# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# # Generate plots for the VQE/RBM/DMRG/ED magic work
# Note that all data are stored in '../data/', not in this directory.

# Global switches shared by every plot in this notebook.
# save_data is forwarded as ``save_data=`` to each plotting call below; it is
# False by default to avoid unnecessary overwrites (applies to all plots).
save_data = False

# h_cutoff is not used directly by the cells in this script.
# It is also defined in plot_templates.py -- MAKE SURE THEY'RE THE SAME.
# (No ``global`` declaration is needed: this assignment is already at module
# scope, where ``global`` has no effect.)
h_cutoff = 0

# + [markdown] jp-MarkdownHeadingCollapsed=true
# ## Imports 
# -

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
from typing import List, Union
from plot_templates import (
    plot_style,
    get_energy_and_magic_from_list_of_filenames,
    get_energy_magic_and_infidelity_from_list_of_filenames,
    create_four_pane_plots,
    create_two_pane_plot_differences_only,
    create_statistical_error_plot,
    create_statistical_error_plot_8_and_12,
    create_three_pane_plot,
    create_six_panel_full_width_plot,
)
# get the preprocessing code from the dashboard
import sys
sys.path.append('../VQE_RBM_comparison_dashboard/')
import preprocess, helper_functions

plot_style()


def _load_first_exact_diagonalisation(data_dir):
    """Return the (energy, magic) pair for a single exact-diagonalisation run.

    ``get_energy_and_magic_from_list_of_filenames`` takes a list of data
    directories and returns two indexable collections (energies, magics),
    presumably one entry per directory. Only one directory is passed here, so
    element 0 of each collection is unwrapped.
    """
    energies, magics = get_energy_and_magic_from_list_of_filenames([data_dir])
    return energies[0], magics[0]


# Exact-diagonalisation reference data, pulled out of their one-element lists.
(
    exact_diagonalisation_energy_as_df,
    exact_diagonalisation_magic_as_df,
) = _load_first_exact_diagonalisation('../data/8_qubits/exact_diagonalisation/')
# NOTE(review): the 12-qubit plotting cell re-reads these two names with
# helper_functions.get_energy_and_magic; confirm both loaders agree.
(
    q12exact_diagonalisation_energy_as_df,
    q12exact_diagonalisation_magic_as_df,
) = _load_first_exact_diagonalisation('../data/12_qubits/exact_diagonalisation/')

# # Paper plots below here

# + [markdown] jp-MarkdownHeadingCollapsed=true
# ## Four-panel plot of the energy, magic, delta_energy, and delta_magic for the best representative of each algorithm

# +
# Best representative of each method at 8 qubits. Each row describes one
# series as (legend label, data path, marker colour, marker symbol); zip()
# transposes the rows into the parallel lists the plotting helper expects.
list_of_labels, list_of_data, colours, markers = map(list, zip(
    ('DMRG', '../data/8_qubits/DMRG/', plt.cm.cool(0.1), '1'),
    ('RBM', '../data/8_qubits/RBM/alpha_5/RBM', plt.cm.cool(0.5), '2'),
    ('VQE', '../data/8_qubits/VQE/4_layers/', plt.cm.cool(1.0), '3'),
))

# Compare each series against the 8-qubit exact-diagonalisation reference.
create_four_pane_plots(
    exact_diagonalisation_energy_as_df,
    exact_diagonalisation_magic_as_df,
    list_of_data,
    list_of_labels,
    colours,
    markers,
    marker_size=16,
    marker_thickness=0.6,
    filename_stem='./figures/best_of_each-four_panel_8q',
    # Legend box placement and size.
    legend_x=0,
    legend_y=0.9,
    legend_width=1,
    legend_height=0.1,
    # Shared y-axis ranges for the energy and magic panels.
    energy_ylim_lower=1e-16,
    energy_ylim_upper=1e0,
    magic_ylim_lower=1e-16,
    magic_ylim_upper=1e0,
    n_yticks=4,
    save_data=save_data,
)

# -

# ## Two-panel plots showing only the energy and magic differences

# ### 12-qubit version for the main text

# +
# 12-qubit exact-diagonalisation reference.
# NOTE(review): this re-reads the 12-qubit ED data with
# helper_functions.get_energy_and_magic and overwrites the
# q12exact_diagonalisation_* values loaded near the top of the script with
# get_energy_and_magic_from_list_of_filenames. Confirm the two loaders return
# equivalent data, then keep only one of them.
q12exact_diagonalisation_energy_as_df = helper_functions.get_energy_and_magic(
    '../data/12_qubits/exact_diagonalisation/', 'energy'
)
q12exact_diagonalisation_magic_as_df = helper_functions.get_energy_and_magic(
    '../data/12_qubits/exact_diagonalisation/', 'magic'
)

