# This script groups spatial units into regions based on exposure, covariance matrix, and adjacency matrix.
import pandas as pd
import numpy as np
from scipy import sparse
#from scipy.stats import norm as normal_dist
import os
from sklearn.cluster import AgglomerativeClustering,ward_tree
import re
from scipy.sparse.csgraph import connected_components
import pyarrow as pa
from tqdm import tqdm

# Floating-point dtype name used by this script for reduced-precision data.
FLOAT_PRECISION = 'float32'

# Input and output folders. All live under one processed-data root, given
# relative to the current working directory.
_PROCESSED_ROOT = '../data/processed'
dir_adj_mat = f'{_PROCESSED_ROOT}/adjacency_matrix'
dir_demog = f'{_PROCESSED_ROOT}/demographics'
dir_pval = f'{_PROCESSED_ROOT}/pval'
dir_cov_mat = f'{_PROCESSED_ROOT}/covariance_matrix'
dir_perimeter = f'{_PROCESSED_ROOT}/perimeter'
dir_regions = f'{_PROCESSED_ROOT}/regions'

def cluster_analysis(pval, adj_mat, TSS_bench):
    """
    Cluster zones with connectivity-constrained Ward clustering.

    The number of clusters is chosen from a sum-of-squares budget.

    The full Ward tree is built first. Each merge is assigned a sum-of-squares
    increment of distance**2 / 2, and the increments are accumulated in merge
    order. The tree is then cut into (number of merges whose cumulative sum of
    squares exceeds ``TSS_bench``) + 1 clusters.

    Parameters:
    pval (pd.Series or np.ndarray): 1-D values to cluster, one per zone.
    adj_mat (scipy.sparse matrix or np.ndarray): n x n connectivity matrix.
        Only connected zones may be merged.
    TSS_bench (float): Sum-of-squares budget that determines the number of
        clusters.
    Returns:
    np.ndarray: Cluster label of each zone, in the order of ``pval``.
    """
    # np.asarray works for both a Series and an ndarray; `.values` would
    # fail on a plain ndarray.
    values = np.asarray(pval, dtype='float64').reshape(-1, 1)

    tree = ward_tree(values, n_clusters=None, connectivity=adj_mat,
                     return_distance=True)
    distances = tree[4]

    # Cumulative sum of squares after each successive merge.
    cumulative_ss = np.cumsum(distances ** 2 / 2)
    n_over_budget = int(np.count_nonzero(cumulative_ss > TSS_bench))

    clustering = AgglomerativeClustering(
        n_clusters=n_over_budget + 1, connectivity=adj_mat, linkage='ward'
    ).fit(values)
    return clustering.labels_


def av_std_expos_reg(zones, cov_mat, col_reg):
    """
    Compute each region's weighted average exposure and its standard deviation.

    Weights are population based. Within a region r, zone i gets weight
    w_i = pop_i / sum(pop over r). If the whole region has zero population,
    every zone instead gets the equal weight 1 / (number of zones in r), so
    the weights always sum to 1. Then

        av_r  = sum_i w_i * NOTEU_i / sum_i w_i
        std_r = sqrt( sum_{i, j in r} w_i * cov_mat[i, j] * w_j )

    so 'std' is the standard deviation of the same average reported in 'av'.

    Parameters:
    zones (pd.DataFrame): Zone data with columns 'pop' (population),
        'NOTEU' (exposure value) and ``col_reg``. The k-th row (by position)
        must correspond to row/column k of ``cov_mat``.
    cov_mat (scipy.sparse matrix or np.ndarray): Square covariance matrix of
        the zone exposures.
    col_reg (str): Column name in ``zones`` holding the region identifier.
    Returns:
    pd.DataFrame: Copy of ``zones`` with a fresh RangeIndex, rows in the same
        order. Two columns are added and repeated on every zone of a region:
        'av' (region average exposure) and 'std' (float32 standard deviation
        of that average).
    Raises:
    ValueError: If ``cov_mat`` is not n x n, with n = len(zones).
    """
    zones = zones.reset_index(drop=True)
    regions = zones[col_reg]
    n = len(zones)

    cov = sparse.coo_matrix(cov_mat)
    if cov.shape != (n, n):
        raise ValueError(f'cov_mat has shape {cov.shape}, expected ({n}, {n})')

    # Population weights within each region. groupby().transform() keeps the
    # row order of `zones`, so the weights stay aligned with `cov_mat`.
    # Merging region totals back in can reorder rows on some pandas versions.
    pop = zones['pop']
    denom = pop.groupby(regions).transform('sum')
    n_zones = regions.map(regions.value_counts())
    wgt = (pop / denom).where(denom != 0, 1.0 / n_zones)

    # Region average exposure.
    av = (wgt * zones['NOTEU']).groupby(regions).sum() / wgt.groupby(regions).sum()

    # Variance of each region average: accumulate w_i * C_ij * w_j over the
    # covariance entries whose two zones lie in the same region. A single pass
    # over the stored entries replaces per-region submatrix extraction.
    codes, uniques = pd.factorize(regions)
    w = wgt.to_numpy(dtype='float64')
    same = (codes[cov.row] == codes[cov.col]) & (codes[cov.row] >= 0)
    rows, cols = cov.row[same], cov.col[same]
    contrib = w[rows] * cov.data[same].astype('float64') * w[cols]
    var = np.bincount(codes[rows], weights=contrib, minlength=len(uniques))
    std = pd.Series(np.sqrt(var), index=uniques)

    zones['av'] = regions.map(av)
    zones['std'] = regions.map(std).astype(np.float32)
    return zones

