# -*- coding: utf-8 -*-
# Created by mjribeiro at 28-1-2021
from numpy.distutils.system_info import accelerate_info

from bluesky.tools.simtime import timed_function
import bluesky as bs
import numpy as np
from bluesky.tools import geo
from bluesky.tools.aero import nm, ft
import copy

# Dimensions of the vectors exchanged with the learning agents. STATE_DIM and
# STATE_DIM2 are used below to split a received vector into its state part and
# its action part (``action[STATE_DIM:]`` / ``action[STATE_DIM2:]``).
# NOTE(review): NUMBER_LAYERS, ACTION_DIM and ACTION_DIM2 are not referenced in
# the visible part of this file.
NUMBER_LAYERS = 6
ACTION_DIM = 3
STATE_DIM = 12
ACTION_DIM2 = 4
STATE_DIM2 = 8
NUM_INTRUDERS_STATE = 8  # maximum number of intruders returned by getLosLayers

LAYER_SEP = 5  # 5 ft
ASAS_PZH = bs.settings.asas_pzh  # ft
ALT_MIN = 65  # ft

# Altitudes of the 17 layers (indices 0..16), spaced LAYER_SEP + ASAS_PZH apart.
ALTS_all = np.arange(ALT_MIN, ALT_MIN + (LAYER_SEP + ASAS_PZH) * 17, LAYER_SEP + ASAS_PZH)  # in ft
# structure [4, 1, 2, 3, 4, 1]
# heading_transform = [-np.inf, 45, 135, 225, 315]
# One heading per entry of ALTS_all (17 values, same heading for three consecutive layers).
ALTS_hdgs = np.array([  86.8, 86.8, 86.8,  356.8, 356.8, 356.8,   266.8,  266.8,  266.8,   176.8, 176.8, 176.8,86.8, 86.8, 86.8, 356.8, 356.8])
# Every third layer altitude, starting at ALT_MIN and staying below 400 ft.
ALTS_traffic = np.arange(ALT_MIN, 400, ASAS_PZH * 3 + LAYER_SEP * 3)

ALTS_all = ALTS_all * ft  # pass to meters
ALTS_traffic = ALTS_traffic * ft  # pass to meters

# NOTE(review): the three penalties below are not referenced in the visible code.
PENALTY_LOS = 20
PENALTY_NEAR_LOS = 2  # 4
PENALTY_CONFLICT = 1  # 2

# NOTE(review): the original comment said "20 minutes", but 4000 s is ~66.7 minutes;
# FINAL_TIME below (1200 s) is the value that equals 20 minutes. Confirm the intent.
LOOK_AHEAD_TIME = 4000  # seconds
STATIC_STRUCT = True
FINAL_TIME = 1200
# FINAL_TIME = 3600
TAS_AVG = 21  # kts - 10.83m/s

# Maximum time from the start of an action until the reward is calculated.
# NOTE(review): presumably seconds, like the minimum below; not used in the visible code.
TIME_MAXIMUM_ACTION_DURATION = 100 # m/2time from start the start of the action until the reward is calculated
# Time (compared against bs.sim.simt) an aircraft must have been in its objective
# layer before the action counts as completed.
TIME_MINIMUM_ACTION_DURATION = 5

# Identifiers returned by merging_RL.getTransitLayers for the three layer roles.
TRAFFIC_LAYER_ID = 0
FAST_LAYER_ID = 1
SLOW_LAYER_ID = 2

def init_plugin():
    """Create the plugin instance and return its BlueSky registration data.

    Returns:
        tuple: ``(config, stackfunctions)``. ``config`` names the plugin and
        its ``update``/``reset`` callbacks; ``stackfunctions`` is empty because
        the plugin defines no stack commands.
    """
    # Additional initialisation code.
    # The module-level name ``merging_RL`` is rebound from the class to its
    # single instance. Only instantiate while the name still refers to the
    # class: calling this function a second time used to try to *call* the
    # instance and fail with a TypeError.
    global merging_RL
    if isinstance(merging_RL, type):
        merging_RL = merging_RL()

    # Configuration parameters
    config = {
        # The name of your plugin
        'plugin_name': 'merging_RL',

        # The type of this plugin. For now, only simulation plugins are possible.
        'plugin_type': 'sim',

        # The update function is called after traffic is updated.
        'update': merging_RL.update,

        # The reset function
        'reset': merging_RL.reset
    }
    stackfunctions = {}

    # init_plugin() should always return these two dicts.
    return config, stackfunctions


