import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Drop every search-path entry that points into the user-local site directory.
# The list is rebuilt in place (slice assignment keeps the same list object)
# instead of calling remove() while iterating over it: removing during
# iteration shifts the remaining items and silently skips the entry that
# follows each removed one.
sys.path[:] = [p for p in sys.path if 'nfs/jinkehe/.local' not in p]

import wandb
api = wandb.Api()

import importlib
import itertools
import math
import random

import copy
import json
import pathlib

import gym
import numpy as np
import torch
print("torch version:", torch.__version__)
from torch.distributions.categorical import Categorical
import multiprocessing as mp

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

import pickle
from torch.cuda.amp import autocast as autocast

# DIFF: efficient zero
from types import SimpleNamespace
from core.utils import *
from core.mcts import MCTS
import core.ctree.cytree as cytree
# DIFF

""" Policy Evaluation: The Core Policy Evaluation Methods """

# batched version of Monte Carlo Policy Evaluation (supposed to be faster when a GPU is used)
def BatchedMonteCarloPolicyEvaluation(
    model,
    policy,
    params
):
    """Roll out ``n_episodes`` episodes in lock-step and return their returns.

    Args:
        model: object providing ``batch_reset(n)`` and
            ``batch_step(state, action)``; a ``BatchedGymModel`` must also
            provide ``states(state)``.
        policy: object whose ``act(states)`` returns one action per episode.
        params: dict with the keys ``"n_episodes"`` and ``"horizon"``.

    Returns:
        A float array of shape ``(n_episodes,)`` with the undiscounted sum of
        rewards collected by each episode.
    """
    n_episodes = params["n_episodes"]
    horizon = params["horizon"]

    # A BatchedGymModel carries (env, done) pairs as its "state", so the policy
    # has to be fed the observations extracted from them. The check does not
    # change during the rollout, so it is evaluated once.
    extract_states = isinstance(model, BatchedGymModel)

    returns = np.zeros((n_episodes, ))
    s = model.batch_reset(n_episodes)
    count = 0
    # NOTE: `<=` makes this loop take horizon + 1 steps.
    while count <= horizon:
        action = policy.act(model.states(s) if extract_states else s)

        # done is not used here; the model is expected to handle termination
        # itself (e.g. by emitting zero reward for finished episodes).
        s, r, done, _ = model.batch_step(s, action)
        assert len(r) == len(returns)
        returns += r
        count += 1
    return returns

def MonteCarloPolicyEvaluation(
    model,
    policy,
    params
):
    """Run a single episode and return its undiscounted return.

    ``params`` must contain ``"augment_state_with_env"`` (when true the policy
    receives ``[(observation_batch, model)]`` instead of the bare observation
    batch) and ``"horizon"`` (at most ``horizon + 1`` steps are taken).
    NumPy's global RNG is re-seeded from OS entropy before the rollout.
    """
    with torch.inference_mode():
        fresh_seed = int.from_bytes(os.urandom(4), byteorder='little')
        np.random.seed(fresh_seed)
        pass_env_to_policy = params["augment_state_with_env"]

        # assumptions: true gym model + MCTS policy
        horizon = params["horizon"]
        episode_return = 0
        steps_taken = 0
        observation = model.reset()
        while steps_taken <= horizon:
            # The policy works on batches, so wrap the observation in a
            # batch of size one and unwrap the single action afterwards.
            batch = np.array([observation])
            policy_input = [(batch, model)] if pass_env_to_policy else batch
            chosen = policy.act(policy_input)[0]
            observation, reward, finished, _ = model.step(chosen)
            episode_return += reward
            if finished:
                break
            steps_taken += 1
        return episode_return

def MonteCarloPolicyEvaluationSaved(
    model,
    policy,
    params
):
    """Run a single episode, recording a snapshot before every step.

    Behaves like ``MonteCarloPolicyEvaluation`` but additionally deep-copies
    ``(model, observation)`` before each action is chosen, and prints the
    final step counter together with the return.

    Returns:
        ``(episode_return, snapshots)`` where ``snapshots`` is the list of
        ``(model_copy, observation_copy)`` tuples in visiting order.
    """
    snapshots = []
    with torch.inference_mode():
        fresh_seed = int.from_bytes(os.urandom(4), byteorder='little')
        np.random.seed(fresh_seed)
        pass_env_to_policy = params["augment_state_with_env"]

        # assumptions: true gym model + MCTS policy
        horizon = params["horizon"]
        episode_return = 0
        steps_taken = 0
        observation = model.reset()
        while steps_taken <= horizon:
            # Snapshot first, so each entry reflects the situation the policy
            # was asked to act in.
            snapshot = (copy.deepcopy(model), copy.deepcopy(observation))
            snapshots.append(snapshot)
            batch = np.array([observation])
            policy_input = [(batch, model)] if pass_env_to_policy else batch
            chosen = policy.act(policy_input)[0]
            observation, reward, finished, _ = model.step(chosen)
            episode_return += reward
            if finished:
                break
            steps_taken += 1
        print(steps_taken, episode_return)
        return episode_return, snapshots

