import os
from typing import List, Tuple

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, html, Input, Output, callback, dcc

import preprocess
import helper_functions

# Initial step of the h slider in the layout below. The callback
# get_common_h_values_for_slider rebinds this module-level name when it runs.
stepsize = 1

app = Dash(__name__)
# Page layout, top to bottom: raw-data plot, plot relative to a reference
# data set, then a wavefunction comparison with its fidelity read-out.
# Selectors created without `options` are populated by the
# get_available_directories_containing_data callback.
app.layout = html.Div([
        # --- Raw data section ---
        html.H1(children='Ground state search comparison'),
        html.H2(children='Please select the data you want to see'),
        html.H3(children='The number of qubits (not all data are available for larger than 8)'),
        dcc.RadioItems(options=['8', '12'], value='8', id='n_qubits_selector'),
        html.H3(children='The algorithm(s): select as many as you like'),
        dcc.Checklist(id='algorithm_selector'),
        html.H3(children='The quantity (choose one)'),
        dcc.RadioItems(options=['energies', 'magic'], value='energies', id='quantity_selector'),
        dcc.Graph(figure={}, id='raw_data_plot'),
        # --- Comparison against a reference data set ---
        html.H2(children="Now please select the data you'd like to compare to"),
        dcc.RadioItems(id='reference_selector'),
        html.H3(children="The axis scale"),
        dcc.RadioItems(options=['linear', 'log'], value='linear', id='linear_or_log_scale_selector'),
        html.H3(children="The exact form of the y axis"),
        dcc.RadioItems(options=['difference of absolute values', 'absolute value of difference'], value='difference of absolute values', id='absolute_value_toggle'),
        html.H3(children="And toggle the errors on or off. Note that the errors shown will be the relative error of the algorithm chosen, it is assumed that the algorithm you are comparing to has no error."),
        dcc.RadioItems(options=['errors', 'hide errors'], value='hide errors', id='relative_errors_toggle'),
        dcc.Graph(figure={}, id='relative_data_plot'),
        # --- Wavefunction comparison ---
        html.H2(children="Now compare the wavefunction amplitudes (modulus taken so only real numbers)"),
        html.H3(children="Compare:"),
        dcc.RadioItems(id='wf_algorithm_selector_1'),
        html.H2(children="with:"),
        dcc.RadioItems(id='wf_algorithm_selector_2'),
        html.H2(children="and select the value of field (h) that you want to see them at:"),
        dcc.Slider(
            id='h_slider',
            min=0,
            max=3,
            value=0,
            step=stepsize,
            marks=None,
            tooltip={"placement": "bottom", "always_visible": True}
        ),
        dcc.Graph(figure={}, id='wavefunction_plot'),
        html.Br(),
        # Filled with the 'Fidelity = ...' text by the get_fidelity callback.
        html.Div(id='fidelity', style={'fontSize': 'larger'})
        ])


@callback(
        Output(component_id='algorithm_selector', component_property='options'),
        Output(component_id='reference_selector', component_property='options'),
        Output(component_id='wf_algorithm_selector_1', component_property='options'),
        Output(component_id='wf_algorithm_selector_2', component_property='options'),
        Input(component_id='n_qubits_selector', component_property='value'),
        )
def get_available_directories_containing_data(
        n_qubits: str) -> Tuple[List[str], List[str], List[str], List[str]]:
    """
    Return the directories below ``../data/{n_qubits}_qubits/`` that hold files.

    The directory paths double as the algorithm identifiers used by the other
    callbacks. The same list is returned once per declared Output, so that all
    four selectors offer identical options.

    Args:
        n_qubits: Value of the qubit-number selector, interpolated into the
            base path as-is.

    Returns:
        A 4-tuple holding the same sorted list of directory paths four times.
        The list is empty when the base directory does not exist or contains
        no files.
    """
    base_of_directory_tree = f"../data/{n_qubits}_qubits/"
    # os.walk yields entries in filesystem order; sort so the options appear
    # in the same order on every machine and every reload.
    algorithms = sorted(
        root
        for root, _dirs, files in os.walk(base_of_directory_tree)
        if files
    )
    return algorithms, algorithms, algorithms, algorithms

@callback(
        Output(component_id='raw_data_plot', component_property='figure'),
        Input(component_id='quantity_selector', component_property='value'),
        Input(component_id='algorithm_selector', component_property='value'),
        Input(component_id='n_qubits_selector', component_property='value'),
        # Do not run on page load, when no algorithm has been selected yet.
        prevent_initial_call=True,
        )
