# -*- coding: utf-8 -*-
"""

MIT License, for details please read LICENSE.TXT

Copyright (c) 2022 Mahtab Sharifi

Created on Tue Apr 26 18:40:45 2022

"""

import os
import numpy as np
import pandas as pd

from fnmatch import fnmatch
from bisect import bisect_left
from fastnumbers import fast_float

import General_Utils as gu
from pytictoc import TicToc


''' 
    ***************************************************************************
    
    init Part
    
    ***************************************************************************
'''

# Project-level parameters; prms_dict is merged into e2e_s_dict below and
# sim_scen is re-exported from the same module.
import base_parameters 
prms_dict = base_parameters.prms_dict


# Module defaults for the edge-to-edge routines. Any key that also exists in
# prms_dict is overridden by the update() call that follows the literal.
e2e_s_dict = {
        'merge_gain_threshold' : 100000,   
        'edges_count' : 5024,
        'chunk_size' : 200, 
        
        # 'avg_speed' : 43.2, # km/h equal to 12m/s'
        'avg_speed_cms' : 1200, #43.2*1000/3600 * 100

        }

e2e_s_dict.update(prms_dict)

# Each sim_scen entry is unpacked elsewhere in this file as:
#       [ scn_path , run_path     , rout_file              , aggregate_path   ],
sim_scen = base_parameters.sim_scen


# Convenience module-level names derived from the merged parameters
# ('base_path', 'net_path', 'net_file' and 'edg_ids_file' must be present
# after the update above, otherwise these lookups raise KeyError).
net_path = os.path.join(e2e_s_dict['base_path'], e2e_s_dict['net_path'])
net_file = e2e_s_dict['net_file']
edg_ids_file = e2e_s_dict['edg_ids_file']

def getNetPath(pPrmsDict):
    """Return the network folder: ``base_path`` joined with ``net_path``.

    pPrmsDict -- parameter dict providing the 'base_path' and 'net_path' keys.
    """
    base_dir = pPrmsDict['base_path']
    net_dir = pPrmsDict['net_path']
    return os.path.join(base_dir, net_dir)

def getE2Epath(pPrmsDict):
    """Return the edge-to-edge data folder.

    The folder is ``base_path / net_path / e2e_path``, all three taken from
    pPrmsDict.
    """
    parts = [pPrmsDict[key] for key in ('base_path', 'net_path', 'e2e_path')]
    return os.path.join(*parts)

def readEdges(pPrmsDict):
    """Read the edge-id file and return its first column as a list.

    The file is ``pPrmsDict['edg_ids_file']`` inside the folder returned by
    getE2Epath(); it is read through gu.read_DF_csv with no header row.
    """
    print('Loading edge_ids ...')
    ids_path = os.path.join(getE2Epath(pPrmsDict), pPrmsDict['edg_ids_file'])
    ids_df = gu.read_DF_csv(ids_path, pHeader=None)
    first_column = ids_df[0]
    return first_column.tolist()

# Load the edge-id list once at import time. The guard skips the load when
# `edge_ids` is already defined in this namespace (presumably to avoid
# re-reading the file when the script is re-run in an interactive session).
if 'edge_ids' not in vars():
    fn_file = os.path.join(getE2Epath(e2e_s_dict), e2e_s_dict['edg_ids_file'])
    if os.path.isfile(fn_file):
        edge_ids = readEdges(e2e_s_dict)
    else:
        # Raising a plain string is itself a TypeError in Python 3; raise a
        # real exception that also names the missing file.
        raise FileNotFoundError('Can not load edg_ids ... ({})'.format(fn_file))

def readEdgeLengths(pPrmsDict):
    """Read the edge-length file and return its first column as a list.

    The file is ``pPrmsDict['edg_len_file']`` inside the folder returned by
    getE2Epath(); it is read through gu.read_DF_csv with no header row.
    """
    print('Loading edge_lens ...')
    lens_path = os.path.join(getE2Epath(pPrmsDict), pPrmsDict['edg_len_file'])
    lens_df = gu.read_DF_csv(lens_path, pHeader=None)
    first_column = lens_df[0]
    return first_column.tolist()

# Load the edge lengths once at import time as an int32 array. The guard skips
# the load when `edge_lens` is already defined in this namespace.
if 'edge_lens' not in vars():
    fn_file = os.path.join(getE2Epath(e2e_s_dict), e2e_s_dict['edg_len_file'])
    if os.path.isfile(fn_file):
        edge_lens = np.array(readEdgeLengths(e2e_s_dict), dtype=np.int32)
    else:
        # Raising a plain string is itself a TypeError in Python 3; raise a
        # real exception that also names the missing file.
        raise FileNotFoundError('Can not load edg_lens ... ({})'.format(fn_file))

''' 
    ***************************************************************************
    
    edge2edge chunk Part
    
    ***************************************************************************
'''

def getChunkStEn(pSize, pCunkSize, pThreshold=0.25, pAddRemains=False):
    """Split ``pSize`` rows into consecutive ``[start, end]`` spans.

    pSize       -- total number of rows to cover.
    pCunkSize   -- nominal number of rows per chunk.
    pThreshold  -- when the leftover rows, as a fraction of pCunkSize, exceed
                   this value (and pAddRemains is false) they are spread one
                   row at a time over the trailing chunks.
    pAddRemains -- when true, leftover rows become one extra, shorter chunk.

    Returns a list of ``[start, end]`` pairs (end exclusive) whose last end is
    always pSize. A single pair is returned when pSize <= pCunkSize.
    """
    if pSize <= pCunkSize:
        return [[0, pSize]]

    n_chunks = int(pSize / pCunkSize)
    leftover = pSize - n_chunks * pCunkSize
    extra = np.zeros(n_chunks, dtype=np.int32)

    spread_leftover = (leftover / pCunkSize > pThreshold) and not pAddRemains
    if spread_leftover:
        # Hand the leftover rows out one by one, starting with the last chunk
        # and walking backwards (wrapping around if there are more rows than
        # chunks).
        for k in range(leftover):
            extra[n_chunks - 1 - (k % n_chunks)] += 1

    spans = []
    start = 0
    for bump in extra:
        stop = start + pCunkSize + bump
        spans.append([start, stop])
        start = stop

    if pAddRemains and leftover > 0:
        spans.append([spans[-1][1], pSize])
    else:
        # Any rows not distributed above are absorbed by the final chunk.
        spans[-1][1] = pSize
    return spans

def scanE2Efolder(pPrmsDict):
    """List the chunk files present in the edge-to-edge folder.

    A directory entry is kept when its upper-cased name matches
    ``<e2e_file_prefix>*<e2e_file_suffix>`` (also upper-cased), i.e. the match
    is case-insensitive. Names are returned as found, in os.listdir order.
    """
    folder = getE2Epath(pPrmsDict)
    prefix = pPrmsDict['e2e_file_prefix']
    suffix = pPrmsDict['e2e_file_suffix']
    pattern = (prefix + '*' + suffix).upper()
    return [name for name in os.listdir(folder) if fnmatch(name.upper(), pattern)]

def setE2Efilenames(pPrmsDict, pChunksStEn):
    """Build one ``[start, end, False, file_name]`` record per chunk.

    The file name is ``<prefix><start>_<end><suffix>`` where both numbers are
    zero-padded to the number of digits of the last chunk's end. The third
    field is a "file found" flag, initialised to False here.
    """
    width = len(str(pChunksStEn[-1][1]))
    prefix = pPrmsDict['e2e_file_prefix']
    suffix = pPrmsDict['e2e_file_suffix']
    name_fmt = '{}{:0{w}}_{:0{w}}{}'

    records = []
    for span in pChunksStEn:
        start, end = span[0], span[1]
        file_name = name_fmt.format(prefix, start, end, suffix, w=width)
        records.append([start, end, False, file_name])
    return records

def createE2Elist(pPrmsDict, pChunksStEn):
    """Return the chunk records with the "file found" flag filled in.

    Records come from setE2Efilenames(); the flag (index 2) is switched to
    True for every record whose file name (index 3) is reported by
    scanE2Efolder().
    """
    records = setE2Efilenames(pPrmsDict, pChunksStEn)
    present = scanE2Efolder(pPrmsDict)
    for record in records:
        file_name = record[3]
        if file_name in present:
            record[2] = True
    return records

def vstackChunks():
    """Assemble the square edge-to-edge matrix from its row-chunk files.

    Reads 'edges_count' and 'chunk_size' from e2e_s_dict, derives the chunk
    row spans and file names, then loads every chunk file as raw int32 data
    and copies it into the matching rows of an
    (edges_count x edges_count) int32 array, which is returned.
    """
    print('Building edg2edg stack ...')
    n_edges = e2e_s_dict['edges_count']
    matrix = np.zeros((n_edges, n_edges), dtype=np.int32)

    spans = getChunkStEn(n_edges, e2e_s_dict['chunk_size'], pThreshold=0.75, pAddRemains=False)
    # Each record is [start_row, end_row, found_flag, file_name].
    chunk_records = createE2Elist(e2e_s_dict, spans)
    folder = getE2Epath(e2e_s_dict)

    for first_row, last_row, _found, file_name in chunk_records:
        raw = np.fromfile(os.path.join(folder, file_name), dtype=np.int32)
        matrix[first_row:last_row, :] = raw.reshape(last_row - first_row, n_edges)
    return matrix

