#include <gecode/brancher/ml-tree.hh>
#include <gecode/brancher/ml-utils.cpp>

using namespace std;

// Builds a tree that contains only `root_node`.
//
// The root is placed at position 0, registered in the node list and recorded
// as the most recently stored node. `root_node` is dereferenced right away,
// so it must not be null.
MLTree::MLTree(MLNode* root_node) {
    // Bookkeeping used by the deep-heuristic score computations.
    current_pos = 0;
    dom_size_nodes_counted = 0;
    regret_nodes_counted = 0;
    smallest_nodes_counted = 0;

    // A tree holding just its root starts with height 1 and a last
    // assignment size of 1.
    height = 1;
    lastAssignmentSize = 1;

    // Register the root itself.
    root_node->setPositionInTree(0);
    root = root_node;
    last_node = root_node;
    nodes.push_back(root_node);
}

// Returns (by value) the list of all nodes stored in this tree, in the order
// in which they were stored.
vector<MLNode*> MLTree::getNodes() {
    return this->nodes;
}

int MLTree::getHeight() {
    return height;
}

void MLTree::setHeight(int tree_height) {
    height = tree_height;
}

int MLTree::getLastAssignmentSize() {
    return lastAssignmentSize;
}

void MLTree::setLastAssignmentSize(int size) {
    lastAssignmentSize = size;
}

// Returns the node that was stored in the tree most recently.
MLNode* MLTree::getLastNode() {
    return this->last_node;
}

// Tells whether any stored node carries exactly the given assignment.
//
// Returns true as soon as a node whose getAssigned() equals `assignment` is
// found, false when no stored node matches.
bool MLTree::checkNodeByAssignment(vector<int> assignment) {
    for (MLNode* candidate : nodes) {
        if (candidate->getAssigned() == assignment)
            return true;
    }
    return false;
}

// Looks up the first stored node whose assignment equals `assignment`.
//
// Returns that node, or nullptr (after logging to stderr) when no stored
// node matches.
MLNode* MLTree::getNodeByAssignment(vector<int> assignment) {
    for (MLNode* candidate : nodes) {
        if (candidate->getAssigned() == assignment)
            return candidate;
    }

    cerr << "Node could not be found by assignment" << endl;
    return nullptr;
}

// Records `node` as the most recently stored node.
void MLTree::setLastNode(MLNode* node) {
    this->last_node = node;
}

// Tries to insert `node` into the tree.
//
// First attempts a regular insertion below the node whose assignment is the
// direct prefix of `node`'s assignment. If that fails, the assignment of the
// most recently stored node is compared value by value with `node`'s
// assignment to handle cases where the solver skipped values.
//
// Returns true when the node was stored, false otherwise (comments in the
// body indicate that a false result may mean a new tree should be created).
bool MLTree::storeNode(MLNode* node) {
    // Regular case: the node fits as a child somewhere in the tree.
    if (recursiveNodeStore(root, node))
        return true;
    if (nodes.empty())
        return false;

    MLNode* const previous = getLastNode();
    const vector<int>& incoming = node->getAssigned();
    const vector<int>& latest = previous->getAssigned();

    // A node carrying the root's assignment is never stored again.
    if (incoming == root->getAssigned())
        return false;

    // Only the positions present in both assignments can be compared.
    const int shared_len = static_cast<int>(
        incoming.size() < latest.size() ? incoming.size() : latest.size());

    // Prefix of the last node's assignment, up to and including position i.
    vector<int> prefix;
    for (int i = 0; i < shared_len; ++i) {
        prefix.push_back(latest.at(i));
        if (latest.at(i) == incoming.at(i))
            continue;

        // Diverging on the very first value: a new tree should be created.
        if (i == 0)
            return false;

        // Same length and only the final value differs: the node becomes a
        // child of the last node's parent.
        const bool same_length = latest.size() == incoming.size();
        const bool at_final_value =
            static_cast<vector<int>::size_type>(i) == latest.size() - 1;
        if (at_final_value && same_length) {
            MLNode* const last_parent = previous->getParent();
            // The last node was the root.
            if (last_parent == nullptr)
                return false;
            storeInTree(last_parent, node);
            return true;
        }

        // Otherwise attach next to the stored node that carries the prefix.
        if (!checkNodeByAssignment(prefix))
            return false;
        MLNode* const prefix_node = getNodeByAssignment(prefix);
        if (prefix_node->getAssigned() == root->getAssigned())
            return false;
        storeInTree(prefix_node->getParent(), node);
        return true;
    }

    // No divergence in the shared part and the new assignment is longer:
    // values were skipped by the solver, so hang the node below the last node.
    if (incoming.size() > latest.size()) {
        storeInTree(previous, node);
        return true;
    }
    return false;
}

