"""
Load the optimised circuit parameters written by
./optimise_and_return_parameters.py, rebuild each circuit, and save the full
wavefunctions to disk for the magic calculations.
"""
import pennylane as pl
import json
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm

from modules import vqe, Hamiltonian_terms, optimisation


def main(n_vqe_layers=4, n_qubits=2, n_reps=1, mu_values=None):
    """Rebuild the optimised circuits and write their wavefunctions to disk.

    For every repeat and every mu value, read the JSON results file produced by
    ``optimise_and_return_parameters.py``. Then re-run ``vqe.vqe_ansatz`` with
    each stored parameter set, and write the returned wavefunction to
    ``./outputs/<n_qubits>_qubits/wavefunctions/`` as one ``real<TAB>imag``
    line per coefficient.

    Args:
        n_vqe_layers: Number of ansatz layers; must match the input files.
        n_qubits: Number of qubits; must match the input files.
        n_reps: Number of optimisation repeats (input files) to process.
        mu_values: Iterable of mu values whose input files are read. Defaults
            to the grid the script has always used (see the note below).

    Raises:
        FileNotFoundError: If an expected input file does not exist.
        ValueError: If an input file's configuration disagrees with
            ``n_vqe_layers`` or ``n_qubits``.
    """
    if mu_values is None:
        # NOTE(review): np.arange(0, 0.0001, 0.25) yields only [0.0] because
        # the step is larger than the stop -- stop and step look swapped.
        # Kept as-is to preserve behaviour; confirm the intended mu grid.
        mu_values = np.arange(0, 0.0001, 0.25)

    data_directory = f"./outputs/{n_qubits}_qubits/"
    wavefunction_directory = f"{data_directory}wavefunctions/"
    # makedirs also creates missing parent directories and tolerates an
    # already-existing target (os.mkdir would raise in both cases).
    os.makedirs(wavefunction_directory, exist_ok=True)

    # The device and QNode depend only on n_qubits, so build them once.
    device = pl.device("default.qubit", wires=n_qubits)
    circuit_function = pl.QNode(vqe.vqe_ansatz, device)

    # Output paths written during this run. The first write to a path
    # truncates it so reruns do not append stale duplicates; later writes to
    # the same path in this run append, as the original "a+" mode did.
    written_paths = set()

    for rep in range(n_reps):
        for mu in mu_values:
            data_filename = (
                f"{data_directory}energies_and_parameters_adam-{n_vqe_layers}"
                f"_layers_mu-{mu:.3f}_repeat-{rep:.0f}.out"
            )
            with open(data_filename, "r", encoding="utf-8") as f:
                data = json.load(f)

            # Double check the configuration details agree (explicit raise
            # rather than assert, which is stripped under `python -O`).
            config_details = data['configuration']
            if int(config_details['n_vqe_layers']) != n_vqe_layers:
                raise ValueError(
                    f"{data_filename}: n_vqe_layers="
                    f"{config_details['n_vqe_layers']}, expected {n_vqe_layers}"
                )
            if int(config_details['n_qubits']) != n_qubits:
                raise ValueError(
                    f"{data_filename}: n_qubits="
                    f"{config_details['n_qubits']}, expected {n_qubits}"
                )
            optimiser = config_details['optimiser']

            # Use the stored parameters to rebuild each circuit and save the
            # resulting wavefunction.
            for mu_key, runs in tqdm(data['results'].items()):
                for run_key, run in runs.items():
                    wavefunction = circuit_function(
                        run['parameters'],
                        None,
                        n_qubits,
                        n_vqe_layers,
                        return_energy_rather_than_wavefunction=False,
                    )
                    output_filename = (
                        f"{wavefunction_directory}{optimiser}_{n_vqe_layers}"
                        f"-layers_{mu_key}_{run_key}.dat"
                    )
                    mode = "a" if output_filename in written_paths else "w"
                    written_paths.add(output_filename)
                    with open(output_filename, mode, encoding="utf-8") as f:
                        f.writelines(
                            f"{coef.real}\t{coef.imag}\n" for coef in wavefunction
                        )
    return

if __name__ == "__main__":
    # Only run the extraction when executed as a script, not when imported.
    main()