# Build the edge-to-edge cost matrix once at import time. The guard skips the
# build when `edg2edg` is already defined in this namespace.
if 'edg2edg' not in vars():
    try:
        edg2edg = vstackChunks()
    except Exception as err:
        # A bare `except:` also traps KeyboardInterrupt/SystemExit, and raising
        # a plain string is itself a TypeError; raise a real exception and
        # chain the original failure so its cause stays visible.
        raise RuntimeError('Can not build edge2edge ...') from err

'''
    ***************************************************************************
    
    gui part
    
    ***************************************************************************
'''

def getRouteFrTo(pSumoNet, pEdgO, pEdgD, pToList= False, cost_limit=np.inf, pFrToCost=True):
    """Ask the network object for the shortest route between two edges.

    pSumoNet   -- object exposing
                  ``getShortestRoute(o, d, cost_limit=..., includeFrToCost=...)``
                  that returns ``(route, cost)``.
    pEdgO/pEdgD -- origin and destination, passed through unchanged.
    pToList    -- when true, the route is converted to the list of
                  ``getID()`` values of its elements.
    cost_limit -- forwarded as ``cost_limit``.
    pFrToCost  -- forwarded as ``includeFrToCost``.

    Returns ``[route, cost]``.
    """
    found_route, cost = pSumoNet.getShortestRoute(
        pEdgO, pEdgD, cost_limit=cost_limit, includeFrToCost=pFrToCost)
    route = [edge.getID() for edge in found_route] if pToList else found_route
    return [route, cost]

def getTripRouteNdx(pSumoNet, pTripsDF, pNdx):
    """Route the trip found at positional index pNdx of pTripsDF.

    The trip's 'from' and 'to' values are routed through getRouteFrTo();
    returns ``[route_ids, cost]``.
    """
    endpoints = pTripsDF[['from', 'to']].iloc[pNdx]
    origin, destination = endpoints
    route, cost = getRouteFrTo(pSumoNet, origin, destination, pToList=True)
    return [route, cost]

def getTripRouteId(pSumoNet, pTripsDF, pId):
    """Route the first trip of pTripsDF whose ``id`` equals pId.

    The trip's 'from' and 'to' values are routed through getRouteFrTo();
    returns ``[route_ids, cost]``. Raises IndexError when no trip matches.
    """
    matching = pTripsDF[['from', 'to']][pTripsDF.id == pId]
    origin, destination = matching.values.tolist()[0]
    route, cost = getRouteFrTo(pSumoNet, origin, destination, pToList=True)
    return [route, cost]

def getRouteFrToMrg(pSumoNet, pEdg1, pEdg2, pEdg3, pEdg4, cost_limit=np.inf):
    """Route a merged trip that visits four edges in the given order.

    Three legs are requested from ``pSumoNet.getShortestRoute``:
    1->2 and 3->4 with ``includeFrToCost=True`` and the middle leg 2->3 with
    ``includeFrToCost=False``. Returns ``[edge_ids, total_cost]`` where
    edge_ids is the concatenation of the ``getID()`` values of all three legs
    and total_cost is the sum of the three leg costs.
    """
    legs = ((pEdg1, pEdg2, True), (pEdg2, pEdg3, False), (pEdg3, pEdg4, True))
    # Request all legs first, then collect the ids.
    results = [
        pSumoNet.getShortestRoute(src, dst, cost_limit=cost_limit, includeFrToCost=flag)
        for src, dst, flag in legs
    ]
    (leg1, cost1), (leg2, cost2), (leg3, cost3) = results

    route = [edge.getID() for leg in (leg1, leg2, leg3) for edge in leg]
    cost = cost1 + cost2 + cost3
    return [route, cost]

def printMergeData(pTripsDF, pTripId1, pTripId2, pCombNdx):
    """Print a diagnostic report for merging two trips in a given stop order.

    pTripsDF  -- trips frame with at least the columns
                 id, from, to, depart, arrive, cost.
    pTripId1/pTripId2 -- ids of the two trips (the first match of each is used).
    pCombNdx  -- index (0..4) into the stop orders listed in ``combs`` below.

    Prints one summary line per trip, then for each of the four stops the
    edge, the leg distance, the computed arrival time and the absolute
    difference to that stop's planned time, and finally the gain.
    Returns the selected stop order (list of four edge ids).
    """
    trip_1 = pTripsDF[pTripsDF.id == pTripId1].iloc[0]
    trip_2 = pTripsDF[pTripsDF.id == pTripId2].iloc[0]

    a1, b1, a1t, b1t = trip_1['from'], trip_1['to'], trip_1['depart'], trip_1['arrive']
    a2, b2, a2t, b2t = trip_2['from'], trip_2['to'], trip_2['depart'], trip_2['arrive']

    # Candidate stop orders (a = origin, b = destination of trip 1 / trip 2) ...
    combs = [[a1, a2, b1, b2],
             [a1, a2, b2, b1],
             [a2, a1, b1, b2],
             [a2, a1, b2, b1],
             [a1, b1, a2, b2],
             ]

    # ... and the planned times of the same stops, in the same order.
    time_combs = [[a1t, a2t, b1t, b2t],
                  [a1t, a2t, b2t, b1t],
                  [a2t, a1t, b1t, b2t],
                  [a2t, a1t, b2t, b1t],
                  [a1t, b1t, a2t, b2t],
                  ]

    # The label is a format field so the second line reads "trip2" (it was
    # hard-coded to "trip1" for both trips before).
    prn_str = 'trip{} - fr: {}, to: {}, depart: {}, arrive: {}, cost: {}'
    summary_cols = ['from', 'to', 'depart', 'arrive', 'cost']
    print('--------------------')
    for label, trip in ((1, trip_1), (2, trip_2)):
        print(prn_str.format(label, *trip[summary_cols].to_list()))

    stops = combs[pCombNdx]
    planned = time_combs[pCombNdx]

    rout_costs = getRouteCostId(*stops)
    # Cumulative travel times, shifted by the planned time of the first stop.
    rout_times = np.array(getRouteTimes(np.array(rout_costs))) + planned[0]
    rout_costs = [0] + rout_costs  # the first stop has no incoming leg
    print('Timing:\t {:>8}\t {:>8}\t {}\t {}'.format('edge', 'dist', 'arrive', 'delay'))
    for t_edge, t_dist, t_arrive, t_planned in zip(stops, rout_costs, rout_times, planned):
        t_delay = abs(t_planned - t_arrive)
        print('\t {:>8}\t {:>8.0f}\t {}\t {}'.format(t_edge, t_dist, t_arrive, t_delay))

    gain = (trip_1.cost + trip_2.cost - sum(rout_costs)) / 100
    print('gain: {:,.0f} m'.format(gain))

    return stops

''' 
    ***************************************************************************
    
    general Part
    
    ***************************************************************************
'''

def get_mod(pN, pK = 20):
    """Return a rounded step size so that about pK steps cover pN items.

    ``pN / pK`` is rounded to the nearest 1000, 100, 10 or 1 depending on its
    magnitude (>= 1000, >= 100, >= 10, otherwise). The result is an int and is
    never smaller than 1, so it is safe to use as a modulus.
    """
    step = pN / pK
    digits = 0
    for exponent, floor in ((3, 1000), (2, 100), (1, 10)):
        if step >= floor:
            digits = exponent
            break
    rounded = int(round(step, -digits))
    return rounded if rounded != 0 else 1

def getEdgeMidPos(pEdgeId):
    """Return the stored length of edge pEdgeId divided by 200.

    The edge is located with getEdgNdx() and its length read from edge_lens.
    NOTE(review): presumably lengths are stored in cm, which would make the
    result the mid-point of the edge in metres -- confirm.
    """
    edge_len = edge_lens[getEdgNdx(pEdgeId)]
    return edge_len / 200

def getRunStr(pStr):
    """Return the last component of a path string (used for log messages).

    Everything after the last path separator is returned; both '\\' and '/'
    are recognised, so paths built by os.path.join on POSIX systems are
    shortened as well. A string without any separator is returned unchanged.
    """
    ndx = max(pStr.rfind('\\'), pStr.rfind('/')) + 1
    return pStr[ndx:]