def dissolve_regions(seg, adj_matrix):
    """
    Merge adjacent nodes that share the same state into connected regions.

    Only the edges of ``adj_matrix`` whose two endpoints have equal ``seg``
    values are kept. The connected components of the resulting graph,
    treated as undirected, are then labelled.

    Parameters:
    seg (pd.Series or array-like): State of each node. Entry k corresponds to
        row/column k of ``adj_matrix``.
    adj_matrix (scipy.sparse matrix or np.ndarray): Square adjacency matrix.
        Every nonzero entry is an edge.
    Returns:
    np.ndarray: Component label of each node (0 .. n_components - 1).
    Raises:
    ValueError: If ``adj_matrix`` is not square or its size differs from
        len(seg).
    """
    seg_array = np.asarray(seg)
    # csr_matrix accepts dense arrays as well as any sparse format.
    adj = sparse.csr_matrix(adj_matrix)
    n = len(seg_array)
    if adj.shape != (n, n):
        raise ValueError(f'adj_matrix has shape {adj.shape}, expected ({n}, {n})')

    # Keep only the edges between nodes with the same state.
    rows, cols = adj.nonzero()
    keep = seg_array[rows] == seg_array[cols]
    filtered_adj = sparse.csr_matrix(
        (np.ones(np.count_nonzero(keep), dtype=bool), (rows[keep], cols[keep])),
        shape=adj.shape,
    )

    _, labels = connected_components(csgraph=filtered_adj, directed=False)
    return labels


def _classify_z(z_val, z_threshold=2):
    """
    Map z-scores to segregation labels, returned as an int8 array.

    Labels are 1 above ``z_threshold``, -1 below ``-z_threshold``, and 0
    otherwise (including NaN).
    """
    z = np.asarray(z_val, dtype='float64')
    seg = np.zeros(len(z), dtype='int8')
    seg[z > z_threshold] = 1
    seg[z < -z_threshold] = -1
    return seg


def label_segregation(zones, city_mu, cov_mat, adj_mat, z_threshold=2):
    """
    Label zones by segregation and merge adjacent same-label clusters into regions.

    First pass: each cluster in the 'region' column is scored with
    z = (av - city_mu) / std (see ``av_std_expos_reg``) and labelled with
    1 / -1 / 0 by comparing z with +/- ``z_threshold``. Adjacent zones with
    the same label are merged into 'new_region'. Second pass: the merged
    regions are re-scored and re-labelled the same way.

    Parameters:
    zones (pd.DataFrame): Zone data with columns 'region', 'pop' and 'NOTEU'.
        Row k (by position) matches row/column k of the matrices.
    city_mu (float): City mean exposure the region averages are compared to.
    cov_mat (scipy.sparse matrix or np.ndarray): Covariance matrix of the
        zone exposures.
    adj_mat (scipy.sparse matrix or np.ndarray): Adjacency matrix of the zones.
    z_threshold (float): Absolute z-score beyond which a region is labelled
        1 or -1 (default 2).
    Returns:
    pd.DataFrame: ``zones`` without 'region'. The added columns are
        'new_region' (merged region id), 'seg' (int8 label of that region),
        'z_val', 'av' and 'std' (statistics of that region).
    """
    # First pass: score the clusters given in 'region'.
    zones = av_std_expos_reg(zones, cov_mat, 'region')
    zones['z_val'] = (zones['av'] - city_mu) / zones['std']
    zones['seg'] = _classify_z(zones['z_val'], z_threshold)

    # Spatial delineation: merge adjacent clusters that got the same label.
    zones['new_region'] = dissolve_regions(zones['seg'], adj_mat)

    # Second pass: re-score the merged regions.
    zones = zones.drop(columns=['av', 'std'])
    zones = av_std_expos_reg(zones, cov_mat, 'new_region')
    zones['z_val'] = (zones['av'] - city_mu) / zones['std']
    zones['seg'] = _classify_z(zones['z_val'], z_threshold)

    return zones.drop(columns=['region'])

