import os
import sys
import warnings
from typing import List, Tuple, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd

# get the preprocessing code from the dashboard
sys.path.append('../VQE_RBM_comparison_dashboard/')
import preprocess, helper_functions

# Optional Tahoma font registration.
# If you don't have the Tahoma font, install the tahoma.ttf file in this repo and
# point `path_to_fonts` at the directory it was installed to. On Linux,
# ```fc-list | grep "tah"``` on the command line shows the installed location.
from matplotlib import font_manager
path_to_fonts = '/home/tspriggs/.local/share/fonts/'
font_file = os.path.join(path_to_fonts, 'tahoma.ttf')
if os.path.isfile(font_file):
    font_manager.fontManager.addfont(font_file)
else:
    # addfont() raises on a missing file, which would make importing this module
    # fail on every machine except the author's; fall back to the default font.
    warnings.warn(
        f"Tahoma font not found at {font_file!r}; matplotlib will use its default font instead.",
        stacklevel=2,
    )


# Colour of every exact-diagonalisation (ED) reference curve.
exact_diagonalisation_colour = 'black'

# Rows with h below this value are dropped when data are loaded.
# Also defined in generate_plots_for_paper.py MAKE SURE THEY'RE THE SAME
h_cutoff = 0

# Plot formatting: default figure size with a golden-ratio aspect.
fig_size_dim    = 8
golden_ratio    = (1+np.sqrt(5))/2
fig_size        = (fig_size_dim, fig_size_dim/golden_ratio)

def plot_style():
    """Apply the default matplotlib rcParams: module figure size, 500 dpi, Tahoma, 8 pt text."""
    text_size = 8
    resolution = 500

    rc_overrides = {
        'figure.figsize': fig_size,
        'figure.dpi': resolution,
        'savefig.dpi': resolution,
        'font.size': text_size,
        'font.family': "Tahoma",
    }
    # Every text element (titles, legend, labels, tick labels) shares one size.
    for size_key in ('figure.titlesize', 'legend.fontsize', 'axes.labelsize',
                     'axes.titlesize', 'xtick.labelsize', 'ytick.labelsize'):
        rc_overrides[size_key] = text_size

    plt.rcParams.update(rc_overrides)


def get_energy_and_magic_from_list_of_filenames(list_of_filenames: List[str]) -> Tuple[List[pd.DataFrame], List[pd.DataFrame]]:
    """
    Load the energy and magic tables for every file in ``list_of_filenames``.

    Each table comes from ``helper_functions.get_energy_and_magic``. Rows with
    ``h < h_cutoff`` are dropped and the index is reset to 0..n-1.

    Returns:
        ``(energies, magics)``: two lists with one DataFrame per filename, in the
        same order as ``list_of_filenames``.
    """
    def _load_filtered(filename: str, quantity: str) -> pd.DataFrame:
        # Load one quantity and keep only rows at or above the h cutoff.
        df = helper_functions.get_energy_and_magic(filename, quantity)
        return df[df['h'] >= h_cutoff].reset_index(drop=True)

    list_of_energy_data_as_dfs = []
    list_of_magic_data_as_dfs = []
    for filename in list_of_filenames:
        list_of_energy_data_as_dfs.append(_load_filtered(filename, 'energy'))
        list_of_magic_data_as_dfs.append(_load_filtered(filename, 'magic'))
    return list_of_energy_data_as_dfs, list_of_magic_data_as_dfs


def get_energy_magic_and_infidelity_from_list_of_filenames(list_of_filenames: List[str]) -> Tuple[List[pd.DataFrame], List[pd.DataFrame], List[pd.DataFrame]]:
    """
    Load energy, magic and infidelity tables for every file in ``list_of_filenames``.

    Energy and magic are loaded with ``get_energy_and_magic_from_list_of_filenames``.
    Infidelity is loaded with ``helper_functions.get_energy_and_magic(filename, 'infidelity')``
    and gets the same treatment: rows with ``h < h_cutoff`` are dropped and the
    index is reset.

    Returns:
        ``(energies, magics, infidelities)``: three lists with one DataFrame per
        filename, in the order of ``list_of_filenames``.
    """
    energy, magic = get_energy_and_magic_from_list_of_filenames(list_of_filenames)
    list_of_infidelity_data_as_dfs = []
    for filename in list_of_filenames:
        infidelity_df = helper_functions.get_energy_and_magic(filename, 'infidelity')
        list_of_infidelity_data_as_dfs.append(
            infidelity_df[infidelity_df['h'] >= h_cutoff].reset_index(drop=True)
        )
    return energy, magic, list_of_infidelity_data_as_dfs