def getScnDirs(pPrmsDict, pSim_Scn, pSimNdx):
    """Return the run-folder paths of scenario pSimNdx.

    One path ``base_path/scn_path/<run_path><NNN>`` is produced for every run
    number from pPrmsDict['run_start'] to pPrmsDict['run_end'] (inclusive),
    NNN being the run number zero-padded to three digits. Nothing is created
    on disk.
    """
    scn_path, run_path, _rout_file, _aggregate_path = pSim_Scn[pSimNdx]
    base = pPrmsDict['base_path']
    first_run, last_run = pPrmsDict['run_start'], pPrmsDict['run_end']
    return [
        os.path.join(base, scn_path, '{}{:0>3d}'.format(run_path, run_no))
        for run_no in range(first_run, last_run + 1)
    ]


def make_get_ScnDirs(pPrmsDict, pSim_Scn, pSimNdx, verbose=False):
    """Return the run-folder paths of scenario pSimNdx, creating missing ones.

    Paths come from getScnDirs(). Every path that is not an existing directory
    is created with os.makedirs; when verbose is truthy a line is printed for
    each directory created.
    """
    _scn_path, _run_path, _rout_file, _aggregate_path = pSim_Scn[pSimNdx]
    run_dirs = getScnDirs(pPrmsDict, pSim_Scn, pSimNdx)
    for run_dir in run_dirs:
        already_there = os.path.exists(run_dir) and os.path.isdir(run_dir)
        if not already_there:
            os.makedirs(run_dir, exist_ok=False)
            if verbose:
                print('Directory created: {}'.format(getRunStr(run_dir)))
    return run_dirs

''' 
    ***************************************************************************
    
    route Part - python version
    
    ***************************************************************************
'''

def getEdgNdx(pEdgID):
    """Return the index of pEdgID in the module-level ``edge_ids`` list.

    The lookup is a binary search (bisect_left) for the leftmost entry exactly
    equal to pEdgID, so ``edge_ids`` must be sorted in ascending order.
    NOTE(review): the loader does not sort the list, so the id file is
    presumably stored sorted -- confirm.

    Raises ValueError (naming the id) when pEdgID is not present.
    """
    pos = bisect_left(edge_ids, pEdgID)
    if pos < len(edge_ids) and edge_ids[pos] == pEdgID:
        return pos
    raise ValueError('edge id not found in edge_ids: {!r}'.format(pEdgID))

def getFrToCostId(pOrgnId,pDestId, pIncludeFT=True):
    """Return the edge-to-edge cost between two edges given by id.

    Both ids are converted to indices with getEdgNdx() and the cost is read
    from edg2edg. When pIncludeFT is false, the edge_lens entries of the
    origin and the destination edge are subtracted from that cost.
    """
    orgn_ndx = getEdgNdx(pOrgnId)
    dest_ndx = getEdgNdx(pDestId)
    full_cost = edg2edg[orgn_ndx, dest_ndx]
    if not pIncludeFT:
        return full_cost - edge_lens[orgn_ndx] - edge_lens[dest_ndx]
    return full_cost
#%timeit _ = getFrToCostId(e1,e2, True)
#2.32 µs ± 2.99 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


def getFrToCostNdx(pOrgnNdx,pDestNdx, pIncludeFT=True):
    """Return the edge-to-edge cost between two edges given by index.

    The cost is ``edg2edg[pOrgnNdx, pDestNdx]``. When pIncludeFT is false, the
    edge_lens entries of both end edges are subtracted from it.
    """
    full_cost = edg2edg[pOrgnNdx, pDestNdx]
    if not pIncludeFT:
        return full_cost - edge_lens[pOrgnNdx] - edge_lens[pDestNdx]
    return full_cost
#%timeit _ = getFrToCostNdx(10,1000, True)
#332 ns ± 0.0925 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

def getFrToInc(pOrgnNdx,pDestNdx):
    """Return ``edg2edg[pOrgnNdx, pDestNdx]`` unchanged (end edges included)."""
    return edg2edg[pOrgnNdx, pDestNdx]

def getFrToNot(pOrgnNdx,pDestNdx):
    """Return the edg2edg cost minus the lengths of both end edges."""
    orgn_len = edge_lens[pOrgnNdx]
    dest_len = edge_lens[pDestNdx]
    return edg2edg[pOrgnNdx, pDestNdx] - orgn_len - dest_len

def getRouteCostId(e1 , e2 , e3 , e4):
    """Return the three leg costs of the stop order e1 -> e2 -> e3 -> e4 (ids).

    The first and last leg use the full cost; the middle leg is taken with
    pIncludeFT=False, i.e. without the lengths of its two end edges.
    """
    first_leg = getFrToCostId(e1, e2)
    middle_leg = getFrToCostId(e2, e3, False)
    last_leg = getFrToCostId(e3, e4)
    return [first_leg, middle_leg, last_leg]
#%timeit _ = getRouteCostId(a1,b1,a2,b2)
#12.8 µs ± 15.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


def getRouteCostNdx(e1 , e2 , e3 , e4):
    """Return the three leg costs of the stop order e1 -> e2 -> e3 -> e4 (indices).

    The first and last leg come from getFrToInc(); the middle leg comes from
    getFrToNot(), i.e. without the lengths of its two end edges.
    """
    first_leg = getFrToInc(e1, e2)
    middle_leg = getFrToNot(e2, e3)
    last_leg = getFrToInc(e3, e4)
    return [first_leg, middle_leg, last_leg]
#%timeit _ = getRouteCostNdx(na1,nb1,na2,nb2)
#6.45 µs ± 30.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

def calcMixGainNdx(na1,nb1,na2,nb2, cost1, cost2):
    """Find the cheapest way to serve two trips and the gain it yields.

    na1/nb1 and na2/nb2 are the origin/destination edge indices of trip 1 and
    trip 2; cost1 and cost2 are their stand-alone costs. Four interleaved stop
    orders (indices 0..3) are compared against driving the trips separately
    (index 4, cost1 + cost2).

    Returns ``[gain, best_index]`` where gain is the saving over the separate
    trips, or 0 when separate trips are cheapest.
    """
    stop_orders = (
        (na1, na2, nb1, nb2),
        (na1, na2, nb2, nb1),
        (na2, na1, nb1, nb2),
        (na2, na1, nb2, nb1),
    )
    totals = [sum(getRouteCostNdx(*order)) for order in stop_orders]
    totals.append(sum([cost1, 0, cost2]))

    best = totals.index(min(totals))
    gain = 0 if best == 4 else totals[4] - totals[best]
    return [gain, best]

#%timeit _ = calcMixGainNdx(na1,nb1,na2,nb2,cost1,cost2)
#57.3 µs ± 101 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    
#%timeit _ = calcMixGain(trip1, trip2)
#122 µs ± 155 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

''' 
    ***************************************************************************
    
    route Part - compiled version
    
    ***************************************************************************
'''

from numba import njit

@njit
def getFrToCostNdxNb(pOrgnNdx,pDestNdx, pIncludeFT=True):
    """Numba-compiled counterpart of getFrToCostNdx.

    Returns ``edg2edg[pOrgnNdx, pDestNdx]``; when pIncludeFT is false the
    edge_lens entries of both end edges are subtracted from it.
    """
    # NOTE(review): numba presumably treats these global arrays as
    # compile-time constants, so they must be loaded before the first call
    # -- confirm.
    global edg2edg, edge_lens
    cost = edg2edg[pOrgnNdx,pDestNdx]
    if pIncludeFT:
        return cost
    else:
        return cost - edge_lens[pOrgnNdx] - edge_lens[pDestNdx]
#%timeit _ = getFrToCostNdxNb(na1,nb1, False)
#535 ns ± 0.31 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
#%timeit _ = getFrToCostNdxNb(na1,nb1, True)
#522 ns ± 0.244 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

@njit
def getFrToNotNb(pOrgnNdx,pDestNdx):
    """Numba-compiled counterpart of getFrToNot.

    Returns the edg2edg cost between the two edge indices minus the edge_lens
    entries of both end edges.
    """
    global edg2edg, edge_lens
    cost = edg2edg[pOrgnNdx,pDestNdx]
    return cost - edge_lens[pOrgnNdx] - edge_lens[pDestNdx]
#%timeit _ = getFrToNotNb(na1,nb1)
#480 ns ± 7.87 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


@njit
def getFrToIncNb(pOrgnNdx,pDestNdx):
    """Numba-compiled counterpart of getFrToInc.

    Returns ``edg2edg[pOrgnNdx, pDestNdx]`` unchanged.
    """
    global edg2edg, edge_lens
    cost = edg2edg[pOrgnNdx,pDestNdx]
    return cost
#%timeit _ = getFrToIncNb(na1,nb1)
#469 ns ± 0.18 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

@njit
def getRouteCostNdxNb(e1 , e2 , e3 , e4):
    """Numba-compiled counterpart of getRouteCostNdx.

    Returns the three leg costs ``[c0, c1, c2]`` of the stop order
    e1 -> e2 -> e3 -> e4 (edge indices). The middle leg e2 -> e3 is computed
    with getFrToNotNb, i.e. without the lengths of its two end edges.
    """
    c0,c1,c2 = getFrToIncNb(e1,e2) , getFrToNotNb(e2,e3) , getFrToIncNb(e3,e4)
    return [c0,c1,c2]
