"""

MIT License, for details please read LICENSE.TXT

Copyright (c) 2022 Mahtab Sharifi

Created on Thu Nov 25 13:59:39 2021

"""

import glob
import xml.etree.ElementTree as ETree
import pandas as pd
import numpy as np
from shapely import geometry
import geopandas as gpd
import geopy.distance as geoDistance
from scipy import spatial
import General_Utils as gu


''' 
    ***************************************************************************
    
    general Part
    
    ***************************************************************************
'''

def getAllFilesNames(pPath, pFilePattern = '*.*', pRecursive=False):
    """Return the paths matching ``pPath + pFilePattern`` via ``glob.glob``.

    ``pPath`` is concatenated verbatim with the pattern, so it should end with
    a path separator. ``pRecursive`` is forwarded to ``glob.glob`` (it only
    has an effect when the pattern contains ``**``).
    """
    search_pattern = pPath + pFilePattern
    return glob.glob(search_pattern, recursive=pRecursive)


def strList2Flot(pList):
    """Convert every item of ``pList`` with ``float()`` and return a new list."""
    return list(map(float, pList))


def getGeoStr2Flot(pGeoPoint, sep=','):
    """Split a coordinate string such as ``'12.5,3.0'`` on ``sep`` into floats."""
    return strList2Flot(pGeoPoint.split(sep))


def getShapeStr2Flot(pShapeStr, sep=' '):
    """Parse a SUMO-style shape string ``'x1,y1 x2,y2 ...'`` into a numpy array.

    Points are separated by ``sep``. Each point is split on ``','`` by
    ``getGeoStr2Flot``. The list of parsed points is wrapped in ``np.array``.
    """
    point_strings = pShapeStr.split(sep)
    return np.array(list(map(getGeoStr2Flot, point_strings)))


def getShapeStr2polygon(pShapeStr, pOffset, sep=' '):
    """Build a shapely ``Polygon`` from a shape string shifted by ``-pOffset``.

    ``pOffset`` is subtracted from every parsed coordinate via numpy
    broadcasting. In this file it is the ``netOffset`` list from
    ``getSumoLocation`` or ``[0, 0]``.
    """
    shifted = getShapeStr2Flot(pShapeStr, sep) - pOffset
    return geometry.Polygon(shifted)


def readSimpleXml2Df(pFname):
    """Read a flat XML file with ``pandas.read_xml`` and apply ``convert_dtypes``."""
    raw_df = pd.read_xml(pFname)
    typed_df = raw_df.convert_dtypes()
    return typed_df


def getLatLong(pPoint, pReverse=False):
    """Return ``[lat, long]`` for the first coordinate of ``pPoint``.

    x is treated as longitude and y as latitude. With ``pReverse`` the
    order is ``[long, lat]`` instead.
    """
    xs, ys = pPoint.xy
    lon, lat = xs[0], ys[0]
    return [lon, lat] if pReverse else [lat, lon]


def getPointXY(pPoint, pReverse=False, pRound=6):
    """Return the first coordinate of ``pPoint`` as ``[x, y]``.

    Both values are rounded to ``pRound`` decimals with ``round``. With
    ``pReverse`` the result is ``[y, x]``.
    """
    coords = pPoint.xy
    x_val = round(coords[0][0], pRound)
    y_val = round(coords[1][0], pRound)
    return [y_val, x_val] if pReverse else [x_val, y_val]


def get_gDistrictsDf(pFname, prefix = 'tz_'):
    """Load a districts/TAZ XML file and prepend ``prefix`` to every column name."""
    districts = readSimpleXml2Df(pFname)
    return districts.rename(columns=lambda name: prefix + name)


def get_taz_polygons(pDf, pOffset):
    """Return ``[ids, polygons]`` for the TAZ rows of ``pDf``.

    ``ids`` is ``pDf.tz_id`` as a list. ``polygons`` is the Series produced by
    applying ``getShapeStr2polygon`` (shifted by ``pOffset``) to
    ``pDf.tz_shape``.
    """
    polygons = pDf.tz_shape.apply(getShapeStr2polygon, pOffset=pOffset)
    return [pDf.tz_id.to_list(), polygons]

    
