import preprocess
import pandas as pd
import numpy as np
import os



def compute_fidelity(wf1: pd.Series, wf2: pd.Series) -> float:
    """
    Fidelity between two wavefunctions held as pandas Series.

    Both Series are converted with ``to_numpy()``. This discards the index, so
    amplitudes are paired by position, not by label. The result of
    ``compute_fidelity_from_arrays`` on those two arrays is returned unchanged.
    """
    first_amplitudes, second_amplitudes = (
        series.to_numpy() for series in (wf1, wf2)
    )
    return compute_fidelity_from_arrays(first_amplitudes, second_amplitudes)


def compute_fidelity_from_arrays(wf1_array: np.ndarray, wf2_array: np.ndarray) -> float:
    """
    Return the fidelity |<wf1|wf2>|^2 between two state vectors.

    ``np.vdot`` complex-conjugates its first argument and flattens both inputs.
    Complex amplitudes are therefore handled correctly, and the result is always
    real. For real-valued vectors this equals the previous
    ``np.inner(wf1, wf2) ** 2``.

    No normalisation is performed; the caller is presumably expected to pass
    normalised states (TODO confirm against the data producers).

    Raises ValueError if the two arrays have different numbers of elements.
    """
    overlap = np.vdot(wf1_array, wf2_array)
    # |overlap|^2 rather than overlap**2: squaring a complex overlap is not a fidelity.
    return np.abs(overlap) ** 2


def basic_formatting(fig):
    """
    Apply the shared figure style and return the updated figure.

    First tries ``fig.update_traces(marker=dict(size=17))``. If that raises
    ValueError, marker sizing is skipped and the original ``fig`` is kept.
    The resulting figure's ``update_layout`` then sets:
    - Times New Roman fonts (size 30, hover labels 24);
    - ``hovermode='x'``;
    - a width of 20000/6 and a height of 8000/6.

    Its return value is returned. ``fig`` is presumably a plotly Figure, given
    the method names (verify against callers).
    """
    marker_style = dict(size=17)
    try:
        styled = fig.update_traces(marker=marker_style)
    except ValueError:
        # Marker sizing is best-effort: keep the figure as-is if it is rejected.
        styled = fig
    divisor = 6
    layout_options = {
        'font_family': 'Times New Roman',
        'font_size': 30,
        'hoverlabel_font_size': 24,
        'hovermode': 'x',
        'width': 20000 / divisor,
        'height': 8000 / divisor,
    }
    return styled.update_layout(**layout_options)


def get_energy_and_magic(filepath: str,  quantity_chosen: str, n_qubits=8):
    """
    There should only be one file in each directory containing 'magic'
    and one containing 'energies'/'energy'. For a given filepath to this
    directory, return the data for the specific quantity_chosen (magic/energy).

    If there is another file with the same name but containing "errors" then
    its data is also returned in the df.

    The returned df has the reader's 'xs' column renamed to 'h', its 'ys' column
    renamed to ``filepath``, and a ``f'{filepath}_errors'`` column. That column
    holds 0 when no error file exists; otherwise it is left-merged on 'h' from
    the error file.

    If several non-error files match, the last one listed by the directory scan
    is used (directory listing order is not guaranteed).

    n_qubits currently unused as no other data is available.

    Raises FileNotFoundError if ``filepath`` is not a readable directory or
    contains no (non-error) data file for ``quantity_chosen``.
    """
    # 'energy' -> 'ene' and 'energies' -> 'energ', so either spelling matches.
    search_key = quantity_chosen[:-3]
    # 'magic' -> 'ma', which is short enough to match the wrong file.
    if len(search_key) == 2:
        search_key = f'{search_key}gic'

    # os.walk yields nothing for a missing/unreadable directory; report that clearly
    # instead of leaking StopIteration.
    listing = next(os.walk(filepath), None)
    if listing is None:
        raise FileNotFoundError(f"Directory not found or unreadable: {filepath!r}")
    matching = [name for name in listing[2] if search_key in name]
    data_files = [name for name in matching if "error" not in name]
    error_files = [name for name in matching if "error" in name]
    if not data_files:
        raise FileNotFoundError(
            f"No data file containing {search_key!r} found in {filepath!r}"
        )

    error_column = f'{filepath}_errors'
    df = preprocess.read_simple_tabular_structure(os.path.join(filepath, data_files[-1]))
    df = df.rename(columns={'xs': 'h', 'ys': filepath})
    # fill the errors with 0 in case there are none
    df[error_column] = 0

    for name in error_files:
        error_df = preprocess.read_simple_tabular_structure(os.path.join(filepath, name))
        error_df = error_df.rename(columns={'xs': 'h', 'ys': error_column})
        # Replace the placeholder (or previous) error column with the file's values.
        df = df.drop(columns=[error_column]).merge(error_df, on='h', how='left')
    return df


