# -*- coding: utf-8 -*-
# Created by mjribeiro at 28-1-2021
import copy
import itertools

import bluesky as bs
import numpy as np
from bluesky.tools import geo
from bluesky.tools.aero import ft

# Number of altitude layers; findLayersHeading() returns one entry per layer.
NUMBER_LAYERS = 6

# Legend for the layer structures below. Each position is one altitude layer and
# holds the heading range allowed there: 1-4 are the ranges centred on 86.8, 176.8,
# 266.8 and 356.8 deg (see findLayersHeading), "a+b" / [a, b] means two ranges share
# the layer, and 0 marks an unused layer.
# A - 0     1   2   3    4   0
# B - 4     1   2   3    4   0
# C - 1+2   3+4
# D - 1+4   2+3
# NOTE(review): LIST_STRUCTURES[1] ends in 1 while legend B ends in 0 -- confirm.

# Structure A - 0
# Structure B - 1
# Structure C - 2
# Structure D - 3
# Earlier variant of LIST_STRUCTURES (mostly 4-layer structures), kept for reference:
# LIST_STRUCTURES = [[0, 1, 2, 3, 4, 0],  #0
#                    #[4, 1, 2, 3, 4, 1],  #1
#                    # [[1, 2], [3,4], [1, 2], [3,4], [1, 2], [3,4]],  #2
#                    # [[1, 4], [2, 3], [1, 4], [2, 3], [1, 4], [2, 3]],  #3
#
#                    # [0 ,[1, 2], [3, 4], [1, 2], [3, 4], 0],  # 4
#                    # [0, [1, 4], [2, 3], [1, 4], [2, 3], 0],  # 5
#                    # [0, 0, [3, 4], [1, 2], 0, 0],  # 6
#                    # [0, 0, [2, 3], [1, 4], 0, 0],  # 7
#                    [[1, 2], [3, 4], [1, 2], [3, 4]],  # 1
#                    [[1, 4], [2, 3], [1, 4], [2, 3]],  # 2
#                    [0, [3, 4], [1, 2], 0],  # 3
#                    [0, [2, 3], [1, 4], 0],  # 4
#                    ]

# Known layer structures, indexed by structure id. getStructureId() looks ids up
# here and appends structures it has not seen before (the list grows at runtime).
LIST_STRUCTURES = [[0, 1, 2, 3, 4,0],  #0 (A)
                   [4, 1, 2, 3, 4, 1],  #1 (B)
                   [[1, 2], [3,4], [1, 2], [3,4], [1, 2], [3,4]],  #2 (C)
                   [[1, 4], [2, 3], [1, 4], [2, 3], [1, 4], [2, 3]],  #3 (D)
                   [0 ,[1, 2], [3, 4], [1, 2], [3, 4], 0],  # 4
                   [0, [1, 4], [2, 3], [1, 4], [2, 3], 0],  # 5
                   [0, 0, [3, 4], [1, 2], 0, 0],  # 6
                   [0, 0, [2, 3], [1, 4], 0, 0],  # 7
                   ]


# START_STRUCTURE =   np.array([0, 0, 0,     1, 1, 1])
# END_STRUCTURE =     np.array([1, 2, 3,     0, 2, 3])
# TRANSITION =        np.array([1, 6, 8,     5, 6, 7])