def getSumoLocation(pFname):
    """Read the location attributes of a SUMO network file into a dict.

    The attributes are taken from the first child of the XML root. In a SUMO
    ``.net.xml`` this is presumably the ``<location>`` element; verify this
    for other files. ``netOffset``, ``convBoundary`` and ``origBoundary`` are
    comma-separated numbers converted to float lists. ``projParameter`` is
    returned as the raw string.
    """
    loc_attrs = ETree.parse(pFname).getroot()[0].attrib

    def _floats(attr_name):
        return strList2Flot(loc_attrs.get(attr_name).split(','))

    return {
        'netOffset': _floats('netOffset'),
        'convBoundary': _floats('convBoundary'),
        'origBoundary': _floats('origBoundary'),
        'projParameter': loc_attrs.get('projParameter'),
    }


def getLocationZone(pLocation):
    """Extract the ``zone`` value from ``pLocation['projParameter']``.

    The proj string (e.g. ``'+proj=utm +zone=32 ...'``) is split on ``'+'``.
    The first ``key=value`` token whose key is exactly ``'zone'`` yields its
    whitespace-stripped value. Returns ``''`` when there is no such token.
    """
    for token in pLocation['projParameter'].split('+'):
        parts = token.split('=')
        if len(parts) != 2:
            continue
        if parts[0] == 'zone':
            return parts[1].strip()
    return ''
        
    
def getTazCentroids(pSumoNet, pTazGrid, pUtmZone=''):
    """Return the centroids of all TAZ polygons as a GeoDataFrame in WGS84.

    Parameters
    ----------
    pSumoNet : str
        SUMO network file. Its location attributes supply ``netOffset`` (used
        to shift the TAZ shapes back to projected coordinates) and
        ``projParameter``.
    pTazGrid : str
        TAZ/districts XML file read by ``get_gDistrictsDf``.
    pUtmZone : str or int, optional
        UTM zone number. When empty it is parsed from ``projParameter``.

    Returns
    -------
    geopandas.GeoDataFrame
        One centroid per TAZ in column ``tz_shape``, reprojected to
        ``EPSG:4326``.

    Raises
    ------
    ValueError
        If no valid UTM zone (1..60) is given or can be parsed from the network.
    """
    location = getSumoLocation(pSumoNet)
    taz_df = get_gDistrictsDf(pTazGrid)
    _, tz_polys = get_taz_polygons(taz_df, location['netOffset'])
    cntrs = tz_polys.apply(lambda c: c.centroid)

    if not pUtmZone:
        pUtmZone = getLocationZone(location)
    if not pUtmZone:
        raise ValueError('getTazCentroids: no UTM zone given and none found '
                         'in projParameter.')
    zone = int(pUtmZone)
    if not 1 <= zone <= 60:
        raise ValueError('getTazCentroids: invalid UTM zone {!r}.'.format(pUtmZone))

    # WGS 84 / UTM EPSG codes are 326zz (north) and 327zz (south) with a
    # two-digit zone. Formatting the raw zone string gave e.g. 'EPSG:3265' for
    # zone 5 and always assumed the northern hemisphere.
    proj_tokens = [t.strip() for t in (location['projParameter'] or '').split('+')]
    epsg = (32700 if 'south' in proj_tokens else 32600) + zone
    utm_zone = 'EPSG:{}'.format(epsg)

    geo_df = gpd.GeoDataFrame(cntrs, geometry='tz_shape', crs=utm_zone)
    return geo_df.to_crs('EPSG:4326')


'''
    ***************************************************************************
    
    split and join taz rectangles Part
    
    ***************************************************************************
'''

def getSumoShapeXmlStr(pStrPolyList):
    """Wrap pre-formatted ``<poly .../>`` lines in a SUMO ``<additional>`` root.

    Returns one string: XML header line, all items of ``pStrPolyList`` joined
    without separators, then the closing ``</additional>`` line.
    """
    pieces = [
        '<additional xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="http://sumo.dlr.de/xsd/additional_file.xsd">\n',
        *pStrPolyList,
        '</additional>\n',
    ]
    return ''.join(pieces)

    