# Best representative of each method at 12 qubits, one row per series:
# (legend label, data path, marker colour, marker symbol).
list_of_labels, list_of_data, colours, markers = map(list, zip(
    ('DMRG', '../data/12_qubits/DMRG/', plt.cm.cool(0.1), '1'),
    ('RBM', '../data/12_qubits/RBM/alpha_5/', plt.cm.cool(0.5), '2'),
    ('VQE', '../data/12_qubits/VQE/4_layers/', plt.cm.cool(1.0), '3'),
))

create_two_pane_plot_differences_only(
    q12exact_diagonalisation_energy_as_df,
    q12exact_diagonalisation_magic_as_df,
    list_of_data,
    list_of_labels,
    colours,
    markers,
    marker_size=18,
    marker_thickness=0.8,
    filename_stem='./figures/12_qubit_best_of_plot',
    # y-axis ranges for the energy-difference and magic-difference panels.
    energy_ylim_lower=1e-16,
    energy_ylim_upper=1e1,
    magic_ylim_lower=1e-10,
    magic_ylim_upper=1e0,
    n_yticks=5,
    save_data=save_data,
)

# -

# ## Statistical errors of RBM vs VQE

# ### Both 8 and 12 qubits

# +

# Statistical errors of the RBM and VQE runs: the 8-qubit data form the "top"
# set and the 12-qubit data the "bottom" set. One row per method:
# (legend label, 8-qubit data path, 12-qubit data path, colour, marker).
# NOTE(review): RBM is drawn here with plt.cm.cool(0.1) and marker '1', which
# the best-of-each plots use for DMRG (there RBM is cool(0.5) / '2' and VQE is
# cool(1.0) / '3'). Confirm that the differing encoding across figures is intended.
(
    list_of_labels,
    list_of_data_top,
    list_of_data_bottom,
    colours,
    markers,
) = map(list, zip(
    (
        'RBM',
        '../data/8_qubits/RBM/alpha_5/RBM',
        '../data/12_qubits/RBM/alpha_5/',
        plt.cm.cool(0.1),
        '1',
    ),
    (
        'VQE',
        '../data/8_qubits/VQE/4_layers/',
        '../data/12_qubits/VQE/4_layers/',
        plt.cm.cool(1.0),
        '2',
    ),
))

create_statistical_error_plot_8_and_12(
    list_of_data_top,
    list_of_data_bottom,
    list_of_labels,
    colours,
    markers,
    marker_size=18,
    marker_thickness=1.0,
    filename_stem='./figures/statistical_errors_of_RBM_and_VQE_8_and_12q',
    energy_ylim_lower=1e-6,
    energy_ylim_upper=1e-1,
    magic_ylim_lower=1e-6,
    magic_ylim_upper=1e-1,
    legend_y=0.9,
    n_yticks=5,
    save_data=save_data,
)
# -

# ## Plotting infidelity (six-panel plot, varying alpha for the RBM and the symmetric RBM)

# +

# 8-qubit RBM runs for alpha = 1, 3, 5: the plain RBM runs first, then the
# symmetric RBM runs.
# NOTE(review): there are six data paths but only three labels/colours/markers.
# Presumably create_six_panel_full_width_plot uses the first three for the
# left-hand panels and the last three for the right-hand panels; confirm in
# plot_templates.
list_of_data = [
    '../data/8_qubits/RBM/alpha_1/RBM',
    '../data/8_qubits/RBM/alpha_3/RBM',
    '../data/8_qubits/RBM/alpha_5/RBM',
] + [
    '../data/8_qubits/RBM/alpha_1/symmetricRBM',
    '../data/8_qubits/RBM/alpha_3/symmetricRBM',
    '../data/8_qubits/RBM/alpha_5/symmetricRBM',
]

# One row per alpha value: (legend label, marker colour, marker symbol).
labels, colours, markers = map(list, zip(
    (r'$\alpha = 1$', plt.cm.cool(0.0), '1'),
    (r'$\alpha = 3$', plt.cm.cool(0.5), '2'),
    (r'$\alpha = 5$', plt.cm.cool(1.0), '3'),
))

create_six_panel_full_width_plot(
    list_of_data,
    labels,
    colours,
    markers,
    marker_size=18,
    marker_thickness=1.0,
    filename_stem='./figures/six_panel_plot_for_varying_alpha-both-RBMs',
    # Left-hand panels.
    energy_left_ylim_lower=1e-5,
    energy_left_ylim_upper=1e-3,
    magic_left_ylim_lower=1e-6,
    magic_left_ylim_upper=1e-2,
    infidelity_left_ylim_lower=1e-8,
    infidelity_left_ylim_upper=1e0,
    # Right-hand panels (same ranges as the left).
    energy_right_ylim_lower=1e-5,
    energy_right_ylim_upper=1e-3,
    magic_right_ylim_lower=1e-6,
    magic_right_ylim_upper=1e-2,
    infidelity_right_ylim_lower=1e-8,
    infidelity_right_ylim_upper=1e0,
    legend_x=0.0,
    legend_y=0.92,
    n_yticks=3,
    save_data=save_data,
)
