# -*- coding: utf-8 -*-
"""

MIT License, for details please read LICENSE.TXT

Copyright (c) 2022 Mahtab Sharifi

Created on Thu Mar 10 05:16:44 2022

"""
import xml.etree.ElementTree as ETree
import pandas as pd
import numpy as np
from shapely import geometry
import matplotlib.pyplot as plt


''' 
    ***************************************************************************
    
    init Part
    
    ***************************************************************************
'''
def readSimpleXml2Df(pFname):
    """Load an XML file into a DataFrame and let pandas pick the column dtypes.

    Args:
        pFname: Source handed unchanged to ``pandas.read_xml``.

    Returns:
        The parsed DataFrame after ``DataFrame.convert_dtypes()``.
    """
    raw_frame = pd.read_xml(pFname)
    typed_frame = raw_frame.convert_dtypes()
    return typed_frame

def getRoot(fnXML):
    """Parse an XML file and return its root element.

    Args:
        fnXML: File name, path-like object or open file object to parse.
            ``None`` and empty values (for example ``""``) mean "no file".

    Returns:
        The root ``Element`` of the parsed document, or an empty list when
        no file was given.

    Raises:
        OSError: If the file cannot be opened.
        xml.etree.ElementTree.ParseError: If the content is not well-formed XML.
    """
    if fnXML is None:
        return []
    try:
        nothing_given = len(fnXML) == 0
    except TypeError:
        # pathlib.Path and file objects have no len(); they are valid sources
        # for ElementTree.parse, so they must not be rejected here.
        nothing_given = False
    if nothing_given:
        return []
    return ETree.parse(fnXML).getroot()

def strList2Flot(pList):
    """Convert every item of an iterable to ``float``.

    Args:
        pList: Iterable of values accepted by ``float()`` (e.g. numeric strings).

    Returns:
        A new list holding the converted floats, in input order.
    """
    return list(map(float, pList))

def getGeoStr2Flot(pGeoPoint, sep=','):
    """Split one point string such as ``"x,y"`` and return its coordinates.

    Args:
        pGeoPoint: Text with the coordinates of a single point.
        sep: Separator between the coordinates.

    Returns:
        List of floats, one per coordinate.
    """
    return strList2Flot(pGeoPoint.split(sep))

def getShapeStr2Flot(pShapeStr, sep=' '):
    """Parse a shape string such as ``"x1,y1 x2,y2"`` into a coordinate array.

    Points are separated by *sep*; the coordinates inside one point are
    comma-separated. Empty tokens, which appear when the string has leading,
    trailing or repeated separators, are skipped so they do not end up in
    ``float('')``.

    Args:
        pShapeStr: Shape string to parse.
        sep: Separator between points.

    Returns:
        ``numpy`` array with one row per point; an empty array when the
        string holds no points.

    Raises:
        ValueError: If a coordinate is not a valid float literal.
    """
    # str.split with an explicit separator keeps empty strings between
    # consecutive separators, so they have to be filtered out explicitly.
    point_tokens = [token for token in pShapeStr.split(sep) if token.strip()]
    return np.array([getGeoStr2Flot(token) for token in point_tokens])

def getShapeStr2line(pShapeStr, pOffset = 0, sep=' '):
    """Build a shapely ``LineString`` from a shape string.

    Args:
        pShapeStr: Shape string, parsed with ``getShapeStr2Flot``.
        pOffset: Value subtracted from the parsed coordinates before the
            line is created.
        sep: Separator between points in the shape string.

    Returns:
        The resulting ``LineString``.
    """
    coords = getShapeStr2Flot(pShapeStr, sep)
    coords -= pOffset  # in-place shift of the freshly parsed array
    return geometry.LineString(coords)

def getShapeStr2polygon(pShapeStr, pOffset, sep=' '):
    """Build a shapely ``Polygon`` from a shape string.

    Args:
        pShapeStr: Shape string, parsed with ``getShapeStr2Flot``.
        pOffset: Value subtracted from the parsed coordinates before the
            polygon is created.
        sep: Separator between points in the shape string.

    Returns:
        The resulting ``Polygon``.
    """
    ring_coords = getShapeStr2Flot(pShapeStr, sep)
    ring_coords -= pOffset  # in-place shift of the freshly parsed array
    return geometry.Polygon(ring_coords)

def get_gDistrictsDf(pFname, prefix = 'tz_'):
    """Read an XML file into a DataFrame and prefix every column name.

    Args:
        pFname: XML source passed to ``readSimpleXml2Df``.
        prefix: Text put in front of each column name.

    Returns:
        The DataFrame with renamed columns.
    """
    districts = readSimpleXml2Df(pFname)
    renamed = []
    for column_name in districts.columns:
        renamed.append(prefix + column_name)
    districts.columns = renamed
    return districts

