import json
import os
from typing import Union, List, Tuple

import numpy as np
import pennylane as pl

from . import vqe, Hamiltonian_terms


# dev
from pennylane.devices import ExecutionConfig


def get_minimum_energy_and_parameters_vqe_consecutive_agreement(
        J: float,
        mu: float,
        n_qubits: int,
        n_vqe_layers: int,
        consecutive_agreement_tolerance: float,
        track_magic_over_training=False,
        n_shots_if_doing_sampling_for_expectation_value=False,
        y_field_mu=0) -> Tuple[np.ndarray, float]:
    """
    Keep running the optimisation routine until two consecutive runs agree
    within the given consecutive_agreement_tolerance.

    Agreement means that the relative difference between the final energies
    of two back-to-back, independent runs is at most
    consecutive_agreement_tolerance. At least two runs are always performed.

    If track_magic_over_training=True this will write to file the model
    parameters after each 100 iterations of the last run (the one for which
    agreement was met).
        NOTE: this function does not know which run this is, so only do it once
        per value of mu.

    If y_field_mu is non-zero then a different Hamiltonian is used. One with
    an external field in the x and y direction.

    NOTE: there is no upper bound on the number of runs; if consecutive runs
    never agree (e.g. a very tight tolerance combined with finite sampling)
    this function does not return.

    Returns:
        parameters: np.array of angles in the quantum circuit of length
                    3*n_vqe_layers*n_qubits.
        energy: the minimum energy found by the final run.
    """
    def run_single_optimisation():
        # One independent optimisation from a fresh random starting point.
        return get_minimum_energy_and_parameters_vqe(
                                J,
                                mu,
                                n_qubits,
                                n_vqe_layers,
                                track_magic_over_training,
                                n_shots_if_doing_sampling_for_expectation_value,
                                y_field_mu)

    # The first run only provides the reference energy for the comparison.
    parameters, energy, parameters_over_training = run_single_optimisation()
    while True:
        previous_energy = energy
        parameters, energy, parameters_over_training = \
                                                    run_single_optimisation()
        # Same test as |ΔE / E| <= tolerance, written without the division so
        # that an energy of exactly zero does not produce inf/nan.
        if abs(energy - previous_energy) \
                <= consecutive_agreement_tolerance * abs(energy):
            break

    if track_magic_over_training:
        # write_parameters_to_file's six required arguments; the shot count is
        # not part of that required signature.
        write_parameters_to_file(
                                    parameters_over_training,
                                    J,
                                    mu,
                                    n_qubits,
                                    n_vqe_layers,
                                    consecutive_agreement_tolerance)
    return parameters, energy


def get_minimum_energy_and_parameters_vqe(
        J: float,
        mu: float,
        n_qubits: int,
        n_vqe_layers: int,
        track_magic_over_training=False,
        n_shots_if_doing_sampling_for_expectation_value=False,
        y_field_mu=0,
        max_iterations: int = 10000,
        tolerance: float = 1e-6) \
        -> Tuple[np.ndarray, float, Union[np.ndarray, List[np.ndarray]]]:
    """
    Using Pennylane's built in Adam optimiser, find the quantum circuit that
    prepares a wavefunction that minimises the expectation value of the
    Hamiltonian.

    If n_shots_if_doing_sampling_for_expectation_value is a non-zero integer
    then the device is created with that many shots, so the expectation value
    is computed using finite sampling rather than the full wavefunction, and
    the QNode is differentiated with "parameter-shift".

    If it is False (or otherwise falsy) the device is created without shots
    and the QNode is differentiated with "backprop".

    Adam is the optimiser in both cases; only the differentiation method
    changes.

    If y_field_mu == 0 the Hamiltonian comes from
    Hamiltonian_terms.transverse_Ising_Hamiltonian, otherwise from
    Hamiltonian_terms.double_transverse_Ising_Hamiltonian.

    The optimisation stops after max_iterations steps, or earlier once the
    energy has changed by at most `tolerance` on 3 consecutive steps.

    Returns:
        parameters: np.array of angles in the quantum circuit of length
                    3*n_vqe_layers*n_qubits.
        energy: the cost function evaluated at the returned parameters
                (whatever scalar type the QNode returns).
        parameters_over_training:
            if track_magic_over_training=False this is an empty np.ndarray.
            else:
                a list holding the starting angles followed by the angles
                after every 100th optimisation step.
    """
    snapshot_interval = 100
    required_consecutive_converged_steps = 3

    # Can always use gradient, just will change type of gradient calculation
    # based on the way the expectation value is computed: exact or via samples
    optimiser = pl.AdamOptimizer()

    vqe_ansatz = vqe.vqe_ansatz

    # Define the device that the circuit is to be run on, choose if the
    #   expectation value is to be computed using sampling or exactly
    if n_shots_if_doing_sampling_for_expectation_value:
        device = pl.device(
                    "default.qubit",
                    wires=n_qubits,
                    shots=n_shots_if_doing_sampling_for_expectation_value)
        diff_method = "parameter-shift"
    else:
        device = pl.device(
                    "default.qubit",
                    wires=n_qubits)
        diff_method = "backprop"
    cost_function = pl.QNode(vqe_ansatz, device, diff_method=diff_method)

    if y_field_mu == 0:
        Hamiltonian = Hamiltonian_terms.transverse_Ising_Hamiltonian(
                                        n_qubits,
                                        J,
                                        mu)
    else:
        Hamiltonian = Hamiltonian_terms.double_transverse_Ising_Hamiltonian(
                                        n_qubits,
                                        J,
                                        mu,
                                        y_field_mu)

    # Random starting point in [0, 1) for every angle.
    n_angles = 3*n_vqe_layers*n_qubits
    angles = pl.numpy.array(np.random.rand(n_angles), requires_grad=True)
    energy = cost_function(angles, Hamiltonian, n_qubits, n_vqe_layers)

    # Only the snapshots that will actually be returned are kept, instead of
    # the angles of every single iteration.
    snapshots = [angles] if track_magic_over_training else None

    # Begin optimising
    # NOTE: nothing ensures that it actually converges; the loop simply gives
    # up after max_iterations steps.
    converging = 0
    for iteration in range(max_iterations):
        # step_and_cost returns the updated arguments and the cost evaluated
        # *before* the step was taken.
        updates, previous_energy = optimiser.step_and_cost(
                                                cost_function,
                                                angles,
                                                Hamiltonian,
                                                n_qubits,
                                                n_vqe_layers)
        angles = updates[0]
        energy = cost_function(angles, Hamiltonian, n_qubits, n_vqe_layers)

        if track_magic_over_training \
                and (iteration + 1) % snapshot_interval == 0:
            snapshots.append(angles)

        energy_change = np.abs(energy - previous_energy)
        if energy_change <= tolerance:
            converging += 1
            if converging == required_consecutive_converged_steps:
                break
        else:
            converging = 0

    final_parameters = angles.numpy().ravel()
    if track_magic_over_training:
        return final_parameters, energy, snapshots
    return final_parameters, energy, np.array([])


