import torch
import torch.nn as nn
import numpy as np
import argparse
import yaml
import json
import os
from typing import Dict, List, Tuple
import logging
from tqdm import tqdm

from dataset_handler import NuScenesDataset, BiasAwareDataLoader
from evaluation_metrics import ModelEvaluator
from activation_visualizer import BehaviorMetricsAnalyzer
from vit_analyzer import VisionTransformerAnalyzer
from model_3d_analyzer import Model3DAnalyzer
from training_script import load_pretrained_model
from visualization_utils import BiasVisualizationUtils

class ModelEvaluationPipeline:
    """Evaluate a trained checkpoint on a test set and analyze its bias.

    The bias-analysis strategy is chosen from ``config['model']['type']``
    (case-insensitive):

    * contains ``'vit'``                        -> attention-pattern analysis
    * contains ``'centerpoint'`` or ``'fs3d'``  -> 3D spatial/confidence analysis
    * anything else                             -> CNN sensitivity/selectivity analysis
    """

    # Substrings of the lower-cased model type that select the 3D analysis path.
    _MODEL_3D_KEYWORDS = ('centerpoint', 'fs3d')
    # Batch size of the loader built for bias analysis (smaller than training).
    _ANALYSIS_BATCH_SIZE = 16
    # Number of individual point clouds inspected for spatial bias patterns.
    _MAX_SPATIAL_SAMPLES = 5

    def __init__(self, model_path: str, config_path: str, device: torch.device):
        """Load config and model, and build the evaluator and the bias analyzer.

        Args:
            model_path: Checkpoint path passed to ``torch.load``. It holds either a
                raw state dict or a dict with a ``'model_state_dict'`` entry.
            config_path: Path to the YAML experiment configuration.
            device: Device the model is loaded onto and evaluated on.
        """
        # Configure logging first so it is available to every step below.
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        self.device = device
        self.config = self._load_config(config_path)
        self.class_names = self.config['data']['class_names']

        self.model = self._load_model(model_path)
        self.model.eval()

        self.evaluator = ModelEvaluator(self.class_names)

        family = self._model_family()
        if family == 'vit':
            self.analyzer = VisionTransformerAnalyzer(self.model, self.class_names)
        elif family == '3d':
            self.analyzer = Model3DAnalyzer(
                self.model, self.class_names, self.config['model']['type'].lower()
            )
        else:
            self.analyzer = BehaviorMetricsAnalyzer(self.model, self.class_names)

    def _model_family(self) -> str:
        """Classify the configured model type as ``'vit'``, ``'3d'`` or ``'cnn'``."""
        model_type = self.config['model']['type'].lower()
        if 'vit' in model_type:
            return 'vit'
        if any(keyword in model_type for keyword in self._MODEL_3D_KEYWORDS):
            return '3d'
        return 'cnn'

    def _load_config(self, config_path: str) -> Dict:
        """Parse the YAML file at ``config_path`` with the safe loader."""
        with open(config_path, 'r') as f:
            return yaml.safe_load(f)

    def _load_model(self, model_path: str) -> nn.Module:
        """Build the configured architecture and load the checkpoint weights into it.

        Accepts either a bare state dict or a training checkpoint that stores the
        weights under ``'model_state_dict'``.
        """
        # NOTE(security): torch.load unpickles the file (unless the installed torch
        # defaults to weights_only=True). Only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device)

        model_type = self.config['model']['type']
        num_classes = len(self.class_names)
        model = load_pretrained_model(model_type, num_classes, self.device)

        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        return model

    def evaluate_on_dataset(self, test_dataset: 'NuScenesDataset') -> Dict:
        """Run inference over ``test_dataset`` and compute performance and bias metrics.

        It also writes ``confusion_matrix_test.png`` and ``class_performance_test.png``
        through the evaluator. These are relative paths, so presumably they
        resolve against the working directory.

        Returns:
            Dict with ``performance_metrics``, ``bias_metrics`` and the flat
            ``predictions`` and ``targets`` lists (native Python values).
        """
        self.logger.info(f"Evaluating model on {len(test_dataset)} samples")

        test_loader = BiasAwareDataLoader(
            test_dataset,
            batch_size=self.config.get('training', {}).get('batch_size', 32),
            use_weighted_sampling=False
        ).get_dataloader()

        all_predictions: List[int] = []
        all_targets: List[int] = []

        with torch.no_grad():
            for images, labels, _metadata in tqdm(test_loader, desc="Evaluating"):
                outputs = self.model(images.to(self.device))
                predictions = torch.argmax(outputs, dim=1)
                # .tolist() yields native Python numbers. numpy scalars would be
                # stringified by json.dump(default=str) when the report is saved.
                all_predictions.extend(predictions.cpu().tolist())
                all_targets.extend(labels.cpu().tolist())

        performance_metrics = self.evaluator.calculate_classification_metrics(
            all_targets, all_predictions
        )
        bias_metrics = self.evaluator.calculate_bias_metrics(
            all_targets, all_predictions
        )

        self.evaluator.plot_confusion_matrix(
            all_targets, all_predictions,
            save_path='confusion_matrix_test.png'
        )
        self.evaluator.plot_class_performance(
            performance_metrics,
            save_path='class_performance_test.png'
        )

        return {
            'performance_metrics': performance_metrics,
            'bias_metrics': bias_metrics,
            'predictions': all_predictions,
            'targets': all_targets
        }

    def analyze_model_bias(self, test_dataset: 'NuScenesDataset', num_samples: int = 200) -> Dict:
        """Run the bias analysis that matches the model family.

        Args:
            test_dataset: Dataset to draw analysis samples from (unweighted sampling).
            num_samples: Approximate sample budget for the analysis.

        Returns:
            The dict produced by the family-specific ``_analyze_*_bias`` method.
            It always contains a ``'bias_type'`` key.
        """
        self.logger.info("Analyzing model bias patterns")

        test_loader = BiasAwareDataLoader(
            test_dataset,
            batch_size=self._ANALYSIS_BATCH_SIZE,
            use_weighted_sampling=False
        ).get_dataloader()

        family = self._model_family()
        if family == 'vit':
            return self._analyze_vit_bias(test_loader, num_samples)
        if family == '3d':
            return self._analyze_3d_bias(test_loader, num_samples)
        return self._analyze_cnn_bias(test_loader, num_samples)

    def _analyze_vit_bias(self, dataloader, num_samples: int) -> Dict:
        """Collect per-class attention patterns and their cross-class similarity.

        Also saves ``attention_bias_test.png`` via the analyzer.
        """
        attention_patterns = self.analyzer.analyze_class_attention_patterns(
            dataloader, num_samples
        )
        similarity_matrix = self.analyzer.compute_cross_class_attention_similarity(dataloader)

        self.analyzer.visualize_attention_bias(
            attention_patterns,
            save_path='attention_bias_test.png'
        )

        return {
            'attention_patterns': attention_patterns,
            # assumes the analyzer returns a tensor/ndarray (has .tolist()) — TODO confirm
            'attention_similarity': similarity_matrix.tolist(),
            'bias_type': 'attention_bias'
        }

    def _analyze_3d_bias(self, dataloader, num_samples: int) -> Dict:
        """Analyze detection-confidence bias and spatial bias on a few point clouds.

        Batches may be ``(points, labels, ...)`` sequences or dicts with
        ``'points'`` and ``'labels'`` keys.
        """
        # The confidence analysis budget is expressed in batches. Keep at least
        # one so num_samples < batch size does not request zero batches.
        num_batches = max(1, min(50, num_samples // self._ANALYSIS_BATCH_SIZE))
        confidence_scores = self.analyzer.analyze_detection_confidence_bias(
            dataloader, num_batches=num_batches
        )

        sample_bias = {}
        sample_count = 0

        for batch_data in dataloader:
            if sample_count >= self._MAX_SPATIAL_SAMPLES:
                break

            if isinstance(batch_data, (list, tuple)):
                point_clouds, labels = batch_data[:2]
            else:
                point_clouds = batch_data['points']
                labels = batch_data['labels']

            for i, label in enumerate(labels):
                if sample_count >= self._MAX_SPATIAL_SAMPLES:
                    break

                single_cloud = point_clouds[i:i + 1]
                sensitivity_scores = self.analyzer.calculate_3d_sensitivity_score(
                    single_cloud, label.item()
                )

                if 'point_features' in sensitivity_scores:
                    sample_bias[f'sample_{sample_count}'] = (
                        self.analyzer.calculate_spatial_bias_patterns(
                            single_cloud, sensitivity_scores['point_features']
                        )
                    )

                sample_count += 1

        return {
            'confidence_bias': confidence_scores,
            'spatial_bias_patterns': sample_bias,
            'bias_type': '3d_spatial_bias'
        }

    def _collect_class_samples(self, dataloader, num_samples: int) -> Dict[int, List[torch.Tensor]]:
        """Gather up to ``num_samples // num_classes`` single-image tensors per class.

        Iteration stops as soon as every class bucket is full. This always
        terminates once the budget is met, even when floor division makes
        ``num_samples`` itself unreachable. Only the selected slices are moved to
        ``self.device``.
        """
        num_classes = len(self.class_names)
        class_samples: Dict[int, List[torch.Tensor]] = {i: [] for i in range(num_classes)}
        per_class_limit = num_samples // num_classes if num_classes else 0
        target_total = per_class_limit * num_classes
        if target_total == 0:
            return class_samples

        sample_count = 0
        for images, labels, _ in dataloader:
            for i, label in enumerate(labels):
                bucket = class_samples[label.item()]
                if len(bucket) < per_class_limit:
                    bucket.append(images[i:i + 1].to(self.device))
                    sample_count += 1
            if sample_count >= target_total:
                break

        return class_samples

    def _analyze_cnn_bias(self, dataloader, num_samples: int) -> Dict:
        """Compute per-class module sensitivity and per-class selectivity scores.

        Returns:
            Dict with ``sensitivity_scores`` (class index -> module name -> mean
            score), ``selectivity_scores`` (whatever the analyzer returns, merged)
            and ``bias_type``.
        """
        class_samples = self._collect_class_samples(dataloader, num_samples)

        sensitivity_scores: Dict[int, Dict[str, float]] = {}
        for class_idx, samples in class_samples.items():
            if not samples:
                continue
            per_module: Dict[str, List[float]] = {}
            for sample in samples[:3]:  # cap the number of per-class sensitivity passes
                scores = self.analyzer.calculate_sensitivity_score(sample, class_idx)
                for module, score in scores.items():
                    per_module.setdefault(str(module), []).append(score.mean().item())
            sensitivity_scores[class_idx] = {
                module_name: np.mean(values) for module_name, values in per_module.items()
            }

        selectivity_scores = {}
        for class_idx, samples in class_samples.items():
            if len(samples) >= 2:
                input_tensors = {class_idx: torch.cat(samples[:2])}
                selectivity_scores.update(
                    self.analyzer.calculate_selectivity_score(input_tensors)
                )

        # Visualize only the first class that produced sensitivity scores.
        if sensitivity_scores:
            first_class, first_scores = next(iter(sensitivity_scores.items()))
            self.analyzer.visualize_sensitivity(
                {first_class: torch.tensor(list(first_scores.values()))}
            )

        return {
            'sensitivity_scores': sensitivity_scores,
            'selectivity_scores': selectivity_scores,
            'bias_type': 'cnn_sensitivity_bias'
        }

    def generate_comprehensive_report(self, test_results: Dict, bias_analysis: Dict,
                                      output_dir: str) -> Dict:
        """Assemble, save and visualize the full evaluation report.

        Writes ``evaluation_report.json`` into ``output_dir``, creating the directory
        if needed. Values json cannot encode natively are written via ``str()``.

        Returns:
            The report dict (model info, evaluation results, bias analysis, summary).
        """
        os.makedirs(output_dir, exist_ok=True)

        report = {
            'model_info': {
                'type': self.config['model']['type'],
                'num_classes': len(self.class_names),
                'class_names': self.class_names
            },
            'performance_evaluation': test_results,
            'bias_analysis': bias_analysis,
            'summary': self._generate_summary(test_results, bias_analysis)
        }

        with open(os.path.join(output_dir, 'evaluation_report.json'), 'w') as f:
            json.dump(report, f, indent=2, default=str)

        visualizer = BiasVisualizationUtils(self.class_names, output_dir)
        self._create_evaluation_visualizations(report, visualizer, output_dir)

        self.logger.info(f"Comprehensive evaluation report saved to {output_dir}")

        return report

    def _generate_summary(self, test_results: Dict, bias_analysis: Dict) -> Dict:
        """Condense metrics into weighted scores, class extremes and recommendations.

        ``best_performing_class`` and ``worst_performing_class`` are ``None`` (and
        the gap is 0.0) when no class names are configured.
        """
        metrics = test_results['performance_metrics']
        bias_metrics = test_results['bias_metrics']

        overall_score = (
            metrics['accuracy'] * 0.3 +
            metrics['macro_f1'] * 0.3 +
            metrics['macro_precision'] * 0.2 +
            metrics['macro_recall'] * 0.2
        )

        # Lower is better.
        bias_score = (
            bias_metrics.get('demographic_parity_diff', 0) * 0.5 +
            bias_metrics.get('equalized_odds_diff', 0) * 0.5
        )

        class_f1s = {cls: metrics.get(f'{cls}_f1', 0) for cls in self.class_names}
        if class_f1s:
            best_class = max(class_f1s, key=class_f1s.get)
            worst_class = min(class_f1s, key=class_f1s.get)
            performance_gap = class_f1s[best_class] - class_f1s[worst_class]
        else:
            best_class = worst_class = None
            performance_gap = 0.0

        return {
            'overall_performance_score': overall_score,
            'bias_score': bias_score,
            'best_performing_class': best_class,
            'worst_performing_class': worst_class,
            'class_performance_gap': performance_gap,
            'bias_analysis_type': bias_analysis.get('bias_type', 'unknown'),
            'recommendations': self._generate_recommendations(metrics, bias_metrics, bias_analysis)
        }

    def _generate_recommendations(self, metrics: Dict, bias_metrics: Dict,
                                  bias_analysis: Dict) -> List[str]:
        """Turn fixed metric thresholds and the bias type into textual advice."""
        recommendations = []

        if metrics['accuracy'] < 0.8:
            recommendations.append("Consider increasing model capacity or training time")
        if metrics['macro_f1'] < 0.7:
            recommendations.append("Focus on improving recall for underperforming classes")

        if bias_metrics.get('demographic_parity_diff', 0) > 0.1:
            recommendations.append("Apply stronger class balancing techniques")
        if bias_metrics.get('equalized_odds_diff', 0) > 0.1:
            recommendations.append("Consider fairness-aware training objectives")

        bias_type = bias_analysis.get('bias_type', '')
        if 'attention' in bias_type:
            recommendations.append("Use attention-guided data augmentation")
        elif 'spatial' in bias_type:
            recommendations.append("Apply spatial data augmentation techniques")
        elif 'sensitivity' in bias_type:
            recommendations.append("Implement cost-sensitive learning")

        return recommendations

    def _create_evaluation_visualizations(self, report: Dict, visualizer: 'BiasVisualizationUtils',
                                          output_dir: str):
        """Render the radar chart plus the plots specific to the report's bias type.

        ``output_dir`` is unused here; file placement is up to ``visualizer``.
        """
        metrics = report['performance_evaluation']['performance_metrics']
        bias_analysis = report['bias_analysis']

        visualizer.plot_class_performance_radar(metrics, 'class_performance_radar.png')

        bias_type = bias_analysis.get('bias_type', '')

        if bias_type == 'attention_bias':
            attention_patterns = bias_analysis.get('attention_patterns', {})
            if attention_patterns:
                visualizer.plot_attention_patterns(attention_patterns, 'attention_patterns.png')

        elif bias_type == 'cnn_sensitivity_bias':
            sensitivity_scores = bias_analysis.get('sensitivity_scores', {})
            selectivity_scores = bias_analysis.get('selectivity_scores', {})
            if sensitivity_scores:
                visualizer.plot_sensitivity_heatmap(sensitivity_scores, 'sensitivity_heatmap.png')
            if selectivity_scores:
                visualizer.plot_selectivity_comparison(selectivity_scores, 'selectivity_comparison.png')

        elif bias_type == '3d_spatial_bias':
            spatial_patterns = bias_analysis.get('spatial_bias_patterns', {})
            if spatial_patterns:
                # Average each pattern type across the analyzed samples.
                grouped: Dict[str, list] = {}
                for sample_data in spatial_patterns.values():
                    for pattern_type, value in sample_data.items():
                        grouped.setdefault(pattern_type, []).append(value)
                avg_patterns = {k: np.mean(v) for k, v in grouped.items()}
                visualizer.plot_3d_sensitivity_distribution(avg_patterns, '3d_bias_patterns.png')

def main():
    """Command-line entry point: evaluate a checkpoint, analyze bias, write a report."""
    parser = argparse.ArgumentParser(description='Evaluate trained model and analyze bias')
    parser.add_argument('--model_path', type=str, required=True,
                        help='Path to trained model checkpoint')
    parser.add_argument('--config_path', type=str, required=True,
                        help='Path to experiment configuration file')
    parser.add_argument('--test_data_path', type=str, required=True,
                        help='Path to test dataset annotations')
    parser.add_argument('--data_root', type=str, required=True,
                        help='Root directory of the dataset')
    parser.add_argument('--output_dir', type=str, default='./evaluation_results',
                        help='Output directory for evaluation results')
    parser.add_argument('--num_bias_samples', type=int, default=200,
                        help='Number of samples for bias analysis')

    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The pipeline parses the config file once. Its class list is reused below
    # rather than reading and parsing the same YAML file a second time.
    evaluator = ModelEvaluationPipeline(args.model_path, args.config_path, device)

    test_dataset = NuScenesDataset(
        data_path=args.data_root,
        annotations_path=args.test_data_path,
        class_names=evaluator.class_names,
        augment_underrepresented=False
    )

    print(f"Loaded test dataset with {len(test_dataset)} samples")

    print("Running performance evaluation...")
    test_results = evaluator.evaluate_on_dataset(test_dataset)

    print("Running bias analysis...")
    bias_analysis = evaluator.analyze_model_bias(test_dataset, args.num_bias_samples)

    print("Generating comprehensive report...")
    report = evaluator.generate_comprehensive_report(
        test_results, bias_analysis, args.output_dir
    )

    summary = report['summary']
    print("\n" + "="*50)
    print("EVALUATION SUMMARY")
    print("="*50)
    print(f"Overall Performance Score: {summary['overall_performance_score']:.3f}")
    print(f"Bias Score: {summary['bias_score']:.3f}")
    print(f"Best Performing Class: {summary['best_performing_class']}")
    print(f"Worst Performing Class: {summary['worst_performing_class']}")
    print(f"Class Performance Gap: {summary['class_performance_gap']:.3f}")
    print(f"Bias Analysis Type: {summary['bias_analysis_type']}")

    print("\nRecommendations:")
    for i, rec in enumerate(summary['recommendations'], 1):
        print(f"{i}. {rec}")

    print(f"\nDetailed results saved to: {args.output_dir}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()