// Searches the subtree rooted at `node2check` (depth first) for the parent of
// `node` and stores `node` below it.
//
// A tree node is considered the parent when its assignment equals `node`'s
// assignment with the last value removed.
//
// Returns the result of storeAsChild() for the first matching parent found,
// or false when no node in the subtree matches. A `node` with an empty
// assignment has no such prefix and is never stored (returns false).
bool MLTree::recursiveNodeStore(MLNode* node2check, MLNode* node) {
    vector<int> parent_assignment(node->getAssigned());

    // pop_back() on an empty vector is undefined behaviour; an empty
    // assignment cannot be the extension of any parent assignment.
    if (parent_assignment.empty())
        return false;
    parent_assignment.pop_back();

    // Identical prefix: node2check is the parent we are looking for.
    if (node2check->getAssigned() == parent_assignment)
        return storeAsChild(node2check, node);

    // Otherwise keep looking further down; stop at the first success.
    const vector<MLNode*>& children = node2check->getChildren();
    for (MLNode* child : children) {
        if (recursiveNodeStore(child, node))
            return true;
    }

    return false;
}

// Stores `child` below `parent` unless an equivalent child already exists.
//
// An existing child is equivalent when both its assignment and its value
// equal those of `child`. Returns false in that case (nothing is stored),
// true after the child has been stored.
bool MLTree::storeAsChild(MLNode* parent, MLNode* child) {
    const vector<MLNode*> siblings = parent->getChildren();

    for (MLNode* sibling : siblings) {
        const bool same_assignment = sibling->getAssigned() == child->getAssigned();
        if (same_assignment && sibling->getValue() == child->getValue())
            return false;
    }

    storeInTree(parent, child);
    return true;
}

// Unconditionally attaches `child` below `parent` and updates the tree's
// bookkeeping.
//
// The child's position is the parent's position plus one; the tree height is
// raised when the child ends up deeper than any node stored so far. The child
// is then linked to the parent, appended to the node list and recorded as the
// last stored node. `parent` is dereferenced and must not be null.
void MLTree::storeInTree(MLNode* parent, MLNode* child) {
    const int child_pos = parent->getPositionInTree() + 1;
    child->setPositionInTree(child_pos);

    // Positions start at 0, so a node at position p needs a height of p + 1.
    const int required_height = child_pos + 1;
    if (getHeight() < required_height)
        setHeight(required_height);

    child->linkChildToParent(parent);
    nodes.push_back(child);

    setLastNode(child);
    setLastAssignmentSize(child->getAssigned().size());
}

/// Compute Deep Impact

void MLTree::computeDeepImpact(int depth) {
    // Compute deep impact for each node in the tree 
    computeAllImpact(root, depth);
}

// Sets the deep-impact score of `node` and, recursively, of every node in
// its subtree.
//
// The score is the result of recursiveImpact() started at `node`, rounded to
// four decimal places.
void MLTree::computeAllImpact(MLNode* node, int depth) {
    // Naive (non-memoized) evaluation for this node.
    const double raw_impact = recursiveImpact(node, depth, 0);
    node->setNodeScore(round(raw_impact * 10000.0) / 10000.0);

    const vector<MLNode*>& children = node->getChildren();
    for (MLNode* child : children)
        computeAllImpact(child, depth);
}

