import random
import numpy as np
from bluesky.tools.aero import nm, ft
import bluesky as bs
from bluesky.tools import geo
import bluesky.tools.cityboundaries as cb
from itertools import count, filterfalse

# Altitude layer geometry
ALT_LAYER_SEP = 20  # ft; presumably the vertical spacing between consecutive layers
MIN_ALT = 57.5  # ft; presumably the altitude of the lowest layer

NUMBER_LAYERS = 6

# Altitude layers (some may be active, some may not).
# Each traffic layer is followed by a fast and a slow transition layer, so the
# pattern repeats every three layer indices:
# 0 - traffic layer
# 1 - transition layer - fast
# 2 - transition layer - slow
# 3 - traffic layer
# 4 - transition layer - fast
# 5 - transition layer - slow
# 6 - traffic layer
# 7 - transition layer - fast
# 8 - transition layer - slow
# 9 - traffic layer
# 10 - transition layer - fast
# 11 - transition layer - slow
# 12 - traffic layer
# 13 - transition layer - fast
# 14 - transition layer - slow
# 15 - traffic layer
# 16 - transition layer - fast

# every third layer index, one per traffic layer: [0, 3, 6, 9, 12, 15]
TRAFFIC_LAYERS = list(range(0, 3 * NUMBER_LAYERS, 3))

