# This script measures the theoretical variance of the demographic variables.
import pandas as pd
import numpy as np
import os

# Directory holding the raw demographic inputs read by this script
# (inhabitants_count.tsv, household_size.tsv, grid_link_<country>.csv,
# <country>_L4.csv). Relative paths resolve against the current working
# directory, so the script must be launched from the expected folder.
path_demographics = '../data/raw/demographics/'
# Root directory for results; each country gets a sub-folder
# <path_output>/<country>/ holding var_coef_<country>.parquet.
path_output = '../data/processed/demographics/'

def load_hh_data(countries=('DE', 'FR', 'IT', 'ES', 'UK', 'IE', 'PT', 'NL')):
    """
    Load 2011 household-size data and turn it into per-capita ratios.

    Steps:
    1. Load inhabitant counts and keep the all-ages ('TOTAL'), both-sexes ('T')
       totals for 2011.
    2. Keep regions whose 'geo' code starts with one of ``countries`` and whose
       population is positive.
    3. Load household counts for all household types ('TOTAL') by number of
       persons (1..10 and GE11) for 2011, stripping the ' u' flag from values.
    4. Merge both tables on the region code (inner join).
    5. Aggregate households with 6 or more persons into size '6'.
    6. Recompute the number of single-person households as the population not
       covered by households of 2-6 persons (6+ households counted as exactly
       6 persons), so that sum_k k * count_k equals the population.
    7. Divide every count by the population.

    Parameters:
    countries (iterable of str): Country codes (prefixes of the 'geo' codes)
        to keep. They are joined into a regex alternation. Defaults to the
        eight countries processed by this script.

    Returns:
        pd.DataFrame: Indexed by region code ('geo'), with columns '1'..'6'
                      holding the number of households of each size per
                      inhabitant.
    """
    # Load data on inhabitants. The separator splits on tabs AND commas so the
    # comma-joined dimension labels (age, sex, geo, ...) become separate
    # columns. Regex separators need the python parser engine; request it
    # explicitly instead of relying on pandas' fallback (and its warning).
    inh = pd.read_csv(os.path.join(path_demographics, 'inhabitants_count.tsv'),
                      sep=r'[\t,]', engine='python')

    # Discarding data on age groups and sex.
    totals = (inh['age'] == 'TOTAL') & (inh['sex'] == 'T')
    inh = inh.loc[totals, ['geo\\TIME_PERIOD', '2011']].copy()
    inh = inh.rename(columns={'geo\\TIME_PERIOD': 'geo'})
    inh['2011'] = inh['2011'].astype(int)

    # Filtering for countries of interest. Region codes are expected to begin
    # with the country code (NUTS convention), so anchor the match at the
    # start of 'geo' (str.match) instead of accepting the code anywhere in
    # the label (str.contains).
    pattern = '|'.join(countries)
    inh = inh.loc[inh['geo'].str.match(pattern) & (inh['2011'] > 0)].copy()

    # Loading household size data.
    hh = pd.read_csv(os.path.join(path_demographics, 'household_size.tsv'),
                     sep=r'[\t,]', engine='python')
    hh = hh[['hhcomp', 'n_person', 'geo\\TIME_PERIOD', '2011']].copy()
    hh = hh.rename(columns={'geo\\TIME_PERIOD': 'geo'})

    # Keeping only total households (household type is not of interest) and
    # household sizes from 1 to 10 persons plus 11 or more (GE11).
    size_labels = [str(n) for n in range(1, 11)] + ['GE11']
    hh = hh.loc[(hh['hhcomp'] == 'TOTAL') & hh['n_person'].isin(size_labels)].copy()

    # Removing the " u" flag and converting to integers. astype(str) keeps this
    # working when pandas already parsed the column as numbers (no .str
    # accessor on numeric columns); regex=False makes the literal match explicit.
    hh['2011'] = hh['2011'].astype(str).str.replace(' u', '', regex=False).astype(int)

    # Putting household size as columns.
    hh = hh.pivot(index='geo', columns='n_person', values='2011')
    hh = hh.reset_index()
    hh.columns.name = None

    # Adding population (inner join: regions missing from either table are dropped).
    hh = hh.merge(inh.rename(columns={'2011': 'pop'}), on='geo')
    hh = hh.set_index('geo')

    # Aggregating households with 6 or more persons into one category
    # (skipna=False keeps a missing size count propagating as NaN).
    hh['6'] = hh[['6', '7', '8', '9', '10', 'GE11']].sum(axis=1, skipna=False)
    # Recomputing single-person households as the residual population, so the
    # distribution accounts for every inhabitant.
    hh['1'] = hh['pop'] - sum(size * hh[str(size)] for size in range(2, 7))

    # Computing size distribution of households per inhabitant.
    sizes = [str(size) for size in range(1, 7)]
    return hh[sizes].div(hh['pop'], axis=0)