def splitRectangleNp(pPoly, pRows, pCols, pMargin=2):
    """Split the bounding box of ``pPoly`` into a ``pRows`` x ``pCols`` grid.

    Each cell is shrunk by ``pMargin`` on every side, so neighbouring cells
    do not touch.

    Parameters
    ----------
    pPoly : object with a ``bounds`` attribute ``(minx, miny, maxx, maxy)``
        For example a shapely geometry.
    pRows, pCols : int
        Number of rows (along y) and columns (along x).
    pMargin : float, optional
        Inset applied to every cell edge, in the polygon's coordinate units.
        Defaults to 2, the previously hard-coded value.

    Returns
    -------
    list of list of tuple
        One closed ring ``[(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)]``
        per cell, ordered column by column (x outer loop, y inner loop).
    """
    minx, miny, maxx, maxy = pPoly.bounds
    d_x = (maxx - minx) / pCols
    d_y = (maxy - miny) / pRows
    # Lower-left corners of the cells: drop the last linspace point (== max).
    x_starts = np.linspace(minx, maxx, pCols + 1).tolist()[:-1]
    y_starts = np.linspace(miny, maxy, pRows + 1).tolist()[:-1]

    rectangles = []
    for x in x_starts:
        x1 = x + pMargin
        x2 = x + d_x - pMargin
        for y in y_starts:
            y1 = y + pMargin
            y2 = y + d_y - pMargin
            rectangles.append([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)])
    return rectangles


def splitGdTaz2Poly(pDfTzGrid, pRows, pCols):
    """Split every TAZ rectangle into a grid of SUMO ``<poly>`` elements.

    Each TAZ polygon of ``pDfTzGrid`` (columns ``tz_id``/``tz_shape``, no
    offset applied) is divided with ``splitRectangleNp``. Cell ``j`` of TAZ
    ``tz_id`` gets the id ``'<tz_id>_<j>'``. Returns the complete XML string
    built by ``getSumoShapeXmlStr``.
    """
    poly_tpl = '\t<poly id="{}" color="{}" fill="{}" layer="{}" shape="{}"/>\n'
    color, fill, layer = 'red', '0', '1.00'
    ids, polys = get_taz_polygons(pDfTzGrid, [0, 0])
    lines = []
    for idx in range(len(polys)):
        taz_id = ids[idx]
        cells = splitRectangleNp(polys[idx], pRows, pCols)
        for cell_no, cell in enumerate(cells):
            # "(x, y) (x, y) ..." -> "x,y x,y ..."
            shape = ' '.join(str(pt) for pt in cell)
            shape = shape.replace('(', '').replace(')', '').replace(', ', ',')
            lines.append(poly_tpl.format(taz_id + '_' + str(cell_no),
                                         color, fill, layer, shape))
    return getSumoShapeXmlStr(lines)
    

def getSPS(x,y):
    """Format one shape point as ``'x,y'`` with six decimals each."""
    return f'{x:0.6f},{y:0.6f}'


def getPoly4Points(pDf):
    """Return the bounding box of all shapes in ``pDf['tz_shape']`` as a shape string.

    All ``'x,y x,y ...'`` strings are concatenated and parsed into an (n, 2)
    array. The result is the closed 5-point ring
    ``(xmin,ymin) (xmin,ymax) (xmax,ymax) (xmax,ymin) (xmin,ymin)``, with
    each point formatted by ``getSPS``.
    """
    flat = pDf['tz_shape'].str.cat(sep=' ').replace(' ', ',')
    pts = np.fromstring(flat, sep=',')
    pts = pts.reshape((int(len(pts) / 2), 2))
    (xmin, ymin), (xmax, ymax) = pts.min(axis=0), pts.max(axis=0)
    ring = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)]
    return ' '.join(getSPS(px, py) for px, py in ring)


