import random
import numpy as np
import gurobipy as gu
from gurobipy import quicksum as summ
from itertools import product
import networkx as nx

class Instance:
    """Scheduling instance with or-groups and renewable/cumulative resources.

    Main attributes (filled by ``read_from_file``):

    * ``N`` - node indices ``range(#nodes)``; ``n0`` is node 0 and ``ne`` the
      last node read.
    * ``d[n]`` - duration of node ``n``.
    * ``Rr`` / ``Rc`` / ``R`` - renewable, cumulative and all resource
      indices; ``lmbda[r]`` is the availability of resource ``r``.
    * ``k[r, n]`` - renewable usage; ``k_cons[r, n]`` / ``k_prod[r, n]`` -
      cumulative consumption / production.
    * ``G`` - or-group indices; group ``g`` has activator ``a_g[g]`` and
      successor set ``S_g[g]``.
    * ``P`` - list of time precedences ``(i, j)``.
    """

    def __init__(self):
        # Caches filled lazily by get_rp_edges() and the choiceList property.
        self.rp_edges = None
        self._choiceList = None

    def read_from_file(self, file):
        """Load the instance from the whitespace-separated text file *file*.

        Layout (the one ``write_to_file`` produces):

        * line 0: ``#nodes #renewable #cumulative``
        * line 1: availability of every resource (``lmbda``)
        * line 2: blank
        * four lines per node, starting at line 3:
          1. duration, one usage per renewable resource, then a
             ``consumption production`` pair per cumulative resource;
          2. number of or-groups, then ``size member...`` per group;
          3. number of time successors, then the successors;
          4. blank (skipped).

        Besides the raw data this derives direct (``timeprec_i`` /
        ``timesucc_i``) and transitive (``full_timeprec`` / ``full_timesucc``)
        time predecessors/successors and the node <-> or-group lookups
        ``node2activators`` / ``activator2groups``.
        """
        self.T = range(500)

        self.P = []
        with open(file) as f:
            lines = f.readlines()

        # Tokens that look like (optionally negative) integers become int,
        # everything else float; blank lines become empty lists.
        data = [[int(number) if number.lstrip('-').isdigit() else float(number)
                 for number in line.split()]
                for line in lines]

        # First line: number of nodes, renewable and cumulative resources.
        self.N = range(data[0][0])
        self.T_n = [self.T for _ in self.N]

        self.Rr = set(range(data[0][1]))
        self.Rc = set(range(data[0][1], data[0][1] + data[0][2]))
        self.R = self.Rr | self.Rc

        # Second line: resource availabilities.
        self.lmbda = data[1]
        self.d = []
        self.k = np.zeros((len(self.Rr), len(self.N)))
        self.k_cons = SparseDict()
        self.k_prod = SparseDict()
        self.G = []
        self.a_g = []
        self.S_g = []
        self.n0 = 0

        n = 0
        for line in range(3, len(data), 4):
            self.d.append(data[line][0])

            # First line per node: duration, renewable usage, then a
            # "consumption production" pair per cumulative resource.
            for r in self.Rr:
                self.k[r, n] = data[line][r + 1]

            for r in self.Rc:
                self.k_cons[r, n] = data[line][2 * r - len(self.Rr) + 1]
                self.k_prod[r, n] = data[line][2 * r - len(self.Rr) + 2]

            # Second line per node: or-groups ("count size members ...").
            or_groups_num = data[line + 1][0]
            i = 1
            group_size_left = 0
            groups_created = 0

            while i < len(data[line + 1]) and or_groups_num > 0:
                if group_size_left == 0:
                    # Start of a new group: this token is its size.
                    group_size_left = data[line + 1][i]
                    current_group = len(self.G)
                    self.G.append(current_group)
                    self.a_g.append(n)
                    self.S_g.append(set())
                    groups_created += 1
                else:
                    self.S_g[current_group].add(data[line + 1][i])
                    group_size_left -= 1
                i += 1

            errorStr = "Mismatch in or-groups of node {0}. {1} created, but {2} expected \n {3}".format(n, groups_created, or_groups_num, data[line])
            assert groups_created == or_groups_num, errorStr

            # Third line per node: time successors.
            time_prec_num = data[line + 2][0]
            assert time_prec_num == len(data[line + 2][1:]), "Mismatsch in time prec of node {0}. {1} created, but {2} expected".format(n, len(data[line + 2][1:]), time_prec_num)
            self.P.extend([(n, j) for j in data[line + 2][1:]])

            n += 1

        self.ne = n - 1

        # Direct time predecessors / successors of every node.
        self.timeprec_i = [{j for j, k in self.P if i == k} for i in self.N]
        self.timesucc_i = [{k for j, k in self.P if i == j} for i in self.N]

        # Transitive predecessors, propagated forward from n0 until a fixed
        # point is reached (only predecessors on paths from n0 are recorded).
        # A node is re-queued only when its set actually grew: same result as
        # re-queueing unconditionally, but it also terminates on cycles.
        self.full_timeprec = [set() for _ in self.N]
        open_nodes = {self.n0}
        while open_nodes:
            i = open_nodes.pop()
            for j in self.timesucc_i[i]:
                known = len(self.full_timeprec[j])
                self.full_timeprec[j] |= {i} | self.full_timeprec[i]
                if len(self.full_timeprec[j]) != known:
                    open_nodes.add(j)

        self.full_timesucc = [set() for _ in self.N]
        for j in self.N:
            for i in self.full_timeprec[j]:
                self.full_timesucc[i].add(j)

        # node -> activators of groups containing it; activator -> its groups.
        self.node2activators = [set() for _ in self.N]
        self.activator2groups = [set() for _ in self.N]
        for g in self.G:
            for i in self.S_g[g]:
                self.node2activators[i].add(self.a_g[g])
            activator = self.a_g[g]
            if activator in self.N:
                self.activator2groups[activator].add(g)

    @property
    def choiceList(self):
        """Or-groups sorted by the positions ``y`` of a small Gurobi MIP (cached).

        Binary ``x[g, i]`` selects exactly one successor ``i`` per group ``g``;
        continuous ``y[g]`` is a position.  ``D_g[h]`` is the activator of
        group ``h`` plus, when ``h`` has more than one successor, its
        successors.  Selecting ``i`` in ``g`` forces ``y[h] >= y[g] + 1`` for
        every other group ``h`` with ``i`` in ``D_g[h]`` (big-M otherwise).
        """
        if self._choiceList is None:
            M = 1000  # big-M: relaxes the ordering constraint when x[g, i] = 0

            D_g = []
            for g in self.G:
                disable_nodes = {self.a_g[g]}
                if len(self.S_g[g]) > 1:
                    disable_nodes |= self.S_g[g]
                D_g.append(disable_nodes)

            m = gu.Model('mip')
            m.setParam('OutputFlag', 0)

            x = m.addVars([(g, i) for g in self.G for i in self.S_g[g]], name='x', vtype=gu.GRB.BINARY)
            y = m.addVars([g for g in self.G])

            # Each group selects exactly one successor.
            for g in self.G:
                m.addConstr(summ(x[g, i] for i in self.S_g[g]) == 1)

            # Selecting successor i of g pushes every other group h with i in
            # D_g[h] strictly after g.
            for g in self.G:
                for i in self.S_g[g]:
                    for h in self.G:
                        if i in D_g[h] and h != g:
                            m.addConstr(M * (x[g, i] - 1) + y[h] >= y[g] + 1)

            m.optimize()
            y_val = m.getAttr('x', y)

            self._choiceList = sorted(self.G, key=lambda g: y_val[g])
        return self._choiceList

    def get_nees(self, candidates=None, executed=None):
        """Solve a Gurobi MIP over or-groups and return the selected candidates.

        Binary ``x[g]`` marks groups, binary ``y[i]`` marks candidates; the
        objective (Gurobi's default sense, minimisation) is the number of
        selected candidates.  At least one group of ``n0`` must be selected,
        and a selected group ``g`` requires, for each successor ``i`` of ``g``
        inside the considered node set, a selected group of ``i`` or ``i``
        itself selected as a candidate.

        When *executed* is given, only *executed* and the nodes paired with
        it in ``get_rp_edges()`` are considered, and only groups whose
        activator and at least one successor lie in that set.

        Returns:
            ``{n0}`` if ``n0`` is a candidate; else the set of candidates with
            ``y = 1`` if the model is solved to optimality; else ``None``.
        """
        if candidates is None:
            candidates = []

        if self.n0 in candidates:
            return {self.n0}

        # Nodes considered.
        if executed is None:
            N = self.N
        else:
            N = {executed}
            edges = self.get_rp_edges()
            for i in self.N:
                if (i, executed) in edges or (executed, i) in edges:
                    N.add(i)

        # Groups considered.
        if executed is None:
            G = self.G
        else:
            G = {g for g in self.G
                 if self.a_g[g] in N and any(j in N for j in self.S_g[g])}

        m = gu.Model('nees')
        m.setParam('OutputFlag', 0)
        x = m.addVars([g for g in G], vtype=gu.GRB.BINARY)
        y = m.addVars([i for i in candidates], vtype=gu.GRB.BINARY)

        m.setObjective(summ(y[i] for i in candidates))

        m.addConstr(summ(x[g] for g in self.activator2groups[self.n0] if g in G) >= 1)

        for g in G:
            for i in self.S_g[g]:
                if i in N:
                    successor_groups = [h for h in self.activator2groups[i] if h in G]
                    m.addConstr(x[g] <= summ(x[h] for h in successor_groups)
                                + (y[i] if i in candidates else 0))

        m.optimize()
        if m.status == gu.GRB.OPTIMAL:
            return get_gu_indexset(y)
        return None

    def get_rp_edges(self):
        """Return (and cache in ``self.rp_edges``) undirected pairs ``(i, j)``, ``i < j``.

        The set contains every pair joined by a directed path in
        ``get_graph()``, plus the closure of the or-group expansion below
        seeded with successors of different groups of the same activator.
        """
        if self.rp_edges is not None:
            return self.rp_edges

        def dirEdge(i, j):
            # Canonical undirected form: smaller index first.
            return (i, j) if i < j else (j, i)

        def hasCommonGroup(i, j):
            return any(i in self.S_g[g] and j in self.S_g[g] for g in self.G)

        # Directed reachability between distinct nodes of N: one traversal
        # per source instead of one has_path query per ordered pair.
        graph = self.get_graph()
        dir_paths = {(i, j) for i in self.N for j in nx.descendants(graph, i) if j in self.N}
        paths = {dirEdge(i, j) for i, j in dir_paths}

        # Seed: successors of different groups of the same activator that do
        # not share a group themselves.
        edges = set()
        for i in self.N:
            groups = self.activator2groups[i]
            for g1, g2 in product(groups, groups):
                if g1 < g2:
                    for u, v in product(self.S_g[g1], self.S_g[g2]):
                        if u != v and not hasCommonGroup(u, v):
                            edges.add(dirEdge(u, v))

        active_edges = edges.copy()
        while True:
            new_edges = set()
            while active_edges:
                edge = active_edges.pop()

                # Expand along each endpoint i: pair the other endpoint j with
                # every or-successor of i that j does not already reach.
                for i, j in (edge, edge[::-1]):
                    for g in self.activator2groups[i]:
                        for i_successor in self.S_g[g]:
                            if j != i_successor and (j, i_successor) not in dir_paths:
                                new_edges.add(dirEdge(j, i_successor))

                # Pair up or-successors shared by both endpoints.
                first, second = edge
                first_successors = set().union(*[self.S_g[g] for g in self.activator2groups[first]])
                second_successors = set().union(*[self.S_g[g] for g in self.activator2groups[second]])
                common_successors = first_successors & second_successors
                for u, v in product(common_successors, common_successors):
                    if u < v and not hasCommonGroup(u, v):
                        new_edges.add(dirEdge(u, v))

                edges.add(edge)

            # Only genuinely new edges are expanded in the next round.
            new_edges -= edges | active_edges
            active_edges |= new_edges
            if not new_edges:
                break

        edges |= paths
        self.rp_edges = edges
        return edges

    def get_graph(self):
        """
        Gets networkx graph (or-groups are not present in this)

        Nodes are ``self.N``; each or-group contributes an edge from its
        activator to every one of its successors.
        """
        G = nx.DiGraph()
        G.add_nodes_from(self.N)

        for g in self.G:
            for j in self.S_g[g]:
                G.add_edge(self.a_g[g], j, style='solid')
        return G

    def calculate_resource_needed(self):
        """Build ``self.resource_needed[n]``, a ``(len(R), len(T))`` array per node.

        Columns are periods relative to the node's start.  Renewable rows hold
        ``k[r, n]`` during the first ``d[n]`` periods.  Cumulative rows hold
        ``k_cons[r, n]`` during the first ``d[n]`` periods and the net
        ``k_cons[r, n] - k_prod[r, n]`` afterwards.
        """
        self.resource_needed = []
        for n in self.N:
            resource_needed = np.zeros((len(self.R), len(self.T)))
            for r in self.R:
                if r in self.Rr:
                    resource_needed[r, 0:self.d[n]] = self.k[r, n]
                else:
                    resource_needed[r, 0:self.d[n]] = self.k_cons[r, n]
                    resource_needed[r, self.d[n]:] = self.k_cons[r, n] - self.k_prod[r, n]
            self.resource_needed.append(resource_needed)

    def write_to_file(self, fileName=None):
        """Serialize the instance in the text format read by ``read_from_file``.

        Resources are emitted in ascending index order.  The reader locates
        every value by resource index, and set iteration order is not
        guaranteed to be ascending, so the resource sets are sorted first.

        Args:
            fileName: optional path; when given, the text is also written there.

        Returns:
            The serialized instance as a string.
        """
        renewable = sorted(self.Rr)
        cumulative = sorted(self.Rc)

        # Group lookups built once instead of scanning G / P for every node.
        groups_by_activator = {}
        for g in self.G:
            groups_by_activator.setdefault(self.a_g[g], []).append(g)
        successors_by_node = {}
        for i, j in self.P:
            successors_by_node.setdefault(i, []).append(j)

        parts = []
        # Number of activities - number of renewable resources - number of cumulative resources
        parts.append("{0} {1} {2} \n".format(len(self.N), len(self.Rr), len(self.Rc)))

        # Resource availabilities
        parts.append("".join("{0} ".format(self.lmbda[r]) for r in sorted(self.R)) + "\n\n")

        for n in self.N:
            # Duration, renewable usage, cumulative consumption/production pairs
            parts.append('{0} '.format(self.d[n])
                         + ' '.join(str(self.k[r, n]) for r in renewable)
                         + ' '
                         + ' '.join(f'{self.k_cons[r, n]} {self.k_prod[r, n]}' for r in cumulative)
                         + "\n")

            # Or-groups: count, then "size members" per group
            groups = groups_by_activator.get(n, [])
            parts.append(str(len(groups)) + " ")
            for g in groups:
                parts.append("{0} {1} ".format(len(self.S_g[g]), " ".join(str(succ) for succ in self.S_g[g])))
            parts.append("\n")

            # Time successors
            successors = successors_by_node.get(n, [])
            parts.append("{0} {1}\n\n".format(len(successors), ' '.join(str(node) for node in successors)))

        text = "".join(parts)

        if fileName is not None:
            with open(fileName, 'w') as f:
                f.write(text)

        return text