def write_parameters_to_file(
        parameters_over_training: List[np.ndarray],
        J: float,
        mu: float,
        n_qubits: int,
        n_vqe_layers: int,
        consecutive_agreement_tolerance: float,
        n_shots_if_doing_sampling_for_expectation_value=None) -> None:
    """
    Dump the parameters recorded during training, together with the run
    configuration, to a JSON file.

    The output location is hardcoded and not, therefore, affected by changing
    it elsewhere. It depends only on n_qubits, n_vqe_layers and mu, so a
    second call with the same three values overwrites the earlier file. The
    output directory is created if it does not exist yet.

    Args:
        parameters_over_training: sequence of parameter arrays; entry k is
            stored under the key 'iteration_{100*k}_parameters' as a flat
            list.
        J, mu, n_qubits, n_vqe_layers, consecutive_agreement_tolerance:
            run configuration; all except mu are stored in the
            'configuration' section (mu only appears in the filename).
        n_shots_if_doing_sampling_for_expectation_value: optional; when given
            (not None) it is stored in the 'configuration' section too.
    """
    output_filename = f"./outputs/{n_qubits}_qubits/magic_over_training/energies_and_parameters_adam-{n_vqe_layers}_layers_mu{mu:.3f}.out"
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)

    # Constructs for logging the outputs
    configuration_details = {
        'optimiser': 'adam',
        'n_vqe_layers': n_vqe_layers,
        'J': J,
        'n_qubits': n_qubits,
        'consecutive_agreement_tolerance': consecutive_agreement_tolerance,
    }
    if n_shots_if_doing_sampling_for_expectation_value is not None:
        configuration_details[
            'n_shots_if_doing_sampling_for_expectation_value'] = \
            n_shots_if_doing_sampling_for_expectation_value

    # np.asarray accepts plain numpy arrays as well as ndarray subclasses, so
    # entries do not need to provide a .numpy() method.
    results = {
        f'iteration_{100 * index}_parameters':
            np.asarray(parameters_at_one_iteration).ravel().tolist()
        for index, parameters_at_one_iteration
        in enumerate(parameters_over_training)
    }

    output_data_as_json = {
        'configuration': configuration_details,
        'results': results,
    }
    with open(output_filename, "w") as f:
        json.dump(output_data_as_json, f, indent=4)

    
    

if __name__ == "__main__":
    # Quick manual check: optimise the same small problem (J=-1, mu=0,
    # 3 qubits, 3 layers) once with the exact expectation value and once
    # with a finite number of shots, printing the result of each.
    n_shots = 1000
    demo_runs = (
        ("When using exact computation for expectation value:\n",
         {}),
        ("When using finite sampling for expectation value:\n",
         {"n_shots_if_doing_sampling_for_expectation_value": n_shots}),
    )
    for banner, extra_keyword_arguments in demo_runs:
        print(banner)
        parameters, energy, parameters_over_training = \
            get_minimum_energy_and_parameters_vqe(
                -1, 0, 3, 3, **extra_keyword_arguments)
        print(parameters, type(parameters))
        print(energy)