def joinGdPoly(pDfPolyGrid):
    """Merge split grid cells back into one bounding-box ``<poly>`` per TAZ.

    A row's TAZ id is its ``tz_id`` cut before the last ``'_'``. This undoes
    the ``'<tz_id>_<j>'`` ids written by ``splitGdTaz2Poly``. For every
    distinct TAZ id, in order of first appearance, ``getPoly4Points`` computes
    the bounding box of its cells. Returns the complete XML string.
    """
    poly_tpl = '\t<poly id="{}" color="{}" fill="{}" layer="{}" shape="{}"/>\n'
    color, fill, layer = 'red', '0', '1.00'
    cells = pDfPolyGrid.copy()
    cells['tz_id'] = cells['tz_id'].apply(lambda cid: cid[:cid.rfind('_')])
    lines = [
        poly_tpl.format(taz_id, color, fill, layer,
                        getPoly4Points(cells[['tz_shape']][cells['tz_id'] == taz_id]))
        for taz_id in cells['tz_id'].unique()
    ]
    return getSumoShapeXmlStr(lines)


''' 
    ***************************************************************************
    
    count per taz Part
    
    ***************************************************************************
'''



def getPoiCountsFiles(pOsmKeys, pPath, pSumoNet, pTazGrid, pMinDist = 1.0, cbCopy = False):
    """Count OSM points of interest per TAZ, reading one GeoJSON file per key.

    Parameters
    ----------
    pOsmKeys : list of [key, file_stem]
        ``key`` becomes the output column name. The file read is
        ``pPath + file_stem + '.geojson'``.
    pPath : str
        Directory prefix of the GeoJSON files. It is concatenated verbatim,
        so it should end with a path separator.
    pSumoNet : str
        SUMO network file providing ``netOffset`` and ``projParameter``.
    pTazGrid : str
        TAZ/districts XML file.
    pMinDist : float, optional
        A point is assigned to the first TAZ whose polygon is closer than this
        distance (projected units).
    cbCopy : bool, optional
        If true, the result is passed as tab-separated text to ``gu.copy``.

    Returns
    -------
    pandas.DataFrame
        Column ``taz_id`` followed by one integer count column per key.
    """
    osm_keys = [x[0] for x in pOsmKeys]
    osm_files = [x[1] for x in pOsmKeys]
    location = getSumoLocation(pSumoNet)
    taz_df = get_gDistrictsDf(pTazGrid)
    tz_ids, tz_polys = get_taz_polygons(taz_df, location['netOffset'])
    tcnt = np.zeros((len(tz_polys), len(osm_files)), dtype=int)

    for i, (osm_key, osm_file) in enumerate(zip(osm_keys, osm_files)):
        # Read from the caller's pPath. This previously read the module-level
        # ``path`` global, silently ignoring the argument.
        geodf = gpd.read_file(pPath + osm_file + '.geojson')
        utmdf = geodf.to_crs(location['projParameter'])
        print('{} \t {}'.format(osm_key, len(utmdf)))
        for tpoint in utmdf.geometry:
            for k, tpolygon in enumerate(tz_polys):
                if geometry.Point(tpoint).distance(tpolygon) < pMinDist:
                    tcnt[k, i] += 1
                    break  # each point is counted for at most one TAZ

    tout = pd.DataFrame(tcnt)
    tout.columns = osm_keys
    tout['taz_id'] = tz_ids
    outcols = ['taz_id'] + osm_keys
    if cbCopy:
        gu.copy(tout[outcols].to_csv(sep='\t'))

    return tout[outcols]


