from bluesky.tools.simtime import timed_function
import bluesky as bs
import numpy as np
from bluesky.tools import geo
from shapely.ops import unary_union
from shapely.geometry import Point
from shapely.geometry.polygon import Polygon, LineString
from bluesky.tools.aero import nm
from bluesky.traffic.asas.mvp import MVP


#STATE (closest # intruders)
# distance to CPA closest intruders 
# time to CPA closest intruders 
# relative bearing closest intruders
# current speed
# relative bearing to next waypoint
# 5*3 + 2 = 17

# action
# combinations for MVP

#reward (per cluster)
# Number of LoSs
# Drift from bearing to next waypoint
# If aircraft is in secondary altitude level

# Selects the branch taken in ConflictResolution.setDDPGaction: the MVP-style
# resolution when True, direct track/speed increments when False.
USE_MVP = False
# Number of closest intruders described in the state vector (see getStateArray).
MAX_INTRUDERS = 4
# NOTE(review): not referenced elsewhere in this file; presumably the number of
# values per aircraft action - confirm against the RL side.
ACTION_DIM = 2
# Three ownship values plus four values per intruder, i.e. the layout that
# getStateArray builds when MAX_INTRUDERS intruders are available.
STATE_DIM = 3 + 4*MAX_INTRUDERS
#from sklearn.cluster import KMeans
#from pandas import DataFrame

def init_plugin():
    """Create the plugin instance and return its registration data.

    Returns:
        tuple: ``(config, stackfunctions)``. ``config`` holds the plugin name,
        the plugin type and the bound ``update``/``reset`` callbacks;
        ``stackfunctions`` is an empty dict because this plugin defines no
        stack commands.
    """
    # Additional initialisation code.
    # The module-level name is rebound from the class to its single instance
    # (existing behaviour, kept for anything that looks the instance up by
    # this name). Only instantiate while the name still refers to the class,
    # so a repeated call reuses the instance instead of failing with
    # "object is not callable".
    global ConflictResolution
    if isinstance(ConflictResolution, type):
        ConflictResolution = ConflictResolution()
    resolver = ConflictResolution

    # Configuration parameters
    config = {
        # The name of your plugin
        'plugin_name': 'ConflictResolution',

        # The type of this plugin. For now, only simulation plugins are possible.
        'plugin_type': 'sim',

        # The update function is called after traffic is updated.
        'update': resolver.update,

        # The reset function
        'reset': resolver.reset
    }
    stackfunctions = {}

    # init_plugin() should always return these two dicts.
    return config, stackfunctions