def get_path(aicraftname, route, heading_per_layer, active_layers, aicraft_starting_per_layer):
    """Assign a traffic layer to every leg of *route* and flag where the layer changes.

    Args:
        aicraftname: aircraft name; not used (only referenced by a former debug print).
        route: sequence of graph node ids. Nodes in ``cb.REPLACED_NODES`` are dropped
            and consecutive repeated nodes are collapsed before layers are assigned.
        heading_per_layer: forwarded to ``getAltitudeLayer``.
        active_layers: forwarded to ``getAltitudeLayer``.
        aicraft_starting_per_layer: per-layer aircraft counter, forwarded to and
            updated by ``getAltitudeLayer``.

    Returns:
        tuple ``(alts, route_initial, vars, aicraft_starting_per_layer)`` where
        ``alts`` holds one traffic layer per leg with the first entry duplicated in
        front, ``route_initial`` is the filtered route as int64 (before repeated
        nodes are collapsed), and ``vars`` is 1 wherever ``alts`` differs from the
        previous entry and 0 elsewhere.

    Raises:
        IndexError: if fewer than two distinct nodes remain (there is no leg).

    Note:
        The transition planning on ``alts_modified`` below is not part of the
        return value, but it can add auxiliary nodes and edges to ``cb`` (via
        ``do_multi_step_trans``).
    """
    alts = np.array([])
    # nodes from the osmnx package are int64 not int
    route = [x for x in route if x not in cb.REPLACED_NODES]
    route = np.array(route).astype(np.int64)
    route_initial = np.copy(route)

    # remove consecutive repeated nodes
    it = 1
    while it < len(route):
        if route[it] == route[it - 1]:
            route = np.delete(route, it)
        else:
            it += 1

    # one traffic layer per leg; the layers assigned so far steer the next choice
    for it in range(1, len(route)):
        alt, *_, aicraft_starting_per_layer = getAltitudeLayer(
            route[it - 1], route[it], heading_per_layer, active_layers, alts, aicraft_starting_per_layer)
        alts = np.append(alts, alt)

    alts_modified = np.copy(alts)
    # duplicate the first layer so the aircraft does not start immediately on the
    # first waypoint; it will "ascend" from the transition layer
    alts = np.insert(alts, 0, alts[0])
    alts_modified = np.insert(alts_modified, 0, alts_modified[0])
    change_flags = np.zeros(len(alts_modified))

    # find altitude changes and plan the transitions between layers
    it = 1
    while it < len(alts_modified) - 1:
        if abs(alts_modified[it + 1] - alts_modified[it]) <= 2:
            it += 1
            continue

        # is this a temporary transition (A, B, A)?
        temp_transition, alts_modified, it, change_flags = do_temp_trans(it, alts_modified, change_flags)
        if temp_transition:
            continue

        # exception - we ended up in a transition layer
        if alts_modified[it] not in TRAFFIC_LAYERS:
            alts_modified, route, change_flags, it = handle_exception_transition_layer(
                alts_modified, route, it, change_flags, heading_per_layer, active_layers)

        var_alt = abs(alts_modified[it + 1] - alts_modified[it])
        if var_alt < 4:
            alts_modified, it, change_flags = do_1step_trans(it, alts_modified, change_flags, heading_per_layer)
        else:
            alts_modified, it, route, change_flags = do_multi_step_trans(
                it, alts_modified, route, change_flags, heading_per_layer)

    # final check: drop repeated nodes (same node, same layer) that have no edge to
    # themselves; stop at the first genuinely disconnected pair
    to_remove = []
    for it in range(len(route) - 1):
        edges = cb.EDGES_PER_NODE[route[it]]
        if any(edge is not None and edge[0] == route[it + 1] for edge in edges):
            continue
        if route[it] == route[it + 1] and alts_modified[it] == alts_modified[it + 1]:
            to_remove.append(it)
        else:
            break

    # delete all indices in one call: deleting them one at a time shifts every later
    # index, removing the wrong elements (or running past the end of the array)
    if to_remove:
        alts_modified = np.delete(alts_modified, to_remove)
        route = np.delete(route, to_remove)
        change_flags = np.delete(change_flags, to_remove)

    # flag turns of more than 5 degrees within a single layer
    # NOTE(review): these flags are overwritten just below, so this loop only has an
    # effect if one of its lookups raises; kept to preserve behaviour.
    for it in range(len(route) - 2):
        if alts_modified[it] == alts_modified[it + 1] == alts_modified[it + 2]:
            lat0, lon0 = cb.getLatLonnode(cb.GRAPH, route[it])
            lat1, lon1 = cb.getLatLonnode(cb.GRAPH, route[it + 1])
            lat2, lon2 = cb.getLatLonnode(cb.GRAPH, route[it + 2])
            qdr1, _ = geo.kwikqdrdist(lat0, lon0, lat1, lon1)
            qdr2, _ = geo.kwikqdrdist(lat1, lon1, lat2, lon2)
            if abs(qdr1 - qdr2) > 5:
                change_flags[it + 2] = True

    # returned flags: 1 wherever the assigned layer differs from the previous leg's
    change_flags = np.zeros(len(alts))
    change_flags[1:] = alts[1:] != alts[:-1]

    return alts, route_initial, change_flags, aicraft_starting_per_layer

def handle_exception_transition_layer(alts_modified, route, it, vars, heading_per_layer, active_layers):
    """Shift the layer at index *it* when it turned out to be a transition layer.

    If the next entry is higher, ``alts_modified[it]`` is raised by 1; otherwise
    (lower or equal) it is lowered by 2. The array is modified in place.

    ``heading_per_layer`` and ``active_layers`` are accepted but not used.

    Returns:
        ``(alts_modified, route, vars, it)``; only ``alts_modified`` is changed.
    """
    climbing = alts_modified[it + 1] > alts_modified[it]
    alts_modified[it] += 1 if climbing else -2
    return alts_modified, route, vars, it


def check_aux_layers(it, alts_modified):
    """Overwrite the run of equal layers that directly follows index *it*.

    The run starts at ``it + 1`` and extends while entries equal
    ``alts_modified[it + 1]``. It is set to 0 when ``alts_modified[it] == 2`` and
    to 10 otherwise. The array is modified in place.

    Returns:
        ``(alts_modified, 2)``.
    """
    start = it + 1
    stop = start + 1
    while stop < len(alts_modified) and alts_modified[stop] == alts_modified[start]:
        stop += 1

    alts_modified[start:stop] = 0 if alts_modified[it] == 2 else 10
    return alts_modified, 2