#%timeit _ = getRouteCostNdx(na1,nb1,na2,nb2)
#6.45 µs ± 30.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
#%timeit _ = getRouteCostNdxNb(na1,nb1,na2,nb2)
#1.02 µs ± 8.02 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

@njit
def getRouteTimes(pCosts):
    """Convert three leg costs into cumulative arrival times at four stops.

    pCosts must unpack into exactly three values. Each is divided by the
    speed constant, the results are accumulated, and each cumulative time is
    rounded with ``int(x + 0.5)``.

    Returns ``[0, t1, t2, t3]``: the offset of each stop from the first one.
    """
#    global e2e_s_dict
    c1 , c2 , c3 = pCosts
    # Hard-coded copy of e2e_s_dict['avg_speed_cms'] (the dict lookup is
    # commented out); keep the two values in sync.
    t_speed = 1200 #e2e_s_dict['avg_speed_cms']
    t1,t2,t3 = c1/t_speed, c2/t_speed, c3/t_speed
    # Turn per-leg durations into cumulative times.
    t2 += t1
    t3 += t2
    return [0,int(t1+0.5),int(t2+0.5),int(t3+0.5)]

@njit
def calcMixGainNdxNb(na1,nb1,na2,nb2, cost1, cost2):
    """Numba-compiled search for the cheapest way to serve two trips.

    na1/nb1 and na2/nb2 are the origin/destination edge indices of trip 1 and
    trip 2; cost1 and cost2 are their stand-alone costs.

    Rows 0..3 of l_costs hold the leg costs of the four interleaved stop
    orders; row 4 is ``[cost1, 0, cost2]``, i.e. the two trips driven
    separately.

    Returns ``[gain, t_ndx, t1, t2, t3]``: the saving over separate trips, the
    index of the cheapest row, and the cumulative times of stops 2..4 from
    getRouteTimes. All of gain/t1/t2/t3 are 0 when row 4 is the cheapest.
    """
    l_costs = [getRouteCostNdxNb(na1 , na2 , nb1 , nb2),
               getRouteCostNdxNb(na1 , na2 , nb2 , nb1),
               getRouteCostNdxNb(na2 , na1 , nb1 , nb2),
               getRouteCostNdxNb(na2 , na1 , nb2 , nb1),
               [cost1 ,0, cost2],               
             ]
    
    l_costs = np.array(l_costs)
    # Total cost per candidate, then pick the cheapest one.
    costs = l_costs.sum(axis=1) 
    t_ndx = costs.argmin()
    if t_ndx == 4:
        gain = 0
        t0,t1,t2,t3 = 0,0,0,0
    else:
        gain = costs[4] - costs[t_ndx]
        t0,t1,t2,t3 = getRouteTimes(l_costs[t_ndx])
    # t0 (always 0) is not returned.
    return [gain, t_ndx, t1, t2, t3]
#%timeit _ = calcMixGainNdxNb(na1,nb1,na2,nb2,cost1,cost2)
#3.51 µs ± 5.33 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

@njit
def mapRouteTimes(pA1t, pA2t, pT1, pT2, pT3, pCombNdx):
    """Map the stop times of a merged route back to (a1, b1, a2, b2) order.

    pA1t/pA2t -- departure times of trip 1 / trip 2.
    pT1..pT3  -- cumulative times of the 2nd, 3rd and 4th stop of the merged
                 route, relative to its first stop.
    pCombNdx  -- stop order index (0..4); orders 0 and 1 start at a1 and are
                 therefore offset by pA1t, orders 2 and 3 start at a2 and are
                 offset by pA2t.

    Returns a 4-tuple of absolute times for a1, b1, a2, b2. ``out`` is only
    assigned for pCombNdx 0..4.
    """
#    Output order is always: [a1,b1, a2, b2]
#    out = [0,0,0,0]
    
    if pCombNdx == 0:
#        stop order a1, a2, b1, b2 -> offsets used for a1, b1, a2, b2:
#        pT0, pT2, pT1, pT3
#        na1, nb1, na2, nb2
        out = pA1t, pA1t+pT2, pA1t+pT1, pA1t+pT3
    elif pCombNdx == 1:
#        stop order a1, a2, b2, b1 -> offsets used for a1, b1, a2, b2:
#        pT0, pT3, pT1, pT2
#        na1, nb1, na2, nb2
        out = pA1t, pA1t+pT3, pA1t+pT1, pA1t+pT2
    elif pCombNdx == 2:
#        stop order a2, a1, b1, b2 -> offsets used for a1, b1, a2, b2:
#        pT1, pT2, pT0, pT3
#        na1, nb1, na2, nb2
        out = pA2t+pT1, pA2t+pT2, pA2t, pA2t+pT3
    elif pCombNdx == 3:
#        stop order a2, a1, b2, b1 -> offsets used for a1, b1, a2, b2:
#        pT1, pT3, pT0, pT2
#        na1, nb1, na2, nb2
        out = pA2t+pT1, pA2t+pT3, pA2t, pA2t+pT2
    elif pCombNdx == 4:
#        trips driven separately (a1, b1, a2, b2)
        out = pA1t, pA1t+pT1, pA2t, pA2t+pT3
    return  out


@njit
def isValidMixed(pA1t, pB1t, pA2t, pB2t, pT1, pT2, pT3, pCombNdx):
    """Check whether a merged route keeps both trips within the delay limits.

    pA1t/pB1t and pA2t/pB2t are the planned depart/arrive times of trip 1 and
    trip 2; pT1..pT3 and pCombNdx describe the merged route as in
    mapRouteTimes.

    Returns False for pCombNdx > 3 (not a merged order). Otherwise returns
    True only when every stop deviates from its planned time by at most
    ``delay`` and each trip's two deviations add up to at most ``tot_delay``.
    """
    if pCombNdx > 3:
        return False
    # Hard-coded limits; the commented-out dict keys show where they were
    # meant to come from.
    delay, tot_delay = 600, 900 #e2e_s_dict['alwd_delay'], e2e_s_dict['alwd_tot_delay']
    
    mix_times = mapRouteTimes(pA1t, pA2t, pT1, pT2, pT3, pCombNdx)
    # Absolute deviation of each stop (a1, b1, a2, b2) from its planned time.
    a1_delay, b1_delay, a2_delay, b2_delay = [abs(x - y) for x,y in zip(mix_times,[pA1t, pB1t, pA2t, pB2t])]
    t1_total, t2_total = a1_delay + b1_delay , a2_delay + b2_delay
#    print(a1_delay,b1_delay,a2_delay,b2_delay,t1_total,t2_total)
    if a1_delay <= delay and  b1_delay <= delay and  a2_delay <= delay and  b2_delay <= delay and \
        t1_total  <= tot_delay and  t2_total <= tot_delay:
        out_flg = True
    else:
        out_flg = False
    return out_flg

    
@njit
def getBestGain(pPrimTrip, pLookupTrips):
    """Find the candidate trip whose merge with the primary trip gains most.

    pPrimTrip    -- one trip row with the eight fields listed below.
    pLookupTrips -- sequence of candidate rows with the same layout.

    For every candidate the best stop order is computed with
    calcMixGainNdxNb and accepted only if isValidMixed approves its timing.

    Returns ``[max_gain, max_cmb, max_row]``: the largest accepted gain, its
    stop-order index and the candidate's position in pLookupTrips. When no
    candidate is accepted all three are 0, so callers must test the gain
    before trusting max_row.
    """
#   column positions of a trip row:
#   tmp_cols = [' 0', '     1', '        2', '       3', '       4', '   5', '     6', '     7']
#   tmp_cols = ['id', 'depart', 'processed', 'orgn_ndx', 'dest_ndx', 'cost', 'arrive', 'df_ndx']
    tr_id1, depart1, processed1, ndx_a1, ndx_b1, cost1, arrive1, df_ndx1 = pPrimTrip
    max_gain = max_cmb = max_row = 0
    for i in range(len(pLookupTrips)):
        tr_id2, depart2, processed2, ndx_a2, ndx_b2, cost2, arrive2, df_ndx2 = pLookupTrips[i]
        gain, com_ndx, t1, t2, t3 = calcMixGainNdxNb(ndx_a1,ndx_b1,ndx_a2, ndx_b2, cost1, cost2)
        t_flag = isValidMixed(depart1, arrive1, depart2, arrive2, t1, t2, t3, com_ndx)
        # Strict '>' keeps the first candidate among equal gains.
        if gain > max_gain and t_flag:
            max_gain = gain
            max_cmb = com_ndx
            max_row = i
    return [max_gain, max_cmb, max_row]

''' 
    ***************************************************************************
    
    route Part - reading route file
    
    ***************************************************************************
'''

