#include <gecode/brancher/data-collector.hh>
#include <gecode/brancher/ml-utils.cpp>
#include <math.h>

// Lazily created process-wide instance; see getInstance().
DataCollector* DataCollector::instance = nullptr;

// Puts every counter at zero and clears the Python model/scaler handles so that
// getters (and predictML's handle check) called before init()/runML() see
// well-defined values instead of uninitialized members. In particular,
// number_of_get_node_calls is not reset by init(), so it is set here.
DataCollector::DataCollector() {
    init();
    number_of_get_node_calls = 0;
    p_ml_model = nullptr;
    p_ml_scaler = nullptr;
}

// Returns the shared DataCollector, creating it on first use.
// NOTE(review): creation is not synchronized; assumes single-threaded use.
DataCollector* DataCollector::getInstance() {
    if (instance == nullptr) {
        instance = new DataCollector();
    }
    return instance;
}

// Returns a copy of the tree list; the MLTree objects themselves are shared.
vector<MLTree*> DataCollector::getTrees() { return trees; }

// Resets the prediction, searched-node and tiebreak counters to zero.
void DataCollector::init() {
    number_of_predictions = 0;
    nodes_searched = 0;
    tiebreaks = 0;
}

// Counts one more tiebreak.
void DataCollector::incrementTiebreaks() { ++tiebreaks; }

int DataCollector::getNumberOfTiebreaks() { return tiebreaks; }

// Counts one more searched node.
void DataCollector::incrementNodesSearched() { ++nodes_searched; }

int DataCollector::getNodesSearched() { return nodes_searched; }

// Root values not yet removed via removeChecked(), returned by copy.
vector<int> DataCollector::getRootDomain() { return root_values; }

// Root variable indices not yet removed via removeCheckedVars(), returned by copy.
vector<int> DataCollector::getRootVars() { return root_vars; }

// Appends every value in the current domain of the first view of `y` to
// root_values. Values are appended, not replaced, so repeated calls accumulate.
// NOTE(review): assumes `y` is non-empty, since y[0] is read unconditionally.
void DataCollector::setRootDomain(ViewArray<Int::IntView> y){
    Int::ViewValues<Int::IntView> value_it(y[0]);
    while (value_it()) {
        root_values.push_back(value_it.val());
        ++value_it;
    }
}

// Appends the indices 0 .. size-1 to root_vars (no-op when size <= 0).
void DataCollector::setRootVars(int size){
    int index = 0;
    while (index < size) {
        root_vars.push_back(index);
        ++index;
    }
}

// Drops every occurrence of `checked` from root_values (erase-remove idiom).
void DataCollector::removeChecked(int checked){
    auto kept_end = remove(root_values.begin(), root_values.end(), checked);
    root_values.erase(kept_end, root_values.end());
}

// Drops every occurrence of `checked` from root_vars (erase-remove idiom).
void DataCollector::removeCheckedVars(int checked){
    auto kept_end = remove(root_vars.begin(), root_vars.end(), checked);
    root_vars.erase(kept_end, root_vars.end());
}

// Name of the heuristic in use. runML() treats "impact" and "anti_ff" as
// reverse-ranked when it trains the model.
string DataCollector::getFunction() { return func; }

// Stores a copy of the heuristic name `f`.
void DataCollector::setFunction(string f) { func = f; }

// Wraps the features of one search node in a new MLNode and files it into the
// collected trees. When the node's assignment has more than one entry, the
// trees are offered the node from the back of `trees` (most recently created
// first) and the first tree whose storeNode() accepts it keeps it. Otherwise,
// or when no tree accepts it, the node becomes the root of a new MLTree.
void DataCollector::storeMLNode(vector<int> assigned, int domain_size, long degree, int var_order, int pos, int val, int val_pos, int max_val_dom,
                                int min_val_dom, int regret_min, int regret_max) {
    MLNode* node = new MLNode(assigned, domain_size, degree, var_order, pos, val, val_pos, max_val_dom, min_val_dom, regret_min, regret_max);

    bool placed = false;
    if (node->getAssigned().size() > 1) {
        for (auto tree_it = trees.rbegin(); !placed && tree_it != trees.rend(); ++tree_it)
            placed = (*tree_it)->storeNode(node);
    }

    if (!placed)
        trees.push_back(new MLTree(node));
}