def get_taz_polygons(pDf, pOffset):
    """Turn the ``tz_shape`` column into polygons and collect the ``tz_id`` values.

    Args:
        pDf: DataFrame with ``tz_id`` and ``tz_shape`` columns.
        pOffset: Offset forwarded to ``getShapeStr2polygon`` for every shape.

    Returns:
        Two-item list ``[ids, polygons]``: ids as a plain list, polygons as
        the Series produced by ``Series.apply``.
    """
    def _to_polygon(shape_str):
        return getShapeStr2polygon(shape_str, pOffset=pOffset)

    polygons = pDf['tz_shape'].apply(_to_polygon)
    zone_ids = pDf['tz_id'].to_list()
    return [zone_ids, polygons]

def getTazCorners(pTaz):
    """Return the four corners of a polygon's axis-aligned bounding box.

    Args:
        pTaz: Object exposing ``exterior.xy`` as a pair of x and y
            coordinate sequences (e.g. a shapely ``Polygon``).

    Returns:
        Dict with keys ``'bl'``, ``'br'``, ``'tr'``, ``'tl'`` (bottom-left,
        bottom-right, top-right, top-left), each mapped to ``[x, y]``.
    """
    x_values, y_values = pTaz.exterior.xy
    left, right = min(x_values), max(x_values)
    bottom, top = min(y_values), max(y_values)
    return {
        'bl': [left, bottom],
        'br': [right, bottom],
        'tr': [right, top],
        'tl': [left, top],
    }

def offsetTaz(pTazPolys, pOffset):
    """Replace each polygon by its bounding box grown by ``pOffset`` on every side.

    Args:
        pTazPolys: Iterable of polygons exposing ``exterior.xy``.
        pOffset: Distance added around each bounding box.

    Returns:
        List of rectangular ``Polygon`` objects, in input order.
    """
    def _grown_box(poly):
        x_values, y_values = poly.exterior.xy
        left = min(x_values) - pOffset
        bottom = min(y_values) - pOffset
        right = max(x_values) + pOffset
        top = max(y_values) + pOffset
        # Closed ring: the first corner is repeated at the end.
        ring = [(left, bottom), (right, bottom), (right, top), (left, top), (left, bottom)]
        return geometry.Polygon(ring)

    return [_grown_box(poly) for poly in pTazPolys]


''' 
    ***************************************************************************
    
    Calculating share Part
    
    ***************************************************************************
'''
def getEdges(fnEdg, fnNod):
    """Read the edges of a network into a DataFrame.

    Every direct child of the edge file's root tagged ``edge`` becomes one
    row. Edges without a ``shape`` attribute get a straight two-point shape
    built from the ``x``/``y`` attributes of their ``from`` and ``to`` nodes,
    which are looked up among the ``node`` children of the node file's root.

    Args:
        fnEdg: XML file with the edges (read through ``getRoot``).
        fnNod: XML file with the nodes (read through ``getRoot``).

    Returns:
        DataFrame with the columns ``edge_id``, ``edge_from``, ``edge_to``,
        ``edge_priority``, ``edge_type`` (the string ``'None'`` when the edge
        has no ``type``), ``edge_numLanes``, ``edge_speed``, ``edge_shape``
        and ``edge_length`` (rounded to 3 decimals), after ``convert_dtypes()``.

    Raises:
        KeyError: If an edge lacks one of the mandatory attributes, or a
            referenced node lacks ``x``/``y``.
        ValueError: If an edge without shape references a node id that is
            not present in the node file.
    """
    edg_root = getRoot(fnEdg)
    nod_root = getRoot(fnNod)

    # Index the nodes once instead of scanning the node tree for every edge.
    # setdefault keeps the first element for a duplicated id, which is the
    # element a positional search of the children would have found.
    nodes_by_id = {}
    for node in nod_root:
        if node.tag == 'node':
            nodes_by_id.setdefault(node.attrib.get('id'), node)

    def _node_xy(node_id, edge_id):
        """Return the (x, y) attribute strings of the node with ``node_id``."""
        node = nodes_by_id.get(node_id)
        if node is None:
            raise ValueError(
                'edge {!r} references unknown node {!r}'.format(edge_id, node_id)
            )
        return node.attrib['x'], node.attrib['y']

    out = []
    for child in edg_root:
        if child.tag != 'edge':
            continue
        e_dict = child.attrib
        t_out = [
            e_dict['id'],
            e_dict['from'],
            e_dict['to'],
            e_dict['priority'],
            e_dict.get('type', 'None'),
            e_dict['numLanes'],
            e_dict['speed'],
        ]
        if 'shape' in e_dict:
            t_shape = e_dict['shape']
        else:
            # No explicit geometry: connect the two end nodes directly.
            x1, y1 = _node_xy(e_dict['from'], e_dict['id'])
            x2, y2 = _node_xy(e_dict['to'], e_dict['id'])
            t_shape = '{},{} {},{}'.format(x1, y1, x2, y2)
        out.append(t_out + [t_shape])

    edge_columns = 'edge_id,edge_from,edge_to,edge_priority,edge_type,edge_numLanes,edge_speed,edge_shape'.split(',')
    tdf = pd.DataFrame(out, columns=edge_columns)
    tdf['edge_length'] = [round(x.length, 3) for x in tdf['edge_shape'].apply(getShapeStr2line)]
    tdf = tdf.convert_dtypes()
    # Diagnostic output: show the dtypes chosen by convert_dtypes().
    print(tdf.dtypes)
    return tdf


