import pandas as pd
import numpy as np
import os
import re
from tqdm import tqdm
import pyarrow as pa

# Directory paths below are relative, so they resolve against the current working
# directory (presumably the script is run from a sibling of `data/` — confirm).
# Raw demographic inputs; main() reads `<country>_L4.csv` from here.
path_demographics = '../data/raw/demographics'
# Per-country perimeter files `<country>/indices_<fua>.csv`, read as columns
# ('grid_ref', 'GRID_ID').
path_perimeter = '../data/processed/perimeter'
# Output directory for `<country>/<fua>.parquet` files and `city_statistics.csv`.
path_demographics_processed = '../data/processed/demographics'

def process_demographics(full_demog, grid, country):
    """
    Processes demographic data for one perimeter: merges it onto a grid, fills
    missing values, pivots origins into columns and computes population proportions.
    Args:
        full_demog (pd.DataFrame): Demographic data with columns 'GRID_ID', 'origin' and 'pop'.
        grid (pd.DataFrame): Perimeter grid with columns 'GRID_ID' and 'grid_ref'. Every
            grid cell is kept (right merge), even when it has no demographic record.
        country (str): Country code; used as the origin label of cells without records
            and as the name of the native-population column dropped from the output.
    Returns:
        tuple: A tuple containing:
            - demog (pd.DataFrame): One row per 'grid_ref' with 'NOTEU' (share of non-EU
              population in the cell, 0 for empty cells) and 'pop' (total population,
              floored), plus any origin columns other than `country`, 'EU27' and 'NOTEU'.
            - mu (float): Proportion of non-EU population at the city level, or NaN
              when the perimeter has zero total population.
    """
    demog = full_demog.merge(grid, how='right', on='GRID_ID').drop(columns=['GRID_ID'])
    # Cells without records get the majority class as origin; their 'pop' is NaN
    # and becomes 0 below, so the label only needs to be a valid column name.
    demog['origin'] = demog['origin'].fillna(country)
    demog = demog.pivot(index='grid_ref', columns='origin', values='pop')
    demog.columns.name = None
    # pivot only creates columns for origins that actually occur in this perimeter;
    # make sure the ones used below exist so small perimeters don't raise KeyError.
    for col in (country, 'EU27', 'NOTEU'):
        if col not in demog.columns:
            demog[col] = 0.0
    demog = demog.fillna(0)
    demog['pop'] = demog.sum(axis=1)
    # Proportion of non-EU population at the city level; an empty perimeter
    # yields NaN explicitly instead of a 0/0 RuntimeWarning.
    total_pop = demog['pop'].sum()
    mu = demog['NOTEU'].sum() / total_pop if total_pop > 0 else np.nan
    # Express NOTEU as a share of the cell population; empty cells keep 0.
    demog['NOTEU'] = demog['NOTEU'].mask(demog['pop'] > 0, demog['NOTEU'] / demog['pop'])
    demog['pop'] = np.floor(demog['pop'])
    demog = demog.reset_index()
    demog = demog.drop(columns=[country, 'EU27'])
    return demog, mu


# Perimeter file names look like `indices_<fua>.csv`; group 1 is the FUA id.
_FUA_FILE_PATTERN = re.compile(r"indices_(.+?)\.csv")


def _pending_fuas(country, processed):
    """Return the sorted FUA ids with a perimeter file for `country` that are not in `processed`."""
    fuas = []
    for fname in os.listdir(os.path.join(path_perimeter, country)):
        match = _FUA_FILE_PATTERN.search(fname)
        if match and match.group(1) not in processed:
            fuas.append(match.group(1))
    return sorted(fuas)


def _fua_parquet_schema(grid, demog_fua):
    """Build the output pyarrow schema, using uint16 for columns whose max fits in 16 bits."""
    grid_ref_dtype = pa.uint16() if grid['grid_ref'].max() < 2**16 else pa.uint32()
    pop_dtype = pa.uint16() if demog_fua['pop'].max() < 2**16 else pa.uint32()
    return pa.schema([
        ('grid_ref', grid_ref_dtype),
        ('NOTEU', pa.float32()),
        ('pop', pop_dtype)
    ])


def main():
    """Build per-FUA demographic parquet files and the city-level statistics table.

    For each country, every `indices_<fua>.csv` perimeter under `path_perimeter/<country>`
    whose FUA is not yet listed in `city_statistics.csv` is processed. Processing reads
    `<country>_L4.csv` from `path_demographics`, drops rows whose origin is 'OTH', and
    writes `<fua>.parquet` under `path_demographics_processed/<country>`. It also appends
    the FUA's non-EU share `mu` and `var = mu * (1 - mu)` to the statistics table.
    The statistics CSV is rewritten after each country so finished work is kept
    if a later country fails.
    """
    list_countries = ['DEU', 'FRA', 'ITA', 'ESP', 'GBR', 'PRT', 'IRL', 'NLD']
    os.makedirs(path_demographics_processed, exist_ok=True)
    stats_path = os.path.join(path_demographics_processed, 'city_statistics.csv')
    if os.path.exists(stats_path):
        # Read ids as str so they compare equal to the ids parsed from file names
        # (numeric-looking ids would otherwise be parsed as numbers and never match).
        city_statistics = pd.read_csv(stats_path, dtype={'city': str})
    else:
        # Declare the columns up front: a bare DataFrame() has no 'city' column
        # and the "already processed" lookup below would raise KeyError.
        city_statistics = pd.DataFrame(columns=['city', 'mu', 'var'])

    for country in list_countries:
        list_fuas = _pending_fuas(country, set(city_statistics['city']))
        if not list_fuas:
            # Nothing left to do: skip reading the (potentially large) country file.
            continue

        demog = pd.read_csv(os.path.join(path_demographics, f'{country}_L4.csv'))
        # Filtering out population with unknown origins.
        demog = demog.loc[demog['origin'] != 'OTH'].copy()

        destination_path = os.path.join(path_demographics_processed, country)
        os.makedirs(destination_path, exist_ok=True)

        new_rows = []
        for fua in tqdm(list_fuas):
            # Loading perimeter of study.
            grid = pd.read_csv(os.path.join(path_perimeter, country, f'indices_{fua}.csv'),
                               header=0, names=['grid_ref', 'GRID_ID'])
            demog_fua, mu = process_demographics(demog, grid, country)
            demog_fua.to_parquet(os.path.join(destination_path, f'{fua}.parquet'),
                                 engine='pyarrow', compression='snappy',
                                 schema=_fua_parquet_schema(grid, demog_fua))
            new_rows.append({'city': fua, 'mu': mu, 'var': mu * (1 - mu)})

        # Concatenate once per country instead of once per FUA.
        new_stats = pd.DataFrame(new_rows)
        if city_statistics.empty:
            city_statistics = new_stats
        else:
            city_statistics = pd.concat([city_statistics, new_stats], ignore_index=True)
        # Checkpoint so a failure in a later country does not lose this one.
        city_statistics.to_csv(stats_path, index=False)

    city_statistics.to_csv(stats_path, index=False)

# Run the processing pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()