import pandas as pd
pd.set_option('mode.chained_assignment','raise')
import numpy as np
import os
import re
from scipy.stats import norm as normal_dist
from tqdm import tqdm
from build_covariance_matrix import theoretical_variance_resmix
import pyarrow as pa

# dtype string passed to .astype() when down-casting floating-point columns in main().
FLOAT_PRECISION = 'float32'
# NOTE(review): `test` is not referenced anywhere in the code visible here — confirm it is still needed.
test = False
# Relative paths below are resolved against the current working directory.
# Input: city_statistics.csv, <country>/var_coef_<country>.parquet and <country>/<fua>.parquet.
dir_demog = '../data/processed/demographics'
# Input: <country>/<fua>.parquet tables with from_grid_ref / to_grid_ref / prox.
dir_proximity = '../data/processed/proximity'
# Output: <country>/pval_<fua>.parquet files written by main().
dir_pval = '../data/processed/pval'
# Input: <country>/indices_<fua>.csv files (grid_ref, GRID_ID); their names determine which FUAs are processed.
dir_perimeter = '../data/processed/perimeter'

def compute_exposure(demog, proximity, city_mean):
    """
    Compute the exposure of a demographic group to a certain proximity measure.

    Every (origin -> destination) link gets the raw weight ``prox * pop`` of its
    origin cell. Raw weights are divided by the total arriving at the destination,
    so they sum to one whenever that total is positive. The normalised weights then
    average 'NOTEU' and propagate 'var_coef' (scaled by the squared weight) onto each
    destination. The result is standardised against ``city_mean``.

    Parameters:
    demog (pd.DataFrame): DataFrame containing demographic data with columns 'grid_ref', 'pop', 'NOTEU', and 'var_coef'.
    proximity (pd.DataFrame): DataFrame containing proximity data with columns 'from_grid_ref', 'to_grid_ref', and 'prox'.
    city_mean (float): The mean value of the city for the demographic measure.
    Returns:
    pd.DataFrame: One row per destination cell, sorted by 'grid_ref', with columns
        'grid_ref', 'NOTEU', 'var_coef', 'z', and 'p_val'. 'p_val' is the standard
        normal CDF evaluated at 'z' (lower tail).
    """
    # Raw weight of every link: proximity times the origin's population.
    links = pd.merge(
        demog[['grid_ref', 'pop']].copy(),
        proximity,
        left_on='grid_ref',
        right_on='from_grid_ref',
    )
    links['wgt'] = links['prox'] * links['pop']
    links = links.drop(columns=['grid_ref', 'pop', 'prox'])

    # Total raw weight arriving at each destination; infinite totals become NaN.
    totals = (
        links[['to_grid_ref', 'wgt']]
        .groupby(by='to_grid_ref')
        .sum()
        .reset_index()
        .rename(columns={'wgt': 'denom'})
        .replace(np.inf, np.nan)
    )
    links = links.merge(totals, on='to_grid_ref', how='left')

    # Normalise: positive totals divide, zero totals yield a zero weight.
    # NOTE(review): a NaN total matches neither condition, so such rows keep the
    # unnormalised prox*pop value — confirm this is intended.
    numer = links['wgt']
    denom = links['denom']
    links['wgt'] = numer.mask(denom > 0, numer / denom).mask(denom == 0, 0)

    # Attach each origin's NOTEU / var_coef to its outgoing links and aggregate per destination.
    contrib = demog.merge(links, left_on='grid_ref', right_on='from_grid_ref').drop(columns='grid_ref')
    contrib['NOTEU'] = contrib['NOTEU'] * contrib['wgt']
    contrib['var_coef'] = contrib['var_coef'] * contrib['wgt'] ** 2
    expos = (
        contrib[['to_grid_ref', 'NOTEU', 'var_coef']]
        .groupby('to_grid_ref')
        .sum()
        .reset_index()
        .rename(columns={'to_grid_ref': 'grid_ref'})
    )

    # A variance of exactly zero is replaced by 0.01 before the square root so z stays finite.
    std = np.sqrt(expos['var_coef'].mask(expos['var_coef'] == 0, 0.01))
    expos['z'] = (expos['NOTEU'] - city_mean) / std
    expos['p_val'] = normal_dist.cdf(expos['z'], loc=0, scale=1)

    return expos


