from __future__ import annotations
from configparser import SectionProxy
from pathlib import Path
import sys
import importlib.util
from types import ModuleType
from typing import Any, Callable, Type, TYPE_CHECKING

import numpy as np

import matplotlib.cm as mpl_cm

from complexity.constants import *
from complexity.models import *
from complexity.infrastructures import *

if TYPE_CHECKING:
    from complexity.base_classes import PedModel

def read_basic_config_section(section: SectionProxy,
                              variable_config: dict[str,str],
                              optional: bool = False) -> dict[str, Any]:
    """Read and type-convert the variables listed in ``variable_config``.

    Args:
        section: Config section the raw string values are read from.
        variable_config: Maps each variable name to the type name that is
            passed on to ``cast_variable``.
        optional: If true, names missing from ``section`` are skipped;
            otherwise looking up a missing name raises ``KeyError``.

    Returns:
        Mapping from variable name to its converted value.
    """
    return {
        name: cast_variable(section[name], type_name)
        for name, type_name in variable_config.items()
        if not optional or name in section
    }


def cast_variable(value_str: str, var_type: str, *args: Any) -> Any:
    """Convert a raw config string to the value type named by ``var_type``.

    Everything from the first ``#`` on is discarded as an inline comment and
    surrounding whitespace is stripped. The literal ``None`` yields ``None``,
    or ``[None]`` for the collection types.

    Args:
        value_str: Raw value as read from the config file.
        var_type: One of the ``*_TYPE`` type-name constants.
        *args: Extra arguments forwarded to ``parse_collection`` (for
            ``COLLECTION_TYPE``) and to the float fallback of ``DISTR_TYPE``.

    Returns:
        The converted value. For ``DIR_TYPE`` the directory (including any
        missing parent directories) is created if it does not exist yet.

    Raises:
        ValueError: If ``value_str`` cannot be converted; the underlying
            error is chained as ``__cause__``.
        TypeError: If ``var_type`` is not a known type name. A ``"None"``
            value still returns ``None`` for any ``var_type``.
    """
    try:
        if '#' in value_str:
            value_str = value_str[:value_str.find('#')]
        value_str = value_str.strip()

        if value_str == str(None):
            if var_type in (COLLECTION_TYPE, INT_COLLECTION_TYPE, FLOAT_COLLECTION_TYPE, DISTR_COLLECTION_TYPE, ARG_COLLECTION_TYPE):
                return [None]
            return None

        if var_type == STRING_TYPE:
            return value_str
        elif var_type == FLOAT_TYPE:
            return float(value_str)
        elif var_type == INT_TYPE:
            return int(value_str)
        elif var_type == BOOL_TYPE:
            # Anything other than a case-insensitive "true" is False.
            return value_str.lower() == 'true'
        elif var_type == DIR_TYPE:
            directory = Path(value_str)
            if not directory.is_dir():
                # parents=True so nested output paths work; exist_ok guards
                # against the directory appearing between check and creation.
                directory.mkdir(parents=True, exist_ok=True)
            return directory
        elif var_type == FLOAT_COLLECTION_TYPE:
            return parse_float_collection(value_str)
        elif var_type == INT_COLLECTION_TYPE:
            return parse_int_collection(value_str)
        elif var_type == COLLECTION_TYPE:
            return parse_collection(value_str, *args)
        elif var_type == DISTR_TYPE:
            if value_str == str(DistrType.HOMOGENEOUS):
                return DistrType.HOMOGENEOUS
            elif value_str == str(DistrType.CONCENTRATED):
                return DistrType.CONCENTRATED
            else:
                return cast_variable(value_str, FLOAT_TYPE, *args)
        elif var_type == DISTR_COLLECTION_TYPE:
            return parse_distr_collection(value_str)
        elif var_type == ARG_COLLECTION_TYPE:
            return parse_args_collection(value_str)
        elif var_type == ARG_TYPE:
            return parse_arg(value_str)
    except Exception as err:
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # pass through, and chain the original error to keep its traceback.
        raise ValueError(f'Cannot cast value "{value_str}" to {var_type}!') from err

    raise TypeError(f'Unknown variable type "{var_type}"')

