''' This module provides the Conflict Detection base class. '''
import numpy as np

import bluesky as bs
from bluesky.tools.aero import ft, nm
from bluesky.tools.replaceable import ReplaceableSingleton
from bluesky.tools.trafficarrays import TrafficArrays, RegisterElementParameters
from bluesky import stack

# Default conflict-detection settings, read by ConflictDetection:
#   asas_pzr         - protected-zone radius, in NM (scaled by `nm`)
#   asas_pzh         - protected-zone height, in ft (scaled by `ft`)
#   asas_dtlookahead - conflict-detection lookahead time, in seconds
bs.settings.set_variable_defaults(
    asas_pzr=0.027,
    asas_pzh=30.0,
    asas_dtlookahead=60.0,
)


class ConflictDetection(ReplaceableSingleton, TrafficArrays):
    ''' Base class for Conflict Detection implementations.

        Stores the protected-zone settings, the conflicts and losses of
        separation (LoS) found in the current timestep, and running logs of
        all conflicts, LoS events and intrusion details since the last reset.
        The base class itself detects nothing (see ``detect``); it is the
        method that is selected when conflict detection is turned OFF.
    '''

    def __init__(self):
        TrafficArrays.__init__(self)
        # [m] Horizontal separation minimum for detection
        self.rpz = bs.settings.asas_pzr * nm
        # [m] Vertical separation minimum for detection
        self.hpz = bs.settings.asas_pzh * ft
        # [s] lookahead time
        self.dtlookahead = bs.settings.asas_dtlookahead
        self.dtnolook = 0.0

        # Conflicts and LoS detected in the current timestep (used for resolving)
        self.confpairs = list()
        self.lospairs = list()
        # Per-pair arrays returned by detect(); initialised here so that they
        # can be read before the first update()
        self._clear_detect_outputs()
        # Unique conflicts and LoS in the current timestep (a, b) = (b, a)
        self.confpairs_unique = set()
        self.lospairs_unique = set()

        # All conflicts and LoS since simt=0
        self.confpairs_all_cruising = list()
        self.lospairs_all_cruising = list()
        self.confpairs_all = list()
        self.lospairs_all = list()

        # Geofence-breach bookkeeping (only initialised/reset in this class)
        self.geo_breach = np.array([])
        self.geo_breach_exit = np.array([])
        self.geo_breach_reentrance = np.array([])
        self.geo_breach_dist = np.array([])
        self.geo_breach_total = 0
        self.geo_breach_max_dist = np.array([])

        self.nodes_routes = dict()

        # Intrusion / conflict event logs
        self._reset_event_logs()

        # Per-aircraft conflict data (created inside the context manager so
        # they are registered as per-aircraft element parameters)
        with RegisterElementParameters(self):
            self.inconf = np.array([], dtype=bool)  # In-conflict flag
            self.tcpamax = np.array([])  # Maximum time to CPA for aircraft in conflict

    def _clear_detect_outputs(self):
        ''' Reset every per-pair array produced by ``detect`` to an empty array. '''
        self.qdr = np.array([])
        self.dist = np.array([])
        self.dcpa = np.array([])
        self.tcpa = np.array([])
        self.tLOS = np.array([])
        self.tcpa_all = np.array([])
        self.tLOS_all = np.array([])
        self.dist_all = np.array([])
        self.dcpa_all = np.array([])
        self.mdalt = np.array([])
        self.tinhor = np.array([])
        self.touthor = np.array([])
        self.swhorconf = np.array([])
        self.qdr_all = np.array([])
        self.dalt = np.array([])

    def _reset_event_logs(self):
        ''' Empty the intrusion and conflict event logs. '''
        self.intrusions = np.array([])  # Severity of all intrusions: rpz - dcpa [nm]
        self.intrusionstime = np.array([])  # Sim time at which each new LoS was detected [s]
        # Flattened (ownship, intruder) value pairs recorded at LoS onset
        self.intrusionsalt = np.array([])
        self.intrusionsspeed = np.array([])
        self.intrusionslat = np.array([])
        self.intrusionslon = np.array([])
        # NOTE: the node/altitude logs below are not filled in by this class
        # (the logging code for them is currently disabled)
        self.intrusionsnodes1 = np.array([])
        self.intrusionsnodes0 = np.array([])
        self.confnodes1 = np.array([])
        self.confnodes0 = np.array([])
        self.conftimeLoS = np.array([])
        self.confalt = np.array([])

    def clearconfdb(self):
        ''' Clear conflict database.

            Empties the current-timestep conflict/LoS data and resets the
            per-aircraft flags and the geofence bookkeeping. The running logs
            (confpairs_all, lospairs_all, ...) are left untouched; ``reset``
            clears those.
        '''
        self.confpairs_unique.clear()
        self.lospairs_unique.clear()
        self.confpairs.clear()
        self.lospairs.clear()
        self._clear_detect_outputs()
        # Keep the bool dtype declared for inconf in __init__
        self.inconf = np.zeros(bs.traf.ntraf, dtype=bool)
        self.tcpamax = np.zeros(bs.traf.ntraf)
        self.geo_breach = np.zeros(bs.traf.ntraf)
        self.geo_breach_exit = np.array([])
        self.geo_breach_reentrance = np.array([])
        self.geo_breach_dist = np.array([])
        self.geo_breach_max_dist = np.zeros(bs.traf.ntraf)
        self.geo_breach_total = 0

    def reset(self):
        ''' Reset all conflict data and logs, and reload the PZ/lookahead
            settings from ``bs.settings``. '''
        super().reset()
        self.clearconfdb()
        # Report how many conflicts/LoS were logged in the scenario being
        # reset, before the logs are cleared
        print('RESET', len(self.confpairs_all), len(self.lospairs_all), bs.stack.get_scenname())
        self.confpairs_all.clear()
        self.lospairs_all.clear()
        self.confpairs_all_cruising.clear()
        self.lospairs_all_cruising.clear()
        self.rpz = bs.settings.asas_pzr * nm
        self.hpz = bs.settings.asas_pzh * ft
        self.dtlookahead = bs.settings.asas_dtlookahead
        self.dtnolook = 0.0
        self.geo_breach = np.zeros(bs.traf.ntraf)
        self.geo_breach_exit = np.zeros(bs.traf.ntraf)
        self.geo_breach_reentrance = np.array([])
        self.geo_breach_dist = np.array([])
        self.geo_breach_total = 0
        self.geo_breach_max_dist = np.zeros(bs.traf.ntraf)

        self.nodes_routes = dict()

        self._reset_event_logs()

    @classmethod
    def setmethod(cls, name=''):
        ''' Select a CD method.

            Without a name, report the current and available methods.
            'OFF' selects this base class; 'ON' selects the first registered
            non-base method. Any selection clears the conflict database.
            Returns a (success, message) tuple.
        '''
        # Get a dict of all registered CD methods
        methods = cls.derived()
        names = ['OFF' if n == 'CONFLICTDETECTION' else n for n in methods]
        if not name:
            curname = 'OFF' if cls.selected() is ConflictDetection else cls.selected().__name__
            return True, f'Current CD method: {curname}' + \
                   f'\nAvailable CD methods: {", ".join(names)}'
        # Check if the requested method exists
        if name == 'OFF':
            # Select the base method and clear the conflict database
            ConflictDetection.select()
            ConflictDetection.instance().clearconfdb()
            return True, 'Conflict Detection turned off.'
        if name == 'ON':
            # Just select the first CD method in the list
            name = next(n for n in names if n != 'OFF')
        method = methods.get(name, None)
        if method is None:
            return False, f'{name} doesn\'t exist.\n' + \
                   f'Available CD methods: {", ".join(names)}'

        # Select the requested method
        method.select()
        ConflictDetection.instance().clearconfdb()
        return True, f'Selected {method.__name__} as CD method.'

    def setrpz(self, value=None):
        ''' Set the horizontal separation distance (value in NM). '''
        if value is None:
            return True, ("ZONER [radius (nm)]\nCurrent PZ radius: %.2f NM" % (self.rpz / nm))
        self.rpz = value * nm

    def sethpz(self, value=None):
        ''' Set the vertical separation distance (value in ft). '''
        if value is None:
            return True, ("ZONEDH [height (ft)]\nCurrent PZ height: %.2f ft" % (self.hpz / ft))
        self.hpz = value * ft

    def setdtlook(self, value=None):
        ''' Set the lookahead time for conflict detection (value in seconds). '''
        if value is None:
            return True, ("DTLOOK [time]\nCurrent value: %.1f sec" % self.dtlookahead)
        self.dtlookahead = value

    def setdtnolook(self, value=None):
        ''' Set the interval in which conflict detection is skipped after a
            conflict resolution. '''
        if value is None:
            return True, ("DTNOLOOK [time]\nCurrent value: %.1f sec" % self.dtnolook)
        self.dtnolook = value

    def setnotesroute(self, nodesroutes):
        ''' Store ``nodesroutes`` as ``self.nodes_routes``.

            NOTE(review): not read anywhere in this class; presumably used by
            other modules - confirm with callers.
        '''
        self.nodes_routes = nodesroutes

    def update(self, ownship, intruder):
        ''' Perform an update step of the Conflict Detection implementation.

            Runs ``detect``, stores its results on ``self``, and appends the
            conflict / LoS pairs that are new since the previous step to the
            running logs (all pairs, and pairs where both aircraft have zero
            vertical speed). New LoS pairs also get their severity, time,
            altitude, ground speed and position logged.
        '''
        (self.confpairs, self.lospairs, self.inconf, self.tcpamax,
         self.qdr, self.dist, self.dcpa, self.tcpa, self.tLOS,
         self.tcpa_all, self.tLOS_all, self.dist_all, self.dcpa_all,
         self.mdalt, self.tinhor, self.touthor, self.swhorconf,
         self.qdr_all, self.dalt) = self.detect(
            ownship, intruder, self.rpz, self.hpz, self.dtlookahead)

        # confpairs has conflicts observed from both sides (a, b) and (b, a);
        # confpairs_unique keeps only one of these
        confpairs_unique = {frozenset(pair) for pair in self.confpairs}
        lospairs_unique = {frozenset(pair) for pair in self.lospairs}

        # Conflicts that were not present in the previous timestep
        new_conf_pairs = [tuple(x) for x in confpairs_unique - self.confpairs_unique]
        self.confpairs_all.extend(new_conf_pairs)
        self.confpairs_all_cruising.extend(self._cruising_pairs(new_conf_pairs))

        # LoS that were not present in the previous timestep; logged as tuples,
        # consistent with confpairs_all
        new_los_pairs = [tuple(x) for x in lospairs_unique - self.lospairs_unique]
        self.lospairs_all.extend(new_los_pairs)
        self.lospairs_all_cruising.extend(self._cruising_pairs(new_los_pairs))

        if new_los_pairs:
            self._log_intrusions(new_los_pairs)

        # Update confpairs_unique and lospairs_unique
        self.confpairs_unique = confpairs_unique
        self.lospairs_unique = lospairs_unique

    @staticmethod
    def _cruising_pairs(pairs):
        ''' Return the pairs in which both aircraft have zero vertical speed. '''
        traf = bs.traf
        return [pair for pair in pairs
                if traf.vs[traf.id2idx(pair[0])] == 0
                and traf.vs[traf.id2idx(pair[1])] == 0]

    def _log_intrusions(self, new_los_pairs):
        ''' Append severity, time, altitude, ground speed and position data of
            newly started LoS pairs to the intrusion logs. '''
        # Row of each conflict pair in the per-pair detect arrays (dcpa, ...).
        # setdefault keeps the first occurrence, like list.index did.
        conf_index = {}
        for i, pair in enumerate(self.confpairs):
            conf_index.setdefault(tuple(pair), i)

        severities = []
        for pair in new_los_pairs:
            # The tuple order from a frozenset is arbitrary, so also try the
            # reversed pair; NaN if the LoS pair has no conflict entry at all
            i = conf_index.get(pair, conf_index.get(pair[::-1]))
            severities.append(np.nan if i is None else (self.rpz - self.dcpa[i]) / nm)
        self.intrusions = np.append(self.intrusions, severities)
        self.intrusionstime = np.append(self.intrusionstime, [bs.sim.simt] * len(new_los_pairs))

        traf = bs.traf
        new_array_alt = []
        new_array_speed = []
        new_array_lat = []
        new_array_lon = []
        for pair in new_los_pairs:
            id0 = traf.id2idx(pair[0])
            id1 = traf.id2idx(pair[1])
            new_array_alt.append((traf.alt[id0], traf.alt[id1]))
            new_array_speed.append((traf.gs[id0], traf.gs[id1]))
            new_array_lat.append((traf.lat[id0], traf.lat[id1]))
            new_array_lon.append((traf.lon[id0], traf.lon[id1]))
        # np.append flattens the (ownship, intruder) tuples into one 1-D array
        self.intrusionsalt = np.append(self.intrusionsalt, new_array_alt)
        self.intrusionsspeed = np.append(self.intrusionsspeed, new_array_speed)
        self.intrusionslat = np.append(self.intrusionslat, new_array_lat)
        self.intrusionslon = np.append(self.intrusionslon, new_array_lon)

    def detect(self, ownship, intruder, rpz, hpz, dtlookahead):
        '''
            Detect any conflicts between ownship and intruder.
            This function should be reimplemented in a subclass for actual
            detection of conflicts. See for instance
            bluesky.traffic.asas.statebased.

            Must return a 19-tuple in the order unpacked by ``update``:
            (confpairs, lospairs, inconf, tcpamax, qdr, dist, dcpa, tcpa,
            tLOS, tcpa_all, tLOS_all, dist_all, dcpa_all, mdalt, tinhor,
            touthor, swhorconf, qdr_all, dalt).
            This base implementation reports no conflicts: empty pair lists,
            an all-False ``inconf``, zero ``tcpamax`` and empty arrays for
            every other output.
        '''
        confpairs = []
        lospairs = []
        inconf = np.zeros(ownship.ntraf, dtype=bool)
        tcpamax = np.zeros(ownship.ntraf)
        # No pairwise data without detection; one separate empty array per output
        (qdr, dist, dcpa, tcpa, tLOS, tcpa_all, tLOS_all, dist_all, dcpa_all,
         mdalt, tinhor, touthor, swhorconf, qdr_all, dalt) = \
            (np.array([]) for _ in range(15))
        return (confpairs, lospairs, inconf, tcpamax, qdr, dist, dcpa, tcpa,
                tLOS, tcpa_all, tLOS_all, dist_all, dcpa_all, mdalt, tinhor,
                touthor, swhorconf, qdr_all, dalt)


