// Copyright (c) 2023 Dominic Masters
// 
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

#include "Ray3D.hpp"

using namespace Dawn;

/**
 * Tests a ray against a sphere.
 *
 * Solves |origin + t * direction - center|^2 = radius^2 for t and reports the
 * smallest non-negative root. If the ray origin lies inside the sphere the
 * near root is behind the origin, so the far (exit) root is reported instead.
 *
 * @param ray Ray to test. The direction need not be normalized; the reported
 *            distance is the ray parameter t, which equals the Euclidean
 *            distance only when the direction has unit length.
 * @param sphere Sphere to test against.
 * @param hit Receives the hit point (written only when a hit is found).
 * @param normal Receives the normalized vector from the sphere center to the
 *               hit point.
 * @param distance Receives the ray parameter t of the hit.
 * @return True if the ray hits the sphere at some t >= 0.
 */
bool_t Dawn::raytestSphere(
  struct Ray3D ray,
  struct PhysicsSphere sphere,
  glm::vec3 *hit,
  glm::vec3 *normal,
  float_t *distance
) {
  assertNotNull(hit, "Ray3D::raytestSphere: hit cannot be null");
  assertNotNull(normal, "Ray3D::raytestSphere: normal cannot be null");
  assertNotNull(distance, "Ray3D::raytestSphere: distance cannot be null");

  const glm::vec3 toOrigin = ray.origin - sphere.center;
  const float_t a = glm::dot(ray.direction, ray.direction);

  // A zero-length direction would divide by zero below and report a NaN hit.
  if(a <= 0.0f) return false;

  const float_t b = 2.0f * glm::dot(ray.direction, toOrigin);
  const float_t c = glm::dot(toOrigin, toOrigin) - sphere.radius * sphere.radius;

  const float_t discriminant = b * b - 4.0f * a * c;
  if(discriminant < 0.0f) return false;

  const float_t sqrtDiscriminant = sqrtf(discriminant);
  float_t t = (-b - sqrtDiscriminant) / (2.0f * a);
  if(t < 0.0f) {
    // Near root is behind the origin; the origin may be inside the sphere,
    // in which case the far root is the exit point in front of the ray.
    t = (-b + sqrtDiscriminant) / (2.0f * a);
    if(t < 0.0f) return false; // Sphere lies entirely behind the ray.
  }

  *hit = ray.origin + t * ray.direction;
  *normal = glm::normalize(*hit - sphere.center);
  *distance = t;

  return true;
}

/**
 * Tests a ray against a triangle.
 *
 * Intersects the ray with the triangle's plane, then checks that the plane
 * hit lies on the inner side of all three edges. The face normal follows the
 * winding v0 -> v1 -> v2 (normalize(cross(v1 - v0, v2 - v0))).
 *
 * @param ray Ray to test.
 * @param triangle Triangle to test against.
 * @param hitPoint Receives the intersection point (written only on a hit).
 * @param hitNormal Receives the unit face normal (written only on a hit).
 * @param hitDistance Receives the ray parameter t of the hit.
 * @return True if the ray hits the triangle at some t >= 0.
 */
bool_t Dawn::raytestTriangle(
  struct Ray3D ray,
  struct PhysicsTriangle triangle,
  glm::vec3 *hitPoint,
  glm::vec3 *hitNormal,
  float_t *hitDistance
) {
  assertNotNull(hitPoint, "Ray3D::raytestTriangle: hitPoint cannot be null");
  assertNotNull(hitNormal, "Ray3D::raytestTriangle: hitNormal cannot be null");
  assertNotNull(hitDistance, "Ray3D::raytestTriangle: hitDistance cannot be null");

  // Edges walked in winding order; reused by both the normal and the inside test.
  const glm::vec3 edge0 = triangle.v1 - triangle.v0;
  const glm::vec3 edge1 = triangle.v2 - triangle.v1;
  const glm::vec3 edge2 = triangle.v0 - triangle.v2;

  // Face normal. A zero-area triangle has no plane to intersect.
  const glm::vec3 faceCross = glm::cross(edge0, triangle.v2 - triangle.v0);
  const float_t faceCrossLength = glm::length(faceCross);
  if(faceCrossLength <= 0.0f) return false;
  const glm::vec3 normal = faceCross / faceCrossLength;

  // Ray parallel to the plane: no single intersection point.
  // (This previously returned -1, which converts to true in a bool_t function.)
  const float_t denominator = glm::dot(normal, ray.direction);
  if(denominator == 0.0f) return false;

  // Ray parameter at which the ray meets the triangle's plane.
  const float_t d = glm::dot(triangle.v0 - ray.origin, normal) / denominator;

  // Plane intersection is behind the ray origin.
  if(d < 0.0f) return false;

  const glm::vec3 intersectionPoint = ray.origin + d * ray.direction;

  // The point is inside when it is on the inner side of every edge, i.e. each
  // edge x (point - edge start) points the same way as the face normal.
  const bool_t inside =
    glm::dot(glm::cross(edge0, intersectionPoint - triangle.v0), normal) >= 0.0f &&
    glm::dot(glm::cross(edge1, intersectionPoint - triangle.v1), normal) >= 0.0f &&
    glm::dot(glm::cross(edge2, intersectionPoint - triangle.v2), normal) >= 0.0f;
  if(!inside) return false;

  *hitPoint = intersectionPoint;
  *hitNormal = normal;
  *hitDistance = d;
  return true;
}

