# -*- coding: utf-8 -*-
# Created by mjribeiro at 7-7-2020

import bluesky.tools.cityboundaries as cb
import numpy as np
from bluesky.tools.simtime import timed_function
import bluesky as bs
from bluesky import stack


def init_plugin():
    """Instantiate the plugin and return its BlueSky registration data.

    Returns:
        tuple: ``(config, stackfunctions)``. ``config`` names the plugin,
        marks it as a simulation plugin and wires its ``update`` and
        ``reset`` callbacks. ``stackfunctions`` is empty because this
        plugin registers no stack commands.
    """
    global checkcongestion
    # Rebinds the module-level name from the class to its instance.
    # NOTE(review): afterwards the name no longer refers to a class (the
    # class defines no __call__), so a second call to init_plugin() would
    # fail. Confirm it is only called once per load.
    checkcongestion = checkcongestion()

    # 'update' is called after traffic is updated; 'reset' is the reset hook.
    config = dict(
        plugin_name='CHECKCONGESTION',
        plugin_type='sim',  # for now, only simulation plugins are possible
        update=checkcongestion.update,
        reset=checkcongestion.reset,
    )
    stackfunctions = dict()

    # init_plugin() should always return these two dicts.
    return config, stackfunctions


class checkcongestion():
    """Simulation plugin that keeps per-aircraft routes over the city graph
    and re-plans them when edge direction values change.

    ``init_plugin`` creates the instance and registers ``update`` and
    ``reset`` with BlueSky.

    NOTE(review): this class body uses members it does not define:
    ``getPointFileName``, ``readDeliveryPoints``, ``setAircraftRoute``,
    ``actions_per_edges`` and ``adjustDirectionValues``; the attributes
    ``nodes_states``, ``nodes_groups``, ``edges_directions`` and
    ``action_decision``; and the global ``BOTH_ACTIONS``. They are presumably
    supplied elsewhere. Confirm before relying on the code paths using them.
    """

    def __init__(self):
        # Zero-argument super() so construction does not depend on the
        # module-level name ``checkcongestion``, which init_plugin() rebinds
        # to the instance.
        super().__init__()
        self.nodes_id = np.array([])
        self.nodes_lat = np.array([])
        self.nodes_lon = np.array([])
        self.edges_speeds = dict()
        # Route (sequence of node ids) per aircraft id.
        self.aircraft_routes = dict()
        self.generation_times = dict()
        # Delivery node ids per aircraft id.
        self.deliverypoints = dict()
        # These are read by update()/set_Maximum_Speeds() but were otherwise
        # only assigned in reset(); initialise them so those reads cannot
        # raise AttributeError on a fresh instance.
        self.previous_read_points = None
        self.prevlospairs = set()
        self.prevconfpairs = set()
        self.actions = dict()
        self.aircraft_waiting = dict()

    def get_edge(self, node0, node1):
        """Return the undirected edge key for two nodes: (smaller, larger)."""
        return (min(node0, node1), max(node0, node1))

    @timed_function('PATHFINDER', dt=15.0)
    def update(self, dt):
        """Periodic plugin update, registered with ``timed_function`` at dt=15.0.

        Sets the RL state/action dimensions while no nodes are loaded and a
        city graph exists. Re-reads delivery points when the point file name
        differs from the last one read.
        """
        # NOTE(review): the timer is registered under 'PATHFINDER', not
        # 'CHECKCONGESTION'. Confirm this name is intended.
        if len(self.nodes_id) == 0 and cb.GRAPH is not None:
            self.state_dim = 6
            self.actiondim = 1
            # bs.sim.DDPGcreate(self.state_dim, self.actiondim)

        if self.previous_read_points != self.getPointFileName():
            self.readDeliveryPoints()

    def reset(self):
        """Clear conflict pairs, pending actions and waiting aircraft, then
        call ``bs.sim.RLsave`` with the current scenario name."""
        self.prevlospairs = set()
        self.prevconfpairs = set()

        self.actions = dict()
        self.aircraft_waiting = dict()

        bs.sim.RLsave(bs.stack.get_scenname())

    def sendAicraftRoutes(self):
        """Hand the per-aircraft routes to ``bs.traf.cd.setnotesroute``.

        (The misspelled name is kept for compatibility with existing callers.)
        """
        bs.traf.cd.setnotesroute(self.aircraft_routes)

    def getIndPerAngle(self, angle):
        """Map an angle in degrees to one of four 90-degree sector indices.

        0: [0, 90), 1: [90, 180), 2: [180, 270), 3: [270, 360).
        Angles below 0 fall into sector 0; 360 or more fall into sector 3.
        """
        # Equal quadrants. A 300 boundary here made sector 2 120 degrees wide,
        # so approaches at 180 and 270 degrees shared index 2.
        if angle < 90:
            return 0
        if angle < 180:
            return 1
        if angle < 270:
            return 2
        return 3

    def getNodeSide(self, aircraft, wp):
        """Return ``(node_id, side_index)`` for an aircraft at a route position.

        Args:
            aircraft: traffic index into the ``bs.traf`` arrays.
            wp: index into the aircraft's route of the node of interest.

        If that node is inactive in ``self.nodes_states``, the previous route
        node is used instead (when there is one). ``side_index`` is the
        ``getIndPerAngle`` sector of the aircraft track rotated by 180 degrees.
        """
        route = self.aircraft_routes[bs.traf.id[aircraft]]
        node_int = route[wp]
        node_pos = (np.where(self.nodes_id == node_int)[0]).item(0)
        # Fall back to the previous node when the next one is inactive. With
        # wp == 0 there is no previous node: route[-1] would silently pick the
        # final node of the route, so the current node is kept.
        if not self.nodes_states[node_pos] and wp > 0:
            node_int = route[wp - 1]

        side = self.getIndPerAngle((bs.traf.trk[aircraft] + 180) % 360)
        return node_int, side

    def addtoRLstate(self, nodes_with_aircraft, node_int, edge_it):
        """Increment the side counter ``edge_it`` of ``node_int``.

        Creates a ``[0, 0, 0, 0]`` entry for unseen nodes. Mutates and returns
        ``nodes_with_aircraft``.
        """
        if nodes_with_aircraft.get(node_int) is None:
            nodes_with_aircraft[node_int] = [0, 0, 0, 0]
        nodes_with_aircraft[node_int][edge_it] += 1

        return nodes_with_aircraft

    def getNextWaypoint(self, aircraft):
        """Return ``iactwp`` of the autopilot route for traffic index ``aircraft``."""
        return bs.traf.ap.route[aircraft].iactwp

    def getcurrentstate(self):
        """Build the RL state vector into ``self.nodes_RL``.

        For each aircraft, add 1 to the entry of ``self.nodes_groups`` for the
        node at its active waypoint. Aircraft whose active waypoint index is
        past the end of their stored route are ignored.
        """
        new_state = np.zeros(self.state_dim)

        for aircraft in range(bs.traf.ntraf):
            route = self.aircraft_routes[bs.traf.id[aircraft]]
            iactwp = self.getNextWaypoint(aircraft)
            if iactwp < len(route):
                new_state[self.nodes_groups[route[iactwp]]] += 1

        self.nodes_RL = new_state

    def callRL(self):
        """Pass ``self.nodes_RL`` and the scenario name to ``bs.sim.RLaction``."""
        bs.sim.RLaction(stack.get_scenname(), self.nodes_RL)

    def getroutesegments(self, route, deliverypoints):
        """Split ``route`` into consecutive segments ending at delivery points.

        Delivery points are matched in order, each one searched only in the
        part of the route after the previous segment. A node visited more than
        once is therefore matched at the right occurrence. Delivery points not
        found there are skipped, and any tail after the last match is dropped.

        Args:
            route: sequence of node ids (list or numpy array).
            deliverypoints: iterable of node ids, in visiting order.

        Returns:
            list of numpy array slices of ``route``.
        """
        route = np.asarray(route)
        route_segments = []
        last_id = 0

        for deliverypoint in deliverypoints:
            # Searching the whole route could return an occurrence before
            # last_id and produce an empty or overlapping segment.
            matches = np.flatnonzero(route[last_id:] == deliverypoint)
            if matches.size > 0:
                point_it = last_id + int(matches[0])
                route_segments.append(route[last_id:point_it + 1])
                last_id = point_it + 1

        return route_segments

    def checkMatch(self, delivery_segment, edges_changed):
        """Return True if any element of ``edges_changed`` is in ``delivery_segment``."""
        return any(edge in delivery_segment for edge in edges_changed)

    def set_Maximum_Speeds(self):
        """Apply the pending per-node actions and re-plan affected routes.

        Despite the name, this updates ``self.edges_directions`` (keyed by
        ``get_edge``), not speeds. For every node in ``self.actions``, the
        action values are mapped onto the node's edges, and nodes where any
        edge value changed are collected. Every existing aircraft whose
        remaining route passes such a node, or that is waiting for a valid
        route, then has its route rebuilt between its delivery points with
        ``cb.getshortestpath``.

        Returns:
            list: ids of aircraft whose ``self.aircraft_routes`` entry was replaced.
        """
        change_route_aircrafts = []
        nodes_changed = set()

        for node in list(self.actions.keys()):
            actions = np.array(list(self.actions[node]))[0]
            if self.action_decision == BOTH_ACTIONS:
                # Keep only the values at even positions.
                actions = np.array(actions[0::2])
            edges = [edge[0] for edge in cb.EDGES_PER_NODE[node]]
            actions = self.actions_per_edges(edges, actions, node)
            actions = self.adjustDirectionValues(actions, edges, node)

            # Store new edge values and remember which nodes actually changed.
            for it, edge in enumerate(edges):
                edge_key = self.get_edge(node, edge)
                if actions[it] != self.edges_directions[edge_key]:
                    nodes_changed.add(node)
                    self.edges_directions[edge_key] = actions[it]

        if len(nodes_changed) > 0:
            for aircraft, route in self.aircraft_routes.items():
                route_changed = False
                route_new = np.array([])
                if aircraft in bs.traf.id:  # does aircraft exist?
                    aircraft_it = bs.traf.id2idx(aircraft)
                    # Only the route from the active waypoint onwards counts.
                    iactwp = self.getNextWaypoint(aircraft_it)
                    indices_intersection = np.in1d(np.array(route), np.array(list(nodes_changed)))
                    path_changed = np.sum(indices_intersection[iactwp:]) > 0
                    waiting = aircraft in self.aircraft_waiting.keys()
                    if (iactwp < len(route) and path_changed) or waiting:
                        # Choose the delivery point to restart planning from.
                        if waiting:
                            # NOTE(review): this is the loop index ``it`` stored
                            # below, which indexes ``route_points`` (possibly
                            # prefixed with the current node and offset by the
                            # previous it_start), not self.deliverypoints[aircraft].
                            # Confirm.
                            it_start = self.aircraft_waiting[aircraft]
                        else:
                            # NOTE(review): if the aircraft has no delivery points,
                            # it_start is never bound and the lines below raise.
                            for it_start in range(len(self.deliverypoints[aircraft])):
                                if self.deliverypoints[aircraft][it_start] in route[iactwp:]:
                                    break
                        route_complete = True
                        route_points = []
                        current_route_point = 0
                        # Start from the active waypoint unless it is that delivery point.
                        if iactwp < len(route) and route[iactwp] != self.deliverypoints[aircraft][it_start]:
                            route_points = [route[iactwp]]
                            current_route_point = iactwp
                        route_points = route_points + self.deliverypoints[aircraft][it_start::]
                        # NOTE(review): all_new is never set to False, so the
                        # branch below that copies the unchanged part of the old
                        # route is never taken and every leg is recomputed.
                        # Confirm whether that is intended.
                        all_new = True
                        last_it = 0
                        for it in range(len(route_points) - 1):
                            # Due to odd turns the new route may already pass
                            # through this point; then there is nothing to add.
                            if route_points[it + 1] not in route_new:
                                ind_next_deliv_point = np.where(np.array(route) == route_points[it + 1])[0]
                                if len(ind_next_deliv_point) > 1:
                                    # Several occurrences: take the first one after
                                    # the current position in the old route.
                                    for value in ind_next_deliv_point:
                                        if value > current_route_point:
                                            ind_next_deliv_point = np.array([value])
                                            break
                                if (all_new
                                        or len(ind_next_deliv_point) == 0
                                        or ind_next_deliv_point.item(0) < current_route_point
                                        or np.sum(indices_intersection[current_route_point:ind_next_deliv_point.item(0) + 1]) > 0):
                                    # Recompute this leg on the current graph state.
                                    aux, _, valid = cb.getshortestpath(route_points[last_it],
                                                                       route_points[it + 1],
                                                                       self.edges_directions, self.nodes_id,
                                                                       self.nodes_states)
                                else:
                                    # Reuse the unchanged part of the old route.
                                    aux = route[current_route_point:ind_next_deliv_point.item(0) + 1]
                                    valid = True
                                if len(ind_next_deliv_point) > 0:
                                    current_route_point = ind_next_deliv_point.item(0)
                                # Do not duplicate the joint node between legs.
                                if len(route_new) > 0 and route_new[-1] == aux[0]:
                                    aux = aux[1:]
                                route_new = np.append(route_new, aux)
                                route_changed = True
                                if not valid:
                                    # No valid path for this leg: mark the aircraft as waiting.
                                    self.aircraft_waiting[aircraft] = it
                                    route_complete = False
                                    break

                                last_it = it + 1
                            else:
                                all_new = True

                        # A waiting aircraft that now has a complete route stops waiting.
                        if route_complete and waiting:
                            del self.aircraft_waiting[aircraft]

                if route_changed:
                    self.aircraft_routes[aircraft] = route_new
                    change_route_aircrafts.append(aircraft)

        return change_route_aircrafts

    def set_actions(self):
        """Fetch the latest actions from ``bs.sim.getDDPGaction()`` and apply them.

        Does nothing when no actions are returned. Otherwise pushes every
        re-planned route to its aircraft via ``setAircraftRoute``.
        """
        self.actions = bs.sim.getDDPGaction()

        if len(self.actions) == 0:
            return

        for aircraft in self.set_Maximum_Speeds():
            aircraft_it = bs.traf.id2idx(aircraft)
            self.setAircraftRoute(aircraft_it, self.aircraft_routes[aircraft])