class Solution_al(object):
    """
    Activity list solution representation

    self.al = [n0,n1,...nm] - List of all activities, where ni is scheduled before nj if i < j
    self.executed  (List of booleans) [i] = True iff activity i is executed
    self.al_executed = [n0,n1,...nm] - List of EXECUTED activities, where ni is scheduled before nj if i < j

    NOTE: ``al`` and ``executed`` are not set by ``__init__``; callers must
    assign them after construction and before calling ``get_result``.
    ``al_executed`` is only built as a local inside ``get_result``; it is not
    stored on the instance.
    """
    def __init__(self,ins):
        # The Instance being scheduled; get_result may extend ins.T and
        # ins.resource_needed in place.
        self.ins = ins


    def get_result(self):
        """
        serial Generation Scheme

        Executed activities are placed one by one in ``self.al`` order.  Each
        one is tried at resource-change times ``t`` from ``min_t`` (latest
        finish of its already placed direct time predecessors) up to the
        current last finish time.  It is placed at the first ``t`` where no
        resource level would drop below zero (levels that are already negative
        count as zero).  If no such ``t`` exists, it is placed at the first
        ``t`` where only the renewable resources fit, so cumulative levels may
        go negative.  If neither exists, ``scheduled_time`` stays ``None`` and
        the following arithmetic raises ``TypeError``.

        Side effect: when the horizon ``ins.T`` becomes too short, it is
        extended in place on ``self.ins``, together with every
        ``ins.resource_needed[i]`` (the last column is repeated).

        Returns:
            dict with
            ``makespan`` - latest end time of an executed activity;
            ``res_penalty`` - summed magnitude of negative resource levels
            over all periods;
            ``prec_penalty`` - number of executed direct time predecessors
            not yet placed when an activity was placed;
            ``selection_penalty`` - per group with an executed activator:
            1 if none of its successors is executed, else the number of
            executed successors beyond the first;
            ``times`` - start time per node (0 for nodes not placed);
            ``end_times`` - end times in executed-list order;
            ``resources`` - final resource level matrix (resources x periods).
        """
        ins = self.ins
        unscheduled = [i for i in self.al if self.executed[i]]

        scheduled = set()
        start_times = [0 for n in ins.N]
        # One row per resource, every period starting at the availability lmbda[r].
        resource_levels = np.tile(np.array([ins.lmbda]).T,(len(ins.T)))


        prec_penalty = 0
        # Periods where resource levels can change (starts / ends of placed activities).
        change_times = {0}
        executed_set = set(i for i in self.ins.N if self.executed[i])
        al_executed = [i for i in self.al if self.executed[i]]

        selection_penalty = 0
        for g in ins.G:
            if self.executed[ins.a_g[g]]:
                successors = [i for i in ins.S_g[g] if self.executed[i]]
                selection_penalty += max(1 if len(successors) == 0 else 0, len(successors)-1)

        for n in unscheduled:

            #check earliest starting time on precedence
            preceeding = ins.timeprec_i[n] & scheduled
            # Executed predecessors that are not placed yet are only penalized.
            prec_penalty += len((ins.timeprec_i[n] & executed_set) - preceeding)
            if len(preceeding) == 0:
                min_t = 0
            else:
                min_t = max(start_times[m]+ins.d[m] for m in preceeding)


            #Check resource feasibility
            finish_time_last = 0 if len(scheduled) == 0 else max(start_times[m]+ins.d[m] for m in scheduled)
            resource_needed = ins.resource_needed[n]
            scheduled_time = None


            check_times_all = np.array([_t for _t in change_times])

            for t in range(min_t,finish_time_last+1):
                # Levels only change at change times, so only those are tried.
                if t not in change_times:
                    continue

                # Levels are piecewise constant between change times, so
                # checking those >= t covers the whole future.
                check_times = check_times_all[check_times_all >= t]

                modified = np.maximum(resource_levels[:,check_times],np.zeros( resource_levels[:,check_times].shape))-resource_needed[:,check_times-t]

                if not np.any(modified < 0):
                    # Fully feasible: take it.
                    scheduled_time = t
                    break
                elif scheduled_time is None:
                    # Fallback: first time where the renewable resources fit.
                    if not np.any(modified[list(ins.Rr),:] < 0):
                        scheduled_time = t


            #Schedule
            t = scheduled_time
            start_times[n] = t
            scheduled.add(n)
            change_times |= {t,t+ins.d[n]}

            # Extend the horizon (on the shared instance) when the remaining
            # periods could no longer hold the longest activity.
            if t+ins.d[n] + max(ins.d) > len(ins.T):
                step_size = 50+ins.d[n]
                ins.T = range(len(ins.T)+step_size)
                #Resource requirements
                for i in ins.N:
                    ins.resource_needed[i] = np.hstack((ins.resource_needed[i], np.tile(ins.resource_needed[i][:, [-1]], step_size)))

                #resource levels
                resource_levels = np.hstack((resource_levels, np.tile(resource_levels[:, [-1]], step_size)))


            # Subtract the activity's start-relative profile from t onwards.
            resource_levels[:,t:len(ins.T)] = resource_levels[:,t:len(ins.T)]-ins.resource_needed[n][:,:len(ins.T)-t]

        end_times = [start_times[n]+ins.d[n] for n in al_executed]
        makespan = max(end_times)


        #get res penalties
        res_penalties = resource_levels.copy()
        res_penalties[res_penalties > 0] = 0
        res_penalty = np.sum(res_penalties*-1)


        result = {
            'makespan' : makespan,
            'res_penalty' : res_penalty,
            'prec_penalty' : prec_penalty,
            'selection_penalty' : selection_penalty,
            'times' : start_times,
            'end_times' : end_times,
            'resources' : resource_levels
        }
        return result