/**
 * Tests a ray against an axis-aligned bounding box using the slab method.
 *
 * For each axis the ray is clipped against the pair of planes bounding the
 * box on that axis. The latest entry and the earliest exit across all axes
 * bound the overlap. The reported face is the one whose slab produced the
 * chosen t, so the normal is always written on a hit, without relying on an
 * epsilon match against the hit point.
 *
 * If the ray origin is inside the box, the entry point is behind the origin,
 * so the exit point is reported instead (with that face's outward normal).
 *
 * @param ray Ray to test. The direction need not be normalized; the reported
 *            distance is the ray parameter t.
 * @param box Box to test against.
 * @param point Receives the hit point (written only on a hit).
 * @param normal Receives the outward axis-aligned unit normal of the hit face.
 * @param distance Receives the ray parameter t of the hit (>= 0).
 * @return True if the ray hits the box at some t >= 0.
 */
bool_t Dawn::raytestAABB(
  struct Ray3D ray,
  struct AABB3D box,
  glm::vec3 *point,
  glm::vec3 *normal,
  float_t *distance
) {
  assertNotNull(point, "Ray3D::raytestAABB: point cannot be null");
  assertNotNull(normal, "Ray3D::raytestAABB: normal cannot be null");
  assertNotNull(distance, "Ray3D::raytestAABB: distance cannot be null");

  float_t tNear = 0.0f;
  float_t tFar = 0.0f;
  int nearAxis = -1; // Axis whose slab produced tNear, -1 until set.
  int farAxis = -1;  // Axis whose slab produced tFar, -1 until set.

  for(int axis = 0; axis < 3; axis++) {
    const float_t origin = ray.origin[axis];
    const float_t direction = ray.direction[axis];

    if(direction == 0.0f) {
      // Parallel to this slab: the ray never crosses its planes, so it misses
      // unless the origin already lies between them. Handling this explicitly
      // avoids 0 * inf = NaN when the origin sits exactly on a plane.
      if(origin < box.min[axis] || origin > box.max[axis]) return false;
      continue;
    }

    const float_t invDir = 1.0f / direction;
    const float_t tA = (box.min[axis] - origin) * invDir;
    const float_t tB = (box.max[axis] - origin) * invDir;
    const float_t tEnter = glm::min(tA, tB);
    const float_t tExit = glm::max(tA, tB);

    if(nearAxis < 0 || tEnter > tNear) {
      tNear = tEnter;
      nearAxis = axis;
    }
    if(farAxis < 0 || tExit < tFar) {
      tFar = tExit;
      farAxis = axis;
    }
  }

  // Zero-length direction: no slab constrained the ray, so there is no hit.
  if(nearAxis < 0) return false;

  // Slab intervals do not overlap.
  if(tNear >= tFar) return false;

  // The whole overlap is behind the ray origin.
  if(tFar < 0.0f) return false;

  float_t t;
  int hitAxis;
  float_t side;
  if(tNear >= 0.0f) {
    // Entering face: its outward normal opposes the ray direction on that axis.
    t = tNear;
    hitAxis = nearAxis;
    side = ray.direction[hitAxis] > 0.0f ? -1.0f : 1.0f;
  } else {
    // Origin inside the box: report the exit face, whose outward normal
    // follows the ray direction on that axis.
    t = tFar;
    hitAxis = farAxis;
    side = ray.direction[hitAxis] > 0.0f ? 1.0f : -1.0f;
  }

  glm::vec3 faceNormal(0.0f);
  faceNormal[hitAxis] = side;

  *point = ray.origin + t * ray.direction;
  *normal = faceNormal;
  *distance = t;

  return true;
}

