# -*- coding: utf-8 -*-
"""

MIT License, for details please read LICENSE.TXT

Copyright (c) 2022 Mahtab Sharifi

Created on Thu Mar 24 16:54:04 2022

"""

# Emitted before the imports below, so it is the first thing shown when this
# module starts executing (on direct run and on any import of the module).
print("Initializing ...")

import os
from multiprocessing import Pool
from datetime import timedelta

import numpy as np
import pandas as pd

from fnmatch import fnmatch
from bisect import bisect_left

import mht_sumolib as mtSumoLib
import General_Utils as gu
from pytictoc import TicToc

''' 
    ***************************************************************************
    
    init part
    
    ***************************************************************************
'''

# Project-wide parameters; everything below is executed at import time.
import base_parameters 
prms_dict = base_parameters.prms_dict


# Defaults specific to the edge-to-edge computation.
e2e_s_dict = {
       
        # File (inside the e2e folder) holding the sorted edge ids.
        'edg2edg_file' : 'rtm_edg_ids.txt',        
        # Width of each cost-matrix row and total number of rows to chunk.
        # NOTE(review): presumably must equal len(edge_ids) for the loaded
        # network -- it is hard-coded here, confirm against the net file.
        'edges_count' : 5024,
        # Number of matrix rows per output chunk file.
        'chunk_size' : 200,
        
        # After this much computing time a cool-down pause is triggered ...
        'sleep_threshold' : 300, # seconds
        # ... of (about) this length; both are read in calcChunkRows().
        'sleep_duration' : 60, # seconds
        }

# Values from base_parameters take precedence over the defaults above.
e2e_s_dict.update(prms_dict)


# The merged dict must provide 'base_path', 'net_path' and 'net_file' here;
# other functions in this file additionally read 'e2e_path',
# 'e2e_file_prefix', 'e2e_file_suffix' and 'edg_len_file'.
net_path = os.path.join(e2e_s_dict['base_path'], e2e_s_dict['net_path'])
net_file = e2e_s_dict['net_file']
edg2edg_file = e2e_s_dict['edg2edg_file']
# Module-level network object used as a global by the routing functions.
sumo_net = mtSumoLib.SumoNetLib(os.path.join(net_path, net_file))

#if not 'sumo_net' in vars():
#    print('Loading Sumo Net...')
#    import mht_sumolib as mtSumoLib
#    sumo_net = mtSumoLib.SumoNetLib(os.path.join(net_path, net_file))
#    e2e_s_dict['sumo_net'] = sumo_net
##else:
##    e2e_s_dict['sumo_net'] = sumo_net

def getNetPath(pPrmsDict):
    """Return the network folder: ``base_path`` joined with ``net_path``.

    Both keys are read from ``pPrmsDict``; a missing key raises ``KeyError``.
    """
    base_dir = pPrmsDict['base_path']
    net_dir = pPrmsDict['net_path']
    return os.path.join(base_dir, net_dir)

''' 
    ***************************************************************************
    
    chunk & os  Part
    
    ***************************************************************************
'''

def getChunkStEn(pSize, pCunkSize, pThreshold=0.25, pAddRemains=False):
    """Split ``pSize`` rows into consecutive ``[start, end]`` index pairs.

    ``end`` is exclusive and each pair starts where the previous one ended,
    so the pairs cover ``0 .. pSize`` without gaps.

    Parameters
    ----------
    pSize : int
        Total number of rows to split.
    pCunkSize : int
        Nominal number of rows per chunk.
    pThreshold : float
        Only used when ``pAddRemains`` is false.  If the leftover rows
        (``pSize % pCunkSize``) are more than this fraction of a chunk, they
        are spread one by one over the chunks, starting from the last chunk;
        otherwise they are all appended to the last chunk.
    pAddRemains : bool
        If true, leftover rows form an additional (shorter) final chunk.

    Returns
    -------
    list[list[int]]
        ``[[start, end], ...]`` made of plain Python ints.

    Raises
    ------
    ValueError
        If more than one chunk is needed and ``pCunkSize`` is not positive.
    """
    if pSize <= pCunkSize:
        return [[0, pSize]]
    if pCunkSize <= 0:
        raise ValueError('pCunkSize must be positive, got {}'.format(pCunkSize))

    # Integer arithmetic only: no float rounding for big sizes and no numpy
    # scalars (with their fixed width) leaking into the returned lists.
    chunkCnt = pSize // pCunkSize
    remainder = pSize - chunkCnt * pCunkSize

    extras = [0] * chunkCnt
    if remainder / pCunkSize > pThreshold and not pAddRemains:
        # Round-robin distribution starting at the last chunk: every chunk
        # gets `base` extra rows and the last `surplus` chunks one more.
        base, surplus = divmod(remainder, chunkCnt)
        first_bigger = chunkCnt - surplus
        extras = [base + (1 if k >= first_bigger else 0) for k in range(chunkCnt)]

    out, st = [], 0
    for extra in extras:
        en = st + pCunkSize + extra
        out.append([st, en])
        st = en

    if remainder > 0 and pAddRemains:
        out.append([st, pSize])
    else:
        # Absorb any rows not distributed above into the last chunk.
        out[-1][1] = pSize
    return out

