import numpy as np
import scipy.stats as sps
from utility import Utility, ODpair

class Shipment:
    """A quantity of cargo travelling on one origin-destination pair.

    Both constructor arguments are stored unchanged:
      quantity -- the amount of cargo in this shipment.
      od       -- the OD-pair object the shipment travels on; shippers read
                  its cost/access/frequency attributes to compute utilities.
    """

    def __init__(self, quantity, od):
        self.quantity, self.od = quantity, od

class Shipper:
    """A single shipper that picks the highest-utility mode for its shipment.

    Utilities combine coefficients from ``util`` with attributes of
    ``shipment.od``.  ``util.type`` selects the behaviour:
      'Deter' -- deterministic: all error terms are 0 and the inter-modal
                 cost coefficient is ``MU_b_c_inter``.
      'MNL'   -- the cost coefficient is ``MU_b_c_inter``; each mode gets a
                 Gumbel (``scipy.stats.gumbel_r``) error term.
      'Mixed' -- like 'MNL', but the cost coefficient is drawn per shipper as
                 ``-exp(MU_b_c_inter + SIG_b_c_inter * N(0, 1))``.
    For 'MNL'/'Mixed' the global NumPy RNG is reseeded with ``seedgen`` so a
    shipper's draws are reproducible.

    Raises:
        ValueError: if ``util.type`` is none of the values above (previously
            the object was left without its coefficient/error attributes).
    """

    # Utility returned for an IWT service whose frequency is not positive,
    # i.e. the service is not offered on this OD pair.
    _UNAVAILABLE_UTILITY = -10000000

    def __init__(self, shipment, util, MU_b_c_inter, seedgen = 0, SIG_b_c_inter = 0):
        self.shipment = shipment
        self.util = util
        self.seed = seedgen
        kind = self.util.type
        if kind == 'Deter':
            self.b_c_inter = MU_b_c_inter
            self.Eps_IWT1 = 0
            self.Eps_IWT2 = 0
            self.Eps_ROAD = 0
            self.Eps_RAIL = 0
        elif kind in ('MNL', 'Mixed'):
            np.random.seed(seedgen)
            if kind == 'Mixed':
                # The normal draw must precede the Gumbel draws so that the
                # random stream, and hence results, match for a given seed.
                self.b_c_inter = -np.exp(MU_b_c_inter + SIG_b_c_inter * np.random.normal(0, 1))
            else:
                self.b_c_inter = MU_b_c_inter
            # scipy's rvs() without random_state uses the global NumPy RNG
            # seeded above; draw order IWT1, IWT2, ROAD, RAIL is significant.
            self.Eps_IWT1 = sps.gumbel_r.rvs()
            self.Eps_IWT2 = sps.gumbel_r.rvs()
            self.Eps_ROAD = sps.gumbel_r.rvs()
            self.Eps_RAIL = sps.gumbel_r.rvs()
        else:
            raise ValueError(
                "Unknown utility type %r; expected 'Deter', 'MNL' or 'Mixed'" % (kind,)
            )

    def getUroad(self):
        """Return the utility of the ROAD mode (uses ``util.b_c_road``)."""
        od = self.shipment.od
        return self.util.a_road + self.util.b_c_road*(od.cRoad+od.VOTcRoad)/1000+self.util.b_acc_road*od.accRoad+self.Eps_ROAD

    def getUrail(self):
        """Return the utility of the RAIL mode (uses the shipper's ``b_c_inter``)."""
        od = self.shipment.od
        return self.util.a_rail + self.b_c_inter*(od.cRail+od.VOTcRail)/1000+self.util.b_acc_inter*od.accRail+self.util.b_freq_inter*od.fRail+self.Eps_RAIL

    def _iwt_utility(self, cost, freq, eps):
        """Shared IWT utility: unavailable sentinel when ``freq <= 0``."""
        if not freq > 0:
            return self._UNAVAILABLE_UTILITY
        od = self.shipment.od
        return self.b_c_inter*(cost+od.VOTcIWT)/1000+self.util.b_acc_inter*od.accIWT+self.util.b_port*od.portIWT+self.util.b_freq_inter*freq+eps

    def getUiwt1(self, cIWT1, fIWT1):
        """Return the utility of IWT service 1 at cost ``cIWT1`` and frequency ``fIWT1``."""
        return self._iwt_utility(cIWT1, fIWT1, self.Eps_IWT1)

    def getUiwt2(self, cIWT2, fIWT2):
        """Return the utility of IWT service 2 at cost ``cIWT2`` and frequency ``fIWT2``."""
        return self._iwt_utility(cIWT2, fIWT2, self.Eps_IWT2)

    def modeChoice(self, cIWT1, cIWT2, fIWT1, fIWT2):
        """Return the name of the highest-utility mode.

        Returns one of "IWT1", "IWT2", "ROAD", "RAIL"; ties go to the earliest
        mode in that order (``max`` keeps the first maximal key).
        """
        utilities = {
            "IWT1": self.getUiwt1(cIWT1, fIWT1),
            "IWT2": self.getUiwt2(cIWT2, fIWT2),
            "ROAD": self.getUroad(),
            "RAIL": self.getUrail(),
        }
        return max(utilities, key=utilities.get)
        