// Deep Impact
//
// Opt(x,v,k) = 1/k * I_{y,v} + 1 / |M| * SUM_{v in M} ( Opt(y,v,k+1) )
//
// Where
//  x : the variable to assign a value to
//  k : current level in the tree
//  y : the next available unassigned variable
//  M : set of explored assignments
//  v : value to assign to the variable
//
// `cur_depth` is the distance from the node the computation started at. The
// starting node (cur_depth == 0) contributes no impact of its own; every other
// node contributes MLUtils::computeImpact(parent degree, own degree) divided
// by cur_depth. Nodes deeper than `depth` contribute 0.0 but are still counted
// when the children's contributions are averaged.
//
// Dynamic programming (memoization) could be used here.
double MLTree::recursiveImpact(MLNode* node, int depth, int cur_depth) {
    // Beyond the requested look-ahead: nothing to add.
    if (cur_depth > depth)
        return 0.0;

    double own_impact = 0.0;
    if (cur_depth != 0) {
        // Impact between parent and this node, discounted by the level.
        own_impact = MLUtils::computeImpact(node->getParent()->getDegree(), node->getDegree());
        own_impact *= (1.0 / static_cast<double>(cur_depth));
    }

    const vector<MLNode*>& children = node->getChildren();
    const int child_count = children.size();
    if (child_count == 0)
        return own_impact;

    double children_total = 0.0;
    for (MLNode* child : children)
        children_total += recursiveImpact(child, depth, cur_depth + 1);

    // Own impact plus the average over the children.
    return own_impact + (children_total / static_cast<double>(child_count));
}

// Smallest value in domain

void MLTree::computeDeepSmallest(int depth) {
    // Compute deep impact for each node in the tree 
    computeAllSmallest(root, depth);
}

// Sets the deep-smallest score of `node` and, recursively, of its subtree.
//
// The score is the average of the values summed by recursiveSmallest() over
// all nodes it counted, rounded to four decimal places. The score is 0 when
// fewer than `depth` levels were actually reached below `node`.
//
// When the tree is not high enough for `depth` levels below `node`, the node
// gets score 0 and its descendants are not visited (their scores are left
// untouched).
void MLTree::computeAllSmallest(MLNode* node, int depth) {
    if (height - node->getPositionInTree() < depth) {
        node->setNodeScore(0);
        return;
    }

    // Reset the shared counters used by recursiveSmallest().
    smallest_nodes_counted = 0;
    current_pos = 0;
    const double total = recursiveSmallest(node, depth, 0);

    // No node is counted when depth == 0; dividing by zero would produce a
    // NaN score, so fall back to 0 in that case.
    double average = 0.0;
    if (smallest_nodes_counted > 0)
        average = round(total / static_cast<double>(smallest_nodes_counted) * 10000.0) / 10000.0;

    // current_pos holds the deepest level reached by the recursion.
    if (current_pos < depth) {
        node->setNodeScore( 0 );
    } else {
        node->setNodeScore( average );
    }

    const vector<MLNode*>& children = node->getChildren();
    for (MLNode* child : children)
        computeAllSmallest(child, depth);
}

// Deep Smallest
//
// Opt(x,v,k) = ( S_d + SUM_{v in M} ( Opt(y,v,k+1) ) ) / #nodes_counted
//
// Where
//  x : the variable to assign a value to
//  k : current level in the tree
//  y : the next available unassigned variable
//  M : set of explored assignments
//  v : value to assign to the variable
//  S_d : smallest value in the domain
//
// This function returns only the sum (numerator); it increments
// smallest_nodes_counted for every node it adds, and raises current_pos to
// the deepest level visited, so that the caller can do the division.
// Nodes at level `depth` are reached but neither summed nor counted.
//
// Dynamic programming (memoization) could be used here.
double MLTree::recursiveSmallest(MLNode* node, int depth, int cur_depth) {
    // Remember the deepest level this traversal has reached.
    if (cur_depth > current_pos)
        current_pos = cur_depth;

    // Look-ahead exhausted.
    if (cur_depth == depth)
        return 0.0;

    const double own_smallest = static_cast<double>(node->getMinValDom());
    ++smallest_nodes_counted;

    const vector<MLNode*>& children = node->getChildren();
    if (children.empty())
        return own_smallest;

    double children_total = 0.0;
    for (MLNode* child : children)
        children_total += recursiveSmallest(child, depth, cur_depth + 1);

    // Own value plus the sum over the whole subtree below.
    return own_smallest + children_total;
}