def getE2Epath(pPrmsDict):
    """Return the edge-to-edge output folder.

    The folder is ``base_path / net_path / e2e_path``, all three taken from
    ``pPrmsDict``; a missing key raises ``KeyError``.
    """
    parts = (
        pPrmsDict['base_path'],
        pPrmsDict['net_path'],
        pPrmsDict['e2e_path'],
    )
    return os.path.join(*parts)

def makeE2Efolder(pPrmsDict, verbose=False):
    """Make sure the edge-to-edge output folder exists and return its path.

    The folder (see ``getE2Epath``) is created with ``os.mkdir`` when it is
    missing, so its parent must already exist.  ``verbose`` is accepted but
    not used.

    Raises
    ------
    Exception
        If the path exists but is a regular file.
    """
    target = getE2Epath(pPrmsDict)
    if os.path.exists(target):
        # An existing directory is fine; a file in its place is not.
        if os.path.isfile(target):
            raise Exception("Can't create: {}".format(target))
        return target
    os.mkdir(target)
    print('Directory created: {}'.format(target))
    return target

def scanE2Efolder(pPrmsDict):
    """List the chunk files already present in the edge-to-edge folder.

    A directory entry is reported when its upper-cased name matches the
    upper-cased pattern ``<e2e_file_prefix>*<e2e_file_suffix>``.  Names are
    returned with their original spelling, in ``os.listdir`` order.
    """
    folder = getE2Epath(pPrmsDict)
    prefix = pPrmsDict['e2e_file_prefix']
    suffix = pPrmsDict['e2e_file_suffix']
    pattern = '{}*{}'.format(prefix, suffix).upper() if False else (prefix + '*' + suffix).upper()
    return [name for name in os.listdir(folder) if fnmatch(name.upper(), pattern)]

def setE2Efilenames(pPrmsDict, pChunksStEn):
    """Build the chunk table ``[start, end, processed, file_name]``.

    ``processed`` is initialised to ``False`` for every chunk.  The file name
    is ``<prefix><start>_<end><suffix>`` where both numbers are zero-padded to
    the number of digits of the last chunk's end index.
    """
    width = len(str(pChunksStEn[-1][1]))
    prefix = pPrmsDict['e2e_file_prefix']
    suffix = pPrmsDict['e2e_file_suffix']
    pad_spec = '0{}'.format(width)

    table = []
    for bounds in pChunksStEn:
        start, end = bounds[0], bounds[1]
        file_name = '{}{}_{}{}'.format(
            prefix, format(start, pad_spec), format(end, pad_spec), suffix)
        table.append([start, end, False, file_name])
    return table

def createE2Elist(pPrmsDict, pChunksStEn):
    """Return the chunk table with ``processed`` set for existing files.

    The table comes from ``setE2Efilenames``; a row is flagged ``True`` when
    its file name is among the names reported by ``scanE2Efolder``.
    """
    chunk_table = setE2Efilenames(pPrmsDict, pChunksStEn)
    existing = set(scanE2Efolder(pPrmsDict))
    for row in chunk_table:
        if row[3] in existing:
            row[2] = True
    return chunk_table

''' 
    ***************************************************************************
    
    creating / loading edge_ids - edge_ids is a sorted list for binary search < 1 micro-second
    
    ***************************************************************************
'''
def createEdges(pPrmsDict):
    """Return the sorted ids of all non-internal edges of the global net.

    ``pPrmsDict`` is not used; the edges are read from the module-level
    ``sumo_net``.  The ids are sorted through a numpy array and returned as a
    plain list.
    """
    print('Creating edg_ids ...')
    raw_ids = np.array(
        [edge.getID() for edge in sumo_net.net.getEdges(withInternal=False)])
    order = raw_ids.argsort()
    sorted_frame = pd.DataFrame(raw_ids[order])
    return sorted_frame[0].tolist()

