# -*- coding: utf-8 -*-
# Created by mjribeiro at 28-1-2021
from bluesky.tools.simtime import timed_function
import bluesky as bs
import numpy as np
from bluesky.tools import geo
from bluesky.tools.aero import nm, ft
import copy

NUMBER_LAYERS = 6  # matches the six traffic-layer indices (0, 3, ..., 15) used by the class below
ACTION_DIM = 1  # the RL action is a single scalar (see `action[STATE_DIM:]`)
STATE_DIM = 8  # number of entries in the state vector built in update()

LAYER_SEP = 5  # extra vertical spacing between layers, in ft
ASAS_PZH = bs.settings.asas_pzh  # ft
ALT_MIN = 65  # altitude of the lowest layer, in ft

# Altitudes of all 17 layers (index 0..16), spaced LAYER_SEP + ASAS_PZH apart, in ft.
ALTS_all = np.arange(ALT_MIN, ALT_MIN + (LAYER_SEP + ASAS_PZH) * 17, LAYER_SEP + ASAS_PZH)  # in ft
# Altitudes of every third layer, starting at ALT_MIN and staying below 400 ft.
ALTS_traffic = np.arange(ALT_MIN, 400, ASAS_PZH * 3 + LAYER_SEP * 3)

ALTS_all = ALTS_all * ft  # convert to meters
ALTS_traffic = ALTS_traffic * ft  # convert to meters

# NOTE(review): the PENALTY_* constants are not referenced in the visible part
# of this file; getLoss()/getNearMisses() use the literals 10 and 1 instead.
PENALTY_LOS = 20
PENALTY_NEAR_LOS = 2  # 4
PENALTY_CONFLICT = 1  # 2

# NOTE(review): the original comment said "20 minutes", but 4000 s is about
# 67 minutes. Not referenced in the visible part of this file.
LOOK_AHEAD_TIME = 4000  # seconds
STATIC_STRUCT = True
FINAL_TIME = 1200  # presumably seconds (= 20 minutes) - TODO confirm
# FINAL_TIME = 3600
# NOTE(review): the original comment gave both "kts" and "m/s". The value is
# compared against bs.traf.tas in getSpeedDif(), so m/s is presumably meant.
TAS_AVG = 10.83  # presumably m/s - TODO confirm

TIME_ACTION_DURATION = 10  # minimum sim time between the start of an action and the calculation of its reward

def init_plugin():
    """Create the plugin instance (once) and return the BlueSky plugin descriptors.

    The module-level name ``merging_RL`` starts out as the class and is
    rebound to the single plugin instance on the first call, exactly as
    before. Later calls reuse that instance instead of trying to call it,
    which used to raise ``TypeError``.

    Returns:
        tuple: ``(config, stackfunctions)`` where ``config`` names the plugin
        and its ``update``/``reset`` callbacks, and ``stackfunctions`` is an
        empty dict (the plugin defines no stack commands).
    """
    # Additional initialisation code
    global merging_RL
    if isinstance(merging_RL, type):
        # Still the class: instantiate it and keep the instance under the same
        # module-level name.
        merging_RL = merging_RL()

    # Configuration parameters
    config = {
        # The name of your plugin
        'plugin_name': 'merging_RL',

        # The type of this plugin. For now, only simulation plugins are possible.
        'plugin_type': 'sim',

        # The update function is called after traffic is updated.
        'update': merging_RL.update,

        # The reset function
        'reset': merging_RL.reset
    }
    stackfunctions = {}

    # init_plugin() should always return these two dicts.
    return config, stackfunctions