def generate_raw_data_plot(quantity_chosen: str, algorithms_chosen: List[str], n_qubits: str):
    """
    Plot the chosen quantity against h for every selected algorithm.

    For each algorithm the data frame returned by
    ``helper_functions.get_energy_and_magic`` is drawn as one marker trace:
    column ``'h'`` on the x axis, the column named after the algorithm on the
    y axis, and the ``'<algorithm>_errors'`` column as error bars.

    Args:
        quantity_chosen: Quantity to plot; also used as the y-axis title.
        algorithms_chosen: Selected algorithm identifiers. May be ``None`` or
            empty when nothing is ticked, which yields a figure with no traces.
        n_qubits: Number of qubits, used only in the plot title.

    Returns:
        The formatted figure.
    """
    fig = go.Figure()
    # Treat "no selection" (None) like an empty selection instead of failing
    # when trying to iterate over None.
    for algorithm in algorithms_chosen or []:
        df = helper_functions.get_energy_and_magic(algorithm, quantity_chosen)
        fig.add_trace(go.Scatter(
                    x=df['h'],
                    y=df[algorithm],
                    error_y=dict(
                        type='data',
                        array=df[f'{algorithm}_errors'],
                        ),
                    name=algorithm,
                    hovertemplate='%{y}',
                    mode='markers')
                )
    fig = helper_functions.basic_formatting(fig)
    fig.update_layout(
            title_text=f'{n_qubits} qubits',
            xaxis_title_text='h',
            yaxis_title_text=quantity_chosen,
            )
    return fig


@callback(
        Output(component_id='relative_data_plot', component_property='figure'),
        Input(component_id='quantity_selector', component_property='value'),
        Input(component_id='algorithm_selector', component_property='value'),
        Input(component_id='n_qubits_selector', component_property='value'),
        Input(component_id='reference_selector', component_property='value'),
        Input(component_id='linear_or_log_scale_selector', component_property='value'),
        Input(component_id='absolute_value_toggle', component_property='value'),
        Input(component_id='relative_errors_toggle', component_property='value'),
        # Do not run on page load, when no reference has been selected yet.
        prevent_initial_call=True,
        )
def generate_relative_data_plot(
        quantity_chosen: str,
        algorithms_chosen: List[str],
        n_qubits: str,
        reference_algorithm: str,
        axis_scale: str,
        absolute_value_toggle: str,
        relative_errors_toggle: str):
    """
    Plot each selected algorithm's deviation from a reference data set.

    Currently only compare to VQE or exact diagonalisation because RBM has many
    more data in the x axis.

    Every selected algorithm other than the reference is merged with the
    reference data on column ``'h'`` (outer join) and drawn as one marker
    trace.

    Args:
        quantity_chosen: Quantity to compare; also used in the y-axis title.
        algorithms_chosen: Selected algorithm identifiers. ``None`` or empty
            yields a figure with no traces. The list is not modified.
        n_qubits: Number of qubits, used only in the plot title.
        reference_algorithm: Identifier of the data set to compare against.
        axis_scale: y-axis type passed to plotly ('linear' or 'log').
        absolute_value_toggle: 'difference of absolute values' plots
            ``|reference| - |algorithm|``; any other value plots
            ``|reference - algorithm|``.
        relative_errors_toggle: 'errors' adds error bars equal to the
            algorithm's error column divided by its value column; any other
            value hides them.

    Returns:
        The formatted figure.

    Raises:
        PreventUpdate: If no reference has been selected, so the existing
            plot is left untouched.
    """
    # Imported here so this callback does not depend on module-level imports.
    from dash.exceptions import PreventUpdate

    # Without a reference there is nothing to compare against.
    if reference_algorithm is None:
        raise PreventUpdate

    fig = go.Figure()
    reference_df = helper_functions.get_energy_and_magic(reference_algorithm, quantity_chosen)

    # Never plot the reference against itself. Build a new list rather than
    # removing the entry from the argument in place.
    algorithms_to_plot = [
        algorithm
        for algorithm in (algorithms_chosen or [])
        if algorithm != reference_algorithm
    ]
    show_errors = relative_errors_toggle == 'errors'
    difference_of_absolutes = absolute_value_toggle == 'difference of absolute values'

    for algorithm in algorithms_to_plot:
        df = helper_functions.get_energy_and_magic(algorithm, quantity_chosen)
        df = df.merge(reference_df, on='h', how='outer')

        if difference_of_absolutes:
            y_values = abs(df[reference_algorithm]) - abs(df[algorithm])
        else:
            y_values = abs(df[reference_algorithm] - df[algorithm])

        trace_options = dict(
            x=df['h'],
            y=y_values,
            hovertemplate='%{y}',
            name=algorithm,
            mode='markers',
        )
        if show_errors:
            # Relative error of the algorithm only; the reference is taken
            # to be exact.
            trace_options['error_y'] = dict(
                type='data',
                array=df[f'{algorithm}_errors'] / df[algorithm],
            )
        fig.add_trace(go.Scatter(**trace_options))

    fig = helper_functions.basic_formatting(fig)
    fig.update_layout(
            title_text=f'{n_qubits} qubits',
            xaxis_title_text='h',
            yaxis_title_text=f'{absolute_value_toggle} of {quantity_chosen}',
            )
    fig.update_yaxes(type=axis_scale, showexponent='all', exponentformat='e')
    return fig