def to_string(value: Any, var_type: str, no_brackets:bool = False) -> str:
    """Format ``value`` as a config string for the type named by ``var_type``.

    Collection types are written as comma-separated items, wrapped in
    parentheses unless ``no_brackets`` is set; a value that is not a list
    or numpy array is treated as a one-item collection. Every other type
    is formatted with ``str``.
    """
    collection_types = (FLOAT_COLLECTION_TYPE, INT_COLLECTION_TYPE, COLLECTION_TYPE, ARG_COLLECTION_TYPE, DISTR_COLLECTION_TYPE)
    if var_type not in collection_types:
        return str(value)

    items = value if isinstance(value, (list, np.ndarray)) else [value]
    joined = ", ".join(str(item) for item in items)
    return joined if no_brackets else f'({joined})'

def clean_comma_separated_str(value_str: str, as_tuples:bool=False) -> list[str]:
    """Split a collection string such as ``"[1, 2, 3]"`` into its entries.

    Surrounding whitespace and one leading ``[`` / trailing ``]`` are
    removed, and all whitespace between entries is dropped.

    Args:
        value_str: The raw collection string.
        as_tuples: Treat the string as a list of parenthesised tuples, e.g.
            ``"(1, 2), (3, 4)"`` -> ``["(1,2)", "(3,4)"]``. If the string
            contains exactly one ``(``, the text between its first and last
            character is returned as the only entry, with its inner
            whitespace kept (``"(1, 2)"`` -> ``["1, 2"]``).

    Returns:
        The entries as strings. Without ``as_tuples``, a leading ``(`` and a
        trailing ``)`` are dropped as well (``"(1, 2)"`` -> ``["1", "2"]``).

    Raises:
        ValueError: If nothing is left once brackets and whitespace are removed.
    """
    text = value_str.strip()
    if text.startswith('['):
        text = text[1:]
    if text.endswith(']'):
        text = text[:-1]
    # Strip again: whitespace may sit just inside the brackets, e.g. "[ (1, 2) ]".
    text = text.strip()
    if not text:
        raise ValueError(f'Empty collection string "{value_str}"')

    content = "".join(text.split())
    if as_tuples:
        if text.count('(') == 1:
            return [text[1:-1]]
        return [entry if entry.endswith(')') else f'{entry})' for entry in content.split('),')]

    # Judge both ends on the same whitespace-free string so that e.g.
    # "(1, 2) " loses its closing parenthesis too.
    if content.startswith('('):
        content = content[1:]
    if content.endswith(')'):
        content = content[:-1]
    return content.split(',')

def parse_int_collection(value_str: str, as_np:bool=False, return_sorted:bool=True) -> IntSeq:
    """Parse a collection string into integers; see ``parse_num_collection``."""
    return parse_num_collection(value_str, num_fcn=int, as_np=as_np, return_sorted=return_sorted)  # type: ignore

def parse_float_collection(value_str: str, as_np:bool=False, return_sorted:bool=True) -> FloatSeq:
    """Parse a collection string into floats; see ``parse_num_collection``."""
    return parse_num_collection(value_str, num_fcn=float, as_np=as_np, return_sorted=return_sorted)  # type: ignore

def parse_num_collection(value_str: str, num_fcn: Callable[[str], int|float], as_np:bool=False, return_sorted:bool=True) -> NumSeq:
    """Parse a comma-separated collection of numbers, expanding range patterns.

    Each entry is either a single number or an inclusive range written
    ``start:stop`` (step 1) or ``start:step:stop``; e.g. ``"1, 4:2:9"``
    yields ``[1, 4, 6, 8]``. A range never goes past ``stop``.

    Args:
        value_str: Raw collection string (split by ``clean_comma_separated_str``).
        num_fcn: Converter applied to every number token, e.g. ``int`` or ``float``.
        as_np: Return a numpy array instead of a list.
        return_sorted: Sort the result; if any range pattern was used,
            duplicates are removed as well.

    Returns:
        The numbers as a list of Python ints/floats, or a numpy array if ``as_np``.

    Raises:
        ValueError: If a token cannot be converted or a range pattern is malformed.
    """
    num_range: list[int|float] = []
    check_unique = False
    for entry in clean_comma_separated_str(value_str):
        if ":" in entry:
            check_unique = True
            num_range += _expand_range_entry(entry, num_fcn)
        else:
            num_range.append(num_fcn(entry))

    if return_sorted:
        num_range = sorted(set(num_range)) if check_unique else sorted(num_range)
    if as_np:
        return np.array(num_range)

    return num_range