void DataCollector::computeDeepImpact(int depth) {
    int leafs = 0;
    double avg_height = 0.0;
    //cout << "Using height of tree as depth " << endl;
    for(int i = 0; i < trees.size(); ++i) {
        trees.at(i)->computeDeepImpact(depth);
        leafs += trees.at(i)->countLeafs();
        avg_height += trees.at(i)->getHeight();
        //trees.at(i)->toString();
    }
    cout << "Total leafs: " << leafs << endl;
    cout << "AVG Height per tree: " << avg_height / (double)(trees.size()) << endl;
}

void DataCollector::computeDeepSmallest(int depth) {
    int leafs = 0;
    double avg_height = 0.0;
    //cout << "Using height of tree as depth " << endl;
    for(int i = 0; i < trees.size(); ++i) {
        trees.at(i)->computeDeepSmallest(depth);
        leafs += trees.at(i)->countLeafs();
        avg_height += trees.at(i)->getHeight();
        //trees.at(i)->toString();
    }
    cout << "Total leafs: " << leafs << endl;
    cout << "AVG Height per tree: " << avg_height / (double)(trees.size()) << endl;
}

void DataCollector::computeDeepMaxRegret(int depth) {
    int leafs = 0;
    double avg_height = 0.0;
    //cout << "Using height of tree as depth " << endl;
    for(int i = 0; i < trees.size(); ++i) {
        trees.at(i)->computeDeepMaxRegret(depth);
        leafs += trees.at(i)->countLeafs();
        avg_height += trees.at(i)->getHeight();
        //trees.at(i)->toString();
    }
    cout << "Total leafs: " << leafs << endl;
    cout << "AVG Height per tree: " << avg_height / (double)(trees.size()) << endl;
}

void DataCollector::computeDeepAntiFF(int depth) {
    int leafs = 0;
    double avg_height = 0.0;
    //cout << "Using height of tree as depth " << endl;
    for(int i = 0; i < trees.size(); ++i) {
        trees.at(i)->computeAntiFF(depth);
        leafs += trees.at(i)->countLeafs();
        avg_height += trees.at(i)->getHeight();
        //trees.at(i)->toString();
    }
    cout << "Total leafs: " << leafs << endl;
    cout << "AVG Height per tree: " << avg_height / (double)(trees.size()) << endl;
}

// Number of predictML() calls since the counter was last reset.
int DataCollector::getNumberOfPredictions() { return number_of_predictions; }

// NOTE(review): no update of number_of_get_node_calls is visible in this
// block; confirm where it is maintained.
int DataCollector::getNumberOfGetNodeCalls() { return number_of_get_node_calls; }

// Records `var` as the next selected variable index.
void DataCollector::addVar(int var){
    var_order_indices.insert(var_order_indices.end(), var);
}

// Copy of the variable indices recorded through addVar().
vector<int> DataCollector::varsSelected(){ return var_order_indices; }

// Stores a copy of one complete variable ordering.
void DataCollector::addVarOrdering(vector<int> var_order){
    var_orders.insert(var_orders.end(), var_order);
}

// Copy of all orderings recorded through addVarOrdering().
vector<vector<int>> DataCollector::getVarOrderings(){ return var_orders; }

// Forgets every recorded ordering and every recorded selected variable.
void DataCollector::resetVarOrderings(){
    var_order_indices.clear();
    var_orders.clear();
}



