import itertools
import numpy as np
import gurobipy as gp
from gurobipy import GRB,quicksum
import math
import pandas as pd
import scipy.stats as sps
import sys

# Command-line arguments: sys.argv[1] = random seed, sys.argv[2] = number of random
# draws (scenarios) used for every sample average in this script.
if len(sys.argv) < 3:
    sys.exit('usage: python ' + sys.argv[0] + ' SEED N_DRAWS')

N_draws = int(sys.argv[2])
if N_draws < 1:
    # Every sample average below divides by N_draws; fail fast instead of after the
    # long precomputation.
    sys.exit('N_DRAWS must be a positive integer, got ' + sys.argv[2])
R=range(N_draws)

# Parameters of the random intermodal cost coefficient: each draw is
# -exp(MU_b_c_inter + SIGMA_b_c_inter * z) with z standard normal (see loop below).
MU_b_c_inter = 2.40
SIGMA_b_c_inter = 0.618

R_b_c_inter = []

sysEpsilons = []

seedH=int(sys.argv[1])

# scipy.stats .rvs() without random_state draws from NumPy's global RandomState,
# so this one seed makes every draw below (and the permutations) reproducible.
np.random.seed(seedH)

# Draw one Gumbel error and one cost coefficient per scenario, interleaved
# (the draw order matters for reproducibility).
for r in R:
#     syspoint=(r+1)/(N_draws+1)
    sysEpsilons.append(sps.gumbel_r.rvs())
#     syspointR=sps.uniform.ppf(syspoint, loc=0.01, scale=0.98)
#     sysEpsilons.append(0)
#         R_b_c_inter.append(-5.76)
    R_b_c_inter.append(-np.exp(MU_b_c_inter + SIGMA_b_c_inter * sps.norm.rvs()))


# Each mode gets its own random permutation of the same Gumbel sample.
Eps_IC=np.random.permutation(sysEpsilons)
Eps_IWT=np.random.permutation(sysEpsilons)
Eps_ROAD=np.random.permutation(sysEpsilons)
Eps_RAIL=np.random.permutation(sysEpsilons)


# Define all vessel related parameters
# Every per-vessel-type list below is indexed by k, with k=0 -> "M8" and k=1 -> "M11".

vesseltypes=["M8","M11"]
V_fleet=[24,18] # Fleet per vessel type
# Must be in the same time unit as Tsailod/Twaitport/Thandport, because the
# maximum number of cycles is T_operate divided by a service's cycle time
# (presumably hours — TODO confirm).
T_operate=120 # Maximal time a vessel can operate per week
U_cap=[180,300] # Maximal capacity per vessel type
# Big-M sentinel: stored as the variable cost of an OD that a service cannot carry,
# and later compared against to decide which ODs a service serves.
M=100000

# Define all demand related parameters
# Terminal order fixes the node indices 0, 1, 2 used in every [i][j] matrix below.

terminals = ["RTM","DUI","BON"]
#OD_matrix=[[0,4700,1100],[6900,0,100],[700,200,0]] # Weekly demand for IWT
# OD_matrix_allM[i][j]: demand from terminals[i] to terminals[j] (presumably TEU — TODO confirm).
OD_matrix_allM = [[0,6500,1900],[8400,0,6700],[1500,6500,0]] # Weekly demand for all modes
# for i in range(len(terminals)):
#     for j in range(len(terminals)):
#         OD_matrix_allM[i][j]=OD_matrix_allM[i][j]*4

# Define all cost related parameters
# Ocosts[k][i][j]: cost of sailing leg i->j with vessel type k; a service's fixed cost
# is the sum over its legs.
# Vcosts[k][i][j]: per-unit cost of leg i->j; an OD's variable cost on a service is the
# sum over the legs that OD uses.

Ocosts= [[[0,2200,2700],[4700,0,1100],[5100,1100,0]],[[0,3600,4200],[6100,0,2000],[6600,2000,0]]] # Operational costs per vessel type for each OD
Vcosts= [[[0,24,31],[13,0,8],[18,8,0]],[[0,21,27],[10,0,6],[12,6,0]]] # Voyage costs/TEU per vessel type for each OD (=Voyagecosts/(80%Ucap))

# Define all time related parameters
# Twaitport[b] is charged once per leg arriving at terminal b; Thandport[k] is charged
# twice per leg (2*Thandport[k]) for vessel type k. Handling time scales with capacity
# (0.024 time units per unit of capacity, rounded).

