from __future__ import annotations
from dataclasses import dataclass, field
import itertools
from typing import TYPE_CHECKING
from complexity.base_classes import PedModel

import numpy as np
import numpy.typing as npt

from complexity.constants import DistrType

if TYPE_CHECKING:
    from complexity.computational_complexity import Experiment

CELL_SIZE: int = 1  # edge length of one square grid cell, in metres

@dataclass(frozen=True)
class TriangularFD:
    """Triangular fundamental diagram relating density to flow.

    The flow rises linearly from zero with slope ``free_flow_speed`` up to
    ``capacity`` at ``critical_density``. It then falls linearly with slope
    ``cong_wave_speed`` (negative) back to zero at ``jam_density``.

    Attributes:
        jam_density: Density at which the flow drops to zero.
        free_flow_speed: Slope of the free-flow branch.
        critical_density: Density at which the flow peaks.
        capacity: Peak flow, ``critical_density * free_flow_speed`` (derived).
        cong_wave_speed: Signed slope of the congested branch (derived,
            negative).

    Raises:
        ValueError: If ``jam_density`` does not exceed ``critical_density``.
            Otherwise the congested branch is undefined (division by zero)
            or would slope upwards.
    """

    jam_density: float = 6.25
    free_flow_speed: float = 1.34
    critical_density: float = 1.75
    capacity: float = field(init=False)
    cong_wave_speed: float = field(init=False)

    def __post_init__(self) -> None:
        if self.jam_density <= self.critical_density:
            raise ValueError(
                "jam_density must exceed critical_density, got "
                f"jam_density={self.jam_density}, "
                f"critical_density={self.critical_density}")
        capacity = self.critical_density*self.free_flow_speed
        # The dataclass is frozen, so derived fields are set via object.__setattr__.
        object.__setattr__(self, 'capacity', capacity)
        cong_wave_speed = -capacity/(self.jam_density - self.critical_density)
        object.__setattr__(self, 'cong_wave_speed', cong_wave_speed)

    def get_local_demand_flow(self, density: float, sum_density: float) -> float:
        """Return the demand (sending) flow of one component.

        Args:
            density: Density of the component (e.g. one pedestrian class).
            sum_density: Summed density of all components in the cell.

        Returns:
            ``density * free_flow_speed`` while the total density is below the
            critical density. Otherwise ``capacity`` is shared in proportion
            to the component's share of the total density.
        """
        if sum_density < self.critical_density:
            return density*self.free_flow_speed

        return density/sum_density*self.capacity

    def get_local_capacity_flow(self, density: float, sum_density: float) -> float:
        """Return the supply (receiving) flow associated with one component.

        Args:
            density: Density of the component (e.g. one pedestrian class).
            sum_density: Summed density of all components in the cell.

        Returns:
            The following values, by case:

            - ``capacity`` for an empty cell (``sum_density == 0``).
            - ``capacity`` scaled by the component's share of the total
              density, below the critical density.
            - Otherwise the congested branch, which falls linearly from
              ``capacity`` at the critical density to zero at the jam
              density.
        """
        if sum_density == 0:
            return self.capacity

        # NOTE(review): the regime test uses the component density, whereas
        # get_local_demand_flow uses the summed density - confirm intended.
        if density < self.critical_density:
            return density/sum_density*self.capacity

        # cong_wave_speed is negative, so it must be *added* to make the supply
        # decrease with density. Subtracting it pushed supply above capacity
        # and reached 2*capacity at jam density.
        return self.capacity + (density - self.critical_density)*self.cong_wave_speed



