# -*- coding: utf-8 -*-
# Created by mjribeiro at 19-6-2021

# -*- coding: utf-8 -*-
# Created by mjribeiro at 18-10-2019
from bluesky.traffic.DDPG import TempConfig
from bluesky.traffic.DDPG.ActorNetwork import ActorNetwork
from bluesky.traffic.DDPG.CriticNetwork import CriticNetwork
from bluesky.traffic.DDPG.OU import OrnsteinUhlenbeckActionNoise
from bluesky.traffic.DDPG.ReplayBuffer import ReplayBuffer
import tensorflow as tf
import numpy as np
import glob
import os

# from keras.models import load_model

# --- Training hyper-parameters ----------------------------------------------
BUFFER_SIZE = 1_000_000  # size handed to the ReplayBuffer
BATCH_SIZE = 256  # number of transitions requested per training step
TAU = 1e-3  # target-network hyper-parameter (soft update of target parameters)
LRA = 1e-4  # learning rate for the actor
LRC = 1e-3  # learning rate for the critic
EPSILON = 0.1  # NOTE(review): not referenced in the visible code - confirm it is still needed
ALPHA = 0.9  # multiplies the target Q-value in batch_update, i.e. acts as the discount factor

# --- Network dimensions -----------------------------------------------------
ACTION_DIM = 4
STATE_DIM = 16  # 5 distances + 5 times-to-CPA + 5 bearings + own speed (see normalizeState)