def _list_fuas(country):
    """
    Return the FUA names that have a perimeter file for ``country``.

    A FUA name is the text captured from file names matching ``indices_<fua>.csv``
    in ``dir_perimeter/<country>``. Names are sorted because os.listdir order is
    arbitrary, and sorting makes runs process FUAs in a reproducible order.
    """
    pattern = re.compile(r"indices_(.+?)\.csv")
    found = (pattern.search(name) for name in os.listdir(os.path.join(dir_perimeter, country)))
    return sorted(match.group(1) for match in found if match)


def _write_parquet_atomic(df, path, schema):
    """
    Write ``df`` to ``path`` as snappy-compressed parquet without exposing partial files.

    The data is first written to ``<path>.tmp`` and then renamed over ``path`` with
    os.replace. An interrupted run therefore never leaves a truncated file at
    ``path``, which main() would otherwise treat as finished and skip forever.
    """
    tmp_path = path + '.tmp'
    df.to_parquet(tmp_path, index=False, engine="pyarrow", compression="snappy", schema=schema)
    os.replace(tmp_path, path)


def main(countries=None):
    """
    Compute exposure z-scores and p-values for every FUA and write them to parquet.

    For each country, every FUA that has an ``indices_<fua>.csv`` perimeter file is
    processed. A FUA is skipped if ``dir_pval/<country>/pval_<fua>.parquet`` already
    exists, which makes the script resumable.

    Parameters:
    countries (list of str, optional): country folder names to process (e.g. 'DEU').
        Defaults to DEU, FRA, ITA, ESP, GBR, PRT, IRL and NLD.
    Raises:
    KeyError: if a FUA to be processed has no matching 'city' row in city_statistics.csv.
    """
    if countries is None:
        countries = ['DEU', 'FRA', 'ITA', 'ESP', 'GBR', 'PRT', 'IRL', 'NLD']
    city_statistics = pd.read_csv(os.path.join(dir_demog, 'city_statistics.csv'))

    for country in countries:
        dir_output = os.path.join(dir_pval, country)
        os.makedirs(dir_output, exist_ok=True)

        # Only FUAs without an output file still need work; if none remain,
        # skip loading the country-wide var_coef table altogether.
        pending = [
            fua for fua in _list_fuas(country)
            if not os.path.exists(os.path.join(dir_output, f'pval_{fua}.parquet'))
        ]
        if not pending:
            continue
        var_coef = pd.read_parquet(os.path.join(dir_demog, country, f'var_coef_{country}.parquet'))

        for fua in tqdm(pending, desc=country):
            # Look up the city mean first so a missing entry fails before any per-FUA I/O.
            mu = city_statistics.loc[city_statistics['city'] == fua, 'mu']
            if mu.empty:
                raise KeyError(f"No 'mu' for FUA {fua!r} in city_statistics.csv")
            city_mean = mu.values[0]

            grid = pd.read_csv(os.path.join(dir_perimeter, country, f'indices_{fua}.csv'),
                               header=0, names=['grid_ref', 'GRID_ID'])
            # Smallest unsigned integer type that holds every grid index.
            grid_ref_dtype = 'uint16' if grid['grid_ref'].max() < 2**16 else 'uint32'
            grid['grid_ref'] = grid['grid_ref'].astype(grid_ref_dtype)
            demog = pd.read_parquet(os.path.join(dir_demog, country, f'{fua}.parquet'))
            var_coef_fua = theoretical_variance_resmix(grid, var_coef, city_statistics, fua)
            # Match demog's key dtype before the merge below.
            var_coef_fua['grid_ref'] = var_coef_fua['grid_ref'].astype(demog['grid_ref'].dtype)
            var_coef_fua['var_coef'] = var_coef_fua['var_coef'].astype(FLOAT_PRECISION)
            proximity = pd.read_parquet(os.path.join(dir_proximity, country, f'{fua}.parquet'))

            demog = demog.merge(var_coef_fua, on='grid_ref')
            exposure = compute_exposure(demog, proximity, city_mean)
            exposure['p_val'] = exposure['p_val'].astype(FLOAT_PRECISION)

            grid_ref_type = pa.uint16() if exposure['grid_ref'].max() < 2**16 else pa.uint32()
            schema = pa.schema([
                ('grid_ref', grid_ref_type),
                ('NOTEU', pa.float32()),
                ('var_coef', pa.float32()),
                ('z', pa.float32()),
                ('p_val', pa.float32()),
            ])
            _write_parquet_atomic(exposure, os.path.join(dir_output, f'pval_{fua}.parquet'), schema)


# Run the batch only when executed as a script, not when the module is imported.
if __name__ == '__main__':
    main()