Twaitport = [30,1,1] # Waiting time at each terminal
Thandport = [round(value * 0.024) for value in U_cap] # Handling time to load/unload a vessel type
Tsailod = [[[0,15,19],[13,0,4],[16,4,0]],[[0,19,25],[15,0,5],[19,5,0]]] # Sailing time per vessel type for each OD (sail+lock time)

#Define utility related parameters
# a_road / a_rail: additive constants in the road / rail utilities (IWT and IC have none).
# b_port multiplies the portif indicator; b_freq_inter multiplies service frequency;
# b_acc_* multiply the access values; b_c_road multiplies (cost + VOT cost)/1000 in the
# road utility (rail, IC and IWT use the random coefficients R_b_c_inter instead).
a_road=2.35
a_rail=0.816
b_port=1.60
b_freq_inter=0.0262
b_acc_road=0.0506
b_acc_inter=0.173
b_c_road=-8.73

# Define all parameters for competing modes
# All matrices are [origin][destination] over the terminals list; the diagonal is unused.

costsROAD=[[0,252,317],[251,0,136],[315,136,0]]
costsRAIL=[[0,203,214],[203,0,152],[214,152,0]]
# maxiCOSTS=[[0,252,317],[251,0,152],[315,152,0]]
accROAD=[[0,45,42],[45,0,53],[42,53,0]]
accRAIL=[[0,15,9],[15,0,14],[9,14,0]]
freqRAIL=[[0,35,5],[35,0,2],[5,2,0]]
VOTcostROAD=[[0,4,5],[4,0,2],[5,2,0]]
VOTcostRAIL=[[0,6,7],[6,0,3],[7,3,0]]
# VOTcostIWT is used both for the competing IC offer and for the planned IWT service.
VOTcostIWT=[[0,32,39],[62,0,17],[67,17,0]]

# Define mode parameters for IWT
# freqIC / costsIC describe the competing IWT offer "IC" (its utility is Vic);
# acc and portif are shared by IC and the planned IWT service.
freqIC=[[0,22,4],[19,0,6],[6,5,0]]
costsIC=[[0,68,76],[69,0,46],[74,46,0]]#[[0,35,44],[36,0,13],[42,13,0]]#[[0,68,76],[69,0,46],[74,46,0]]#[[0,123,146],[156,0,69],[180,69,0]]
acc=[[0,17,8],[17,0,15],[8,15,0]]
# 0/1 indicator multiplied by b_port in the IC and IWT utilities.
portif=[[0,1,1],[1,0,0],[1,0,0]]


# Problem dimensions: number of terminals, the largest number of legs a service can
# have (out and back through all N_nodes terminals, i.e. 2*(N_nodes-1)), and the
# number of vessel types.
N_nodes = len(terminals)
N = range(N_nodes)
nodes = list(N)
maxLegs = 2 * (N_nodes - 1)
L = range(maxLegs)
N_Vtypes = len(vesseltypes)
K = range(N_Vtypes)

# Enumerate candidate services: every subset of at least two terminals, visited in
# increasing index order on the way out and in reverse order on the way back.
# legs_matrix[s][l][a][b] = 1 when leg l of service s sails from a to b.
legs_matrix = []
services = []
service_list = []

for n_stops in range(2, N_nodes + 1):
    for route in itertools.combinations(nodes, n_stops):
        outbound = list(zip(route[:-1], route[1:]))
        inbound = [(dest, origin) for origin, dest in reversed(outbound)]
        leg_array = np.zeros((maxLegs, N_nodes, N_nodes))
        for leg_index, (origin, dest) in enumerate(outbound + inbound):
            leg_array[leg_index][origin][dest] = 1
        legs_matrix.append(leg_array)
        services.append(route)
        service_list.append([terminals[stop] for stop in route])

N_services = len(services)
S = range(N_services)

# usage_matrix[s][l][i][j] = 1 when OD (i, j) travels on leg l of service s.
# Outbound leg l sails route[l] -> route[l+1], so it carries every OD whose origin is
# at or before route[l] and whose destination is after it. Return leg 2*n_out-1-l
# retraces outbound leg l backwards, so its usage is the transpose of leg l's.
usage_matrix = np.zeros((N_services, maxLegs, N_nodes, N_nodes))

