from __future__ import annotations
import argparse
from collections import defaultdict
from configparser import ConfigParser
from copy import deepcopy
import csv
from dataclasses import astuple, dataclass, field
from datetime import datetime
import itertools
import json
import logging
from math import floor
from multiprocessing import Pool, Value, cpu_count
from pathlib import Path
import random
import sys
import time
from typing import Any, Optional, TextIO

import numpy as np
import numpy.typing as npt

from complexity.constants import *
from complexity.supporting_functions import *
from complexity.base_classes import *

logger = logging.getLogger(__name__)

#=======================================================
#=======================================================

# ConvergenceConfig fields that are written to / read from the input file.
# ConvergenceConfig.to_config_str reads them via getattr (so each tag must be
# an attribute name) and ConvergenceConfig.from_config_str passes the parsed
# values positionally, so this order must match the dataclass field order.
CC_EXPORT_TAGS = (INIT_STEP_COUNT, ALLOWABLE_ERR, STEP_SIZE, MAX_STEP_COUNT)

# Type passed to cast_variable when parsing each exported field.
CC_EXPORT_TAGS_TO_TYPE = {
    INIT_STEP_COUNT: INT_TYPE,
    ALLOWABLE_ERR: FLOAT_TYPE,
    STEP_SIZE: INT_TYPE,
    MAX_STEP_COUNT: INT_TYPE
}


@dataclass
class ConvergenceConfig():
    """Convergence-loop settings plus the mutable progress of a run.

    The first four fields are the configuration (the only ones exported via
    ``CC_EXPORT_TAGS``); ``current_step`` and ``converged_experiments`` track
    how far a run has progressed.
    """
    init_step_count: int = 30
    allowable_err: float = 0.005
    step_size: int = 10
    max_step_count: int = 3000
    current_step: int = 0
    converged_experiments: set[str] = field(default_factory=set)

    def get_copy(self) -> ConvergenceConfig:
        """Return an independent copy; the converged set is duplicated too."""
        snapshot = astuple(self)  # astuple copies nested containers recursively
        return ConvergenceConfig(*snapshot)

    def increment_step(self) -> None:
        """Advance ``current_step`` by one ``step_size``."""
        self.current_step = self.current_step + self.step_size

    def to_tuple(self) -> tuple[int, float, int, int]:
        """Return the four configuration values (no progress state)."""
        return (
            self.init_step_count,
            self.allowable_err,
            self.step_size,
            self.max_step_count,
        )

    @property
    def header_str(self) -> str:
        """Comment line naming the exported fields."""
        joined_tags = ", ".join(CC_EXPORT_TAGS)
        return f'{COMMENT} {joined_tags}'

    def to_config_str(self) -> str:
        """Comma-separated values of the exported fields."""
        return ", ".join(str(getattr(self, tag)) for tag in CC_EXPORT_TAGS)

    @staticmethod
    def from_config_str(config_str: str) -> ConvergenceConfig:
        """Parse a line produced by ``to_config_str``."""
        field_types = [CC_EXPORT_TAGS_TO_TYPE[tag] for tag in CC_EXPORT_TAGS]
        parsed = cast_variable(config_str, COLLECTION_TYPE, field_types)  # type: ignore
        return ConvergenceConfig(*tuple(parsed))


# Experiment fields in serialization order. Experiment objects are rebuilt
# positionally from values in this order, so it must match the dataclass
# field order of Experiment.
EXP_EXPORT_TAGS = (
    MODEL_TAG,
    INFRA_TAG,
    PED_COUNT,
    INFRA_ARG,
    OBSTACLE_COUNT,
    LOCAL_DENSITY,
    PED_CLASS_COUNT,
    CLASS_DISTR,
    MODEL_INSTANCE_IND,
)

# Result fields: every exported field except the trailing model-instance index.
EXP_RESULT_TAGS = EXP_EXPORT_TAGS[:len(EXP_EXPORT_TAGS) - 1]
# ExperimentSet fields: the result fields followed by the set metadata.
EXP_SET_EXPORT_TAGS = (*EXP_RESULT_TAGS, BASE_VAR, COMB_VARS, CONFIG_NAME)

# Position of each tag inside EXP_EXPORT_TAGS.
EXP_RESULT_IND = dict(zip(EXP_EXPORT_TAGS, range(len(EXP_EXPORT_TAGS))))

# Type used by to_string / cast_variable for each experiment field.
EXP_FIELDS_TO_TYPE = {
    MODEL_TAG: STRING_TYPE,
    INFRA_TAG: STRING_TYPE,
    PED_COUNT: INT_TYPE,
    INFRA_ARG: FLOAT_COLLECTION_TYPE,
    OBSTACLE_COUNT: INT_TYPE,
    LOCAL_DENSITY: DISTR_TYPE,
    PED_CLASS_COUNT: INT_TYPE,
    CLASS_DISTR: DISTR_TYPE,
    MODEL_INSTANCE_IND: INT_TYPE,
}


@dataclass(frozen=True)
class ExperimentBase():
    """Fields and helpers shared by single experiments and experiment sets.

    Subclasses supply ``export_tags``: the ordered field names used for
    tuples, tags, headers and config strings.
    """
    model_tag: str
    infra_tag: str
    ped_count: int
    infrastructure_arg: tuple
    obstacle_count: int
    local_density: float | DistrType
    ped_class_count: int
    class_distribution: DistrType

    def to_tuple(self) -> tuple:
        """Exported field values as a tuple."""
        return tuple(self.to_list())

    def to_list(self) -> list:
        """Exported field values as a list, in ``export_tags`` order."""
        return [getattr(self, tag) for tag in self.export_tags]

    def get_tuple_without_fields(self, fields: str | list[str]) -> tuple:
        """Exported field values, skipping the field(s) named in ``fields``."""
        excluded = [fields] if isinstance(fields, str) else fields
        return tuple(
            getattr(self, tag) for tag in self.export_tags if tag not in excluded
        )

    @property
    def header_str(self) -> str:
        """Comment line listing the exported field names."""
        joined_tags = ", ".join(self.export_tags)
        return f'{COMMENT} {joined_tags}'

    def __str__(self) -> str:
        return self.to_tag()

    def to_tag(self) -> str:
        """Underscore-joined string of all exported values."""
        return "_".join(str(getattr(self, tag)) for tag in self.export_tags)

    def to_config_str(self) -> str:
        """Comma-separated exported values, each formatted by its field type."""
        return ", ".join(
            to_string(getattr(self, tag), EXP_FIELDS_TO_TYPE[tag])
            for tag in self.export_tags
        )

    def is_part_of_set(self, base_vars: dict[str, Any]) -> bool:
        """True when every ``field -> value`` pair in ``base_vars`` matches."""
        return all(
            self.has_field_value(name, expected)
            for name, expected in base_vars.items()
        )

    def get(self, field: str) -> Any:
        """Value of the attribute named ``field``."""
        return getattr(self, field)

    def has_field_value(self, field: str, value: Any) -> bool:
        """Compare one field with ``value``.

        Floats match within 1e-12, the infrastructure tag matches by prefix,
        anything else by equality.
        """
        current = getattr(self, field)
        if isinstance(current, float):
            return abs(current - value) < 1e-12  # type: ignore
        if field == INFRA_TAG:
            return current.startswith(value)  # type: ignore
        return current == value  # type: ignore

    @property
    def export_tags(self) -> tuple:
        """Ordered exported field names; provided by subclasses."""
        raise NotImplementedError


@dataclass(frozen=True)
class Experiment(ExperimentBase):
    """A single experiment bound to one model instance."""
    model_instance_ind: int  # index into ExperimentCreator's model_instances list

    @property
    def export_tags(self) -> tuple:
        """Ordered field names used to serialize an experiment."""
        return EXP_EXPORT_TAGS

    @staticmethod
    def from_config_str(config_str: str) -> Experiment:
        """Rebuild an experiment from a line written by ``to_config_str``."""
        field_types = [EXP_FIELDS_TO_TYPE[tag] for tag in EXP_EXPORT_TAGS]
        values = cast_variable(config_str, COLLECTION_TYPE, field_types)  # type: ignore
        return Experiment(*tuple(values))


