import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple, Optional, Union
from collections import Counter
import copy

class FocalLoss(nn.Module):
    """Multi-class focal loss (Lin et al., 2017).

    Each element's cross-entropy ``CE`` is scaled by ``(1 - p_t) ** gamma``,
    where ``p_t = exp(-CE)`` is the predicted probability of the true class,
    so confidently-correct examples contribute less. An optional per-class
    weight ``alpha[target]`` is applied on top.

    Args:
        alpha: Optional 1-D tensor of per-class weights, indexed by target
            class. It is cast to the loss' dtype/device on every call and is
            never modified in place.
        gamma: Focusing parameter; ``0`` reduces to (alpha-weighted)
            cross-entropy.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-element loss.
    """

    def __init__(self, alpha: Optional[torch.Tensor] = None, gamma: float = 2.0,
                 reduction: str = 'mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss.

        Args:
            inputs: Unnormalized logits in any layout accepted by
                ``F.cross_entropy`` with class-index targets, e.g. ``(N, C)``
                or ``(N, C, d1, ...)``.
            targets: Integer class indices shaped like ``inputs`` without
                the class dimension.

        Returns:
            A scalar for ``'mean'``/``'sum'``; otherwise the per-element
            losses, shaped like ``targets``.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) recovers the softmax probability of the target class.
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            # Cast a local copy rather than overwriting self.alpha; advanced
            # indexing keeps the shape of ``targets`` so spatial
            # (N, C, d1, ...) inputs broadcast correctly, unlike view(-1).
            alpha = self.alpha.to(device=focal_loss.device, dtype=focal_loss.dtype)
            focal_loss = alpha[targets] * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

