import argparse
from collections import defaultdict
from copy import copy, deepcopy
import csv
from dataclasses import dataclass, field

import itertools
import logging
from pathlib import Path
from types import ModuleType
from typing import Callable, Optional, TypeAlias
import sys

logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING)
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.axes_divider import Divider
import mpl_toolkits.axes_grid1.axes_size as Size
from matplotlib.figure import Figure
from matplotlib.axes import Axes
import numpy.typing as npt

from scipy.optimize import curve_fit

from complexity.constants import *
from complexity.supporting_functions import *
from complexity.computational_complexity import EXP_RESULT_IND, EXP_RESULT_TAGS, ExperimentSet, ResultsFile

from complexity.plot_config import *
from complexity.plot_config.social_force_fit import *

# Axes handles that may be absent, e.g. the colour-bar and legend axes passed
# to the plot functions (``cb_ax`` / ``legend_ax``) when none were created.
Opt_Axes: TypeAlias = Optional[Axes]
Opt_Axes_List: TypeAlias = Optional[list[Axes]]

#=======================================================
# CONFIGURATION DATACLASSES
#=======================================================

# Module-level logger.
logger = logging.getLogger(__name__)

@dataclass
class DividerConfig():
    """Layout settings for the axes divider used when building the figure.

    ``left`` / ``left_label`` / ``bottom`` are used as ``Size.Fixed`` sizes by
    ``create_axes``; the remaining sizes presumably follow the same convention
    (TODO confirm against the rest of ``create_axes``).

    Supplying a non-empty ``share_y_list`` / ``share_x_list`` switches on the
    corresponding ``share_y`` / ``share_x`` flag.
    """
    left: float = 0.5
    right: float = 0.05
    top: float = 0.30
    bottom: float = 0.70
    gap_hor: float = 0.5
    gap_ver: float = 0.5

    left_label: Optional[float] = None
    v_legend: Optional[float] = None

    legend_width: Optional[float] = None
    legend_gap: Optional[float] = None

    share_y: bool = False
    share_x: bool = False

    color_bar_width: float = 0.2
    color_label_width: float = 0.2
    color_bar_gap: float = 0.05

    share_y_list: list = field(default_factory=list)
    share_x_list: list = field(default_factory=list)

    hor_axes_locs: list = field(default_factory=list)
    ver_axes_locs: list = field(default_factory=list)

    def __post_init__(self) -> None:
        # A non-empty share list implies sharing along that axis; an empty
        # list leaves the explicitly passed flag untouched.
        flag_by_list = (('share_y', self.share_y_list),
                        ('share_x', self.share_x_list))
        for flag_name, share_list in flag_by_list:
            if len(share_list) > 0:
                setattr(self, flag_name, True)

    @property
    def has_left_label(self) -> bool:
        """True when a width for the left label column was configured."""
        return not (self.left_label is None)

    @property
    def has_v_legend(self) -> bool:
        """True when a vertical legend size was configured."""
        return not (self.v_legend is None)

    @property
    def uses_custom_locs(self) -> bool:
        """True when explicit horizontal axes locations were given."""
        return not len(self.hor_axes_locs) == 0


@dataclass(kw_only=True)
class PlotConfig():
    """Settings shared by every subplot configuration.

    ``model_tag`` may be a single tag or a sequence of tags; after
    initialisation ``model_tags`` always holds a sequence of tags.
    """
    data_type: str
    plot_type: str
    plot_fcn: Callable

    time_steps_per_sec_per_model: dict
    model_tag: str | Sequence[str]
    x_data_var: str

    ax_config: dict
    artist_config: dict
    colors: dict
    styles: dict
    line_widths: dict

    ax_loc: Optional[tuple[int, int]] = None
    legend_loc: Optional[str] = None
    color_bar_loc: Optional[ColorBarLoc] = None

    # Derived from ``model_tag`` in ``__post_init__``.
    model_tags: Sequence[str] = field(init=False)

    def __contains__(self, field: str) -> bool:
        """Support ``name in config``: True if the config has that attribute."""
        return hasattr(self, field)

    def __post_init__(self) -> None:
        # Normalise a single tag to a one-element list; sequences are kept as-is.
        single_tag = isinstance(self.model_tag, str)
        self.model_tags = [self.model_tag] if single_tag else self.model_tag

    @property
    def has_defined_ax_loc(self) -> bool:
        """True when an explicit axes location was configured."""
        return not (self.ax_loc is None)

    @property
    def has_defined_legend_loc(self) -> bool:
        """True when an explicit legend location was configured."""
        return not (self.legend_loc is None)

    @property
    def has_defined_loc(self) -> bool:
        """Alias of ``has_defined_ax_loc``."""
        return self.has_defined_ax_loc

    def get_var_as_array(self, var_name: str) -> npt.NDArray:
        """Return the attribute ``var_name`` converted with ``np.array``."""
        raw_value = getattr(self, var_name)
        return np.array(raw_value)


@dataclass(kw_only=True)
class TheoreticalConfig(PlotConfig):
    """Plot configuration whose data is computed from equation modules.

    ``eq_module_name_per_model`` maps each model tag to the stem of a
    ``.py`` file; ``load_module`` loads those files into ``eq_module``.
    """
    lines: list[dict]
    eq_module_name_per_model: dict[str, str]
    eq_type: str | Sequence[str]
    special_vars: dict

    # Filled by ``load_module``; keyed by model tag.
    eq_module: dict[str, ModuleType] = field(init=False)

    def load_module(self, eq_file_dir: Path) -> None:
        """Load ``<eq_file_dir>/<module name>.py`` for every model tag.

        Raises:
            KeyError: if a model tag has no entry in ``eq_module_name_per_model``.
        """
        self.eq_module = {}
        for tag in self.model_tags:
            module_path = eq_file_dir / f'{self.eq_module_name_per_model[tag]}.py'
            # ``load_module`` here is the module-level helper, not this method.
            self.eq_module[tag] = load_module(module_path)


@dataclass(kw_only=True)
class EmpiricalConfig(PlotConfig):
    """Plot configuration for measured run-time data.

    ``lines`` describes the experiment lines to draw, and
    ``run_time_unit.conversion_factor`` scales the raw run times. Both are
    read by ``CcPlotter._get_lines_run_times``. ``fit_fcn`` is an optional
    fit function; ``single_line_plot`` only uses it when it is not None.
    """
    lines: Optional[list[dict]] = None
    run_time_unit: RunTimeUnit

    # Optional: every reader checks ``fit_fcn is not None`` before use, so it
    # should not be a required constructor argument.
    fit_fcn: Optional[Callable] = None
   
#=======================================================
# PLOTTER (data loading, plotting and CSV export)
#=======================================================


