import numpy as np
import scipy.stats as stats
import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.stats.multitest import multipletests


# Figure defaults for the report: small sans-serif type, thin rules, and
# fonts kept as real text in vector output (Type 42 / TrueType in PDF,
# un-outlined text in SVG) so labels stay editable afterwards.
plt.rcParams.update({
    'pdf.fonttype': 42,
    'font.family': 'Liberation Sans',
    'font.size': 7,
    'axes.linewidth': 0.5,
    'grid.linewidth': 0.5,
    'svg.fonttype': 'none',
})



def test_individual_quantile_bins(quantiles, n_bins=10, alpha=0.05):
    """
    Test each equal-width bin of a quantile sample for deviation from uniformity.

    The [0, 1] interval is split into ``n_bins`` equal bins. Each bin's count is
    compared with the count expected under a uniform distribution using a
    two-sided binomial test, and the resulting p-values are corrected with the
    Benjamini-Yekutieli procedure.

    Parameters:
    -----------
    quantiles : array-like
        List or array of quantile values (between 0 and 1). Values that
        np.histogram does not place in any bin (NaN, or outside [0, 1]) are
        excluded from both the observed counts and the sample size.
    n_bins : int
        Number of bins to divide the [0,1] interval into (default=10 for deciles,
        use 20 for 0.05 quantiles, etc.)
    alpha : float
        Significance level applied to the adjusted p-values (default=0.05)

    Returns:
    --------
    DataFrame with one row per bin and the columns
        Bin               : left edge of the bin
        Bin_name          : 'left-right' label of the bin
        Observed          : number of values counted in the bin
        Expected          : count expected under uniformity (n / n_bins)
        Binom_p_value     : two-sided binomial test p-value
        Chi2_contribution : (Observed - Expected)**2 / Expected
        Std_residual      : (Observed - Expected) / sqrt(Expected)
        BY_adjusted_p     : Benjamini-Yekutieli adjusted p-value
        Is_significant    : BY_adjusted_p < alpha
        Direction         : 'Higher', 'Lower' or 'Not significant'

    Raises:
    -------
    ValueError
        If n_bins is smaller than 1 or no value of quantiles lies in [0, 1].
    """
    if n_bins < 1:
        raise ValueError(f'n_bins must be at least 1, got {n_bins}')

    # Create bins and count observations
    bin_edges = np.linspace(0, 1, n_bins + 1)
    observed_counts = np.histogram(quantiles, bins=bin_edges)[0]

    # Use the number of values that were actually binned as the sample size.
    # len(quantiles) would also count NaN / out-of-range entries that the
    # histogram drops, which inflates the expected count and biases every
    # per-bin test towards "Lower".
    n = int(observed_counts.sum())
    if n == 0:
        raise ValueError('quantiles contains no values inside [0, 1]')
    expected_per_bin = n / n_bins

    # Create results dataframe
    results = pd.DataFrame({
        'Bin': bin_edges[:-1],
        'Bin_name': [
            f'{left:.2g}-{right:.2g}'
            for left, right in zip(bin_edges[:-1], bin_edges[1:])
        ],
        'Observed': observed_counts,
        'Expected': expected_per_bin
    })

    # Method 1: Individual Binomial Tests
    results['Binom_p_value'] = [
        stats.binomtest(count, n=n, p=1/n_bins, alternative='two-sided').pvalue
        for count in observed_counts
    ]

    # Method 2: Chi-Square Components and residuals scaled by sqrt(Expected)
    deviation = observed_counts - expected_per_bin
    results['Chi2_contribution'] = deviation**2 / expected_per_bin
    results['Std_residual'] = deviation / np.sqrt(expected_per_bin)

    # Apply multiple testing correction (Benjamini-Yekutieli method for dependent tests)
    results['BY_adjusted_p'] = multipletests(results['Binom_p_value'], method='fdr_by')[1]

    # Add significance indicators
    results['Is_significant'] = results['BY_adjusted_p'] < alpha
    results['Direction'] = np.where(
        results['Is_significant'],
        np.where(results['Observed'] > results['Expected'], 'Higher', 'Lower'),
        'Not significant'
    )

    return results


