from __future__ import annotations
from dataclasses import dataclass, field
from math import ceil, exp, floor, log, pi, sqrt
from complexity.base_classes import PedModel, Infrastructure

from typing import TYPE_CHECKING, Optional

import numpy as np
import numpy.typing as npt

from complexity.constants import *

if TYPE_CHECKING:
    from complexity.computational_complexity import Experiment

# Edge length of the route floor-field cells that hold the preferred walking
# directions (same length unit as pedestrian positions).
ROUTE_CELL_SIZE = 0.1
# Cut-off distance for pedestrian-pedestrian interaction in the grid model.
PED_RANGE = 2  #m
PED_RANGE_2 = PED_RANGE**2  # squared cut-off, compared against squared distances
# Spatial-grid cells are exactly as wide as the interaction range, so the 3x3
# block of cells around a pedestrian contains every neighbour within PED_RANGE.
SPATIAL_GRID_SIZE = PED_RANGE  #m
# Per-axis cell offsets spanning that 3x3 neighbourhood.
OFF_SET_TUPLE = (-1, 0, 1)



@dataclass(frozen=True)
class WeidemannFD():
    """Weidmann fundamental diagram relating pedestrian density to speed and flow.

    Speed follows ``v(rho) = v_f * (1 - exp(a * (1/rho - 1/rho_jam)))`` with
    ``v_f = free_flow_speed`` and ``rho_jam = jam_density``. The exponent factor
    ``a`` is derived in ``__post_init__`` so that the flow
    ``v(critical_density) * critical_density`` equals ``capacity``. Units are
    presumably ped/m2 for densities and m/s for speeds (not enforced).
    """
    jam_density: float = 6.4
    capacity: float = 1.22495
    free_flow_speed: float = 1.34
    critical_density: float = 1.75
    # Derived in __post_init__; not constructor arguments.
    exp_factor: float = field(init=False)
    capacity_speed: float = field(init=False)
    jam_density_inv: float = field(init=False)

    def __post_init__(self) -> None:
        """Derive ``exp_factor``, ``jam_density_inv`` and ``capacity_speed``."""
        # Solve v(rho_c) * rho_c == capacity for the exponent factor:
        # exp(a*(1/rho_c - 1/rho_jam)) == 1 - capacity/(v_f*rho_c).
        exp_factor = log(1 - self.capacity/
                         (self.free_flow_speed*self.critical_density))/(
                             1/self.critical_density - 1/self.jam_density)
        # The dataclass is frozen, so derived fields go through object.__setattr__.
        object.__setattr__(self, 'exp_factor', exp_factor)
        object.__setattr__(self, 'jam_density_inv', 1/self.jam_density)
        # Must run after exp_factor and jam_density_inv are set (_calc_speed reads both).
        capacity_speed = self._calc_speed(self.critical_density)
        object.__setattr__(self, 'capacity_speed', capacity_speed)

    def get_speed(self, density: float) -> float:
        """Return the speed at ``density``; zero density yields ``free_flow_speed``.

        The zero case is special-cased because the formula divides by density.
        """
        if density == 0:
            return self.free_flow_speed

        return self._calc_speed(density)

    def get_speed_np(self, density: npt.NDArray) -> npt.NDArray:
        """Vectorised ``get_speed``: elements with density <= 0 get ``free_flow_speed``."""
        speeds = np.ones_like(density)*self.free_flow_speed
        occupied = density > 0
        speeds[occupied] = self.free_flow_speed*(1 - np.exp(
            self.exp_factor*(1/density[occupied] - self.jam_density_inv)))

        return speeds

    def _calc_speed(self, density: float) -> float:
        """Evaluate the speed formula; ``density`` must be non-zero."""
        return self.free_flow_speed*(1 - exp(self.exp_factor*
                                             (1/density - self.jam_density_inv)))

    def get_flow(self, density: float) -> float:
        """Return the flow ``speed(density) * density``."""
        return self.get_speed(density)*density

    def get_local_demand_flow(self, density: float, sum_density:float) -> float:
        """Return the demand flow of a density component within a total density.

        Args:
            density: Density of the component.
            sum_density: Total density the component is part of.

        Returns:
            0.0 for an empty total (``sum_density == 0``). Otherwise, if
            ``density < critical_density``, ``density * v(sum_density)``.
            Otherwise the share ``density/sum_density`` of
            ``critical_density*free_flow_speed``.
        """
        # Without this guard _calc_speed(0) raised ZeroDivisionError; an empty
        # total has nothing to send (mirrors the guard in get_local_capacity_flow).
        if sum_density == 0:
            return 0.0

        # NOTE(review): the regime switch tests the component ``density`` rather
        # than ``sum_density``, and the maximum used below is
        # critical_density*free_flow_speed rather than ``capacity`` - confirm both.
        if density < self.critical_density:
            return density*self._calc_speed(sum_density)

        return density/sum_density*self.critical_density*self.free_flow_speed

    def get_local_capacity_flow(self, density: float, sum_density:float) -> float:
        """Return the capacity flow of a density component within a total density.

        An empty total yields ``free_flow_speed*critical_density``. Below
        ``critical_density`` (tested on ``density``) the component gets its share
        ``density/sum_density`` of that value; otherwise ``density * v(sum_density)``.
        """
        if sum_density == 0:
            return self.free_flow_speed*self.critical_density

        if density < self.critical_density:
            return density/sum_density*self.critical_density*self.free_flow_speed

        return density*self._calc_speed(sum_density)


