from .base import make_env, load_muzero_network_and_config, MuZeroModel, MuZeroPVMCTSPolicy
import sys
import numpy as np
import pathlib
import torch
import yaml
import random
import time
import multiprocessing

# Integer ids for the action-sampling methods selected by the config's
# ``action_sampling.method`` value:
#   ENUMERATE  ("enumerate")   -> 0
#   SAMPLE     ("sample...")   -> 1
#   ENUMERATE2 ("enumerate2")  -> 2
#   MCTS       ("MCTS")        -> 3
ENUMERATE, SAMPLE, ENUMERATE2, MCTS = range(4)

def make_new_env(name):
    """Build an environment for ``name`` and return the first item of ``make_env``'s result.

    Defined at module level (outside the ``__main__`` guard) so that
    ``multiprocessing.Pool`` workers can look it up by name even under the
    ``spawn`` start method, where the guarded code is not executed in workers.
    """
    return make_env(name)[0]


def _tree_offset(branching, depth):
    """Return 1 + b + b**2 + ... + b**(depth - 1) for ``b = branching``.

    This is the number of nodes in the first ``depth`` levels of a complete
    ``branching``-ary tree. It is therefore also the flat, breadth-first index
    of the first node on level ``depth`` (root = index 0). Integer arithmetic
    keeps the result exact for large trees, and ``branching == 1`` is handled
    without dividing by zero.

    Example: ``_tree_offset(2, 3) == 7`` (1 + 2 + 4).
    """
    if branching == 1:
        return depth
    return (branching ** depth - 1) // (branching - 1)