def readRouteFile(pPath, pRouteFile):
    """Load a trip file (XML) and derive the columns used by the fusion code.

    pPath / pRouteFile -- folder and file name of the XML file, read with
    pandas.read_xml.

    Added or rewritten columns:
      processed -- False for every trip.
      orgn_ndx / dest_ndx -- indices of the 'from' / 'to' edges (getEdgNdx).
      cost      -- edg2edg[orgn_ndx, dest_ndx].
      arrive    -- depart + cost / avg_speed_cms, rounded to int.
      depart    -- rounded to int.
      id        -- cast to int.
      frto_taz  -- '<fromTaz>-<toTaz>'.
      df_ndx    -- copy of the frame index.
    The 'departLane' and 'departSpeed' columns are removed when present.

    Raises ValueError (from getEdgNdx) when an edge id is unknown.
    """
#    resulting columns:
#    ['id', 'depart', 'from', 'to', 'fromTaz', 'toTaz', 'processed', 'orgn_ndx', 'dest_ndx', 'cost', 'arrive', 'frto_taz', 'df_ndx']
    tdf = pd.read_xml(os.path.join(pPath, pRouteFile))
    tdf['processed'] = False
    # The two attributes are optional in the input; do not fail when the file
    # does not carry them.
    tdf.drop(columns=['departLane', 'departSpeed'], inplace=True, errors='ignore')
    tdf['orgn_ndx'] = tdf['from'].apply(getEdgNdx).astype(int)
    tdf['dest_ndx'] = tdf['to'].apply(getEdgNdx).astype(int)
    # One vectorized lookup instead of a Python-level iterrows() loop; the
    # values are the same as getFrToCostNdx(orgn_ndx, dest_ndx) per row.
    tdf['cost'] = edg2edg[tdf['orgn_ndx'].to_numpy(), tdf['dest_ndx'].to_numpy()].astype(np.int64)
    t_speed = e2e_s_dict['avg_speed_cms']
    # Adding 0.5 before the int cast rounds to the nearest integer.
    tdf['arrive'] = (tdf['depart'] + tdf['cost'] / t_speed + 0.5).astype(int)
    tdf['depart'] = (tdf['depart'] + 0.5).astype(int)
    tdf['id'] = tdf['id'].astype(int)
    tdf = tdf.convert_dtypes()

    tdf['frto_taz'] = tdf['fromTaz'] + '-' + tdf['toTaz']
    tdf['df_ndx'] = tdf.index
    return tdf

# run_path = os.path.join(e2e_s_dict['base_path'], 'run_009_mht\\')
# rou_file = 'rtm_batch.rou.xml'


'''
    ***************************************************************************
    
    fused trips to xml part
    
    ***************************************************************************
'''

def fusedTrips2XmlSeries(pTrip1,pTrip2, pMode):
    """Build the output record for one fused (or passed-through) trip pair.

    pTrip1/pTrip2 -- mappings with the keys 'id', 'depart', 'from', 'to'
                     (for pMode 5 only 'id' is read from pTrip2).
    pMode -- 0..3: merged trip, index into the stop orders listed below;
             5: only pTrip1's own origin and destination are used.

    Returns ``[mode, t_id, depart, from_edg, to_edg, via_edg, end_pos]``.
    t_id is ``<mode><id1 padded to 6>-<id2 padded to 6>``; depart is pTrip1's
    depart plus a random fraction below one (two decimals). For merged trips
    via_edg and end_pos hold the two intermediate stops joined by '_'.

    Raises ValueError for any other pMode.
    """
    t_id = '{}{:0>6}-{:0>6}'.format(pMode, pTrip1['id'], pTrip2['id'])
    # NOTE(review): the random fraction presumably de-duplicates identical
    # departure times -- confirm.
    t_depart = pTrip1['depart'] + round(np.random.rand(), 2)
    via_edg = end_pos = ''

    if pMode == 5:
        t_from, t_to = pTrip1['from'], pTrip1['to']
    elif 0 <= pMode < 4:
        a1, b1 = pTrip1['from'], pTrip1['to']
        a2, b2 = pTrip2['from'], pTrip2['to']
        # Stop orders, same indexing as in calcMixGainNdx / calcMixGainNdxNb.
        combs = [[a1, a2, b1, b2],
                 [a1, a2, b2, b1],
                 [a2, a1, b1, b2],
                 [a2, a1, b2, b1],
                 ]
        t_from, via_1, via_2, t_to = combs[pMode]
        via_edg = '{}_{}'.format(via_1, via_2)
        end_pos = '{:.1f}_{:.1f}'.format(getEdgeMidPos(via_1), getEdgeMidPos(via_2))
    else:
        # Raising a plain string is itself a TypeError in Python 3; raise a
        # real exception and name the function that actually failed.
        raise ValueError('mht_trip_fusion3 > fusedTrips2XmlSeries(...) > not suported pMode: {}'.format(pMode))

    return [pMode, t_id, t_depart, t_from, t_to, via_edg, end_pos]

def fusedTrips2XmlDf(pDfInTrips, pDfFt2):
    """Build the frame of output trips: fused pairs plus untouched trips.

    pDfInTrips -- input trips; its 'processed' column is set to True (in
                  place) for every trip that is part of a fused pair.
    pDfFt2     -- fused pairs with the columns
                  ['trip1_ndx', 'trip2_ndx', 'trip1_id', 'trip2_id', 'gain', 'comb_ndx'].

    Every pair becomes one record via fusedTrips2XmlSeries(); every trip left
    unprocessed becomes a mode-6 record with id ``6<id padded to 6>-000000``.
    Returns a frame with the columns
    ['mode','t_id','depart','from_edg','to_edg','via_edg','end_pos'], sorted
    by depart. Progress is printed while running.
    """
    print('Merging trips ...')
    n_pairs = len(pDfFt2)
    report_every = get_mod(n_pairs, 5)
    records = []
    fused_ndx = []
    for pos in range(n_pairs):
        if pos % report_every == 0:
            print('{}\t{}'.format(pos, n_pairs))

        pair = pDfFt2.iloc[pos]
        first_trip = pDfInTrips.iloc[pair['trip1_ndx']]
        second_trip = pDfInTrips.iloc[pair['trip2_ndx']]
        record = fusedTrips2XmlSeries(first_trip, second_trip, pair['comb_ndx'])
        fused_ndx.extend([pair['trip1_ndx'], pair['trip2_ndx']])
        records.append(record)

    pDfInTrips.loc[fused_ndx, 'processed'] = True

    print('Processing untouched trips ...')
    untouched = pDfInTrips[pDfInTrips['processed'] == False]
    n_untouched = len(untouched)
    report_every = get_mod(n_untouched, 5)
    for pos in range(n_untouched):
        if pos % report_every == 0:
            print('{}\t{}'.format(pos, n_untouched))

        single = untouched.iloc[pos]
        trip_id, depart, from_edg, to_edg = single[['id', 'depart', 'from', 'to']].tolist()
        out_id = '6{:0>6}-000000'.format(trip_id)
        out_depart = depart + round(np.random.rand(), 2)
        records.append([6, out_id, out_depart, from_edg, to_edg, '', ''])

    df_out = pd.DataFrame(records).convert_dtypes()
    df_out.columns = ['mode','t_id','depart','from_edg','to_edg','via_edg','end_pos']
    df_out.sort_values('depart', ascending=True, inplace=True, ignore_index=True)
    return df_out

def getTripStops(pViaEdge, pViaPos):
    """Return the two ``<stop .../>`` elements of a merged trip as one string.

    pViaEdge -- the two intermediate edge ids joined by '_'.
    pViaPos  -- the two stop positions joined by '_'.

    A stop gets an ``endPos`` attribute (rounded to whole numbers) only when
    its position is greater than 5; every stop has ``duration="30"``.
    NOTE(review): edge ids that themselves contain '_' would be split in the
    wrong place -- confirm this cannot occur in the network.
    """
    edge_parts = pViaEdge.split('_')
    pos_parts = pViaPos.split('_')
    positions = (fast_float(pos_parts[0]), fast_float(pos_parts[1]))

    with_pos = '<stop edge="{}" endPos="{:.0f}" duration="30"/>'
    without_pos = '<stop edge="{}" duration="30"/>'
    stops = []
    for k, pos in enumerate(positions):
        if pos > 5:
            stops.append(with_pos.format(edge_parts[k], pos))
        else:
            stops.append(without_pos.format(edge_parts[k]))
    return stops[0] + stops[1]

