import pandas as pd
import glob
import os
import ast
import re
import numpy as np
from scipy.stats import jarque_bera, kurtosis, skew, median_abs_deviation
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.cm import get_cmap
import ptitprince as pt
from functools import reduce
import operator
import os
# Make this script's own directory the working directory, so that relative paths
# used in this module (e.g. the file names passed to savefig) resolve next to it.
# NOTE(review): this is a side effect that runs on import, and __file__ is undefined
# in some interactive contexts — consider moving it under an `if __name__ == "__main__":` guard.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

def select_quantile(my_df, mode = "macro", quantile_value = 0.0185, op="lt"):
    '''
    Filters the DataFrame `my_df` by comparing its "Error Value" column against a
    quantile threshold. Operates in three modes: "macro", "micro" and "individual".

    Parameters:
    - my_df (pd.DataFrame): 
        The input DataFrame, which must contain at least the following columns:
        - "Error Type": Categorizes different error types.
        - "Error Value": Numeric values on which quantile thresholds are applied.

    - mode (str, default="macro"):
        Determines how the threshold is computed. Options:
        - "macro": One quantile of the whole "Error Value" column is used for every row.
        - "micro": The quantile is computed separately within each "Error Type" group.
        - "individual": Each "Error Type" group uses its own quantile level, read from
          the `quantile_value` DataFrame.

    - quantile_value (real number or pd.DataFrame, default=0.0185):
        - If mode is "macro" or "micro": a real number in [0, 1], the quantile level
          (e.g. 0.0185). Integers (0 or 1) and numpy scalars are accepted.
        - If mode is "individual": a DataFrame with exactly two columns, "Error Type" and
          one column holding the quantile level for that type. Each "Error Type" must be
          listed once, and the set of types must equal the set of types in `my_df`.

    - op (str, default="lt"):
        Comparison applied as `op(row value, threshold)`; rows where it holds are kept:
        - "lt": Less than.
        - "le": Less than or equal to.
        - "gt": Greater than.
        - "ge": Greater than or equal to.

    Returns:
    - pd.DataFrame: 
        The rows of `my_df` passing the comparison, with `my_df`'s original columns and
        index (in every mode).

    Raises:
    - ValueError: for an unknown `op`, an unknown mode / quantile_value combination, or a
      malformed threshold table in "individual" mode.
    '''
    import numbers  # local: only needed for the scalar-level check below

    operators = {
        "lt": operator.lt,  # Less than
        "le": operator.le,  # Less than or equal to
        "gt": operator.gt,  # Greater than
        "ge": operator.ge   # Greater than or equal to
    }
    if op not in operators:
        raise ValueError(f"Unknown operator '{op}'. Valid options are: {list(operators.keys())}")
    compare = operators[op]
    values = my_df["Error Value"]

    # numbers.Real covers int, float and numpy scalars; bool is an int subclass, so exclude it.
    is_scalar_level = isinstance(quantile_value, numbers.Real) and not isinstance(quantile_value, bool)

    if mode == "macro" and is_scalar_level:
        macro_threshold = values.quantile(quantile_value)
        return my_df[compare(values, macro_threshold)]

    if mode == "micro" and is_scalar_level:
        micro_threshold = my_df.groupby("Error Type")["Error Value"].transform(lambda x: x.quantile(quantile_value))
        return my_df[compare(values, micro_threshold)]

    if mode == "individual" and isinstance(quantile_value, pd.DataFrame):
        if "Error Type" not in quantile_value.columns:
            raise ValueError("Threshold data frame has no 'Error Type' column")
        level_columns = quantile_value.columns.drop("Error Type")
        if len(level_columns) != 1:
            raise ValueError("Threshold data frame must have exactly one column besides 'Error Type'")
        level_column = level_columns[0]

        threshold_types = quantile_value["Error Type"]
        if threshold_types.duplicated().any():
            # A repeated type would otherwise be ambiguous (the old merge duplicated rows).
            raise ValueError("Threshold data frame lists an 'Error Type' more than once")
        if set(my_df["Error Type"].unique()) != set(threshold_types):
            raise ValueError("Could not assert no mismatches")

        level_by_type = dict(zip(threshold_types, quantile_value[level_column]))
        # Quantile of each group at that group's own level, then broadcast back to the rows.
        threshold_by_type = {
            error_type: group.quantile(level_by_type[error_type])
            for error_type, group in my_df.groupby("Error Type")["Error Value"]
        }
        thresholds = my_df["Error Type"].map(threshold_by_type)
        return my_df[compare(values, thresholds)]

    raise ValueError(f"Unknown mode-quantile combination {mode} and {quantile_value}")

