"""
Preprocess the log files and wavefunctions into something a little easier to plot.
The wavefunction preprocessing is basically just concatenating the files together.

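The wavefunctions must be preprocessed before the magic can be computed, so the
magic section of main() cannot be enabled on the first run. The workflow is:
    1. Run the energy and wavefunction preprocessing (magic section disabled).
    2. Compute the magic (done outside this script).
    3. Enable the magic section and run the magic preprocessing.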
"""
import json
import os
import pdb
from typing import List, Tuple, Union

import numpy as np


def main() -> None:
    """
    Preprocess the raw energy logs and wavefunction files for every alpha.

    Raw files are read from ./outputs/{n_qubits}_qubits/ and the cleaned files
    are written to ./cleaned_outputs/{n_qubits}_qubits/alpha_{alpha}/{ansatz}/.
    The magic preprocessing at the end is disabled by wrapping it in a string
    literal: it reads magic files from the 'magic/' input directory, which this
    script does not produce.
    """
    ## Parameters

    n_repeats = 1
    n_hs = 2
    hs = np.linspace(0,3,n_hs)
    alphas = np.array([1,5])

    n_qubits = 4
    ansatz = 'symmetricRBM'

    # These are deliberately plain strings, not f-strings: the {n_qubits}
    # placeholder is filled in later by str.format inside the reader helpers.
    energy_data_dir = './outputs/{n_qubits}_qubits/'
    wavefunction_data_dir = energy_data_dir+'wavefunctions/'
    magic_data_dir = energy_data_dir+'magic/'

    # makedirs creates every missing parent (including ./cleaned_outputs) and
    # does not fail when the directory is already there.
    for alpha in alphas:
        os.makedirs(
            f'./cleaned_outputs/{n_qubits}_qubits/alpha_{alpha}/{ansatz}',
            exist_ok=True)

    ## Energies
    unformatted_input_filename = energy_data_dir+'{ansatz}_alpha{alpha}_sweep12_chains8_samples5000_discards200_max-timesteps5000_J-1_h{h:.3f}_rep{rep}_MetropolisLocal_adam_SR_TFI.nk.log'
    for alpha in alphas:

        data = read_data_from_many_jsons(
            unformatted_input_filename,
            n_qubits,
            ansatz,
            alpha,
            hs,
            n_repeats)

        output_filename = f'./cleaned_outputs/{n_qubits}_qubits/alpha_{alpha}/{ansatz}/energies_{ansatz}_alpha{alpha}_sweep12_chains8_samples5000_discards200_max-timesteps5000_J-1_MetropolisLocal_adam_SR_TFI.dat'
        write_to_file(data, output_filename, hs)

    ## Wavefunctions
    # The input template does not depend on alpha or rep, so build it once.
    unformatted_input_filename = wavefunction_data_dir+'{ansatz}_alpha{alpha}_sweep12_chains8_samples5000_discards200_max-timesteps5000_J-1_h{h:.3f}_rep{rep}_MetropolisLocal_adam_SR_TFI'
    for alpha in alphas:
        for rep in range(n_repeats):
            output_filename = f'./cleaned_outputs/{n_qubits}_qubits/alpha_{alpha}/{ansatz}/wavefunction_{ansatz}_alpha{alpha}_sweep12_chains8_samples5000_discards200_max-timesteps5000_J-1_MetropolisLocal_adam_SR_TFI_rep{rep}.dat'
            wavefunctions = process_files_for_all_h(
                unformatted_input_filename,
                hs,
                ansatz,
                alpha,
                rep,
                n_qubits)
            write_files_for_all_h(wavefunctions, output_filename)

    # Disabled on purpose: remove the surrounding triple quotes once the magic
    # files exist in the 'magic/' input directory.
    """

    ## Magic
    repeats_of_magics = np.zeros((n_repeats, n_hs))
    for alpha in alphas:
        output_filename = f'./cleaned_outputs/{n_qubits}_qubits/alpha_{alpha}/{ansatz}/magic_{ansatz}_alpha{alpha}_sweep12_chains8_samples5000_discards200_max-timesteps5000_J-1_MetropolisLocal_adam_SR_TFI.dat'
        for rep in range(n_repeats):
            input_filename = magic_data_dir.format(n_qubits=n_qubits)+f'magic_{ansatz}_alpha{alpha}_sweep12_chains8_samples5000_discards200_max-timesteps5000_J-1_MetropolisLocal_adam_SR_TFI_rep{rep}.dat'
            hs, repeats_of_magics[rep] = read_magic_file_for_all_h(input_filename)
        mean_magics = np.mean(repeats_of_magics, axis=0)
        error_in_magics = np.std(repeats_of_magics, axis=0)/np.sqrt(n_repeats)
        write_to_file((mean_magics, error_in_magics), output_filename, hs)
    """

    return


## Energy utils
def read_data_from_many_jsons(
        filename_unformatted: str,
        n_qubits: int,
        ansatz: str,
        alpha: int,
        h_values: List[float],
        n_repeats: int = 10) -> Tuple[List[float], List[float]]:
    """
    Average the final energies of the optimisation logs for every value of h.

    Each log is a JSON file whose energy trace is stored under
    ['Energy']['Mean']. For a single (h, rep) pair the energy is the mean of
    the last 10 entries of that trace (fewer if the trace is shorter). For a
    single h the result is the mean of those energies over all repeats, and
    its error is the standard deviation over repeats divided by
    sqrt(number of repeats).

    Args:
        filename_unformatted: str.format template of the log path; it is
            formatted with the keywords n_qubits, ansatz, alpha, h and rep.
        n_qubits, ansatz, alpha: values substituted into the template.
        h_values: the h values to read, one output entry per value.
        n_repeats: number of repeats; reps 0 .. n_repeats-1 are read.

    Returns:
        (means, errors): two lists with one entry per h, in the order of
        h_values.
    """
    means = []
    errors = []
    n_reps = int(n_repeats)
    for h in h_values:
        final_energies = []
        for rep in range(n_reps):
            log_path = filename_unformatted.format(
                n_qubits=n_qubits,
                ansatz=ansatz,
                alpha=alpha,
                h=h,
                rep=rep
            )
            # Context manager so the log file is closed as soon as it is parsed.
            with open(log_path) as log_file:
                log = json.load(log_file)
            final_energies.append(np.mean(log['Energy']['Mean'][-10:]))
        means.append(np.mean(final_energies))
        errors.append(np.std(final_energies)/np.sqrt(len(final_energies)))
    return means, errors


def write_to_file(
        data: Tuple[List[float], List[float]],
        output_filename: str,
        h_values: List[float]) -> None:
    """
    Write means and their errors to two tab-separated, two-column files.

    The means go to output_filename. The errors go to a sibling file whose
    name is output_filename with its extension replaced by '_errors.dat'
    (e.g. 'energies.dat' -> 'energies_errors.dat'). Every line has the form
    '<h with 3 decimals>\\t<value>'.

    Args:
        data: (means, errors), each holding one value per entry of h_values.
        output_filename: path of the file that receives the means.
        h_values: the h value belonging to each entry of means / errors.

    Raises:
        IndexError: if means or errors has more entries than h_values.
    """
    means, errors = data
    # splitext strips the real extension whatever its length, rather than
    # assuming it is exactly four characters long.
    root, _ = os.path.splitext(output_filename)
    targets = (
        (output_filename, means),
        (f'{root}_errors.dat', errors),
    )
    for path, column in targets:
        with open(path, 'w') as f:
            f.writelines(
                f'{h_values[idx]:.3f}\t{value}\n'
                for idx, value in enumerate(column))
    return

## Wavefunction utils
def process_single_file(filename: str) -> np.ndarray:
    """
    Read a text file holding one float per line into a 1-D array.

    Blank lines (for example a trailing empty line) are skipped instead of
    being passed to float().

    Args:
        filename: path of the file to read.

    Returns:
        The parsed values in file order; an empty array for an empty file.

    Raises:
        ValueError: if a non-blank line cannot be parsed as a float.
    """
    with open(filename, "r") as f:
        stripped_lines = (line.strip() for line in f)
        wavefunction = [float(entry) for entry in stripped_lines if entry]
    return np.array(wavefunction)


def process_files_for_all_h(
        filename_unformatted: str,
        hs_for_filenames: np.array,
        ansatz: str,
        alpha: int,
        rep: int,
        n_qubits: int) -> np.ndarray:
    """
    Load the wavefunction file of every h value for a single repeat.

    The filename template is formatted once per h (together with n_qubits,
    ansatz, alpha and rep) and each resulting file is read with
    process_single_file. The results are collected in the order of
    hs_for_filenames and returned as one array.
    """
    # Everything except h is the same for every file of this repeat.
    fixed_fields = dict(n_qubits=n_qubits, ansatz=ansatz, alpha=alpha, rep=rep)
    per_h_wavefunctions = [
        process_single_file(filename_unformatted.format(h=field, **fixed_fields))
        for field in hs_for_filenames
    ]
    return np.array(per_h_wavefunctions)


def write_files_for_all_h(
    data: np.ndarray,
    output_filename: str) -> None:
    """
    Concatenate the wavefunctions of all h values into a single file.

    data is unpacked as a 2-D (n_hs, n_amplitudes) shape. The amplitudes are
    written one per line, row after row, so the file holds
    n_hs * n_amplitudes lines.
    """
    row_count, column_count = np.shape(data)
    with open(output_filename, "w+") as out:
        out.writelines(
            f'{data[row, column]}\n'
            for row in range(row_count)
            for column in range(column_count)
        )
    return

## Magic utils
def read_magic_file_for_all_h(filename: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Read a two-column, whitespace-separated file of (h, magic) pairs.

    Blank lines are skipped so that a trailing empty line does not break the
    two-value unpacking.

    Args:
        filename: path of the magic file to read.

    Returns:
        (hs, magics): two 1-D arrays in file order, one entry per data line.

    Raises:
        ValueError: if a non-blank line does not hold exactly two fields or
            a field cannot be parsed as a float.
    """
    magics = []
    hs = []
    with open(filename, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            h, magic = fields
            hs.append(float(h))
            magics.append(float(magic))
    return np.array(hs), np.array(magics)
            


# Run the preprocessing only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
