# -*- coding: utf-8 -*-
# Created by mjribeiro at 9-6-2020
import random
import osmnx as ox
# ox.config(log_console=True, use_cache=True)
import numpy as np
from time import strftime, gmtime
import bluesky as bs
import itertools

from bluesky.tools import geo
from bluesky.tools.aero import kts, nm
import os
import sys
from itertools import count, filterfalse
# np.set_printoptions(threshold=sys.maxsize)
# np.set_printoptions(threshold=np.inf, suppress=True, formatter={'float_kind': '{:0.0f}'.format})

# import matplotlib.pyplot as plt

COMPUTE_NEW_ROUTES = False  # when True, definecity writes new scenario files
FINISHED_READING_DATA = False  # set to True by definecity once the map data has been processed
import math

N_WAYPOINTS = 3

EDGE_WIDTH = 110 / nm  # 110 m, expressed in nautical miles
MIN_SEPARATION = 50  # m
TURNRAD = 3 / nm  # 3 m, expressed in nautical miles

TOTAL_DURATION = 60 # minutes
TIME_DIRECTION_CHANGE = 1200 # seconds (20 minutes); only used by commented-out code in computeRoutes

# NOTE(review): speeds above roughly 15 m/s (TAS_MAX = 30 kts) presumably give
# too big a turn radius for this UAV model -- an earlier note said "15 kts";
# confirm the unit. Smaller speeds are used for turns.
TAS_MAX = 30.0  # kts - 15.4m/s
TAS_AVG = 21  # kts - 10.80m/s
TAS_CURVE = 10.0  # kts -5.14m/s
# TOTAL_AIRCRAFT = 700
DENSITIES = [100]  # numbers of simultaneous aircraft; one scenario set is generated per value



# Border side indices, as filled into EDGES_BOUNDARY by drawedges:
# 0 - east (right), 1 - south (bottom), 2 - west (left), 3 - north (top)

# NOTE(review): an earlier note claimed "40 different routes in total"; the actual
# number depends on ROUTE_TURNS and the origin-side sets in definecity -- confirm.
ORIGIN_SIDES = [0, 1, 2, 3]
# Number of turns per route: positive values turn right, negative values turn
# left, 0 flies straight (see getDestSide). Only the last assignment takes effect.
ROUTE_TURNS = [0, 1, 2, 3, 4, 5, 6, -1, -2, -3, -4, -5, -6] # 13 route types in total
ROUTE_TURNS = [0, 2, 4, 6, -2, -4, -6] # 7 route types in total
ROUTE_TURNS = [0, 2, 4, -2, -4] # 5 route types in total
# every 4-element combination of ROUTE_TURNS; in this section only referenced by
# commented-out code in computeRoutes
COMBINATION_SCENARIOS = list(itertools.combinations(ROUTE_TURNS, 4))

GRAPH = None  # OSMnx graph, set by definecity
# node id -> edge slots [east, south, west, north] once drawedges has finished;
# each slot is None or (neighbour id, heading [deg], length [m])
EDGES_PER_NODE = dict()
BOUNDARIES = None  # [north, south, east, west] in degrees, set by definecity
EDGES = set()  # undirected edges stored as (smaller id, larger id)
EDGES_BOUNDARY = [[] for x in range(4)]  # border node ids per side (indices as above)
NODES = dict()  # node id -> (lat, lon) of nodes created or cached by this module
REPLACED_NODES = []  # nodes spliced out of the grid by drawedges
# nodes to be used only to transition between layers (getshortestpath drops edges leading into them)
TRANSITION_NODES = []

# traffic-direction flags (not referenced in the code shown in this section)
GOING_OUT = 1
GOING_IN = 0
BOTH_DIRECTIONS = 2



def getCurrentBorder(node, G):
    """Return the index of the first EDGES_BOUNDARY side list that contains ``node``.

    ``G`` is accepted for signature compatibility and is not used.
    Returns None when the node is not listed on any border.
    """
    return next(
        (side for side, border_nodes in enumerate(EDGES_BOUNDARY) if node in border_nodes),
        None,
    )


def getborderposition(side, G):
    """Draw a random node from the border list of ``side``.

    Side indices follow EDGES_BOUNDARY as filled in drawedges:
    0 - east, 1 - bottom (south), 2 - west, 3 - top (north).

    :return: (latitude, longitude, node id) of the chosen border node.
    """
    border_node = random.choice(EDGES_BOUNDARY[side])
    node_lat, node_lon = getLatLonnode(G, border_node)
    return node_lat, node_lon, border_node

def checknodeproximity(newnode, previousnodes, window=15):
    """Tell whether ``newnode`` was already picked among the most recent picks.

    getOrigin uses this to re-draw an origin node that was used too recently.

    :param newnode: candidate node id.
    :param previousnodes: sequence/array of previously picked node ids, oldest first.
    :param window: number of most recent picks that count as "too close";
        15 is the distance threshold the original implementation tested.
    :return: True if a new node should be drawn, False otherwise.
    """
    # The previous implementation assigned ``getnewnode = False`` on every path,
    # so it could never ask for a new node. A node re-used within the last
    # ``window`` entries (the case its ``> 15`` distance test singled out) now
    # returns True.
    if window <= 0:
        # a [-0:] slice would select everything, so handle "no window" explicitly
        return False
    recent = np.asarray(previousnodes)[-window:]
    return bool(np.any(recent == newnode))

def getOrigin(G, originsnode, possible_origins):
    """Pick a random border node to start a route from.

    :param G: OSMnx graph used for node coordinates.
    :param originsnode: origin nodes picked so far (oldest first); a candidate is
        re-drawn while checknodeproximity reports it as too recent.
    :param possible_origins: a single side index (Python or NumPy integer) or a
        sequence of side indices to choose from at random.
    :return: (latitude, longitude, node id, side index) of the chosen origin.
    """
    # isinstance accepts every NumPy integer width (and int subclasses); the
    # previous exact type() check only knew int, np.int32 and np.int64.
    if isinstance(possible_origins, (int, np.integer)):
        originside = possible_origins
    else:
        originside = random.choice(possible_origins)

    getneworigin = True
    cycle_number = 0
    while getneworigin:
        # get the start node on the map
        originLat, originLon, orig_node = getborderposition(originside, G)
        getneworigin = checknodeproximity(orig_node, originsnode)
        cycle_number += 1
        # give up after 101 draws and keep the last candidate
        if cycle_number > 100:
            print('getneworigin reached a limit.')
            getneworigin = False

    return originLat, originLon, orig_node, originside

def getLatLonnode(G, node):
    """Return the coordinates (lat, lon) of ``node``.

    Coordinates cached in NODES take precedence and are returned as stored.
    Otherwise they are read from the graph's 'y'/'x' node attributes and
    returned as a two-element list.
    """
    cached = NODES.get(node)
    if cached is not None:
        return cached
    attributes = G.nodes()[node]
    return [attributes['y'], attributes['x']]