def make_half_violinplot(my_df, scatter = True, boxes = True, my_size=(10, 6), my_rotation=60):
    '''
    Generates a half violin plot for visualising the distribution of "Error Value" 
    grouped by "Error Type" in the input DataFrame. Optionally includes scatter points 
    and boxplots for added context.

    Parameters:
    - my_df (pd.DataFrame): 
        Input DataFrame that must contain at least two columns:
        - "Error Type": Categorical variable used to group the data.
        - "Error Value": Numeric values representing the data to be visualized.
        Rows with a missing "Error Type" are not plotted.

    - scatter (bool, default=True): 
        Whether to overlay the individual data points, jittered horizontally (using
        numpy's global random state) within each group's slot.

    - boxes (bool, default=True): 
        Whether to overlay boxplots (median, quartiles, capped whiskers, no fliers)
        plus a black dot at the mean for each group ("Error Type").

    - my_size (tuple, default=(10, 6)): 
        Tuple specifying the size of the figure (width, height) in inches.

    - my_rotation (int, default=60): 
        Degree of rotation for the x-axis tick labels (the "Error Type" labels). 
        Helps improve readability if the labels are long or overlapping.

    Returns:
    - None: Call plt.show() after running this function to display the plot.
    '''
    X_SHIFT = 0.5  # width (in category units) of each group's scatter band and box

    fig, ax = plt.subplots(figsize=my_size)

    # Missing labels are dropped: no violin is drawn for them, so keeping them here
    # would shift every later scatter band / box one slot away from its violin.
    provinces = list(my_df["Error Type"].dropna().unique())
    colours = sns.color_palette("colorblind", len(provinces))
    colour_mapping = dict(zip(provinces, colours))

    # Outline the violins only when nothing else is drawn over them.
    my_linewidth = 0.0 if scatter or boxes else 0.5

    # `order` pins violin i to provinces[i], i.e. the same slot the scatter and boxes use
    # below, even if "Error Type" is categorical with a different category order.
    pt.half_violinplot(x='Error Type', y='Error Value', data=my_df, order=provinces,
                       palette=colour_mapping, bw=0.075, linewidth=my_linewidth, alpha=1, ax=ax)

    group_values = [
        my_df.loc[my_df["Error Type"] == province, "Error Value"].to_numpy()
        for province in provinces
    ]

    if scatter:  # jittered raw data to the right of each slot's origin
        for i, (province, values) in enumerate(zip(provinces, group_values)):
            x = i + np.random.uniform(high=X_SHIFT, size=len(values))
            ax.scatter(x, values, alpha=0.2, color=colour_mapping[province], s=5)

    if boxes:
        # Centre each box in the middle of its scatter band.
        positions = [i + X_SHIFT / 2 for i in range(len(provinces))]

        means = [np.mean(values) for values in group_values]
        ax.scatter(positions, means, color='black', zorder=3, label='Mean')

        # The style of the line that represents the median.
        medianprops = {"linewidth": 1, "color": "black", "solid_capstyle": "butt"}
        # The style of the box; also used for the whiskers.
        boxprops = {"linewidth": 1, "color": "black"}

        ax.boxplot(
            group_values,
            vert=True,
            positions=positions,
            manage_ticks=False,
            showfliers=False,  # Do not show the outliers beyond the caps.
            showcaps=True,     # Draw the caps at the whisker ends.
            medianprops=medianprops,
            whiskerprops=boxprops,
            boxprops=boxprops,
            widths=X_SHIFT
        )

    plt.title("Violin Plot of Error Distributions")
    plt.xlabel("Province")
    plt.ylabel("Error Value Distribution")
    plt.xticks(rotation=my_rotation)
    

def make_violinplot(my_df):
    '''
    Draws one full (two-sided) violin per "Error Type", showing how "Error Value"
    is distributed within each group.

    Parameters:
    - my_df (pd.DataFrame):
        Needs an "Error Type" column (the grouping label) and an "Error Value"
        column (the numeric values to plot).

    Returns:
    - None: the plot is drawn on a new figure; call plt.show() to display it.
    '''
    plt.figure(figsize=(10, 6))

    # One colour-blind-safe colour per group, keyed by the group label.
    groups = my_df["Error Type"].unique()
    palette = {
        label: colour
        for label, colour in zip(groups, sns.color_palette("colorblind", len(groups)))
    }

    sns.violinplot(
        data=my_df,
        x='Error Type',
        y='Error Value',
        palette=palette,
        bw=0.075,
        linewidth=0.1,
        alpha=0.75,
    )

    plt.title("Violin Plot of Error Distributions")
    plt.xlabel("Province")
    plt.ylabel("Error Value Distribution")
    plt.xticks(rotation=60)
    

def make_stripplot(my_df, my_jitter = 0.4):
    '''
    Generates a strip plot for visualising the distribution of "Error Value" 
    grouped by "Error Type" in the input DataFrame. Adds jitter to the points 
    to avoid overlap, showing the distribution of individual data points.

    Parameters:
    - my_df (pd.DataFrame): 
        The input DataFrame, which must contain at least two columns:
        - "Error Type": Categorical variable used to group the data.
        - "Error Value": Numeric values representing the data to be visualized.
    
    - my_jitter (float, default=0.4): 
        Controls the amount of horizontal jitter applied to the points to reduce overlap. 
        A higher value increases the spread of points.

    Returns: 
    - None: Call plt.show() after running this function to display the plot.
    '''
    plt.figure(figsize=(10, 6))

    provinces = my_df["Error Type"].unique()
    colours = sns.color_palette("colorblind", len(provinces))
    colour_mapping = dict(zip(provinces, colours))

    sns.stripplot(x='Error Type', y='Error Value', data=my_df, palette=colour_mapping, jitter=my_jitter, alpha=0.2)
    # This is a strip plot (one jittered dot per value); the title previously said "Box Plot".
    plt.title("Strip Plot of Error Distributions")
    plt.xlabel("Province")
    plt.ylabel("Error Value Distribution")
    plt.xticks(rotation=60)

def _nan_contribution_table(my_df, my_provinces):
    '''
    Builds the (NaN count x province) probability table drawn by plot_nan_contribution.

    Row k holds P(a row has exactly k NaNs among `my_provinces`), split across the
    provinces in proportion to how many of those NaNs each province contributes.
    Row 0 splits P(no NaN) evenly across all provinces. Rows are sorted by k.

    Raises:
    - ValueError: if `my_df` has no rows or `my_provinces` is empty.
    '''
    if len(my_df) == 0:
        raise ValueError("my_df is empty; cannot compute NaN contributions")
    if len(my_provinces) == 0:
        raise ValueError("my_provinces is empty; nothing to analyse")

    nan_mask = my_df[my_provinces].isna()
    nan_counts_per_row = nan_mask.sum(axis=1)

    # P(row has k NaNs) for every k that occurs.
    nan_count_distribution = nan_counts_per_row.value_counts(normalize=True).sort_index()

    # Per k: how many NaNs each province contributes, normalised to shares summing to 1.
    # The k=0 group has a zero row sum (0/0 -> NaN); it is overwritten below.
    nan_contribution_per_row = nan_mask.groupby(nan_counts_per_row).sum()
    nan_contribution_normalized = nan_contribution_per_row.div(nan_contribution_per_row.sum(axis=1), axis=0)
    nan_contribution_scaled = nan_contribution_normalized.multiply(nan_count_distribution, axis=0)

    # k = 0: no province is responsible, so spread P(no NaN) evenly.
    zero_nan_prob = int((nan_counts_per_row == 0).sum()) / len(my_df)
    zero_nan_contribution = pd.Series([1/len(my_provinces)] * len(my_provinces), index=my_provinces) * zero_nan_prob
    nan_contribution_scaled.loc[0] = zero_nan_contribution

    return nan_contribution_scaled.sort_index()