def create_four_pane_plots(
    exact_diagonalisation_energy_as_df: pd.DataFrame,
    exact_diagonalisation_magic_as_df: pd.DataFrame,
    list_of_data_to_load: List[str],
    labels: List[str],
    colours: "list of[strings or matplotlib colours]",
    markers: List[str],
    marker_size = 20,
    marker_thickness = 0.8,
    ED_linewidth = 0.7,
    filename_stem ='./figures/test',
    legend_x = 0,
    legend_y = 0.92,
    legend_width = 1,
    legend_height = 0.1,
    energy_ylim_lower = 1e-7,
    energy_ylim_upper = 1e-0,
    magic_ylim_lower = 1e-7,
    magic_ylim_upper = 1e-0,
    n_yticks = 4,
    save_data=False,
):
    """
    Generate a four pane plot comparing each algorithm with exact diagonalisation (ED).

    From top left to bottom right the panes are: energy, magic, difference in
    energy (|E_ED| - |E|, log scale) and difference in magic (|M_ED - M|, log scale).

    Args:
        exact_diagonalisation_energy_as_df: ED energies. Column 0 is plotted as h,
            column 1 as the value. It must contain an ``'h'`` column to merge on.
        exact_diagonalisation_magic_as_df: ED magic, same layout as the energy frame.
        list_of_data_to_load: Files loaded with
            ``get_energy_and_magic_from_list_of_filenames``. Column 0 of each table
            is plotted as h and column 1 as the value.
        labels, colours, markers: One entry per file, used for its scatter series.
        marker_size, marker_thickness: Scatter marker area and edge width.
        ED_linewidth: Line width of the dashed ED curves.
        filename_stem: Output path without extension, used when ``save_data``.
        legend_x, legend_y, legend_width, legend_height: Figure-legend bbox.
        energy_ylim_lower, energy_ylim_upper: y-limits of the energy-difference pane.
        magic_ylim_lower, magic_ylim_upper: y-limits of the magic-difference pane.
        n_yticks: ``numticks`` for the log-scale major tick locators.
        save_data: If True, save the figure to ``f'{filename_stem}.png'``.
    """
    def merge_onto_ed_grid(data, ed_df):
        # Right-merge one algorithm's values onto ED's h grid (the data may use a
        # different h spacing). Both values for a given h land on the same row,
        # so the difference no longer depends on the two frames' indexes lining up.
        value_key = data.keys()[1]
        ed_value_key = ed_df.keys()[1]
        return pd.merge(
            data[['h', value_key]].rename(columns={value_key: 'algorithm'}),
            ed_df[['h', ed_value_key]].rename(columns={ed_value_key: 'ed'}),
            on='h',
            how='right',
        )

    # Read data
    list_of_energy_data, list_of_magic_data = get_energy_and_magic_from_list_of_filenames(list_of_data_to_load)

    # Basic formatting of the whole figure.
    fig, axs = plt.subplots(ncols=2, nrows=2, dpi=300)
    axs = axs.flatten()

    fig.set_size_inches(3.40457, 1.8)
    font = {'size' : 8}
    mpl.rc('font', **font)
    fig.subplots_adjust(bottom = 0.2, top=0.90, left=0.16, right=0.98, wspace=0.35)

    for ax in axs:
        ax.grid(True, which='both', linewidth=0.1)
        for axis in ['top', 'bottom', 'left', 'right']:
            ax.spines[axis].set_linewidth(0.5)
        ax.tick_params(width=0.5, which='both')
        ax.tick_params(length=2.5, which='major')
        ax.tick_params(length=1.5, which='minor')

    # Energy pane, axs[0]: ED curve plus each algorithm (these feed the legend).
    keys = exact_diagonalisation_energy_as_df.keys()
    axs[0].plot(
        exact_diagonalisation_energy_as_df[keys[0]],
        exact_diagonalisation_energy_as_df[keys[1]],
        '--',
        color = exact_diagonalisation_colour,
        linewidth = ED_linewidth,
        label = 'ED',
    )
    for plot_idx, data in enumerate(list_of_energy_data):
        keys = data.keys()
        axs[0].scatter(
            data[keys[0]],
            data[keys[1]],
            color = colours[plot_idx],
            s = marker_size,
            linewidth = marker_thickness,
            marker = markers[plot_idx],
            label = labels[plot_idx],
        )
    axs[0].set_ylabel('$E$')
    axs[0].set_ylim(-26,-4)
    axs[0].set_yticks(np.linspace(-25, -5, 3))
    axs[0].set_xlim(-0.05+h_cutoff,3.05)

    # Error in energy pane, axs[2].
    for plot_idx, data in enumerate(list_of_energy_data):
        merged = merge_onto_ed_grid(data, exact_diagonalisation_energy_as_df)
        axs[2].scatter(
            merged['h'],
            merged['ed'].abs() - merged['algorithm'].abs(),
            color= colours[plot_idx],
            s = marker_size,
            linewidth = marker_thickness,
            marker = markers[plot_idx],
        )
    axs[2].set_ylabel(r'|$E_{ED}$| - |$E$|')
    axs[2].set_yscale('log')
    axs[2].set_xlim(-0.05+h_cutoff,3.05)
    axs[2].set_ylim(energy_ylim_lower, energy_ylim_upper)
    axs[2].yaxis.set_major_locator(ticker.LogLocator(base=10, numticks=n_yticks))

    # Actual value pane, magic, axs[1].
    keys = exact_diagonalisation_magic_as_df.keys()
    axs[1].plot(
        exact_diagonalisation_magic_as_df[keys[0]],
        exact_diagonalisation_magic_as_df[keys[1]],
        '--',
        color = exact_diagonalisation_colour,
        linewidth = ED_linewidth,
    )
    for plot_idx, data in enumerate(list_of_magic_data):
        keys = data.keys()
        axs[1].scatter(
            data[keys[0]],
            data[keys[1]],
            color = colours[plot_idx],
            s = marker_size,
            linewidth = marker_thickness,
            marker = markers[plot_idx],
        )
    axs[1].set_ylabel('$M$')
    axs[1].set_ylim(-0.01,0.41)
    axs[1].set_yticks(np.linspace(0, 0.4, 5))
    axs[1].set_xlim(-0.05+h_cutoff,3.05)

    # Error in magic pane, axs[3].
    for plot_idx, data in enumerate(list_of_magic_data):
        merged = merge_onto_ed_grid(data, exact_diagonalisation_magic_as_df)
        axs[3].scatter(
            merged['h'],
            (merged['ed'] - merged['algorithm']).abs(),
            color = colours[plot_idx],
            s = marker_size,
            linewidth = marker_thickness,
            marker = markers[plot_idx],
        )
    axs[3].set_ylabel(r'|$M_{ED}$ - $M$|')
    axs[3].set_yscale('log')
    axs[3].set_xlim(-0.05+h_cutoff,3.05)
    axs[3].set_ylim(magic_ylim_lower,magic_ylim_upper)
    axs[3].yaxis.set_major_locator(ticker.LogLocator(base=10, numticks=n_yticks))

    # Final bit of formatting.
    for ax in axs:
        ax.set_xticks([0+h_cutoff,1,2,3], labels = [f'{h_cutoff}', '1', '2', '3'])
    fig.legend(frameon=False, bbox_to_anchor=(legend_x, legend_y, legend_width, legend_height), loc='center', ncol=4)
    axs[2].set_xlabel('$h$')
    axs[3].set_xlabel('$h$')
    plt.subplots_adjust(wspace=0.5, hspace=0.5)
    if save_data:
        plt.savefig(f'{filename_stem}.png')
    plt.show()
    return


