/**
 * @brief A ray (a bundle of particles) travelling through a triangulated geometry.
 *
 * A Ray carries an origin, a velocity vector and a mass flux. propagate() finds the
 * nearest triangle the ray hits and spawns child rays whose velocities are produced
 * by a user-supplied gas-surface interaction kernel.
 *
 * @tparam T Floating-point type used for computations (e.g., float, double).
 */
template <typename T>
class Ray {

    private:
        Matrix<T> origin;               ///< Origin point of the ray (expected to be a 3x1 column vector).
        Matrix<T> velocity;             ///< Velocity vector of the ray (expected to be a 3x1 column vector).
        T mass_flux{};                  ///< Mass flux carried by the ray.
        unsigned int children{0};       ///< Number of child rays this ray spawns when it collides.
        int spawn_vertex_index{-1};     ///< Index of the triangle this ray was spawned from (-1 when it has none).
        bool failed = false;            ///< Failure flag; NOTE(review): never set to true by the members defined with this class.

    public:
        /**
         * @brief Default constructor: a ray at rest at the coordinate origin.
         *
         * Origin and velocity are 3x1 zero vectors, the ray carries no mass flux,
         * spawns no children and has no spawn triangle (spawn_vertex_index == -1).
         */
        Ray();

        /**
         * @brief Builds a ray from explicit state.
         *
         * The origin and velocity matrices are copied into the new ray.
         *
         * @param origin Origin point of the ray.
         * @param velocity Velocity vector of the ray.
         * @param m_dot Mass flux carried by the ray.
         * @param children Number of child rays this ray spawns when it collides.
         * @param spawn_vertex_index Index of the triangle the ray leaves from (-1 for none).
         */
        Ray(Matrix<T>& origin, Matrix<T>& velocity, T m_dot, unsigned int children, int spawn_vertex_index);

        /**
         * @brief Copy constructor.
         *
         * @param other The ray object to copy from.
         */
        Ray(const Ray<T>& other);

        /**
         * @brief Move constructor.
         *
         * Takes over the state of @p other and resets the source's mass flux and
         * child count to zero.
         *
         * @param other The ray object to move from.
         */
        Ray(Ray<T>&& other);

        /**
         * @brief Destructor.
         */
        ~Ray();

        /**
         * @brief Copy assignment operator.
         *
         * @param other The ray object to copy from.
         * @return Ray& Reference to the current ray object.
         */
        Ray& operator=(const Ray& other);

        /**
         * @brief Move assignment operator.
         *
         * Takes over the state of @p other and resets the source's mass flux and
         * child count to zero.
         *
         * @param other The ray object to move from.
         * @return Ray& Reference to the current ray object.
         */
        Ray& operator=(Ray&& other);

        /**
         * @brief Mutable access to the velocity vector of the ray.
         *
         * @return Matrix<T>& Reference to the stored velocity.
         */
        Matrix<T>& get_velocity() { return this->velocity; }

        /**
         * @brief Read-only access to the velocity vector (usable on const rays).
         *
         * @return const Matrix<T>& Reference to the stored velocity.
         */
        const Matrix<T>& get_velocity() const { return this->velocity; }

        /**
         * @brief Sets the velocity of the ray.
         *
         * @param new_velocity The new velocity vector to store.
         */
        void set_velocity(Matrix<T> new_velocity) { this->velocity = new_velocity; }

        /**
         * @brief Mutable access to the origin of the ray.
         *
         * @return Matrix<T>& Reference to the stored origin.
         */
        Matrix<T>& get_origin() { return this->origin; }

        /**
         * @brief Read-only access to the origin of the ray (usable on const rays).
         *
         * @return const Matrix<T>& Reference to the stored origin.
         */
        const Matrix<T>& get_origin() const { return this->origin; }

        /**
         * @brief Sets the origin of the ray.
         *
         * @param new_origin The new origin point to store.
         */
        void set_origin(Matrix<T> new_origin) { this->origin = new_origin; }

        /**
         * @brief Gets the mass flux of the ray.
         *
         * @return T The current mass flux of the ray.
         */
        T get_mass_flux() const { return mass_flux; }

        /**
         * @brief Sets the mass flux of the ray.
         *
         * @param mass_flux The new mass flux to store.
         */
        void set_mass_flux(T mass_flux) { this->mass_flux = mass_flux; }