def _stack_nan_rows(table, stack):
    '''
    Merges the rows of `table` named by each group in `stack` into a single summed row.

    A flat list of integers is treated as one group. Non-list entries of a list-of-lists
    are ignored, as are group members missing from `table`. The result has string labels
    ("[1, 2]" for a group, "3" for an untouched row), ordered by the smallest NaN count
    each label covers, so e.g. "10" sorts after "2" and a group sits where it belongs.
    '''
    if all(isinstance(x, (int, np.integer)) for x in stack):
        stack = [stack]

    entries = []  # (sort key, label, row values aligned with table.columns)
    grouped_indices = set()
    for group in stack:
        if not isinstance(group, list):
            continue
        valid_indices = [idx for idx in group if idx in table.index]
        if valid_indices:
            entries.append((min(valid_indices), str(group), table.loc[valid_indices].sum().to_numpy()))
            grouped_indices.update(valid_indices)

    for idx, row in table.drop(index=list(grouped_indices), errors='ignore').iterrows():
        entries.append((idx, str(idx), row.to_numpy()))

    entries.sort(key=lambda entry: entry[0])  # stable: overlapping groups keep stack order
    return pd.DataFrame(
        [values for _, _, values in entries],
        index=[label for _, label, _ in entries],
        columns=table.columns,
    )


def plot_nan_contribution(my_df, my_provinces, stack=None, my_rotation=0):
    '''
    Plots the contribution of missing values (NaNs) for a set of provinces, showing 
    how often different numbers of NaNs occur in rows and which provinces contribute 
    to those NaNs. Allows optional grouping of x-values for stacked plotting and 
    label rotation.

    Parameters:
    - my_df (pd.DataFrame): 
        Input DataFrame containing data for the provinces to be analysed.
    - my_provinces (list of str): 
        List of column names corresponding to the provinces whose NaN contributions 
        will be analysed and visualized.
    - stack (list or list of lists, optional):
        Optional parameter to specify grouping of x-values. When a list of integers 
        is provided, all rows with those counts of NaNs will be summed together.
        If a list of lists is provided, each sublist indicates x-values to be grouped 
        into a single stacked bar. Bars are ordered by the smallest NaN count they cover.
    - my_rotation (int, optional):
        The rotation angle for the x-axis labels. Defaults to 0.

    Returns:
    - None: the figure is laid out and displayed via plt.show() inside this function.

    Raises:
    - ValueError: if `my_df` has no rows or `my_provinces` is empty.
    '''
    nan_contribution_scaled = _nan_contribution_table(my_df, my_provinces)
    if stack:
        nan_contribution_scaled = _stack_nan_rows(nan_contribution_scaled, stack)

    # Colourblind-safe colour per province, kept stable across bars.
    colours = sns.color_palette("colorblind", len(my_provinces))
    colour_mapping = dict(zip(my_provinces, colours))

    nan_contribution_scaled.plot(
        kind='bar',
        stacked=True,
        figsize=(10, 7),
        color=[colour_mapping[col] for col in nan_contribution_scaled.columns]
    )

    plt.xlabel('Number of misfits (stacked)' if stack else 'Number of misfits')
    plt.ylabel('Probability (Proportion)')
    plt.xticks(rotation=my_rotation)
    plt.legend(title='Province')

    plt.tight_layout()
    plt.show()

# def plot_nan_contribution(my_df, my_provinces):
#     '''
#     Plots the contribution of missing values (NaNs) for a set of provinces, showing 
#     how often different numbers of NaNs occur in rows and which provinces contribute 
#     to those NaNs. The result is a stacked bar chart, where each bar represents a 
#     specific number of NaNs in a row, and the stacked sections represent the 
#     contribution of each province.

#     Parameters:
#     - my_df (pd.DataFrame): 
#         Input DataFrame containing data for the provinces to be analysed.
    
#     - my_provinces (list of str): 
#         List of column names corresponding to the provinces whose NaN contributions 
#         will be analysed and visualized.

#     Returns:
#     - None: Call plt.show() after running this function to display the plot.
#     '''
#     # Step 1: Create a mask to check where NaNs are present
#     nan_mask = my_df[my_provinces].isna()

#     # Step 2: Calculate the number of NaNs in each row
#     nan_counts_per_row = nan_mask.sum(axis=1)

#     # Step 3: Compute the probability distribution of NaN counts
#     nan_count_distribution = nan_counts_per_row.value_counts(normalize=True).sort_index()

#     # Step 4: Calculate the contribution of each province to the NaN count in each row
#     nan_contribution_per_row = nan_mask.groupby(nan_counts_per_row).sum()

#     # Step 5: Normalize the contributions
#     nan_contribution_normalized = nan_contribution_per_row.div(nan_contribution_per_row.sum(axis=1), axis=0)

#     # Step 6: Scale the normalized contributions by the overall NaN count distribution
#     nan_contribution_scaled = nan_contribution_normalized.multiply(nan_count_distribution, axis=0)

#     # Step 7: Add the case where there are zero NaNs
#     rows_with_no_nan = my_df[nan_mask.sum(axis=1) == 0]
#     zero_nan_prob = len(rows_with_no_nan) / len(my_df)
#     zero_nan_contribution = pd.Series([1/len(my_provinces)] * len(my_provinces), index=my_provinces) * zero_nan_prob
#     nan_contribution_scaled.loc[0] = zero_nan_contribution

#     # Sort the DataFrame so that 0 NaNs are at the beginning
#     nan_contribution_scaled = nan_contribution_scaled.sort_index()