def writeEdges(pPrmsDict):
    """Create the sorted edge-id list, save it, and return it.

    The ids are written with ``gu.write_DF_csv`` (no header) to the file
    named by ``pPrmsDict['edg2edg_file']`` inside the e2e folder.
    """
    ids_frame = pd.DataFrame(createEdges(pPrmsDict))
    out_dir = getE2Epath(pPrmsDict)
    target = os.path.join(out_dir, pPrmsDict['edg2edg_file'])
    gu.write_DF_csv(target, ids_frame, pHeader=None)
    return ids_frame[0].tolist()

def readEdges(pPrmsDict):
    """Load the edge-id list from the e2e folder and return it as a list.

    Reads the header-less file ``pPrmsDict['edg2edg_file']`` through
    ``gu.read_DF_csv`` and returns its first column.
    """
    print('Loading edge_ids ...')
    source = os.path.join(getE2Epath(pPrmsDict), pPrmsDict['edg2edg_file'])
    ids_frame = gu.read_DF_csv(source, pHeader=None)
    return ids_frame[0].tolist()

# Import-time bootstrap of the global `edge_ids` list used by the routing
# functions: load it from the cached file when that file exists, otherwise
# build it from the network and write the cache.
#if not 'edge_ids' in vars():
fn_file = os.path.join(getE2Epath(e2e_s_dict), e2e_s_dict['edg2edg_file'])
if os.path.exists(fn_file) and os.path.isfile(fn_file):
    edge_ids = readEdges(e2e_s_dict)
else:
    # The output folder may not exist yet on a first run.
    makeE2Efolder(e2e_s_dict)
    # NOTE(review): this branch creates the ids (writeEdges -> createEdges
    # prints 'Creating edg_ids ...' right after), so the message is misleading.
    print('Loading edg_ids ...')
    edge_ids = writeEdges(e2e_s_dict)


''' 
    ***************************************************************************
    
    edge_ndx to edge_id Part
    
    ***************************************************************************
'''
def getEdgNdx(pEdgSortedList, pEdgID):
    """Return the index of the leftmost entry equal to ``pEdgID``.

    ``pEdgSortedList`` must be sorted, because the lookup is a binary search.

    Raises
    ------
    ValueError
        If ``pEdgID`` is not in the list.
    """
    pos = bisect_left(pEdgSortedList, pEdgID)
    found = pos < len(pEdgSortedList) and pEdgSortedList[pos] == pEdgID
    if not found:
        raise ValueError
    return pos
    
    
''' 
    ***************************************************************************
    
    creating / loading edge_lengths - edge_lengths are listed in the same order as edge_ids
    
    ***************************************************************************
'''
def createEdgeLengths(pPrmsDict):
    """Return one integer length per entry of the global ``edge_ids``.

    ``pPrmsDict`` is not used.  For every edge id the second element returned
    by ``sumo_net.getShortestRoute(id, id)`` is taken, 0.5 is added, the
    result is truncated to int and then multiplied by 100.
    NOTE(review): the value is rounded *before* the multiplication, so the
    result is always a multiple of 100 -- confirm this is the intended unit.
    """
    print('Creating edg_lens ...')
    router = sumo_net
    lengths = []
    for edge_id in edge_ids:
        self_route = router.getShortestRoute(edge_id, edge_id)
        lengths.append(int(self_route[1] + 0.5) * 100)
    return lengths

def writeEdgeLengths(pPrmsDict):
    """Compute the edge lengths, save them, and return them as a list.

    The values are written with ``gu.write_DF_csv`` (no header) to the file
    named by ``pPrmsDict['edg_len_file']`` inside the e2e folder.
    """
    lens_frame = pd.DataFrame(createEdgeLengths(pPrmsDict))
    out_dir = getE2Epath(pPrmsDict)
    target = os.path.join(out_dir, pPrmsDict['edg_len_file'])
    gu.write_DF_csv(target, lens_frame, pHeader=None)
    return lens_frame[0].tolist()

def readEdgeLengths(pPrmsDict):
    """Load the edge lengths from the e2e folder and return them as a list.

    Reads the header-less file ``pPrmsDict['edg_len_file']`` through
    ``gu.read_DF_csv`` and returns its first column.
    """
    print('Loading edge_lens ...')
    source = os.path.join(getE2Epath(pPrmsDict), pPrmsDict['edg_len_file'])
    lens_frame = gu.read_DF_csv(source, pHeader=None)
    return lens_frame[0].tolist()