int DataCollector::runML() {
    //cout << "This is my data size per asset when started the ML: " << dom_size_data.at(0).size() << endl;
    PyObject *pModule, *pName, *pFunc;
    PyObject *pArgs, *p_list;

    number_of_predictions = 0;

    Py_Initialize();
    PyRun_SimpleString("import sys");
    PyRun_SimpleString("sys.path.append(\".\")");
    // Execute python file with machine learning
    pName = PyUnicode_DecodeFSDefault("ml_model");

    pModule = PyImport_Import(pName);
    Py_DECREF(pName);

    if (pModule != NULL) {
        // Execute function 'printTest'
        pFunc = PyObject_GetAttrString(pModule, "supportVectorRegression");
        /* pFunc is a new reference */

        if (pFunc && PyCallable_Check(pFunc)) {
            pArgs = PyTuple_New(12);

            PyObject *pDomSizeList = PyList_New(0);
            PyObject *pDegreeList = PyList_New(0);
            PyObject *pVarOrderList = PyList_New(0);
            PyObject *pVarList = PyList_New(0);
            PyObject *pValList = PyList_New(0);
            PyObject *pValPosList = PyList_New(0);
            PyObject *pMinValDomList = PyList_New(0);
            PyObject *pMaxValDomList = PyList_New(0);
            PyObject *pRegretMinList = PyList_New(0);
            PyObject *pRegretMaxList = PyList_New(0);
            PyObject *pNodeScoreList = PyList_New(0);

            int reverse_ranking = 0;
            string func = DataCollector::getInstance()->getFunction();
            if(func == "impact" || func == "anti_ff")
                reverse_ranking = 1;
            

            MLTree* cur_tree;
            MLNode* cur_node;
            for(int i = 0; i < trees.size(); ++i) {
                cur_tree = trees.at(i);
                for(int j = 0; j < cur_tree->getNodes().size(); ++j) {
                    cur_node = cur_tree->getNodes().at(j);
                    PyList_Append(pDomSizeList,         PyLong_FromLong(    cur_node->getDomainSize()           ));
                    PyList_Append(pDegreeList,          PyLong_FromLong(    cur_node->getDegree()               ));
                    PyList_Append(pVarOrderList,        PyLong_FromLong(    cur_node->getVarOrder()             ));
                    PyList_Append(pVarList,             PyLong_FromLong(    cur_node->getPosition()             ));
                    PyList_Append(pValList,             PyLong_FromLong(    cur_node->getValue()                ));
                    PyList_Append(pValPosList,          PyLong_FromLong(    cur_node->getValuePosition()        ));
                    PyList_Append(pMinValDomList,       PyLong_FromLong(    cur_node->getMinValDom()            ));
                    PyList_Append(pMaxValDomList,       PyLong_FromLong(    cur_node->getMaxValDom()            ));
                    PyList_Append(pRegretMinList,       PyLong_FromLong(    cur_node->getRegretMin()            ));
                    PyList_Append(pRegretMaxList,       PyLong_FromLong(    cur_node->getRegretMax()            ));
                    PyList_Append(pNodeScoreList,       PyFloat_FromDouble( cur_node->getNodeScore()            ));
                }
            }

            PyTuple_SetItem(pArgs, 0, pDomSizeList);
            PyTuple_SetItem(pArgs, 1, pDegreeList);
            PyTuple_SetItem(pArgs, 2, pVarOrderList);
            PyTuple_SetItem(pArgs, 3, pVarList);
            PyTuple_SetItem(pArgs, 4, pValList);
            PyTuple_SetItem(pArgs, 5, pValPosList);
            PyTuple_SetItem(pArgs, 6, pMinValDomList);
            PyTuple_SetItem(pArgs, 7, pMaxValDomList);
            PyTuple_SetItem(pArgs, 8, pRegretMinList);
            PyTuple_SetItem(pArgs, 9, pRegretMaxList);
            PyTuple_SetItem(pArgs, 10, pNodeScoreList);
            PyTuple_SetItem(pArgs, 11, PyLong_FromLong(reverse_ranking));

            PyObject *p_list = PyObject_CallObject(pFunc, pArgs);
            p_ml_model = PyList_GetItem(p_list, 0);
            Py_INCREF(p_list);
            p_ml_scaler = PyList_GetItem(p_list, 1);
            Py_INCREF(p_list);
            cout << "Cosine simularity: " << PyFloat_AsDouble(PyList_GetItem(p_list, 2)) << endl;
            Py_INCREF(p_list);
            cout << "Spearman correlation: " << PyFloat_AsDouble(PyList_GetItem(p_list, 3)) << endl;
            Py_INCREF(p_list);
            cout << "Spearman top 10 correlation: " << PyFloat_AsDouble(PyList_GetItem(p_list, 4)) << endl;
            Py_INCREF(p_list);
            cout << "Max score: " << PyFloat_AsDouble(PyList_GetItem(p_list, 5)) << endl;
            Py_INCREF(p_list);
            cout << "R2: " << PyFloat_AsDouble(PyList_GetItem(p_list, 6)) << endl;
            Py_INCREF(p_list);
            cout << "R2 @10: " << PyFloat_AsDouble(PyList_GetItem(p_list, 7)) << endl;
            Py_INCREF(p_list);
            cout << "First pred place: " << PyFloat_AsDouble(PyList_GetItem(p_list, 8)) << endl;
            Py_INCREF(p_list);
            cout << "First test place: " << PyFloat_AsDouble(PyList_GetItem(p_list, 9)) << endl;
            Py_INCREF(p_list);
            cout << "Second pred place: " << PyFloat_AsDouble(PyList_GetItem(p_list, 10)) << endl;
            Py_INCREF(p_list);
            cout << "Second test place: " << PyFloat_AsDouble(PyList_GetItem(p_list, 11)) << endl;
            Py_INCREF(p_list);
            cout << "Third pred place: " << PyFloat_AsDouble(PyList_GetItem(p_list, 12)) << endl;
            Py_INCREF(p_list);
            cout << "Third test place: " << PyFloat_AsDouble(PyList_GetItem(p_list, 13)) << endl;
            Py_INCREF(p_list);
            cout << "Fitted data set size: " << PyFloat_AsDouble(PyList_GetItem(p_list, 14)) << endl;


            // Decrementing reference counters
            Py_DECREF(pArgs);
            Py_DECREF(p_list);
            Py_DECREF(pDomSizeList);
            Py_DECREF(pDegreeList);
            Py_DECREF(pVarOrderList);
            Py_DECREF(pVarList);
            Py_DECREF(pValList);
            Py_DECREF(pValPosList);
            Py_DECREF(pMinValDomList);
            Py_DECREF(pMaxValDomList);
            Py_DECREF(pRegretMinList);
            Py_DECREF(pRegretMaxList);
            Py_DECREF(pNodeScoreList);
            
            if (p_ml_model == NULL) {
                //Py_DECREF(p_ml_model);
                //Py_DECREF(pModule);
                Py_DECREF(pFunc);
                PyErr_Print();
                fprintf(stderr,"Call failed\n");
                return 1;
            }
        }
        Py_XDECREF(pFunc);
        Py_DECREF(pModule);
    } else {
        PyErr_Print();
        exit(EXIT_FAILURE);
    }
    return 0;
}