#     # Step 8: Define the custom colour mapping using the colourblind palette
#     num_provinces = len(my_provinces)
#     colours = sns.color_palette("colorblind", num_provinces)
#     colour_mapping = dict(zip(my_provinces, colours))

#     # Step 9: Plot the stacked bar chart with the custom colour mapping
#     ax = nan_contribution_scaled.plot(
#         kind='bar',
#         stacked=True,
#         figsize=(10, 7),
#         color=[colour_mapping[col] for col in nan_contribution_scaled.columns]
#     )

#     plt.title('NaN-count Distribution per province')
#     plt.xlabel('Number of NaNs in Row')
#     plt.ylabel('Probability (Proportion)')
#     plt.legend(title='Province')

#     plt.tight_layout()
    
def make_histogram(my_df, my_bins=100, separate_plots=False):
    '''
    Creates histograms (with KDE curves) to visualize the distribution of "Error Value"
    grouped by "Error Type" in the input DataFrame. Supports either overlaying all groups
    in a single plot or generating separate subplots for each group.

    Parameters:
    - my_df (pd.DataFrame): 
        Input DataFrame that must contain at least two columns:
        - "Error Type": Categorical variable used to group the data.
        - "Error Value": Numeric values representing the data to be visualized.

    - my_bins (int, default=100): 
        Number of bins to use in the histogram.

    - separate_plots (bool, default=False): 
        If True, generates separate subplots for each "Error Type" in a near-square grid. 
        If False, overlays all groups in a single plot.

    Returns: 
    - None: Call plt.show() after running this function to display the plot.
    '''
    provinces = my_df["Error Type"].unique()
    num_provinces = len(provinces)
    colours = sns.color_palette("colorblind", num_provinces)
    colour_mapping = dict(zip(provinces, colours))
    
    # Common plot options
    plot_opts = {
        "bins": my_bins,
        "kde": True,
        "kde_kws": {"bw_adjust": 0.5},
        "alpha": 0.25,
        "linewidth": 0.05
    }

    def plot_single_histogram(data, province=None, ax=None):
        '''Draw one province in its own colour, or (province=None) every province
        overlaid and coloured by "Error Type". Draws on `ax`, or the current axes.'''
        if province is not None:
            colour_opts = {"color": colour_mapping[province]}
        else:
            # Same palette as the custom legend below; without it seaborn uses its
            # default palette and the legend colours would not match the bars.
            colour_opts = {"hue": "Error Type", "palette": colour_mapping}
        sns.histplot(data=data, x='Error Value', ax=ax, **colour_opts, **plot_opts)
        target = ax if ax is not None else plt.gca()
        target.set_xlabel("Error Value")
        target.set_ylabel("Frequency")

    if separate_plots:
        # Near-square subplot grid
        ncols = int(np.ceil(np.sqrt(num_provinces)))
        nrows = int(np.ceil(num_provinces / ncols))
        
        # squeeze=False keeps `axes` 2-D even for a single province, so flatten() works.
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10), squeeze=False)
        axes = axes.flatten()
        
        for ax, province in zip(axes, provinces):
            plot_single_histogram(my_df[my_df["Error Type"] == province], province, ax=ax)
        
        # Hide any extra axes
        for ax in axes[num_provinces:]:
            ax.remove()
        
        plt.tight_layout()

    else:
        plt.figure(figsize=(10, 6))
        plot_single_histogram(my_df)

        plt.title("Overlayed Distributions of Error Columns")
        
        # Custom legend with the same colours used for the overlay
        legend_handles = [Patch(color=colour_mapping[province], label=province) for province in provinces]
        plt.legend(handles=legend_handles, title='Error Type')

def calculate_plateau_points(error_values, quantile_value=0.03, fixed_threshold=100, direction="left", tick_count=500):
    '''
    Computes quantiles of `error_values` on an evenly spaced grid of quantile levels,
    their numerical derivative, and the plateau points where that derivative exceeds
    `fixed_threshold`. A steep quantile derivative means few observations per unit of
    value, i.e. the error distribution is relatively flat there.

    Parameters:
    - error_values (pd.Series): 
        The numeric values representing the error data for which quantiles are computed.
    - quantile_value (float, default=0.03): 
        Boundary of the analysed range: [0, quantile_value] for "left",
        [quantile_value, 1] for "right".
    - fixed_threshold (float, default=100): 
        Derivative threshold; runs of ticks whose derivative exceeds it form plateaus.
    - direction (str, default="left"): 
        "left" analyses the lower tail (each plateau reported by its first tick);
        "right" analyses the upper tail (each plateau reported by its last tick).
    - tick_count (int, default=500): 
        Number of ticks (quantile steps) to use for calculating quantiles and derivatives.

    Returns:
    - plateau_points (list of float): 
        Quantile levels where plateaus are detected. The run touching the outer edge of
        the range (tick 0.0 for "left", tick 1.0 for "right") is excluded. If no plateau
        remains, this is `[np.nan]`.
    - quantile_ticks (np.ndarray): 
        The array of quantile ticks used for computing quantiles.
    - quantiles (pd.Series): 
        The quantile values corresponding to the ticks (indexed by tick).
    - derivatives (np.ndarray): 
        The derivatives (gradients) of the quantiles with respect to the ticks.

    Raises:
    - ValueError: if `direction` is neither "left" nor "right".
    '''
    if direction not in ("left", "right"):
        raise ValueError("Invalid direction. Choose either 'left' or 'right'.")

    if direction == "left":
        quantile_ticks = np.linspace(0.0, quantile_value, tick_count)
    else:
        quantile_ticks = np.linspace(quantile_value, 1.0, tick_count)
    quantiles = error_values.quantile(quantile_ticks)
    derivatives = np.gradient(quantiles, quantile_ticks)

    # Split the above-threshold tick indices into runs of consecutive indices.
    above = np.flatnonzero(derivatives > fixed_threshold)
    runs = [run for run in np.split(above, np.flatnonzero(np.diff(above) > 1) + 1) if len(run) > 0]

    # The extreme tail is normally steep, so the run touching the outer edge tick is an
    # edge artefact rather than a plateau. Only that run is dropped. Unconditionally
    # popping the first/last point would raise on an empty list and would discard a
    # genuine plateau when no run touches the edge.
    last_tick = len(quantile_ticks) - 1
    if direction == "left":
        plateau_points = [quantile_ticks[run[0]] for run in runs if run[0] != 0]
    else:
        plateau_points = [quantile_ticks[run[-1]] for run in runs if run[-1] != last_tick]

    return plateau_points if plateau_points else [np.nan], quantile_ticks, quantiles, derivatives