# Shared diagram with default parameters. It supplies Pedestrian's default
# preferred speed and the initial speeds set in SocialForceModel._create_pedestrians.
WEIDEMANN_FD = WeidemannFD()


@dataclass
class Pedestrian():
    """Mutable state and parameters of one social-force pedestrian.

    Position, velocity and acceleration are stored per axis. ``grid_loc`` and
    ``array_loc`` are bookkeeping for SocialForceModelGrid's spatial grid and
    default to ``(-1, -1)`` / ``None`` until the pedestrian is placed in a cell.
    """

    pos_x: float
    pos_y: float
    vel_x: float
    vel_y: float

    acc_x: float = 0.0
    acc_y: float = 0.0
    grid_loc: tuple[int, int] = (-1, -1)  # (x, y) index of the spatial-grid cell
    array_loc: Optional[int] = None  # slot index inside that cell

    radius: float = 0.2  #m
    relaxation_time: float = 0.5  # time scale of relaxing to the preferred velocity

    # Desired walking speed; defaults to the shared diagram's free-flow speed.
    pref_speed: float = WEIDEMANN_FD.free_flow_speed
    # Magnitude and decay length of the exponential repulsion from other
    # pedestrians (force = mag * exp((r_i + r_j - dist) / sig)).
    ped_force_mag: float = 1.0
    ped_force_sig: float = 1.0
    # Same for obstacles (force = mag * exp((r_i - dist) / sig)).
    obs_force_mag: float = 1.0
    obs_force_sig: float = 1.0

    def __post_init__(self) -> None:
        # Cache the reciprocal so the per-step relaxation term multiplies
        # instead of divides; it is not a dataclass field.
        self.relaxation_time_inv: float = 1/self.relaxation_time


