import sys
import numpy as np
import pathlib
import yaml
import time
import random
import gym
import torch
import copy

# import sys
# # add the current file path to sys.path
# import os
# sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from .base import make_env, load_muzero_network_and_config, MonteCarloPolicyEvaluation, MuZeroRawPolicy, MuZeroPVMCTSPolicy

def load_policy(policy, model_step=None, muzero_config=None, muzero_network=None):
    """Construct the evaluation policy described by the ``policy`` options.

    Args:
        policy: Mapping of policy options. ``use_MCTS`` (default ``False``)
            selects ``MuZeroPVMCTSPolicy``; otherwise ``MuZeroRawPolicy`` is
            built. ``deterministic`` (default ``False``) is forwarded in both
            cases. The remaining keys read below only apply to the MCTS policy.
        model_step: Training step of the loaded model. Only used, converted
            with ``int``, to derive the default visit-softmax temperature when
            ``policy`` does not provide ``temperature``.
        muzero_config: MuZero configuration object. It is deep-copied first, so
            the optional ``pb_c_init`` override never leaks into the caller's
            object.
        muzero_network: Network handed to the policy. Required for the raw
            (non-MCTS) policy.

    Returns:
        The constructed ``MuZeroRawPolicy`` or ``MuZeroPVMCTSPolicy``.

    Raises:
        AssertionError: If the raw policy is requested without a network.
    """
    muzero_config = copy.deepcopy(muzero_config)
    deterministic = policy.get("deterministic", False)
    params = {"deterministic": deterministic}

    if not policy.get("use_MCTS", False):
        assert muzero_network is not None
        raw_policy = MuZeroRawPolicy(
            muzero_network=muzero_network,
            params={"deterministic": deterministic},
        )
        print(f"MuZero policy with config {params} is constructed.")
        return raw_policy

    # Resolve config-derived defaults lazily: ``dict.get(key, default)`` always
    # evaluates ``default``, which used to fail on ``int(None)`` / a missing
    # config even when the policy options overrode both values.
    if "temperature" in policy:
        temperature = policy["temperature"]
    else:
        temperature = muzero_config.visit_softmax_temperature_fn(
            num_moves=0, trained_steps=int(model_step)
        )
    if "num_simulations" in policy:
        num_simulations = policy["num_simulations"]
    else:
        num_simulations = muzero_config.num_simulations

    # Key order is kept stable because ``params`` is printed below.
    params["num_simulations"] = num_simulations
    params["temperature"] = temperature
    params["use_uniform_prior"] = policy.get("use_uniform_prior", False)
    params["use_true_model"] = policy.get("use_true_model", False)
    params["rollout_horizon"] = policy.get("rollout_horizon", -1)
    params["with_exploration_noise"] = policy.get("with_exploration_noise", True)
    params["log_prob_error"] = policy.get("log_prob_error", False)

    if "pb_c_init" in policy:
        # Applied to the private deep copy only.
        muzero_config.pb_c_init = policy["pb_c_init"]
        print("pb_c_init is set to {}".format(muzero_config.pb_c_init))

    mcts_policy = MuZeroPVMCTSPolicy(
        muzero_network=muzero_network,
        muzero_config=muzero_config,
        params=params,
    )
    print(f"MCTS policy with config {params} is constructed.")
    return mcts_policy

def _build_vector_env(env_name, env_seed, num_episodes, synchronous):
    """Create a vector env of ``num_episodes`` seeded copies of ``env_name``.

    Copy ``i`` is produced by ``make_env(env_name, seed=env_seed + i)[0]``.
    A ``gym.vector.SyncVectorEnv`` is returned when ``synchronous`` is true
    (the rollout loop reads ``envs.envs`` in that mode), otherwise a
    ``gym.vector.AsyncVectorEnv``.
    """
    # The index is bound as a default argument so each factory keeps its own offset.
    env_fns = [
        lambda x=ep_i: make_env(env_name, seed=env_seed + x)[0]
        for ep_i in range(num_episodes)
    ]
    vector_cls = gym.vector.SyncVectorEnv if synchronous else gym.vector.AsyncVectorEnv
    return vector_cls(env_fns)