def sanitise_filename(title):
    '''
    Turns an arbitrary string (typically a plot title) into a file-name-safe stem.

    Every character that is not a word character, whitespace or a hyphen is dropped,
    surrounding whitespace is trimmed, and each remaining space becomes an underscore.

    Parameters:
    - title (str): 
        The text to sanitise.

    Returns:
    - str: 
        The sanitised file name stem (without extension).
    '''
    kept = re.sub(r'[^\w\s-]', '', title)
    return "_".join(kept.strip().split(" "))

def plot_error_distribution(quantile_ticks, quantiles, derivatives, plateau_points, province, colour_mapping, show=True, save_graphic=False, separate_plots=False, ax=None):
    '''
    Plots the error distribution of a province: the derivative of the quantiles on the
    left y-axis, the quantile values on a twin right y-axis, and a dashed vertical line
    at each plateau point. Optionally, the plot can be saved to a file or displayed.

    Parameters:
    - quantile_ticks (np.ndarray): 
        The array of quantile ticks used for plotting.
    - quantiles (np.ndarray or pd.Series): 
        The quantile values corresponding to the ticks.
    - derivatives (np.ndarray): 
        The array of derivatives (gradients) of the quantile values.
    - plateau_points (list of float): 
        Quantile values where plateaus were detected. NaN entries (the "no plateau"
        sentinel of calculate_plateau_points) are skipped.
    - province (str): 
        The name of the province (used in the plot title).
    - colour_mapping (dict): 
        Maps "gradient", "normal" and each non-NaN plateau point to a colour.
    - show (bool, default=True): 
        Whether to call plt.show() (only when this function created the figure).
    - save_graphic (str or bool, default=False): 
        If a format string (e.g. 'pdf', 'png'), the figure holding the plot is saved in it.
    - separate_plots (bool, default=False): 
        If True, the saved file is named after the sanitised plot title; otherwise
        "combined_error_distribution_plots.<format>".
    - ax (matplotlib.axes.Axes, optional): 
        Axis to plot on (if None, a new figure and axis are created).

    Returns:
    - ax or fig (matplotlib.axes.Axes or matplotlib.figure.Figure): 
        The passed axis, or the newly created figure if no axis was passed.
    '''
    if ax is None:
        fig, ax1 = plt.subplots(figsize=(10, 6))
    else:
        ax1 = ax
        fig = ax.figure  # save the figure that holds `ax`, not whichever figure is current

    # Left axis: derivative of the quantile function
    ax1.plot(quantile_ticks, derivatives, color=colour_mapping["gradient"], label='Derivative of Quantiles')
    ax1.axhline(0, color='gray', linestyle='--', linewidth=0.5)
    
    ax1.set_xlabel('Quantile Ticks')
    ax1.set_ylabel('Derivative Value', color=colour_mapping["gradient"])
    ax1.set_title(f'Derivative and Quantile Plot for {province}')
    ax1.tick_params(axis='y', labelcolor=colour_mapping["gradient"])
    
    # Right axis (shared x): the quantile values themselves
    ax2 = ax1.twinx()
    ax2.plot(quantile_ticks, quantiles, color=colour_mapping["normal"])
    ax2.set_ylabel('Quantile Values', color=colour_mapping["normal"])
    ax2.tick_params(axis='y', labelcolor=colour_mapping["normal"])
    
    for plateau in plateau_points:
        if pd.isna(plateau):
            continue  # "no plateau found" sentinel: nothing to mark
        colour = colour_mapping[plateau]
        ax1.axvline(plateau, linestyle='--', color=colour)
        ax2.axvline(plateau, linestyle='--', color=colour, label=f'Cut-off at x = {plateau:.3f}')
    
    # One legend for both axes: the derivative curve is on ax1, the cut-offs on ax2.
    # It is attached to ax2 because the twin axis is drawn on top.
    handles1, labels1 = ax1.get_legend_handles_labels()
    handles2, labels2 = ax2.get_legend_handles_labels()
    if labels1 or labels2:
        ax2.legend(handles1 + handles2, labels1 + labels2)

    if save_graphic is not False:
        if separate_plots:
            # Name the file after the (sanitised) plot title
            file_name = f"{sanitise_filename(ax1.get_title())}.{save_graphic}"
        else:
            file_name = f"combined_error_distribution_plots.{save_graphic}"
        fig.savefig(file_name, format=save_graphic)

    if show and ax is None:
        plt.show()

    return ax if ax is not None else fig

