"""
Compute equation (1) from https://arxiv.org/pdf/2204.11236.pdf

Computes the stabiliser Renyi entropy of order 2 PER QUBIT, i.e. the entropy
divided by the number of qubits.
"""
import itertools
import time
from typing import List, Tuple, Union

import cirq
import numpy as np
from tqdm import tqdm


def return_all_Pauli_strings(n_qubits: int) -> \
            Tuple[List[cirq.ops.linear_combinations.PauliSum],
                  List[cirq.devices.line_qubit.LineQubit]]:
    """
    Build every Pauli string of length n_qubits over the operators I, X, Y, Z.

    There are 4**n_qubits such strings. Each one is wrapped in a single-term
    ``cirq.PauliSum`` built from a coefficient-1 ``cirq.PauliString``.

    Args:
        n_qubits: number of qubits; must be at least 1.

    Returns:
        A tuple ``(all_pauli_strings, qubits)``:
        - ``qubits`` is ``[cirq.LineQubit(0), ..., cirq.LineQubit(n_qubits - 1)]``.
        - ``all_pauli_strings`` lists the strings in ``itertools.product``
          order over (I, X, Y, Z), so the operator on the last qubit varies
          fastest.

    Raises:
        ValueError: if ``n_qubits`` is less than 1.
    """
    if n_qubits < 1:
        raise ValueError(f"n_qubits must be at least 1, got {n_qubits}")

    qubits = [cirq.LineQubit(qubit) for qubit in range(n_qubits)]
    sigmas = [cirq.I, cirq.X, cirq.Y, cirq.Z]

    all_pauli_strings = []
    # Each combination assigns one of I, X, Y, Z to every qubit (4**n_qubits total).
    for combination in itertools.product(sigmas, repeat=n_qubits):
        pauli_string = 1  # scalar start value so the first *= yields a PauliString
        for sigma, qubit in zip(combination, qubits):
            pauli_string *= cirq.PauliString(sigma(qubit))
        all_pauli_strings.append(cirq.PauliSum.from_pauli_strings(pauli_string))
    return all_pauli_strings, qubits


def compute_stabiliser_Renyi_entropy_order_2(
        qubits: List[cirq.devices.line_qubit.LineQubit],
        all_pauli_strings: List[cirq.ops.linear_combinations.PauliSum],
        wavefunction: np.ndarray) -> float:
    """
    Return the order-2 stabiliser Renyi entropy divided by the number of qubits.

    Computes

        M_2 = -log2( 2**(-n) * sum_P |<psi|P|psi>|**4 )

    with n = len(qubits) and the sum taken over ``all_pauli_strings``. The
    value returned is M_2 / n.

    Args:
        qubits: the qubits the Pauli strings act on. The qubit at position i
            in this list is mapped to index i of the state vector.
        all_pauli_strings: the Pauli strings to sum over. The result is the
            stabiliser Renyi entropy only if this contains all 4**n strings,
            e.g. the first element returned by ``return_all_Pauli_strings``.
        wavefunction: state vector (array-like) on ``qubits``. It is
            normalised here, so it need not be normalised on input.

    Returns:
        M_2 / n as defined above.

    Raises:
        ValueError: if ``qubits`` is empty or ``wavefunction`` has zero norm.
    """
    n_qubits = len(qubits)
    if n_qubits == 0:
        # The per-qubit division below would otherwise divide by zero.
        raise ValueError("qubits must not be empty")

    wavefunction = np.asarray(wavefunction)
    norm = np.linalg.norm(wavefunction)
    if norm == 0:
        # Normalising a zero vector would silently produce NaNs.
        raise ValueError("wavefunction must have non-zero norm")
    normalised_wavefunction = wavefunction / norm

    qubit_map = {qubit: index for index, qubit in enumerate(qubits)}

    expectation_value = 0
    for pauli_string in all_pauli_strings:
        single_matrix_element = pauli_string.expectation_from_state_vector(
                                    normalised_wavefunction,
                                    qubit_map=qubit_map)
        expectation_value += 2**(-n_qubits) * np.abs(single_matrix_element)**4
    return -np.log2(expectation_value) / n_qubits