@callback(
        Output(component_id='h_slider', component_property='step'),
        Input(component_id='wf_algorithm_selector_1', component_property='value'),
        Input(component_id='wf_algorithm_selector_2', component_property='value'),
        Input(component_id='n_qubits_selector', component_property='value'),
        # Do not run on page load, when neither data set has been chosen yet.
        prevent_initial_call=True,
        )
def get_common_h_values_for_slider(
        algorithm_1: str,
        algorithm_2: str,
        n_qubits: str):
    """
    Return the slider step shared by the h values of both chosen algorithms.

    The h values are read from the column labels of the frames returned by
    ``helper_functions.get_wavefunctions`` and handed to
    ``helper_functions.get_common_stepsize``.

    Args:
        algorithm_1: Identifier of the first wavefunction data set.
        algorithm_2: Identifier of the second wavefunction data set.
        n_qubits: Unused in the body; listed as an Input so the callback also
            fires when the qubit number changes.

    Returns:
        The common step size, which becomes the slider's ``step``.

    Raises:
        PreventUpdate: If either data set has not been selected yet, so the
            slider keeps its current step.
    """
    # Imported here so this callback does not depend on module-level imports.
    from dash.exceptions import PreventUpdate

    global stepsize

    # Both data sets are needed before a common step can be computed.
    if algorithm_1 is None or algorithm_2 is None:
        raise PreventUpdate

    wf_1_h_range = helper_functions.get_wavefunctions(algorithm_1).columns
    wf_2_h_range = helper_functions.get_wavefunctions(algorithm_2).columns
    # NOTE(review): the module-level `stepsize` is still rebound for backward
    # compatibility, but it is shared by every browser session; only the
    # returned value actually drives the slider.
    stepsize = helper_functions.get_common_stepsize(wf_1_h_range, wf_2_h_range)
    return stepsize




@callback(
        Output(component_id='wavefunction_plot', component_property='figure'),
        Input(component_id='wf_algorithm_selector_1', component_property='value'),
        Input(component_id='wf_algorithm_selector_2', component_property='value'),
        Input(component_id='n_qubits_selector', component_property='value'),
        Input(component_id='h_slider', component_property='value'),
        # Do not run on page load, when neither data set has been chosen yet.
        prevent_initial_call=True,
        )