def analyse_error_distribution(df, quantile_value=0.03, fixed_threshold=100, direction="left", show=False, save_graphic=False, separate_plots=True, tick_count=500):
    '''
    Analyses the error distribution for each province in the input DataFrame, calculating 
    quantiles, derivatives, and plateau points (via calculate_plateau_points). Plots are
    generated only when they are shown and/or saved; the plateau points are always computed.

    Parameters:
    - df (pd.DataFrame): 
        Input DataFrame that must contain at least two columns: 
        - "Error Type": The categorical column representing provinces or regions.
        - "Error Value": The numeric values for which quantiles are calculated.
    - quantile_value (float, default=0.03): 
        The quantile boundary used for plateau point detection.
    - fixed_threshold (float, default=100): 
        Threshold for the gradient (derivative) used to identify plateau points.
    - direction (str, default="left"): 
        Direction for calculating quantiles. "left" focuses on lower quantiles, while "right" focuses on higher quantiles.
    - show (bool, default=False): 
        Whether to display the plots.
    - save_graphic (str or bool, default=False): 
        If a format string (e.g. 'png', 'pdf'), saves the plots in that format.
    - separate_plots (bool, default=True): 
        Whether to create one figure per province or one combined figure with subplots.
    - tick_count (int, default=500): 
        Number of ticks (quantile steps) to use for calculating quantiles.

    Returns:
    - plateau_ticks (dict): 
        Maps each province name to its list of plateau points (`[nan]` when none).
    '''
    provinces = df["Error Type"].unique()

    # The plateau points are the return value, so compute them whether or not anything
    # is plotted (previously the default show=False/save_graphic=False returned {}).
    results = {}
    for province in provinces:
        error_values = df.loc[df["Error Type"] == province, "Error Value"]
        results[province] = calculate_plateau_points(
            error_values, quantile_value, fixed_threshold, direction, tick_count
        )
    plateau_ticks = {province: result[0] for province, result in results.items()}

    # Skip plotting entirely when the plots would be neither shown nor saved.
    if not show and save_graphic is False:
        return plateau_ticks

    def colours_for(plateau_points):
        # "gradient", "normal" and one entry per plateau point, all colourblind-safe.
        keys = ["gradient", "normal"] + list(plateau_points)
        return dict(zip(keys, sns.color_palette("colorblind", len(keys))))

    if separate_plots:
        # One figure per province; plot_error_distribution saves/shows each one itself.
        for province, (plateau_points, quantile_ticks, quantiles, derivatives) in results.items():
            plot_error_distribution(quantile_ticks, quantiles, derivatives, plateau_points, province,
                                    colours_for(plateau_points), show=show, save_graphic=save_graphic,
                                    separate_plots=True)
    else:
        # One combined figure with a near-square grid of subplots
        ncols = int(np.ceil(np.sqrt(len(provinces))))
        nrows = int(np.ceil(len(provinces) / ncols))
        # squeeze=False keeps `axes` 2-D even for a single province, so flatten() works.
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10), squeeze=False)
        axes = axes.flatten()

        for ax, (province, (plateau_points, quantile_ticks, quantiles, derivatives)) in zip(axes, results.items()):
            plot_error_distribution(quantile_ticks, quantiles, derivatives, plateau_points, province,
                                    colours_for(plateau_points), show=False, save_graphic=False, ax=ax)

        # Hide grid cells that have no province
        for ax in axes[len(results):]:
            ax.remove()

        plt.tight_layout()
        if save_graphic is not False:
            # format must be the string itself (it was previously passed as a set).
            fig.savefig(f"combined_error_distribution_plots.{save_graphic}", format=save_graphic)
        if show:
            plt.show()

    return plateau_ticks