def _random_executed_nodes(ins):
    """Randomly choose one successor per reachable or-group, in ``ins.choiceList`` order.

    Starts from ``{ins.n0}``.  A group is handled only if its activator is
    already executed and none of its successors is.  Successors rejected by
    earlier choices may not be picked.
    """
    executed = {ins.n0}
    forbidden = set()
    for g in ins.choiceList:
        if ins.a_g[g] not in executed:
            continue

        successors = set(ins.S_g[g])
        # Group already satisfied by an earlier choice.
        if executed & successors:
            continue

        successors = successors - forbidden
        # NOTE(review): random.choice raises IndexError if every successor is
        # forbidden; presumably prevented by the choiceList ordering - not
        # checked here.
        successor = random.choice(list(successors))
        executed.add(successor)
        forbidden |= successors - {successor}
    return executed


def _random_activity_list(ins):
    """Random activity list: ``ins.n0`` first, then a random precedence-feasible order of all nodes."""
    al = [ins.n0]
    nodes_unscheduled = set(ins.N) - {ins.n0}
    while nodes_unscheduled:
        candidates = [node for node in nodes_unscheduled
                      if not ins.timeprec_i[node] & nodes_unscheduled]
        node_selected = random.choice(candidates)
        al.append(node_selected)
        nodes_unscheduled.remove(node_selected)
    return al