@dataclass(frozen=True)
class ExperimentSet(ExperimentBase):
    """Group of experiments that differ only in ``base_var`` (and ``comb_vars``).

    After ``__post_init__`` the field named by ``base_var`` and every field
    named in ``comb_vars`` hold a list of values, one entry per member
    experiment and aligned by index; every other field holds the single value
    shared by all members.
    """
    base_var: str
    comb_vars: list[str]
    config_name: str

    def __post_init__(self) -> None:
        """Normalise ``comb_vars`` and wrap the varying fields in lists.

        Skipped when the base-variable field already holds a list, i.e. when
        the set is rebuilt from a previously exported tuple.

        Raises:
            TypeError: if ``comb_vars`` is not None, a list or a string.
        """
        if isinstance(getattr(self, self.base_var), list):
            return

        comb_vars = self._normalize_comb_vars(self.comb_vars)

        # The dataclass is frozen, so the fields are set via object.__setattr__.
        object.__setattr__(self, 'comb_vars', comb_vars)
        object.__setattr__(self, self.base_var, [getattr(self, self.base_var)])
        for comb_var in self.comb_vars:
            object.__setattr__(self, comb_var, [getattr(self, comb_var)])

    @staticmethod
    def _normalize_comb_vars(comb_vars: Any) -> list[str]:
        """Convert the raw ``comb_vars`` argument to a list of field names."""
        if comb_vars is None:
            return []
        if isinstance(comb_vars, list):
            # An empty list or a leading None both mean "no combined variables".
            if len(comb_vars) == 0 or comb_vars[0] is None:
                return []
            return [INPUT_TO_EXP_TAGS[comb_var] for comb_var in comb_vars]
        if isinstance(comb_vars, str):
            return [comb_vars]
        raise TypeError(f'Unsupported comb_vars value: {comb_vars!r}')

    def add_values(self, experiment: Experiment) -> None:
        """Append the varying values of ``experiment`` to this set.

        Raises:
            ValueError: if the experiment's base-variable value is already in
                the set.
        """
        new_value = experiment.get(self.base_var)
        if new_value in getattr(self, self.base_var):
            raise ValueError(
                f'{self.base_var}={new_value!r} is already part of the experiment set {self}')
        getattr(self, self.base_var).append(new_value)
        for comb_var in self.comb_vars:
            getattr(self, comb_var).append(experiment.get(comb_var))

    def sort_base_var(self) -> None:
        """Sort the base-variable values and reorder the comb vars to match.

        Infrastructure arguments are ordered by the product of their first two
        entries. The lists are reordered by index rather than by round-tripping
        through ``np.array``: that keeps the original element types (tuples
        stay tuples, Python scalars stay Python scalars) and avoids failures or
        type coercion on mixed or ragged values.
        """
        values = self.base_var_values
        if self.base_var == INFRA_ARG:
            sort_ind = np.argsort([x[0] * x[1] for x in values])
        else:
            sort_ind = np.argsort(values)

        object.__setattr__(self, self.base_var, [values[ii] for ii in sort_ind])
        for comb_var in self.comb_vars:
            comb_values = self.get(comb_var)
            object.__setattr__(self, comb_var,
                               [comb_values[ii] for ii in sort_ind])

    def get_key_tuple(self) -> tuple:
        """Exported values identifying the set (varying fields and metadata excluded)."""
        fields = [self.base_var, MODEL_INSTANCE_IND, BASE_VAR, COMB_VARS, CONFIG_NAME] + self.comb_vars

        return super().get_tuple_without_fields(fields)

    def does_fit(self, value_dict: dict[str, Any]) -> bool:
        """True when every ``field -> value`` pair equals this set's field."""
        for field_name, value in value_dict.items():
            if self.get(field_name) != value:
                return False

        return True

    def to_ind_dict(self) -> dict[str, Any]:
        """Result fields that are constant across the set, by name."""
        return {
            tag: self.get(tag)
            for tag in EXP_RESULT_TAGS
            if tag not in self.comb_vars and tag != self.base_var
        }

    @property
    def base_var_values(self) -> Any:
        """List of base-variable values of the member experiments."""
        return self.get(self.base_var)

    @property
    def export_tags(self) -> tuple:
        """Ordered field names used to serialize an experiment set."""
        return EXP_SET_EXPORT_TAGS

    @staticmethod
    def get_from_experiment(experiment: Experiment,
                            base_var: str,
                            comb_vars: Optional[list[str]],
                            config_name: str) -> ExperimentSet:
        """Start a new set containing only ``experiment``."""
        args_tuple = tuple(
            list(experiment.get_tuple_without_fields(MODEL_INSTANCE_IND))
            + [base_var, comb_vars, config_name])
        return ExperimentSet(*args_tuple)

    def get_line_string(self, line_kwargs: dict) -> str:
        """Underscore-joined label of the result fields.

        Values in ``line_kwargs`` override the set's own; sequence-valued
        fields are shown as ``first-last``.
        """
        parts = []
        for tag in EXP_RESULT_TAGS:
            if tag in line_kwargs:
                parts.append(str(line_kwargs[tag]))
                continue
            value = self.get(tag)
            if isinstance(value, (np.ndarray, tuple, list)):
                parts.append(f'{value[0]}-{value[-1]}')
            else:
                parts.append(str(value))

        return '_'.join(parts)

#=======================================================
#=======================================================

# Variables of an experiment group (and of the [experiment_setup] section),
# mapped to the type cast_variable uses to parse them. parse_experiments
# iterates this dict, so its order is the order in which values are read.
EXPERIMENT_SETUP_VARS = {
    MODEL_TAGS: COLLECTION_TYPE,
    INFRA_TAGS: COLLECTION_TYPE,
    PED_COUNTS: INT_COLLECTION_TYPE,
    OBSTACLE_COUNTS: INT_COLLECTION_TYPE,
    LOCAL_DENSITIES: DISTR_COLLECTION_TYPE,
    CLASS_DISTRS: DISTR_COLLECTION_TYPE,
    PED_CLASS_COUNTS: INT_COLLECTION_TYPE,
    INFRA_ARGS: ARG_COLLECTION_TYPE,
    BASE_VAR: STRING_TYPE,
}

# Entries of the [convergence_setup] section (read with optional=True) and
# their types; the parsed dict is passed as keyword arguments to
# ConvergenceConfig, so the keys must be its field names.
CONVERGENCE_SETUP_VARS = {
    'init_step_count': INT_TYPE,
    'allowable_err': FLOAT_TYPE,
    'step_size': INT_TYPE,
    'max_step_count': INT_TYPE
}


def get_inputs(
    input_file: Path, base_name: str
) -> tuple[list[PedModel],
           list[Experiment],
           list[ExperimentSet],
           ConvergenceConfig,
           str]:
    """Read an input file and build everything a run needs.

    Returns:
        ``(model_instances, experiments, experiment_sets, convergence_config,
        run_id)`` where ``run_id`` is ``<base_name>_<YYYYmmddHHMMSS>``.
    """
    parsed_config = parse_input_file(input_file)
    convergence = ConvergenceConfig(**parsed_config['convergence_setup'])
    run_id = f"{base_name}_{datetime.now().strftime('%Y%m%d%H%M%S')}"

    created = ExperimentCreator.create_experiments(parsed_config)
    experiments, model_instances, experiment_sets = created

    return model_instances, experiments, experiment_sets, convergence, run_id


def parse_input_file(input_file: Path) -> dict[str, Any]:
    """Parse an experiment input file into a configuration dictionary.

    The result has the keys ``'models'``, ``'infrastructures'``,
    ``'convergence_setup'`` (all entries optional; an empty section is used
    when the file has none) and ``'experiments'`` (filled by
    ``parse_experiments``).
    """
    parser = ConfigParser()
    with open(input_file, 'r') as handle:
        parser.read_file(handle)

    config: dict[str, Any] = {
        'models': get_classes(parser['models'], 'models'),
        'infrastructures': get_classes(parser['infrastructures'],
                                       'infrastructures'),
    }

    if not parser.has_section('convergence_setup'):
        parser.add_section('convergence_setup')
    config['convergence_setup'] = read_basic_config_section(
        parser['convergence_setup'], CONVERGENCE_SETUP_VARS, optional=True)

    parse_experiments(parser, config)
    return config


