import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from typing import Dict, List, Tuple, Optional, Union
import json
import os
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import Rectangle
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

class BiasVisualizationUtils:
    """Plotting helpers for bias-analysis experiment results.

    Static figures are drawn with matplotlib/seaborn and written as 300-dpi
    PNG files under ``output_dir``. ``create_interactive_bias_dashboard``
    writes a plotly HTML file to the same directory.
    """

    def __init__(self, class_names: List[str], output_dir: str = './visualizations'):
        """Store the class names and create the output directory.

        Args:
            class_names: Display names, indexed by class id.
            output_dir: Directory every figure is written to; created if missing.

        Note: this sets the matplotlib style and seaborn palette globally for
        the whole process.
        """
        self.class_names = class_names
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        # Set style (process-global: affects every later matplotlib figure)
        plt.style.use('default')
        sns.set_palette("husl")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _save_and_show(self, fig, save_name: str) -> None:
        """Lay out *fig*, save it to ``output_dir/save_name``, show it, then close it."""
        fig.tight_layout()
        fig.savefig(os.path.join(self.output_dir, save_name), dpi=300, bbox_inches='tight')
        plt.show()
        # pyplot keeps every figure registered until it is closed. Under a
        # non-interactive backend plt.show() is a no-op, so without this
        # close, figures would pile up for the lifetime of the process.
        plt.close(fig)

    def _class_label(self, class_idx: int) -> str:
        """Return the display name for *class_idx*, or ``class_<idx>`` if it is out of range."""
        if 0 <= class_idx < len(self.class_names):
            return self.class_names[class_idx]
        return f'class_{class_idx}'

    @staticmethod
    def _ordered_layers(scores: Dict[int, Dict[str, float]]) -> List[str]:
        """Return the union of layer names across all classes, in first-seen order."""
        return list(dict.fromkeys(layer for per_class in scores.values() for layer in per_class))

    @staticmethod
    def _group_offsets(n_series: int, group_width: float = 0.75) -> Tuple[float, np.ndarray]:
        """Return the bar width and x offsets that centre *n_series* bars on each tick.

        The bars of one group together span *group_width*. With the default
        and 3 series, this gives the original 0.25-wide layout.
        """
        if n_series <= 0:
            return group_width, np.zeros(0)
        width = group_width / n_series
        offsets = (np.arange(n_series) - (n_series - 1) / 2) * width
        return width, offsets

    # ------------------------------------------------------------------
    # Static (matplotlib) figures
    # ------------------------------------------------------------------

    def plot_training_curves(self, training_history: Dict, save_name: str = 'training_curves.png'):
        """Plot losses, accuracy, bias score and per-class F1 over training.

        Args:
            training_history: Dict with optional keys:
                * ``train_loss`` / ``val_loss``: per-epoch loss lists.
                * ``performance_metrics``: list of per-epoch metric dicts,
                  read for ``accuracy`` and ``<class_name>_f1``.
                * ``bias_scores``: list spread evenly across the training
                  epoch range.
                Missing series are skipped rather than raising.
            save_name: PNG file name inside ``output_dir``.
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        train_loss = training_history.get('train_loss') or []
        val_loss = training_history.get('val_loss') or []
        perf = training_history.get('performance_metrics') or []
        bias_scores = training_history.get('bias_scores') or []

        # Training and validation loss. Each series gets its own x range, so
        # differing lengths cannot make matplotlib raise.
        ax = axes[0, 0]
        if train_loss:
            ax.plot(range(len(train_loss)), train_loss, label='Training Loss', linewidth=2)
        if val_loss:
            ax.plot(range(len(val_loss)), val_loss, label='Validation Loss', linewidth=2)
        ax.set_title('Training and Validation Loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        if train_loss or val_loss:
            ax.legend()
        ax.grid(True, alpha=0.3)

        perf_epochs = range(len(perf))

        # Accuracy progression
        if perf:
            ax = axes[0, 1]
            accuracies = [metrics.get('accuracy', 0) for metrics in perf]
            ax.plot(perf_epochs, accuracies, label='Accuracy', color='green', linewidth=2)
            ax.set_title('Validation Accuracy')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Accuracy')
            ax.grid(True, alpha=0.3)
            ax.set_ylim(0, 1)

        # Bias score progression. Scores may be recorded less often than
        # every epoch, so they are spread evenly over the training epoch range.
        if bias_scores:
            ax = axes[1, 0]
            last_epoch = max(len(train_loss), 1) - 1
            bias_epochs = np.linspace(0, last_epoch, len(bias_scores))
            ax.plot(bias_epochs, bias_scores, label='Bias Score', color='red', linewidth=2)
            ax.set_title('Bias Score Over Time')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Bias Score')
            ax.grid(True, alpha=0.3)

        # Class-specific F1 scores
        if perf:
            ax = axes[1, 1]
            for class_name in self.class_names:
                f1_scores = [metrics.get(f'{class_name}_f1', 0) for metrics in perf]
                ax.plot(perf_epochs, f1_scores, label=f'{class_name} F1', linewidth=2)

            ax.set_title('Class-specific F1 Scores')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('F1 Score')
            if self.class_names:
                ax.legend()
            ax.grid(True, alpha=0.3)
            ax.set_ylim(0, 1)

        self._save_and_show(fig, save_name)

    def plot_sensitivity_heatmap(self, sensitivity_scores: Dict[int, Dict[str, float]],
                                 save_name: str = 'sensitivity_heatmap.png'):
        """Draw a class-by-layer heatmap of sensitivity scores.

        Rows are the classes in ascending index order. Indices with no entry
        in ``class_names`` are labelled ``class_<idx>``. Columns are the union
        of layer names across all classes; a layer a class lacks is shown as 0.
        Does nothing if there is no data.
        """
        if not sensitivity_scores:
            return

        class_indices = sorted(sensitivity_scores)
        layers = self._ordered_layers(sensitivity_scores)
        if not layers:
            return

        data = [[sensitivity_scores[class_idx].get(layer, 0.0) for layer in layers]
                for class_idx in class_indices]

        fig = plt.figure(figsize=(14, 8))
        ax = sns.heatmap(data,
                         xticklabels=[f"Layer_{i}" for i in range(len(layers))],
                         yticklabels=[self._class_label(idx) for idx in class_indices],
                         annot=True,
                         fmt='.3f',
                         cmap='YlOrRd',
                         cbar_kws={'label': 'Sensitivity Score'})

        ax.set_title('Layer-wise Sensitivity Scores by Class')
        ax.set_xlabel('Model Layers')
        ax.set_ylabel('Classes')
        self._save_and_show(fig, save_name)

    def plot_selectivity_comparison(self, selectivity_scores: Dict[int, Dict[str, float]],
                                    save_name: str = 'selectivity_comparison.png'):
        """Draw grouped bars of selectivity per layer, one bar per known class.

        Class indices outside ``class_names`` are skipped. Bar width and
        offsets scale with the number of plotted classes, so each group stays
        centred on its tick whatever the class count.
        """
        if not selectivity_scores:
            return

        class_indices = [idx for idx in sorted(selectivity_scores)
                         if 0 <= idx < len(self.class_names)]
        layers = self._ordered_layers(selectivity_scores)

        fig, ax = plt.subplots(figsize=(12, 6))

        x = np.arange(len(layers))
        width, offsets = self._group_offsets(len(class_indices))

        for offset, class_idx in zip(offsets, class_indices):
            scores = [selectivity_scores[class_idx].get(layer, 0) for layer in layers]
            ax.bar(x + offset, scores, width,
                   label=self.class_names[class_idx], alpha=0.8)

        ax.set_xlabel('Model Layers')
        ax.set_ylabel('Selectivity Score')
        ax.set_title('Layer-wise Selectivity Scores by Class')
        ax.set_xticks(x)
        ax.set_xticklabels([f'Layer_{i}' for i in range(len(layers))])
        if class_indices:
            ax.legend()
        ax.grid(True, alpha=0.3)

        self._save_and_show(fig, save_name)

    def plot_attention_patterns(self, attention_patterns: Dict[int, Dict[str, float]],
                                save_name: str = 'attention_patterns.png'):
        """Draw one bar-chart panel of per-layer attention bias per class.

        Panel ``k`` always shows class ``k``. Indices outside ``class_names``
        are ignored, and classes without data leave their panel empty.
        """
        n_classes = len(self.class_names)
        if not attention_patterns or n_classes == 0:
            return

        # squeeze=False always yields a 2-D array, so the single-class case
        # needs no special handling.
        fig, axes = plt.subplots(1, n_classes, figsize=(5 * n_classes, 4), squeeze=False)
        axes = axes[0]

        for class_idx in sorted(attention_patterns):
            if not 0 <= class_idx < n_classes:
                continue
            ax = axes[class_idx]
            patterns = attention_patterns[class_idx]
            n_layers = len(patterns)

            ax.bar(range(n_layers), list(patterns.values()), alpha=0.7)
            ax.set_title(f'{self.class_names[class_idx]}\nAttention Patterns')
            ax.set_xlabel('Transformer Layers')
            ax.set_ylabel('Attention Bias Score')
            ax.set_xticks(range(n_layers))
            ax.set_xticklabels([f'L{j}' for j in range(n_layers)])
            ax.grid(True, alpha=0.3)

        self._save_and_show(fig, save_name)

    def plot_3d_sensitivity_distribution(self, spatial_bias_patterns: Dict[str, float],
                                         save_name: str = '3d_bias_patterns.png'):
        """Draw a labelled bar chart of named spatial bias magnitudes.

        Each bar is annotated with its value, placed beyond the end of the bar
        (below it for negative values).
        """
        if not spatial_bias_patterns:
            return

        bias_types = list(spatial_bias_patterns.keys())
        bias_values = list(spatial_bias_patterns.values())

        fig, ax = plt.subplots(figsize=(10, 6))
        # matplotlib cycles this colour list when there are more than four bars.
        bars = ax.bar(bias_types, bias_values, alpha=0.7, color=['red', 'blue', 'green', 'orange'])

        ax.set_title('3D Spatial Bias Patterns')
        ax.set_xlabel('Bias Type')
        ax.set_ylabel('Bias Magnitude')
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3)

        # Add value labels on bars
        for bar, value in zip(bars, bias_values):
            x_center = bar.get_x() + bar.get_width() / 2
            if value >= 0:
                ax.text(x_center, value + 0.01, f'{value:.3f}', ha='center', va='bottom')
            else:
                ax.text(x_center, value - 0.01, f'{value:.3f}', ha='center', va='top')

        self._save_and_show(fig, save_name)

    # ------------------------------------------------------------------
    # Interactive (plotly) dashboard
    # ------------------------------------------------------------------

    def create_interactive_bias_dashboard(self, experiment_results: Dict,
                                          save_name: str = 'bias_dashboard.html'):
        """Write a 2x2 plotly HTML dashboard comparing baseline and mitigation runs.

        Panels:
            * macro performance metrics (bars)
            * the baseline's last-epoch per-class F1 (pie)
            * fairness metrics (bars)
            * mitigation-run losses (lines)

        A panel whose inputs are missing from *experiment_results* is left empty.
        """
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Performance Comparison', 'Class Distribution',
                            'Bias Metrics', 'Training Progress'),
            specs=[[{"type": "bar"}, {"type": "pie"}],
                   [{"type": "bar"}, {"type": "scatter"}]]
        )

        baseline_results = experiment_results.get('baseline_results', {})
        mitigation_results = experiment_results.get('mitigation_results', {})
        baseline_metrics = baseline_results.get('metrics', {})
        improved_metrics = mitigation_results.get('metrics', {})

        # Performance comparison
        if baseline_metrics and improved_metrics:
            metrics_to_compare = ['accuracy', 'macro_f1', 'macro_precision', 'macro_recall']
            baseline_vals = [baseline_metrics.get(m, 0) for m in metrics_to_compare]
            improved_vals = [improved_metrics.get(m, 0) for m in metrics_to_compare]

            fig.add_trace(
                go.Bar(name='Baseline', x=metrics_to_compare, y=baseline_vals),
                row=1, col=1
            )
            fig.add_trace(
                go.Bar(name='After Mitigation', x=metrics_to_compare, y=improved_vals),
                row=1, col=1
            )

        # Class distribution: per-class F1 from the last recorded baseline
        # epoch. A missing or empty history is skipped rather than raising.
        baseline_perf = baseline_results.get('training_history', {}).get('performance_metrics') or []
        if baseline_perf:
            last_metrics = baseline_perf[-1]
            class_f1s = [last_metrics.get(f'{class_name}_f1', 0) for class_name in self.class_names]

            fig.add_trace(
                go.Pie(labels=self.class_names, values=class_f1s, name="Class F1 Scores"),
                row=1, col=2
            )

        # Bias metrics comparison
        baseline_bias = baseline_results.get('bias_metrics', {})
        improved_bias = mitigation_results.get('bias_metrics', {})

        if baseline_bias and improved_bias:
            bias_metrics = ['demographic_parity_diff', 'equalized_odds_diff']
            baseline_bias_vals = [baseline_bias.get(m, 0) for m in bias_metrics]
            improved_bias_vals = [improved_bias.get(m, 0) for m in bias_metrics]

            fig.add_trace(
                go.Bar(name='Baseline Bias', x=bias_metrics, y=baseline_bias_vals),
                row=2, col=1
            )
            fig.add_trace(
                go.Bar(name='After Mitigation', x=bias_metrics, y=improved_bias_vals),
                row=2, col=1
            )

        # Training progress of the mitigation run; each loss series is optional.
        mitigation_history = mitigation_results.get('training_history', {})
        for history_key, trace_name in (('train_loss', 'Training Loss'),
                                        ('val_loss', 'Validation Loss')):
            losses = mitigation_history.get(history_key)
            if losses:
                fig.add_trace(
                    go.Scatter(x=list(range(len(losses))), y=losses,
                               mode='lines', name=trace_name),
                    row=2, col=2
                )

        fig.update_layout(height=800, showlegend=True,
                          title_text="Bias Analysis Dashboard")

        # Save interactive plot
        output_path = os.path.join(self.output_dir, save_name)
        fig.write_html(output_path)
        print(f"Interactive dashboard saved to {output_path}")

    # ------------------------------------------------------------------
    # More static figures
    # ------------------------------------------------------------------

    def visualize_image_with_attention(self, image: np.ndarray, attention_map: np.ndarray,
                                       save_name: str = 'attention_overlay.png'):
        """Show an image, its attention map, and the map overlaid on the image.

        Args:
            image: Array accepted by ``imshow``. Its first two dimensions are
                taken as (height, width).
            attention_map: 2-D array. It may have a different resolution than
                *image*; in the overlay it is stretched to cover the image.
            save_name: PNG file name inside ``output_dir``.
        """
        image = np.asarray(image)
        attention_map = np.asarray(attention_map)
        height, width = image.shape[:2]

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Original image
        axes[0].imshow(image)
        axes[0].set_title('Original Image')
        axes[0].axis('off')

        # Attention map
        im1 = axes[1].imshow(attention_map, cmap='hot', alpha=0.8)
        axes[1].set_title('Attention Map')
        axes[1].axis('off')
        plt.colorbar(im1, ax=axes[1])

        # Overlay. Without an explicit extent, a map whose size differs from
        # the image would be drawn in its own pixel coordinates and cover only
        # part of the image. This extent matches imshow's default for the image.
        axes[2].imshow(image)
        axes[2].imshow(attention_map, cmap='hot', alpha=0.5,
                       extent=(-0.5, width - 0.5, height - 0.5, -0.5))
        axes[2].set_title('Attention Overlay')
        axes[2].axis('off')

        self._save_and_show(fig, save_name)

    def plot_class_performance_radar(self, metrics: Dict[str, float],
                                     save_name: str = 'class_performance_radar.png'):
        """Draw one polar (radar) chart of precision/recall/F1 per class.

        Values are read from ``metrics['<class_name>_<metric>']``, and missing
        keys plot as 0.
        """
        n_classes = len(self.class_names)
        if n_classes == 0:
            return

        performance_metrics = ['precision', 'recall', 'f1']
        labels = [metric.capitalize() for metric in performance_metrics]

        # The angles are the same for every class. The first angle is repeated
        # so the polygon closes.
        angles = np.linspace(0, 2 * np.pi, len(performance_metrics), endpoint=False).tolist()
        angles += angles[:1]

        fig, axes = plt.subplots(1, n_classes,
                                 figsize=(5 * n_classes, 5),
                                 subplot_kw=dict(projection='polar'),
                                 squeeze=False)

        for ax, class_name in zip(axes[0], self.class_names):
            values = [metrics.get(f'{class_name}_{metric}', 0) for metric in performance_metrics]
            values += values[:1]  # close the polygon

            ax.plot(angles, values, 'o-', linewidth=2)
            ax.fill(angles, values, alpha=0.25)
            ax.set_xticks(angles[:-1])
            ax.set_xticklabels(labels)
            ax.set_ylim(0, 1)
            ax.set_title(f'{class_name} Performance', pad=20)
            ax.grid(True)

        self._save_and_show(fig, save_name)

    def generate_bias_report_summary(self, experiment_results: Dict,
                                     save_name: str = 'bias_report_summary.png'):
        """Draw a 2x2 summary figure comparing baseline and mitigation runs.

        Panels:
            * percentage improvements
            * per-class F1
            * fairness metrics
            * training-loss curves

        A panel whose inputs are missing is left blank.
        """
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

        baseline_results = experiment_results.get('baseline_results', {})
        mitigation_results = experiment_results.get('mitigation_results', {})
        baseline_metrics = baseline_results.get('metrics', {})
        improved_metrics = mitigation_results.get('metrics', {})
        improvements = experiment_results.get('improvement_summary', {}).get('performance_improvements', {})

        # Overall performance improvement
        if improvements:
            metrics = ['accuracy', 'macro_f1', 'macro_precision', 'macro_recall']
            improvements_vals = [improvements.get(f'{m}_improvement', 0) for m in metrics]

            bars = ax1.bar(metrics, improvements_vals, alpha=0.7)
            ax1.set_title('Overall Performance Improvements (%)')
            ax1.set_ylabel('Improvement (%)')
            ax1.grid(True, alpha=0.3)

            # Green for gains, red otherwise. Each label goes past the end of
            # its bar, below the bar for negative values.
            for bar, val in zip(bars, improvements_vals):
                bar.set_color('green' if val > 0 else 'red')
                x_center = bar.get_x() + bar.get_width() / 2
                if val >= 0:
                    ax1.text(x_center, val + 0.1, f'{val:.1f}%', ha='center', va='bottom')
                else:
                    ax1.text(x_center, val - 0.1, f'{val:.1f}%', ha='center', va='top')

        # Class-specific F1 improvements
        if baseline_metrics and improved_metrics:
            baseline_f1s = [baseline_metrics.get(f'{cls}_f1', 0) for cls in self.class_names]
            improved_f1s = [improved_metrics.get(f'{cls}_f1', 0) for cls in self.class_names]

            x = np.arange(len(self.class_names))
            width = 0.35

            ax2.bar(x - width / 2, baseline_f1s, width, label='Baseline', alpha=0.7)
            ax2.bar(x + width / 2, improved_f1s, width, label='After Mitigation', alpha=0.7)

            ax2.set_title('Class-specific F1 Score Comparison')
            ax2.set_ylabel('F1 Score')
            ax2.set_xticks(x)
            ax2.set_xticklabels(self.class_names)
            ax2.legend()
            ax2.grid(True, alpha=0.3)

        # Bias metrics comparison
        baseline_bias = baseline_results.get('bias_metrics', {})
        improved_bias = mitigation_results.get('bias_metrics', {})

        if baseline_bias and improved_bias:
            bias_metrics = ['demographic_parity_diff', 'equalized_odds_diff']
            baseline_vals = [baseline_bias.get(m, 0) for m in bias_metrics]
            improved_vals = [improved_bias.get(m, 0) for m in bias_metrics]

            x = np.arange(len(bias_metrics))
            width = 0.35

            ax3.bar(x - width / 2, baseline_vals, width, label='Baseline', alpha=0.7, color='red')
            ax3.bar(x + width / 2, improved_vals, width, label='After Mitigation', alpha=0.7, color='blue')

            ax3.set_title('Bias Metrics Comparison (Lower is Better)')
            ax3.set_ylabel('Bias Score')
            ax3.set_xticks(x)
            ax3.set_xticklabels(['Demographic Parity', 'Equalized Odds'])
            ax3.legend()
            ax3.grid(True, alpha=0.3)

        # Training loss comparison; each run's curve is optional.
        baseline_loss = baseline_results.get('training_history', {}).get('train_loss') or []
        mitigation_loss = mitigation_results.get('training_history', {}).get('train_loss') or []

        if baseline_loss or mitigation_loss:
            if baseline_loss:
                ax4.plot(range(len(baseline_loss)), baseline_loss,
                         label='Baseline Training', alpha=0.7)
            if mitigation_loss:
                ax4.plot(range(len(mitigation_loss)), mitigation_loss,
                         label='With Bias Mitigation', alpha=0.7)

            ax4.set_title('Training Loss Comparison')
            ax4.set_xlabel('Epoch')
            ax4.set_ylabel('Loss')
            ax4.legend()
            ax4.grid(True, alpha=0.3)

        self._save_and_show(fig, save_name)

def normalize_sensitivity_patterns(patterns: Dict, class_names: List[str]) -> Dict[int, Dict[str, float]]:
    """Convert JSON-loaded sensitivity patterns into ``{class_idx: {layer: score}}``.

    JSON object keys are always strings, so each key is mapped back to a class
    index:
        * decimal-digit strings are parsed as indices;
        * other strings are looked up as class names in *class_names*.

    Keys that match neither, and indices outside ``class_names``, are skipped
    rather than being folded into class 0.

    Values:
        * a dict is kept as a per-layer mapping, limited to its numeric entries;
        * a bare number becomes ``{'pattern': value}``, a single heatmap column;
        * anything else is skipped.
    """
    if not isinstance(patterns, dict):
        return {}

    # First occurrence wins if class_names contains duplicates.
    name_to_idx: Dict[str, int] = {}
    for idx, name in enumerate(class_names):
        name_to_idx.setdefault(name, idx)

    def _is_number(value) -> bool:
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    normalized: Dict[int, Dict[str, float]] = {}
    for key, value in patterns.items():
        key = str(key)
        if key.isdecimal():
            class_idx = int(key)
        elif key in name_to_idx:
            class_idx = name_to_idx[key]
        else:
            continue
        if not 0 <= class_idx < len(class_names):
            continue

        if isinstance(value, dict):
            layer_scores = {str(layer): float(score) for layer, score in value.items() if _is_number(score)}
            if layer_scores:
                normalized[class_idx] = layer_scores
        elif _is_number(value):
            normalized[class_idx] = {'pattern': float(value)}
    return normalized


def load_and_visualize_results(results_path: str, class_names: List[str], output_dir: str):
    """Load an experiment-results JSON file and write every bias visualization.

    Args:
        results_path: Path to the JSON file produced by the experiment.
        class_names: Display names, indexed by class id.
        output_dir: Directory all figures and the HTML dashboard are written to.
    """
    # Load experiment results
    with open(results_path, 'r', encoding='utf-8') as f:
        results = json.load(f)

    # Initialize visualizer
    visualizer = BiasVisualizationUtils(class_names, output_dir)

    # Generate all visualizations
    print("Generating bias analysis visualizations...")

    mitigation_results = results.get('mitigation_results', {})

    # Training curves
    if 'training_history' in mitigation_results:
        visualizer.plot_training_curves(mitigation_results['training_history'])

    # Sensitivity analysis (JSON string keys are converted back to class indices)
    final_bias_analysis = mitigation_results.get('final_bias_analysis', {})
    if 'sensitivity_patterns' in final_bias_analysis:
        sensitivity_scores = normalize_sensitivity_patterns(
            final_bias_analysis['sensitivity_patterns'], class_names)
        if sensitivity_scores:
            visualizer.plot_sensitivity_heatmap(sensitivity_scores)

    # Interactive dashboard
    visualizer.create_interactive_bias_dashboard(results)

    # Summary report
    visualizer.generate_bias_report_summary(results)

    # Radar chart for class performance
    if 'metrics' in mitigation_results:
        visualizer.plot_class_performance_radar(mitigation_results['metrics'])

    print(f"All visualizations saved to {output_dir}")

if __name__ == '__main__':
    # Command-line entry point: parse the CLI options and render every figure.
    import argparse

    cli = argparse.ArgumentParser(description='Generate bias analysis visualizations')
    cli.add_argument(
        '--results_path',
        type=str,
        required=True,
        help='Path to experiment results JSON file',
    )
    cli.add_argument(
        '--output_dir',
        type=str,
        default='./visualizations',
        help='Output directory for visualizations',
    )
    cli.add_argument(
        '--class_names',
        type=str,
        nargs='+',
        default=["pedestrian", "cyclist", "motorcyclist"],
        help='Class names for the dataset',
    )

    opts = cli.parse_args()
    load_and_visualize_results(
        results_path=opts.results_path,
        class_names=opts.class_names,
        output_dir=opts.output_dir,
    )