def getPoiCountsJoined(pOsmKeys, pDfJoined, pSumoNet, pTazGrid, pMinDist = 1.0, cbCopy = False, pOutCols = None):
    """Count POIs per TAZ from one pre-joined GeoDataFrame and aggregate them.

    Parameters
    ----------
    pOsmKeys : list of [key, file_stem]
        Maps each file stem (the ``osm_file`` column of ``pDfJoined``) to the
        count column ``key``.
    pDfJoined : GeoDataFrame
        Needs ``osm_file`` and ``geometry`` columns, e.g. the output of
        ``joinGeojsonFiles``, optionally passed through ``dropDuplicates``.
    pSumoNet : str
        SUMO network file providing ``netOffset`` and ``projParameter``.
    pTazGrid : str
        TAZ/districts XML file.
    pMinDist : float, optional
        A point is assigned to the first TAZ whose polygon is closer than this
        distance (projected units).
    cbCopy : bool, optional
        If true, the result is passed as tab-separated text to ``gu.copy``.
    pOutCols : list of str, optional
        Count/aggregate columns to return after ``taz_id``. Defaults to the
        module-level ``osm_out`` list, which was previously hard-wired.

    Returns
    -------
    pandas.DataFrame
        ``taz_id`` followed by the ``pOutCols`` columns.

    Raises
    ------
    ValueError
        If a row's ``osm_file`` is not one of the file stems in ``pOsmKeys``.
    """
    osm_keys = [x[0] for x in pOsmKeys]
    osm_files = [x[1] for x in pOsmKeys]
    # Build the file stem -> column position map once. The first occurrence
    # wins, exactly like the former per-point ``osm_files.index(...)`` lookup.
    file_col = {}
    for col, stem in enumerate(osm_files):
        file_col.setdefault(stem, col)

    location = getSumoLocation(pSumoNet)
    taz_df = get_gDistrictsDf(pTazGrid)
    tz_ids, tz_polys = get_taz_polygons(taz_df, location['netOffset'])
    tcnt = np.zeros((len(tz_polys), len(osm_files)), dtype=int)

    geodf = pDfJoined[['osm_file', 'geometry']]
    utmdf = geodf.to_crs(location['projParameter'])
    n_points = len(utmdf)
    for j, (stem, tpoint) in enumerate(zip(utmdf.osm_file, utmdf.geometry)):
        if j % 500 == 0:
            print('{} \t {}'.format(j, n_points))  # progress output
        if stem not in file_col:
            raise ValueError('{!r} is not in list'.format(stem))
        i = file_col[stem]
        for k, tpolygon in enumerate(tz_polys):
            if geometry.Point(tpoint).distance(tpolygon) < pMinDist:
                tcnt[k, i] += 1
                break  # each point is counted for at most one TAZ

    tout = pd.DataFrame(tcnt)
    tout.columns = osm_keys

    # Aggregated categories: output column -> comma-separated member keys.
    # 'care_facilties' (sic) is kept unchanged because the ``osm_out`` list
    # refers to that exact column name.
    categories = {
        'food_industry': 'restaurant,fast_food,cafe,pub,bar',
        'care_facilties': 'pharmacy,kindergarten,dentist,doctors,nursing_home,childcare,clinic',
        'recreational_facilities': 'place_of_worship,social_facility,community_centre,library,arts_centre,cinema,social_centre,studio,conference_centre',
        'commercial_cmb': 'retail,commercial,Shop',
        'miscellaneous_high_importance': 'college,hospital,townhall,marketplace,university,station',
    }
    for out_col, members in categories.items():
        tout[out_col] = tout[members.split(',')].sum(axis=1)

    tout['taz_id'] = tz_ids
    if pOutCols is None:
        pOutCols = osm_out
    outcols = ['taz_id'] + list(pOutCols)
    if cbCopy:
        gu.copy(tout[outcols].to_csv(sep='\t'))

    return tout[outcols]


# Positional row indices into the joined GeoJSON frame (the output of
# joinGeojsonFiles). The script removes these rows with dropDuplicates
# before counting.
# NOTE(review): presumably found by manually inspecting near-duplicate POIs
# (findUtmNeighibours / observeNeighbour). The positions depend on the exact
# GeoJSON files and their order in osm_keys — confirm before reusing.
dropList =[1,191,1088,1331,1368,1528,1592,1603,1604,1620,1630,1644,1653,1682,1688,1710,1842,1844,1923,2685,
           2805,2806,3061,3074,3179,3182,3212,3256,3373,3420,3423,3439,3545,3546,3617,4002,4396,4509,4660,
           4670,5035,5065,5070,5092,5112,5140,5582,5894,5920,6158,6300,6383,6384,7071,7542,7830,8254,8360,
           8528,8544,8576,8587,8881,9117,9364,9368,9419,]