def create_instance(source_file, target_file=None, resource_lower=5, resource_upper=10,
                    resource_factor=0.2, resource_amount=2, iterations=10000):
    """Extend the instance in *source_file* with random cumulative resources.

    For each of ``resource_amount`` new cumulative resources,
    ``round(resource_factor * #candidate pairs)`` pairs ``(i, j)`` are sampled
    from ``get_candidate_pairs``.  Activity ``i`` consumes and activity ``j``
    produces the same random integer amount in
    ``[resource_lower, resource_upper]``.  The capacities of the new
    resources come from ``iterations`` random solutions: for each one the
    requirement (negated minimum level) of every new resource is recorded,
    and the requirement vector of the sample with the smallest total is used.

    Args:
        source_file: instance file to read.
        target_file: optional path the extended instance is written to.
        resource_lower, resource_upper: bounds of the per-pair amount.
        resource_factor: fraction of candidate pairs used per new resource.
        resource_amount: number of new cumulative resources.
        iterations: number of random solutions sampled (must be >= 1).

    Returns:
        The extended ``Instance``.
    """
    ins = Instance()
    ins.read_from_file(source_file)

    candidate_pairs = get_candidate_pairs(ins)
    resource_pairs = round(resource_factor * len(candidate_pairs))
    # New resources go after *all* existing resources (renewable and
    # cumulative), so they never overwrite an existing cumulative resource
    # and stay aligned with the entries appended to ins.lmbda below.
    first_new = len(ins.R)
    new_resources = list(range(first_new, first_new + resource_amount))

    print('Got pairs')

    for r in new_resources:
        pairs = random.sample(candidate_pairs, resource_pairs)

        for consumer, producer in pairs:
            pair_resource_amount = random.randint(resource_lower, resource_upper)
            ins.k_cons[r, consumer] += pair_resource_amount
            ins.k_prod[r, producer] += pair_resource_amount

        ins.lmbda.append(0)
        ins.Rc.add(r)
        ins.R.add(r)

    print('Allocated pairs')
    ins.calculate_resource_needed()

    # Requirement of every new resource in each random sample.
    requirement_samples = np.zeros((iterations, resource_amount))

    for iteration in range(iterations):
        executed = _random_executed_nodes(ins)

        sol = Solution_al(ins)
        sol.al = _random_activity_list(ins)
        sol.executed = [i in executed for i in ins.N]
        result = sol.get_result()

        # New resources start at level 0, so the negated minimum level over
        # time is the capacity this sample would need.
        requirement_samples[iteration, :] = -np.min(result['resources'], axis=1)[new_resources].T

    capacities = requirement_samples[requirement_samples.sum(axis=1).argmin(), :].astype(int)

    for offset, r in enumerate(new_resources):
        ins.lmbda[r] = capacities[offset]

    print('got capacities')
    if target_file is not None:
        ins.write_to_file(target_file)
        print('Wrote to file')

    return ins