def distributeEmmission(pTazIds, pTazPolys, pEmmDf):
    """Distribute per-edge CO2 values over zones by intersected length.

    Each zone polygon is replaced by its bounding box grown by 1 unit
    (``offsetTaz(pTazPolys, 1)``). For every edge whose intersection with a
    box is longer than 0.5, the fraction ``intersection length / edge length``
    of the edge's ``edge_CO2_abs`` and ``edge_CO2_normed`` values is added to
    that zone. Progress is printed roughly every tenth zone.

    Args:
        pTazIds: Zone ids, indexable by position, one per polygon.
        pTazPolys: Iterable of zone polygons.
        pEmmDf: DataFrame with the columns ``edge_shape``, ``edge_CO2_abs``
            and ``edge_CO2_normed``. It is not modified.

    Returns:
        DataFrame with the columns ``co2_abs``, ``co2_nrm`` and ``taz_id``,
        one row per zone id.
    """
    # Parse every edge once and keep plain tuples, so the zone loop does not
    # rebuild a row Series for each edge/zone pair.
    edge_lines = [getShapeStr2line(shape) for shape in pEmmDf['edge_shape']]
    edge_records = [
        (line, line.length, co2_abs, co2_nrm)
        for line, co2_abs, co2_nrm in zip(
            edge_lines, pEmmDf['edge_CO2_abs'], pEmmDf['edge_CO2_normed']
        )
    ]

    taz_emm_share = np.zeros((len(pTazIds), 2), dtype=np.float64)
    # At least 1: with fewer than ten zones the step would be 0 and the
    # modulo below would raise ZeroDivisionError.
    progress_step = max(1, len(pTazPolys) // 10)

    ofst_polys = offsetTaz(pTazPolys, 1)

    for ndx, taz_poly in enumerate(ofst_polys):
        t_id = pTazIds[ndx]
        if ndx % progress_step == 0:
            print(ndx, t_id)
        for line, line_length, co2_abs, co2_nrm in edge_records:
            overlap = line.intersection(taz_poly).length
            if overlap > 0.5:
                t_share = overlap / line_length
                taz_emm_share[ndx, 0] += co2_abs * t_share
                taz_emm_share[ndx, 1] += co2_nrm * t_share

    out = pd.DataFrame(taz_emm_share, columns=['co2_abs', 'co2_nrm'])
    out['taz_id'] = pTazIds
    return out


def plotDistributeEdge(pTazIds, pTazPolys, pEdgeSr, pOffset=0):
    """Plot one edge together with every zone box it crosses.

    A new 6x6 inch figure is created; the edge is drawn in gray, and each
    zone bounding box (grown by ``pOffset``) whose intersection with the edge
    is longer than 0.5 is drawn and labelled with its id and the intersected
    length. The figure is not shown here; that is left to the caller.

    Args:
        pTazIds: Zone ids, indexable by position, one per polygon.
        pTazPolys: Iterable of zone polygons.
        pEdgeSr: Row-like object with ``edge_shape`` and ``edge_id``.
        pOffset: Margin passed to ``offsetTaz``.

    Returns:
        ``numpy`` array with the intersected lengths of the plotted zones.
    """
    edge_row = pEdgeSr.copy()
    edge_line = getShapeStr2line(edge_row['edge_shape'])

    figure, axes = plt.subplots()
    figure.set_size_inches(6, 6)

    axes.plot(*edge_line.xy, color="gray")
    axes.set_title('edge: {}'.format(edge_row.edge_id), fontsize='large')

    overlap_lengths = []
    for position, zone_box in enumerate(offsetTaz(pTazPolys, pOffset)):
        overlap = edge_line.intersection(zone_box).length
        if overlap > 0.5:
            top_left = getTazCorners(zone_box)['tl']
            axes.plot(*zone_box.exterior.xy)
            label = '\n  id: {}\n  len: {:,.2f}'.format(pTazIds[position], overlap)
            # Anchor the label inside the box at its top-left corner.
            axes.text(top_left[0], top_left[1], label, ha='left', va='top')
            overlap_lengths.append(overlap)

    axes.axis('off')
    figure.tight_layout()

    return np.array(overlap_lengths)

''' 
    ***************************************************************************
    
    Running for all simulations
    
    ***************************************************************************
'''

def calcEdgeTazShare(pEdgesDf, pTazIds, pTazPolys):
    """Compute, for every edge, the zones it crosses and the share in each.

    Each zone polygon is replaced by its bounding box grown by 1 unit
    (``offsetTaz(pTazPolys, 1)``). A zone counts for an edge when their
    intersection is longer than 0.5; the share is
    ``intersection length / edge length`` rounded to 6 decimals. Progress is
    printed roughly every tenth edge.

    Args:
        pEdgesDf: DataFrame with ``edge_id`` and ``edge_shape`` columns. It
            is not modified.
        pTazIds: Zone ids, indexable by position, one per polygon.
        pTazPolys: Iterable of zone polygons.

    Returns:
        DataFrame with the columns ``edge_id``, ``taz_ids`` (list of zone
        ids per edge) and ``taz_share`` (list of matching shares).
    """
    edge_lines = [getShapeStr2line(shape) for shape in pEdgesDf['edge_shape']]
    ofst_polys = offsetTaz(pTazPolys, 1)

    n_edges = len(pEdgesDf)
    # At least 1: with fewer than ten edges the step would be 0 and the
    # modulo below would raise ZeroDivisionError.
    progress_step = max(1, n_edges // 10)

    t_taz_list, t_share_list = [], []
    # Count by position, not by index label, so progress reporting works for
    # any DataFrame index.
    for position, edge_line in enumerate(edge_lines):
        if position % progress_step == 0:
            print(position, n_edges)
        edge_length = edge_line.length
        e_taz, e_share = [], []
        for t_ndx, taz_poly in enumerate(ofst_polys):
            overlap = edge_line.intersection(taz_poly).length
            if overlap > 0.5:
                e_taz.append(pTazIds[t_ndx])
                e_share.append(round(overlap / edge_length, 6))
        t_taz_list.append(e_taz)
        t_share_list.append(e_share)

    out = pEdgesDf[['edge_id']].copy()
    out['taz_ids'] = t_taz_list
    out['taz_share'] = t_share_list
    return out


def getEdgeTazShare(fn_Edges, fn_Nodes, fn_GridTaz):
    """Load zones and edges from files and compute the edge/zone shares.

    Args:
        fn_Edges: XML file with the edges.
        fn_Nodes: XML file with the nodes.
        fn_GridTaz: XML file with the zone grid.

    Returns:
        Two-item list ``[zone_ids, shares]`` where ``shares`` is the
        DataFrame produced by ``calcEdgeTazShare``.
    """
    zone_ids, zone_polys = get_taz_polygons(get_gDistrictsDf(fn_GridTaz), 0)
    shares = calcEdgeTazShare(getEdges(fn_Edges, fn_Nodes), zone_ids, zone_polys)
    return [zone_ids, shares]


''' 
    ***************************************************************************
    
    Visualisation part to check all the TAZ an edge passes through
    See figure 4.17 in Thesis doc
    
    ***************************************************************************
'''

def edge_share_test(edge_id = '723115170'):
    """Plot a single edge and the zones it crosses, as a visual check.

    File locations are taken from ``base_parameters.prms_dict``. Nothing is
    plotted when no edge with the given id exists.

    Args:
        edge_id: Id of the edge to plot.
    """
    from os.path import join
    import base_parameters

    params = base_parameters.prms_dict

    def _net_file(file_key):
        # All input files live under <base_path>/<net_path>.
        return join(params['base_path'], params['net_path'], params[file_key])

    edges_file = _net_file('fn_edges')
    nodes_file = _net_file('fn_nodes')
    grid_file = _net_file('fn_py_grid')

    zones = get_gDistrictsDf(grid_file)
    zone_ids, zone_polys = get_taz_polygons(zones, 0)
    all_edges = getEdges(edges_file, nodes_file)
    matches = all_edges[all_edges.edge_id == edge_id]

    if len(matches) == 0:
        return

    plotDistributeEdge(zone_ids, zone_polys, matches.iloc[0])
    plt.show()

