
import gym
from gym import spaces
import numpy as np
import stable_baselines
from stable_baselines import PPO2
from stable_baselines.common.vec_env import DummyVecEnv
import math
import random
import matplotlib.pyplot as plt

# Simulation parameters which will be used to build the classes.
# The vehicle classes below (leader, av) read n, dt and tauA directly from module scope.
T = 50.1            # simulation time [s]
dt = 0.1          # simulation time step [s]
n = int(T / dt)        # number of simulation steps; also the length of each vehicle's history lists
tauA = 0.2     # actuator lag, representing driveline dynamics [s]

# class of the leading vehicle
# Position, speed, and acceleration profiles are recorded in the vehicle object it self.
class leader():
    """Leading vehicle that keeps its own trajectory history.

    ``speed``, ``pos`` and ``a`` are lists of length ``n`` (module-level
    number of simulation steps); index ``k`` holds the value at step ``k``.
    Only the initial speed and position are set at construction, every other
    slot starts at 0.
    """

    def __init__(self, speed0, pos0):
        # One history list per quantity, pre-filled with zeros
        self.speed = [0 for _ in range(n)]
        self.pos = [0 for _ in range(n)]
        self.a = [0 for _ in range(n)]
        # Initial conditions live in slot 0
        self.speed[0] = speed0
        self.pos[0] = pos0

    def getSpeed(self, k):
        """Return the speed stored for step ``k``."""
        return self.speed[k]

    def getPos(self, k):
        """Return the position stored for step ``k``."""
        return self.pos[k]

    def getA(self, k):
        """Return the acceleration stored for step ``k``."""
        return self.a[k]

    def updatePos(self, k):
        """Fill position ``k`` from the step ``k-1`` position, speed and acceleration."""
        last = k - 1
        from_speed = self.speed[last] * dt
        from_accel = 0.5 * self.a[last] * math.pow(dt, 2)
        self.pos[k] = self.pos[last] + from_speed + from_accel

    def updateSpeed(self, k):
        """Fill speed ``k`` from the step ``k-1`` speed and acceleration."""
        last = k - 1
        self.speed[k] = self.speed[last] + self.a[last] * dt

    def setA(self, u, k):
        """Store ``u`` as the acceleration of step ``k`` (applied without lag)."""
        self.a[k] = u

# class of autonomous vehicles
class av():
    """Following (autonomous) vehicle that keeps its own trajectory history.

    ``speed``, ``pos`` and ``a`` are lists of length ``n`` (module-level
    number of simulation steps).  ``predecessor1`` is the vehicle object
    passed in as ``lead1``; this class only stores it.
    """

    def __init__(self, lead1, speed0, pos0):
        # One history list per quantity, pre-filled with zeros
        self.speed = [0 for _ in range(n)]
        self.pos = [0 for _ in range(n)]
        self.a = [0 for _ in range(n)]
        # Initial conditions live in slot 0
        self.speed[0] = speed0
        self.pos[0] = pos0
        self.predecessor1 = lead1

    def getSpeed(self, k):
        """Return the speed stored for step ``k``."""
        return self.speed[k]

    def getPos(self, k):
        """Return the position stored for step ``k``."""
        return self.pos[k]

    def getA(self, k):
        """Return the acceleration stored for step ``k``."""
        return self.a[k]

    def updatePos(self, k):
        """Fill position ``k`` from the step ``k-1`` position, speed and acceleration."""
        last = k - 1
        from_speed = self.speed[last] * dt
        from_accel = 0.5 * self.a[last] * math.pow(dt, 2)
        self.pos[k] = self.pos[last] + from_speed + from_accel

    def updateSpeed(self, k):
        """Fill speed ``k`` from the step ``k-1`` speed and acceleration."""
        last = k - 1
        self.speed[k] = self.speed[last] + self.a[last] * dt

    def updateA(self, u, k):
        """Move acceleration ``k`` from the step ``k-1`` value toward command ``u``.

        The change is scaled by ``dt / tauA`` (module-level actuator lag).
        Note that for ``k == 0`` the "previous" slot is ``a[-1]``, i.e. the
        last list entry.
        """
        prev_a = self.a[k - 1]
        self.a[k] = prev_a + (u - prev_a) * dt / tauA

