#include <algorithm>
#include <cmath>
#include <functional>
#include <iostream>
#include <random>
#include <tuple>
#include <vector>

/// @brief Draws one approximate sample from a 1-D density with a random-walk
///        Metropolis chain.
///
/// Starting from x0, the chain makes `steps` Gaussian proposals
/// x' = x + size * N(0, 1). Proposals outside [domain_x[0], domain_x[1]] are
/// rejected; in-domain proposals are accepted with probability
/// min(1, pdf(x') / pdf(x)). The final chain state is returned.
///
/// @param pdf        Density (need not be normalised), called as
///                   pdf(x, parameters). Assumed non-negative and
///                   deterministic, because the value at the current state is
///                   cached between iterations.
/// @param parameters Extra arguments forwarded to every pdf call.
/// @param x0         Starting point of the chain.
/// @param size       Standard deviation of the Gaussian proposal step.
/// @param domain_x   Two-element array {lower, upper}; bounds are inclusive.
/// @param steps      Number of Metropolis iterations (default 100).
/// @return           State of the chain after `steps` iterations (x0 when
///                   steps == 0).
template <typename T>
T random_sampler_1D(T (*pdf)(T, std::vector<T>), std::vector<T> parameters, T x0, T size, T domain_x[],
                    unsigned long steps = 100) {

    std::random_device rd;
    std::seed_seq fullSeed{rd(), rd(), rd(), rd(), rd(), rd(), rd(), rd(), rd(), rd()};
    std::mt19937 rng(fullSeed);
    std::uniform_real_distribution<T> uniformDist(0.0f, 1.0f);
    std::normal_distribution<T> normDist(0.0f, 1.0f);

    // Typed constant so std::min deduces a single T (std::min(1.0, x) does not
    // compile for T = float).
    const T one = 1;

    T x = x0;
    // Density at the current state; it only changes when a move is accepted.
    T pdf_x = pdf(x, parameters);

    for (unsigned long i = 0; i < steps; ++i) {

        const T x_p = x + normDist(rng) * size;
        const T u_sample = uniformDist(rng);

        if (x_p < domain_x[0] || x_p > domain_x[1]) {
            continue;  // out-of-domain proposals are always rejected
        }

        const T pdf_p = pdf(x_p, parameters);
        // From a zero-density state any in-domain move is accepted, which also
        // avoids evaluating 0 / 0.
        const T acceptance = (pdf_x > 0) ? std::min(one, pdf_p / pdf_x) : one;

        // u_sample lies in [0, 1), so the strict comparison accepts with
        // probability exactly `acceptance` and never when acceptance == 0.
        if (u_sample < acceptance) {
            x = x_p;
            pdf_x = pdf_p;
        }
    }

    return x;
}

