# This script builds urban fragments based on the urban frontiers per perimeter of study.
import geopandas as gpd
import pandas as pd
from shapely.geometry import box
import os

# Side length of a square grid cell, expressed in the coordinate units of the cell centers
# (presumably metres -- TODO confirm against the CRS of the perimeter files).
CELL_SIZE = 100
# Buffer distance, in the same units as CELL_SIZE: frontier lines are buffered by BUFFER_SIZE
# and the cells remaining after removing the frontiers are buffered by BUFFER_SIZE/2.
BUFFER_SIZE = 10
# Minimum number of cells in an urban fragment; cells of smaller fragments are reallocated.
MIN_CELLS = 5

# Input: one folder per country holding '<fua>.gpkg' files with the cell centers.
dir_perimeter = '../data/processed/perimeter'
# Input: one folder per country holding '<fua>.parquet' files with a 'pop' value per 'grid_ref'.
dir_pop = '../data/processed/demographics'
# Input: one folder per country holding '<fua>_line.gpkg' and '<fua>_poly.gpkg' frontier files.
dir_frontiers = '../data/processed/frontiers'
# Output: one folder per country receiving '<fua>_fragments.csv' files.
dir_fragments = '../data/processed/urban_fragments'

def create_squares_from_centers(points):
    """
    Builds one axis-aligned square per center point.
    Each square has a side length of CELL_SIZE and is centered on its point.
    Parameters:
    points (GeoSeries): Point geometries (objects exposing 'x' and 'y'), one per cell center.
    Returns:
    GeoSeries: The result of applying the square construction to every point of the input.
    """
    half_side = CELL_SIZE/2

    def square_around(center):
        # box() takes (minx, miny, maxx, maxy).
        return box(center.x - half_side, center.y - half_side, center.x + half_side, center.y + half_side)

    return points.apply(square_around)


def create_fragments(cells, frontiers):
    """
    Builds the urban fragments: the connected pieces left once the frontiers are removed from the cells.
    Parameters:
    cells (GeoDataFrame): The cell geometries.
    frontiers (GeoDataFrame): The frontier geometries to cut out of the cells.
    Returns:
    GeoDataFrame: One row per fragment, with the columns 'fragment_id' (a running integer id) and 'geometry'.
    """
    # Keep only the part of the cells that is not covered by a frontier.
    remaining = gpd.overlay(cells, frontiers, how='difference')
    # Grow the pieces by BUFFER_SIZE/2 so that pieces meeting only at a corner end up connected.
    remaining['geometry'] = remaining.buffer(BUFFER_SIZE/2)
    # Merge everything into a single geometry, then split it into its separate parts.
    merged = remaining.dissolve()
    parts = merged.explode().reset_index(drop=True)
    # The position of each part in the exploded frame becomes its identifier.
    parts['fragment_id'] = parts.index
    return parts[['fragment_id','geometry']].copy()

def reallocation(cells_fragment):
    """
    Reallocates the cells of too-small urban fragments to the nearest large-enough fragment.
    Fragments with at least MIN_CELLS cells are kept as they are. Every cell belonging to a
    smaller fragment is reassigned to the nearest kept fragment; when several kept fragments
    are returned as nearest for the same cell, the one with the most cells is chosen.
    Parameters:
    cells_fragment (GeoDataFrame): Cells with 'grid_ref', 'fragment_id' and 'geometry' columns.
    Returns:
    DataFrame: The cells of the kept fragments followed by the reallocated cells (new index).
               Reallocated rows carry the helper columns 'index_right' and 'count'; these are
               missing for the cells that were not reallocated.
    Raises:
    Exception: If the number of cells after reallocation differs from the number of input cells.
    """
    # Number of cells per fragment, restricted to fragments with at least MIN_CELLS cells.
    # The series is named explicitly so the join below does not depend on the default name
    # chosen by value_counts (which could clash with the 'fragment_id' column).
    fragments_counts = cells_fragment['fragment_id'].value_counts()
    fragments_counts = fragments_counts.loc[fragments_counts>=MIN_CELLS].rename('count')

    # Split the cells once: those in a kept fragment and those that need a new fragment.
    in_kept_fragment = cells_fragment['fragment_id'].isin(fragments_counts.index)
    cells_new_fragment = cells_fragment.loc[in_kept_fragment].copy()
    cells_reallocation = cells_fragment.loc[~in_kept_fragment,['grid_ref','geometry']].copy()

    # One geometry per kept fragment, made of all its cells.
    new_fragments = cells_new_fragment[['fragment_id','geometry']].dissolve(by='fragment_id').reset_index()

    # Attach every remaining cell to its nearest kept fragment.
    cells_reallocation = cells_reallocation.sjoin_nearest(new_fragments, how='inner')

    # Sometimes a cell is assigned to several fragments: keep only the assignment to the
    # biggest one. Sorting on the cell count (fragment_id only breaks exact ties, to keep the
    # result deterministic) puts the biggest candidate first for drop_duplicates.
    cells_reallocation = cells_reallocation.join(fragments_counts, on='fragment_id')
    cells_reallocation = cells_reallocation.sort_values(by=['count','fragment_id'], ascending=False)
    cells_reallocation = cells_reallocation.drop_duplicates(subset='grid_ref')

    # Checking that no cell was lost or duplicated.
    if not len(cells_reallocation) + len(cells_new_fragment) == len(cells_fragment):
        raise Exception('Misalignment between cells_reallocation and cells_new_fragment')

    return pd.concat([cells_new_fragment, cells_reallocation], ignore_index=True)

