"""
Find the ground state of the transverse-field Ising model (TFIM) for many field
strengths using RBMs with different hidden-unit densities (alphas), so that
their fidelities can be compared.
"""

import netket as nk
import numpy as np
import os
import sys
import pdb
import json
from tqdm import tqdm
from models import symmetricRBM


def timer_to_dict(timer, _root=True):
    """Convert a (possibly nested) timer into a plain dictionary.

    Each entry of ``timer.sub_timers`` becomes one key. A sub-timer that has
    no sub-timers of its own is collapsed to its ``total`` attribute. A
    sub-timer that does have children is converted recursively (without an
    overview entry).

    When ``_root`` is true, ``str(timer)`` is stored under the
    ``"total_overview"`` key, ahead of the per-section entries.
    """
    result = {"total_overview": str(timer)} if _root else {}
    for name, child in timer.sub_timers.items():
        has_children = len(child.sub_timers) > 0
        result[name] = (
            timer_to_dict(child, _root=False) if has_children else child.total
        )
    return result


def timer_to_json(timer, path, indent=4):
    """Serialise ``timer`` to ``path`` as indented JSON.

    The parent directory of ``path`` is created first when ``path`` has one.
    The payload is the dictionary built by ``timer_to_dict``. The file is
    written as UTF-8 with non-ASCII characters kept verbatim.
    """
    parent = os.path.dirname(path)
    if parent != "":
        os.makedirs(parent, exist_ok=True)
    payload = timer_to_dict(timer)
    with open(path, "w", encoding='utf-8') as handle:
        json.dump(payload, handle, indent=indent, ensure_ascii=False)


def optimise_and_output(run_args: tuple[float, int, int], model_name):
    """Optimise an RBM-type ansatz for one TFIM point and save the results.

    Builds the requested model and runs VMC (Adam optimiser with an SR
    preconditioner) on a periodic ``n_qubits``-site chain with
    ``nk.operator.IsingJax`` (field ``h``, coupling ``J = -1``). It then
    writes, under ``./outputs/{n_qubits}_qubits/``:

    * ``<suffix>.nk``: the log produced by ``gs.run``;
    * ``<suffix>.timer``: the JSON timing breakdown from ``timer_to_json``;
    * ``wavefunctions/<suffix>``: the entries of ``vs.to_array()``, one per
      line.

    ``<suffix>`` encodes every hyperparameter of the run.

    Args:
        run_args: ``(h, alpha, rep)``. ``h`` is the transverse field strength
            and ``alpha`` is passed to the model constructor. ``rep`` is the
            repetition index and is only used in the printout and file names.
        model_name: ``'RBM'`` (``nk.models.RBM``) or ``'symmetricRBM'`` (the
            project-local ``symmetricRBM``).

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    h = run_args[0]
    alpha = run_args[1]
    rep = run_args[2]

    # Select the ansatz. An unknown name used to fall through both branches
    # and only fail later with an UnboundLocalError on ``model``, after the
    # output directories had been created, so fail fast instead.
    if model_name == 'RBM':
        model = nk.models.RBM(alpha=alpha)
    elif model_name == 'symmetricRBM':
        model = symmetricRBM(alpha=alpha)
    else:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected 'RBM' or 'symmetricRBM'"
        )
    print(f'{alpha=}, {h=}, {rep=}')

    # Physics parameters
    n_qubits = 4
    J = -1

    # NQS parameters
    n_chains = 8
    n_samples = 5000
    n_discard_per_chain = 200
    sweep_size = n_qubits * 3
    n_time_steps = 5000  # maximum VMC steps; early stopping may end the run sooner

    # All hyperparameters go into the file name, so runs with different
    # settings (or different reps) write to different files.
    run_specific_filename_suffix = f'{model_name}_alpha{alpha}_sweep{sweep_size}_chains{n_chains}_samples{n_samples}_discards{n_discard_per_chain}_max-timesteps{n_time_steps}_J{J}_h{h:.3f}_rep{rep}_MetropolisLocal_adam_SR_TFI'
    output_directory = f'./outputs/{n_qubits}_qubits/'
    wavefunction_output_directory = f'./outputs/{n_qubits}_qubits/wavefunctions/'
    # makedirs creates every missing parent (including ./outputs/). With
    # exist_ok=True it also tolerates the directory being created concurrently
    # by another job. The previous exists()+mkdir() pair failed when ./outputs/
    # was missing and could race.
    os.makedirs(wavefunction_output_directory, exist_ok=True)
    output_filename = f'{output_directory}{run_specific_filename_suffix}'
    wavefunction_output_filename = f'{wavefunction_output_directory}{run_specific_filename_suffix}'

    hi = nk.hilbert.Spin(s=1/2, N=n_qubits)
    graph = nk.graph.Chain(n_qubits, pbc=True)
    H = nk.operator.IsingJax(hi, graph, h, J)
    sampler = nk.sampler.MetropolisLocal(
        hi,
        sweep_size=sweep_size,
        n_chains=n_chains,
    )
    optimiser = nk.optimizer.Adam()
    sr = nk.optimizer.SR()

    vs = nk.vqs.MCState(
        sampler,
        model,
        n_samples=n_samples,
        n_discard_per_chain=n_discard_per_chain,
    )

    def callback(step, logged_data, driver):
        """Record the sampler acceptance rate in the log at every step.

        Always returns True, so this callback never halts the run.
        """
        logged_data["acceptance"] = float(driver.state.sampler_state.acceptance)
        return True

    # Stop early once the monitored 'mean' stops improving
    # (min_reldelta=1e-7, patience=500).
    # Currently disabled: nk.callbacks.InvalidLossStopping(monitor='variance')
    # and nk.callbacks.InvalidLossStopping(monitor='Sigma').
    early_stopping_mean = nk.callbacks.EarlyStopping(
        min_reldelta=1e-7, patience=500, monitor='mean'
    )

    gs = nk.driver.VMC(
        H,
        optimiser,
        preconditioner=sr,
        variational_state=vs,
    )

    with nk.utils.timing.Timer() as t:
        gs.run(
            n_time_steps,
            out=f'{output_filename}.nk',
            callback=[early_stopping_mean, callback],
        )
    timer_to_json(t, f'{output_filename}.timer')

    # Dump the state vector: one entry per line, as formatted by an f-string.
    wavefunction = vs.to_array()
    with open(wavefunction_output_filename, "w", encoding="utf-8") as f:
        f.writelines(f'{amplitude}\n' for amplitude in wavefunction)

if __name__ == "__main__":

    # --- Local sweep: a small grid for quick checks ---
    n_hs = 2
    hs = np.linspace(0, 3, n_hs)
    alphas = np.array([1, 5])
    reps = np.arange(1)
    run_args = [
        (h, alpha, rep)
        for h in hs
        for alpha in alphas
        for rep in reps
    ]

    # --- Cluster sweep: one repetition index per job, read from argv ---
    # n_hs = 76
    # hs = np.linspace(0, 3, n_hs)
    # alphas = np.array([1, 2, 3, 4, 5])
    # rep = int(sys.argv[1])
    # run_args = [(h, alpha, rep) for h in hs for alpha in alphas]

    # Run every (h, alpha, rep) point sequentially with the symmetric RBM.
    for point in tqdm(run_args):
        optimise_and_output(point, 'symmetricRBM')