#CURRENT_SRUCTURE_ID = None
#
# 6-layer candidate layouts, one list per entry of LIST_STRUCTURES (entry i holds
# permutations/shifts of structure i).
# NOTE(review): this table is dead -- ALL_POSSIBILITIES is reassigned a few lines
# further down before anything reads it. Keep or delete deliberately.
ALL_POSSIBILITIES = [[[0, 1, 2, 3, 4, 0], [0, 0, 1, 2, 3, 4],[1, 2, 3, 4, 0, 0],[0, 4, 3, 2, 1, 0], [0, 0, 4, 3, 2, 1], [4, 3, 2, 1, 0, 0]], #0

                     [[4, 1, 2, 3, 4, 1], [1, 4, 3, 2, 1, 4], [1, 2, 3, 4, 1, 2], [2, 1, 4, 3, 2, 1], [2, 3, 4, 1, 2, 3], [3, 2, 1, 4, 3, 2], [3, 4, 1, 2, 3, 4], [4, 3, 2, 1, 4, 3]],  #1

                     [[[1, 2], [3,4], [1, 2], [3,4], [1, 2], [3,4]], [[3, 4], [1, 2], [3, 4], [1, 2], [3, 4], [1, 2]]],  #2

                     [[[1, 4], [2, 3], [1, 4], [2, 3], [1, 4], [2, 3]], [[2, 3], [1, 4], [2, 3], [1, 4], [2, 3], [1, 4]]    ], #3

                     [  [0 ,[1, 2], [3, 4], [1, 2], [3, 4], 0], [[1, 2], [3, 4], [1, 2], [3, 4], 0, 0], [0, [1, 2], [3, 4], [1, 2], [3, 4], 0], [0, 0, [1, 2], [3, 4], [1, 2], [3, 4]],
                      [[1, 2], [3,4], [3,4], [1, 2], 0, 0], [0, [1, 2], [3,4], [3,4], [1, 2], 0],
                      [0, 0, [1, 2], [3,4], [3,4], [1, 2]],  [[3,4], [1, 2], [1, 2], [3,4], 0, 0], [0, [3,4], [1, 2], [1, 2], [3,4], 0],
                      [0, 0, [3,4], [1, 2], [1, 2], [3,4]], [[3, 4], [1, 2], [3, 4], [1, 2], 0, 0], [0, [3, 4], [1, 2], [3, 4], [1, 2], 0], [0, 0, [3, 4], [1, 2], [3, 4], [1, 2]]], #4

                     [ [0, [1, 4], [2, 3], [1, 4], [2, 3], 0],  [[1, 4], [2, 3], [1, 4], [2, 3], 0, 0], [0, [1, 4], [2, 3], [1, 4], [2, 3], 0],
                      [0, 0, [1, 4], [2, 3], [1, 4], [2, 3]],  [[1, 4], [2, 3], [2, 3], [1, 4], 0, 0], [0, [1, 4], [2, 3], [2, 3], [1, 4], 0],
                      [0, 0, [1, 4], [2, 3], [2, 3], [1, 4]], [[2, 3], [1, 4], [1, 4], [2, 3], 0, 0], [0, [2, 3], [1, 4], [1, 4], [2, 3], 0],
                      [0, 0, [2, 3], [1, 4], [1, 4], [2, 3]], [[2, 3], [1, 4], [2, 3], [1, 4], 0, 0], [0, [2, 3], [1, 4], [2, 3], [1, 4], 0],
                      [0, 0, [2, 3], [1, 4], [2, 3], [1, 4]], ], #5

                     [[0, 0, [3, 4], [1, 2], 0, 0], [0, 0, [1, 2], [3, 4], 0, 0], [[1, 2], [3,4], 0, 0, 0, 0],  [[3, 4], [1,2], 0, 0, 0, 0],
                      [ 0, 0, 0, [1, 2], [3,4], 0],[ 0, 0, 0, [3, 4], [1,2], 0], [0, 0, 0, 0, [1, 2], [3,4]], [0, 0, 0, 0, [3, 4], [1,2]]], #6

                    [ [0, 0, [2, 3], [1, 4], 0, 0],   [0, [2, 3], [1, 4], 0, 0, 0],  [0, 0, 0, [2, 3], [1, 4], 0], [0, 0, 0, 0, [2, 3], [1, 4]],
                      [[1, 4], [2, 3], 0, 0, 0, 0], [0, [1, 4], [2, 3], 0, 0, 0], [0, 0, [1, 4], [2, 3], 0, 0], [0, 0, 0, [1, 4], [2, 3], 0],
                      [0, 0, 0, 0, [1, 4], [2, 3]]]  #7

                     ]