/// @brief Draws one approximate sample from a 2-D density with a random-walk
///        Metropolis chain.
///
/// Starting from (x0, y0), the chain makes `steps` independent Gaussian
/// proposals in x and y (standard deviations size_x, size_y). Proposals
/// outside the box [domain_x[0], domain_x[1]] x [domain_y[0], domain_y[1]]
/// are rejected; in-box proposals are accepted with probability
/// min(1, pdf(x', y') / pdf(x, y)).
///
/// @param pdf        Density (need not be normalised), called as
///                   pdf(x, y, parameters). Assumed non-negative and
///                   deterministic, because the value at the current state is
///                   cached between iterations.
/// @param parameters Extra arguments forwarded to every pdf call.
/// @param x0, y0     Starting point of the chain.
/// @param size_x     Standard deviation of the proposal step in x.
/// @param size_y     Standard deviation of the proposal step in y.
/// @param domain_x   Two-element array {lower, upper} for x; inclusive.
/// @param domain_y   Two-element array {lower, upper} for y; inclusive.
/// @param steps      Number of Metropolis iterations (default 100).
/// @return           (x, y) state of the chain after `steps` iterations.
template <typename T>
std::tuple<T, T> random_sampler_2D(T (*pdf)(T, T, std::vector<T>), std::vector<T> parameters, T x0, T y0, T size_x, T size_y, T domain_x[], T domain_y[],
                                   unsigned long steps = 100) {

    std::random_device rd;
    std::seed_seq fullSeed{rd(), rd(), rd(), rd(), rd(), rd(), rd(), rd(), rd(), rd()};
    std::mt19937 rng(fullSeed);
    std::uniform_real_distribution<T> uniformDist(0.0f, 1.0f);
    std::normal_distribution<T> normDist(0.0f, 1.0f);

    // Typed constant so std::min deduces a single T (std::min(1.0, x) does not
    // compile for T = float).
    const T one = 1;

    T x = x0;
    T y = y0;
    // Density at the current state; it only changes when a move is accepted.
    T pdf_xy = pdf(x, y, parameters);

    for (unsigned long i = 0; i < steps; ++i) {

        const T x_p = x + normDist(rng) * size_x;
        const T y_p = y + normDist(rng) * size_y;
        const T u_sample = uniformDist(rng);

        const bool inside = x_p >= domain_x[0] && x_p <= domain_x[1] &&
                            y_p >= domain_y[0] && y_p <= domain_y[1];
        if (!inside) {
            continue;  // out-of-domain proposals are always rejected
        }

        const T pdf_p = pdf(x_p, y_p, parameters);
        // From a zero-density state any in-domain move is accepted, which also
        // avoids evaluating 0 / 0.
        const T acceptance = (pdf_xy > 0) ? std::min(one, pdf_p / pdf_xy) : one;

        // u_sample lies in [0, 1), so the strict comparison accepts with
        // probability exactly `acceptance` and never when acceptance == 0.
        if (u_sample < acceptance) {
            x = x_p;
            y = y_p;
            pdf_xy = pdf_p;
        }
    }

    return std::make_tuple(x, y);
}

/// @brief Draws one approximate sample of a pair of angles with a random-walk
///        Metropolis chain.
///
/// The code treats theta_r1 as a polar angle and theta_r2 as an azimuth.
/// Each step proposes Gaussian increments (standard deviations size_x,
/// size_y). A proposal with a negative polar angle is reflected through the
/// pole (theta -> -theta, phi -> phi + pi), and the azimuth is then wrapped
/// into [0, 2*pi] regardless of how large the step was. Proposals whose polar
/// angle is >= domain_x[1] are rejected; the rest are accepted with
/// probability min(1, pdf(p') / pdf(p)).
///
/// @param pdf        Density over (polar, azimuth), called as
///                   pdf(theta_r1, theta_r2, parameters). Assumed
///                   non-negative and deterministic (the current value is
///                   cached). NOTE(review): any sin(theta) Jacobian must be
///                   included by pdf itself — confirm with callers.
/// @param parameters Extra arguments forwarded to every pdf call.
/// @param theta_r1   Starting polar angle.
/// @param theta_r2   Starting azimuth.
/// @param size_x     Standard deviation of the polar-angle step.
/// @param size_y     Standard deviation of the azimuth step.
/// @param domain_x   Only domain_x[1] is read: the exclusive upper bound on
///                   the polar angle. The lower bound is implicitly 0 via the
///                   pole reflection.
/// @param domain_y   Not read; the azimuth is always wrapped into [0, 2*pi].
///                   Kept for interface compatibility.
/// @param steps      Number of Metropolis iterations (default 100).
/// @return           (polar, azimuth) state after `steps` iterations.
template <typename T>
std::tuple<T, T> random_sampler_angles(T (*pdf)(T, T, std::vector<T>), std::vector<T> parameters, T theta_r1, T theta_r2, T size_x, T size_y, T domain_x[], T domain_y[],
                                       unsigned long steps = 100) {

    static_cast<void>(domain_y);  // intentionally unused; see doc comment

    std::random_device rd;
    std::seed_seq fullSeed{rd(), rd(), rd(), rd(), rd(), rd(), rd(), rd(), rd(), rd()};
    std::mt19937 rng(fullSeed);
    std::uniform_real_distribution<T> uniformDist(0.0f, 1.0f);
    std::normal_distribution<T> normDist(0.0f, 1.0f);

    // Typed constant so std::min deduces a single T.
    const T one = 1;
    // Portable replacement for the non-standard M_PI macro.
    const T pi = static_cast<T>(3.14159265358979323846L);
    const T two_pi = 2 * pi;

    // Density at the current state; it only changes when a move is accepted.
    T pdf_cur = pdf(theta_r1, theta_r2, parameters);

    for (unsigned long i = 0; i < steps; ++i) {

        T theta_r1_p = theta_r1 + normDist(rng) * size_x;
        T theta_r2_p = theta_r2 + normDist(rng) * size_y;
        const T u_sample = uniformDist(rng);

        // Stepping past the pole: reflect the polar angle and rotate the
        // azimuth by half a turn.
        if (theta_r1_p < 0) {
            theta_r1_p = -theta_r1_p;
            theta_r2_p += pi;
        }
        // fmod handles steps of any size, not just a single overshoot.
        theta_r2_p = std::fmod(theta_r2_p, two_pi);
        if (theta_r2_p < 0) {
            theta_r2_p += two_pi;
        }

        if (!(theta_r1_p < domain_x[1])) {
            continue;  // beyond the upper polar bound: always rejected
        }

        const T pdf_p = pdf(theta_r1_p, theta_r2_p, parameters);
        // From a zero-density state any in-domain move is accepted, which also
        // avoids evaluating 0 / 0.
        const T acceptance = (pdf_cur > 0) ? std::min(one, pdf_p / pdf_cur) : one;

        // u_sample lies in [0, 1), so the strict comparison accepts with
        // probability exactly `acceptance` and never when acceptance == 0.
        if (u_sample < acceptance) {
            theta_r1 = theta_r1_p;
            theta_r2 = theta_r2_p;
            pdf_cur = pdf_p;
        }
    }

    return std::make_tuple(theta_r1, theta_r2);
}