def _enumerate_rollouts(obs, env, env_pool, muzero_model, policy_to_evaluate,
                        action_space_size, horizon, use_saved_states, timers):
    """Score every length-``horizon`` action sequence from the current state of ``env``.

    The sequences form a complete ``action_space_size``-ary tree that is
    expanded one level per step. Nodes are numbered breadth-first with the
    root at 0 (see ``_tree_offset``). Once a branch terminates in the true
    environment it is no longer stepped: its true reward stays 0, and its
    action probability and nonterminal mask are zeroed. The learned model
    keeps being stepped on every branch, and its rewards are not masked.

    Args:
        obs: current observation of ``env`` (the root of the tree).
        env: environment whose current state is the root.
        env_pool: pre-built environments.
            - ``use_saved_states`` False ("enumerate"): one env per tree node is
              needed, and each child is prepared with ``child.copy(parent)``.
              Presumably ``copy`` loads the argument's state into the receiver.
            - ``use_saved_states`` True ("enumerate2"): exactly one env, rewound
              to the parent's ``get_state()`` snapshot with ``set_state``.
        muzero_model: learned model exposing ``encode`` and ``batch_step``.
        policy_to_evaluate: policy whose ``policy(obs_batch)`` returns a
            ``(batch, action_space_size)`` array of action probabilities.
        action_space_size: number of discrete actions.
        horizon: number of steps per sequence.
        use_saved_states: select the snapshot/restore variant (see above).
        timers: dict with keys "env", "model", "policy"; wall-clock seconds
            are added to it in place.

    Returns:
        ``(true_rewards, muzero_rewards, action_probabilities, nonterminal_masks)``.
        Each has shape ``(action_space_size ** horizon, horizon)``. Row ``j`` is
        the ``j``-th action sequence in lexicographic order, and the value of a
        tree node is repeated over every leaf sequence below it.
    """
    n_leaves = action_space_size ** horizon
    true_r = np.zeros((n_leaves, horizon))
    muzero_r = np.zeros((n_leaves, horizon))
    probs = np.zeros((n_leaves, horizon))
    masks = np.zeros((n_leaves, horizon))

    root_env = env_pool[0]
    if use_saved_states:
        assert len(env_pool) == 1
        node_states = [env.get_state()]  # node_states[k] = snapshot of tree node k
    else:
        root_env.copy(env)
        node_states = None

    # encode the root once; it is replicated level by level below
    model_states = muzero_model.encode(obs)
    dones = np.zeros((1, ))
    eval_obs = np.array([obs])
    for eval_step in range(horizon):
        # action probabilities at every node of the current level
        t0 = time.time()
        policies = policy_to_evaluate.policy(eval_obs)
        timers["policy"] += time.time() - t0
        assert policies.shape == (action_space_size ** eval_step, action_space_size)
        assert len(dones) == policies.shape[0]

        # expand every node of this level into one child per action
        n_evals = action_space_size ** (eval_step + 1)
        dones = np.repeat(dones, action_space_size)
        assert len(dones) == n_evals
        actions = np.tile(np.arange(0, action_space_size), action_space_size ** eval_step)
        assert len(actions) == n_evals

        # step the true environment for every still-running child
        step_rewards = np.zeros((n_evals, ))
        step_dones = np.zeros((n_evals, ))
        eval_obs = np.zeros((n_evals, *obs.shape))
        first_child = _tree_offset(action_space_size, eval_step + 1)
        first_parent = _tree_offset(action_space_size, eval_step)
        t0 = time.time()
        for i in range(n_evals):
            if dones[i] == 1:
                if use_saved_states:
                    node_states.append(None)  # keep flat indices aligned
                continue
            parent = first_parent + i // action_space_size
            if use_saved_states:
                root_env.set_state(node_states[parent])
                eval_obs[i], step_rewards[i], step_dones[i], _ = root_env.step(actions[i])
                node_states.append(root_env.get_state())
            else:
                child_env = env_pool[first_child + i]
                child_env.copy(env_pool[parent])
                eval_obs[i], step_rewards[i], step_dones[i], _ = child_env.step(actions[i])
        timers["env"] += time.time() - t0
        if use_saved_states:
            assert len(node_states) == first_child + n_evals

        # learned-model rewards for the same action sequences (not masked)
        model_states = torch.repeat_interleave(model_states, action_space_size, dim=0)
        t0 = time.time()
        model_states, step_muzero_rewards, _, _ = muzero_model.batch_step(model_states, actions)
        timers["model"] += time.time() - t0

        # terminated branches were never stepped, so their true reward is 0
        assert step_rewards[dones == 1].sum() == 0
        assert (step_rewards == step_rewards * (1 - dones)).all()

        # write node values onto every leaf sequence below each node
        repeat_times = action_space_size ** (horizon - eval_step - 1)
        true_r[:, eval_step] = np.repeat(step_rewards, repeat_times)
        muzero_r[:, eval_step] = np.repeat(step_muzero_rewards, repeat_times)
        masks[:, eval_step] = np.repeat(1 - dones, repeat_times)
        # row-major flatten: parent-major order, matching the tile/repeat above
        masked_policies = np.reshape(policies, (n_evals, )) * (1 - dones)
        assert masked_policies.shape == (n_evals, )
        probs[:, eval_step] = np.repeat(masked_policies, repeat_times)
        dones = np.logical_or(dones, step_dones)

    return true_r, muzero_r, probs, masks


def _policy_action_chooser(policy_to_evaluate, n_rollouts, action_space_size, uniform_random):
    """Build a ``choose_actions(eval_step, eval_obs)`` callback for the "sample" method.

    If ``uniform_random`` is true, actions are drawn with ``np.random.randint``
    and reported with probability 0; the evaluated policy is not queried.
    Otherwise each rollout draws its action from
    ``policy_to_evaluate.policy(eval_obs)`` and reports that action's
    probability.
    """
    def choose_actions(eval_step, eval_obs):
        if uniform_random:
            return (np.random.randint(0, action_space_size, size=(n_rollouts, )),
                    np.zeros((n_rollouts, )))
        eval_policies = policy_to_evaluate.policy(eval_obs)
        actions = np.array([np.random.choice(len(p), p=p) for p in eval_policies])
        return actions, eval_policies[np.arange(0, len(actions)), actions]
    return choose_actions