def getTripXmlStr(pTripXmlSr):
    """Render one output record as a ``<trip>`` XML element (single line).

    pTripXmlSr unpacks into
    ['mode', 't_id', 'depart', 'from_edg', 'to_edg', 'via_edg', 'end_pos'].

    Modes 5 and 6 produce a self-closing ``<trip .../>``. Modes 0..3 produce
    ``<trip ... via="...">`` (the '_' in via_edg replaced by a space) followed
    by the stops from getTripStops() and ``</trip>``.

    Raises ValueError for any other mode.
    """
    mode, t_id, depart, from_edg, to_edg, via_edg, end_pos = pTripXmlSr
    if mode == 5 or mode == 6:
        out_str = '<trip id="{}" depart="{}" from="{}" to="{}" departLane="free" departSpeed="max"/>'
        return out_str.format(t_id, depart, from_edg, to_edg)
    if 0 <= mode < 4:
        out_str0 = '<trip id="{}" depart="{}" from="{}" to="{}" via="{}" departLane="free" departSpeed="max">'
        out_str1 = getTripStops(via_edg, end_pos)
        out_str2 = '</trip>'
        out_str0 = out_str0.format(t_id, depart, from_edg, to_edg, via_edg.replace('_', ' '))
        return out_str0 + out_str1 + out_str2
    # Raising a plain string is itself a TypeError in Python 3; raise a real
    # exception so the message reaches the caller.
    raise ValueError('mht_trip_fusion3 > getTripXmlStr(...) > not suported pMode: {}'.format(mode))

def fusedTrips2XmlStr(pDfTripsXml):
    """Render the output-trip frame as the text of a routes XML file.

    The document consists of the XML declaration, a "generated on" comment
    stamped with gu.getCurrentDateTime(), the ``<routes>`` root, and one
    element per row from getTripXmlStr(). Trips are indented with one tab,
    stops with two tabs on their own lines. The text ends with a newline.
    """
    time_stamp = gu.getCurrentDateTime().replace('T', ' ')
    header = [
        '<?xml version="1.0" encoding="UTF-8"?>',
        '<!-- generated on {} by Heuristic Trip Fusion Algorithm Version 4.0 -->'.format(time_stamp),
        '<routes xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="http://sumo.dlr.de/xsd/routes_file.xsd">',
    ]
    trips = [getTripXmlStr(pDfTripsXml.iloc[pos]) for pos in range(len(pDfTripsXml))]
    document = '\n'.join(header + trips + ['</routes>'])

    # Indentation is applied afterwards by textual replacement.
    layout = (
        ('<trip id', '\t<trip id'),
        ('<stop edge', '\n\t\t<stop edge'),
        ('</trip>', '\n\t</trip>'),
    )
    for marker, indented in layout:
        document = document.replace(marker, indented)
    return document + '\n'

'''
    ***************************************************************************
    
    fusion algorithm 1 - time sorted lookup part
    
    ***************************************************************************
'''

def getCandidTripsAlg1(pDf, pPrimeTrip):
    """Select the merge candidates of a primary trip (DataFrame version).

    A trip qualifies when it is not processed, is not the primary trip itself
    and departs within ``[depart, arrive)`` of the primary trip.
    Returns the matching rows of pDf.
    """
    prime_id = pPrimeTrip['id']
    window_start, window_end = pPrimeTrip['depart'], pPrimeTrip['arrive']

    departs = pDf['depart']
    unprocessed = pDf['processed'] == False
    in_window = (departs >= window_start) & (departs < window_end)
    not_self = pDf['id'] != prime_id
    return pDf[unprocessed & in_window & not_self]
#%timeit _ = getCandidTripsAlg1(a0_trips, a0_prime)
#5.13 ms ± 28 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

def getCandidTripsAlg1Np(pArr, pPrimeTrip):
    """Select the merge candidates of a primary trip (numpy version).

    Column layout of pArr and pPrimeTrip:
    ['id', 'depart', 'processed', 'orgn_ndx', 'dest_ndx', 'cost', 'arrive', 'df_ndx']

    A row qualifies when it is not processed, is not the primary trip itself
    and departs within ``[depart, arrive)`` of the primary trip.
    Returns the matching rows (a new array).
    """
    prime_id = pPrimeTrip[0]
    window_start, window_end = pPrimeTrip[1], pPrimeTrip[6]

    ids, departs, processed = pArr[:, 0], pArr[:, 1], pArr[:, 2]
    unprocessed = processed == False
    in_window = (departs >= window_start) & (departs < window_end)
    not_self = ids != prime_id
    return pArr[unprocessed & in_window & not_self]

#%timeit _ = getCandidTripsAlg1Numpy(a0_arr, a0_arr[88000])
#979 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

def mixTimeSortedTrips(pDf): # algorithm 1
    """Greedily pair trips in frame order (algorithm 1).

    pDf must provide the columns
    ['id', 'depart', 'processed', 'orgn_ndx', 'dest_ndx', 'cost', 'arrive', 'df_ndx'];
    they are copied into an int32 array, so pDf itself is not modified.

    Each unprocessed trip is matched against its candidates from
    getCandidTripsAlg1Np(); the best candidate from getBestGain() is accepted
    when its gain exceeds e2e_s_dict['merge_gain_threshold'], and both trips
    are then flagged as processed in the working array.

    Returns a list of ``[trip1_ndx, trip2_ndx, trip1_id, trip2_id, gain, comb_ndx]``.
    Progress and timing are printed while running.
    """
    gain_threshold = e2e_s_dict['merge_gain_threshold']
    merged_pairs = []
    total_gain = 0.0
    n_trips = len(pDf)
    report_every = get_mod(n_trips, 20)
    timer = TicToc()
    timer.tic()

#   column positions: 0 id, 1 depart, 2 processed, 3 orgn_ndx, 4 dest_ndx, 5 cost, 6 arrive, 7 df_ndx
    tmp_cols = ['id', 'depart', 'processed', 'orgn_ndx', 'dest_ndx', 'cost', 'arrive', 'df_ndx']
    trips = pDf[tmp_cols].to_numpy(dtype=np.int32)

    for pos in range(n_trips):
        if pos % report_every == 0:
            print('{}\t{}\t{:,.0f}'.format(pos, n_trips, total_gain/100000), end='\t')
            timer.toc()

        prime = trips[pos]
        if prime[2]:
            # already consumed by an earlier pair
            continue

        candidates = getCandidTripsAlg1Np(trips, prime)
        if len(candidates) == 0:
            continue

        best_gain, best_comb, best_row = getBestGain(prime, candidates)
        if best_gain > gain_threshold:
            prime_ndx = prime[7]
            partner = candidates[best_row]
            partner_ndx, partner_id = partner[7], partner[0]
            merged_pairs.append([prime_ndx, partner_ndx, prime[0], partner_id, best_gain, best_comb])
            # df_ndx is used as the row position inside the working array.
            trips[[prime_ndx, partner_ndx], 2] = True
            total_gain += best_gain
    timer.toc()
    return merged_pairs


def mergeTripsAlg1(pSim_Scn, pInScnNdx, pOutScnNdx):
    """Run fusion algorithm 1 over every run folder of a scenario.

    pSim_Scn   -- scenario table; each entry is
                  [scn_path, run_path, rout_file, aggregate_path].
    pInScnNdx  -- index of the scenario whose route files are read.
    pOutScnNdx -- index of the scenario whose folders receive the output
                  (created when missing).

    For every run the input trips are read, paired with
    mixTimeSortedTrips(), the pairs are written to 'intermediate_data.csv'
    (tab separated) and the resulting routes XML is saved under the output
    scenario's route file name. Runs without any pair are skipped entirely.

    Returns (and prints) one row per processed run:
    [number of input trips, number of pairs, total cost, gain, gain percent].
    """
    _scn_path, _run_path, out_file_name, _aggregate_path = pSim_Scn[pOutScnNdx]
    _base_scn, _base_run, in_file_name, _base_aggregate = pSim_Scn[pInScnNdx]

    in_folders = getScnDirs(e2e_s_dict, pSim_Scn, pInScnNdx)
    out_folders = make_get_ScnDirs(e2e_s_dict, pSim_Scn, pOutScnNdx, verbose= 1)

    out_stat = []
    for run_pos, i_dir in enumerate(in_folders):
        o_dir = out_folders[run_pos]
        print('Processing: {}'.format(getRunStr(i_dir)))
        in_trips = readRouteFile(i_dir, in_file_name)
        pairs = mixTimeSortedTrips(in_trips)
        if len(pairs) == 0:
            continue

        pairs_df = pd.DataFrame(pairs)
        pairs_df.columns = ['trip1_ndx','trip2_ndx','trip1_id','trip2_id','gain','comb_ndx']
        pairs_df.to_csv(os.path.join(o_dir, 'intermediate_data.csv'), sep='\t', index_label='row')

        print('Converting to XML ...')
        xml_df = fusedTrips2XmlDf(in_trips, pairs_df)
        xml_text = fusedTrips2XmlStr(xml_df)

        print('--------------------')
        # Both totals are scaled by 1/100000 before being reported.
        t_gain = (pairs_df['gain'].sum()/ 100000).sum()
        t_cost = (in_trips['cost']  / 100000).sum()
        t_gain_prcnt = round(t_gain / t_cost * 100, 2)
        prn_str = 'in_trips: {}, out_trips: {}, total_cost: {:,.0f}, gain: {:,.0f}  %{:.1f}'
        print(prn_str.format(len(in_trips) , len(xml_df), t_cost, t_gain, t_gain_prcnt))
        print('--------------------')
        out_stat.append([len(in_trips) , len(pairs_df), t_cost, t_gain, t_gain_prcnt])
        print('Writing: {}'.format(getRunStr(o_dir)))
        gu.FileSave(os.path.join(o_dir, out_file_name), xml_text, wMode = 'w')
    print(out_stat)
    return out_stat