def ParallelizedMonteCarloPolicyEvaluation(
    model,
    policy,
    params
):
    """Evaluate ``params["n_episodes"]`` episodes and return their returns.

    NOTE: despite the name, episodes currently run one after another. An
    earlier ``mp.Pool`` variant (sized from ``SLURM_JOB_CPUS_PER_NODE``) was
    disabled. Progress is printed before every episode.
    """
    episode_returns = []
    for episode in range(params["n_episodes"]):
        print(episode, "doing")
        outcome = MonteCarloPolicyEvaluation(model, policy, params)
        episode_returns.append(outcome)

    return np.array(episode_returns)

def SequentialMonteCarloPolicyEvaluation(
    model,
    policy,
    params
):
    """Run ``params["n_episodes"]`` episodes back to back.

    Returns a NumPy array holding the return of each episode, in order.
    """
    episode_count = params["n_episodes"]
    return np.array([
        MonteCarloPolicyEvaluation(model, policy, params)
        for _ in range(episode_count)
    ])

""" Policy Evaluation: The Models """

# DIFF: MuZero's model
class MuZeroModel:
    """Expose a learned MuZero network through a simulator-like interface.

    States are the latent tensors produced by the network's
    ``representation`` function; transitions and rewards come from its
    ``dynamics`` function. The wrapper never signals termination: ``done`` is
    always ``False`` and ``info`` is always ``None``.
    """

    def __init__(self, muzero_network):
        self.muzero_network = muzero_network
        # Every tensor this wrapper creates is placed on the network's device.
        self.device = next(self.muzero_network.parameters()).device

    def _zero_reward_hidden(self, batch_size):
        """Return a pair of zero tensors shaped (1, batch_size, lstm_hidden_size).

        The reward hidden state is rebuilt from zeros on every step, so no
        hidden state is carried from one call to the next.
        """
        shape = (1, batch_size, self.muzero_network.lstm_hidden_size)
        return (
            torch.zeros(shape, dtype=torch.float32, device=self.device),
            torch.zeros(shape, dtype=torch.float32, device=self.device),
        )

    # given the state and action, return the next state, reward, done, and info
    def step_(self, state, action):
        """Advance a latent state by a single scalar ``action``.

        Returns ``(next_state, reward, False, None)`` with ``reward`` as a
        Python float taken from the first row of the model output.
        """
        action = torch.from_numpy(np.array([[action]])).to(self.device)
        reward_hidden = self._zero_reward_hidden(state.size(0))
        next_state, _, value_prefix = self.muzero_network.dynamics(state, reward_hidden, action)
        reward = self.muzero_network.inverse_reward_transform(value_prefix).detach().cpu()[0].item()
        return next_state, reward, False, None

    def encode(self, z):
        """Encode one observation (a batch dimension of size one is added)."""
        return self.muzero_network.representation(torch.from_numpy(np.array([z], dtype=np.float32)).to(self.device))

    def batch_encode(self, z):
        """Encode an already batched sequence of observations."""
        return self.muzero_network.representation(torch.from_numpy(np.array(z, dtype=np.float32)).to(self.device))

    def batch_reset(self, n):
        """Not supported: the wrapper holds no environment to sample from."""
        raise NotImplementedError

    def batch_step(self, z, action):
        """Advance a batch of latent states by one action each.

        Args:
            z: batch of latent states (``len(z)`` is the batch size).
            action: sequence with one action per latent state.

        Returns:
            ``(next_state, rewards, False, None)`` where ``rewards`` is a 1-D
            NumPy array with one entry per batch element.
        """
        batch_size = len(z)
        assert len(z) == len(action) == batch_size
        # Use the network's device (not the module-level default) so the
        # action tensor always lives next to the latent state.
        action = torch.from_numpy(np.array(action)).to(self.device).unsqueeze(dim=1)
        reward_hidden = self._zero_reward_hidden(batch_size)
        next_state, _, value_prefix = self.muzero_network.dynamics(z, reward_hidden, action)
        reward = self.muzero_network.inverse_reward_transform(value_prefix).detach().cpu().numpy()
        assert len(reward) == batch_size
        return next_state, reward[:, 0], False, None