def _trajectory_action_chooser(trajectories, policy_to_evaluate, compute_action_probability):
    """Build a ``choose_actions(eval_step, eval_obs)`` callback that replays ``trajectories``.

    ``trajectories`` is a ``(n_rollouts, horizon)`` array of actions, and column
    ``eval_step`` is used at step ``eval_step``. If ``compute_action_probability``
    is true, the evaluated policy's probability of each replayed action is
    reported; otherwise zeros are reported.
    """
    n_rollouts = trajectories.shape[0]

    def choose_actions(eval_step, eval_obs):
        actions = trajectories[:, eval_step]
        assert actions.shape == (n_rollouts, )
        if compute_action_probability:
            eval_policies = policy_to_evaluate.policy(eval_obs)
            probs = eval_policies[np.arange(0, len(actions)), actions]
        else:
            probs = np.zeros((n_rollouts, ))
        return actions, probs
    return choose_actions


def _rollout_from_state(obs, env, env_pool, muzero_model, n_rollouts, horizon,
                        choose_actions, timers):
    """Run ``n_rollouts`` parallel rollouts of ``horizon`` steps from ``env``'s current state.

    ``env_pool[i]`` is prepared with ``env_pool[i].copy(env)``; presumably this
    loads ``env``'s state into the pooled env. At every step
    ``choose_actions(eval_step, eval_obs)`` returns ``(actions, probs)``, each of
    shape ``(n_rollouts,)``. Still-running rollouts are then stepped in the true
    environment, and all rollouts are stepped in the learned model. Once a
    rollout terminates its true reward is 0, and its probability and
    nonterminal mask are zeroed. Learned-model rewards are not masked.

    ``timers`` is a dict with keys "env", "model", "policy" that is updated in
    place with wall-clock seconds.

    Returns:
        ``(true_rewards, muzero_rewards, action_probabilities, nonterminal_masks)``,
        each of shape ``(n_rollouts, horizon)``.
    """
    true_r = np.zeros((n_rollouts, horizon))
    muzero_r = np.zeros((n_rollouts, horizon))
    probs_out = np.zeros((n_rollouts, horizon))
    masks = np.zeros((n_rollouts, horizon))

    for env_index in range(n_rollouts):
        env_pool[env_index].copy(env)

    dones = np.zeros((n_rollouts, ))
    eval_obs = np.repeat(obs[np.newaxis, :], n_rollouts, axis=0)
    model_states = torch.repeat_interleave(muzero_model.encode(obs), n_rollouts, dim=0)
    step_dones = np.zeros((n_rollouts, ))

    for eval_step in range(horizon):
        t0 = time.time()
        actions, probs = choose_actions(eval_step, eval_obs)
        timers["policy"] += time.time() - t0
        assert probs.shape == (n_rollouts, )
        assert probs.min() >= 0.0 and probs.max() <= 1.0

        # Fresh buffer every step: terminated rollouts are skipped below, and a
        # reused buffer would keep their last (possibly non-zero) reward.
        step_rewards = np.zeros((n_rollouts, ))
        t0 = time.time()
        for env_index in range(n_rollouts):
            if dones[env_index] == 1:
                continue
            eval_obs[env_index], step_rewards[env_index], step_dones[env_index], _ = \
                env_pool[env_index].step(actions[env_index])
        timers["env"] += time.time() - t0

        not_done_yet = np.logical_not(dones)
        # true rewards
        masked_true_rewards = step_rewards * not_done_yet
        assert (step_rewards == masked_true_rewards).all()
        true_r[:, eval_step] = masked_true_rewards
        # learned-model rewards
        t0 = time.time()
        model_states, muzero_r[:, eval_step], _, _ = muzero_model.batch_step(model_states, actions)
        timers["model"] += time.time() - t0
        # action probabilities and termination mask
        probs_out[:, eval_step] = probs * not_done_yet
        masks[:, eval_step] = not_done_yet
        dones = np.logical_or(dones, step_dones)

    return true_r, muzero_r, probs_out, masks