def getclosestnode(newlats, newlons, lat_nodes, lon_nodes, node_ids):
    """Return the ids of the candidate nodes closest to two query positions.

    :param newlats: latitudes of the two query points (length-2 sequence; a
        scalar or length-1 value is broadcast to both points).
    :param newlons: longitudes of the two query points, same rules as newlats.
    :param lat_nodes: latitudes of the candidate nodes, aligned with node_ids.
    :param lon_nodes: longitudes of the candidate nodes, aligned with node_ids.
    :param node_ids: ids of the candidate nodes.
    :return: (id closest to query point 0, id closest to query point 1).
    """
    # Only the two query points are needed as matrix rows. The previous version
    # zero-padded them to len(lat_nodes), computing an N x N distance matrix to
    # read two rows (and failing outright when fewer than two nodes were given).
    # Slice assignment keeps the old broadcasting of scalar inputs.
    query_lats = np.zeros(2)
    query_lons = np.zeros(2)
    query_lats[:] = newlats
    query_lons[:] = newlons

    qdr, dist = geo.kwikqdrdist_matrix(np.asmatrix(query_lats), np.asmatrix(query_lons),
                                       np.asmatrix(lat_nodes), np.asmatrix(lon_nodes))

    ind0 = np.argmin(dist[0])
    ind1 = np.argmin(dist[1])

    return node_ids[ind0], node_ids[ind1]


def getnextnodeswithpreference(allnodes, min_value_ind):
    """Keep only the entries whose value at ``min_value_ind`` is minimal.

    :param allnodes: dict mapping node -> indexable weight record.
    :param min_value_ind: position inside each record to minimise over.
    :return: new dict with every tied-for-minimum entry, in the original order.
    :raises ValueError: if ``allnodes`` is empty.
    """
    lowest = min(record[min_value_ind] for record in allnodes.values())
    return {
        node: record
        for node, record in allnodes.items()
        if record[min_value_ind] == lowest
    }


def getshortestpath(startnode, lastnode):
    """Find the preferred path between two grid nodes with a Dijkstra-style search.

    Candidate paths are ranked lexicographically: fewest altitude variations
    first, then fewest turns, then shortest length. Edges come from
    EDGES_PER_NODE; their four slots are numbered 1..4 here (``edge_hdg``).

    Side effect: edges leading into TRANSITION_NODES are overwritten with None
    directly inside the EDGES_PER_NODE lists, so they stay removed afterwards.

    :param startnode: id of the start node.
    :param lastnode: id of the destination node.
    :return: (path, pathlength, path_alt_variatons, path_turns); path is the list
        of node ids from startnode to lastnode and pathlength the sum of the
        edge lengths (element 2 of each edge slot).
    :raises ValueError: from min() in getnextnodeswithpreference when lastnode is
        unreachable and every discovered node has been visited.
    """
    # shortest paths is a dict of nodes, whose value is a tuple of:
    # (previous node, hdg_previous_node, path_length until now, turns until now, altitude variations until now)
    shortest_paths = {startnode: (None, None, 0, 0, 0)}
    dist_ind = 2
    turns_ind = 3
    alt_var_int = 4
    current_node = startnode
    visited = set()

    # while we haven't reached our destination
    while current_node != lastnode:
        visited.add(current_node)
        # NOTE: this is the list stored in EDGES_PER_NODE itself (no copy), so
        # the None assignments below modify the global grid
        destinations = EDGES_PER_NODE[current_node]
        # remove the nodes that are just to be used for transitioning between layers not to travel through
        for destinations_it in range(len(destinations)):
            if destinations[destinations_it] is not None and destinations[destinations_it][0] in TRANSITION_NODES:
                destinations[destinations_it] = None
        weight_to_current_node = shortest_paths[current_node]

        # slot number (1..4) of the edge being examined
        edge_hdg = 1
        for next_node in destinations:
            if next_node is not None:
                distance_to_next_node = next_node[2]
                # add path length to new node
                w_distance = distance_to_next_node + weight_to_current_node[dist_ind]
                # add altitude variations to new node
                alt_var = 0
                w_alt_var = weight_to_current_node[alt_var_int]
                # add number turns variations to new node
                w_turns = weight_to_current_node[turns_ind]
                save_edge_hdg = edge_hdg
                # leaving through a different slot than the one used to arrive counts
                # as one turn and |slot difference| altitude variations
                if weight_to_current_node[1] is not None and weight_to_current_node[1] != edge_hdg:
                    alt_var = abs(edge_hdg - weight_to_current_node[1])
                    # because of the auxiliary layers
                    # a jump between slots 1 and 4 counts as a single variation; the stored
                    # slot becomes 0 or 5 so the next comparison is made relative to it
                    if alt_var == 3 and 1 <= weight_to_current_node[1] <= 4:
                        alt_var = 1
                        if edge_hdg == 4:
                            save_edge_hdg = 0
                        elif edge_hdg == 1:
                            save_edge_hdg = 5
                    w_alt_var = weight_to_current_node[alt_var_int] + alt_var
                    w_turns = weight_to_current_node[turns_ind] + 1
                # is this a new node?
                # NOTE(review): already-visited nodes are not excluded from relaxation here
                if next_node[0] not in shortest_paths:
                    shortest_paths[next_node[0]] = (current_node, save_edge_hdg, w_distance, w_turns, w_alt_var)
                # is we already have this node, is this a better way to get there?
                else:
                    cur_best_weight = shortest_paths[next_node[0]]
                    # first prioritize altitude variations, and then turns, and then path length
                    if w_alt_var < cur_best_weight[alt_var_int] or \
                            (w_alt_var == cur_best_weight[alt_var_int] and w_turns < cur_best_weight[turns_ind]) or \
                            (w_alt_var == cur_best_weight[alt_var_int] and w_turns == cur_best_weight[
                                turns_ind] and w_distance < cur_best_weight[dist_ind]):
                        shortest_paths[next_node[0]] = (current_node, save_edge_hdg, w_distance, w_turns, w_alt_var)
            edge_hdg += 1

        # frontier: every discovered node not yet expanded (rebuilt on each iteration)
        next_destinations = {node: shortest_paths[node] for node in shortest_paths if node not in visited}
        # next node is the destination with the lowest weight
        # first we prioritize alt variations, then turns, and then path length
        new_next_nodes = getnextnodeswithpreference(next_destinations, alt_var_int)
        if len(new_next_nodes) > 1:
            new_next_nodes = getnextnodeswithpreference(new_next_nodes, turns_ind)
            if len(new_next_nodes) > 1:
                new_next_nodes = getnextnodeswithpreference(new_next_nodes, dist_ind)
        current_node = list(new_next_nodes.keys())[0]

    # Work back through destinations in shortest path
    path = []
    pathlength = shortest_paths[current_node][dist_ind]
    path_turns = shortest_paths[current_node][turns_ind]
    path_alt_variatons = shortest_paths[current_node][alt_var_int]
    while current_node is not None:
        path.append(current_node)
        next_node = shortest_paths[current_node][0]
        current_node = next_node
    # Reverse path
    path = path[::-1]

    return path, pathlength, path_alt_variatons, path_turns