def create_two_pane_plot_differences_only(
    exact_diagonalisation_energy_as_df: pd.DataFrame,
    exact_diagonalisation_magic_as_df: pd.DataFrame,
    list_of_data_to_load: List[str],
    labels: List[str],
    colours: "list of[strings or matplotlib colours]",
    markers: List[str],
    marker_size = 20,
    marker_thickness = 0.8,
    filename_stem ='./figures/test',
    legend_x = 0,
    legend_y = 0.89,
    legend_width = 1,
    legend_height = 0.1,
    energy_ylim_lower = 1e-7,
    energy_ylim_upper = 1e-0,
    magic_ylim_lower = 1e-7,
    magic_ylim_upper = 1e-0,
    n_yticks = 4,
    turn_off_minor_ticks = True,
    save_data=False,
):
    """
    Generate a two pane plot of differences from exact diagonalisation (ED).

    The top pane shows |E_ED| - |E| and the bottom pane |M_ED - M|, both on a
    log scale against h.

    Args:
        exact_diagonalisation_energy_as_df: ED energies. Column 1 is the value, and
            it must contain an ``'h'`` column to merge on.
        exact_diagonalisation_magic_as_df: ED magic, same layout as the energy frame.
        list_of_data_to_load: Files loaded with
            ``get_energy_and_magic_from_list_of_filenames``. Column 1 of each table
            is the value compared with ED.
        labels, colours, markers: One entry per file. Only the energy pane is
            labelled, so each file appears once in the legend.
        marker_size, marker_thickness: Scatter marker area and edge width.
        filename_stem: Output path without extension, used when ``save_data``.
        legend_x, legend_y, legend_width, legend_height: Figure-legend bbox.
        energy_ylim_lower, energy_ylim_upper, magic_ylim_lower, magic_ylim_upper:
            y-limits of the two panes.
        n_yticks: ``numticks`` for the log-scale major tick locators.
        turn_off_minor_ticks: If True, hide minor y ticks.
        save_data: If True, save the figure to ``f'{filename_stem}.png'``.
    """
    def merge_onto_ed_grid(data, ed_df):
        # Right-merge one algorithm's values onto ED's h grid (the data may use a
        # different h spacing). Both values for a given h land on the same row,
        # so the difference no longer depends on the two frames' indexes lining up.
        value_key = data.keys()[1]
        ed_value_key = ed_df.keys()[1]
        return pd.merge(
            data[['h', value_key]].rename(columns={value_key: 'algorithm'}),
            ed_df[['h', ed_value_key]].rename(columns={ed_value_key: 'ed'}),
            on='h',
            how='right',
        )

    # Read data
    list_of_energy_data, list_of_magic_data = get_energy_and_magic_from_list_of_filenames(list_of_data_to_load)

    # Basic formatting of the whole figure.
    fig, axs = plt.subplots(ncols=1, nrows=2, dpi=300)

    fig.set_size_inches(3.40457, 1.8)
    font = {'size' : 8}
    mpl.rc('font', **font)
    fig.subplots_adjust(bottom = 0.2, top=0.9, left=0.16, right=0.98, wspace=0.35)

    for ax in axs:
        ax.grid(True, which='both', linewidth=0.1)
        for axis in ['top', 'bottom', 'left', 'right']:
            ax.spines[axis].set_linewidth(0.5)
        ax.tick_params(width=0.5, which='both')
        ax.tick_params(length=2.5, which='major')
        ax.tick_params(length=1.5, which='minor')

    # Error in energy pane, axs[0].
    for plot_idx, data in enumerate(list_of_energy_data):
        merged = merge_onto_ed_grid(data, exact_diagonalisation_energy_as_df)
        axs[0].scatter(
            merged['h'],
            merged['ed'].abs() - merged['algorithm'].abs(),
            color= colours[plot_idx],
            s = marker_size,
            linewidth = marker_thickness,
            marker = markers[plot_idx],
            label = labels[plot_idx]
        )
    axs[0].set_ylabel(r'|$E_{ED}$| - |$E$|')
    axs[0].set_yscale('log')
    axs[0].set_xlim(-0.05+h_cutoff,3.05)
    axs[0].set_ylim(energy_ylim_lower, energy_ylim_upper)
    axs[0].yaxis.set_major_locator(ticker.LogLocator(base=10, numticks=n_yticks))

    # Error in magic pane, axs[1].
    for plot_idx, data in enumerate(list_of_magic_data):
        merged = merge_onto_ed_grid(data, exact_diagonalisation_magic_as_df)
        axs[1].scatter(
            merged['h'],
            (merged['ed'] - merged['algorithm']).abs(),
            color = colours[plot_idx],
            s = marker_size,
            linewidth = marker_thickness,
            marker = markers[plot_idx],
        )
    axs[1].set_ylabel(r'|$M_{ED}$ - $M$|')
    axs[1].set_yscale('log')
    axs[1].set_xlim(-0.05+h_cutoff,3.05)
    axs[1].set_ylim(magic_ylim_lower,magic_ylim_upper)
    axs[1].yaxis.set_major_locator(ticker.LogLocator(base=10, numticks=n_yticks))

    # Final bit of formatting.
    for ax in axs:
        ax.set_xticks([0+h_cutoff,1,2,3], labels = [f'{h_cutoff}', '1', '2', '3'])
        if turn_off_minor_ticks:
            ax.yaxis.set_minor_locator(ticker.NullLocator())
    fig.legend(frameon=False, bbox_to_anchor=(legend_x, legend_y, legend_width, legend_height), loc='center', ncol=4)
    axs[1].set_xlabel('$h$')
    plt.subplots_adjust(hspace=0.4)
    if save_data:
        plt.savefig(f'{filename_stem}.png')
    plt.show()
    return