/**
 * Tests a ray against a box that has been transformed by an arbitrary affine
 * transform (e.g. a rotated/scaled cube).
 *
 * The ray is moved into the box's local space, tested with raytestAABB, and
 * the hit is mapped back to world space.
 *
 * @param ray World-space ray to test.
 * @param box Box bounds in the box's local space.
 * @param transform Local-to-world transform of the box; assumed invertible.
 * @param point Receives the world-space hit point (written only on a hit).
 * @param normal Receives the world-space unit surface normal.
 * @param distance Receives the ray parameter t of the hit, valid for the
 *                 world-space ray (point == ray.origin + t * ray.direction).
 * @return True if the ray hits the box.
 */
bool_t Dawn::raytestCube(
  struct Ray3D ray,
  struct AABB3D box,
  glm::mat4 transform,
  glm::vec3 *point,
  glm::vec3 *normal,
  float_t *distance
) {
  assertNotNull(point, "Ray3D::raytestCube: point cannot be null");
  assertNotNull(normal, "Ray3D::raytestCube: normal cannot be null");
  assertNotNull(distance, "Ray3D::raytestCube: distance cannot be null");

  const glm::mat4 inverseTransform = glm::inverse(transform);

  // Move the ray into the box's local space. The direction is deliberately
  // NOT renormalized: an affine transform preserves the ray parameter t, so the
  // distance raytestAABB reports is also correct for the world-space ray.
  // Renormalizing would report distances in local (possibly scaled) units.
  struct Ray3D localRay;
  localRay.origin = glm::vec3(inverseTransform * glm::vec4(ray.origin, 1.0f));
  localRay.direction = glm::vec3(inverseTransform * glm::vec4(ray.direction, 0.0f));

  if(!raytestAABB(localRay, box, point, normal, distance)) return false;

  // Points map with the transform itself; normals map with the
  // inverse-transpose so they stay perpendicular under non-uniform scale.
  *point = glm::vec3(transform * glm::vec4(*point, 1.0f));
  *normal = glm::normalize(
    glm::vec3(glm::transpose(inverseTransform) * glm::vec4(*normal, 0.0f))
  );

  return true;
}

/**
 * Tests a ray against a rectangle lying in the local z = 0 plane.
 *
 * The quad spans [min, max] on the local x/y axes and is placed in the world
 * by `transform`. Only hits at or in front of the ray origin count.
 *
 * @param ray World-space ray to test.
 * @param min Lower x/y corner of the quad in local space.
 * @param max Upper x/y corner of the quad in local space.
 * @param transform Local-to-world transform of the quad; assumed invertible.
 * @param point Receives the world-space hit point (written only on a hit).
 * @param normal Receives the world-space unit normal of the quad's local +z
 *               face (the same for hits from either side).
 * @param distance Receives the ray parameter t of the hit, valid for the
 *                 world-space ray.
 * @return True if the ray hits the quad.
 */
bool_t Dawn::raytestQuad(
  struct Ray3D ray,
  glm::vec2 min,
  glm::vec2 max,
  glm::mat4 transform,
  glm::vec3 *point,
  glm::vec3 *normal,
  float_t *distance
) {
  assertNotNull(point, "Ray3D::raytestQuad: point cannot be null");
  assertNotNull(normal, "Ray3D::raytestQuad: normal cannot be null");
  assertNotNull(distance, "Ray3D::raytestQuad: distance cannot be null");

  // Transform the ray into the quad's local space (quad lies in z = 0 there).
  const glm::mat4 inverseTransform = glm::inverse(transform);
  const glm::vec3 localOrigin = glm::vec3(inverseTransform * glm::vec4(ray.origin, 1.0f));
  const glm::vec3 localDirection = glm::vec3(inverseTransform * glm::vec4(ray.direction, 0.0f));

  // Ray parallel to the quad's plane. Without this guard the division below
  // yields inf/NaN, and a NaN t slips through every comparison as a "hit".
  if(localDirection.z == 0.0f) return false;

  // Ray parameter at which local z reaches 0.
  const float_t t = -localOrigin.z / localDirection.z;
  if(t < 0.0f) return false; // Plane is behind the ray origin.

  const glm::vec2 localHit = glm::vec2(localOrigin) + t * glm::vec2(localDirection);
  if(
    glm::any(glm::lessThan(localHit, min)) ||
    glm::any(glm::greaterThan(localHit, max))
  ) {
    return false; // Plane hit lies outside the quad's bounds.
  }

  *distance = t;
  *point = glm::vec3(transform * glm::vec4(localHit, 0.0f, 1.0f));
  // Normals map with the inverse-transpose so they stay perpendicular to the
  // surface under non-uniform scale combined with rotation.
  *normal = glm::normalize(
    glm::vec3(glm::transpose(inverseTransform) * glm::vec4(0.0f, 0.0f, 1.0f, 0.0f))
  );

  return true;
}