# class of the customized simulation environment for the controller 1 that follows gym interface
class two_ACC1_LSTM(gym.Env):
    """Platoon test environment for controller 1.

    One ``leader`` follows a fixed acceleration profile and ``n_f`` ``av``
    followers each track the vehicle directly ahead of them.  The class
    exposes the multi-vehicle test entry points ``reset_test_multi`` and
    ``step_test_multi`` plus ``render``.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self):
        super(two_ACC1_LSTM, self).__init__()
        # Define the action space
        self.u_max = 3    # maximum acceleration [m/s^2]
        self.u_min = -6     # minimum acceleration [m/s^2]
        self.action_space = spaces.Box(low=self.u_min, high=self.u_max, shape=(1,), dtype='float32')

        # Define the observation space:
        # [distance gap, ego-vehicle speed, relative speed, jerk, previous acceleration]
        vlow = np.array([0,0,-40,0,-6],dtype = 'f')
        vhigh = np.array([500,40,40,10,3],dtype = 'f')
        self.observation_space = spaces.Box(low=vlow, high=vhigh, shape=(5,))

        self.done = False

        # Define simulation parameters
        self.T = 50.1            # simulation time [s]
        self.dt = 0.1          # simulation time step [s]
        self.n = int(self.T / self.dt)        # number of simulation steps
        # Define system parameters
        self.min_gap = 2            # minimum distance gap at standstill [m]
        self.vlength = 4        # vehicle length [m]
        self.td = 1          # desired time gap [s]
        self.tauA = 0.2     # actuator lag, representing driveline dynamics [s]
        self.tauS = 0.2     # delay in sensor [s]

    @staticmethod
    def _time_gap(distance, speed):
        """Return ``distance / speed``; 0.0 when ``speed`` is zero (gap undefined at standstill)."""
        if speed == 0:
            return 0.0
        return distance / speed

    # Reset all the required variables and initialize every setting to start optimization again
    def reset_test_multi(self):
        """Create a fresh leader and follower platoon and return the initial observations.

        Returns:
            ``self.observedState``: float array with one row per follower and
            the five observation features as columns.
        """
        self.done = False
        self.current_step = 0

        # Initialize the leader
        self.n_l = 1
        self.l = [None] * self.n_l
        v0_leader = 33
        pos0_leader = 800
        self.l[0] = leader(v0_leader,pos0_leader)

        # Initialize the followers; each one starts s0 + vehicle length behind its predecessor
        self.n_f = 19
        gap0 = 1
        v0_follower = 33
        s0 = v0_follower * gap0 + self.min_gap
        self.f = [None] * self.n_f
        self.u = [0] * self.n_f
        self.f[0] = av(self.l[0],v0_follower,self.l[0].getPos(0)-s0-self.vlength)
        for v in range(1,self.n_f):
            self.f[v] = av(self.f[v-1],v0_follower,self.f[v-1].getPos(0)-s0-self.vlength)

        # Every follower starts from the same observation row:
        # [distance gap, speed, relative speed, jerk, previous acceleration]
        first_row = np.array([[v0_follower * gap0, v0_follower, v0_leader - v0_follower, 0, 0]], dtype=float)
        self.observedState = np.tile(first_row, (self.n_f, 1))

        return(self.observedState)

    # run one time step and make a decision
    def step_test_multi(self, action):
        """Advance the leader and all followers by one time step.

        Args:
            action: sequence with one entry per follower; each entry must be
                contained in ``self.action_space`` and its element 0 is the
                commanded acceleration.

        Returns:
            Tuple ``(observedState, reward, done, info)`` where ``reward`` is
            a list with one value per follower and ``info`` is an empty dict.

        Raises:
            AssertionError: if an entry of ``action`` is outside the action space.
        """
        self.done = False

        # Receive the actions given by the agents
        # Assert that they are valid actions
        for a in action:
            assert self.action_space.contains(a), "Invalid Action"

        k = self.current_step
        # Calculate the num of time steps of sensor delay
        sensor_delay = int(self.tauS / self.dt)
        sensing_active = k >= sensor_delay

        # Update leader acceleration: brake, then speed up, otherwise cruise
        if 30 < k <= 70:
            leader_acc = -3
        elif 120 < k <= 200:
            leader_acc = 1.5
        else:
            leader_acc = 0
        self.l[0].setA(leader_acc, k)

        # Apply leader behavior
        self.l[0].updateSpeed(k + 1)
        self.l[0].updatePos(k + 1)

        self.u = [0] * self.n_f
        reward = [0] * self.n_f
        g_previous = [0] * self.n_f
        g_now = [0] * self.n_f
        d_previous = [0] * self.n_f
        d_now = [0] * self.n_f
        lspeed_now = [0] * self.n_f

        # Loop-invariant noise levels and reward weights
        std_level_d = 0.2
        std_level_s = 0.2
        max_es = self.td * 0.5
        max_jerk = (self.u_max - self.u_min) / 3
        w1 = 0.9
        w2 = 0.1

        # Apply followers actions
        for v in range(self.n_f):
            follower = self.f[v]

            # Commands are ignored until the first delayed measurement is available
            self.u[v] = action[v][0] if sensing_active else 0

            # Update follower acceleration, then position and speed
            follower.updateA(self.u[v], k)
            follower.updatePos(k + 1)
            follower.updateSpeed(k + 1)

            # Noise is drawn every step (even while sensing is inactive) so the
            # random stream does not depend on the sensor delay
            d_error = np.random.normal(0, std_level_d)
            s_error = np.random.normal(0, std_level_s)

            if sensing_active:
                ahead = follower.predecessor1
                delayed = k - sensor_delay
                d_previous[v] = (ahead.getPos(delayed) - follower.getPos(delayed) - self.vlength - self.min_gap)
                d_now[v] = (ahead.getPos(delayed + 1) - follower.getPos(delayed + 1) - self.vlength - self.min_gap)
                lspeed_now[v] = ahead.getSpeed(delayed + 1)
                self.observedState[v] = np.array([d_now[v] + d_error,
                                        follower.getSpeed(k + 1),
                                        (lspeed_now[v] + s_error) - follower.getSpeed(delayed + 1),
                                        follower.getA(k) - follower.getA(k - 1),
                                        follower.getA(k - 1)])

            # Time gaps for the reward.  Previously these stayed at 0 in this
            # class, which made the spacing term of the reward a constant.
            g_previous[v] = self._time_gap(d_previous[v], follower.getSpeed(k))
            g_now[v] = self._time_gap(d_now[v], follower.getSpeed(k + 1))

            reward[v] = w1 * (-1) * abs(g_now[v] - self.td) / max_es + w2 * (-1) * abs(self.observedState[v][3]) / max_jerk

            # wrong behavior (gap error grew), negative reward
            if abs(g_previous[v]-self.td) < abs(g_now[v]-self.td):
                reward[v] = reward[v] + (abs(g_previous[v]-self.td) - abs(g_now[v]-self.td)) / max_es

        self.current_step += 1

        # The history lists hold self.n samples, so the last writable index is self.n - 1
        if self.current_step >= self.n - 1:
            self.done = True

        return(self.observedState, reward, self.done, {})

    # Function to plot the simulation result of an episode
    def render(self, mode='human', close=False):
        """Plot acceleration, speed and time-gap profiles of the leader and first followers.

        ``mode`` and ``close`` are accepted for gym compatibility and not used.
        """
        render_num = 5  # Number of followers to show in the plot
        steps = range(self.n)
        t_axis = [round(i * self.dt, 1) for i in steps]

        # plot results (acceleration, speed, time gap profiles)
        fig, axs = plt.subplots(4,figsize=(10,8),sharex=True)
        linecolor = ['tab:blue','tab:orange','tab:green','tab:red','tab:purple','tab:brown','tab:pink']

        def gap_series(vehicle, ahead, count):
            # Time gap to a vehicle `count` positions ahead, net of vehicle lengths and standstill gaps
            return [(ahead.getPos(i) - vehicle.getPos(i) - self.vlength*count - self.min_gap*count) / vehicle.getSpeed(i)
                    for i in steps]

        for j in range(self.n_l):
            axs[0].plot(t_axis,[self.l[j].getA(i) for i in steps],'--',color=linecolor[j])
            axs[1].plot(t_axis,[self.l[j].getSpeed(i) for i in steps],'--',color=linecolor[j])

        for j in range(render_num):
            veh = self.f[j]
            color = linecolor[j+1]
            axs[0].plot(t_axis,[veh.getA(i) for i in steps],color=color)
            axs[1].plot(t_axis,[veh.getSpeed(i) for i in steps],color=color)
            axs[2].plot(t_axis,gap_series(veh, veh.predecessor1, 1),color=color)
            if j >= 1:
                # The first follower has no vehicle two positions ahead
                axs[3].plot(t_axis,gap_series(veh, veh.predecessor1.predecessor1, 2),color=color)

        axs[0].set(xlabel='time (s)',ylabel='acceleration (m/s\u00b2)')
        axs[0].set_ylim([-6, 3])
        axs[1].set(xlabel='time (s)',ylabel='velocity (m/s)')
        axs[1].set_ylim([15, 40])
        axs[2].set(xlabel='time (s)',ylabel='gap 1 (s)')
        axs[2].set_ylim([0.85, 1.15])
        axs[3].set(xlabel='time (s)',ylabel='gap 2 (s)')
        axs[3].set_ylim([1.8, 2.2])

        # One label per plotted line, in plotting order
        labels = ['leader{}'.format(j + 1) for j in range(self.n_l)]
        labels += ['follower{}'.format(j + 1) for j in range(render_num)]
        axs[0].legend(labels,loc='lower right',fontsize='medium', ncol=2)

        fig.tight_layout()

        plt.show()

# class of the customized simulation environment for the controller 2 that follows gym interface
class two_ACC2_LSTM(gym.Env):
    """Platoon test environment for controller 2.

    Each follower's ``predecessor1`` is the vehicle two positions ahead:
    ``f[v]`` tracks ``f[v-2]``, ``f[1]`` tracks the real leader ``l[0]`` and
    ``f[0]`` tracks a separate virtual leader ``lv``.  The class exposes the
    multi-vehicle test entry points ``reset_test_multi`` and
    ``step_test_multi``.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self):
        super(two_ACC2_LSTM, self).__init__()
        # Define the action space
        self.u_max = 3    # maximum acceleration [m/s^2]
        self.u_min = -6     # minimum acceleration [m/s^2]
        self.action_space = spaces.Box(low=self.u_min, high=self.u_max, shape=(1,), dtype='float32')

        # Define the observation space:
        # [distance gap, ego-vehicle speed, relative speed, jerk, previous acceleration]
        vlow = np.array([0,0,-40,0,-6],dtype = 'f')
        vhigh = np.array([500,40,40,10,3],dtype = 'f')
        self.observation_space = spaces.Box(low=vlow, high=vhigh, shape=(5,))

        self.done = False

        # Define simulation parameters
        self.T = 50.1            # simulation time [s]
        self.dt = 0.1          # simulation time step [s]
        self.n = int(self.T / self.dt)        # number of simulation steps
        # Define system parameters
        self.min_gap = 2            # minimum distance gap at standstill [m]
        self.vlength = 4        # vehicle length [m]
        self.td = 2          # desired time gap [s]
        self.tauA = 0.2     # actuator lag, representing driveline dynamics [s]
        self.tauS = 0.2     # delay in sensor [s]

    @staticmethod
    def _time_gap(distance, speed):
        """Return ``distance / speed``; 0.0 when ``speed`` is zero (gap undefined at standstill)."""
        if speed == 0:
            return 0.0
        return distance / speed

    # Reset all the required variables and initialize every setting to start optimization again
    def reset_test_multi(self):
        """Create fresh leaders and followers and return the initial observations.

        Returns:
            ``self.observedState``: float array with one row per follower and
            the five observation features as columns.
        """
        self.done = False
        self.current_step = 0

        # Initialize the virtual leader
        v0_leaderv = 40
        pos0_leaderv = 900
        self.lv = leader(v0_leaderv,pos0_leaderv)

        # Initialize the leader
        self.n_l = 1    # Number of the leaders
        self.l = [None] * self.n_l
        v0_leader = 33  # initial speed of the leader
        pos0_leader = 800  # initial position of the leader
        self.l[0] = leader(v0_leader,pos0_leader)

        # Initialize the followers
        self.n_f = 19
        gap0 = 1
        v0_follower = 33
        s0 = v0_follower * gap0 + self.min_gap
        self.f = [None] * self.n_f
        self.u = [0] * self.n_f

        # f[0] tracks the virtual leader but is positioned behind the real leader;
        # f[1] tracks the real leader; every later follower tracks f[v-2] and is
        # positioned behind f[v-1].
        self.f[0] = av(self.lv,v0_follower,self.l[0].getPos(0)-s0-self.vlength)
        self.f[1] = av(self.l[0],v0_follower,self.f[0].getPos(0)-s0-self.vlength)
        for v in range(2,self.n_f):
            self.f[v] = av(self.f[v-2],v0_follower,self.f[v-1].getPos(0)-s0-self.vlength)

        # Every follower starts from the same observation row:
        # [distance gap, speed, relative speed, jerk, previous acceleration]
        first_row = np.array([[v0_follower * (gap0*2), v0_follower, v0_leader - v0_follower, 0, 0]], dtype=float)
        self.observedState = np.tile(first_row, (self.n_f, 1))

        return(self.observedState)

    # run one time step and make a decision
    def step_test_multi(self, action):
        """Advance both leaders and all followers by one time step.

        Args:
            action: sequence with one entry per follower; each entry must be
                contained in ``self.action_space`` and its element 0 is the
                commanded acceleration.

        Returns:
            Tuple ``(observedState, reward, done, info)`` where ``reward`` is
            a list with one value per follower and ``info`` is an empty dict.

        Raises:
            AssertionError: if an entry of ``action`` is outside the action space.
        """
        self.done = False

        # Receive the actions given by the agents
        # Assert that they are valid actions
        for a in action:
            assert self.action_space.contains(a), "Invalid Action"

        k = self.current_step
        # Calculate the num of time steps of sensor delay
        sensor_delay = int(self.tauS / self.dt)
        sensing_active = k >= sensor_delay

        # Virtual leader: zero acceleration at every step
        self.lv.setA(0, k)
        self.lv.updateSpeed(k + 1)
        self.lv.updatePos(k + 1)

        # Update leader acceleration: brake, then speed up, otherwise cruise
        if 30 < k <= 70:
            leader_acc = -3
        elif 120 < k <= 200:
            leader_acc = 1.5
        else:
            leader_acc = 0
        self.l[0].setA(leader_acc, k)

        # Apply leader behavior
        self.l[0].updateSpeed(k + 1)
        self.l[0].updatePos(k + 1)

        # Initialize the variables which will be used
        self.u = [0] * self.n_f
        reward = [0] * self.n_f
        d_previous = [0] * self.n_f
        d_now = [0] * self.n_f
        lspeed_now = [0] * self.n_f
        g_previous = [0] * self.n_f
        g_now = [0] * self.n_f

        # Loop-invariant reward weights
        max_es = self.td * 0.5
        max_jerk = (self.u_max - self.u_min) / 3
        w1 = 0.9
        w2 = 0.1

        # Apply followers actions
        for v in range(self.n_f):
            follower = self.f[v]

            # Commands are ignored until the first delayed measurement is available
            self.u[v] = action[v][0] if sensing_active else 0

            # Update follower acceleration, then position and speed
            follower.updateA(self.u[v], k)
            follower.updatePos(k + 1)
            follower.updateSpeed(k + 1)

            # The first follower gets noise-free measurements
            if v == 0:
                std_level_d = 0
                std_level_s = 0
            else:
                std_level_d = 0.5
                std_level_s = 0.5
            # Noise is drawn every step (even while sensing is inactive) so the
            # random stream does not depend on the sensor delay
            d_error = np.random.normal(0, std_level_d)
            s_error = np.random.normal(0, std_level_s)

            if sensing_active:
                ahead = follower.predecessor1
                delayed = k - sensor_delay
                # Two vehicle lengths and two standstill gaps lie between a
                # follower and the vehicle two positions ahead
                d_previous[v] = (ahead.getPos(delayed) - follower.getPos(delayed) - self.vlength*2 - self.min_gap*2)
                d_now[v] = (ahead.getPos(delayed + 1) - follower.getPos(delayed + 1) - self.vlength*2 - self.min_gap*2)
                lspeed_now[v] = ahead.getSpeed(delayed + 1)
                self.observedState[v] = np.array([d_now[v] + d_error,
                                        follower.getSpeed(k + 1),
                                        (lspeed_now[v] + s_error) - follower.getSpeed(delayed + 1),
                                        follower.getA(k) - follower.getA(k - 1),
                                        follower.getA(k - 1)])

            # Calculate rewards (not important for simulation testing).
            # The helper avoids dividing by a zero speed.
            g_previous[v] = self._time_gap(d_previous[v], follower.getSpeed(k))
            g_now[v] = self._time_gap(d_now[v], follower.getSpeed(k + 1))

            reward[v] = w1 * (-1) * abs(g_now[v] - self.td) / max_es + w2 * (-1) * abs(self.observedState[v][3]) / max_jerk

            # wrong behavior (gap error grew), negative reward
            if abs(g_previous[v]-self.td) < abs(g_now[v]-self.td):
                reward[v] = reward[v] + (abs(g_previous[v]-self.td) - abs(g_now[v]-self.td)) / max_es

        self.current_step += 1

        # The history lists hold self.n samples, so the last writable index is self.n - 1
        if self.current_step >= self.n - 1:
            self.done = True

        return(self.observedState, reward, self.done, {})