def create_statistical_error_plot(
    list_of_data_to_load: List[str],
    labels: List[str],
    colours: "list of[strings or matplotlib colours]",
    markers: List[str],
    marker_size = 20,
    marker_thickness = 0.8,
    filename_stem ='./figures/test',
    legend_x = 0,
    legend_y = 0.89,
    legend_width = 1,
    legend_height = 0.1,
    energy_ylim_lower = 1e-5,
    energy_ylim_upper = 1e-2,
    magic_ylim_lower = 1e-5,
    magic_ylim_upper = 1e-1,
    n_yticks = 4,
    turn_off_minor_ticks = True,
    save_data=False,
):
    """
    Plot statistical errors against h in two side-by-side log-scale panes.

    The left pane is energy (delta E) and the right pane is magic (delta M). Each
    pane plots column 0 (h) against column 2 of every loaded table. Only the left
    pane's series carry labels, so each file appears once in the figure legend.
    If ``save_data`` is true, the figure is written to ``f'{filename_stem}.png'``
    before being shown.
    """
    energy_tables, magic_tables = get_energy_and_magic_from_list_of_filenames(list_of_data_to_load)

    fig, axs = plt.subplots(ncols=2, nrows=1, dpi=300)
    fig.set_size_inches(3.40457, 1.8)
    mpl.rc('font', size=8)
    fig.subplots_adjust(bottom=0.2, top=0.88, left=0.16, right=0.98, wspace=0.35)

    # Thin spines and ticks plus a faint grid on both panes.
    for ax in axs:
        ax.grid(True, which='both', linewidth=0.1)
        for side in ('top', 'bottom', 'left', 'right'):
            ax.spines[side].set_linewidth(0.5)
        ax.tick_params(width=0.5, which='both')
        ax.tick_params(length=2.5, which='major')
        ax.tick_params(length=1.5, which='minor')

    # (tables, y-label, y-limits, contributes legend entries?)
    pane_specs = (
        (energy_tables, r'$\delta E$', (energy_ylim_lower, energy_ylim_upper), True),
        (magic_tables, r'$\delta M$', (magic_ylim_lower, magic_ylim_upper), False),
    )
    for ax, (tables, y_label, (y_low, y_high), in_legend) in zip(axs, pane_specs):
        for series_idx, table in enumerate(tables):
            columns = table.keys()
            legend_kwargs = {'label': labels[series_idx]} if in_legend else {}
            ax.scatter(
                table[columns[0]],
                table[columns[2]],
                color=colours[series_idx],
                s=marker_size,
                linewidth=marker_thickness,
                marker=markers[series_idx],
                **legend_kwargs,
            )
        ax.set_ylabel(y_label)
        ax.set_yscale('log')
        ax.set_xlim(-0.05 + h_cutoff, 3.05)
        ax.set_ylim(y_low, y_high)
        ax.yaxis.set_major_locator(ticker.LogLocator(base=10, numticks=n_yticks))

    for ax in axs:
        if turn_off_minor_ticks:
            ax.yaxis.set_minor_locator(ticker.NullLocator())
        ax.set_xticks([0 + h_cutoff, 1, 2, 3], labels=[f'{h_cutoff}', '1', '2', '3'])
    fig.legend(frameon=False, bbox_to_anchor=(legend_x, legend_y, legend_width, legend_height), loc='center', ncol=4)
    for ax in axs:
        ax.set_xlabel('$h$')
    plt.subplots_adjust(hspace=0.3, wspace=0.5)
    if save_data:
        plt.savefig(f'{filename_stem}.png')
    plt.show()