def getLatLonAllNodes(G):
    """Collect the coordinates of every node in EDGES_PER_NODE.

    :param G: OSMnx graph used for nodes not cached in NODES.
    :return: (lat_nodes, lon_nodes, node_ids) -- two float arrays aligned with
        the node_ids list (which follows EDGES_PER_NODE's key order).
    """
    node_ids = list(EDGES_PER_NODE.keys())
    # Collect in lists and convert once: np.append copied the whole array on
    # every iteration, which made this loop quadratic in the number of nodes.
    lat_list = []
    lon_list = []
    for node_id in node_ids:
        lat0, lon0 = getLatLonnode(G, node_id)
        lat_list.append(lat0)
        lon_list.append(lon0)

    lat_nodes = np.array(lat_list, dtype=float)
    lon_nodes = np.array(lon_list, dtype=float)
    return lat_nodes, lon_nodes, node_ids

def getDestSide(origin_side,routes_number_turns ):
    """Return the sides a route heads to / comes from after one turn.

    A positive ``routes_number_turns`` means a right turn; zero or negative
    means a left turn. Sides use the EDGES_BOUNDARY numbering.

    :return: (new destination side, new origin side); the new origin side is
        the one opposite the new destination side.
    """
    dest_after_right = (3, 0, 1, 2)
    dest_after_left = (1, 2, 3, 0)
    turning_right = np.sign(routes_number_turns) > 0
    if not turning_right:
        return dest_after_left[origin_side], dest_after_right[origin_side]
    return dest_after_right[origin_side], dest_after_left[origin_side]

def insideMap(G, node):
    """Return whether ``node`` lies inside BOUNDARIES ([north, south, east, west]), edges included."""
    lat, lon = getLatLonnode(G, node)

    lat_inside = (BOUNDARIES[1] <= lat) & (lat <= BOUNDARIES[0])
    lon_inside = (BOUNDARIES[3] <= lon) & (lon <= BOUNDARIES[2])
    return lat_inside & lon_inside

# change to a referential where node is going 'north'
def put_in_node_ref(G, step, previous_node, current_node):
    """Translate a step relative to the travel direction into an absolute edge slot.

    The travel direction is the bearing previous_node -> current_node. Edge
    slots 0..3 have nominal headings of 86.8, 176.8, 266.8 and 356.8 degrees.
    In the frame of a node travelling "up", ``step`` selects
    0 = right, 1 = down (back), 2 = left, 3 = up (ahead).

    :return: the absolute edge slot index for ``step``.
    """
    lat_previous_node, lon_previous_node = getLatLonnode(G, previous_node)
    lat_current_node, lon_current_node = getLatLonnode(G, current_node)
    qdr, dist = geo.kwikqdrdist(lat_previous_node, lon_previous_node, lat_current_node, lon_current_node)

    right_qdr = (qdr + 90) % 360
    possible_directions = [86.8, 176.8, 266.8, 356.8]

    def _angle_to(i):
        # Smallest angle between slot i and right_qdr. A plain abs() difference
        # ignores the 0/360 wrap: right_qdr = 2 would pick slot 0 (84.8 deg away)
        # instead of slot 3 (only 5.2 deg away).
        return abs((possible_directions[i] - right_qdr + 180) % 360 - 180)

    closest_heading = min(range(len(possible_directions)), key=_angle_to)

    # edges = [right, down, left, up]
    edges = [closest_heading, (closest_heading + 1) % 4, (closest_heading + 2) % 4, (closest_heading + 3) % 4]
    return edges[step]

