from typing import TYPE_CHECKING, Callable
import numpy.typing as npt

from complexity.constants import *
if TYPE_CHECKING:
    from computational_complexity import ExperimentSet

def fit_data_scipy_args(x_data: npt.NDArray, y_data: npt.NDArray, base_var: str) -> tuple[Callable, list]:
    """Select the model function and extra fit arguments for ``base_var``.

    Args:
        x_data: Independent-variable samples. Currently unused; kept so the
            signature stays stable for existing callers.
            NOTE(review): presumably reserved for data-dependent initial
            guesses — confirm.
        y_data: Dependent-variable samples. Currently unused, as above.
        base_var: The quantity being varied; must equal ``PED_COUNT`` or
            ``OBSTACLE_COUNT``.

    Returns:
        ``(model, extra_args)``. ``model`` is ``ped_count_fit`` for
        ``PED_COUNT`` or ``obs_count_fit`` for ``OBSTACLE_COUNT``.
        ``extra_args`` is a new, empty list on every call.
        NOTE(review): the list is presumably forwarded to a scipy fitting
        routine (e.g. as ``p0``) — confirm with the caller.

    Raises:
        NotImplementedError: If ``base_var`` is not one of the supported
            quantities.
    """
    if base_var == PED_COUNT:
        return ped_count_fit, []
    if base_var == OBSTACLE_COUNT:
        return obs_count_fit, []
    # Name the offending value so a misconfigured experiment is easy to diagnose.
    raise NotImplementedError(
        f"No fit model defined for base_var={base_var!r}; "
        f"expected {PED_COUNT!r} or {OBSTACLE_COUNT!r}"
    )

def fields(base_var: str) -> list:
    """Return labels for the two terms of the linear model fitted for ``base_var``.

    The first label names the slope term: ``'a*n'`` for ``PED_COUNT`` and
    ``'a*m'`` for ``OBSTACLE_COUNT``. The second label, ``'b'``, names the
    intercept. A new list is returned on every call, so callers may mutate it.

    Args:
        base_var: The quantity being varied; must equal ``PED_COUNT`` or
            ``OBSTACLE_COUNT``.

    Returns:
        A two-element list of label strings.

    Raises:
        NotImplementedError: If ``base_var`` is not one of the supported
            quantities.
    """
    if base_var == PED_COUNT:
        return ['a*n', 'b']
    if base_var == OBSTACLE_COUNT:
        return ['a*m', 'b']
    # Name the offending value so a misconfigured experiment is easy to diagnose.
    raise NotImplementedError(
        f"No field labels defined for base_var={base_var!r}; "
        f"expected {PED_COUNT!r} or {OBSTACLE_COUNT!r}"
    )
    

def ped_count_fit(x: npt.NDArray, a: float, b: float) -> npt.NDArray:
    """Evaluate the linear model ``a * x + b`` element-wise.

    This is the model ``fit_data_scipy_args`` selects for ``PED_COUNT``.

    Args:
        x: Independent-variable values (a scalar also works).
        a: Slope coefficient.
        b: Intercept.

    Returns:
        ``a * x + b``, with the same shape as ``x``.
    """
    slope_term = a * x
    return slope_term + b  # type: ignore

def obs_count_fit(x: npt.NDArray, a: float, b: float) -> npt.NDArray:
    """Evaluate the linear model ``a * x + b`` element-wise.

    This is the model ``fit_data_scipy_args`` selects for ``OBSTACLE_COUNT``.

    Args:
        x: Independent-variable values (a scalar also works).
        a: Slope coefficient.
        b: Intercept.

    Returns:
        ``a * x + b``, with the same shape as ``x``.
    """
    scaled = a * x
    result = scaled + b
    return result