def create_statistical_error_plot_8_and_12(
    list_of_data_to_load_for_top_pane: List[str],
    list_of_data_to_load_for_bottom_pane: List[str],
    labels: List[str],
    colours: "list of[strings or matplotlib colours]",
    markers: List[str],
    marker_size = 20,
    marker_thickness = 0.8,
    filename_stem ='./figures/test',
    legend_x = 0,
    legend_y = 0.89,
    legend_width = 1,
    legend_height = 0.1,
    energy_ylim_lower = 1e-5,
    energy_ylim_upper = 1e-2,
    magic_ylim_lower = 1e-5,
    magic_ylim_upper = 1e-1,
    n_yticks = 4,
    turn_off_minor_ticks = True,
    save_data=False,
):
    """
    Plot statistical errors against h on a 2x2 grid of log-scale panes.

    Columns: energy (delta E) on the left and magic (delta M) on the right.
    Rows: the top-pane files (8 qubits) above the bottom-pane files (12 qubits).
    Flattened pane order is axs[0]=E_8, axs[1]=M_8, axs[2]=E_12, axs[3]=M_12.
    Each pane plots column 0 (h) against column 2 of every loaded table. Only the
    top-left pane is labelled, so the same ``labels``/``colours``/``markers`` are
    reused for both rows. If ``save_data`` is true, the figure is written to
    ``f'{filename_stem}.png'``.
    """
    fig, axs = plt.subplots(ncols=2, nrows=2, dpi=300)
    axs = axs.flatten()
    fig.set_size_inches(3.40457, 1.8)
    mpl.rc('font', size=8)
    fig.subplots_adjust(bottom=0.2, top=0.90, left=0.16, right=0.98, wspace=0.35)

    # Thin spines and ticks plus a faint grid on every pane.
    for ax in axs:
        ax.grid(True, which='both', linewidth=0.1)
        for side in ('top', 'bottom', 'left', 'right'):
            ax.spines[side].set_linewidth(0.5)
        ax.tick_params(width=0.5, which='both')
        ax.tick_params(length=2.5, which='major')
        ax.tick_params(length=1.5, which='minor')

    # (files, energy y-label, magic y-label, energy series feed the legend?)
    row_specs = (
        (list_of_data_to_load_for_top_pane, r'$\delta E_{N=8}$', r'$\delta M_{N=8}$', True),
        (list_of_data_to_load_for_bottom_pane, r'$\delta E_{N=12}$', r'$\delta M_{N=12}$', False),
    )
    for row_idx, (files, energy_label, magic_label, labelled) in enumerate(row_specs):
        energy_tables, magic_tables = get_energy_and_magic_from_list_of_filenames(files)
        pane_specs = (
            (axs[2 * row_idx], energy_tables, energy_label, energy_ylim_lower, energy_ylim_upper, labelled),
            (axs[2 * row_idx + 1], magic_tables, magic_label, magic_ylim_lower, magic_ylim_upper, False),
        )
        for ax, tables, y_label, y_low, y_high, in_legend in pane_specs:
            for series_idx, table in enumerate(tables):
                columns = table.keys()
                legend_kwargs = {'label': labels[series_idx]} if in_legend else {}
                ax.scatter(
                    table[columns[0]],
                    table[columns[2]],
                    color=colours[series_idx],
                    s=marker_size,
                    linewidth=marker_thickness,
                    marker=markers[series_idx],
                    **legend_kwargs,
                )
            ax.set_ylabel(y_label)
            ax.set_yscale('log')
            ax.set_xlim(-0.05 + h_cutoff, 3.05)
            ax.set_ylim(y_low, y_high)
            ax.yaxis.set_major_locator(ticker.LogLocator(base=10, numticks=n_yticks))

    for ax in axs:
        if turn_off_minor_ticks:
            ax.yaxis.set_minor_locator(ticker.NullLocator())
        ax.set_xticks([0 + h_cutoff, 1, 2, 3], labels=[f'{h_cutoff}', '1', '2', '3'])
    fig.legend(frameon=False, bbox_to_anchor=(legend_x, legend_y, legend_width, legend_height), loc='center', ncol=4)
    # Only the bottom row gets an h axis label.
    for ax in axs[2:]:
        ax.set_xlabel('$h$')
    plt.subplots_adjust(wspace=0.5, hspace=0.5)
    if save_data:
        plt.savefig(f'{filename_stem}.png')
    plt.show()