// Max regret value in domain

void MLTree::computeDeepMaxRegret(int depth) {
    // Compute deep impact for each node in the tree 
    computeAllMaxRegret(root, depth);
}

// Sets the deep max-regret score of `node` and, recursively, of its subtree.
//
// The score is the average of the values summed by recursiveMaxRegret() over
// all nodes it counted, rounded to four decimal places. The score is 0 when
// fewer than `depth` levels were actually reached below `node`.
//
// When the tree is not high enough for `depth` levels below `node`, the node
// gets score 0 and its descendants are not visited (their scores are left
// untouched).
void MLTree::computeAllMaxRegret(MLNode* node, int depth) {
    if (height - node->getPositionInTree() < depth) {
        node->setNodeScore(0);
        return;
    }

    // Reset the shared counters used by recursiveMaxRegret().
    regret_nodes_counted = 0;
    current_pos = 0;
    const double total = recursiveMaxRegret(node, depth, 0);

    // No node is counted when depth == 0; dividing by zero would produce a
    // NaN score, so fall back to 0 in that case.
    double average = 0.0;
    if (regret_nodes_counted > 0)
        average = round(total / static_cast<double>(regret_nodes_counted) * 10000.0) / 10000.0;

    // current_pos holds the deepest level reached by the recursion.
    if (current_pos < depth) {
        node->setNodeScore( 0 );
    } else {
        node->setNodeScore( average );
    }

    const vector<MLNode*>& children = node->getChildren();
    for (MLNode* child : children)
        computeAllMaxRegret(child, depth);
}

// Deep Maximum regret
//
// Opt(x,v,k) = ( S_d + SUM_{v in M} ( Opt(y,v,k+1) ) ) / #nodes_counted
//
// Where
//  x : the variable to assign a value to
//  k : current level in the tree
//  y : the next available unassigned variable
//  M : set of explored assignments
//  v : value to assign to the variable
//  S_d : largest difference between two smallest values
//        (read from MLNode::getRegretMin())
//
// This function returns only the sum (numerator); it increments
// regret_nodes_counted for every node it adds, and raises current_pos to the
// deepest level visited, so that the caller can do the division.
// Nodes at level `depth` are reached but neither summed nor counted.
//
// Dynamic programming (memoization) could be used here.
double MLTree::recursiveMaxRegret(MLNode* node, int depth, int cur_depth) {
    // Remember the deepest level this traversal has reached.
    if (current_pos < cur_depth)
        current_pos = cur_depth;

    // Stop when we reach depth
    if (cur_depth == depth)
        return 0.0;

    const double max_regret = static_cast<double>(node->getRegretMin());
    regret_nodes_counted++;

    const vector<MLNode*>& children = node->getChildren();
    if (children.empty())
        return max_regret;

    double children_regret = 0.0;
    for (MLNode* child : children) {
        // Recursive call to next level in the tree
        children_regret += recursiveMaxRegret(child, depth, cur_depth + 1);
    }

    // Own regret plus the sum over the whole subtree below.
    return max_regret + children_regret;
}

// Anti first-fail (domain size)

void MLTree::computeAntiFF(int depth) {
    // Compute deep impact for each node in the tree 
    computeAllAntiFF(root, depth);
}

