# Usage: python mapmatching.py
# This script reads GPS CSV files and writes map-matched CSV files; edit the input/output
# file paths near the bottom of this file before running.
# Set MAPBOX_KEY (your Mapbox access token) in a .env file in the same directory as this script.

import requests
import numpy as np
import pandas as pd
from pathlib import Path
from dotenv import load_dotenv
import os

def lonlat2str(lon, lat):
    """Format coordinates for the Mapbox Map Matching URL path.

    Returns ``"lon1,lat1;lon2,lat2;..."`` -- longitude first, matching the example
    coordinate string used by ``mapbox_request``. Each value is rendered with ``str``.

    Args:
        lon: iterable of longitudes.
        lat: iterable of latitudes, same length as ``lon``.

    Raises:
        ValueError: if ``lon`` and ``lat`` have different lengths.
    """
    lon = list(lon)
    lat = list(lat)
    if len(lon) != len(lat):
        raise ValueError(f'lon and lat must have the same length, got {len(lon)} and {len(lat)}')
    # join instead of repeated += keeps this linear in the number of points
    return ';'.join(str(x) + ',' + str(y) for x, y in zip(lon, lat))

def timestamps2str(unix_timestamps):
    """Render timestamps as ``"t1;t2;t3;..."`` for the Mapbox ``timestamps`` query parameter.

    Each value is converted with ``str``; an empty input gives an empty string.
    """
    return ';'.join(map(str, unix_timestamps))

def mapbox_request(lon, lat, timestamps, timeout=60):
    """Query the Mapbox Map Matching API for one trace and return the decoded JSON body.

    The request uses the ``mapbox/driving`` profile, asks for the ``maxspeed``
    annotation, ``tidy=true``, full overview geometry as GeoJSON, and sends one
    timestamp per coordinate.

    Args:
        lon, lat: longitudes / latitudes of the trace, in request order.
        timestamps: Unix timestamps in whole seconds, one per coordinate.
        timeout: seconds ``requests`` waits for the server before raising; without a
            timeout a stalled connection would block the script forever.

    Returns:
        The parsed JSON response. HTTP error statuses are not raised here; the body is
        returned as-is and the caller must check it (e.g. for ``tracepoints``).

    Raises:
        RuntimeError: if ``MAPBOX_KEY`` is not set in the environment or ``.env`` file.
    """
    load_dotenv()
    mapbox_key = os.getenv("MAPBOX_KEY")
    if not mapbox_key:
        # fail early instead of sending "access_token=None" and getting an auth error body
        raise RuntimeError("MAPBOX_KEY is not set; add it to a .env file in the same directory as this script")
    profile = 'mapbox/driving'
    geom = 'geojson'  # or polyline
    annotations = 'maxspeed'
    coordinates = lonlat2str(lon, lat)  # e.g. "-122.39636,37.79129;-122.39732,37.79283;-122.39606,37.79349"
    timestamps = timestamps2str(timestamps)  # e.g. "1589580000;1589580001;1589580002"
    url = f"https://api.mapbox.com/matching/v5/{profile}/{coordinates}?annotations={annotations}&tidy=true&overview=full&geometries={geom}&timestamps={timestamps}&access_token={mapbox_key}"
    r = requests.get(url, timeout=timeout)
    return r.json()

_RESPONSE_COLUMNS = ['waypoint_index', 'matchings_index', 'lon_mm', 'lat_mm',
                     'confidence', 'street_name_mm', 'speed_limit_mm']


def _leg_maxspeed(leg, position):
    """Return the numeric ``speed`` of the leg's ``maxspeed`` entry at ``position``, else NaN.

    NaN is returned when the annotation or entry is missing or its ``speed`` is not a
    number (bools are rejected explicitly because ``bool`` is a subclass of ``int``).
    """
    try:
        speed = leg['annotation']['maxspeed'][position]['speed']
    except (KeyError, IndexError, TypeError):
        return np.nan
    if isinstance(speed, bool) or not isinstance(speed, (int, float)):
        return np.nan
    return speed