def parse_experiments(config_parser: ConfigParser, config: dict[str,
                                                                Any]) -> None:
    """Fill ``config['experiments']`` with one dict per experiment group.

    With an ``[experiments]`` section, each listed group section is read.
    A value missing from a group section is taken from the optional
    ``[experiment_setup]`` section. If that section is absent too, model and
    infrastructure tags default to every class declared in ``[models]`` /
    ``[infrastructures]``.

    Without an ``[experiments]`` section, a single group named
    ``experiment_setup`` is built from ``[experiment_setup]``. That group
    always uses all declared models and infrastructures.

    A missing ``non_comb`` entry becomes ``[None]``, the "no non-combined
    variables" marker. ``ExperimentCreator.get_iter_args`` skips None entries,
    and ``ExperimentSet`` treats a leading None as having no comb vars.

    Raises:
        ValueError: if neither section exists, or a group lacks a value that
            cannot be defaulted and there is no ``[experiment_setup]``.
    """
    config['experiments'] = []
    experiment_setup: Optional[dict[str, Any]] = None
    if 'experiment_setup' in config_parser:
        experiment_setup = read_basic_config_section(
            config_parser['experiment_setup'], EXPERIMENT_SETUP_VARS)

    if 'experiments' not in config_parser:
        if experiment_setup is None:
            raise ValueError(
                'The input file needs an "experiments" or an "experiment_setup" section!')
        exp_config: dict[str, Any] = {CONFIG_NAME: 'experiment_setup'}
        for key in EXPERIMENT_SETUP_VARS:
            exp_config[key] = experiment_setup[key]

        exp_config[MODEL_TAGS] = list(config['models'].keys())
        exp_config[INFRA_TAGS] = list(config['infrastructures'].keys())
        exp_config[NON_COMB] = [None]
        exp_config[BASE_VAR] = INPUT_TO_EXP_TAGS[exp_config[BASE_VAR]]
        config['experiments'].append(exp_config)
        return

    exp_groups = clean_comma_separated_str(
        config_parser['experiments']['experiments'])
    for exp_group in exp_groups:
        group_config = config_parser[exp_group]
        exp_config = {CONFIG_NAME: exp_group}
        for key, val_type in EXPERIMENT_SETUP_VARS.items():
            if key in group_config:
                exp_config[key] = cast_variable(group_config[key], val_type)
            elif experiment_setup is not None:
                exp_config[key] = experiment_setup[key]
            elif key == MODEL_TAGS:
                exp_config[MODEL_TAGS] = list(config['models'].keys())
            elif key == INFRA_TAGS:
                exp_config[INFRA_TAGS] = list(config['infrastructures'].keys())
            else:
                raise ValueError(
                    f'"{key}" is missing from section "{exp_group}" and there '
                    'is no "experiment_setup" section to fall back on!')

        if NON_COMB in group_config:
            exp_config[NON_COMB] = cast_variable(group_config[NON_COMB],
                                                 COLLECTION_TYPE,
                                                 (STRING_TYPE,))
        else:
            exp_config[NON_COMB] = [None]
        exp_config[BASE_VAR] = INPUT_TO_EXP_TAGS[exp_config[BASE_VAR]]

        config['experiments'].append(exp_config)


# Config entries whose values are combined with itertools.product (one
# experiment per combination). The order defines the layout of the `comb`
# tuples passed to ExperimentCreator.create_experiment.
ITER_ARGS = (INFRA_TAGS,
             MODEL_TAGS,
             PED_COUNTS,
             OBSTACLE_COUNTS,
             LOCAL_DENSITIES,
             CLASS_DISTRS,
             PED_CLASS_COUNTS,
             INFRA_ARGS)

# Experiment field filled by each iterated config entry (infra tags, model
# tags and ped counts are handled separately in create_experiment).
ITER_ARGS_2_EXP_FIELD = {
    OBSTACLE_COUNTS: OBSTACLE_COUNT,
    LOCAL_DENSITIES: LOCAL_DENSITY,
    CLASS_DISTRS: CLASS_DISTR,
    PED_CLASS_COUNTS: PED_CLASS_COUNT,
    INFRA_ARGS: INFRA_ARG
}

# Placeholder for a non-combined variable: instead of being iterated, its
# value is taken from the same position as the experiment's ped count.
COMB = 'comb'


class ExperimentCreator():
    """Expand parsed experiment groups into experiments, model instances and sets.

    Every combination of a group's iterated values becomes one ``Experiment``.
    Infrastructure and model instances are cached so that experiments with
    identical settings share them. Experiments are grouped into
    ``ExperimentSet`` objects that differ only in the group's base variable
    (and its non-combined variables).
    """

    def __init__(self, experiment_config: dict[str, Any]) -> None:
        self.exp_configs = experiment_config['experiments']
        self.models = experiment_config['models']
        self.infrastructures = experiment_config['infrastructures']

        self.experiments: list[Experiment] = []
        self.experiments_sets: list[ExperimentSet] = []
        self.model_instances: list[PedModel] = []
        # (model class, infrastructure instance) -> index into model_instances
        self.model_inst_to_ind: dict[tuple, int] = {}
        # (infrastructure class, infrastructure args, obstacle count) -> instance
        self.infra_instances: dict[tuple, Infrastructure] = {}
        # Experiment sets of the group currently being expanded, by key tuple.
        self.exp_sets_local: dict[tuple, ExperimentSet] = {}

    def _create_experiments(
            self) -> tuple[list[Experiment], list[PedModel], list]:
        for exp_config in self.exp_configs:
            self.exp_sets_local = {}
            for comb in itertools.product(*self.get_iter_args(exp_config)):
                self.create_experiment(comb, exp_config)
            for exp_set in self.exp_sets_local.values():
                exp_set.sort_base_var()
                self.experiments_sets.append(exp_set)

        # Remove duplicates (groups may overlap) while keeping creation order;
        # list(set(...)) would give an order that changes between runs.
        unique_experiments = list(dict.fromkeys(self.experiments))
        return unique_experiments, self.model_instances, self.experiments_sets

    def create_experiment(self, comb: tuple, exp_config: dict[str, Any]) -> None:
        """Create and register the experiment for one value combination.

        Args:
            comb: One value per entry of ``ITER_ARGS``, in that order. A
                ``COMB`` placeholder marks a non-combined variable whose value
                is taken from the position of the experiment's ped count.
            exp_config: The parsed experiment group.
        """
        exp_kwargs: dict[str, Any] = {
            PED_COUNT: comb[ITER_ARGS.index(PED_COUNTS)]
        }
        for ii, config_key in enumerate(ITER_ARGS):
            if config_key in (INFRA_TAGS, MODEL_TAGS, PED_COUNTS):
                continue
            value = comb[ii]
            # isinstance guard: `array == str` would give an ambiguous array.
            if isinstance(value, str) and value == COMB:
                ind = exp_config[PED_COUNTS].index(exp_kwargs[PED_COUNT])
                value = exp_config[config_key][ind]
            if value is None:
                value = np.nan

            exp_kwargs[ITER_ARGS_2_EXP_FIELD[config_key]] = value

        infra_instance = self._get_infra_instance(
            comb[ITER_ARGS.index(INFRA_TAGS)], exp_kwargs)
        model_instance_ind = self._get_model_instance_ind(
            comb[ITER_ARGS.index(MODEL_TAGS)], infra_instance)

        exp_kwargs[INFRA_TAG] = infra_instance.tag
        exp_kwargs[MODEL_TAG] = self.model_instances[model_instance_ind].tag
        exp_kwargs[MODEL_INSTANCE_IND] = model_instance_ind

        experiment = Experiment(**exp_kwargs)
        self.experiments.append(experiment)
        # A group without non_comb entry has no combined variables.
        self.add_exp_to_set(experiment,
                            exp_config[BASE_VAR],
                            exp_config.get(NON_COMB),
                            exp_config[CONFIG_NAME])

    def _get_infra_instance(self, infra_tag: str,
                            exp_kwargs: dict[str, Any]) -> Infrastructure:
        """Return the cached infrastructure for these settings, creating it once."""
        infra_class = self.infrastructures[infra_tag]
        infra_id = (infra_class,
                    exp_kwargs[INFRA_ARG],
                    exp_kwargs[OBSTACLE_COUNT])
        if infra_id not in self.infra_instances:
            self.infra_instances[infra_id] = infra_class.get_instance(
                *exp_kwargs[INFRA_ARG], exp_kwargs[OBSTACLE_COUNT])
        return self.infra_instances[infra_id]

    def _get_model_instance_ind(self, model_tag: str,
                                infra_instance: Infrastructure) -> int:
        """Return the index of the model instance for this infrastructure, creating it once."""
        model_class = self.models[model_tag]
        model_inst_id = (model_class, infra_instance)
        model_instance_ind = self.model_inst_to_ind.get(model_inst_id)
        if model_instance_ind is None:
            self.model_instances.append(model_class(infra_instance))
            model_instance_ind = len(self.model_instances) - 1
            self.model_inst_to_ind[model_inst_id] = model_instance_ind
        return model_instance_ind

    def add_exp_to_set(self,
                       experiment: Experiment,
                       base_var: str,
                       non_comb: Optional[list[str]],
                       config_name: str) -> None:
        """Add ``experiment`` to the matching set of the current group (or start one)."""
        exp_set = ExperimentSet.get_from_experiment(experiment,
                                                    base_var,
                                                    non_comb,
                                                    config_name)
        key = exp_set.get_key_tuple()
        if key not in self.exp_sets_local:
            self.exp_sets_local[key] = exp_set
        else:
            self.exp_sets_local[key].add_values(experiment)

    def get_iter_args(self, exp_config: dict[str, Any]) -> tuple:
        """Value lists to combine, with non-combined variables replaced by ``[COMB]``."""
        iter_args = [exp_config[key] for key in ITER_ARGS]
        for key in exp_config.get(NON_COMB) or ():
            if key is not None:
                iter_args[ITER_ARGS.index(key)] = [COMB]

        return tuple(iter_args)

    @staticmethod
    def create_experiments(
        experiment_config: dict[str, Any]
    ) -> tuple[list[Experiment], list[PedModel], list[ExperimentSet]]:
        """Return ``(experiments, model_instances, experiment_sets)`` for a parsed config."""
        return ExperimentCreator(experiment_config)._create_experiments()