def assign_fragments(cells_center, fragments):
    """
    Assigns every cell to an urban fragment based on the position of its center.
    A cell whose center lies within a fragment is assigned to that fragment; any other cell is
    assigned to the nearest fragment. Cells ending up in fragments that are too small are then
    moved to a larger fragment by reallocation().
    Parameters:
    cells_center (GeoDataFrame): The cell centers, with a 'grid_ref' column identifying each cell.
    fragments (GeoDataFrame): The urban fragments, with a 'fragment_id' column.
    Returns:
    DataFrame: One row per cell with the columns 'grid_ref' and 'fragment_id'.
    Raises:
    Exception: If the total number of cells assigned to fragments does not match the number of input cells.
    """
    # Cells whose center is contained in a fragment.
    cells_within = cells_center.sjoin(fragments, how='inner', predicate='within')
    # Cells that do not belong to any fragment.
    cells_isolated = cells_center.loc[~cells_center['grid_ref'].isin(cells_within['grid_ref'])].copy()
    # Assigning isolated cells to the nearest fragment.
    cells_isolated = cells_isolated.sjoin_nearest(fragments, how='inner')
    # A cell equally distant from several fragments comes back once per fragment; keep a single
    # assignment so that such ties do not trigger the misalignment check below.
    cells_isolated = cells_isolated.drop_duplicates(subset='grid_ref')
    # Checking that no cell was lost or duplicated.
    if not len(cells_within) + len(cells_isolated) == len(cells_center):
        raise Exception('Misaligned cells_within and cells_isolated')

    # Gathering both groups; the 'index_right' column added by the spatial joins is not needed.
    cells_within = cells_within.drop(columns=['index_right'])
    cells_isolated = cells_isolated.drop(columns=['index_right'])
    cells_fragment = pd.concat([cells_within, cells_isolated], ignore_index=True)

    # Reallocation of cells belonging to fragments that are too small.
    cells_fragment = reallocation(cells_fragment)
    return cells_fragment[['grid_ref','fragment_id']].copy()
    

def main():
    """
    Builds the urban fragments of every perimeter of study and writes them to CSV.
    For each country and each '<fua>.gpkg' file found in its perimeter folder:
    the populated cells are rebuilt from their centers, the urban frontiers (lines and,
    when available, polygons) are removed from them, and each populated cell is assigned
    to one of the resulting fragments. The result is written to
    '<dir_fragments>/<country>/<fua>_fragments.csv'.
    A perimeter is skipped when its output file already exists or when its line frontiers
    cannot be read.
    """
    list_countries = ['DEU','FRA','ITA','ESP','GBR','PRT','IRL','NLD']
    for country in list_countries:
        # Names of the perimeters: the gpkg filenames without their extension.
        # splitext only strips the final extension, so names containing dots stay intact.
        list_fuas = [
            os.path.splitext(filename)[0]
            for filename in os.listdir(os.path.join(dir_perimeter, country))
            if filename.endswith('.gpkg')
        ]

        out_dir = os.path.join(dir_fragments, country)
        os.makedirs(out_dir, exist_ok=True)

        for fua in list_fuas:
            out_path = os.path.join(out_dir, f'{fua}_fragments.csv')
            # Perimeters that were already processed are not recomputed.
            if os.path.exists(out_path):
                continue
            cells_center = gpd.read_file(os.path.join(dir_perimeter,country,f'{fua}.gpkg'))
            cells_center['grid_ref'] = cells_center.index.copy()

            # Creating cells from the center.
            cells = cells_center.copy()
            cells['geometry'] = create_squares_from_centers(cells_center['geometry'])

            # Filtering out unpopulated cells.
            pop = pd.read_parquet(os.path.join(dir_pop, country, f'{fua}.parquet'))
            cells = cells.merge(pop, on='grid_ref')
            cells = cells.loc[cells['pop']>0].copy()
            cells_center = cells_center.loc[cells_center['grid_ref'].isin(cells['grid_ref'])].copy()

            # Extracting urban frontiers. Line frontiers are mandatory: without them the perimeter is skipped.
            try:
                frontiers_line = gpd.read_file(os.path.join(dir_frontiers, country,f'{fua}_line.gpkg'))
            except Exception:
                print(f'No frontiers found for {country}/{fua}')
                continue
            if 'highway' in frontiers_line.columns:
                # Features tagged as 'primary' highways are not used as frontiers.
                frontiers_line = frontiers_line.loc[frontiers_line['highway']!='primary'].copy()

            # Polygon frontiers are optional: an empty frame is used when they cannot be read.
            try:
                frontiers_poly = gpd.read_file(os.path.join(dir_frontiers, country,f'{fua}_poly.gpkg'))
            except Exception:
                print(f'No polygons found for {country}/{fua}')
                frontiers_poly = gpd.GeoDataFrame()
            else:
                # Same 'primary' filter as for the lines, applied to the polygons that were loaded.
                if 'highway' in frontiers_poly.columns:
                    frontiers_poly = frontiers_poly.loc[frontiers_poly['highway']!='primary'].copy()

            # Lines are turned into areas so that they can be subtracted from the cells.
            frontiers_line['geometry'] = frontiers_line.buffer(BUFFER_SIZE)
            frontiers = pd.concat([frontiers_line, frontiers_poly], ignore_index=True)

            # Create urban fragments
            fragments = create_fragments(cells, frontiers)

            cells_fragment = assign_fragments(cells_center, fragments)
            cells_fragment.to_csv(out_path, index=False)

# Run the processing only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()