// Sets the deep anti-first-fail score of `node` and, recursively, of its
// subtree.
//
// The score is the average of the domain sizes summed by recursiveAntiFF()
// over all nodes it counted, rounded to four decimal places. The score is 0
// when fewer than `depth` levels were actually reached below `node`.
//
// When the tree is not high enough for `depth` levels below `node`, the node
// gets score 0 and its descendants are not visited (their scores are left
// untouched).
void MLTree::computeAllAntiFF(MLNode* node, int depth) {
    if (height - node->getPositionInTree() < depth) {
        node->setNodeScore(0);
        return;
    }

    // Reset the shared counters used by recursiveAntiFF().
    dom_size_nodes_counted = 0;
    current_pos = 0;
    const double total = recursiveAntiFF(node, depth, 0);

    // No node is counted when depth == 0; dividing by zero would produce a
    // NaN score, so fall back to 0 in that case.
    double average = 0.0;
    if (dom_size_nodes_counted > 0)
        average = round(total / static_cast<double>(dom_size_nodes_counted) * 10000.0) / 10000.0;

    // current_pos holds the deepest level reached by the recursion.
    if (current_pos < depth) {
        node->setNodeScore( 0 );
    } else {
        node->setNodeScore( average );
    }

    const vector<MLNode*>& children = node->getChildren();
    for (MLNode* child : children)
        computeAllAntiFF(child, depth);
}

// Deep Anti First Fail
//
// Opt(x,v,k) = ( D_s + SUM_{v in M} ( Opt(y,v,k+1) ) ) / #nodes_counted
//
// Where
//  x : the variable to assign a value to
//  k : current level in the tree
//  y : the next available unassigned variable
//  M : set of explored assignments
//  v : value to assign to the variable
//  D_s : domain size
//
// This function returns only the sum (numerator); it increments
// dom_size_nodes_counted for every node it adds, and raises current_pos to
// the deepest level visited, so that the caller can do the division.
// Nodes at level `depth` are reached but neither summed nor counted.
//
// Dynamic programming (memoization) could be used here.
double MLTree::recursiveAntiFF(MLNode* node, int depth, int cur_depth) {
    // Remember the deepest level this traversal has reached.
    if (cur_depth > current_pos)
        current_pos = cur_depth;

    // Look-ahead exhausted.
    if (cur_depth == depth)
        return 0.0;

    const double own_dom_size = static_cast<double>(node->getDomainSize());
    ++dom_size_nodes_counted;

    const vector<MLNode*>& children = node->getChildren();
    if (children.empty())
        return own_dom_size;

    double children_total = 0.0;
    for (MLNode* child : children)
        children_total += recursiveAntiFF(child, depth, cur_depth + 1);

    // Own domain size plus the sum over the whole subtree below.
    return own_dom_size + children_total;
}

// Returns the number of leaf nodes in the whole tree.
int MLTree::countLeafs() {
    return countLeafsRec(root);
}

// Returns the number of leaf nodes in the subtree rooted at `node`.
//
// A node for which isLeaf() is true counts as one; otherwise the result is
// the sum over all of its children.
int MLTree::countLeafsRec(MLNode* node) {
    if (node->isLeaf())
        return 1;

    int leaf_total = 0;
    const vector<MLNode*>& children = node->getChildren();
    for (MLNode* child : children)
        leaf_total += countLeafsRec(child);
    return leaf_total;
}

// Prints a one-line textual dump of the tree to stdout: the height, the
// number of stored nodes, the root's value and then the nested children.
void MLTree::toString() {
    const int tree_height = getHeight();
    const auto node_count = nodes.size();

    cout << "Tree h=(" << tree_height << ") s=(" << node_count << "): " << root->getValue();
    toStringRec(root);
    cout << endl;
}

// Prints the children of `node` (and, recursively, their subtrees) to stdout
// as a brace-delimited group. Nothing is printed for a node without children.
void MLTree::toStringRec(MLNode* node) {
    const vector<MLNode*>& children = node->getChildren();
    if (children.empty())
        return;

    cout << " {";
    for (MLNode* child : children) {
        cout << " " << child->getValue() << " ";
        toStringRec(child);
    }
    cout << "} ";
}