#include <cmath>
#include <cstdio>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <random>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <sys/stat.h>
#include <sys/types.h>
#include <mpi.h>
#include "Constants.h"
#include "Matrix.h"
#include "Hermite_Tools.h"
#include "Local_Kernels.h"
#include "Random_Sampler.h"
#include "Geometry.h"
#include "Gas.h"
#include "Ray.h"
#include "Surface.h"
#include "PG_Kernel.h"
#include "Raytracer.h"
#include "Importer.h"


namespace {

/// Options recognised on the command line.
struct Options {
    bool batch = false;        // "--batch": run every simulation returned by import_batch
    std::string project_name;  // text after "--project="; empty when the flag is absent
};

/// Parses the driver's command-line flags.
///
/// Only an argument that *begins* with "--project=" sets the project name
/// (everything after the prefix). The previous check matched the prefix
/// anywhere in the argument but still sliced at a fixed offset of 10.
/// argv[0] (the program name) is skipped; unrecognised arguments are ignored.
Options parse_options(int argc, char* argv[]) {
    const std::string project_prefix = "--project=";
    Options opts;
    for (int i = 1; i < argc; ++i) {
        const std::string arg = argv[i];
        if (arg == "--batch") {
            opts.batch = true;
        } else if (arg.compare(0, project_prefix.size(), project_prefix) == 0) {
            opts.project_name = arg.substr(project_prefix.size());
        }
    }
    return opts;
}

/// Share of `total` particles that MPI rank `rank` (of `size`) should simulate.
///
/// Every rank gets total / size, and the first total % size ranks get one
/// extra, so the per-rank counts add up exactly to `total`. A plain
/// `total / size` silently dropped the remainder, and gave every rank zero
/// when total < size. Count must be an integral type.
template <typename Count>
Count particles_for_rank(Count total, int rank, int size) {
    const Count ranks = static_cast<Count>(size);
    const Count extra = (static_cast<Count>(rank) < total % ranks) ? Count(1) : Count(0);
    return static_cast<Count>(total / ranks + extra);
}

/// Runs one kernel simulation and one raytracer simulation for `project_name`.
///
/// Each rank writes <project>/results/{kernel,raytracer}_data_<rank>.dat.
/// After a barrier, rank 0 merges the per-rank files via combine_files.
/// Reported timings are rank 0's local wall-clock times.
/// Returns the process exit status (0 on success).
int run_single(const std::string& project_name, int rank, int size) {
    std::tuple<PG_Kernel<double>, Raytracer<double>> project_tuple = import_single<double>(project_name);

    PG_Kernel<double> kernel = std::move(std::get<0>(project_tuple));
    Raytracer<double> raytracer = std::move(std::get<1>(project_tuple));

    kernel.set_num_particles(particles_for_rank(kernel.get_num_particles(), rank, size));
    raytracer.set_num_particles(particles_for_rank(raytracer.get_num_particles(), rank, size));

    double start_time = MPI_Wtime();
    std::vector<trajectory<double>> trajectories_kernel = kernel.sample_batch();
    double end_time = MPI_Wtime();

    if (rank == 0) std::cout << "Kernel simulation finished in " << end_time - start_time << " seconds!\n";

    start_time = MPI_Wtime();
    std::vector<trajectory<double>> trajectories_raytracer = raytracer.simulate();
    end_time = MPI_Wtime();

    if (rank == 0) std::cout << "Raytracer simulation finished in " << end_time - start_time << " seconds!\n";

    kernel.save(trajectories_kernel, project_name + "/results/kernel_data_" + std::to_string(rank) + ".dat");

    if (rank == 0) std::cout << "Kernel results saved!\n";

    raytracer.save(trajectories_raytracer, project_name + "/results/raytracer_data_" + std::to_string(rank) + ".dat");

    if (rank == 0) std::cout << "Raytracer results saved!\n";

    // Every rank must have written its file before rank 0 merges them.
    MPI_Barrier(MPI_COMM_WORLD);

    if (rank == 0) {
        std::cout << "Combining files ...\n";

        kernel.combine_files(project_name + "/results/kernel_data", size);
        raytracer.combine_files(project_name + "/results/raytracer_data", size);

        std::cout << "Results have been generated!\n";
    }
    return 0;
}

/// Runs every (kernel, raytracer) pair returned by import_batch.
///
/// kernels[i] is paired with raytracers[i]. Per-rank output files are named
/// after each simulation's get_sim_name() and merged by rank 0 after a barrier.
/// Returns the process exit status: 1 if the two lists differ in length
/// (pairing by index would read out of bounds), otherwise 0.
int run_batch(const std::string& project_name, int rank, int size) {
    std::tuple<std::vector<PG_Kernel<double>>, std::vector<Raytracer<double>>> project_tuple = import_batch<double>(project_name);

    std::vector<PG_Kernel<double>> kernels = std::move(std::get<0>(project_tuple));
    std::vector<Raytracer<double>> raytracers = std::move(std::get<1>(project_tuple));

    // NOTE(review): assumes every rank imports the same project and so sees
    // the same sizes, letting all ranks take this exit path together.
    if (kernels.size() != raytracers.size()) {
        if (rank == 0) {
            std::cerr << "Batch import returned " << kernels.size() << " kernel(s) but "
                      << raytracers.size() << " raytracer(s); aborting.\n";
        }
        return 1;
    }

    for (std::size_t i = 0; i < kernels.size(); ++i) {
        PG_Kernel<double>& kernel = kernels[i];
        Raytracer<double>& raytracer = raytracers[i];

        kernel.set_num_particles(particles_for_rank(kernel.get_num_particles(), rank, size));
        raytracer.set_num_particles(particles_for_rank(raytracer.get_num_particles(), rank, size));

        double start_time = MPI_Wtime();
        std::vector<trajectory<double>> trajectories_kernel = kernel.sample_batch();
        double end_time = MPI_Wtime();

        if (rank == 0) std::cout << "Kernel simulation with " << kernel.get_sim_name() << " finished in " << end_time - start_time << " seconds!\n";

        start_time = MPI_Wtime();
        std::vector<trajectory<double>> trajectories_raytracer = raytracer.simulate();
        end_time = MPI_Wtime();

        if (rank == 0) std::cout << "Raytracer simulation with " << raytracer.get_sim_name() << " finished in " << end_time - start_time << " seconds!\n\n";

        kernel.save(trajectories_kernel, project_name + "/results/kernel_data_" + kernel.get_sim_name() + "_" + std::to_string(rank) + ".dat");
        raytracer.save(trajectories_raytracer, project_name + "/results/raytracer_data_" + raytracer.get_sim_name() + "_" + std::to_string(rank) + ".dat");

        // Every rank must have written its files before rank 0 merges them.
        MPI_Barrier(MPI_COMM_WORLD);

        if (rank == 0) {
            kernel.combine_files(project_name + "/results/kernel_data_" + kernel.get_sim_name(), size);
            raytracer.combine_files(project_name + "/results/raytracer_data_" + raytracer.get_sim_name(), size);
        }
    }
    return 0;
}

}  // namespace

/// Entry point: `<program> --project=<name> [--batch]`, launched under MPI.
///
/// MPI_Init runs before argv is inspected because the MPI runtime may
/// add or remove arguments. Without a project name every rank exits with
/// status 1 after finalizing MPI.
int main(int argc, char* argv[]) {
    MPI_Init(&argc, &argv);

    int rank = 0;
    int size = 1;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    const Options opts = parse_options(argc, argv);

    if (opts.project_name.empty()) {
        if (rank == 0) std::cerr << "Usage: <program> --project=<name> [--batch]\n";
        MPI_Finalize();
        return 1;
    }

    const int status = opts.batch ? run_batch(opts.project_name, rank, size)
                                  : run_single(opts.project_name, rank, size);

    MPI_Finalize();

    return status;
}