def create_input_file(experiment_defs: dict[str, dict[str, Any]],
                      models: dict[str, str],
                      infrastructures: dict[str, str],
                      input_file: Path) -> None:
    """Write an experiment input file in the format read by ``parse_input_file``.

    Args:
        experiment_defs: Group name -> {setup variable: value}. The
            ``NON_COMB`` entry may be a single string or a list/tuple.
        models: Entries of the ``[models]`` section.
        infrastructures: Entries of the ``[infrastructures]`` section.
        input_file: Destination path; overwritten.
    """
    def format_value(field_key: str, field_value: Any) -> str:
        if field_key != NON_COMB:
            return to_string(field_value,
                             EXPERIMENT_SETUP_VARS[field_key],
                             no_brackets=True)
        if isinstance(field_value, (tuple, list)):
            return to_string(field_value, COLLECTION_TYPE)
        return to_string(field_value, STRING_TYPE)

    with open(input_file, 'w') as file:
        for section_header, entries in (('[models]\n', models),
                                        ('\n[infrastructures]\n', infrastructures)):
            file.write(section_header)
            for name, module_name in entries.items():
                file.write(f'{name} = {module_name}\n')

        # Empty section: all convergence settings fall back to their defaults.
        file.write('\n[convergence_setup]\n')

        file.write('\n[experiments]\n')
        file.write(f'experiments = {", ".join(experiment_defs)}\n')

        for exp_key, exp_def in experiment_defs.items():
            file.write(f'\n[{exp_key}]\n')
            for field_key, field_value in exp_def.items():
                file.write(f'{field_key} = {format_value(field_key, field_value)}\n')


#=======================================================
#=======================================================

# Format versions reported by InputFile and by OutputFile / TempStateFile.
INPUT_FILE_VERSION = 2
OUTPUT_FILE_VERSION = 1
# Prefix of the version line of the text input file ("File version = N").
VERSION_HEADER_BASE = 'File version = '
# Comment lines that introduce the sections of the text input file.
SEEDS_HEADER = f'{COMMENT} Seeds'
CONVERGENCE_CONFIG_HEADER = f'{COMMENT} Convergence config'
EXPERIMENTS_HEADER = f'{COMMENT} Experiments'
# Parser states used by InputFile.read_data_from_file for those sections.
SEEDS_MODE = 'seeds'
CONVERGENCE_CONFIG_MODE = 'convergence_config'
EXPERIMENTS_MODE = 'experiments'


class File():
    """Base class for the files belonging to one run.

    Subclasses derive their filename from ``directory`` and ``run_id``,
    implement reading/writing and report their format version.
    """

    def __init__(self, directory: Path, run_id: str) -> None:
        self.run_id = run_id
        self.directory = directory

    def write_data_to_file(self, *args: Any, **kwargs: Any) -> None:
        """Write the file; implemented by subclasses."""
        raise NotImplementedError

    def read_data_from_file(self) -> Any:
        """Read the file; implemented by subclasses."""
        raise NotImplementedError

    def write_version_header(self, file: TextIO) -> None:
        """Write the ``File version = N`` line for this file type."""
        version_line = VERSION_HEADER_BASE + str(self.file_version)
        file.write(version_line + '\n')

    @property
    def file_version(self) -> int:
        """Format version of this file type; implemented by subclasses."""
        raise NotImplementedError


class InputFile(File):
    """Text record of a run's inputs: seeds, convergence config and experiments.

    The file starts with a version line. It then has three sections, each
    introduced by a comment header line:

    * the seeds, as one comma-separated line;
    * the convergence config, as a column header and one value line;
    * the experiments, as a column header and one line per experiment.
    """

    # Files written by earlier versions of write_data_to_file contain this
    # literal line (a missing f-string prefix) instead of the real header.
    _LEGACY_CONVERGENCE_CONFIG_HEADER = '{CONVERGENCE_CONFIG_HEADER}'

    def __init__(self, directory: Path, run_id: str) -> None:
        super().__init__(directory, run_id)
        self.filename = self.directory.joinpath(
            f'{self.run_id}{INPUT_FILE_EXT}')

    def write_data_to_file(  # type:ignore
            self,
            experiments: list[Experiment],
            convergence_config: ConvergenceConfig,
            seeds: list[int]) -> None:
        """Write the run inputs.

        Note: ``experiments`` is sorted in place, so callers see the new order.
        """
        experiments.sort(key=lambda x: (x.model_tag,
                                        x.infra_tag,
                                        x.ped_count,
                                        x.infrastructure_arg,
                                        x.obstacle_count,
                                        x.local_density,
                                        x.ped_class_count,
                                        x.class_distribution))
        if experiments:
            experiments_header = experiments[0].header_str
        else:
            experiments_header = f'{COMMENT} {", ".join(EXP_EXPORT_TAGS)}'

        with open(self.filename, 'w') as file:
            self.write_version_header(file)

            file.write(f'{SEEDS_HEADER}\n')
            # Use str() per seed. Formatting the whole list uses repr(), which
            # gives "np.int64(...)" for NumPy >= 2 integers and can't be parsed back.
            file.write(', '.join(str(seed) for seed in seeds))
            file.write('\n\n')

            file.write(f'{CONVERGENCE_CONFIG_HEADER}\n')
            file.write(f'{convergence_config.header_str}\n')
            file.write(f'{convergence_config.to_config_str()}\n\n')

            file.write(f'{EXPERIMENTS_HEADER}\n')
            file.write(f'{experiments_header}\n')
            for experiment in experiments:
                file.write(f'{experiment.to_config_str()}\n')
            file.write('\n')

    def read_data_from_file(
            self) -> tuple[list[Experiment], ConvergenceConfig, list[int]]:
        """Read the experiments, convergence config and seeds back.

        Raises:
            ValueError: for a data line outside of any known section.
        """
        with open(self.filename, 'r') as file:
            lines = file.readlines()

        version = -1
        seeds = None
        convergence_config = None
        experiments: list[Experiment] = []

        mode = None
        for line in lines:
            line = line.strip()
            if not line:
                continue
            if line == self._LEGACY_CONVERGENCE_CONFIG_HEADER:
                mode = CONVERGENCE_CONFIG_MODE
                continue
            if line.startswith('#'):
                # Section headers switch the mode; other comments (column
                # headers) are skipped.
                if line == SEEDS_HEADER:
                    mode = SEEDS_MODE
                elif line == CONVERGENCE_CONFIG_HEADER:
                    mode = CONVERGENCE_CONFIG_MODE
                elif line == EXPERIMENTS_HEADER:
                    mode = EXPERIMENTS_MODE
                continue
            if line.startswith(VERSION_HEADER_BASE):
                version = int(line.replace(VERSION_HEADER_BASE, ''))
            elif mode == SEEDS_MODE:
                seeds = cast_variable(line, INT_COLLECTION_TYPE)
            elif mode == CONVERGENCE_CONFIG_MODE:
                convergence_config = ConvergenceConfig.from_config_str(line)
            elif mode == EXPERIMENTS_MODE:
                experiments.append(Experiment.from_config_str(line))
            else:
                raise ValueError(f'The line "{line}" has an unknown format!')

        return experiments, convergence_config, seeds  # type:ignore

    @property
    def file_version(self) -> int:
        return INPUT_FILE_VERSION