def _parse_tracepoint(tracepoint, matchings):
    """Build one output row (a dict keyed by ``_RESPONSE_COLUMNS``) for a single tracepoint."""
    row = dict.fromkeys(_RESPONSE_COLUMNS, np.nan)
    try:
        # Read every field before writing any, so an incomplete tracepoint can never
        # leave the row (or the column lengths) half-filled.
        waypoint_index = int(tracepoint['waypoint_index'])
        matchings_index = int(tracepoint['matchings_index'])
        lon, lat = tracepoint['location'][0], tracepoint['location'][1]
    except (TypeError, KeyError, IndexError, ValueError):
        # e.g. a null tracepoint (point not matched) or a missing field: whole row stays NaN
        return row
    row.update(waypoint_index=waypoint_index, matchings_index=matchings_index,
               lon_mm=lon, lat_mm=lat)

    if not 0 <= matchings_index < len(matchings):
        return row
    matching = matchings[matchings_index]
    row['confidence'] = matching.get('confidence', np.nan)

    legs = matching.get('legs') or []
    if 0 <= waypoint_index < len(legs):
        # Leg that starts at this waypoint; its first maxspeed entry presumably covers
        # the road segment at the waypoint itself.
        leg = legs[waypoint_index]
        row['street_name_mm'] = leg.get('summary', np.nan)
        row['speed_limit_mm'] = _leg_maxspeed(leg, 0)
    elif legs and waypoint_index == len(legs):
        # n waypoints presumably have n-1 legs, so the final waypoint has no outgoing
        # leg; use the end of the incoming leg instead of leaving it NaN.
        leg = legs[-1]
        row['street_name_mm'] = leg.get('summary', np.nan)
        row['speed_limit_mm'] = _leg_maxspeed(leg, -1)
    return row


def parse_response(response):
    """Turn a Mapbox Map Matching response into one DataFrame row per tracepoint.

    Rows follow ``response['tracepoints']`` in order (presumably one per requested
    coordinate). Columns:

    - ``waypoint_index``, ``matchings_index``: indices from the tracepoint (NaN if unmatched)
    - ``lon_mm``, ``lat_mm``: matched location (NaN if unmatched)
    - ``confidence``: confidence of the matching the point belongs to
    - ``street_name_mm``: ``summary`` of the leg at this waypoint
    - ``speed_limit_mm``: numeric ``maxspeed`` ``speed`` at this waypoint; the unit
      reported alongside it is not checked or converted

    Raises:
        ValueError: if the response has no ``tracepoints`` (e.g. an API error body).
    """
    if 'tracepoints' not in response:
        raise ValueError(
            "Mapbox response has no 'tracepoints' "
            f"(code={response.get('code')!r}, message={response.get('message')!r})"
        )
    matchings = response.get('matchings') or []
    rows = [_parse_tracepoint(tracepoint, matchings) for tracepoint in response['tracepoints']]
    return pd.DataFrame(rows, columns=_RESPONSE_COLUMNS)


def df_chunker(df, chunk_size):
    """Split ``df`` into consecutive row chunks of at most ``chunk_size`` rows.

    Chunks are positional slices (``df.iloc``) in the original row order. If the final
    chunk would contain a single row, the last row of the previous chunk is moved into
    it, so for ``len(df) >= 2`` and ``chunk_size >= 3`` every chunk has at least 2 rows.

    Args:
        df: dataframe to split.
        chunk_size: maximum number of rows per chunk; must be at least 2.

    Returns:
        A list of dataframes: empty if ``df`` is empty, a single 1-row chunk if ``df``
        has exactly one row.

    Raises:
        ValueError: if ``chunk_size`` is smaller than 2.
    """
    if chunk_size < 2:
        raise ValueError(f'chunk_size must be at least 2, got {chunk_size}')

    chunks = [df.iloc[start:start + chunk_size] for start in range(0, len(df), chunk_size)]

    if len(chunks) >= 2 and len(chunks[-1]) == 1:
        # Move the previous chunk's last row in front of the lone final row. pd.concat
        # keeps column dtypes (enlarging a slice with .loc could upcast e.g. integer
        # timestamps to float) and never writes into a slice of the caller's dataframe.
        chunks[-1] = pd.concat([chunks[-2].iloc[-1:], chunks[-1]])
        chunks[-2] = chunks[-2].iloc[:-1]

    return chunks

