from bluesky.tools.simtime import timed_function
import bluesky as bs
import numpy as np
from bluesky.tools import geo
from shapely.geometry.polygon import LineString
from bluesky.tools.aero import nm
from bluesky.tools.trafficarrays import TrafficArrays, RegisterElementParameters
import pickle

# When True, setRLaction turns the agent's output into an MVP
# (Modified Voltage Potential) resolution manoeuvre.
USE_MVP = True
# Number of closest intruders encoded in every observation.
MAX_INTRUDERS = 4
# Per-aircraft action components (look-ahead, heading, speed, vertical switch).
ACTION_DIM = 4
# Observation length: 4 ownship features + 5 features for each intruder.
STATE_DIM = 4 + MAX_INTRUDERS * 5

ft = 0.3048  # metres per foot

class ConflictResoRL(TrafficArrays):
    """Bridge between BlueSky conflict detection and a reinforcement-learning agent.

    On every ``update()`` the observations of all aircraft flagged in
    ``bs.traf.cd.inconf`` are handed to ``bs.sim.RLaction`` (presumably
    forwarded to the RL agent -- confirm). ``setRLaction`` is expected to be
    called with the agent's reply; it converts the reply into resolution
    targets stored in the per-aircraft arrays ``trk``, ``tas``, ``vs`` and
    ``alt`` (read back through ``getValues``). On the next ``update()`` the
    transition (state, action, reward, next state) is reported through
    ``bs.sim.RLresult``.
    """

    def __init__(self):
        TrafficArrays.__init__(self)
        # Per-aircraft resolution targets, registered with TrafficArrays
        # (presumably so they are resized together with the traffic).
        with RegisterElementParameters(self):
            self.trk = np.array([])
            self.tas = np.array([])
            self.vs = np.array([])
            self.alt = np.array([])

        # RL bookkeeping is (re)initialised by reset(); set it here as well so
        # update()/setRLaction() do not fail if they run before the first reset.
        self._clearRLBuffers()

    def _clearRLBuffers(self):
        """Empty all pending RL batches and the reward bookkeeping."""
        self.RL_actions = []    # per applied batch: actions contained in the agent's reply
        self.RL_states = []     # per applied batch: states contained in the agent's reply
        self.RL_aircraft = []   # batches of aircraft ids waiting for the agent's reply
        self.RL_ac_done = []    # batches whose action was applied, waiting for their reward
        self.conf_los_ac = np.array([])             # aircraft id of every recorded action
        self.conf_total_until_point = np.array([])  # len(cd.confpairs_all) when the action was applied
        self.los_total_until_point = np.array([])   # len(cd.lospairs_all) when the action was applied

    def dump_pickle(self, obj, name):
        """Serialise ``obj`` to the file ``name`` using the highest pickle protocol."""
        with open(name, 'wb') as handle:
            pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def reset(self):
        """Clear all RL bookkeeping, then call ``bs.sim.RLsave`` with the scenario name.

        NOTE(review): this override does not call ``TrafficArrays.reset``, so the
        registered arrays presumably keep their old length; ``getValues``
        compensates by reallocating them when their length differs from
        ``bs.traf.ntraf``.
        """
        self._clearRLBuffers()
        bs.sim.RLsave(bs.stack.get_scenname())

    def getNormalized(self, bearing):
        """Wrap an angle in degrees into [-180, 180] with at most one +-360 shift.

        Only correct for inputs at most one turn outside that range, e.g.
        tracks in [0, 360).
        """
        if bearing > 180:
            return bearing - 360
        if bearing < -180:
            return bearing + 360
        return bearing

    @staticmethod
    def _padIntruders(values, fill):
        """Right-pad ``values`` with ``fill`` up to ``MAX_INTRUDERS`` entries."""
        return np.pad(values, (0, MAX_INTRUDERS - len(values)), constant_values=fill)

    def getStateArray(self, ac_id):
        """Build the observation vector (length ``STATE_DIM``) of aircraft index ``ac_id``.

        Layout:
          * 4 ownship features: track and autopilot track (both wrapped to
            [-180, 180]), ground speed, altitude minus 100 ft;
          * then, for the ``MAX_INTRUDERS`` aircraft with the smallest
            ``cd.dist_all`` entries, five consecutive groups of
            ``MAX_INTRUDERS`` values: distance, distance at CPA, time to CPA,
            relative bearing (wrapped to [-180, 180]) and relative altitude.

        Missing intruders are padded with 3000, 3000, 600, 180 and 100 ft
        respectively.
        """
        cd = bs.traf.cd
        closest = np.argsort(cd.dist_all[ac_id])[:MAX_INTRUDERS]

        rel_bearing = cd.qdr_all[ac_id][closest] - bs.traf.trk[ac_id]
        rel_bearing = np.where(rel_bearing > 180, rel_bearing - 360, rel_bearing)
        rel_bearing = np.where(rel_bearing < -180, rel_bearing + 360, rel_bearing)

        own = np.array([
            self.getNormalized(bs.traf.trk[ac_id]),
            self.getNormalized(bs.traf.ap.trk[ac_id]),
            bs.traf.gs[ac_id],
            bs.traf.alt[ac_id] - 100 * ft,
        ], dtype=float)

        intruders = [
            self._padIntruders(cd.dist_all[ac_id][closest], 3000),
            self._padIntruders(cd.dcpa_all[ac_id][closest], 3000),
            self._padIntruders(cd.tcpa_all[ac_id][closest], 600),
            self._padIntruders(rel_bearing, 180),
            self._padIntruders(cd.dalt[ac_id][closest], 100 * ft),
        ]
        return np.concatenate([own] + intruders)

    def update(self):
        """Request actions for aircraft in conflict and report finished transitions.

        1. If any aircraft is flagged in ``bs.traf.cd.inconf``, their ids are
           queued as a new batch in ``RL_aircraft`` and their observations are
           passed to ``bs.sim.RLaction``.
        2. Every batch already applied by ``setRLaction`` gets its reward and
           next state reported (``setRLreward``) and is then dropped.
        """
        inconf = np.flatnonzero(bs.traf.cd.inconf)
        if inconf.size > 0:
            self.RL_aircraft.append(np.asarray(bs.traf.id)[inconf].tolist())
            states = [self.getStateArray(ac_id).tolist() for ac_id in inconf]
            bs.sim.RLaction(bs.stack.get_scenname(), states)

        # Snapshot the number of finished batches so only those are reported
        # and removed, even if new ones are appended meanwhile.
        n_done = len(self.RL_ac_done)
        for it, acs in enumerate(self.RL_ac_done[:n_done]):
            self.setRLreward(it, acs)
        del self.RL_ac_done[:n_done]
        del self.RL_actions[:n_done]
        del self.RL_states[:n_done]

    def getTLosAll(self):
        """Return an (ntraf, ntraf) matrix of predicted times [s] to loss of separation.

        Entry [i, j] is the time at which aircraft j enters the protected zone
        of aircraft i (``cd.rpz`` horizontally, ``cd.hpz`` vertically). Both
        aircraft are assumed to keep their current ground speed, track and
        vertical speed. The value is negative for pairs already in LoS.
        Pairs that will not lose separation from now on get the sentinel 1e8.
        This covers pairs with no horizontal intrusion, pairs vertically
        separated during the whole horizontal conflict window, pairs whose
        conflict window already lies in the past, and the diagonal.
        """
        ntraf = bs.traf.ntraf
        I = np.eye(ntraf)
        no_conflict = 1e8  # sentinel time for pairs that never lose separation

        # Horizontal conflict ------------------------------------------------
        # qdr[i, j] / dist[i, j]: bearing and distance from i to j
        qdr, dist = geo.kwikqdrdist_matrix(np.asmatrix(bs.traf.lat), np.asmatrix(bs.traf.lon),
                                           np.asmatrix(bs.traf.lat), np.asmatrix(bs.traf.lon))
        qdr = np.asarray(qdr)
        # Convert to metres and push own/own pairs far away
        dist = np.asarray(dist) * nm + 1e9 * I

        qdrrad = np.radians(qdr)
        dx = dist * np.sin(qdrrad)  # east position of j relative to i
        dy = dist * np.cos(qdrrad)  # north position of j relative to i

        # Ground velocity components as (1, ntraf) rows [m/s]
        trkrad = np.radians(bs.traf.trk)
        u = bs.traf.gs * np.sin(trkrad).reshape((1, ntraf))
        v = bs.traf.gs * np.cos(trkrad).reshape((1, ntraf))
        du = u - u.T  # du[i, j]: eastern speed of j relative to i
        dv = v - v.T  # dv[i, j]: northern speed of j relative to i

        dv2 = du * du + dv * dv
        dv2 = np.where(np.abs(dv2) < 1e-6, 1e-6, dv2)  # limit lower absolute value
        vrel = np.sqrt(dv2)

        tcpa = -(du * dx + dv * dy) / dv2 + 1e9 * I
        dcpa2 = np.abs(dist * dist - tcpa * tcpa * dv2)  # distance^2 at CPA

        R2 = bs.traf.cd.rpz * bs.traf.cd.rpz
        swhorconf = dcpa2 < R2
        dxinhor = np.sqrt(np.maximum(0., R2 - dcpa2))  # half the distance travelled inside the zone
        dtinhor = dxinhor / vrel
        tinhor = np.where(swhorconf, tcpa - dtinhor, no_conflict)
        touthor = np.where(swhorconf, tcpa + dtinhor, -no_conflict)

        # Vertical conflict: crossing of the band (-hpz, +hpz) ----------------
        dalt = bs.traf.alt.reshape((1, ntraf)) - bs.traf.alt.reshape((1, ntraf)).T + 1e9 * I
        dvs = bs.traf.vs.reshape(1, ntraf) - bs.traf.vs.reshape(1, ntraf).T
        dvs = np.where(np.abs(dvs) < 1e-6, 1e-6, dvs)  # prevent division by zero

        tcrosshi = (dalt + bs.traf.cd.hpz) / -dvs
        tcrosslo = (dalt - bs.traf.cd.hpz) / -dvs
        tinver = np.minimum(tcrosshi, tcrosslo)
        toutver = np.maximum(tcrosshi, tcrosslo)

        # Combine horizontal and vertical --------------------------------------
        tinconf = np.maximum(tinver, tinhor)
        toutconf = np.minimum(toutver, touthor)
        # Separation is only lost if both windows overlap and the overlap has
        # not ended yet; otherwise tinconf is meaningless (e.g. vertically
        # separated traffic or conflicts that are already behind us).
        will_lose_sep = swhorconf & (tinconf <= toutconf) & (toutconf > 0.0)
        return np.where(will_lose_sep, tinconf, no_conflict)

    def getValues(self):
        """Return the resolution targets ``(trk, tas, vs, alt)`` and a copy of ``bs.traf.ap.trk``.

        If the registered arrays no longer match ``bs.traf.ntraf`` in length
        (presumably after a reset), they are first reallocated as zeros.
        """
        ntraf = bs.traf.ntraf
        if len(self.trk) != ntraf:
            print('resetting arrays in conflictResoRL', len(self.trk), ntraf, bs.sim.simt, bs.stack.get_scenname())
            self.trk = np.zeros(ntraf)
            self.tas = np.zeros(ntraf)
            self.vs = np.zeros(ntraf)
            self.alt = np.zeros(ntraf)

        return self.trk, self.tas, self.vs, self.alt, np.copy(bs.traf.ap.trk)

    def setRLaction(self, action):
        """Apply the agent's reply to the oldest batch waiting in ``RL_aircraft``.

        Args:
            action: sequence whose first half holds the observations the agent
                acted on and whose second half holds one action per aircraft of
                the batch (each indexable with at least 4 entries, see
                ``_applyMVP``).

        The batch is moved to ``RL_ac_done`` so the next ``update()`` reports
        its reward. For every aircraft, the current LoS/conflict counters are
        recorded for ``setRLreward`` and the current state is stored as the
        default target. When ``USE_MVP`` is set, the MVP targets are then
        computed.
        """
        half = len(action) // 2
        self.RL_states.append(action[:half])
        self.RL_actions.append(action[half:])
        acs = self.RL_aircraft.pop(0)
        self.RL_ac_done.append(acs)

        _vmin, _vmax, vsmin, vsmax = bs.traf.perf.currentlimits()
        # Nothing below modifies bs.traf, so the pairwise time-to-LoS matrix is
        # computed once per batch instead of once per aircraft.
        tLOS_all = self.getTLosAll() if USE_MVP and len(acs) > 0 else None
        batch_actions = self.RL_actions[-1]

        for ac_name_it, ac_name in enumerate(acs):
            ac_idx = bs.traf.id2idx(ac_name)
            # Bookkeeping consumed by setRLreward (recorded even for aircraft
            # that disappeared, so every action still gets a reward entry).
            self.conf_los_ac = np.append(self.conf_los_ac, ac_name)
            self.conf_total_until_point = np.append(self.conf_total_until_point, len(bs.traf.cd.confpairs_all))
            self.los_total_until_point = np.append(self.los_total_until_point, len(bs.traf.cd.lospairs_all))

            if ac_idx < 0:
                # Aircraft no longer exists (deleted while waiting for the
                # reply); a negative index would overwrite another aircraft.
                continue

            # Default targets: keep the current state if no manoeuvre is chosen.
            self.trk[ac_idx] = bs.traf.trk[ac_idx]
            self.tas[ac_idx] = bs.traf.tas[ac_idx]
            self.vs[ac_idx] = bs.traf.vs[ac_idx]
            self.alt[ac_idx] = bs.traf.alt[ac_idx]

            if USE_MVP:
                self._applyMVP(ac_idx, batch_actions[ac_name_it], tLOS_all, vsmin, vsmax)

    def _applyMVP(self, ac_idx, ac_action, tLOS_all, vsmin, vsmax):
        """Turn one aircraft's action into MVP resolution targets.

        ``ac_action[0]`` scales the look-ahead time (capped at
        ``cd.dtlookahead``). ``ac_action[1]``, ``[2]`` and ``[3]`` >= 0 enable
        the heading, speed and vertical components respectively. Intruders are
        all aircraft whose time to LoS in ``tLOS_all`` is below the look-ahead
        time. Only the enabled components overwrite the default targets that
        ``setRLaction`` stored in ``self.trk/tas/vs/alt``.
        """
        dtlookahead = bs.traf.cd.dtlookahead
        look_ahead_time = min((ac_action[0] + 1) * dtlookahead / 2, dtlookahead)
        heading_solution = ac_action[1] >= 0
        speed_solution = ac_action[2] >= 0
        vertical_solution = ac_action[3] >= 0

        intruders_idx = np.where(tLOS_all[ac_idx] < look_ahead_time)[0]
        if len(intruders_idx) == 0 or not (heading_solution or speed_solution or vertical_solution):
            return

        # Sum the horizontal MVP corrections of all intruders.
        dv = np.zeros(2)
        drel = vrel = None
        for idx2 in intruders_idx:
            dv_mvp, drel, vrel = self._mvpPair(ac_idx, idx2, tLOS_all[ac_idx][idx2])
            dv = dv - dv_mvp
        newv = np.array([bs.traf.gseast[ac_idx], bs.traf.gsnorth[ac_idx]]) + dv

        if vertical_solution:
            # NOTE(review): only the geometry of the last intruder in
            # intruders_idx (drel/vrel from the loop above) drives the vertical
            # manoeuvre -- confirm this is intended.
            tlos = tLOS_all[ac_idx][intruders_idx[-1]]
            hpz_margin = bs.traf.cd.hpz * bs.traf.cr.resofacv
            if abs(vrel[2]) > 0.0:
                iV = hpz_margin
                tsolV = abs(drel[2] / vrel[2])
            else:
                iV = hpz_margin - abs(drel[2])
                tsolV = tlos
            # Too slow to solve vertically within the look-ahead: solve by the LoS time.
            if tsolV > look_ahead_time:
                tsolV = tlos
                iV = hpz_margin
            direction = -np.sign(vrel[2]) if abs(vrel[2]) > 0.0 else 1.0
            newvs = (iV / tsolV) * direction

            vscapped = np.maximum(vsmin[ac_idx], np.minimum(vsmax[ac_idx], newvs))
            asasalttemp = vscapped * tsolV + bs.traf.alt[ac_idx]
            selalt = bs.traf.selalt[ac_idx]
            signdvs = np.sign(vscapped - bs.traf.ap.vs[ac_idx] * np.sign(selalt - bs.traf.alt[ac_idx]))
            signalt = np.sign(asasalttemp - selalt)
            # Use the resolution altitude when signdvs is zero or equals
            # signalt, otherwise fall back to the selected altitude.
            self.vs[ac_idx] = vscapped
            self.alt[ac_idx] = np.where(np.logical_or(signdvs == 0, signdvs == signalt), asasalttemp, selalt)

        # Horizontal components: each one only replaces its default target, so
        # a vertical-only action leaves track and speed unchanged.
        if heading_solution:
            self.trk[ac_idx] = (np.arctan2(newv[0], newv[1]) * 180 / np.pi) % 360
        if speed_solution:
            self.tas[ac_idx] = np.sqrt(newv[0] ** 2 + newv[1] ** 2)

    def _mvpPair(self, ac_idx, idx2, tcpa, asas_mar=1.05):
        """Return ``(dv_mvp, drel, vrel)`` for ownship ``ac_idx`` against intruder ``idx2``.

        ``dv_mvp`` is the 2-D (east, north) MVP velocity correction that the
        caller subtracts from the ownship velocity. ``drel`` and ``vrel`` are
        the 3-D relative position and velocity of the intruder. ``tcpa`` is
        the time horizon for the CPA estimate; the caller passes the predicted
        time to LoS. ``asas_mar`` multiplies the protected-zone radius.
        """
        qdr, dist = geo.kwikqdrdist(bs.traf.lat[ac_idx], bs.traf.lon[ac_idx], bs.traf.lat[idx2], bs.traf.lon[idx2])
        qdr_radians = np.deg2rad(qdr)
        dist *= nm  # to metres
        drel = np.array([np.sin(qdr_radians) * dist, np.cos(qdr_radians) * dist,
                         bs.traf.alt[idx2] - bs.traf.alt[ac_idx]])
        v1 = np.array([bs.traf.gseast[ac_idx], bs.traf.gsnorth[ac_idx], bs.traf.vs[ac_idx]])
        v2 = np.array([bs.traf.gseast[idx2], bs.traf.gsnorth[idx2], bs.traf.vs[idx2]])
        vrel = v2 - v1

        # Horizontal miss distance after tcpa seconds.
        dcpa = drel + vrel * tcpa
        dabsH = np.sqrt(dcpa[0] * dcpa[0] + dcpa[1] * dcpa[1])
        rpz_mar = bs.traf.cd.rpz * asas_mar
        iH = rpz_mar - dabsH  # horizontal intrusion

        # Head-on: a (near) zero miss vector would divide by zero below, so
        # resolve perpendicular to drel instead.
        if dabsH <= 10.:
            dabsH = 10.
            dcpa[0] = drel[1] / dist * dabsH
            dcpa[1] = -drel[0] / dist * dabsH

        if rpz_mar < dist and dabsH < dist:
            # Intruder outside the (margined) PZ: extra factor so the
            # resolution does not graze the PZ. abs(tcpa) because it becomes
            # negative during an intrusion.
            erratum = np.cos(np.arcsin(rpz_mar / dist) - np.arcsin(dabsH / dist))
            dv1 = ((rpz_mar / erratum - dabsH) * dcpa[0]) / (abs(tcpa) * dabsH)
            dv2 = ((rpz_mar / erratum - dabsH) * dcpa[1]) / (abs(tcpa) * dabsH)
        else:
            dv1 = (iH * dcpa[0]) / (abs(tcpa) * dabsH)
            dv2 = (iH * dcpa[1]) / (abs(tcpa) * dabsH)
        return np.array([dv1, dv2]), drel, vrel

    def checkCRsolutionValid(self, asase, heading_solution, speed_solution, speed_dif):
        """Decide whether a candidate resolution is acceptable.

        Speed-only resolutions (``heading_solution`` falsy) are always
        accepted. Otherwise a solution must exist (``asase`` is not None).
        When ``speed_solution`` is falsy (heading-only), it is additionally
        required that ``speed_dif < 1`` (units presumably m/s).
        """
        if not heading_solution:
            return True
        if asase is None:
            return False
        return bool(speed_solution or speed_dif < 1)

    def findLineSolution(self, hdg, vmax, all_VOs, speed_now):
        """Find the velocity along heading ``hdg`` outside ``all_VOs`` closest in speed to ``speed_now``.

        Candidates lie on the segment from the origin to
        ``vmax * (sin(hdg), cos(hdg))``, so ``hdg`` is in radians and the
        components are (east, north). ``all_VOs`` is a shapely geometry. For
        every part of the segment outside it, the point nearest to
        ``speed_now`` is considered.

        Returns:
            ``(east, north, abs_speed_difference)`` of the best candidate, or
            ``(None, None, None)`` when there is none or shapely fails.
        """
        speed_difs, sol_e, sol_n = [], [], []
        line = LineString([[0, 0], [np.sin(hdg) * vmax, np.cos(hdg) * vmax]])

        try:
            differences = line.difference(all_VOs)
            # Multi-part geometries are not sequences in Shapely >= 2.0;
            # iterate over .geoms (also available in Shapely 1.x).
            parts = [differences] if isinstance(differences, LineString) else differences.geoms
            for part in parts:
                if part.is_empty:
                    continue
                xs, ys = part.xy
                speed_begin_line = np.sqrt(xs[0] ** 2 + ys[0] ** 2)
                speed_end_line = np.sqrt(xs[1] ** 2 + ys[1] ** 2)
                if speed_begin_line > speed_now:
                    # part starts above the current speed: take its start point
                    speed_difs.append(abs(speed_now - speed_begin_line))
                    sol_e.append(xs[0])
                    sol_n.append(ys[0])
                elif speed_end_line < speed_now:
                    # part ends below the current speed: take its end point
                    speed_difs.append(abs(speed_end_line - speed_now))
                    sol_e.append(xs[1])
                    sol_n.append(ys[1])
                else:
                    # current speed lies within this part: no speed change needed
                    speed_difs.append(0)
                    sol_e.append(speed_now * np.sin(hdg))
                    sol_n.append(speed_now * np.cos(hdg))
        except Exception as err:  # best effort: geometry errors must not stop the simulation
            print("shapelly could not find difference", all_VOs, err)

        if not speed_difs:
            return None, None, None
        best = int(np.argmin(speed_difs))
        return (np.asarray(sol_e, dtype=float)[best],
                np.asarray(sol_n, dtype=float)[best],
                np.asarray(speed_difs, dtype=float)[best])

    def setRLreward(self, it, acs):
        """Report the transition of batch ``it`` through ``bs.sim.RLresult``.

        Args:
            it: index of the batch in ``RL_states`` / ``RL_actions``.
            acs: aircraft ids of that batch.

        The reward is the sum of per-aircraft ``getLoss`` penalties. The next
        state is each aircraft's current observation. The bookkeeping entry
        recorded by ``setRLaction`` for each aircraft is consumed.
        """
        rewards = []
        for ac_name in acs:
            ac_id = bs.traf.id2idx(ac_name)
            matches = np.flatnonzero(self.conf_los_ac == ac_name)
            reward = 0
            if matches.size > 0:
                # Entries are recorded and consumed in batch order, so the
                # oldest remaining entry of this aircraft belongs to this batch.
                first = matches[0]
                reward = self.getLoss(reward, ac_id, self.los_total_until_point[first])
                self.los_total_until_point = np.delete(self.los_total_until_point, first)
                self.conf_total_until_point = np.delete(self.conf_total_until_point, first)
                self.conf_los_ac = np.delete(self.conf_los_ac, first)
            rewards.append(reward)

        # NOTE(review): id2idx yields -1 for deleted aircraft, in which case
        # getStateArray(-1) returns the observation of the last aircraft.
        next_state = [self.getStateArray(bs.traf.id2idx(ac_name)).tolist() for ac_name in acs]

        bs.sim.RLresult(bs.stack.get_scenname(), self.RL_states[it], next_state, sum(rewards), self.RL_actions[it])

    def qdrdist_matrix_indices(self, ntraf):
        """Return int32 index arrays ``(ind1, ind2)`` of all pairs i < j among ``ntraf`` aircraft.

        Pairs are ordered row-major: (0, 1), (0, 2), ..., (1, 2), ...
        """
        ind1, ind2 = np.triu_indices(max(ntraf, 0), k=1)
        return ind1.astype(np.int32), ind2.astype(np.int32)

    def getDriftTarget(self, rewards, ac_idx):
        """Penalise the deviation of the actual track from the autopilot track.

        The angular difference is wrapped into [-180, 180), so 359 deg vs
        1 deg counts as 2 deg. The penalty is ``abs(difference) / 360``.
        """
        diff = (bs.traf.ap.trk[ac_idx] - bs.traf.trk[ac_idx] + 180.0) % 360.0 - 180.0
        return rewards - abs(diff) / 360

    def getNearMisses(self, rewards, ac_idx, conf_begin_action):
        """Subtract 1 for every distinct cruising conflict pair involving ``ac_idx``.

        Only pairs recorded since index ``conf_begin_action`` of
        ``cd.confpairs_all_cruising`` are counted.
        """
        new_conf = bs.traf.cd.confpairs_all_cruising[int(conf_begin_action):]
        for pair in {frozenset(p) for p in new_conf}:
            members = tuple(pair)
            if ac_idx in (bs.traf.id2idx(members[0]), bs.traf.id2idx(members[1])):
                rewards -= 1
        return rewards

    def getLoss(self, rewards, ac_idx, los_begin_action):
        """Subtract the number of distinct LoS pairs recorded since ``los_begin_action``.

        NOTE: every new LoS pair in the simulation is counted, not only pairs
        involving ``ac_idx`` (``ac_idx`` is currently unused).
        """
        start = int(los_begin_action)
        lospairs_all = bs.traf.cd.lospairs_all
        if start < len(lospairs_all):
            rewards -= len({frozenset(pair) for pair in lospairs_all[start:]})
        return rewards