"""
All wavefunctions generated by code in this repo are in the same format.
Therefore, you can point the 'compute_magic.py' code at the output wavefunction
from any of the methods and compute the magic. 

This is an example script for the 2 qubit symmetric RBM test results.
"""

import numpy as np
import os
from tqdm import tqdm

import magic_module
import sys
sys.path.append('../utils/')
import compute_infidelity


def _write_magics(output_filename, hs, magics):
    """Write tab-separated ``h<TAB>magic`` rows to ``output_filename``.

    Rows are written to a temporary sibling file that is then moved over
    ``output_filename`` with ``os.replace``. An interruption mid-write
    therefore never leaves a truncated checkpoint behind.

    Args:
        output_filename: Destination path.
        hs: Sequence of h values, one per row.
        magics: Sequence of magic values, paired element-wise with ``hs``.
            Only ``min(len(hs), len(magics))`` rows are written.
    """
    tmp_filename = output_filename + '.tmp'
    with open(tmp_filename, 'w') as f:
        f.writelines(f'{h}\t{magic}\n' for h, magic in zip(hs, magics))
    os.replace(tmp_filename, output_filename)


def main():
    """Compute the order-2 stabiliser Renyi entropy ("magic") of RBM wavefunctions.

    Loads one wavefunction per h value from the cleaned RBM output file.
    Evaluates ``magic_module.compute_stabiliser_Renyi_entropy_order_2`` on
    each one. Writes tab-separated ``h<TAB>magic`` rows into
    ``./outputs/{n_qubits}_qubits/magic/``.

    Partial results are checkpointed every ``checkpoint_every`` h values,
    because the computation can be slow for larger systems.
    """
    n_qubits = 2

    all_pauli_strings, qubits = magic_module.return_all_Pauli_strings(n_qubits)

    alpha = 1
    rep = 0

    ansatz = 'symmetricRBM'

    n_hs = 2
    # One amplitude per computational basis state of n_qubits qubits.
    n_amplitudes = 2**n_qubits
    # NOTE(review): assumed to match the h grid the wavefunctions were
    # generated on — confirm against the RBM run that produced the input.
    hs = np.linspace(0, 3, n_hs)
    checkpoint_every = 10

    # Run identifier shared by the input wavefunction file and the output file.
    run_tag = (
        f'{ansatz}_alpha{alpha}_sweep6_chains8_samples5000_discards200'
        f'_max-timesteps5000_J-1_MetropolisLocal_adam_SR_TFI_rep{rep}'
    )

    input_filename = (
        f'../RBM/cleaned_outputs/{n_qubits}_qubits/alpha_{alpha}/{ansatz}/'
        f'wavefunction_{run_tag}.dat'
    )
    print(input_filename)
    wf = compute_infidelity.load_wavefunctions_for_all_h(
        input_filename,
        n_amplitudes,
        n_hs,
        dtype=complex,
    )

    output_directory = f'./outputs/{n_qubits}_qubits/magic/'
    # makedirs also creates the missing parents, and exist_ok keeps re-runs
    # from failing when the directories already exist.
    os.makedirs(output_directory, exist_ok=True)
    output_filename = os.path.join(output_directory, f'magic_{run_tag}.dat')

    magics = np.zeros(n_hs)
    for h_idx in range(n_hs):
        magics[h_idx] = magic_module.compute_stabiliser_Renyi_entropy_order_2(
            qubits,
            all_pauli_strings,
            wf[h_idx],
        )
        if h_idx % checkpoint_every == 0:
            # Checkpoint only the rows computed so far, so h values that have
            # not been processed yet are not mistaken for a magic of 0.0.
            _write_magics(output_filename, hs, magics[:h_idx + 1])

    _write_magics(output_filename, hs, magics)


if __name__ == "__main__":
    main()
