
import math
import os
import random

import gym
import matplotlib.pyplot as plt
import numpy as np
import stable_baselines
import tensorflow as tf
from gym import spaces
from stable_baselines import PPO2
from stable_baselines.common.callbacks import CallbackList, EvalCallback
from stable_baselines.common.policies import MlpLstmPolicy
from stable_baselines.common.vec_env import DummyVecEnv

# Simulation parameters used by the vehicle classes below
T = 30.1            # simulation time [s]
dt = 0.1            # simulation time step [s]
# round() rather than int(): T / dt is a float quotient that can land just below
# the intended integer (e.g. int(0.3 / 0.1) == 2), which would silently drop a step.
n = round(T / dt)   # number of simulation steps
tauA = 0.2          # actuator lag, representing driveline dynamics [s]

# Base class shared by the leading and the automated vehicle
class _Vehicle():
    """Point-mass longitudinal vehicle with per-step state histories.

    ``speed``, ``pos`` and ``a`` are lists of length ``n_steps``; index ``k``
    holds the value at step ``k``. Only index 0 is initialised, and the other
    entries start at 0 until the update methods fill them.

    ``n_steps`` and ``time_step`` default to the module-level ``n`` and ``dt``.
    They are looked up when an instance is created, so existing callers keep
    their behaviour.
    """

    def __init__(self, speed0, pos0, *, n_steps=None, time_step=None):
        if n_steps is None:
            n_steps = n
        if time_step is None:
            time_step = dt
        self.dt = time_step
        self.speed = [0] * n_steps
        self.pos = [0] * n_steps
        self.a = [0] * n_steps
        self.speed[0] = speed0
        self.pos[0] = pos0

    def getSpeed(self, k):
        return self.speed[k]

    def getPos(self, k):
        return self.pos[k]

    def getA(self, k):
        return self.a[k]

    def updatePos(self, k):
        """Set pos[k] from step k-1 under constant acceleration a[k-1]."""
        self.pos[k] = self.pos[k-1] + self.speed[k-1] * self.dt + 0.5 * self.a[k-1] * math.pow(self.dt, 2)

    def updateSpeed(self, k):
        """Set speed[k] from step k-1 under constant acceleration a[k-1]."""
        self.speed[k] = self.speed[k-1] + self.a[k-1] * self.dt


# class of the leading vehicle
class leader(_Vehicle):
    """Leading vehicle whose acceleration is prescribed directly via ``setA``."""

    def setA(self, u, k):
        """Set the acceleration at step k to u."""
        self.a[k] = u


# class of autonomous vehicles
class av(_Vehicle):
    """Automated follower whose acceleration tracks a command through a first-order lag.

    ``lead1`` is stored as ``predecessor1``, the vehicle being followed.
    ``actuator_lag`` defaults to the module-level ``tauA``.
    """

    def __init__(self, lead1, speed0, pos0, *, n_steps=None, time_step=None, actuator_lag=None):
        super(av, self).__init__(speed0, pos0, n_steps=n_steps, time_step=time_step)
        self.predecessor1 = lead1
        self.tau = tauA if actuator_lag is None else actuator_lag

    def updateA(self, u, k):
        """Set a[k] = a[k-1] + (u - a[k-1]) * dt / tau.

        At k == 0 this reads a[-1], the last history entry. That entry is still
        0 unless the end of the list has already been written.
        """
        self.a[k] = self.a[k-1] + (u - self.a[k-1]) * self.dt / self.tau