/**
 * Tests a ray against a capsule (two end-cap spheres joined by a cylinder).
 *
 * End caps are tested exactly via raytestSphere and the nearest cap hit wins.
 * Otherwise the cylinder part is approximated by comparing a clamped closest
 * point on the ray against a clamped closest point on the capsule axis.
 *
 * NOTE(review): the capsule axis is derived from the *ray direction*, not from
 * the capsule, so the tested shape rotates with the ray. The intended axis
 * convention of PhysicsCapsule (and whether `origin` is the base or centre)
 * is not visible in this file. TODO confirm and fix before relying on this.
 *
 * @param ray Ray to test.
 * @param capsule Capsule to test against.
 * @param point Receives the hit point (written only on a hit).
 * @param normal Receives the unit surface normal at the hit point.
 * @param distance Receives the hit distance (ray parameter for cap hits).
 * @return True if a hit was found.
 */
bool_t Dawn::raytestCapsule(
  struct Ray3D ray,
  struct PhysicsCapsule capsule,
  glm::vec3 *point,
  glm::vec3 *normal,
  float_t *distance
) {
  assertNotNull(point, "Ray3D::raytestCapsule: point cannot be null");
  assertNotNull(normal, "Ray3D::raytestCapsule: normal cannot be null");
  assertNotNull(distance, "Ray3D::raytestCapsule: distance cannot be null");

  // Axis end points; these are also the end-cap sphere centres.
  const glm::vec3 capsuleAxis = glm::normalize(ray.direction); // See NOTE(review) above.
  const glm::vec3 capsuleP0 = capsule.origin;
  const glm::vec3 capsuleP1 = capsule.origin + capsule.height * capsuleAxis;
  const float_t radius = capsule.radius;

  // End caps: test both and keep the nearest hit. raytestSphere already gives
  // the normal relative to the sphere that was actually hit. (The previous code
  // recomputed it from P0 even for P1 hits, and reported P0 whenever both hit.)
  const glm::vec3 capCenters[2] = { capsuleP0, capsuleP1 };
  bool_t capHit = false;
  glm::vec3 capPoint(0.0f);
  glm::vec3 capNormal(0.0f);
  float_t capDistance = 0.0f;
  for(const glm::vec3 &center : capCenters) {
    glm::vec3 p, n;
    float_t d;
    if(!raytestSphere(ray, { .center = center, .radius = radius }, &p, &n, &d)) continue;
    if(!capHit || d < capDistance) {
      capHit = true;
      capPoint = p;
      capNormal = n;
      capDistance = d;
    }
  }
  if(capHit) {
    *point = capPoint;
    *normal = capNormal;
    *distance = capDistance;
    return true;
  }

  // Cylinder part (approximate): closest point on the axis segment to the ray
  // origin, measured from whichever end point is nearer the origin.
  glm::vec3 closestPointAxis;
  if(glm::distance(ray.origin, capsuleP0) < glm::distance(ray.origin, capsuleP1)) {
    closestPointAxis = capsuleP0 + glm::clamp(
      glm::dot(ray.origin - capsuleP0, capsuleAxis), 0.0f, capsule.height
    ) * capsuleAxis;
  } else {
    closestPointAxis = capsuleP1 - glm::clamp(
      glm::dot(ray.origin - capsuleP1, -capsuleAxis), 0.0f, capsule.height
    ) * capsuleAxis;
  }

  // Point on the ray nearest that axis point, with the ray parameter clamped
  // to [0, |direction|].
  const glm::vec3 closestPointRay = ray.origin + glm::clamp(
    glm::dot(closestPointAxis - ray.origin, ray.direction),
    0.0f, glm::length(ray.direction)
  ) * ray.direction;

  // Miss if those two points are farther apart than the capsule radius.
  const glm::vec3 separation = closestPointRay - closestPointAxis;
  if(glm::dot(separation, separation) > radius * radius) return false;

  *distance = glm::distance(ray.origin, closestPointRay);
  *point = closestPointRay;
  *normal = glm::normalize(*point - closestPointAxis);
  return true;
}