from __future__ import annotations
import copy
import itertools
import logging
from math import ceil, floor
from multiprocessing.sharedctypes import Value
import random
from threading import local

from complexity.base_classes import Infrastructure

import numpy as np
import numpy.typing as npt

from complexity.constants import *

# Corridor geometry constants; lengths are in metres, areas in square metres.
DEST_LENGTH: float = 0.5  # m, depth of the destination strip past x = length
ROUTE_CELL_SIZE: float = 0.1  # m
PREF_DIR = (1, 0)  # direction written into every cell of the route floor field
DENS_CELL_SIZE: float = 0.4  # m, side of one pedestrian-placement (density) cell
DENS_CELL_AREA: float = DENS_CELL_SIZE*DENS_CELL_SIZE  # m2, area of one density cell

logger = logging.getLogger(__name__)

class UnidirectionalCorridor(Infrastructure):
    """Straight corridor of ``length`` x ``width`` with a destination strip at its far end.

    The walkable space spans x in ``[0, length]`` and y in ``[0, width]``; the
    destination polygon extends it by ``DEST_LENGTH`` beyond ``x = length``.
    Pedestrian cells are produced on a ``DensityGrid`` whose cells have side
    ``DENS_CELL_SIZE``.
    """

    def __init__(self, length: float, width: float, obstacle_count:int, alpha_cp:float=1, beta_cp:float=1) -> None:
        """Build the corridor polygons and its obstacle list.

        Args:
            length: Extent of the corridor along x.
            width: Extent of the corridor along y.
            obstacle_count: Forwarded to ``Infrastructure`` and used by
                ``create_obstacles``.
            alpha_cp: Multiplier for the number of obstacles whose first
                point is ``(-length-1, 0)`` (see ``create_obstacles``).
            beta_cp: Multiplier for the number of obstacles whose first
                point is ``(length+1, 0)``.
        """
        super().__init__(length*width, obstacle_count)
        self.length = length
        self.width = width
        self.alpha_cp = alpha_cp
        self.beta_cp = beta_cp

        # Polygons given as lists of (x, y) corners.
        self.walkable_space = [(0.0, 0.0), (length, 0.0), (length, width),
                               (0.0, width)]
        self.destination = [(length, 0), (length + DEST_LENGTH, 0.0),
                            (length + DEST_LENGTH, width), (length, width)]

        self.create_obstacles()

    def create_obstacles(self) -> None:
        """Fill ``self.obstacles`` from ``obstacle_count``, ``alpha_cp`` and ``beta_cp``.

        Every entry is ``(point, (length, 0), 1/length**2)``; entries differ
        only in ``point``: ``(0, 0)`` for the middle group,
        ``(-length-1, 0)`` for ``floor(obstacle_count*alpha_cp)`` entries and
        ``(length+1, 0)`` for ``floor(obstacle_count*beta_cp)`` entries.

        NOTE(review): with the default ``alpha_cp = beta_cp = 1`` the middle
        count is negative, ``list * negative`` gives ``[]``, and the list holds
        ``2*obstacle_count`` entries with none at ``(0, 0)`` - confirm intended.
        """
        right_obs_count = floor(self.obstacle_count*self.alpha_cp) # lambda < 0  if vec = [1 0]
        left_obs_count = floor(self.obstacle_count*self.beta_cp) # lambda > 1  if vec = [1 0]
        mid_obs_count = self.obstacle_count - right_obs_count - left_obs_count

        self.obstacles = [((0,0),(self.length, 0), 1/(self.length**2))]*mid_obs_count
        self.obstacles += [((-self.length-1,0),(self.length, 0), 1/(self.length**2))]*right_obs_count
        self.obstacles += [((self.length+1,0),(self.length, 0), 1/(self.length**2))]*left_obs_count

    def get_route_floor_field(self,
                              cell_size: float) -> npt.NDArray[np.float64]:
        """Return a ``(y_count, x_count, 2)`` array whose every cell holds ``PREF_DIR``.

        ``x_count``/``y_count`` are ``ceil(length/cell_size)`` and
        ``ceil(width/cell_size)``.
        """
        x_count = ceil(self.length/cell_size)
        y_count = ceil(self.width/cell_size)

        pref_direction = np.ones((y_count, x_count, 2), dtype=float)
        pref_direction[:, :, 0] = PREF_DIR[0]
        pref_direction[:, :, 1] = PREF_DIR[1]

        return pref_direction

    def get_ped_pos_continuous(
            self,
            ped_count: int,
            jam_density: float,
            local_density: float|DistrType) -> tuple[list[CONT_POS], float]:
        """Like ``get_ped_pos_cell``, but return the ``(x, y)`` centre of each cell.

        Coordinates are rounded to 12 decimals; the second element is the
        density returned by ``get_ped_pos_cell``.
        """
        cell_pos_per_ped, density = self.get_ped_pos_cell(
            ped_count, jam_density, local_density
        )
        # Convert (row, col) cells to the xy coordinates of their centres.
        return [
            (
                round(col*DENS_CELL_SIZE + DENS_CELL_SIZE/2, 12),
                round(row*DENS_CELL_SIZE + DENS_CELL_SIZE/2, 12)
            ) for row,
            col in cell_pos_per_ped
        ], density

    def get_ped_pos_cell(
            self,
            ped_count: int,
            jam_density: float,
            local_density: float|DistrType) -> tuple[list[CELL_POS], float]:
        """Assign one density-grid cell ``(row, col)`` to each pedestrian.

        Args:
            ped_count: Number of pedestrians to place.
            jam_density: Density used when ``local_density`` is
                ``DistrType.CONCENTRATED``.
            local_density: ``DistrType.HOMOGENEOUS`` to spread pedestrians
                over the whole corridor, ``DistrType.CONCENTRATED`` to pack
                them at ``jam_density``, or a number (raised to the mean
                density when it is lower).

        Returns:
            ``(cells, density)``: one cell per pedestrian, and the mean
            density (homogeneous case) or the resolved local density.

        Raises:
            ValueError: if the grid has fewer cells than ``ped_count``.
        """
        # Assume a grid-like packing
        self._dens_grid = DensityGrid(self.length, self.width, DENS_CELL_SIZE)
        if self._dens_grid.cell_count < ped_count:
            raise ValueError(f'The number of cells ({self._dens_grid.cell_count}) is smaller than the number of pedestrians ({ped_count})!')
        mean_density = ped_count/self.area
        if local_density is DistrType.HOMOGENEOUS:
            return self._get_position_per_ped(ped_count, self.area), mean_density

        local_density = self._resolve_local_density(
            ped_count, jam_density, local_density, mean_density)
        upper_dens_area = self._get_upper_dens_area(
            ped_count, mean_density, local_density)
        # Restrict placement to a sub-grid just large enough for that area.
        self._dens_grid = self._dens_grid.get_sub_area(upper_dens_area)

        return self._get_position_per_ped(ped_count, upper_dens_area), local_density

    def _resolve_local_density(self, ped_count: int, jam_density: float,
                               local_density: float|DistrType,
                               mean_density: float) -> float:
        """Turn a non-homogeneous ``local_density`` request into a number.

        ``DistrType.CONCENTRATED`` maps to ``jam_density``; a numeric value is
        clamped from below to ``mean_density`` (logged at debug level).
        """
        if local_density is DistrType.CONCENTRATED:
            return jam_density
        if local_density < mean_density:
            logger.debug(f'The local density is smaller than the mean density! {self} - pedCount = {ped_count} - localDensity = {local_density}')
        return max(mean_density, local_density)

    def _get_upper_dens_area(self, ped_count: int, mean_density: float,
                             local_density: float) -> float:
        """Return the area that ``ped_count`` pedestrians occupy at ``local_density``."""
        if ped_count == 0:
            # Nothing to place. This matches 0/local_density*area for any
            # positive density and avoids 0/0 when local_density resolved to 0.
            return 0.0
        return mean_density/local_density*self.area

    def _get_position_per_ped(self, ped_count: int,
                              area: float) -> list[CELL_POS]:
        """Walk ``self._dens_grid`` and pick one random cell per pedestrian.

        Each pedestrian is given a run of consecutive cells (sizes from
        ``DensityGrid._get_cells_per_ped``) and one cell of that run is chosen
        uniformly at random.
        """
        if ped_count == 0:
            return []

        pos_per_ped = []
        cell_count_per_ped = self._dens_grid._get_cells_per_ped(ped_count, area)
        for cell_count in cell_count_per_ped:
            choice_set = [self._dens_grid.move_position() for _ in range(cell_count)]
            pos_per_ped.append(random.choice(choice_set))

        return pos_per_ped

    def get_density_per_cell(self, cell_size:float, ped_count: int,
            jam_density: float,
            local_density: float|DistrType) -> npt.NDArray[np.float64]:
        """Return an ``(x_count, y_count)`` array of per-cell densities.

        For ``DistrType.HOMOGENEOUS`` every cell holds the mean density.
        Otherwise the walk of a ``DensityGrid`` fills cells with the resolved
        local density until the occupied area is used up; one more cell gets the
        leftover density, and all remaining cells get 0.

        Raises:
            Exception: if spreading over the whole grid still exceeds
                ``jam_density``, or if pedestrians are left over after every
                cell was filled.
        """
        density_grid, x_count, y_count = self.get_grid_representation(cell_size)
        mean_density = ped_count/self.area
        if local_density is DistrType.HOMOGENEOUS:
            density_grid[:] = mean_density
            return density_grid

        local_density = self._resolve_local_density(
            ped_count, jam_density, local_density, mean_density)
        upper_dens_area = self._get_upper_dens_area(
            ped_count, mean_density, local_density)

        # Width and length are passed swapped so the grid's (row, col) cells
        # index density_grid, which is shaped (x_count, y_count).
        dens_grid = DensityGrid(self.width, self.length, cell_size)
        total_cell_count = dens_grid.cell_count
        cell_area = cell_size*cell_size
        upper_dens_cell_count = int(floor(upper_dens_area/cell_area))
        if upper_dens_cell_count > total_cell_count:
            # Occupied area does not fit: spread everyone over the whole grid.
            local_density = ped_count/(total_cell_count*cell_area)
            upper_dens_cell_count = total_cell_count
            if local_density > jam_density:
                raise Exception(f'Density too large ({local_density} - {jam_density})')

        for _ in range(upper_dens_cell_count):
            density_grid[dens_grid.move_position()] = local_density
        # Density of the single partially filled cell holding the leftovers.
        remaining_dens = (ped_count - upper_dens_cell_count*cell_area*local_density)/cell_area

        if upper_dens_cell_count == total_cell_count:
            if abs(remaining_dens) > 1e-12:
                raise Exception(f'Remaining density ({remaining_dens}) is not zero although every cell is filled')
        else:
            density_grid[dens_grid.move_position()] = remaining_dens
        for _ in range(x_count*y_count - upper_dens_cell_count - 1):
            density_grid[dens_grid.move_position()] = 0

        return density_grid

    def get_grid_representation(self,
                                cell_size: float,
                                z_size: int = 1,
                                dtype: npt.DTypeLike = np.float64,
                                default_val: object = 1) -> tuple[npt.NDArray, int, int]:
        """Return ``(grid, x_count, y_count)`` with every grid entry set to ``default_val``.

        ``grid`` has shape ``(x_count, y_count)`` when ``z_size == 1`` and
        ``(x_count, y_count, z_size)`` otherwise.
        """
        x_count = ceil(self.length/cell_size)
        y_count = ceil(self.width/cell_size)

        if z_size == 1:
            grid = np.ones((x_count, y_count), dtype=dtype)
        else:
            grid = np.ones((x_count, y_count, z_size), dtype=dtype)
        grid[:] = default_val

        return grid, x_count,y_count

    def get_grid_representation_as_dict(
            self, cell_size: float) -> dict[tuple[int, int], list]:
        """Return a dict mapping every ``(x, y)`` cell index to a fresh empty list."""
        x_count = ceil(self.length/cell_size)
        y_count = ceil(self.width/cell_size)
        return {
            coord: []
            for coord in itertools.product(range(x_count), range(y_count))
        }

    def get_network_representation(self, grid_size: float) -> None:
        """Not implemented for this corridor; always returns ``None``."""
        pass

    @staticmethod
    def get_instance(length:float, width: float, alpha_cp:float, beta_cp:float, obstacle_count:int) -> UnidirectionalCorridor:
        """Factory that forwards its arguments to the constructor in constructor order."""
        return UnidirectionalCorridor(length, width, obstacle_count, alpha_cp, beta_cp)

    def __str__(self) -> str:
        return f'Unidirectional corridor ({self.length} X {self.width})m. {self.obstacle_count} obstacles'

    @property
    def tag(self) -> str:
        """Short identifier for this infrastructure type."""
        return 'unidirectional_corridor'

    @property
    def min_obstacle_count(self) -> int:
        """Smallest obstacle count for this corridor (always 2)."""
        return 2

