import networkx as nx
import random
import statistics as st
from networkx.algorithms.dag import topological_sort,descendants
import numpy as np
from scheduling.networks import Network_patterson
from itertools import combinations
import gurobipy as gu
from gurobipy import quicksum as summ
import scipy as sp
from scipy.optimize import least_squares



"""

PARAMETERS
 - random_delete : Fraction of all possible edges (n*(n-1)/2) that is deleted randomly from the full selection network
 - max_diff: Maximum relative deviation of a branch workload from the mean of its selection group (divided by the group size)
 - successors_max: Maximum number of successors per group
 - distance_target: Target fraction of edges that are in either the precedence graph or the selection graph, but not in both
 - groups_target: Target number of selection groups, as a fraction of the number of nodes
 - i2_target: Target for i2, see RanGen for more information

DESCRIPTION
- Create precedence input network from patterson
- Create full network for selection network
- remove fraction random_delete of edges randomly
- Greedily remove edges until distance and i2 cannot be improved
- Remove triangles in precedence graph
- get candidate groups
    - Select node randomly
    - take its direct successors in the selection graph whose edge is not yet used by a group as successor candidates
    - For each subset S of the successor candidates with 2 <= |S| <= successors_max: (shuffled)
        - Get the duration of each branch until the branches come together again.
        - Delete a successor if its duration deviates too much from the mean (based on max_diff)
        - if |S| > 1 afterwards, add group to candidate groups, add its edges to forbidden edges and break
        
- select candidate groups with lowest stdev in normalized duration (stdev/mean)
- scale work content by durations

"""