class CcPlotter():
    """Drive one plot-configuration file from data loading to figure and CSV.

    The steps run by ``visualize`` are:

    1. ``load_data``
    2. ``load_plot_configs``
    3. ``plot``
    4. ``create_csv_files``

    Empirical data is only loaded when both ``output_dir`` and ``run_id`` are
    given. Theoretical plots need ``eq_file_dir`` to locate the equation
    modules.
    """

    def __init__(self,
                 plot_config_filename: Path,
                 create_fit_csv:bool,
                 create_line_csv:bool,
                 output_dir: Optional[Path] = None,
                 run_id: Optional[str] = None,
                 eq_file_dir: Optional[Path] = None) -> None:
        """Store the inputs; nothing is read until the ``load_*`` methods run."""
        self.plot_config_filename = plot_config_filename
        self.create_fit_csv = create_fit_csv
        self.create_line_csv = create_line_csv
        self.output_dir = output_dir
        self.run_id = run_id
        self.eq_file_dir = eq_file_dir

        # CSV rows grouped by their header tuple. Each group is written as the
        # header row, its content rows and an empty separator row.
        self.csv_fit_lines = defaultdict(list)
        self.csv_line_lines = defaultdict(list)

    def create_csv_files(self) -> None:
        """Write the fit CSV and the line-data CSV, each only if requested and non-empty."""
        self.create_fit_csv_file()
        self.create_line_csv_file()

    def create_fit_csv_file(self) -> None:
        """Write the collected fit rows to ``<config stem>_fit.csv`` in ``output_dir``."""
        if not (self.create_fit_csv and len(self.csv_fit_lines) > 0):
            return

        csv_filename = self.output_dir.joinpath(f'{self.plot_config_filename.stem}_fit.csv') # type: ignore
        # QUOTE_NONNUMERIC quotes every non-numeric field.
        self._write_grouped_csv(csv_filename, self.csv_fit_lines,
                                quoting=csv.QUOTE_NONNUMERIC)

    def create_line_csv_file(self) -> None:
        """Write the collected line rows to ``<config stem>_lines.csv``.

        The file goes to ``output_dir``, or to ``eq_file_dir`` when no
        ``output_dir`` was given.
        """
        if not (self.create_line_csv and len(self.csv_line_lines) > 0):
            return

        base_dir = self.output_dir
        if base_dir is None:
            base_dir = self.eq_file_dir
        csv_filename = base_dir.joinpath(f'{self.plot_config_filename.stem}_lines.csv') # type: ignore
        self._write_grouped_csv(csv_filename, self.csv_line_lines)

    @staticmethod
    def _write_grouped_csv(csv_filename: Path, groups: dict, **writer_kwargs) -> None:
        """Write every (header, rows) group of *groups*, each followed by an empty row."""
        with open(csv_filename, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile, delimiter=',', **writer_kwargs)
            for header_line, content_lines in groups.items():
                writer.writerow(header_line)
                writer.writerows(content_lines)
                writer.writerow([])

    @staticmethod
    def visualize(plot_config_filenames: list[Path],
                  output_dir: Optional[Path],
                  run_id: Optional[str],
                  eq_file_dir: Optional[Path],
                  create_fit_csv:bool,
                  create_line_csv:bool) -> None:
        """Plot every config file, write the requested CSVs, then show all figures."""
        for plot_config_filename in plot_config_filenames:
            plotter = CcPlotter(plot_config_filename,
                                create_fit_csv,
                                create_line_csv,
                                output_dir,
                                run_id,
                                eq_file_dir)
            plotter.load_data()
            plotter.load_plot_configs()
            plotter.plot()
            plotter.create_csv_files()

        plt.show()

    def load_data(self) -> None:
        """Read the results file and group the experiment sets by config name.

        Does nothing when ``output_dir`` or ``run_id`` is missing. In that
        case the empirical-data attributes stay unset.
        """
        if self.output_dir is None or self.run_id is None:
            return

        file_data = ResultsFile(self.output_dir,
                                self.run_id).read_data_from_file()
        self.run_times, self.id_matrix, self.field_values, self.index_table, experiment_sets = file_data
        self.experiment_sets = {}
        for exp_set in experiment_sets:
            self.experiment_sets.setdefault(exp_set.config_name, []).append(exp_set)

    def load_plot_configs(self) -> None:
        """Build the divider config and one ``PlotConfig`` per entry of the config module.

        Raises:
            ValueError: for an unknown data or plot type, or when a theoretical
                config is used without ``eq_file_dir``.
        """
        module = load_module(self.plot_config_filename)
        self.mpl_config, divider_config, plot_configs_dict = module.get_config()

        self.divider_config = DividerConfig(**divider_config)

        self.plot_configs = []
        for plot_config in plot_configs_dict:
            if plot_config[DATA_TYPE] == EMPIRICAL_DATA:
                config_class = EmpiricalConfig
            elif plot_config[DATA_TYPE] == THEORETICAL_DATA:
                config_class = TheoreticalConfig
            else:
                raise ValueError(
                    f'The data type "{plot_config[DATA_TYPE]}" is not known!')

            # The plot function is attached to the (mutated) config dict.
            if plot_config[PLOT_TYPE] == SINGLE_LINE:
                plot_config[PLOT_FCN] = single_line_plot
            elif plot_config[PLOT_TYPE] == MULTI_LINE:
                plot_config[PLOT_FCN] = multi_line_plot
            elif plot_config[PLOT_TYPE] == THREE_D:
                plot_config[PLOT_FCN] = three_d_plot
            else:
                raise ValueError(
                    f'The plot type "{plot_config[PLOT_TYPE]}" is not known!')

            config_inst = config_class(**plot_config)
            if isinstance(config_inst, TheoreticalConfig):
                if self.eq_file_dir is None:
                    raise ValueError(
                        'The equation file directory has not been provided!')
                config_inst.load_module(self.eq_file_dir)

            self.plot_configs.append(config_inst)

    def plot(self) -> None:
        """Create the figure and axes, gather each config's data and draw it."""
        fig = create_fig(self.mpl_config[FIGURE])
        axes, aux_outputs = create_axes(self.plot_configs, self.divider_config, fig)
        cb_axes = aux_outputs[2]
        legend_axes = aux_outputs[-1]
        x_data = []
        y_data = []
        z_data = []
        line_info = []
        for plot_config in self.plot_configs:
            x_data_loc, y_data_loc, z_data_loc, line_info_loc = self.get_plot_data(plot_config)
            x_data.append(x_data_loc)
            y_data.append(y_data_loc)
            z_data.append(z_data_loc)
            line_info.append(line_info_loc)

        # Optional global post-processing across all subplots.
        if SINGLE_COLOR_BAR in self.mpl_config and self.mpl_config[
                SINGLE_COLOR_BAR]:
            self.set_v_data(z_data)
        if REL_DIFF_FIRST_TEXT in self.mpl_config and self.mpl_config[
                REL_DIFF_FIRST_TEXT]:
            line_info = self.get_rel_diff_text(z_data, line_info)

        for ax_nr, plot_config in enumerate(self.plot_configs):
            ax = axes[ax_nr]
            legend_ax = None
            if ax in legend_axes:
                legend_ax = legend_axes[ax]
            plot_config.plot_fcn(ax,
                                 x_data[ax_nr],
                                 y_data[ax_nr],
                                 z_data[ax_nr],
                                 plot_config,
                                 self.mpl_config,
                                 cb_axes[ax_nr],
                                 legend_ax,
                                 line_info[ax_nr])

            set_ax_config(ax, plot_config.ax_config)

    def get_rel_diff_text(self,
                          z_data: list[npt.NDArray[np.float64]],
                          line_info: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Attach cell labels ``z / z_data[0][0, 0]`` (one decimal) to every 3-D plot.

        The labels are stored under ``TEXT`` in the plot's line info. A
        non-dict line info is replaced by a new dict.
        """
        base_number = z_data[0][0,0]
        for ii, plot_config in enumerate(self.plot_configs):
            if plot_config.plot_type != THREE_D:
                continue

            rel_diff = z_data[ii]/base_number
            texts = [[f'{rel_diff[x, y]:.1f}' for y in range(z_data[ii].shape[1])]
                     for x in range(z_data[ii].shape[0])]

            if not isinstance(line_info[ii], dict):
                line_info[ii] = {TEXT: texts}
            else:
                line_info[ii][TEXT] = texts

        return line_info

    def set_v_data(self, z_data: list[npt.NDArray]) -> None:
        """Give all 3-D plots a shared colour range (global min/max over *z_data*)."""
        z_max = -np.inf
        z_min = np.inf
        for z_data_loc in z_data:
            if z_data_loc is None:
                continue
            z_max = max(z_max, np.max(z_data_loc))
            z_min = min(z_min, np.min(z_data_loc))

        for plot_config in self.plot_configs:
            if plot_config.plot_type == THREE_D:
                plot_config.artist_config[P_COLOR_MESH]['vmax'] = z_max
                plot_config.artist_config[P_COLOR_MESH]['vmin'] = z_min

    def get_plot_data(
        self, plot_config: PlotConfig
    ) -> tuple[npt.NDArray | list[npt.NDArray],
               npt.NDArray | list[npt.NDArray],
               Optional[npt.NDArray],
               list]:
        """Return ``(x_data, y_data, z_data, lines_info)`` for *plot_config*."""
        if isinstance(plot_config, EmpiricalConfig):
            return self._get_emprical_plot_data(plot_config)
        elif isinstance(plot_config, TheoreticalConfig):
            return self._get_theoretical_plot_data(plot_config)
        else:
            raise NotImplementedError

    def _get_emprical_plot_data(
        self, plot_config: EmpiricalConfig
    ) -> tuple[npt.NDArray | list[npt.NDArray],
               npt.NDArray | list[npt.NDArray],
               Optional[npt.NDArray],
               list]:
        """Plot data for empirical configs; only multi-line plots are supported."""
        lines_info = []
        z_data = None

        if plot_config.plot_type == SINGLE_LINE:
            raise NotImplementedError
        elif plot_config.plot_type == MULTI_LINE:
            x_data, y_data, lines_info = self._get_lines_run_times(plot_config)
        elif plot_config.plot_type == THREE_D:
            raise NotImplementedError
        else:
            raise ValueError(
                f'The plot type "{plot_config.plot_type}" is not known!')

        return x_data, y_data, z_data, lines_info

    def _get_lines_run_times(self, plot_config: EmpiricalConfig) -> tuple[list[npt.NDArray], list[npt.NDArray], list]:
        """Build one run-time line per entry of ``plot_config.lines``.

        Each line is sorted by its x values. When requested, the line also
        adds a fit row and/or a line-data row to the CSV buffers.

        Raises:
            TypeError: if ``plot_config.lines`` is None.
        """
        y_data = []
        x_data = []
        lines_info = []
        line_info_base = {LEGEND_LABEL: None, LINE_CONFIG: {}}
        if plot_config.lines is None:
            raise TypeError

        for line in plot_config.lines:
            exp_set = self.get_experiment_set(line[CONFIG_NAME], line[KWARGS])
            line_ind = self.get_ind(exp_set, self.id_matrix)
            run_times_line = get_value_per_sec(self.run_times[line_ind]*plot_config.run_time_unit.conversion_factor, plot_config, exp_set.model_tag)
            line_x_data = self.id_matrix[line_ind, EXP_RESULT_IND[line[X_DATA_VAR]]]
            if line[X_DATA_VAR] == INFRA_ARG:
                # The x value is the product of the two infrastructure arguments.
                line_x_data = np.array([x[0]*x[1] for x in exp_set.infrastructure_arg])

            if X_CONV_FCN in line and line[X_CONV_FCN] == DENSITY:
                area = exp_set.infrastructure_arg[0]*exp_set.infrastructure_arg[1]
                line_x_data = line_x_data/area

            sort_ind = np.argsort(line_x_data)
            y_data.append(run_times_line[sort_ind])
            x_data.append(line_x_data[sort_ind])

            lines_info.append(deepcopy(line_info_base))
            lines_info[-1][LEGEND_LABEL] = line[LEGEND_LABEL]
            lines_info[-1][LINE_CONFIG] = line[LINE_CONFIG]

            if self.create_fit_csv and FIT_MODULE in line:
                header_line, content_line = fit_data(x_data[-1], y_data[-1], exp_set, line[KWARGS], line[FIT_MODULE])
                self.csv_fit_lines[header_line].append(content_line)

            if self.create_line_csv:
                header_line = tuple([line[X_DATA_VAR]] + [str(val) for val in x_data[-1]])
                kwargs_label = " - ".join([f'{key}={value}' for key, value in line[KWARGS].items()])
                line_label = f'{line[CONFIG_NAME]} - {kwargs_label}'
                content_line = [line_label] + [f"{val:.3E}" for val in y_data[-1]]
                self.csv_line_lines[header_line].append(content_line)

        return x_data, y_data, lines_info

    def get_experiment_set(self, config_name:str, kwargs:dict) -> ExperimentSet:
        """Return the first experiment set of *config_name* whose ``does_fit(kwargs)`` is true.

        Raises:
            KeyError: if *config_name* has no experiment sets.
            LookupError: if no experiment set of *config_name* matches *kwargs*.
        """
        for exp_set in self.experiment_sets[config_name]:
            if exp_set.does_fit(kwargs):
                return exp_set

        raise LookupError(
            f'No experiment set matches {config_name} - {kwargs}')

    def get_ind(self, exp_set: ExperimentSet, id_matrix: npt.NDArray) -> npt.NDArray:
        """Return a boolean row mask of *id_matrix* selecting the runs of *exp_set*.

        A row is selected when both of these hold:

        * It matches one of the set's base-variable values, together with
          every comb variable's value at the same position.
        * It matches every constraint from ``exp_set.to_ind_dict()``.

        Values of variables listed in ``self.index_table`` are translated
        through that table first.
        """
        combs = exp_set.to_ind_dict()
        ind = np.zeros(id_matrix.shape[0], dtype=bool)
        for ii, base_var_value in enumerate(exp_set.base_var_values):
            if exp_set.base_var in self.index_table:
                if exp_set.base_var == INFRA_ARG:
                    # Converted to a (hashable) tuple before the table lookup.
                    base_var_value = tuple(base_var_value)
                base_var_value = self.index_table[exp_set.base_var][base_var_value]

            # Start from the base-variable match and narrow it with each comb
            # variable's value at position ii (comb values are indexed in
            # lockstep with base_var_values).
            local_ind = lookup_action(id_matrix, exp_set.base_var, base_var_value)
            for comb_var in exp_set.comb_vars:
                value = exp_set.get(comb_var)[ii]
                local_ind = local_ind & lookup_action(id_matrix, comb_var, value)

            ind = ind | local_ind

        for key, values in combs.items():
            if not isinstance(values, list):
                values = [values]
            # Any of the listed values is acceptable for this key.
            local_ind = np.zeros(id_matrix.shape[0], dtype=bool)
            for value in values:
                if key in self.index_table:
                    value = self.index_table[key][value]
                if isinstance(value, DistrType):
                    value = value.value

                local_ind = local_ind | lookup_action(id_matrix, key, value)
            ind = ind & local_ind

        return ind

    def _get_theoretical_plot_data(
        self, plot_config: TheoreticalConfig
    ) -> tuple[npt.NDArray | list[npt.NDArray],
               npt.NDArray | list[npt.NDArray],
               Optional[npt.NDArray],
               list]:
        """Plot data computed from the equation modules of *plot_config*.

        Raises:
            TypeError: single-line plot with several equation types.
            NotImplementedError: 3-D plots.
            ValueError: unknown plot type.
        """
        lines_info = []
        z_data = None

        if plot_config.plot_type == SINGLE_LINE:
            if not isinstance(plot_config.eq_type, str):
                raise TypeError(
                    'When plotting a single line only a single equation type can be provided'
                )
            x_data, y_data = self._get_operations_single(plot_config, plot_config.eq_type)
        elif plot_config.plot_type == MULTI_LINE:
            if plot_config.lines is not None:
                x_data, y_data, lines_info = self._get_lines_operations(plot_config)
            else:
                if isinstance(plot_config.eq_type, str):
                    eq_types = [plot_config.eq_type]
                else:
                    eq_types = plot_config.eq_type
                x_data, y_data, lines_info = self._get_operations_multi(plot_config, eq_types)
        elif plot_config.plot_type == THREE_D:
            raise NotImplementedError
        else:
            # Previously fell through to an UnboundLocalError on x_data.
            raise ValueError(
                f'The plot type "{plot_config.plot_type}" is not known!')

        return x_data, y_data, z_data, lines_info  # type: ignore

    def _get_operations_single(self,
                               plot_config: TheoreticalConfig,
                               eq_type: str) -> tuple[npt.NDArray, npt.NDArray]:
        """Evaluate one equation type for the single model tag of *plot_config*.

        Raises:
            TypeError: if ``plot_config.model_tag`` is not a single string.
        """
        # _get_ops_args returns (x_data, args_list, x_ind, dim, y_values_info).
        x_data, args_list, _, _, _ = _get_ops_args(plot_config)
        args = tuple(args_list)
        if not isinstance(plot_config.model_tag, str):
            raise TypeError(
                'A single-line plot requires a single model tag')

        y_data = _get_operations(plot_config,
                                 plot_config.model_tag,
                                 eq_type,
                                 args)
        return x_data, y_data

    def _get_lines_operations(self, plot_config: TheoreticalConfig) -> tuple[list[npt.NDArray], list[npt.NDArray], list]:
        """Evaluate one line per entry of ``plot_config.lines``.

        Each entry provides ``MODEL_TAG``, ``EQ_TYPE``, ``ARGS`` and its x
        values under ``X_DATA_VAR``.

        Raises:
            TypeError: if ``plot_config.lines`` is None.
        """
        y_data = []
        x_data = []
        lines_info = []
        line_info_base = {LEGEND_LABEL: None, LINE_CONFIG: {}}
        if plot_config.lines is None:
            raise TypeError

        for line in plot_config.lines:
            model_tag = line[MODEL_TAG]
            eq_type = line[EQ_TYPE]
            args = line[ARGS]

            y_data_loc = _get_operations(plot_config,
                                         model_tag,
                                         eq_type,
                                         args)  # type: ignore

            x_data.append(line[X_DATA_VAR])
            y_data.append(y_data_loc)

            lines_info.append(deepcopy(line_info_base))
            lines_info[-1][LEGEND_LABEL] = line[LEGEND_LABEL]
            lines_info[-1][LINE_CONFIG] = line[LINE_CONFIG]
            if self.create_line_csv:
                header_line = tuple([X_DATA_VAR] + [str(val) for val in x_data[-1]])
                content_line = [line[LEGEND_LABEL]] + [f'{val:.3E}' for val in y_data[-1]]

                self.csv_line_lines[header_line].append(content_line)

        return x_data, y_data, lines_info

    def _get_operations_multi(
        self, plot_config: TheoreticalConfig, eq_types: Sequence[str]
    ) -> tuple[list[npt.NDArray], list[npt.NDArray], list]:
        """Evaluate every equation type, one line per parameter combination.

        When more than one argument besides the x variable holds several
        values (``dim > 1``), one line is produced per model tag and per
        element of the cartesian product of those values.
        """
        y_data = []
        x_data = []
        lines_info = []
        line_info_base = {LEGEND_LABEL: None, LINE_CONFIG: {}}

        for eq_type in eq_types:
            # Recomputed per eq_type because args_list is mutated below.
            x_data_loc, args_list, x_ind, dim, y_values_info = _get_ops_args(plot_config)

            if dim > 1:
                for model_tag in plot_config.model_tags:
                    y_values = tuple([
                        y_value_info['values'] for y_value_info in y_values_info
                    ])
                    values = np.array(list(itertools.product(*y_values)))
                    for ii, y_value_info in enumerate(y_values_info):
                        _, y = np.meshgrid(x_data_loc, values[:, ii])
                        args_list[y_value_info['ind']] = y

                    base_label = f'{model_tag} - {eq_type}'
                    for row_ind in range(values.shape[0]):
                        label = copy(base_label)
                        value_set = values[row_ind, :]
                        line_data = {EQ_TYPE: eq_type}
                        for col_ind in range(values.shape[1]):
                            y_value_info = y_values_info[col_ind]
                            line_data[
                                y_value_info['var_name']] = value_set[col_ind]
                            label = f'{label} - {y_value_info["var_name"]} = {value_set[col_ind]}'
                        line_info = deepcopy(line_info_base)
                        line_info[LEGEND_LABEL] = label
                        line_info[LINE_CONFIG] = _get_line_config(
                            plot_config, line_data)
                        lines_info.append(line_info)

                    # One row of x values per parameter combination.
                    args_list[x_ind] = np.reshape(
                        np.tile(x_data_loc, values.shape[0]),
                        (values.shape[0], x_data_loc.size))
                    args = tuple(args_list)
                    y_data_loc = _get_operations(plot_config,
                                                 model_tag,
                                                 eq_type,
                                                 args)
                    for row_nr in range(y_data_loc.shape[0]):
                        y_data_row = y_data_loc[row_nr, :]
                        y_data.append(y_data_row)
                        x_data.append(x_data_loc)
            else:
                args = tuple(args_list)
                y_data_loc = _get_operations(plot_config,
                                             plot_config.model_tag,
                                             eq_type,
                                             args)  # type: ignore
                y_data.append(y_data_loc)
                # Append this line's x values (previously appended x_data itself).
                x_data.append(x_data_loc)
                line_info = {
                    LEGEND_LABEL: f'{eq_type}',
                    LINE_CONFIG: {},
                }
                lines_info.append(line_info)

        return x_data, y_data, lines_info


def lookup_action(id_matrix:npt.NDArray, key:str, value:int|float) -> npt.NDArray:
    """Return a boolean mask of the rows whose ``key`` column equals *value*.

    NaN is matched with ``np.isnan``, because ``NaN == NaN`` is always False.
    """
    value_is_nan = np.isnan(value)
    column = id_matrix[:, EXP_RESULT_IND[key]]
    return np.isnan(column) if value_is_nan else (column == value)

def _get_operations(plot_config: TheoreticalConfig,
                    model_tag: str,
                    eq_type: str,
                    args: tuple) -> npt.NDArray:
    """Evaluate equation type *eq_type* of *model_tag*'s module with *args*.

    The result is scaled by ``get_value_per_sec``.

    Raises:
        ValueError: if *eq_type* is not FLOPS, MEMOPS or OPS.
    """
    if eq_type == FLOPS:
        fcn_name = 'get_flop_operations'
    elif eq_type == MEMOPS:
        fcn_name = 'get_memop_operations'
    elif eq_type == OPS:
        fcn_name = 'get_operations'
    else:
        # Report the value actually dispatched on; plot_config.eq_type can be
        # a list or differ from a per-line eq_type.
        raise ValueError(
            f'The operations type "{eq_type}" is not known!')

    y_data = getattr(plot_config.eq_module[model_tag], fcn_name)(*args)
    return get_value_per_sec(y_data, plot_config, model_tag)


def _get_ops_args(
        plot_config: TheoreticalConfig
) -> tuple[npt.NDArray, list, int, int, list]:
    """Collect the equation arguments of every model in ``input_order``.

    An argument is taken from the config attribute of the same name or,
    failing that, from ``special_vars``.

    Returns:
        A tuple ``(x_data, args_list, x_ind, dim, y_values_info)``:

        * ``x_data`` and ``x_ind``: the values and position of
          ``plot_config.x_data_var``.
        * ``dim``: 1 plus the number of non-x arguments holding more than one
          value.
        * ``y_values_info``: describes each of those multi-valued arguments.

    Raises:
        KeyError: if a variable cannot be found, or if ``x_data_var`` is not
            an input of any equation module.

    NOTE(review): with several model tags the arguments of all models are
    appended to the same list, and the x position is taken from the last
    model — confirm this is intended.
    """
    args_list = []
    y_values_info = []
    dim = 1
    x_data = None
    x_ind = None
    for model_tag in plot_config.model_tags:
        for var_name in plot_config.eq_module[model_tag].input_order:
            var_name = plot_config.eq_module[model_tag].name_2_variables[
                var_name]
            if var_name in plot_config:
                args_list.append(np.array(getattr(plot_config, var_name)))
                if var_name == plot_config.x_data_var:
                    x_data = args_list[-1]
                    x_ind = len(args_list) - 1
            elif var_name in plot_config.special_vars:
                args_list.append(np.array(plot_config.special_vars[var_name]))
            else:
                raise KeyError(f'The variable "{var_name}" cannot be found!')
            if var_name != plot_config.x_data_var and args_list[-1].size > 1:
                y_values_info.append({
                    'model_tag': model_tag,
                    'var_name': var_name,
                    'ind': len(args_list) - 1,
                    'values': args_list[-1]
                })
                dim += 1

    if x_data is None or x_ind is None:
        # Previously surfaced as an UnboundLocalError.
        raise KeyError(
            f'The x data variable "{plot_config.x_data_var}" is not an input of any equation module!')

    return x_data, args_list, x_ind, dim, y_values_info


#=======================================================
# OTHER FUNCTIONS
#=======================================================


def get_value_per_sec(values: npt.NDArray,
                      plot_config: PlotConfig,
                      model_tag: str) -> npt.NDArray:
    """Scale *values* by ``1 / time_steps_per_sec_per_model[model_tag]``."""
    steps_per_sec = plot_config.time_steps_per_sec_per_model[model_tag]
    return values * (1 / steps_per_sec)


def _get_line_config(plot_config: PlotConfig, line_data: dict) -> dict:
    """Map line parameters to colour / style / width settings.

    For every ``key: val`` in *line_data*, a setting is applied when the
    config's ``colors`` / ``styles`` / ``line_widths`` table has an entry
    for ``[key][val]``. Later keys override earlier ones.
    """
    line_config = {}
    for key, val in line_data.items():
        lookup_tables = ((plot_config.colors, COLOR),
                         (plot_config.styles, STYLE),
                         (plot_config.line_widths, LINE_WIDTH))
        for table, target in lookup_tables:
            if key in table and val in table[key]:
                line_config[target] = table[key][val]

    return line_config


def fit_data(x_data:npt.NDArray, y_data:npt.NDArray, exp_set:ExperimentSet, line_kwargs:dict, fit_module_name:str) -> tuple[tuple, list[Any]]:
    """Fit one measured line and return ``(csv_header, csv_row)``.

    The header is ``('line_info', 'r_squared', *fit_module.fields(base_var))``.
    The row starts with the line description, followed by either the output
    of ``get_line_fit`` or a "not implemented" message.
    """
    fit_module = get_line_fit_module(fit_module_name)
    header = tuple(['line_info', 'r_squared'] + fit_module.fields(exp_set.base_var))
    try:
        fit_fcn, arg_names = fit_module.fit_data_scipy_args(x_data, y_data, exp_set.base_var)
    except NotImplementedError:
        return header, [f'Not implemented for base var {exp_set.base_var}']

    if len(arg_names) == 0:
        x_data_fit = x_data
    else:
        # Two columns: x values and a column initialised to -1. Row `row` of
        # the second column carries the value of arg_names[row].
        # NOTE(review): this assumes the fit function decodes that layout, and
        # that x_data is float (an int array would truncate the values) —
        # confirm against the fit modules.
        x_column = np.expand_dims(x_data, 1)
        x_data_fit = np.append(x_column, -1 * np.ones_like(x_column), 1)
        for row, arg_name in enumerate(arg_names):
            value = line_kwargs[arg_name] if arg_name in line_kwargs else exp_set.get(arg_name)
            if arg_name == INFRA_ARG:
                value = value[0] * value[1]
            x_data_fit[row, 1] = value

    line_string = exp_set.get_line_string(line_kwargs)
    return header, [line_string] + get_line_fit(x_data_fit, y_data, fit_fcn)

def get_line_fit_np(x_data: npt.NDArray, y_data: npt.NDArray) -> str:
    """Return a least-squares quadratic fit formatted as ``'a n^2 + b n + c'``."""
    from numpy.polynomial import polynomial as poly
    # polyfit returns coefficients from lowest to highest degree.
    coefs = poly.polyfit(x_data, y_data, 2, full=True)[0]
    const_term, lin_term, quad_term = coefs[-3], coefs[-2], coefs[-1]
    return f'{quad_term}n^2 + {lin_term}n + {const_term}'


def get_line_fit(x_data: npt.NDArray,
                 y_data: npt.NDArray,
                 fit_fcn) -> list:
    """Fit *fit_fcn* to the data with ``curve_fit``.

    Returns ``[r_squared, *params]`` on success. On failure it returns a
    one-element list holding a message, so one bad line does not abort the
    whole CSV export.
    """
    try:
        param_values, _ = curve_fit(fit_fcn, x_data, y_data)
    except TypeError:
        # curve_fit raises TypeError when there are more parameters than points.
        return ['Too few data points for fit']
    except RuntimeError as err:
        # curve_fit raises RuntimeError when the optimisation does not converge.
        logging.getLogger(__name__).warning('Curve fit failed: %s', err)
        return ['Fit failed to converge']

    r_sqrd = compute_r_squared(x_data, y_data, param_values, fit_fcn)
    return [r_sqrd] + list(param_values)


def compute_r_squared(x_data: npt.NDArray,
                      y_data: npt.NDArray,
                      param_values: npt.NDArray,
                      fcn: Callable) -> float:
    """Return the coefficient of determination ``1 - SS_res / SS_tot``.

    ``SS_res`` is the residual sum of squares of ``fcn(x_data, *param_values)``.
    ``SS_tot`` is the total sum of squares of *y_data* around its mean.
    """
    predicted = fcn(x_data, *param_values)
    sum_sq_residual = np.sum((y_data - predicted)**2)
    sum_sq_total = np.sum((y_data - np.mean(y_data))**2)
    return 1 - (sum_sq_residual / sum_sq_total)


def get_line_fit_module(module_name: Optional[str]) -> Any:
    """Return the fit module ``complexity.plot_config.<module_name>``.

    Returns None when *module_name* is None. The module is imported on demand
    when it is not already loaded. ``import_module`` returns the cached entry
    from ``sys.modules`` otherwise.
    """
    import importlib

    if module_name is None:
        return None

    return importlib.import_module(f'complexity.plot_config.{module_name}')


#=======================================================
# PLOT FUNCTIONS
#=======================================================


def multi_line_plot(ax: Axes,
                    x_data: npt.NDArray,
                    y_data: list[npt.NDArray],
                    _,
                    plot_config: PlotConfig,
                    mpl_config: dict,
                    cb_ax: Opt_Axes,
                    legend_ax: Opt_Axes,
                    lines_info: list) -> None:
    """Draw one line per entry of *x_data* / *y_data* and add a legend.

    Each line's style is ``artist_config[LINE]`` updated with the optional
    ``LINE_CONFIG`` of its ``lines_info`` entry. The legend is drawn in
    *legend_ax* (whose axis is then hidden) when given, otherwise in *ax*.
    """
    line_handles = []
    line_labels = []
    for ii, x_line in enumerate(x_data):
        line_info = lines_info[ii]
        line_config = deepcopy(plot_config.artist_config[LINE])
        # Some producers (e.g. single-equation lines) provide no LINE_CONFIG.
        line_config.update(line_info.get(LINE_CONFIG, {}))
        line_handles += ax.plot(x_line, y_data[ii], **line_config)
        line_labels.append(f'{line_info[LEGEND_LABEL]}')

    if legend_ax is not None:
        legend_ax.legend(line_handles, line_labels, **mpl_config[LEGEND])
        legend_ax.set_axis_off()
    else:
        ax.legend(line_handles, line_labels, **mpl_config[LEGEND])

def single_line_plot(ax: Axes,
                     x_data: npt.NDArray,
                     y_data: npt.NDArray,
                     _,
                     plot_config: PlotConfig,
                     mpl_config: dict,
                     cb_ax: Opt_Axes,
                     legend_ax: Opt_Axes,
                     legend_info: list) -> None:
    """Draw a single line.

    For empirical configs, the line also gets a legend with the fit result
    when ``fit_fcn`` is set, and default x ticks at the data points.
    """
    line_handle = ax.plot(x_data, y_data, **plot_config.artist_config[LINE])

    if isinstance(plot_config, EmpiricalConfig):
        if plot_config.fit_fcn is not None:
            # get_line_fit takes (x, y, fit_fcn) and returns a list
            # ([r_squared, *params] or a message), so join it into one label.
            fit_result = get_line_fit(x_data, y_data, plot_config.fit_fcn)
            fit_label = ', '.join(str(val) for val in fit_result)
            ax.legend(line_handle, [fit_label], **mpl_config[LEGEND])
        set_x_ticks(plot_config, x_data)


def three_d_plot(ax: Axes,
                 x_data: npt.NDArray,
                 y_data: npt.NDArray,
                 z_data: npt.NDArray,
                 plot_config: PlotConfig,
                 mpl_config: dict,
                 cb_ax: Opt_Axes,
                 legend_ax: Opt_Axes,
                 legend_info: dict) -> None:
    """Draw *z_data* as a ``pcolormesh`` over the 2-D grids *x_data* / *y_data*.

    Optionally adds a colour bar in *cb_ax*, and per-cell text labels from
    ``legend_info[TEXT]``. It also sets default ticks from the first row of
    *x_data* and the first column of *y_data*.
    """
    im = ax.pcolormesh(x_data,
                       y_data,
                       z_data,
                       **plot_config.artist_config[P_COLOR_MESH])

    if plot_config.color_bar_loc is not None:
        cb_config = deepcopy(plot_config.artist_config[COLOR_BAR])
        cb_config['orientation'] = plot_config.color_bar_loc.orientation
        # Pass the copy that carries the orientation (the unmodified artist
        # config used to be passed, silently dropping it).
        plt.colorbar(im, cax=cb_ax, **cb_config)

    if TEXT in legend_info:
        for x_ind in range(z_data.shape[0]):
            for y_ind in range(z_data.shape[1]):
                ax.text(x_data[x_ind, y_ind],
                        y_data[x_ind, y_ind],
                        legend_info[TEXT][x_ind][y_ind],
                        ha='center',
                        va='center',
                        **plot_config.artist_config[TEXT])

    set_x_ticks(plot_config, x_data[0, :])
    set_y_ticks(plot_config, y_data[:, 0])


def print_line_fit(x_data, y_data, plot_config, var_values):
    """Print the fit summary of a single line to stdout.

    Nothing happens when ``plot_config.fit_fcn`` is ``None``. Otherwise three
    lines are printed: the ``var_values`` entries as ``key = val`` pairs
    joined by `` - ``, the string from ``get_line_fit`` and the value
    returned by ``get_line_fit_np``.
    """
    if plot_config.fit_fcn is not None:
        fit_description = get_line_fit(x_data, y_data, plot_config, var_values)
        header = ' - '.join(f'{name} = {value}' for name, value in var_values.items())
        print(header)
        print(fit_description)
        print(get_line_fit_np(x_data, y_data))


def set_x_ticks(plot_config: PlotConfig, x_values: npt.NDArray) -> None:
    """Use ``x_values`` as the x-tick positions unless ticks are already configured.

    An existing ``X_TICKS`` entry in ``plot_config.ax_config`` takes
    precedence and is left untouched; otherwise ``list(x_values)`` is stored.
    """
    ax_config = plot_config.ax_config
    if X_TICKS not in ax_config:
        ax_config[X_TICKS] = list(x_values)


def set_y_ticks(plot_config: PlotConfig, y_values: npt.NDArray) -> None:
    """Use ``y_values`` as the y-tick positions unless ticks are already configured.

    An existing ``Y_TICKS`` entry in ``plot_config.ax_config`` takes
    precedence and is left untouched; otherwise ``list(y_values)`` is stored.
    """
    ax_config = plot_config.ax_config
    if Y_TICKS not in ax_config:
        ax_config[Y_TICKS] = list(y_values)


#=======================================================
# FIGURE AND AX CREATION FUNCTIONS
#=======================================================


def create_fig(fig_config: dict) -> Figure:
    """Create a pyplot figure from ``fig_config``.

    All entries except ``TITLE`` are forwarded as keyword arguments to
    ``plt.figure``. A non-``None`` ``TITLE`` becomes the window title.

    ``fig_config`` itself is not modified, so the same config dict can be
    reused for several figures without losing its title.
    """
    # Shallow copy: only the top-level TITLE key is removed.
    figure_kwargs = dict(fig_config)
    fig_title = figure_kwargs.pop(TITLE, None)
    fig = plt.figure(**figure_kwargs)

    # Matplotlib types ``canvas.manager`` as optional; guard defensively.
    manager = fig.canvas.manager
    if fig_title is not None and manager is not None:
        manager.set_window_title(fig_title)

    return fig


def create_axes(
    plot_configs: list[PlotConfig], divider_config: DividerConfig, fig: Figure
) -> tuple[list[Axes],
           tuple[Opt_Axes, Opt_Axes_List, list[Opt_Axes], list, list, dict[Opt_Axes,Opt_Axes]]]:
    """Create the main, colour-bar, legend and label axes of ``fig`` on one grid.

    The grid is an ``axes_grid1`` ``Divider`` covering the whole figure. Plot
    cells are ``Size.Scaled(1)``; every other cell has a fixed size taken from
    ``divider_config``.

    Horizontal cells, left to right:
      * optional ``left_label`` cell, then the ``left`` margin;
      * per grid column: [label, bar, gap] if any plot in that column has a
        LEFT colour bar; the plot cell; [gap, bar, label] for a RIGHT colour
        bar; [legend gap, legend width] if any plot in the column has a RIGHT
        legend; a ``gap_hor`` cell;
      * the final ``gap_hor`` is replaced by the ``right`` margin.

    Vertical cells start with the ``bottom`` margin and follow the same
    pattern per grid row with BOTTOM/TOP colour bars and ``gap_ver``. The
    final gap becomes the ``top`` margin, and an optional ``v_legend`` cell is
    appended after it.

    Returns:
        ``(axes, (legend_ax, left_label_axes, color_bar_axes, hor_axes_locs,
        ver_axes_locs, legend_ax_per_ax))``:
          * ``axes``: one main axes per plot config, in the same order;
          * ``legend_ax``: axes spanning the full width of the ``v_legend``
            cell, or ``None``;
          * ``left_label_axes``: one axis-less axes per occupied grid row,
            ordered by increasing vertical cell index, or ``None``;
          * ``color_bar_axes``: aligned with ``plot_configs``; ``None`` where a
            plot has no colour bar;
          * ``hor_axes_locs`` / ``ver_axes_locs``: as from ``get_ax_locs``;
          * ``legend_ax_per_ax``: maps a *main* axes to its RIGHT-legend axes.
    """
    hor_axes_locs, ver_axes_locs = get_ax_locs(plot_configs, divider_config)

    hor_division: list[Size._Base] = []
    if divider_config.has_left_label:
        hor_division.append(Size.Fixed(divider_config.left_label))
    hor_division.append(Size.Fixed(divider_config.left))

    ver_division: list[Size._Base] = [Size.Fixed(divider_config.bottom)]

    # Colour-bar / legend positions each grid column or row must reserve room
    # for. A column/row is sized for every request made by any of its plots.
    col_with_cb = defaultdict(list)
    row_with_cb = defaultdict(list)
    col_with_legend = defaultdict(list)
    for ax_nr, plot_config in enumerate(plot_configs):
        hor_axes_nr, ver_axes_nr = get_ax_loc(plot_config, ax_nr, hor_axes_locs, ver_axes_locs)
        if plot_config.has_defined_legend_loc:
            col_with_legend[hor_axes_nr].append(plot_config.legend_loc)

        cb_loc = plot_config.color_bar_loc
        if cb_loc is None:
            continue
        if cb_loc.is_hor:
            col_with_cb[hor_axes_nr].append(cb_loc)
        elif cb_loc.is_ver:
            row_with_cb[ver_axes_nr].append(cb_loc)

    # Horizontal cells. Remember the global cell index of every plot column,
    # colour bar and right-hand legend while building them.
    hor_ind = len(hor_division)
    column_count = max(hor_axes_locs) + 1
    col_nr_2_glob_nr = {}
    col_nr_2_cb = {ColorBarLoc.LEFT: {}, ColorBarLoc.RIGHT: {}}
    col_nr_2_legend = {}
    for col_nr in range(column_count):
        col_cbs = col_with_cb.get(col_nr, [])
        if ColorBarLoc.LEFT in col_cbs:
            hor_division += [
                Size.Fixed(divider_config.color_label_width),
                Size.Fixed(divider_config.color_bar_width),
                Size.Fixed(divider_config.color_bar_gap)
            ]
            hor_ind += 3
            col_nr_2_cb[ColorBarLoc.LEFT][col_nr] = hor_ind - 2  # the bar cell

        col_nr_2_glob_nr[col_nr] = hor_ind
        hor_division.append(Size.Scaled(1))
        hor_ind += 1

        if ColorBarLoc.RIGHT in col_cbs:
            hor_division += [
                Size.Fixed(divider_config.color_bar_gap),
                Size.Fixed(divider_config.color_bar_width),
                Size.Fixed(divider_config.color_label_width)
            ]
            hor_ind += 3
            col_nr_2_cb[ColorBarLoc.RIGHT][col_nr] = hor_ind - 2  # the bar cell

        if RIGHT in col_with_legend.get(col_nr, []):
            hor_division += [
                Size.Fixed(divider_config.legend_gap),
                Size.Fixed(divider_config.legend_width),
            ]
            hor_ind += 2
            # Record the legend-width cell directly. Deriving it later from the
            # plot or colour-bar index is wrong for LEFT colour bars and for
            # columns where only another row has a RIGHT colour bar.
            col_nr_2_legend[col_nr] = hor_ind - 1

        hor_division.append(Size.Fixed(divider_config.gap_hor))
        hor_ind += 1

    # The trailing inter-column gap becomes the right margin.
    hor_division[-1] = Size.Fixed(divider_config.right)

    # Vertical cells, built the same way (first entry is the bottom margin).
    ver_ind = len(ver_division)
    row_count = max(ver_axes_locs) + 1
    row_nr_2_glob_nr = {}
    row_nr_2_cb = {ColorBarLoc.BOTTOM: {}, ColorBarLoc.TOP: {}}
    for row_nr in range(row_count):
        row_cbs = row_with_cb.get(row_nr, [])
        if ColorBarLoc.BOTTOM in row_cbs:
            ver_division += [
                Size.Fixed(divider_config.color_label_width),
                Size.Fixed(divider_config.color_bar_width),
                Size.Fixed(divider_config.color_bar_gap)
            ]
            ver_ind += 3
            row_nr_2_cb[ColorBarLoc.BOTTOM][row_nr] = ver_ind - 2  # the bar cell

        row_nr_2_glob_nr[row_nr] = ver_ind
        ver_division.append(Size.Scaled(1))
        ver_ind += 1

        if ColorBarLoc.TOP in row_cbs:
            ver_division += [
                Size.Fixed(divider_config.color_bar_gap),
                Size.Fixed(divider_config.color_bar_width),
                Size.Fixed(divider_config.color_label_width)
            ]
            ver_ind += 3
            row_nr_2_cb[ColorBarLoc.TOP][row_nr] = ver_ind - 2  # the bar cell

        ver_division.append(Size.Fixed(divider_config.gap_ver))
        ver_ind += 1

    # The trailing inter-row gap becomes the top margin; the optional
    # full-width legend cell is appended after it.
    ver_division[-1] = Size.Fixed(divider_config.top)
    if divider_config.has_v_legend:
        ver_division.append(Size.Fixed(divider_config.v_legend))

    divider = Divider(fig, (0, 0, 1, 1),
                      hor_division,
                      ver_division,
                      aspect=False)

    ax_y_coords = set()
    axes: list[Axes] = []
    color_bar_axes: list[Opt_Axes] = []
    legend_ax_per_ax = {}
    for ax_nr, plot_config in enumerate(plot_configs):
        hor_axes_nr, ver_axes_nr = get_ax_loc(plot_config, ax_nr, hor_axes_locs, ver_axes_locs)

        nx = col_nr_2_glob_nr[hor_axes_nr]
        ny = row_nr_2_glob_nr[ver_axes_nr]

        # ``label`` must be a keyword: extra positional arguments to
        # Figure.add_axes are ignored (and deprecated since Matplotlib 3.8).
        main_ax = fig.add_axes((0, 0, 1, 1), label=f'main_ax_{ax_nr}')
        main_ax.set_axes_locator(divider.new_locator(nx=nx, ny=ny))
        axes.append(main_ax)
        ax_y_coords.add(ny)

        cb_loc = plot_config.color_bar_loc
        cb_ax = None
        if cb_loc is not None:
            cb_ax = fig.add_axes((0, 0, 1, 1), label=f'cb_ax_{ax_nr}')
            if cb_loc.is_hor:
                cb_ax.set_axes_locator(
                    divider.new_locator(nx=col_nr_2_cb[cb_loc][hor_axes_nr], ny=ny))
            elif cb_loc.is_ver:
                cb_ax.set_axes_locator(
                    divider.new_locator(nx=nx, ny=row_nr_2_cb[cb_loc][ver_axes_nr]))
        color_bar_axes.append(cb_ax)

        if plot_config.has_defined_legend_loc and plot_config.legend_loc == RIGHT:
            legend_cell_ax = fig.add_axes((0, 0, 1, 1), label=f'legend_ax_{ax_nr}')
            legend_cell_ax.set_axes_locator(
                divider.new_locator(nx=col_nr_2_legend[hor_axes_nr], ny=ny))
            # Keyed by the *main* axes (not the colour-bar axes) so callers can
            # look the legend up from the plot they are drawing on.
            legend_ax_per_ax[main_ax] = legend_cell_ax

    left_label_axes: Opt_Axes_List = None
    if divider_config.has_left_label:
        left_label_axes = []
        # Sorted so the list order is deterministic rather than set order.
        for y_coord in sorted(ax_y_coords):
            left_label_ax = fig.add_axes((0, 0, 1, 1), label=f'left_label_ax_{y_coord}')
            left_label_ax.set_axes_locator(divider.new_locator(nx=0, ny=y_coord))
            left_label_ax.set_axis_off()
            left_label_axes.append(left_label_ax)

    if divider_config.share_y:
        if divider_config.share_y_list:
            # Explicit groups: the first index of each group is the reference.
            for share_set in divider_config.share_y_list:
                base_ax = axes[share_set[0]]
                for ax_ind in share_set[1:]:
                    axes[ax_ind].sharey(base_ax)
        else:
            # No groups given: every axes shares y with the first one.
            for other_ax in axes[1:]:
                other_ax.sharey(axes[0])

    legend_ax = None
    if divider_config.has_v_legend:
        legend_ax = fig.add_axes((0, 0, 1, 1), label='legend_ax')
        # Span every horizontal cell of the topmost (v_legend) cell.
        legend_ax.set_axes_locator(divider.new_locator(nx=0, nx1=len(hor_division),
                                                       ny=len(ver_division) - 1))  # type: ignore

    return axes, (legend_ax,
                  left_label_axes,
                  color_bar_axes,
                  hor_axes_locs,
                  ver_axes_locs,
                  legend_ax_per_ax)


def get_ax_loc(plot_config: PlotConfig,
               ax_nr: int,
               hor_axes_locs: list[int],
               ver_axes_locs: list[int]) -> tuple[int, int]:
    """Return the ``(column, row)`` grid cell of the ``ax_nr``-th plot.

    An explicit ``plot_config.ax_loc`` wins when ``has_defined_ax_loc`` is
    set. Otherwise the position is looked up in the parallel
    ``hor_axes_locs`` / ``ver_axes_locs`` lists at index ``ax_nr``.

    NOTE(review): ``get_ax_locs`` tests ``has_defined_loc`` while this tests
    ``has_defined_ax_loc``; confirm both PlotConfig properties are intended.
    """
    if not plot_config.has_defined_ax_loc:
        return hor_axes_locs[ax_nr], ver_axes_locs[ax_nr]
    return plot_config.ax_loc[0], plot_config.ax_loc[1]  # type: ignore


def get_ax_locs(plot_configs: list[PlotConfig],
                divider_config: DividerConfig) -> tuple[list, list]:
    """Return the grid columns and rows of all plots as two parallel lists.

    With ``divider_config.uses_custom_locs`` the lists stored on the divider
    config are returned as-is (not copied). Otherwise each plot uses its
    ``ax_loc`` when ``has_defined_loc`` is set. Plots without one are placed
    in row 0 at column = their position in ``plot_configs``.
    """
    if divider_config.uses_custom_locs:
        return divider_config.hor_axes_locs, divider_config.ver_axes_locs

    grid_cells = [
        (config.ax_loc[0], config.ax_loc[1]) if config.has_defined_loc else (position, 0)  # type: ignore
        for position, config in enumerate(plot_configs)
    ]
    columns = [column for column, _ in grid_cells]
    rows = [row for _, row in grid_cells]
    return columns, rows


def _apply_with_optional_config(setter: Callable, ax_config: dict,
                                value_key, config_key) -> None:
    """Call ``setter(ax_config[value_key])`` with ``ax_config[config_key]`` as kwargs, if present."""
    setter(ax_config[value_key], **ax_config.get(config_key, {}))


def _match_y_offset_text_size(ax: plt.Axes, tick_label_config: Optional[dict]) -> None:
    """Give the y-axis offset text the configured ``labelsize``, if there is one."""
    if tick_label_config is not None and 'labelsize' in tick_label_config:
        ax.yaxis.get_offset_text().set_fontsize(tick_label_config['labelsize'])


def set_ax_config(ax: plt.Axes, ax_config: dict) -> None:
    """Apply the axis settings found in ``ax_config`` to ``ax``.

    Every key is optional; absent keys leave the corresponding property
    untouched. ``*_CONFIG`` entries are forwarded as keyword arguments to the
    matching Matplotlib call. A ``*_TICK_LABELS_CONFIG`` without the
    corresponding labels is applied via ``tick_params``.

    Note:
        ``FONT_SIZE`` updates the global ``plt.rcParams['font.size']``. It
        therefore affects all text created afterwards, not only this axes.
    """
    if FONT_SIZE in ax_config:
        plt.rcParams.update({'font.size': ax_config[FONT_SIZE]})

    # x ticks
    if X_TICKS in ax_config:
        ax.set_xticks(ax_config[X_TICKS])
    if X_TICK_LABELS in ax_config:
        _apply_with_optional_config(ax.set_xticklabels, ax_config,
                                    X_TICK_LABELS, X_TICK_LABELS_CONFIG)
    elif X_TICK_LABELS_CONFIG in ax_config:
        ax.tick_params(axis='x', **ax_config[X_TICK_LABELS_CONFIG])

    # y ticks. The labels config is optional here: it used to be read
    # unconditionally, which raised KeyError for labels without a config.
    if Y_TICKS in ax_config:
        ax.set_yticks(ax_config[Y_TICKS])
    y_label_config = ax_config.get(Y_TICK_LABELS_CONFIG)
    if Y_TICK_LABELS in ax_config:
        y_labels = ax_config[Y_TICK_LABELS]
        if y_label_config is not None and len(y_labels) > 0:
            ax.set_yticklabels(y_labels, **y_label_config)
        else:
            # An empty label list is set without forwarding the config.
            ax.set_yticklabels(y_labels)
        _match_y_offset_text_size(ax, y_label_config)
    elif y_label_config is not None:
        ax.tick_params(axis='y', **y_label_config)
        _match_y_offset_text_size(ax, y_label_config)

    if Y_TICKS_VISIBLE in ax_config:
        ax.yaxis.set_tick_params(labelleft=ax_config[Y_TICKS_VISIBLE])

    if TICK_LABEL_FORMAT in ax_config:
        ax.ticklabel_format(**ax_config[TICK_LABEL_FORMAT])

    if X_LIM in ax_config:
        ax.set_xlim(ax_config[X_LIM])
    if Y_LIM in ax_config:
        ax.set_ylim(ax_config[Y_LIM])

    if Y_LABEL in ax_config:
        _apply_with_optional_config(ax.set_ylabel, ax_config, Y_LABEL, Y_LABEL_CONFIG)
    if X_LABEL in ax_config:
        _apply_with_optional_config(ax.set_xlabel, ax_config, X_LABEL, X_LABEL_CONFIG)
    if TITLE in ax_config:
        _apply_with_optional_config(ax.set_title, ax_config, TITLE, TITLE_CONFIG)
    if GRID in ax_config:
        ax.grid(**ax_config[GRID])


#=======================================================
# FIT OUTPUT
#=======================================================
# File extension that `valid_output_file` requires for fit-output files.
FIT_OUTPUT_EXT = '.csv'

#=======================================================
# MAIN
#=======================================================


def execute(raw_args: Optional[Sequence[Any]] = None) -> None:
    """Command-line entry point.

    Parses ``raw_args`` (``sys.argv[1:]`` when ``None``) with
    ``parse_sys_input``, logs the two CSV flags and passes everything on to
    ``CcPlotter.visualize``.
    """
    (plot_config_filenames, output_dir, run_id, eq_file_dir,
     create_fit_csv, create_line_csv) = parse_sys_input(raw_args)
    logger.info('create_fit_csv = %s', create_fit_csv)
    logger.info('create_line_csv = %s', create_line_csv)
    CcPlotter.visualize(plot_config_filenames,
                        output_dir,
                        run_id,
                        eq_file_dir,
                        create_fit_csv,
                        create_line_csv)


def parse_sys_input(
    raw_args: Optional[Sequence[Any]]
) -> tuple[list[Path], Optional[Path], Optional[str], Optional[Path], bool, bool]:
    """Parse the plotting command line.

    The first argument selects a data-type sub-command: ``EMPIRICAL_DATA``
    (alias ``e``), ``THEORETICAL_DATA`` (alias ``t``) or ``both`` (alias ``b``).
      * every sub-command except the empirical one takes ``eq_file_dir``;
      * every sub-command except the theoretical one takes ``output_dir``,
        ``run_id`` and the ``--f`` flag;
      * every sub-command accepts the ``--o`` flag.
    At least one plot configuration file follows. Any remaining, otherwise
    unparsed arguments are read as additional plot configuration files.

    Args:
        raw_args: Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``(plot_config_files, output_dir, run_id, eq_file_dir, create_fit_csv,
        create_line_csv)``. Values that do not apply to the chosen
        sub-command are ``None`` / ``False``.
    """
    parser = argparse.ArgumentParser(
        description=
        'Compute the computational complexity using a specific configuration')
    sub_parsers = parser.add_subparsers(title='data types',
                                        dest='action',
                                        description='Valid data types',
                                        required=True)
    empirical_parser = sub_parsers.add_parser(
        EMPIRICAL_DATA, help='For plotting empirical data files', aliases=['e'])
    theoretical_parser = sub_parsers.add_parser(
        THEORETICAL_DATA,
        help='For plotting theoretical equations',
        aliases=['t'])
    both_parser = sub_parsers.add_parser('both',
                                         help='For plotting both',
                                         aliases=['b'])
    for sub_parser in [empirical_parser, theoretical_parser, both_parser]:
        # Everything except the empirical sub-command needs the equation files.
        if sub_parser is not empirical_parser:
            sub_parser.add_argument(
                'eq_file_dir',
                type=existing_dir,
                help=
                "The path to the of the output_dir containing the model equation files."
            )

        sub_parser.add_argument('--o',
                                dest='create_line_csv',
                                action='store_true',
                                help='Create a csv-file in the output directory with the data per line',
                                required=False)

        # Everything except the theoretical sub-command reads run output.
        if sub_parser is theoretical_parser:
            continue
        sub_parser.add_argument(
            'output_dir',
            type=existing_dir,
            help="The path to the output_dir containing the output.")
        sub_parser.add_argument(
            'run_id',
            type=str,
            help="The run_id of the results you want to visualize.")
        sub_parser.add_argument('--f',
                                dest='create_fit_csv',
                                action='store_true',
                                help='Create a csv-file in the output directory with the fitted functions per line',
                                required=False)

    parser.add_argument('plot_config_file',
                        type=valid_input_file,
                        help='The path to the plot configuration file.')

    args, rem_args = parser.parse_known_args(raw_args)

    output_dir = None
    run_id = None
    eq_file_dir = None
    create_fit_csv = False
    create_line_csv = args.create_line_csv

    if args.action not in (EMPIRICAL_DATA, 'e'):
        eq_file_dir = args.eq_file_dir

    if args.action not in (THEORETICAL_DATA, 't'):
        output_dir = args.output_dir
        run_id = args.run_id
        create_fit_csv = args.create_fit_csv

    # Whatever the main parser did not consume is parsed as further plot
    # configuration files.
    rem_parser = argparse.ArgumentParser('Remaining args')
    rem_parser.add_argument('plot_config_files',
                            nargs='*',
                            type=valid_input_file,
                            help='The path to the plot configuration file.')
    extra_config_files = rem_parser.parse_args(rem_args).plot_config_files

    plot_config_files = [args.plot_config_file]
    # Append the extra files exactly once (previously they were added twice,
    # so every extra plot was produced twice).
    if extra_config_files is not None:
        plot_config_files += extra_config_files

    return plot_config_files, output_dir, run_id, eq_file_dir, create_fit_csv, create_line_csv


def existing_dir(string: str) -> Path:
    """argparse ``type=`` converter: resolve ``string`` to an existing directory.

    Returns:
        The resolved directory path.

    Raises:
        argparse.ArgumentTypeError: if the resolved path is not a directory.
            argparse reports this as a normal usage error instead of a
            traceback. It subclasses ``Exception``, so existing
            ``except Exception`` handlers keep working.
    """
    output_dir = Path(string).resolve()
    if not output_dir.is_dir():
        raise argparse.ArgumentTypeError(
            f'The output directory "{output_dir}" does not exist. Input string is "{string}"')
    return output_dir


def valid_input_file(string: str) -> Path:
    """argparse ``type=`` converter: resolve ``string`` to an existing file.

    Returns:
        The resolved file path.

    Raises:
        FileNotFoundError: if the resolved path is not a regular file.
    """
    candidate = Path(string).resolve()
    if candidate.is_file():
        return candidate
    raise FileNotFoundError(
        f'No such file: "{string}" ({candidate.resolve()})')


def valid_output_file(string: str) -> Path:
    """argparse ``type=`` converter for a fit-output file path.

    Resolves ``string`` and checks that its parent directory exists and that
    its suffix equals ``FIT_OUTPUT_EXT``. The file itself need not exist.

    Returns:
        The resolved output path.

    Raises:
        argparse.ArgumentTypeError: if either check fails. argparse reports
            this as a usage error. It subclasses ``Exception``, so existing
            ``except Exception`` handlers keep working.
    """
    output_filename = Path(string).resolve()
    if not output_filename.parent.is_dir():
        raise argparse.ArgumentTypeError(f'The directory "{output_filename.parent}" does not exist!')
    if output_filename.suffix != FIT_OUTPUT_EXT:
        raise argparse.ArgumentTypeError(f'The extension "{output_filename.suffix}" is not valid! It should be "{FIT_OUTPUT_EXT}"')

    return output_filename

if __name__ == '__main__':
    # Script entry point: `execute()` parses sys.argv and calls CcPlotter.visualize.
    execute()