from statsmodels.othermod.betareg import BetaModel
import statsmodels.api as sm
import numpy as np
import pandas as pd


# Model diagnostic plots and tests for beta regression

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
import statsmodels.api as sm



def _scatter_with_lowess(x, y, zero_line=False):
    """Scatter y against x on the current axes and overlay a lowess trend line."""
    plt.scatter(x, y, alpha=0.6)
    if zero_line:
        plt.axhline(y=0, color='r', linestyle='-')
    smooth = sm.nonparametric.lowess(y, x, frac=0.3)
    plt.plot(smooth[:, 0], smooth[:, 1], 'r-', lw=1)


def beta_regression_diagnostics(beta_results, X, continuous_predictors=None):
    """
    Create diagnostic plots and print information criteria for a beta regression.

    Figures are drawn and shown one after another:
      1. a 2x2 panel (residuals vs fitted, normal Q-Q, scale-location,
         residuals vs leverage),
      2. Cook's distance per observation with a 4/len(X) reference line,
      3. Pearson residuals against each requested predictor found in X,
      4. squared Pearson residuals against fitted values.
    AIC and BIC are printed at the end.

    :param beta_results: Fitted beta regression results; the code reads
        ``fittedvalues``, ``resid_pearson``, ``get_influence()``, ``aic``
        and ``bic`` from it
    :param X: Design matrix used in the model; columns are looked up by name
        through ``X.columns`` / ``X[name]`` (e.g. a pandas DataFrame)
    :param continuous_predictors: Optional iterable of column names to plot
        residuals against. Defaults to the four indicators originally
        hard-coded here. Names that are not columns of X are skipped with a
        printed note instead of raising ``KeyError``
    :return: The matplotlib figure holding the 2x2 panel of basic plots
    """
    # 1. Extract fitted values and residuals
    fittedvalues = beta_results.fittedvalues
    residuals = beta_results.resid_pearson  # Pearson residuals

    # 2. Basic residual plots
    fig = plt.figure(figsize=(12, 8))

    # a) Residuals vs Fitted plot, with a lowess smoother to expose patterns
    plt.subplot(221)
    _scatter_with_lowess(fittedvalues, residuals, zero_line=True)
    plt.xlabel('Fitted values')
    plt.ylabel('Pearson residuals')
    plt.title('Residuals vs Fitted')

    # b) QQ plot for residuals
    plt.subplot(222)
    stats.probplot(residuals, dist="norm", plot=plt)
    plt.title('Normal Q-Q Plot')

    # c) Scale-Location plot (sqrt of absolute residuals vs fitted)
    plt.subplot(223)
    sqrt_abs_resid = np.sqrt(np.abs(residuals))
    _scatter_with_lowess(fittedvalues, sqrt_abs_resid)
    plt.xlabel('Fitted values')
    plt.ylabel('√|Pearson residuals|')
    plt.title('Scale-Location Plot')

    # d) Leverage plot
    plt.subplot(224)
    influence = beta_results.get_influence()
    leverage = influence.hat_matrix_diag
    plt.scatter(leverage, residuals, alpha=0.6)
    plt.xlabel('Leverage')
    plt.ylabel('Pearson residuals')
    plt.title('Residuals vs Leverage')

    plt.tight_layout()
    plt.show()

    # 3. Cook's distance to identify influential observations
    cooksd = influence.cooks_distance[0]
    plt.figure(figsize=(10, 6))
    plt.stem(np.arange(len(cooksd)), cooksd)
    plt.xlabel('Observation')
    plt.ylabel("Cook's distance")
    plt.title("Cook's Distance for Influential Observations")
    plt.axhline(y=4/len(X), color='r', linestyle='--')  # Threshold line
    plt.show()

    # 4. Check for patterns in residuals by continuous predictors
    if continuous_predictors is None:
        continuous_predictors = ['share_migrants', 'entropy', 'log_pop', 'center_of_mass_dist_norm']
    requested = list(continuous_predictors)
    known_columns = list(getattr(X, 'columns', []))
    present = [var for var in requested if var in known_columns]
    missing = [var for var in requested if var not in known_columns]
    if missing:
        # The model may have been fitted on a different set of regressors;
        # plot what is available rather than failing on the first absent one.
        print(f"Skipping predictors not found in X: {missing}")

    if present:
        # Two panels per row; four predictors reproduce the 2x2, 15x10 layout.
        n_cols = min(2, len(present))
        n_rows = -(-len(present) // n_cols)  # ceiling division
        plt.figure(figsize=(15, 5 * n_rows))

        for i, var in enumerate(present):
            plt.subplot(n_rows, n_cols, i + 1)
            _scatter_with_lowess(X[var], residuals, zero_line=True)
            plt.xlabel(var)
            plt.ylabel('Pearson residuals')
            plt.title(f'Residuals vs {var}')

        plt.tight_layout()
        plt.show()

    # 5. Visual check for heteroskedasticity:
    # squared residuals against fitted values
    resid_squared = residuals**2
    plt.figure(figsize=(8, 6))
    plt.scatter(fittedvalues, resid_squared, alpha=0.6)
    plt.xlabel('Fitted values')
    plt.ylabel('Squared Pearson residuals')
    plt.title('Testing for heteroskedasticity')
    smooth = sm.nonparametric.lowess(resid_squared, fittedvalues, frac=0.3)
    plt.plot(smooth[:, 0], smooth[:, 1], 'r-', lw=1)
    plt.show()

    # 6. Information criteria
    print(f"AIC: {beta_results.aic}")
    print(f"BIC: {beta_results.bic}")
    return fig

def compute_ferrari_cribari_neto_r2(beta_results):
    """
    Compute the Ferrari & Cribari-Neto pseudo-R² for beta regression.

    This is defined as the square of the sample correlation coefficient between
    the linear predictor and the link-transformed response.

    Parameters:
    -----------
    beta_results : BetaResults
        The fitted beta regression model results. The code reads
        ``model.exog``, ``model.endog``, ``params`` and, when present,
        ``model.link`` (a callable applied to the response).

    Returns:
    --------
    float
        The Ferrari & Cribari-Neto pseudo-R² value.
    """
    model = beta_results.model
    exog = np.asarray(model.exog)
    endog = np.asarray(model.endog)

    # The mean-model coefficients come first in the parameter vector, followed
    # by the precision-model ones. Slicing by the number of exog columns stays
    # correct when the precision model has more than one parameter, whereas
    # dropping only the last entry does not.
    n_mean = exog.shape[1]
    mean_params = np.asarray(beta_results.params)[:n_mean]
    linear_predictor = exog @ mean_params

    # Transform the response with the same link the model was fitted with.
    # Fall back to the logit, log(y / (1 - y)), when the model exposes no link.
    link = getattr(model, 'link', None)
    if link is None:
        link_transformed_y = np.log(endog / (1 - endog))
    else:
        link_transformed_y = np.asarray(link(endog))

    # Squared sample correlation between the two
    corr_coef = np.corrcoef(linear_predictor, link_transformed_y)[0, 1]
    return corr_coef**2

def main():
    """
    Fit a beta regression of the city 'quantile' indicator and export the results.

    Steps: read the city indicators CSV, keep rows flagged as valid, rescale
    the response away from 0 and 1, build a design matrix from mean-centred
    numerical variables plus one dummy per country, fit the model, print the
    Ferrari & Cribari-Neto pseudo-R² and the summary, and write the summary
    as a LaTeX table.
    """
    df = pd.read_csv('../data/processed/city_indicators/city_indicators.csv')
    # .copy() makes the filtered rows an independent frame, so the column added
    # on the next line is not written through a slice of the raw data
    # (avoids pandas' SettingWithCopyWarning).
    df = df.loc[df['valid'] == True].copy()
    df['log_pop'] = np.log10(df['total_pop'])
    # Optional exclusion of specific cities, currently disabled:
    # df = df.loc[~df['fuaname'].isin(['Flörsheim am Main', 'Mainz', 'Zwickau', 'Wiesbaden', 'Bayreuth',
    #        'Geslenkirchen', 'Braunschweig', 'Mulhouse', 'Pau', 'Ajaccio',
    #        'Montpellier', 'Genoa', 'Busto Arsizio', 'Piacenza', 'Bolzano',
    #        'Jaen', 'Talavera_de_la_Reina', 'San Fernando', 'Girona', 'Mataro',
    #        'Linea_de_la_Concepcion', 'Elda', 'Bilbao', 'Arrecife',
    #        'Royal_Tunbridge_Wells', 'Guildford', 'Eastbourne', 'Lisbon',
    #        'Limerick', 'Cork', 'Hengelo'])]

    # Response: with n = len(y), (y * (n - 1) + 0.5) / n maps 0 to 0.5 / n and
    # 1 to (n - 0.5) / n, so exact boundary values end up strictly inside (0, 1).
    y = df['quantile'].copy()
    y = (y * (len(y) - 1) + 0.5) / len(y)

    include_numerical_vars = True
    numerical_vars = ['dissimilarity', 'center_of_mass_dist_norm']

    # Fixed effects: one 0/1 dummy column per country
    group_dummies = pd.get_dummies(df['country'], prefix='country', dtype=int)

    include_intercept = False
    if include_intercept:
        # With an explicit constant, drop one category as the reference level
        group_dummies = group_dummies.drop('country_UK', axis=1)
        group_dummies = sm.add_constant(group_dummies)

    if include_numerical_vars:
        # Centre each numerical variable by subtracting its mean
        numerical = df[numerical_vars].copy()
        for var in numerical_vars:
            numerical[var] = numerical[var] - numerical[var].mean()
        X = pd.concat([numerical, group_dummies], axis=1)
    else:
        # If not using numerical variables, just use the dummies
        X = group_dummies.copy()

    # Create and fit the Beta regression model
    beta_model = BetaModel(y, X)
    beta_results = beta_model.fit()

    r2 = compute_ferrari_cribari_neto_r2(beta_results)
    print(f"Ferrari & Cribari-Neto R²: {r2:.4f}")

    # Display results; build the summary once and reuse it for the export
    summary = beta_results.summary()
    print(summary)

    latex_table = summary.as_latex()

    # Explicit encoding so the file does not depend on the platform's default
    with open('../report/regression/beta_regression_table.tex', 'w', encoding='utf-8') as f:
        f.write(latex_table)
        
# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()