# temporary transition jut go on up
# alts_modified[it], alts_modified[it + 2], alts_modified[it + 3]
# alt0, alt2, alt3,
def do_temp_trans(it, alts_modified, vars):
    """Handle a temporary one-leg layer excursion starting at index *it*.

    The transition layer is ``alts_modified[it] + 1``, or ``alts_modified[it] - 1``
    when ``alts_modified[it] == 10``.
    NOTE(review): the 10 cap looks older than the current TRAFFIC_LAYERS (which go
    up to 15); confirm.

    Cases:
      * ``it + 2`` is past the end: only ``alts_modified[it + 1]`` is set to the
        transition layer, ``vars[-2]`` is flagged and *it* becomes
        ``len(alts_modified)``.
      * ``alts_modified[it + 2] == alts_modified[it]`` (an A, B, A pattern): entries
        *it* and ``it + 1`` are set to the transition layer and flagged, and *it*
        advances by 2.
      * otherwise nothing is modified and the transition is not temporary.

    Returns:
        ``(is_temp, alts_modified, it, vars)``; the arrays are modified in place.
    """
    current = alts_modified[it]
    transition_alt = current - 1 if current == 10 else current + 1

    # we've reached the end of the route
    if it + 2 >= len(alts_modified):
        alts_modified[it + 1] = transition_alt
        vars[-2] = True
        return True, alts_modified, len(alts_modified), vars

    # not temporary: the layer after next is not the current one
    if alts_modified[it + 2] != current:
        return False, alts_modified, it, vars

    for idx in (it, it + 1):
        alts_modified[idx] = transition_alt
        vars[idx] = True
    return True, alts_modified, it + 2, vars


# just going to upper or below altitude layer
# alts_modified[it + 1], alts_modified[it], it, alts_modified
def do_1step_trans(it, alts_modified, vars, heading_per_layer):
    """Move leg *it* one step towards the layer of the next leg.

    On a climb (next layer higher) the layer is raised by 2, passing the fast
    transition layer and landing in the slow one. Otherwise it is lowered by 1.
    ``vars[it]`` is flagged. Both arrays are modified in place.
    ``heading_per_layer`` is accepted but not used.

    Returns:
        ``(alts_modified, it + 1, vars)``.
    """
    if alts_modified[it + 1] > alts_modified[it]:
        delta = 2
    else:
        delta = -1
    alts_modified[it] += delta
    vars[it] = True
    return alts_modified, it + 1, vars


def addnewpoint(alts_modified, it, alt_layer, node, route, step, vars):
    """Insert an auxiliary waypoint (layer *alt_layer*, node *node*) at position *it*.

    A flag is inserted into *vars*: False when ``step == 3`` (going straight on in
    the transition layer), True otherwise (turning in the transition layer).
    NOTE(review): the flag goes in at ``it - 1`` while the layer and node go in at
    ``it``; confirm this offset is intended.

    Returns:
        New arrays ``(alts_modified, route, vars)``; the inputs are not modified.
    """
    going_straight = step == 3
    new_alts = np.insert(alts_modified, it, alt_layer)
    new_route = np.insert(route, it, node)
    new_flags = np.insert(vars, it - 1, not going_straight)
    return new_alts, new_route, new_flags