# Import-time bootstrap of the global `edge_lens` list.  The guard skips the
# work when a module-level `edge_lens` already exists (at module scope
# vars() is the module namespace).
# NOTE(review): nothing in this file defines `edge_lens` earlier, so on a
# fresh run the guard is always true -- presumably kept for re-running the
# script inside an interactive session; confirm.
if not 'edge_lens' in vars():
    fn_file = os.path.join(getE2Epath(e2e_s_dict), e2e_s_dict['edg_len_file'])
    if os.path.exists(fn_file) and os.path.isfile(fn_file):
        edge_lens = readEdgeLengths(e2e_s_dict)
    else:
        # The output folder may not exist yet on a first run.
        makeE2Efolder(e2e_s_dict)
        # NOTE(review): this branch computes the lengths (createEdgeLengths
        # prints 'Creating edg_lens ...' right after); the message is misleading.
        print('Loading edg_lens ...')
        edge_lens   = writeEdgeLengths(e2e_s_dict)


''' 
    ***************************************************************************
    
    sumo net getShortestRoute Part
    
    ***************************************************************************
'''
def getCostFrTo(pEdgODndx):
    """Return the routing cost between two edges given by index.

    Takes a single argument so it can be mapped over a list of pairs by a
    worker pool.

    Parameters
    ----------
    pEdgODndx : sequence of two ints
        ``(origin_index, destination_index)`` into the global ``edge_ids``.

    Returns
    -------
    int
        ``0`` when both indices are equal; ``-1`` when the indices are out of
        range, routing fails, or the returned cost cannot be converted to an
        int; otherwise ``int(cost + 0.5) * 100``.
        NOTE(review): rounding happens before the multiplication, so results
        are multiples of 100 -- confirm this is the intended unit.
    """
    pEdgOndx, pEdgDndx = pEdgODndx
    # The diagonal needs no routing (and no access to the network).
    if pEdgOndx == pEdgDndx:
        return 0

    pSumoNet, pEdgIds = sumo_net, edge_ids
    try:
        t_orgn = pEdgIds[pEdgOndx]
        t_dest = pEdgIds[pEdgDndx]
        _, cost = pSumoNet.getShortestRoute(t_orgn, t_dest)
        # Converted inside the try so that an unusable cost (e.g. None or a
        # non-finite value for an unreachable pair) also yields -1.
        return int(cost + 0.5) * 100
    except Exception:
        # `Exception` rather than a bare except: KeyboardInterrupt and
        # SystemExit must still be able to stop a worker.
        return -1


''' 
    ***************************************************************************
    
    calc chunk & write to file Part
    
    ***************************************************************************
'''
def getChunkNdxTask(pChunks):
    """Return the index of the first chunk not yet processed, or ``-1``.

    Parameters
    ----------
    pChunks : list
        Chunk table rows ``[start, end, processed, file_name]``.

    Returns
    -------
    int
        Index of the first row whose ``processed`` flag equals ``False``;
        ``-1`` when every row is processed or the table is empty.

    Raises
    ------
    IndexError, TypeError
        If a row is malformed.  These are deliberately not swallowed: a
        broken table must not be reported as "all chunks done".
    """
    flags = [chunk[2] for chunk in pChunks]
    try:
        return flags.index(False)
    except ValueError:
        # No False flag left: nothing to do.
        return -1

def getODlist(pOndx,pSize):
    """Return ``[[pOndx, 0], [pOndx, 1], ..., [pOndx, pSize - 1]]``.

    These are the origin/destination index pairs of one cost-matrix row.
    """
    pairs = []
    for dest_ndx in range(pSize):
        pairs.append([pOndx, dest_ndx])
    return pairs

def _formatDuration(pSeconds):
    """Format a duration in seconds as ``H:MM:SS.cc`` (two decimals)."""
    text = str(timedelta(seconds=pSeconds))
    # str(timedelta) only has a 6-digit fraction when microseconds are
    # non-zero; cutting 4 characters blindly would mangle '0:00:05' to '0:0'.
    if '.' in text:
        return text[:-4]
    return text + '.00'