def get_times(ins, nodes):
    """Earliest start times of *nodes* under the direct time precedences.

    Every node of *nodes* except ``ins.n0`` starts at the latest finish
    ``times[p] + ins.d[p]`` over its direct predecessors ``p`` in
    ``ins.timeprec_i[node]``.  Predecessors outside *nodes* are treated as
    starting at 0, and a node without predecessors starts at 0.

    Returns:
        list indexed by node (0 for nodes not in *nodes*).

    Raises:
        ValueError: if the precedences among *nodes* form a cycle, so no
            remaining node can be timed (previously this looped forever).
    """
    times = [0 for _ in ins.N]
    nodes_unscheduled = set(nodes) - {ins.n0}

    while nodes_unscheduled:
        # Nodes whose predecessors all have final times.
        ready = {node for node in nodes_unscheduled
                 if not ins.timeprec_i[node] & nodes_unscheduled}
        if not ready:
            raise ValueError("cyclic time precedences among nodes {0}".format(sorted(nodes_unscheduled)))
        for node in ready:
            times[node] = max((times[prec] + ins.d[prec] for prec in ins.timeprec_i[node]), default=0)
        nodes_unscheduled -= ready

    return times


def get_candidate_pairs(ins):
    """Collect node pairs ``(i, j)`` eligible for a consumer/producer resource link.

    A pair qualifies when all of the following hold:

    * ``i`` is not ``ins.n0`` and ``j`` is not ``ins.ne``;
    * ``j`` is a transitive time successor of ``i``;
    * it is not a direct precedence between two zero-duration nodes;
    * ``ins.get_nees(candidates={j}, executed=i)`` and then
      ``ins.get_nees(candidates={i}, executed=j)`` both return non-``None``.

    Returns:
        list of ``(i, j)`` tuples in ``(i, j)`` iteration order.
    """
    # ins.P is a list; build a set once for O(1) membership tests.
    precedences = set(ins.P)

    candidate_pairs = []
    for i in ins.N:
        if i == ins.n0:
            continue
        for j in ins.N:
            if j == ins.ne:
                continue

            if (i, j) in precedences and ins.d[i] == 0 and ins.d[j] == 0:
                continue

            if j not in ins.full_timesucc[i]:
                continue

            if ins.get_nees(candidates={j}, executed=i) is None:
                continue
            if ins.get_nees(candidates={i}, executed=j) is None:
                continue

            candidate_pairs.append((i, j))

    return candidate_pairs