class Instance_creator:
    def __init__(self):
        """Store the default generation parameters and switch progress output on."""
        # Defaults used by create_instance() for every key the caller omits.
        defaults = dict(
            random_delete=0.05,
            max_diff=0.25,
            distance_target=0.05,
            i2_target=0.5,
            groups_target=0.2,
            successors_max=5,
        )
        self.parameters = defaults
        self.verbose = True




    def create_instance(self,networkfile,params = None):
        """Generate an instance with selection groups from a Patterson network file.

        Steps: read the precedence network, build a full selection network that
        is consistent with its topological order, thin that network out (random
        deletions, then greedy deletions towards the i2 and distance targets,
        then removal of transitive "triangle" edges), collect candidate groups,
        select the groups with the most balanced branch durations and finally
        rescale the resource usage of the branches.

        The procedure uses the ``random`` module, so results differ between
        runs unless the caller seeds it.

        Parameters
        ----------
        networkfile :
            Path passed to ``Network_patterson.read_patterson``.
        params : dict, optional
            Overrides for the entries of ``self.parameters``. Missing keys are
            filled in with the defaults *in the passed dict itself*.

        Returns
        -------
        dict
            ``'group_candidates'``: all (activator, successors) candidates found,
            ``'groups'``: the selected groups,
            ``'ins'``: the resulting ``Instance`` (already initialised via ``init()``).
        """
        if params is None:
            params = {}

        #Add default parameters (fills the caller's dict in place)
        for key,val in self.parameters.items():
            params.setdefault(key,val)

        result = {}

        #Create precedence network
        network = Network_patterson()
        network.read_patterson(networkfile)
        G_p = network.G

        # NOTE: nodes are assumed to be labelled 0..N_num-1; they are used as list indices below
        N_num = len(G_p.nodes)
        N_list = list(range(N_num))
        N = set(range(N_num))
        total_edges = (N_num**2-N_num)/2
        random_delete = round(params['random_delete']*(N_num**2-N_num)/2)
        groups_num = round(params['groups_target']*N_num)


        #Create full network: every edge (i,j) that agrees with the topological order of G_p
        topsort = list(topological_sort(G_p))
        topsort_pos = [0 for i in N_list]
        for p,i in enumerate(topsort):
            topsort_pos[i] = p
        G_s = nx.DiGraph()
        G_s.add_nodes_from(N)
        G_s.add_edges_from([(i,j) for i in N for j in N if topsort_pos[i] < topsort_pos[j]])


        #Remove random edges
        for c in range(random_delete):
            candidate_edges = self._removable_edges(G_s,N)
            if len(candidate_edges) == 0:
                break
            G_s.remove_edge(*random.choice(candidate_edges))


        ## Remove edges until distance and i2 cannot be improved anymore
        while True:
            candidate_edges = self._removable_edges(G_s,N)
            if self.verbose:
                print('\t candidate edges',len(candidate_edges),end='\r')

            pl,i2 = self.get_i2(G_s,topsort)
            distance =  self.get_distance(G_s,G_p)

            new_scores = {}
            for i,j in candidate_edges:
                #Calculate new distance: dropping an edge shared with G_p adds a
                #difference, dropping an edge that only G_s has removes one
                distance_new = distance
                if (i,j) in G_p.edges:
                    distance_new  += 1/total_edges
                else:
                    distance_new  -= 1/total_edges

                i2_new = self.get_updated_i2(G_s,pl,(i,j))

                #Improvement towards both targets; only edges that worsen neither are kept
                score =  (abs(i2-params['i2_target']) - abs(i2_new-params['i2_target']) ,abs(distance-params['distance_target'])- abs(distance_new-params['distance_target']) )
                if (score[0] >= 0 and score[1] >= 0):
                    new_scores[i,j] = sum(score)

            if len(new_scores) > 0:
                #select edge with the largest combined improvement
                edge = max(new_scores,key=new_scores.get)
                G_s.remove_edge(*edge)
            else:
                break


        #Get full successors per node, sorted by topsort
        #(the triangle removal below only drops transitive edges, so these stay valid)
        full_successors = [sorted(list(descendants(G_s,i)),key=lambda x: topsort_pos[x]) for i in N_list]

        #Remove triangles: an edge (i,j) is dropped when j is also reachable via another successor of i
        while True:
            to_remove = None
            for i,j in G_s.edges:
                other_successors = [k for k in G_s.successors(i) if k != j]
                if any(j in descendants(G_s,k) for k in other_successors):
                    to_remove = (i,j)
                    break
            if to_remove is None:
                break

            G_s.remove_edge(*to_remove)

        ## Get candidate groups
        candidate_groups = []
        forbidden_edges = set()
        node_candidates = N_list[:-1]

        while len(candidate_groups) < N_num:
            if len(node_candidates) == 0:
                break

            i = random.choice(node_candidates)
            all_successors = [j for j in G_s.successors(i) if (i,j) not in forbidden_edges]
            if len(all_successors) <= 1:
                node_candidates.remove(i)
                continue

            #Drop successors that have no descendants of their own (not shared with the others)
            group_descendants = self.get_unique_descendants(all_successors,full_successors)
            while len(group_descendants) > 0 and any(len(desc) == 0 for desc in group_descendants.values()):
                to_delete = [node for node,desc in group_descendants.items() if len(desc) == 0]
                for j in to_delete:
                    all_successors.remove(j)
                group_descendants = self.get_unique_descendants(all_successors,full_successors)

            if len(group_descendants) <= 1:
                node_candidates.remove(i)
                continue

            successors_combos = []
            for successors_num in range(2,params['successors_max']+1):
                successors_combos.extend(list(combinations(all_successors,successors_num)))

            random.shuffle(successors_combos)

            for successors in successors_combos:
                successors = list(successors)
                #Repeatedly drop branches whose workload deviates too much from the mean
                while len(successors) > 1:
                    node_descendants = self.get_group_descendants(successors,full_successors)
                    durations = [sum([network.durations[j] for j in node_descendants[succ]]) for succ in successors]
                    duration_mean = st.mean(durations)
                    if duration_mean == 0:
                        #All branches have zero workload (durations assumed non-negative):
                        #they are perfectly balanced and the relative deviation is undefined
                        break

                    to_remove = []
                    for idx,succ in enumerate(successors):
                        diff = abs((durations[idx]-duration_mean)/duration_mean) / len(durations)
                        if diff > params['max_diff']:
                            to_remove.append(idx)

                    if len(to_remove) == 0:
                        break
                    #Delete from the back so the remaining indices stay valid
                    to_remove.sort(reverse=True)
                    for idx in to_remove:
                        del successors[idx]

                if len(successors) > 1:
                    candidate_groups.append((i,successors))
                    for j in successors:
                        forbidden_edges.add((i,j))
                    break
            else:
                #No subset of successors gave a usable group for this node
                node_candidates.remove(i)
                continue

        result['group_candidates']  = candidate_groups.copy()


        #Select groups: repeatedly take the candidate with the lowest stdev/mean of branch durations
        groups = []
        while len(candidate_groups) > 0 and len(groups) < groups_num:

            scores = []
            for activator,successors in candidate_groups:
                node_descendants = self.get_group_descendants(successors,full_successors)
                durations = [sum([network.durations[j] for j in node_descendants[i]]) for i in successors]
                duration_mean = st.mean(durations)
                #A zero mean means all-zero durations (assumed non-negative): perfectly balanced
                scores.append(st.stdev(durations)/duration_mean if duration_mean != 0 else 0.0)

            group_idx = int(np.argmin(scores))
            if self.is_feasible(G_s,groups + [candidate_groups[group_idx]]):
                groups.append(candidate_groups[group_idx])
            del candidate_groups[group_idx]

        result['groups'] = groups


        # Scale work content
        G = list(range(len(groups)))
        group_nodes = []
        for group in groups:
            desc = self.get_group_descendants(group[1],full_successors)
            group_nodes.append([desc[succ] for succ in group[1]])

        branch_durations = [ [sum(network.durations[node] for node in nodes) for nodes in group_nodes[g]] for g in G]

        #One scaling coefficient per (group, branch)
        S_g = [group[1] for group in groups]
        coeff2branch = {}
        for g in G:
            for b in range(len(S_g[g])):
                coeff2branch[len(coeff2branch)] = (g,b)

        #Coefficients that apply to each node (a node can lie on branches of several groups)
        node_coeffs = [[coeff for coeff,(g,b) in coeff2branch.items() if i in group_nodes[g][b]] for i in N_list]

        for r in network.resources:

            def equations(coeffs):
                """Residuals: total usage is preserved and branches of a group get equal work content."""
                def scaled_usage(i):
                    #np.prod of an empty list is 1.0, i.e. nodes outside all branches are unscaled
                    return np.prod([coeffs[coeff] for coeff in node_coeffs[i]]) * network.resource_usage[r,i]

                eqs = []

                total_usage = sum(scaled_usage(i) for i in N)
                eqs.append(len(G)*10 * (total_usage- sum([network.resource_usage[r,i] for i in N])))

                for g in G:
                    for branch1,branch2 in combinations(range(len(S_g[g])),2):
                        res_final_1 = sum(scaled_usage(i) for i in group_nodes[g][branch1])
                        res_final_2 = sum(scaled_usage(i) for i in group_nodes[g][branch2])
                        eqs.append(res_final_1 * branch_durations[g][branch1] - res_final_2 * branch_durations[g][branch2])

                return eqs

            #Without groups there is nothing to scale; the (unbounded) fit is only run when needed
            if coeff2branch:
                sol = least_squares(equations,[1 for i in coeff2branch])

                for i in N:
                    for coeff in node_coeffs[i]:
                        #Scaled usage is capped at the resource availability
                        network.resource_usage[r,i] = min(network.resource_usage[r,i]*sol['x'][coeff],network.resource_availability[r])

            for i in N:
                network.resource_usage[r,i] = round(network.resource_usage[r,i])


        #Create ins
        ins = Instance()
        ins.N = set(G_s.nodes)
        ins.P = list(G_p.edges)

        added_edges = set()
        for g,group in enumerate(groups):
            ins.G.append(g)
            ins.a_g.append(group[0])
            ins.S_g.append(set(group[1]))
            added_edges |= {(group[0],j) for j in group[1]}

        #Every selection edge that is not part of a group becomes a singleton group
        for edge in G_s.edges:
            if edge not in added_edges:
                g = len(ins.G)
                added_edges.add(edge)
                ins.G.append(g)
                ins.a_g.append(edge[0])
                ins.S_g.append({edge[1]})

        ins.d.extend(network.durations)

        ins.n0 = 0
        ins.ne = len(ins.N)-1

        #Set resources
        ins.Rr = network.resources
        ins.r = network.resource_usage
        ins.lmbda = [network.resource_availability[r] for r in network.resources]

        ins.k = SparseDict()
        for (r,n),usage in network.resource_usage.items():
            ins.k[r,n] = usage

        ins.init()
        result['ins'] = ins

        return result

    def _removable_edges(self,G_s,N):
        """Return the edges (i,j) of ``G_s`` where i has another outgoing and j another incoming edge.

        Pairs are enumerated as ``for i in N for j in N`` so that the order of
        the returned list does not depend on the edge insertion order of ``G_s``.
        """
        return [(i,j) for i in N for j in N
                if (i,j) in G_s.edges and G_s.out_degree[i] > 1 and G_s.in_degree[j] > 1]



    def get_i2(self,G,topsort,N = None):
        if N is None:
            N  = G.nodes

        if len(N) == 1:
            return 1
        #Get PL
        pl = [1 for i in G.nodes]
        for i in topsort:
            predecessors = set(G.predecessors(i)) & set(N)
            if len(predecessors) > 0:
                pl[i] = max(pl[j] for j in predecessors)+1
        n = len(N)
        if len([pl[i] for i in N]) == 0:
            m = 1
        else:
            m = max([pl[i] for i in N])
        return pl,(m-1)/(n-1)

    def get_updated_i2(self,G,pl,edge):
        pl = pl.copy()
        N = G.nodes

        nodes = {edge[1]}
        while len(nodes) > 0:
            node = nodes.pop()
            if node == edge[1]:
                predecessors = {i for i in G.predecessors(node) if i != edge[0]}
            else:
                predecessors = set(G.predecessors(node))

            pl_new = max(pl[j] for j in predecessors)+1
            if pl_new < pl[node]:
                pl[node] = pl_new
                nodes |= set(G.successors(node))

        if len([pl[i] for i in N]) == 0:
            m = 1
        else:
            m = max([pl[i] for i in N])
        n = len(N)
        return (m-1)/(n-1)


    def get_distance(self,G_s,G_p,N = None):
        if N is None:
            N = G_s.nodes
        nodes_num = len(N)

        if nodes_num <= 1:
            return 0

        edges_to_create = {(i,j) for i,j in G_s.edges if (i,j) not in G_p.edges and i in N and j in N}
        edges_to_delete = {(i,j) for i,j in G_p.edges if (i,j) not in G_s.edges and i in N and j in N}
        total_edges = (nodes_num**2-nodes_num)/2
        distance = len(edges_to_create) + len(edges_to_delete)

        normalized_distance = distance/total_edges
        return normalized_distance

    def get_group_descendants(self,nodes,descendants):
        """Return, per node, its branch up to the point where *all* branches meet.

        For each node in ``nodes`` the sequence ``[node] + descendants[node]``
        is walked in order and cut at the first element that is a descendant of
        every other node in ``nodes``. The result maps each node to the part
        before that cut. With a single node the result is an empty list,
        because the "every other node" condition is then trivially true.
        """
        branches = {}
        for origin in nodes:
            others = [other for other in nodes if other != origin]
            chain = [origin] + descendants[origin]
            cut = len(chain)
            for position,member in enumerate(chain):
                if all(member in descendants[other] for other in others):
                    cut = position
                    break
            branches[origin] = chain[:cut]
        return branches


    def get_unique_descendants(self,nodes,descendants):
        """Return, per node, its branch up to the first element shared with *any* other branch.

        For each node in ``nodes`` the sequence ``[node] + descendants[node]``
        is walked in order and cut at the first element that is a descendant of
        at least one other node in ``nodes``. The result maps each node to the
        part before that cut; a single node keeps its whole sequence.
        """
        exclusive = {}
        for origin in nodes:
            others = [other for other in nodes if other != origin]
            chain = [origin] + descendants[origin]
            cut = len(chain)
            for position,member in enumerate(chain):
                if any(member in descendants[other] for other in others):
                    cut = position
                    break
            exclusive[origin] = chain[:cut]
        return exclusive

    def is_feasible(self,graph,groups):
        """Check with a small binary program whether a set of groups admits a valid selection.

        Every edge of ``graph`` that is not covered by one of ``groups`` is
        treated as a group with a single successor. One binary variable per
        node is created, node 0 is forced to 1, and for every group the
        activator may only be selected if a successor is selected, while at
        most one successor per group may be selected.

        Returns True when Gurobi finishes with status OPTIMAL.
        """
        activators = []
        successor_sets = []
        covered_edges = set()
        for act,succ in groups:
            activators.append(act)
            successor_sets.append(succ)
            covered_edges.update((act,j) for j in succ)

        # Remaining edges act as singleton groups
        for i,j in graph.edges:
            if (i,j) in covered_edges:
                continue
            activators.append(i)
            successor_sets.append({j})

        #Get feasibility
        model = gu.Model('mip')
        x = model.addVars(list(graph.nodes),name='x',vtype=gu.GRB.BINARY)
        model.addConstr(x[0] == 1)

        for act,succ in zip(activators,successor_sets):
            model.addConstr( x[act] <= summ(x[j] for j in succ) )
            model.addConstr( summ(x[j] for j in succ) <= 1)

        model.setParam('OutputFlag',0)
        model.optimize()
        # GRB.OPTIMAL is the status code 2
        return model.status == gu.GRB.OPTIMAL