# change to a referential where node is going 'north'
def put_in_node_ref(step, current_layer, previous_node, current_node):
    """Translate a movement *step* given relative to the travel direction into an edge index.

    The travel direction is the bearing previous_node -> current_node. Edge indices
    0..3 correspond to the grid bearings 86.8, 176.8, 266.8 and 356.8 deg. In the
    frame where the aircraft travels 'up', step 0 is right, 1 down, 2 left, 3 up.
    ``current_layer`` is accepted but not used.

    Returns:
        The edge index (0..3) in the global frame.
    """
    lat_previous_node, lon_previous_node = cb.getLatLonnode(cb.GRAPH, previous_node)
    lat_current_node, lon_current_node = cb.getLatLonnode(cb.GRAPH, current_node)
    qdr, _ = geo.kwikqdrdist(lat_previous_node, lon_previous_node, lat_current_node, lon_current_node)

    right_qdr = (qdr + 90) % 360
    possible_directions = [86.8, 176.8, 266.8, 356.8]
    # Compare bearings on the circle. A plain abs() difference would map e.g.
    # right_qdr = 2 deg to 86.8 (84.8 away) instead of 356.8 (5.2 away).
    closest_heading = min(
        range(len(possible_directions)),
        key=lambda i: abs((possible_directions[i] - right_qdr + 180) % 360 - 180),
    )

    # edges = [right, down, left, up] in the travel frame: rotate by 'step' quarter turns
    return (closest_heading + step) % 4

def getOppositeStep(step):
    """Return the direction index opposite to *step* (two quarter turns further, modulo 4)."""
    half_turn = 2
    return (step + half_turn) % 4

def getNewNode(previous_node, step, edges):
    """Return the grid neighbour of *previous_node* in direction *step*, creating it if needed.

    The target point is ``distances[step]`` meters from *previous_node* along
    ``possible_directions[step]``. If an existing node other than *previous_node*
    lies within 15 m of it, that node is reused. Otherwise a new node is added to
    ``cb`` with the smallest free positive integer id, plus an edge back to
    *previous_node*.

    Args:
        previous_node: node id to step from.
        step: direction index 0..3 into ``possible_directions``.
        edges: edge list of *previous_node*, indexed by direction. It is updated in
            place and stored back into ``cb.EDGES_PER_NODE``. When a new node is
            created, the edge is only written if ``edges[step]`` was empty.

    Returns:
        ``(lat, lon, edges, node_id)`` of the reused or created node.
    """
    possible_directions = [86.8, 176.8, 266.8, 356.8]
    distances = [94.4, 207.6, 94.4, 207.6]  # meters

    lat_node, lon_node = cb.getLatLonnode(cb.GRAPH, previous_node)
    lat0, lon0 = geo.kwikpos(lat_node, lon_node, possible_directions[step], distances[step] / nm)

    # check if we ended up on (or very near) an existing node
    lat_nodes, lon_nodes, node_ids = cb.getLatLonAllNodes(cb.GRAPH)
    closest_node, _ = cb.getclosestnode([lat0], [lon0], lat_nodes, lon_nodes, node_ids)
    lat_closest_node, lon_closest_node = cb.getLatLonnode(cb.GRAPH, closest_node)
    _, snap_dist = geo.kwikqdrdist(lat0, lon0, lat_closest_node, lon_closest_node)

    if closest_node != previous_node and snap_dist * nm < 15:
        # Reuse the existing node. The edge describes previous_node -> closest_node,
        # like the other edge tuples built here (neighbour, bearing from this node,
        # distance in meters), not the small offset from the target point.
        edge_qdr, edge_dist = geo.kwikqdrdist(lat_node, lon_node, lat_closest_node, lon_closest_node)
        edges[step] = (closest_node, edge_qdr, edge_dist * nm)
        cb.EDGES_PER_NODE[previous_node] = edges
        return lat_closest_node, lon_closest_node, edges, closest_node

    # create a new node with the smallest positive id not yet in use
    used_keys = set(cb.NODES.keys())
    new_nodes_key = next(key for key in count(1) if key not in used_keys)
    cb.addtoNodes(lat0, lon0, new_nodes_key, True)
    if edges[step] is None:
        edges[step] = (new_nodes_key, possible_directions[step], distances[step])
        cb.EDGES_PER_NODE[previous_node] = edges

    # give the new node only the edge we need: the one leading back to previous_node
    new_nodes_edges = [None, None, None, None]
    opposite_step = getOppositeStep(step)
    new_nodes_edges[opposite_step] = (previous_node, possible_directions[opposite_step], distances[step])
    cb.EDGES_PER_NODE[new_nodes_key] = new_nodes_edges

    return lat0, lon0, edges, new_nodes_key


