from typing import TYPE_CHECKING, Callable
import numpy.typing as npt

from complexity.constants import *
if TYPE_CHECKING:
    from computational_complexity import ExperimentSet

def fit_data_scipy_args(x_data: npt.NDArray, y_data: npt.NDArray, base_var: str) -> tuple[Callable, list]:
    """Select the model function (and extra arguments) used to fit data for ``base_var``.

    Args:
        x_data: Independent-variable samples. Currently unused; kept so the
            signature stays stable for callers.
        y_data: Dependent-variable samples. Currently unused; kept so the
            signature stays stable for callers.
        base_var: Which experiment variable is being fitted; must equal
            ``INFRA_ARG`` or ``PED_CLASS_COUNT``.

    Returns:
        A ``(model, extra)`` pair, where ``model`` is ``infra_fit`` or
        ``ped_class_count_fit`` and ``extra`` is a fresh empty list.
        NOTE(review): ``extra`` presumably feeds additional scipy fitting
        arguments (e.g. an initial guess) — confirm against the caller.

    Raises:
        NotImplementedError: If ``base_var`` is not a supported variable.
    """
    if base_var == INFRA_ARG:
        return infra_fit, []
    if base_var == PED_CLASS_COUNT:
        return ped_class_count_fit, []
    # Name the offending value so failures are diagnosable.
    raise NotImplementedError(f"no fit function defined for base_var={base_var!r}")

def fields(base_var: str) -> list:
    """Return the display labels for the fitted parameters of ``base_var``'s model.

    The labels correspond positionally to the ``a`` and ``b`` parameters of
    the linear fit functions (slope term, then intercept).

    Args:
        base_var: Which experiment variable is being fitted; must equal
            ``INFRA_ARG`` or ``PED_CLASS_COUNT``.

    Returns:
        A new two-element list of parameter labels.

    Raises:
        NotImplementedError: If ``base_var`` is not a supported variable.
    """
    if base_var == INFRA_ARG:
        return ['a*l', 'b']
    if base_var == PED_CLASS_COUNT:
        return ['a*k', 'b']
    # Every branch above returns; the old trailing bare `return` was
    # unreachable (and would have returned None despite `-> list`).
    raise NotImplementedError(f"no fit fields defined for base_var={base_var!r}")

def infra_fit(x: npt.NDArray, a: float, b: float) -> npt.NDArray:
    """Linear model ``a * x + b``; the fit function selected for ``INFRA_ARG``.

    Args:
        x: Independent-variable values (array or scalar).
        a: Slope coefficient.
        b: Intercept.

    Returns:
        The model evaluated element-wise at ``x``.
    """
    scaled = a * x
    return scaled + b  # type: ignore

def ped_class_count_fit(x: npt.NDArray, a: float, b: float) -> npt.NDArray:
    """Linear model ``a * x + b``; the fit function selected for ``PED_CLASS_COUNT``.

    Args:
        x: Independent-variable values (array or scalar).
        a: Slope coefficient.
        b: Intercept.

    Returns:
        The model evaluated element-wise at ``x``.
    """
    slope_term = a * x
    return slope_term + b  # type: ignore