for s, route in enumerate(services):
    n_out = len(route) - 1
    for leg in range(n_out):
        for origin in route[:leg + 1]:
            for dest in route[leg + 1:]:
                usage_matrix[s][leg][origin][dest] = 1
    for leg in range(n_out):
        usage_matrix[s][2 * n_out - 1 - leg] = np.transpose(usage_matrix[s][leg])

# fixed_costs[k][s]: operational cost of one cycle of service s with vessel type k,
# i.e. the Ocosts of every leg the service sails.
fixed_costs = np.zeros((N_Vtypes, N_services))
for k, s in itertools.product(K, S):
    fixed_costs[k][s] = sum(
        leg[a][b] * Ocosts[k][a][b]
        for leg in legs_matrix[s]
        for a in N
        for b in N
    )

# var_costs[k][s][i][j]: per-unit cost of carrying OD (i, j) on service s with vessel
# type k, i.e. the Vcosts of every leg that OD uses. ODs the service cannot carry get
# the big-M sentinel M; the diagonal is left at 0.
var_costs = np.zeros((N_Vtypes, N_services, N_nodes, N_nodes))

for k in K:
    for s, route in enumerate(services):
        n_legs = 2 * (len(route) - 1)
        # Voyage cost of each leg of this service for vessel type k.
        leg_costs = [
            sum(legs_matrix[s][l][a][b] * Vcosts[k][a][b] for a in N for b in N)
            for l in range(n_legs)
        ]
        for i, j in itertools.permutations(N, 2):
            total = 0
            for l in range(n_legs):
                total += usage_matrix[s][l][i][j] * leg_costs[l]
            var_costs[k][s][i][j] = M if total == 0 else total

# t_sail[k][s] / t_port[k][s]: total sailing time and total port time of one cycle of
# service s with vessel type k. Each leg arriving at terminal b costs the waiting time
# Twaitport[b] plus two handling operations (2*Thandport[k]).
t_sail = np.zeros((N_Vtypes, N_services))
t_port = np.zeros((N_Vtypes, N_services))

for k in K:
    for s, route in enumerate(services):
        sail_total = 0
        port_total = 0
        for leg in legs_matrix[s][:2 * (len(route) - 1)]:
            sail_total += sum(leg[a][b] * Tsailod[k][a][b] for a in N for b in N)
            port_total += sum(leg[a][b] * (Twaitport[b] + 2 * Thandport[k]) for a in N for b in N)
        t_sail[k][s] = sail_total
        t_port[k][s] = port_total

# maxCycles[k][s]: how many whole cycles of service s one type-k vessel fits into T_operate.
maxCycles = np.zeros((N_Vtypes, N_services), dtype=int)

for k, s in itertools.product(K, S):
    cycle_time = t_sail[k][s] + t_port[k][s]
    maxCycles[k][s] = math.floor(T_operate / cycle_time)


# servOD_matrix[s][i][j] = 1.0 when service s can carry OD (i, j), i.e. when its
# vessel-type-0 variable cost is not the M sentinel; the diagonal stays 0.
off_diagonal = ~np.eye(N_nodes, dtype=bool)
servOD_matrix = ((var_costs[0] != M) & off_diagonal).astype(float)


# Road utility per OD and draw: a constant part plus that draw's road error term.
Vroad = np.zeros((N_nodes, N_nodes, N_draws))

for i, j in itertools.permutations(N, 2):
    systematic = a_road + b_c_road*(costsROAD[i][j]+VOTcostROAD[i][j])/1000 + b_acc_road*accROAD[i][j]
    for r in R:
        Vroad[i][j][r] = systematic + Eps_ROAD[r]

# # Determine upper bounds for frequency

# Bound_f=np.zeros((N_nodes,N_nodes), dtype=int)

# for i in N:
    # for j in N:
        # Bound_f[i][j]=int(math.ceil(math.log2(35)))

# Rail and competing-IWT (IC) utilities per OD and draw; both use the random cost
# coefficient of that draw.
Vrail = np.zeros((N_nodes, N_nodes, N_draws))
Vic = np.zeros((N_nodes, N_nodes, N_draws))

for i, j in itertools.permutations(N, 2):
    for r, beta_cost in enumerate(R_b_c_inter):
        Vrail[i][j][r] = a_rail + beta_cost*(costsRAIL[i][j]+VOTcostRAIL[i][j])/1000 + b_acc_inter*accRAIL[i][j] + b_freq_inter*freqRAIL[i][j] + Eps_RAIL[r]
        Vic[i][j][r] = beta_cost*(costsIC[i][j]+VOTcostIWT[i][j])/1000 + b_acc_inter*acc[i][j] + b_freq_inter*freqIC[i][j] + b_port*portif[i][j] + Eps_IC[r]

