# -*- coding: utf-8 -*-
"""

MIT License, for details please read LICENSE.TXT

Copyright (c) 2022 Mahtab Sharifi

Created on Fri Jul 22 08:06:56 2022

"""


import os, subprocess
import General_Utils as gu
import pandas as pd
import numpy as np


''' 
    ***************************************************************************
    
    init Part
    
    ***************************************************************************
'''

# Project-local configuration module; it provides the two objects bound below.
import base_parameters 
# Parameter dictionary. Keys read in this file: 'base_path', 'net_path',
# 'run_start', 'run_end', 'fn_edges', 'fn_nodes' and 'fn_py_grid'.
prms_dict = base_parameters.prms_dict

# Indexable collection of scenarios; each entry is unpacked in this file as
# (scn_path, run_path, rout_file, aggregate_path).
sim_scen = base_parameters.sim_scen

def getNetFullPath(pPrmsDict):
    """Return the network folder path.

    The path is ``pPrmsDict['base_path']`` joined with
    ``pPrmsDict['net_path']``.
    """
    return os.path.join(pPrmsDict['base_path'], pPrmsDict['net_path'])

def getRunStr(pStr):
    """Return the part of ``pStr`` that follows its last path separator.

    Both ``\\`` and ``/`` are treated as separators, so the result is the same
    whether the path was built on Windows or on a POSIX system. A string
    without any separator is returned unchanged; a string ending in a
    separator yields ``''``.
    """
    # rfind returns -1 when the separator is absent, so "+ 1" yields index 0.
    ndx = max(pStr.rfind('\\'), pStr.rfind('/')) + 1
    return pStr[ndx:]


def getScnFolders(pPrmsDict, pSim_Scn, pSimNdx):
    """List the run folders of one scenario.

    The scenario entry ``pSim_Scn[pSimNdx]`` must unpack into four items; only
    the scenario folder and the run-folder prefix are used here. One path is
    produced per run number from ``pPrmsDict['run_start']`` to
    ``pPrmsDict['run_end']`` (inclusive), formed as
    ``base_path/scn_path/<prefix><run number zero-padded to 3 digits>``.

    Returns a list of path strings; the folders are not checked for existence.
    """
    scn_dir, run_prefix, _rout_file, _aggregate_path = pSim_Scn[pSimNdx]
    base_dir = pPrmsDict['base_path']
    first_run = pPrmsDict['run_start']
    last_run = pPrmsDict['run_end']
    return [
        os.path.join(base_dir, scn_dir, '{}{:0>3d}'.format(run_prefix, run_no))
        for run_no in range(first_run, last_run + 1)
    ]

def getCsvDf(pFilePath, sep=','):
    """Read a delimited text file into a DataFrame.

    The file is parsed with ``pd.read_csv`` using ``sep`` as the delimiter and
    the resulting columns are then passed through ``convert_dtypes``.
    """
    return pd.read_csv(pFilePath, sep=sep).convert_dtypes()

''' 
    ***************************************************************************
    
    emission XML to csv Part
    
    ***************************************************************************
'''

def emmXml2Csv(pPrmsDict, pSim_Scn, pSimNdx):
    """Convert each run's ``o_emissions.xml`` to ``o_emissions.csv``.

    For every run folder of scenario ``pSimNdx`` that contains an
    ``o_emissions.xml`` file, the ``xml2csv.py`` tool located under
    ``$SUMO_HOME/tools/xml`` is executed with ``python``.

    The tool is started from an argument list rather than a single command
    string, so paths containing spaces (including ``SUMO_HOME`` itself) are
    passed intact and the call also works where a string command is not
    accepted without a shell.

    Returns a list with one human-readable command line per converted run.
    Raises ``KeyError`` if the ``SUMO_HOME`` environment variable is not set.
    """
    t_run_dirs = getScnFolders(pPrmsDict, pSim_Scn, pSimNdx)
    py_cmd = os.path.join(os.environ['SUMO_HOME'], 'tools', 'xml', 'xml2csv.py')
    t_out = []
    for o_dir in t_run_dirs:
        in_file = os.path.join(o_dir, 'o_emissions.xml')
        if not os.path.isfile(in_file):
            # Runs without emission output are skipped silently.
            continue
        print('emm xml to csv: {}'.format(getRunStr(o_dir)))
        out_file = os.path.join(o_dir, 'o_emissions.csv')
        # The string form is kept only for reporting to the caller.
        t_out.append('python "{}" "{}" -o "{}"'.format(py_cmd, in_file, out_file))
        # The exit code of the tool is not evaluated (best effort).
        subprocess.call(['python', py_cmd, in_file, '-o', out_file])
    return t_out