'''
    ***************************************************************************
    
    fusion algorithm 2 - length sorted lookup part
    
    ***************************************************************************
'''
def readAlg2RouteFile(p_dir, p_file_name):
    """Read a trip file for algorithm 2: trips sorted by descending cost.

    The frame from readRouteFile() is re-sorted on 'cost' (largest first)
    with a fresh 0..n-1 index, and 'df_ndx' is refreshed to that new index.
    """
    trips = readRouteFile(p_dir, p_file_name)
    trips = trips.sort_values('cost', ascending=False, ignore_index=True)
    trips['df_ndx'] = trips.index
    return trips
   
def getCandidTripsAlg2Np(pArr, pPrimeTrip):
    """Select the merge candidates of a primary trip for algorithm 2.

    Column layout of pArr and pPrimeTrip:
    ['id', 'depart', 'processed', 'orgn_ndx', 'dest_ndx', 'cost', 'arrive', 'df_ndx']

    A row qualifies when it is not processed, is not the primary trip itself
    and either departs within ``[depart, arrive)`` or arrives within
    ``(depart, arrive]`` of the primary trip.
    Returns the matching rows (a new array).
    """
    prime_id = pPrimeTrip[0]
    window_start, window_end = pPrimeTrip[1], pPrimeTrip[6]

    ids, departs, processed, arrives = pArr[:, 0], pArr[:, 1], pArr[:, 2], pArr[:, 6]
    unprocessed = processed == False
    departs_inside = (departs >= window_start) & (departs < window_end)
    arrives_inside = (arrives > window_start) & (arrives <= window_end)
    not_self = ids != prime_id
    return pArr[unprocessed & (departs_inside | arrives_inside) & not_self]

def mixCostSortedTrips(pDf): # algorithm 2
    """Greedily pair trips in frame order (algorithm 2).

    pDf must provide the columns
    ['id', 'depart', 'processed', 'orgn_ndx', 'dest_ndx', 'cost', 'arrive', 'df_ndx'];
    they are copied into an int32 array, so pDf itself is not modified.

    Each unprocessed trip is matched against its candidates from
    getCandidTripsAlg2Np(); the best candidate from getBestGain() is accepted
    when its gain exceeds e2e_s_dict['merge_gain_threshold'], and both trips
    are then flagged as processed in the working array.

    Returns a list of ``[trip1_ndx, trip2_ndx, trip1_id, trip2_id, gain, comb_ndx]``.
    Progress and timing are printed while running.
    """
    gain_threshold = e2e_s_dict['merge_gain_threshold']
    merged_pairs = []
    total_gain = 0.0
    n_trips = len(pDf)
    report_every = get_mod(n_trips, 20)
    timer = TicToc()
    timer.tic()

#   column positions: 0 id, 1 depart, 2 processed, 3 orgn_ndx, 4 dest_ndx, 5 cost, 6 arrive, 7 df_ndx
    tmp_cols = ['id', 'depart', 'processed', 'orgn_ndx', 'dest_ndx', 'cost', 'arrive', 'df_ndx']
    trips = pDf[tmp_cols].to_numpy(dtype=np.int32)

    for pos in range(n_trips):
        if pos % report_every == 0:
            print('{}\t{}\t{:,.0f}'.format(pos, n_trips, total_gain/100000), end='\t')
            timer.toc()

        prime = trips[pos]
        if prime[2]:
            # already consumed by an earlier pair
            continue

        candidates = getCandidTripsAlg2Np(trips, prime)
        if len(candidates) == 0:
            continue

        best_gain, best_comb, best_row = getBestGain(prime, candidates)
        if best_gain > gain_threshold:
            prime_ndx = prime[7]
            partner = candidates[best_row]
            partner_ndx, partner_id = partner[7], partner[0]
            merged_pairs.append([prime_ndx, partner_ndx, prime[0], partner_id, best_gain, best_comb])
            # df_ndx is used as the row position inside the working array.
            trips[[prime_ndx, partner_ndx], 2] = True
            total_gain += best_gain
    timer.toc()
    return merged_pairs

def mergeTripsAlg2(pSim_Scn, pInScnNdx, pOutScnNdx):
    """Run fusion algorithm 2 for every run folder of a scenario.

    For each input run folder the route file is read, the trips are paired
    with mixCostSortedTrips, the pairs are saved as 'intermediate_data.csv'
    (tab separated) in the matching output folder, and the fused trips are
    written as an XML route file to that same folder. A run for which no
    pair was found is skipped: nothing is written for it and it gets no
    entry in the returned statistics.

    pSim_Scn:   list of scenario entries; position 2 of an entry is the
                route file name ([scn_path, run_path, rout_file, aggregate_path]).
    pInScnNdx:  index in pSim_Scn of the scenario to read.
    pOutScnNdx: index in pSim_Scn of the scenario to write.
    return:     list with one row per processed run:
                [number of input trips, number of fused pairs,
                 total cost / 100000, total gain / 100000, gain in percent of cost].
    """
    global e2e_s_dict

    # Only the route file name of each scenario entry is needed here.
    out_file_name = pSim_Scn[pOutScnNdx][2]
    in_file_name = pSim_Scn[pInScnNdx][2]

    pInFolders = getScnDirs(e2e_s_dict, pSim_Scn, pInScnNdx)
    pOutFolders = make_get_ScnDirs(e2e_s_dict, pSim_Scn, pOutScnNdx, verbose= 1)

    out_stat = []
    for i, i_dir in enumerate(pInFolders):
        o_dir = pOutFolders[i]
        print('Processing: {}'.format(getRunStr(i_dir)))
        in_trips = readAlg2RouteFile(i_dir, in_file_name)

        out_mixed = mixCostSortedTrips(in_trips)
        if len(out_mixed) == 0:
            # nothing was fused for this run - no output is produced for it
            continue

        out_df = pd.DataFrame(out_mixed)
        out_df.columns = ['trip1_ndx','trip2_ndx','trip1_id','trip2_id','gain','comb_ndx']
        fn_intermediate = os.path.join(o_dir, 'intermediate_data.csv')
        out_df.to_csv(fn_intermediate, sep='\t', index_label='row')

        print('Converting to XML ...')
        out_3 = fusedTrips2XmlDf(in_trips, out_df)
        out_4 = fusedTrips2XmlStr(out_3)

        print('--------------------')
        # Series.sum() already yields a scalar; the former trailing .sum()
        # on that scalar was redundant and fails for plain Python numbers.
        t_gain = out_df['gain'].sum() / 100000
        t_cost = (in_trips['cost']  / 100000).sum()
        t_gain_prcnt = round(t_gain / t_cost * 100, 2)
        prn_str = 'in_trips: {}, out_trips: {}, total_cost: {:,.0f}, gain: {:,.0f}  %{:.1f}'
        print(prn_str.format(len(in_trips) , len(out_3), t_cost, t_gain, t_gain_prcnt))
        print('--------------------')
        # NOTE: the statistics row stores the number of fused pairs (len(out_df)),
        # whereas the line printed above shows the number of output trips (len(out_3)).
        out_stat.append([len(in_trips) , len(out_df), t_cost, t_gain, t_gain_prcnt])
        print('Writing: {}'.format(getRunStr(o_dir)))
        gu.FileSave(os.path.join(o_dir, out_file_name), out_4, wMode = 'w')
    print(out_stat)
    return out_stat


'''
    ***************************************************************************
    
    fusion without stop - from - fusion with stop route_trip files (alg2)
    
    ***************************************************************************
'''

def readFusedRouteFile(pPath, pRouteFile):
    """Load a fused route file (XML) into a DataFrame.

    pPath:      folder that contains the route file.
    pRouteFile: file name of the route file.
    return:     DataFrame (dtypes converted with convert_dtypes) restricted to
                the columns ['id', 'depart', 'from', 'to', 'via', 'departLane', 'departSpeed', 'stop'].
    """
    wanted_cols = ['id', 'depart', 'from', 'to', 'via', 'departLane', 'departSpeed', 'stop']
    fn_route = os.path.join(pPath, pRouteFile)
    trips = pd.read_xml(fn_route).convert_dtypes()
    return trips[wanted_cols]