class WeightedCrossEntropyLoss(nn.Module):
    """Cross-entropy loss with fixed per-class weights.

    Args:
        class_weights: 1-D tensor with one weight per class, forwarded as the
            ``weight`` argument of ``F.cross_entropy``.
    """

    def __init__(self, class_weights: torch.Tensor):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.class_weights = class_weights

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the weighted mean cross-entropy of ``inputs`` against ``targets``."""
        # class_weights is a plain attribute, not a registered buffer, so
        # Module.to()/.cuda() does not move it; align it with the logits here
        # to avoid device/dtype mismatch errors inside cross_entropy.
        weight = self.class_weights.to(device=inputs.device, dtype=inputs.dtype)
        return F.cross_entropy(inputs, targets, weight=weight)

class CostSensitiveLearning:
    """Build a class-weighted loss from a class-frequency table.

    Class weights are inversely proportional to each class's sample count and
    are normalized to sum to the number of classes (mean weight 1.0).

    Args:
        class_distribution: Mapping class name -> number of samples.
        class_names: Class names in label-index order; weight ``i`` belongs
            to ``class_names[i]``.
        loss_type: ``'weighted_ce'`` or ``'focal'``.
        focal_gamma: Focusing parameter used when ``loss_type == 'focal'``.

    Raises:
        ValueError: If ``loss_type`` is not recognized.
    """

    def __init__(self, class_distribution: Dict[str, int], class_names: List[str],
                 loss_type: str = 'weighted_ce', focal_gamma: float = 2.0):
        self.class_distribution = class_distribution
        self.class_names = class_names
        self.loss_type = loss_type
        self.focal_gamma = focal_gamma

        self.class_weights = self._calculate_class_weights()
        self.loss_function = self._create_loss_function()

    def _calculate_class_weights(self) -> torch.Tensor:
        """Return float32 inverse-frequency weights summing to ``len(class_names)``.

        Classes missing from ``class_distribution`` or with a count below 1
        are treated as having one sample. This avoids division by zero, for
        example for a zero count or an empty distribution, and gives such
        classes the largest weight.
        """
        # Only relative magnitudes survive the normalization below, so 1/count
        # is equivalent to the former 1/percentage formulation. It also does
        # not divide by the total, which may be 0.
        counts = [max(self.class_distribution.get(name, 1), 1) for name in self.class_names]
        weights = torch.tensor([1.0 / count for count in counts], dtype=torch.float32)
        return weights / weights.sum() * len(weights)

    def _create_loss_function(self) -> nn.Module:
        """Instantiate the loss selected by ``loss_type`` with the current weights."""
        if self.loss_type == 'weighted_ce':
            return WeightedCrossEntropyLoss(self.class_weights)
        elif self.loss_type == 'focal':
            return FocalLoss(alpha=self.class_weights, gamma=self.focal_gamma)
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

    def get_loss_function(self) -> nn.Module:
        """Return the current loss module (rebuilt whenever weights change)."""
        return self.loss_function

    def update_weights_based_on_performance(self, class_performance: Dict[str, float],
                                          adaptation_rate: float = 0.1):
        """Boost weights of poorly performing classes, then rebuild the loss.

        Each weight is multiplied by ``1 + (1 - performance) * adaptation_rate``.
        A class missing from ``class_performance`` counts as performance 1.0,
        which means no change. The weights are then renormalized to sum to the
        number of classes.

        Args:
            class_performance: Class name -> validation performance.
            adaptation_rate: Scale of the per-round adjustment.
        """
        factors = torch.tensor(
            [1.0 + (1.0 - class_performance.get(name, 1.0)) * adaptation_rate
             for name in self.class_names],
            dtype=self.class_weights.dtype,
            device=self.class_weights.device,
        )
        # Build a new tensor instead of writing in place: the previous loss
        # module (and any caller holding ``class_weights``) shares the old one.
        new_weights = self.class_weights * factors
        self.class_weights = new_weights / new_weights.sum() * len(new_weights)

        # Rebuild so the loss sees the new weights.
        self.loss_function = self._create_loss_function()

class BiasAwareResampling:
    """Flag potentially biased samples and oversample selected classes.

    Args:
        sensitivity_threshold: A sample is flagged when any layer's
            per-sample sensitivity score exceeds this value.
        selectivity_threshold: Every sample is flagged when the mean of the
            selectivity scores falls below this value.
    """

    def __init__(self, sensitivity_threshold: float = 0.7, selectivity_threshold: float = 0.3):
        self.sensitivity_threshold = sensitivity_threshold
        self.selectivity_threshold = selectivity_threshold

    def identify_biased_samples(self, sensitivity_scores: Dict[str, torch.Tensor],
                              selectivity_scores: Dict[str, float],
                              sample_indices: List[int]) -> List[int]:
        """Return the subset of ``sample_indices`` considered biased, in order.

        A sample is biased if either condition holds:

        * a layer's sensitivity tensor (``dim() > 0``) holds a value above
          ``sensitivity_threshold`` at that index; out-of-range indices and
          non-tensor/0-dim entries are ignored;
        * the mean of ``selectivity_scores`` is below
          ``selectivity_threshold``. This does not depend on the sample, so it
          flags all of them.
        """
        # The selectivity criterion is sample-independent: evaluate it once.
        # An empty dict never flags, which matches the old NaN comparison,
        # and avoids numpy's "mean of empty slice" warning.
        low_selectivity = bool(selectivity_scores) and (
            np.mean(list(selectivity_scores.values())) < self.selectivity_threshold
        )
        per_sample_scores = [
            scores for scores in sensitivity_scores.values()
            if torch.is_tensor(scores) and scores.dim() > 0
        ]

        biased_samples = []
        for idx in sample_indices:
            if low_selectivity or any(
                idx < len(scores) and scores[idx] > self.sensitivity_threshold
                for scores in per_sample_scores
            ):
                biased_samples.append(idx)
        return biased_samples

    def apply_targeted_oversampling(self, dataset, target_classes: List[int],
                                  oversample_factor: float = 2.0) -> List[int]:
        """Return dataset indices with the target classes oversampled.

        Every sample index appears once. For each class in ``target_classes``,
        ``int(count * (oversample_factor - 1))`` extra indices are drawn with
        replacement (via ``np.random``) from that class's samples. Factors
        below 1 add nothing.

        Args:
            dataset: Supports ``len()``; ``dataset[i]`` must unpack to
                ``(input, label, metadata)`` with an int-convertible label.
            target_classes: Labels to oversample.
            oversample_factor: Desired size multiplier for target classes.
        """
        targets = {int(c) for c in target_classes}

        # Group all samples by label. The old dict was pre-keyed by
        # range(len(target_classes)), which raised KeyError for larger labels
        # and dropped every non-target sample.
        class_indices: Dict[int, List[int]] = {}
        for idx in range(len(dataset)):
            _, label, _ = dataset[idx]
            # int() also normalizes 0-dim tensor labels, which hash by identity.
            class_indices.setdefault(int(label), []).append(idx)

        augmented_indices: List[int] = []
        for class_idx, indices in class_indices.items():
            augmented_indices.extend(indices)
            if class_idx in targets:
                num_additional = max(0, int(len(indices) * (oversample_factor - 1)))
                if num_additional:
                    additional_indices = np.random.choice(indices, num_additional, replace=True)
                    augmented_indices.extend(additional_indices.tolist())
        return augmented_indices

class AttentionGuidedAugmentation:
    """Perturb image regions where a model's attention is low.

    Args:
        attention_threshold: Attention values strictly below this are
            reported as gaps.
    """

    def __init__(self, attention_threshold: float = 0.8):
        self.attention_threshold = attention_threshold

    def identify_attention_gaps(self, attention_maps: Dict[str, torch.Tensor],
                              class_label: int) -> Dict[str, List[Tuple[int, int]]]:
        """Return, per layer, the ``(row, col)`` cells with attention below threshold.

        Maps with more than two dimensions are averaged over every leading
        dimension, so coordinates always index the last two axes. Maps with
        fewer than two dimensions yield an empty list.

        Args:
            attention_maps: Layer name -> attention tensor. The last two dims
                are assumed to be spatial (rows, cols) — TODO confirm.
            class_label: Currently unused; kept for interface compatibility.
        """
        attention_gaps = {}
        for layer_name, attention_map in attention_maps.items():
            if attention_map.dim() < 2:
                attention_gaps[layer_name] = []
                continue
            # Collapse ALL leading dims, not just dim 0. With 4-D maps the old
            # mean(dim=0) left a 3-D tensor, and the reported "row" was
            # actually a leading-axis index.
            attention_2d = attention_map.reshape(-1, *attention_map.shape[-2:]).mean(dim=0)
            low_attention_indices = torch.nonzero(attention_2d < self.attention_threshold, as_tuple=False)
            attention_gaps[layer_name] = [(int(row), int(col)) for row, col in low_attention_indices.tolist()]
        return attention_gaps

    def generate_attention_aware_augmentations(self, image: torch.Tensor,
                                             attention_gaps: Dict[str, List[Tuple[int, int]]],
                                             augmentation_strength: float = 0.5,
                                             patch_size: int = 16) -> List[torch.Tensor]:
        """Return one augmented copy of ``image`` per layer that has gaps.

        For every gap, the window ``[row - patch_size//2, row + patch_size//2)``
        by ``[col - patch_size//2, col + patch_size//2)`` (clipped to the image)
        gets one randomly chosen augmentation: brightness scaling, contrast
        scaling, or Gaussian noise. The result is clamped to [0, 1]. Layers
        with no gaps are skipped, and ``image`` itself is not modified.

        NOTE(review): the clamp assumes pixel values in [0, 1]. Gap coordinates
        are also used directly as pixel coordinates: if attention maps have a
        lower resolution than the image, gaps only touch its top-left region,
        so callers may need to rescale them first.

        Args:
            image: Tensor whose last two dims are (height, width).
            attention_gaps: Layer name -> list of ``(row, col)`` gaps.
            augmentation_strength: Scales the brightness/contrast range and
                the noise amplitude.
            patch_size: Side length of the window around each gap.
        """
        half = patch_size // 2
        height, width = image.size(-2), image.size(-1)
        augmented_images = []

        for layer_name, gaps in attention_gaps.items():
            if not gaps:
                continue

            augmented_image = image.clone()
            for gap_row, gap_col in gaps:
                start_row = max(0, gap_row - half)
                end_row = min(height, gap_row + half)
                start_col = max(0, gap_col - half)
                end_col = min(width, gap_col + half)

                patch = augmented_image[..., start_row:end_row, start_col:end_col]
                aug_type = np.random.choice(['brightness', 'contrast', 'noise'])

                if aug_type == 'brightness':
                    brightness_factor = 1.0 + (np.random.random() - 0.5) * augmentation_strength
                    patch = torch.clamp(patch * brightness_factor, 0, 1)
                elif aug_type == 'contrast':
                    contrast_factor = 1.0 + (np.random.random() - 0.5) * augmentation_strength
                    mean_val = patch.mean()
                    patch = torch.clamp((patch - mean_val) * contrast_factor + mean_val, 0, 1)
                elif aug_type == 'noise':
                    noise = torch.randn_like(patch) * augmentation_strength * 0.1
                    patch = torch.clamp(patch + noise, 0, 1)

                augmented_image[..., start_row:end_row, start_col:end_col] = patch

            augmented_images.append(augmented_image)

        return augmented_images

class CrossMetricOptimization:
    """Adjust class weights from a weighted blend of sensitivity/selectivity/performance.

    Args:
        sensitivity_weight: Weight of ``1 - mean(sensitivity)`` in the score.
        selectivity_weight: Weight of ``mean(selectivity)`` in the score.
        performance_weight: Weight of ``mean(performance)`` in the score.
    """

    def __init__(self, sensitivity_weight: float = 0.3, selectivity_weight: float = 0.3,
                 performance_weight: float = 0.4):
        self.sensitivity_weight = sensitivity_weight
        self.selectivity_weight = selectivity_weight
        self.performance_weight = performance_weight

    def calculate_composite_score(self, sensitivity_scores: Dict[str, float],
                                selectivity_scores: Dict[str, float],
                                performance_metrics: Dict[str, float]) -> float:
        """Combine averaged metrics into one score (higher is better).

        ``score = (1 - mean(sens)) * w_sens + mean(sel) * w_sel + mean(perf) * w_perf``

        The inputs are averaged but not rescaled; they are presumably already
        in [0, 1] (TODO confirm with callers). An empty metric dict is left
        out, and the remaining terms are scaled by ``total_weight /
        present_weight``, so a missing metric no longer turns the score into
        NaN. With all three present the result equals the plain formula.

        Returns:
            The composite score, or NaN when all three dicts are empty.
        """
        total_weight = self.sensitivity_weight + self.selectivity_weight + self.performance_weight

        terms = []  # (weight, value) for each metric that has data
        if sensitivity_scores:
            # Lower sensitivity (less over-reliance) is better, hence 1 - mean.
            terms.append((self.sensitivity_weight, 1.0 - np.mean(list(sensitivity_scores.values()))))
        if selectivity_scores:
            terms.append((self.selectivity_weight, np.mean(list(selectivity_scores.values()))))
        if performance_metrics:
            terms.append((self.performance_weight, np.mean(list(performance_metrics.values()))))

        if not terms:
            return float('nan')

        composite_score = sum(value * weight for weight, value in terms)
        present_weight = sum(weight for weight, _ in terms)
        if present_weight != total_weight and present_weight != 0:
            composite_score = composite_score * total_weight / present_weight
        return float(composite_score)

    def optimize_class_weights(self, class_weights: torch.Tensor,
                             current_scores: Dict[str, Dict[str, float]],
                             learning_rate: float = 0.01) -> torch.Tensor:
        """Return re-weighted, renormalized class weights (input is not modified).

        For each integer key ``class_idx`` in ``current_scores``, the weight is
        multiplied by ``1 + (1 - composite_score) * learning_rate``, so classes
        with low scores gain weight. String keys and classes with no usable
        metrics (NaN score) are skipped. The result is rescaled to sum to
        ``len(class_weights)``.

        Args:
            class_weights: 1-D tensor of current weights.
            current_scores: class index -> dict with optional ``'sensitivity'``,
                ``'selectivity'`` and ``'performance'`` metric dicts.
            learning_rate: Scale of the per-call adjustment.
        """
        optimized_weights = class_weights.clone()

        for class_idx, class_data in current_scores.items():
            if isinstance(class_idx, str):
                continue

            composite_score = self.calculate_composite_score(
                class_data.get('sensitivity', {}),
                class_data.get('selectivity', {}),
                class_data.get('performance', {}),
            )
            # A single NaN would poison every weight at renormalization.
            if not np.isfinite(composite_score):
                continue

            weight_adjustment = (1.0 - composite_score) * learning_rate
            optimized_weights[class_idx] = optimized_weights[class_idx] * (1 + weight_adjustment)

        return optimized_weights / optimized_weights.sum() * len(optimized_weights)

class AdaptiveBiasMitigation:
    """Coordinate cost-sensitive loss, resampling and augmentation settings.

    Class weights start from inverse class frequencies (weighted
    cross-entropy by default). They are re-optimized by
    ``CrossMetricOptimization`` on every ``update_strategy`` call.

    Args:
        class_names: Class names in label-index order.
        initial_class_distribution: Class name -> sample count used for the
            initial weights.
    """

    def __init__(self, class_names: List[str], initial_class_distribution: Dict[str, int]):
        self.class_names = class_names
        self.cost_sensitive = CostSensitiveLearning(initial_class_distribution, class_names)
        self.resampling = BiasAwareResampling()
        self.attention_augmentation = AttentionGuidedAugmentation()
        self.cross_metric_optimizer = CrossMetricOptimization()

        # One entry per update_strategy call; metric dicts are stored by
        # reference, class weights as a clone.
        self.history = {
            'sensitivity_scores': [],
            'selectivity_scores': [],
            'performance_metrics': [],
            'class_weights': []
        }

    def update_strategy(self, sensitivity_scores: Dict[str, Dict[str, float]],
                       selectivity_scores: Dict[str, Dict[str, float]],
                       performance_metrics: Dict[str, float],
                       validation_data=None):
        """Record the metrics and re-optimize the class weights.

        Args:
            sensitivity_scores: Per-class metric dicts, keyed by class index.
                Class-name keys, as the type hint suggests, are accepted as a
                fallback.
            selectivity_scores: Same layout as ``sensitivity_scores``.
            performance_metrics: Class name -> performance; a missing class
                counts as 0.0.
            validation_data: Currently unused.
        """
        self.history['sensitivity_scores'].append(sensitivity_scores)
        self.history['selectivity_scores'].append(selectivity_scores)
        self.history['performance_metrics'].append(performance_metrics)
        self.history['class_weights'].append(self.cost_sensitive.class_weights.clone())

        combined_scores = {}
        for class_idx, class_name in enumerate(self.class_names):
            combined_scores[class_idx] = {
                'sensitivity': sensitivity_scores.get(class_idx, sensitivity_scores.get(class_name, {})),
                'selectivity': selectivity_scores.get(class_idx, selectivity_scores.get(class_name, {})),
                'performance': {class_name: performance_metrics.get(class_name, 0.0)},
            }

        optimized_weights = self.cross_metric_optimizer.optimize_class_weights(
            self.cost_sensitive.class_weights, combined_scores
        )

        self.cost_sensitive.class_weights = optimized_weights
        # Rebuild so the loss module sees the new weights.
        self.cost_sensitive.loss_function = self.cost_sensitive._create_loss_function()

    def get_current_strategy(self) -> Dict:
        """Return the active loss, class weights and component thresholds."""
        return {
            'loss_function': self.cost_sensitive.get_loss_function(),
            'class_weights': self.cost_sensitive.class_weights,
            'resampling_config': {
                'sensitivity_threshold': self.resampling.sensitivity_threshold,
                'selectivity_threshold': self.resampling.selectivity_threshold
            },
            'augmentation_config': {
                'attention_threshold': self.attention_augmentation.attention_threshold
            }
        }

    def apply_mitigation_to_batch(self, batch_data: Tuple, model_outputs: Dict) -> Tuple:
        """Insert attention-guided augmented copies into a training batch.

        If ``model_outputs`` has no ``'attention_maps'`` key, ``batch_data`` is
        returned unchanged. Otherwise ``attention_maps`` is assumed to be a
        sequence of per-layer tensors whose first dim is the batch (TODO
        confirm). Each sample with attention gaps gets one augmented copy,
        inserted just before the original, and the batch is re-stacked.

        Args:
            batch_data: ``(images, labels, metadata)``. Labels are iterated as
                tensors (``.item()`` is called on each), and ``metadata`` must
                be indexable.
            model_outputs: Model output dict.

        Returns:
            ``(images, labels, metadata_list)`` for the possibly enlarged
            batch, or ``batch_data`` itself when nothing was processed.
        """
        images, labels, metadata = batch_data

        if 'attention_maps' not in model_outputs:
            return batch_data
        attention_maps = model_outputs['attention_maps']

        augmented_batch = []
        for i, (image, label) in enumerate(zip(images, labels)):
            # Select sample i from every layer that has it. The old check,
            # `i < len(attention_maps)`, compared the sample index with the
            # number of layers, so only the first few samples were considered.
            sample_maps = {
                f'layer_{j}': att_map[i]
                for j, att_map in enumerate(attention_maps)
                if i < len(att_map)
            }

            if sample_maps:
                gaps = self.attention_augmentation.identify_attention_gaps(sample_maps, label.item())
                if any(gaps.values()):
                    aug_images = self.attention_augmentation.generate_attention_aware_augmentations(
                        image, gaps
                    )
                    # Keep a single augmented version per sample.
                    if aug_images:
                        augmented_batch.append((aug_images[0], label, metadata[i]))

            augmented_batch.append((image, label, metadata[i]))

        if augmented_batch:
            aug_images, aug_labels, aug_metadata = zip(*augmented_batch)
            return torch.stack(aug_images), torch.stack(aug_labels), list(aug_metadata)

        return batch_data