def create_three_pane_plot(
    list_of_data_to_load: List[str],
    labels: List[str],
    colours: "list of[strings or matplotlib colours]",
    markers: List[str],
    marker_size = 20,
    marker_thickness = 0.8,
    filename_stem ='./figures/test',
    legend_x = 0,
    legend_y = 0.89,
    legend_width = 1,
    legend_height = 0.1,
    energy_ylim_lower = 1e-5,
    energy_ylim_upper = 1e-2,
    magic_ylim_lower = 1e-5,
    magic_ylim_upper = 1e-1,
    infidelity_ylim_lower = 1e-8,
    infidelity_ylim_upper = 1e-4,
    n_yticks = 4,
    turn_off_minor_ticks = True,
    save_data=False,
):
    """
    Stack three log-scale panes against h: energy accuracy, magic accuracy and infidelity.

    Every pane plots column 0 (h) against column 2 of each loaded table. Only the
    top (energy) pane is labelled, so each file appears once in the legend, and
    only the bottom pane shows h tick labels. If ``save_data`` is true, the figure
    is written to ``f'{filename_stem}.png'``.

    NOTE(review): the y-labels describe differences from ED (|E_ED|-|E|, |M_ED-M|),
    while ``create_statistical_error_plot`` labels the same column 2 as a
    statistical error (delta E, delta M). Confirm which quantity the loaded files
    store in that column.
    """
    energy_tables, magic_tables, infidelity_tables = \
        get_energy_magic_and_infidelity_from_list_of_filenames(list_of_data_to_load)

    fig, axs = plt.subplots(ncols=1, nrows=3, dpi=300)
    fig.set_size_inches(3.40457, 1.8)
    mpl.rc('font', size=8)
    fig.subplots_adjust(bottom=0.2, top=0.88, left=0.16, right=0.98, wspace=0.35)

    # Thin spines and ticks plus a faint grid on every pane.
    for ax in axs:
        ax.grid(True, which='both', linewidth=0.1)
        for side in ('top', 'bottom', 'left', 'right'):
            ax.spines[side].set_linewidth(0.5)
        ax.tick_params(width=0.5, which='both')
        ax.tick_params(length=2.5, which='major')
        ax.tick_params(length=1.5, which='minor')

    # (tables, y-label, lower y-limit, upper y-limit), top to bottom.
    pane_specs = (
        (energy_tables, r'|$E_{ED}$|-|$E$|', energy_ylim_lower, energy_ylim_upper),
        (magic_tables, r'|$M_{ED}$-$M$|', magic_ylim_lower, magic_ylim_upper),
        (infidelity_tables, r'$I$', infidelity_ylim_lower, infidelity_ylim_upper),
    )
    for pane_idx, (tables, y_label, y_low, y_high) in enumerate(pane_specs):
        ax = axs[pane_idx]
        for series_idx, table in enumerate(tables):
            columns = table.keys()
            # Only the first pane contributes legend entries.
            legend_kwargs = {'label': labels[series_idx]} if pane_idx == 0 else {}
            ax.scatter(
                table[columns[0]],
                table[columns[2]],
                color=colours[series_idx],
                s=marker_size,
                linewidth=marker_thickness,
                marker=markers[series_idx],
                **legend_kwargs,
            )
        ax.set_ylabel(y_label)
        ax.set_yscale('log')
        ax.set_xlim(-0.05 + h_cutoff, 3.05)
        ax.set_ylim(y_low, y_high)
        ax.yaxis.set_major_locator(ticker.LogLocator(base=10, numticks=n_yticks))

    # Shared h ticks; blank labels everywhere except the bottom pane.
    for ax in axs:
        if turn_off_minor_ticks:
            ax.yaxis.set_minor_locator(ticker.NullLocator())
        ax.set_xticks([0 + h_cutoff, 1, 2, 3], labels=[' ', ' ', ' ', ' '])
    axs[2].set_xticks([0 + h_cutoff, 1, 2, 3], labels=[f'{h_cutoff}', '1', '2', '3'])
    fig.legend(frameon=False, bbox_to_anchor=(legend_x, legend_y, legend_width, legend_height), loc='center', ncol=5)
    axs[2].set_xlabel('$h$')
    plt.subplots_adjust(hspace=0.3, wspace=0.5)
    if save_data:
        plt.savefig(f'{filename_stem}.png')
    plt.show()