''' 
    ***************************************************************************
    
    edge to TAZ share
    
    ***************************************************************************
'''
    
def calc_edge_taz_share():
    """Run ``getEdgeTazShare`` on the configured network files.

    The edge, node and grid file names are taken from the module-level
    ``prms_dict`` (keys ``fn_edges``, ``fn_nodes``, ``fn_py_grid``) and
    resolved inside the network folder.

    Returns a two-element list ``[taz_ids, edg_taz_df]`` holding the two
    values produced by ``getEdgeTazShare``.
    """
    from mht_edge_taz_share import getEdgeTazShare

    def in_net_dir(key):
        # Resolve one configured file name inside the network folder.
        return os.path.join(getNetFullPath(prms_dict), prms_dict[key])

    edges_file = in_net_dir('fn_edges')
    nodes_file = in_net_dir('fn_nodes')
    grid_file = in_net_dir('fn_py_grid')
    taz_ids, edg_taz_df = getEdgeTazShare(edges_file, nodes_file, grid_file)
    return [taz_ids, edg_taz_df]



''' 
    ***************************************************************************
    
    emission aggregation
    aggregating the total emission per simulation run per TAZ
    
    ***************************************************************************
'''

def getTazEmmRun(pTazIds, pEdegTazDf, pEmmEdgDf):
    """Distribute per-edge ``edge_CO2_abs`` values onto TAZs for one run.

    Parameters
    ----------
    pTazIds : sequence of TAZ ids; its order defines the output positions.
        Ids must be hashable (NOTE(review): confirm against the ids produced
        by getEdgeTazShare).
    pEdegTazDf : DataFrame with columns ``edge_id``, ``taz_ids`` and
        ``taz_share``; the latter two hold, per edge, parallel sequences of
        TAZ ids and the share of the edge assigned to each of them.
    pEmmEdgDf : DataFrame with at least ``edge_id`` and ``edge_CO2_abs``.

    Returns
    -------
    numpy.ndarray of float64, one entry per element of ``pTazIds``, holding
    the sum of ``share * edge_CO2_abs`` over all edges present in both frames.

    Raises
    ------
    ValueError if an edge refers to a TAZ id that is not in ``pTazIds``.
    """
    t_arr = np.zeros(len(pTazIds), dtype=np.float64)

    # Position lookup built once; setdefault keeps the first occurrence so the
    # result matches a left-to-right search of pTazIds.
    taz_pos = {}
    for pos, tz_id in enumerate(pTazIds):
        taz_pos.setdefault(tz_id, pos)

    emm_df = pEmmEdgDf[['edge_id', 'edge_CO2_abs']]
    # Inner join: edges missing from either frame do not contribute.
    t_edgs = pd.merge(left=pEdegTazDf, right=emm_df, on='edge_id')

    rows = zip(t_edgs['taz_ids'], t_edgs['taz_share'], t_edgs['edge_CO2_abs'])
    for edge_taz_ids, edge_taz_share, edge_co2 in rows:
        for l_ndx, tz_id in enumerate(edge_taz_ids):
            tz_ndx = taz_pos.get(tz_id)
            if tz_ndx is None:
                raise ValueError('{!r} is not in list'.format(tz_id))
            t_arr[tz_ndx] += edge_taz_share[l_ndx] * edge_co2
    return t_arr

 