def _list_fuas(country):
    """
    List the FUA names for ``country``.

    Names are taken from the ``indices_<fua>.csv`` files in the country's
    perimeter folder; other files are ignored.
    """
    pattern = re.compile(r"indices_(.+?)\.csv")
    matches = (pattern.search(name)
               for name in os.listdir(os.path.join(dir_perimeter, country)))
    return [match.group(1) for match in matches if match]


def _uint_arrow_type(values):
    """Return ``pa.uint16()`` if the largest value fits in 16 bits, else ``pa.uint32()``."""
    return pa.uint16() if values.max() < 2**16 else pa.uint32()


def main():
    """
    Build segregation regions for every FUA of every listed country.

    Each FUA is processed as follows:
    - Load its adjacency matrix, p-values, population and covariance matrix.
    - Cluster the zones on 'p_val' and label them with ``label_segregation``.
    - Write ``<dir_regions>/<country>/<fua>_reg.parquet`` with the columns
      grid_ref, region and seg.

    FUAs whose output file already exists are skipped, so an interrupted run
    can be resumed. FUAs with missing or inconsistent inputs are reported and
    skipped.
    """
    # Per-zone sum-of-squares benchmark used to cut the Ward tree.
    # NOTE(review): 0.08333 ~= 1/12, the variance of a Uniform(0, 1) variable,
    # presumably the null distribution of the p-values; confirm.
    UNIFORM_VARIANCE = 0.08333
    TSS_SCALE = 10

    city_statistics = pd.read_csv(os.path.join(dir_demog, 'city_statistics.csv'))

    list_countries = ['DEU', 'FRA', 'ITA', 'ESP', 'GBR', 'PRT', 'IRL', 'NLD']
    for country in list_countries:
        try:
            list_fuas = _list_fuas(country)
        except FileNotFoundError:
            print(f'No perimeter folder found for {country}')
            continue
        dir_destination = os.path.join(dir_regions, country)

        for fua in tqdm(list_fuas):
            output_path = os.path.join(dir_destination, f'{fua}_reg.parquet')
            if os.path.exists(output_path):
                continue

            # Look up the city mean only for FUAs that still need work, and
            # skip (instead of raising IndexError) when it is missing.
            mu_values = city_statistics.loc[city_statistics['city'] == fua, 'mu'].to_numpy()
            if len(mu_values) == 0:
                print(f'No city statistics found for {country}/{fua}')
                continue
            city_mu = mu_values[0]

            try:
                adj_mat = sparse.load_npz(os.path.join(dir_adj_mat, country, f'{fua}_adj.npz'))
                pval = pd.read_parquet(os.path.join(dir_pval, country, f'pval_{fua}.parquet'))
                pop = pd.read_parquet(os.path.join(dir_demog, country, f'{fua}.parquet'),
                                      columns=['grid_ref', 'pop'])
                pval = pval.merge(pop, on='grid_ref')
                cov_mat = sparse.load_npz(os.path.join(dir_cov_mat, country, f'{fua}_cov_mat.npz'))
            except Exception as err:
                # Best-effort batch run: report and move on, but (unlike a
                # bare except) let KeyboardInterrupt / SystemExit through.
                print(f'No data found for {country}/{fua}: {err}')
                continue

            # Zones are matched to matrix rows/columns by position, so a merge
            # that dropped zones would silently misalign every later step.
            n_zones = len(pval)
            if adj_mat.shape != (n_zones, n_zones) or cov_mat.shape != (n_zones, n_zones):
                print(f'Inconsistent input sizes for {country}/{fua}: {n_zones} zones, '
                      f'adjacency {adj_mat.shape}, covariance {cov_mat.shape}')
                continue

            # Total Sum of Squares in the benchmark scenario.
            TSS_bench = n_zones * UNIFORM_VARIANCE / TSS_SCALE
            pval['region'] = cluster_analysis(pval['p_val'], adj_mat, TSS_bench)
            pval = label_segregation(pval, city_mu, cov_mat, adj_mat)
            regions = pval[['grid_ref', 'new_region', 'seg']].rename(
                columns={'new_region': 'region'})

            os.makedirs(dir_destination, exist_ok=True)
            # Region ids can exceed 16 bits for large FUAs, so size them like grid_ref.
            schema = pa.schema([
                ('grid_ref', _uint_arrow_type(regions['grid_ref'])),
                ('region', _uint_arrow_type(regions['region'])),
                ('seg', pa.int8()),
            ])
            regions.to_parquet(output_path, index=False, engine="pyarrow",
                               compression="snappy", schema=schema)

if __name__ == '__main__':
    # Run the pipeline only when executed as a script, not when imported.
    main()