# Create and load the models
def make_env1():
    """Return a zero-argument factory that builds a new ``two_ACC1_LSTM`` environment."""
    def _init():
        # The environment is only constructed when the factory is called
        return two_ACC1_LSTM()
    return _init

def make_env2():
    """Return a zero-argument factory that builds a new ``two_ACC2_LSTM`` environment."""
    def _init():
        # The environment is only constructed when the factory is called
        return two_ACC2_LSTM()
    return _init

# Wrap each environment in a single-env vectorized wrapper and load its trained policy
env1 = DummyVecEnv([make_env1()])
model1 = PPO2.load("best_model_LSTM_ACC1", env=env1)

env2 = DummyVecEnv([make_env2()])
model2 = PPO2.load("best_model_LSTM_ACC2", env=env2)

# Test trained agent
obs1 = env1.env_method('reset_test_multi')
obs2 = env2.env_method('reset_test_multi')
done1 = [False] * env1.num_envs
done2 = [False] * env2.num_envs

# n_f only exists after reset_test_multi has run, so it is read here
n_f1 = env1.get_attr('n_f')[0]
n_f2 = env2.get_attr('n_f')[0]

# Per-follower slots: combined action, each controller's action, each controller's recurrent state
action = [None] * n_f2
action1 = [None] * n_f1
action2 = [None] * n_f2
state1 = [None] * n_f1
state2 = [None] * n_f2

for i in range(500):
    for j in range(n_f2):
        # Query both controllers for follower j, carrying each one's recurrent state forward
        action1[j], state1[j] = model1.predict([obs1[0][j]], state=state1[j], mask=done1, deterministic=True)
        action2[j], state2[j] = model2.predict([obs2[0][j]], state=state2[j], mask=done2, deterministic=True)
        # First follower always follow the direct leader, while other vehicles pick the minimal acceleration
        action[j] = action1[j] if j == 0 else min(action1[j], action2[j])

    # Step both environments with the same combined actions (env1 first, then env2)
    step1 = env1.env_method('step_test_multi', [action[idx][0] for idx in range(n_f1)])[0]
    step2 = env2.env_method('step_test_multi', [action[idx][0] for idx in range(n_f2)])[0]

    # Re-wrap every returned field in a one-element list (one entry per wrapped env)
    obs1, rewards1, done1, info1 = [[field] for field in step1]
    obs2, rewards2, done2, info2 = [[field] for field in step2]

env1.render()