''' 
    ***************************************************************************
    
    overlap check Part
    
    ***************************************************************************
'''
def dropDuplicates(pDf, pLocList):
    """Drop the rows at positional indices ``pLocList`` and renumber the index.

    Positions are first translated to index labels (``pDf.index[pLocList]``),
    so this works whatever labels the frame has. Returns a new frame with a
    fresh 0..n-1 index.
    """
    labels = pDf.index[pLocList]
    return pDf.drop(index=labels).reset_index(drop=True)


def joinGeojsonFiles(pOsmKeys, pPath):
    """Read one GeoJSON file per OSM key and stack them into one GeoDataFrame.

    Parameters
    ----------
    pOsmKeys : list of [key, file_stem]
        ``pPath + file_stem + '.geojson'`` is read for every entry, in order.
    pPath : str
        Directory prefix, concatenated verbatim, so it should end with a path
        separator.

    Returns
    -------
    geopandas.GeoDataFrame
        All rows re-indexed 0..n-1, plus an ``osm_file`` column holding the
        file stem each row came from.

    Raises
    ------
    ValueError
        If ``pOsmKeys`` is empty.
    """
    osm_files = [x[1] for x in pOsmKeys]
    if not osm_files:
        raise ValueError('joinGeojsonFiles: pOsmKeys is empty.')

    frames = []
    for i, osm_file in enumerate(osm_files):
        if i > 0:
            print('{}\t{}'.format(i, osm_file))
        # Read from the caller's pPath. This previously read the module-level
        # ``path`` global, silently ignoring the argument.
        geotmp = gpd.read_file(pPath + osm_file + '.geojson')
        geotmp['osm_file'] = osm_file
        frames.append(geotmp)

    # DataFrame.append was removed in pandas 2.0. Concatenate once instead,
    # which also avoids re-copying the growing frame for every file.
    return pd.concat(frames, ignore_index=True)


def getDistMeters(p1,p2):
    """Return the geopy geodesic distance between ``p1`` and ``p2`` in metres.

    Points are passed straight to ``geopy.distance.geodesic``, which expects
    ``(lat, long)`` order (as produced by ``getLatLong``).
    """
    distance = geoDistance.geodesic(p1, p2)
    return distance.meters


def findDuplicatesBrute(pGeoDf, pMinDist = 1.0):
    """Find all pairs of points closer than ``pMinDist`` metres by brute force.

    Every pair ``(i, j)`` with ``i < j`` is compared by geopy geodesic
    distance, which costs O(n^2) distance calls. ``findUtmNeighibours`` is a
    KD-tree based alternative working on projected coordinates.

    Parameters
    ----------
    pGeoDf : GeoDataFrame
        Point geometries; ``getLatLong`` treats x as longitude and y as
        latitude.
    pMinDist : float, optional
        Distance threshold in metres.

    Returns
    -------
    list of [i, j]
        Positional index pairs of close points. Each pair is also printed, as
        before. Previously the returned list was never filled.
    """
    # .tolist() makes the lookups positional; label indexing broke for frames
    # whose index is not 0..n-1.
    latlongs = pGeoDf.geometry.apply(getLatLong).tolist()
    n_points = len(latlongs)
    out = []
    for i in range(n_points - 1):
        for j in range(i + 1, n_points):
            dist = geoDistance.geodesic(latlongs[i], latlongs[j]).meters
            if dist < pMinDist:
                print(i, j)
                out.append([i, j])

    return out


def findUtmNeighibours(pUtmDf, pMinDist = 1.0):
    """Query neighbours of the points in ``pUtmDf`` with a scipy KD-tree.

    Coordinates come from ``getPointXY`` (rounded ``[x, y]``), so distances
    are planar in the frame's CRS.

    Returns
    -------
    [tBall, tDistance, tIndex]
        ``tBall`` lists, for every point, the indices of all points within
        ``pMinDist`` (``KDTree.query_ball_point``, which includes the point
        itself). ``tDistance``/``tIndex`` are the distances and indices of the
        two nearest points (``KDTree.query`` with ``k=2``).
    """
    coords = pUtmDf.geometry.apply(getPointXY).to_list()
    tree = spatial.KDTree(coords)
    nearest_dist, nearest_idx = tree.query(coords, 2)
    within = tree.query_ball_point(coords, pMinDist)
    return [within, nearest_dist, nearest_idx]


