from math import gamma
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from bluesky.traffic.SAC.buffer import ReplayBuffer
from bluesky.traffic.SAC.mactor_critic import Actor, CriticQ, CriticV
from torch.nn.utils.clip_grad import clip_grad_norm_
import bluesky.traffic.MADDDPG.TempConfig as tc
import os


# ---------------------------------------------------------------------------
# SAC hyper-parameters
# ---------------------------------------------------------------------------
# Discount applied to the target value when building the Q target in
# SacAgent.update_model (identifier spelling kept as-is for compatibility).
GAMMMA = 0.99
# Interpolation weight of the online V network in the soft target update.
TAU = 0.005
# While total_step is below this value (and not in test mode) do_step returns
# uniformly random actions instead of querying the actor.
INITIAL_RANDOM_STEPS = 100
# The actor and the target network are only updated when total_step is a
# multiple of this value (identifier spelling kept as-is for compatibility).
POLICY_UPDATE_FREQUENCE = 2
# NUM_AGENTS = 5  (disabled)

# Replay-buffer settings handed to ReplayBuffer; BATCH_SIZE is also the
# minimum buffer fill level before setResult triggers a model update.
BUFFER_SIZE = 10 ** 6
BATCH_SIZE = 256

# Network input/output sizes: 4 own-ship features plus 5 features for each of
# the MAX_INTRUDERS intruder slots.
ACTION_DIM = 4
MAX_INTRUDERS = 4
STATE_DIM = MAX_INTRUDERS * 5 + 4