class MacroContGodunov(PedModel):
    """Multi-class macroscopic continuum pedestrian model on a square grid.

    Each call to :meth:`run` performs one Godunov-type cell-transmission step.
    Every cell computes per-class directional demands (what it wants to send
    to each of its four neighbours) and a supply (what it can accept). The
    flow across each face is the minimum of upstream demand and downstream
    supply.

    All per-cell grids are padded with one ghost cell on each side (indices
    ``0`` and ``count + 1``). :meth:`run` only computes demands, supplies and
    densities for the interior cells in ``self.cells``. Ghost cells therefore
    keep zero demand and supply, edge neighbour look-ups need no special
    casing, and nothing crosses the outer boundary.
    """

    def set_to_zero(self) -> None:
        """Allocate fresh zero-filled work arrays for one time step.

        ``d_*`` hold, per cell and class, the demand towards each neighbour.
        ``s_*`` hold the supply a cell grants to the demand arriving from each
        neighbour. ``demand_sum``, ``demand_in`` and ``supply`` are per-cell
        totals. New arrays are allocated on every call.
        """
        self.d_right = np.zeros_like(self.density_grid)
        self.d_left = np.zeros_like(self.density_grid)
        self.d_top = np.zeros_like(self.density_grid)
        self.d_bottom = np.zeros_like(self.density_grid)

        self.demand_sum = np.zeros_like(self.density_grid_sum)
        self.demand_in = np.zeros_like(self.density_grid_sum)
        self.supply = np.zeros_like(self.density_grid_sum)

        self.s_right = np.zeros_like(self.density_grid)
        self.s_left = np.zeros_like(self.density_grid)
        self.s_top = np.zeros_like(self.density_grid)
        self.s_bottom = np.zeros_like(self.density_grid)

    def run(self) -> None:
        """Advance ``density_grid`` by one time step of ``delta_t``.

        Phases:

        1. Demand. For each class, split the pedestrians an occupied cell
           could send this step along the preferred direction into
           right/left/top/bottom demands. Scale these down when the cell's
           total demand exceeds ``cell_area * density`` of that class.
        2. Supply. For each cell, estimate how many pedestrians it can
           accept, capped by the free space left below jam density.
        3. Supply split. Each cell shares its supply among the incoming
           neighbour demands in proportion to their size.
        4. Update. The flow across each face is ``min(supply, demand)``.
           Densities are updated with the net inflow, and
           ``density_grid_sum`` is rebuilt.
        """
        self.set_to_zero()

        # Phase 1: directional demand of every occupied cell.
        for x, y in self.cells:
            if self.density_grid_sum[x, y] > 0:
                for ii in range(self.ped_class_count):
                    if self.density_grid[x, y, ii] > 0:
                        # Pedestrians of class ii that could leave this step.
                        ped_count = self.step_size*self.fd.get_local_demand_flow(
                            self.density_grid[x, y, ii],
                            self.density_grid_sum[x, y])

                        # Positive dir_x goes right, negative left; positive
                        # dir_y goes top (y + 1), negative bottom (y - 1).
                        self.d_right[x, y, ii] = max(self.dir_x[x, y, ii], 0)*ped_count
                        self.d_left[x, y, ii] = -1*min(self.dir_x[x, y, ii], 0)*ped_count
                        self.d_top[x, y, ii] = max(self.dir_y[x, y, ii], 0)*ped_count
                        self.d_bottom[x, y, ii] = -1*min(self.dir_y[x, y, ii], 0)*ped_count

                        self.demand_sum[x, y] += (
                            self.d_right[x, y, ii] + self.d_left[x, y, ii]
                            + self.d_top[x, y, ii] + self.d_bottom[x, y, ii])

                # Cap each class's outflow by the pedestrians present. A zero
                # total demand means every d_* entry here is already zero, so
                # skip the scaling to avoid a 0-division (inf/nan + warnings).
                # NOTE(review): the class's own count is compared with the
                # demand summed over *all* classes - confirm this is intended
                # rather than cell_area*density_grid_sum[x, y].
                demand_total = self.demand_sum[x, y]
                if demand_total > 0:
                    for ii in range(self.ped_class_count):
                        ratio = min(
                            1,
                            self.cell_area*self.density_grid[x, y, ii]
                            /demand_total)
                        self.d_right[x, y, ii] = ratio*self.d_right[x, y, ii]
                        self.d_left[x, y, ii] = ratio*self.d_left[x, y, ii]
                        self.d_top[x, y, ii] = ratio*self.d_top[x, y, ii]
                        self.d_bottom[x, y, ii] = ratio*self.d_bottom[x, y, ii]

        # Phase 2: incoming demand and supply of every cell.
        for x, y in self.cells:
            supply_est = 0.0
            for ii in range(self.ped_class_count):
                # Demand arriving from the four neighbours (ghost cells add 0).
                self.demand_in[x, y] += (
                    self.d_right[x - 1, y, ii] + self.d_left[x + 1, y, ii]
                    + self.d_top[x, y - 1, ii] + self.d_bottom[x, y + 1, ii])
                supply_est += (
                    abs(self.dir_y[x, y, ii]) + abs(self.dir_x[x, y, ii])
                )*self.step_size*self.fd.get_local_capacity_flow(
                    self.density_grid[x, y, ii],
                    self.density_grid_sum[x, y])
            # Never accept more pedestrians than fit before reaching jam density.
            self.supply[x, y] = min(
                self.cell_area*
                (self.fd.jam_density - self.density_grid_sum[x, y]),
                supply_est)

        # Phase 3: share each cell's supply among incoming demands
        # proportionally. s_right is the supply granted to the demand coming
        # from the right neighbour, and so on.
        for x, y in self.cells:
            if self.demand_in[x, y] > 0.0:
                factor = self.supply[x, y]/self.demand_in[x, y]
                for ii in range(self.ped_class_count):
                    self.s_right[x, y, ii] = self.d_left[x + 1, y, ii]*factor
                    self.s_left[x, y, ii] = self.d_right[x - 1, y, ii]*factor
                    self.s_top[x, y, ii] = self.d_bottom[x, y + 1, ii]*factor
                    self.s_bottom[x, y, ii] = self.d_top[x, y - 1, ii]*factor

        # Phase 4: face flows = min(supply, demand); apply net inflow per face.
        # Only d_*/s_* are read here, so updating density_grid in place is safe.
        self.density_grid_sum[:] = 0
        for x, y in self.cells:
            for ii in range(self.ped_class_count):
                # Inflow from the right neighbour minus outflow to it.
                f_right = min(self.s_right[x, y, ii],
                              self.d_left[x + 1, y, ii]) - min(
                                  self.s_left[x + 1, y, ii], self.d_right[x, y, ii])
                f_left = min(self.s_left[x, y, ii], self.d_right[x - 1, y, ii]) - min(
                    self.s_right[x - 1, y, ii], self.d_left[x, y, ii])
                f_top = min(self.s_top[x, y, ii], self.d_bottom[x, y + 1, ii]) - min(
                    self.s_bottom[x, y + 1, ii], self.d_top[x, y, ii])
                f_bottom = min(self.s_bottom[x, y, ii],
                               self.d_top[x, y - 1, ii]) - min(
                                   self.s_top[x, y - 1, ii], self.d_bottom[x, y, ii])

                self.density_grid[x, y, ii] += self.cell_area_inv*(
                    f_right+f_left+f_top+f_bottom)
                self.density_grid_sum[x, y] += self.density_grid[x, y, ii]

    def set_state(self, experiment:Experiment) -> None:
        """Load the initial state of *experiment* and create the work arrays.

        The work arrays are created here filled with ones. :meth:`run` resets
        them via :meth:`set_to_zero` before using them.
        """
        self._fill_cells(experiment)

        self.d_right = np.ones_like(self.density_grid)
        self.d_left = np.ones_like(self.density_grid)
        self.d_top = np.ones_like(self.density_grid)
        self.d_bottom = np.ones_like(self.density_grid)

        self.demand_sum = np.ones_like(self.density_grid_sum)
        self.demand_in = np.ones_like(self.density_grid_sum)
        self.supply = np.ones_like(self.density_grid_sum)

        self.s_right = np.ones_like(self.density_grid)
        self.s_left = np.ones_like(self.density_grid)
        self.s_top = np.ones_like(self.density_grid)
        self.s_bottom = np.ones_like(self.density_grid)

    def initialize(self) -> None:
        """Set up the time step, grid geometry, fundamental diagram and route field.

        Grid dimensions and the preferred-direction field are read from
        ``self.infra`` at resolution ``CELL_SIZE``.
        """
        self.delta_t = 0.075
        _, self.x_count, self.y_count = self.infra.get_grid_representation(CELL_SIZE) # type: ignore
        self.fd = TriangularFD()
        self.cell_area = CELL_SIZE*CELL_SIZE
        self.cell_area_inv = 1/(CELL_SIZE*CELL_SIZE)
        # Flow times step_size gives the pedestrians crossing one cell face per step.
        self.step_size = self.delta_t*CELL_SIZE
        # Interior cells only; indices 0 and count + 1 are the ghost border.
        self.cells = tuple(itertools.product(range(1, self.x_count + 1),range(1, self.y_count + 1)))
        self.pref_direction = self.infra.get_route_floor_field(CELL_SIZE)

    def _fill_cells(self, experiment:Experiment) -> None:
        """Initialise per-class densities, initial flow and directions.

        Builds these attributes:

        - ``density_grid``, shape ``(x_count + 2, y_count + 2, classes)``;
        - ``density_grid_sum``;
        - ``flow``;
        - ``dir_x`` and ``dir_y``.

        A cell's density is split evenly over the first ``k`` classes, where
        ``k`` comes from :meth:`_get_class_count_per_cell`. All classes share
        the same preferred direction.
        """
        self.ped_class_count = experiment.ped_class_count
        # Make all grids two rows and columns larger to deal with borders.
        self.density_grid = np.zeros(
            (self.x_count + 2, self.y_count + 2, experiment.ped_class_count))
        # Presumably shape (x_count, y_count), given the indexing below - TODO confirm.
        density_grid = self.infra.get_density_per_cell(CELL_SIZE,
                                                       experiment.ped_count,
                                                       self.fd.jam_density,
                                                       experiment.local_density)

        self.density_grid_sum = np.zeros(self.density_grid.shape[:2])

        class_count_per_cell = self._get_class_count_per_cell(experiment, density_grid)
        for x in range(self.x_count):
            for y in range(self.y_count):
                for ii in range(class_count_per_cell[x,y]):
                    self.density_grid[x+1, y+1,
                                    ii] = density_grid[x,y]/class_count_per_cell[x,y]
        self.density_grid_sum[1:self.x_count + 1,
                              1:self.y_count + 1] = density_grid

        self.flow = np.zeros_like(self.density_grid)
        for x in range(1, self.x_count + 1):
            for y in range(1, self.y_count + 1):
                for ii in range(experiment.ped_class_count):
                    if self.density_grid_sum[x, y] > 0:
                        self.flow[x, y, ii] = self.fd.get_local_demand_flow(
                            self.density_grid[x, y, ii],
                            self.density_grid_sum[x, y])

        self.dir_x: npt.NDArray[np.float64] = np.zeros_like(self.density_grid)
        self.dir_y: npt.NDArray[np.float64] = np.zeros_like(self.density_grid)

        # pref_direction appears to be indexed (y, x, component), hence the
        # transpose - TODO confirm against infra.get_route_floor_field.
        for ii in range(experiment.ped_class_count):
            self.dir_x[1:self.x_count + 1, 1:self.y_count + 1,
                       ii] = np.transpose(self.pref_direction[:, :, 0])
            self.dir_y[1:self.x_count + 1, 1:self.y_count + 1,
                       ii] = np.transpose(self.pref_direction[:, :, 1])

    def _get_class_count_per_cell(self, experiment:Experiment, density_grid:npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
        """Return how many pedestrian classes to place in each cell.

        HOMOGENEOUS: every cell gets all ``experiment.ped_class_count``
        classes.

        CONCENTRATED: every cell starts with one class. If there are fewer
        occupied cells than classes, the remaining classes are added to
        occupied cells in loop order (``x`` outer, ``y`` inner). Each cell
        receives at most ``max_count_to_add`` extra classes.

        Raises:
            ValueError: CONCENTRATED with classes to place but no occupied cell.
            RuntimeError: CONCENTRATED when not all classes could be placed.
            NotImplementedError: Any other class distribution.
        """
        if experiment.class_distribution is DistrType.HOMOGENEOUS:
            return np.ones_like(density_grid, dtype=int)*experiment.ped_class_count
        elif experiment.class_distribution is DistrType.CONCENTRATED:
            class_count_per_cell = np.ones_like(density_grid, dtype=int)
            # One class per occupied cell is already assigned.
            ped_classes_assigned = np.sum(density_grid > 0)
            if ped_classes_assigned >= experiment.ped_class_count:
                return class_count_per_cell

            if ped_classes_assigned == 0:
                # Previously divided by zero below and then failed with a bare Exception.
                raise ValueError(
                    "cannot distribute pedestrian classes: no occupied cells")

            remaining_count = experiment.ped_class_count - ped_classes_assigned
            max_count_to_add = max(np.floor(np.max(density_grid)*CELL_SIZE*CELL_SIZE), np.floor(experiment.ped_class_count/ped_classes_assigned))

            class_count = min(max_count_to_add, remaining_count)
            for x in range(self.x_count):
                for y in range(self.y_count):
                    if density_grid[x,y] != 0:
                        class_count_per_cell[x,y] += class_count
                        remaining_count -= class_count
                        class_count = min(max_count_to_add, remaining_count)
                    if remaining_count == 0:
                        return class_count_per_cell
            raise RuntimeError(
                f"could not place {remaining_count} remaining pedestrian classes")
        else:
            raise NotImplementedError(
                f"unsupported class distribution: {experiment.class_distribution}")

    @property
    def tag(self) -> str:
        return "macro_cont_godunov"

    @property
    def label(self) -> str:
        return "Macro Continuum Godunov"

class MacroContGodunovV2(MacroContGodunov):
    """Variant of ``MacroContGodunov`` that reuses its work arrays.

    The parent allocates new zero-filled demand/supply arrays at the start of
    every step. This variant clears the arrays created in ``set_state`` in
    place instead.
    """

    # Work arrays cleared at the start of every step, in clearing order.
    _WORK_ARRAY_NAMES = (
        'd_right', 'd_left', 'd_top', 'd_bottom',
        'demand_sum', 'demand_in', 'supply',
        's_right', 's_left', 's_top', 's_bottom',
    )

    def set_to_zero(self) -> None:
        """Reset every demand/supply work array to zero without reallocating."""
        for attr_name in self._WORK_ARRAY_NAMES:
            getattr(self, attr_name).fill(0)

    @property
    def tag(self) -> str:
        tag_value = "macro_cont_godunov_v2"
        return tag_value

    @property
    def label(self) -> str:
        label_value = "Macro Continuum Godunov V2"
        return label_value