# Active candidate table used by findBestTransition(): ALL_POSSIBILITIES[i] lists
# the layouts considered for target structure family i. Each 0 (unused layer) is
# later expanded by getSubPossibilities() into every possible heading set.
# NOTE(review): these layouts have 4 layers, while NUMBER_LAYERS and LIST_STRUCTURES
# use 6 -- confirm which layer count is intended.
ALL_POSSIBILITIES = [[[1, 2, 3, 4], [2, 3, 4, 1], [3, 2, 1, 4], [4, 3, 2, 1]],  #0: one heading range per layer
                     [ [[1, 2], [3, 4], [1, 2], [3, 4]],  [[3,4], [1, 2], [1, 2], [3,4]],[[3,4], [1, 2], [3,4],[1, 2]], [ [1, 2], [3, 4], [3, 4],[1, 2]]],  #1: pairs [1, 2] / [3, 4]
                     [ [[1, 4], [2, 3], [1, 4], [2, 3]],  [[2, 3], [1, 4], [1, 4],[2, 3]],[[2, 3], [1, 4], [2, 3],[1, 4]], [ [1, 4], [2, 3], [2, 3],[1, 4]]],  #2: pairs [1, 4] / [2, 3]

                     [[0, [3, 4], [1, 2], 0], [0, [1, 2], [3, 4], 0], [[1, 2], [3,4], 0, 0],  [[3, 4], [1,2], 0, 0],
                      [ 0, 0, [1, 2], [3,4]],[ 0, 0, [3, 4], [1,2]]],  #3: 4-layer form of LIST_STRUCTURES #6

                     [[0, [1, 4], [2, 3], 0], [0, [2, 3], [1, 4], 0], [[1, 4], [2, 3], 0, 0], [[2, 3], [1, 4], 0, 0],
                      [0, 0, [2, 3], [1, 4]], [0, 0, [1, 4], [2, 3]]],  #4: 4-layer form of LIST_STRUCTURES #7
                     ]

# Only referenced by commented-out code in this part of the file.
LAST_TIME_STATE = None
# Written by setNewStructure() and returned by getCurrentStructure().
CURRENT_STRUCTURE = None

def getCurrentStructure():
    """Return the current layer structure and a per-layer "active" mask.

    The mask is hard-wired to six ``True`` values: every layer is reported as
    active regardless of what CURRENT_STRUCTURE contains.
    """
    all_layers_active = [True] * 6
    return CURRENT_STRUCTURE, all_layers_active

def findLayersHeading(ALTS):
    """Collect, for each altitude layer, the heading ranges used by current traffic.

    For every aircraft, the first remaining route waypoint (from the active one
    onwards) whose altitude matches an entry of *ALTS* is located. The bearing of
    the route leg flown at that altitude is classified into one of four heading
    ranges, numbered 1-4 after the nearest centre in [86.8, 176.8, 266.8, 356.8]
    deg. Each range is recorded at most once per layer.

    Args:
        ALTS: layer altitudes, compared with route waypoint altitudes divided by
            ``ft`` (presumably feet -- confirm). ALTS[i] is layer i; at most
            NUMBER_LAYERS entries are expected.

    Returns:
        A list of NUMBER_LAYERS lists of heading-range ids. It is reversed, so that
        entry j belongs to ALTS[NUMBER_LAYERS - 1 - j]. When there is no traffic,
        returns ``np.zeros(NUMBER_LAYERS)`` instead.
    """
    hdgs_now, _, _ = getCurrentHeadings()
    if len(hdgs_now) == 0:
        return np.zeros(NUMBER_LAYERS)

    layers_heading = [[] for _ in range(NUMBER_LAYERS)]
    mid_hdg_ranges = [86.8, 176.8, 266.8, 356.8]
    layer_alts = np.asarray(ALTS)

    for it in range(len(hdgs_now)):
        route = bs.traf.ap.route[it]
        n_wpts = len(route.wpalt)

        # The aircraft might be climbing/descending: skip ahead to the first
        # remaining waypoint that lies on a layer altitude. np.isclose (rather than
        # exact ==) tolerates round-off from converting the stored altitude with / ft.
        iactwp = route.iactwp
        best_alt_altitude = None
        while iactwp < n_wpts:
            matches = np.flatnonzero(np.isclose(layer_alts, route.wpalt[iactwp] / ft))
            if matches.size > 0:
                best_alt_altitude = matches[0]
                break
            iactwp += 1
        if best_alt_altitude is None:  # this aircraft is about to end its route
            continue

        # Choose the route leg whose bearing defines the heading in that layer.
        # NOTE(review): `iactwp > 1` excludes the leg 0 -> 1; `iactwp > 0` would be
        # the plain bounds check -- confirm which is intended.
        if iactwp > 1 and route.wpalt[iactwp] == route.wpalt[iactwp - 1]:
            wp_from, wp_to = iactwp - 1, iactwp  # level leg flown into iactwp
        elif iactwp + 1 < n_wpts:
            wp_from, wp_to = iactwp, iactwp + 1  # leg leaving iactwp
        elif iactwp > 0:
            # Last waypoint: there is no outgoing leg (reading iactwp + 1 would
            # raise IndexError), so use the inbound one.
            wp_from, wp_to = iactwp - 1, iactwp
        else:
            continue  # single-waypoint route: no leg to take a bearing from
        hdg_layer, _ = geo.kwikqdrdist(route.wplat[wp_from], route.wplon[wp_from],
                                       route.wplat[wp_to], route.wplon[wp_to])

        # Nearest range centre measured around the circle, so that e.g. 2 deg is
        # matched with 356.8 rather than 86.8.
        hdg_range = min(range(len(mid_hdg_ranges)),
                        key=lambda i: abs((mid_hdg_ranges[i] - hdg_layer + 180.0) % 360.0 - 180.0))
        if hdg_range + 1 not in layers_heading[best_alt_altitude]:
            layers_heading[best_alt_altitude].append(hdg_range + 1)

    layers_heading.reverse()
    return layers_heading