# Best competing utility per OD and draw, and which alternative attains it
# (0 = road, 1 = rail, 2 = IC; ties go to the lower index).
maxV = np.zeros((N_nodes, N_nodes, N_draws))
maxALT = np.zeros((N_nodes, N_nodes, N_draws))

for i, j in itertools.permutations(N, 2):
    for r in R:
        options = [Vroad[i][j][r], Vrail[i][j][r], Vic[i][j][r]]
        maxV[i][j][r] = max(options)
        maxALT[i][j][r] = np.argmax(options)
                
# Enumeration grid: candidate IWT prices are p*10 for p in P; total weekly
# frequencies of a direct service are f in F.
N_prices=50
P=range(N_prices)
N_freq=36
F=range(N_freq)

Viwtf=np.zeros((N_freq,N_nodes,N_nodes,N_draws,N_prices))
Diwtf=np.zeros((N_prices,N_nodes,N_nodes,N_freq))
profitf=np.zeros((N_freq,N_freq,N_nodes,N_nodes,N_prices))
maxPROFf=np.zeros((N_nodes,N_nodes,N_freq,N_freq))
maxPROFfij=np.zeros((N_nodes,N_nodes,N_freq,N_freq))
optPf=np.zeros((N_nodes,N_nodes,N_freq,N_freq))


def _split_load(demand, cap_small, cap_big):
    """Cap demand at the combined capacity and split it over the two vessel types.

    Big vessels are filled first; any remainder goes on the small vessels.

    Returns:
        (served, on_small, on_big): the carried volume and its split per vessel type.
    """
    if demand >= cap_small + cap_big:
        return cap_small + cap_big, cap_small, cap_big
    if demand >= cap_big:
        return demand, demand - cap_big, cap_big
    return demand, 0, demand


# For every total frequency f and OD pair i<j served by the direct service i<->j,
# evaluate the IWT demand at every price and the profit of every (fsmall, fbig) split
# of f over the two vessel types.
for f in F:
    if f>0:
        print(f)
        for i in N:
            for j in N:
                if i<j:
                    # Index of the direct two-terminal service i<->j in `services`
                    # (valid for the three-terminal enumeration order used here).
                    s_ij = i + j - 1
                    max_small = V_fleet[0]*maxCycles[0][s_ij]
                    max_big = V_fleet[1]*maxCycles[1][s_ij]
                    # All fleet-feasible splits with fsmall + fbig == f, fsmall ascending.
                    splits = [(fsmall, f - fsmall) for fsmall in range(min(f, max_small) + 1) if f - fsmall <= max_big]
                    if not splits:
                        continue
                    # Demand depends only on (f, p), not on the split, so evaluate the
                    # draws once per price instead of once per split.
                    demand_ij = []
                    demand_ji = []
                    for p in P:
                        D=[0,0]
                        for r in R:
                            Viwtf[f][i][j][r][p]=R_b_c_inter[r]*(p*10+VOTcostIWT[i][j])/1000+b_freq_inter*f+b_acc_inter*acc[i][j]+b_port*portif[i][j]+Eps_IWT[r]
                            Viwtf[f][j][i][r][p]=R_b_c_inter[r]*(p*10+VOTcostIWT[j][i])/1000+b_freq_inter*f+b_acc_inter*acc[j][i]+b_port*portif[j][i]+Eps_IWT[r]
                            # IWT wins the draw when it is at least as good as the best alternative.
                            if Viwtf[f][i][j][r][p]>=maxV[i][j][r]:
                                D[0]+=OD_matrix_allM[i][j]/N_draws
                            if Viwtf[f][j][i][r][p]>=maxV[j][i][r]:
                                D[1]+=OD_matrix_allM[j][i]/N_draws
                        # Uncapped sample-average demand.
                        Diwtf[p][i][j][f]=D[0]
                        Diwtf[p][j][i][f]=D[1]
                        demand_ij.append(D[0])
                        demand_ji.append(D[1])
                    for fsmall, fbig in splits:
                        Ksmall=fsmall*U_cap[0]
                        Kbig=fbig*U_cap[1]
                        # Each direction carries half of the round-trip fixed cost.
                        fixed_share=fsmall*fixed_costs[0][s_ij]/2+fbig*fixed_costs[1][s_ij]/2
                        for p in P:
                            served, dsmall, dbig = _split_load(demand_ij[p], Ksmall, Kbig)
                            profitf[fsmall][fbig][i][j][p]=served*p*10-fixed_share-(dsmall*var_costs[0][s_ij][i][j]+dbig*var_costs[1][s_ij][i][j])
                            served, dsmall, dbig = _split_load(demand_ji[p], Ksmall, Kbig)
                            profitf[fsmall][fbig][j][i][p]=served*p*10-fixed_share-(dsmall*var_costs[0][s_ij][j][i]+dbig*var_costs[1][s_ij][j][i])
                        # Best price (as a price value, not an index) and profit per direction.
                        optPf[i][j][fsmall][fbig]=10*np.argmax(profitf[fsmall][fbig][i][j])
                        optPf[j][i][fsmall][fbig]=10*np.argmax(profitf[fsmall][fbig][j][i])
                        maxPROFf[i][j][fsmall][fbig]=np.max(profitf[fsmall][fbig][i][j])
                        maxPROFf[j][i][fsmall][fbig]=np.max(profitf[fsmall][fbig][j][i])
                        maxPROFfij[i][j][fsmall][fbig]=maxPROFf[i][j][fsmall][fbig]+maxPROFf[j][i][fsmall][fbig]