def computeRoutes(G, density, route_type, possible_origins):
    """Generate the random routes of one scenario.

    Every route starts at a random border node (getOrigin) and heads towards the
    opposite side. It makes ``abs(ROUTE_TURNS[route_type])`` turns (right for a
    positive value, left otherwise; see getDestSide) before leaving the map.
    Routes are generated until their number covers TOTAL_DURATION minutes of
    spawning at the rate ``density * TAS_MAX / average route length``.

    :param G: OSMnx graph used for node coordinates.
    :param density: number of simultaneous aircraft the scenario targets.
    :param route_type: index into ROUTE_TURNS.
    :param possible_origins: a single border side index or a list of side indices.
    :return: (routes, spawnInterval [s], routelength [m] of every accepted route
        (not truncated like routes), distAvg [m], nodes_stops, destinationsides).
        ``nodes_stops`` holds, per route, the nodes just before each turn plus
        the final node.
    """
    originsnode = np.array([])
    destnodes = np.array([])
    routelength = np.array([])
    destinationsides = np.array([])

    nodes_stops = []
    routes = []
    nacTotal = np.inf
    timePassed = 0

    #origin_side_it = 0
    #routes_number_turns = random.choice(ROUTE_TURNS)
    routes_number_turns = ROUTE_TURNS[route_type]
    #indice_COMBINATION_SCENARIOS = 0
    #routes_number_turns = COMBINATION_SCENARIOS[route_type][indice_COMBINATION_SCENARIOS]

    last_change_routes = 0
    while len(routes) < np.ceil(nacTotal):
        # if timePassed - last_change_routes >= TIME_DIRECTION_CHANGE:
        #     last_change_routes = timePassed
        #     indice_COMBINATION_SCENARIOS += 1
        #     print(route_type, route_type, routes_number_turns, timePassed)
        #     routes_number_turns = COMBINATION_SCENARIOS[route_type][indice_COMBINATION_SCENARIOS]

        route_length = 0
        nodes_before_turn = []
        originLat, originLon, originnode, origin_side = getOrigin(G, originsnode, possible_origins)
        # Remember the origin *node* (this used to append originLon), so that
        # checknodeproximity compares node ids with node ids.
        originsnode = np.append(originsnode, originnode)

        # start by heading towards the opposite side of the map
        dest_side = ([2, 3, 0, 1])[origin_side]

        make_turn_every_number_nodes = 20

        # calculate the route; the first hop counts towards the route length too
        first_edge = EDGES_PER_NODE[originnode][dest_side]
        new_node = first_edge[0]
        route_length += first_edge[2]
        last_turn = 0
        number_turns_done = 0
        route = np.array([originnode])
        sucessive_turns = 0
        while insideMap(G, new_node):
            previous_node = new_node
            # NOTE(review): this used to read len(route-1), which equals len(route);
            # if len(route) - 1 was intended, turns would happen one node later.
            if number_turns_done < abs(routes_number_turns) and len(route) - last_turn >= make_turn_every_number_nodes:
                # add the node just before the turn
                nodes_before_turn.append(new_node)
                # now turn
                new_dest_side, new_origin_side = getDestSide(origin_side, routes_number_turns)
                new_node = EDGES_PER_NODE[previous_node][new_dest_side]
                # is this node None and just following the previous route wasnt?
                if new_node is None and EDGES_PER_NODE[previous_node][dest_side] is not None:
                    new_node = EDGES_PER_NODE[previous_node][dest_side]
                else:
                    origin_side = new_origin_side
                    dest_side = new_dest_side
                    last_turn = len(route)
                    number_turns_done += 1
                    sucessive_turns += 1
                    make_turn_every_number_nodes = random.choice(np.arange(2, 5))
            else:
                sucessive_turns = 0
                new_node = EDGES_PER_NODE[new_node][dest_side]

            # does this exist?
            if new_node is None:
                start_lat, end_lat, end_lon, start_lon = BOUNDARIES
                lat_node, lon_node = getLatLonnode(G, previous_node)
                # a turn is still pending: take it at previous_node
                if number_turns_done < abs(routes_number_turns) and sucessive_turns < 2:
                    nodes_before_turn.append(previous_node)
                    dest_side, origin_side = getDestSide(origin_side, routes_number_turns)
                    new_node = EDGES_PER_NODE[previous_node][dest_side]
                    if new_node is None:
                        break
                    last_turn = len(route)
                    number_turns_done += 1
                    sucessive_turns += 1
                    make_turn_every_number_nodes = random.choice(np.arange(5, 10))
                # otherwise the route ends here; record the border it is close to
                elif geo.kwikdist(lat_node, lon_node, lat_node, end_lon) * nm < 110:
                    dest_side = 0
                    break
                elif geo.kwikdist(lat_node, lon_node, end_lat, lon_node) * nm < 160:
                    dest_side = 1
                    break
                elif geo.kwikdist(lat_node, lon_node, start_lat, lon_node) * nm < 210:
                    dest_side = 3
                    break
                elif geo.kwikdist(lat_node, lon_node, lat_node, start_lon) * nm < 360:
                    dest_side = 2
                    break
                else:
                    break

            # new_node is an edge slot: (next node id, heading [deg], length [m]).
            # Element 2 is the length; this used to add element 1 (the heading),
            # which corrupted routelength, distAvg and the spawn interval.
            route_length += new_node[2]
            new_node = new_node[0]
            route = np.append(route, new_node)
            lat_now, lon_now = getLatLonnode(G, new_node)

        # Accept routes that are long enough, made all their turns and end outside
        # the band tested below (lat_now/lon_now are only read when len(route) > 10,
        # i.e. after they were set for this route).
        # NOTE(review): -112.478 looks like a typo for -122.478; for maps around
        # lon -122 (e.g. the example boundaries in definecity) the test
        # ``lon_now > -112.478`` is never true -- confirm before changing it.
        if len(route) > 10 and number_turns_done >= abs(routes_number_turns) and \
                (lon_now > -112.478 or lon_now < -122.505 or lat_now > 37.760 or lat_now < 37.740):
            destnodes = np.append(destnodes, new_node)
            # route has all the nodes
            if route[-1] != new_node and new_node is not None:
                route = np.append(route, new_node)
            routelength = np.append(routelength, route_length)
            routes.append(route)
            # nodes stops only has the nodes needed to turn + final destination node
            nodes_before_turn.append(route[-1])
            nodes_stops.append(tuple(nodes_before_turn))
            destinationsides = np.append(destinationsides, dest_side)

            distAvg = np.average(routelength)
            spawnRate = (density * TAS_MAX * kts) / (distAvg)
            spawnInterval = 1.0 / spawnRate  # seconds

            #if len(routes) * spawnInterval - timePassed > 30*60:
            timePassed = len(routes) * spawnInterval

            # Total number of aircraft in scenario for the total scenario duration
            nacTotal = np.ceil(TOTAL_DURATION * 60.0 / spawnInterval)

    return routes[:int(np.ceil(nacTotal))], spawnInterval, routelength, distAvg, nodes_stops, destinationsides


def addedgetonode(G, node, edge, lat_node, lon_node):
    """Append the undirected ``edge`` to the raw edge list of ``node``.

    The stored entry is (other end of the edge, heading from node [deg],
    distance * nm), the same layout used for every edge slot in this module.
    """
    if EDGES_PER_NODE.get(node) is None:
        EDGES_PER_NODE[node] = list()

    neighbour = edge[1] if edge[0] == node else edge[0]
    neighbour_lat, neighbour_lon = getLatLonnode(G, neighbour)
    heading, distance = geo.kwikqdrdist(lat_node, lon_node, neighbour_lat, neighbour_lon)
    EDGES_PER_NODE[node].append((neighbour, heading, distance * nm))


def addtoEDGES(node0, node1):
    """Record the undirected edge node0-node1 in EDGES as (smaller id, larger id)."""
    low_end, high_end = min(node0, node1), max(node0, node1)
    EDGES.add((low_end, high_end))


def addtoNodes(lat, lon, NODES_it, onlyTransitionNode=False):
    """Register the point (lat, lon) in NODES unless an identical point is already known.

    :param lat: latitude of the point.
    :param lon: longitude of the point.
    :param NODES_it: id to assign if the point is new.
    :param onlyTransitionNode: if True, a newly created node is also appended to
        TRANSITION_NODES.
    :return: (next free id, id of the node at (lat, lon)); the id counter only
        advances when a new node is created.
    """
    point = (lat, lon)
    # A single scan of NODES; the old version searched it once in a loop and
    # then a second time with ``in NODES.values()``.
    for key, value in NODES.items():
        if value == point:
            return NODES_it, key

    NODES[NODES_it] = point
    if onlyTransitionNode:
        TRANSITION_NODES.append(NODES_it)
    return NODES_it + 1, NODES_it


def checkEdgeOk(heading, expected_headings):
    """Return True when ``heading`` is within 3 degrees of any expected heading.

    The comparison is a plain absolute difference (no 0/360 wrap-around).
    """
    return any(abs(heading - candidate) < 3 for candidate in expected_headings)


def replaceedgeinnode(node, new_other_edge, old_other_edge, G):
    """Redirect every edge of ``node`` that points to ``old_other_edge`` towards ``new_other_edge``.

    Redirected entries get a freshly computed heading and distance * nm; every
    other entry (None slots included) is kept. EDGES_PER_NODE[node] is replaced
    by a new list.
    """
    def _redirected_entry():
        node_lat, node_lon = getLatLonnode(G, node)
        target_lat, target_lon = getLatLonnode(G, new_other_edge)
        heading, distance = geo.kwikqdrdist(node_lat, node_lon, target_lat, target_lon)
        return (new_other_edge, heading, distance * nm)

    rebuilt = []
    for slot in EDGES_PER_NODE[node]:
        points_to_old = slot is not None and slot[0] == old_other_edge
        rebuilt.append(_redirected_entry() if points_to_old else slot)

    EDGES_PER_NODE[node] = rebuilt


