# -*- coding: utf-8 -*-
# Created by mjribeiro at 28-1-2021
from bluesky.tools.simtime import timed_function
import bluesky as bs
import numpy as np
from bluesky.tools.aero import ft
import bluesky.tools.structurefinder as sf
import os
from datetime import datetime
import bluesky.tools.cityboundaries as cb
from bluesky.tools import geo
from bluesky.tools import geo
from bluesky.tools.aero import nm

# Number of layers in one airspace structure (each structure list has this many entries).
NUMBER_LAYERS = 6
# Length of the action vector; equals NUMBER_LAYERS * 4, and getStructureFromAction
# reads four consecutive values per layer.
ACTION_DIM = 24
# Length of the state vector built by getFinalState and passed to
# sf.getTrajectoryInformation in getCurrentState.
STATE_DIM = 22

LAYER_SEP = 5  # extra vertical spacing added per layer step, in ft
ASAS_PZH = bs.settings.asas_pzh  # ft; taken from the BlueSky settings (presumably the protected-zone height)
ALT_MIN = 65  # lowest altitude of the generated ranges, in ft

# ALTS = [65, 105, 145, 185, 225, 265, 305, 345, 385, 425, 465]
# Altitudes from ALT_MIN in steps of (LAYER_SEP + ASAS_PZH); the stop value is 17 steps above ALT_MIN.
ALTS_all = np.arange(ALT_MIN, ALT_MIN + (LAYER_SEP + ASAS_PZH) * 17, LAYER_SEP + ASAS_PZH)  # in ft
# Altitudes from ALT_MIN up to (but excluding) 400 in steps of 3 * (ASAS_PZH + LAYER_SEP).
ALTS_traffic = np.arange(ALT_MIN, 400, ASAS_PZH * 3 + LAYER_SEP * 3)

# Table of pre-computed layer structures. Every entry has six elements (one per
# layer); an element is either a single heading index in 1-4 or a list of heading
# indices that share that layer.
# NOTE(review): no active code in this module reads this table - presumably it is
# meant to be indexed by scenario repetition; confirm before removing.
LIST_STRUCTURES =[[3,3,3,3,3,[1, 2, 4]], [4,2,1,1,1,3], [3,3,3,3,3,[1, 2, 4]], [3,3,3,3,3,[1, 2, 4]], [3,1,3,3,3,[2, 4]],
                  [3,1,3,3,1,[2, 4]], [1,3,2,3,3,4], [3,3,3,1,1,[2, 4]], [1,2,2,2,1,[3, 4]], [3,3,2,3,1,4], [1,1,2,3,4,1],
                  [4,4,2,4,4,[1, 3]], [3,4,4,3,3,[1, 2]], [2,4,4,2,4,[1, 3]], [1,1,3,3,4,2], [1,1,3,3,1,[2, 4]], [3,3,3,1,1,[2, 4]],
                  [3,1,3,3,3,[2, 4]], [4,3,1,3,3,2], [3,4,1,3,3,2], [1,2,4,1,4,3], [3,3,3,1,1,[2, 4]], [3,1,3,3,1,[2, 4]],
                  [3,3,3,1,1,[2, 4]], [3,3,3,3,1,[2, 4]], [3,3,3,3,3,[1, 2, 4]], [2,2,4,2,2,[1, 3]], [3,3,3,3,3,[1, 2, 4]],
                  [2,4,4,2,4,[1, 3]], [3,1,3,3,3,[2, 4]], [1,2,1,2,2,[3, 4]], [3,1,3,3,1,[2, 4]], [3,3,3,1,1,[2, 4]],
                  [3,3,3,3,3,[1, 2, 4]], [3,3,3,3,1,[2, 4]], [2,3,3,1,1,4], [3,2,4,2,4,1], [3,3,2,3,3,[1, 4]],
                  [2,2,4,2,4,[1, 3]], [2,1,3,1,1,4], [1,1,3,3,1,[2, 4]], [3,3,3,1,1,[2, 4]], [3,3,3,1,1,[2, 4]],
                  [3,3,3,1,1,[2, 4]], [1,1,3,3,1,[2, 4]], [1,1,3,3,1,[2, 4]], [2,2,4,2,4,[1, 3]], [3,1,3,1,1,[2, 4]],
                  [3,3,2,2,4,1], [1,1,3,3,1,[2, 4]], [3,3,3,3,3,[1, 2, 4]], [2,2,4,2,4,[1, 3]], [3,1,3,3,1,[2, 4]],
                  [2,2,4,2,4,[1, 3]], [3,1,3,3,1,[2, 4]], [3,3,3,3,3,[1, 2, 4]], [3,3,3,3,1,[2, 4]], [3,1,3,3,1,[2, 4]],
                  [3,3,3,1,1,[2, 4]], [3,3,3,3,1,[2, 4]], [3,3,3,1,3,[2, 4]], [2,2,4,2,4,[1, 3]], [3,3,3,1,1,[2, 4]],
                  [2,2,4,2,4,[1, 3]], [3,1,3,3,1,[2, 4]], [3,3,3,1,1,[2, 4]], [3,1,3,1,1,[2, 4]], [3,3,3,1,1,[2, 4]],
                  [3,3,3,1,1,[2, 4]], [3,3,3,1,1,[2, 4]], [3,3,3,3,1,[2, 4]], [3,3,3,1,1,[2, 4]], [1,3,1,3,3,[2, 4]],
                  [3,3,3,1,1,[2, 4]], [3,3,3,1,1,[2, 4]]]