# Price / frequency state of the best-response loop: *1 = input of the current
# iteration, *2 = result of the current iteration.
price1 = np.zeros((N_nodes, N_nodes))
f1 = np.zeros((N_services, N_Vtypes))

price2 = np.zeros((N_nodes, N_nodes))
f2 = np.zeros((N_services, N_Vtypes))

# Starting prices: for each unordered OD pair take the (fsmall, fbig) split with the
# highest two-direction profit and its profit-maximising price in each direction.
for i, j in itertools.combinations(N, 2):
    best_small, best_big = np.unravel_index(np.argmax(maxPROFfij[i][j]), maxPROFfij[i][j].shape)
    price2[i][j] = optPf[i][j][best_small][best_big]
    price2[j][i] = optPf[j][i][best_small][best_big]

Flist = list(F)

# del fix_costsH
# del var_costsH
# del maxV
# del Viwtf
# del profitf
# gc.collect()

pvisit = []
fvisit = []

# Best-response heuristic: alternately (a) solve the service-network MIP for fixed OD
# prices (price1) and (b) re-price every OD from the precomputed optPf table for the
# frequencies the MIP chose (price2). Stops once the new price matrix and the new
# frequency matrix have each been seen in an earlier iteration (checked separately).
while (price2.tolist() not in pvisit) or (f2.tolist() not in fvisit):
    print(price2)
    price1=np.copy(price2)
    pvisit.append(price1.tolist())
    print(f2)
    f1=np.copy(f2)
    fvisit.append(f1.tolist())

    # IWT demand cap per OD and total frequency at the current price (prices are
    # p*10, so price/10 recovers the price index p of Diwtf).
    DemandMaxIWT=np.zeros((N_nodes,N_nodes,N_freq))
    for i in N:
        for j in N:
            DemandMaxIWT[i][j]=Diwtf[int(price1[i][j]/10)][i][j]

    n = gp.Model("SND_SAA.lp")

    freq = []
    bfreq=[]
    Fod=[]
    Dmax=[]
    flow = []
    vessel = []

    # freq[s][k]: integer cycles of service s with vessel type k;
    # vessel[s][k]: integer number of type-k vessels assigned to service s.
    for s in S:
        freq.append([])
        vessel.append([])
        for k in K:
            freq[s].append(n.addVar(lb = 0, vtype = GRB.INTEGER, name = 'freq_sk('+str(s)+','+str(k)+')'))
            vessel[s].append(n.addVar(lb = 0, vtype = GRB.INTEGER, name = 'vessel_sk('+str(s)+','+str(k)+')'))

    z = []

    # Per OD (i, j): Fod = total IWT frequency, Dmax = IWT demand cap,
    # flow[i][j][s][k][r] = volume on service s with vessel type k in draw r,
    # z[i][j][r] = volume not carried by IWT in draw r,
    # bfreq[i][j][f] = 1 selects total frequency f.
    for i in N:
        flow.append([])
        z.append([])
        bfreq.append([])
        Fod.append([])
        Dmax.append([])
        for j in N:
            z[i].append([])
            flow[i].append([])
            bfreq[i].append([])
            Fod[i].append(n.addVar(lb = 0, vtype = GRB.INTEGER, name = 'freq_ij('+str(i)+','+str(j)+')'))
            Dmax[i].append(n.addVar(lb = 0, vtype = GRB.CONTINUOUS, name = 'Dmax_ij('+str(i)+','+str(j)+')'))
            for s in S:
                flow[i][j].append([])
                for k in K:
                    flow[i][j][s].append([])
                    for r in R:
                        flow[i][j][s][k].append(n.addVar(lb = 0, vtype = GRB.CONTINUOUS, name = 'x_ijskr('+str(i)+','+str(j)+','+str(s)+','+str(k)+','+str(r)+')'))
            for r in R:
                z[i][j].append(n.addVar(lb = 0, vtype = GRB.CONTINUOUS, name = 'z1_ijr('+str(i)+','+str(j)+','+str(r)+')'))
            for f in F:
                bfreq[i][j].append(n.addVar(lb = 0, vtype = GRB.BINARY, name = 'bf_ijf('+str(i)+','+str(j)+','+str(f)+')'))

    # Sample-average profit. The fixed-cost term is summed over all draws and then
    # scaled by 1/N_draws, so it counts each service's fixed cost exactly once.
    n.setObjective((1/N_draws)*(quicksum(price1[i][j]*flow[i][j][s][k][r] for i in N for j in N for s in S for k in K for r in R)-quicksum(freq[s][k]*fixed_costs[k][s] for s in S for k in K for r in R)-quicksum(flow[i][j][s][k][r]*(var_costs[k][s][i][j]) for k in K for s in S for i in N for j in N for r in R)))

    n.modelSense = GRB.MAXIMIZE
    n.update()

    # A service can only carry ODs that use one of its legs.
    for i in N:
        for j in N:
            for s in S:
                for k in K:
                    for r in R:
                        n.addConstr(flow[i][j][s][k][r] <= quicksum(usage_matrix[s][l][i][j] for l in L)*OD_matrix_allM[i][j])

    # Cycles are limited by the vessels assigned; average leg load by the capacity offered.
    for k in K:
        for s in S:
            n.addConstr(freq[s][k]<=maxCycles[k][s]*vessel[s][k])
    #         n.addConstr(freq[s][k]*(t_sail[k][s]+t_port[k][s])/T_operate<=vessel[s][k])
    #         n.addConstr((freq[s][k]*(t_sail[k][s]+t_port[k][s])/T_operate)+0.999>=vessel[s][k])
            for l in L:
                n.addConstr(quicksum((1/N_draws)*flow[i][j][s][k][r]*usage_matrix[s][l][i][j] for i in N for j in N for r in R) <= freq[s][k]*U_cap[k])

    # Fleet size per vessel type.
    for k in K:
        n.addConstr(quicksum(vessel[s][k] for s in S)<=V_fleet[k])

    # In every draw the whole OD demand is split between IWT flow and the rest (z).
    for i in N:
        for j in N:
            for r in R:
                n.addConstr(quicksum(flow[i][j][s][k][r] for s in S for k in K)+z[i][j][r]==OD_matrix_allM[i][j])

    # Select at most one total frequency per OD. It fixes both Fod (which must equal the
    # frequency offered by the services serving the OD) and the IWT demand cap Dmax.
    for i in N:
        for j in N:
            n.addConstr(quicksum(bfreq[i][j][f] for f in F)<=1)
            n.addConstr(Dmax[i][j]==quicksum(bfreq[i][j][f]*DemandMaxIWT[i][j][f] for f in F))
            n.addConstr(Fod[i][j]==quicksum(bfreq[i][j][f]*Flist[f] for f in F))
            n.addConstr(Fod[i][j]==quicksum(servOD_matrix[s][i][j]*freq[s][k] for s in S for k in K))
            n.addConstr(quicksum((1/N_draws)*flow[i][j][s][k][r] for s in S for k in K for r in R)<=Dmax[i][j])

    n.update()


    n.write('TEST_SNDpaper.lp')
