"""
Everything related to the variational quantum eigensolver (VQE) method is
contained here.
"""
import pennylane as pl
import numpy as np
import itertools
import sys

import pdb

def vqe_ansatz(
        angles: np.ndarray,
        *args: object,
        return_energy_rather_than_wavefunction: bool = True
        ) -> "pl.measurements.MeasurementProcess":
    """
    Apply a layered parameterised ansatz and return a measurement.

    Positional arguments after ``angles`` are, in order:
        Hamiltonian -- the observable whose expectation value is returned
        n_qubits    -- number of wires the circuit acts on
        n_layers    -- number of ansatz layers to apply

    len(angles) >= 3 * n_qubits * n_layers is required. Exactly n_layers
    layers are applied, consuming the first 3 * n_qubits * n_layers angles.

    The circuit starts from the all-zeros computational basis state.
    One layer consists of:
        Parameterised rotations in the x, y and z axes (RX on every qubit,
        then RY on every qubit, then RZ on every qubit), followed by a CNOT
        between every pair of qubits (i, j) with i < j, fully connecting them.

    Return object depends on the argument return_energy_rather_than_wavefunction:
        if True then the output will be the expectation of the Hamiltonian
        if False then the output will be the full state as a vector

    Raises:
        ValueError: if fewer than 3 * n_qubits * n_layers angles are given.

    Note: when you define the device upon which this will be run, if you provide
    the argument shots then the returned expectation value will be from finite
    samples rather than the entire wavefunction.
    """
    Hamiltonian = args[0]
    n_qubits = int(args[1])
    n_layers = int(args[2])

    # Validate before any gate is queued. Raise rather than sys.exit so a
    # caller (e.g. an optimiser loop) can handle the error instead of the
    # whole process being terminated.
    n_required = 3 * n_qubits * n_layers
    if len(angles) < n_required:
        raise ValueError(
            "Not enough parameters given to vqe ansatz. Should be "
            f"3 * n_qubits * n_layers. You gave {len(angles)} rather than "
            f"3 * {n_qubits} * {n_layers}."
        )

    pl.BasisState(np.zeros(n_qubits, dtype=int), wires=range(n_qubits))

    # Each layer uses a contiguous slice of 3 * n_qubits angles:
    # [RX angles | RY angles | RZ angles], one angle per wire per axis.
    # Loop over n_layers (not over len(angles)) so the circuit depth is
    # exactly what the caller asked for.
    for layer in range(n_layers):
        layer_start = 3 * n_qubits * layer
        for axis_idx, rotation in enumerate((pl.RX, pl.RY, pl.RZ)):
            axis_start = layer_start + axis_idx * n_qubits
            for wire in range(n_qubits):
                rotation(angles[axis_start + wire], wires=wire)
        for control, target in itertools.combinations(range(n_qubits), 2):
            pl.CNOT(wires=[control, target])

    if not return_energy_rather_than_wavefunction:
        return pl.state()
    return pl.expval(Hamiltonian)


if __name__ == '__main__':
    n_qubits = 4
    n_layers = 1
    n_shots = 10000
    pl.numpy.random.seed(0)
    test_Hamiltonian = pl.PauliZ(0) @ pl.PauliY(3)

    # One random, trainable angle per rotation gate in the ansatz.
    initial_angles = np.random.rand(3 * n_layers * n_qubits)
    angles = pl.numpy.array(initial_angles, requires_grad=True)
    circuit_args = (angles, test_Hamiltonian, n_qubits, n_layers)

    # Render the circuit as text; no device is needed just to draw it.
    circuit_drawer = pl.draw(vqe_ansatz, show_all_wires=True)
    print(circuit_drawer(*circuit_args))

    # Executing the circuit requires binding it to a device through a QNode.
    # With a finite shot count the expectation is estimated from samples.
    sampling_device = pl.device("default.qubit", wires=n_qubits, shots=n_shots)
    sampled_circuit = pl.QNode(vqe_ansatz, sampling_device)
    print(f"Expectation value using {n_shots} samples: {sampled_circuit(*circuit_args)}")

    # Without shots the expectation is computed exactly from the state vector.
    exact_device = pl.device("default.qubit", wires=n_qubits)
    exact_circuit = pl.QNode(vqe_ansatz, exact_device)
    print(f"Expectation value using wavefunction: {exact_circuit(*circuit_args)}")