# create new points to allow A/C to transition altitude layers safely
def create_point_transition(node, step, alt_layer, mov_layer, it, alts_modified, route, vars, heading_per_layer):
    """Insert the neighbour of *node* reached by the relative movement *step* into the route.

    The relative *step* is turned into an edge index of *node* via ``put_in_node_ref``,
    with ``route[it - 2]`` as the previous node. If *node* has no edge in that
    direction yet, ``getNewNode`` creates or reuses one. The neighbour is then
    inserted at position *it* with layer *alt_layer* through ``addnewpoint``.
    ``heading_per_layer`` is accepted but not used.

    Returns:
        ``(alts_modified, route, next_node, vars)`` with the updated arrays and the
        id of the inserted node.
    """
    node_edges = cb.EDGES_PER_NODE.get(node)
    direction = put_in_node_ref(step, mov_layer, route[it - 2], node)
    if node_edges[direction] is None:
        # no neighbour in that direction yet: add an auxiliary node to the map
        _, _, node_edges, _ = getNewNode(node, direction, node_edges)

    next_node = node_edges[direction][0]
    alts_modified, route, vars = addnewpoint(alts_modified, it, alt_layer, next_node, route, step, vars)
    return alts_modified, route, next_node, vars

# more than 1 and up to 3 altitude layers
# alts_modified[it + 1], alts_modified[it]
# alt1, alt0,
def do_multi_step_trans(it, alts_modified, route, vars, heading_per_layer):
    """Plan a change of more than one layer at waypoint *it*.

    The leg at *it* is moved into the transition layer next to the target layer
    ``alts_modified[it + 1]``: one below it when climbing, two above it when
    descending.

    If the route makes a ~180 deg turn at ``route[it]``, two auxiliary waypoints
    (two right turns in the transition layer) are inserted so the turn has room.

    Afterwards, if the last node of the transition is not connected to the next
    route node, the gap is filled with the graph's shortest path, falling back to
    a grid path from ``createPath``.

    Returns:
        ``(alts_modified, it, route, vars)`` with *it* pointing past the transition.
    """
    climbing = alts_modified[it + 1] > alts_modified[it]

    lat0, lon0 = cb.getLatLonnode(cb.GRAPH, route[it - 1])
    lat1, lon1 = cb.getLatLonnode(cb.GRAPH, route[it])
    lat2, lon2 = cb.getLatLonnode(cb.GRAPH, route[it + 1])
    now_hdg, _ = geo.kwikqdrdist(lat0, lon0, lat1, lon1)
    target_hdg, _ = geo.kwikqdrdist(lat1, lon1, lat2, lon2)

    angle_dif = target_hdg - now_hdg
    if angle_dif < 0:
        angle_dif += 360

    # turn in the transition layer adjacent to the target layer
    if climbing:
        alts_modified[it] = alts_modified[it + 1] - 1
    else:
        alts_modified[it] = alts_modified[it + 1] + 2

    node = route[it]
    if abs(angle_dif - 180) < 2:
        # 180 deg turn: add two right turns (relative step 0) to make room for it
        turn_layer = alts_modified[it]
        it2 = it + 1
        for step in (0, 0):
            alts_modified, route, node, vars = create_point_transition(
                node, step, turn_layer, turn_layer, it2, alts_modified, route, vars, heading_per_layer)
            it2 += 1

        for it3 in range(it + 1, it2):
            vars[it3] = True
        it = it2 - 1

    # to set up the turn
    it = it + 1
    vars[it] = True

    # Check that the last node of the transition (route[it - 1]) is connected to the
    # next node of the route (route[it]); due to map imperfections this may not be
    # the case. This is the same pair the repair below fills in.
    last_tran_edges = cb.EDGES_PER_NODE[route[it - 1]]
    connected = any(edge is not None and edge[0] == route[it] for edge in last_tran_edges)

    if not connected:
        try:
            new_route, _, _, _ = cb.getshortestpath(route[it - 1], route[it])
        except Exception:
            # no usable path in the graph: build one on the grid instead
            new_route = createPath(route[it - 1], route[it])
        for extra_node in new_route[1:-1]:
            route = np.insert(route, it, extra_node)
            # just keep going in the transition layer
            alts_modified = np.insert(alts_modified, it, alts_modified[it - 1])
            vars = np.insert(vars, it, False)
            it += 1

    return alts_modified, it, route, vars

