import torch
import torch.nn as nn
import numpy as np
import argparse
import yaml
import os
import json
from datetime import datetime
from typing import Dict, List, Tuple
import logging

from dataset_handler import NuScenesDataset, create_balanced_split
from training_script import BiasAwareTrainer, load_pretrained_model
from evaluation_metrics import ModelEvaluator
from activation_visualizer import BehaviorMetricsAnalyzer
from vit_analyzer import VisionTransformerAnalyzer
from model_3d_analyzer import Model3DAnalyzer

def setup_logging(log_dir: str):
    """Configure logging to a file and the console, and return a module logger.

    The directory ``log_dir`` is created if it does not exist yet. Records at
    INFO level and above are sent both to ``<log_dir>/experiment.log`` and to
    a stream handler.

    Note: ``logging.basicConfig`` leaves the root logger untouched when it
    already has handlers, so calling this after logging has been configured
    elsewhere will not install the handlers created here.

    Args:
        log_dir: Directory in which ``experiment.log`` is written.

    Returns:
        The logger named after this module.
    """
    os.makedirs(log_dir, exist_ok=True)

    log_file = os.path.join(log_dir, 'experiment.log')
    file_handler = logging.FileHandler(log_file)
    console_handler = logging.StreamHandler()

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO,
        handlers=[file_handler, console_handler],
    )
    return logging.getLogger(__name__)