# batched gym model 
class BatchedGymModel:
    """Step several independent copies of one environment in lock-step.

    The batched "state" is a list of ``[env, done]`` pairs. Once an episode
    is done its environment is no longer stepped and it yields zero reward.
    """

    def __init__(self, env):
        self.env = env

    def batch_reset(self, n):
        """Return ``n`` freshly reset deep copies of the template env."""
        envs_dones = []
        for _ in range(n):
            env_copy = copy.deepcopy(self.env)
            env_copy.reset()
            envs_dones.append([env_copy, False])
        return envs_dones

    def batch_step(self, z, action):
        """Apply ``action[i]`` to the i-th environment that is still running.

        Args:
            z: list of ``[env, done]`` pairs; updated in place.
            action: sequence with one action per environment.

        Returns:
            ``(z, rewards, dones, None)`` where ``rewards`` is a float array
            (zero for finished episodes) and ``dones`` a boolean array.
        """
        rs = np.zeros((len(action), ))
        dones = []
        for i, act in enumerate(action):
            env, already_done = z[i]
            # Test truthiness rather than `is False`: environments commonly
            # return numpy.bool_ flags, for which the identity check never
            # succeeds and the episode would be frozen after its first step.
            if already_done:
                done = True
            else:
                _, r, done, _ = env.step(act)
                rs[i] = r
                # Normalise to a plain bool so later checks are unambiguous.
                done = bool(done)
            dones.append(done)
            z[i][1] = done
        return z, rs, np.array(dones), None

    def states(self, z):
        """Stack the ``state`` attribute of every env into a float32 array."""
        return np.array([env[0].state for env in z], dtype=np.float32)

""" Policy Evaluation: The Policies """

class MuZeroRawPolicy:
    """Act directly from the network's policy head, without any search.

    ``params["deterministic"]`` selects between greedy (argmax) actions and
    actions sampled from the softmax distribution.
    """

    def __init__(self, muzero_network, params):
        self.muzero_network = muzero_network
        self.params = params

    def act(self, z):
        """Return one action per row of ``z`` as a NumPy array.

        Args:
            z: either a latent-state tensor, or a NumPy batch of raw
                observations that is first passed through the network's
                ``representation`` function.
        """
        with torch.inference_mode():
            # if the state is from the environment, transform it into a latent state first
            if not torch.is_tensor(z):
                # Place the observation on the network's own device instead of
                # the module-level default, so the two can never disagree.
                net_device = next(self.muzero_network.parameters()).device
                z = torch.from_numpy(z).to(torch.float32).to(net_device)
                z = self.muzero_network.representation(z)
            policy_logits = self.muzero_network.prediction(z)[0]
            softmax_outputs = torch.nn.functional.softmax(policy_logits, dim=1)

            if self.params["deterministic"] is True:
                actions = torch.argmax(softmax_outputs, dim=1).detach().cpu().numpy()
            else:
                policy = Categorical(softmax_outputs)
                probabilities = policy.probs.detach().cpu().numpy()
                num_actions = probabilities.shape[1]
                # Sample with NumPy's global RNG (the one the evaluation
                # routines seed) rather than torch's.
                actions = np.array(
                    [np.random.choice(num_actions, p=row) for row in probabilities]
                )
            # the output is always a numpy array, in both branches
            return actions