#     n.setParam('OutputFlag', True)
    n.optimize()

    # Re-price each OD for the chosen small/big-vessel frequencies. Gurobi reports
    # integer variables only within its integrality tolerance (e.g. 2.9999999), so
    # round instead of truncating with int(), which would pick the wrong table entry.
    for i in N:
        for j in N:
            if i!=j:
                fs=int(round(sum(servOD_matrix[s][i][j]*freq[s][0].x for s in S)))
                fb=int(round(sum(servOD_matrix[s][i][j]*freq[s][1].x for s in S)))
                price2[i][j]=optPf[i][j][fs][fb]

    # Store integral frequencies so the convergence test compares exact values.
    for s in S:
        for k in K:
            f2[s][k]=round(freq[s][k].x)

                
# Objective decomposition of the last solved model (sample averages over the draws).
# NOTE(review): Revenue is evaluated at price2 (the prices re-computed after the last
# solve), whereas n.ObjVal was optimised with price1; they coincide only if the
# heuristic stopped at a fixed point — confirm which price is intended here.
_revenue = sum(price2[i][j]*flow[i][j][s][k][r].x for i in N for j in N for s in S for k in K for r in R)/N_draws
_fixed_total = sum(f2[s][k]*fixed_costs[k][s] for s in S for k in K)
_variable_total = sum(flow[i][j][s][k][r].x*var_costs[k][s][i][j] for k in K for s in S for i in N for j in N for r in R)/N_draws
dobj = {
    'Amounts [EUR]': pd.Series(
        [n.ObjVal, _revenue, _fixed_total, _variable_total],
        index=['Profit', 'Revenue', 'FixedCosts', 'VariableCosts'],
    )
}