class SocialForceModel(PedModel):
    """Simple social-force model with all-pairs pedestrian repulsion.

    Every call to ``run`` advances all pedestrians by one explicit Euler step of
    ``DELTA_T``. A pedestrian's acceleration is the relaxation towards its
    preferred velocity (``pref_speed`` along the route floor field), plus an
    exponential repulsion from every other pedestrian and from the closest point
    of every obstacle line.
    """

    JAM_DENSITY = 1/0.16  #ped/m2
    DELTA_T = 0.05

    def __init__(self,
                 infrastructure: Infrastructure,
                 ped_parameters: Optional[dict] = None) -> None:
        """Create the model.

        Args:
            infrastructure: Infrastructure passed on to ``PedModel.__init__``.
            ped_parameters: Keyword arguments forwarded to every ``Pedestrian``
                built by ``set_state``; ``None`` means no overrides.
        """
        if ped_parameters is None:
            ped_parameters = {}
        self.ped_parameters: dict = ped_parameters
        super().__init__(infrastructure)

    def run(self) -> None:
        """Advance every pedestrian by one time step of ``DELTA_T``.

        Pedestrians are updated in place one after another. Pedestrians later
        in ``self.pedestrians`` therefore feel the already-advanced positions
        of earlier ones.
        """
        for ped in self.pedestrians:
            pref_dir_x, pref_dir_y = self.get_pref_dir(ped.pos_x, ped.pos_y)
            acc_x = ped.relaxation_time_inv*(ped.pref_speed*pref_dir_x
                                             - ped.vel_x)
            acc_y = ped.relaxation_time_inv*(ped.pref_speed*pref_dir_y
                                             - ped.vel_y)

            for other_ped in self.pedestrians:
                # Identity, not ``==``: the dataclass __eq__ compares every field
                # (costly in this O(n^2) loop). It would also skip a distinct
                # pedestrian that merely has identical state.
                if other_ped is ped:
                    continue

                x_diff = ped.pos_x - other_ped.pos_x
                y_diff = ped.pos_y - other_ped.pos_y
                dist = sqrt(x_diff*x_diff + y_diff*y_diff)
                if dist == 0:
                    # Coincident pedestrians: the repulsion direction is undefined.
                    continue
                force_mag = ped.ped_force_mag*exp(
                    (ped.radius + other_ped.radius - dist)/ped.ped_force_sig)
                acc_x += x_diff/dist*force_mag
                acc_y += y_diff/dist*force_mag

            for line in self.obstacles:
                closest_x, closest_y = closest_point_on_line(line, ped.pos_x, ped.pos_y)
                x_diff = ped.pos_x - closest_x
                y_diff = ped.pos_y - closest_y
                dist = sqrt(x_diff*x_diff + y_diff*y_diff)
                force_mag = ped.obs_force_mag*exp(
                    (ped.radius - dist)/ped.obs_force_sig)
                acc_x += x_diff/dist*force_mag
                acc_y += y_diff/dist*force_mag

            # Explicit Euler: velocity first, then position with the new velocity.
            ped.vel_x += acc_x*self.DELTA_T # type: ignore
            ped.vel_y += acc_y*self.DELTA_T # type: ignore

            ped.pos_x += ped.vel_x*self.DELTA_T
            ped.pos_y += ped.vel_y*self.DELTA_T

    def set_state(self, experiment: Experiment) -> None:
        """(Re)create the pedestrians for ``experiment``."""
        self._create_pedestrians(experiment)

    def initialize(self) -> None:
        """Precompute obstacle segments and load the route floor field.

        Each obstacle's two end points are turned into
        ``((x, y), (delta_x, delta_y), 1/|delta|^2)``, the form consumed by
        ``closest_point_on_line``.
        """
        self.obstacles: list[Line] = []
        for obstacle in self.infra.obstacles:
            x = obstacle[0][0]
            y = obstacle[0][1]
            delta_x = obstacle[1][0] - x
            delta_y = obstacle[1][1] - y
            length_inv = 1/(delta_x*delta_x + delta_y*delta_y)

            self.obstacles.append(((x, y), (delta_x, delta_y), length_inv))

        self.pref_direction = self.infra.get_route_floor_field(ROUTE_CELL_SIZE)

    def _create_pedestrians(self, experiment: Experiment) -> None:
        """Place ``experiment.ped_count`` pedestrians and set their initial velocity.

        Each pedestrian's initial velocity is the Weidmann speed for the density
        reported by the infrastructure, directed along the preferred direction of
        its floor-field cell. Requires ``initialize`` to have set ``pref_direction``.
        """
        # Create pedestrians based on density distribution
        pos_per_ped, density = self.infra.get_ped_pos_continuous(experiment.ped_count, self.JAM_DENSITY, experiment.local_density)
        pref_speed = WEIDEMANN_FD.get_speed(density)
        self.pedestrians: list[Pedestrian] = []
        for ped_pos in pos_per_ped:
            row_ind, col_ind = get_cell_ind_from_pos(*ped_pos)
            vel = tuple(pref_speed*self.pref_direction[row_ind, col_ind, :])
            self.pedestrians.append(
                Pedestrian(*ped_pos, *vel, **self.ped_parameters))

    def get_pref_dir(self, pos_x: float, pos_y: float) -> tuple[float, float]:
        """Return the (x, y) preferred direction stored for the cell containing the position."""
        row_ind, col_ind = get_cell_ind_from_pos(pos_x, pos_y)
        return self.pref_direction[row_ind, col_ind,0], self.pref_direction[row_ind, col_ind,1]

    @property
    def tag(self) -> str:
        return "simple_social_force"

    @property
    def label(self) -> str:
        return "Simple Social Force"


