# This script builds the spatial contiguity matrix.
import geopandas as gpd
import os
from scipy import sparse
from tqdm import tqdm

# Input root: build_adj_mat reads <dir_perimeter>/<country>/<fua>.gpkg.
# Relative path, so it resolves against the current working directory.
dir_perimeter = '../data/processed/perimeter/'
# Output root: build_adj_mat writes <dir_adjacency>/<country>/<fua>_adj.npz.
dir_adjacency = '../data/processed/adjacency_matrix/'

# build_adj_mat buffers each zone by CELL_SIZE + 5 before the spatial join.
# NOTE(review): presumably the grid cell edge length in CRS units (metres?) - confirm.
CELL_SIZE = 100

def build_adj_mat(country, fua):
    """
    Builds and saves the adjacency matrix of one Functional Urban Area (FUA).
    Parameters:
    country (str): Name of the country sub-folder under dir_perimeter / dir_adjacency.
    fua (str): Name of the FUA; the input file is '<fua>.gpkg', the output '<fua>_adj.npz'.
    Returns:
    bool: Always True, both when the matrix already existed on disk (nothing is
          recomputed) and when it has just been built and saved.
    Notes:
    - Zones are read from the GeoPackage and identified by their row label after
      reading, which is also used as their row/column position in the matrix.
    - Every zone is buffered by CELL_SIZE + 5 and spatially joined ('intersects')
      against the unbuffered zones; each matching pair becomes a True entry.
    - A non-empty zone always intersects its own buffer, so the diagonal is set too.
    - The result is a boolean scipy CSR matrix of shape (n_zones, n_zones) saved in
      .npz format. There is no cap on the number of contiguities.
    - The file is written to '<output>.tmp' first and then moved into place, so an
      interrupted run never leaves a truncated .npz that would later be mistaken
      for a finished result.
    """
    path_adjacency = os.path.join(dir_adjacency, country, fua + '_adj.npz')
    if os.path.exists(path_adjacency):
        return True

    zones = gpd.read_file(os.path.join(dir_perimeter, country, fua + '.gpkg'))
    n_zones = len(zones)
    zones['grid_ref'] = zones.index.copy()
    # Only the id and the geometry are needed for the join; dropping the other
    # attributes avoids column-name clashes (e.g. an existing 'grid_ref_2')
    # and keeps the join result small.
    zones = zones[['grid_ref', zones.geometry.name]]

    zones_buffered = zones.rename(columns={'grid_ref': 'grid_ref_2'})
    # The extra 5 on top of CELL_SIZE is a tolerance margin.
    # NOTE(review): presumably meant to catch neighbours whose edges do not touch exactly - confirm.
    zones_buffered.geometry = zones_buffered.buffer(CELL_SIZE + 5)

    pairs = zones.sjoin(zones_buffered, how='inner', predicate='intersects')
    row = pairs['grid_ref'].values
    col = pairs['grid_ref_2'].values
    data = [True] * len(row)
    adj_mat = sparse.csr_matrix((data, (row, col)), shape=(n_zones, n_zones), dtype=bool)

    # Make the function usable on its own, without main() preparing the folder.
    os.makedirs(os.path.dirname(path_adjacency), exist_ok=True)
    # Write through a file object: given a *name*, save_npz would append '.npz',
    # and a temporary '*.npz' file would be counted as a finished FUA by main().
    path_tmp = path_adjacency + '.tmp'
    with open(path_tmp, 'wb') as f:
        sparse.save_npz(f, adj_mat)
    os.replace(path_tmp, path_adjacency)
    return True

def main():
    """
    Builds the adjacency matrices of all FUAs of the hard-coded country list.

    For each country, every '<fua>.gpkg' file found in the country's perimeter
    folder is processed with build_adj_mat, except those for which a
    '<fua>_adj.npz' file already exists in the country's adjacency folder.
    The adjacency folder is created if it is missing.
    """
    list_countries = ['DEU','FRA','ESP','ITA','NLD','IRL','GBR','PRT']
    ext_in = '.gpkg'
    ext_out = '_adj.npz'
    for country in list_countries:
        print('Processing country:', country)
        dir_in = os.path.join(dir_perimeter, country)
        dir_out = os.path.join(dir_adjacency, country)
        os.makedirs(dir_out, exist_ok=True)
        # Strip the exact suffixes instead of splitting on '.' / '_a': splitting
        # truncates FUA names that contain a dot or '_a', which either points
        # build_adj_mat at a non-existent file or wrongly marks another FUA as done.
        list_fuas = [name[:-len(ext_in)] for name in os.listdir(dir_in)
                     if name.endswith(ext_in)]
        done = {name[:-len(ext_out)] for name in os.listdir(dir_out)
                if name.endswith(ext_out)}
        list_fuas = [fua for fua in list_fuas if fua not in done]
        for fua in tqdm(list_fuas):
            build_adj_mat(country, fua)
    
# Run the batch only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()