# Entry names used inside the .npz output and results files.
KEYS_TO_EXP = 'keys_to_exp'  # OutputFile: run-time key -> experiment tuple
FIELDS = 'fields'  # OutputFile: names of stored field arrays; ResultsFile: field values
FILE_VERSION = 'version'  # OutputFile format version
INDEX_TABLE = 'index_table'  # ResultsFile: value -> index lookup per tag
ID_MATRIX = 'id_matrix'  # ResultsFile: numeric encoding of each experiment
RUN_TIMES = 'run_times'  # ResultsFile: mean run time per experiment
EXP_SETS = 'exp_sets'  # ResultsFile: exported experiment-set tuples


class ExpSetsFile(File):
    """``.npz`` file with the exported tuple of every experiment set of a run."""

    def __init__(self, directory: Path, run_id: str) -> None:
        super().__init__(directory, run_id)
        self.filename = self.directory.joinpath(
            f'{self.run_id}{EXP_SETS_FILE_EXT}')

    def write_data_to_file(self, experiment_sets: list[ExperimentSet]) -> None:
        """Save one entry per experiment set, keyed by its string tag."""
        data_to_save: dict[str, npt.NDArray] = {
            str(exp_set): ExpSetsFile._to_object_array(exp_set.to_tuple())
            for exp_set in experiment_sets
        }

        np.savez(self.filename, **data_to_save)

    @staticmethod
    def _to_object_array(values: tuple) -> npt.NDArray:
        """Pack ``values`` into a 1-D object array with one element per entry.

        An experiment-set tuple mixes strings with lists. Letting NumPy infer
        an array from it is a ragged construction, which raises ``ValueError``
        on NumPy >= 1.24.
        """
        packed = np.empty(len(values), dtype=object)
        for ii, value in enumerate(values):
            packed[ii] = value
        return packed

    def read_data_from_file(self) -> list[ExperimentSet]:
        """Rebuild the experiment sets stored in the file."""
        # allow_pickle is required for the object arrays. It unpickles data, so
        # only load files written by this tool.
        with np.load(self.filename, allow_pickle=True) as data_from_file:
            return [
                ExperimentSet(*tuple(exp_set_array))
                for exp_set_array in data_from_file.values()
            ]

class OutputFile(File):
    """``.npz`` file holding the raw run times of every experiment of a run.

    Each experiment's run times are stored under a numeric string key.
    ``KEYS_TO_EXP`` maps those keys to experiment tuples. Optional field values
    are stored alongside, with their names listed under ``FIELDS``.
    """

    def __init__(self,
                 directory: Path,
                 run_id: str,
                 is_temp: bool = False) -> None:
        super().__init__(directory, run_id)
        self.is_temp = is_temp
        self.filename = self.directory.joinpath(
            f'{self.run_id}{OUTPUT_FILE_EXT_PARTS[0]}{".temp" if self.is_temp else ""}{OUTPUT_FILE_EXT_PARTS[1]}'
        )

    def write_data_to_file(  # type:ignore
            self,
            run_times_per_exp: dict[Experiment, list[float]],
            field_values: Optional[dict[str, Any]] = None) -> None:
        """Save the run times and, optionally, named field values."""
        data_to_save: dict[str, Any] = {}
        keys_to_exp: dict[str, tuple] = {}

        for ii, (experiment, run_times) in enumerate(run_times_per_exp.items()):
            data_to_save[str(ii)] = run_times
            keys_to_exp[str(ii)] = experiment.to_tuple()

        data_to_save[KEYS_TO_EXP] = keys_to_exp
        data_to_save[FILE_VERSION] = self.file_version

        if field_values is not None:
            fields = []
            for key, values in field_values.items():
                data_to_save[key] = field_values_to_np_array(key, values)
                fields.append(key)

            data_to_save[FIELDS] = fields

        np.savez(self.filename, **data_to_save)

    def export_to_csv(self) -> None:
        """Write a CSV next to the output file.

        The first two rows hold the field names and values. They are followed
        by a header row and one row per experiment with its mean run time.
        """
        csv_filename = self.filename.with_suffix('.csv')

        run_times_per_exp, field_values = self.read_data_from_file()

        with open(csv_filename, 'w', newline='') as file:
            csv_writer = csv.writer(file, delimiter=',')
            csv_writer.writerows((list(field_values.keys()),
                                  list(field_values.values())))

            exp_header = list(EXP_EXPORT_TAGS) + ['run_time']
            exp_rows = [
                exp.to_list() + [np.mean(run_times)]
                for exp, run_times in run_times_per_exp.items()
            ]

            csv_writer.writerow(exp_header)
            csv_writer.writerows(exp_rows)

    def read_data_from_file(
        self,
        as_list: bool = False
    ) -> tuple[dict[Experiment, NumSeq], dict[str, list]]:
        """Load the run times per experiment and the stored field values.

        Args:
            as_list: Return each experiment's run times as a list instead of
                the loaded NumPy array.

        Progress is written to stdout while the experiments are rebuilt.
        """
        run_times_per_exp: dict[Experiment, NumSeq] = {}
        # The context manager closes the underlying zip file; arrays read from
        # it are already in memory.
        with np.load(self.filename, allow_pickle=True) as data_from_file:
            keys_to_exp = data_from_file[KEYS_TO_EXP].item()
            exp_count = len(keys_to_exp)

            for ii, (key, exp_tuple) in enumerate(keys_to_exp.items()):
                experiment = Experiment(*exp_tuple)
                run_times = data_from_file[key]
                run_times_per_exp[experiment] = list(run_times) if as_list else run_times
                sys.stdout.write(f'\r({ii+1}/{exp_count})')
                sys.stdout.flush()

            field_values = OutputFile.get_field_values(data_from_file)

        return run_times_per_exp, field_values

    @staticmethod
    def get_field_values(data_from_file: Any) -> dict[str, list]:
        """Decode the field values listed under ``FIELDS`` (empty if none)."""
        field_values = {}
        if FIELDS in data_from_file:
            for key in list(data_from_file[FIELDS]):
                field_values[key] = np_array_to_field_values(
                    key, data_from_file[key])

        return field_values

    def convert_to_summary_file(self) -> ResultsFile:
        """Create the ResultsFile summary for this output."""
        return ResultsFile.create_from_output(self)

    @property
    def file_version(self) -> int:
        return OUTPUT_FILE_VERSION


class ResultsFile(File):
    """Compact ``<run_id>.res.npz`` summary built from an OutputFile.

    Stores the mean run time per experiment and the field values. It also
    stores the experiment sets and an id matrix with one row per experiment
    and one column per ``EXP_RESULT_TAGS`` entry. In that matrix, model tags,
    infra tags and infra args are replaced by their index in ``index_table``.
    """

    def __init__(self, directory: Path, run_id: str) -> None:
        super().__init__(directory, run_id)
        self.filename = self.directory.joinpath(f'{self.run_id}.res.npz')

    def read_data_from_file(
            self) -> tuple[npt.NDArray, npt.NDArray, dict, dict, list]:
        """Return ``(run_times, id_matrix, field_values, index_table, experiment_sets)``."""
        # allow_pickle is needed for the stored dicts; only load trusted files.
        with np.load(self.filename, allow_pickle=True) as data_from_file:
            index_table = data_from_file[INDEX_TABLE].item()
            field_values = data_from_file[FIELDS].item()
            id_matrix = data_from_file[ID_MATRIX]
            run_times = data_from_file[RUN_TIMES]
            exp_sets_arrays = data_from_file[EXP_SETS].item()

        experiment_sets = [
            ExperimentSet(*tuple(exp_set_array))
            for exp_set_array in exp_sets_arrays.values()
        ]

        return run_times, id_matrix, field_values, index_table, experiment_sets

    def create(self,
               run_times_per_exp: dict[Experiment, NumSeq],
               field_values: dict[str, list],
               experiment_sets: list[ExperimentSet]) -> None:
        """Compute the summary arrays and save them to ``self.filename``."""
        index_table = {
            tag: {value: ii for ii, value in enumerate(field_values[tag])}
            for tag in (MODEL_TAG, INFRA_TAG, INFRA_ARG)
        }

        res_count = len(run_times_per_exp)
        # Unfilled entries stay NaN.
        id_matrix = np.full((res_count, len(EXP_RESULT_TAGS)), np.nan)
        run_time_per_exp = np.zeros(res_count)

        for ii, (experiment, run_times) in enumerate(run_times_per_exp.items()):
            for jj, tag in enumerate(EXP_RESULT_TAGS):
                field_value = experiment.get(tag)
                if tag in index_table:
                    id_matrix[ii, jj] = index_table[tag][field_value]
                else:
                    id_matrix[ii, jj] = field_value_to_np_array(tag, field_value)
            run_time_per_exp[ii] = np.mean(run_times)

        exp_sets = {str(exp_set): exp_set.to_tuple() for exp_set in experiment_sets}

        data_to_save: dict[str, Any] = {
            INDEX_TABLE: index_table,
            ID_MATRIX: id_matrix,
            RUN_TIMES: run_time_per_exp,
            FIELDS: field_values,
            EXP_SETS: exp_sets
        }

        np.savez(self.filename, **data_to_save)

    @staticmethod
    def create_from_output(output_file: OutputFile) -> ResultsFile:
        """Build and save the summary for ``output_file`` and its experiment sets."""
        run_times_per_exp, field_values = output_file.read_data_from_file()
        experiment_sets = ExpSetsFile(output_file.directory, output_file.run_id).read_data_from_file()
        res_file = ResultsFile(output_file.directory, output_file.run_id)
        res_file.create(run_times_per_exp, field_values, experiment_sets)

        return res_file

