import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional, Union
from mpl_toolkits.mplot3d import Axes3D

class Model3DAnalyzer:
    def __init__(self, model: nn.Module, class_names: List[str], model_type: str = 'centerpoint'):
        """Wrap *model* and immediately attach the analysis hooks to it.

        Args:
            model: Network whose submodules are inspected.
            class_names: Class labels; a label index is used as a position in
                this list.
            model_type: Architecture tag, stored lower-cased in
                ``self.model_type``.
        """
        # Caches filled by the hooks (and, for the last one, reserved for
        # detection results); each starts out as its own empty dict.
        self.detection_outputs = {}
        self.feature_activations = {}
        self.point_gradients = {}

        self.model_type = model_type.lower()
        self.class_names = class_names
        self.model = model

        # The model and both caches must exist before the hooks are installed.
        self._register_hooks()
    
    def _register_hooks(self):
        """Attach activation- and gradient-capturing hooks to matching submodules.

        A submodule is hooked when its dotted name, as reported by
        ``named_modules()``, contains ``conv``, ``linear``, ``pointnet`` or
        ``backbone`` (case-insensitive).  The forward hook stores the module
        output in ``self.feature_activations[name]`` (detached when it is a
        tensor, unchanged otherwise); the backward hook stores the detached
        gradient with respect to the module output in
        ``self.point_gradients[name]``.

        The handles returned by PyTorch are kept in ``self._hook_handles`` so
        the hooks can be detached again with :meth:`remove_hooks`.  Calling
        this method again first removes the hooks from the previous call, so
        hooks never pile up on the model.
        """
        layer_keywords = ('conv', 'linear', 'pointnet', 'backbone')

        def activation_hook(name):
            def hook(module, inputs, output):
                self.feature_activations[name] = output.detach() if torch.is_tensor(output) else output
            return hook

        def gradient_hook(name):
            def hook(module, grad_input, grad_output):
                if grad_output[0] is not None:
                    self.point_gradients[name] = grad_output[0].detach()
            return hook

        # Drop hooks from an earlier registration before installing new ones.
        self.remove_hooks()

        for name, module in self.model.named_modules():
            lowered = name.lower()
            if not any(keyword in lowered for keyword in layer_keywords):
                continue
            self._hook_handles.append(module.register_forward_hook(activation_hook(name)))
            # NOTE(review): PyTorch deprecates register_backward_hook in favour
            # of register_full_backward_hook.  It is kept because only
            # grad_output is read here and the full variant is documented to
            # reject in-place modification of hooked outputs -- confirm the
            # model has no in-place ops after hooked layers before migrating.
            self._hook_handles.append(module.register_backward_hook(gradient_hook(name)))

    def remove_hooks(self):
        """Detach every hook installed by :meth:`_register_hooks`.

        Safe to call when no hooks have been registered yet.
        """
        for handle in getattr(self, '_hook_handles', []):
            handle.remove()
        self._hook_handles = []
    
    def calculate_3d_sensitivity_score(self, point_cloud: torch.Tensor,
                                     target_class: int,
                                     fusion_data: Optional[Dict] = None) -> Dict[str, torch.Tensor]:
        """Compute gradient-magnitude sensitivity for one target class.

        The model is run on *point_cloud*, a binary-cross-entropy-with-logits
        loss against a one-hot target for *target_class* is back-propagated,
        and absolute gradients are collected.

        Args:
            point_cloud: Model input.  It is switched to ``requires_grad`` in
                place and its ``.grad`` is reset before the backward pass.
            target_class: Column of the score matrix set to 1 in the target.
                If it is not smaller than the number of score columns the
                target stays all-zero.
            fusion_data: Optional extra keyword arguments forwarded to the
                model.

        Returns:
            A dict with ``'point_features'`` (absolute input gradient averaged
            over the last dimension, present when the input received a
            gradient) plus one entry per layer whose backward hook fired
            during this call.  Layer gradients with more than two dimensions
            are averaged over every dimension after the second.

        Raises:
            KeyError: If the model returns a dict with neither ``'cls_scores'``
                nor ``'scores'``.
        """
        point_cloud.requires_grad_(True)
        # model.zero_grad() only resets parameter gradients; the input's .grad
        # would otherwise accumulate over repeated calls on the same tensor.
        point_cloud.grad = None
        # Drop gradients captured by earlier backward passes so that only
        # layers touched by this call are reported.
        self.point_gradients.clear()
        self.model.zero_grad()

        if fusion_data:
            output = self.model(point_cloud, **fusion_data)
        else:
            output = self.model(point_cloud)

        # Handle the different output formats.
        if isinstance(output, dict):
            detection_scores = output.get('cls_scores', output.get('scores'))
            if detection_scores is None:
                raise KeyError("model output dict has neither 'cls_scores' nor 'scores'")
        elif isinstance(output, (list, tuple)):
            detection_scores = output[0]
        else:
            detection_scores = output

        # Flatten to (rows, classes); reshape also copes with non-contiguous
        # tensors, and scores without a leading dimension become one row.
        if detection_scores.dim() > 2:
            detection_scores = detection_scores.reshape(-1, detection_scores.size(-1))
        elif detection_scores.dim() < 2:
            detection_scores = detection_scores.reshape(1, -1)

        target = torch.zeros_like(detection_scores)
        if target_class < target.size(-1):
            target[:, target_class] = 1

        loss = F.binary_cross_entropy_with_logits(detection_scores, target)
        loss.backward()

        sensitivity_scores = {}

        # Point-level sensitivity.
        if point_cloud.grad is not None:
            sensitivity_scores['point_features'] = point_cloud.grad.abs().mean(dim=-1)

        # Layer-wise sensitivity from the gradients captured by the hooks.
        for layer_name, gradient in self.point_gradients.items():
            if gradient is None:
                continue
            layer_sensitivity = gradient.abs()
            if layer_sensitivity.dim() > 2:
                trailing_dims = tuple(range(2, layer_sensitivity.dim()))
                layer_sensitivity = layer_sensitivity.mean(dim=trailing_dims)
            sensitivity_scores[layer_name] = layer_sensitivity

        return sensitivity_scores
    
    def calculate_3d_selectivity_score(self, point_clouds: Dict[int, torch.Tensor],
                                     fusion_data_dict: Optional[Dict[int, Dict]] = None) -> Dict[int, Dict[str, float]]:
        """Score how strongly each hooked layer prefers one class over the rest.

        For every class the model is run once (eval mode, no gradients) on that
        class's point cloud and the mean activation of each hooked layer is
        recorded.  The selectivity of a layer for a class is::

            |a_target - mean(a_others)| / max(|a_target|, |mean(a_others)|, 1e-8)

        where ``a_others`` are the mean activations of the same layer for the
        other classes.  With no other classes the comparison value is 0.

        Args:
            point_clouds: Mapping from class index to the input for that class.
            fusion_data_dict: Optional mapping from class index to extra
                keyword arguments forwarded to the model.

        Returns:
            ``{class_idx: {layer_name: selectivity}}``.

        Side effects:
            ``self.feature_activations`` is cleared before each forward pass.
            Each submodule's train/eval flag is restored when the method
            returns, instead of leaving the model in eval mode.
        """
        class_activations = {class_idx: {} for class_idx in point_clouds}

        # Remember the mode per module: model.train(flag) would flip every
        # submodule and undo any deliberately frozen layers.
        module_modes = [(module, module.training) for module in self.model.modules()]
        self.model.eval()
        try:
            with torch.no_grad():
                for class_idx, point_cloud in point_clouds.items():
                    self.feature_activations.clear()

                    fusion_data = fusion_data_dict.get(class_idx) if fusion_data_dict else None
                    if fusion_data:
                        self.model(point_cloud, **fusion_data)
                    else:
                        self.model(point_cloud)

                    for layer_name, activation in self.feature_activations.items():
                        if not torch.is_tensor(activation):
                            continue
                        # mean() is undefined for integer/bool tensors.
                        if not (activation.is_floating_point() or activation.is_complex()):
                            activation = activation.float()
                        class_activations[class_idx][layer_name] = activation.mean().item()
        finally:
            for module, was_training in module_modes:
                module.training = was_training

        selectivity_scores = {class_idx: {} for class_idx in point_clouds}
        for target_class, target_layers in class_activations.items():
            for layer_name, target_activation in target_layers.items():
                # Only compare against classes whose forward pass actually
                # produced this layer; a plain lookup would raise KeyError.
                other_activations = [
                    layers[layer_name]
                    for cls, layers in class_activations.items()
                    if cls != target_class and layer_name in layers
                ]
                avg_other_activation = float(np.mean(other_activations)) if other_activations else 0.0

                numerator = abs(target_activation - avg_other_activation)
                denominator = max(abs(target_activation), abs(avg_other_activation), 1e-8)
                selectivity_scores[target_class][layer_name] = numerator / denominator

        return selectivity_scores
    
    def visualize_point_cloud_sensitivity(self, point_cloud: torch.Tensor,
                                        sensitivity_scores: torch.Tensor,
                                        max_points: int = 5000,
                                        save_path: Optional[str] = None):
        """Draw a 3D scatter plot of the points coloured by sensitivity.

        Only the first sample is plotted when *point_cloud* has more than two
        dimensions (or *sensitivity_scores* more than one).  Columns 0, 1 and 2
        of the points are used as the X, Y and Z coordinates.

        Args:
            point_cloud: Points to plot.
            sensitivity_scores: One score per point, used as the colour value.
            max_points: If there are more points than this, a random subset of
                this size is plotted instead.
            save_path: When given, the figure is also written to this path.
        """
        batched_points = point_cloud.dim() > 2
        batched_scores = sensitivity_scores.dim() > 1

        # Strip the batch dimension first, then subsample if necessary.
        points = point_cloud[0] if batched_points else point_cloud
        sensitivity = sensitivity_scores[0] if batched_scores else sensitivity_scores

        num_points = point_cloud.size(1) if batched_points else point_cloud.size(0)
        if num_points > max_points:
            keep = torch.randperm(num_points)[:max_points]
            points = points[keep]
            sensitivity = sensitivity[keep]

        xyz = points.detach().cpu().numpy()
        colours = sensitivity.detach().cpu().numpy()

        fig = plt.figure(figsize=(12, 8))
        ax = fig.add_subplot(111, projection='3d')

        scatter = ax.scatter(
            xyz[:, 0], xyz[:, 1], xyz[:, 2],
            c=colours, cmap='viridis', s=1,
        )

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title('Point Cloud Sensitivity Visualization')

        plt.colorbar(scatter, label='Sensitivity Score', shrink=0.8)

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
    
    def analyze_detection_confidence_bias(self, dataloader, num_batches: int = 50) -> Dict[str, List[float]]:
        """Collect the softmax confidence of the labelled class, per class.

        At most *num_batches* batches are read from *dataloader*.  A batch is
        either a list/tuple ``(points, labels[, fusion_kwargs])`` or a dict
        with ``'points'`` and ``'labels'`` whose remaining items are forwarded
        to the model as keyword arguments.

        Args:
            dataloader: Iterable of batches in one of the formats above.
            num_batches: Maximum number of batches to process.

        Returns:
            ``{class_name: [confidence, ...]}`` with one entry per sample whose
            label is a valid index into ``self.class_names``.  Labels outside
            ``[0, len(class_names))`` are skipped; this includes negative
            labels, which previously wrapped around to the last class.

        Raises:
            KeyError: If the model returns a dict with neither ``'cls_scores'``
                nor ``'scores'``.

        Side effects:
            Each submodule's train/eval flag is restored when the method
            returns, instead of leaving the model in eval mode.
        """
        confidence_scores = {class_name: [] for class_name in self.class_names}
        num_classes = len(self.class_names)

        # Remember the mode per module: model.train(flag) would flip every
        # submodule and undo any deliberately frozen layers.
        module_modes = [(module, module.training) for module in self.model.modules()]
        self.model.eval()
        try:
            with torch.no_grad():
                for batch_idx, batch_data in enumerate(dataloader):
                    if batch_idx >= num_batches:
                        break

                    # Handle the different batch formats.
                    if isinstance(batch_data, (list, tuple)):
                        point_cloud, labels = batch_data[:2]
                        fusion_data = batch_data[2] if len(batch_data) > 2 else None
                    else:
                        point_cloud = batch_data['points']
                        labels = batch_data['labels']
                        fusion_data = {k: v for k, v in batch_data.items() if k not in ('points', 'labels')}

                    if fusion_data:
                        output = self.model(point_cloud, **fusion_data)
                    else:
                        output = self.model(point_cloud)

                    # Extract the raw class scores.
                    if isinstance(output, dict):
                        scores = output.get('cls_scores', output.get('scores'))
                        if scores is None:
                            raise KeyError("model output dict has neither 'cls_scores' nor 'scores'")
                    elif isinstance(output, (list, tuple)):
                        scores = output[0]
                    else:
                        scores = output

                    # Scores without a class dimension cannot be turned into
                    # per-class probabilities; such batches contribute nothing.
                    if scores.dim() <= 1:
                        continue
                    probs = F.softmax(scores, dim=-1)

                    for i, label in enumerate(labels):
                        label = int(label)
                        if not 0 <= label < num_classes:
                            continue
                        # Labels beyond the number of score rows record 0,
                        # as the original implementation did.
                        confidence = probs[i, label].item() if i < probs.size(0) else 0
                        confidence_scores[self.class_names[label]].append(confidence)
        finally:
            for module, was_training in module_modes:
                module.training = was_training

        return confidence_scores
    
    def calculate_spatial_bias_patterns(self, point_cloud: torch.Tensor,
                                      sensitivity_scores: torch.Tensor) -> Dict[str, float]:
        """Measure how unevenly sensitivity is spread across space.

        For each of the X, Y and Z coordinates (columns 0, 1 and 2) and for the
        distance from the origin, the points are split at the median of that
        quantity and the absolute difference between the mean sensitivity of
        the lower half (strictly below the median) and the upper half (at or
        above it) is reported.  A half with no points contributes 0.

        Only the first sample is used when *point_cloud* has more than two
        dimensions (or *sensitivity_scores* more than one).

        Args:
            point_cloud: Points whose first three columns are the coordinates.
                Any further columns are treated as per-point features and are
                ignored, including for the distance computation.
            sensitivity_scores: One score per point.

        Returns:
            Dict with the keys ``'x_axis_bias'``, ``'y_axis_bias'``,
            ``'z_axis_bias'`` and ``'distance_bias'``, each a plain float.
        """
        points = point_cloud[0] if point_cloud.dim() > 2 else point_cloud
        sensitivity = sensitivity_scores[0] if sensitivity_scores.dim() > 1 else sensitivity_scores

        points_np = points.detach().cpu().numpy()
        sensitivity_np = sensitivity.detach().cpu().numpy()

        def median_split_gap(values):
            """Return |mean sensitivity below median - mean sensitivity at/above median|."""
            median = np.median(values)
            lower = values < median
            upper = values >= median
            lower_mean = np.mean(sensitivity_np[lower]) if np.any(lower) else 0
            upper_mean = np.mean(sensitivity_np[upper]) if np.any(upper) else 0
            return float(abs(lower_mean - upper_mean))

        # Distance uses the XYZ columns only, so extra feature channels
        # (whatever they hold) cannot distort the near/far split.
        distances = np.linalg.norm(points_np[:, :3], axis=1)

        return {
            'x_axis_bias': median_split_gap(points_np[:, 0]),    # left vs right
            'y_axis_bias': median_split_gap(points_np[:, 1]),    # front vs back
            'z_axis_bias': median_split_gap(points_np[:, 2]),    # bottom vs top
            'distance_bias': median_split_gap(distances),        # near vs far
        }
    
    def generate_3d_grad_cam(self, point_cloud: torch.Tensor, target_class: int,
                           layer_name: Optional[str] = None) -> torch.Tensor:
        """Compute a Grad-CAM style relevance map for *target_class*.

        The model is run on *point_cloud*, the score of *target_class* for the
        first sample is back-propagated, and the activation of one hooked layer
        is weighted by the gradient that the same layer received.

        Layer choice: *layer_name* is used when it produced a tensor activation
        during this forward pass.  Otherwise the last layer whose name contains
        ``conv`` and which received a gradient is used.  Activation and
        gradient always come from the same layer.

        Args:
            point_cloud: Model input.  It is switched to ``requires_grad`` in
                place and its ``.grad`` is reset.
            target_class: Index of the class score to explain.
            layer_name: Optional name of the hooked layer to inspect.

        Returns:
            The ReLU-ed map.  For activations with more than two dimensions
            (treated as ``(batch, channels, ...)``) the channel dimension is
            reduced, giving ``(batch, ...)``.  For 2-D activations the last
            dimension is reduced -- assumes a ``(points, channels)`` layout,
            TODO confirm against the hooked layers.  When no usable layer or
            gradient is available, ``torch.zeros_like(point_cloud)`` is
            returned.

        Raises:
            KeyError: If the model returns a dict with neither ``'cls_scores'``
                nor ``'scores'``.

        Side effects:
            ``self.feature_activations`` and ``self.point_gradients`` are
            cleared before the forward pass so stale entries from earlier
            calls are never used.
        """
        point_cloud.requires_grad_(True)
        # model.zero_grad() does not touch the input's gradient.
        point_cloud.grad = None
        self.feature_activations.clear()
        self.point_gradients.clear()
        self.model.zero_grad()

        output = self.model(point_cloud)

        # Only tensor activations can be combined with gradients.
        activations = {
            name: activation
            for name, activation in self.feature_activations.items()
            if torch.is_tensor(activation)
        }
        if layer_name and layer_name in activations:
            candidates = [layer_name]
        else:
            candidates = [name for name in activations if 'conv' in name.lower()]
            if not candidates:
                return torch.zeros_like(point_cloud)

        if isinstance(output, dict):
            scores = output.get('cls_scores', output.get('scores'))
            if scores is None:
                raise KeyError("model output dict has neither 'cls_scores' nor 'scores'")
        elif isinstance(output, (list, tuple)):
            scores = output[0]
        else:
            scores = output

        target_score = scores[0, target_class] if scores.dim() > 1 else scores[target_class]
        # Dense score maps leave more than one value per class; backward()
        # needs a scalar, so reduce them by summation.
        if target_score.dim() > 0:
            target_score = target_score.sum()
        target_score.backward()

        # Take the last candidate that took part in the backward pass, so a
        # layer outside the score's graph is never paired with a gradient
        # that belongs to something else.
        target_layer = next(
            (name for name in reversed(candidates) if self.point_gradients.get(name) is not None),
            None,
        )
        if target_layer is None:
            return torch.zeros_like(point_cloud)

        gradients = self.point_gradients[target_layer]
        target_activation = activations[target_layer]

        if gradients.dim() > 2:
            # Global-average-pool the gradient into one weight per channel and
            # append singleton dims so it broadcasts over every trailing axis.
            weights = gradients.mean(dim=tuple(range(2, gradients.dim())))
            weights = weights.reshape(tuple(weights.shape) + (1,) * (target_activation.dim() - 2))
            grad_cam = (weights * target_activation).sum(dim=1)
        else:
            weights = gradients.mean(dim=0)
            grad_cam = (weights * target_activation).sum(dim=-1)

        return F.relu(grad_cam)