class ConflictResolution():

    def __init__(self):       
        super(ConflictResolution, self).__init__()

    def reset(self):
        """Drop all reinforcement-learning bookkeeping and call ``bs.sim.RLsave``.

        ``RLsave`` receives the current scenario name; what it persists is
        defined on the simulation side and not visible from this plugin.
        """
        # Queues shared between update(), setDDPGaction() and setDDPGreward().
        self.DDPG_aircraft, self.DDPG_ac_done = [], []
        self.DDPG_states, self.DDPG_actions = [], []

        # Snapshots taken per aircraft at the moment an action is applied.
        self.conf_los_ac = np.array([])
        self.los_total_until_point = np.array([])
        self.conf_total_until_point = np.array([])

        scenario_name = bs.stack.get_scenname()
        bs.sim.RLsave(scenario_name)

    def getNormalized(self, bearing):
        if bearing > 180:
            bearing = -(360 - bearing)
        elif bearing < -180:
            bearing = 360 + bearing

        return bearing

    def getStateArray(self, ac_id):
        """Build the observation vector of one aircraft.

        Layout of the returned 1-D array::

            [own track, autopilot track, ground speed,
             distance          x MAX_INTRUDERS,
             distance at CPA   x MAX_INTRUDERS,
             time to CPA       x MAX_INTRUDERS,
             relative bearing  x MAX_INTRUDERS]

        The intruders are the MAX_INTRUDERS entries with the smallest value in
        ``bs.traf.cd.dist_all[ac_id]``. When fewer entries exist, every group
        is padded up to MAX_INTRUDERS (3000 for both distances, 600 for the
        time, 180 for the bearing) so that the vector always has
        ``3 + 4 * MAX_INTRUDERS`` elements.

        Args:
            ac_id: Index of the aircraft in the traffic arrays.

        Returns:
            numpy.ndarray: The state vector described above.
        """
        cd = bs.traf.cd
        closest_intruders = np.argsort(cd.dist_all[ac_id])[:MAX_INTRUDERS]

        dist = cd.dist_all[ac_id][closest_intruders]
        dist_cpa = cd.dcpa_all[ac_id][closest_intruders]
        time_cpa = cd.tcpa_all[ac_id][closest_intruders]

        # Bearing of each intruder relative to the own track, folded back by
        # one turn when it leaves [-180, 180].
        relative_bearing = cd.qdr_all[ac_id][closest_intruders] - bs.traf.trk[ac_id]
        relative_bearing = np.where(relative_bearing > 180, -(360 - relative_bearing), relative_bearing)
        relative_bearing = np.where(relative_bearing < -180, 360 + relative_bearing, relative_bearing)

        def pad(values, filler):
            # Fill up to MAX_INTRUDERS entries; previously only a single
            # filler was appended, which left the state too short whenever
            # more than one intruder slot was empty.
            missing = MAX_INTRUDERS - len(values)
            if missing <= 0:
                return values
            return np.append(values, np.full(missing, float(filler)))

        state = np.asarray(self.getNormalized(bs.traf.trk[ac_id]))
        state = np.append(state, self.getNormalized(bs.traf.ap.trk[ac_id]))
        state = np.append(state, bs.traf.gs[ac_id])
        state = np.append(state, pad(dist, 3000))
        state = np.append(state, pad(dist_cpa, 3000))
        state = np.append(state, pad(time_cpa, 600))
        state = np.append(state, pad(relative_bearing, 180))

        return state

    def getClusters(self):
        """Group all aircraft into position-based clusters of limited size.

        K-means is fitted on the aircraft latitude/longitude; each aircraft is
        then assigned to the nearest cluster centre that still holds fewer
        than NUMBER_ACTORS_MARL aircraft.

        NOTE(review): this method cannot run as the file stands:
        ``NUMBER_ACTORS_MARL`` is not defined in this file and the imports of
        ``KMeans`` (sklearn) and ``DataFrame`` (pandas) are commented out at
        the top of the module. Its only call site in ``update`` is commented
        out as well.

        Returns:
            tuple: ``(n_cluster, cluster_indexes)`` where ``cluster_indexes``
            is a list with one list of aircraft indices per cluster.
        """
        # Enough clusters so that every aircraft fits under the per-cluster cap.
        n_cluster = int(np.ceil(bs.traf.ntraf/NUMBER_ACTORS_MARL))
        ids = []
        lat = []
        lon = []
        for flight_idx in range(bs.traf.ntraf):
            ids.append(flight_idx)
            lat.append(bs.traf.lat[flight_idx])
            lon.append(bs.traf.lon[flight_idx])
        df = DataFrame( {'id': ids, 'lat': lat, 'lon': lon})
        kmeans = KMeans(n_clusters = n_cluster).fit(df[['lat', 'lon']])
        cluster_indexes = [ [] for _ in range(n_cluster)]
        # K-means does not limit the number of points per cluster, so the cap
        # is enforced here: walk the centres from nearest to farthest and take
        # the first one that is not full yet.
        for flight_idx in range(bs.traf.ntraf):
            distance_to_centers = [0] * n_cluster
            for idx_center in range(n_cluster):
                distance_to_centers[idx_center] = geo.kwikdist(bs.traf.lat[flight_idx], bs.traf.lon[flight_idx], kmeans.cluster_centers_[idx_center][0], kmeans.cluster_centers_[idx_center][1])
            picked_center = 0
            while len(cluster_indexes[np.argsort(distance_to_centers)[picked_center]]) >=NUMBER_ACTORS_MARL:
                picked_center +=1
            cluster_indexes[np.argsort(distance_to_centers)[picked_center]].append(flight_idx)

        return n_cluster, cluster_indexes

    @timed_function('ConflictResolution', dt=1.0)
    def update(self, dt):
        """Request actions for aircraft in conflict and settle finished decisions.

        Registered through ``timed_function`` with ``dt=1.0``.

        1. For all aircraft flagged in ``bs.traf.cd.inconf`` the callsigns are
           queued in ``DDPG_aircraft`` and their state vectors are passed to
           ``bs.sim.RLaction``. Clustering via ``getClusters`` is not used.
        2. Every decision waiting in ``DDPG_ac_done`` is rewarded through
           ``setDDPGreward`` and then removed together with its stored state
           and action.

        Args:
            dt: Time step handed in by the timer; not used by this method.
        """
        in_conflict = np.where(bs.traf.cd.inconf == True)[0]  # noqa: E712
        if len(in_conflict) > 0:
            self.DDPG_aircraft.append(np.asarray(bs.traf.id)[in_conflict].tolist())
            states = [self.getStateArray(ac_id).tolist() for ac_id in in_conflict]
            bs.sim.RLaction(bs.stack.get_scenname(), states)

        # Reward all pending decisions first ...
        n_pending = len(self.DDPG_ac_done)
        for slot in range(n_pending):
            self.setDDPGreward(slot, self.DDPG_ac_done[slot])

        # ... then drop them, back to front so the indices stay valid.
        for slot in reversed(range(n_pending)):
            self.DDPG_ac_done.pop(slot)
            self.DDPG_actions.pop(slot)
            self.DDPG_states.pop(slot)

    def getTLosAll(self):
        """Compute, for every aircraft pair, the time of entering conflict.

        Straight-line extrapolation of the current positions, tracks, ground
        speeds, altitudes and vertical speeds against a protected zone of
        radius ``bs.traf.cd.rpz`` and half-height ``bs.traf.cd.hpz``.

        Returns:
            numpy.ndarray: ``ntraf x ntraf`` matrix holding the later of the
            horizontal and the vertical entry time of each pair. Pairs without
            a horizontal conflict carry a horizontal entry time of 1e8.
            NOTE(review): the matching exit times are not evaluated here, so a
            small value alone does not prove that both windows overlap.
        """
        n_ac = bs.traf.ntraf
        own_mask = np.eye(n_ac)

        # Horizontal part -----------------------------------------------------

        # Entry [i, j] is the bearing/distance from i to j.
        lat = np.asmatrix(bs.traf.lat)
        lon = np.asmatrix(bs.traf.lon)
        bearing, separation = geo.kwikqdrdist_matrix(lat, lon, lat, lon)

        # Back to plain arrays for element-wise maths; distances are scaled by
        # nm and own/own pairs are pushed far away.
        bearing_rad = np.radians(np.asarray(bearing))
        separation = np.asarray(separation) * nm + 1e9 * own_mask

        # Position of j relative to i.
        rel_east = separation * np.sin(bearing_rad)
        rel_north = separation * np.cos(bearing_rad)

        # Velocity components; ownship and intruders are the same traffic, so
        # a single row vector per component is sufficient.
        track_rad = np.radians(bs.traf.trk)
        east_speed = bs.traf.gs * np.sin(track_rad).reshape((1, n_ac))
        north_speed = bs.traf.gs * np.cos(track_rad).reshape((1, n_ac))

        d_east = east_speed - east_speed.T
        d_north = north_speed - north_speed.T

        rel_speed_sq = d_east * d_east + d_north * d_north
        # Keep a lower bound so the divisions below stay finite.
        rel_speed_sq = np.where(np.abs(rel_speed_sq) < 1e-6, 1e-6, rel_speed_sq)
        rel_speed = np.sqrt(rel_speed_sq)

        t_cpa = -(d_east * rel_east + d_north * rel_north) / rel_speed_sq + 1e9 * own_mask

        # Squared distance at the closest point of approach.
        d_cpa_sq = np.abs(separation * separation - t_cpa * t_cpa * rel_speed_sq)

        rpz_sq = bs.traf.cd.rpz * bs.traf.cd.rpz
        hor_conflict = d_cpa_sq < rpz_sq

        # Half of the chord flown inside the zone and the time that takes.
        half_chord = np.sqrt(np.maximum(0., rpz_sq - d_cpa_sq))
        half_chord_time = half_chord / rel_speed

        t_in_hor = np.where(hor_conflict, t_cpa - half_chord_time, 1e8)

        # Vertical part -------------------------------------------------------

        alt_row = bs.traf.alt.reshape((1, n_ac))
        d_alt = alt_row - alt_row.T + 1e9 * own_mask

        vs_row = bs.traf.vs.reshape(1, n_ac)
        d_vs = vs_row - vs_row.T
        d_vs = np.where(np.abs(d_vs) < 1e-6, 1e-6, d_vs)  # avoid division by zero

        # Times of crossing the upper and lower face of the zone.
        t_cross_hi = (d_alt + bs.traf.cd.hpz) / -d_vs
        t_cross_lo = (d_alt - bs.traf.cd.hpz) / -d_vs
        t_in_ver = np.minimum(t_cross_hi, t_cross_lo)

        # Combined entry: both the horizontal and the vertical entry happened.
        return np.maximum(t_in_ver, t_in_hor)

    def setDDPGaction(self, action):
        """Store the answer of the RL agent and apply it to the aircraft.

        Args:
            action: Sequence whose first half is stored as the states and
                whose second half is stored as the actions of this decision.
                The actions are indexed per aircraft, in the order of the
                oldest callsign list queued by ``update``.

        Side effects:
            Moves the oldest entry of ``DDPG_aircraft`` to ``DDPG_ac_done``,
            records for each aircraft the current number of conflict and
            loss-of-separation pairs, and writes ``bs.traf.cr.trk``,
            ``bs.traf.cr.tas`` and ``bs.traf.cr.active``.
        """
        half = len(action) // 2
        self.DDPG_states.append(action[:half])
        self.DDPG_actions.append(action[half:])
        acs = self.DDPG_aircraft.pop(0)
        self.DDPG_ac_done.append(acs)

        vmin, vmax, vsmin, vsmax = bs.traf.perf.currentlimits()

        # Loop invariants: nothing below changes positions, speeds or the
        # pair lists, so these are evaluated once instead of per aircraft.
        # The pairwise matrix is only needed by the MVP branch.
        tLOS_all = self.getTLosAll() if USE_MVP else None
        # NOTE(review): getLoss/getNearMisses slice the "*_cruising" pair
        # lists with these offsets although they are taken from the "*_all"
        # lists - confirm that both lists are index-compatible.
        conf_total = len(bs.traf.cd.confpairs_all)
        los_total = len(bs.traf.cd.lospairs_all)

        actions = self.DDPG_actions[-1]
        for slot, ac_name in enumerate(acs):
            # Bookkeeping is written for every callsign, also for aircraft
            # that have left the simulation, because setDDPGreward looks up
            # each callsign of the decision.
            self.conf_los_ac = np.append(self.conf_los_ac, ac_name)
            self.conf_total_until_point = np.append(self.conf_total_until_point, conf_total)
            self.los_total_until_point = np.append(self.los_total_until_point, los_total)

            ac_idx = bs.traf.id2idx(ac_name)
            if ac_idx < 0:
                # Assumes id2idx reports an unknown callsign as a negative
                # index (upstream BlueSky returns -1) - confirm for this
                # fork. Without this guard the command would be applied to
                # the last aircraft in the arrays.
                continue

            if USE_MVP:
                self._applyMvpAction(ac_idx, actions[slot], tLOS_all[ac_idx])
            else:
                self._applyDirectAction(ac_idx, actions[slot], vmin, vmax)

    def _applyDirectAction(self, ac_idx, ac_action, vmin, vmax):
        """Apply an action as increments on the commanded track and speed."""
        cr = bs.traf.cr
        # First value: track change in steps of 15, wrapped to [0, 360).
        cr.trk[ac_idx] += ac_action[0] * 15
        cr.trk[ac_idx] = cr.trk[ac_idx] % 360
        # Second value: speed change in steps of 5, limited to the current
        # performance limits of this aircraft.
        cr.tas[ac_idx] += ac_action[1] * 5
        cr.tas[ac_idx] = min(max(cr.tas[ac_idx], vmin[ac_idx]), vmax[ac_idx])
        cr.active[ac_idx] = True

    def _applyMvpAction(self, ac_idx, ac_action, tlos_row):
        """Apply an action that parameterises an MVP-style resolution.

        ``ac_action[0]`` scales the look-ahead time, ``ac_action[1] >= 0``
        enables a heading change and ``ac_action[2] >= 0`` a speed change.
        Nothing is commanded when no intruder enters conflict within the
        look-ahead time or when both changes are disabled.
        """
        lookahead = bs.traf.cd.dtlookahead
        look_ahead_time = min((ac_action[0] + 1) * lookahead / 2, lookahead)
        heading_solution = ac_action[1] >= 0
        speed_solution = ac_action[2] >= 0
        intruders_idx = np.where(tlos_row < look_ahead_time)[0]

        if len(intruders_idx) == 0 or not (heading_solution or speed_solution):
            return

        # Sum the avoidance velocity over all intruders.
        dv = np.zeros(2)
        for idx2 in intruders_idx:
            dv = dv - self._mvpPairDv(ac_idx, idx2, tlos_row[idx2])

        newv = np.array([bs.traf.gseast[ac_idx], bs.traf.gsnorth[ac_idx]]) + dv
        new_trk = (np.arctan2(newv[0], newv[1]) * 180 / np.pi) % 360
        new_spd = np.sqrt(newv[0] ** 2 + newv[1] ** 2)

        # A disabled component keeps the current value of the aircraft.
        cr = bs.traf.cr
        cr.trk[ac_idx] = new_trk if heading_solution else bs.traf.trk[ac_idx]
        cr.tas[ac_idx] = new_spd if speed_solution else bs.traf.tas[ac_idx]
        cr.active[ac_idx] = True

    def _mvpPairDv(self, ac_idx, idx2, tcpa):
        """Return the horizontal MVP velocity change for one intruder.

        NOTE(review): ``tcpa`` is taken from the getTLosAll matrix, i.e. it is
        the time of entering conflict rather than the time of closest
        approach - confirm that this is intended. A value of zero leads to a
        division by zero below.
        """
        asas_mar = 1.05
        rpz_m = bs.traf.cd.rpz * asas_mar

        # Relative position vector from the ownship to the intruder.
        qdr, dist = geo.kwikqdrdist(bs.traf.lat[ac_idx], bs.traf.lon[ac_idx],
                                    bs.traf.lat[idx2], bs.traf.lon[idx2])
        qdr_radians = np.deg2rad(qdr)
        dist = dist * nm
        drel = np.array([np.sin(qdr_radians) * dist,
                         np.cos(qdr_radians) * dist,
                         bs.traf.alt[idx2] - bs.traf.alt[ac_idx]])

        # Velocities as vectors and the relative velocity vector.
        v1 = np.array([bs.traf.gseast[ac_idx], bs.traf.gsnorth[ac_idx], bs.traf.vs[ac_idx]])
        v2 = np.array([bs.traf.gseast[idx2], bs.traf.gsnorth[idx2], bs.traf.vs[idx2]])
        vrel = np.array(v2 - v1)

        # Relative position and horizontal distance at time tcpa.
        dcpa = drel + vrel * tcpa
        dabsH = np.sqrt(dcpa[0] * dcpa[0] + dcpa[1] * dcpa[1])

        # Horizontal intrusion into the (enlarged) protected zone.
        iH = rpz_m - dabsH

        # Head-on case: pick a perpendicular direction to prevent a division
        # by zero in the next step.
        if dabsH <= 10.:
            dabsH = 10.
            dcpa[0] = drel[1] / dist * dabsH
            dcpa[1] = -drel[0] / dist * dabsH

        # abs(tcpa) because the time becomes negative during an intrusion.
        if rpz_m < dist and dabsH < dist:
            # Intruder outside the protected zone: extra factor so that the
            # resolution does not graze the zone.
            erratum = np.cos(np.arcsin(rpz_m / dist) - np.arcsin(dabsH / dist))
            dv1 = ((rpz_m / erratum - dabsH) * dcpa[0]) / (abs(tcpa) * dabsH)
            dv2 = ((rpz_m / erratum - dabsH) * dcpa[1]) / (abs(tcpa) * dabsH)
        else:
            dv1 = (iH * dcpa[0]) / (abs(tcpa) * dabsH)
            dv2 = (iH * dcpa[1]) / (abs(tcpa) * dabsH)

        return np.array([dv1, dv2])

    def checkCRsolutionValid(self, asase,heading_solution, speed_solution, speed_dif):
        if not heading_solution: # speed only 
            return True
        if asase is not None: # we have a solution
            if not speed_solution: # heading only
                if speed_dif < 1: # 2m/s
                    return True
            else:
                return True
        return False

    def findLineSolution(self, hdg, vmax, all_VOs, speed_now):
        """Find the velocity along one heading that lies outside ``all_VOs``.

        A line from the origin to ``vmax`` in direction ``hdg`` is cut by
        ``all_VOs``; of the remaining pieces the point that needs the
        smallest change from ``speed_now`` is returned.

        Args:
            hdg: Heading in radians.
            vmax: Maximum speed, i.e. the length of the probing line.
            all_VOs: Shapely geometry to be avoided.
            speed_now: Current speed.

        Returns:
            tuple: ``(east, north, speed_difference)`` of the best point, or
            ``(None, None, None)`` when the whole line is covered or the
            geometric difference could not be computed.
        """
        sin_hdg, cos_hdg = np.sin(hdg), np.cos(hdg)
        line = LineString([[0, 0], [sin_hdg * vmax, cos_hdg * vmax]])

        try:
            free_part = line.difference(all_VOs)
        except Exception:
            # Best effort as before: report and continue without a solution.
            # Not a bare except, so KeyboardInterrupt/SystemExit pass through.
            print("shapelly could not find difference", all_VOs)
            return None, None, None

        # A multi-part result exposes its pieces through ``.geoms``; with
        # Shapely 2 such a result supports neither len() nor indexing. A
        # single LineString has no ``.geoms`` and is wrapped in a list.
        pieces = getattr(free_part, 'geoms', None)
        if pieces is None:
            pieces = [free_part]

        best = None  # (speed difference, east, north)
        for piece in pieces:
            if piece.is_empty or piece.geom_type != 'LineString':
                continue
            xs, ys = piece.xy
            speed_begin_line = np.sqrt(xs[0] ** 2 + ys[0] ** 2)
            speed_end_line = np.sqrt(xs[-1] ** 2 + ys[-1] ** 2)
            if speed_begin_line > speed_now:
                # Whole piece is faster than now: its start is the closest.
                candidate = (abs(speed_now - speed_begin_line), xs[0], ys[0])
            elif speed_end_line < speed_now:
                # Whole piece is slower than now: its end is the closest.
                candidate = (abs(speed_end_line - speed_now), xs[-1], ys[-1])
            else:
                # The current speed itself is free on this heading.
                candidate = (0.0, speed_now * sin_hdg, speed_now * cos_hdg)
            # Strict comparison keeps the first of several equal candidates.
            if best is None or candidate[0] < best[0]:
                best = candidate

        if best is None:
            return None, None, None
        return best[1], best[2], best[0]

    def setDDPGreward(self, it, acs):
        """Compute the reward of one finished decision and report it.

        Args:
            it: Index of the decision in ``DDPG_states``/``DDPG_actions``.
            acs: Callsigns of the aircraft that took part in the decision.

        Side effects:
            Removes the bookkeeping entry of every aircraft in ``acs`` and
            calls ``bs.sim.RLresult`` with the stored state, the new state of
            each aircraft, the summed reward and the stored action.
        """
        rewards = []
        for ac_name in acs:
            ac_id = bs.traf.id2idx(ac_name)
            matches = np.where(self.conf_los_ac == ac_name)[0]
            if matches.size == 0:
                # No snapshot was recorded for this aircraft, so no penalty
                # can be attributed to it.
                rewards.append(0)
                continue

            # Use the oldest snapshot only and pass it as a scalar; an
            # aircraft may own several snapshots when decisions overlap.
            oldest = matches[0]
            reward = self.getLoss(0, ac_id, self.los_total_until_point[oldest])
            #reward = self.getDriftTarget(reward, ac_id)
            rewards.append(reward)

            # Drop the consumed snapshot.
            self.los_total_until_point = np.delete(self.los_total_until_point, oldest)
            self.conf_total_until_point = np.delete(self.conf_total_until_point, oldest)
            self.conf_los_ac = np.delete(self.conf_los_ac, oldest)

        state = self.DDPG_states[it]
        action = self.DDPG_actions[it]

        # Look up each aircraft on its own; the index of the previous loop
        # must not be reused, otherwise every aircraft would report the state
        # of the last one.
        # NOTE(review): an aircraft that no longer exists is still passed to
        # getStateArray with whatever index id2idx returns - decide how a
        # terminal state should be represented.
        next_state = [
            self.getStateArray(bs.traf.id2idx(ac_name)).tolist()
            for ac_name in acs
        ]

        bs.sim.RLresult(bs.stack.get_scenname(), state, next_state, sum(rewards), action)
           
    def qdrdist_matrix_indices(self, ntraf):
        tmp_range = np.arange(ntraf - 1, dtype=np.int32)
        ind1 = np.repeat(tmp_range, (tmp_range + 1)[::-1])
        ind2 = np.ones(ind1.shape[0], dtype=np.int32)
        inds = np.cumsum(tmp_range[1:][::-1] + 1)
        np.put(ind2, inds, np.arange(ntraf * -1 + 3, 1))
        ind2 = np.cumsum(ind2, out=ind2)
        return ind1, ind2

    def getDriftTarget(self, rewards, ac_idx):
        """Penalise the deviation between the autopilot track and the track.

        Args:
            rewards: Reward accumulated so far.
            ac_idx: Index of the aircraft in the traffic arrays.

        Returns:
            ``rewards`` reduced by the smallest angle between the two tracks
            divided by 360, i.e. by at most 0.5.
        """
        # Take the shorter way around the circle: 350 and 10 are 20 apart,
        # not 340.
        difference = abs(bs.traf.ap.trk[ac_idx] - bs.traf.trk[ac_idx]) % 360
        drift = min(difference, 360 - difference) / 360
        rewards -= drift
        return rewards

    def getNearMisses(self, rewards, ac_idx, conf_begin_action):
        """Penalise every new conflict pair that involves the given aircraft.

        Args:
            rewards: Reward accumulated so far.
            ac_idx: Index of the aircraft in the traffic arrays.
            conf_begin_action: Offset into
                ``bs.traf.cd.confpairs_all_cruising``; only entries from this
                position on are considered.

        Returns:
            ``rewards`` reduced by one for each distinct (unordered) pair in
            which one of the two callsigns resolves to ``ac_idx``.
        """
        recent_pairs = bs.traf.cd.confpairs_all_cruising[int(conf_begin_action):]
        # (A, B) and (B, A) describe the same conflict.
        distinct_pairs = [tuple(members) for members in {frozenset(p) for p in recent_pairs}]
        involved = sum(
            1
            for pair in distinct_pairs
            if ac_idx in (bs.traf.id2idx(pair[0]), bs.traf.id2idx(pair[1]))
        )
        rewards -= involved
        return rewards

    def getLoss(self, rewards, ac_idx, los_begin_action):
        """Penalise every new loss of separation since an action was taken.

        Args:
            rewards: Reward accumulated so far.
            ac_idx: Not used; all new pairs count, whichever aircraft is
                involved.
            los_begin_action: Offset into
                ``bs.traf.cd.lospairs_all_cruising``, given as a scalar or as
                an array whose first element is used.

        Returns:
            ``rewards`` reduced by the number of distinct (unordered) pairs
            found from the offset on.
        """
        # Callers may hand in a one-element array; converting an array with
        # ndim > 0 straight to int is deprecated in NumPy, so take the first
        # element explicitly.
        start = int(np.asarray(los_begin_action).reshape(-1)[0])
        new_los = bs.traf.cd.lospairs_all_cruising[start:]
        # (A, B) and (B, A) describe the same loss of separation.
        lospairs_unique = {frozenset(pair) for pair in new_los}
        rewards -= len(lospairs_unique)
        return rewards