from __future__ import annotations
from enum import Enum, auto
from typing import Optional, Sequence, TypeAlias

import numpy as np
import numpy.typing as npt

Vector: TypeAlias = Sequence[float]
Line: TypeAlias = tuple[tuple[float,float], tuple[float,float], float]
Obstacle: TypeAlias = tuple[tuple[float,float], tuple[float,float]]

CELL_POS: TypeAlias = tuple[int, int]
CONT_POS: TypeAlias = tuple[float, float]
FloatSeq: TypeAlias = npt.NDArray[np.double]|list[float]
IntSeq: TypeAlias = npt.NDArray[np.int_]|list[int]
NumSeq: TypeAlias = FloatSeq|IntSeq



# Scalar type-tag strings.
STRING_TYPE: str = 'str'
FLOAT_TYPE: str = 'float'
INT_TYPE: str = 'int'
BOOL_TYPE: str = 'bool'
DIR_TYPE: str = 'dir'
DISTR_TYPE: str = 'distr'
ARG_TYPE: str = 'arg'

# Collection type-tag strings.
COLLECTION_TYPE: str = 'collection'
FLOAT_COLLECTION_TYPE: str = 'float_collection'
INT_COLLECTION_TYPE: str = 'int_collection'
DISTR_COLLECTION_TYPE: str = 'distr_collection'
ARG_COLLECTION_TYPE: str = 'arg_collection'

# Experiment-description keys.
EXPERIMENT_COLLECTION: str = 'exp_collection'
NON_COMB: str = 'non_comb'
BASE_VAR: str = 'base_var'
COMB_VARS: str = 'comb_vars'
CONFIG_NAME: str = 'config_name'
X_CONV_FCN: str = 'x_conv_fcn'
DENSITY: str = 'density'
# String forms returned by DistrType.__str__.
HOMOGENEOUS: str = 'homogeneous'
CONCENTRATED: str = 'concentrated'

# Comment marker and file extensions (presumably for experiment
# input/output files — confirm).
COMMENT: str = '#'
INPUT_FILE_EXT: str = '.input'
OUTPUT_FILE_EXT_PARTS: tuple[str, str] = ('.output', '.npz')
EXP_SETS_FILE_EXT: str = '.sets.npz'

# Step-count / tolerance keys.
INIT_STEP_COUNT: str = 'init_step_count'
ALLOWABLE_ERR: str = 'allowable_err'
STEP_SIZE: str = 'step_size'
MAX_STEP_COUNT: str = 'max_step_count'

# Singular keys (one value per experiment).
MODEL_TAG: str = 'model_tag'
INFRA_TAG: str = 'infra_tag'
PED_COUNT: str = 'ped_count'
INFRA_ARG: str = 'infrastructure_arg'
OBSTACLE_COUNT: str = 'obstacle_count'
LOCAL_DENSITY: str = 'local_density'
CLASS_DISTR: str = 'class_distribution'
PED_CLASS_COUNT: str = 'ped_class_count'
MODEL_INSTANCE_IND: str = 'model_instance_ind'

# Plural keys (collections of values).
MODEL_TAGS: str = 'model_tags'
INFRA_TAGS: str = 'infra_tags'
PED_COUNTS: str = 'ped_counts'
OBSTACLE_COUNTS: str = 'obstacle_counts'
LOCAL_DENSITIES: str = 'local_densities'
CLASS_DISTRS: str = 'class_distributions'
PED_CLASS_COUNTS: str = 'ped_class_counts'
INFRA_ARGS: str = 'infrastructure_args'

# Maps each plural key (e.g. 'model_tags') to its singular counterpart
# (e.g. 'model_tag').  Annotated like every other module-level constant.
INPUT_TO_EXP_TAGS: dict[str, str] = {
    MODEL_TAGS: MODEL_TAG,
    INFRA_TAGS: INFRA_TAG,
    PED_COUNTS: PED_COUNT,
    OBSTACLE_COUNTS: OBSTACLE_COUNT,
    LOCAL_DENSITIES: LOCAL_DENSITY,
    CLASS_DISTRS: CLASS_DISTR,
    PED_CLASS_COUNTS: PED_CLASS_COUNT,
    INFRA_ARGS: INFRA_ARG,
}