def checkNoLayersPreset(layers_hdq):
    """Return True when no layer in *layers_hdq* carries any heading.

    Each entry is either a scalar (0 means an empty layer, e.g. the np.zeros(...)
    placeholder findLayersHeading returns without traffic) or a sequence of
    heading-range ids (an empty sequence means no headings).
    """
    for layer in layers_hdq:
        # np.isscalar accepts Python and NumPy ints/floats alike. A fixed type list
        # misses np.int64 or plain float, which then fail in len().
        if np.isscalar(layer):
            if layer > 0:
                return False
        elif len(layer) > 0:
            return False
    return True

def setNewStructure(structure):
    """Make *structure* the module-level CURRENT_STRUCTURE.

    The object is stored as-is: no validation and no copy.
    """
    global CURRENT_STRUCTURE  # rebind the module attribute rather than a local
    CURRENT_STRUCTURE = structure

# def setNewStructure(new_layer_structure, first_structure):
#     global CURRENT_STRUCTURE # global declaration
#
#     # is this the first structure to be set?
#     if first_structure or CURRENT_STRUCTURE is None:# checkNoLayersPreset(layers_hdq):
#         CURRENT_STRUCTURE = LIST_STRUCTURES[new_layer_structure]
#         return CURRENT_STRUCTURE
#
#     # find all possible combinations
#     print('previous structure', CURRENT_STRUCTURE)
#     CURRENT_STRUCTURE, number_ver_trans = findBestTransition(CURRENT_STRUCTURE, new_layer_structure)
#     print('new structure', CURRENT_STRUCTURE, 'number transitions', number_ver_trans)
#     return CURRENT_STRUCTURE

def checkActiveLayers(structure):
    """Return a 6-element boolean mask that is False where *structure* holds the integer 0.

    Entries that are heading ids or lists of heading ids leave their layer True.
    Mask positions beyond the end of *structure* also stay True. Only exact
    int / np.int32 / np.int64 zeros count as unused.
    """
    empty_marker_types = (int, np.int32, np.int64)
    mask = np.ones(6, dtype=bool)
    unused_layers = [idx for idx, layer in enumerate(structure)
                     if type(layer) in empty_marker_types and layer == 0]
    mask[unused_layers] = False
    return mask

def findBestTransition(layers_hdq, index_end_str):
    """Choose the candidate layer structure that needs the fewest layer changes.

    Candidates are the entries of ALL_POSSIBILITIES[index_end_str], each expanded by
    getSubPossibilities(). The cost of a candidate is
    calculateAltTransitions(layers_hdq, candidate).

    Args:
        layers_hdq: current per-layer heading ranges (presumably the output of
            findLayersHeading -- confirm with callers).
        index_end_str: index of the target structure family in ALL_POSSIBILITIES.

    Returns:
        (best_structure, cost): the first candidate with the lowest cost, and that cost.
    """
    final_transitions = []
    number_alt_changes = []
    perfect_match_found = False

    for possibility in ALL_POSSIBILITIES[index_end_str]:
        # Deep copy so the returned structure never aliases the shared table.
        final_structure = copy.deepcopy(possibility)
        for sub_possibility in getSubPossibilities(final_structure):
            final_transitions.append(sub_possibility)
            number_trans = calculateAltTransitions(layers_hdq, sub_possibility)
            number_alt_changes.append(number_trans)
            if number_trans == 0:
                perfect_match_found = True
                break
        if perfect_match_found:
            # Costs are sums of absolute layer distances, so 0 cannot be beaten, and
            # np.argmin would select this first zero anyway: stop the whole search,
            # not just the current possibility.
            break

    return final_transitions[np.argmin(number_alt_changes)], np.min(number_alt_changes)