class RectangularCorridor(UnidirectionalCorridor):
    """Unidirectional corridor whose length and width are multiples of ``DENS_CELL_SIZE``.

    The multiple-of check lives in ``get_instance``; calling the constructor
    directly bypasses it. ``get_instance`` uses the default ``alpha_cp`` and
    ``beta_cp``.
    """

    @staticmethod
    def get_instance(length:float, width: float, obstacle_count:int) -> RectangularCorridor:
        """Validate both dimensions and build a corridor.

        Args:
            length: Extent along x; must be a multiple of ``DENS_CELL_SIZE``.
            width: Extent along y; must be a multiple of ``DENS_CELL_SIZE``.
            obstacle_count: Forwarded to the constructor.

        Raises:
            ValueError: if ``length`` or ``width`` is not such a multiple.
        """
        if not RectangularCorridor.is_dens_cell_length_multiple(length):
            raise ValueError(f'The length must be a multiple of {DENS_CELL_SIZE}')
        if not RectangularCorridor.is_dens_cell_length_multiple(width):
            raise ValueError(f'The width must be a multiple of {DENS_CELL_SIZE}')

        return RectangularCorridor(length, width, obstacle_count)

    def __str__(self) -> str:
        return f'Rectangular corridor ({self.length} X {self.width})m. {self.obstacle_count} obstacles'

    @property
    def tag(self) -> str:
        """Short identifier for this infrastructure type."""
        return 'rectangular_corridor'

    @staticmethod
    def is_dens_cell_length_multiple(magnitude:float) -> bool:
        """Return True if ``magnitude`` is a multiple of ``DENS_CELL_SIZE`` (tolerance 1e-12).

        Float ``%`` can leave a tiny remainder or one just below the divisor
        (``1.2 % 0.4 == 0.3999999999999999``), so both ends are accepted.
        """
        rem = magnitude % DENS_CELL_SIZE
        return rem < 1e-12 or abs(rem - DENS_CELL_SIZE) < 1e-12