class merging_RL():

    def __init__(self):
        super(merging_RL, self).__init__()

    def reset(self):
        """Clear all per-scenario bookkeeping and notify the simulation.

        Fixes a stray ``>`` after ``self.objective_layer = np.array([])`` that
        made the whole module a syntax error.
        """
        # Pending RL transitions as three parallel lists:
        # [aircraft names, states, actions] (filled in setDDPGaction()).
        self.actions_DDPG = [[] for _ in range(3)]

        # Aircraft for which an action was requested but not yet received
        # (parallel arrays, one entry per waiting aircraft).
        self.acs_waiting_action = np.array([])
        self.acs_waiting_action_time = np.array([])
        self.acs_waiting_action_layer = np.array([])
        self.acs_waiting_action_wpt_layer = np.array([])

        # Aircraft currently executing an action (parallel arrays, one entry
        # per aircraft in acs_performing_action).
        self.acs_performing_action = np.array([])
        self.performing_action = np.array([])
        self.acs_performing_action_time = np.array([])
        self.acs_performing_action_initial_layer = np.array([])
        self.acs_performing_action_wpt_layer = np.array([])
        self.conf_total_until_point = np.array([])
        self.los_total_until_point = np.array([])
        self.objective_layer = np.array([])
        self.ac_has_reached_objective = np.array([])

        # Per-aircraft layer indices, recomputed on every update().
        self.current_layer = np.zeros(bs.traf.ntraf)
        self.wpt_layer = np.zeros(bs.traf.ntraf)

        # [aircraft indices, target altitudes], appended to by applyPossibleAction().
        self.RL_actions = [[] for _ in range(2)]

        # Project-specific hook on the simulation object; presumably persists
        # the RL model/results for this scenario - TODO confirm.
        bs.sim.RLsave(bs.stack.get_scenname())

    def getActiveWaypoint(self, aircraft):
        """Return the active-waypoint index (``iactwp``) of the route of ``aircraft``."""
        return bs.traf.ap.route[aircraft].iactwp

    def getTransitLayers(self, wpt_layer, current_layer, ac_id):
        """Describe the current layer of an aircraft and the adjacent layer it should move to.

        Layers repeat in groups of three: index % 3 == 0 is a traffic layer,
        1 a fast layer and 2 a slow layer (so -1 counts as a slow layer). The
        kind is derived arithmetically instead of from hard-coded index lists,
        which raised UnboundLocalError for any index outside -1..16.

        Args:
            wpt_layer: per-aircraft array with the layer index of the next relevant waypoint.
            current_layer: per-aircraft array with the current layer index.
            ac_id: index of the aircraft in both arrays.

        Returns:
            tuple: ``(current layer, next layer, current layer kind,
            next layer kind, layer_dif)`` where the kinds are 0 (traffic),
            1 (fast) or 2 (slow) and ``layer_dif`` is waypoint layer minus
            current layer.
        """
        layer_now = current_layer[ac_id]
        # do we want to go up or down? (a difference of zero counts as going down)
        layer_dif = wpt_layer[ac_id] - layer_now
        step = 1 if layer_dif > 0 else -1

        current_layer_id = int(layer_now) % 3
        next_layer = layer_now + step
        # one layer up/down is the next/previous kind in the traffic-fast-slow cycle
        next_layer_id = (current_layer_id + step) % 3

        return layer_now, next_layer, current_layer_id, next_layer_id, layer_dif

    def getLosLayers(self, current_layer, next_layer, ac_id):
        """Return the smallest positive time-to-LoS towards aircraft in two layers.

        For each of ``current_layer`` and ``next_layer`` the row
        ``bs.traf.cd.tLOS_all[ac_id]`` is searched for the smallest value > 0
        among the aircraft whose ``self.current_layer`` equals that layer.
        Both results are capped at 600, which is also the value used when no
        such aircraft/value exists or the tLOS table does not cover the indices.

        Returns:
            tuple: ``(tlos_currentlayer, tlos_nextlayer)``.
        """
        default_value = 600

        def soonest_los(layer):
            # aircraft currently flying in this layer
            occupants = np.where(self.current_layer == layer)[0]
            if len(occupants) == 0:
                return default_value
            # the table must cover both ownship and every occupant index
            if len(bs.traf.cd.tLOS_all) <= ac_id or len(bs.traf.cd.tLOS_all[ac_id]) <= max(occupants):
                return default_value
            tlos_to_occupants = bs.traf.cd.tLOS_all[ac_id][occupants]
            positive = np.where(tlos_to_occupants > 0)[0]
            if len(positive) == 0:
                return default_value
            return min(tlos_to_occupants[positive])

        tlos_currentlayer = soonest_los(current_layer)
        tlos_nextlayer = soonest_los(next_layer)
        return min(tlos_currentlayer, default_value), min(tlos_nextlayer, default_value)

    @timed_function('merging_RL', dt=1.0)
    def update(self, dt):
        """Plugin update callback: track layers, settle finished actions, request new ones.

        Wrapped by ``timed_function`` with ``dt=1.0``. Steps, in order:

        1. For every aircraft, compute its current layer index and the layer
           of the next route waypoint that lies in a different traffic layer,
           plus an estimated time to the active waypoint.
        2. For every aircraft performing an action, hand out the reward once
           it has reached its objective layer and TIME_ACTION_DURATION has
           elapsed, then drop it from the "performing" bookkeeping.
        3. For aircraft that still have to change layer and are not already
           waiting for / performing an action, build a state and request an
           action from the RL model.

        NOTE(review): step 3 is currently cut short by a bare ``return`` (see
        below), so the state is never built and no action is ever requested.
        """
        ac_per_layer = np.zeros(len(ALTS_all))  # aircraft count per layer index; filled below but not read again here
        self.current_layer = np.zeros(bs.traf.ntraf)
        dist_next_waypoints = np.zeros(bs.traf.ntraf)
        time_next_waypoints = np.zeros(bs.traf.ntraf)
        # NOTE(review): 0 doubles as "no upcoming layer change" and as layer index 0.
        self.wpt_layer = np.zeros(bs.traf.ntraf)

        # go through every aircraft
        for ac_id in range(bs.traf.ntraf):

            # get aircraft at each layer: index of the layer altitude closest to the aircraft
            layer_now = min(range(len(ALTS_all)), key=lambda i: abs(ALTS_all[i] - bs.traf.alt[ac_id]))
            # if the ac is in the middle of two rows, we need to decide which one is correct based on vs direction
            if bs.traf.alt[ac_id] > ALTS_all[layer_now] and bs.traf.vs[ac_id] > 0:
                layer_now += 1
            elif bs.traf.alt[ac_id] < ALTS_all[layer_now] and bs.traf.vs[ac_id] < 0:
                layer_now -= 1

            # cap at the top layer index; layer_now can still be -1 (descending below layer 0)
            layer_now = min(16, layer_now)
            self.current_layer[ac_id] = layer_now
            ac_per_layer[layer_now] += 1

            # get the next node that will result in a change:
            # scan the route from the active waypoint until a waypoint lies in
            # a traffic layer other than the current layer
            layer_wpt = layer_now
            iactwp = self.getActiveWaypoint(ac_id)
            next_wpt = iactwp
            isTrafficLayer = False
            while (layer_wpt == layer_now or not isTrafficLayer) and next_wpt < len(bs.traf.ap.route[ac_id].wpalt):
                layer_wpt = min(range(len(ALTS_all)),
                                key=lambda i: abs(ALTS_all[i] - bs.traf.ap.route[ac_id].wpalt[next_wpt]))
                if layer_wpt in [0, 3, 6, 9, 12, 15]:
                    isTrafficLayer = True
                else:
                    isTrafficLayer = False
                next_wpt += 1

            if layer_wpt != layer_now and isTrafficLayer:
                self.wpt_layer[ac_id] = layer_wpt
                # distance/time are measured to the ACTIVE waypoint, not to the waypoint found above
                dist_next_waypoints[ac_id] = geo.kwikdist(bs.traf.lat[ac_id], bs.traf.lon[ac_id],
                                                          bs.traf.ap.route[ac_id].wplat[iactwp],
                                                          bs.traf.ap.route[ac_id].wplon[iactwp])
                # if dist_next_waypoints[ac_id]*nm/bs.traf.gs[ac_id] > 100000:
                #     print('time_next_waypoints', dist_next_waypoints[ac_id]*nm, bs.traf.gs[ac_id])
                time_next_waypoints[ac_id] = dist_next_waypoints[ac_id] * nm / bs.traf.gs[ac_id]
                # zero ground speed gives inf; fall back to 600
                # NOTE(review): 0/0 gives NaN, which this check does not catch.
                if np.isinf(time_next_waypoints[ac_id]):
                    time_next_waypoints[ac_id] = 600

        # check if previous set actions are completed
        completed_actions = np.array([], dtype=np.int8)
        for it in range(len(self.acs_performing_action)):
            ac_name = self.acs_performing_action[it]
            # NOTE(review): the result is not checked; if id2idx returns -1 for
            # an aircraft that no longer exists, the indexing below silently
            # reads the last aircraft - confirm against bs.traf.id2idx.
            ac_id = bs.traf.id2idx(ac_name)
            # we wait TIME_ACTION_DURATION to get a good idea of the effect of an action. but what if during that time an ac moves to the objective layer and moves out. this needs to be counted as well
            if self.current_layer[ac_id] == self.objective_layer[it] or (ac_name in self.ac_has_reached_objective):
                if ac_name not in self.ac_has_reached_objective:
                    self.ac_has_reached_objective = np.append(self.ac_has_reached_objective, ac_name)
                if bs.sim.simt >= self.acs_performing_action_time[it] + TIME_ACTION_DURATION:  # give each action at least TIME_ACTION_DURATION of sim time
                    # give back reward
                    self.setDDPGreward(ac_name, self.current_layer, time_next_waypoints, self.wpt_layer,
                                       self.los_total_until_point[it], self.conf_total_until_point[it],
                                       abs(self.current_layer[ac_id] - self.acs_performing_action_wpt_layer[it]),
                                       abs(self.acs_performing_action_initial_layer[it] - self.acs_performing_action_wpt_layer[it]),
                                       TAS_AVG, self.performing_action[it])
                    #self.performing_action[it][1])
                    completed_actions = np.append(completed_actions, int(it))

                    index = np.where(self.ac_has_reached_objective == ac_name)[0]
                    if len(index) > 0:
                        self.ac_has_reached_objective = np.delete(self.ac_has_reached_objective, index[0])

        # drop the settled entries from every parallel "performing" array
        if len(completed_actions) > 0:
            # print(bs.stack.get_scenname(), 'remove acs', self.acs_performing_action[completed_actions])
            self.acs_performing_action = np.delete(self.acs_performing_action, completed_actions)
            #for it_completed_actions in completed_actions[::-1]:
            #    self.performing_action.pop(it_completed_actions)
            self.performing_action = np.delete(self.performing_action, completed_actions)
            self.acs_performing_action_time = np.delete(self.acs_performing_action_time, completed_actions)
            self.acs_performing_action_initial_layer = np.delete(self.acs_performing_action_initial_layer, completed_actions)
            self.acs_performing_action_wpt_layer = np.delete(self.acs_performing_action_wpt_layer,completed_actions)
            self.objective_layer = np.delete(self.objective_layer, completed_actions)
            self.los_total_until_point = np.delete(self.los_total_until_point, completed_actions)
            self.conf_total_until_point = np.delete(self.conf_total_until_point, completed_actions)

        # request a new action for aircraft that still need to change layer
        for ac_id in range(len(time_next_waypoints)):
            ac_name = bs.traf.id[ac_id]
            # NOTE(review): `ac_name not in self.acs_waiting_action` is tested twice below.
            if self.wpt_layer[ac_id] != self.current_layer[ac_id] and ac_name not in self.acs_performing_action and \
                    ac_name not in self.acs_waiting_action and ac_name not in self.acs_waiting_action \
                    and ac_name not in self.actions_DDPG[0]:

                # if we are about to turn, dont bother calling the action
                # (also requires the target to be more than one layer away)
                if time_next_waypoints[ac_id] >= 5 and abs(self.wpt_layer[ac_id] - self.current_layer[ac_id]) > 1:

                    current_layer, next_layer, current_layer_id, next_layer_id, layer_dif = self.getTransitLayers(self.wpt_layer, self.current_layer, ac_id)
                    tlos_currentlayer, tlos_nextlayer = self.getLosLayers(current_layer, next_layer, ac_id)

                    # state = [tlos_trafficlayer, self.getACNumberNearBy(traffic_layer, ac_id),
                    #          tlos_fastlayer, self.getACNumberNearBy(fast_layer, ac_id),
                    #          tlos_slowlayer,self.getACNumberNearBy(slow_layer, ac_id),
                    #          time_next_waypoints[ac_id], abs(layer_dif)]

                    # NOTE(review): this bare return leaves update() at the first
                    # eligible aircraft and makes everything below unreachable,
                    # so no RL action is ever requested. It looks like a
                    # debugging/baseline switch - confirm before removing.
                    return

                    # state vector (STATE_DIM entries) sent to the RL model
                    state = [current_layer_id, self.getACNumberNearBy(current_layer, ac_id),
                             tlos_currentlayer, next_layer_id, self.getACNumberNearBy(next_layer, ac_id), tlos_nextlayer,
                              time_next_waypoints[ac_id], abs(layer_dif)]

                    # remember that this aircraft is waiting for an answer
                    self.acs_waiting_action = np.append(self.acs_waiting_action, ac_name)
                    self.acs_waiting_action_time = np.append(self.acs_waiting_action_time, bs.sim.simt)
                    self.acs_waiting_action_layer = np.append(self.acs_waiting_action_layer, self.current_layer[ac_id])
                    self.acs_waiting_action_wpt_layer =  np.append(self.acs_waiting_action_wpt_layer, self.wpt_layer[ac_id])

                    if bs.sim.isRLcreated():
                        bs.sim.RLaction(bs.stack.get_scenname(), state)
                    else:
                        bs.sim.RLcreate(bs.stack.get_scenname(), state)

    def getACNumberNearBy(self, layer, ac_id):
        """Count the aircraft in ``layer`` that are closer than 2 * rpz to aircraft ``ac_id``.

        Aircraft at exactly zero distance (the ownship itself when it is in
        that layer) are excluded.

        The previous implementation returned ``len(dist < limit)``, which is
        the length of the boolean mask (every other aircraft in the layer),
        not the number of aircraft within the limit.

        Args:
            layer: layer index to inspect (compared with ``self.current_layer``).
            ac_id: index of the reference aircraft.

        Returns:
            int: number of nearby aircraft in the layer.
        """
        ac_inLayer = np.where(self.current_layer == int(layer))[0]
        if len(ac_inLayer) == 0:
            return 0

        own_lat = np.ones(len(ac_inLayer)) * bs.traf.lat[ac_id]
        own_lon = np.ones(len(ac_inLayer)) * bs.traf.lon[ac_id]
        dist = geo.kwikdist_matrix(bs.traf.lat[ac_inLayer], bs.traf.lon[ac_inLayer], own_lat, own_lon)
        dist = dist * nm  # scale by `nm` so the result can be compared with rpz
        dist = dist[dist > 0]  # remove the ownship

        return int(np.count_nonzero(dist < 2 * bs.traf.cd.rpz))

    def isSafeMove(self, ac_idx, moveTo_layer):
        """Report whether aircraft ``ac_idx`` may move to ``moveTo_layer``.

        Currently always returns True: the check is switched off by the
        unconditional return on the first line, and everything after it is
        unreachable.

        NOTE(review) on the disabled code, should it ever be re-enabled:
        - in the "different layer" branch ``vsspeed_all`` is computed but never
          passed to ``detect``, and ``tLOS_all`` is not restricted to the
          aircraft / its intruders as in the other branch;
        - ``intruders[-1]`` assumes at least one aircraft in the current layer.
        """
        return True
        # where is the aircraft now
        current_layer = self.current_layer[ac_idx]

        if current_layer == moveTo_layer:
            # staying in the layer: look at the aircraft that share it
            intruders = np.where(self.current_layer == current_layer)[0]
            if ac_idx < len(bs.traf.cd.tLOS_all) and intruders[-1] < len(bs.traf.cd.tLOS_all[ac_idx]):
                # reuse the tLOS table already held by the conflict detector
                tLOS_all = bs.traf.cd.tLOS_all[ac_idx][intruders]
            else:
                # table does not cover these indices: run conflict detection again
                confpairs, lospairs, inconf, tcpamax, qdr, dist, dcpa, tcpa, tLOS, tcpa_all, tLOS_all, \
                dist_all, dcpa_all, dalt, tinhor, touthor, swhorconf = \
                    bs.traf.cd.detect(bs.traf, bs.traf, bs.traf.cd.rpz, bs.traf.cd.hpz, bs.traf.cd.dtlookahead)
                tLOS_all = tLOS_all[ac_idx][intruders]
        else:

            # changing layer: a copy of the vertical speeds with this aircraft set to 7.62
            vsspeed_all = copy.deepcopy(bs.traf.vs)
            vsspeed_all[ac_idx] = 7.62
            confpairs, lospairs, inconf, tcpamax, qdr, dist, dcpa, tcpa, tLOS, tcpa_all, tLOS_all, \
            dist_all, dcpa_all, dalt, tinhor, touthor, swhorconf = \
                bs.traf.cd.detect(bs.traf, bs.traf, bs.traf.cd.rpz, bs.traf.cd.hpz, bs.traf.cd.dtlookahead)

        # safe only if no time-to-LoS falls inside the look-ahead window
        return np.all(tLOS_all > bs.traf.cd.dtlookahead)

    def applyPossibleAction(self, ac_idx, actions):
        """Translate a discrete action vector into a target layer for aircraft ``ac_idx``.

        The action with the highest score is chosen (``np.argmax``); for the
        indices 0-2 the aircraft index and a target altitude are appended to
        ``self.RL_actions``; any other index keeps the current layer.

        NOTE(review): this method is not called anywhere in the visible part
        of this file and looks out of date. getTransitLayers() returns
        ``(current layer, next layer, current layer kind, next layer kind,
        layer_dif)``, but the result is unpacked here as
        ``traffic_layer, fast_layer, slow_layer, current_layer_id, layer_dif``.
        As written, "traffic_layer" is the current layer, "fast_layer" is the
        next layer and "slow_layer" is a layer kind (0/1/2), which is then
        used as an index into ALTS_all. Confirm the intended mapping before
        using this method.

        Returns:
            The chosen objective layer, or None if no branch assigned one.
        """
        # 0 - move to next traffic layer
        # 1 - move to fast layer
        # 2 - move to slow layer
        # 3 (or higher) - stay in same layer
        objective_layer = None
        action_idx = np.argmax(actions)
        if action_idx < 3:  # if 3, do nothing
            traffic_layer, fast_layer, slow_layer, current_layer_id, layer_dif = \
                self.getTransitLayers(self.wpt_layer, self.current_layer, ac_idx)
            # print(self.wpt_layer, self.current_layer, ac_idx, '-',  traffic_layer, fast_layer, slow_layer, current_layer_id, layer_dif)
            # NOTE(review): the aircraft index is recorded before any altitude is;
            # if no branch below appends an altitude the two lists go out of step.
            self.RL_actions[0].append(ac_idx)
            if action_idx == 0 and self.isSafeMove(ac_idx, traffic_layer):  # move to traffic layer
                self.RL_actions[1].append(ALTS_all[int(traffic_layer)])
                objective_layer = int(traffic_layer)
            elif action_idx == 1 and self.isSafeMove(ac_idx, fast_layer):  # move to fast layer
                self.RL_actions[1].append(ALTS_all[int(fast_layer)])
                objective_layer = int(fast_layer)
            # NOTE(review): no `action_idx == 2` test here, so this branch is also
            # tried when action 0 or 1 was rejected by isSafeMove().
            elif self.isSafeMove(ac_idx, slow_layer):  # move to slow layer
                self.RL_actions[1].append(ALTS_all[int(slow_layer)])
                objective_layer = int(slow_layer)
        else:
            objective_layer = self.current_layer[ac_idx]

        # which layer is the objective
        return objective_layer

    # Is the layer the aircraft is moving to in between its current layer and the objective layer?
    def actionIsValid(self, currentLayer, wpt_layer, moving_to_layer):
        """Check that ``moving_to_layer`` lies between the current layer and the waypoint layer.

        A move is valid when the target layer is the current layer, the
        waypoint (objective) layer, or any layer in between.

        The previous implementation compared the target only with
        ``wpt_layer``: while climbing it rejected every intermediate layer
        (including the very next one) and it accepted overshooting the
        objective or moving away from it.

        Args:
            currentLayer: layer index the aircraft is in now.
            wpt_layer: layer index the aircraft ultimately has to reach.
            moving_to_layer: layer index proposed by the action.

        Returns:
            bool: True if the proposed layer is within the closed range
            spanned by ``currentLayer`` and ``wpt_layer``.
        """
        lowest = min(currentLayer, wpt_layer)
        highest = max(currentLayer, wpt_layer)
        return bool(lowest <= moving_to_layer <= highest)

    def setDDPGaction(self, action):
        """Apply the RL model's answer for the aircraft at the head of the waiting queue.

        ``action`` is the state that was sent (first STATE_DIM entries)
        followed by the action chosen by the model (remaining entries).

        The aircraft is removed from the waiting bookkeeping. The answer is
        then discarded when it is stale: the aircraft no longer exists, it has
        changed layer since the request, or more time has passed than the
        time-to-waypoint stored in the state (plus 2). Otherwise the
        transition is recorded so the reward can be computed later, and the
        route altitudes are updated if the action asks for a layer change.

        Compared with the previous version, an aircraft that has left the
        simulation is now skipped. Its lookup index was used unchecked, so an
        index of -1 made the action apply to the last aircraft instead.
        """
        self.waiting_response = False
        ac_name = self.acs_waiting_action[0]
        ac_id = bs.traf.id2idx(ac_name)

        # pull this aircraft's entries out of the parallel "waiting" arrays
        is_this_ac = self.acs_waiting_action == ac_name
        layer_at_time_action_call = self.acs_waiting_action_layer[is_this_ac][0]
        time_action_was_requested = self.acs_waiting_action_time[is_this_ac][0]
        wpt_layer_call = self.acs_waiting_action_wpt_layer[is_this_ac][0]
        still_waiting = ~is_this_ac
        self.acs_waiting_action_layer = self.acs_waiting_action_layer[still_waiting]
        self.acs_waiting_action_time = self.acs_waiting_action_time[still_waiting]
        self.acs_waiting_action_wpt_layer = self.acs_waiting_action_wpt_layer[still_waiting]
        self.acs_waiting_action = self.acs_waiting_action[still_waiting]  # no longer waiting for action, we got it

        # the aircraft was deleted while the model was deciding: nothing to apply
        if ac_id < 0:
            return

        # perhaps the RL model took too long and the ac is already at the intended layer, then move on
        # (action[:STATE_DIM][-2] is the time-to-waypoint entry of the state)
        if self.current_layer[ac_id] != layer_at_time_action_call or bs.sim.simt > (time_action_was_requested + action[:STATE_DIM][-2] + 2):
            return

        # the objective is always the adjacent layer in the direction of the waypoint layer
        current_layer, next_layer, current_layer_id, next_layer_id, layer_dif = self.getTransitLayers(self.wpt_layer, self.current_layer, ac_id)
        objective_layer = next_layer

        # save info that we performed this action so we can return the reward afterwards
        self.actions_DDPG[0].append(ac_name)
        self.actions_DDPG[1].append(action[:STATE_DIM])
        self.actions_DDPG[2].append(action[STATE_DIM:])

        # totals at the start of the action, used to find the new pairs afterwards
        self.conf_total_until_point = np.append(self.conf_total_until_point, len(bs.traf.cd.confpairs_all))
        self.los_total_until_point = np.append(self.los_total_until_point, len(bs.traf.cd.lospairs_all))

        self.acs_performing_action = np.append(self.acs_performing_action, ac_name)
        self.performing_action = np.append(self.performing_action, action[STATE_DIM:])
        self.acs_performing_action_time = np.append(self.acs_performing_action_time, bs.sim.simt)
        self.acs_performing_action_initial_layer = np.append(self.acs_performing_action_initial_layer, layer_at_time_action_call)
        self.acs_performing_action_wpt_layer = np.append(self.acs_performing_action_wpt_layer, wpt_layer_call)
        self.objective_layer = np.append(self.objective_layer, objective_layer)

        if action[STATE_DIM:][0] > 0.01:  # only if the RL model decided to move to a different layer
            self.updateWaypoints(ac_id, ALTS_all[int(objective_layer)])

    def updateWaypoints(self, ac, movingTo_alt):
        """Rewrite the upcoming waypoint altitudes of aircraft ``ac`` to ``movingTo_alt``.

        Starting at the active waypoint, every consecutive waypoint that still
        has the altitude of the previous waypoint is set to ``movingTo_alt``.
        The first waypoint after that run is also overwritten when it is not a
        traffic-layer altitude and is a local minimum or maximum relative to
        its two neighbours (a transition layer that is no longer needed).
        """
        wpalt = bs.traf.ap.route[ac].wpalt
        idx = bs.traf.ap.route[ac].iactwp
        current_alt = wpalt[max(idx - 1, 0)]

        # run of waypoints that continue at the current altitude
        while idx < len(wpalt) and wpalt[idx] == current_alt and wpalt[idx] != movingTo_alt:
            wpalt[idx] = movingTo_alt
            idx += 1

        # check the next element just to see if it's in a transition layer that is no longer needed
        if idx + 1 >= len(wpalt) or wpalt[idx] in ALTS_traffic:
            return

        before, here, after = wpalt[idx - 1], wpalt[idx], wpalt[idx + 1]
        is_dip = here < after and before > here
        is_bump = here > after and before < here
        if is_dip or is_bump:
            wpalt[idx] = movingTo_alt

    def getActions(self):
        """Return the commands for every aircraft that is performing an action.

        Returns:
            tuple: ``(aircraft names, vertical speeds, altitudes)``. The
            vertical speed is ``vsmax * direction * action`` with the direction
            taken from objective layer minus initial layer. The altitude is
            that of the aircraft's current layer when the action is above
            0.01, otherwise that of its objective layer.
        """
        vs_needed = []
        alts = []

        # lets tell pilot to move ac toward the desired action/layer
        for it, ac_name in enumerate(self.acs_performing_action):
            ac_idx = bs.traf.id2idx(ac_name)
            strength = self.performing_action[it]
            direction = np.sign(self.objective_layer[it] - self.acs_performing_action_initial_layer[it])

            if strength > 0.01:
                layer = self.current_layer[ac_idx]
            else:
                layer = self.objective_layer[it]
            alts.append(ALTS_all[int(layer)])
            vs_needed.append(bs.traf.perf.vsmax[ac_idx] * direction * strength)

        return self.acs_performing_action, np.array(vs_needed, dtype=float), np.array(alts, dtype=float)

    def detectConflicts(self, vs, hpz, dalt, tinhor, ntraf):
        """Combine vertical and horizontal entry times into a pairwise time-into-conflict matrix.

        Args:
            vs: vertical speeds, reshaped here to ``(1, ntraf)``.
            hpz: height of the protected zone.
            dalt: pairwise altitude differences.
            tinhor: pairwise times of entering horizontal conflict.
            ntraf: number of aircraft.

        Returns:
            Element-wise maximum of the vertical entry time and ``tinhor``.
        """
        vs_row = vs.reshape(1, ntraf)
        pairwise_dvs = vs_row - vs_row.T
        # prevent division by zero
        closing = -np.where(np.abs(pairwise_dvs) < 1e-6, 1e-6, pairwise_dvs)

        # Check for passing through each others zone
        t_upper = (dalt + hpz) / closing
        t_lower = (dalt - hpz) / closing

        # Combine vertical and horizontal conflict
        return np.maximum(np.minimum(t_upper, t_lower), tinhor)

    def setDDPGreward(self, ac_name, current_layer, time_next_waypoint, wpt_layer, los_begin_action,
                      conf_begin_action, layer_dif, before_layer_dif, desired_speed, action):
        """Compute the reward of a finished action and report the transition to the RL model.

        The reward is built from the loss-of-separation and near-miss
        penalties, a bonus for layers gained (only when the action was above
        0.01) and a speed-deviation penalty. The stored state/action of the
        aircraft are then sent with the new state through ``bs.sim.RLresult``
        and removed from ``self.actions_DDPG``.

        The stored entries are now deleted by index. They used to be removed
        with ``list.remove(value)``, which deletes the first *equal* element
        (possibly another aircraft's) and raises ValueError when numpy arrays
        are compared.

        Args:
            ac_name: callsign of the aircraft.
            current_layer: unused; ``self.current_layer`` is read instead.
            time_next_waypoint: per-aircraft time to the active waypoint.
            wpt_layer: unused; ``self.wpt_layer`` is read instead.
            los_begin_action: length of ``lospairs_all`` when the action started.
            conf_begin_action: length of ``confpairs_all`` when the action started.
            layer_dif: layers between the aircraft and its waypoint layer now.
            before_layer_dif: same distance when the action started.
            desired_speed: reference speed for the speed penalty.
            action: the scalar action that was performed.
        """
        ac_id = bs.traf.id2idx(ac_name)
        reward = 0
        reward = self.getLoss(reward, ac_name, los_begin_action)
        reward = self.getNearMisses(reward, ac_name, conf_begin_action)
        if action > 0.01:
            reward = self.getDistanceToTarget(reward, layer_dif, before_layer_dif)
        reward = self.getSpeedDif(reward, ac_id, desired_speed)

        # position of this aircraft in the three parallel actions_DDPG lists
        state_idx = np.where(np.asarray(self.actions_DDPG[0]) == ac_name)[0]
        entry = int(state_idx[0])

        state = self.actions_DDPG[1][entry]
        action = self.actions_DDPG[2][entry]

        # state after the action, built the same way as the state sent with the request
        current_layer, next_layer, current_layer_id, next_layer_id, dummy = self.getTransitLayers(
            self.wpt_layer, self.current_layer, ac_id)
        tlos_currentlayer, tlos_nextlayer = self.getLosLayers(current_layer, next_layer, ac_id)

        next_state = [current_layer_id, self.getACNumberNearBy(current_layer, ac_id), tlos_currentlayer, next_layer_id,
                      self.getACNumberNearBy(next_layer, ac_id), tlos_nextlayer, time_next_waypoint[ac_id], abs(layer_dif)]

        bs.sim.RLresult(bs.stack.get_scenname(), state, next_state, reward, action)

        # forget the settled transition (name, state and action)
        for column in self.actions_DDPG:
            del column[entry]

    def getSpeedDif(self, rewards, ac_id, desired_speed):
        """Return ``rewards`` reduced by a fifth of the gap between ``desired_speed`` and the aircraft's TAS."""
        speed_gap = abs(desired_speed - bs.traf.tas[ac_id])
        rewards -= speed_gap / 5
        return rewards

    def getNearMisses(self, rewards, ac_name, conf_begin_action):
        """Subtract 1 from ``rewards`` for every new conflict pair that involves ``ac_name``.

        ``conf_begin_action`` is the length of ``bs.traf.cd.confpairs_all``
        when the action started, so the pairs added since then are
        ``confpairs_all[conf_begin_action:]``.

        Two defects are fixed: the list read was ``lospairs_all`` although the
        start index was taken from ``confpairs_all``, and the slice
        ``[-n:]`` selected the last n entries instead of the entries added
        after index n.
        """
        new_conf = bs.traf.cd.confpairs_all[int(conf_begin_action):]
        # (a, b) and (b, a) describe the same conflict: count it once
        confpairs_unique = {frozenset(pair) for pair in new_conf}
        rewards -= sum(1 for pair in confpairs_unique if ac_name in pair)
        return rewards

    def getLoss(self, rewards, ac_name, los_begin_action):
        """Subtract 10 from ``rewards`` for every new loss-of-separation pair that involves ``ac_name``.

        ``los_begin_action`` is the length of ``bs.traf.cd.lospairs_all`` when
        the action started, so the pairs added since then are
        ``lospairs_all[los_begin_action:]``. The previous slice ``[-n:]``
        selected the last n entries instead, mixing old and new pairs.
        """
        new_los = bs.traf.cd.lospairs_all[int(los_begin_action):]
        # (a, b) and (b, a) describe the same loss of separation: count it once
        lospairs_unique = {frozenset(pair) for pair in new_los}
        rewards -= 10 * sum(1 for pair in lospairs_unique if ac_name in pair)
        return rewards

    def getDistanceToTarget(self, rewards, layer_dif, previous_layer_dif):
        """Return ``rewards`` increased by twice the absolute change in layer distance."""
        layers_changed = abs(layer_dif - previous_layer_dif)
        return rewards + layers_changed * 2