def get_cell_ind_from_pos(pos_x: float, pos_y: float) -> tuple[int, int]:
    """Map a position to the (row, column) index of its route floor-field cell.

    Rows follow the y axis and columns the x axis; cells are ``ROUTE_CELL_SIZE``
    wide. Flooring (not truncation) keeps negative coordinates in the cell
    below zero.
    """
    row = int(floor(pos_y/ROUTE_CELL_SIZE))
    col = int(floor(pos_x/ROUTE_CELL_SIZE))
    return row, col


def closest_point_on_line(line: Line, pos_x: float, pos_y: float) -> Vector:
    """Return the point of a line segment that is closest to (pos_x, pos_y).

    ``line`` holds ``(start, delta, 1/|delta|^2)``; the segment runs from
    ``start`` to ``start + delta``. The position is projected onto ``delta``,
    and the projection parameter is clamped to [0, 1] so the result never
    lies beyond an end point.
    """
    origin, direction, inv_len_2 = line[0], line[1], line[2]
    start_x, start_y = origin[0], origin[1]
    delta_x, delta_y = direction[0], direction[1]

    t = ((pos_x - start_x)*delta_x + (pos_y - start_y)*delta_y)*inv_len_2
    # For the two corridor borders (no interior obstacles) t already lies in [0, 1].
    t = min(max(t, 0), 1)

    return start_x + t*delta_x, start_y + t*delta_y


class SocialForceModelNoPedInteraction(SocialForceModel):
    """Social-force variant without pedestrian-pedestrian repulsion.

    Only the relaxation towards the preferred velocity and the obstacle
    repulsion act on a pedestrian, so one step loops over pedestrians times
    obstacles instead of over pedestrian pairs.
    """

    def run(self) -> None:
        """Advance every pedestrian by one explicit Euler step of ``DELTA_T``."""
        dt = self.DELTA_T
        for ped in self.pedestrians:
            dir_x, dir_y = self.get_pref_dir(ped.pos_x, ped.pos_y)
            tau_inv = ped.relaxation_time_inv
            speed = ped.pref_speed
            acc_x = tau_inv*(speed*dir_x - ped.vel_x)
            acc_y = tau_inv*(speed*dir_y - ped.vel_y)

            # Exponential push away from the nearest point of every obstacle.
            for obstacle in self.obstacles:
                near_x, near_y = closest_point_on_line(obstacle, ped.pos_x, ped.pos_y)
                dx = ped.pos_x - near_x
                dy = ped.pos_y - near_y
                gap = sqrt(dx*dx + dy*dy)
                strength = ped.obs_force_mag*exp((ped.radius - gap)/ped.obs_force_sig)
                acc_x += dx/gap*strength
                acc_y += dy/gap*strength

            # Velocity is updated first; the position then uses the new velocity.
            ped.vel_x += acc_x*dt  # type: ignore
            ped.vel_y += acc_y*dt  # type: ignore
            ped.pos_x += ped.vel_x*dt
            ped.pos_y += ped.vel_y*dt