# Reward penalties. PENALTY_NEAR_LOS and PENALTY_CONFLICT are applied in addTlos;
# PENALTY_LOS is not referenced by any active code in this module (getLoss
# subtracts the raw count of new LoS pairs instead).
PENALTY_LOS = 20
PENALTY_NEAR_LOS = 2  # 4
PENALTY_CONFLICT = 1  # 2

# Compared in update() against the simulation time elapsed since the last
# structure change, and passed to sf.getTrajectoryInformation.
# 4000 s is about 66.7 minutes, not 20 minutes (1200 s is the FINAL_TIME value).
LOOK_AHEAD_TIME = 4000  # seconds
# When True, update() applies a structure once and never refreshes it.
STATIC_STRUCT = True
# Compared against bs.sim.simt in update().
FINAL_TIME = 1200
# FINAL_TIME = 3600
# NOTE(review): readDeliveryPoints divides a distance in metres by this value,
# i.e. it is used as if it were m/s, while 21 kts is roughly 10.8 m/s - confirm
# the intended unit.
TAS_AVG = 21  # kts - 10.83m/s

# Heading boundaries in degrees; not referenced by any active code in this module.
HDG_RANGES_AREAS = [45, 135, 225, 315]


def init_plugin():
    """Create the plugin object and return the BlueSky plugin description.

    Returns:
        tuple: ``(config, stackfunctions)`` where ``config`` names the plugin,
        its type and its ``update``/``reset`` callbacks, and ``stackfunctions``
        is an empty dict (this plugin registers no stack commands).
    """
    # Additional initialisation code.
    global structure_RL
    # The module-level name is rebound from the class to its single instance.
    # Only instantiate while it still refers to the class, so that calling
    # init_plugin() again reuses the instance instead of raising TypeError.
    if isinstance(structure_RL, type):
        structure_RL = structure_RL()

    # Configuration parameters
    config = {
        # The name of your plugin
        'plugin_name': 'structure_RL',

        # The type of this plugin. For now, only simulation plugins are possible.
        'plugin_type': 'sim',

        # The update function is called after traffic is updated.
        'update': structure_RL.update,

        # The reset function
        'reset': structure_RL.reset
    }
    stackfunctions = {}

    # init_plugin() should always return these two dicts.
    return config, stackfunctions