def emmRunAggTaz(pPrmsDict, pSim_Scn, pSimNdx, pTazIds, pEdegTazDf):
    """Build a TAZ-by-run emission table for one scenario.

    The result starts with a ``taz`` column holding ``pTazIds``. For every run
    folder of scenario ``pSimNdx`` that contains ``o_emissions.csv`` (read
    with ``;`` as separator), one column named after the run folder is added
    with the per-TAZ values computed by ``getTazEmmRun``. Run folders without
    that file are skipped.
    """
    run_dirs = getScnFolders(pPrmsDict, pSim_Scn, pSimNdx)
    odf = pd.DataFrame(pTazIds, columns=['taz'])
    for run_dir in run_dirs:
        csv_file = os.path.join(run_dir, 'o_emissions.csv')
        if not os.path.isfile(csv_file):
            continue
        print('emm csv: {}'.format(getRunStr(run_dir)))
        edge_emissions = getCsvDf(csv_file, sep=';')
        taz_emissions = getTazEmmRun(pTazIds, pEdegTazDf, edge_emissions)
        odf[getRunStr(run_dir)] = taz_emissions
    return odf


def calcEmmTazStatistics(pTazRunEmmDf):
    """Compute per-TAZ statistics across the run columns.

    ``pTazRunEmmDf`` must contain a ``taz`` column; every other numeric column
    is treated as one simulation run.

    The ``taz`` column is removed before aggregating. Without this, numeric
    TAZ ids would pass the ``numeric_only`` filter and be mixed into the
    row-wise sum, mean, standard deviation, minimum and maximum.

    Returns a new DataFrame with the columns ``taz``, ``emm_sum``,
    ``emm_avg``, ``emm_std``, ``emm_min``, ``emm_max`` and ``emm_delta``
    (``emm_max - emm_min``).
    """
    taz_emm_statistics = pTazRunEmmDf[['taz']].copy()
    run_values = pTazRunEmmDf.drop(columns=['taz'])

    taz_emm_statistics['emm_sum'] = run_values.sum(numeric_only=True, axis=1)
    taz_emm_statistics['emm_avg'] = run_values.mean(numeric_only=True, axis=1)
    taz_emm_statistics['emm_std'] = run_values.std(numeric_only=True, axis=1)
    taz_emm_statistics['emm_min'] = run_values.min(numeric_only=True, axis=1)
    taz_emm_statistics['emm_max'] = run_values.max(numeric_only=True, axis=1)
    taz_emm_statistics['emm_delta'] = (
        taz_emm_statistics['emm_max'] - taz_emm_statistics['emm_min']
    )

    return taz_emm_statistics


def saveTazEmm(pPrmsDict, pSim_Scn, pSimNdx, pTazRunEmmDf, pTazStatEmmDf):
    """Write the per-run and statistics tables of one scenario to disk.

    The output folder is ``base_path/scn_path/aggregate_path`` of scenario
    ``pSimNdx``. It is created together with any missing parent folders, so a
    nested ``aggregate_path`` or a not-yet-existing scenario folder does not
    make the call fail.

    Files written (tab-separated, index column labelled ``row``, floats with
    five decimals):
      * ``taz_emm_run.csv``        from ``pTazRunEmmDf``
      * ``taz_emm_statistics.csv`` from ``pTazStatEmmDf``
    """
    scn_path, _run_path, _rout_file, aggregate_path = pSim_Scn[pSimNdx]
    out_path = os.path.join(pPrmsDict['base_path'], scn_path, aggregate_path)
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(out_path, exist_ok=True)

    fn_taz_agg = os.path.join(out_path, 'taz_emm_run.csv')
    pTazRunEmmDf.to_csv(fn_taz_agg, sep='\t', index_label='row', float_format='%.5f')

    fn_taz_stat = os.path.join(out_path, 'taz_emm_statistics.csv')
    pTazStatEmmDf.to_csv(fn_taz_stat, sep='\t', index_label='row', float_format='%.5f')
    