# make sure we only have the 4 expected edges (with expected heading)
def removeEdge(node, edges, G):
    """Sort the raw edge list of ``node`` into four heading slots.

    Slots 0..3 expect headings of about 86, 176, 266 and 356 degrees. An edge
    within 5 degrees of a slot's heading fills that slot. If some edges were not
    placed, each unplaced edge goes to its nearest still-empty slot when it is
    within 10 degrees; otherwise its undirected pair is removed from EDGES.

    If exactly two slots are still empty afterwards, all four slots are cleared.
    When the two remaining edges are opposite (slots 0/2 or 1/3), the node's two
    neighbours are first connected directly: EDGES is updated and
    replaceedgeinnode redirects their entries.

    :param node: id of the node being processed.
    :param edges: raw edge entries (neighbour id, heading [deg], length) of the node.
    :param G: OSMnx graph, passed on for coordinate lookups.
    :return: the 4-slot list, which is also stored in EDGES_PER_NODE[node].
    """
    expected_headings = [86, 176, 266, 356]
    correct_headings = [None, None, None, None]

    # look for the expected headings
    # NOTE(review): plain abs() differences ignore the 0/360 wrap, so e.g. a heading
    # of 1 deg is not matched to the 356 deg slot -- confirm this cannot occur
    for it in range(len(expected_headings)):
        for edge in edges:
            if abs(expected_headings[it] - edge[1]) < 5:
                correct_headings[it] = edge

    # NOTE(review): closest_hdg is updated below but never read
    closest_hdg = 10000
    missing_edges = sum(x is None for x in correct_headings)
    # more empty slots than the number of edges allows -> some edges were not placed
    if missing_edges > len(expected_headings) - len(edges):
        for edge in edges:
            if edge not in correct_headings:
                closest_heading = min(range(len(expected_headings)), key=lambda i: abs(expected_headings[i] - edge[1]))
                dig_hdg = abs(edge[1] - expected_headings[closest_heading])
                if dig_hdg < closest_hdg:
                    closest_hdg = dig_hdg
                if correct_headings[closest_heading] is None and dig_hdg < 10:
                    correct_headings[closest_heading] = edge
                # edge cannot be placed: drop it from the undirected edge set
                elif (min(node, edge[0]), max(node, edge[0])) in EDGES:
                    EDGES.remove((min(node, edge[0]), max(node, edge[0])))

    missing_edges = sum(x is None for x in correct_headings)
    # we dont need this node, its just in the way
    if missing_edges == 2:
        edges_delete = []
        if correct_headings[0] is not None and correct_headings[2] is not None:
            edges_delete.append(correct_headings[0][0])
            edges_delete.append(correct_headings[2][0])
        if correct_headings[1] is not None and correct_headings[3] is not None:
            edges_delete.append(correct_headings[1][0])
            edges_delete.append(correct_headings[3][0])
        # straight pass-through node: connect its two neighbours directly
        if len(edges_delete) > 0:
            if (min(node, edges_delete[1]), max(node, edges_delete[1])) in EDGES:
                EDGES.remove((min(node, edges_delete[1]), max(node, edges_delete[1])))
            if (min(edges_delete[0], node), max(edges_delete[0], node)) in EDGES:
                EDGES.remove((min(edges_delete[0], node), max(edges_delete[0], node)))
            addtoEDGES(edges_delete[0], edges_delete[1])
            replaceedgeinnode(edges_delete[0], edges_delete[1], node, G)
            replaceedgeinnode(edges_delete[1], edges_delete[0], node, G)
        # NOTE(review): for two perpendicular remaining edges edges_delete stays empty,
        # yet the slots are still cleared and the node's EDGES pairs are left in place
        correct_headings = [None, None, None, None]

    EDGES_PER_NODE[node] = correct_headings
    return correct_headings