class MuZeroPVMCTSPolicy:
    """Policy that turns the result of a batched MCTS into action probabilities.

    The search is delegated to ``cytree.Roots`` and ``MCTS(config).search``;
    this class prepares the roots from the network's initial inference and
    converts what ``roots.get_distributions()`` returns into a
    temperature-scaled distribution.
    """

    def __init__(self, muzero_network, muzero_config, params, add_one_visit_to_every_action=False):
        self.muzero_network = muzero_network
        self.muzero_config = muzero_config
        # dict of search/evaluation options; "temperature" and "deterministic"
        # are required, the remaining keys read in policy() have defaults
        self.params = params
        # when True, every action's count is incremented by one before the
        # counts are converted into probabilities
        self.add_one_visit_to_every_action = add_one_visit_to_every_action

    def policy(self, z, env_pool=None, extra_info=None):
        """Run the search for a batch and return per-action probabilities.

        Args:
            z: NumPy observation batch, or (when ``params["use_true_model"]``
                is set) a pair ``(observations, env_states)``.
            env_pool: must not be None when the true model is used; forwarded
                to the search.
            extra_info: forwarded to the search unchanged.

        Returns:
            Array of shape ``(batch, config.action_space_size)``.
        """
        with torch.inference_mode():
            # Work on a private copy so overriding num_simulations below does
            # not modify self.muzero_config.
            config = copy.deepcopy(self.muzero_config)

            num_simulations = self.params.get("num_simulations", self.muzero_config.num_simulations)
            use_uniform_prior = self.params.get("use_uniform_prior", False)
            use_true_model = self.params.get("use_true_model", False)
            rollout_horizon = self.params.get("rollout_horizon", -1)
            with_exploration_noise = self.params.get("with_exploration_noise", True)
            if use_true_model:
                assert env_pool is not None
                z, env_states = z

            """ construct the root node """
            config.num_simulations = num_simulations
            x = torch.from_numpy(z).to(torch.float32).to(next(self.muzero_network.parameters()).device)
            with autocast():
                network_output = self.muzero_network.initial_inference(x)
            hidden_state_roots = network_output.hidden_state
            reward_hidden_roots = network_output.reward_hidden
            value_prefix_pool = network_output.value_prefix
            if use_uniform_prior:
                # all-zero logits in place of the network's policy output
                policy_logits_pool = np.zeros_like(network_output.policy_logits).tolist()
            else:
                policy_logits_pool = network_output.policy_logits.tolist()
            # The tree's seed is drawn from NumPy's global RNG.
            MCTS_seed = np.random.randint(0, 2**32 - 1)
            roots = cytree.Roots(x.shape[0], config.action_space_size, config.num_simulations, seed=MCTS_seed)
            if with_exploration_noise:
                # print("with exploration noise")
                # one Dirichlet sample per batch element
                noises = [np.random.dirichlet([config.root_dirichlet_alpha] * config.action_space_size).astype(np.float32).tolist() for _ in range(x.shape[0])]
                roots.prepare(config.root_exploration_fraction, noises, value_prefix_pool, policy_logits_pool)
            else:
                # print("without exploration noise")
                roots.prepare_no_noise(value_prefix_pool, policy_logits_pool)
            if use_true_model:
                MCTS(config).search(roots, self.muzero_network, hidden_state_roots, reward_hidden_roots, num_simulations=num_simulations, use_uniform_prior=use_uniform_prior, env_pool=env_pool, env_states=env_states, use_true_model=True, rollout_horizon=rollout_horizon, extra_info=extra_info)
            else:
                MCTS(config).search(roots, self.muzero_network, hidden_state_roots, reward_hidden_roots, num_simulations=num_simulations, use_uniform_prior=use_uniform_prior, rollout_horizon=rollout_horizon, use_true_model=False, extra_info=extra_info)
            # presumably the per-action visit counts of each root - TODO confirm against cytree
            visit_counts = np.array(roots.get_distributions())
            if self.add_one_visit_to_every_action is True:
                for i in range(visit_counts.shape[1]):
                    visit_counts[:, i] += 1
                assert (np.sum(visit_counts, axis=1) == config.num_simulations + visit_counts.shape[1]).all()
            # print(visit_counts)
            assert visit_counts.shape == (x.shape[0], config.action_space_size)
            # Raise the counts to the power 1/temperature, then normalise each row.
            action_probs = visit_counts ** (1 / self.params["temperature"])
            total_count = np.sum(action_probs, axis=1)
            assert total_count.shape == (x.shape[0],)
            action_probs = action_probs / total_count[:, None]
            # print(action_probs)
            # print("root value", roots.get_values())
            # print(roots.get_trajectories())
            assert action_probs.shape == (x.shape[0], config.action_space_size)
            # NOTE: only the first row is checked for summing to one.
            assert np.isclose(action_probs[0].sum(), 1.0)
            del roots
            return action_probs
        
    def act(self, z, env_pool=None, extra_info=None):
        """Pick one action per batch element from ``policy(...)``.

        With ``params["deterministic"]`` truthy the most probable action is
        taken; otherwise each action is sampled with ``np.random.choice``.
        Returns an array of shape ``(batch,)``.
        """
        policy = self.policy(z, env_pool=env_pool, extra_info=extra_info)
        deterministic = self.params["deterministic"]
        if deterministic:
            action = np.argmax(policy, axis=1)
        else:
            # batch sample action
            action = np.array([np.random.choice(self.muzero_config.action_space_size, p=policy_i) for policy_i in policy])
        # z is a tuple whose first element is the observation batch in the
        # (observations, env_states) form, a bare batch otherwise
        if type(z) == tuple:
            assert action.shape == (z[0].shape[0],)
        else:
            assert action.shape == (z.shape[0],)
        return action