vector<double> DataCollector::predictML(vector<int> dom_sizes, vector<long> degrees, vector<int> var_orders, vector<int> var_posses, vector<int> values, 
                                vector<int> val_posses, vector<int> min_vals, vector<int> max_vals, vector<int> regret_min_vals,
                                vector<int> regret_max_vals) {
    number_of_predictions++;
    vector<double> predictions;
    PyObject *pFunc;
    PyObject *pArgs, *pPredictionList;
    PyObject *pModule;

    PyRun_SimpleString("import sys");
    PyRun_SimpleString("sys.path.append(\".\")");
    pModule = PyImport_Import(PyUnicode_DecodeFSDefault("ml_model"));

    if (pModule != NULL) {
        pFunc = PyObject_GetAttrString(pModule, "predictML");

        if (pFunc && PyCallable_Check(pFunc)) {

            // Prep tuples of data
            PyObject *pDomSizeTuple = PyTuple_New(dom_sizes.size());
            PyObject *pDegreeTuple = PyTuple_New(degrees.size());
            PyObject *pVarOrdersTuple = PyTuple_New(var_orders.size());
            PyObject *pVarPosTuple = PyTuple_New(var_posses.size());
            PyObject *pValuesTuple = PyTuple_New(values.size());
            PyObject *pValPosTuple = PyTuple_New(val_posses.size());
            PyObject *pMinValsTuple = PyTuple_New(min_vals.size());
            PyObject *pMaxValsTuple = PyTuple_New(max_vals.size());
            PyObject *pRegretMinTuple = PyTuple_New(regret_min_vals.size());
            PyObject *pRegretMaxTuple = PyTuple_New(regret_max_vals.size());

            for(int i = 0; i < dom_sizes.size(); ++i){
                PyTuple_SetItem(pDomSizeTuple, i, PyLong_FromLong(dom_sizes.at(i)));
                PyTuple_SetItem(pDegreeTuple, i, PyLong_FromLong(degrees.at(i)));
                PyTuple_SetItem(pVarOrdersTuple, i, PyLong_FromLong(var_orders.at(i)));
                PyTuple_SetItem(pVarPosTuple, i, PyLong_FromLong(var_posses.at(i)));
                PyTuple_SetItem(pValuesTuple, i, PyLong_FromLong(values.at(i)));
                PyTuple_SetItem(pValPosTuple, i, PyLong_FromLong(val_posses.at(i)));
                PyTuple_SetItem(pMinValsTuple, i, PyLong_FromLong(min_vals.at(i)));
                PyTuple_SetItem(pMaxValsTuple, i, PyLong_FromLong(max_vals.at(i)));
                PyTuple_SetItem(pRegretMinTuple, i, PyLong_FromLong(regret_min_vals.at(i)));
                PyTuple_SetItem(pRegretMaxTuple, i, PyLong_FromLong(regret_max_vals.at(i)));
            }

            pArgs = PyTuple_New(12);
            Py_INCREF(p_ml_model);
            Py_INCREF(p_ml_scaler);
            PyTuple_SetItem(pArgs, 0, p_ml_model);
            PyTuple_SetItem(pArgs, 1, p_ml_scaler);
            PyTuple_SetItem(pArgs, 2, pDomSizeTuple);
            PyTuple_SetItem(pArgs, 3, pDegreeTuple);
            PyTuple_SetItem(pArgs, 4, pVarOrdersTuple);
            PyTuple_SetItem(pArgs, 5, pVarPosTuple);
            PyTuple_SetItem(pArgs, 6, pValuesTuple);
            PyTuple_SetItem(pArgs, 7, pValPosTuple);
            PyTuple_SetItem(pArgs, 8, pMinValsTuple);
            PyTuple_SetItem(pArgs, 9, pMaxValsTuple);
            PyTuple_SetItem(pArgs, 10, pRegretMinTuple);
            PyTuple_SetItem(pArgs, 11, pRegretMaxTuple);


            pPredictionList = PyObject_CallObject(pFunc, pArgs);

            for(int i = 0; i < dom_sizes.size(); ++i){
                predictions.push_back(PyFloat_AsDouble(PyList_GetItem(pPredictionList, i)));
                Py_INCREF(pPredictionList);
            }

            // Decrementing reference counters
            Py_DECREF(pArgs);
            Py_DECREF(pPredictionList);
            // Py_DECREF(pDomSizeTuple);
            // Py_DECREF(pVarPosTuple);
            // Py_DECREF(pValuesTuple);
            // Py_DECREF(pValPosTuple);
            // Py_DECREF(pMinValsTuple);
            // Py_DECREF(pMaxValsTuple);
            
            if (pPredictionList == NULL) {
                Py_DECREF(pFunc);
                Py_DECREF(pModule);
                PyErr_Print();
                fprintf(stderr,"Call failed\n");
                return predictions;
            }
        }
        Py_DECREF(pModule);
        Py_XDECREF(pFunc);
    } else {
        PyErr_Print();
        exit(EXIT_FAILURE);
    }
    return predictions;

}