def drawedges(G):
    """Convert the OSMnx graph into this module's 4-slot street grid and draw it.

    Steps (all results are written to module globals):
    1. collect undirected edges in EDGES; interior points of an edge's
       'geometry' become extra nodes in NODES (ids counted up from 0);
    2. build each node's raw edge list in EDGES_PER_NODE, then sort it into
       the four heading slots with removeEdge;
    3. repeatedly delete nodes with fewer than two usable slots;
    4. make neighbouring slots symmetric;
    5. splice out neighbours on a straight run (only E-W or only N-S slots)
       that continue in the same direction; they are recorded in REPLACED_NODES;
    6. classify border nodes per side into EDGES_BOUNDARY;
    7. cache coordinates in NODES and draw every EDGES pair as a BlueSky LINE.

    :param G: OSMnx graph of the area (see definecity).
    """
    # the cleaned graph is bound to the local name only (whether the caller's
    # graph object is also changed depends on OSMnx)
    G = ox.utils_graph.remove_isolated_nodes(G)

    NODES_it = 0

    number = 0
    for edge in list(G.edges(data=True)):
        # G.edges return duplicates edges: edge(A-B), edges(B-A) - we just want to save each edge once,
        # we will later decide the traffic direction in this edge
        # maybe the line is not a straight line
        if 'geometry' in edge[2]:
            lines = edge[2]['geometry']
            new_nodes = []
            # interior points of the polyline (xy[1] = latitudes, xy[0] = longitudes)
            for line_it in range(1, len(lines.xy[1]) - 1):
                NODES_it, node_id = addtoNodes(lines.xy[1][line_it], lines.xy[0][line_it], NODES_it)
                new_nodes.append(node_id)
            # NOTE(review): assumes new_nodes is non-empty (IndexError otherwise) and holds
            # consecutive increasing ids; addtoNodes returns an existing id for an
            # already-known point, which would break the range() chaining below
            addtoEDGES(edge[0], new_nodes[0])
            addtoEDGES(edge[1], new_nodes[-1])
            for it in range(new_nodes[0], new_nodes[-1]):
                addtoEDGES(it, it + 1)
        else:
            addtoEDGES(edge[0], edge[1])

    # build the raw per-node edge lists from the undirected edge set
    for edge in EDGES:
        lat0, lon0 = getLatLonnode(G, edge[0])
        lat1, lon1 = getLatLonnode(G, edge[1])
        addedgetonode(G, edge[0], edge, lat0, lon0)
        addedgetonode(G, edge[1], edge, lat1, lon1)

    global EDGES_PER_NODE
    # sort every node into 4 heading slots; nodes with < 2 usable slots are dropped
    nodes_remove = []
    for node, edges in EDGES_PER_NODE.items():
        new_edges = removeEdge(node, edges, G)
        ok_edges = sum(x is not None for x in new_edges)
        if ok_edges < 2:
            nodes_remove.append(node)

    # remove those nodes; their neighbours may in turn drop below 2 slots, so repeat
    while len(nodes_remove) > 0:
        for node_remove in nodes_remove:
            edges = EDGES_PER_NODE[node_remove]
            for edge in edges:
                if edge is not None and (min(node_remove, edge[0]), max(node_remove, edge[0])) in EDGES:
                    EDGES.remove((min(node_remove, edge[0]), max(node_remove, edge[0])))
            del EDGES_PER_NODE[node_remove]

        nodes_remove_new = []
        for node, edges in EDGES_PER_NODE.items():
            new_edges = []
            for edge in edges:
                if edge is not None and edge[0] not in nodes_remove:
                    new_edges.append(edge)
                else:
                    new_edges.append(None)
            ok_edges = sum(x is not None for x in new_edges)
            if ok_edges > 1:
                EDGES_PER_NODE[node] = new_edges
            else:
                nodes_remove_new.append(node)
        nodes_remove = nodes_remove_new

    # check that every node has the correct connections
    for node, edges in EDGES_PER_NODE.items():
        for it in range(len(edges)):
            if edges[it] is not None:
                other_node = edges[it][0]
                # self-loop: clear the slot
                if other_node == node:
                    edges[it] = None
                    EDGES_PER_NODE[node] = edges
                else:
                    # the neighbour should point back through the opposite slot
                    opposite_edge = (it + 2) % 4
                    edges_other_node = EDGES_PER_NODE[other_node]
                    edge_other_node = edges_other_node[opposite_edge]
                    if edge_other_node is not None and edge_other_node[0] != node:
                        # the two nodes disagree about this edge; only reported
                        print('??')
                    elif edge_other_node is None:
                        lat0, lon0 = getLatLonnode(G, node)
                        lat1, lon1 = getLatLonnode(G, other_node)
                        qdr, dist = geo.kwikqdrdist(lat1, lon1, lat0, lon0)
                        edges_other_node[opposite_edge] = (node, qdr, dist * nm)
                        EDGES_PER_NODE[other_node] = edges_other_node

    # remove nodes in the middle too close to other nodes
    # (EDGES itself is not updated in this pass)
    nodes_remove = []
    for node, edges in EDGES_PER_NODE.items():
        # NOTE(review): np.copy turns the slot list into a NumPy array; with a mix of
        # tuples and None this needs an object array, which recent NumPy refuses to
        # build implicitly (ragged input) -- confirm the NumPy version in use
        copy_edges = np.copy(edges)
        for edge_it in range(len(edges)):
            edge = edges[edge_it]
            if edge is not None:
                distance = edge[2]
                node_to_remove = edge[0]
                edges_deleted_node = EDGES_PER_NODE[node_to_remove]
                # neighbour only has one axis of slots (N-S or E-W) ...
                if sum(x is None for x in edges_deleted_node) >= 2 and \
                        (edges_deleted_node[3] == edges_deleted_node[1] == None or edges_deleted_node[0] ==
                         edges_deleted_node[2] == None):
                    # ... and continues in the same direction: splice it out
                    if edges_deleted_node[edge_it] is not None:
                        REPLACED_NODES.append(node_to_remove)
                        # print('deleted', node_to_remove, distance, edges_deleted_node, edge[1])
                        for edges_deleted_node_it in range(len(edges_deleted_node)):
                            if (edges[edges_deleted_node_it] is None and edges_deleted_node[
                                edges_deleted_node_it] is not None) \
                                    or edges_deleted_node_it == edge_it:
                                new_node_to_connect_to = edges_deleted_node[edges_deleted_node_it][0]
                                lat0, lon0 = getLatLonnode(G, node)
                                lat1, lon1 = getLatLonnode(G, new_node_to_connect_to)
                                qdr, dist = geo.kwikqdrdist(lat0, lon0, lat1, lon1)
                                copy_edges[edges_deleted_node_it] = (new_node_to_connect_to, qdr, dist * nm)
                                edges_node_other_side = EDGES_PER_NODE[new_node_to_connect_to]
                                opposite_step = (edges_deleted_node_it + 2) % 4
                                edges_node_other_side[opposite_step] = (node, (qdr + 180) % 360, dist * nm)
                                EDGES_PER_NODE[new_node_to_connect_to] = edges_node_other_side
                        EDGES_PER_NODE[node] = copy_edges
                        nodes_remove.append(node_to_remove)

    for node in nodes_remove:
        del EDGES_PER_NODE[node]

    # classify border nodes: a missing slot facing a nearby map edge marks the side
    # (kwikdist * nm gives metres assuming kwikdist returns nautical miles); the
    # sides are tested in this order and a node ends up in at most one list
    start_lat, end_lat, end_lon, start_lon = BOUNDARIES
    for node in list(EDGES_PER_NODE.keys()):
        # remove duplicates and order per edge qdr
        # EDGES_PER_NODE[node] = sorted(list(set(EDGES_PER_NODE[node])), key=lambda x: x[1])
        edges = EDGES_PER_NODE[node]
        lat_node, lon_node = getLatLonnode(G, node)
        if edges[0] is None and geo.kwikdist(lat_node, lon_node, lat_node, end_lon) * nm < 100:  # right
            EDGES_BOUNDARY[0].append(node)
        elif edges[1] is None and geo.kwikdist(lat_node, lon_node, end_lat, lon_node) * nm < 150:  # bottom
            EDGES_BOUNDARY[1].append(node)
        elif edges[-1] is None and geo.kwikdist(lat_node, lon_node, start_lat, lon_node) * nm < 200:  # top
            EDGES_BOUNDARY[3].append(node)
        elif edges[2] is None and geo.kwikdist(lat_node, lon_node, lat_node, start_lon) * nm < 350:  # left
            EDGES_BOUNDARY[2].append(node)

    # cache endpoint coordinates and draw each edge as a BlueSky LINE object
    for edge in EDGES:
        lat1, lon1 = getLatLonnode(G, edge[0])
        lat2, lon2 = getLatLonnode(G, edge[1])
        NODES[edge[0]] = (lat1, lon1)
        NODES[edge[1]] = (lat2, lon2)
        # qdr, dist = geo.kwikqdrdist(lat1, lon1, lat2, lon2)
        # print(edge[0], lat1, lon1, edge[1], lat2, lon2, qdr)
        bs.scr.objappend('LINE', "line_" + str(number), np.append((lat1, lon1), (lat2, lon2)))
        number += 1

    # fig, axes = plt.subplots()
    # fig.set_size_inches(20, 20)
    # plt.grid(True)
    # axes.set_xlim([start_lon - 0.001, end_lon + 0.001])
    # axes.set_ylim([end_lat - 0.001, start_lat + 0.001])
    # color = ['red', 'black', 'blue', 'green']
    # for it in range(len(EDGES_BOUNDARY)):
    #     for it2 in range(len(EDGES_BOUNDARY[it])):
    #         lat, lon = getLatLonnode(G, EDGES_BOUNDARY[it][it2])
    #         axes.plot(lon, lat, marker='o', color=color[it])
    # # fig.savefig(os.getcwd() + "\\testfiles\\map_borders.png", bbox_inches='tight')
    #
    # # fig, axes = plt.subplots()
    # # fig.set_size_inches(20, 20)
    # # plt.grid(True)
    # # axes.set_xlim([start_lon - 0.001, end_lon + 0.001])
    # # axes.set_ylim([end_lat - 0.001, start_lat + 0.001])
    # for edge in EDGES:
    #     lat0, lon0 = getLatLonnode(G, edge[0])
    #     lat1, lon1 = getLatLonnode(G, edge[1])
    #     axes.plot([lon0, lon1], [lat0, lat1], color='b', linewidth=1)
    # fig.savefig(os.getcwd() + "\\testfiles\\map.png", bbox_inches='tight')


