import os
import pandas as pd
from tqdm import tqdm
import geopandas as gpd

# Data locations, given relative to the working directory the script is started from.
# main() reads the per-FUA inputs from '<directory>/<country>/<fua>...' below these roots.
# Synthetic partitions: '<fua>_synthetic.parquet'.
dir_synthetic = '../data/processed/synthetic_partition'
# Population/demographic cells: '<fua>.parquet'.
dir_pop = '../data/processed/demographics'
# Urban fragments: '<fua>_fragments.csv'.
dir_frag = '../data/processed/urban_fragments'
# Demographic regions: '<fua>_reg.parquet'; this folder's listing also defines which FUAs exist.
dir_reg = '../data/processed/regions'
# Control partitions: '<fua>_control.csv' (only read when CONTROL is True).
dir_control = '../data/processed/control_partition'
# Results: 'purity.csv' is read from here at start-up (if present) and written back at the end.
dir_purity = '../data/processed/purity'
# When True, main() processes only the hard-coded FUA 'DE027_0', prints its scores,
# dumps them to '<fua>_purity.csv' in the working directory and stops.
DEBUG = False
# Not referenced by the code visible in this file; presumably the name of the
# category column ('seg') — TODO confirm or remove.
CAT = 'seg'
# When True, main() passes each synthetic partition through build_control() before scoring.
CONTROL = False

def measure_purity_scores(cells, cols):
    """
    Compute, for every partition, the share of population living in the dominant 'seg' class of its fragment.
    For each partition column, every fragment (distinct value of that column) contributes the
    population of its most populous 'seg' class; these contributions are summed per partition and
    divided by the total population of `cells`.
    Parameters:
    cells (pd.DataFrame): Cell data holding the columns 'seg' and 'pop' plus every column named in `cols`.
    cols (list): Names of the partition columns to score.
    Returns:
    pd.DataFrame: One row per partition column (index named 'partition_id') with a single column 'purity'.
    """
    id_cols = ['seg', 'pop']
    # Long format: one row per (cell, partition), the fragment label goes to 'frag_id'.
    long_cells = cells[id_cols + cols].melt(
        id_vars=id_cols,
        var_name='partition_id',
        value_name='frag_id',
    )
    # Population of every 'seg' class inside every fragment of every partition.
    class_pop = long_cells.groupby(['seg', 'partition_id', 'frag_id'])['pop'].sum().reset_index()
    # With the rows ordered by decreasing population, the first row of each fragment is its dominant class.
    class_pop = class_pop.sort_values('pop', ascending=False)
    dominant = class_pop.groupby(['frag_id', 'partition_id'])['pop'].first().reset_index()
    # Population covered by dominant classes, per partition, as a share of the whole population.
    total_pop = cells['pop'].sum()
    share = dominant.groupby('partition_id')[['pop']].sum() / total_pop
    return share.rename(columns={'pop': 'purity'})

def build_control(control_path, synth, cols_partitions=None):
    """
    Refine every synthetic partition by intersecting it with a control partition.
    The control partition is joined to `synth` on 'grid_ref' (inner join, so cells missing from
    either side are dropped). Each partition column is then relabelled so that two cells share an
    id only if they shared the original partition id AND have the same 'control_id'.
    Parameters:
    control_path (str): Path (or file-like object) of the control partition CSV; it must provide the columns 'grid_ref' and 'control_id'.
    synth (pd.DataFrame): Synthetic partitions with a 'grid_ref' column; it is not modified.
    cols_partitions (list, optional): Partition columns to refine. Defaults to every column of `synth` except 'grid_ref'.
    Returns:
    pd.DataFrame: The rows of `synth` that have a control entry, with the same columns as `synth` and the refined partition ids.
    Raises:
    ValueError: If `synth` already contains a 'control_id' column (it would clash with the control partition).
    KeyError: If the control CSV lacks 'grid_ref' or 'control_id'.
    Notes:
    - The new ids are integers 0..k-1 per column; they identify distinct (partition id, control id) pairs and carry no other meaning.
    - 'control_id' is not part of the result, so a caller that treats every non-'grid_ref' column as a partition does not pick it up.
    """
    if 'control_id' in synth.columns:
        raise ValueError("synth must not already contain a 'control_id' column")

    original_columns = list(synth.columns)
    if cols_partitions is None:
        cols_partitions = [col for col in original_columns if col != 'grid_ref']
    else:
        cols_partitions = list(cols_partitions)

    # Load control partition; only the join key and the control label are needed.
    control = pd.read_csv(control_path)
    merged = synth.merge(control[['grid_ref', 'control_id']], on='grid_ref', how='inner')

    for col in cols_partitions:
        # Number the distinct (partition id, control id) pairs. Grouping on the pair itself avoids
        # the ambiguity of concatenated strings ('1'+'23' == '12'+'3'); dropna=False keeps cells
        # with a missing label in a group of their own instead of leaving them unnumbered.
        merged[col] = merged.groupby([col, 'control_id'], sort=False, dropna=False).ngroup()

    return merged[original_columns]