def _expand_range_entry(entry: str, num_fcn: Callable[[str], int|float]) -> list[int|float]:
    """Expand one ``start:stop`` / ``start:step:stop`` token into an inclusive list."""
    parts = [num_fcn(number_str.strip()) for number_str in entry.split(':')]
    if len(parts) == 2:
        start, step, stop = parts[0], 1, parts[1]
    elif len(parts) == 3:
        start, step, stop = parts
    else:
        raise ValueError(f'Invalid pattern "{entry}"')
    if step == 0:
        raise ValueError(f'Invalid pattern "{entry}": step must be non-zero')

    # Number of points start + i*step that lie within [start, stop]. The small
    # epsilon absorbs float round-off so that e.g. "0:0.1:0.3" still reaches 0.3,
    # while "1:2:10" stops at 9 instead of overshooting to 11.
    count = int(np.floor((stop - start) / step + 1e-9)) + 1
    if count <= 0:
        return []
    # .tolist() yields plain Python ints/floats rather than numpy scalars.
    return (start + step * np.arange(count)).tolist()

def parse_distr_collection(value_str: str) -> list[float|DistrType]:
    """Parse a comma-separated list of distribution entries.

    Every entry goes through ``cast_variable(entry, DISTR_TYPE)``: the string
    forms of ``DistrType.HOMOGENEOUS`` / ``DistrType.CONCENTRATED`` map to those
    members, any other entry is cast to ``float``.
    """
    return [cast_variable(token, DISTR_TYPE) for token in clean_comma_separated_str(value_str)]

def parse_args_collection(value_str: str) -> list[Any]:
    """Parse a list of parenthesised argument groups.

    The string is split with ``clean_comma_separated_str(..., as_tuples=True)``
    and each resulting entry is converted with ``parse_arg``.
    """
    entries = clean_comma_separated_str(value_str, as_tuples=True)
    return list(map(parse_arg, entries))

def parse_arg(value_str:str) -> Any:
    """Convert one argument string to the most specific Python value.

    Checked in this order:

    * ``"None"`` -> ``None``;
    * integers, e.g. ``"3"`` or ``"1, 2"`` -> tuple of ints in input order
      (a single number also becomes a 1-tuple);
    * floats, likewise -> tuple of floats;
    * ``"true"`` / ``"false"`` (any case) -> ``bool``;
    * anything else is returned unchanged as a string.
    """
    token = value_str.strip()
    if token == str(None):
        return None
    if can_be_casted_as(value_str, INT_COLLECTION_TYPE):
        return tuple(parse_int_collection(value_str, return_sorted=False))
    if can_be_casted_as(value_str, FLOAT_COLLECTION_TYPE):
        return tuple(parse_float_collection(value_str, return_sorted=False))
    # cast_variable(..., BOOL_TYPE) maps every string other than "true" to False
    # without raising, so can_be_casted_as(..., BOOL_TYPE) is always True and
    # cannot tell booleans from plain text. Match the literal spellings instead.
    if token.lower() in ('true', 'false'):
        return token.lower() == 'true'
    if can_be_casted_as(value_str, INT_TYPE):
        return cast_variable(value_str, INT_TYPE)
    if can_be_casted_as(value_str, FLOAT_TYPE):
        return cast_variable(value_str, FLOAT_TYPE)

    return value_str

def can_be_casted_as(value_str:str, var_type: str) -> bool:
    """Return True if ``cast_variable(value_str, var_type)`` succeeds.

    Only a ``ValueError`` (failed conversion) counts as "cannot be cast"; the
    ``TypeError`` raised for an unknown ``var_type`` propagates to the caller.
    """
    try:
        cast_variable(value_str, var_type)
    except ValueError:
        return False
    else:
        return True


def parse_collection(value_str: str, types: Optional[Sequence[str]|str] = None) -> list[Any]:
    """Split a comma-separated collection and optionally cast each entry.

    Args:
        value_str: Raw collection string (split by ``clean_comma_separated_str``).
        types: ``None`` to return the raw string entries; a single type name
            applied to every entry; or a sequence with one type name per
            entry (extra trailing type names are ignored).

    Returns:
        The list of (possibly cast) entries.

    Raises:
        ValueError: If an entry cannot be cast, or ``types`` is a sequence
            with fewer type names than there are entries.
    """
    contents = clean_comma_separated_str(value_str)
    # Identity check: ``== None`` misbehaves for array-like ``types``.
    if types is None:
        return contents

    if isinstance(types, str):
        return [cast_variable(entry, types) for entry in contents]

    if len(types) < len(contents):
        raise ValueError(
            f'Expected at least {len(contents)} types for "{value_str}", got {len(types)}')
    return [cast_variable(entry, entry_type) for entry, entry_type in zip(contents, types)]