class structure_RL():

    def __init__(self):
        super(structure_RL, self).__init__()

    def reset(self):
        # all layers start as inactive. state, action
        self.actions_DDPG = [[] for i in range(2)]
        # no structure initially set
        self.actions_DDPG[1] = [-1]
        self.structure_set = False
        self.last_structure_set_time = None
        self.aircraftRoutes_times = []
        self.aircraftRoutes_hdgs = []
        self.previous_read_points = None
        self.prevlospairs = 0
        self.prevconfpairs = 0
        self.expected_traffic = 0
        self.waiting_response = False
        self.history_structures = []
        self.FIRST_STRUCTURE = True

    @timed_function('structure_RL', dt=1.0)
    def update(self, dt):
        """Periodic plugin step: load route data and apply a structure when due.

        The delivery-point file is (re)read whenever the expected file name
        differs from the one read last. A structure is then applied through
        ``setDDPGaction`` if none has been set yet, or - only when
        ``STATIC_STRUCT`` is False - once ``LOOK_AHEAD_TIME`` of simulation time
        has passed since the last one and no response is pending.

        The statements that previously followed the unconditional ``return``
        statements (requesting an action from the RL agent and reporting the
        final reward) could never execute and have been removed; the runtime
        behaviour is unchanged.

        Args:
            dt: Update interval supplied by ``timed_function``; not used here.
        """
        if self.previous_read_points != self.getPointFileName():
            self.readDeliveryPoints()

        # NOTE(review): no active code in this module sets waiting_response
        # back to False, so with STATIC_STRUCT = False the refresh condition
        # below cannot become true after the first structure - confirm.
        if not self.structure_set or (
                not STATIC_STRUCT
                and not self.waiting_response
                and self.last_structure_set_time is not None
                and (bs.sim.simt - self.last_structure_set_time) >= LOOK_AHEAD_TIME):
            # The action argument is ignored by setDDPGaction, hence None.
            self.setDDPGaction(None)
            self.structure_set = True
            self.waiting_response = True

    def getFinalState(self, reward=False):
        """Build a state vector counting current aircraft per reference heading.

        Each aircraft is assigned to the reference heading closest to its
        current heading and the matching counter, stored at indices 16-19 of
        the state vector, is incremented. All other entries stay zero.

        Args:
            reward: Unused; kept for signature compatibility with
                ``getCurrentState``.

        Returns:
            numpy.ndarray: Array of length ``STATE_DIM``.
        """
        possible_directions = [86.8, 176.8, 266.8, 356.8]
        state = np.zeros(STATE_DIM)

        for aircraft in range(bs.traf.ntraf):
            hdg = bs.traf.hdg[aircraft]
            # Headings are angles, so the distance must wrap around 360 deg:
            # a heading of 2 deg is 5.2 deg away from 356.8, not 354.8.
            direction = min(
                range(len(possible_directions)),
                key=lambda i: abs((possible_directions[i] - hdg + 180.0) % 360.0 - 180.0))
            state[direction + 16] += 1

        return state

    def getCurrentState(self, reward=False):
        """Return the current state from the stored aircraft routes.

        The state is whatever ``sf.getTrajectoryInformation`` produces from the
        stored route times and headings, ``LOOK_AHEAD_TIME``, ``STATE_DIM`` and
        ``reward``.
        NOTE(review): earlier notes in this method described the state as the
        number of aircraft per heading (1-4) now and in the future, plus the
        previous structure (0 for the first one) - confirm against
        ``sf.getTrajectoryInformation``.

        Args:
            reward: Forwarded unchanged to ``sf.getTrajectoryInformation``.

        Returns:
            The computed state, or ``None`` while ``cb.FINISHED_READING_DATA``
            is falsy.
        """
        # Nothing can be computed yet while the information needed to build
        # the city is still being read.
        if not cb.FINISHED_READING_DATA:
            return None

        return sf.getTrajectoryInformation(
            self.aircraftRoutes_times,
            self.aircraftRoutes_hdgs,
            LOOK_AHEAD_TIME,
            STATE_DIM,
            reward,
        )

    def getAircraftBetween(self, begin, end, hdgs):
        if begin < end:
            return len(np.where(np.logical_and(begin < hdgs, hdgs <= end))[0])
        else:
            return len(np.where(np.logical_or(hdgs <= end, hdgs > begin))[0])

    def setDDPGaction(self, action):
        """Apply a layer structure and record it in the structure history.

        The structure is currently fixed to ``[4, 1, 2, 3, 4, 1]``; the
        ``action`` argument is accepted for interface compatibility but is not
        used. The scenario-name parsing that used to run here produced a value
        that was never used and could raise on names without ``Rep``, so it has
        been removed.

        Args:
            action: Ignored.
        """
        # Remember how many conflicts and losses of separation exist now, so
        # that later rewards only count the ones detected from here on.
        self.prevconfpairs = len(bs.traf.cd.confpairs_all)
        self.prevlospairs = len(bs.traf.cd.lospairs_all)

        new_structure = [4, 1, 2, 3, 4, 1]
        new_layer_structure = sf.setNewStructure(new_structure)
        # History entry: time, last stored state, requested structure and the
        # value returned by sf.setNewStructure.
        self.history_structures.append([bs.sim.simt, self.actions_DDPG[0], new_structure, new_layer_structure])

        self.last_structure_set_time = bs.sim.simt
        self.FIRST_STRUCTURE = False

    def getStructureFromAction(self, action):
        """Translate an action vector into one heading choice per layer.

        The action is read in consecutive groups of four values, one group per
        layer; the position of the largest value in a group (1-based) is the
        heading index chosen for that layer. If the layers together use fewer
        than four distinct headings, the last layer is replaced by the
        headings not used by the other layers - a single index when exactly one
        is missing, otherwise a list of indices.

        Args:
            action: Sequence with at least ``NUMBER_LAYERS * 4`` values.

        Returns:
            list: ``NUMBER_LAYERS`` entries, each an int or (last entry only) a
            list of ints.
        """
        scores = np.array(action)
        structure = [
            int(np.argmax(scores[4 * layer: 4 * layer + 4]) + 1)
            for layer in range(NUMBER_LAYERS)
        ]

        if len(set(structure)) < 4:
            # Only the layers before the last one count as covered, because
            # the last layer is about to be overwritten.
            covered = set(structure[:NUMBER_LAYERS - 1])
            missing = [hdg for hdg in (1, 2, 3, 4) if hdg not in covered]
            structure[NUMBER_LAYERS - 1] = missing[0] if len(missing) == 1 else missing

        return structure

    # def getStructureFromAction(self, action):
    #     structure = [None] * NUMBER_LAYERS
    #     heading_applied = set()
    #
    #     current_hdg = 0
    #     sum_so_far = 0
    #
    #     for it in range(NUMBER_LAYERS):
    #
    #         if it == NUMBER_LAYERS - 1 and len(heading_applied) < 4:
    #             add_hdgs = []
    #             for hdg in [1,2,3,4]:
    #                 if hdg not in heading_applied:
    #                     add_hdgs.append(hdg)
    #             if len(add_hdgs) == 1:
    #                 structure[it] = add_hdgs[0]
    #             else:
    #                 structure[it] = add_hdgs
    #             continue
    #
    #         portion_angle = action[it]
    #         portion_angle = max(portion_angled, 0)
    #         sum_so_far += portion_angle
    #         hdgs_this_layer = []
    #
    #         if sum_so_far > 0.75:
    #             final_hdg = 4
    #         elif sum_so_far > 0.5:
    #             final_hdg = 3
    #         elif sum_so_far > 0.25:
    #             final_hdg = 2
    #         else:
    #             final_hdg = 1
    #
    #         for hdg in range(current_hdg, final_hdg):
    #             current_hdg += 1
    #             heading_applied.add(current_hdg)
    #             hdgs_this_layer.append(current_hdg)
    #
    #         if len(hdgs_this_layer) == 0:
    #             hdgs_this_layer.append(current_hdg)
    #
    #         if len(hdgs_this_layer) == 1:
    #             structure[it] = hdgs_this_layer[0]
    #         else:
    #             structure[it] = hdgs_this_layer
    #
    #     print(action, structure)
    #     return structure

    # def getStructureFromAction(self, action):
    #     structure = []
    #     heading_applied = set()
    #
    #     action = np.array(action)*6
    #
    #     for it in np.arange(0, ACTION_DIM, 4):
    #         L = np.argsort(action[it:it+4])
    #         if it == (ACTION_DIM-4) and len(heading_applied) < 4:
    #             add_hdgs = []
    #             for hdg in [1,2,3,4]:
    #                 if hdg not in heading_applied:
    #                     add_hdgs.append(hdg)
    #             if len(add_hdgs) == 1:
    #                 structure.append(add_hdgs[0])
    #             else:
    #                 structure.append(add_hdgs)
    #             continue
    #
    #         if action[it+L[-1]] > 0.5:
    #             structure.append(L[-1]+1)
    #             heading_applied.add(L[-1]+1)
    #         elif action[it+L[-1]] + action[it+L[-2]] > 0.5:
    #             structure.append([L[-1]+1,L[-2]+1])
    #             heading_applied.add(L[-1] + 1)
    #             heading_applied.add(L[-2] + 1)
    #         elif action[it+L[-1]]  + action[it+L[-2]] + action[it+L[-3]] > 0.5:
    #             structure.append([L[-1]+1,L[-2]+1,L[-3]+1])
    #             heading_applied.add(L[-1] + 1)
    #             heading_applied.add(L[-2] + 1)
    #             heading_applied.add(L[-3] + 1)
    #         else:
    #             structure.append([1,2,3,4])
    #             heading_applied.add(1)
    #             heading_applied.add(2)
    #             heading_applied.add(3)
    #             heading_applied.add(4)
    #
    #     return structure

    # def getStructureFromAction(self, action):
    #     structure = []
    #     heading_applied = set()
    #
    #     for it in np.arange(0, ACTION_DIM, 4):
    #         L = np.argsort(action[it:it+4])
    #         if it == (ACTION_DIM-4) and len(heading_applied) < 4:
    #             add_hdgs = []
    #             for hdg in [1,2,3,4]:
    #                 if hdg not in heading_applied:
    #                     add_hdgs.append(hdg)
    #             structure.append(add_hdgs)
    #             continue
    #
    #         if np.max(action[it:it+4]) > 0.6:
    #             structure.append(L[-1]+1)
    #             heading_applied.add(L[-1]+1)
    #         elif np.max(action[it:it+4]) > 0.4:
    #             structure.append([L[-1]+1,L[-2]+1])
    #             heading_applied.add(L[-1] + 1)
    #             heading_applied.add(L[-2] + 1)
    #         elif np.max(action[it:it+4]) > 0.3:
    #             structure.append([L[-1]+1,L[-2]+1,L[-3]+1])
    #             heading_applied.add(L[-1] + 1)
    #             heading_applied.add(L[-2] + 1)
    #             heading_applied.add(L[-3] + 1)
    #         else:
    #             structure.append([1,2,3,4])
    #
    #     return structure

    def setDDPGreward(self, next_state):
        """Compute the reward and report the transition to the RL agent.

        The reward is based only on the number of new losses of separation
        (see ``getLoss``); the traffic in/outflow and near-miss terms are not
        applied. Nothing is reported while no previous state has been stored.

        Args:
            next_state: State reached after the last action.
        """
        # Always evaluated first, because getLoss also updates the LoS counter.
        reward = self.getLoss(0)

        previous_state = self.actions_DDPG[0]
        if len(previous_state) == 0:
            return

        # Argument order: scenario name, state, next state, reward, action.
        bs.sim.RLresult(bs.stack.get_scenname(), previous_state, next_state, reward, self.actions_DDPG[1])

        # Notes on a possible outflow-inflow reward:
        # inflow: a/c passing through the first node of the controlled section
        # outflow: a/c passing through the last node of the entrance/exit section

    def getTrafficOutflow(self, rewards):
        """Penalise the gap between current and expected traffic counts.

        Args:
            rewards: Reward accumulated so far.

        Returns:
            ``rewards`` reduced by a tenth of the absolute difference between
            ``bs.traf.ntraf`` and ``self.expected_traffic``.
        """
        traffic_gap = abs(bs.traf.ntraf - self.expected_traffic)
        rewards -= traffic_gap / 10
        return rewards

    def getNearMisses(self, rewards):
        """Penalise every conflict pair detected since the last snapshot.

        The stored conflict counter is moved forward to the current total, so
        each conflict pair is counted once.

        Args:
            rewards: Reward accumulated so far.

        Returns:
            ``rewards`` reduced by the number of new conflict pairs.
        """
        total_conflicts = len(bs.traf.cd.confpairs_all)
        rewards -= total_conflicts - self.prevconfpairs
        # A per-conflict weighting by time to LoS (see addTlos) is not applied here.
        self.prevconfpairs = total_conflicts
        return rewards

    def addTlos(self, timeLos, rewards):
        """Subtract a penalty that depends on the time to loss of separation.

        Args:
            timeLos: Time to loss of separation of one conflict.
            rewards: Reward accumulated so far.

        Returns:
            ``rewards`` reduced by ``PENALTY_NEAR_LOS`` when ``timeLos`` is
            below 10, otherwise by ``PENALTY_CONFLICT``.
        """
        penalty = PENALTY_NEAR_LOS if timeLos < 10 else PENALTY_CONFLICT
        rewards -= penalty
        return rewards

    def getLoss(self, rewards):
        """Penalise every loss-of-separation pair detected since the last snapshot.

        The stored LoS counter is moved forward to the current total, so each
        pair is counted once.

        Args:
            rewards: Reward accumulated so far.

        Returns:
            ``rewards`` reduced by the number of new LoS pairs.
        """
        total_los = len(bs.traf.cd.lospairs_all)
        rewards -= total_los - self.prevlospairs
        self.prevlospairs = total_los
        return rewards

    def findclosestfrontiervalue(self, frontier_values, value):
        return min(range(len(frontier_values)), key=lambda i: abs(frontier_values[i] - value))

    def readDeliveryPoints(self):
        """Read the delivery-point file and store per-aircraft route data.

        Does nothing while no scenario name is set. Otherwise the file named
        by ``getPointFileName`` is opened from ``bs.settings.scenario_path`` and
        every line that is neither blank nor starts with ``#`` is parsed as one
        aircraft: the second comma-separated field is the generation time
        (``HH:MM:SS.ffffff``) and the bracketed list holds the route's node ids.
        For each route one array of passage times and one array of leg
        headings is appended to ``self.aircraftRoutes_times`` and
        ``self.aircraftRoutes_hdgs``.

        Raises:
            OSError: If the file cannot be opened.
            ValueError, IndexError: If a data line does not have the expected
                layout.
        """
        scenarioname = bs.stack.get_scenname()
        if len(scenarioname) == 0:
            return

        file_name = self.getPointFileName()
        self.previous_read_points = file_name
        with open(os.path.join(bs.settings.scenario_path, file_name), "r") as reader:
            for line in reader:
                # Blank lines carry no aircraft; comment lines start with '#'.
                if not line.strip() or line.startswith("#"):
                    continue

                fields = line.split(',')
                stamp = datetime.strptime(fields[1], '%H:%M:%S.%f')
                # Fractions of a second are dropped.
                generation_time = stamp.hour * 3600 + stamp.minute * 60 + stamp.second

                # Take the text between the brackets, so the last line is
                # parsed correctly even without a trailing newline.
                route_text = line.split('[')[1].split(']')[0]
                aircraftRoute = [np.int64(i) for i in route_text.replace(' ', '').split(',')]

                times = []
                hdgs = []
                elapsed = generation_time
                for node_from, node_to in zip(aircraftRoute, aircraftRoute[1:]):
                    lat0, lon0 = cb.getLatLonnode(cb.GRAPH, node_from)
                    lat1, lon1 = cb.getLatLonnode(cb.GRAPH, node_to)
                    qdr, dist = geo.kwikqdrdist(lat0, lon0, lat1, lon1)
                    dist = dist * nm  # pass to meters
                    # NOTE(review): a distance in metres is divided by TAS_AVG,
                    # whose comment gives it in kts - confirm the intended unit.
                    elapsed = elapsed + dist / TAS_AVG
                    hdgs.append(qdr)
                    times.append(elapsed)

                # Lists are collected first and converted once, instead of
                # re-allocating an array for every leg.
                self.aircraftRoutes_times.append(np.array(times, dtype=float))
                self.aircraftRoutes_hdgs.append(np.array(hdgs, dtype=float))

    def getPointFileName(self):
        scenarioname = bs.stack.get_scenname()
        if 'SSD' in scenarioname:
            repetition = scenarioname.split('-')[-2]
            instance_name = scenarioname.split('-')[-3]
        else:
            # instance_name = scenarioname.split('-')[2]
            # instance_name = instance_name.split('Inst')[1]
            # repetition = scenarioname.split('-')[-1]
            # repetition = repetition.split('_')[0]
            repetition = scenarioname.split('Rep')[1]
            repetition = repetition.split('.')[0]
        repetition = int(repetition) % 5
        return 'point-test-Rep' + str(0) + '.scn'