        /**
         * @brief Tests whether the ray hits the triangle (p1, p2, p3) ahead of its origin.
         *
         * Uses the Möller–Trumbore ray/triangle algorithm with the unit direction of
         * the ray's velocity. The outputs are written only when a hit is reported.
         *
         * @param p1 The first vertex of the triangle.
         * @param p2 The second vertex of the triangle.
         * @param p3 The third vertex of the triangle.
         * @param normal The normal vector of the triangle (not used by the current implementation).
         * @param col_position Output: position of the collision.
         * @param distance Output: distance from the ray origin to the collision point.
         * @return true If the ray hits the triangle in front of its origin, false otherwise.
         */
        bool check_vertex_collision(Matrix<T>& p1, Matrix<T>& p2, Matrix<T>& p3, Matrix<T>& normal, Matrix<T>& col_position, T& distance);

        /**
         * @brief Checks if the ray collides with a bounding box.
         *
         * @param bb The bounding box to check collision with (representation defined by the implementation).
         * @return true If the ray collides with the bounding box, false otherwise.
         */
        bool check_bb_collision(Matrix<T>& bb);

        /**
         * @brief Propagates the ray to its nearest collision and spawns reflected child rays.
         *
         * Every triangle of @p geometry except the one the ray was spawned from is
         * tested; the closest hit wins. At the hit, the kernel is invoked once per
         * child as kernel(vel_in, vel_out, gsi_props, gas_props, surface_props), with
         * velocities expressed as {tangential 1, tangential 2, normal} components and
         * gas_props holding {temperature, density, molar mass, speed}.
         *
         * @param geometry The triangulated geometry through which the ray propagates.
         * @param gas The gas properties passed to the kernel.
         * @param kernel The gas-surface interaction kernel computing the reflected velocity.
         * @param collision_index Receives the index of the triangle that was hit; written only when a collision occurs.
         * @return std::vector<Ray<T>> Child rays spawned at the collision; empty when nothing is hit.
         */
        std::vector<Ray<T>> propagate(Geometry<T>& geometry, const Gas<T>& gas, void (*kernel)(T*, T*, T*, T*, T*), unsigned int* collision_index);
};

/**
 * @brief Default constructor for the Ray class.
 * 
 * Initializes the ray with default values for origin, velocity, mass flux, and other properties.
 */
template <typename T>
Ray<T>::Ray() {
    origin = Matrix<T>(3, 1, {0.0, 0.0, 0.0});
    velocity = Matrix<T>(3, 1, {0.0, 0.0, 0.0});
    mass_flux = 0.0;
    children = 0;
    spawn_vertex_index = -1;
}

/**
 * @brief Builds a ray from explicitly supplied state.
 *
 * The origin and velocity matrices are copied into the new ray; every other
 * field is taken verbatim from the arguments.
 *
 * @param origin             Starting point of the ray.
 * @param velocity           Velocity vector of the ray.
 * @param m_dot              Mass flux carried by the ray.
 * @param children           Number of child rays to spawn at the next collision.
 * @param spawn_vertex_index Index of the triangle the ray leaves from (-1 for none).
 */
template <typename T>
Ray<T>::Ray(Matrix<T>& origin, Matrix<T>& velocity, T m_dot, unsigned int children, int spawn_vertex_index)
    : origin(origin),
      velocity(velocity),
      mass_flux(m_dot),
      children(children),
      spawn_vertex_index(spawn_vertex_index) {
}

/**
 * @brief Copy constructor: creates an independent copy of another ray.
 *
 * All state is copied, including the failure flag. Members are listed in
 * declaration order, which is the order the compiler actually initializes
 * them in (avoids -Wreorder and any confusion about evaluation order).
 *
 * @param other The ray object to copy from.
 */
template <typename T>
Ray<T>::Ray(const Ray<T>& other)
    : origin(other.origin),
      velocity(other.velocity),
      mass_flux(other.mass_flux),
      children(other.children),
      spawn_vertex_index(other.spawn_vertex_index),
      failed(other.failed) {}

/**
 * @brief Move constructor: takes over the state of another ray.
 *
 * The origin and velocity matrices are moved, not copied. The source is left as
 * an inert ray: zero mass flux, no children, and no spawn triangle
 * (spawn_vertex_index == -1, the same "none" value the default constructor
 * uses). Its matrices are in whatever moved-from state Matrix<T> defines.
 *
 * NOTE(review): the constructor is not declared noexcept, so std::vector<Ray>
 * will still copy elements on reallocation.
 *
 * @param other The ray object to move from.
 */