def getNonstopTrips(pTripSr):
    """Split one (possibly fused) trip into nonstop trips.

    pTripSr: sequence of 8 values
             (id, depart, from, to, via, departLane, departSpeed, stop);
             the 'stop' value is not used.
    return:  list of rows
             [id, depart, from, to, departPos, arrivalPos, departLane, departSpeed].
             A trip without 'via' yields a single row with both positions 0.
             A trip with 'via' yields three rows with ids '<id>-0', '<id>-1'
             and '<id>-2': from -> first via edge, first -> second via edge,
             second via edge -> to. Only the first two via edges are used.

    Example of a resulting trip once written as XML:
    <trip id="6000055-000000" depart="23.55" from="691267087" to="47271299#0" departLane="free" departSpeed="max"/>
    """
    global e2e_s_dict
    tr_id, tr_depart, tr_from, tr_to, tr_via, departLane, departSpeed, tr_stop = pTripSr

    # no via edges: the trip is passed through unchanged
    if pd.isna(tr_via):
        return [[tr_id, tr_depart, tr_from, tr_to, 0, 0, departLane, departSpeed]]

    via_edges = tr_via.split(' ')
    # stop positions on the two via edges, as returned by getEdgeMidPos and
    # scaled by 100 for the cost arithmetic below
    # (presumably metres -> centimetres, as the variable naming suggests - TODO confirm)
    mid_a_m = getEdgeMidPos(via_edges[0])
    mid_b_m = getEdgeMidPos(via_edges[1])
    mid_a_cm, mid_b_cm = int(mid_a_m * 100), int(mid_b_m * 100)
    speed_cms = e2e_s_dict['avg_speed_cms']
    stop_a, stop_b = via_edges[0], via_edges[1]

    # leg 0: origin edge -> stop position on the first via edge
    leg_0 = [tr_id + '-0', tr_depart, tr_from, stop_a, 0, mid_a_m, departLane, departSpeed]

    # leg 1 departs after the estimated travel time of leg 0 plus a fixed 30
    # (presumably seconds of waiting at the stop - TODO confirm)
    cost_0 = getFrToCostId(tr_from, stop_a, pIncludeFT=True) - mid_a_cm
    depart_1 = round(tr_depart + (cost_0 / speed_cms) + 30, 2)
    leg_1 = [tr_id + '-1', depart_1, stop_a, stop_b, mid_a_m, mid_b_m, departLane, departSpeed]

    # leg 2 is scheduled the same way, relative to the departure of leg 1
    cost_1 = getFrToCostId(stop_a, stop_b, pIncludeFT=True) - mid_a_cm - mid_b_cm
    depart_2 = round(depart_1 + (cost_1 / speed_cms) + 30, 2)
    leg_2 = [tr_id + '-2', depart_2, stop_b, tr_to, mid_b_m, 0, departLane, departSpeed]

    return [leg_0, leg_1, leg_2]
    
def convertFusedRoute2Nonstop(pDf):
    """Convert a DataFrame of fused trips into a DataFrame of nonstop trips.

    Every row of pDf is expanded with getNonstopTrips; the resulting rows
    are collected and sorted by departure time.

    pDf:    DataFrame with the columns
            ['id', 'depart', 'from', 'to', 'via', 'departLane', 'departSpeed', 'stop'].
    return: DataFrame with the columns
            ['id', 'depart', 'from', 'to', 'departPos', 'arrivalPos', 'departLane', 'departSpeed'],
            sorted ascending by 'depart' with a fresh 0..n-1 index.
            An input without rows yields an empty frame with these columns.
    """
    out_cols = ['id', 'depart', 'from', 'to', 'departPos', 'arrivalPos', 'departLane', 'departSpeed']
    cnt_len = len(pDf)
    cnt_mod = get_mod(cnt_len, 10) # progress reporting interval
    out = []
    for i in range(cnt_len):
        if i % cnt_mod == 0:
            print('{}\t{}'.format(i, cnt_len))
        out.extend(getNonstopTrips(pDf.iloc[i]))

    # The column names are given to the constructor rather than assigned
    # afterwards: assigning 8 names to the column-less frame built from an
    # empty list raises a length-mismatch ValueError.
    tdf = pd.DataFrame(out, columns=out_cols).convert_dtypes()
    tdf.sort_values('depart', ascending=True, inplace=True, ignore_index=True)

    return tdf

def getNsTripXmlStr(pTripNsSr):
    """Format one nonstop trip as a <trip .../> XML element string.

    pTripNsSr: sequence of 8 values
               (id, depart, from, to, departPos, arrivalPos, departLane, departSpeed).
    return:    the XML element as a string. The departPos / arrivalPos
               attributes are only emitted when the value is greater than 5;
               the string is passed through gu.removeDblSpaces before it is returned.
    """
    min_pos = 5  # positions up to this value are left out of the XML
    tr_id, depart, tr_from, tr_to, departPos, arrivalPos, departLane, departSpeed = pTripNsSr

    dep_attr = ' departPos="{}" '.format(departPos) if departPos > min_pos else ''
    arv_attr = ' arrivalPos="{}" '.format(arrivalPos) if arrivalPos > min_pos else ''

    # The placeholders of omitted attributes leave runs of blanks behind;
    # presumably gu.removeDblSpaces collapses them - verify in General_Utils.
    trip_tmpl = '<trip id="{}" depart="{}" from="{}" to="{}" {} {} departLane="{}" departSpeed="{}"/>'
    raw_xml = trip_tmpl.format(tr_id, depart, tr_from, tr_to, dep_attr, arv_attr, departLane, departSpeed)
    return gu.removeDblSpaces(raw_xml)

def nonstopTrips2XmlStr(pDfTripsXml):
    """Build the complete XML route file text for a DataFrame of nonstop trips.

    pDfTripsXml: DataFrame whose rows are accepted by getNsTripXmlStr.
    return:      the XML document as one string: XML declaration, a
                 'generated on' comment carrying the current time stamp, the
                 <routes> element with one tab-indented <trip> line per row,
                 and a trailing newline.
    """
    time_stamp = gu.getCurrentDateTime().replace('T', ' ')
    header = [
        '<?xml version="1.0" encoding="UTF-8"?>',
        '<!-- generated on {} by Heuristic Trip Fusion Algorithm (Via Stops to Nonstop) Version 4.0 -->'.format(time_stamp),
        '<routes xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="http://sumo.dlr.de/xsd/routes_file.xsd">',
    ]
    trip_lines = [getNsTripXmlStr(pDfTripsXml.iloc[row]) for row in range(len(pDfTripsXml))]

    xml_text = '\n'.join(header + trip_lines + ['</routes>'])
    # indent the trip elements inside <routes>
    xml_text = xml_text.replace('<trip id', '\t<trip id')
    return xml_text + '\n'

def mainFused2NonStop(pSim_Scn, pInScnNdx, pOutScnNdx):
    """Convert the fused route files of a scenario into nonstop route files.

    For every run folder of the input scenario the fused route file is read,
    converted with convertFusedRoute2Nonstop and written as XML to the
    matching run folder of the output scenario.

    pSim_Scn:   list of scenario entries; position 2 of an entry is the
                route file name.
    pInScnNdx:  index in pSim_Scn of the scenario to read.
    pOutScnNdx: index in pSim_Scn of the scenario to write.
    return:     list with one row per run:
                [run string of the output folder, number of input trips,
                 number of nonstop trips].
    """
    global e2e_s_dict
    src_dirs = getScnDirs(e2e_s_dict, pSim_Scn, pInScnNdx)
    dst_dirs = make_get_ScnDirs(e2e_s_dict, pSim_Scn, pOutScnNdx, verbose= 1)
    src_file = pSim_Scn[pInScnNdx][2]
    dst_file = pSim_Scn[pOutScnNdx][2]

    summary = []
    for run_no, src_dir in enumerate(src_dirs):
        dst_dir = dst_dirs[run_no]
        print('Processing: {}'.format(getRunStr(src_dir)))

        fused_trips = readFusedRouteFile(src_dir, src_file)
        nonstop_trips = convertFusedRoute2Nonstop(fused_trips)
        xml_text = nonstopTrips2XmlStr(nonstop_trips)

        stat_row = [getRunStr(dst_dir), len(fused_trips), len(nonstop_trips)]
        summary.append(stat_row)
        print('Writing: {}\t in_trips: {}, out_trips: {}'.format(*stat_row))
        gu.FileSave(os.path.join(dst_dir, dst_file), xml_text, wMode = 'w')

    return summary

if __name__ == '__main__':
    # This file performs the trip fusion.
    # Call pattern: (scenario_list, input_scenario_index, output_scenario_index)
    
    # algorithm 1: reads scenario 0, writes scenario 1
    out_stat = mergeTripsAlg1(sim_scen, 0, 1)
    
    # algorithm 2: reads scenario 0, writes scenario 2
    out_stat = mergeTripsAlg2(sim_scen, 0, 2)

    # algorithm 3: reads scenario 2 (the output index passed to algorithm 2
    # above) and writes scenario 3, so it has to run after algorithm 2
    out_stat = mainFused2NonStop(sim_scen, 2, 3)