# One column per vessel type with the chosen frequencies and vessel counts per service,
# indexed by the service's terminal sequence (e.g. "RTM-DUI").
dfreq = {}
dvessel = {}
service_names = ["-".join(stops) for stops in service_list]
for k, vtype in enumerate(vesseltypes):
    dfreq[vtype + 'freq'] = pd.Series([freq[s][k].x for s in S], index=service_names)
    dvessel[vtype + 'vessels'] = pd.Series([vessel[s][k].x for s in S], index=service_names)

dicofv = {**dfreq, **dvessel}


# Per-OD prices, average volumes and market shares for IWT and the competing modes.
od_pairs = list(itertools.permutations(N, 2))
od_labels = [terminals[i] + "-" + terminals[j] for i, j in od_pairs]

dprice = {}
dvol = {}
dshare = {}

# IWT: average optimal flow over draws, services and vessel types; the price column
# shows -1000 for ODs that end up without IWT volume.
iwt_vols = [sum(flow[i][j][s][k][r].x for s in S for k in K for r in R) / N_draws for i, j in od_pairs]
dprice["IWT"] = pd.Series([price2[i][j] if vol > 0 else -1000 for (i, j), vol in zip(od_pairs, iwt_vols)], index=od_labels)
dvol["IWTvol"] = pd.Series(iwt_vols, index=od_labels)
dshare["IWTshare"] = pd.Series([vol / OD_matrix_allM[i][j] for (i, j), vol in zip(od_pairs, iwt_vols)], index=od_labels)

# Competing modes: in each draw, the volume not carried by IWT (z) is attributed to
# the alternative with the highest utility in that draw (maxALT: 0 road, 1 rail, 2 IC).
for mode, vol_key, share_key, alt_index, prices in (
    ("Road", "Roadvol", "Roadshare", 0, costsROAD),
    ("Rail", "Railvol", "Railshare", 1, costsRAIL),
    ("IC", "ICvol", "ICshare", 2, costsIC),
):
    vols = []
    for i, j in od_pairs:
        leftover = 0
        for r in R:
            if maxALT[i][j][r] == alt_index:
                leftover += z[i][j][r].x
        vols.append(leftover / N_draws)
    dprice[mode] = pd.Series([prices[i][j] for i, j in od_pairs], index=od_labels)
    dvol[vol_key] = pd.Series(vols, index=od_labels)
    dshare[share_key] = pd.Series([vol / OD_matrix_allM[i][j] for (i, j), vol in zip(od_pairs, vols)], index=od_labels)

dicovs = {**dvol, **dshare}

# Write all result tables to one workbook named after the draw count and the seed.
output_path = 'heurMIX3n_' + sys.argv[2] + '_' + sys.argv[1] + '.xlsx'
sheets = (
    ('ObjectiveFunction', dobj),
    ('Prices', dprice),
    ('Frequencies', dicofv),
    ('Shares', dicovs),
)
with pd.ExcelWriter(output_path) as writer:
    for sheet_name, columns in sheets:
        pd.DataFrame(columns).to_excel(writer, sheet_name=sheet_name)