def impute_speeds(df_orig, df_with_speeds, t_col):
    """Propagate map-matched values from the downsampled rows back to every original row.

    The original data was downsampled before map matching. For each pair of consecutive
    downsampled rows, every original row between them (by ``old_index``, inclusive) takes
    the speed limit and matched coordinates of whichever of the two rows is closer in
    time; ties go to the earlier row. Original rows before the first / after the last
    downsampled row take that first / last row's values.

    Args:
        df_orig: original dataframe; its index labels must correspond to the
            ``old_index`` values of ``df_with_speeds``. Modified in place.
        df_with_speeds: downsampled dataframe with columns ``old_index``, ``t_col``,
            ``speed_limit_mm``, ``lon_mm`` and ``lat_mm``. Rows whose ``old_index`` is
            NaN are ignored.
        t_col: name of the datetime column present in both dataframes.

    Returns:
        ``df_orig`` with columns ``speedlimit_mapbox``, ``lon_mapbox`` and ``lat_mapbox``.
    """
    src_cols = ['speed_limit_mm', 'lon_mm', 'lat_mm']
    dst_cols = ['speedlimit_mapbox', 'lon_mapbox', 'lat_mapbox']
    for dst in dst_cols:
        df_orig[dst] = np.nan

    samples = df_with_speeds.dropna(subset=['old_index'])
    if samples.empty:
        return df_orig

    # Positional arrays, so this does not depend on df_with_speeds having a RangeIndex.
    ixs = samples['old_index'].astype(np.int64).to_numpy()
    times = samples[t_col]
    values = samples[src_cols].to_numpy(dtype=float)

    # Head/tail: rows outside the sampled range get the nearest sample's values.
    # The loop below overwrites the inclusive endpoints with the same values.
    for col_pos, dst in enumerate(dst_cols):
        df_orig.loc[:ixs[0], dst] = values[0, col_pos]
        df_orig.loc[ixs[-1]:, dst] = values[-1, col_pos]

    orig_times = df_orig[t_col]
    for k in range(len(ixs) - 1):
        ix_0, ix_1 = ixs[k], ixs[k + 1]
        window = orig_times.loc[ix_0:ix_1]
        dist_to_t0 = (window - times.iloc[k]).dt.total_seconds().abs().to_numpy()
        dist_to_t1 = (window - times.iloc[k + 1]).dt.total_seconds().abs().to_numpy()
        # argmin returns 0 on ties, so equidistant rows take the earlier sample
        take_first = np.vstack([dist_to_t0, dist_to_t1]).argmin(axis=0) == 0
        for col_pos, dst in enumerate(dst_cols):
            df_orig.loc[ix_0:ix_1, dst] = np.where(take_first, values[k, col_pos], values[k + 1, col_pos])

    return df_orig

def main():
    """Map-match the selected GoPro GPS runs with Mapbox and write one ``mm_*.csv`` per run.

    Paths, column names and the run selection are configured at the top of this function.
    """
    # Build paths from parts so they work on Windows and POSIX alike (a raw string with
    # backslashes is a single odd file name on POSIX).
    in_dir = Path('..') / 'Preprocessed data' / 'Before matching with GPS azimuth and distance'
    out_dir = Path('files_out') / 'newfiles'

    downsample_s = 5  # keep one point per this many seconds before matching; falsy disables it
    chunk_size = 100  # points per Mapbox request
    lon_col = 'longitude'
    lat_col = 'latitude'
    t_col = 'date'

    todo = ['3.2', '3.4', '3.5', '5.1', '5.2', '5.3', '6.1', '6.2', '6.3']  # runs selected for paper (have gopro gps data etc.)

    out_dir.mkdir(parents=True, exist_ok=True)

    for file in sorted(in_dir.glob('*gps_gopro.csv')):
        if file.name[:3] not in todo:
            continue

        print(file.name)

        df_orig = pd.read_csv(file, parse_dates=[t_col])
        # Row position in the original file; read_csv gives a RangeIndex, so this also
        # equals the row label that impute_speeds slices df_orig with.
        df_orig['old_index'] = range(len(df_orig))

        # Only rows with a position and a timestamp can be sent to Mapbox.
        valid = df_orig.dropna(subset=[lon_col, lat_col, t_col])
        if downsample_s:
            # Take the whole first row of each non-empty window (by position) instead of
            # resample().first() on every column: that keeps real timestamps, never mixes
            # values from different rows and produces no all-NaN rows for empty windows.
            first_rows = valid.resample(f'{downsample_s}s', on=t_col)['old_index'].first().dropna()
            df = df_orig.iloc[first_rows.astype(np.int64).to_numpy()].reset_index(drop=True)
        else:
            df = valid.reset_index(drop=True)

        if len(df) < 2:
            print('Fewer than 2 GPS points, skipping')
            continue

        df['t_unix'] = df[t_col].astype(np.int64) // 10**9

        chunks = df_chunker(df, chunk_size)

        for i, chunk_df in enumerate(chunks):
            print('Chunk', i)

            response = mapbox_request(chunk_df[lon_col], chunk_df[lat_col], chunk_df.t_unix)
            response_df = parse_response(response)

            ixs = chunk_df.index
            df.loc[ixs, 'lon_mm'] = response_df.lon_mm.values
            df.loc[ixs, 'lat_mm'] = response_df.lat_mm.values
            df.loc[ixs, 'speed_limit_mm'] = response_df.speed_limit_mm.values

        df_final = impute_speeds(df_orig, df, t_col)

        # Drop the original speed-limit columns (replaced by the three *_mapbox columns,
        # keeping the column count of the input) and the helper index column.
        df_final.drop(columns=['speedlimit', 'speedzone_id', 'speedzone_name', 'old_index'], inplace=True)
        df_final.to_csv(out_dir / f'mm_{file.name}', index=False)


if __name__ == '__main__':
    main()