MODEL_INFRA_COMBS: str = 'model_infra_combs'

# Figure-layout keys: margins and gaps.
LEFT: str = 'left'
RIGHT: str = 'right'
BOTTOM: str = 'bottom'
TOP: str = 'top'
GAP_HOR: str = 'gap_hor'
GAP_VER: str = 'gap_ver'
LEFT_LABEL: str = 'left_label'
V_LEGEND: str = 'v_legend'

# Figure-layout keys: axis sharing and axes placement.
SHARE_Y: str = 'share_y'
SHARE_X: str = 'share_x'
SHARE_Y_LIST: str = 'share_y_list'
SHARE_X_LIST: str = 'share_x_list'
HOR_AXES_LOCS: str = 'hor_axes_locs'
VER_AXES_LOCS: str = 'ver_axes_locs'

# Figure-layout keys: color bar and legend sizing.
COLOR_BAR_WIDTH: str = 'color_bar_width'
COLOR_LABEL_WIDTH: str = 'color_label_width'
COLOR_BAR_GAP: str = 'color_bar_gap'
LEGEND_WIDTH: str = 'legend_width'
LEGEND_GAP: str = 'legend_gap'

class DistrType(Enum):
    """Symbolic extremes of a distribution variable.

    ``HOMOGENEOUS`` has value ``-np.inf`` and ``CONCENTRATED`` has value
    ``np.inf``, so members can be mixed with plain floats and ordered against
    them.  ``HOMOGENEOUS`` is strictly less than every value it is not equal
    to, and ``CONCENTRATED`` is strictly greater than every value it is not
    equal to.  A member compares equal to itself, to its float value, and to
    its string form (``str(member)``).
    """

    HOMOGENEOUS = -np.inf
    CONCENTRATED = np.inf

    def __str__(self) -> str:
        """Return the keyword string for this member."""
        if self is DistrType.HOMOGENEOUS:
            return HOMOGENEOUS
        if self is DistrType.CONCENTRATED:
            return CONCENTRATED
        # Unreachable with the two members above; kept as a safeguard.
        raise KeyError(f'Unknown distribution type "{self.name}"!')

    def __lt__(self, other: float | DistrType) -> bool:
        """Return True if ``self`` sorts strictly before ``other``.

        Only ``HOMOGENEOUS`` (-inf) can be less than anything.  It is never
        less than a value it compares equal to (itself, ``-np.inf``, or its
        string form), which keeps the ordering strict.
        """
        return self is DistrType.HOMOGENEOUS and not self == other

    def __gt__(self, other: float | DistrType) -> bool:
        """Return True if ``self`` sorts strictly after ``other``.

        Only ``CONCENTRATED`` (+inf) can be greater than anything.  It is
        never greater than a value it compares equal to.
        """
        return self is DistrType.CONCENTRATED and not self == other

    def __le__(self, other: float | DistrType) -> bool:
        """Return True if ``self`` is less than or equal to ``other``."""
        return self < other or self == other

    def __ge__(self, other: float | DistrType) -> bool:
        """Return True if ``self`` is greater than or equal to ``other``."""
        return self > other or self == other

    def __eq__(self, __o: object) -> bool:
        """Compare against a member, a float value, or a string form.

        Always returns a real bool.  Members are compared by identity, and
        anything else by the member's value and then its string form.
        """
        if isinstance(__o, DistrType):
            # Enum members are singletons, so identity is equality.  Returning
            # a bool here (instead of object.__eq__'s NotImplemented, which is
            # truthy) keeps direct __eq__ calls honest.
            return self is __o
        try:
            if self.value == __o:
                return True
            if str(self) == __o:
                return True
        except Exception:
            # Best effort: an operand whose comparison raises (e.g. one with
            # an ambiguous truth value, such as a multi-element array) is
            # treated as unequal.  Exception rather than a bare except, so
            # KeyboardInterrupt/SystemExit still propagate.
            pass
        return False

    def __hash__(self) -> int:
        # Defining __eq__ would otherwise leave the class unhashable.
        # NOTE(review): this defers to Enum's hash (presumably name-based), so
        # it is not consistent with equality to floats/strings; for example, a
        # dict keyed by -np.inf may not find DistrType.HOMOGENEOUS.  Confirm
        # this is acceptable.
        return super().__hash__()

    @staticmethod
    def get_instance(arg: str | float | DistrType) -> DistrType:
        """Return the member identified by ``arg``.

        Args:
            arg: a ``DistrType`` member (returned unchanged), ``np.inf`` /
                ``-np.inf``, or a member's string form.

        Raises:
            ValueError: if ``arg`` does not identify a member.
        """
        if isinstance(arg, DistrType):
            return arg

        if isinstance(arg, float):
            if arg == np.inf:
                return DistrType.CONCENTRATED
            if arg == -np.inf:
                return DistrType.HOMOGENEOUS
            raise ValueError(f'The float argument {arg} is not recognized!')

        if arg == str(DistrType.CONCENTRATED):
            return DistrType.CONCENTRATED
        if arg == str(DistrType.HOMOGENEOUS):
            return DistrType.HOMOGENEOUS

        raise ValueError(f'The string argument {arg} is not recognized!')