def makeUnique0(pList):
    """Remove reversed duplicates from a list of index pairs.

    A pair ``[b, a]`` is dropped when an earlier kept pair is ``[a, b]``, so
    the first occurrence of every reversed couple survives. This matches
    ``makeUnique`` (the numpy-array version). Repeats in the same orientation
    are kept.

    The previous implementation removed *both* members of a reversed couple
    (``[[1, 2], [2, 1]]`` became ``[]``) because it kept processing pairs it
    had already deleted.

    Parameters
    ----------
    pList : list of 2-element sequences with hashable items

    Returns
    -------
    list
        New list with the surviving pairs (same objects, original order).
    """
    kept = []
    blocked = set()  # reversed forms of all pairs kept so far
    for pair in pList:
        if (pair[0], pair[1]) in blocked:
            continue
        kept.append(pair)
        blocked.add((pair[1], pair[0]))
    return kept


def makeUnique(pArr):
    """Drop rows that are the reverse of an earlier kept row.

    ``pArr`` is expected to be an (n, 2) array of index pairs. Row ``j`` is
    removed when some earlier row ``i`` that was itself kept satisfies
    ``[pArr[i, 1], pArr[i, 0]] == pArr[j]``. Repeats in the same orientation
    are kept.

    Returns a new array with the surviving rows in their original order.
    """
    pairs = pArr.copy()
    n_rows = len(pairs)
    keep = np.ones(n_rows, dtype=bool)
    for i in range(n_rows - 1):
        if not keep[i]:
            continue
        reverse = np.array([pairs[i, 1], pairs[i, 0]])
        same_as_reverse = np.equal(pairs[i + 1:], reverse).all(axis=1)
        keep[i + 1:] &= ~same_as_reverse
    return pairs[keep]


def observeNeighbour(pGeoDf, pPairs):
    """Print a comparison report for two rows of ``pGeoDf``.

    ``pPairs`` is ``(ndx1, ndx2)`` of positional indices. The report shows
    the geodesic distance in metres, both points as ``[lat, long]`` (labelled
    for pasting into Google Maps), and the non-null attributes of each row.
    Returns ``None``.
    """
    first, second = pPairs
    ll_first = getLatLong(pGeoDf.geometry.iloc[first])
    ll_second = getLatLong(pGeoDf.geometry.iloc[second])
    meters = getDistMeters(ll_first, ll_second)
    print("\n" + "=" * 30)
    print('ndx1: {}\tndx2:{}\tdistance: {:0.2f}'.format(first, second, meters))
    print('google map points: {}\t{}'.format(ll_first, ll_second))
    for ndx in (first, second):
        print("-" * 30)
        print('ndx: {}\n{}'.format(ndx, pGeoDf.iloc[ndx].dropna()))
    

''' 
    ***************************************************************************
    
    taz to kml Part
    
    ***************************************************************************
'''
def shapely2np(pShpPolygon):
    """Return the first ring's coordinates of a polygon as a numpy array.

    The polygon is converted with ``str()`` to WKT, e.g.
    ``'POLYGON ((0 0, 1 0, 1 1, 0 0))'``, and the coordinates inside the
    first ``( ... )`` group are parsed. That group is the exterior ring, also
    for polygons with holes. A WKT string can be passed directly.

    Raises
    ------
    Exception
        If the text contains no ``( ... )`` coordinate group.
    """
    tstr = str(pShpPolygon)
    en = tstr.find(')')
    # Search backwards from the first ')' so that, for polygons with holes,
    # the exterior ring is used instead of the last interior ring.
    st = tstr.rfind('(', 0, en) if en >= 0 else -1
    if st < 0 or en < 0:
        raise Exception('shapely2np: pShpStr is not polygon.')
    # Slice up to (not including) ')'. The former ``en-1`` cut off the last
    # character of the closing coordinate and produced a ragged array.
    ring = tstr[st + 1:en]
    coords = [[float(v) for v in point.split()] for point in ring.split(',')]
    return np.array(coords)