def get_classes(section: SectionProxy,
                package_name: str) -> dict[str, Type[PedModel]]:
    """Resolve ``module.ClassName`` config entries to class objects.

    Each value in ``section`` names a class as ``<module>.<ClassName>``
    relative to the ``complexity.<package_name>`` package; ``<module>`` may
    itself be dotted. Inline ``#`` comments and surrounding whitespace are
    ignored.

    Args:
        section: Config section mapping entry names to class paths.
        package_name: Sub-package of ``complexity`` that holds the modules.

    Returns:
        Mapping from entry name to the referenced class.

    Raises:
        ValueError: If an entry is not of the form ``module.ClassName``.
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute ``ClassName``.
    """
    classes = {}
    for entry_name in section:
        class_path = section[entry_name].split('#', 1)[0].strip()
        module_name, separator, class_name = class_path.rpartition('.')
        if not separator or not module_name or not class_name:
            raise ValueError(
                f'Entry "{entry_name}" must have the form "module.ClassName", got "{class_path}"')
        # import_module returns the cached entry from sys.modules when the
        # module is already loaded and imports it otherwise.
        module = importlib.import_module(f'complexity.{package_name}.{module_name}')
        classes[entry_name] = getattr(module, class_name)

    return classes


def get_colors_from_cmap(cmap_name:str, color_info:int|Sequence[float], start_buffer:float = 0, end_buffer:float = 0) -> list[tuple[float, float, float,float]]:
    """Sample RGBA colors from a named matplotlib colormap.

    Args:
        cmap_name: Name of a registered matplotlib colormap.
        color_info: Either the number of colors to sample at evenly spaced
            positions from ``start_buffer`` to ``1 - end_buffer``, or an
            explicit list/tuple/numpy array of positions passed straight to
            the colormap.
        start_buffer: Position of the first sample (integer ``color_info`` only).
        end_buffer: Distance of the last sample from 1 (integer ``color_info`` only).

    Returns:
        One RGBA tuple per sampled position.

    Raises:
        NotImplementedError: If ``color_info`` has an unsupported type.
    """
    import matplotlib  # needed for the matplotlib.colormaps registry

    registry = getattr(matplotlib, 'colormaps', None)
    if registry is not None and hasattr(registry, 'get_cmap'):
        # matplotlib >= 3.6. matplotlib.cm.get_cmap was deprecated in 3.7 and
        # removed in 3.9, so only fall back to it on older versions.
        cm = registry.get_cmap(cmap_name)
    else:
        cm = mpl_cm.get_cmap(cmap_name)

    if isinstance(color_info, (tuple, list, np.ndarray)):
        color_locs = color_info
    elif isinstance(color_info, (int, np.integer)):
        color_locs = np.linspace(start_buffer, 1 - end_buffer, color_info)
    else:
        raise NotImplementedError(f'Unsupported color_info type: {type(color_info).__name__}')

    colors = cm(color_locs) # type: ignore
    if isinstance(colors, tuple): # type: ignore
        # A scalar position yields a single RGBA tuple rather than an array.
        return [colors]

    return [tuple(row) for row in colors]

def load_module(filename: Path) -> ModuleType:
    """Import a Python source file and register it in ``sys.modules``.

    The module is registered under the file's stem (``foo.py`` -> ``foo``),
    replacing any module already registered under that name.

    Args:
        filename: Path to the Python source file (a ``str`` also works).

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for
            ``filename``, e.g. because its suffix is not a Python source suffix.
        Exception: Whatever executing the module's code raises. In that case
            ``sys.modules`` is restored to its previous state for that name.
    """
    module_name = Path(filename).stem
    spec = importlib.util.spec_from_file_location(module_name, filename)
    if spec is None or spec.loader is None:
        raise ImportError(f'Cannot load module from "{filename}"')
    module = importlib.util.module_from_spec(spec)

    # Register before executing, as in the importlib "import a source file
    # directly" recipe, so that code run during import can find the module.
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module