def make_MuZeroPVMCTSPolicyDist(muzero_network, muzero_config, params, addOneVisitToEveryAction=False):
    """Build a policy function that returns one action distribution per input.

    The returned ``policy(z)`` runs a separate tree search for every element
    of ``z`` and returns a list of normalised, temperature-scaled visit-count
    distributions.

    ``params`` must provide ``"use_true_model"``, ``"rollout_horizon"``,
    ``"uniform_noise_ratio_in_policy_prior"`` and ``"temperature"``.

    NOTE(review): ``Node`` and ``models`` are not defined in this file's
    visible code; presumably they arrive via ``from core.utils import *`` -
    confirm. This function also calls ``MCTS(...).run(...)``, whereas
    ``MuZeroPVMCTSPolicy`` uses ``MCTS(...).search(...)`` - verify that the
    imported ``MCTS`` offers both.
    """
    print(params)

    def policy(z):

        use_true_model = params["use_true_model"]
        rollout_horizon = params["rollout_horizon"]
        
        def compute_action_with_MCTS(x):
            """Search from a single input and return its action distribution."""

            # with the true model each input is an (observation, env) pair
            if use_true_model:
                x, env = x

            """ construct the root node """
            root = Node(0)

            # encode the state
            if not torch.is_tensor(x):
                observation = (torch.tensor(x).float().unsqueeze(0).to(next(muzero_network.parameters()).device))
                encoded_state = muzero_network.representation(observation)
            else:
                encoded_state = x

            policy_logits, value = muzero_network.prediction(encoded_state)

            # Log of a one-hot vector whose single 1 sits at the centre index of
            # the support, repeated once per row of encoded_state and then
            # converted to a scalar; presumably this decodes to a zero reward
            # for the root - TODO confirm against support_to_scalar.
            reward = torch.log(
                (
                    torch.zeros(1, muzero_network.full_support_size)
                    .scatter(1, torch.tensor([[muzero_network.full_support_size // 2]]).long(), 1.0)
                    .repeat(len(encoded_state), 1)
                    .to(encoded_state.device)
                )
            )
            reward = models.support_to_scalar(reward, muzero_config.support_size).item()
            if use_true_model:
                reward = 0.0
            
            # NOTE: this first assignment is redundant; both branches of the
            # if/else below assign root_state again.
            root_state = encoded_state
            if use_true_model:
                # pair the latent state with a private copy of the env and a
                # False flag
                root_state = (encoded_state, (copy.deepcopy(env), False))
            else:
                root_state = encoded_state
            
            root.expand(
                muzero_config.action_space,
                0,
                reward,
                policy_logits,
                root_state,
                fraction_of_uniform_noise_in_policy_prior=params["uniform_noise_ratio_in_policy_prior"]
            )

            """ run the MCTS algorithm """
            # The pre-expanded root is passed in via override_root_with, so no
            # observation is supplied.
            root, mcts_info = MCTS(muzero_config).run(
                model=muzero_network,
                observation=None,
                legal_actions=muzero_config.action_space,
                to_play=0,
                add_exploration_noise=True,
                override_root_with=root,
                fraction_of_uniform_noise_in_policy_prior=params["uniform_noise_ratio_in_policy_prior"],
                use_true_model=use_true_model,
                rollout_horizon=rollout_horizon
            )
            
            """ sample the actions """
            visit_counts = np.array(
                [child.visit_count for child in root.children.values()], dtype="int32"
            )

            if addOneVisitToEveryAction is True:
                for i in range(len(muzero_config.action_space)):
                    visit_counts[i] += 1
                assert np.sum(visit_counts) == muzero_config.num_simulations + len(muzero_config.action_space)

            # Raise the counts to the power 1/temperature and normalise.
            temperature = params["temperature"]
            visit_count_distribution = visit_counts ** (1 / temperature)
            visit_count_distribution = visit_count_distribution / sum(
                visit_count_distribution
            )

            del root
            return visit_count_distribution
        
        batch_size = len(z)

        # inputs are searched one at a time
        dists = []
        for i in range(batch_size):
            dists.append(compute_action_with_MCTS(z[i]))

        return dists

    return policy

def make_UCTPolicy(num_actions, model, params):
    """Build a UCT (UCB-based tree search) policy over ``num_actions`` actions.

    ``params`` must provide ``"n_simulations"``, ``"effective_horizon"``,
    ``"discount_factor"`` and ``"exploration_constant"``.

    The returned ``policy(all_z)`` runs an independent search for every
    element of ``all_z`` and returns a list with one greedy action each.
    With a ``MuZeroModel`` the search steps through ``model.step_`` on latent
    states; otherwise the search state is an environment that is stepped
    directly.
    """

    def policy(all_z):
        
        with torch.inference_mode():
            batch_size = len(all_z)
            actions = []

            for j in range(batch_size):
                z = all_z[j]

                n_sims = params["n_simulations"]
                effective_horizon = params["effective_horizon"]
                discount_factor = params["discount_factor"]
                exploration_constant = params["exploration_constant"]
                max_horizon = effective_horizon

                class Node:
                    """Search-tree node holding per-action statistics."""
                    
                    def __init__(self):
                        # children, children_N and children_Q stay None until expand()
                        self.children = None
                        self.children_N = None
                        self.children_Q = None
                        self.N = 0.0
                        self.state = None
                        self.terminal = None
                        # number of steps taken from the root to this node
                        self.n_steps = None
                        self.untried_actions = list(range(num_actions))
                        # reward observed the first time each action was simulated
                        self.rewards = [None for _ in range(num_actions)]

                    def update(self, action, Q):
                        """Fold the sampled return ``Q`` into the running mean of ``action``."""
                        self.children_Q[action] = (self.children_Q[action] * self.children_N[action] + Q) / (self.children_N[action] + 1)
                        self.N += 1
                        self.children_N[action] += 1
                        if action in self.untried_actions:
                            self.untried_actions.remove(action)

                    def pick_action(self):
                        """Pick an untried action at random, else the action with the highest UCB score."""
                        if len(self.untried_actions) != 0:
                            action = self.untried_actions[random.randint(0, len(self.untried_actions)-1)]
                        else:
                            assert self.N == sum(self.children_N)
                            max_val = None
                            max_action = None
                            for action in range(num_actions):
                                # mean return plus exploration bonus
                                val = self.children_Q[action] + exploration_constant * np.sqrt(math.log(self.N)/self.children_N[action])
                                # `<=` means ties go to the higher action index
                                if max_action is None or max_val <= val:
                                    max_action = action
                                    max_val = val
                            action = max_action
                        return action
                    
                    def expand(self):
                        """Create the child nodes and zeroed per-action statistics."""
                        self.children = [Node() for _ in range(num_actions)]
                        self.children_N = [0 for _ in range(num_actions)]
                        self.children_Q = [0.0 for _ in range(num_actions)]

                def one_step_simulate(state, action):
                    """Advance ``state`` by ``action``; returns (next_state, reward, done, None)."""
                    if type(model) == MuZeroModel:
                        return model.step_(state, action)
                    else:
                        # gym model, use the environment
                        # the env is stepped in place, so callers pass in a copy
                        env = state
                        _, r, done, _ = env.step(action)
                        return env, r, done, None

                def rollout(node: Node):
                    """Return the discounted return of a uniformly random rollout from ``node``."""
                    state = copy.deepcopy(node.state)
                    done = False
                    count = node.n_steps
                    the_return = 0.0
                    factor = 1.0
                    while not done:
                        action = random.randint(0, num_actions-1)
                        state, r, done, _ = one_step_simulate(state, action)
                        # cut the rollout once the total depth reaches max_horizon
                        if count + 1 >= max_horizon:
                            done = True
                        count += 1
                        the_return += factor * r
                        factor *= discount_factor
                    return the_return

                def simulate(node: Node):
                    """Run one simulation through ``node`` and return its sampled return."""

                    # if terminal node
                    if node.terminal is True:
                        return 0.0
                    
                    # if leaf node
                    if node.children is None:
                        # expand the node
                        node.expand()
                        return rollout(node)
                    # if not leaf node
                    else:
                        # pick an action
                        action = node.pick_action()
                        # if we going to explore a new action
                        if node.children_N[action] == 0:
                            # terminal, state, reward, n_steps
                            # let's perform the simulation
                            state = copy.deepcopy(node.state)
                            next_state, r, done, _ = one_step_simulate(state, action)
                            if done is True or node.n_steps + 1 >= max_horizon:
                                done = True
                                node.children[action].terminal = True
                            # store the transition so later visits reuse it
                            # instead of simulating the step again
                            node.rewards[action] = r
                            node.children[action].n_steps = node.n_steps + 1
                            node.children[action].state = next_state
                        else:
                            assert len(node.untried_actions) == 0
                        # then do the next simulation
                        r = node.rewards[action]
                        the_return = r + discount_factor * simulate(node.children[action])
                        node.update(action, the_return)
                        return the_return

                # encode the state
                if type(model) == MuZeroModel:
                    state = model.encode(z)
                else:
                    # presumably z is an (observation, env) pair here - TODO confirm against callers
                    state = z[1]

                # The root is expanded up front, so the first simulation
                # already selects an action instead of only rolling out.
                root = Node()
                root.expand()
                root.state = copy.deepcopy(state)
                root.terminal = False
                root.n_steps = 0

                for i in range(n_sims):
                    simulate(root)
                # eventually pick the greedy action
                actions.append(np.argmax(root.children_Q))

                # print(root.children_N, root.children_Q)
        
            return actions
    
    return policy

def make_MuZeroOneStepPolicy(num_actions, muzero_network):
    """Build a greedy policy that looks a single step ahead in the learned model.

    The returned ``policy(z)`` accepts either a latent-state tensor or a raw
    observation (encoded with the network's ``representation`` function),
    queries ``dynamics`` once per action and returns ``[best_action]``, the
    action with the largest predicted reward.

    NOTE(review): ``dynamics`` is called here with two arguments and unpacked
    into two results, while ``MuZeroModel`` calls it with a reward hidden
    state and unpacks three - confirm which network interface this policy
    targets. ``support_to_scalar`` is presumably provided by
    ``from core.utils import *``.
    """

    def policy(z):

        # encode the state
        if not torch.is_tensor(z):
            observation = (torch.tensor(z).float().unsqueeze(0).to(next(muzero_network.parameters()).device))
            encoded_state = muzero_network.representation(observation)
        else:
            encoded_state = z

        # one step simulation
        rewards = []
        for action in range(num_actions):
            # Place the action on the latent state's device rather than the
            # module-level default, so both dynamics inputs always agree.
            a = torch.from_numpy(np.array([[action]])).to(encoded_state.device)
            _, r = muzero_network.dynamics(encoded_state, a)
            r = support_to_scalar(r, muzero_network.support_size).item()
            rewards.append(r)

        return [np.argmax(rewards)]

    return policy

def make_NStepPolicy(num_actions, model, n):
    """Build a policy that exhaustively evaluates every n-step action sequence.

    The returned ``policy(z)`` simulates all ``num_actions ** n`` sequences,
    sums their (undiscounted) rewards and returns ``[a]`` where ``a`` is the
    first action of a best sequence; ties are broken uniformly at random.

    With a ``MuZeroModel`` the search runs on latent states via
    ``model.step_``; otherwise ``z`` is treated as an environment object that
    is deep-copied before every simulated step.

    Rewards and transitions are assumed to be deterministic (true for both
    frozen lake and cartpole), so results are cached per action prefix.
    """
    uses_learned_model = isinstance(model, MuZeroModel)

    def policy(z):

        # Root of the search: a latent state for the learned model, the
        # environment itself otherwise. (It used to be left undefined for
        # non-MuZero models, which raised a NameError.)
        root_state = model.encode(z) if uses_learned_model else z

        # action prefix -> (reward, done, state reached after the prefix)
        cache = {}
        actions = list(range(num_actions))
        returns = []
        all_seq_actions = list(itertools.product(actions, repeat=n))

        for seq_actions in all_seq_actions:
            the_return = 0
            done = False
            s = root_state
            for step in range(n):
                prefix = seq_actions[:step + 1]
                if prefix in cache:
                    r, done, s = cache[prefix]
                else:
                    action = seq_actions[step]
                    # one step simulation
                    if uses_learned_model:
                        s, r, done, _ = model.step_(s, action)
                    else:
                        # gym model, use the environment. Copy the CURRENT
                        # state, not the root, so that consecutive steps
                        # build on each other and cached environments are
                        # never mutated.
                        env = copy.deepcopy(s)
                        _, r, done, _ = env.step(action)
                        # the stepped env becomes the new state
                        s = env
                    cache[prefix] = (r, done, s)
                the_return += r
                if done:
                    break
            returns.append(the_return)

        # pick the action - break the tie randomly
        returns = np.array(returns)
        best = np.flatnonzero(returns == returns.max())
        action = all_seq_actions[np.random.choice(best)][0]
        return [action]

    return policy
        

""" Policy Evaluation: The Domains """

class FrozenLakeEnv():
    """Deterministic 4x4 FrozenLake whose observations are one-hot encoded.

    Every observation is a float32 array of shape ``(1, 1, 16)`` with a
    single 1 at the index of the current cell.
    """

    def __init__(self):
        self.env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=False)
        self.env.render_mode = "ansi"

    @staticmethod
    def _encode(cell):
        """Return the one-hot encoding of ``cell`` with two leading axes."""
        identity = np.eye(16, dtype=np.float32)
        return np.array([[identity[cell]]])

    def step(self, action):
        """Step the wrapped env; the observation comes back one-hot encoded."""
        cell, reward, done, info = self.env.step(action)
        return self._encode(cell), reward, done, info

    def reset(self):
        """Reset the wrapped env and return the encoded start cell."""
        start_cell = self.env.reset()
        return self._encode(start_cell)

    @property
    def action_space(self):
        """Action space of the wrapped environment."""
        return self.env.action_space

    @property
    def state(self):
        """Encoded current cell, read from the wrapped env's ``s`` attribute."""
        return self._encode(self.env.s)

    def render(self):
        """Delegate rendering to the wrapped environment."""
        return self.env.render()

""" Policy Evaluation: Utils """

def make_env(env_name, seed=None):
    """Create a test environment from the ``config.atari_muzero`` game config.

    An ``env_name`` containing ``"terminal-on-loss-of-life"`` has the
    ``"terminal-on-loss-of-life-"`` prefix text removed and its game is
    created without ``test=True``; every other name is created with
    ``test=True``.

    The imported ``game_config`` object is modified in place (``env_name``,
    ``image_channel``, ``obs_shape``, ``cvt_string``).

    Returns:
        ``(env, horizon)`` where ``horizon`` is ``config.max_moves``.
    """
    terminal_on_loss_of_life = "terminal-on-loss-of-life" in env_name
    if terminal_on_loss_of_life:
        env_name = env_name.replace("terminal-on-loss-of-life-", "")

    config = importlib.import_module("config." + "atari_muzero").game_config
    config.env_name = env_name
    if config.gray_scale:
        config.image_channel = 1
    # frames are 96x96; stacked observations are folded into the channel axis
    stacked_channels = config.image_channel * config.stacked_observations
    config.obs_shape = (stacked_channels, 96, 96)
    config.cvt_string = False

    game_kwargs = {"seed": seed}
    if not terminal_on_loss_of_life:
        game_kwargs["test"] = True
    env = config.new_test_game(**game_kwargs)

    return env, config.max_moves

def make_time_unlimited_env(env_name):
    """Return only the environment built by ``make_env``, dropping the horizon."""
    env, _horizon = make_env(env_name)
    return env

def load_muzero_network_and_config(env_name, run_id, model_id):
    """Load a trained network and its game config for a wandb run.

    The run's config selects the ``config.<case>`` module whose
    ``game_config`` is then populated from the run's settings. Files are
    cached under ``muzero_models/<env_name>/<run_id>/``: the raw weights
    (``model_<id>.pt``) are downloaded if missing and converted once into a
    fully pickled network (``agent_<id>.pt``), which is what gets loaded.

    NOTE: ``torch.load`` unpickles arbitrary objects; only use it on
    checkpoints from a trusted source.

    Returns:
        ``(muzero_network, muzero_config)`` with the network in eval mode
        (and compiled when running under torch 2.x).
    """
    wandb_api = wandb.Api(overrides={"entity": "jinkehe"})
    run = wandb_api.run("MuZero/{}".format(run_id))

    muzero_config = importlib.import_module("config." + run.config["case"]).game_config
    print(muzero_config)

    # set_config expects an argparse-like namespace; add the fields that the
    # run's config does not carry
    run_args = SimpleNamespace(**dict(run.config))
    run_args.env = run_args.env_name
    run_args.result_dir = ""
    run_args.info = ""
    muzero_config.set_config(run_args)
    muzero_config.device = device
    muzero_config.amp_type = "none"

    # download the model
    cache_dir = "muzero_models/{}/{}/".format(env_name, run_id)
    model_path = "muzero_models/{}/{}/model_{}.pt".format(env_name, run_id, model_id)
    agent_path = "muzero_models/{}/{}/agent_{}.pt".format(env_name, run_id, model_id)

    agent_is_cached = os.path.exists(agent_path)
    if not agent_is_cached:
        if not os.path.exists(model_path):
            pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)
            remote_name = "model_{}".format(model_id) + ".pt"
            run.file(remote_name).download(replace=True, root=cache_dir)
        weights = torch.load(model_path, map_location=device)
        fresh_network = muzero_config.get_uniform_network().to(device)
        fresh_network.set_weights(weights)
        torch.save(fresh_network, agent_path)
        # release both before the network is loaded again below
        del fresh_network
        del weights

    muzero_network = torch.load(agent_path, map_location=device)
    muzero_network.eval()
    if torch.__version__.startswith('2'):
        muzero_network = torch.compile(muzero_network, mode="max-autotune")
        print("model compiled.")

    return muzero_network, muzero_config

def get_run_ids_by_exp_name(exp_name):
    """Return the ids of all ``jinkehe/MuZero`` runs named ``exp_name``.

    The ids are ordered by each run's configured ``seed``.
    """
    matching_runs = wandb.Api().runs(
        path="jinkehe/MuZero",
        filters={"display_name": exp_name}
    )
    ordered_runs = sorted(matching_runs, key=lambda run: run.config["seed"])
    return [run.id for run in ordered_runs]

def get_muzero_model_from_run_id(env_name, run_id, model_id):
    """Load the network of a wandb run and wrap it in a ``MuZeroModel``.

    Args:
        env_name: environment name, used for the local checkpoint path.
        run_id: wandb run identifier.
        model_id: identifier of the checkpoint to load.

    Returns:
        A ``MuZeroModel`` wrapping the loaded network.
    """
    muzero_network, _muzero_config = load_muzero_network_and_config(env_name, run_id, model_id)
    # MuZeroModel.__init__ accepts only the network; the extra environment
    # argument that used to be passed here raised a TypeError.
    return MuZeroModel(muzero_network)