def get_wavefunctions(filepath: str, n_qubits=8):
    """
    Load the wavefunction file from the directory ``filepath``.

    Looks for files whose name contains "wavefunction" and reads one with
    ``preprocess.read_wavefunction_into_df``. If several match, the last one
    listed by the directory scan is used (directory listing order is not
    guaranteed).

    n_qubits currently unused.

    Raises FileNotFoundError if ``filepath`` is not a readable directory or
    contains no wavefunction file.
    """
    # os.walk yields nothing for a missing/unreadable directory; report that clearly
    # instead of leaking StopIteration.
    listing = next(os.walk(filepath), None)
    if listing is None:
        raise FileNotFoundError(f"Directory not found or unreadable: {filepath!r}")
    wavefunction_files = [name for name in listing[2] if "wavefunction" in name]
    if not wavefunction_files:
        raise FileNotFoundError(f"No wavefunction file found in {filepath!r}")
    return preprocess.read_wavefunction_into_df(
        os.path.join(filepath, wavefunction_files[-1])
    )


def get_common_stepsize(range_1: pd.Series, range_2: pd.Series) -> float:
    """
    Given two evenly spaced pandas series each going from 0-3, find the
    smallest common step size that they both share:
        e.g. 0,1,2,3 and 0,0.5,1,1.5,2,2.5,3 would result in a stepsize of 1
        (not 0.5 clearly)

    Each series' step is taken from its first two entries. These are read by
    position, so a Series with a non-default index works. Grid points are
    rounded before comparison: grids built from float steps such as 0.2 and 0.3
    otherwise fail to meet at 0.6, 1.2, ... because of representation error.

    Raises ValueError if the two grids share no point in [0, 3) other than 0.
    """
    # Decimal places kept when comparing grid points; absorbs float noise.
    decimals = 10
    first_1, second_1 = np.asarray(range_1)[:2]
    first_2, second_2 = np.asarray(range_2)[:2]
    step_size_1 = float(second_1) - float(first_1)
    step_size_2 = float(second_2) - float(first_2)

    h1s = np.round(np.arange(0, 3, step_size_1), decimals)
    h2s = np.round(np.arange(0, 3, step_size_2), decimals)

    common_values = np.intersect1d(h1s, h2s)
    if common_values.size < 2:
        raise ValueError(
            f"No common step below 3 for step sizes {step_size_1} and {step_size_2}"
        )
    # Both grids start at 0, so the first two shared points give the common step.
    return common_values[1] - common_values[0]


def get_all_basis_states(n_qubits=8) -> pd.DataFrame:
    """
    The VQE is in the lexicographic order, so for two qubits the basis is ordered
     00, 01, 10, 11
    Return a df with these elements for n_qubits qubits.

    The df has a single column (label 0) of zero-padded binary strings, one row
    per basis state, 2**n_qubits rows in total. ``n_qubits`` may be anything
    ``int()`` accepts (e.g. 8, '8' or 8.0). It is converted once, so the number
    of states and the padding width always agree.
    """
    width = int(n_qubits)
    labels = [format(index, f'0{width}b') for index in range(2 ** width)]
    return pd.DataFrame(labels)

 
def OLD_get_data(algorithm: str, n_qubits: int, quantity_chosen: str):
    """
    Legacy loader that tries several filename conventions in a fixed order.

    Accounts for the difference in filename for the VQE and exact, plus the
    differing number of layers in the ansatz. The first three candidates live in
    ``../data/<algorithm[:3]>/<n_qubits>_qubits/`` and use ``algorithm[4:]`` as
    the file prefix:
    1. Adam with 3 layers;
    2. Adam with 4 layers;
    3. RBM (no optimiser suffix).
    The final fallback uses the full ``algorithm`` name for both the directory
    and the prefix.

    Each candidate that raises FileNotFoundError is skipped. If every candidate
    is missing, the last one's FileNotFoundError propagates.

    Returns the loaded df with 'xs' renamed to 'h' and 'ys' renamed to
    ``algorithm``.
    """
    family = algorithm[:3]
    variant = algorithm[4:]
    base_dir = f"../data/{family}/{n_qubits}_qubits"
    candidates = (
        # Adam 3 layers
        f"{base_dir}/{variant}_expectations__{quantity_chosen}_adam-3_layers.dat",
        # Adam 4 layers
        f"{base_dir}/{variant}_expectations__{quantity_chosen}_adam-4_layers.dat",
        # RBM
        f"{base_dir}/{variant}_expectations__{quantity_chosen}.dat",
        # anything else
        f"../data/{algorithm}/{n_qubits}_qubits/{algorithm}_{quantity_chosen}.dat",
    )

    for candidate in candidates[:-1]:
        try:
            df = preprocess.read_simple_tabular_structure(candidate)
        except FileNotFoundError:
            continue
        break
    else:
        # No earlier candidate loaded: the last one is mandatory and may raise.
        df = preprocess.read_simple_tabular_structure(candidates[-1])

    return df.rename(columns={'xs': 'h', 'ys': algorithm})