int DataCollector::testPythonSum(int a, int b) {
    int sum;
    PyObject *pFunc;
    PyObject *pArgs, *pSum;
    PyObject *pModule;

    cout << "I Was here 1" << endl;
    Py_Initialize();
    PyRun_SimpleString("import sys");
    PyRun_SimpleString("sys.path.append(\".\")");
    pModule = PyImport_Import(PyUnicode_DecodeFSDefault("ml_model"));
    cout << "Did we crash yet" << endl;

    if (pModule != NULL) {
        cout << "I Was here 2" << endl;
        pFunc = PyObject_GetAttrString(pModule, "testSum");

        if (pFunc && PyCallable_Check(pFunc)) {
            pArgs = PyTuple_New(2);
            PyTuple_SetItem(pArgs, 0, PyLong_FromLong(a));
            PyTuple_SetItem(pArgs, 1, PyLong_FromLong(b));    
            cout << "I Was here 3" << endl;         

            pSum = PyObject_CallObject(pFunc, pArgs);
            sum = PyLong_AsLong(pSum);
            cout << "I Was here 4" << endl;

            // Decrementing reference counters
            Py_DECREF(pArgs);
            Py_DECREF(pSum);
            
            if (pSum == NULL) {
                Py_DECREF(pSum);
                Py_DECREF(pFunc);
                PyErr_Print();
                fprintf(stderr,"Call failed\n");
                return sum;
            }
        }
        Py_DECREF(pModule);
        Py_XDECREF(pFunc);
    } else {
        PyErr_Print();
        return -1;
    }
    return sum;
}


int DataCollector::finalizeML() {
    Py_FinalizeEx();
    return 1;
}