class TempStateFile(File):
    """Checkpoint of an unfinished run.

    Partial run times go to a temporary OutputFile. The seed index per
    experiment, the current step and the converged experiments are written as
    JSON to ``<run_id>.state.temp``.
    """

    def __init__(self, directory: Path, run_id: str) -> None:
        super().__init__(directory, run_id)
        self.temp_output_file = OutputFile(self.directory, self.run_id, is_temp=True)
        self.temp_state_filename = self.directory.joinpath(f'{self.run_id}.state.temp')

    def write_data_to_file(  # type: ignore
            self,
            run_times_per_exp: dict[Experiment, list[float]],
            convergence_config: ConvergenceConfig,
            experiment_seed_ind: dict[Experiment, int]) -> None:
        """Save the partial run times and the loop state."""
        self.temp_output_file.write_data_to_file(run_times_per_exp)

        seed_ind_by_tag = {}
        for exp, seed_ind in experiment_seed_ind.items():
            seed_ind_by_tag[exp.to_tag()] = seed_ind

        save_dict = {
            'seed_ind': seed_ind_by_tag,
            'current_step': convergence_config.current_step,
            'converged_experiments': list(convergence_config.converged_experiments),
        }
        with open(self.temp_state_filename, 'w') as state_file:
            json.dump(save_dict, state_file, indent=4)

    def read_data_from_file(
            self) -> tuple[list[Experiment], ConvergenceConfig, list[int]]:
        """Not supported yet."""
        raise NotImplementedError

    def finalize(self) -> None:
        """Delete the temporary output and drop the ``.temp`` suffix of the state file."""
        self.temp_output_file.filename.unlink()
        final_state_filename = self.temp_state_filename.with_suffix('')
        self.temp_state_filename.rename(final_state_filename)

    @property
    def file_version(self) -> int:
        return OUTPUT_FILE_VERSION


#=======================================================
#=======================================================


def compute_complexity(input_file: Path,
                       base_name: str,
                       output_dir: Path,
                       run_parallel: bool = True) -> None:
    """Load the configuration via ``get_inputs`` and run the complexity study.

    Args:
        input_file: configuration file passed to ``get_inputs``.
        base_name: basename passed to ``get_inputs``.
        output_dir: directory where all output files are written.
        run_parallel: run replications in a process pool.
    """
    inputs = get_inputs(input_file, base_name)
    (model_instances, experiments, experiment_sets,
     convergence_config, run_id) = inputs

    _compute_complexity(model_instances,
                        experiments,
                        experiment_sets,
                        convergence_config,
                        run_id,
                        output_dir,
                        run_parallel=run_parallel)


def _compute_complexity(
        model_instances: list[PedModel],
        experiments: list[Experiment],
        experiment_sets: list[ExperimentSet],
        convergence_config: ConvergenceConfig,
        run_id: str,
        output_dir: Path,
        run_parallel: bool = True,
        cpu_time: bool = False) -> dict[Experiment, list[float]]:
    """Replicate every experiment until its mean run time converges.

    Writes the input, experiment-set, temporary-state, output and results
    files for ``run_id`` into ``output_dir``.

    Args:
        model_instances: models, indexed by ``Experiment.model_instance_ind``.
        experiments: the experiments to time.
        experiment_sets: written to the ExpSetsFile and passed to the
            ResultsFile.
        convergence_config: convergence settings. A copy is used, so the
            caller's object is not modified.
        run_id: identifier used to name all output files.
        output_dir: directory for all output files.
        run_parallel: run replications in a multiprocessing pool.
        cpu_time: time with process CPU time instead of wall-clock time.

    Returns:
        The run times in nanoseconds collected per experiment.
    """
    convergence_config = convergence_config.get_copy()
    rng = np.random.default_rng()
    # One seed per replication index. tolist() yields plain Python ints:
    # random.seed() rejects NumPy integers on Python >= 3.11.
    seeds: list[int] = rng.integers(
        1,
        max(convergence_config.max_step_count, np.iinfo('int32').max),
        size=convergence_config.max_step_count).tolist()
    InputFile(output_dir, run_id).write_data_to_file(experiments,
                                                     convergence_config,
                                                     seeds)

    ExpSetsFile(output_dir, run_id).write_data_to_file(experiment_sets)

    # Full run-time history per experiment tag. The runners append to it and
    # remove_outliers() (via all_have_converged) may rewrite entries in place.
    run_times_per_exp: dict[str, list[float]] = {
        exp.to_tag(): [] for exp in experiments
    }

    get_time_in_nano_sec = (get_cpu_time_in_nano_sec
                            if cpu_time else get_run_time_in_nano_sec)

    logger.info(
        f'Running {len(experiments)} experiments using {"cpu" if cpu_time else "wall"} time'
    )

    temp_file = TempStateFile(output_dir, run_id)
    experiment_seed_ind = {experiment: 0 for experiment in experiments}
    tag_to_exp = {exp.to_tag(): exp for exp in experiments}
    experiment_tags = set(tag_to_exp.keys())
    run_times_per_exp_out: dict[Experiment,
                                list[float]] = {exp: [] for exp in experiments}
    while not all_have_converged(run_times_per_exp, convergence_config):
        # Mirror the full history (previously the whole history was appended
        # again every step, duplicating data in the output).
        _sync_run_times_out(run_times_per_exp_out, run_times_per_exp,
                            tag_to_exp)
        temp_file.write_data_to_file(
            run_times_per_exp_out,
            convergence_config,
            experiment_seed_ind,
        )
        experiments_rep = get_replications(tag_to_exp,
                                           experiment_tags,
                                           convergence_config)

        if run_parallel:
            run_experiments_parallel(experiments_rep,
                                     experiment_seed_ind,
                                     run_times_per_exp,
                                     seeds,
                                     get_time_in_nano_sec,
                                     model_instances)
        else:
            run_experiments_sequential(experiments_rep,
                                       experiment_seed_ind,
                                       run_times_per_exp,
                                       seeds,
                                       get_time_in_nano_sec,
                                       model_instances,
                                       temp_file,
                                       convergence_config,
                                       tag_to_exp)
        print('\n')

    # Pick up the final step and any outlier corrections made by the last
    # convergence check.
    _sync_run_times_out(run_times_per_exp_out, run_times_per_exp, tag_to_exp)

    field_values = get_field_values(experiments)
    output_file = OutputFile(output_dir, run_id)
    output_file.write_data_to_file(run_times_per_exp_out, field_values)
    temp_file.finalize()
    ResultsFile(output_dir, run_id).create(run_times_per_exp_out, field_values, experiment_sets) # type:ignore
    return run_times_per_exp_out


def _sync_run_times_out(run_times_per_exp_out: dict[Experiment, list[float]],
                        run_times_per_exp: dict[str, list[float]],
                        tag_to_exp: dict[str, Experiment]) -> None:
    """Overwrite each experiment's output entry with a copy of its history."""
    for tag, values in run_times_per_exp.items():
        run_times_per_exp_out[tag_to_exp[tag]] = list(values)


