import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import time
import os
from typing import Dict, List, Tuple, Optional
import logging
from tqdm import tqdm

from dataset_handler import NuScenesDataset, BiasAwareDataLoader
from activation_visualizer import BehaviorMetricsAnalyzer
from vit_analyzer import VisionTransformerAnalyzer
from model_3d_analyzer import Model3DAnalyzer
from bias_mitigation import AdaptiveBiasMitigation
from evaluation_metrics import ModelEvaluator

class BiasAwareTrainer:
    """Train a model while periodically measuring and mitigating class bias.

    The trainer wires together a model, a model-type specific analyzer, a
    ``ModelEvaluator`` and (after ``setup_training_components``) an
    ``AdaptiveBiasMitigation`` strategy that supplies the loss function.

    Config keys read by this class (all optional): ``use_amp``, ``optimizer``
    (``type``, ``lr``, ``momentum``, ``weight_decay``), ``scheduler``
    (``type``, ``min_lr``, ``step_size``, ``gamma``), ``epochs``,
    ``batch_size``, ``use_weighted_sampling``, ``log_interval``,
    ``bias_analysis_interval``, ``save_best_model`` and ``patience``.
    """

    # Substrings of ``model_type`` that select the 3D analysis path.
    _MODEL_3D_MARKERS = ('centerpoint', 'fs3d')

    def __init__(self, model: nn.Module, model_type: str, class_names: List[str],
                 device: torch.device, config: Dict):
        """Move the model to ``device`` and build analyzer, evaluator and history.

        Args:
            model: Network to train; moved to ``device``.
            model_type: Model family name; lower-cased and used to pick the analyzer.
            class_names: Class labels, indexed by class id.
            device: Device used for the model and for every batch.
            config: Training configuration dictionary (see class docstring).
        """
        self.model = model.to(device)
        self.model_type = model_type.lower()
        self.class_names = class_names
        self.device = device
        self.config = config

        # Initialize analyzers based on model type
        if 'vit' in self.model_type:
            self.analyzer = VisionTransformerAnalyzer(model, class_names)
        elif self._is_3d_model():
            self.analyzer = Model3DAnalyzer(model, class_names, self.model_type)
        else:
            self.analyzer = BehaviorMetricsAnalyzer(model, class_names)

        self.evaluator = ModelEvaluator(class_names)
        self.bias_mitigator = None

        # Training components (created by setup_training_components)
        self.optimizer = None
        self.scheduler = None
        self.scaler = torch.cuda.amp.GradScaler() if config.get('use_amp', False) else None

        # Logging
        self.setup_logging()

        # Metrics tracking
        self.training_history = {
            'train_loss': [],
            'val_loss': [],
            'sensitivity_scores': [],
            'selectivity_scores': [],
            'performance_metrics': [],
            'bias_scores': []
        }

    def _is_3d_model(self) -> bool:
        """Return True when ``model_type`` names one of the 3D model families."""
        return any(marker in self.model_type for marker in self._MODEL_3D_MARKERS)

    def setup_logging(self):
        """Configure root logging to ``training.log`` and the console.

        ``logging.basicConfig`` is a no-op once the root logger already has
        handlers, so creating several trainers does not duplicate handlers.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler('training.log'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def setup_training_components(self, train_dataset: "NuScenesDataset"):
        """Create the bias mitigator, optimizer and (optionally) the scheduler.

        Args:
            train_dataset: Training dataset; its ``class_distribution`` attribute
                is handed to ``AdaptiveBiasMitigation``.
        """
        # Initialize bias mitigation strategy
        self.bias_mitigator = AdaptiveBiasMitigation(
            self.class_names,
            train_dataset.class_distribution
        )

        # Setup optimizer: 'adam' (default) selects Adam, anything else SGD.
        optimizer_config = self.config.get('optimizer', {})
        if optimizer_config.get('type', 'adam').lower() == 'adam':
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=optimizer_config.get('lr', 1e-3),
                weight_decay=optimizer_config.get('weight_decay', 1e-4)
            )
        else:
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=optimizer_config.get('lr', 1e-3),
                momentum=optimizer_config.get('momentum', 0.9),
                weight_decay=optimizer_config.get('weight_decay', 1e-4)
            )

        # Setup scheduler: only 'cosine' and 'step' are recognised; any other
        # value leaves self.scheduler untouched.
        scheduler_config = self.config.get('scheduler', {})
        scheduler_type = scheduler_config.get('type')
        if scheduler_type == 'cosine':
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=self.config.get('epochs', 100),
                eta_min=scheduler_config.get('min_lr', 1e-6)
            )
        elif scheduler_type == 'step':
            self.scheduler = optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=scheduler_config.get('step_size', 30),
                gamma=scheduler_config.get('gamma', 0.1)
            )

    @staticmethod
    def _unpack_batch(batch_data) -> Tuple:
        """Split a batch into ``(inputs, labels, metadata)``.

        Batches of length 3 are unpacked as-is; any other batch contributes its
        first two items and ``metadata`` is ``None``.
        """
        if len(batch_data) == 3:
            inputs, labels, metadata = batch_data
        else:
            inputs, labels = batch_data[:2]
            metadata = None
        return inputs, labels, metadata

    def _current_loss_function(self):
        """Return the loss callable of the mitigator's current strategy.

        Raises:
            RuntimeError: If ``setup_training_components`` has not been called.
        """
        if self.bias_mitigator is None:
            raise RuntimeError(
                'Bias mitigator is not initialised; call '
                'setup_training_components() before training or validating.'
            )
        return self.bias_mitigator.get_current_strategy()['loss_function']

    def _snapshot_model_state(self) -> Dict:
        """Return a deep copy of the model's ``state_dict``.

        ``state_dict()`` (and a shallow ``.copy()`` of it) holds references to
        the live parameter tensors, so without a deep copy the "best" weights
        would keep changing as training continues.
        """
        import copy
        return copy.deepcopy(self.model.state_dict())

    @staticmethod
    def _mean_class_score(scores_by_class: Dict) -> float:
        """Average the per-class mean score; ``0.0`` when nothing was scored.

        Args:
            scores_by_class: Mapping of class id to a mapping of layer/module
                name to score. Classes with no scores are ignored.
        """
        class_means = [
            np.mean(list(class_scores.values()))
            for class_scores in scores_by_class.values()
            if len(class_scores) > 0
        ]
        return float(np.mean(class_means)) if class_means else 0.0

    def analyze_model_bias(self, dataloader: DataLoader, num_samples: int = 100) -> Tuple[Dict, Dict]:
        """Compute per-class sensitivity and selectivity scores.

        Puts the model in eval mode and dispatches on ``model_type`` to the
        ViT, 3D or CNN analysis path.

        Args:
            dataloader: Source of batches to analyse.
            num_samples: Upper bound on the number of samples inspected.

        Returns:
            ``(sensitivity_scores, selectivity_scores)``, both keyed by class.
        """
        self.model.eval()

        if 'vit' in self.model_type:
            return self._analyze_vit_bias(dataloader, num_samples)
        if self._is_3d_model():
            return self._analyze_3d_bias(dataloader, num_samples)
        return self._analyze_cnn_bias(dataloader, num_samples)

    def _analyze_vit_bias(self, dataloader: DataLoader, num_samples: int) -> Tuple[Dict, Dict]:
        """Use the analyzer's class attention patterns for both score dicts."""
        attention_patterns = self.analyzer.analyze_class_attention_patterns(dataloader, num_samples)

        # NOTE(review): the same pattern object is stored as both sensitivity
        # and selectivity, so a sensitivity-minus-selectivity bias score
        # cancels out for ViT models -- confirm this is intended.
        sensitivity_scores = dict(attention_patterns.items())
        selectivity_scores = dict(attention_patterns.items())
        return sensitivity_scores, selectivity_scores

    def _analyze_3d_bias(self, dataloader: DataLoader, num_samples: int) -> Tuple[Dict, Dict]:
        """Average per-layer 3D sensitivity scores over up to ``num_samples`` samples."""
        sensitivity_scores = {}
        selectivity_scores = {}

        sample_count = 0
        for batch_data in dataloader:
            if sample_count >= num_samples:
                break

            # Batches are either sequences (points, labels, ...) or dicts with
            # 'points' and 'labels' entries.
            if isinstance(batch_data, (list, tuple)):
                point_clouds, labels = batch_data[:2]
            else:
                point_clouds = batch_data['points']
                labels = batch_data['labels']

            for i, label in enumerate(labels):
                if sample_count >= num_samples:
                    break

                class_idx = label.item()
                layer_scores = sensitivity_scores.setdefault(class_idx, {})
                selectivity_scores.setdefault(class_idx, {})

                # Calculate 3D sensitivity scores for this single sample
                sens_scores = self.analyzer.calculate_3d_sensitivity_score(
                    point_clouds[i:i+1], class_idx
                )
                for layer_name, score in sens_scores.items():
                    layer_scores.setdefault(layer_name, []).append(score.mean().item())

                sample_count += 1

        # Average scores.
        # NOTE(review): selectivity is filled with the averaged sensitivity
        # value, so the two dicts are identical on this path -- confirm.
        for class_idx, layer_scores in sensitivity_scores.items():
            for layer_name, scores in list(layer_scores.items()):
                averaged = np.mean(scores) if scores else 0.0
                layer_scores[layer_name] = averaged
                selectivity_scores[class_idx][layer_name] = averaged

        return sensitivity_scores, selectivity_scores

    def _analyze_cnn_bias(self, dataloader: DataLoader, num_samples: int) -> Tuple[Dict, Dict]:
        """Collect a per-class sample quota, then score sensitivity and selectivity."""
        sensitivity_scores = {}
        selectivity_scores = {}

        # At least one sample per class, so small num_samples values still
        # produce scores instead of an empty result.
        num_classes = len(self.class_names)
        per_class_quota = max(1, num_samples // max(1, num_classes))

        # Collect samples for each class
        class_samples = {i: [] for i in range(num_classes)}

        sample_count = 0
        for batch_data in dataloader:
            quotas_filled = all(
                len(bucket) >= per_class_quota for bucket in class_samples.values()
            )
            # Stop scanning as soon as nothing more can be collected.
            if sample_count >= num_samples or quotas_filled:
                break

            images, labels, _ = self._unpack_batch(batch_data)
            images = images.to(self.device)
            for i, label in enumerate(labels):
                bucket = class_samples.setdefault(label.item(), [])
                if len(bucket) < per_class_quota:
                    bucket.append(images[i:i+1])
                    sample_count += 1

        # Calculate sensitivity scores for each class
        for class_idx, samples in class_samples.items():
            if not samples:
                continue

            class_sensitivity = {}
            for sample in samples[:5]:  # Limit to avoid memory issues
                sens_scores = self.analyzer.calculate_sensitivity_score(sample, class_idx)
                for module, score in sens_scores.items():
                    class_sensitivity.setdefault(str(module), []).append(score.mean().item())

            # Average sensitivity scores
            sensitivity_scores[class_idx] = {
                module_name: np.mean(scores)
                for module_name, scores in class_sensitivity.items()
            }

            # Calculate selectivity scores from up to three stacked samples
            input_tensors = {
                class_idx: torch.cat(samples[:3]) if len(samples) >= 3 else samples[0]
            }
            selectivity_result = self.analyzer.calculate_selectivity_score(input_tensors)
            selectivity_scores.update(selectivity_result)

        return sensitivity_scores, selectivity_scores

    def train_epoch(self, train_loader: DataLoader, epoch: int) -> float:
        """Run one optimisation pass over ``train_loader``.

        Args:
            train_loader: Iterable of ``(inputs, labels)`` or
                ``(inputs, labels, metadata)`` batches.
            epoch: Epoch index, used for the progress bar and log messages.

        Returns:
            Mean batch loss, or NaN when the loader yields no batches.

        Raises:
            RuntimeError: If ``setup_training_components`` has not been called.
        """
        if self.optimizer is None:
            raise RuntimeError(
                'Optimizer is not initialised; call '
                'setup_training_components() before training.'
            )

        self.model.train()
        total_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}')

        for batch_idx, batch_data in enumerate(progress_bar):
            inputs, labels, metadata = self._unpack_batch(batch_data)

            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            # Apply bias mitigation if available
            if self.bias_mitigator and hasattr(self.model, 'get_intermediate_outputs'):
                with torch.no_grad():
                    model_outputs = self.model.get_intermediate_outputs(inputs)
                inputs, labels, _ = self.bias_mitigator.apply_mitigation_to_batch(
                    (inputs, labels, metadata), model_outputs
                )
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass with mixed precision if enabled
            if self.scaler is not None:
                with torch.cuda.amp.autocast():
                    outputs = self.model(inputs)
                    loss = self._current_loss_function()(outputs, labels)

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self.model(inputs)
                loss = self._current_loss_function()(outputs, labels)
                loss.backward()
                self.optimizer.step()

            # .item() forces a host sync, so read the scalar only once per batch.
            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1

            # Update progress bar
            progress_bar.set_postfix({'Loss': f'{loss_value:.4f}'})

            # Log batch metrics periodically
            if batch_idx % self.config.get('log_interval', 100) == 0:
                self.logger.info(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss_value:.4f}')

        if num_batches == 0:
            self.logger.warning(f'Epoch {epoch}: training loader produced no batches')
            return float('nan')

        return total_loss / num_batches

    def validate_epoch(self, val_loader: DataLoader, epoch: int) -> Tuple[float, Dict]:
        """Evaluate the model on ``val_loader`` without gradient tracking.

        Args:
            val_loader: Iterable of ``(inputs, labels)`` or
                ``(inputs, labels, metadata)`` batches.
            epoch: Epoch index, used only in log messages.

        Returns:
            ``(avg_loss, performance_metrics)``. ``avg_loss`` is the mean batch
            loss (NaN when the loader yields no batches);
            ``performance_metrics`` is whatever the evaluator returns for the
            collected labels and argmax predictions.

        Raises:
            RuntimeError: If ``setup_training_components`` has not been called.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        all_predictions = []
        all_labels = []

        with torch.no_grad():
            for batch_data in val_loader:
                inputs, labels, _ = self._unpack_batch(batch_data)
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                outputs = self.model(inputs)
                loss = self._current_loss_function()(outputs, labels)

                total_loss += loss.item()
                num_batches += 1

                # Collect predictions for metrics calculation
                predictions = torch.argmax(outputs, dim=1)
                all_predictions.extend(predictions.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        # Average over the batches actually seen rather than len(val_loader),
        # which is zero for an empty loader.
        if num_batches == 0:
            self.logger.warning(f'Epoch {epoch}: validation loader produced no batches')
            avg_loss = float('nan')
        else:
            avg_loss = total_loss / num_batches

        # Calculate performance metrics
        performance_metrics = self.evaluator.calculate_classification_metrics(
            all_labels, all_predictions, self.class_names
        )

        return avg_loss, performance_metrics

    def _run_bias_analysis(self, val_loader: DataLoader, performance_metrics: Dict) -> None:
        """Analyse bias, update the mitigation strategy and record the scores."""
        sensitivity_scores, selectivity_scores = self.analyze_model_bias(val_loader)

        # Per-class F1 is looked up as '<class name>_f1'; missing entries count as 0.0.
        class_performance = {
            class_name: performance_metrics.get(f'{class_name}_f1', 0.0)
            for class_name in self.class_names
        }

        self.bias_mitigator.update_strategy(
            sensitivity_scores, selectivity_scores, class_performance
        )

        # Store metrics
        self.training_history['sensitivity_scores'].append(sensitivity_scores)
        self.training_history['selectivity_scores'].append(selectivity_scores)

        # Overall bias score: higher = more biased
        bias_score = (
            self._mean_class_score(sensitivity_scores)
            - self._mean_class_score(selectivity_scores)
        )
        self.training_history['bias_scores'].append(bias_score)

    def train(self, train_dataset: "NuScenesDataset", val_dataset: "NuScenesDataset") -> Dict:
        """Run the full training loop with checkpointing and early stopping.

        Each epoch trains, validates, optionally runs bias analysis (every
        ``bias_analysis_interval`` epochs), steps the scheduler and keeps a
        snapshot of the weights with the lowest validation loss. Training stops
        early after ``patience`` epochs without improvement, and the best
        weights are restored before returning.

        Args:
            train_dataset: Dataset used for optimisation.
            val_dataset: Dataset used for validation and bias analysis.

        Returns:
            The ``training_history`` dictionary.
        """
        self.setup_training_components(train_dataset)

        # Create data loaders
        batch_size = self.config.get('batch_size', 32)
        train_loader = BiasAwareDataLoader(
            train_dataset,
            batch_size,
            use_weighted_sampling=self.config.get('use_weighted_sampling', True)
        ).get_dataloader()

        val_loader = BiasAwareDataLoader(
            val_dataset,
            batch_size,
            use_weighted_sampling=False
        ).get_dataloader()

        best_val_loss = float('inf')
        best_model_state = None
        patience_counter = 0

        self.logger.info("Starting training...")

        for epoch in range(self.config.get('epochs', 100)):
            start_time = time.time()

            # Training phase
            train_loss = self.train_epoch(train_loader, epoch)

            # Validation phase
            val_loss, performance_metrics = self.validate_epoch(val_loader, epoch)

            # Bias analysis (periodic)
            if epoch % self.config.get('bias_analysis_interval', 10) == 0:
                self._run_bias_analysis(val_loader, performance_metrics)

            # Learning rate scheduling
            if self.scheduler:
                self.scheduler.step()

            # Store training history
            self.training_history['train_loss'].append(train_loss)
            self.training_history['val_loss'].append(val_loss)
            self.training_history['performance_metrics'].append(performance_metrics)

            # Model checkpointing
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Deep copy: a shallow copy would alias the live parameters and
                # drift as training continues.
                best_model_state = self._snapshot_model_state()
                patience_counter = 0

                # Save best model
                if self.config.get('save_best_model', True):
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': best_model_state,
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'val_loss': best_val_loss,
                        'config': self.config
                    }, f'best_model_{self.model_type}.pth')
            else:
                patience_counter += 1

            epoch_time = time.time() - start_time

            # Logging
            self.logger.info(
                f'Epoch {epoch}: Train Loss: {train_loss:.4f}, '
                f'Val Loss: {val_loss:.4f}, '
                f'Val Accuracy: {performance_metrics.get("accuracy", 0):.4f}, '
                f'Time: {epoch_time:.2f}s'
            )

            # Early stopping (epoch is 0-based, so epoch + 1 epochs have run)
            if patience_counter >= self.config.get('patience', 20):
                self.logger.info(f'Early stopping triggered after {epoch + 1} epochs')
                break

        # Load best model
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)

        self.logger.info("Training completed!")

        return self.training_history

    def save_training_results(self, save_dir: str):
        """Write history, final model and mitigation strategy into ``save_dir``.

        Creates ``training_history.pth`` and ``final_model.pth``; when a bias
        mitigator exists, also ``bias_mitigation_strategy.pth``.

        Args:
            save_dir: Output directory; created if it does not exist.
        """
        os.makedirs(save_dir, exist_ok=True)

        # Save training history
        torch.save(self.training_history, os.path.join(save_dir, 'training_history.pth'))

        # Save final model
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'config': self.config,
            'class_names': self.class_names,
            'model_type': self.model_type
        }, os.path.join(save_dir, 'final_model.pth'))

        # Save bias mitigation strategy
        if self.bias_mitigator:
            torch.save(
                self.bias_mitigator.get_current_strategy(),
                os.path.join(save_dir, 'bias_mitigation_strategy.pth')
            )

        self.logger.info(f"Training results saved to {save_dir}")

def load_pretrained_model(model_type: str, num_classes: int, device: torch.device) -> nn.Module:
    """Build a classifier for ``model_type`` with a ``num_classes``-way head.

    Supported types (case-insensitive): ``'resnet18'`` and ``'squeezenet'``
    load torchvision weights and replace the final layer; any name containing
    ``'vit'`` builds a ``ViTForImageClassification`` from a fresh ``ViTConfig``.

    NOTE(review): the ViT branch constructs the model from a config only, so
    despite the function name it starts from random weights -- confirm whether
    pretrained ViT weights are expected here.

    Args:
        model_type: Model family name.
        num_classes: Number of output classes for the replaced head.
        device: Device the returned model is moved to.

    Returns:
        The constructed model on ``device``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported families.
    """
    normalized_type = model_type.lower()

    if normalized_type == 'resnet18':
        import torchvision.models as models
        model = models.resnet18(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif normalized_type == 'squeezenet':
        import torchvision.models as models
        model = models.squeezenet1_1(pretrained=True)
        # Derive the input width from the layer being replaced instead of
        # hard-coding it.
        original_head = model.classifier[1]
        model.classifier[1] = nn.Conv2d(
            original_head.in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1)
        )
        # Keep the model's recorded class count in sync with the new head.
        model.num_classes = num_classes
    elif 'vit' in normalized_type:
        from transformers import ViTForImageClassification, ViTConfig
        config = ViTConfig(
            image_size=224,
            num_labels=num_classes,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12
        )
        model = ViTForImageClassification(config)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    return model.to(device)