/// @brief Samples a 3-component velocity from Maxwellian distributions at the
///        given temperature.
///
/// With sigma = sqrt(R_gas * temperature / mol_mass):
///  - vel_x and vel_y are drawn from N(0, sigma^2);
///  - vel_z = -sqrt(2) * sigma * sqrt(-ln U) with U in (0, 1]. This is a
///    Rayleigh (flux-weighted) speed with the same sigma, always pointing
///    towards -z.
/// NOTE(review): this looks like diffuse re-emission from a wall whose
/// outward normal is -z — confirm with callers.
/// NOTE(review): assumes R_gas (defined elsewhere) and mol_mass are in
/// consistent units — confirm.
///
/// @param mol_mass    Molar mass of the gas species.
/// @param temperature Temperature.
/// @return            Matrix<T>(3, 1, {vel_x, vel_y, vel_z}) — presumably a
///                    3x1 column vector; verify against the Matrix constructor.
template <typename T>
Matrix<T> maxwell_sampler(T mol_mass, T temperature) {

    std::random_device rd;
    std::seed_seq fullSeed{rd(), rd(), rd(), rd(), rd(), rd(), rd(), rd(), rd(), rd()};
    std::mt19937 rng(fullSeed);
    std::normal_distribution<T> normDist(0.0f, 1.0f);
    std::uniform_real_distribution<T> uniformDist(0.0f, 1.0f);

    T vel_wall = std::sqrt(2.0 * R_gas / mol_mass * temperature);
    T sigma = std::sqrt(R_gas * temperature / mol_mass);

    T vel_x = sigma * normDist(rng);
    T vel_y = sigma * normDist(rng);
    // uniformDist returns values in [0, 1); 1 - u lies in (0, 1], so the log is
    // finite. Using u directly would give log(0) = -inf and an infinite vel_z.
    T vel_z = - vel_wall * std::sqrt(-std::log(1 - uniformDist(rng)));

    Matrix<T> velocity_out(3, 1, {vel_x, vel_y, vel_z});

    return velocity_out;

}