# POI categories to count. Each entry is [output column name, GeoJSON file
# stem]; the file read is ``path + stem + '.geojson'`` (see joinGeojsonFiles
# and getPoiCountsFiles). getPoiCountsJoined sums several of these column
# names (e.g. 'Shop') into aggregate columns, so the names must stay as-is.
osm_keys = [
    ['restaurant', 'rtm_val_restaurant'],
    ['fast_food', 'rtm_val_fast_food'],
    ['school', 'rtm_val_school'],
    ['cafe', 'rtm_val_cafe'],
    ['pub', 'rtm_val_pub'],
    ['place_of_worship', 'rtm_val_place_of_worship'],
    ['bar', 'rtm_val_bar'],
    ['pharmacy', 'rtm_val_pharmacy'],
    ['kindergarten', 'rtm_val_kindergarten'],
    ['social_facility', 'rtm_val_social_facility'],
    ['college', 'rtm_val_college'],
    ['dentist', 'rtm_val_dentist'],
    ['doctors', 'rtm_val_doctors'],
    ['community_centre', 'rtm_val_community_centre'],
    ['library', 'rtm_val_library'],
    ['nursing_home', 'rtm_val_nursing_home'],
    ['childcare', 'rtm_val_childcare'],
    ['arts_centre', 'rtm_val_arts_centre'],
    ['hospital', 'rtm_val_hospital'],
    ['townhall', 'rtm_val_townhall'],
    ['cinema', 'rtm_val_cinema'],
    ['social_centre', 'rtm_val_social_centre'],
    ['marketplace', 'rtm_val_marketplace'],
    ['university', 'rtm_val_university'],
    ['clinic', 'rtm_val_clinic'],
    ['studio', 'rtm_val_studio'],
    ['conference_centre', 'rtm_val_conference_centre'],
    ['station', 'rtm_val_station'],
    ['retail', 'rtm_val_retail'],
    ['residential', 'rtm_val_residential'],
    ['commercial', 'rtm_val_commercial'],
    ['Shop', 'rtm_key_Shop'],
    ['leisure', 'rtm_key_leisure'],
    ['tourism', 'rtm_key_tourism'],
        ]

# Columns returned by getPoiCountsJoined after 'taz_id': the five aggregate
# columns it computes plus three raw keys from osm_keys. 'care_facilties'
# (sic) must match the column name written there. Note that 'residential' is
# counted but not part of this output.
osm_out=[
    'food_industry',
    'care_facilties',
    'recreational_facilities',
    'commercial_cmb',
    'miscellaneous_high_importance',
    'school',
    'leisure',
    'tourism',
        ]


# Input locations (Windows-style relative paths). ``path`` is the GeoJSON
# export directory. It must end with a separator because the file stems are
# concatenated to it directly.
taz_gd_fname = '..\\01_Network\\taz_gd_sumo.xml'
sumo_net_fname = '..\\01_Network\\rtm_netconvert.net.xml'
path = "..\\01_Network\\Geojason Exports\\"


# Alternative: count per GeoJSON file, without dropping duplicates.
# tdf = getPoiCountsFiles(osm_keys, path, sumo_net_fname, taz_gd_fname, cbCopy=1)

# Script run. These module-level statements execute on import as well.
# NOTE(review): consider an ``if __name__ == "__main__":`` guard.
# Steps: join all GeoJSON exports, drop the listed duplicate rows, count POIs
# per TAZ, and pass the table to gu.copy (cbCopy=1).
geoAllf = joinGeojsonFiles(osm_keys, path)
geoAll = dropDuplicates(geoAllf, dropList)
tdf = getPoiCountsJoined(osm_keys, geoAll, sumo_net_fname, taz_gd_fname, cbCopy=1)