def create_six_panel_full_width_plot(
    list_of_data_to_load: List[str],
    labels: List[str],
    colours: "list of[strings or matplotlib colours]",
    markers: List[str],
    marker_size = 20,
    marker_thickness = 0.8,
    filename_stem ='./figures/test',
    legend_x = 0,
    legend_y = 0.89,
    legend_width = 1,
    legend_height = 0.1,
    energy_left_ylim_lower = 1e-5,
    energy_left_ylim_upper = 1e-2,
    magic_left_ylim_lower = 1e-5,
    magic_left_ylim_upper = 1e-1,
    infidelity_left_ylim_lower = 1e-8,
    infidelity_left_ylim_upper = 1e-4,
    energy_right_ylim_lower = 1e-5,
    energy_right_ylim_upper = 1e-2,
    magic_right_ylim_lower = 1e-5,
    magic_right_ylim_upper = 1e-1,
    infidelity_right_ylim_lower = 1e-8,
    infidelity_right_ylim_upper = 1e-4,
    n_yticks = 4,
    turn_off_minor_ticks = True,
    save_data=False,
):
    """
    Full-width 3x2 figure of energy accuracy, magic accuracy and infidelity against h.

    The left column (titled 'RBM') plots the first three loaded files
    (``list_of_data_to_load[:3]``). The right column (titled 'symmetric RBM')
    plots the rest (``[3:]``). Every pane plots column 0 (h) against column 2 of
    each table on a log scale, with separate y-limits per pane.

    Only the left-hand data are put in the legend, and the right-hand series reuse
    ``colours``/``markers`` starting from index 0. So MAKE SURE the markers and
    labels are correctly duplicated between the two halves.

    NOTE(review): column 2 is labelled here as a difference from ED, while
    ``create_statistical_error_plot`` labels the same column as a statistical
    error. Confirm which quantity the loaded files store there.

    If ``save_data`` is true, the figure is written to ``f'{filename_stem}.png'``.
    """
    list_of_energy_data, list_of_magic_data, list_of_infidelity_data = \
        get_energy_magic_and_infidelity_from_list_of_filenames(
            list_of_data_to_load
        )

    fig, axs = plt.subplots(ncols=2, nrows=3, dpi=300)
    # Column-major flattening: axs[0:3] is the left column and axs[3:6] the
    # right column, each ordered top to bottom.
    axs = axs.flatten(order='F')

    fig.set_size_inches(3.40457 * 2, 3)
    font = {'size' : 8}
    mpl.rc('font', **font)
    fig.subplots_adjust(bottom = 0.2, top=0.88, left=0.08, right=0.98, wspace=0.35)

    for ax in axs:
        ax.grid(True, which='both', linewidth=0.1)
        for axis in ['top', 'bottom', 'left', 'right']:
            ax.spines[axis].set_linewidth(0.5)
        ax.tick_params(width=0.5, which='both')
        ax.tick_params(length=2.5, which='major')
        ax.tick_params(length=1.5, which='minor')

    # One entry per row (top to bottom): the data lists and the y-label.
    rows = (
        (list_of_energy_data, r'|$E_{ED}$| - |$E$|'),
        (list_of_magic_data, r'|$M_{ED}$ - $M$|'),
        (list_of_infidelity_data, r'$I$'),
    )
    # One entry per column: index of its top axis, which files it shows, the
    # per-row y-limits, and whether its top pane feeds the legend.
    columns = (
        (0, slice(None, 3), (
            (energy_left_ylim_lower, energy_left_ylim_upper),
            (magic_left_ylim_lower, magic_left_ylim_upper),
            (infidelity_left_ylim_lower, infidelity_left_ylim_upper),
        ), True),
        (3, slice(3, None), (
            (energy_right_ylim_lower, energy_right_ylim_upper),
            (magic_right_ylim_lower, magic_right_ylim_upper),
            (infidelity_right_ylim_lower, infidelity_right_ylim_upper),
        ), False),
    )
    for first_ax_idx, file_slice, ylims, labelled in columns:
        for row_idx, ((data_lists, ylabel), (ylim_lower, ylim_upper)) in enumerate(zip(rows, ylims)):
            ax = axs[first_ax_idx + row_idx]
            for plot_idx, data in enumerate(data_lists[file_slice]):
                keys = data.keys()
                # Only the top-left pane contributes legend entries.
                label_kwargs = {'label': labels[plot_idx]} if labelled and row_idx == 0 else {}
                ax.scatter(
                    data[keys[0]],
                    data[keys[2]],
                    color = colours[plot_idx],
                    s = marker_size,
                    linewidth = marker_thickness,
                    marker = markers[plot_idx],
                    **label_kwargs,
                )
            ax.set_ylabel(ylabel)
            ax.set_yscale('log')
            ax.set_xlim(-0.05+h_cutoff,3.05)
            ax.set_ylim(ylim_lower, ylim_upper)
            ax.yaxis.set_major_locator(ticker.LogLocator(base=10, numticks=n_yticks))

    # Final bit of formatting: h tick labels only on the bottom pane of each column.
    for ax in axs:
        if turn_off_minor_ticks:
            ax.yaxis.set_minor_locator(ticker.NullLocator())
        ax.set_xticks([0+h_cutoff,1,2,3], labels = [' ',' ',' ',' '])
    axs[2].set_xticks([0+h_cutoff,1,2,3], labels = [f'{h_cutoff}', '1', '2', '3'])
    axs[5].set_xticks([0+h_cutoff,1,2,3], labels = [f'{h_cutoff}', '1', '2', '3'])
    fig.legend(frameon=False, bbox_to_anchor=(legend_x, legend_y, legend_width, legend_height), loc='center', ncol=3)
    axs[2].set_xlabel('$h$')
    axs[5].set_xlabel('$h$')
    axs[0].set_title('RBM')
    axs[3].set_title('symmetric RBM')
    plt.subplots_adjust(hspace=0.3, wspace=0.3)
    if save_data:
        plt.savefig(f'{filename_stem}.png')
    plt.show()
    return