class DDPG(object):

    def __init__(self, testingMode=False):
        """Create the agent: bookkeeping, TensorFlow session, networks, buffer.

        Args:
            testingMode: Stored as ``self.testingMode``. Loading previously
                saved weights in testing mode is currently disabled, so the
                flag does not change how the networks are built here.
        """
        self.testingMode = testingMode

        # Episode life-cycle flags and counters.
        self.episodeStarted = False
        self.waiting_reward = False
        self.episodecount = -1  # start_episode() increments, so the first episode is 0
        self.numberSteps = 0

        # Loss / LoS histories.
        self.totalLoSs = np.array([])
        self.L_per_Episode = np.array([])
        self.L = np.array([])
        self.L_per_step = np.array([])

        # Per-scenario logs of what was performed, keyed by scenario name.
        self.reward_per_action = dict()
        self.states = dict()
        self.next_states = dict()
        self.actions = dict()
        self.numberSteps_all = dict()
        self.layer_per_action = dict()

        # Eager execution is switched off and a compat.v1 session is created
        # and registered with Keras; the same session is handed to both
        # networks below. (In TF 2.x use tf.compat.v1.Session, not tf.Session.)
        tf.compat.v1.disable_eager_execution()
        session = tf.compat.v1.Session()
        tf.compat.v1.keras.backend.set_session(session)

        self.action_dim = ACTION_DIM
        self.state_dim = STATE_DIM

        # Created empty; nothing in this constructor fills them.
        self.actors = []
        self.critics = []

        # A single actor/critic pair is always built from scratch. Restoring
        # weights from "results/DDPG_actor*.h5" / "results/DDPG_critic*.h5"
        # in testing mode is not active at the moment.
        self.actor = ActorNetwork(session, self.state_dim, self.action_dim, BATCH_SIZE, TAU, LRA)
        self.critic = CriticNetwork(session, self.state_dim, self.action_dim, BATCH_SIZE, TAU, LRC)

        # Replay memory for the mini-batch updates.
        self.buff = ReplayBuffer(BUFFER_SIZE)

        # Exploration noise. The original DDPG paper recommends time-correlated
        # Ornstein-Uhlenbeck noise (mu: long-running mean, theta: speed of mean
        # reversion, sigma: volatility); more recent results suggest that
        # uncorrelated, mean-zero Gaussian noise may work just as well.
        # TODO: try mean-zero Gaussian noise later!
        self.noise1 = OrnsteinUhlenbeckActionNoise(ACTION_DIM)

    def start_episode(self):
        self.episodeStarted = True
        self.episodecount += 1
        self.reward_in_this_episode = 0

        # save losses
        self.L = np.array([])
        self.L_per_step = np.array([])

    def episode_Started(self):
        """Return the ``episodeStarted`` flag.

        The flag is set by ``start_episode`` and cleared by ``episode_end``.
        """
        return self.episodeStarted

    def getEpisodeCount(self):
        """Return the zero-based index of the most recently started episode.

        The counter is initialised to -1 and incremented by ``start_episode``,
        so it is -1 until the first episode has been started.
        """
        return self.episodecount

    def getStepsCount(self):
        """Return ``self.numberSteps``.

        The attribute is initialised to 0 in ``__init__``.
        NOTE(review): no method visible in this class updates it - presumably
        it is maintained by the caller; confirm.
        """
        return self.numberSteps

    def normalizeState(self, s_t): 
        # normalized distance
        s_t[0:5] = s_t[0:5] / 3000

        # normalized timetoCPA
        s_t[5:10] = s_t[5:10] / 120

        # normalized bearing
        s_t[10:15] = s_t[10:15] / 180

        # # bearing to next waypoint
        # s_t[15] = s_t[15] / 360

        # own speed
        s_t[15] = s_t[15] / 18

        return s_t

    def do_step(self, s_t, scenname):
        s_t_initial = s_t
        # print('--- state', s_t)
        s_t = self.normalizeState(np.asarray(s_t))
        if np.any(s_t) > 1:
            print(s_t)

        #with tf.Graph().as_default():
        a_t_original = self.actor.model.predict(s_t.reshape(1, s_t.shape[0]))

        actions = a_t_original[0]
        # if not self.testingMode:  # add noise
        #     # repetition = scenname.split('-')[-1]
        #     # repetition = repetition.split('_')[0]
        #     repetition = scenname.split('Rep')[1]
        #     repetition = int(repetition.split('.')[0])
        #     # if repetition < 5000:
        #     #actions += self.noise1.sample()  # (5000 - repetition) /5000 * (self.noise1.sample())
        #     # actions += self.noise1.sample()
        #     actions = np.clip(actions, 0, 1)

        if np.isnan(np.asarray(actions)).any():
            print('ACTION', s_t_initial, a_t_original[0])

        return actions

    def batch_update(self):
        """Run one training step on a mini-batch drawn from the replay buffer.

        Requests ``BATCH_SIZE`` transitions, trains the critic on the targets
        ``reward + ALPHA * Q_target(s', actor_target(s'))``, feeds the critic's
        gradients for the actor's current actions into the actor's training
        step and finally calls ``target_train()`` on both networks. Nothing
        happens when the sampled batch is empty.
        """
        transitions = self.buff.getBatch(BATCH_SIZE)

        # Each transition is indexed as (state, action, reward, next state).
        states = np.asarray([t[0] for t in transitions])
        actions = np.asarray([t[1] for t in transitions])
        rewards = np.asarray([t[2] for t in transitions])
        new_states = np.asarray([t[3] for t in transitions])

        # The target array starts as a copy of the sampled actions, which
        # fixes its shape and dtype; every row is overwritten below.
        y_t = actions.copy()

        if len(states) == 0:
            return

        # Bootstrapped targets from the two target networks.
        next_actions = self.actor.target_model.predict(new_states)
        target_q_values = self.critic.target_model.predict([new_states, next_actions])
        for k, reward in enumerate(rewards):
            y_t[k] = reward + ALPHA * target_q_values[k]

        # Critic update; the returned loss is recorded in self.L.
        critic_loss = self.critic.model.train_on_batch([states, actions], y_t)
        self.L = np.append(self.L, critic_loss)

        # Actor update using the critic's gradients for the actor's own actions.
        a_for_grad = self.actor.model.predict(states)
        action_grads = self.critic.gradients(states, a_for_grad)
        self.actor.train(states, action_grads)

        # Update the target networks.
        self.actor.target_train()
        self.critic.target_train()

    def setResult(self, scenname, state, nextstate, rewards, actions):

        if self.states.get(scenname) is None:
            self.states[scenname] = np.array([])
        if self.next_states.get(scenname) is None:
            self.next_states[scenname] = np.array([])
        if self.reward_per_action.get(scenname) is None:
            self.reward_per_action[scenname] = np.array([])
        if self.actions.get(scenname) is None:
            self.actions[scenname] = np.array([])
        if self.layer_per_action.get(scenname) is None:
            self.layer_per_action[scenname] = np.array([])

        self.states[scenname] = np.append(self.states[scenname], state)
        self.next_states[scenname] = np.append(self.next_states[scenname], nextstate)
        self.reward_per_action[scenname] = np.append(self.reward_per_action[scenname], rewards)
        self.layer_per_action[scenname] = np.append(self.layer_per_action[scenname], np.argmax(actions))
        self.actions[scenname] = np.append(self.actions[scenname], actions)

        # states = [state / NORMALIZATION_FACTOR for state in states]
        # nextstates = [nextstate / NORMALIZATION_FACTOR for nextstate in nextstates]
        state = self.normalizeState(np.asarray(state))
        nextstate = self.normalizeState(np.asarray(nextstate))

        rewards = rewards / 10

        # print('setResult', scenname, state, nextstate, rewards/100, actions)
        # print('rewards', rewards)

        self.buff.add(state, actions, rewards, nextstate)  # Add replay buffer

        self.batch_update()
        self.L_per_step = np.append(self.L_per_step, np.mean(self.L))

        self.waiting_reward = False

    def episode_end(self, scenarioName):
        print('episode end', scenarioName)
        self.episodeStarted = False
        self.waiting_reward = False

        self.numberSteps_all[scenarioName] = self.numberSteps
        if len(scenarioName) == 0:
            return
        repetition = scenarioName.split('Rep')[1]
        repetition = int(repetition.split('.')[0])
        if repetition % 500 == 0:
            TempConfig.save_DDQL('results', "DDPG_critic_" + scenarioName + ".h5", self.critic)
            TempConfig.save_DDQL('results', "DDPG_actor_" + scenarioName + ".h5", self.actor)
            TempConfig.dump_pickle(self.L, 'results/save/loss_' + scenarioName)

        if scenarioName not in self.reward_per_action:
            return

        TempConfig.dump_pickle(self.reward_per_action[scenarioName], 'results/save/reward_' + scenarioName)
        print('max reward', min(self.reward_per_action[scenarioName]))
        # print(scenarioName, 'REWARDS', self.reward_per_action[scenarioName], 'ACTIONS', self.actions[scenarioName])
        # TempConfig.dump_pickle(self.states[scenarioName], 'results/save/states_' + scenarioName)
        TempConfig.dump_pickle(self.actions[scenarioName], 'results/save/actions_' + scenarioName)
        # TempConfig.dump_pickle(self.next_states[scenarioName], 'results/save/next_states_' + scenarioName)