import numpy as np
from bluesky.traffic.MADDDPG.super_agent import SuperAgent
import bluesky.traffic.MADDDPG.TempConfig as tc

# Per-agent observation layout, as indexed by MADDPG.normalizeState:
# 3 ego values (two bearings, speed) followed by 3 values per intruder
# (distance, time to CPA, bearing).
NUMBER_INTRUDERS_STATE = 2
# Dimension of each agent's action, handed to SuperAgent.
ACTION_DIM = 2
# 3 + 3 * 2 == 9 entries per agent state vector.
STATE_DIM = 3 + 3 * NUMBER_INTRUDERS_STATE

class MADDPG(object):
    """Normalize per-agent states and delegate acting/learning to a SuperAgent.

    Also keeps, per scenario name, a flat history of the rewards and actions
    passed to setResult(), which episode_end() summarises and pickles.
    """

    def __init__(self):
        # Per-scenario history: scenario name -> flat numpy array, grown by
        # setResult() and dumped by episode_end().
        self.actions = dict()
        self.reward_per_action = dict()
        self.super_agent = SuperAgent(STATE_DIM, ACTION_DIM)

    def normalizeState(self, s_t):
        """Return a normalized float copy of one agent's raw state vector.

        Layout, as indexed below:
        [0:2] bearings, [2] speed, [3:5] intruder distances,
        [5:7] intruder time-to-CPA, [7:9] intruder bearings.

        The input is never modified. It is first copied into a float array
        so that integer inputs are not truncated on assignment. Previously an
        int ndarray would store a bearing of 90 / 180 as 0 instead of 0.5.

        Raises:
            IndexError: if the state has fewer than 9 entries.
        """
        s = np.array(s_t, dtype=float)
        if len(s) < 9:
            raise IndexError(
                f"state vector needs at least 9 entries, got {len(s)}"
            )

        # bearings
        s[0:2] /= 180
        # speed
        s[2] /= 18
        # distances to the two intruders, clipped at 1
        s[3:5] = np.minimum(s[3:5] / 3000, 1)
        # time to CPA for the two intruders, clipped at 1
        s[5:7] = np.minimum(s[5:7] / 600, 1)
        # bearings to the two intruders
        s[7:9] /= 180
        return s

    def do_step(self, s_t, episode_name):
        """Normalize every agent's state and return the SuperAgent's actions.

        Args:
            s_t: sequence with one raw state vector per agent. It is not
                modified. Normalizing it in place made it unsafe to pass the
                same object to setResult() later, which would normalize it
                a second time.
            episode_name: not used here; kept for interface compatibility.

        Returns:
            Whatever ``SuperAgent.get_actions`` returns for the normalized
            states.
        """
        normalized = [self.normalizeState(state) for state in s_t]
        return self.super_agent.get_actions(normalized)

    def setResult(self, scenname, state, nextstate, reward, actions):
        """Record reward/actions for ``scenname`` and train on one transition.

        Args:
            scenname: scenario key for the reward/action history.
            state, nextstate: sequences with one raw state vector per agent.
                Both are normalized into new lists; the caller's objects are
                left untouched.
            reward: reward(s) for this step. It is flattened into the history.
            actions: action(s) taken. It is flattened into the history.
        """
        empty = np.array([])
        self.reward_per_action[scenname] = np.append(
            self.reward_per_action.get(scenname, empty), reward
        )
        self.actions[scenname] = np.append(
            self.actions.get(scenname, empty), actions
        )

        norm_state = [self.normalizeState(s) for s in state]
        norm_nextstate = [self.normalizeState(s) for s in nextstate]

        # Joint observation: all agents' normalized states laid end to end.
        all_state = np.concatenate(norm_state)
        all_nextstate = np.concatenate(norm_nextstate)

        # NOTE(review): the trailing 0 is interpreted by
        # SuperAgent.update_batch. Its meaning is not visible here; confirm.
        self.super_agent.update_batch(
            norm_state, norm_nextstate, actions,
            all_state, all_nextstate, reward, 0,
        )

    def episode_end(self, scenarioName):
        """Report and pickle the history of ``scenarioName``, then notify the agent.

        An empty/falsy scenario name is ignored. A scenario with no recorded
        steps no longer crashes on min()/max() of a missing or empty history.
        """
        if not scenarioName:
            return

        print('episode end', scenarioName)
        rewards = self.reward_per_action.get(scenarioName, np.array([]))
        actions = self.actions.get(scenarioName, np.array([]))

        if rewards.size:
            print(scenarioName, 'min reward', rewards.min(),
                  'max reward', rewards.max())
        else:
            print(scenarioName, 'no rewards recorded')

        tc.dump_pickle(rewards, 'results/save/reward_' + scenarioName)
        tc.dump_pickle(actions, 'results/save/actions_' + scenarioName)
        self.super_agent.episode_end(scenarioName)