class merging_RL():

    def __init__(self):
        super(merging_RL, self).__init__()

    def reset(self):
        """Re-create every bookkeeping container and ask the simulation to save both agents.

        All per-action containers start empty; the two per-aircraft layer
        arrays are zero-filled with one entry per aircraft currently in
        ``bs.traf``.
        """
        ntraf = bs.traf.ntraf

        # all layers start as inactive. state, action
        self.actions_DDPG = [[] for _ in range(3)]

        # Aircraft (by name) that requested an action and wait for the answer.
        self.acs_waiting_action = np.array([])
        self.acs_waiting_action2 = np.array([])
        self.acs_waiting_action_time = np.array([])
        self.acs_waiting_action_layer = np.array([])
        self.acs_waiting_action_wpt_layer = np.array([])

        # Aircraft that are currently executing an action, plus per-action data.
        self.acs_performing_action = np.array([])
        self.performing_action = []
        self.acs_performing_action_time = np.array([])
        self.acs_reaching_objective_layer_time = np.array([])
        self.performing_action_lat = np.array([])
        self.performing_action_lon = np.array([])
        self.acs_performing_action_initial_layer = np.array([])
        self.acs_performing_action_wpt_layer = np.array([])
        self.conf_total_until_point = np.array([])
        self.los_total_until_point = np.array([])
        self.objective_layer = np.array([])
        self.ac_has_reached_objective = np.array([])

        # Per-aircraft layer indices, recomputed on every update().
        self.current_layer = np.zeros(ntraf)
        self.wpt_layer = np.zeros(ntraf)

        self.RL_actions = [[] for _ in range(2)]
        self.ac_already_set = np.array([])
        self.speed_during_action = np.array([])

        # Hand the scenario name to the simulation so both agents are saved.
        bs.sim.RLsave(bs.stack.get_scenname())
        bs.sim.DDPG2save(bs.stack.get_scenname())

        # Aircraft that received a first-agent action and now wait to merge.
        self.ac_waiting_merging = np.array([])
        self.ac_waiting_merging_until_wpt = np.array([])

    def getActiveWaypoint(self, aircraft):
        """Return the active-waypoint index (``iactwp``) of the route of ``aircraft``."""
        return bs.traf.ap.route[aircraft].iactwp

    def getTransitLayers(self, layer):
        if layer in [0, 3, 6, 9, 12, 15]:  # aircraft is in a traffic layer
            current_layer_id = 0

        elif layer in [1, 4, 7, 10, 13, 16]:  # aircraft is in a fast layer
            current_layer_id = 1

        elif layer in [-1, 2, 5, 8, 11, 14]:  # aircraft is in a slow layer
            current_layer_id = 2

        return current_layer_id

    def getTLosFromArray(self, array, pos):
        if len(array) > pos:
            return min(array[pos], 600)

        return 600

    def getDistanceToAC(self, ac_id, current_layer, intended_layer):
        """Estimate the time to loss of separation with the nearest aircraft ahead of and behind ``ac_id``.

        Only aircraft whose entry in ``current_layer`` equals
        ``intended_layer`` are considered, and only when at least two such
        aircraft exist.

        Args:
            ac_id: Traffic index of the ownship.
            current_layer: Array with the layer index of every aircraft.
            intended_layer: Layer index to inspect.

        Returns:
            tuple: ``(tlos_front, tlos_back, n_aircraft)``. Both times default
            to 160 when no aircraft is found in the respective sector and are
            capped at 3000; ``n_aircraft`` is the number of aircraft in
            ``intended_layer``.
        """
        tlos_frontac = 160
        tlos_backac = 160

        ac_intruders = np.where(current_layer == intended_layer)[0]
        if len(ac_intruders) > 0:
            qdr, dist = geo.kwikqdrdist_matrix(bs.traf.lat[ac_intruders], bs.traf.lon[ac_intruders],
                                               np.ones(len(ac_intruders)) * bs.traf.lat[ac_id],
                                               np.ones(len(ac_intruders)) * bs.traf.lon[ac_id])
            dist *= nm
            if len(dist) > 1:
                # Bearing sector treated as "in front"; zero distances (the
                # ownship itself) are excluded.
                in_front = np.where(np.logical_or(qdr >= 340, qdr <= 20))[0]
                in_front = in_front[dist[in_front] > 0]
                if len(in_front) > 0:
                    tlos_frontac = self._tlosToNearest(ac_id, ac_intruders, dist, in_front)

                # Bearing sector treated as "behind".
                behind = np.where(np.logical_and(qdr > 160, qdr < 190))[0]
                if len(behind) > 0 and np.any(dist[behind] > 0):
                    tlos_backac = self._tlosToNearest(ac_id, ac_intruders, dist, behind)

        return min(tlos_frontac, 3000), min(tlos_backac, 3000), len(ac_intruders)

    def _tlosToNearest(self, ac_id, ac_intruders, dist, candidates):
        """Return the time to loss of separation with the closest aircraft among ``candidates``.

        ``candidates`` holds positions into ``dist``/``ac_intruders``; the
        position is translated to a traffic index through ``ac_intruders``
        before the ground speed is read (using the position directly would
        read the speed of an unrelated aircraft).
        """
        nearest = candidates[np.argmin(dist[candidates])]
        intruder_id = ac_intruders[nearest]
        dvs = bs.traf.gs[ac_id] - bs.traf.gs[intruder_id]
        if dvs == 0:
            # Equal speeds: the gap never changes, report the cap directly
            # instead of dividing by zero.
            return 3000
        # NOTE(review): sign convention kept from the original code (the
        # result is positive when the ownship is the slower one) - confirm.
        tlos = (dist[nearest] - bs.traf.cd.hpz) / -dvs
        if tlos < 0:  # it happened in the past
            return 3000
        return tlos

    def getLosLayers(self, current_layer, ac_id):
        """Describe the intruders closest to ``ac_id``.

        Aircraft at distance 0 (the ownship itself) are skipped; the remaining
        ones are ordered by increasing distance and at most
        ``NUM_INTRUDERS_STATE`` are returned.

        Args:
            current_layer: Array with the layer index of every aircraft.
            ac_id: Traffic index of the ownship.

        Returns:
            tuple of four arrays, all in the same (nearest-first) order:
            intruder traffic indices (as floats), distance capped at 1000,
            speed relative to the ownship, and the layer difference divided
            by 4.
        """
        ntraf = bs.traf.ntraf
        _, dist = geo.kwikqdrdist_matrix(bs.traf.lat, bs.traf.lon,
                                         np.ones(ntraf) * bs.traf.lat[ac_id],
                                         np.ones(ntraf) * bs.traf.lon[ac_id])
        dist *= nm

        # Velocity components of every aircraft and their difference with the
        # ownship. The speed is taken relative to ``ac_id`` (not to aircraft
        # 0), and only that one row is computed instead of the n x n matrix.
        trkrad = np.radians(bs.traf.trk)
        east = bs.traf.gs * np.sin(trkrad)  # m/s
        north = bs.traf.gs * np.cos(trkrad)  # m/s
        du = east - east[ac_id]
        dv = north - north[ac_id]

        dv2 = du * du + dv * dv
        dv2 = np.where(np.abs(dv2) < 1e-6, 1e-6, dv2)  # limit lower absolute value
        vrel = np.sqrt(dv2)

        # Nearest first, without the zero-distance entries.
        order = np.argsort(dist)
        order = order[dist[order] != 0][:NUM_INTRUDERS_STATE]

        layers = np.asarray(current_layer)
        layer_dif = (layers[order] - layers[ac_id]) / 4
        # NOTE(review): this only limits positive differences to 1; negative
        # ones are not limited to -1. Kept as in the original - confirm.
        relative_layer = np.minimum(layer_dif, np.sign(layer_dif))

        intruders = order.astype(float)
        relative_position = np.minimum(dist[order], 1000)
        relative_speed = vrel[order]
        return intruders, relative_position, relative_speed, relative_layer

    @timed_function('merging_RL', dt=1.0)
    def update(self, dt):
        """Run one periodic plugin step.

        The step (1) recomputes the current layer and the active-waypoint
        layer of every aircraft, (2) retires second-agent actions that are
        finished, and (3) builds a state vector and requests an action for
        every aircraft whose active waypoint lies in another layer and that
        is not already waiting for, or performing, an action.

        Args:
            dt: Not used by this method.
        """
        self._updateLayerTracking()
        self._retireCompletedActions()
        self._requestNewActions()

    def _nearestLayer(self, altitude):
        """Return the index of the entry of ALTS_all closest to ``altitude``."""
        return int(np.argmin(np.abs(ALTS_all - altitude)))

    def _updateLayerTracking(self):
        """Refresh ``current_layer``, ``wpt_layer`` and ``wpts_to_turn`` for all aircraft."""
        ntraf = bs.traf.ntraf
        self.current_layer = np.zeros(ntraf)
        self.wpt_layer = np.zeros(ntraf)
        self.wpts_to_turn = np.zeros(ntraf)

        for ac_id in range(ntraf):
            route = bs.traf.ap.route[ac_id]
            iactwp = self.getActiveWaypoint(ac_id)

            # The layer the aircraft is flying in is derived from the altitude
            # of the waypoint it has just passed.
            previous_waypoint = 0 if iactwp == 0 else iactwp - 1
            layer_now = self._nearestLayer(route.wpalt[previous_waypoint])

            # An aircraft that is climbing above / descending below that
            # altitude is counted in the neighbouring layer already.
            if bs.traf.vs[ac_id] > 0 and bs.traf.alt[ac_id] > ALTS_all[layer_now]:
                layer_now += 1
            elif bs.traf.vs[ac_id] < 0 and bs.traf.alt[ac_id] < ALTS_all[layer_now]:
                layer_now -= 1

            layer_now = min(16, layer_now)
            self.current_layer[ac_id] = layer_now

            # Layer of the active waypoint, and the layer of the first
            # upcoming waypoint that differs from the current layer.
            layer_wpt = self._nearestLayer(route.wpalt[iactwp])
            layer_wpt2 = layer_wpt
            next_wpt = iactwp
            while layer_wpt2 == layer_now and next_wpt < len(route.wpalt):
                layer_wpt2 = self._nearestLayer(route.wpalt[next_wpt])
                next_wpt += 1

            self.wpts_to_turn[ac_id] = layer_wpt2 - layer_now
            self.wpt_layer[ac_id] = layer_wpt

    def _retireCompletedActions(self):
        """Drop every performing action that is finished or no longer applicable."""
        completed_actions = []
        for it, ac_name in enumerate(self.acs_performing_action):
            ac_id = bs.traf.id2idx(ac_name)
            if ac_id < 0:
                # presumably id2idx returns a negative index for an aircraft
                # that no longer exists - verify. Such an action can never
                # complete, and indexing with -1 would read another aircraft.
                completed_actions.append(it)
                continue

            # Remember the moment the objective layer was first reached.
            if self.current_layer[ac_id] == self.objective_layer[it] and \
                    self.acs_reaching_objective_layer_time[it] == np.inf:
                self.acs_reaching_objective_layer_time[it] = bs.sim.simt

            # Done once the aircraft stayed long enough in the objective
            # layer, or when its active waypoint no longer targets that layer.
            if bs.sim.simt >= self.acs_reaching_objective_layer_time[it] + TIME_MINIMUM_ACTION_DURATION or \
                    self.wpt_layer[ac_id] != self.objective_layer[it]:
                completed_actions.append(it)

        if completed_actions:
            self.acs_performing_action = np.delete(self.acs_performing_action, completed_actions)
            self.speed_during_action = np.delete(self.speed_during_action, completed_actions)
            self.acs_reaching_objective_layer_time = np.delete(self.acs_reaching_objective_layer_time,
                                                               completed_actions)
            self.objective_layer = np.delete(self.objective_layer, completed_actions)
            # Pop from the back so the remaining indices stay valid.
            for it in reversed(completed_actions):
                self.performing_action.pop(it)

    def _requestNewActions(self):
        """Build the state of every aircraft that needs a decision and send it to the matching agent."""
        for ac_id in range(bs.traf.ntraf):
            ac_name = bs.traf.id[ac_id]
            if ac_name in self.acs_waiting_action or ac_name in self.acs_performing_action \
                    or ac_name in self.acs_waiting_action2:
                continue
            if self.wpt_layer[ac_id] == self.current_layer[ac_id]:
                continue
            # Use this aircraft's own active waypoint (not the value left over
            # from the last aircraft of the tracking loop).
            iactwp = self.getActiveWaypoint(ac_id)
            if iactwp >= len(bs.traf.ap.route[ac_id].wpalt) - 1:
                continue

            # we are gonna move into another area: describe the layer triplet
            # around the layer of the active waypoint.
            target_layer = self.wpt_layer[ac_id]
            current_layer_id = self.getTransitLayers(target_layer)

            if current_layer_id == TRAFFIC_LAYER_ID:
                traffic_layer = target_layer
            elif current_layer_id == FAST_LAYER_ID:
                traffic_layer = target_layer - 1
            else:
                traffic_layer = target_layer - 2

            tlos_frontac_traffic_layer, tlos_backac_traffic_layer, number_ac_trafficlayer = \
                self.getDistanceToAC(ac_id, self.current_layer, traffic_layer)
            tlos_frontac_fastlayer, tlos_backac_fastlayer, number_ac_fastlayer = \
                self.getDistanceToAC(ac_id, self.current_layer, traffic_layer + 1)
            tlos_frontac_slowLayer, tlos_backac_slowlayer, number_ac_slowlayer = \
                self.getDistanceToAC(ac_id, self.current_layer, traffic_layer + 2)

            if ac_name not in self.ac_waiting_merging:
                # First agent: choose which layer of the triplet to aim for.
                state = [bs.traf.gs[ac_id], current_layer_id,
                         number_ac_trafficlayer, tlos_frontac_traffic_layer, tlos_backac_traffic_layer,
                         number_ac_fastlayer, tlos_frontac_fastlayer, tlos_backac_fastlayer,
                         number_ac_slowlayer, tlos_frontac_slowLayer, tlos_backac_slowlayer,
                         self.wpts_to_turn[ac_id]]

                self.acs_waiting_action = np.append(self.acs_waiting_action, ac_name)
                if bs.sim.isRLcreated():
                    bs.sim.RLaction(bs.stack.get_scenname(), state)
                else:
                    bs.sim.RLcreate(bs.stack.get_scenname(), state)
            else:
                # Second agent: the adjacent layer in the direction of the
                # active waypoint's layer is the candidate to move into.
                if self.wpt_layer[ac_id] > self.current_layer[ac_id]:
                    next_layer_aux = self.current_layer[ac_id] + 1
                else:
                    next_layer_aux = self.current_layer[ac_id] - 1
                tlos_frontac_next_layer, tlos_backac_next_layer, number_ac_nextlayer = \
                    self.getDistanceToAC(ac_id, self.current_layer, next_layer_aux)

                self.acs_waiting_action2 = np.append(self.acs_waiting_action2, ac_name)
                state = [bs.traf.gs[ac_id],
                         abs(ALTS_hdgs[int(self.current_layer[ac_id])] - bs.traf.trk[ac_id]),
                         tlos_frontac_traffic_layer, tlos_backac_traffic_layer,
                         abs(ALTS_hdgs[int(next_layer_aux)] - bs.traf.trk[ac_id]),
                         tlos_frontac_next_layer, tlos_backac_next_layer, self.wpts_to_turn[ac_id]]

                if bs.sim.isDDPG2created():
                    bs.sim.DDPG2action(bs.stack.get_scenname(), state)
                else:
                    bs.sim.DDPG2create(bs.stack.get_scenname(), state)

    def applyPossibleAction(self, ac_idx, actions):
        """Translate the highest-scoring entry of ``actions`` into an objective layer.

        NOTE(review): this method looks stale. It calls ``getTransitLayers``
        with three arguments and unpacks five return values, while the
        ``getTransitLayers`` defined in this class takes a single layer and
        returns one id; it also relies on ``self.isSafeMove``, which is not
        defined in the visible part of the class. As written, any call with
        an action index below 3 fails.

        Returns:
            The objective layer index, the current layer of ``ac_idx`` when
            the action index is 3 or higher, or ``None`` when no safe move
            was found.
        """
        # 0 - move to next traffic layer
        # 1 - move to fast layer
        # 2 - move to slow layer
        # 3 - stay in same layer
        objective_layer = None
        action_idx = np.argmax(actions)
        if action_idx < 3:  # if 3, do nothing
            traffic_layer, fast_layer, slow_layer, current_layer_id, layer_dif = \
                self.getTransitLayers(self.wpt_layer, self.current_layer, ac_idx)
            # print(self.wpt_layer, self.current_layer, ac_idx, '-',  traffic_layer, fast_layer, slow_layer, current_layer_id, layer_dif)
            self.RL_actions[0].append(ac_idx)
            if action_idx == 0 and self.isSafeMove(ac_idx, traffic_layer):  # move to traffic layer
                self.RL_actions[1].append(ALTS_all[int(traffic_layer)])
                objective_layer = int(traffic_layer)
            elif action_idx == 1 and self.isSafeMove(ac_idx, fast_layer):  # move to fast layer
                self.RL_actions[1].append(ALTS_all[int(fast_layer)])
                objective_layer = int(fast_layer)
            # This branch is also reached for actions 0 and 1 when their move
            # was not safe, so the slow layer acts as the fallback.
            elif self.isSafeMove(ac_idx, slow_layer):  # move to slow layer
                self.RL_actions[1].append(ALTS_all[int(slow_layer)])
                objective_layer = int(slow_layer)
        else:
            objective_layer = self.current_layer[ac_idx]

        # which layer is the objective
        return objective_layer

    # is the moving to Layer in between the current layer and the objective?
    def actionIsValid(self, currentLayer, wpt_layer, moving_to_layer):
        # is ac is to continue in the same layer, there's no problem
        if currentLayer != moving_to_layer:
            # ac is going down
            if wpt_layer > currentLayer and moving_to_layer < wpt_layer:
                return False
            # ac is going down
            if wpt_layer < currentLayer and moving_to_layer > wpt_layer:
                return False
        return True

    def setDDPGaction(self, action):
        """Apply the first agent's answer to the aircraft at the head of the waiting queue.

        ``action`` holds the state (first ``STATE_DIM`` entries) followed by
        the action scores. The preferred action selects an offset from the
        layer of the active waypoint; the route altitudes are rewritten to
        that layer and the aircraft is registered as waiting to merge.
        """
        self.waiting_response = False
        ac_name = self.acs_waiting_action[0]
        ac_id = bs.traf.id2idx(ac_name)

        # The aircraft has been answered: take it out of the queue.
        still_waiting = self.acs_waiting_action != ac_name
        self.acs_waiting_action = self.acs_waiting_action[still_waiting]

        actions_prefered = np.argmax(action[STATE_DIM:])

        # Fast layer is one above the waypoint layer, slow layer two above;
        # any other action keeps the waypoint layer itself.
        objective_layer = self.wpt_layer[ac_id]
        if actions_prefered == FAST_LAYER_ID:
            objective_layer += 1
        elif actions_prefered == SLOW_LAYER_ID:
            objective_layer += 2

        objective_layer = min(16, objective_layer)  # 16 is the highest layer index
        it_wpt = self.updateWaypoints(ac_id, ALTS_all[int(objective_layer)])

        # Remember up to which waypoint the aircraft stays in the chosen layer.
        last_wpt = it_wpt - 1
        if ac_name in self.ac_waiting_merging:
            idx = np.where(self.ac_waiting_merging == ac_name)[0]
            self.ac_waiting_merging_until_wpt[idx] = last_wpt
        else:
            self.ac_waiting_merging = np.append(self.ac_waiting_merging, ac_name)
            self.ac_waiting_merging_until_wpt = np.append(self.ac_waiting_merging_until_wpt, last_wpt)

    def setDDPG2action(self, action):
        """Apply the second agent's answer to the aircraft at the head of its waiting queue.

        ``action`` holds the state (first ``STATE_DIM2`` entries) followed by
        the action scores. The aircraft moves from the waiting queue to the
        performing list and stops being flagged as waiting to merge. When the
        preferred action is 3, the objective is the adjacent layer in the
        direction of the active waypoint's layer; otherwise the aircraft
        keeps its current layer.
        """
        ac_name = self.acs_waiting_action2[0]
        ac_id = bs.traf.id2idx(ac_name)
        self.acs_waiting_action2 = self.acs_waiting_action2[self.acs_waiting_action2 != ac_name]
        self.acs_performing_action = np.append(self.acs_performing_action, ac_name)
        self.performing_action.append(action[STATE_DIM2:])
        # np.inf marks "objective layer not reached yet".
        self.acs_reaching_objective_layer_time = np.append(self.acs_reaching_objective_layer_time, np.inf)
        self.speed_during_action = np.append(self.speed_during_action, bs.traf.gs[ac_id])

        self.ac_waiting_merging = self.ac_waiting_merging[self.ac_waiting_merging != ac_name]

        actions_prefered = np.argmax(action[STATE_DIM2:])
        # Compare the chosen index itself: taking argmax of that scalar again
        # always yields 0, which made the layer change unreachable.
        if actions_prefered == 3:
            if self.wpt_layer[ac_id] > self.current_layer[ac_id]:
                objective_layer = self.current_layer[ac_id] + 1
            else:
                objective_layer = self.current_layer[ac_id] - 1
        else:
            objective_layer = self.current_layer[ac_id]

        objective_layer = min(16, objective_layer)  # 16 is the highest layer index
        #self.updateWaypoints2(ac_id, ALTS_all[int(objective_layer)])
        self.objective_layer = np.append(self.objective_layer, objective_layer)

    def updateWaypoints2(self, ac, movingTo_alt):
        """Set the altitude of the active waypoint of aircraft ``ac`` to ``movingTo_alt``."""
        route = bs.traf.ap.route[ac]
        route.wpalt[route.iactwp] = movingTo_alt


    def updateWaypoints(self, ac, movingTo_alt):
        """Rewrite the altitude of the active waypoint and of the waypoints that follow it.

        Starting after the active waypoint, every waypoint is moved to
        ``movingTo_alt`` until one is found whose altitude differs from the
        active waypoint's previous altitude and is one of ``ALTS_traffic``.

        Returns:
            int: Index of the first waypoint that was left untouched (the
            route length when all remaining waypoints were rewritten).
        """
        route = bs.traf.ap.route[ac]
        first = route.iactwp
        previous_alt = route.wpalt[first]
        route.wpalt[first] = movingTo_alt

        it_wpt = first + 1
        while it_wpt < len(route.wpalt):
            alt = route.wpalt[it_wpt]
            # Stop at the first waypoint that belongs to another traffic layer.
            if alt != previous_alt and alt in ALTS_traffic:
                break
            route.wpalt[it_wpt] = movingTo_alt
            it_wpt += 1

        return it_wpt

    def getActions(self):
        """Return the commands for every aircraft that is performing an action.

        Action indices (argmax of the stored action scores):
            0 - hold altitude, fly the current true airspeed
            1 - hold altitude, speed at action start + 2 (limited by vmax)
            2 - hold altitude, speed at action start - 2 (not below 1)
            3 - climb/descend to the objective layer at maximum vertical speed
        Actions 1 and 2 only change the speed while the aircraft is not in
        its objective layer.

        Returns:
            tuple: ``(aircraft names, vertical speeds, altitudes, speeds)``,
            the last three as arrays aligned with the names.
        """
        _, vmax, _, vsmax = bs.traf.perf.currentlimits()

        # aircraft must follow their route independently
        # NOTE(review): the result of Reached() is not used here; the call is
        # kept in case it updates state inside bs.traf.actwp - confirm.
        qdr, distinnm = geo.qdrdist(bs.traf.lat, bs.traf.lon, bs.traf.actwp.lat, bs.traf.actwp.lon)  # [deg][nm])
        bs.traf.actwp.Reached(qdr, distinnm * nm, bs.traf.actwp.flyby, bs.traf.actwp.flyturn,
                              bs.traf.actwp.turnrad)

        vs_needed, gs_needed, alts = [], [], []
        for it, ac_name in enumerate(self.acs_performing_action):
            ac_idx = bs.traf.id2idx(ac_name)
            action = np.argmax(self.performing_action[it])
            start_speed = self.speed_during_action[it]

            if action > 2:
                # Layer change.
                vs_needed.append(vsmax[ac_idx])
                gs_needed.append(bs.traf.tas[ac_idx])
                alts.append(ALTS_all[int(self.objective_layer[it])])
                continue

            # Speed-only action: no vertical movement.
            vs_needed.append(0)
            if self.current_layer[ac_idx] == self.objective_layer[it] or action == 0:
                gs_needed.append(bs.traf.tas[ac_idx])
            elif action == 1:
                gs_needed.append(min(vmax[ac_idx], start_speed + 2))
            elif action == 2:
                gs_needed.append(max(1, start_speed - 2))
            alts.append(ALTS_all[int(self.current_layer[ac_idx])])

        return (self.acs_performing_action,
                np.array(vs_needed, dtype=float),
                np.array(alts, dtype=float),
                np.array(gs_needed, dtype=float))



    def getNearMisses(self, rewards, conf_begin_action, layer_now):
        """Subtract 1 from ``rewards`` for every new conflict pair that touches ``layer_now``.

        Only the entries of ``bs.traf.cd.confpairs_all_cruising`` from index
        ``conf_begin_action`` onwards are considered, each unordered pair
        once. A pair counts when the layer nearest to the altitude of either
        aircraft equals ``layer_now``.

        Returns:
            The updated ``rewards``.
        """
        recent_pairs = bs.traf.cd.confpairs_all_cruising[int(conf_begin_action):]
        for unique_pair in {frozenset(pair) for pair in recent_pairs}:
            members = tuple(unique_pair)
            idx0 = bs.traf.id2idx(members[0])
            idx1 = bs.traf.id2idx(members[1])
            in_layer = any(int(np.argmin(np.abs(ALTS_all - bs.traf.alt[idx]))) == layer_now
                           for idx in (idx0, idx1))
            if in_layer:
                rewards -= 1
        return rewards

    def getLoss(self, rewards, los_begin_action, layer_now):
        """Subtract 10 from ``rewards`` for every new loss-of-separation pair that touches ``layer_now``.

        Only the entries of ``bs.traf.cd.lospairs_all_cruising`` from index
        ``los_begin_action`` onwards are considered, each unordered pair
        once. A pair counts when the layer nearest to the altitude of either
        aircraft equals ``layer_now``.

        Returns:
            The updated ``rewards``.
        """
        recent_pairs = bs.traf.cd.lospairs_all_cruising[int(los_begin_action):]
        for unique_pair in {frozenset(pair) for pair in recent_pairs}:
            members = tuple(unique_pair)
            idx0 = bs.traf.id2idx(members[0])
            idx1 = bs.traf.id2idx(members[1])
            in_layer = any(int(np.argmin(np.abs(ALTS_all - bs.traf.alt[idx]))) == layer_now
                           for idx in (idx0, idx1))
            if in_layer:
                rewards -= 10
        return rewards