def printProgress(pStart, pEnd, pCurrent, pTicToc):
    """Print one progress line with elapsed, average and remaining time.

    Parameters
    ----------
    pStart : int
        First row of the chunk (accepted for symmetry, not used).
    pEnd : int
        End row of the chunk.
    pCurrent : int
        Row that has just been computed.
    pTicToc : object
        Timer exposing ``tocstate()`` that returns
        ``(elapsed, last, average)`` in seconds.
        NOTE(review): the published pytictoc ``TicToc`` documents only
        ``tic``/``toc``/``tocvalue`` -- confirm the installed version
        provides ``tocstate``.

    The remaining time is estimated as ``(pEnd - pCurrent) * average``.
    """
    t_elapsed, _, t_average = pTicToc.tocstate()
    t_remains = (pEnd - pCurrent) * t_average
    s_elapsed = _formatDuration(t_elapsed)
    s_average = _formatDuration(t_average)
    s_remains = _formatDuration(t_remains)
    print('{} - {}, Elapsed: {}, Avg: {}, ETR: {}'.format(pCurrent, pEnd, s_elapsed, s_average, s_remains))



def calcChunkRows(pPrmsDict, pChunk):
    """Compute one chunk of cost-matrix rows and write it to its file.

    Every row ``r`` in ``[row_st, row_en)`` is filled with
    ``getCostFrTo([r, c])`` for all ``c`` in ``range(edges_count)``, computed
    in parallel by a process pool with ``os.cpu_count()`` workers.  After
    ``sleep_threshold`` seconds of work a pause is requested through
    ``gu.msgSleep`` using ``sleep_duration``.

    Parameters
    ----------
    pPrmsDict : dict
        Must provide ``edges_count``, ``sleep_threshold``, ``sleep_duration``
        and the keys needed by ``getE2Epath``.
    pChunk : sequence
        ``[row_st, row_en, processed, file_name]``.

    Returns
    -------
    numpy.ndarray
        ``int32`` array of shape ``(row_en - row_st, edges_count)``; the same
        data is written raw to ``file_name`` in the e2e folder via ``tofile``.

    Raises
    ------
    BaseException
        Any failure during the computation is re-raised after the error
        message is printed.  No chunk file is written in that case, so an
        incomplete chunk can never be mistaken for a finished one by a
        later run.
    """
    row_st, row_en, _, fn_chunk = pChunk
    print('Starting calculation rows: {} - {}'.format(row_st, row_en))
    t_size = pPrmsDict['edges_count']
    # Only this chunk's rows are stored; a full square matrix is not needed.
    chunk_rows = np.zeros((max(row_en - row_st, 0), t_size), dtype=np.int32)

    tic_toc = TicToc() #create instance of class
    tic_toc.tic() #Start timer
    t_sleep_cntr = 0
    try:
        # The context manager shuts the pool down on normal exit and on error.
        with Pool(os.cpu_count()) as t_pool:
            for row_ndx in range(row_st, row_en):
                todo_row = getODlist(row_ndx, t_size)
                chunk_rows[row_ndx - row_st] = t_pool.map(getCostFrTo, todo_row)
                printProgress(row_st, row_en, row_ndx, tic_toc)
                if tic_toc.tocvalue() - t_sleep_cntr >= pPrmsDict['sleep_threshold']:
                    gu.msgSleep('Cooling down for ', pWait = pPrmsDict['sleep_duration'], pVariance = 0.55)
                    t_sleep_cntr = tic_toc.tocvalue()
    except BaseException:
        print("mht_edge2edge_turbo.py > calcChunkRows() : Something went wrong.")
        raise

    # Write into the folder described by the parameters that were passed in.
    fpfn_chunk = os.path.join(getE2Epath(pPrmsDict), fn_chunk)
    chunk_rows.tofile(fpfn_chunk)

    print('File {} is created.'.format(fn_chunk), end='\t')
    tic_toc.toc() #Stop timer
    return chunk_rows

''' 
    ***************************************************************************
    
    run Part
    
    ***************************************************************************
'''
    
def main_run():
    """Compute the first chunk that has no output file yet.

    The chunk table rows are ``[start_row, end_row, processed, file_name]``.
    At most one chunk is computed per call; when every chunk already has its
    file, a message is printed and nothing else happens.
    """
    chunk_bounds = getChunkStEn(
        e2e_s_dict['edges_count'],
        e2e_s_dict['chunk_size'],
        pThreshold=0.75,
        pAddRemains=False,
    )
    chunk_table = createE2Elist(e2e_s_dict, chunk_bounds)
    pending_ndx = getChunkNdxTask(chunk_table)
    if pending_ndx != -1:
        calcChunkRows(e2e_s_dict, chunk_table[pending_ndx])
        return
    print('All Chunks are done. Noting to do...')

# Start the computation only when the file is executed as a script, not when
# it is imported (calcChunkRows creates a multiprocessing Pool).
if __name__ == '__main__':
    main_run()

    