template <typename T>
Ray<T>::Ray(Ray<T>&& other)
    : origin(std::move(other.origin)),
      velocity(std::move(other.velocity)),
      mass_flux(other.mass_flux),
      children(other.children),
      spawn_vertex_index(other.spawn_vertex_index),
      failed(other.failed) {
    other.mass_flux = 0.0;
    other.children = 0;
    // -1 means "no spawn triangle". 0 would be a real triangle index that
    // propagate() would then skip.
    other.spawn_vertex_index = -1;
}

/**
 * @brief Destructor.
 *
 * Ray owns no resources beyond its members, whose own destructors release
 * them. Writing to the fields of an object that is being destroyed has no
 * observable effect, so the compiler-generated behaviour is sufficient.
 */
template <typename T>
Ray<T>::~Ray() = default;

/**
 * @brief Copy assignment operator.
 *
 * Replaces this ray's state with a copy of @p other's state, including the
 * failure flag, consistent with the copy constructor. Self-assignment is a no-op.
 *
 * @param other The ray object to copy from.
 * @return Ray& Reference to the current ray object.
 */
template <typename T>
Ray<T>& Ray<T>::operator=(const Ray<T>& other) {
    if (this == &other) {
        return *this; // Nothing to do; avoids copying the matrices onto themselves.
    }
    origin = other.origin;
    velocity = other.velocity;
    mass_flux = other.mass_flux;
    children = other.children;
    spawn_vertex_index = other.spawn_vertex_index;
    failed = other.failed;
    return *this;
}

/**
 * @brief Move assignment operator.
 *
 * Takes over @p other's state. The matrices are moved, not copied. The source is
 * then left as an inert ray: zero mass flux, no children, and no spawn triangle
 * (spawn_vertex_index == -1). Self-move leaves the ray unchanged. Without that
 * guard, the reset below would wipe this ray's own flux and child count.
 *
 * @param other The ray object to move from.
 * @return Ray& Reference to the current ray object.
 */
template <typename T>
Ray<T>& Ray<T>::operator=(Ray<T>&& other) {
    if (this == &other) {
        return *this;
    }

    origin = std::move(other.origin);
    velocity = std::move(other.velocity);
    mass_flux = other.mass_flux;
    children = other.children;
    spawn_vertex_index = other.spawn_vertex_index;
    failed = other.failed;

    // Reset the moved-from ray. -1 is the "no spawn triangle" value used by the
    // default constructor; 0 would be a real triangle index.
    other.mass_flux = 0.0;
    other.children = 0;
    other.spawn_vertex_index = -1;
    return *this;
}

/**
 * @brief Ray/triangle intersection test (Möller–Trumbore).
 *
 * Solves origin + t * d = p1 + u * (p2 - p1) + v * (p3 - p1) for (t, u, v), where
 * d is the unit direction of the ray's velocity. A hit requires:
 *  - the ray not to be parallel to the triangle's plane (|det| >= EPSILON),
 *  - the barycentric coordinates to lie inside the triangle, within EPSILON, and
 *  - the hit to lie strictly ahead of the origin (t > EPSILON).
 * A zero velocity presumably yields a NaN direction, for which every comparison
 * fails and false is returned.
 *
 * @param p1           First triangle vertex.
 * @param p2           Second triangle vertex.
 * @param p3           Third triangle vertex.
 * @param normal       Triangle normal; not used by this test.
 * @param col_position Output: hit point, written only on a hit.
 * @param distance     Output: distance from the origin to the hit point (t along
 *                     the unit direction), written only on a hit.
 * @return true if the ray hits the triangle ahead of its origin, false otherwise.
 */
template <typename T>
bool Ray<T>::check_vertex_collision(Matrix<T>& p1, Matrix<T>& p2, Matrix<T>& p3, Matrix<T>& normal, Matrix<T>& col_position, T& distance) {
    (void)normal; // The plane orientation is implied by the vertex order.

    const T EPSILON = 1e-15; // Tolerance for the parallel test, barycentric bounds and minimum t

    // Unit direction, so that the ray parameter below is a Euclidean distance.
    Matrix<T> dir = velocity / velocity.norm();

    // Triangle edges sharing vertex p1.
    Matrix<T> e1 = p2 - p1;
    Matrix<T> e2 = p3 - p1;

    // det = e1 . (dir x e2); near zero means dir lies in the triangle's plane.
    Matrix<T> pvec = dir.cross(e2);
    T det = e1.tr().dot(pvec)(0, 0);
    const bool is_parallel = (det > -EPSILON) && (det < EPSILON);
    if (is_parallel) {
        return false;
    }
    T inv_det = 1.0 / det;

    // First barycentric coordinate.
    Matrix<T> tvec = (origin - p1);
    T bary_u = inv_det * tvec.tr().dot(pvec)(0, 0);
    const bool u_outside = (bary_u < -EPSILON) || (bary_u > 1.0 + EPSILON);
    if (u_outside) {
        return false;
    }

    // Second barycentric coordinate.
    Matrix<T> qvec = tvec.cross(e1);
    T bary_v = inv_det * dir.tr().dot(qvec)(0, 0);
    const bool v_outside = (bary_v < -EPSILON) || (bary_u + bary_v > 1.0 + EPSILON);
    if (v_outside) {
        return false;
    }

    // Ray parameter of the hit; hits at or behind the origin do not count.
    T hit_t = inv_det * e2.tr().dot(qvec)(0, 0);
    if (!(hit_t > EPSILON)) {
        return false;
    }

    distance = hit_t; // hit_t > EPSILON > 0, so it already equals its absolute value
    col_position = origin + dir * hit_t;
    return true;
}