def tripOutputAgg(pPrmsDict, pSim_Scn, pSimNdx):
    """Summarise ``timeLoss`` and speed per run and save the table.

    For every run folder of scenario ``pSimNdx`` containing
    ``o_rtm.trip_output.xml`` one row is produced with:
      * ``run_id``   - name of the run folder
      * ``r_sum``, ``r_mean``, ``r_std``   - statistics of ``timeLoss``
      * ``sp_mean``, ``sp_std``   - statistics of ``routeLength / duration``
      * ``d_0060`` ... ``d_3540``   - histogram of ``timeLoss`` with bin edges
        0, 60, ..., 3540; each column is named after the upper edge of its
        bin. Values above 3540 are not counted.

    The table is written tab-separated to ``timeloss_statistics.csv`` in
    ``base_path/scn_path/aggregate_path`` (created if missing) and returned.
    If no run provides a trip output file, an empty table with the same
    columns is written and returned.
    """
    scn_path, _run_path, _rout_file, aggregate_path = pSim_Scn[pSimNdx]
    t_run_dirs = getScnFolders(pPrmsDict, pSim_Scn, pSimNdx)

    division = list(range(0, 3600, 60))
    # 60 edges give 59 bins, hence the first edge has no column of its own.
    hist_cols = ['d_' + str(x).zfill(4) for x in division[1:]]
    columns = ['run_id', 'r_sum', 'r_mean', 'r_std', 'sp_mean', 'sp_std'] + hist_cols

    out = []
    for o_dir in t_run_dirs:
        in_file = os.path.join(o_dir, 'o_rtm.trip_output.xml')
        if not os.path.isfile(in_file):
            continue
        run_id = getRunStr(o_dir)
        print('reading trip_output: {}'.format(run_id))
        tdf = pd.read_xml(in_file)
        tdf = tdf[['id', 'duration', 'routeLength', 'timeLoss']].copy()
        count, _ = np.histogram(tdf.timeLoss, bins=division)
        t_stat = tdf.timeLoss.agg(['sum', 'mean', 'std']).round(4)
        t_speed = (tdf.routeLength / tdf.duration).agg(['mean', 'std']).round(4)
        out.append([run_id] + t_stat.tolist() + t_speed.tolist() + count.tolist())

    # Passing the labels to the constructor also works for an empty row list,
    # where assigning ``odf.columns`` afterwards would raise a length mismatch.
    odf = pd.DataFrame(out, columns=columns)

    out_path = os.path.join(pPrmsDict['base_path'], scn_path, aggregate_path)
    os.makedirs(out_path, exist_ok=True)
    out_file = os.path.join(out_path, 'timeloss_statistics.csv')
    odf.to_csv(out_file, sep='\t', index_label='row')

    return odf


def emission_agg(prms_dict, sim_scen, taz_ids, edg_taz_df):
    """Run emission and trip aggregation for every scenario.

    For each index of ``sim_scen`` the per-run TAZ emission table is built,
    its statistics are computed, both are saved, and the trip time-loss table
    is produced and saved.

    Returns the time-loss table of the LAST scenario only, or ``None`` when
    ``sim_scen`` is empty.
    """
    # Defined up front so an empty scenario list returns None instead of
    # failing on an unbound local.
    t_agg = None
    for scn_ndx in range(len(sim_scen)):
        taz_emm_agg = emmRunAggTaz(prms_dict, sim_scen, scn_ndx, taz_ids, edg_taz_df)
        taz_emm_stat = calcEmmTazStatistics(taz_emm_agg)
        saveTazEmm(prms_dict, sim_scen, scn_ndx, taz_emm_agg, taz_emm_stat)
        t_agg = tripOutputAgg(prms_dict, sim_scen, scn_ndx)

    return t_agg
    

if __name__ == '__main__':
    # Step 1: convert the o_emissions.xml of every scenario run to CSV.
    # t_x2c keeps the command list of the last scenario processed.
    for scn_ndx in range(len(sim_scen)):
        t_x2c = emmXml2Csv(prms_dict, sim_scen, scn_ndx)

    # Step 2: obtain taz_ids and edg_taz_df (edge-to-TAZ shares) through the
    # shared helper instead of repeating its body here.
    taz_ids, edg_taz_df = calc_edge_taz_share()

    # Step 3: per-TAZ emission aggregation, statistics and trip time loss.
    a0_agg = emission_agg(prms_dict, sim_scen, taz_ids, edg_taz_df)
    