class SocialForceModelGrid(SocialForceModel):
    """Social-force model that limits pedestrian interactions to ``PED_RANGE``.

    Pedestrians are bucketed into a spatial grid whose cells are
    ``SPATIAL_GRID_SIZE`` (== ``PED_RANGE``) wide, so only the 3x3 block of cells
    around a pedestrian is searched for neighbours. ``ped_cell_pos`` is an
    object array of shape ``(x_count, y_count, array_size)``. Its last axis
    holds the pedestrians of a cell, with ``None`` marking a free slot.
    """

    JAM_DENSITY = 1/0.16  #ped/m2
    DELTA_T = 0.05

    def __init__(self,
                 infrastructure: Infrastructure,
                 ped_parameters: Optional[dict] = None) -> None:
        super().__init__(infrastructure, ped_parameters)

    def run(self) -> None:
        """Advance every pedestrian by one explicit Euler step of ``DELTA_T``.

        Pedestrians are updated in place one after another. Each is moved to
        its new grid cell right away, so later pedestrians see the
        already-advanced state of earlier ones.
        """
        for ped in self.pedestrians:
            pref_dir_x, pref_dir_y = self.get_pref_dir(ped.pos_x, ped.pos_y)
            acc_x = ped.relaxation_time_inv*(ped.pref_speed*pref_dir_x
                                             - ped.vel_x)
            acc_y = ped.relaxation_time_inv*(ped.pref_speed*pref_dir_y
                                             - ped.vel_y)

            x_grid, y_grid = ped.grid_loc
            for x_offset in OFF_SET_TUPLE:
                x = x_grid + x_offset
                if x < 0 or x >= self.x_count:
                    continue
                for y_offset in OFF_SET_TUPLE:
                    y = y_grid + y_offset
                    if y < 0 or y >= self.y_count:
                        continue
                    for other_ped in self.ped_cell_pos[x, y]:
                        # Identity checks: free slots hold None. The dataclass
                        # ``==`` would compare every field of two pedestrians.
                        if other_ped is None or other_ped is ped:
                            continue

                        x_diff = ped.pos_x - other_ped.pos_x
                        y_diff = ped.pos_y - other_ped.pos_y
                        dist_2 = x_diff*x_diff + y_diff*y_diff
                        # Skip pedestrians out of range, and coincident ones
                        # (repulsion direction undefined at zero distance).
                        if dist_2 > PED_RANGE_2 or dist_2 == 0:
                            continue

                        dist = sqrt(dist_2)
                        force_mag = ped.ped_force_mag*exp(
                            (ped.radius + other_ped.radius - dist)
                            /ped.ped_force_sig)
                        acc_x += x_diff/dist*force_mag
                        acc_y += y_diff/dist*force_mag

            for line in self.obstacles:
                closest_x, closest_y = closest_point_on_line(line, ped.pos_x, ped.pos_y)
                x_diff = ped.pos_x - closest_x
                y_diff = ped.pos_y - closest_y
                dist = sqrt(x_diff*x_diff + y_diff*y_diff)
                force_mag = ped.obs_force_mag*exp(
                    (ped.radius - dist)/ped.obs_force_sig)
                acc_x += x_diff/dist*force_mag
                acc_y += y_diff/dist*force_mag

            ped.vel_x += acc_x*self.DELTA_T # type: ignore
            ped.vel_y += acc_y*self.DELTA_T # type: ignore

            ped.pos_x += ped.vel_x*self.DELTA_T
            ped.pos_y += ped.vel_y*self.DELTA_T

            new_grid_x, new_grid_y = get_spatial_grid_pos(ped.pos_x,
                                                          ped.pos_y,
                                                          SPATIAL_GRID_SIZE)
            # Re-bucket whenever the cell index changes along EITHER axis.
            if new_grid_x != x_grid or new_grid_y != y_grid:
                self.remove_ped_from_grid(ped)
                self.add_ped_to_grid(ped, new_grid_x, new_grid_y)

    def set_state(self, experiment:Experiment) -> None:
        """Create the pedestrians, then bucket them into a fresh spatial grid."""
        super().set_state(experiment)
        self.min_ped_radius = min(ped.radius for ped in self.pedestrians)
        self.create_spatial_grid()

    def add_ped_to_grid(self, ped: Pedestrian, x: int, y:int) -> None:
        """Store ``ped`` in the first free slot of cell (x, y) and record its location.

        If every slot of the cell is taken, the slot axis of the whole grid is
        enlarged first. Previously the pedestrian was silently left out of the
        grid, and its stale ``array_loc`` could later clear another pedestrian's slot.
        """
        cell = self.ped_cell_pos[x, y]
        slot = next((i for i, occupant in enumerate(cell) if occupant is None),
                    None)
        if slot is None:
            slot = len(cell)
            self._grow_cell_capacity()
        self.ped_cell_pos[x, y, slot] = ped
        ped.array_loc = slot
        ped.grid_loc = (x, y)

    def _grow_cell_capacity(self) -> None:
        """Double the number of slots per cell (adding at least one)."""
        extra = np.full((self.x_count, self.y_count, max(self.array_size, 1)),
                        None, dtype=object)
        self.ped_cell_pos = np.concatenate((self.ped_cell_pos, extra), axis=2)
        self.array_size = self.ped_cell_pos.shape[2]

    def remove_ped_from_grid(self, ped: Pedestrian) -> None:
        """Free ``ped``'s grid slot (if it has one) and mark it as off-grid."""
        # Indexing a numpy array with ``None`` adds an axis, so clearing with
        # array_loc=None would wipe a whole cell; only clear a real slot.
        if ped.array_loc is not None:
            self.ped_cell_pos[ped.grid_loc][ped.array_loc] = None
        ped.grid_loc = (-1, -1)
        ped.array_loc = None

    def initialize(self) -> None:
        super().initialize()

    def create_spatial_grid(self) -> None:
        """Allocate the spatial grid and place every pedestrian in its cell."""
        # Slots per cell: the pedestrians a cell holds at jam density, i.e.
        # cell area [m2] * JAM_DENSITY [ped/m2]. Dividing instead gave 1 slot.
        # add_ped_to_grid grows the grid should a cell still overflow.
        z_size = int(
            ceil(SPATIAL_GRID_SIZE*SPATIAL_GRID_SIZE*self.JAM_DENSITY))
        self.ped_cell_pos, _, _ = self.infra.get_grid_representation(
            SPATIAL_GRID_SIZE, z_size=z_size, dtype=object, default_val=None)
        self.x_count, self.y_count, self.array_size = self.ped_cell_pos.shape

        for ped in self.pedestrians:
            new_grid_loc = get_spatial_grid_pos(ped.pos_x,
                                                ped.pos_y,
                                                SPATIAL_GRID_SIZE)
            self.add_ped_to_grid(ped, *new_grid_loc)

    @property
    def tag(self) -> str:
        return "simple_social_force_grid"

    @property
    def label(self) -> str:
        return "Simple Social Force Grid"


def get_unit_spatial_grid_pos(x: float, y: float) -> tuple[int, int]:
    """Return the (x, y) indices of the 1 x 1 grid cell that contains the point."""
    cell_x, cell_y = (int(floor(coord)) for coord in (x, y))
    return cell_x, cell_y


def get_spatial_grid_pos(x: float, y: float,
                         grid_size: float) -> tuple[int, int]:
    """Return the (x, y) indices of the square cell of side ``grid_size`` containing the point.

    Uses ``floor(coord/grid_size)`` rather than ``coord // grid_size``. The two
    can differ by one under float rounding (e.g. ``1.0 // 0.1 == 9.0``).
    """
    ix = int(floor(x/grid_size))
    iy = int(floor(y/grid_size))
    return ix, iy