/**
 * @brief Propagates the ray to its nearest collision and spawns reflected child rays.
 *
 * Every triangle of @p geometry except the one this ray was spawned from is tested
 * with check_vertex_collision(); the closest hit wins. At the hit point a local
 * frame {tan1, tan2, normal} is built and the incoming velocity is expressed in it.
 * @p kernel is then called once per child to obtain the reflected velocity
 * components. Each child carries mass_flux / children, rescaled by |v_in| / |v_out|.
 *
 * Kernel call convention:
 *   kernel(vel_in, vel_out, gsi_props, gas_props, surface_props)
 *   - vel_in / vel_out: {tan1, tan2, normal} velocity components (4 slots, 3 used).
 *   - gsi_props / surface_props: per-triangle property values at the hit triangle
 *     (one spare trailing slot, set to zero).
 *   - gas_props: {temperature, density, molar mass, speed} (5 slots, 4 used).
 *
 * @param geometry        Triangulated geometry to test against.
 * @param gas             Gas properties passed to the kernel.
 * @param kernel          Gas-surface interaction kernel (see above).
 * @param collision_index If non-null and a collision occurs, receives the index of
 *                        the triangle that was hit. Left untouched otherwise.
 * @return Child rays spawned at the collision; empty when the ray hits nothing.
 */