class SacAgent:
    """Soft Actor-Critic agent with one actor, twin Q critics and a V critic.

    The same actor is applied to every entry of the ``state`` sequence passed
    to :meth:`do_step`; all transitions go into a single replay buffer.
    A target copy of the V critic is kept and soft-updated with ``TAU``.
    """

    def __init__(self):
        # Per-scenario logs keyed by scenario name; appended to by setResult
        # and written to disk by episode_end.
        self.actions = dict()
        self.reward_per_action = dict()
        self.state = dict()

        self.memory = ReplayBuffer(STATE_DIM, ACTION_DIM, BUFFER_SIZE, BATCH_SIZE)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Temperature (alpha) is learned through its logarithm; the target is
        # minus the action dimensionality.
        self.target_alpha = -np.prod((ACTION_DIM,)).item()
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=3e-4)

        self.actor = Actor(STATE_DIM, ACTION_DIM).to(self.device)

        # State-value critic and its target copy (starts with equal weights).
        self.vf = CriticV(STATE_DIM).to(self.device)
        self.vf_target = CriticV(STATE_DIM).to(self.device)
        self.vf_target.load_state_dict(self.vf.state_dict())

        # Twin action-value critics.
        self.qf1 = CriticQ(STATE_DIM + ACTION_DIM).to(self.device)
        self.qf2 = CriticQ(STATE_DIM + ACTION_DIM).to(self.device)

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)
        self.vf_optimizer = optim.Adam(self.vf.parameters(), lr=3e-4)
        self.qf1_optimizer = optim.Adam(self.qf1.parameters(), lr=3e-4)
        self.qf2_optimizer = optim.Adam(self.qf2.parameters(), lr=3e-4)

        # Number of do_step calls so far.
        self.total_step = 0
        # Transitions built by the most recent setResult call.
        self.transition = []

        self.is_test = False
        # To resume from saved weights, call self.load_models() here.

    def normalizeState(self, s_t, max_intruders=None):
        """Return a normalised float copy of one agent's state vector.

        Layout along the first axis (``n`` = number of intruder slots):
        ``[current bearing, target bearing, speed, altitude,
        n distances, n distances-to-CPA, n times-to-CPA, n bearings,
        n altitudes]``.

        The input is never modified: the result is a new float ``ndarray``.
        Normalising in place would scale a state twice when the same object is
        handed to both ``do_step`` and ``setResult``, and would truncate to
        integers when the input is an integer array.

        Args:
            s_t: sequence or array with at least ``4 + 5 * n`` entries.
            max_intruders: number of intruder slots ``n``; defaults to the
                module-level ``MAX_INTRUDERS``.

        Returns:
            ``numpy.ndarray`` of floats with the same shape as ``s_t``.
            Entries beyond ``4 + 5 * n`` are copied through unchanged.

        Raises:
            IndexError: if ``s_t`` has fewer than ``4 + 5 * n`` entries.
        """
        n = MAX_INTRUDERS if max_intruders is None else max_intruders
        normalized = np.array(s_t, dtype=float)  # fresh copy, float dtype

        required = 4 + 5 * n
        if normalized.shape[0] < required:
            raise IndexError(
                "state vector has %d entries, expected at least %d"
                % (normalized.shape[0], required)
            )

        # Own-ship features.
        normalized[0:2] /= 180  # current bearing, target bearing
        normalized[2] /= 18     # speed
        normalized[3] /= 10     # altitude

        # Intruder features, stored as consecutive groups of n values.
        dist_end = 4 + 2 * n    # distance and distance-to-CPA share a scale
        tcpa_end = 4 + 3 * n
        bearing_end = 4 + 4 * n
        alt_end = 4 + 5 * n

        normalized[4:dist_end] = np.minimum(normalized[4:dist_end] / 3000, 1)
        normalized[dist_end:tcpa_end] = np.minimum(
            normalized[dist_end:tcpa_end] / 600, 1
        )
        normalized[tcpa_end:bearing_end] /= 180
        normalized[bearing_end:alt_end] /= 10

        return normalized

    def do_step(self, state, scename):
        """Choose one action per agent for the given raw states.

        During the first ``INITIAL_RANDOM_STEPS`` calls (unless ``is_test`` is
        set) actions are drawn uniformly from ``[-1, 1]``; afterwards the
        first output of the actor is used and clipped to ``[-1, 1]``.

        Args:
            state: sequence of raw (un-normalised) state vectors, one per
                agent. It is not modified.
            scename: scenario name; currently unused.

        Returns:
            List with one action (list of ``ACTION_DIM`` floats) per agent.
        """
        normalized = [self.normalizeState(agent_state) for agent_state in state]

        if self.total_step < INITIAL_RANDOM_STEPS and not self.is_test:
            selected_action = np.random.uniform(
                -1, 1, (len(normalized), ACTION_DIM)
            )
        else:
            # Pure inference: no autograd graph is needed here.
            with torch.no_grad():
                per_agent = [
                    self.actor(torch.FloatTensor(agent_state).to(self.device))[0]
                    .detach()
                    .cpu()
                    .numpy()
                    for agent_state in normalized
                ]
            selected_action = np.clip(np.array(per_agent), -1, 1)

        self.total_step += 1
        return selected_action.tolist()

    def setResult(self, scenname, state, nextstate, reward, actions, done=True):
        """Log the outcome of a step, store the transitions and maybe train.

        The raw ``state``, ``reward`` and ``actions`` are appended to the
        per-scenario logs. Unless ``is_test`` is set, one transition per agent
        is stored in the replay buffer using normalised copies of the states.
        A model update runs once the buffer holds more than ``BATCH_SIZE``
        transitions and the random warm-up phase is over.

        Args:
            scenname: scenario name used as key for the logs.
            state: sequence of raw state vectors, one per agent (not modified).
            nextstate: sequence of raw successor state vectors with at least
                as many entries as ``state`` (not modified).
            reward: reward stored unchanged with every agent's transition.
                NOTE(review): if this is meant to be per-agent, confirm what
                ``ReplayBuffer.store`` expects.
            actions: sequence of actions, one per agent.
            done: terminal flag stored with every transition. The default
                ``True`` keeps the previous behaviour; presumably the buffer
                returns it as ``samples["done"]``, in which case the bootstrap
                term in ``update_model`` is masked out and the Q target is the
                reward alone.

        Raises:
            IndexError: if ``nextstate`` has fewer entries than ``state``.
        """
        no_data = np.array([])
        self.reward_per_action[scenname] = np.append(
            self.reward_per_action.get(scenname, no_data), reward
        )
        self.actions[scenname] = np.append(
            self.actions.get(scenname, no_data), actions
        )
        # The log keeps the raw values; normalisation below works on copies.
        self.state[scenname] = np.append(
            self.state.get(scenname, no_data), state
        )

        n_agents = len(state)
        if len(nextstate) < n_agents:
            raise IndexError(
                "nextstate has %d entries but state has %d"
                % (len(nextstate), n_agents)
            )

        norm_state = [self.normalizeState(agent_state) for agent_state in state]
        norm_next = [
            self.normalizeState(agent_state)
            for agent_state in nextstate[:n_agents]
        ]

        self.transition = [[] for _ in range(n_agents)]
        if not self.is_test:
            for i, (current, following) in enumerate(zip(norm_state, norm_next)):
                self.transition[i] = [
                    current,
                    np.asarray(actions[i]),
                    reward,
                    following,
                    done,
                ]
                self.memory.store(*self.transition[i])

        if len(self.memory) > BATCH_SIZE and self.total_step > INITIAL_RANDOM_STEPS:
            self.update_model()

    def update_model(self):
        """Run one SAC update on a batch sampled from the replay buffer.

        Order of updates: temperature, then (every
        ``POLICY_UPDATE_FREQUENCE`` steps) actor plus soft target update, then
        both Q critics, then the V critic.

        Returns:
            Tuple ``(actor_loss, qf_loss, v_loss, alpha_loss)`` of detached
            tensors; ``actor_loss`` is a zero tensor on steps where the actor
            is not updated and ``qf_loss`` is the sum of both Q losses.
        """
        device = self.device

        samples = self.memory.sample_batch()
        state = torch.FloatTensor(samples["obs"]).to(device)
        next_state = torch.FloatTensor(samples["next_obs"]).to(device)
        action = torch.FloatTensor(samples["acts"].reshape(-1, ACTION_DIM)).to(device)
        reward = torch.FloatTensor(samples["rews"].reshape(-1, 1)).to(device)
        done = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)
        new_action, log_prob = self.actor(state)

        # --- temperature -----------------------------------------------------
        alpha_loss = (
            -self.log_alpha.exp() * (log_prob + self.target_alpha).detach()
        ).mean()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        # Used as a constant below; detached so the actor loss does not leave
        # stray gradients on log_alpha.
        alpha = self.log_alpha.exp().detach()

        # --- Q critics ---------------------------------------------------------
        mask = 1 - done
        q1_pred = self.qf1(state, action)
        q2_pred = self.qf2(state, action)
        # The target is a constant for the critics, so skip graph construction.
        with torch.no_grad():
            q_target = reward + GAMMMA * self.vf_target(next_state) * mask
        qf1_loss = F.mse_loss(q1_pred, q_target)
        qf2_loss = F.mse_loss(q2_pred, q_target)

        # --- V critic ----------------------------------------------------------
        v_pred = self.vf(state)
        q_pred = torch.min(
            self.qf1(state, new_action), self.qf2(state, new_action)
        )
        v_target = q_pred - alpha * log_prob
        v_loss = F.mse_loss(v_pred, v_target.detach())

        # --- actor (delayed) ---------------------------------------------------
        if self.total_step % POLICY_UPDATE_FREQUENCE == 0:
            advantage = q_pred - v_pred.detach()
            actor_loss = (alpha * log_prob - advantage).mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            self._target_soft_update()
        else:
            actor_loss = torch.zeros(1)

        # Critic gradients are zeroed here, which also discards whatever the
        # actor backward pass accumulated on the Q networks.
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()
        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        qf_loss = qf1_loss + qf2_loss

        self.vf_optimizer.zero_grad()
        v_loss.backward()
        self.vf_optimizer.step()

        return (
            actor_loss.detach(),
            qf_loss.detach(),
            v_loss.detach(),
            alpha_loss.detach(),
        )

    def episode_end(self, scenarioName):
        """Checkpoint the networks periodically and dump the scenario logs.

        The repetition number is parsed from the text following ``'Rep'`` in
        the scenario name; network weights are saved every 16th repetition
        when not in test mode. The state, reward and action logs of the
        scenario are always dumped (as empty arrays if nothing was recorded).

        Args:
            scenarioName: scenario name; an empty name is ignored.

        Raises:
            IndexError: if the name does not contain ``'Rep'``.
            ValueError: if the text after ``'Rep'`` is not an integer.
        """
        print('episode end', scenarioName)
        if len(scenarioName) == 0:
            return

        repetition = int(scenarioName.split('Rep')[1])
        if repetition % 16 == 0 and not self.is_test:
            # torch.save does not create missing directories.
            os.makedirs("results", exist_ok=True)
            torch.save(self.actor.state_dict(), "results/mactor.pt")
            torch.save(self.qf1.state_dict(), "results/mqf1.pt")
            torch.save(self.qf2.state_dict(), "results/mqf2.pt")
            torch.save(self.vf.state_dict(), "results/mvf.pt")

        # NOTE(review): the per-scenario logs are kept in memory after being
        # dumped; consider dropping them here if nothing else reads them.
        no_data = np.array([])
        tc.dump_pickle(
            self.state.get(scenarioName, no_data),
            'results/save/state_' + scenarioName,
        )
        tc.dump_pickle(
            self.reward_per_action.get(scenarioName, no_data),
            'results/save/reward_' + scenarioName,
        )
        tc.dump_pickle(
            self.actions.get(scenarioName, no_data),
            'results/save/actions_' + scenarioName,
        )

    def load_models(self):
        """Load actor and critic weights saved by :meth:`episode_end`.

        Tensors are mapped onto ``self.device`` so checkpoints written on a
        CUDA machine also load on a CPU-only machine. The V target network is
        re-synchronised with the loaded V network; otherwise it would keep its
        random initial weights.
        """
        self.actor.load_state_dict(
            torch.load("results/mactor.pt", map_location=self.device)
        )
        self.qf1.load_state_dict(
            torch.load("results/mqf1.pt", map_location=self.device)
        )
        self.qf2.load_state_dict(
            torch.load("results/mqf2.pt", map_location=self.device)
        )
        self.vf.load_state_dict(
            torch.load("results/mvf.pt", map_location=self.device)
        )
        self.vf_target.load_state_dict(self.vf.state_dict())

    def _target_soft_update(self):
        """Move the V target network a fraction ``TAU`` towards the V network."""
        with torch.no_grad():
            for t_param, l_param in zip(
                self.vf_target.parameters(), self.vf.parameters()
            ):
                t_param.copy_(TAU * l_param + (1.0 - TAU) * t_param)

