import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
from transformers import ViTModel, ViTConfig
import cv2

class VisionTransformerAnalyzer:
    """Inspect the attention behaviour of a ViT-style classifier.

    Forward hooks are attached to every submodule whose qualified name contains
    ``"attention"`` or ``"transformer"`` (case-insensitive). Each time such a
    module runs, an attention tensor found in its output (an ``attentions``
    attribute, or the second element of a tuple output) is recorded in
    ``self.attention_maps`` under the module's qualified name.

    Args:
        model: The model to analyse. ``layer_wise_relevance_propagation``
            assumes ``model(x)`` returns a ``(batch, num_classes)`` logits tensor.
        class_names: Display names; their count defines the class indices
            ``0 .. len(class_names) - 1`` expected in dataloader labels.
        image_size: Side length, in pixels, of the square input images.
        patch_size: Side length, in pixels, of one patch (16 for standard ViT).
    """

    def __init__(self, model: nn.Module, class_names: List[str], image_size: int = 224,
                 patch_size: int = 16):
        self.model = model
        self.class_names = class_names
        self.image_size = image_size
        # Attention tensors captured by the hooks, keyed by module name.
        self.attention_maps: Dict[str, torch.Tensor] = {}
        # Kept for backward compatibility; nothing in this class populates it.
        self.relevance_scores = {}
        self.patch_size = patch_size
        self.num_patches = (image_size // self.patch_size) ** 2
        # Handles returned by register_forward_hook, so remove_hooks() can detach them.
        self._hook_handles = []

        self._register_attention_hooks()

    # ------------------------------------------------------------------ hooks

    def _register_attention_hooks(self):
        """Attach a recording forward hook to every attention/transformer submodule."""
        for name, module in self.model.named_modules():
            lowered = name.lower()
            if 'attention' in lowered or 'transformer' in lowered:
                handle = module.register_forward_hook(self._make_attention_hook(name))
                self._hook_handles.append(handle)

    def _make_attention_hook(self, name: str):
        """Build a forward hook that records attention found in the output of ``name``."""
        def attention_hook(module, inputs, output):
            attentions = getattr(output, 'attentions', None)
            if isinstance(attentions, torch.Tensor):
                self._store_attention(name, attentions)
            elif isinstance(attentions, (tuple, list)):
                # A container of tensors: record each one under an indexed name.
                for idx, item in enumerate(attentions):
                    if isinstance(item, torch.Tensor):
                        self._store_attention(f"{name}.{idx}", item)
            elif (isinstance(output, tuple) and len(output) > 1
                  and isinstance(output[1], torch.Tensor)):
                # The second tuple slot may be None (e.g. weights not requested);
                # only a real tensor is recorded.
                self._store_attention(name, output[1])
        return attention_hook

    def _store_attention(self, key: str, weights: torch.Tensor) -> None:
        """Record ``weights`` under ``key`` unless the same tensor is already stored.

        A wrapper module may return its child's attention tensor unchanged; without
        this check the same storage would be recorded under several names and
        counted as several layers.
        """
        weights = weights.detach()
        for other_key, existing in self.attention_maps.items():
            if (other_key != key and existing.data_ptr() == weights.data_ptr()
                    and existing.shape == weights.shape):
                return
        self.attention_maps[key] = weights

    def remove_hooks(self) -> None:
        """Detach every forward hook this analyzer registered on ``model``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def _to_model_device(self, tensor: torch.Tensor) -> torch.Tensor:
        """Move ``tensor`` to the device of the model's first parameter, if it has one."""
        first_param = next(self.model.parameters(), None)
        return tensor if first_param is None else tensor.to(first_param.device)

    # -------------------------------------------------------------- extraction

    def extract_attention_maps(self, input_tensor: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run one forward pass and return the captured attention, averaged over heads.

        Args:
            input_tensor: Batched model input; it is moved to the model's device.

        Returns:
            Mapping from hooked module name to its attention tensor. 4-D tensors
            (assumed ``[batch, heads, tokens, tokens]``) are averaged over dim 1;
            tensors of any other rank are returned unchanged.
        """
        self.attention_maps.clear()

        with torch.no_grad():
            self.model(self._to_model_device(input_tensor))

        return {
            name: weights.mean(dim=1) if weights.dim() == 4 else weights
            for name, weights in self.attention_maps.items()
        }

    # ----------------------------------------------------------------- scoring

    def _per_sample_bias_scores(self, attention_map: torch.Tensor) -> Optional[torch.Tensor]:
        """Return one bias score per batch item, or None for an unsupported rank.

        Accepts ``[batch, tokens, tokens]`` or ``[batch, heads, tokens, tokens]``;
        heads are averaged. The score is ``1 - H(p) / log(n)``. Here ``p`` is the
        normalised attention mass each patch receives (column sums) and ``n`` is
        the number of patches. A score of 0 means perfectly uniform attention; 1
        means all attention lands on a single patch.
        """
        if attention_map.dim() == 4:
            attention_map = attention_map.mean(dim=1)
        if attention_map.dim() != 3:
            return None
        attention_map = attention_map.float()

        # With more tokens than patches, the first token is treated as CLS and dropped.
        if attention_map.size(-2) > self.num_patches:
            attention_map = attention_map[:, 1:, 1:]

        # Attention rows are already non-negative mass, so normalise by the total.
        # Applying softmax here would squash the distribution towards uniform.
        received = attention_map.sum(dim=-2)
        probs = received / received.sum(dim=-1, keepdim=True).clamp_min(1e-12)
        entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1)

        num_keys = attention_map.size(-1)
        if num_keys < 2:
            # log(1) == 0: the normalised entropy is undefined, so report no bias.
            return torch.zeros(attention_map.size(0), device=attention_map.device)
        return 1.0 - entropy / float(np.log(num_keys))

    def calculate_attention_bias_score(self, attention_maps: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Return the batch-mean attention bias score for each layer.

        Args:
            attention_maps: Output of :meth:`extract_attention_maps`.

        Returns:
            ``{layer_name: score}``; see :meth:`_per_sample_bias_scores` for the
            score definition. Layers with an unsupported rank or an empty batch
            are omitted.
        """
        bias_scores = {}
        for layer_name, attention_map in attention_maps.items():
            scores = self._per_sample_bias_scores(attention_map)
            if scores is not None and scores.numel() > 0:
                bias_scores[layer_name] = float(scores.mean())
        return bias_scores

    def layer_wise_relevance_propagation(self, input_tensor: torch.Tensor,
                                         target_class: int) -> torch.Tensor:
        """Return a gradient x input relevance map for ``target_class``.

        This is a simplified LRP: ``d logit[target_class] / d input * input``.
        The caller's tensor is not modified. Gradients are taken with
        ``torch.autograd.grad``, so nothing accumulates in ``.grad`` of the input
        or of the model parameters, and repeated calls give identical results.
        """
        inputs = self._to_model_device(input_tensor).detach().clone().requires_grad_(True)

        # enable_grad lets this work even when called inside a no_grad block.
        with torch.enable_grad():
            output = self.model(inputs)
            # One-hot gradient seed selecting the target logit for every batch item.
            target = torch.zeros_like(output)
            target[:, target_class] = 1
            (grad,) = torch.autograd.grad(output, inputs, grad_outputs=target)

        return (grad * inputs).detach()

    # ---------------------------------------------------------------- analyses

    def analyze_class_attention_patterns(self, dataloader, num_samples: int = 100) -> Dict[int, Dict[str, float]]:
        """Average each sample's per-layer bias score, grouped by the sample's class.

        At most ``num_samples // len(class_names)`` samples are taken per class.
        Iteration stops as soon as every class has reached that quota.
        ``dataloader`` is assumed to yield ``(images, labels, extra)`` triples.

        Returns:
            ``{class_idx: {layer_name: mean_score}}``. A class with no collected
            samples maps to an empty dict.
        """
        num_classes = len(self.class_names)
        per_class_quota = num_samples // num_classes
        collected = {i: {} for i in range(num_classes)}
        samples_per_class = {i: 0 for i in range(num_classes)}

        def quotas_met() -> bool:
            return all(count >= per_class_quota for count in samples_per_class.values())

        self.model.eval()
        with torch.no_grad():
            if not quotas_met():
                for images, labels, _ in dataloader:
                    attention_maps = self.extract_attention_maps(images)

                    # Per-sample scores, so each sample is credited with its own
                    # score rather than the batch average.
                    layer_scores = {}
                    for layer_name, attention_map in attention_maps.items():
                        scores = self._per_sample_bias_scores(attention_map)
                        if scores is not None:
                            layer_scores[layer_name] = scores.cpu().tolist()

                    for i, label in enumerate(labels):
                        class_idx = int(label)
                        if samples_per_class[class_idx] >= per_class_quota:
                            continue
                        for layer_name, scores in layer_scores.items():
                            collected[class_idx].setdefault(layer_name, []).append(scores[i])
                        samples_per_class[class_idx] += 1

                    # Checked after processing so no further batch is requested once done.
                    if quotas_met():
                        break

        return {
            class_idx: {layer: float(np.mean(scores)) for layer, scores in layers.items()}
            for class_idx, layers in collected.items()
        }

    def visualize_attention_bias(self, class_attention_patterns: Dict[int, Dict[str, float]],
                                 save_path: Optional[str] = None):
        """Plot a class x layer heatmap of bias scores; optionally save it, then show it.

        Raises:
            ValueError: If ``class_attention_patterns`` contains no layer scores.
        """
        # Union of layer names over all classes, in first-seen order. An empty
        # first class therefore cannot hide the layers of the other classes.
        layers = list(dict.fromkeys(
            layer for scores in class_attention_patterns.values() for layer in scores
        ))
        if not layers:
            raise ValueError("class_attention_patterns contains no layer scores to plot")

        class_order = sorted(class_attention_patterns.keys())
        classes = [self.class_names[i] for i in class_order]
        data = [
            [class_attention_patterns[class_idx].get(layer, 0.0) for layer in layers]
            for class_idx in class_order
        ]

        plt.figure(figsize=(12, 8))
        sns.heatmap(data,
                    xticklabels=[f"Layer_{i}" for i in range(len(layers))],
                    yticklabels=classes,
                    annot=True,
                    fmt='.3f',
                    cmap='RdYlBu_r',
                    cbar_kws={'label': 'Attention Bias Score'})

        plt.title('Attention Bias Patterns Across Classes and Layers')
        plt.xlabel('Transformer Layers')
        plt.ylabel('Classes')
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def generate_attention_heatmap(self, image: torch.Tensor, layer_idx: int = -1) -> np.ndarray:
        """Return an ``(image_size, image_size)`` map of mean attention received per patch.

        Args:
            image: A single un-batched input; a batch dimension is added here.
            layer_idx: Index into the captured layers, in capture order. Any
                out-of-range index (positive or negative) falls back to the last layer.

        Returns:
            The patch grid upsampled with bicubic interpolation. An all-zero map
            is returned when no attention was captured.
        """
        attention_maps = self.extract_attention_maps(image.unsqueeze(0))

        if not attention_maps:
            return np.zeros((self.image_size, self.image_size))

        layer_names = list(attention_maps.keys())
        if not -len(layer_names) <= layer_idx < len(layer_names):
            layer_idx = -1

        attention_map = attention_maps[layer_names[layer_idx]][0]  # first batch item

        # Average over query rows, giving the attention each patch receives;
        # the first (CLS) token is dropped when there are more tokens than patches.
        if attention_map.size(0) > self.num_patches:
            patch_attention = attention_map[1:, 1:].mean(dim=0)
        else:
            patch_attention = attention_map.mean(dim=0)

        # Use the largest square grid that fits the available patch values.
        patches_per_side = int(np.sqrt(len(patch_attention)))
        attention_2d = patch_attention[:patches_per_side ** 2].reshape(patches_per_side, patches_per_side)

        # OpenCV does not accept float16 arrays, so convert to float32 first.
        return cv2.resize(
            attention_2d.detach().float().cpu().numpy(),
            (self.image_size, self.image_size),
            interpolation=cv2.INTER_CUBIC,
        )

    def compute_cross_class_attention_similarity(self, dataloader, num_samples: int = 50) -> np.ndarray:
        """Return the cosine similarity between per-class mean attention vectors.

        Each sample's head-averaged attention maps from all captured layers are
        flattened and concatenated into one vector. At most ``num_samples`` are
        kept per class, and iteration stops once every class is full.
        ``dataloader`` is assumed to yield ``(images, labels, extra)`` triples.

        Returns:
            A ``(num_classes, num_classes)`` array. Rows and columns of classes
            that received no samples are zero.
        """
        num_classes = len(self.class_names)
        class_attention_vectors = {i: [] for i in range(num_classes)}

        def quotas_met() -> bool:
            return all(len(vectors) >= num_samples for vectors in class_attention_vectors.values())

        self.model.eval()
        with torch.no_grad():
            if not quotas_met():
                for images, labels, _ in dataloader:
                    attention_maps = self.extract_attention_maps(images)

                    for i, label in enumerate(labels):
                        vectors = class_attention_vectors[int(label)]
                        if len(vectors) >= num_samples:
                            continue
                        parts = [
                            layer_attention[i].detach().float().flatten().cpu().numpy()
                            for layer_attention in attention_maps.values()
                        ]
                        vectors.append(np.concatenate(parts) if parts else np.zeros(0))

                    if quotas_met():
                        break

        # Compute each class mean once (O(C) means instead of O(C^2)).
        mean_vectors = {
            class_idx: np.mean(vectors, axis=0)
            for class_idx, vectors in class_attention_vectors.items()
            if vectors
        }

        similarity_matrix = np.zeros((num_classes, num_classes))
        for i, mean_vec_i in mean_vectors.items():
            for j, mean_vec_j in mean_vectors.items():
                similarity_matrix[i, j] = np.dot(mean_vec_i, mean_vec_j) / (
                    np.linalg.norm(mean_vec_i) * np.linalg.norm(mean_vec_j) + 1e-8
                )

        return similarity_matrix