"""
THIS RETURNS THE DATA BACK TO THE DIRECTORY IT CAME FROM.

To keep the visualisation code cleaner, I will compute all of the fidelities
here and then just plot them in the other code.
"""
import os
from typing import List, Sequence, Tuple, Union

import numpy as np


def main():
    """
    Compute RBM infidelities with respect to the reference (ED) wavefunctions.

    For every alpha and every repeat, the RBM wavefunction at each field value
    h is compared with the reference wavefunction at the same h as
    ``1 - compute_fidelity_from_arrays(reference, rbm)``. The mean over repeats
    and its standard error (population std / sqrt(number of repeats)) are then
    written, via ``write_to_file``, into the directory the RBM data came from.
    """
    repeat_count = 1
    field_count = 2
    fields = np.linspace(0,3,field_count)
    qubit_count = 4
    amplitude_count = 2**qubit_count

    # NOTE(review): the file name mentions energies, but it is parsed as
    # wavefunction amplitudes here -- confirm this is the intended file.
    reference_path = f'../ED/outputs/{qubit_count}_qubits/analytic_energies.out'
    reference_wfs = load_wavefunctions_for_all_h(reference_path, amplitude_count, field_count)

    ansatz = 'symmetricRBM'
    alpha_values = np.array([5])

    for alpha in alpha_values:
        # Rows index the field value h, columns index the repeat.
        infidelities = np.zeros((field_count, repeat_count))
        for rep in range(repeat_count):
            rbm_path = f'../RBM/cleaned_outputs/{qubit_count}_qubits/alpha_{alpha}/{ansatz}/wavefunction_{ansatz}_alpha{alpha}_sweep12_chains8_samples5000_discards200_max-timesteps5000_J-1_MetropolisLocal_adam_SR_TFI_rep{rep}.dat'
            rbm_wfs = load_wavefunctions_for_all_h(rbm_path, amplitude_count, field_count)
            # Pair up the rows (one row per h) of the reference and RBM data.
            infidelities[:, rep] = [
                1 - compute_fidelity_from_arrays(reference_wf, rbm_wf)
                for reference_wf, rbm_wf in zip(reference_wfs, rbm_wfs)
            ]

        mean_infidelities = infidelities.mean(axis=1)
        stderr_infidelities = infidelities.std(axis=1) / np.sqrt(repeat_count)

        summary_path = f'../RBM/cleaned_outputs/{qubit_count}_qubits/alpha_{alpha}/{ansatz}/infidelity_wrt_GS-{ansatz}_alpha{alpha}_sweep12_chains8_samples5000_discards200_max-timesteps5000_J-1_MetropolisLocal_adam_SR_TFI.dat'
        write_to_file((mean_infidelities, stderr_infidelities), summary_path, fields)


def load_wavefunctions_for_all_h(
        filename: str,
        n_amplitudes: int,
        n_hs: int,
        dtype = float) -> np.ndarray:
    """
    Read ``n_hs`` consecutive wavefunctions of ``n_amplitudes`` entries each.

    Each non-blank line of the file holds one amplitude, written either as a
    single real number or as a whitespace-separated ``real imag`` pair. For a
    pair, the modulus sqrt(real**2 + imag**2) is stored, so any complex phase
    is discarded. Consecutive runs of ``n_amplitudes`` amplitudes fill
    successive rows (one row per field value h). Blank lines are skipped, and
    anything after the first ``n_hs * n_amplitudes`` amplitudes is ignored, so
    only part of a longer file can be read if required.

    Parameters
    ----------
    filename : str
        Path of the text file to read.
    n_amplitudes : int
        Number of amplitudes per wavefunction (row length).
    n_hs : int
        Number of wavefunctions (rows) to read.
    dtype : data-type, optional
        dtype of the returned array (default ``float``).

    Returns
    -------
    np.ndarray
        Array of shape ``(n_hs, n_amplitudes)``.

    Raises
    ------
    ValueError
        If a line cannot be parsed as one or two numbers, or if the file holds
        fewer than ``n_hs * n_amplitudes`` amplitudes.
    """
    data = np.zeros((n_hs, n_amplitudes), dtype=dtype)
    n_expected = n_hs * n_amplitudes
    n_read = 0
    with open(filename, 'r') as f:
        for line_number, line in enumerate(f, start=1):
            if n_read == n_expected:
                break  # allows to read only part of the file if required
            fields = line.split()
            if not fields:
                continue  # tolerate blank/separator lines
            if len(fields) == 1:
                value = float(fields[0])
            elif len(fields) == 2:
                real, imag = fields
                value = np.sqrt(float(real)**2 + float(imag)**2)
            else:
                raise ValueError(
                    f'{filename}:{line_number}: expected one or two numbers, '
                    f'got {line.strip()!r}')
            h_idx, amp_idx = divmod(n_read, n_amplitudes)
            data[h_idx, amp_idx] = value
            n_read += 1
    if n_read < n_expected:
        # A short file would otherwise silently leave all-zero wavefunctions.
        raise ValueError(
            f'{filename}: expected {n_expected} amplitudes '
            f'({n_hs} x {n_amplitudes}), found only {n_read}')
    return data


def write_to_file(
        data: Tuple[Sequence[float], Sequence[float]],
        output_filename: str,
        h_values: Sequence[float]) -> None:
    """
    Write the fidelities to file, both the mean and the errors.

    Two tab-separated files are written, each containing one line
    ``h<TAB>value`` per entry of the corresponding series, with h printed to
    three decimal places:

    * ``output_filename`` receives the means (``data[0]``);
    * ``<output_filename without its extension>_errors.dat`` receives the
      errors (``data[1]``).

    Parameters
    ----------
    data : (means, errors)
        Two series of values, indexed in the same order as ``h_values``.
    output_filename : str
        Path of the file for the means.
    h_values : sequence of float
        Field values labelling each line.

    Raises
    ------
    IndexError
        If a series is longer than ``h_values``. In that case the affected
        file is not created.
    """
    means, errors = data
    # splitext instead of slicing off 4 chars, so extension-less names or
    # extensions of other lengths do not get mangled.
    root, _ = os.path.splitext(output_filename)
    _write_h_series(output_filename, h_values, means)
    _write_h_series(f'{root}_errors.dat', h_values, errors)


def _write_h_series(
        filename: str,
        h_values: Sequence[float],
        values: Sequence[float]) -> None:
    """Write one ``h<TAB>value`` line per entry of ``values`` to ``filename``."""
    # Build the lines first so an IndexError does not leave a half-written file.
    lines = [f'{h_values[idx]:.3f}\t{value}\n' for idx, value in enumerate(values)]
    with open(filename, 'w') as f:
        f.writelines(lines)


def compute_fidelity_from_arrays(wf1_array: np.array, wf2_array: np.array) -> float:
    """
    Return the fidelity |<wf1|wf2>|**2 between two wavefunction vectors.

    ``np.vdot`` conjugates its first argument, so complex amplitudes are
    handled correctly. For real 1-D inputs, the result equals the plain
    squared inner product. Both inputs are assumed to be normalised, because
    no normalisation is performed here.
    """
    overlap = np.vdot(wf1_array, wf2_array)
    return float(np.abs(overlap) ** 2)


if __name__ == "__main__":
    main()