# Custom simulation environment that follows the gym interface
class LSTM_ACC2(gym.Env):
    """Single-lane car-following environment (one leader, one follower) with a gym interface.

    The agent commands the follower's acceleration ``u``. The command must lie
    in the action space [-6, 3] m/s^2, which ``step`` asserts. The follower
    realises the command through the lag in ``av.updateA``.

    Observation: ``[distance gap, ego speed, relative speed, jerk, previous
    acceleration]``.

    - "jerk" is the one-step change in the follower's realised acceleration
      (not divided by dt).
    - "previous acceleration" is the realised acceleration one step earlier.
    - The gap and relative speed carry Gaussian noise (std 0.5).

    ``reset``/``step`` run randomised training episodes. ``reset_test``/
    ``step_test`` replay a fixed leader profile with a 0.2 s sensor delay.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self):
        super(LSTM_ACC2, self).__init__()
        # Action space: commanded follower acceleration
        self.u_max = 3     # maximum acceleration [m/s^2]
        self.u_min = -6    # minimum acceleration [m/s^2]
        self.action_space = spaces.Box(low=self.u_min, high=self.u_max, shape=(1,), dtype='float32')

        # Observation space: [distance gap, ego-vehicle speed, relative speed, jerk, previous acceleration]
        # NOTE(review): the jerk entry can be negative although its lower bound is 0.
        # The bounds are left unchanged on purpose: models saved with this space may
        # be rejected on load if the space differs. Confirm before editing them.
        vlow = np.array([0, 0, -40, 0, -6], dtype='f')
        vhigh = np.array([500, 40, 40, 10, 3], dtype='f')
        self.observation_space = spaces.Box(low=vlow, high=vhigh, shape=(5,))

        self.done = False

        # Simulation parameters
        self.T = 30.1     # simulation time [s]
        self.dt = 0.1     # simulation time step [s]
        # round() rather than int(): T / dt can evaluate a hair below the intended integer.
        self.n = round(self.T / self.dt)   # number of simulation steps
        # System parameters
        self.min_gap = 2   # minimum distance gap at standstill [m]
        self.vlength = 4   # vehicle length [m]
        self.td = 2        # desired time gap [s]
        # NOTE(review): not read in this class; the follower's lag comes from the av class.
        self.tauA = 0.2    # actuator lag, representing driveline dynamics [s]
        # Sensor delay [s]. reset_test sets it to 0.2 and nothing resets it afterwards.
        self.tauS = 0

    # ----------------------------------------------------------------- episodes
    def reset(self):
        """Start a randomised training episode and return the initial observation.

        The leader starts at a random integer speed in [15, 35] m/s. The follower
        matches that speed at a 2 s time gap. A leader speed disturbance is then
        sampled (see ``_sample_disturbance``).
        """
        self.done = False
        self.current_step = 0
        self.u = 0  # control input (acceleration)
        v0_leader1 = random.randint(15, 35)  # initial speed of the leader
        self.observedState = self._place_vehicles(v0_leader1)
        self._sample_disturbance(v0_leader1)
        return self.observedState

    def reset_test(self):
        """Start the fixed test episode (leader at 33 m/s, 0.2 s sensor delay)."""
        self.done = False
        self.tauS = 0.2     # delay in sensor and estimator [s]
        self.current_step = 0
        self.u = 0
        self.observedState = self._place_vehicles(33)  # designated leader speed [m/s]
        return self.observedState

    def step(self, action):
        """Run one training step under the sampled leader disturbance.

        Returns ``(observation, reward, done, info)``; see ``_simulate_step``.
        """
        return self._simulate_step(action, self._training_leader_accel)

    def step_test(self, action):
        """Run one step of the fixed test scenario.

        Returns ``(observation, reward, done, info)``; see ``_simulate_step``.
        """
        return self._simulate_step(action, self._test_leader_accel)

    # ------------------------------------------------------------------ helpers
    def _place_vehicles(self, v0_leader1):
        """Create one leader at 500 m and one follower, and return the initial observation.

        The follower starts at the leader's speed, placed so that the initial
        distance gap equals 2 s times its speed.
        """
        self.n_l = 1    # number of leaders
        self.l = [leader(v0_leader1, 500)]
        self.n_f = 1    # number of followers
        v0_follower1 = v0_leader1   # no initial relative speed
        gap0 = 2                    # initial time gap [s]
        s0 = v0_follower1 * gap0 + self.min_gap * 2
        self.f = [av(self.l[0], v0_follower1, self.l[0].getPos(0) - s0 - self.vlength * 2)]

        s = self._gap(0)
        vf = self.f[0].getSpeed(0)
        rv = self.l[0].getSpeed(0) - self.f[0].getSpeed(0)
        jerk = 0
        pa = 0  # acceleration in the previous time step
        return np.array([s, vf, rv, jerk, pa])

    def _sample_disturbance(self, v0_leader1):
        """Sample the leader's speed disturbance for a training episode.

        Sets the following attributes:

        - ``dis_point``: start step of the disturbance.
        - ``dis_length``: its length in steps.
        - ``dis_rate``: its acceleration.
        - ``low_speed_length``: hold phase in steps.
        - ``fix_rate`` / ``fix_length``: recovery acceleration and its length in steps.

        The disturbed leader speed is kept within [11, 39] m/s. ``dis_length / 10``
        converts steps to seconds (assumes dt = 0.1).
        """
        case = random.uniform(-1, 1)
        if case < -0.25:
            # Speed-up disturbance, later undone by a deceleration
            self.dis_length = random.randint(0, 50)
            self.dis_rate = random.uniform(0, 2)
            speed_change = self.dis_rate * (self.dis_length / 10)
            if v0_leader1 + speed_change > 39:
                speed_change = 39 - v0_leader1
                self.dis_rate = speed_change / (self.dis_length / 10)
            self.fix_rate = random.uniform((-1) * self.dis_rate, (-1) * self.dis_rate / 3)
        else:
            # Slow-down disturbance, later undone by an acceleration capped at 3 m/s^2
            self.dis_length = random.randint(0, 50)
            self.dis_rate = random.uniform(-4, 0)
            speed_change = self.dis_rate * (self.dis_length / 10)
            if v0_leader1 + speed_change < 11:
                speed_change = 11 - v0_leader1
                self.dis_rate = speed_change / (self.dis_length / 10)
            self.fix_rate = random.uniform((-1) * self.dis_rate / 3, (-1) * self.dis_rate)
            if self.fix_rate > 3:
                self.fix_rate = 3

        self.dis_point = random.randint(20, 40)
        self.low_speed_length = random.randint(5, 80)
        # fix_rate is 0 only if dis_rate was sampled as exactly 0, in which case
        # speed_change is 0 too: skip the recovery instead of dividing by zero.
        if self.fix_rate == 0:
            self.fix_length = 0
        else:
            self.fix_length = math.floor(10 * abs(speed_change / self.fix_rate))

    def _training_leader_accel(self, k):
        """Return the leader acceleration at step k for the sampled training disturbance."""
        disturbance_end = self.dis_point + self.dis_length
        recovery_start = disturbance_end + self.low_speed_length
        if self.dis_point < k <= disturbance_end:
            return self.dis_rate
        if recovery_start < k <= recovery_start + self.fix_length:
            return self.fix_rate
        return 0

    @staticmethod
    def _test_leader_accel(k):
        """Return the fixed test-scenario leader acceleration at step k."""
        if 30 < k <= 70:
            return -3
        if 120 < k <= 200:
            return 1.5
        return 0

    def _gap(self, k):
        """Return the distance gap at step k, as used throughout the environment.

        Computed as leader position minus follower position, minus twice the
        vehicle length and twice the standstill gap.
        NOTE(review): the doubling of both terms is kept from the original
        formulation; confirm its intent.
        """
        return (self.l[0].getPos(k) - self.f[0].getPos(k)
                - self.vlength * 2 - self.min_gap * 2)

    @staticmethod
    def _time_gap(gap, speed):
        """Return gap / speed.

        A stopped follower gets +inf for a positive gap and -inf otherwise,
        which avoids a ZeroDivisionError.
        """
        if speed == 0:
            return math.inf if gap > 0 else -math.inf
        return gap / speed

    # Check if the state is out of the operation range
    def disengaged(self):
        """Return True and print the reason if the current state leaves the operating range."""
        end = False
        speed = self.f[0].getSpeed(self.current_step)
        time_gap = self._time_gap(self._gap(self.current_step), speed)
        # crash
        if time_gap <= 0:
            end = True
            print(repr(self.current_step)+" collision")
        # cruising mode (gap too large to count as following)
        if time_gap > (self.td + 5.0):
            end = True
            print(repr(self.current_step)+" cruise")
        # negative speed
        if speed < 0:
            end = True
            print(repr(self.current_step)+" negative speed")
        return end

    def _simulate_step(self, action, leader_accel):
        """Advance the simulation by one step and return (observation, reward, done, info).

        ``leader_accel(k)`` gives the leader's acceleration at step k. The
        follower applies ``action[0]`` once the sensor-delay window has passed,
        and 0 before that. While the window is still filling, the observation is
        held unchanged.

        Reward terms:

        - Normalised absolute time-gap error, weighted 0.9.
        - Normalised absolute jerk, weighted 0.1.
        - Any growth of the time-gap error during this step.
        - -100 when the state leaves the operating range.
        """
        self.done = False

        # Receive the action given by the agent and assert that it is valid
        assert self.action_space.contains(action), "Invalid Action"

        k = self.current_step
        # Number of time steps of sensor delay (round: tauS / dt may fall just below an integer)
        sensor_delay = round(self.tauS / self.dt)
        # Delayed measurement indices, clamped at 0. Early steps then read the
        # initial state instead of wrapping to the unfilled end of the history lists.
        delayed_prev = max(k - sensor_delay, 0)
        delayed_now = max(k + 1 - sensor_delay, 0)

        # Leader: set this step's acceleration, then integrate
        self.l[0].setA(leader_accel(k), k)
        self.l[0].updateSpeed(k + 1)
        self.l[0].updatePos(k + 1)

        # Follower: apply the command through the actuator lag, then integrate
        self.u = 0 if k < sensor_delay else action[0]
        self.f[0].updateA(self.u, k)
        self.f[0].updatePos(k + 1)
        self.f[0].updateSpeed(k + 1)

        # Generate random measurement error (noise)
        std_level = 0.5     # or 1.0, 1.5, 2.0
        d_error = np.random.normal(0, std_level, None)
        s_error = np.random.normal(0, std_level, None)

        d_previous = self._gap(delayed_prev)
        d_now = self._gap(delayed_now)
        lspeed_now = self.l[0].getSpeed(delayed_now)

        if k >= sensor_delay:
            # Acceleration before the first step is 0 (the histories start zero-filled)
            prev_a = self.f[0].getA(k - 1) if k > 0 else 0
            self.observedState = np.array([d_now + d_error,
                                           self.f[0].getSpeed(k + 1),
                                           (lspeed_now + s_error) - self.f[0].getSpeed(delayed_now),
                                           self.f[0].getA(k) - prev_a,
                                           prev_a])

        # Time gaps from the noise-free (delayed) distance gaps and the current follower speed
        g_previous = d_previous / self.f[0].getSpeed(k)
        g_now = d_now / self.f[0].getSpeed(k + 1)

        # Normalisation constants and weights of the reward
        max_es = 0.5 * self.td
        max_jerk = (self.u_max - self.u_min) / 3
        w1 = 0.9
        w2 = 0.1
        reward = w1 * (-1) * abs(g_now - self.td) / max_es + w2 * (-1) * abs(self.observedState[3]) / max_jerk

        # Extra negative reward when this step moved the time gap further from the target
        if abs(g_previous - self.td) < abs(g_now - self.td):
            reward = reward + (abs(g_previous - self.td) - abs(g_now - self.td)) / max_es

        self.current_step += 1

        # Penalty and episode end when the state leaves the operating range
        penalty = -100
        if self.disengaged():
            reward = reward + penalty
            self.done = True

        # The episode ends at step n - 1 (300 with the defaults)
        if self.current_step >= self.n - 1:
            self.done = True

        return (self.observedState, reward, self.done, {})

    # ------------------------------------------------------------------ display
    def render(self, mode='human', close=False):
        """Plot leader/follower acceleration, speed and time-gap profiles of the episode.

        ``mode`` and ``close`` are accepted for gym compatibility and ignored.
        The histories are zero-filled until stepped, for example after an early
        termination. Time-gap samples where the follower speed is 0 are therefore
        plotted as NaN (a break in the line) instead of raising ZeroDivisionError.
        """
        steps = range(self.n)
        t = [round(i * self.dt, 1) for i in steps]

        def time_gap(follower, i):
            speed = follower.getSpeed(i)
            if speed == 0:
                return math.nan
            gap = (follower.predecessor1.getPos(i) - follower.getPos(i)
                   - self.vlength * 2 - self.min_gap * 2)
            return gap / speed

        # plot results (acceleration, speed, time gap profiles)
        fig, axs = plt.subplots(3, figsize=(10, 6), sharex=True)
        linecolor = ['tab:blue', 'tab:orange']
        for j in range(self.n_l):
            axs[0].plot(t, [self.l[j].getA(i) for i in steps], '--', color=linecolor[j])
            axs[1].plot(t, [self.l[j].getSpeed(i) for i in steps], '--', color=linecolor[j])
        for j in range(self.n_f):
            axs[0].plot(t, [self.f[j].getA(i) for i in steps], color=linecolor[j + 1])
            axs[1].plot(t, [self.f[j].getSpeed(i) for i in steps], color=linecolor[j + 1])
            axs[2].plot(t, [time_gap(self.f[j], i) for i in steps], color=linecolor[j + 1])
        axs[0].set(xlabel='time (s)', ylabel='acceleration (m/s\u00b2)')
        axs[0].set_ylim([-6, 3])
        axs[1].set(xlabel='time (s)', ylabel='velocity (m/s)')
        axs[1].set_ylim([15, 40])
        axs[2].set(xlabel='time (s)', ylabel='gap 2 (s)')
        axs[2].set_ylim([1.8, 2.2])
        axs[0].legend(['leader', 'follower'], loc='lower right', fontsize='medium', ncol=2)

        fig.tight_layout()

        plt.show()

# create the environment
env = LSTM_ACC2()
obs = env.reset()

# A separate environment is used only by the evaluation callback. The callback
# keeps the best model and its evaluation log in the folder below.
eval_env = LSTM_ACC2()
eval_settings = dict(
    best_model_save_path='./logs_LSTM_ACC_td2/',
    log_path='./logs_LSTM_ACC_td2/',
    n_eval_episodes=100,
    eval_freq=100000,
    deterministic=True,
    verbose=1,
    render=False,
)
eval_callback = EvalCallback(eval_env, **eval_settings)
call_back = CallbackList([eval_callback])

# customize the network structure and hyperparameters
class CustomPolicy(MlpLstmPolicy):
    """MLP-LSTM policy with a 128-unit LSTM (by default) and tanh activations.

    ``net_arch`` puts the LSTM first, followed by separate [64, 64] value and
    policy heads (stable-baselines ``LstmPolicy`` net_arch convention).
    """

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=128, reuse=False, **_kwargs):
        architecture = ['lstm', {'vf': [64, 64], 'pi': [64, 64]}]
        super().__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
                         net_arch=architecture, act_fun=tf.tanh, **_kwargs)

# PPO2 hyperparameters. A negative cliprange_vf disables value-function clipping
# (per the stable-baselines PPO2 documentation).
ppo2_settings = dict(
    learning_rate=0.0003,
    n_steps=2048,
    ent_coef=0.0,
    cliprange_vf=-1,
    nminibatches=1,
    noptepochs=4,
    verbose=0,
)
model = PPO2(CustomPolicy, env, **ppo2_settings)

# Train with periodic evaluation, then store the final weights
model.learn(total_timesteps=6000000, callback=call_back)
model.save("final_model_LSTM_ACC2")
# Drop the in-memory model; the evaluation below reloads one from disk
del model

# The code below uses the trained model to show the results.
# EvalCallback saves the best model found during training in the "./logs_LSTM_ACC_td2/"
# folder; copy it to "best_model_LSTM_ACC2.zip" to load it below.
def make_env():
    """Return a zero-argument factory that builds a fresh LSTM_ACC2 environment.

    DummyVecEnv below takes a list of such factories.
    """
    return lambda: LSTM_ACC2()

env = DummyVecEnv([make_env()])

# EvalCallback (best_model_save_path='./logs_LSTM_ACC_td2/') writes the best model
# into that folder as "best_model.zip", but nothing in this script creates
# "best_model_LSTM_ACC2". If a copy under that name exists, use it. Otherwise load
# the callback's file directly instead of failing on a missing file.
best_model_path = "best_model_LSTM_ACC2"
if not (os.path.exists(best_model_path) or os.path.exists(best_model_path + ".zip")):
    best_model_path = os.path.join('./logs_LSTM_ACC_td2/', 'best_model')
model = PPO2.load(best_model_path, env=env)

# Test trained agent on the fixed scenario. reset_test / step_test are custom
# methods, so they are reached through env_method.
obs = env.env_method('reset_test')
state = None
done = [False] * env.num_envs

for i in range(300):
    action, state = model.predict(obs, state=state, mask=done, deterministic=True)
    # env_method returns one result per sub-environment. Unpack the single env's
    # tuple, then re-wrap each item as a one-element batch for the next predict call.
    obs, rewards, done, info = env.env_method('step_test', action[0])[0]
    obs = [obs]
    rewards = [rewards]
    done = [done]
    info = [info]
    # NOTE(review): after an early termination the loop keeps stepping the same
    # episode instead of stopping. Confirm this is intended for the plot.

env.render()