class Population:
    """The set of shippers on a network and their cargo assignment to modes."""

    # Cost written into an IWT cell once that service is full for an OD pair,
    # so later mode choices on that pair avoid it.  NOTE(review): this only
    # deters the choice if shippers' b_c_inter is negative.
    _FULL_COST = 1000000000000
    # Effectively unlimited capacity used for the ROAD and RAIL modes.
    _UNLIMITED_CAPACITY = 1000000000000

    def __init__(self, network, util, MU_b_c_inter, seed_gen, n_shippersperOD, SIG_b_c_inter = 0, shippers_list = None):
        """Build (or adopt) the population of shippers.

        If ``shippers_list`` is a non-empty list it is used as-is.  Otherwise
        ``n_shippersperOD[i][j]`` shippers are created for every ordered pair
        of distinct terminals i != j, each carrying an equal share of that OD
        pair's ``totDemand``.  The global NumPy RNG is seeded with
        ``seed_gen`` and each shipper receives its own seed drawn from it.
        OD objects are looked up in ``network.ODs`` by concatenating the two
        terminal entries (presumably string names -- confirm with callers).
        """
        self.network = network
        # A fresh list per instance: the former shared mutable default made
        # every Population built without shippers_list reuse one list, so the
        # second population silently got the first population's shippers.
        self.shipperslist = [] if shippers_list is None else shippers_list
        if not self.shipperslist:
            terminals = self.network.listTerminals
            nodes = range(len(terminals))
            np.random.seed(seed_gen)
            for i in nodes:
                for j in nodes:
                    if i == j:
                        continue
                    odij = self.network.ODs[terminals[i] + terminals[j]]
                    n_ij = n_shippersperOD[i][j]
                    for _ in range(n_ij):
                        myshipment = Shipment(odij.totDemand / n_ij, odij)
                        myshipper = Shipper(myshipment, util, MU_b_c_inter, np.random.randint(1, 999999999), SIG_b_c_inter)
                        self.shipperslist.append(myshipper)

    def assignCargo(self, cIWT1, cIWT2, fIWT1, fIWT2, capIWT1, capIWT2):
        """Assign every shipper's cargo to a mode, respecting IWT capacities.

        Shippers are processed in list order.  When the chosen mode lacks
        remaining capacity on the OD pair, that IWT service is priced out
        (cost set to ``_FULL_COST``) for the rest of this call and the
        shipper chooses again.

        The cost and capacity matrices are copied, so the caller's inputs are
        no longer modified.

        Returns:
            dict mapping "IWT1", "IWT2", "ROAD", "RAIL" to an
            (n_terminals x n_terminals) array of assigned quantities.

        Raises:
            RuntimeError: if a full mode is chosen again after being priced
                out, or ROAD/RAIL capacity is exceeded.  Previously this looped
                forever or raised an opaque KeyError.
        """
        terminals = self.network.listTerminals
        n_nodes = len(terminals)
        shape = (n_nodes, n_nodes)
        # Float copies: the previous code wrote into the caller's cost and
        # capacity arrays.  Integer capacities also truncated fractional
        # quantities when they were subtracted.
        costsIWT = {
            "IWT1": np.array(cIWT1, dtype=float),
            "IWT2": np.array(cIWT2, dtype=float),
        }
        capacities = {
            "IWT1": np.array(capIWT1, dtype=float),
            "IWT2": np.array(capIWT2, dtype=float),
            "ROAD": np.full(shape, self._UNLIMITED_CAPACITY, dtype=float),
            "RAIL": np.full(shape, self._UNLIMITED_CAPACITY, dtype=float),
        }
        assignment = {mode: np.zeros(shape) for mode in ("IWT1", "IWT2", "ROAD", "RAIL")}
        for shipper in self.shipperslist:
            o = terminals.index(shipper.shipment.od.origin)
            d = terminals.index(shipper.shipment.od.destination)
            containers = shipper.shipment.quantity
            chosen = shipper.modeChoice(costsIWT["IWT1"][o][d], costsIWT["IWT2"][o][d], fIWT1[o][d], fIWT2[o][d])
            excluded = set()
            while containers > capacities[chosen][o][d]:
                # Capacities do not change inside this loop, so re-choosing an
                # already priced-out mode (or a non-IWT mode) can never succeed.
                if chosen not in costsIWT or chosen in excluded:
                    raise RuntimeError(
                        "Cannot assign %r containers from %r to %r: mode %r is full"
                        % (containers, terminals[o], terminals[d], chosen)
                    )
                excluded.add(chosen)
                costsIWT[chosen][o][d] = self._FULL_COST
                chosen = shipper.modeChoice(costsIWT["IWT1"][o][d], costsIWT["IWT2"][o][d], fIWT1[o][d], fIWT2[o][d])
            assignment[chosen][o][d] += containers
            capacities[chosen][o][d] -= containers
        return assignment
    


  