# This script performs overlap analysis.
def main():
    """
    Score the urban fragmentation of every FUA against its synthetic partitions and store the results.
    For each country, the FUAs are taken from the '<fua>_reg.parquet' files found in the country's
    folder under dir_reg; FUAs already listed in purity.csv are skipped. For each remaining FUA the
    purity of the 'fragment_id' partition is computed together with the fraction of synthetic
    partitions whose purity is strictly lower ('quantile'). The rows are appended to
    '<dir_purity>/purity.csv' (columns 'fua', 'purity_score', 'quantile').
    FUAs whose input files are missing or unreadable are reported and skipped.
    With DEBUG set, only the hard-coded FUA 'DE027_0' is processed: its scores are printed, the
    purity of all partitions is written to '<fua>_purity.csv' and no row is added to purity.csv.
    Returns:
    None
    """
    list_countries = ['DEU','FRA','ITA','ESP','GBR','PRT','IRL','NLD']
    path_purity = os.path.join(dir_purity, 'purity.csv')
    reg_suffix = '_reg.parquet'

    # Resume from earlier results. A completely empty file cannot be parsed by read_csv,
    # so it is treated like a missing one instead of aborting the whole run.
    try:
        df_purity = pd.read_csv(path_purity)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        df_purity = pd.DataFrame()
    # Set for O(1) membership tests; it also grows with the FUAs scored during this run.
    done_fuas = set() if df_purity.empty else set(df_purity['fua'])
    new_rows = []

    for country in list_countries:
        try:
            filenames = os.listdir(os.path.join(dir_reg, country))
        except OSError:
            print(f'No data found for {country}')
            continue
        # Derive the FUA identifiers from the region files, skipping the ones already scored.
        list_fuas = [name[:-len(reg_suffix)] for name in filenames if name.endswith(reg_suffix)]
        list_fuas = [fua for fua in list_fuas if fua not in done_fuas]

        for fua in tqdm(list_fuas):
            if DEBUG:
                fua = 'DE027_0'
            try:
                # Load the population data.
                pop = pd.read_parquet(os.path.join(dir_pop, country, f'{fua}.parquet'))
                # Load the demographic region data.
                reg = pd.read_parquet(os.path.join(dir_reg, country, f'{fua}_reg.parquet'))
                # Load the fragment data.
                frag = pd.read_csv(os.path.join(dir_frag, country, f'{fua}_fragments.csv'))
                # Load the synthetic partition data.
                synth = pd.read_parquet(os.path.join(dir_synthetic, country, f'{fua}_synthetic.parquet'))
            except (OSError, ValueError) as err:
                # Missing or unreadable input: report the cause and move on. Anything else
                # (e.g. Ctrl-C or a missing parquet engine) is deliberately not swallowed.
                print(f'No data found for {country}/{fua}: {err}')
                continue

            # Merge the dataframes.
            df = pop.merge(reg, on='grid_ref').merge(frag, on='grid_ref')
            synth = synth.reset_index()
            # Filtering the synthetic partition to match the population data.
            synth = synth.loc[synth['grid_ref'].isin(pop['grid_ref'])].copy()
            if CONTROL:
                synth = build_control(os.path.join(dir_control, country, f'{fua}_control.csv'), synth)

            # Reindex the synthetic partition as the population data.
            synth = synth.set_index('grid_ref').reindex(df['grid_ref']).reset_index()
            # Every remaining column of the synthetic frame is one partition to score.
            cols_synthetic = [col for col in synth.columns if col != 'grid_ref']

            df = df[['grid_ref','seg','pop','fragment_id']].merge(synth, on='grid_ref')
            purity_scores = measure_purity_scores(df, ['fragment_id'] + cols_synthetic)

            purity_score_frag = purity_scores.loc['fragment_id','purity']
            purity_synth = purity_scores['purity'].drop(index='fragment_id')
            # Check the quantile of the purity score for the fragmentation compared to all synthetic partitions.
            # Taken on the 'purity' column so the result is a plain number, not a one-element Series.
            quantile = (purity_synth < purity_score_frag).mean()

            if DEBUG:
                print(f'purity for {fua}: {purity_score_frag:.4f}')
                print(f'quantile for {fua}: {quantile:.4f}')
                purity_scores['purity'].to_csv(f'{fua}_purity.csv', index=False)
                break

            new_rows.append({'fua': fua, 'purity_score': purity_score_frag, 'quantile': quantile})
            done_fuas.add(fua)

        if DEBUG:
            break

    # TODO: optionally merge the FUA names ('../data/processed/purity/fuanames.csv') on 'fua' before saving.
    if new_rows:
        df_new = pd.DataFrame(new_rows)
        df_purity = df_new if df_purity.empty else pd.concat([df_purity, df_new], ignore_index=True)

    # Writing an empty frame would leave a file that read_csv cannot parse on the next run.
    if not df_purity.empty:
        os.makedirs(dir_purity, exist_ok=True)
        df_purity.to_csv(path_purity, index=False)

# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()