def _run_rollouts(policy, envs, obs, num_episodes, augment_state_with_env, env_pool):
    """Step ``envs`` with ``policy`` until every slot has finished one episode.

    Args:
        policy: Object exposing ``act``; called as ``act((obs, env_states),
            env_pool=env_pool)`` when ``augment_state_with_env`` is true and as
            ``act(obs)`` otherwise.
        envs: Vector env that has already been reset.
        obs: Observations returned by that reset.
        num_episodes: Number of slots in ``envs``.
        augment_state_with_env: Whether to pass per-slot env states to the policy.
        env_pool: Forwarded to ``policy.act`` in the augmented mode.

    Returns:
        ``(results, env_time, policy_time)`` where ``results[i]`` is the reward
        accumulated by slot ``i`` up to its first ``done``, and the two timings
        are the wall-clock seconds spent in ``envs.step`` and ``policy.act``.
    """
    results = [None] * num_episodes
    episodic_rewards = np.zeros((num_episodes, ))
    results_count = 0
    env_time = 0
    policy_time = 0
    count = 0

    while results_count < num_episodes:
        start_policy = time.time()
        if augment_state_with_env:
            # Slots that already reported a result are flagged (None, True);
            # the others contribute their current env state.
            env_states = [
                (env.get_state(), False) if results[env_i] is None else (None, True)
                for env_i, env in enumerate(envs.envs)
            ]
            actions = policy.act((obs, env_states), env_pool=env_pool)
        else:
            actions = policy.act(obs)
        policy_time += time.time() - start_policy

        start_env = time.time()
        obs, rewards, dones, _ = envs.step(actions)
        env_time += time.time() - start_env
        episodic_rewards += rewards

        for i, done in enumerate(dones):
            # Only the first completed episode of each slot is recorded.
            if done and results[i] is None:
                results[i] = episodic_rewards[i]
                results_count += 1
                print("num", results_count, "index", i, "rewards", episodic_rewards[i], "length", count+1, flush=True)
        count += 1

    return results, env_time, policy_time


if __name__ == "__main__":

    # configurations
    if len(sys.argv) < 2:
        sys.exit("usage: pass the path of the experiment YAML config as the first argument")
    config_path = sys.argv[1]
    with open(config_path, "r") as config_file:
        exp_config = yaml.safe_load(config_file)
    env_name = exp_config["env_name"]
    run_id = exp_config.get("run_id", None)
    model_steps = exp_config.get("model_steps", None)
    policies = exp_config["policies"]
    num_episodes = exp_config["num_episodes"]
    seed = exp_config.get("seed", 0)
    log_prob_error = exp_config.get("log_prob_error", False)

    if model_steps is None:
        # "model_steps" is optional in the config; run a single pass instead of
        # failing on iterating None.
        # NOTE(review): with a run_id set, None is forwarded to
        # load_muzero_network_and_config -- confirm that the loader accepts it.
        model_steps = [None]

    if log_prob_error:
        # The rollout below was only ever executed when this flag was False;
        # with it enabled no episode ran and every result stayed None, which
        # crashed later inside np.average. Fail early with a clear message.
        raise NotImplementedError(
            "log_prob_error evaluation is not implemented in this script; "
            "set log_prob_error to false"
        )

    # seeding
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print("seed:", seed)

    # build env pool (only when at least one policy asks for the true model)
    env_pool = None
    if any(cfg.get("use_true_model", False) for cfg in policies.values()):
        env_pool = [make_env(env_name)[0] for _ in range(num_episodes)]
        print(f"env pool with {num_episodes} envs is constructed.")

    print("===========================================")
    for model_step in model_steps:

        # Without a run_id nothing is loaded; keep both names bound so that
        # load_policy receives explicit Nones instead of raising NameError.
        muzero_network, muzero_config = None, None
        if run_id is not None:
            muzero_network, muzero_config = load_muzero_network_and_config(env_name, run_id, model_step)

        for policy_key, policy_config in policies.items():

            print(f"evaluating run_id: {run_id}, model_step: {model_step}, policy: {policy_key}")
            print("-------------------------------------------")

            start_time = time.time()
            _, horizon = make_env(env_name)
            policy = load_policy(policy_config, model_step, muzero_config, muzero_network)

            # MCTS on the true model needs the env states next to the observations.
            augment_state_with_env = bool(
                policy_config.get("use_true_model", False) and policy_config.get("use_MCTS", False)
            )

            evaluation_config = {"horizon": horizon, "n_episodes": num_episodes, "augment_state_with_env": augment_state_with_env}
            print(f"evaluation config: {evaluation_config}")
            print("-------------------------------------------")

            with torch.inference_mode():
                env_seed = random.randint(0, 1000000)
                print("environment seed:", env_seed)
                envs = _build_vector_env(
                    env_name, env_seed, num_episodes, synchronous=augment_state_with_env
                )
                try:
                    obs = envs.reset()
                    results, env_time, policy_time = _run_rollouts(
                        policy, envs, obs, num_episodes, augment_state_with_env, env_pool
                    )
                finally:
                    # Always release the vector env, even if a rollout fails.
                    envs.close()

            print(f"env time: {env_time}, policy time: {policy_time}")
            print(f"average: {np.average(results)}, ste: {np.std(results)/np.sqrt(len(results))}")

            results_name = policy_key
            if run_id is not None:
                results_name += "_runid_{}_modelstep_{}".format(run_id, model_step)

            results_folder_path = "policy_evaluation/results/test_policies/{}".format(env_name)
            pathlib.Path(results_folder_path).mkdir(parents=True, exist_ok=True)
            results_path = results_folder_path+"/{}.npy".format(results_name)
            np.save(results_path, results)
            print(f"results saved to {results_path}")
            print(f"time elapsed: {time.time() - start_time}")

            print("===========================================")