def run_experiments_sequential(experiments_rep: list[Experiment],
                               experiment_seed_ind: dict[Experiment, int],
                               run_times_per_exp: dict[str, list[float]],
                               seeds: list[int],
                               get_time_in_nano_sec: Callable[
                                   [Experiment, list[PedModel]], int],
                               model_instances: list[PedModel],
                               temp_file: TempStateFile,
                               convergence_config: ConvergenceConfig,
                               tag_to_exp: dict[str, Experiment]) -> None:
    """Run every replication in ``experiments_rep`` in this process.

    For each replication this function:

    * seeds Python's ``random`` module with the experiment's next unused seed;
    * times the run with ``get_time_in_nano_sec``;
    * appends the result to ``run_times_per_exp`` under ``experiment.to_tag()``;
    * advances ``experiment_seed_ind``.

    Progress is written to stdout. The collected data is checkpointed to
    ``temp_file`` whenever more than 10 minutes have passed since the last
    checkpoint.
    """
    exp_count = len(experiments_rep)
    sys.stdout.write(f'\r0/{exp_count}')
    sys.stdout.flush()
    t_start = time.time()
    for ii, experiment in enumerate(experiments_rep):
        seed_ind = experiment_seed_ind[experiment]
        # int(): random.seed() rejects NumPy integers on Python >= 3.11.
        # NOTE(review): only the stdlib RNG is seeded here; any NumPy generator
        # used by the model must be seeded by the model itself — TODO confirm.
        random.seed(int(seeds[seed_ind]))

        run_time = get_time_in_nano_sec(experiment, model_instances)
        run_times_per_exp[experiment.to_tag()].append(run_time)
        experiment_seed_ind[experiment] += 1
        sys.stdout.write(f'\r{ii+1}/{exp_count}')
        sys.stdout.flush()
        if time.time() - t_start > 60*10:
            logger.debug('Write data to temp file')
            run_times_per_exp_out = {
                tag_to_exp[tag]: values
                for tag, values in run_times_per_exp.items()
            }
            temp_file.write_data_to_file(
                run_times_per_exp_out,
                convergence_config,
                experiment_seed_ind,
            )
            t_start = time.time()


def run_experiments_parallel(experiments_rep: list[Experiment],
                             experiment_seed_ind: dict[Experiment, int],
                             run_times_per_exp: dict[str, list[float]],
                             seeds: list[int],
                             get_time_in_nano_sec: Callable[
                                 [Experiment, list[PedModel]], int],
                             model_instances: list[PedModel]) -> None:
    """Run all replications in ``experiments_rep`` on a multiprocessing pool.

    Run times are appended to ``run_times_per_exp`` (keyed by experiment tag).
    ``experiment_seed_ind`` is advanced by the number of replications run per
    experiment, as the sequential runner does.
    """
    # Leave one core for the parent process, but never ask for zero workers.
    cpu_count_use = max(1, cpu_count() - 1)
    logger.info(f'Running on {cpu_count_use} cpus')
    experiments_rep_sub_set = _split_into_subsets(experiments_rep,
                                                  cpu_count_use)

    # Workers receive pickled copies of the seed-index dict, so their
    # increments never reach the parent. Assign each subset its starting seed
    # indices here so that:
    # * every replication gets its own seed, and
    # * the parent's counters advance.
    seed_ind_per_subset = []
    next_seed_ind = dict(experiment_seed_ind)
    for subset in experiments_rep_sub_set:
        seed_ind_per_subset.append(dict(next_seed_ind))
        for experiment in subset:
            next_seed_ind[experiment] += 1

    exp_count = len(experiments_rep)

    args_per_subset = [(subset,
                        seed_ind,
                        seeds,
                        get_time_in_nano_sec,
                        model_instances,
                        exp_count)
                       for subset, seed_ind in zip(experiments_rep_sub_set,
                                                   seed_ind_per_subset)]

    with Pool(cpu_count_use, initializer=init_worker,
              initargs=(Value('i', 0),)) as p:
        run_times_per_exp_pool = p.starmap(run_experiments_subset,
                                           args_per_subset)
        p.close()
        p.join()

    experiment_seed_ind.update(next_seed_ind)
    for run_times_per_exp_loc in run_times_per_exp_pool:
        for key, value in run_times_per_exp_loc.items():
            run_times_per_exp[key] += value


def _split_into_subsets(experiments_rep: list[Experiment],
                        n_subsets: int) -> list[list[Experiment]]:
    """Split into ``n_subsets`` contiguous chunks; the last takes the remainder."""
    set_size = len(experiments_rep) // n_subsets
    subsets = [experiments_rep[ii*set_size:(ii + 1)*set_size]
               for ii in range(n_subsets - 1)]
    # Slice to the end (the old ``[start:-1]`` dropped the last replication).
    subsets.append(experiments_rep[(n_subsets - 1)*set_size:])
    return subsets


def init_worker(counter) -> None:  # type: ignore
    """Pool initializer: store ``counter`` as this worker's ``exp_counter``.

    ``run_experiments_subset`` reads and increments the counter, under its
    lock, to print overall progress.
    """
    global exp_counter
    exp_counter = counter


def run_experiments_subset(experiments_rep: list[Experiment],
                           experiment_seed_ind: dict[Experiment, int],
                           seeds: list[int],
                           get_time_in_nano_sec: Callable[
                               [Experiment, list[PedModel]], int],
                           model_instances: list[PedModel],
                           exp_count: int) -> dict[str, list[float]]:
    """Worker entry point: time each replication of one subset.

    Relies on the module-global ``exp_counter`` set by ``init_worker``.

    Args:
        experiments_rep: the replications to run, in order.
        experiment_seed_ind: starting seed index per experiment. This local
            copy is advanced per replication.
        seeds: the shared seed table.
        get_time_in_nano_sec: timing function.
        model_instances: models, indexed by ``Experiment.model_instance_ind``.
        exp_count: total number of replications across all workers (for the
            progress display).

    Returns:
        Run times keyed by experiment tag.

    Raises:
        Any exception raised by the timed run is logged and re-raised.
    """
    global exp_counter
    run_times_per_exp = defaultdict(list)
    for experiment in experiments_rep:
        seed_ind = experiment_seed_ind[experiment]
        # int(): random.seed() rejects NumPy integers on Python >= 3.11.
        random.seed(int(seeds[seed_ind]))
        try:
            run_time = get_time_in_nano_sec(experiment, model_instances)
        except Exception:
            logger.critical(f'{experiment}')
            raise

        run_times_per_exp[experiment.to_tag()].append(run_time)
        experiment_seed_ind[experiment] += 1
        # Read the value while holding the lock so the printed count is
        # consistent.
        with exp_counter.get_lock():
            exp_counter.value += 1
            done = exp_counter.value
        sys.stdout.write(f'\r{done}/{exp_count}')
        sys.stdout.flush()

    return run_times_per_exp


def all_have_converged(run_times_per_exp: dict[str, list[float]],
                       convergence_config: ConvergenceConfig) -> bool:
    """Decide whether the replication loop can stop.

    Every call advances ``convergence_config.current_step`` by one step. While
    ``current_step`` is below ``init_step_count``, no convergence test is done.
    After that, each experiment tag not yet in ``converged_experiments`` is
    tested with ``has_converged``. Tags that pass are added to the set.

    Returns:
        True when all experiments have converged, or when the step counter has
        passed ``max_step_count``. In the latter case, a warning is logged for
        each experiment that has not converged.
    """
    if convergence_config.current_step < convergence_config.init_step_count:
        convergence_config.increment_step()
        return False

    experiments_to_check = set(run_times_per_exp.keys()).difference(
        convergence_config.converged_experiments)

    not_converged: list[str] = []
    for experiment_tag in experiments_to_check:
        if has_converged(run_times_per_exp[experiment_tag],
                         convergence_config,
                         experiment_tag):
            convergence_config.converged_experiments.add(experiment_tag)
            logger.debug(
                f'{experiment_tag} has converged at step {convergence_config.current_step}.'
            )
        else:
            not_converged.append(experiment_tag)

    all_converged = not not_converged
    if all_converged:
        logger.info(
            f'All experiments have converged at step {convergence_config.current_step}.'
        )

    logger.info(
        f'({len(convergence_config.converged_experiments)}/{len(run_times_per_exp)}) experiments have converged at step {convergence_config.current_step}.'
    )

    convergence_config.increment_step()

    if all_converged:
        return True

    if convergence_config.current_step > convergence_config.max_step_count:
        logger.warning('Not all experiments have converged at maximum step count!')
        # Only report the experiments that actually failed this check.
        for experiment in not_converged:
            logger.warning(f'{experiment} has not converged!')
        return True

    return False


def has_converged(run_times: list[float],
                  config: ConvergenceConfig,
                  experiment_tag: str) -> bool:
    """Check whether the mean of ``run_times`` has stabilised.

    First calls ``remove_outliers``, which may modify ``run_times`` in place.
    The mean of the full list is compared with the means obtained after
    dropping the last 1 .. ``config.step_size - 1`` values. The list counts as
    converged unless one of those relative differences exceeds
    ``config.allowable_err``.
    """
    remove_outliers(run_times, experiment_tag)
    reference_mean = np.mean(run_times)
    # any() stops at the first offending prefix, like a loop with break.
    # A NaN difference never counts as "too large", matching the original
    # ``>`` test.
    return not any(
        abs((reference_mean - np.mean(run_times[:-dropped])) / reference_mean)
        > config.allowable_err
        for dropped in range(1, config.step_size))