class ColorBarLoc(Enum):
    """Side of an axes on which a color bar is placed."""

    TOP = auto()
    BOTTOM = auto()
    RIGHT = auto()
    LEFT = auto()

    @property
    def is_hor(self) -> bool:
        """Whether the bar sits beside the axes (``LEFT`` or ``RIGHT``)."""
        return self is ColorBarLoc.LEFT or self is ColorBarLoc.RIGHT

    @property
    def is_ver(self) -> bool:
        """Whether the bar sits above/below the axes (``TOP`` or ``BOTTOM``)."""
        return self is ColorBarLoc.TOP or self is ColorBarLoc.BOTTOM

    @property
    def orientation(self) -> Optional[str]:
        """Orientation string for the bar itself.

        A bar beside the axes runs ``'vertical'``, and one above or below
        runs ``'horizontal'``.  ``None`` is returned for a member that is
        neither, which is unreachable with the current members.
        """
        if self.is_ver:
            return 'horizontal'
        if self.is_hor:
            return 'vertical'
        return None


IntVar: TypeAlias = int | Sequence[int] | str
FloatVar: TypeAlias = int | Sequence[float] | str
# A distribution variable: a float, a DistrType extreme, a sequence mixing
# both, or a string.
DistrVar: TypeAlias = (
    float | DistrType | Sequence[float | DistrType] | str
)

PLOT_FCN: str = 'plot_fcn'

# Plot-configuration keys.
DATA_TYPE: str = 'data_type'
PLOT_TYPE: str = 'plot_type'
X_DATA_VAR: str = 'x_data_var'
Y_DATA_VAR: str = 'y_data_var'
FIT_MODULE: str = 'fit_module'
AX_CONFIG: str = 'ax_config'
ARTIST_CONFIG: str = 'artist_config'
COLOR_BAR_LOC: str = 'color_bar_loc'
AX_LOC: str = 'ax_loc'
LEGEND_LOC: str = 'legend_loc'
TIME_STEPS_PER_SEC_PER_MODEL: str = 'time_steps_per_sec_per_model'

# Data-source values (presumably for DATA_TYPE — confirm).
# NOTE(review): the EMPIRICAL_DATA value is misspelled ('emprirical').  It is
# left unchanged because existing config/data files may already contain it;
# confirm before correcting.
EMPIRICAL_DATA: str = 'emprirical'
THEORETICAL_DATA: str = 'theoretical'