if __name__ == "__main__":
    # Evaluate how well a learned MuZero model predicts rewards along action
    # sequences, compared with the true environment.
    #
    # Usage: pass the path of a YAML config as the single argument. The file
    # uses a relative import, so it must be run as a module of its package.
    # Config keys read: env_name, run_id (optional), evaluation_horizon,
    # num_states_to_sample, seed (optional, default 0), action_sampling,
    # evaluations (list of {model_step, policy_step, state_distribution}).
    #
    # For each evaluation, one .npy array of shape
    # (num_states, rollouts_per_state, 4, evaluation_horizon) is saved under
    # policy_evaluation/results/test_value_prediction_error/<env_name>/. Axis 2
    # holds [true rewards, muzero rewards, action probabilities, nonterminal masks].

    # configurations
    config_path = sys.argv[1]
    with open(config_path, "r") as config_file:
        exp_config = yaml.safe_load(config_file)
    env_name = exp_config["env_name"]
    run_id = exp_config.get("run_id", None)
    evaluation_horizon = exp_config["evaluation_horizon"]
    N_states = exp_config["num_states_to_sample"]
    seed = exp_config.get("seed", 0)

    # seeding
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print("seed:", seed)

    # action sampling method (validated up front so a typo fails immediately)
    action_sampling = exp_config["action_sampling"]
    method_name = action_sampling["method"]
    if method_name == "enumerate":
        action_sampling_method = ENUMERATE
    elif method_name.startswith("sample"):
        action_sampling_method = SAMPLE
    elif method_name == "enumerate2":
        action_sampling_method = ENUMERATE2
    elif method_name == "MCTS":
        action_sampling_method = MCTS
    else:
        raise ValueError(f"Unknown action sampling method {method_name}")
    action_space_size = make_env(env_name)[0].action_space_size

    if action_sampling_method == ENUMERATE:
        # one pooled env per node of the full action tree, root included
        num_envs_needed = _tree_offset(action_space_size, evaluation_horizon + 1)
        print(num_envs_needed)
        N_rollouts_per_state = action_space_size ** evaluation_horizon
    elif action_sampling_method == ENUMERATE2:
        # a single simulator rewound with get_state/set_state
        num_envs_needed = 1
        N_rollouts_per_state = action_space_size ** evaluation_horizon
    elif action_sampling_method == SAMPLE:
        N_rollouts_per_state = action_sampling["N_rollouts_per_state"]
        # NOTE(review): noise_rate only has an effect when it equals 1.0 AND
        # compute_action_probability is false (uniform random actions).
        # Otherwise actions are sampled from the evaluated policy with no
        # injected noise, although the value still appears in the result file
        # name. Confirm this is intended.
        noise_rate = action_sampling["noise_rate"]
        num_envs_needed = N_rollouts_per_state
        compute_action_probability = action_sampling.get("compute_action_probability", True)
    else:  # MCTS: one rollout per MCTS simulation trajectory
        N_rollouts_per_state = action_sampling["num_simulations"]
        num_envs_needed = N_rollouts_per_state
        if evaluation_horizon != action_sampling["rollout_horizon"]:
            raise ValueError(
                f"evaluation_horizon ({evaluation_horizon}) must equal "
                f"action_sampling.rollout_horizon ({action_sampling['rollout_horizon']})")
        compute_action_probability = action_sampling.get("compute_action_probability", True)

    # env pool
    start_env_pool = time.time()
    env_names = [env_name] * num_envs_needed
    with multiprocessing.Pool() as pool:
        env_pool = pool.map(make_new_env, env_names)
    print(f"building {num_envs_needed} env took {time.time() - start_env_pool} seconds")

    # iterate over all evaluations
    print("===========================================")
    for evaluation in exp_config["evaluations"]:

        start_time = time.time()

        print(evaluation)
        model_step = evaluation["model_step"]
        policy_step = evaluation["policy_step"]
        state_distribution = evaluation["state_distribution"]

        print(f"evaluating run id {run_id} with model step {model_step} and policy step {policy_step} and state distribution {state_distribution}")
        print('-------------------------------------------')

        # network/config providing the learned model
        muzero_network_for_model, muzero_config_for_model = load_muzero_network_and_config(
            env_name,
            run_id,
            model_step)

        # network/config of the policy to evaluate
        muzero_network_for_policy_to_evaluate, muzero_config_for_policy_to_evaluate = load_muzero_network_and_config(
            env_name,
            run_id,
            policy_step)

        # network/config whose policy generates the visited states
        if state_distribution == "model":
            muzero_network_for_state_sampling = muzero_network_for_model
            muzero_config_for_state_sampling = muzero_config_for_model
            state_sampling_policy_step = model_step
        elif state_distribution == "policy":
            muzero_network_for_state_sampling = muzero_network_for_policy_to_evaluate
            muzero_config_for_state_sampling = muzero_config_for_policy_to_evaluate
            state_sampling_policy_step = policy_step
        else:
            # could also be the initial state distribution
            raise ValueError(f"Unknown state distribution {state_distribution}")

        # the learned model
        muzero_model = MuZeroModel(muzero_network_for_model)

        # the policy to evaluate; enumeration turns on add_one_visit_to_every_action,
        # presumably so every enumerated action gets non-zero probability (confirm in MuZeroPVMCTSPolicy)
        temperature_for_policy_to_evaluate = muzero_config_for_policy_to_evaluate.visit_softmax_temperature_fn(num_moves=0, trained_steps=int(policy_step))
        add_one_visit_to_every_action = action_sampling_method in (ENUMERATE, ENUMERATE2)
        policy_to_evaluate_params = {
            "deterministic": False,
            "temperature": temperature_for_policy_to_evaluate,
        }
        policy_to_evaluate = MuZeroPVMCTSPolicy(
            muzero_network_for_policy_to_evaluate,
            muzero_config_for_policy_to_evaluate,
            params=policy_to_evaluate_params,
            add_one_visit_to_every_action=add_one_visit_to_every_action)
        print("evaluation policy", policy_to_evaluate_params)

        # the state sampling policy
        if action_sampling_method == MCTS:
            # MCTS settings come from the action_sampling config, falling back to the network config
            temperature = action_sampling.get("temperature", muzero_config_for_state_sampling.visit_softmax_temperature_fn(num_moves=0, trained_steps=int(state_sampling_policy_step)))
            num_simulations = action_sampling.get("num_simulations", muzero_config_for_state_sampling.num_simulations)
            deterministic = action_sampling.get("deterministic", False)
            state_sampling_policy_params = {"deterministic": deterministic}
            state_sampling_policy_params["num_simulations"] = num_simulations
            state_sampling_policy_params["temperature"] = temperature
            state_sampling_policy_params["use_uniform_prior"] = action_sampling.get("use_uniform_prior", False)
            state_sampling_policy_params["use_true_model"] = action_sampling.get("use_true_model", False)
            state_sampling_policy_params["rollout_horizon"] = action_sampling.get("rollout_horizon", -1)
            state_sampling_policy = MuZeroPVMCTSPolicy(
                muzero_network=muzero_network_for_state_sampling,
                muzero_config=muzero_config_for_state_sampling,
                params=state_sampling_policy_params
            )
            print("state sampling policy", state_sampling_policy_params)
        else:
            temperature_for_state_sampling_policy = muzero_config_for_state_sampling.visit_softmax_temperature_fn(num_moves=0, trained_steps=int(state_sampling_policy_step))
            state_sampling_policy_params = {
                "deterministic": False,
                "temperature": temperature_for_state_sampling_policy,
            }
            state_sampling_policy = MuZeroPVMCTSPolicy(
                muzero_network_for_state_sampling,
                muzero_config_for_state_sampling,
                params=state_sampling_policy_params)
            print("state sampling policy", state_sampling_policy_params)

        with torch.inference_mode():

            # DESIGN
            # 1. no masking on muzero rewards for future research
            # 2. masking on true rewards (to 0.0)
            # 3. masking on action probabilities (to 0.0)
            # 4. assume env.step does not affect np_random

            true_rewards = np.zeros((N_states, N_rollouts_per_state, evaluation_horizon))
            muzero_rewards = np.zeros((N_states, N_rollouts_per_state, evaluation_horizon))
            action_probabilities = np.zeros((N_states, N_rollouts_per_state, evaluation_horizon))
            nonterminal_masks = np.zeros((N_states, N_rollouts_per_state, evaluation_horizon))

            # accumulated wall-clock seconds, reported at the end
            timers = {"env": 0, "model": 0, "policy": 0}

            # the environment that walks through the sampled states
            env_seed = random.randint(0, 1000000)
            env, _ = make_env(env_name, seed=env_seed)
            obs = env.reset()

            # outer loop: sample states
            state_count = 0
            while state_count < N_states:

                # Model evaluation from the current state, then pick the transition action.
                # Call order is kept as is (rollouts before/after act) so RNG
                # consumption, and hence seeded results, are unchanged.
                # TODO: parallelize the environment step
                if action_sampling_method in (ENUMERATE, ENUMERATE2):
                    rollout = _enumerate_rollouts(
                        obs, env, env_pool, muzero_model, policy_to_evaluate,
                        action_space_size, evaluation_horizon,
                        use_saved_states=(action_sampling_method == ENUMERATE2),
                        timers=timers)
                    action = state_sampling_policy.act(obs[np.newaxis, :])[0]
                elif action_sampling_method == SAMPLE:
                    choose_actions = _policy_action_chooser(
                        policy_to_evaluate, N_rollouts_per_state, action_space_size,
                        uniform_random=(not compute_action_probability and noise_rate == 1.0))
                    rollout = _rollout_from_state(
                        obs, env, env_pool, muzero_model, N_rollouts_per_state,
                        evaluation_horizon, choose_actions, timers)
                    action = state_sampling_policy.act(obs[np.newaxis, :])[0]
                else:  # MCTS: replay the search trajectories behind the chosen action
                    extra_info = {}
                    action = state_sampling_policy.act(obs[np.newaxis, :], extra_info=extra_info)[0]
                    # axis 1 indexed with 0 -- presumably the batch axis of the single observation
                    trajectories = np.array(extra_info["trajectories"])[:, 0, :]
                    assert trajectories.shape == (N_rollouts_per_state, evaluation_horizon)
                    choose_actions = _trajectory_action_chooser(
                        trajectories, policy_to_evaluate, compute_action_probability)
                    rollout = _rollout_from_state(
                        obs, env, env_pool, muzero_model, N_rollouts_per_state,
                        evaluation_horizon, choose_actions, timers)

                (true_rewards[state_count],
                 muzero_rewards[state_count],
                 action_probabilities[state_count],
                 nonterminal_masks[state_count]) = rollout

                # transition
                print(obs.mean(), action)
                obs, _, done, _ = env.step(action)

                if done:
                    obs = env.reset()

                state_count += 1
                print(state_count, flush=True)

        del muzero_network_for_model
        del muzero_config_for_model
        del muzero_network_for_policy_to_evaluate
        del muzero_config_for_policy_to_evaluate

        # save result
        results_array = np.stack([true_rewards, muzero_rewards, action_probabilities, nonterminal_masks], axis=2)
        assert results_array.shape == (N_states, N_rollouts_per_state, 4, evaluation_horizon)

        print(results_array.mean(axis=(0, 1, 3)))

        results_folder_path = "policy_evaluation/results/test_value_prediction_error/{}".format(env_name)
        pathlib.Path(results_folder_path).mkdir(parents=True, exist_ok=True)

        results_name = "runid_{}_modelstep_{}_policystep_{}_state_{}_Nstates_{}_Nrollouts_{}_horizon_{}".format(
            run_id,
            model_step,
            policy_step,
            state_distribution,
            N_states,
            N_rollouts_per_state,
            evaluation_horizon
        )
        if action_sampling_method == SAMPLE:
            results_name += "_sample_noiseratio_{}".format(noise_rate)
        elif action_sampling_method == ENUMERATE:
            results_name += "_enumerate"
        elif action_sampling_method == ENUMERATE2:
            results_name += "_enumerate2"
        elif action_sampling_method == MCTS:
            # use_uniform_prior is optional (same default as the policy setup above);
            # indexing it directly crashed here after the whole run had finished
            results_name += "_MCTS_numsim_{}_horizon_{}_uniformprior_{}".format(
                action_sampling["num_simulations"],
                action_sampling["rollout_horizon"],
                action_sampling.get("use_uniform_prior", False))

        results_path = results_folder_path + "/{}.npy".format(results_name)
        np.save(file=results_path, arr=results_array)
        print(f"env time: {timers['env']} | model time: {timers['model']} | policy time: {timers['policy']}")
        print(f"results saved to {results_path}")
        print(f"time elapsed: {time.time() - start_time}")

        print("===========================================")