template <typename T>
std::vector<Ray<T>> Ray<T>::propagate(Geometry<T>& geometry, const Gas<T>& gas, void (*kernel)(T*, T*, T*, T*, T*), unsigned int* collision_index) {

    std::vector<Ray<T>> children_rays; // Child rays generated at the collision point

    // Triangle vertices and normals, stored as one array per coordinate component.
    Matrix<T> p1_x = geometry.get_p1_x(), p1_y = geometry.get_p1_y(), p1_z = geometry.get_p1_z();
    Matrix<T> p2_x = geometry.get_p2_x(), p2_y = geometry.get_p2_y(), p2_z = geometry.get_p2_z();
    Matrix<T> p3_x = geometry.get_p3_x(), p3_y = geometry.get_p3_y(), p3_z = geometry.get_p3_z();
    Matrix<T> norm_x = geometry.get_norm_x(), norm_y = geometry.get_norm_y(), norm_z = geometry.get_norm_z();

    // ---- 1. Find the closest triangle hit by this ray -------------------------
    Matrix<T> norm_col(3, 1), pos_col(3, 1); // Normal and hit point of the closest collision
    T min_dist = 0.0;                        // Distance to the closest collision (valid once hit >= 0)
    int hit = -1;                            // Index of the closest hit triangle; -1 while none found

    const int num_vertices = static_cast<int>(geometry.get_num_vertices());
    for (int i = 0; i < num_vertices; i++) {
        // A ray never collides with the triangle it was spawned from. Skip it
        // before paying for the intersection test.
        if (i == spawn_vertex_index) continue;

        Matrix<T> p1(3, 1, {p1_x(i), p1_y(i), p1_z(i)});
        Matrix<T> p2(3, 1, {p2_x(i), p2_y(i), p2_z(i)});
        Matrix<T> p3(3, 1, {p3_x(i), p3_y(i), p3_z(i)});
        Matrix<T> norm(3, 1, {norm_x(i), norm_y(i), norm_z(i)});
        Matrix<T> pos(3, 1);
        T dist = 0.0;

        if (!check_vertex_collision(p1, p2, p3, norm, pos, dist)) continue;

        // Keep the nearest hit. Tracking "no hit yet" explicitly, rather than with
        // a large sentinel distance, accepts hits at any range.
        if (hit < 0 || dist < min_dist) {
            norm_col = norm;
            pos_col = pos;
            min_dist = dist;
            hit = i;
        }
    }

    // No collision: the ray leaves the domain and spawns no children.
    if (hit < 0) return children_rays;
    const unsigned int vertex_index = static_cast<unsigned int>(hit);

    // ---- 2. Build the local frame {tan1, tan2, normal} at the hit point -------
    // NOTE(review): assumes geometry normals are unit length — TODO confirm.
    Matrix<T> norm_direction = norm_col;
    T vel_in_norm = (velocity.tr().dot(norm_direction))(0, 0);

    // tan1 follows the tangential part of the incoming velocity when it exists.
    Matrix<T> tangential = velocity - norm_direction * vel_in_norm;
    auto tangential_norm = tangential.norm();
    const T NORMAL_INCIDENCE_TOL = 1e-10; // Relative to |velocity|
    Matrix<T> tan1_direction(3, 1);
    if (tangential_norm > NORMAL_INCIDENCE_TOL * velocity.norm()) {
        tan1_direction = tangential / tangential_norm;
    } else {
        // Normal (or near-normal) incidence: the tangential part vanishes, and
        // normalizing it would divide by ~0 and produce NaNs. Any unit vector
        // perpendicular to the normal is a valid tangent, so build one from the
        // coordinate axis least aligned with the normal.
        Matrix<T> axis = (fabs(norm_x(hit)) < 0.9) ? Matrix<T>(3, 1, {1.0, 0.0, 0.0})
                                                   : Matrix<T>(3, 1, {0.0, 1.0, 0.0});
        Matrix<T> perpendicular = norm_direction.cross(axis);
        tan1_direction = perpendicular / perpendicular.norm();
    }
    Matrix<T> tan2_direction = norm_direction.cross(tan1_direction); // Completes the right-handed frame

    // Incoming velocity in the local frame. The 4th slot is unused padding.
    T vel_in_tan1 = (velocity.tr().dot(tan1_direction))(0, 0);
    T vel_in_tan2 = (velocity.tr().dot(tan2_direction))(0, 0);
    T vel_in[4] = {vel_in_tan1, vel_in_tan2, vel_in_norm};

    // ---- 3. Gather the kernel inputs at the hit triangle -----------------------
    // RAII-owned buffers (no manual delete[]). A spare trailing slot keeps data()
    // non-null even when there are no properties.
    std::vector<Matrix<T>> GSI_properties(geometry.get_GSI_properties());
    std::vector<Matrix<T>> surface_properties(geometry.get_surface_properties());

    std::vector<T> GSI_prop_col;
    GSI_prop_col.reserve(GSI_properties.size() + 1);
    for (auto& property : GSI_properties) {
        GSI_prop_col.push_back(property(0, vertex_index));
    }
    GSI_prop_col.push_back(T(0));

    std::vector<T> surface_prop_col;
    surface_prop_col.reserve(surface_properties.size() + 1);
    for (auto& property : surface_properties) {
        surface_prop_col.push_back(property(0, vertex_index));
    }
    surface_prop_col.push_back(T(0));

    T gas_prop_col[5] = {gas.get_temperature(), gas.get_density(), gas.get_molar_mass(), gas.get_speed()};

    // ---- 4. Spawn the child rays ------------------------------------------------
    children_rays.reserve(children);
    const auto incoming_speed = velocity.norm();
    for (unsigned int c = 0; c < children; c++) {
        T vel_refl[4] = {0.0, 0.0, 0.0}; // Reflected {tan1, tan2, normal} components written by the kernel
        kernel(vel_in, vel_refl, GSI_prop_col.data(), gas_prop_col, surface_prop_col.data());

        // Back to global coordinates. The normal component is oriented opposite
        // to the incoming normal velocity, so the child leaves the surface
        // (assuming the kernel returns a non-negative normal component).
        Matrix<T> child_velocity = norm_direction * (-vel_refl[2]) * sgn<T>(vel_in[2]) +
                                   tan1_direction * vel_refl[0] +
                                   tan2_direction * vel_refl[1];

        // Split the flux evenly across children and rescale by |v_in| / |v_out|.
        // NOTE(review): a kernel returning a zero velocity makes this divide by zero.
        T child_flux = mass_flux / children * incoming_speed / child_velocity.norm();
        children_rays.emplace_back(pos_col, child_velocity, child_flux, 1u, hit);
    }

    // Report which triangle was hit.
    if (collision_index != nullptr) {
        *collision_index = vertex_index;
    }

    return children_rays;
}