class SparseDict(dict):
    """Dictionary that yields a default for missing keys and never stores that default.

    Reading a missing key returns ``default`` without inserting it. Assigning
    ``default`` to a key removes the key instead of storing the value.
    """

    def __init__(self,default=0):
        self.default = default

    def __getitem__(self,key):
        stored = self.get(key)
        if stored is None:
            # Missing key (or a stored None): fall back to the default
            return self.default
        return stored

    def __setitem__(self,key,value):
        if value != self.default:
            super().__setitem__(key, value)
            return
        # Assigning the default: drop an existing entry rather than store it
        if self.get(key) is not None:
            del self[key]


class Instance:
    """Scheduling instance: activities, precedence relations, selection groups and resources.

    Fill in the public attributes (``N``, ``P``, ``G``, ``a_g``, ``S_g``,
    ``Rr``, ``Rc``, ``lmbda``, ``k``, ``d``) and then call ``init()`` to build
    the derived lookup structures. Several of those structures are lists that
    are indexed by node, so nodes are assumed to be the integers ``0..n-1`` and
    ``N`` is assumed to iterate in that order.
    """

    def __init__(self):
        self.solution = None

        #network
        self.N = [] #nodes/activities
        self.n0 = None #starting node
        self.ne = None #ending node
        self.P = None #Precedence relationships. Set of tuples
        self.G = [] #or groups (g)
        self.a_g = []#activator nodes (indexed by g)
        self.S_g = []#Activated nodes. List of sets (indexed by g)


        #resources
        self.Rr = set() #Renewable resources
        self.Rc = set() #Cumulative resources
        self._R = None #Explicit override for R; None means "derive from Rr and Rc"
        self.lmbda = [] #resource capacity (indexed by resources r)
        self.k = SparseDict() #resource usage (indexed by (r,n))

        #Time
        self._T = range(0) #All time periods
        self.T_n = [] #Time periods for activity n (indexed by activities n)
        self.d = [] #durations (indexed by n)

        #plotting
        self._pos = None

        #other
        self._choiceList = None
        self.cuts_verbose = False #Whether to show gurobi output for creating cuts
        self.verbose = False
        self.preprocess_timehorizon = True #Whether T_n is automatically set if T is set
        self.cut_objectives = {}
        self.max_cut_iterations= {'combined' : 100,'combined_nochange' : 10}
        self.cut_tolerance = 0.0001
        self.use_choicelist = False #If a choice list is used for a heuristic
        self.has_choicelist = None # If a choicelist was successfully calculated
        self.combined_cuts_groups = None #Can be set to calculate the combined cutting planes earlier, with calculate_combined_cuts()
        self.combined_cutting_planes = ['lastnode','time']
        self.rp_edges = None #storage placeholder
        self._inv_rp_edges = None
        self.ub_milp = None #Storage for ub milp

    @property
    def R(self):
        """All resources: the renewable and cumulative ones combined, in sorted order.

        An explicitly assigned value takes precedence over the derived one.
        """
        explicit = getattr(self,'_R',None)
        if explicit is not None:
            return explicit
        return sorted(set(self.Rr) | set(self.Rc))

    @R.setter
    def R(self,resources):
        self._R = resources

    def topsort_prec(self):
        """Return the nodes of ``N`` in a topological order of the precedence relations ``P``.

        Raises
        ------
        ValueError
            If the precedence relations contain a cycle.
        """
        nodes = list(self.N)
        indegree = {n: 0 for n in nodes}
        successors = {n: [] for n in nodes}
        for i,j in (self.P or ()):
            successors[i].append(j)
            indegree[j] += 1

        #Kahn's algorithm; `order` doubles as the work queue
        order = [n for n in nodes if indegree[n] == 0]
        head = 0
        while head < len(order):
            node = order[head]
            head += 1
            for nxt in successors[node]:
                indegree[nxt] -= 1
                if indegree[nxt] == 0:
                    order.append(nxt)

        if len(order) != len(nodes):
            raise ValueError('precedence relations contain a cycle')
        return order

    def init(self):
        """
        Build the derived lookup structures.

        Call this function after the network attributes (N, P, G, a_g, S_g)
        have been set. It determines the start and end node from a topological
        order of P and fills the predecessor/successor and group mappings.

        """
        topsort = list(self.topsort_prec())
        self.n0 = topsort[0]
        self.ne = topsort[-1]

        self.timeprec_i = [{j for j, k in self.P if i == k} for i in self.N]  #all direct time predecessors of i
        self.timesucc_i = [{k for j, k in self.P if i == j} for i in self.N]  #all direct time successors of i

        #Create logical successor maps
        self.activator2successors = [set() for i in self.N]
        for g in self.G:
            self.activator2successors[self.a_g[g]] |= self.S_g[g]


        #Create various mappings: input2output[input] = output
        self.node2activators = [set() for i in self.N]
        self.activator2groups = [set() for i in self.N]
        for g in self.G:
            for i in self.S_g[g]:
                self.node2activators[i].add(self.a_g[g])
            self.activator2groups[self.a_g[g]].add(g)

        #Get succeeding groups: groups whose selection edges are all precedence edges too
        self.G_succ = [g for g in self.G if all([(self.a_g[g],j) in self.P for j in self.S_g[g]])]

        self.edge2groups = {(i,j) : set() for i in self.N for j in self.N}
        self.successor2groups = {(i): set() for i in self.N}
        for g in self.G:
            for j in self.S_g[g]:
                self.edge2groups[self.a_g[g],j].add(g)
                self.successor2groups[j].add(g)


        #Full time predecessors (transitive), for the nodes reachable from n0
        self.full_timeprec = self._transitive_predecessors(self.timesucc_i)

        #get full time successors
        self.full_timesucc = [set() for n in self.N]
        for j in self.N:
            for i in self.full_timeprec[j]:
                self.full_timesucc[i].add(j)

        #Get full selection-predecessors
        self.full_selpred = self._transitive_predecessors(self.activator2successors)

        #Get full selection successors
        self.full_selsucc = [set() for n in self.N]
        for j in self.N:
            for i in self.full_selpred[j]:
                self.full_selsucc[i].add(j)

    def _transitive_predecessors(self,successors_of):
        """Return, per node, all its (transitive) predecessors among the nodes reachable from n0."""
        full_pred = [set() for n in self.N]
        open_nodes = {self.n0}
        while len(open_nodes) > 0:
            i = open_nodes.pop()
            for j in successors_of[i]:
                full_pred[j] |= {i} | full_pred[i]
                #j is re-queued so that later growth of its set reaches its own successors
                open_nodes.add(j)
        return full_pred

    def writeToFile(self,fileName=None):
        """Serialise the instance to text and optionally write it to ``fileName``.

        Layout: a header with the number of activities, renewable and
        cumulative resources; a line with the resource capacities; then per
        activity its duration and resource usage, its groups (count, then size
        and members of each group) and its time successors.

        Returns the generated text.
        """
        resources = self.R
        parts = []
        parts.append("{0} {1} {2} \n".format(len(self.N),len(self.Rr),len(self.Rc))) #Number of activities - Number of renewable resources - number of cumulative resources

        #resources
        parts.append("".join(["{0} ".format(self.lmbda[r]) for r in resources]) +"\n\n")

        for n in self.N:

            #duration and resource usage
            parts.append('{0} '.format(self.d[n]) +' '.join([str(self.k[r,n]) for r in resources]) + "\n")

            #or groups
            groups = [g for g in self.G if self.a_g[g] == n]
            parts.append(str(len(groups)) + " ")
            for g in groups:
                parts.append("{0} {1} ".format(len(self.S_g[g])," ".join(str(succ) for succ in self.S_g[g])))
            parts.append("\n")

            #Time successors
            successors = [j for i,j in self.P if i == n]
            parts.append("{0} {1}\n\n".format(len(successors),' '.join([str(node) for node in successors])))

        text = "".join(parts)

        if fileName is not None:
            with open(fileName,'w') as f:
                f.write(text)

        return text


if __name__ == '__main__':
    # Example: build one instance from a Patterson network and store it on disk.
    nodes = 30
    rc = 0.5 #Resource constrainedness, used in RanGen
    i2 = 0.4 #i2, used in Rangen

    # Patterson file, created with Rangen2; rc and i2 appear in the name as whole percentages
    rc_percent = int(rc*100)
    i2_percent = int(i2*100)
    filename = 'pat_networks/Pat_{0}_{1}_{2}_{3}.rcp'.format(nodes,4,rc_percent,i2_percent)

    # Only the i2 target is overridden; all other parameters use the defaults
    params = {'i2_target': i2}
    ic = Instance_creator()
    result = ic.create_instance(filename,params)

    #save to file
    result['ins'].writeToFile('instances/test.txt')