def definecity(boundaries):
    """Load the street grid inside ``boundaries`` and optionally write scenario files.

    Downloads the drivable OSMnx graph for the bounding box, converts it into
    the grid structures (drawedges), stores it in GRAPH and sets
    FINISHED_READING_DATA. When COMPUTE_NEW_ROUTES is True, one pair of scenario
    files (createfile / createFilePoint) is written for every combination of
    DENSITIES entry, origin-side set and ROUTE_TURNS entry.

    :param boundaries: [north, south, east, west] in degrees,
        e.g. [37.813, 37.776, -122.38, -122.45].
    """
    global BOUNDARIES, GRAPH, FINISHED_READING_DATA
    BOUNDARIES = boundaries
    # OSMNX MUST BE VERSION 0.14.0 - OTHER VERSIONS HAVE DIFFERENT NODES!!
    # (version 0.16.0, for example, does not have node 65340962)
    G = ox.graph_from_bbox(boundaries[0], boundaries[1], boundaries[2], boundaries[3], network_type='drive')

    drawedges(G)

    GRAPH = G

    # WE CANNOT START THE SIMULATIONS WHILE THIS DATA IS STILL BEING PROCESSED
    FINISHED_READING_DATA = True
    file_number = 0

    # Work-around for an as yet unexplained defect on the original map: clear the
    # last (north, ~356 deg) edge slot of node 65293150. That node is specific to
    # one bounding box, so skip the patch when it is absent instead of raising
    # KeyError for every other map.
    patched_edges = EDGES_PER_NODE.get(65293150)
    if patched_edges is not None:
        patched_edges[-1] = None
        EDGES_PER_NODE[65293150] = patched_edges

    # origin-side sets: every non-empty subset of the four border sides
    ORIGIN = [0, 1, 2, 3,
              [0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3],
              [0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3],
              [0, 1, 2, 3]]

    # number of scenario files a full run produces
    print(len(DENSITIES) * len(ORIGIN) * len(ROUTE_TURNS))

    if COMPUTE_NEW_ROUTES:
        for density in DENSITIES:
            for origin_side in ORIGIN:
                # use range(len(COMBINATION_SCENARIOS)) for the combination scenarios
                for route_type in range(len(ROUTE_TURNS)):
                    routes, spawnInterval, routelength, distAvg, nodes_stops, destinationsides = computeRoutes(G, density, route_type, origin_side)
                    createfile(routes, density, spawnInterval, routelength, boundaries, route_type, file_number, distAvg, destinationsides, origin_side)
                    createFilePoint(nodes_stops, routes, density, boundaries, route_type, distAvg, spawnInterval, file_number, origin_side)
                    print('file written', file_number)
                    file_number += 1
        print('finished files')


def tim2txt(t):
    """Format seconds ``t`` as HH:MM:SS.hh (hundredths truncated, zero-padded to two digits)."""
    clock = strftime("%H:%M:%S.", gmtime(t))
    hundredths = int((t - int(t)) * 100.)
    return clock + format(hundredths, "02d")


def i2txt(i, n):
    """Left-pad ``str(i)`` with '0' characters up to ``n`` characters (longer strings are kept)."""
    return str(i).rjust(n, "0")


def _is_heading_change(heading_a, heading_b, threshold=5):
    """Return True when two headings (degrees) differ by more than ``threshold``.

    The difference is wrapped into [-180, 180] before taking its magnitude,
    so headings on either side of north (e.g. 359.5 and 0.5) are 1 degree
    apart. A plain ``abs(heading_a - heading_b)`` would report 359 degrees
    there and flag a spurious turn on every leg that crosses north.
    """
    return abs((heading_a - heading_b + 180.0) % 360.0 - 180.0) > threshold


