# -*- coding: utf-8 -*-
# Created by mjribeiro at 9-10-2020
from bluesky.tools.simtime import timed_function
import bluesky as bs
import numpy as np
import bluesky.tools.cityboundaries as cb
from bluesky import stack
from bluesky.tools.aero import ft

# Reference values of the altitude layers. Altitudes are divided by `ft` before
# being matched against this list, and the code works with the *index* of the
# closest entry (see findclosestfrontiervalue), not with the value itself.
ALTS = [65, 105, 145, 185, 225, 265, 305, 345, 385, 425, 465]
# Speed limits an action can select; indexed with the argmax of the action vector.
SPEEDS = [10, 15, 20, 25, 30]  # kts
# Speed reported in the state when no limit is stored for an (altitude, node) pair.
DEFAULT_MAX_SPEED = 30  # kts
# Number of entries in a state tuple (not referenced in this file).
STATE_DIM = 4
# Time after which an action is evaluated; also the look-ahead horizon for routes.
ACTION_DURATION = 60  # seconds
# Amounts subtracted from a section's reward for each matching event.
PENALTY_LOS = 10
PENALTY_NEAR_LOS = 4
PENALTY_CONFLICT = 2


# STATE (one tuple per section, STATE_DIM entries)
# 0 - Number of aircraft entering the Entrance/Exit Section
# 1 - Number of aircraft leaving the Entrance/Exit Section
# 2 - Traffic density that will go from the Detection Area and Controlled Section into the Entrance/Exit Section
# 3 - Maximum speed in the Detection Section
# 4 - (not implemented yet, to be tried later) Conflicts in the Detection Section

def init_plugin():
    """Create the plugin instance and describe the plugin to BlueSky.

    Returns:
        A ``(config, stackfunctions)`` tuple. ``config`` holds the plugin
        name, type and the update/reset hooks; ``stackfunctions`` is empty.
    """
    # Additional initialisation code.
    # The module-level name VSL is rebound from the class to its instance.
    # Only instantiate while the name still refers to the class, so that a
    # repeated call does not try to "call" the instance and raise TypeError.
    global VSL
    if isinstance(VSL, type):
        VSL = VSL()

    # Configuration parameters
    config = {
        # The name of your plugin
        'plugin_name': 'VSL',

        # The type of this plugin. For now, only simulation plugins are possible.
        'plugin_type': 'sim',

        # The update function is called after traffic is updated.
        'update': VSL.update,

        # The reset function
        'reset': VSL.reset
    }
    stackfunctions = {}

    # init_plugin() should always return these two dicts.
    return config, stackfunctions