def plot_quantiles(df, quantile_ranges=((0.0, 1.0),), ncols=2):
    '''
    Plots the quantile distribution of "Error Value" for different "Error Type" categories 
    over specified quantile ranges. One subplot is created per quantile range and laid out
    in a grid; a single figure-level legend maps colours to error types.

    Parameters:
    - df (pd.DataFrame): 
        Input DataFrame that must contain at least two columns:
        - "Error Type": Categorical variable used to group the data.
        - "Error Value": Numeric values for which quantiles will be calculated.

    - quantile_ranges (iterable of (low, high) tuples, default=((0.0, 1.0),)): 
        The quantile ranges to be plotted. Each tuple specifies the lower and 
        upper bounds of the quantiles (e.g., (0.0, 1.0) represents the full range of quantiles).
        Each range is sampled at 20 evenly spaced quantile levels.

    - ncols (int, default=2): 
        Number of columns in the subplot grid. The number of rows is automatically calculated 
        based on the number of quantile ranges and columns. Unused grid cells are hidden.

    Returns:
    - None: Call plt.show() after running this function to display the plot.

    Raises:
    - ValueError: If `quantile_ranges` is empty.
    '''
    # Materialise once so generators are accepted and len() is available.
    quantile_ranges = list(quantile_ranges)
    if not quantile_ranges:
        raise ValueError("quantile_ranges must contain at least one (low, high) tuple")

    error_types = df['Error Type'].unique()
    num_plots = len(quantile_ranges)

    colours = sns.color_palette("colorblind", len(error_types))
    # (label, colour, values) per error type. The data is sliced once here and
    # reused for every quantile range instead of being re-filtered per subplot.
    per_type = [
        (error_type, colour, df.loc[df['Error Type'] == error_type, 'Error Value'])
        for error_type, colour in zip(error_types, colours)
    ]

    nrows = (num_plots + ncols - 1) // ncols  # Ceiling division

    # squeeze=False always yields a 2D array of Axes, so flatten() works uniformly.
    fig, axs = plt.subplots(nrows, ncols, figsize=(ncols * 7.5, nrows * 6), squeeze=False)
    axs = axs.flatten()

    # nrows * ncols >= num_plots, so zip() covers every requested range.
    for ax, quantile_range in zip(axs, quantile_ranges):
        quantile_ticks = np.linspace(quantile_range[0], quantile_range[1], 20)
        for error_type, colour, error_values in per_type:
            ax.plot(
                quantile_ticks,
                error_values.quantile(quantile_ticks),
                linestyle='-',
                color=colour,
                label=error_type,
            )
        ax.set_title(f'Quantile Plot for Range {quantile_range}')
        ax.grid(False)

    # Every subplot draws the same error types in the same colours, so the first
    # subplot's handles are sufficient for one figure-level legend.
    handles, labels = axs[0].get_legend_handles_labels()

    # Hide any unused subplots if the number of quantile_ranges < nrows*ncols
    for ax in axs[num_plots:]:
        ax.set_visible(False)

    # Add a single legend for the whole figure
    if handles and labels:
        fig.legend(handles, labels, loc='upper center', ncol=max(1, len(error_types)//2), title='Error Type')

    # Add global x and y axis labels
    fig.text(0.5, 0.04, 'Quantile', ha='center', va='center')
    fig.text(0.04, 0.5, 'Error Value', ha='center', va='center', rotation='vertical')

    plt.tight_layout(rect=[0.05, 0.05, 1, 0.95])  # Adjust layout to make space for the legend and labels

def generate_stats(df):
    '''
    Computes and returns various summary statistics for the "Error Value" column in the input DataFrame, 
    grouped by the "Error Type". The statistics include central tendency, dispersion, skewness, kurtosis, 
    and other relevant metrics.

    Parameters:
    - df (pd.DataFrame): 
        Input DataFrame that must contain at least two columns:
        - "Error Type": Categorical variable used to group the data.
        - "Error Value": Numeric values on which statistics will be computed.

    Returns:
    - stats_df (pd.DataFrame): 
        A DataFrame where the index is the unique values of "Error Type" (in order of first
        appearance), and the columns contain the following statistics for each group:
        - "amount": Count of data points.
        - "mean": Mean of "Error Value".
        - "median": Median of "Error Value".
        - "mad": Median absolute deviation.
        - "std": Standard deviation.
        - "var": Variance.
        - "min": Minimum value.
        - "max": Maximum value.
        - "skew": Skewness (asymmetry of the distribution).
        - "kurt": Kurtosis (tailedness of the distribution).
        - "range": Range (difference between max and min).
        - "iqr": Interquartile range (difference between 75th and 25th percentiles).
        - "cv": Coefficient of variation (std/mean); inf or NaN when the group mean is 0.
        - "10th_%ile": 10th percentile of "Error Value".
        - "90th_%ile": 90th percentile of "Error Value".
        - "jarque_bera": Jarque-Bera test statistic for normality.

    Notes:
    - Side effect: sets the *global* pandas option `display.float_format` so floats are
      displayed with 4 significant digits (`.4g`). This affects every DataFrame shown
      afterwards in the session, not only the returned one.
    - "mad", "skew", "kurt" and "jarque_bera" come from scipy, which by default propagates
      NaN values inside a group; the pandas-based statistics skip them.
    - A NaN/None "Error Type" still gets a row (it is a unique value) but all of its
      statistics are NaN, because groupby drops NaN keys.
    '''
    # Build the grouping once; every statistic below is a per-"Error Type" reduction.
    # observed=True keeps categorical groupers from producing empty, unobserved groups.
    grouped = df.groupby('Error Type', observed=True)['Error Value']
    error_types = df['Error Type'].unique()
    stats_df = pd.DataFrame(index=error_types)

    # Each assignment aligns the grouped result on the "Error Type" index.
    stats_df["amount"] = grouped.size()
    stats_df["mean"] = grouped.mean()
    stats_df["median"] = grouped.median()
    stats_df["mad"] = grouped.apply(median_abs_deviation)
    stats_df["std"] = grouped.std()
    stats_df["var"] = grouped.var()
    stats_df["min"] = grouped.min()
    stats_df["max"] = grouped.max()
    stats_df["skew"] = grouped.apply(skew)
    stats_df["kurt"] = grouped.apply(kurtosis)
    stats_df["range"] = stats_df["max"] - stats_df["min"]
    stats_df["iqr"] = grouped.quantile(0.75) - grouped.quantile(0.25)
    stats_df["cv"] = stats_df["std"] / stats_df["mean"]
    stats_df["10th_%ile"] = grouped.quantile(0.1)
    stats_df["90th_%ile"] = grouped.quantile(0.9)
    # jarque_bera returns (statistic, p-value); only the statistic is kept.
    stats_df["jarque_bera"] = grouped.apply(lambda values: jarque_bera(values.to_numpy())[0])

    # General format with 4 significant digits (switches to exponent notation only for
    # very large/small magnitudes). Global option - see Notes.
    pd.set_option('display.float_format', lambda x: f"{x:.4g}")

    return stats_df

def check_for_problems(my_df):
    '''
    Checks the input DataFrame for rows containing problematic values, such as NaN (missing) values 
    or infinite values (+inf, -inf). If any such rows are found, they are printed; otherwise, 
    a message indicating no problems is displayed.

    Parameters:
    - my_df (pd.DataFrame): 
        Input DataFrame that will be checked for NaN or infinite values in any column.

    Returns:
    - None: Prints the rows that contain problematic values (NaN or inf) or confirms if no issues are found.
    '''
    # One elementwise mask (NaN, +inf or -inf), reduced to "any problem in this row".
    # Equality comparisons are used instead of np.isinf so that non-numeric/object
    # columns are tolerated: they simply compare as False.
    is_problem = (
        my_df.isna()
        | (my_df == np.inf)
        | (my_df == -np.inf)
    ).any(axis=1)
    problematic_rows = my_df[is_problem]

    if problematic_rows.empty:
        print("No rows contain NaN or inf values.")
        return

    print("Rows containing NaN or inf values:")
    print(problematic_rows)

from fastdtw import fastdtw

def my_distance_function(x, y):
    '''
    Returns the dynamic-time-warping distance between two sequences, as computed by `fastdtw`.

    Parameters:
    - x, y (iterables): The two sequences to compare; each is converted to a list first.

    Returns:
    - The distance value reported by `fastdtw` (the warping path it also returns is discarded).
    '''
    seq_a = list(x)
    seq_b = list(y)
    # fastdtw yields a (distance, path) pair; only the distance is needed here.
    dtw_result = fastdtw(seq_a, seq_b)
    return dtw_result[0]

def get_province_data():
    '''
    Reads normalized data files for various provinces from a specified directory, 
    processes the contents, and returns a DataFrame where each column corresponds 
    to a province and the rows contain the associated numeric data.

    The function looks for files that match the pattern `*_normalised.txt`, where 
    the prefix before `_normalised.txt` is assumed to be the province name (it may itself
    contain underscores). Each file contains a list of comma-separated values representing
    some metric for that province.

    Parameters:
    - None: 
        The function looks in `../../data` relative to the current working directory
        for files matching the naming convention `*_normalised.txt`.

    Returns:
    - province_data (pd.DataFrame): 
        A DataFrame where each column corresponds to a province (derived from the file names) 
        and the rows contain numeric data loaded from the corresponding files. 
        Columns appear in sorted file-name order.

    Notes:
    - Empty fields (e.g. from a trailing comma or a trailing newline) are ignored.
    - I/O and parse errors for a file are printed and that file is skipped; the
      remaining files are still processed.
    - All files must contain the same number of values; pandas raises otherwise when
      building the DataFrame.
    '''
    suffix = '_normalised.txt'
    input_directory = os.path.abspath(os.path.join(os.getcwd(), "..", "..", "data"))

    # glob() order depends on the filesystem; sort for a deterministic column order.
    matching_files = sorted(glob.glob(os.path.join(input_directory, '*' + suffix)))

    def read_txt_as_list(filename):
        '''Parse one comma-separated file into a list of floats, skipping empty fields.'''
        with open(filename, 'r', encoding='utf-8') as file:
            tokens = file.read().split(',')
        return [float(token) for token in tokens if token.strip()]

    province_data = {}
    for file_path in matching_files:
        file_name = os.path.basename(file_path)
        # Everything before the suffix is the province name, so names containing
        # underscores stay intact and cannot collide on a shared first word.
        province = file_name[:-len(suffix)]
        try:
            province_data[province] = read_txt_as_list(file_path)
        except (OSError, ValueError) as e:  # ValueError also covers UnicodeDecodeError
            print(f"Error reading {file_name}: {e}")

    return pd.DataFrame(province_data)

def make_df_from_filenames(filtered_files, sep, reporters):
    '''
    Reads a list of files (either CSV or Excel), combines the data into a single DataFrame, 
    and formats specified columns by converting them to appropriate types and lists. 
    Handles file reading errors gracefully by skipping problematic files.

    Parameters:
    - filtered_files (list of str): 
        A list of file paths to be read. Each file is tried as CSV first, then as Excel.
    
    - sep (str): 
        The delimiter to be used when reading CSV files (ignored for Excel files).
    
    - reporters (iterable of str): 
        Column names in the DataFrame that should be converted into lists (in addition to
        "actual_fitness_change", "fitness_before", "fitness_after", "actions_after",
        "actions_before", which are always converted).

    Returns:
    - combined_df (pd.DataFrame): 
        A single DataFrame containing the combined data from all successfully read files, with the specified columns 
        converted to lists and other columns reformatted (e.g., type changes for "action", "second_index", 
        "predicted_fitness_change"). An empty DataFrame is returned if no file could be read.
    
    Raises:
    - ValueError: If a list-column cell is a string that cannot be parsed.
    - KeyError: If one of the expected columns is missing.

    Notes:
    - If reading a file fails (either due to format or parsing issues), the file is skipped and an error message is printed.
    - List-column cells that are not strings (e.g. numbers already parsed by the reader, or
      missing values) are wrapped in a one-element list instead of being parsed.
    '''
    # Read every file: try CSV, fall back to Excel, skip files neither reader accepts.
    dataframes = []
    for iter_file in filtered_files:
        try:
            iter_df = pd.read_csv(iter_file, sep=sep)
        except Exception:
            try:
                iter_df = pd.read_excel(iter_file)
            except Exception as e:
                print(f"Failed to read {iter_file}: {e}")
                continue  # Skip to the next file if reading fails
        dataframes.append(iter_df)

    if not dataframes:
        # Nothing to format: the expected columns would not exist.
        print("No files could be read; returning an empty DataFrame.")
        return pd.DataFrame()

    # NOTE(review): combine_first aligns rows on the index. Each file read here has a
    # fresh 0..n-1 index, so rows of later files sharing an index label with an earlier
    # file only fill that row's NaN cells instead of being appended. If a plain row-wise
    # concatenation is intended (as the message below suggests), use
    # pd.concat(dataframes, ignore_index=True) - confirm before changing.
    combined_df = reduce(lambda df1, df2: df1.combine_first(df2), dataframes)

    try:
        combined_df = combined_df.drop_duplicates()
    except TypeError as e:
        # Raised when cells hold unhashable objects such as lists.
        print(f'Operation failed with TypeError "{e}"')

    print("Concatenated all data frames. Formatting...")

    # Column types; missing values become None.
    combined_df["action"] = combined_df["action"].apply(lambda x: str(x) if pd.notna(x) else None)
    combined_df["second_index"] = combined_df["second_index"].apply(lambda x: int(x) if pd.notna(x) else None)
    combined_df["predicted_fitness_change"] = combined_df["predicted_fitness_change"].apply(lambda x: float(x) if pd.notna(x) else None)

    def convert_to_list(value):
        '''Parse one cell into a list, converting numeric items (and digit-only strings) to float.'''
        if isinstance(value, str):
            try:
                # Safe parse of Python literals such as "[1, 2.5, 'a']".
                parsed_list = ast.literal_eval(value)
            except (ValueError, SyntaxError):
                try:
                    # SECURITY: eval() executes arbitrary code contained in the file.
                    # Only acceptable for trusted, locally generated input files.
                    parsed_list = eval(value)
                except Exception as err:
                    raise ValueError(f"Failed to parse the content: {value}") from err
        else:
            # Non-string cells (numbers already parsed by read_csv/read_excel, missing
            # values, or lists from an earlier conversion) cannot go through
            # literal_eval/eval; use them as they are.
            parsed_list = value

        # Ensure output is always a list, even if a single float or other object
        if not isinstance(parsed_list, list):
            parsed_list = [parsed_list]

        return [
            float(item) if isinstance(item, (int, float, np.integer, np.floating)) or
                        (isinstance(item, str) and item.replace('.', '', 1).isdigit())
            else item
            for item in parsed_list
        ]

    list_columns = [*reporters, "actual_fitness_change", "fitness_before", "fitness_after", "actions_after", "actions_before"]
    for column in list_columns:
        combined_df[column] = combined_df[column].apply(convert_to_list)

    return combined_df