def createPath(node0, node1):
    """Build a chain of grid nodes leading from *node0* to *node1*.

    Repeatedly steps, via ``getNewNode`` (which reuses a nearby node or creates a
    new one), in the grid direction closest to the bearing towards *node1*. It
    stops once the remaining distance is no larger than one step in that
    direction. The last node is then linked to *node1*; the edge is only added if
    that direction slot is still empty.

    NOTE(review): there is no iteration cap. If ``getNewNode`` keeps snapping to
    nodes that do not get closer, the loop may not terminate.

    Returns:
        List of node ids starting with *node0* and ending with *node1*.
    """
    possible_directions = [86.8, 176.8, 266.8, 356.8]
    distances = [94.4, 207.6, 94.4, 207.6]  # meters

    def closest_direction(bearing):
        # circular difference, so that e.g. a bearing of 5 deg maps to 356.8
        # (north) rather than 86.8 (east)
        return min(range(len(possible_directions)),
                   key=lambda i: abs((possible_directions[i] - bearing + 180) % 360 - 180))

    new_route = [node0]
    lat0, lon0 = cb.getLatLonnode(cb.GRAPH, node0)
    lat1, lon1 = cb.getLatLonnode(cb.GRAPH, node1)
    qdr, dist = geo.kwikqdrdist(lat0, lon0, lat1, lon1)
    dist = dist * nm  # convert to meters, the unit of `distances`
    closest_heading = closest_direction(qdr)
    current_node = node0

    while dist > distances[closest_heading]:
        edges_current_node = cb.EDGES_PER_NODE[current_node]
        lat0, lon0, _, current_node = getNewNode(current_node, closest_heading, edges_current_node)
        new_route.append(current_node)
        qdr, dist = geo.kwikqdrdist(lat0, lon0, lat1, lon1)
        dist = dist * nm
        closest_heading = closest_direction(qdr)

    if new_route[-1] != node1:
        last_edges = cb.EDGES_PER_NODE[new_route[-1]]
        if last_edges[closest_heading] is None:
            last_edges[closest_heading] = (node1, qdr, dist)
        new_route.append(node1)
    return new_route