def getSubPossibilities(main_structure):
    """Expand a structure by filling its unused layers with every candidate heading set.

    An unused layer is an integer 0 entry (int, np.int32 or np.int64). Each unused
    layer is replaced by every element of [1, 2, 3, 4, [1, 2], [2, 3], [1, 4], [3, 4]].
    All combinations are returned, with the last unused layer varying fastest.

    Args:
        main_structure: list of per-layer entries (heading id, list of heading ids,
            or 0 for an unused layer). It is not modified.

    Returns:
        A list of structures. When nothing needs expanding, the list holds
        *main_structure* itself (not a copy). Otherwise each element is a deep copy
        with the unused layers filled in.
    """
    integer_types = (int, np.int32, np.int64)
    all_pos_transition_layer = [1, 2, 3, 4, [1, 2], [2, 3], [1, 4], [3, 4]]

    # Expansion is triggered by an integer 0 layer or by a sub-list containing 0.
    # Only integer 0 entries are replaced, so a structure flagged solely by such a
    # sub-list comes back as a single unmodified deep copy.
    empty_layers_found = any(
        layer == 0 if type(layer) in integer_types else 0 in layer
        for layer in main_structure
    )
    if not empty_layers_found:
        return [main_structure]

    # Positions to fill, found element by element instead of via
    # np.where(np.array(main_structure) == 0): structures mix ints and lists, and
    # NumPy >= 1.24 raises ValueError when building an array from ragged input.
    temporary_layers = [idx for idx, layer in enumerate(main_structure)
                        if type(layer) in integer_types and layer == 0]

    sub_possibilities = []
    # itertools.product advances its last position fastest.
    for combination in itertools.product(all_pos_transition_layer,
                                         repeat=len(temporary_layers)):
        aux = copy.deepcopy(main_structure)
        for layer_idx, option in zip(temporary_layers, combination):
            aux[layer_idx] = option
        sub_possibilities.append(aux)

    return sub_possibilities

def calculateAltTransitions(start_st, final_structure):
    """Count how many layer changes traffic needs to go from *start_st* to *final_structure*.

    For every heading range present in a layer of *start_st*, the distance (in layer
    indices) to the nearest layer of *final_structure* that allows that heading is
    added to the total.

    Args:
        start_st: per-layer entries. Each is either a scalar heading-range id
            (0 = empty layer) or a sequence of heading-range ids.
        final_structure: candidate structure. Each layer is tested against heading
            ranges 1-4 with checkHdgsMatch().

    Returns:
        The summed layer distance (0 when nothing has to move).

    Raises:
        ValueError: if a heading present in *start_st* is allowed in no layer of
            *final_structure* (min() of an empty sequence).
    """
    # hdgs[h - 1] lists the final-structure layers in which heading range h may fly.
    hdgs = [[layer_idx for layer_idx, layer_value in enumerate(final_structure)
             if checkHdgsMatch(layer_value, hdg)]
            for hdg in (1, 2, 3, 4)]

    def layers_to_move(hdg, layer_start):
        # Distance from layer_start to the closest layer that allows heading range hdg.
        return min(abs(np.array(hdgs[int(hdg) - 1]) - layer_start))

    number_alt_transitions = 0
    for layer_start, start_layer in enumerate(start_st):
        if np.isscalar(start_layer):
            # np.isscalar also accepts floats such as the np.zeros(...) "no traffic"
            # placeholder. An int-only type check would send those to the iterating
            # branch and crash.
            if start_layer > 0:
                number_alt_transitions += layers_to_move(start_layer, layer_start)
        else:
            for hdg in start_layer:
                number_alt_transitions += layers_to_move(hdg, layer_start)

    return number_alt_transitions