def _load_grid_population(country, NUTS_1):
    # Purpose: total population per grid cell (origin 'OTH' excluded), tagged with nuts_id.
    grid = pd.read_csv(os.path.join(path_demographics, f'grid_link_{country}.csv'),
                       usecols=['GRID_ID', 'nuts_id'])
    demog = pd.read_csv(os.path.join(path_demographics, f'{country}_L4.csv'),
                        usecols=['GRID_ID', 'pop', 'origin'])
    # Removing population with unknown origins.
    demog = demog.loc[demog['origin'] != 'OTH'].drop(columns='origin')
    # Getting total population per grid cell.
    demog = demog.groupby('GRID_ID').sum().reset_index()
    # Adding nuts_id; cells without a link fall back to the country-level code.
    # NOTE(review): a left merge duplicates a cell if grid_link lists the same
    # GRID_ID under several nuts_id values — confirm GRID_ID is unique there.
    demog = demog.merge(grid, on='GRID_ID', how='left')
    demog['nuts_id'] = demog['nuts_id'].fillna(NUTS_1)
    # Cells with less than one inhabitant are discarded.
    return demog.loc[demog['pop'] >= 1]


def _complete_households(demog):
    # Purpose: add back persons lost when the household counts '1'..'6' were floored.
    # At most one extra household per size is added, largest size first, while
    # enough persons remain. For size 1, only an exact single leftover person
    # becomes an extra household. Mutates and returns ``demog``.
    left = demog['pop'] - sum(size * demog[str(size)] for size in range(1, 7))
    for size in range(6, 1, -1):
        fits = left >= size
        demog[str(size)] = demog[str(size)] + fits
        left = left - size * fits
    demog['1'] = demog['1'] + (left == 1)
    return demog


def estimate_hh_distribution(hh, country, NUTS_1):
    """
    Allocate whole households to every populated grid cell and save the
    resulting variance coefficient.

    The population of each grid cell of ``country`` (origin 'OTH' excluded,
    rounded down, cells below one inhabitant dropped) is split into
    households of 1 to 6 persons. The split uses the household ratios of the
    cell's NUTS region. Counts are rounded down, then completed with at most
    one extra household per size to absorb the persons lost by rounding.
    For every cell,

        var_coef = sum_k k**2 * n_k / pop**2

    with n_k the number of k-person households.

    Parameters:
    hh (pd.DataFrame): Household-size ratios indexed by NUTS code, with columns
        '1'..'6' giving households of each size per inhabitant (as returned by
        load_hh_data). Cells whose nuts_id is not in this index are dropped.
    country (str): Country code used in the input and output file names
        (e.g. 'DEU').
    NUTS_1 (str): Code substituted for grid cells without a nuts_id. main()
        passes the country-level code (e.g. 'DE'), so that row of ``hh`` is used.

    Returns:
    None: The GRID_ID / var_coef table is written to
        <path_output>/<country>/var_coef_<country>.parquet.
    """
    sizes = ['1', '2', '3', '4', '5', '6']
    ratio_cols = [size + '_ratio' for size in sizes]
    # Renaming ratio columns so they do not clash with the household counts.
    hh = hh.rename(columns=dict(zip(sizes, ratio_cols)))

    demog = _load_grid_population(country, NUTS_1)
    # Adding household size ratios based on the cell's NUTS region (inner join).
    demog = demog.merge(hh, left_on='nuts_id', right_index=True)
    # Rounding down population.
    demog['pop'] = np.floor(demog['pop'])
    # Initial household counts: ratio * population, rounded down.
    demog[sizes] = np.floor(demog[ratio_cols].multiply(demog['pop'], axis=0).values)
    demog = _complete_households(demog)
    demog['var_coef'] = sum(int(size) ** 2 * demog[size] for size in sizes) / demog['pop'] ** 2

    # exist_ok avoids the check-then-create race of a separate exists() test.
    out_dir = os.path.join(path_output, country)
    os.makedirs(out_dir, exist_ok=True)
    demog[['GRID_ID', 'var_coef']].to_parquet(
        os.path.join(out_dir, f'var_coef_{country}.parquet'),
        engine='pyarrow', compression='snappy')


def main():
    """Compute household ratios once, then process every country in turn."""
    # Pairs of (country code used in the input/output file names,
    #           code substituted for grid cells that have no nuts_id).
    country_codes = (
        ('DEU', 'DE'),
        ('FRA', 'FR'),
        ('ITA', 'IT'),
        ('ESP', 'ES'),
        ('GBR', 'UK'),
        ('IRL', 'IE'),
        ('PRT', 'PT'),
        ('NLD', 'NL'),
    )
    ratios = load_hh_data()
    for file_code, fallback_nuts in country_codes:
        print(f'Estimating household size distribution for {file_code}')
        estimate_hh_distribution(ratios, file_code, fallback_nuts)


if __name__ == '__main__':
    main()