import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Dict

class BehaviorMetricsAnalyzer:
    """Collect per-channel sensitivity and selectivity metrics for Conv2d layers.

    On construction a forward hook is attached to every ``nn.Conv2d`` in
    ``model``. The hook stores a detached copy of the layer output in
    ``activations``. When that output is part of an autograd graph, the hook
    also registers a tensor hook that stores the gradient flowing back into the
    output in ``gradients``. Both dicts are keyed by the module object and are
    overwritten by each new pass.

    Note: if ``model`` is in training mode, BatchNorm layers update their
    running statistics on every forward pass made here; call ``model.eval()``
    beforehand if that is not wanted. Call :meth:`remove_hooks` to detach the
    hooks from ``model`` when done.
    """

    def __init__(self, model: nn.Module, class_names: List[str]):
        self.model = model
        self.class_names = class_names
        # Latest output / output-gradient per hooked Conv2d module.
        self.activations: Dict[nn.Module, torch.Tensor] = {}
        self.gradients: Dict[nn.Module, torch.Tensor] = {}
        # Qualified module names (e.g. "layer1.0.conv1") used as plot labels.
        self.module_names: Dict[nn.Module, str] = {}
        # Handles returned by register_forward_hook, kept so hooks can be removed.
        self._hook_handles: list = []

        self._register_hooks()

    def _register_hooks(self):
        """Attach a forward hook to every Conv2d that records its output and output-gradient."""

        def activation_hook(module, inputs, output):
            # clone() so that a later in-place op (e.g. ReLU(inplace=True))
            # cannot mutate the stored activation through shared storage.
            self.activations[module] = output.detach().clone()
            # Capture d(loss)/d(output) with a tensor hook rather than the
            # deprecated Module.register_backward_hook. Outputs produced under
            # torch.no_grad() carry no graph, so there is nothing to hook.
            if output.requires_grad:
                def gradient_hook(grad):
                    self.gradients[module] = grad.detach()

                output.register_hook(gradient_hook)

        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                # The root module has an empty name; fall back to its class name.
                self.module_names[module] = name or type(module).__name__
                self._hook_handles.append(module.register_forward_hook(activation_hook))

    def remove_hooks(self) -> None:
        """Detach every hook this analyzer registered on the model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def calculate_sensitivity_score(self, input_tensor: torch.Tensor, target_class: int) -> Dict[nn.Module, torch.Tensor]:
        """Return the per-channel mean |d loss / d output| for each hooked Conv2d layer.

        Args:
            input_tensor: Batched model input (assumed 4-D ``(N, C, H, W)`` so
                that conv outputs are 4-D -- TODO confirm for other models).
            target_class: Class index the cross-entropy loss is computed against
                for every sample in the batch.

        Returns:
            Mapping from Conv2d module to a 1-D tensor with one score per output
            channel: the absolute gradient averaged over batch and spatial dims.
        """
        self.model.zero_grad()
        # Drop gradients left over from an earlier call.
        self.gradients.clear()
        output = self.model(input_tensor)

        # Class-index target; identical to cross-entropy against a one-hot row,
        # and supported on every torch version.
        target = torch.full(
            (output.shape[0],), target_class, dtype=torch.long, device=output.device
        )
        loss = F.cross_entropy(output, target)
        loss.backward()

        return {
            module: gradient.abs().mean(dim=(0, 2, 3))
            for module, gradient in self.gradients.items()
        }

    def calculate_selectivity_score(
        self, input_tensors: Dict[int, torch.Tensor], eps: float = 1e-8
    ) -> Dict[int, Dict[nn.Module, float]]:
        """Score how strongly each Conv2d layer's channels respond to one class vs. all classes.

        For every class input, one forward pass (without autograd) records the
        per-channel mean activation of each hooked layer. For class ``c`` and
        layer ``m`` the per-channel selectivity is::

            |a_c - a_avg| / max(|a_c|, |a_avg|, eps)

        Here ``a_avg`` is the per-channel activation averaged over all classes.
        The layer score is the mean of that value over channels and lies in
        [0, 2].

        Args:
            input_tensors: Mapping class index -> batched input for that class.
            eps: Floor on the denominator to avoid division by zero.

        Returns:
            Mapping class index -> {Conv2d module -> selectivity score}.
        """
        # One forward pass per class; snapshot per-channel mean activations.
        class_means: Dict[int, Dict[nn.Module, torch.Tensor]] = {}
        with torch.no_grad():
            for class_idx, input_tensor in input_tensors.items():
                self.activations.clear()
                self.model(input_tensor)
                class_means[class_idx] = {
                    module: activation.mean(dim=(0, 2, 3))
                    for module, activation in self.activations.items()
                }

        if not class_means:
            return {}

        # Only layers observed for every class can be compared across classes.
        first_means = next(iter(class_means.values()))
        modules = [
            module for module in first_means
            if all(module in means for means in class_means.values())
        ]

        # Per-channel activation of each layer averaged over all classes.
        avg_means = {
            module: torch.stack([means[module] for means in class_means.values()]).mean(dim=0)
            for module in modules
        }

        selectivity_scores: Dict[int, Dict[nn.Module, float]] = {}
        for class_idx, means in class_means.items():
            class_scores: Dict[nn.Module, float] = {}
            for module in modules:
                class_activation = means[module]
                avg_activation = avg_means[module]
                # abs() keeps the denominator positive for pre-ReLU conv outputs,
                # which can be negative.
                denominator = torch.max(class_activation.abs(), avg_activation.abs()).clamp_min(eps)
                selectivity = (class_activation - avg_activation).abs() / denominator
                class_scores[module] = selectivity.mean().item()
            selectivity_scores[class_idx] = class_scores

        return selectivity_scores

    def visualize_sensitivity(self, sensitivity_scores: Dict[nn.Module, torch.Tensor]):
        """Show a bar chart of the channel-averaged sensitivity of each layer."""
        import matplotlib.pyplot as plt

        # Qualified names are unique; str(module) repeats for identically
        # configured layers, which would collapse their bars onto one category.
        layer_names = [
            self.module_names.get(module, str(module)) for module in sensitivity_scores
        ]
        avg_sensitivities = [scores.mean().item() for scores in sensitivity_scores.values()]
        positions = list(range(len(layer_names)))

        plt.figure(figsize=(12, 6))
        plt.bar(positions, avg_sensitivities)
        plt.xticks(positions, layer_names, rotation=45, ha='right')
        plt.title('Layer Sensitivity Scores')
        plt.xlabel('Layers')
        plt.ylabel('Average Sensitivity')
        plt.tight_layout()
        plt.show()


def main():
    """Run the analyzer on an untrained ResNet-18 using random placeholder inputs."""
    # torch.hub fetches the model definition from GitHub (requires network access).
    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
    # Inference mode: BatchNorm then uses its running statistics instead of
    # batch-of-1 statistics, and the analysis passes do not overwrite them.
    model.eval()

    class_names = ['pedestrian', 'cyclist', 'motorcyclist']
    analyzer = BehaviorMetricsAnalyzer(model, class_names)
    # Random tensors stand in for real per-class images (one per class name).
    # NOTE(review): resnet18 has 1000 ImageNet outputs, so target_class=0 below
    # is ImageNet class 0, not 'pedestrian' -- confirm the intended head.
    input_tensors = {
        class_idx: torch.randn(1, 3, 224, 224)
        for class_idx, _ in enumerate(class_names)
    }

    sensitivity_scores = analyzer.calculate_sensitivity_score(
        input_tensors[0], target_class=0
    )
    analyzer.visualize_sensitivity(sensitivity_scores)

    selectivity_scores = analyzer.calculate_selectivity_score(input_tensors)
    print("Selectivity Scores:", selectivity_scores)


if __name__ == '__main__':
    main()