class DensityGrid():
    """Grid of square cells handed out one at a time in a serpentine order.

    Cells are ``(row, col)`` tuples with ``row`` in ``range(row_count)``
    (derived from ``width``) and ``col`` in ``range(column_count)`` (derived
    from ``length``). The walk starts at a randomly chosen corner, runs along
    the rows of one column, then steps one column and reverses. As a result,
    ``cell_count`` calls to ``move_position`` visit every cell exactly once.
    """

    def __init__(self, length: float, width: float, cell_size: float) -> None:
        """Cover ``length`` x ``width`` with cells of side ``cell_size`` (counts rounded up)."""
        self._setup(length, width, cell_size,
                    ceil(length/cell_size), ceil(width/cell_size))

    @classmethod
    def _from_counts(cls, column_count: int, row_count: int,
                     cell_size: float) -> DensityGrid:
        """Build a grid with exactly ``column_count`` x ``row_count`` cells.

        Going through ``__init__`` with ``count*cell_size`` is unsafe:
        ``3*0.4 == 1.2000000000000002`` and dividing that by ``0.4`` rounds
        above 3, so ``ceil`` would add a spurious row/column.
        """
        grid = cls.__new__(cls)
        grid._setup(column_count*cell_size, row_count*cell_size, cell_size,
                    column_count, row_count)
        return grid

    def _setup(self, length: float, width: float, cell_size: float,
               column_count: int, row_count: int) -> None:
        """Store geometry and counts, then pick a random start corner for the walk."""
        self.length = length
        self.width = width
        self.cell_size = cell_size

        self.column_count = column_count
        self.row_count = row_count
        self.cell_count = self.column_count*self.row_count
        # Row corner is drawn before the column corner (same order as always).
        self.start_cell = (random.choice([0, self.row_count - 1]),
                           random.choice([0, self.column_count - 1]))

        # Walk away from the start corner along both axes.
        self._col_direction = 1 if self.start_cell[1] == 0 else -1
        self._row_direction = 1 if self.start_cell[0] == 0 else -1
        self._current_cell = self.start_cell  # tuples are immutable; no copy needed

    def get_upper_dens_cell_count(self,
                                  mean_dens: float,
                                  dens_diff: float,
                                  lower_dens: float) -> int:
        """Return ``round(((mean_dens - lower_dens)*(cell_count - 1) + mean_dens)/dens_diff)``.

        NOTE(review): the derivation of this formula is not visible here.
        """
        return int(
            round(((mean_dens-lower_dens)*
                   (self.cell_count - 1) + mean_dens)/dens_diff))

    def move_position(self) -> tuple[int, int]:
        """Return the current cell and advance the serpentine walk by one cell.

        The row moves in the current row direction. When it would leave
        ``[0, row_count)``, it stays at the edge, the row direction reverses and
        the column advances by the column direction. The returned cell is also
        stored in ``_previous_cell``.
        """
        cur_cell = self._current_cell
        self._previous_cell = cur_cell
        new_row_ind = cur_cell[0] + self._row_direction
        new_col_ind = cur_cell[1]
        if new_row_ind < 0:
            new_row_ind = 0
            self._row_direction = 1
            new_col_ind += self._col_direction
        elif new_row_ind == self.row_count:
            new_row_ind = self.row_count - 1
            self._row_direction = -1
            new_col_ind += self._col_direction

        self._current_cell = (new_row_ind, new_col_ind)

        return cur_cell

    def _get_cells_per_ped(self, ped_count: int, sub_area: float, full_area:bool = False) -> list[int]:
        """Split a number of cells as evenly as possible among ``ped_count`` pedestrians.

        The cell total is ``cell_count`` when ``full_area`` is set, otherwise
        ``round(sub_area/DENS_CELL_AREA)`` (at least ``ped_count``). Each
        pedestrian gets ``floor`` or ``ceil`` of the average (minimum 1),
        returned in shuffled order. Returns ``[]`` when ``ped_count`` is 0.
        """
        if ped_count == 0:
            return []
        if full_area:
            cell_count = self.cell_count
        else:
            cell_count = max(round(sub_area/DENS_CELL_AREA), ped_count)

        cells_per_ped = cell_count/ped_count
        cells_per_ped_upper = ceil(cells_per_ped)
        cells_per_ped_lower = max(floor(cells_per_ped),1)

        if cells_per_ped_lower == cells_per_ped_upper:
            peds_with_upper = ped_count
            peds_with_lower = 0
        else:
            # lower*n_lower + upper*(ped_count - n_lower) == cell_count with
            # upper == lower + 1 gives n_lower = upper*ped_count - cell_count.
            peds_with_lower = cells_per_ped_upper*ped_count - cell_count
            peds_with_upper = ped_count - peds_with_lower
        cell_count_per_ped = [cells_per_ped_lower]*peds_with_lower + [
            cells_per_ped_upper
        ]*peds_with_upper
        random.shuffle(cell_count_per_ped)

        return cell_count_per_ped

    def get_sub_area(self, area:float) -> DensityGrid:
        """Return a new, smaller grid with the same cell size whose cells cover ``area``.

        A square of ``ceil(sqrt(area)/cell_size)`` cells per side is used when
        it fits in both directions. Otherwise the result is a strip spanning all
        rows with just enough columns.

        Raises:
            ValueError: if the strip needs more columns than this grid has.
        """
        side_count = int(np.ceil(np.sqrt(area)/self.cell_size))
        if side_count <= self.row_count and side_count <= self.column_count:
            return DensityGrid._from_counts(side_count, side_count, self.cell_size)

        width = self.row_count*self.cell_size
        column_count = int(np.ceil(area/width/self.cell_size))
        # Compare cell counts, not lengths: column_count*cell_size can exceed
        # self.length by float round-off (e.g. 3*0.4 > 1.2).
        if column_count > self.column_count:
            raise ValueError(f'An area of {area} does not fit in the density grid ({self.column_count} x {self.row_count} cells of size {self.cell_size})')

        return DensityGrid._from_counts(column_count, self.row_count, self.cell_size)