def getAltitudeLayer(point0, point1, heading_per_layer, active_layers, previous_alts, aicraft_starting_per_layer):
    """Pick the traffic layer for the leg *point0* -> *point1*.

    The leg bearing is binned into a quadrant code 1..4: the closest of 45, 135,
    225 and 315 deg. A layer is a candidate when its entry in *heading_per_layer*
    (flipped to bottom-to-top order) equals that code or, for a sequence entry,
    contains it. Once a path has started (*previous_alts* non-empty), inactive
    layers are excluded.

    When there are several candidates:
      1. stay in the current layer if it is a candidate;
      2. otherwise keep those needing the fewest layer changes;
      3. if several remain, prefer one already used on this path (the last such);
      4. else pick randomly. At the start of a path, pick instead among the
         layers with the fewest aircraft in *aicraft_starting_per_layer*.

    NOTE(review): *active_layers* is indexed in its given order while
    *heading_per_layer* is flipped first; confirm both use the same ordering.

    Returns:
        ``(altitude, qdr, lat0, lon0, lat1, lon1, aicraft_starting_per_layer)``
        where *altitude* is the layer number from ``TRAFFIC_LAYERS``. At the start
        of a path, the chosen layer's counter is incremented in place.

    Raises:
        ValueError: if no layer accepts the leg's heading.
    """
    # set it from bottom to top
    heading_per_layer = np.flip(heading_per_layer)

    # Quadrant centres. Index 0 (-inf) is never the closest, so a heading code of 0
    # marks a layer that accepts no heading.
    heading_transform = [-np.inf, 45, 135, 225, 315]
    expected_headings = []
    for it in range(len(heading_per_layer)):
        active = active_layers[it]
        hdg = heading_per_layer[it]
        if len(previous_alts) > 0 and not active:
            # do not use this layer; it was only there to help at the start of the path
            expected_headings.append(0)
        elif isinstance(hdg, (int, np.integer)):
            expected_headings.append(hdg)
        else:
            expected_headings.append(list(hdg))

    lat0, lon0 = cb.getLatLonnode(cb.GRAPH, point0)
    lat1, lon1 = cb.getLatLonnode(cb.GRAPH, point1)

    qdr, _ = geo.kwikqdrdist(lat0, lon0, lat1, lon1)
    correct_layer_hdg = min(range(len(heading_transform)), key=lambda i: abs(heading_transform[i] - qdr))

    # options of layers where to put the aircraft in
    candidates = []
    for layer, expected in enumerate(expected_headings):
        if isinstance(expected, (int, np.integer)):
            matches = expected == correct_layer_hdg
        else:
            matches = correct_layer_hdg in expected
        if matches:
            candidates.append(layer)
    options_layers = np.array(candidates, dtype=float)

    if len(options_layers) == 0:
        raise ValueError("no altitude layer accepts heading {:.1f} deg (leg {} -> {})".format(
            float(qdr), point0, point1))

    if len(options_layers) == 1:
        # start of the path or there is only one possibility anyway
        altitude = options_layers[0]
    else:
        if len(previous_alts) == 0:
            current_layer = None
        else:
            current_layer = TRAFFIC_LAYERS.index(previous_alts[-1])

        if current_layer is not None and current_layer in options_layers:
            # it is an option to continue in the same layer
            altitude = current_layer
        else:
            # number of layer changes needed to reach each valid layer
            if current_layer is not None:
                alt_variations = np.abs(options_layers - current_layer)
            else:
                alt_variations = np.zeros(len(options_layers))
            min_var = alt_variations.min()
            best_options = np.flatnonzero(alt_variations == min_var)
            options_layers = options_layers[best_options]

            if len(best_options) > 1:
                # several layers need the same variation: prefer one already used on this path
                chosen = -1
                for idx, layer in enumerate(options_layers):
                    if TRAFFIC_LAYERS[int(layer)] in previous_alts:
                        chosen = idx

                if chosen < 0:
                    # otherwise pick randomly ...
                    chosen = random.choice(range(len(options_layers)))
                    if len(previous_alts) == 0:
                        # ... except at the start, where aircraft are spread over the
                        # layers with the fewest starting aircraft
                        counts = aicraft_starting_per_layer[options_layers.astype(int)]
                        least_loaded = np.flatnonzero(counts == counts.min())
                        if len(least_loaded) > 1:
                            chosen = random.choice(least_loaded)
                        else:
                            # a scalar index, not a 1-element array (int() on arrays is deprecated)
                            chosen = least_loaded[0]
            else:
                chosen = 0
            altitude = options_layers[int(chosen)]

    # turn the traffic layer index into its actual layer number (accounting for transition layers)
    if len(previous_alts) == 0:
        aicraft_starting_per_layer[int(altitude)] += 1
    altitude = TRAFFIC_LAYERS[int(altitude)]

    return altitude, qdr, lat0, lon0, lat1, lon1, aicraft_starting_per_layer