def get_gu_indexset(x):
    """
    Get index set from gurobi variable

    Return the keys of *x* (a mapping such as a Gurobi ``tupledict``) whose
    variable's solution value ``.x`` rounds to 1.
    """
    return {index for index, variable in x.items() if round(variable.x) == 1}

class SparseDict(dict):
    """Dictionary that stores only entries differing from ``default``.

    Reading a missing key returns ``default`` instead of raising ``KeyError``.
    Assigning ``default`` to a key removes its stored entry, so only
    non-default values are kept.
    """

    def __init__(self, init_dict=None, default=0):
        super().__init__()
        self.default = default
        if init_dict is not None:
            for key, val in init_dict.items():
                # Route through __setitem__ so default values are dropped.
                self[key] = val

    def __getitem__(self, key):
        item = self.get(key)
        return self.default if item is None else item

    def __setitem__(self, key, value):
        if value != self.default:
            super().__setitem__(key, value)
        elif self.get(key) is not None:
            del self[key]

    def copy(self):
        return self.__copy__()

    def __copy__(self):
        sd = SparseDict(default=self.default)
        for key, value in self.items():
            sd[key] = value
        return sd

    def __reduce__(self):
        # The default reduction of a dict subclass restores the items (via
        # __setitem__) *before* the instance __dict__, so __setitem__ would
        # read self.default before it exists and unpickling fails.  Rebuild
        # through __init__ instead, which sets default first.
        return (self.__class__, (dict(self), self.default))


if __name__ == '__main__':
    # Script entry point: extend the source instance and save the result.
    source_path, target_path = 'sourcefile.txt', 'targetfile.txt'
    create_instance(source_file=source_path, target_file=target_path)