def visualize_quantile_results_merged(quantiles, results, bin_type='quantile', alpha=0.05):
    """
    Create single plot visualization showing counts with significance thresholds

    One bar is drawn per bin, coloured by where the bin lies in [0, 1]. The
    expected count is drawn as a dashed line and the unadjusted two-sided
    bounds (expected +/- z * sqrt(expected), z taken from alpha) as dotted
    lines. Bins flagged in results['Is_significant'] are not highlighted.

    Parameters:
    -----------
    quantiles : array-like
        Not used by the plot; kept so existing callers keep working.
    results : DataFrame
        Output from test_individual_quantile_bins function. The columns
        'Bin', 'Observed' and 'Expected' are read.
    bin_type : str
        Not used by the plot; kept so existing callers keep working.
    alpha : float
        Original significance level before correction (default=0.05). It sets
        both the position of the dotted bounds and their legend label.

    Returns:
    --------
    The matplotlib Figure containing the plot

    Raises:
    -------
    ValueError
        If results has no rows.
    """
    bin_starts = results['Bin'].to_numpy()
    observed = results['Observed'].to_numpy()
    if bin_starts.size == 0:
        raise ValueError('results must contain at least one bin')

    # Positional access: works for any DataFrame index, and a single bin is
    # taken to run from its left edge to 1 (the x-axis is fixed to [0, 1]).
    if bin_starts.size > 1:
        bin_width = bin_starts[1] - bin_starts[0]
    else:
        bin_width = 1.0 - bin_starts[0]

    width_mm = 80
    height_mm = 45  # Reduced height since we only have one plot
    width_in = width_mm / 25.4
    height_in = height_mm / 25.4

    fig, ax = plt.subplots(1, 1, figsize=(width_in, height_in), dpi=300)

    # Calculate significance thresholds in count units
    # For residual = ±z, we solve: (observed - expected) / sqrt(expected) = ±z
    # So: observed = expected ± z * sqrt(expected)
    expected_count = results['Expected'].mean()
    unadjusted_threshold = stats.norm.ppf(1 - alpha/2)  # 1.96 for alpha=0.05
    half_width = unadjusted_threshold * np.sqrt(expected_count)
    upper_threshold = expected_count + half_width
    lower_threshold = expected_count - half_width

    # Colour each bar by the range its bin falls into. The comparisons carry
    # a 0.01 tolerance around the 0.05 / 0.1 / 0.9 / 0.95 boundaries to absorb
    # floating-point error in the bin edges. A bin that fits none of the first
    # four ranges (e.g. one wider than the range) gets the last colour.
    colors = []
    for bin_start in bin_starts:
        bin_end = bin_start + bin_width

        if bin_end <= 0.06:  # Below 0.05
            color = '#ffffcc'  # Pale yellow
        elif bin_start >= 0.04 and bin_end <= 0.11:  # 0.05-0.1
            color = '#a1dab4'  # Light green
        elif bin_start >= 0.09 and bin_end <= 0.91:  # 0.1-0.9 (middle)
            color = '#41b6c4'  # Teal
        elif bin_start >= 0.89 and bin_end <= 0.96:  # 0.9-0.95
            color = '#2c7fb8'  # Medium blue
        else:  # Above 0.95
            color = '#253494'  # Dark blue

        colors.append(color)

    # Create main bars with black edges, centred on each bin
    ax.bar(bin_starts + bin_width/2, observed,
           color=colors, width=bin_width, edgecolor='black', linewidth=0.5)

    # Plot expected counts as dashed line
    ax.axhline(y=expected_count, color='black', linestyle='--', alpha=0.7, linewidth=0.5,
               label=f'Expected bin height $\\mu = {expected_count:.0f}$', zorder=5)

    # Add unadjusted significance thresholds. The label is derived from alpha
    # so it always matches the plotted bounds; for alpha=0.05 it reads
    # "95%" and "2\sigma" (1.96 rounded to two significant digits).
    confidence_pct = 100 * (1 - alpha)
    bounds_label = (rf'Unadjusted {confidence_pct:g}% confidence bounds '
                    rf'($\mu \pm {unadjusted_threshold:.2g}\sigma$)')
    ax.axhline(y=upper_threshold, color='#6C5623', linestyle=':', alpha=0.7, linewidth=0.5,
               label=bounds_label, zorder=5)
    ax.axhline(y=lower_threshold, color='#6C5623', linestyle=':', alpha=0.7, linewidth=0.5,
               zorder=5)

    ax.set_xlim(0, 1)
    # Set ticks at every 0.1 interval
    tick_positions = np.arange(0, 1.05, 0.1)
    ax.set_xticks(tick_positions)
    # General format hides float noise such as 0.30000000000000004
    tick_labels = [f'{x:.15g}' for x in tick_positions]
    ax.set_xticklabels(tick_labels)
    ax.set_ylabel('Count')
    ax.set_xlabel('Quantile')
    # Legend entries come from the label= arguments of the axhline calls
    ax.legend(loc='upper right', bbox_to_anchor=(1, 1))

    fig.tight_layout()
    return fig


def main():
    """
    Build the 0.05-quantile distribution figure from the city indicators table.

    Reads the indicators CSV, keeps the rows whose 'valid' column equals True,
    runs the per-bin uniformity tests on the 'quantile' column with 20 bins,
    and saves the resulting figure as SVG. Both paths are relative to the
    current working directory.
    """
    indicators = pd.read_csv('../data/processed/city_indicators/city_indicators.csv')
    # Element-wise comparison; rows where 'valid' is anything other than True are dropped
    valid_rows = indicators.loc[indicators['valid'] == True]
    quantile_values = valid_rows['quantile']

    bin_results = test_individual_quantile_bins(quantile_values, n_bins=20)
    figure = visualize_quantile_results_merged(
        quantile_values, bin_results, bin_type='0.05-quantile'
    )
    figure.savefig('../report/figures/quantile_dist_merged.svg', format='svg')


if __name__ == '__main__':
    main()