def wavefunction_plot(
        algorithm_1: str,
        algorithm_2: str,
        n_qubits: str,
        h: float):
    """
    Draw the wavefunction amplitudes of two algorithms as bars at field h.

    Each algorithm contributes one bar trace taken from the column of its
    wavefunction frame labelled with ``h`` formatted to three decimals. The
    hover text comes from column ``0`` of
    ``helper_functions.get_all_basis_states(n_qubits)``.

    Args:
        algorithm_1: Identifier of the first wavefunction data set.
        algorithm_2: Identifier of the second wavefunction data set.
        n_qubits: Number of qubits, passed to ``get_all_basis_states``.
        h: Field value selected on the slider.

    Returns:
        The formatted bar figure.

    Raises:
        PreventUpdate: If either data set has not been selected yet, so the
            existing plot is left untouched.
    """
    # Imported here so this callback does not depend on module-level imports.
    from dash.exceptions import PreventUpdate

    # Both data sets are needed before anything can be compared.
    if algorithm_1 is None or algorithm_2 is None:
        raise PreventUpdate

    wf_1 = helper_functions.get_wavefunctions(algorithm_1)
    wf_2 = helper_functions.get_wavefunctions(algorithm_2)
    all_basis_states_df = helper_functions.get_all_basis_states(n_qubits)
    # Columns are looked up by h formatted to three decimals (e.g. '0.500').
    h_column = f'{h:.3f}'
    fig = go.Figure(data=[
        go.Bar(name=f'{algorithm_1}', y=wf_1[h_column], hovertext=all_basis_states_df[0]),
        go.Bar(name=f'{algorithm_2}', y=wf_2[h_column], hovertext=all_basis_states_df[0]),
        ])
    fig = helper_functions.basic_formatting(fig)
    fig.update_layout(
            title_text=f'h = {h}',
            xaxis_title_text='index',
            yaxis_title_text='|amplitude| of wavefunction',
            )
    return fig


@callback(
        Output(component_id='fidelity', component_property='children'),
        Input(component_id='wf_algorithm_selector_1', component_property='value'),
        Input(component_id='wf_algorithm_selector_2', component_property='value'),
        Input(component_id='n_qubits_selector', component_property='value'),
        Input(component_id='h_slider', component_property='value'),
        # Do not run on page load, when neither data set has been chosen yet.
        prevent_initial_call=True,
        )
def get_fidelity(
        algorithm_1: str,
        algorithm_2: str,
        n_qubits: str,
        h: float):
    """
    Return the fidelity between the two chosen wavefunctions at field h.

    Both wavefunctions are taken from the column labelled with ``h``
    formatted to three decimals and passed to
    ``helper_functions.compute_fidelity``.

    Args:
        algorithm_1: Identifier of the first wavefunction data set.
        algorithm_2: Identifier of the second wavefunction data set.
        n_qubits: Unused in the body; listed as an Input so the callback also
            fires when the qubit number changes.
        h: Field value selected on the slider.

    Returns:
        The text ``'Fidelity = <value>'`` shown below the wavefunction plot.

    Raises:
        PreventUpdate: If either data set has not been selected yet, so the
            displayed text is left untouched.
    """
    # Imported here so this callback does not depend on module-level imports.
    from dash.exceptions import PreventUpdate

    # Both data sets are needed before a fidelity can be computed.
    if algorithm_1 is None or algorithm_2 is None:
        raise PreventUpdate

    wf_1 = helper_functions.get_wavefunctions(algorithm_1)
    wf_2 = helper_functions.get_wavefunctions(algorithm_2)
    # Columns are looked up by h formatted to three decimals (e.g. '0.500').
    h_column = f'{h:.3f}'
    fidelity = helper_functions.compute_fidelity(wf_1[h_column], wf_2[h_column])
    return f'Fidelity = {fidelity}'


# Start the app only when this file is executed as a script, not on import.
if __name__ == '__main__':
    # debug=True is meant for local development only.
    app.run(debug=True)
    



# NOTE: the bare string literal below is never executed. It holds three
# disabled callback drafts that each forwarded n_qubits to
# get_available_directories_containing_data for a single selector. Two of them
# target component ids ('algorithm_selector_dupe1', 'algorithm_selector_dupe2')
# that do not appear in the layout defined in this file.
"""
@callback(
        Output(component_id='reference_selector', component_property='options'),
        Input(component_id='n_qubits_selector', component_property='value'),
        prevent_intial_call=True,
        )
def call_get_available_directories_containing_data(n_qubits: int) -> List[str]:
    return get_available_directories_containing_data(n_qubits)

@callback(
        Output(component_id='algorithm_selector_dupe1', component_property='options'),
        Input(component_id='n_qubits_selector', component_property='value'),
        prevent_intial_call=True,
        )
def call_get_available_directories_containing_data(n_qubits: int) -> List[str]:
    return get_available_directories_containing_data(n_qubits)


@callback(
        Output(component_id='algorithm_selector_dupe2', component_property='options'),
        Input(component_id='n_qubits_selector', component_property='value'),
        prevent_intial_call=True,
        )
def call_get_available_directories_containing_data(n_qubits: int) -> List[str]:
    return get_available_directories_containing_data(n_qubits)
"""