class VSL():

    def __init__(self):
        super(VSL, self).__init__()

    def reset(self):
        """Start a new episode: drop all bookkeeping and clear the autopilot data.

        If a scenario name is set, the RL model is saved under that name first.
        """
        # pending actions: parallel lists of section keys, states and actions
        self.actions_DDPG = [[], [], []]
        self.timestamps = {}
        # counters of the conflict / LoS pairs that were already rewarded
        self.prevlospairs = 0
        self.prevconfpairs = 0
        self.max_speeds = {}
        self.current_states = {}

        scenario = stack.get_scenname()
        if len(scenario) > 0:
            bs.sim.RLsave(scenario)

        # clear the values from the previous episode
        bs.traf.ap.clearMaxSpeeds()
        bs.traf.ap.clearCrossings()

    @timed_function('VSL', dt=10.0)
    def update(self, dt):
        """Periodic plugin step (registered with a 10 s interval).

        Creates the RL model on the first call. On later calls it rebuilds
        the section states from the aircraft routes, hands finished actions
        over for rewarding, and requests actions for the new sections.

        Args:
            dt: time step passed in by the timed_function wrapper (unused).
        """
        if not bs.sim.isRLcreated():
            bs.sim.RLcreate(bs.stack.get_scenname())
            return

        routes = bs.traf.ap.route
        # Active waypoint index of every aircraft, built in one go instead of
        # growing the array with np.append (which copies it on every call).
        # The float -> int conversion mirrors what np.append on an empty
        # array followed by astype(int) produced.
        next_wpt = np.array([route.iactwp for route in routes], dtype=float).astype(int)

        entranceexit_sections = self.get_entrance_exit_sections(routes, next_wpt)
        entranceexit_sections = self.join_entrance_exit_sections(entranceexit_sections)
        # fill the traffic density in the state
        entranceexit_sections = self.find_traffic_density(entranceexit_sections, routes, next_wpt)
        # check max speed and add it to the state
        entranceexit_sections = self.checkmaxspeeds(entranceexit_sections)
        # new conflicts - test with this later
        # self.checkNewConflicts(current_alt, next_alt, next_wpt, bs.traf.ap.route)

        # Must run before the new sections are selected: rewarding an action
        # removes its section from current_states.
        if len(self.timestamps) > 0:
            self.check_actions_finished(routes, next_wpt)

        # compute new max speed for the entrance/exit actions
        # add only new sections
        new_sections = {k: entranceexit_sections[k] for k in set(entranceexit_sections) - set(self.current_states)}
        if len(new_sections) > 0:
            # extra check: it may be that a section with a max-speed just got enlarged
            # in this case, we keep the same speed that it already has
            new_sections = self.check_smaller_section_exists(new_sections)
            new_sections = self.check_bigger_section_exists(new_sections)
            new_sections = self.removeoneacstates(new_sections)

            bs.sim.RLaction(bs.stack.get_scenname(), new_sections)
        self.current_states.update(new_sections)

    # remove sections that only have one aircraft; there is no point in controlling them
    def removeoneacstates(self, new_sections):
        remove_sections = []
        for key, states in new_sections.items():
            if sum(states[0:3]) < 2:
                remove_sections.append(key)

        for remove_section in remove_sections:
            if remove_section in new_sections:
                del new_sections[remove_section]

        return new_sections

    # remove the section if it exists
    def remove_section(self, shorter_section):
        index = None
        keys = list(self.current_states)
        for it in range(len(keys)):
            if keys[it] == shorter_section:
                index = it
                break
        if index is not None:
            # remove the previous section
            del self.current_states[shorter_section]
            if shorter_section in self.timestamps:
                del self.timestamps[shorter_section]

    # remove the previous version and use the new one
    def check_smaller_section_exists(self, new_sections):
        for section in new_sections.keys():
            if len(section) > 4:  # meaning its not normal size, more nodes were added to this section
                for it in range(4, len(section)):
                    extra = 0
                    while it + extra <= len(section):
                        shorter_section = tuple([section[0]] + list(section[1 + extra:it + extra]))  # dfault size
                        self.remove_section(shorter_section)
                        extra += 1
        return new_sections

    # keep the bigger section and do not consider this one
    def check_bigger_section_exists(self, new_sections):
        remove_sections = []
        for section in new_sections.keys():
            keys = list(self.current_states)
            for it in range(len(keys)):
                size_section = len(keys[it])
                if size_section > len(section):
                    if all(x in keys[it] for x in section):
                        remove_sections.append(section)

        for remove_section in remove_sections:
            if remove_section in new_sections:
                del new_sections[remove_section]

        return new_sections

    def check_actions_finished(self, routes, next_wpt):
        """Reward every action that has been active for longer than ACTION_DURATION.

        The finished entries are moved out of the parallel actions_DDPG lists,
        their timestamps are deleted, and they are passed to checkrewards.

        Args:
            routes: unused.
            next_wpt: unused.
        """
        # sections whose action has been running long enough
        expired = [
            key for key, started in self.timestamps.items()
            if bs.sim.simt - started > ACTION_DURATION
        ]

        # same layout as actions_DDPG: keys, states, actions
        finished = [[], [], []]
        for key in expired:
            position = self.actions_DDPG[0].index(key)
            for slot in range(3):
                finished[slot].append(self.actions_DDPG[slot].pop(position))

        # delete from timestamps
        for key in finished[0]:
            del self.timestamps[key]

        if len(finished[0]) > 0:
            # report the reward to the DDPG model
            self.checkrewards(finished)

    def checkrewards(self, actions):
        """Compute the rewards of finished actions and report them to the RL model.

        Args:
            actions: three parallel lists: section keys, the states the
                actions were taken in, and the action vectors.
        """
        section_keys, states, action_vectors = actions[0], actions[1], actions[2]

        rewards = np.zeros(len(section_keys))
        # reward based on the number of LoSs
        rewards = self.getLoss(section_keys, rewards)
        # reward based on traffic inflow vs outflow
        rewards = self.getTrafficOutflow(section_keys, states, rewards)
        # reward based on the near misses (almost LoS)
        rewards = self.getNearMisses(section_keys, rewards)
        # reward based on allowing for a 'fast' speed
        rewards = self.getSpeedBonus(action_vectors, rewards)

        # the speed limits of these sections no longer apply
        bs.traf.ap.removeMaxSpeeds(section_keys)

        # The next state of a section is its current state, which is then
        # forgotten; a section that is no longer tracked gets an all-zero state.
        next_states = []
        for key in section_keys:
            section = tuple(key)
            state = self.current_states.get(section)
            if state is None:
                next_states.append((0, 0, 0, 0))
                continue
            next_states.append(state)
            del self.current_states[section]

        # scenario name, states, next states, rewards, actions
        bs.sim.RLresult(bs.stack.get_scenname(), states, next_states, list(rewards), action_vectors)

    def getSpeedBonus(self, actions, rewards):
        for it in range(len(actions)):
            rewards[it] += np.argmax(actions[it])
        return rewards

    def getNearMisses(self, keys, rewards):
        """Penalise sections for the conflicts recorded since the previous call.

        Only the newest entries of the conflict-detection lists are used;
        prevconfpairs remembers how many pairs were already processed.

        Args:
            keys: section keys, parallel to ``rewards``.
            rewards: indexable rewards, updated through addTlos.

        Returns:
            ``rewards``.
        """
        detection = bs.traf.cd
        fresh = len(detection.confpairs_all) - self.prevconfpairs

        if fresh > 0:
            nodes0 = detection.confnodes0[-fresh:]
            nodes1 = detection.confnodes1[-fresh:]
            # two altitudes per pair, stored back to back
            alts = detection.confalt[-fresh * 2:]
            time_to_los = detection.conftimeLoS[-fresh:]
            recent_pairs = detection.confpairs_all[-fresh:]

            for idx, _pair in enumerate(recent_pairs):
                alt0 = self.findclosestfrontiervalue(ALTS, alts[2 * idx] / ft)
                alt1 = self.findclosestfrontiervalue(ALTS, alts[2 * idx + 1] / ft)
                # NOTE(review): unlike getLoss, only the second aircraft's
                # (alt1, node) is penalised, and pairs whose aircraft share
                # both layer and node are skipped - confirm this is intended.
                differs = alt0 != alt1 or nodes0[idx] != nodes1[idx]
                if differs and time_to_los[idx] > 0:
                    rewards = self.addTlos(alt1, nodes1[idx], rewards, keys, time_to_los[idx])

        self.prevconfpairs = len(detection.confpairs_all)
        return rewards

    def addTlos(self, alt, node, rewards, keys, timeLos):
        """Subtract a conflict penalty from every section containing (alt, node).

        A section matches when its first entry equals ``alt`` and ``node`` is
        among its entries from index 2 onwards. The penalty is
        PENALTY_NEAR_LOS when ``timeLos`` is below 10, else PENALTY_CONFLICT.

        Returns:
            ``rewards``, updated in place.
        """
        for position, key in enumerate(keys):
            if key[0] != alt or node not in key[2:]:
                continue
            # a short time to loss of separation counts as a near miss
            rewards[position] -= PENALTY_NEAR_LOS if timeLos < 10 else PENALTY_CONFLICT
        return rewards

    def getLoss(self, keys, rewards):
        """Penalise sections for the losses of separation recorded since the previous call.

        Only the newest entries of the intrusion lists are used; prevlospairs
        remembers how many pairs were already processed.

        Args:
            keys: section keys, parallel to ``rewards``.
            rewards: indexable rewards, updated through addlos.

        Returns:
            ``rewards``.
        """
        detection = bs.traf.cd
        fresh = len(detection.lospairs_all) - self.prevlospairs
        if fresh > 0:
            nodes0 = detection.intrusionsnodes0[-fresh:]
            nodes1 = detection.intrusionsnodes1[-fresh:]
            # two altitudes per pair, stored back to back
            alts = detection.intrusionsalt[-fresh * 2:]
            recent_pairs = detection.lospairs_all[-fresh:]

            for idx, _pair in enumerate(recent_pairs):
                alt0 = self.findclosestfrontiervalue(ALTS, alts[2 * idx] / ft)
                alt1 = self.findclosestfrontiervalue(ALTS, alts[2 * idx + 1] / ft)
                if idx >= len(nodes0):
                    continue
                rewards = self.addlos(alt0, nodes0[idx], rewards, keys)
                # do not count the same LoS twice
                second_differs = idx < len(nodes1) and (alt0 != alt1 or nodes0[idx] != nodes1[idx])
                if second_differs:
                    rewards = self.addlos(alt1, nodes1[idx], rewards, keys)

        self.prevlospairs = len(detection.lospairs_all)
        return rewards

    # see if the LoS was within one of the sections that we are considering
    def addlos(self, alt, node, rewards, keys):
        """Subtract PENALTY_LOS from every section containing (alt, node).

        A section matches when its first entry equals ``alt`` and ``node`` is
        among its entries from index 2 onwards.

        Returns:
            ``rewards``, updated in place.
        """
        for position, key in enumerate(keys):
            if key[0] != alt or node not in key[2:]:
                continue
            rewards[position] -= PENALTY_LOS
        return rewards

    def setDDPGaction(self, actions):
        """Apply the speed limits chosen by the model to the tracked sections.

        Args:
            actions: three parallel sequences: section keys, states and
                action vectors. The argmax of an action vector selects the
                entry of SPEEDS to apply.
        """
        for position, key in enumerate(actions[0]):
            section = tuple(key)
            # maybe we've just cancelled this section for a bigger one
            if section not in self.current_states:
                continue

            # remember key, state and action until the action is rewarded
            for slot in (0, 1, 2):
                self.actions_DDPG[slot].append(tuple(actions[slot][position]))
            self.timestamps[section] = bs.sim.simt

            limit = SPEEDS[np.argmax(actions[2][position])]
            layer = key[0]
            for node in key[2:]:
                self.max_speeds[(layer, node)] = limit

            # start counting crossings at the last node of the section
            bs.traf.ap.countCrossings(layer, key[-1])

        bs.traf.ap.setMaximumSpeedNode(self.max_speeds)

    def findcrossings(self, alt, last_node, crossing_alts, crossing_nodes, crossing_total):
        """Count the recent crossings recorded for one (altitude, node) pair.

        Args:
            alt: altitude layer to look up.
            last_node: node to look up.
            crossing_alts: altitude of every crossing record.
            crossing_nodes: node of every crossing record (parallel to
                ``crossing_alts``).
            crossing_total: per record, the times at which it was crossed.

        Returns:
            Number of crossing times of the first matching record that are
            later than ``bs.sim.simt - ACTION_DURATION``; 0 without a match.
        """
        matches = np.logical_and(crossing_alts == alt, crossing_nodes == last_node)
        if not matches.any():
            return 0

        first_match = np.where(matches)[0].item(0)
        crossings = np.array(crossing_total[first_match])
        # Use the shared constant rather than a literal 60, so that the
        # window follows the action duration used elsewhere in this class.
        crossings = crossings[crossings > bs.sim.simt - ACTION_DURATION]
        return len(crossings)

    # outflow - inflow
    # inflow: a/c passing through the first node of the controlled section
    # outflow: a/c passing through the last node of the entrance/exit section
    def getTrafficOutflow(self, keys, states, rewards):
        """Penalise sections that fewer aircraft left than the state predicted.

        Args:
            keys: section keys, parallel to ``states`` and ``rewards``.
            states: state tuples the actions were taken in.
            rewards: indexable rewards, updated in place.

        Returns:
            ``rewards``.
        """
        crossing_alts, crossing_nodes, crossing_total = bs.traf.ap.getCrossings()
        # NOTE(review): when no crossing was recorded at all, no section is
        # penalised, even if an outflow was expected - confirm this is intended.
        if len(crossing_total) > 0:
            for it in range(len(keys)):
                alt = keys[it][0]
                # entering aircraft plus the ones coming up
                expected_outflow = states[it][0] + states[it][2]
                final_outflow = self.findcrossings(alt, keys[it][-1], crossing_alts, crossing_nodes, crossing_total)
                # More crossings than expected only happens when an aircraft
                # was created after the state for the next 60 seconds was
                # calculated. We don't want this edge case to have a positive
                # influence on the reward, so only a shortfall is subtracted.
                shortfall = expected_outflow - final_outflow
                if shortfall > 0:
                    rewards[it] -= shortfall
        return rewards

    def addToState(self, dict, key, entering, leaving, staying, max_speed=None):
        state = dict.get(key)
        if state is not None:
            if max_speed is None:
                max_speed = state[3]
            new_state = (state[0] + entering, state[1] + leaving, state[2] + staying, max_speed)
        else:
            if max_speed is None:
                max_speed = 0
            # the last two indices will be filled up later
            new_state = (entering, leaving, staying, max_speed)

        dict[key] = new_state
        return dict

    def getOppositeStep(self, step):
        return (step + 2) % 4

    def find_next_and_previous_node(self, start_node, mid_node):
        """Find the nodes straight ahead of and straight behind the edge start -> mid.

        The edge is looked up in ``cb.EDGES_PER_NODE[start_node]``. The next
        node is taken from the same slot of ``mid_node``'s edges, the
        previous node from the opposite slot of ``start_node``'s edges.

        Returns:
            ``(next_node, previous_node)``; an entry is None when that edge
            slot is empty, and both are None when ``mid_node`` is not
            reachable from ``start_node``.
        """
        slot_to_mid = -1
        for slot, edge in enumerate(cb.EDGES_PER_NODE[start_node]):
            # no break: if several slots lead to mid_node, the last one is used
            if edge is not None and edge[0] == mid_node:
                slot_to_mid = slot

        if slot_to_mid < 0:
            return None, None

        onward = cb.EDGES_PER_NODE[mid_node][slot_to_mid]
        next_node = None if onward is None else onward[0]

        backward = cb.EDGES_PER_NODE[start_node][self.getOppositeStep(slot_to_mid)]
        previous_node = None if backward is None else backward[0]

        return next_node, previous_node

    def join_entrance_exit_sections(self, pre_sections):
        """Merge sections that directly follow each other into one longer section.

        Args:
            pre_sections: dict mapping section keys of the form
                (altitude, control node, start node, end node) to state tuples.

        Returns:
            A new dict. A joined pair appears as one key (the first key with
            the extra end node appended) whose state is the element-wise sum
            of both states; every key left unjoined is copied over unchanged.
        """
        # working list of keys that have not been joined yet
        sections_keys = list(pre_sections.keys())
        new_sections = dict()
        previous_section = None
        it = 0
        while it < len(sections_keys):
            # The key at this index was already handled without a join:
            # move on to the next one.
            if previous_section == sections_keys[it]:
                it += 1
                if it >= len(sections_keys):
                    break
            previous_section = sections_keys[it]
            altitude = sections_keys[it][0]
            # start_c_section = sections_keys[it][1]
            start_ee_section = sections_keys[it][2]
            end_ee_section = sections_keys[it][3]
            checking_for_next_node = True
            while checking_for_next_node:
                after_end_node, dummy = self.find_next_and_previous_node(start_ee_section, end_ee_section)
                # do we have another section just right after this one? we need at least one edge separation
                # (the following section uses this section's start node as its control node)
                adjacent_section = (altitude, start_ee_section, end_ee_section, after_end_node)
                if after_end_node is not None and adjacent_section in sections_keys:
                    state0 = pre_sections[sections_keys[it]]
                    state1 = pre_sections[adjacent_section]
                    aux = list(sections_keys[it])
                    # add the new ending
                    aux.append(int(after_end_node))
                    new_sections[tuple(aux)] = tuple([i + y for i, y in zip(state0, state1)])
                    # print(sections_keys[it], adjacent_section, 'joined')
                    # both originals are consumed by the joined section
                    sections_keys.remove(sections_keys[it])
                    sections_keys.remove(adjacent_section)
                    # continue looking for possible next sections
                    # NOTE(review): after the removals above, sections_keys[it]
                    # is a different key (or out of range) on the next pass of
                    # this inner loop, and start_ee_section is not advanced
                    # along with end_ee_section. Confirm that chains of three
                    # or more sections are joined as intended.
                    end_ee_section = after_end_node
                else:
                    checking_for_next_node = False

        # everything that was not joined is kept as it was
        for section in sections_keys:
            new_sections[section] = pre_sections[section]

        return new_sections

    def findclosestfrontiervalue(self, frontier_values, value):
        return min(range(len(frontier_values)), key=lambda i: abs(frontier_values[i] - value))

    def getPathInNextTimeStamp(self, route, next_wpt):
        """Walk along a route until ACTION_DURATION of travel time has accumulated.

        The travel time of a leg is the third entry of the connecting edge in
        ``cb.EDGES_PER_NODE`` divided by the waypoint speed at the end of the
        leg; a leg without a connecting edge adds no time.

        Args:
            route: route exposing ``wptnodes`` and ``wpspd``.
            next_wpt: waypoint index to start from.

        Returns:
            The index reached minus one, which is ``next_wpt - 1`` when no
            leg could be walked.
        """
        elapsed = 0
        final_index = len(route.wptnodes) - 1
        while elapsed < ACTION_DURATION and next_wpt < final_index:
            target_node = route.wptnodes[next_wpt + 1]
            # first edge leaving the current node that ends in the target node
            edge = next(
                (candidate for candidate in cb.EDGES_PER_NODE[route.wptnodes[next_wpt]]
                 if candidate is not None and candidate[0] == target_node),
                None,
            )
            if edge is not None:
                elapsed += edge[2] / route.wpspd[next_wpt + 1]
            next_wpt += 1

        return next_wpt - 1

    def get_entrance_exit_sections(self, routes, next_wpt):
        """Collect the sections in which aircraft are about to change altitude layer.

        For each route, the legs from the active waypoint up to the look-ahead
        limit are scanned; a leg whose two waypoints fall in different
        altitude layers yields a section key
        (layer, control node, start node, end node).

        Args:
            routes: routes exposing ``wpalt`` and ``wptnodes``.
            next_wpt: active waypoint index of each route.

        Returns:
            dict mapping section keys to state tuples with one aircraft
            counted as entering or leaving per detected transition.
        """
        sections = dict()
        for idx, route in enumerate(routes):
            first_wpt = next_wpt[idx]
            last_wpt = self.getPathInNextTimeStamp(route, first_wpt)
            for wpt in range(first_wpt, last_wpt):
                layer_from = self.findclosestfrontiervalue(ALTS, route.wpalt[wpt] / ft)
                layer_to = self.findclosestfrontiervalue(ALTS, route.wpalt[wpt + 1] / ft)

                # only altitude layer transitions are of interest
                if layer_from == layer_to:
                    continue

                start_ee_section = route.wptnodes[wpt]
                end_ee_section = route.wptnodes[wpt + 1]
                _, control_node = self.find_next_and_previous_node(start_ee_section, end_ee_section)
                if control_node is not None:
                    control_node = int(control_node)

                # moving into an even-numbered layer counts as entering that
                # layer; otherwise the aircraft is leaving the layer it is in
                entering = layer_to % 2 == 0
                layer = layer_to if entering else layer_from

                key = (int(layer), control_node, int(start_ee_section), int(end_ee_section))
                if entering:
                    sections = self.addToState(sections, key, 1, 0, 0)
                else:
                    sections = self.addToState(sections, key, 0, 1, 0)

        return sections

    def find_traffic_density(self, sections, routes, next_wpt):
        """Count, per section, the aircraft that travel through it within their own layer.

        An aircraft counts as staying in a section when its route visits the
        section's nodes consecutively, the section is still ahead of it and
        within the look-ahead range, and all waypoints of that stretch lie
        in the section's altitude layer.

        Args:
            sections: dict of section keys to state tuples; updated in place.
            routes: routes exposing ``wptnodes`` and ``wpalt``.
            next_wpt: active waypoint index of each route.

        Returns:
            ``sections``.
        """
        for it_routes, route in enumerate(routes):
            travel_nodes = route.wptnodes
            active_wpt = next_wpt[it_routes]
            # look-ahead limit of this route; the same for every section, so
            # it is computed at most once, and only if some section needs it
            max_wpt = None

            for key in list(sections.keys()):
                if not all(x in travel_nodes for x in key[2:]):
                    continue

                # make sure they are in order
                index_begin = (np.where(travel_nodes == key[2])[0]).item(0)
                index_end = (np.where(travel_nodes == key[-1])[0]).item(0)
                if max_wpt is None:
                    max_wpt = self.getPathInNextTimeStamp(route, active_wpt)

                consecutive = index_end - index_begin == len(key) - 3
                # NOTE(review): max_wpt already looks like a waypoint index, so
                # adding active_wpt to it may overstate the range - confirm.
                in_range = active_wpt <= index_end and index_begin < active_wpt + max_wpt
                if not (consecutive and in_range):
                    continue

                # Go from altitudes to altitude layers without writing into
                # the slice: if wpalt is a NumPy array, the slice is a view
                # and in-place writes would overwrite the route's altitudes.
                layers = (
                    self.findclosestfrontiervalue(ALTS, altitude / ft)
                    for altitude in route.wpalt[index_begin: index_end + 1]
                )

                # are we at the correct altitude?
                if all(layer == key[0] for layer in layers):
                    sections = self.addToState(sections, key, 0, 0, 1)

        return sections

    def checkmaxspeeds(self, sections):
        """Write the applicable maximum speed into the state of every section.

        The nodes of a section (key entries from index 2 onwards) are visited
        in order and each one overwrites the speed entry, so the value for
        the last node is the one that remains. A node without a stored limit
        contributes DEFAULT_MAX_SPEED.

        Args:
            sections: dict of section keys to state tuples; updated in place.

        Returns:
            ``sections``.
        """
        for key in sections:
            # the first node of the entrance/exit section is also the last node of the controlled section, where
            # the max speed is set
            layer = key[0]
            for node in key[2:]:
                stored = self.max_speeds.get((layer, node))
                limit = DEFAULT_MAX_SPEED if stored is None else stored
                # only values of existing keys change, so iterating is safe
                self.addToState(sections, key, 0, 0, 0, limit)
        return sections