def createfile(routes, density, spawnInterval, routelength, boundaries, route_type, file_number, distAvg, destinationsides, origin_side):
    """Write a BlueSky scenario file with one aircraft per route.

    Every aircraft gets a CRE, DEST, ADDWPT and VNAV command, all issued at
    its spawn time. Spawn times start at 0 and advance by ``spawnInterval``
    seconds per aircraft. A waypoint is flown at ``TAS_CURVE`` (preceded by
    FLYTURN/TURNRAD when the heading changes there) whenever the heading
    differs by more than 5 degrees, otherwise at ``TAS_MAX``.

    :param routes: one node sequence per aircraft; each needs >= 3 nodes,
        which are resolved with ``getLatLonnode(GRAPH, node)``
    :param density: written to the header as "Num Instantaneous A/C"
    :param spawnInterval: seconds between consecutive spawns
    :param routelength: per-route value written as "Direct Distance"
    :param boundaries: written to the header
    :param route_type: written to the header as "Turns"
    :param file_number: index used in the output file name
    :param distAvg: written to the header as "Average Distance" (m)
    :param destinationsides: per-route side code; 1/2/3 map to an exit
        heading of 180/90/270 degrees, anything else to 0
    :param origin_side: written to the header
    :raises FileExistsError: if the target scenario file already exists
    """
    # Exit heading (degrees) used to place the point beyond the destination.
    exit_heading_by_side = {1: 180, 2: 90, 3: 270}

    # Storage list for scenario definition
    lines = []

    lines.append('# ##################################################### #' + "\n")
    lines.append('# Num Instantaneous A/C: ' + str(density) + "\n")
    lines.append('# Spawned A/C: ' + str(len(routes)) + "\n")
    lines.append('# Turns: ' + str(route_type) + "\n")
    lines.append('# Origin Side: ' + str(origin_side) + "\n")
    lines.append('# Boundaries: ' + str(boundaries) + "\n")
    lines.append('# Average Distance: ' + str(distAvg) + " m \n")
    lines.append('# ##################################################### #' + "\n")

    spawn_seconds = 0
    for it, route in enumerate(routes):
        aircraftname = 'AC' + format(it + 1, '04d')
        # Convert time to hh:mm:ss format; all commands of this aircraft use it.
        spawntime = tim2txt(spawn_seconds) + '>'

        lines.append('\n# ' + aircraftname + ' Direct Distance ' + str(routelength[it]) + '\n')

        origin_lat, origin_lon = getLatLonnode(GRAPH, route[0])
        secondpoint_lat, secondpoint_lon = getLatLonnode(GRAPH, route[1])
        thirdpoint_lat, thirdpoint_lon = getLatLonnode(GRAPH, route[2])
        dest_lat, dest_lon = getLatLonnode(GRAPH, route[-1])

        currentheading, _ = geo.kwikqdrdist(origin_lat, origin_lon, secondpoint_lat, secondpoint_lon)
        next_heading, _ = geo.kwikqdrdist(secondpoint_lat, secondpoint_lon, thirdpoint_lat, thirdpoint_lon)
        # Spawn at curve speed if the very first waypoint is already a turn.
        initial_speed = TAS_CURVE if _is_heading_change(currentheading, next_heading) else TAS_MAX

        lines.append(
            spawntime + 'CRE,' + aircraftname + ',MAVIC,' + str(origin_lat) + ',' + str(origin_lon) + ',' + str(
                currentheading % 360) + ',50,' + str(initial_speed) + "\n")

        # Final point: 1000 m (converted to nm) past the destination, on the
        # heading associated with the destination side.
        heading_final_point = exit_heading_by_side.get(destinationsides[it], 0)
        final_lat, finallon = geo.kwikpos(dest_lat, dest_lon, heading_final_point, 1000 / nm)

        lines.append(spawntime + 'DEST,' + aircraftname + ',' + str(final_lat) + ',' + str(finallon) + "\n")

        for wpy in range(1, len(route) - 1):
            node_lat, node_lon = getLatLonnode(GRAPH, route[wpy])
            nextnode_lat, nextnode_lon = getLatLonnode(GRAPH, route[wpy + 1])

            newheading, _ = geo.kwikqdrdist(node_lat, node_lon, nextnode_lat, nextnode_lon)

            speed = str(TAS_MAX)
            # currentheading only moves when a turn is detected, so gradual
            # drift accumulates until it exceeds the threshold.
            if _is_heading_change(currentheading, newheading):
                currentheading = newheading
                lines.append(spawntime + 'ADDWPT ' + aircraftname + ' FLYTURN' + "\n")
                lines.append(spawntime + 'ADDWPT ' + aircraftname + ' TURNRAD ' + str(TURNRAD) + "\n")
                speed = str(TAS_CURVE)

            # Look two nodes ahead so the aircraft already brakes for an
            # upcoming turn; on the last leg compare against the destination.
            if wpy + 2 < len(route):
                next_next_node_lat, next_next_node_lon = getLatLonnode(GRAPH, route[wpy + 2])
                lookahead_heading, _ = geo.kwikqdrdist(node_lat, node_lon, next_next_node_lat, next_next_node_lon)
                if _is_heading_change(currentheading, lookahead_heading):
                    speed = str(TAS_CURVE)
            else:  # dest point
                head_dest, _ = geo.kwikqdrdist(node_lat, node_lon, dest_lat, dest_lon)
                if _is_heading_change(currentheading, head_dest):
                    speed = str(TAS_CURVE)

            lines.append(
                spawntime + 'ADDWPT ' + aircraftname + ' ' + str(node_lat) + ',' + str(
                    node_lon) + ", 50, " + speed + "\n")

        # head_dest is always set: routes have >= 3 nodes, so the last loop
        # iteration takes the "dest point" branch.
        speed = str(TAS_MAX)
        if _is_heading_change(heading_final_point, head_dest):
            lines.append(spawntime + 'ADDWPT ' + aircraftname + ' FLYTURN' + "\n")
            lines.append(spawntime + 'ADDWPT ' + aircraftname + ' TURNRAD ' + str(TURNRAD) + "\n")
            speed = str(TAS_CURVE)
        lines.append(
            spawntime + 'ADDWPT ' + aircraftname + ' ' + str(dest_lat) + ',' + str(dest_lon) + ", 50, " + speed + "\n")

        lines.append(spawntime + 'ADDWPT ' + aircraftname + ' ' + str(final_lat) + ',' + str(finallon) + ", 50, " + str(
            TAS_MAX) + "\n")
        lines.append(spawntime + 'VNAV ' + aircraftname + ' ON' + "\n")
        spawn_seconds += spawnInterval

    # Mode "x" deliberately refuses to overwrite an existing scenario file.
    path = os.path.join(os.getcwd(), "testfiles", "intent-test-Rep" + str(file_number) + '.scn')
    with open(path, "x") as f:
        f.writelines(lines)


def createFilePoint(nodes_stops, routes, density, boundaries, route_type, distAvg, spawnInterval, file_number, origin_side):
    """Write the companion "point" file: one CSV-like line per aircraft.

    Each line is ``<name>,<HH:MM:SS.hh>,<nodes_stops[i]>,[<route elements>]``.
    Spawn times start at 0 and advance by ``spawnInterval`` seconds per
    aircraft, matching the intent file written by ``createfile``.

    Numpy print options are applied only while formatting this file, so
    numpy's global print state is left unchanged. The previous
    ``np.set_printoptions`` call leaked into every later array print.

    :param nodes_stops: per-route value written as the third field
    :param routes: one route per aircraft (ndarray or array-like)
    :param density: written to the header as "Num Instantaneous A/C"
    :param boundaries: written to the header
    :param route_type: written to the header as "Turns"
    :param distAvg: written to the header as "Average Distance" (m)
    :param spawnInterval: seconds between consecutive spawns
    :param file_number: index used in the output file name
    :param origin_side: written to the header
    :raises FileExistsError: if the target file already exists
    """
    # Storage list for scenario definition
    lines = []

    lines.append('# ##################################################### #' + "\n")
    lines.append('# Num Instantaneous A/C: ' + str(density) + "\n")
    lines.append('# Spawned A/C: ' + str(len(routes)) + "\n")
    lines.append('# Turns: ' + str(route_type) + "\n")
    lines.append('# Origin Side: ' + str(origin_side) + "\n")
    lines.append('# Boundaries: ' + str(boundaries) + "\n")
    lines.append('# Average Distance: ' + str(distAvg) + " m \n")
    lines.append('# ##################################################### #' + "\n")

    spawn_seconds = 0
    # No truncation ("..."), no scientific notation, floats without decimals.
    with np.printoptions(threshold=np.inf, suppress=True, formatter={'float_kind': '{:0.0f}'.format}):
        for it, route in enumerate(routes):
            aircraftname = 'AC' + format(it + 1, '04d')
            # Convert time to hh:mm:ss format
            spawntime = tim2txt(spawn_seconds)

            # array2string needs an ndarray; asarray is a no-op for arrays.
            route_text = np.array2string(np.asarray(route), precision=0, separator=',', max_line_width=np.inf)
            lines.append(aircraftname + ',' + spawntime + ',' + str(nodes_stops[it]) + ',' + route_text + "\n")
            spawn_seconds += spawnInterval

    # Mode "x" deliberately refuses to overwrite an existing file.
    path = os.path.join(os.getcwd(), "testfiles", "point-test-Rep" + str(file_number) + '.scn')
    with open(path, "x") as f:
        f.writelines(lines)