# Plot-kind values (presumably for PLOT_TYPE — confirm).
SINGLE_LINE: str = 'single_line'
MULTI_LINE: str = 'multi_line'
THREE_D: str = '3d'
ALL: str = 'all'


# Artist and configuration keys.
FIGURE: str = 'figure'
AXES: str = 'axes'
LEGEND: str = 'legend'
LINE: str = 'line'
TEXT: str = 'text'
P_COLOR_MESH: str = 'pcolormesh'
COLOR_BAR: str = 'color_bar'
TITLE: str = 'title'
TITLE_CONFIG: str = 'title_config'
GRID: str = 'grid'
COLORS: str = 'colors'
STYLES: str = 'styles'
LINE_WIDTHS: str = 'line_widths'
SINGLE_COLOR_BAR: str = 'single_cb'
REL_DIFF_FIRST_TEXT: str = 'rel_diff_first_text'
ARGS: str = 'args'
KWARGS: str = 'kwargs'
LINES: str = 'lines'

# Per-line style keys.  The first three values look like matplotlib Line2D
# keyword names — presumably forwarded as kwargs; confirm.
COLOR: str = 'color'
STYLE: str = 'linestyle'
LINE_WIDTH: str = 'linewidth'
LINE_CONFIG: str = 'line_config'

FONT_SIZE: str = 'font_size'

# x-axis configuration keys.
X_TICKS: str = 'x_ticks'
X_TICK_LABELS: str = 'x_tick_labels'
X_TICK_LABELS_CONFIG: str = 'x_tick_labels_config'
X_LIM: str = 'x_lim'
X_LABEL: str = 'x_label'
X_LABEL_CONFIG: str = 'x_label_config'

# y-axis configuration keys.
Y_TICKS: str = 'y_ticks'
Y_TICK_LABELS: str = 'y_tick_labels'
Y_TICK_LABELS_CONFIG: str = 'y_tick_labels_config'
Y_TICKS_VISIBLE: str = 'y_ticks_visible'
Y_LIM: str = 'y_lim'
Y_LABEL: str = 'y_label'
Y_LABEL_CONFIG: str = 'y_label_config'
TICK_LABEL_FORMAT: str = 'tick_label_format'

LEGEND_LABEL: str = 'legend_label'

RUN_TIME_UNIT: str = 'run_time_unit'
class RunTimeUnit(Enum):
    NANO_SEC = 1
    MICRO_SEC = 1/1e3
    MILI_SEC = 1/1e6

    @property
    def conversion_factor(self) -> float:
        return self.value # type: ignore

    def __str__(self) -> str:
        if self == RunTimeUnit.NANO_SEC:
            return 'ns'
        elif self == RunTimeUnit.MICRO_SEC:
            return r'$\mu$s'
        elif self == RunTimeUnit.MILI_SEC:
            return 'ms'
        else:
            raise KeyError(f'Unknown unit "{self.name}"!')

# Parameter-name keys (presumably model coefficients — confirm).
ALPHA: str = 'alpha'
BETA: str = 'beta'
GAMMA: str = 'gamma'
ZETA: str = 'zeta'
XI: str = 'xi'
PSI: str = 'psi'
CHI: str = 'chi'
ALPHA_CP: str = 'a_cp'
BETA_CP: str = 'b_cp'
ALPHA_DD: str = 'a_dd'
ALPHA_CELL: str = 'a_cell'
ALPHA_DQ: str = 'a_dq'
ALPHA_SQ: str = 'a_sq'
BETA_SQ: str = 'b_sq'
CELL_COUNT: str = 'cell_count'

# Operation-count labels.
FLOPS: str = 'FLOPS'
MEMOPS: str = 'MEMOPS'
OPS: str = 'OPS'

# Equation-module configuration keys.
EQ_MODULE_NAME_PER_MODEL: str = 'eq_module_name_per_model'
EQ_TYPE: str = 'eq_type'
SPECIAL_VARS: str = 'special_vars'
CONVERT: str = 'convert'