def getNumberHeadings(layer_headings):
    """Return how many heading ranges a layer entry holds.

    A bare integer (int, np.int32 or np.int64, including 0) counts as one heading.
    Anything else is treated as a sequence and measured with len().
    """
    is_single_heading = type(layer_headings) in (int, np.int32, np.int64)
    return 1 if is_single_heading else len(layer_headings)

def checkHdgsMatch(start_layer, end_layer):
    """Return whether two layer entries are compatible in terms of heading ranges.

    Each argument is a single heading-range id or a sequence of ids.

    - If one side holds more headings, the other side is tested for membership in it.
    - With one heading each, the two ids must be equal.
    - With the same number of several headings, every heading of *start_layer*
      must appear in *end_layer* (order does not matter).
    """
    len_start_layer = getNumberHeadings(start_layer)
    len_end_layer = getNumberHeadings(end_layer)

    if len_start_layer > len_end_layer:
        return end_layer in start_layer
    if len_end_layer > len_start_layer:
        return start_layer in end_layer

    # Both entries hold the same number of headings.
    if len_end_layer == 1:
        return start_layer == end_layer
    # Returning the tuple (a in end, b in end) would always be truthy; require all
    # headings to match instead.
    return all(hdg in end_layer for hdg in start_layer)

# def saveTransition(start, end, transition, active_layers):
#     global START_STRUCTURE
#     global END_STRUCTURE
#     global TRANSITION
#
#     START_STRUCTURE = np.append(START_STRUCTURE, start)
#     END_STRUCTURE = np.append(END_STRUCTURE, end)
#     index_transition = getStructureId(transition)
#     ACTIVE_LAYERS.append(active_layers)
#     TRANSITION = np.append(TRANSITION, index_transition)

def getStructureId(structure):
    """Return the index of *structure* in LIST_STRUCTURES, registering it if unknown.

    Lookup uses equality (``in`` / ``list.index``). An unseen structure is appended
    to the module-level LIST_STRUCTURES, and its new index is returned.
    """
    already_known = structure in LIST_STRUCTURES
    if not already_known:
        LIST_STRUCTURES.append(structure)
        return len(LIST_STRUCTURES) - 1
    return LIST_STRUCTURES.index(structure)

# def getTrajectoryInformation(times, hdgs, look_ahead_time):
#     possible_directions = [86.8, 176.8, 266.8, 356.8]
#     state = np.zeros(4)
#
#     global LAST_TIME_STATE
#     now = LAST_TIME_STATE
#     if LAST_TIME_STATE is None:
#         now = 0
#
#     for it in range(0, len(times)): # per aircraft
#         already_added = np.zeros(4)
#         for it2 in range(0, len(times[it])):
#             timestamp = times[it][it2]
#             if now <= timestamp <= now + look_ahead_time:
#                 index = min(range(len(possible_directions)), key=lambda i: abs(possible_directions[i] - hdgs[it][it2]))
#                 if timestamp < now + 300 and already_added[index] == 0:
#                     state[index] += 1
#                     already_added[index] += 1
#                 elif timestamp < now + 600 and already_added[4 +index] == 0:
#                     state[4 +index] += 1
#                     already_added[4+index] += 1
#                 elif timestamp < now + 900 and already_added[8 + index] == 0:
#                     state[8 + index] += 1
#                     already_added[8 + index] += 1
#                 elif timestamp < now + 1200 and already_added[12 + index] == 0:
#                     state[12 + index] += 1
#                     already_added[12 + index] += 1
#                 elif timestamp < now + 1500 and already_added[16 + index] == 0:
#                     state[16 + index] += 1
#                     already_added[16 + index] += 1
#                 elif timestamp < now + 1800 and already_added[20 + index] == 0:
#                     state[20 + index] += 1
#                     already_added[20 + index] += 1
#
#     for aircraft_idx in range(bs.traf.ntraf):
#         index = min(range(len(possible_directions)), key=lambda i: abs(possible_directions[i] - bs.traf.hdg[aircraft_idx]))
#         state[index] += 1
#
#     LAST_TIME_STATE = now + look_ahead_time
#     return state