def remove_outliers(run_times: list[float],
                    experiment_tag: str,
                    threshold: float = 12.0) -> None:
    """Replace extreme high outliers in ``run_times`` in place.

    Repeatedly takes the current maximum. If it exceeds
    ``mean + threshold * std`` of the values strictly below it, the first
    occurrence is replaced by that mean and a warning is logged. Stops when
    the maximum is within bounds, when every value equals the maximum, or when
    the list is empty.

    Args:
        run_times: values to clean; modified in place.
        experiment_tag: label used in the warning message.
        threshold: outlier cut-off in standard deviations (default 12).
    """
    while run_times:
        run_times_arr = np.array(run_times)
        max_value = np.max(run_times_arr)
        rest = run_times_arr[run_times_arr < max_value]
        if rest.size == 0:
            # All values equal the maximum: nothing to compare against.
            break
        mean_val = np.mean(rest)
        std_val = np.std(rest)
        if max_value <= mean_val + threshold*std_val:
            break
        # argmax returns the first index of the maximum.
        loc = int(np.argmax(run_times_arr))
        run_times[loc] = float(mean_val)
        logger.warning(
            f'{experiment_tag} Outlier detected! {max_value} {mean_val} {max_value/mean_val - 1}'
        )


def get_replications(tag_to_exp: dict[str, Experiment],
                     experiment_tags: set[str],
                     config: ConvergenceConfig) -> list[Experiment]:
    """Build the shuffled list of runs for the next step.

    Every experiment whose tag is not yet in ``config.converged_experiments``
    appears ``config.step_size`` times. The list is shuffled with Python's
    ``random`` module.
    """
    pending = [
        tag_to_exp[tag]
        for tag in experiment_tags.difference(config.converged_experiments)
    ]
    replications: list[Experiment] = []
    for _ in range(config.step_size):
        replications.extend(pending)
    random.shuffle(replications)
    return replications


def get_field_values(experiments: list[Experiment]) -> dict[str, list[Any]]:
    """Collect the distinct values used for each exported field.

    Returns a dict keyed by every tag in ``EXP_EXPORT_TAGS``. Each value is
    the list of distinct ``experiment.get(tag)`` results, in set-iteration
    order (that is, unspecified).
    """
    return {
        key: list({experiment.get(key) for experiment in experiments})
        for key in EXP_EXPORT_TAGS
    }


def field_values_to_np_array(field: str, values: list) -> npt.NDArray:
    """Return ``values`` as a sorted NumPy array.

    For the distribution fields (``LOCAL_DENSITY`` and ``CLASS_DISTR``), each
    value is first mapped through ``field_value_to_np_array``, which turns
    ``DistrType`` members into their ``.value``. The caller's list is not
    modified.
    """
    if field in (LOCAL_DENSITY, CLASS_DISTR):
        # Build a new list; the previous version overwrote the caller's
        # entries in place.
        values = [field_value_to_np_array(field, value) for value in values]

    return np.array(sorted(values))


def field_value_to_np_array(field: str, value: Any) -> Any:
    """Map one field value to a NumPy-storable value.

    ``DistrType`` members become their ``.value``; anything else is returned
    unchanged. ``field`` is currently not used.
    """
    return value.value if isinstance(value, DistrType) else value


def np_array_to_field_values(field: str, values: npt.NDArray) -> list:
    """Convert each array element back to a field value.

    Each element is passed through ``np_array_to_field_value``; a list is
    returned.
    """
    return [np_array_to_field_value(field, element) for element in values]


def np_array_to_field_value(field: str, value: Any) -> Any:
    """Convert one stored value back to its field representation.

    * Distribution fields: converted with ``DistrType.get_instance``. The raw
      value is kept if that raises ``ValueError``.
    * ``INFRA_ARG``: converted to a tuple.
    * Anything else: returned unchanged.
    """
    if field in (LOCAL_DENSITY, CLASS_DISTR):
        try:
            converted = DistrType.get_instance(value)
        except ValueError:
            converted = value
        return converted

    return tuple(value) if field == INFRA_ARG else value


def get_run_time_in_nano_sec(experiment: Experiment,
                             model_instances: list[PedModel]) -> int:
    """Return the wall-clock nanoseconds spent in ``model.run()``.

    The model is ``model_instances[experiment.model_instance_ind]``. Its
    ``set_state(experiment)`` call happens before the clock starts, so it is
    not included in the measurement.
    """
    model = model_instances[experiment.model_instance_ind]
    model.set_state(experiment)
    begin_ns = time.perf_counter_ns()
    model.run()
    end_ns = time.perf_counter_ns()
    return end_ns - begin_ns


def get_cpu_time_in_nano_sec(experiment: Experiment,
                             model_instances: list[PedModel]) -> int:
    """Return the process CPU nanoseconds spent in ``model.run()``.

    The model is ``model_instances[experiment.model_instance_ind]``. Its
    ``set_state(experiment)`` call happens before the clock starts, so it is
    not included in the measurement.
    """
    model = model_instances[experiment.model_instance_ind]
    model.set_state(experiment)
    begin_ns = time.process_time_ns()
    model.run()
    end_ns = time.process_time_ns()
    return end_ns - begin_ns


#=======================================================
#=======================================================
def execute(raw_args: Optional[Sequence[Any]] = None) -> None:
    """Command-line entry point: parse ``raw_args``, then run the study.

    ``None`` makes argparse read ``sys.argv[1:]``.
    """
    parsed_args = parse_sys_input(raw_args)
    compute_complexity(*parsed_args)


def parse_sys_input(
        raw_args: Optional[Sequence[Any]]) -> tuple[Path, str, Path, bool]:
    """Parse the command-line arguments.

    Args:
        raw_args: the argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``(input_file, basename, output_dir, run_parallel)``. Each positional
        argument is validated and converted by its ``valid_*`` type callback.
    """
    parser = argparse.ArgumentParser(
        description=
        'Compute the computational complexity using a specific configuration')
    parser.add_argument('input_file',
                        type=valid_input_file,
                        help="The path to the configuration file")
    parser.add_argument('basename',
                        type=valid_base_name,
                        help="The basename for all output files")
    parser.add_argument('output_dir',
                        type=valid_output_dir,
                        help="The path to the output directory")
    parser.add_argument('-p',
                        dest='run_parallel',
                        action='store_true',
                        help="Run simulations in parallel")

    args = parser.parse_args(raw_args)

    return args.input_file, args.basename, args.output_dir, args.run_parallel


def valid_input_file(string: str) -> Path:
    """argparse ``type`` callback that resolves ``string`` to an existing file.

    Raises:
        FileNotFoundError: if the resolved path is not an existing regular
            file.

    NOTE(review): argparse only turns ``ArgumentTypeError``, ``TypeError``
    and ``ValueError`` from a type callback into a usage error. This
    exception therefore surfaces as a traceback.
    """
    candidate = Path(string).resolve()
    if candidate.is_file():
        return candidate
    raise FileNotFoundError(
        f'No such file: "{string}" ({candidate.resolve()})')


def valid_base_name(string: str) -> str:
    """argparse ``type`` callback that validates a basename for output files.

    Rejects names that contain spaces. Names the OS will not accept as a
    filename are detected by creating ``<string>.tmp`` in the current working
    directory and removing it again. A file that already has that name is
    left untouched; its existence proves the name is usable.

    Raises:
        ValueError: if the name contains a space or cannot be used as a
            filename.
    """
    if ' ' in string:
        raise ValueError(f'The basename cannot contain spaces! ({string})')

    test_filename = Path(f'{string}.tmp')
    try:
        # Exclusive mode fails on an existing file instead of truncating it,
        # so a user's file is never clobbered or deleted.
        with open(test_filename, 'x'):
            pass
        test_filename.unlink()
    except FileExistsError:
        pass
    except OSError as err:
        raise ValueError(f'Invalid basename: "{string}"') from err

    return string


def valid_output_dir(string: str) -> Path:
    """argparse ``type`` callback that resolves ``string`` to an output directory.

    The directory is created, including any missing parents, if it does not
    exist yet.

    Raises:
        FileExistsError: if the path exists but is not a directory.
    """
    output_dir = Path(string).resolve()
    # parents=True handles nested paths; exist_ok=True avoids a
    # check-then-create race.
    output_dir.mkdir(parents=True, exist_ok=True)
    return output_dir


if __name__ == '__main__':
    # With no explicit arguments, argparse reads sys.argv[1:] itself.
    execute()