def load_config(config_path: str) -> Dict:
    """Load the experiment configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration as a dictionary.

    Raises:
        ValueError: If the file is empty or its top-level YAML node is not a
            mapping. ``yaml.safe_load`` returns ``None`` for an empty
            document, which would otherwise surface much later as an opaque
            ``TypeError`` when the configuration is indexed.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the
    # platform's default locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file '{config_path}' must contain a YAML mapping "
            f"at the top level, got {type(config).__name__}"
        )
    return config

def prepare_datasets(config: Dict) -> Tuple[NuScenesDataset, NuScenesDataset]:
    """Build the training and validation datasets described by ``config``.

    The full annotation set is loaded once, split with
    ``create_balanced_split``, and each part is attached to its own
    ``NuScenesDataset`` whose class distribution and class weights are then
    recomputed from the assigned annotations.

    Reads from ``config['data']``:
        class_names, data_path, annotations_path: required.
        augment_underrepresented: optional, defaults to True. Applied to the
            training dataset; the validation dataset is never augmented.
        train_ratio: optional, defaults to 0.8.

    Args:
        config: Experiment configuration dictionary.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple.
    """
    data_cfg = config['data']
    class_names = data_cfg['class_names']
    data_path = data_cfg['data_path']
    annotations_path = data_cfg['annotations_path']
    augment = data_cfg.get('augment_underrepresented', True)

    # Full dataset; only its annotations are used, as input to the split.
    dataset = NuScenesDataset(
        data_path=data_path,
        annotations_path=annotations_path,
        class_names=class_names,
        augment_underrepresented=augment
    )

    train_annotations, val_annotations = create_balanced_split(
        dataset.annotations,
        train_ratio=data_cfg.get('train_ratio', 0.8)
    )

    def build_subset(annotations, augment_flag):
        # Create a dataset without loading annotations from disk, attach the
        # given annotations, then refresh the statistics derived from them.
        subset = NuScenesDataset(
            data_path=data_path,
            annotations_path=None,
            class_names=class_names,
            augment_underrepresented=augment_flag
        )
        subset.annotations = annotations
        subset.class_distribution = subset._analyze_class_distribution()
        subset.class_weights = subset._calculate_class_weights()
        return subset

    # Honour the configured augmentation flag for training (previously this
    # was hard-coded to True, so disabling it in the config had no effect).
    train_dataset = build_subset(train_annotations, augment)
    val_dataset = build_subset(val_annotations, False)

    return train_dataset, val_dataset

def run_baseline_experiment(config: Dict, train_dataset: NuScenesDataset,
                          val_dataset: NuScenesDataset, logger: logging.Logger) -> Dict:
    """Train and evaluate the model with bias mitigation switched off.

    The training section of ``config`` is copied and both
    ``use_weighted_sampling`` and ``use_cost_sensitive_loss`` are forced to
    False before it is handed to ``BiasAwareTrainer``. After training, the
    model is run over the validation set and the arg-max predictions are
    scored with ``ModelEvaluator``.

    Args:
        config: Experiment configuration dictionary.
        train_dataset: Dataset used for training.
        val_dataset: Dataset used for validation and final evaluation.
        logger: Logger that receives progress and result messages.

    Returns:
        A dictionary with the keys ``metrics``, ``bias_metrics`` and
        ``training_history``.
    """
    logger.info("Starting baseline experiment...")

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    model_type = config['model']['type']
    class_names = config['data']['class_names']

    model = load_pretrained_model(model_type, len(class_names), device)

    # Baseline settings: same training options, mitigation features disabled.
    baseline_config = dict(config['training'])
    baseline_config.update(
        use_weighted_sampling=False,
        use_cost_sensitive_loss=False,
    )

    trainer = BiasAwareTrainer(
        model=model,
        model_type=model_type,
        class_names=class_names,
        device=device,
        config=baseline_config
    )
    history = trainer.train(train_dataset, val_dataset)

    evaluator = ModelEvaluator(class_names)

    # Collect predictions and ground-truth labels over the validation set.
    model.eval()
    predicted = []
    expected = []
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=baseline_config.get('batch_size', 32),
        shuffle=False
    )
    with torch.no_grad():
        for batch in val_loader:
            images, labels, _ = batch
            outputs = model(images.to(device))
            batch_predictions = torch.argmax(outputs, dim=1)
            predicted.extend(batch_predictions.cpu().numpy())
            expected.extend(labels.numpy())

    metrics = evaluator.calculate_classification_metrics(expected, predicted)
    bias_metrics = evaluator.calculate_bias_metrics(expected, predicted)

    logger.info("Baseline Results:")
    logger.info(f"Accuracy: {metrics['accuracy']:.4f}")
    logger.info(f"Macro F1: {metrics['macro_f1']:.4f}")
    logger.info(f"Demographic Parity Diff: {bias_metrics['demographic_parity_diff']:.4f}")

    return {
        'metrics': metrics,
        'bias_metrics': bias_metrics,
        'training_history': history
    }

def run_bias_mitigation_experiment(config: Dict, train_dataset: NuScenesDataset,
                                 val_dataset: NuScenesDataset, logger: logging.Logger) -> Dict:
    """Train and evaluate the model with bias mitigation switched on.

    The training section of ``config`` is copied, ``use_weighted_sampling``
    and ``use_cost_sensitive_loss`` are forced to True, and
    ``bias_analysis_interval`` is taken from the optional
    ``bias_mitigation.analysis_interval`` setting (default 10). After
    training, the model is scored on the validation set and a final bias
    analysis is run whose kind depends on ``config['model']['type']``:

    * type contains ``vit``: class attention patterns;
    * type contains ``centerpoint`` or ``fs3d``: detection confidence bias;
    * otherwise: sensitivity scores for the first sample of the first
      validation batch.

    Args:
        config: Experiment configuration dictionary.
        train_dataset: Dataset used for training.
        val_dataset: Dataset used for validation and final evaluation.
        logger: Logger that receives progress and result messages.

    Returns:
        A dictionary with the keys ``metrics``, ``bias_metrics``,
        ``training_history`` and ``final_bias_analysis``.
    """
    logger.info("Starting bias mitigation experiment...")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_type = config['model']['type']
    class_names = config['data']['class_names']

    # Load model
    model = load_pretrained_model(model_type, len(class_names), device)

    # The 'bias_mitigation' section is optional: tolerate it being absent or
    # empty (an empty YAML section parses to None) instead of raising before
    # the default interval can apply.
    bias_section = config.get('bias_mitigation') or {}

    # Create bias mitigation config
    mitigation_config = config['training'].copy()
    mitigation_config['use_weighted_sampling'] = True
    mitigation_config['use_cost_sensitive_loss'] = True
    mitigation_config['bias_analysis_interval'] = bias_section.get('analysis_interval', 10)

    # Create trainer with bias mitigation
    trainer = BiasAwareTrainer(
        model=model,
        model_type=model_type,
        class_names=class_names,
        device=device,
        config=mitigation_config
    )

    # Train model with bias mitigation
    mitigation_history = trainer.train(train_dataset, val_dataset)

    # Evaluate improved model
    evaluator = ModelEvaluator(class_names)

    # Get final predictions for evaluation
    model.eval()
    all_predictions = []
    all_targets = []

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=mitigation_config.get('batch_size', 32),
        shuffle=False
    )

    with torch.no_grad():
        for images, labels, _ in val_loader:
            images = images.to(device)
            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)

            all_predictions.extend(predictions.cpu().numpy())
            all_targets.extend(labels.numpy())

    improved_metrics = evaluator.calculate_classification_metrics(all_targets, all_predictions)
    improved_bias_metrics = evaluator.calculate_bias_metrics(all_targets, all_predictions)

    # Analyze final model bias with the analyzer matching the model family
    model_type_lower = model_type.lower()
    if 'vit' in model_type_lower:
        analyzer = VisionTransformerAnalyzer(model, class_names)
        attention_patterns = analyzer.analyze_class_attention_patterns(val_loader)
        final_bias_analysis = {'attention_patterns': attention_patterns}
    elif any(name in model_type_lower for name in ['centerpoint', 'fs3d']):
        analyzer = Model3DAnalyzer(model, class_names, model_type)
        confidence_scores = analyzer.analyze_detection_confidence_bias(val_loader)
        final_bias_analysis = {'confidence_bias': confidence_scores}
    else:
        analyzer = BehaviorMetricsAnalyzer(model, class_names)
        # Analyze only the first sample of the first validation batch
        sample_images, sample_labels, _ = next(iter(val_loader))
        sample_images = sample_images.to(device)
        sensitivity_scores = analyzer.calculate_sensitivity_score(sample_images[:1], sample_labels[0].item())
        # Reduce each score to its mean so the result is a plain float
        final_bias_analysis = {'sensitivity_patterns': {str(k): v.mean().item() for k, v in sensitivity_scores.items()}}

    mitigation_results = {
        'metrics': improved_metrics,
        'bias_metrics': improved_bias_metrics,
        'training_history': mitigation_history,
        'final_bias_analysis': final_bias_analysis
    }

    logger.info("Bias Mitigation Results:")
    logger.info(f"Accuracy: {improved_metrics['accuracy']:.4f}")
    logger.info(f"Macro F1: {improved_metrics['macro_f1']:.4f}")
    logger.info(f"Demographic Parity Diff: {improved_bias_metrics['demographic_parity_diff']:.4f}")

    return mitigation_results

def analyze_improvement(baseline_results: Dict, mitigation_results: Dict,
                       config: Dict, logger: logging.Logger) -> Dict:
    """Compare baseline and mitigation results and summarise the changes.

    Performance changes come from ``ModelEvaluator.compare_model_performance``.
    Bias reductions are computed here for every bias metric whose name
    contains ``diff`` or ``bias`` and that is present in both result sets:
    the relative reduction in percent, ``(baseline - improved) / baseline *
    100``. When the baseline value is zero the relative form is undefined and
    ``-improved * 100`` is stored instead.

    Args:
        baseline_results: Result dictionary of the baseline run; its
            ``metrics`` and ``bias_metrics`` entries are read.
        mitigation_results: Result dictionary of the mitigation run; its
            ``metrics``, ``bias_metrics`` and ``training_history`` entries
            are read.
        config: Experiment configuration dictionary.
        logger: Logger that receives the summary messages.

    Returns:
        A dictionary with the keys ``performance_improvements``,
        ``bias_reductions`` and ``key_findings``.
    """
    logger.info("Analyzing improvements...")

    evaluator = ModelEvaluator(config['data']['class_names'])

    # Calculate improvements
    improvements = evaluator.compare_model_performance(
        baseline_results['metrics'],
        mitigation_results['metrics']
    )

    # Analyze bias reduction
    baseline_bias = baseline_results['bias_metrics']
    improved_bias = mitigation_results['bias_metrics']

    bias_improvements = {}
    for metric, baseline_val in baseline_bias.items():
        if metric not in improved_bias:
            continue
        # Only bias-style metrics are treated as "lower is better"
        if 'diff' not in metric and 'bias' not in metric:
            continue

        improved_val = improved_bias[metric]
        if baseline_val != 0:
            bias_reduction = ((baseline_val - improved_val) / baseline_val) * 100
        else:
            bias_reduction = -improved_val * 100
        bias_improvements[f'{metric}_reduction'] = bias_reduction

    # Create summary report
    summary = {
        'performance_improvements': improvements,
        'bias_reductions': bias_improvements,
        'key_findings': {
            'accuracy_improvement': improvements.get('accuracy_improvement', 0),
            'f1_improvement': improvements.get('macro_f1_improvement', 0),
            'bias_reduction': bias_improvements.get('demographic_parity_diff_reduction', 0),
            # Number of recorded training-loss entries of the mitigation run
            'training_stability': len(mitigation_results['training_history']['train_loss'])
        }
    }

    # Log key improvements. The '+' format flag emits the correct sign, so a
    # regression is shown as "-1.23%" rather than "+-1.23%".
    key_findings = summary['key_findings']
    logger.info("Key Improvements:")
    logger.info(f"Accuracy: {key_findings['accuracy_improvement']:+.2f}%")
    logger.info(f"Macro F1: {key_findings['f1_improvement']:+.2f}%")
    logger.info(f"Bias Reduction: {key_findings['bias_reduction']:.2f}%")

    return summary

def save_experiment_results(baseline_results: Dict, mitigation_results: Dict,
                          summary: Dict, config: Dict, output_dir: str):
    """Write the experiment results, training histories and plots to disk.

    Files created inside ``output_dir`` (which is created if missing):

    * ``experiment_results.json``: configuration, both result sets, the
      summary and a timestamp. Values the JSON encoder cannot handle are
      converted with ``str``.
    * ``baseline_history.pth`` / ``mitigation_history.pth``: training
      histories saved with ``torch.save``.
    * ``performance_comparison.png``, ``baseline_class_performance.png`` and
      ``improved_class_performance.png``: paths passed as ``save_path`` to
      the ``ModelEvaluator`` comparison and plotting methods.

    Args:
        baseline_results: Result dictionary of the baseline run.
        mitigation_results: Result dictionary of the mitigation run.
        summary: Improvement summary to store alongside the results.
        config: Experiment configuration dictionary.
        output_dir: Destination directory.
    """
    os.makedirs(output_dir, exist_ok=True)

    def out_path(filename):
        # Resolve a file name relative to the output directory.
        return os.path.join(output_dir, filename)

    # Combined JSON report
    payload = {
        'experiment_config': config,
        'baseline_results': baseline_results,
        'mitigation_results': mitigation_results,
        'improvement_summary': summary,
        'timestamp': datetime.now().isoformat()
    }
    with open(out_path('experiment_results.json'), 'w') as report_file:
        json.dump(payload, report_file, indent=2, default=str)

    # Training histories, one file per run
    history_targets = (
        (baseline_results, 'baseline_history.pth'),
        (mitigation_results, 'mitigation_history.pth'),
    )
    for run_results, filename in history_targets:
        torch.save(run_results['training_history'], out_path(filename))

    # Plots
    evaluator = ModelEvaluator(config['data']['class_names'])

    evaluator.compare_model_performance(
        baseline_results['metrics'],
        mitigation_results['metrics'],
        save_path=out_path('performance_comparison.png')
    )

    plot_targets = (
        (baseline_results, 'baseline_class_performance.png'),
        (mitigation_results, 'improved_class_performance.png'),
    )
    for run_results, filename in plot_targets:
        evaluator.plot_class_performance(
            run_results['metrics'],
            save_path=out_path(filename)
        )

def main():
    """Command-line entry point for the bias analysis/mitigation experiment.

    Parses the command line, loads the YAML configuration, sets up logging
    in the output directory, prepares the datasets and runs the baseline
    and/or bias mitigation experiment. When both experiments run, the
    improvement analysis is performed and the full set of results is saved;
    when only one runs, its results are written to a single JSON file.

    ``--baseline_only`` and ``--mitigation_only`` are mutually exclusive:
    passing both would skip every experiment, so argparse rejects that
    combination with a usage error.
    """
    parser = argparse.ArgumentParser(description='Bias Analysis and Mitigation Experiment')
    parser.add_argument('--config', type=str, required=True,
                       help='Path to experiment configuration file')
    parser.add_argument('--output_dir', type=str, default='./experiment_results',
                       help='Directory to save experiment results')

    mode_group = parser.add_mutually_exclusive_group()
    mode_group.add_argument('--baseline_only', action='store_true',
                           help='Run only baseline experiment')
    mode_group.add_argument('--mitigation_only', action='store_true',
                           help='Run only bias mitigation experiment')

    args = parser.parse_args()

    # Load configuration
    config = load_config(args.config)

    # Setup logging (this also creates the output directory)
    logger = setup_logging(args.output_dir)
    logger.info("Starting bias analysis and mitigation experiment")
    logger.info(f"Configuration: {args.config}")

    # Prepare datasets
    logger.info("Preparing datasets...")
    train_dataset, val_dataset = prepare_datasets(config)

    logger.info(f"Training dataset: {len(train_dataset)} samples")
    logger.info(f"Validation dataset: {len(val_dataset)} samples")
    logger.info(f"Class distribution: {train_dataset.class_distribution}")

    # Run experiments
    baseline_results = None
    mitigation_results = None

    if not args.mitigation_only:
        baseline_results = run_baseline_experiment(config, train_dataset, val_dataset, logger)

    if not args.baseline_only:
        mitigation_results = run_bias_mitigation_experiment(config, train_dataset, val_dataset, logger)

    # Analyze improvements if both experiments were run
    if baseline_results is not None and mitigation_results is not None:
        summary = analyze_improvement(baseline_results, mitigation_results, config, logger)
        save_experiment_results(baseline_results, mitigation_results, summary, config, args.output_dir)
    elif baseline_results is not None:
        # Save only baseline results
        with open(os.path.join(args.output_dir, 'baseline_results.json'), 'w') as f:
            json.dump(baseline_results, f, indent=2, default=str)
    elif mitigation_results is not None:
        # Save only mitigation results
        with open(os.path.join(args.output_dir, 'mitigation_results.json'), 'w') as f:
            json.dump(mitigation_results, f, indent=2, default=str)

    logger.info(f"Experiment completed. Results saved to {args.output_dir}")

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()