# def getTrajectoryInformation(times, hdgs, look_ahead_time, STATE_DIM, reward):
#     possible_directions = [86.8, 176.8, 266.8, 356.8]
#     state = np.zeros(STATE_DIM)
#
#     now = bs.sim.simt
#
#     for it in range(0, len(times)):
#         already_added = np.zeros(STATE_DIM)
#         for it2 in range(0, len(times[it])):
#             timestamp = times[it][it2]
#             if now <= timestamp <= now + look_ahead_time:
#                 index = min(range(len(possible_directions)), key=lambda i: abs(possible_directions[i] - hdgs[it][it2]))
#                 if index != index_previous:
#                     state[index] += 1
#                 index = index_previous
#
#     if reward:
#         state = np.zeros(STATE_DIM)
#
#     for aircraft_idx in range(bs.traf.ntraf):
#         index = min(range(len(possible_directions)),
#                     key=lambda i: abs(possible_directions[i] - bs.traf.hdg[aircraft_idx]))
#         state[index] += 1

def getTrajectoryInformation(times, hdgs, look_ahead_time, STATE_DIM, reward):
    """Build a (time bucket x heading range) traffic count vector.

    The first 20 minutes are split into four 5-minute buckets. For each aircraft
    (one row of *times* / *hdgs*), the heading at its first timestamp is counted in
    the bucket that timestamp falls into. After that, a waypoint adds one count when
    it lies in the bucket immediately following the last one counted. Headings are
    assigned to the nearest of four direction centres (86.8, 176.8, 266.8 and
    356.8 deg), measured around the circle.

    Args:
        times: per-aircraft sequences of timestamps in seconds (presumably relative
            to the current simulation time -- confirm with the caller).
        hdgs: per-aircraft headings in degrees, aligned element-wise with *times*.
        look_ahead_time: unused; kept for interface compatibility.
        STATE_DIM: length of the returned vector; must be at least 16.
        reward: unused; kept for interface compatibility.

    Returns:
        np.ndarray of length STATE_DIM. Entry ``4 * bucket + direction`` holds the
        count for that bucket and direction; entries from 16 onwards stay 0.
    """
    possible_directions = [86.8, 176.8, 266.8, 356.8]
    time_step = 5 * 60  # 5 minutes
    n_time_buckets = 4
    state = np.zeros(STATE_DIM)

    def nearest_direction(hdg):
        # Circular distance, so e.g. 2 deg maps to 356.8 rather than 86.8.
        return min(range(len(possible_directions)),
                   key=lambda i: abs((possible_directions[i] - hdg + 180.0) % 360.0 - 180.0))

    for ac_times, ac_hdgs in zip(times, hdgs):
        if len(ac_times) == 0:
            continue
        # Bucket of the first timestamp; anything before 5 min, including negative
        # values, is bucket 0. Aircraft starting after the 20-minute horizon are
        # skipped individually. A `break` here would also drop every later aircraft
        # unless *times* happened to be sorted by first timestamp.
        time_index = next((k for k in range(n_time_buckets)
                           if ac_times[0] < time_step * (k + 1)), None)
        if time_index is None:
            continue

        for wp_idx, (timestamp, hdg) in enumerate(zip(ac_times, ac_hdgs)):
            if time_index >= n_time_buckets:
                break
            # The bucket test needs an upper bound of time_step * (time_index + 1).
            # With the same bound on both sides it is never true, and only the
            # first waypoint of each aircraft is counted.
            # NOTE(review): a waypoint that skips a whole bucket stops further
            # counting for this aircraft -- confirm this is intended.
            in_next_bucket = time_step * time_index <= timestamp < time_step * (time_index + 1)
            if wp_idx == 0 or in_next_bucket:
                state[4 * time_index + nearest_direction(hdg)] += 1
                time_index += 1

    return state

def getCurrentHeadings():
    """Return copies of the current heading, altitude and vertical-speed arrays.

    Each array has one float entry per aircraft (bs.traf.ntraf of them). Units are
    whatever bs.traf stores (presumably deg, m and m/s -- confirm).

    Returns:
        (hdgs_now, alts_now, vs_now) as new float64 NumPy arrays.
    """
    ntraf = bs.traf.ntraf
    # Slice and copy once, instead of growing each array with np.append per
    # aircraft (which reallocates every time and is quadratic in ntraf).
    hdgs_now = np.array(bs.traf.hdg[:ntraf], dtype=float)
    alts_now = np.array(bs.traf.alt[:ntraf], dtype=float)
    vs_now = np.array(bs.traf.vs[:ntraf], dtype=float)
    return hdgs_now, alts_now, vs_now
