import torch
import numpy as np
from typing import Dict, List, Tuple, Optional, Union
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

class ModelEvaluator:
    """Compute detection, classification and per-class bias metrics.

    Integer class label ``i`` (in ``'labels'`` tensors and in ``y_true`` /
    ``y_pred`` lists) refers to ``class_names[i]``. Returned metric dicts use
    keys of the form ``f'{class_name}_<metric>'`` plus aggregate keys.
    """

    def __init__(self, class_names: List[str]):
        """Store the ordered class names.

        Args:
            class_names: Class names; position ``i`` names integer label ``i``.
        """
        self.class_names = class_names
        self.num_classes = len(class_names)

    def calculate_iou(self, predicted_boxes: torch.Tensor, target_boxes: torch.Tensor,
                      threshold: float = 0.5) -> torch.Tensor:
        """Return the pairwise IoU matrix between two sets of axis-aligned boxes.

        Args:
            predicted_boxes: ``[N, 4]`` boxes as ``(x1, y1, x2, y2)``.
            target_boxes: ``[M, 4]`` boxes as ``(x1, y1, x2, y2)``.
            threshold: Unused; kept only for signature backward compatibility.

        Returns:
            ``[N, M]`` tensor whose entry ``(i, j)`` is IoU(predicted_i, target_j).
            If either input is empty, an all-zero ``[N, M]`` tensor on the
            device of ``predicted_boxes`` is returned.
        """
        if predicted_boxes.numel() == 0 or target_boxes.numel() == 0:
            return torch.zeros(predicted_boxes.size(0), target_boxes.size(0),
                               device=predicted_boxes.device)

        # Broadcast to [N, M] pairs.
        pred_expanded = predicted_boxes.unsqueeze(1)   # [N, 1, 4]
        target_expanded = target_boxes.unsqueeze(0)    # [1, M, 4]

        # Corners of the intersection rectangle for every (pred, target) pair.
        x1 = torch.max(pred_expanded[:, :, 0], target_expanded[:, :, 0])
        y1 = torch.max(pred_expanded[:, :, 1], target_expanded[:, :, 1])
        x2 = torch.min(pred_expanded[:, :, 2], target_expanded[:, :, 2])
        y2 = torch.min(pred_expanded[:, :, 3], target_expanded[:, :, 3])

        # Non-overlapping pairs yield negative width/height; clamp those to 0.
        intersection = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)

        pred_area = (predicted_boxes[:, 2] - predicted_boxes[:, 0]) * \
                    (predicted_boxes[:, 3] - predicted_boxes[:, 1])
        target_area = (target_boxes[:, 2] - target_boxes[:, 0]) * \
                      (target_boxes[:, 3] - target_boxes[:, 1])

        union = pred_area.unsqueeze(1) + target_area.unsqueeze(0) - intersection

        # The epsilon avoids 0/0 for degenerate (zero-area) box pairs.
        return intersection / (union + 1e-8)

    def calculate_map(self, predictions: List[Dict], targets: List[Dict],
                      iou_thresholds: Optional[List[float]] = None) -> Dict[str, float]:
        """Return per-class AP averaged over IoU thresholds, plus their mean.

        Args:
            predictions: Per-image dicts with ``'boxes'``, ``'labels'``, ``'scores'``.
            targets: Per-image dicts with ``'boxes'`` and ``'labels'``; paired
                with ``predictions`` by position.
            iou_thresholds: IoU thresholds to average over. Defaults to
                0.50, 0.55, ..., 0.95.

        Returns:
            ``{f'{class_name}_mAP': AP averaged over thresholds}`` for every
            class, plus ``'mAP'``, the mean of those per-class values. Classes
            with no predictions or no ground truth contribute 0.

        Raises:
            ValueError: If ``iou_thresholds`` is an empty list.
        """
        if iou_thresholds is None:
            iou_thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
        if len(iou_thresholds) == 0:
            raise ValueError("iou_thresholds must contain at least one threshold")

        ap_results = {class_name: [] for class_name in self.class_names}
        for iou_threshold in iou_thresholds:
            class_aps = self._calculate_ap_at_iou(predictions, targets, iou_threshold)
            for class_name, ap in class_aps.items():
                ap_results[class_name].append(ap)

        map_scores = {}
        for class_name in self.class_names:
            map_scores[f'{class_name}_mAP'] = np.mean(ap_results[class_name])

        # Computed before 'mAP' is inserted, so it only averages per-class entries.
        map_scores['mAP'] = np.mean(list(map_scores.values()))

        return map_scores

    def _calculate_ap_at_iou(self, predictions: List[Dict], targets: List[Dict],
                             iou_threshold: float) -> Dict[str, float]:
        """Compute per-class Average Precision at a single IoU threshold.

        ``predictions[k]`` and ``targets[k]`` describe the same image.
        Prediction dicts need ``'boxes'``, ``'labels'`` and ``'scores'``; target
        dicts need ``'boxes'`` and ``'labels'``. Entries missing those keys
        contribute nothing.

        Matching follows the PASCAL VOC protocol:

        * Detections are processed in descending score order.
        * Each detection is compared only with ground-truth boxes of the same
          class in the same image.
        * A detection is a true positive only if its best-overlapping ground
          truth reaches ``iou_threshold`` and has not already been claimed by a
          higher-scoring detection.

        Duplicate detections of one object therefore count as false positives.

        Returns:
            ``{class_name: AP}`` using 11-point interpolation. Classes with no
            predictions or no ground truth get ``0.0``.
        """
        ap_scores = {}

        for class_idx, class_name in enumerate(self.class_names):
            # (score, image index, box) for every prediction of this class.
            detections = []
            # image index -> [K, 4] ground-truth boxes of this class in that image.
            gt_boxes_by_image = {}
            num_gt = 0

            for image_idx, (pred, target) in enumerate(zip(predictions, targets)):
                if 'boxes' in pred and 'labels' in pred and 'scores' in pred:
                    class_mask = pred['labels'] == class_idx
                    if torch.any(class_mask):
                        for box, score in zip(pred['boxes'][class_mask],
                                              pred['scores'][class_mask]):
                            detections.append((float(score), image_idx, box))

                if 'boxes' in target and 'labels' in target:
                    target_class_mask = target['labels'] == class_idx
                    if torch.any(target_class_mask):
                        image_gt = target['boxes'][target_class_mask]
                        gt_boxes_by_image[image_idx] = image_gt
                        num_gt += image_gt.size(0)

            if not detections or num_gt == 0:
                ap_scores[class_name] = 0.0
                continue

            # Stable sort: equal scores keep their original (image, position) order.
            detections.sort(key=lambda det: det[0], reverse=True)

            # Per-image flags for ground truths already claimed by a detection.
            gt_matched = {img: [False] * boxes.size(0)
                          for img, boxes in gt_boxes_by_image.items()}

            tp = np.zeros(len(detections))
            fp = np.zeros(len(detections))

            for i, (_, image_idx, pred_box) in enumerate(detections):
                image_gt = gt_boxes_by_image.get(image_idx)
                if image_gt is None:
                    # No ground truth of this class in this image.
                    fp[i] = 1
                    continue

                ious = self.calculate_iou(pred_box.unsqueeze(0), image_gt)[0]
                best_iou, best_idx = torch.max(ious, dim=0)
                best_idx = int(best_idx)

                if best_iou.item() >= iou_threshold and not gt_matched[image_idx][best_idx]:
                    tp[i] = 1
                    gt_matched[image_idx][best_idx] = True
                else:
                    fp[i] = 1

            tp_cumsum = np.cumsum(tp)
            fp_cumsum = np.cumsum(fp)

            # Every detection is either TP or FP, so the denominator is >= 1.
            precision = tp_cumsum / (tp_cumsum + fp_cumsum)
            recall = tp_cumsum / num_gt

            ap_scores[class_name] = self._calculate_ap_11_point(precision, recall)

        return ap_scores

    def _calculate_ap_11_point(self, precision: np.ndarray, recall: np.ndarray) -> float:
        """Return 11-point interpolated AP (PASCAL VOC 2007 style).

        For each recall level ``t`` in {0.0, 0.1, ..., 1.0}, the interpolated
        precision is the maximum precision over points with ``recall >= t``
        (0 if there are none). AP is the mean of those 11 values.
        """
        ap = 0.0
        for i in range(11):
            # i / 10 yields the exact doubles for 0.3, 0.6, 0.7. np.arange(0, 1.1, 0.1)
            # produces 0.30000000000000004 etc., which wrongly excludes recall == 0.3.
            t = i / 10
            mask = recall >= t
            p = np.max(precision[mask]) if np.any(mask) else 0.0
            ap += p / 11.0
        return float(ap)

    def calculate_nds(self, predictions: List[Dict], targets: List[Dict],
                      map_scores: Optional[Dict[str, float]] = None) -> float:
        """Return a simplified nuScenes Detection Score (NDS).

        NDS = (1/10) * [5 * mAP + sum over five TP errors of (1 - min(1, err))].

        Only two ingredients are computed from the inputs:

        * mAP comes from :meth:`calculate_map`. It uses IoU matching, not the
          official nuScenes center-distance matching.
        * mATE is the mean distance from each predicted box center to the
          nearest target box center in the same image. It ignores class and
          uses the boxes' own coordinate units. It falls back to 0.5 if no
          (prediction, target) pair exists.

        mASE, mAOE, mAVE and mAAE are fixed placeholder constants.

        Args:
            predictions: Per-image prediction dicts (see :meth:`calculate_map`).
            targets: Per-image target dicts, paired with ``predictions`` by position.
            map_scores: Optional precomputed ``calculate_map`` output for the
                same inputs; when given, mAP is not recomputed.
        """
        if map_scores is None:
            map_scores = self.calculate_map(predictions, targets)
        mAP = map_scores['mAP']

        translation_errors = []
        for pred, target in zip(predictions, targets):
            if 'boxes' in pred and 'boxes' in target:
                # Box centers; assumes (x1, y1, x2, y2) rows.
                pred_centers = (pred['boxes'][:, :2] + pred['boxes'][:, 2:]) / 2
                target_centers = (target['boxes'][:, :2] + target['boxes'][:, 2:]) / 2

                if len(pred_centers) > 0 and len(target_centers) > 0:
                    # Nearest target center for each prediction (no one-to-one matching).
                    distances = torch.cdist(pred_centers, target_centers)
                    min_distances, _ = torch.min(distances, dim=1)
                    translation_errors.extend(min_distances.tolist())

        mATE = np.mean(translation_errors) if translation_errors else 0.5
        mASE = 0.3  # Placeholder
        mAOE = 0.4  # Placeholder
        mAVE = 0.6  # Placeholder
        mAAE = 0.2  # Placeholder

        nds = (1/10) * (5 * mAP + (1 - min(1, mATE)) + (1 - min(1, mASE)) +
                        (1 - min(1, mAOE)) + (1 - min(1, mAVE)) + (1 - min(1, mAAE)))

        return nds

    def calculate_classification_metrics(self, y_true: List[int], y_pred: List[int],
                                         class_names: List[str] = None) -> Dict[str, float]:
        """Return accuracy plus per-class, macro and weighted precision/recall/F1.

        Metrics are computed over labels ``0 .. len(class_names) - 1``.
        Per-class, macro and weighted figures all use that same label set.

        Args:
            y_true: Ground-truth integer labels.
            y_pred: Predicted integer labels.
            class_names: Optional override of ``self.class_names`` used for
                the label set and the metric key prefixes.

        Returns:
            Dict with ``'accuracy'``, ``f'{name}_precision'`` / ``_recall`` /
            ``_f1`` / ``_support`` per class, and ``macro_*`` / ``weighted_*``
            precision, recall and F1.
        """
        if class_names is None:
            class_names = self.class_names
        labels = list(range(len(class_names)))

        accuracy = accuracy_score(y_true, y_pred)
        # With average=None and explicit labels, each array has len(labels) entries.
        precision, recall, f1, support = precision_recall_fscore_support(
            y_true, y_pred, average=None, labels=labels
        )

        metrics = {'accuracy': accuracy}

        for i, class_name in enumerate(class_names):
            metrics[f'{class_name}_precision'] = precision[i]
            metrics[f'{class_name}_recall'] = recall[i]
            metrics[f'{class_name}_f1'] = f1[i]
            metrics[f'{class_name}_support'] = support[i]

        metrics['macro_precision'] = np.mean(precision)
        metrics['macro_recall'] = np.mean(recall)
        metrics['macro_f1'] = np.mean(f1)

        # Same label set as above so weighted and per-class figures are consistent.
        weighted_precision, weighted_recall, weighted_f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='weighted', labels=labels
        )
        metrics['weighted_precision'] = weighted_precision
        metrics['weighted_recall'] = weighted_recall
        metrics['weighted_f1'] = weighted_f1

        return metrics

    def calculate_bias_metrics(self, y_true: List[int], y_pred: List[int]) -> Dict[str, float]:
        """Return per-class one-vs-rest rates and cross-class disparity summaries.

        Per class: ``_tpr`` (recall), ``_fpr`` and ``_balanced_accuracy``
        ((TPR + TNR) / 2), each computed one-vs-rest from the confusion matrix.

        Summaries (class-level analogues of group-fairness metrics):

        * ``demographic_parity_diff``: largest minus smallest fraction of
          predictions assigned to any class.
        * ``equal_opportunity_diff``: spread (max - min) of per-class TPR.
        * ``equalized_odds_diff``: larger of the TPR spread and the FPR spread.
        """
        # Explicit labels keep row/column i tied to class i even if a class is absent.
        labels = list(range(len(self.class_names)))
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        total = cm.sum()

        bias_metrics = {}
        tprs = []
        fprs = []

        for i, class_name in enumerate(self.class_names):
            true_pos = cm[i, i]
            actual_pos = cm[i, :].sum()
            predicted_pos = cm[:, i].sum()
            actual_neg = total - actual_pos
            false_pos = predicted_pos - true_pos
            true_neg = actual_neg - false_pos

            # The epsilon avoids division by zero when a class has no positives/negatives.
            tpr = true_pos / (actual_pos + 1e-8)
            fpr = false_pos / (actual_neg + 1e-8)
            tnr = true_neg / (actual_neg + 1e-8)

            bias_metrics[f'{class_name}_tpr'] = tpr
            bias_metrics[f'{class_name}_fpr'] = fpr
            bias_metrics[f'{class_name}_balanced_accuracy'] = (tpr + tnr) / 2
            tprs.append(tpr)
            fprs.append(fpr)

        y_pred_arr = np.asarray(y_pred)
        class_pred_rates = [np.sum(y_pred_arr == i) / len(y_pred_arr)
                            for i in range(len(self.class_names))]
        bias_metrics['demographic_parity_diff'] = max(class_pred_rates) - min(class_pred_rates)

        tpr_gap = max(tprs) - min(tprs)
        fpr_gap = max(fprs) - min(fprs)
        bias_metrics['equal_opportunity_diff'] = tpr_gap
        bias_metrics['equalized_odds_diff'] = max(tpr_gap, fpr_gap)

        return bias_metrics

    def generate_performance_report(self, predictions: List[Dict], targets: List[Dict],
                                    y_true: List[int] = None, y_pred: List[int] = None,
                                    save_path: Optional[str] = None) -> Dict[str, float]:
        """Collect all applicable metrics into one flat dict.

        Detection metrics (per-class mAP, ``'mAP'``, ``'NDS'``) are included
        when ``predictions`` and ``targets`` are non-empty and every prediction
        has ``'boxes'``. Classification and bias metrics are included when
        both ``y_true`` and ``y_pred`` are given. If ``save_path`` is set, the
        report is also written there as JSON.
        """
        report = {}

        if predictions and targets and all('boxes' in p for p in predictions):
            map_scores = self.calculate_map(predictions, targets)
            # Reuse mAP rather than recomputing it inside calculate_nds.
            nds_score = self.calculate_nds(predictions, targets, map_scores=map_scores)

            report.update(map_scores)
            report['NDS'] = nds_score

        if y_true is not None and y_pred is not None:
            classification_metrics = self.calculate_classification_metrics(y_true, y_pred)
            bias_metrics = self.calculate_bias_metrics(y_true, y_pred)

            report.update(classification_metrics)
            report.update(bias_metrics)

        if save_path:
            self._save_report(report, save_path)

        return report

    def _save_report(self, report: Dict[str, float], save_path: str):
        """Write ``report`` to ``save_path`` as indented JSON.

        Tensors and numpy arrays are converted with ``tolist()``. Numpy scalars
        (e.g. the ``np.int64`` support counts) are converted with ``item()``,
        since ``json`` cannot serialize them directly.
        """
        import json

        serializable_report = {}
        for key, value in report.items():
            if torch.is_tensor(value):
                # tolist() returns a Python scalar for 0-d tensors, a list otherwise.
                value = value.tolist()
            elif isinstance(value, np.ndarray):
                value = value.tolist()
            elif isinstance(value, np.generic):
                value = value.item()
            serializable_report[key] = value

        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(serializable_report, f, indent=2)

    def plot_confusion_matrix(self, y_true: List[int], y_pred: List[int],
                              save_path: Optional[str] = None):
        """Draw an annotated confusion-matrix heatmap, optionally save it, and show it.

        Rows are actual classes and columns are predicted classes. Both are
        ordered as in ``self.class_names``.
        """
        # Explicit labels keep the matrix shape equal to the number of tick labels.
        cm = confusion_matrix(y_true, y_pred, labels=list(range(len(self.class_names))))

        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=self.class_names,
                    yticklabels=self.class_names)
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def plot_class_performance(self, metrics: Dict[str, float],
                               save_path: Optional[str] = None):
        """Draw grouped bars of per-class precision, recall and F1.

        Reads ``f'{class_name}_precision'``, ``_recall`` and ``_f1`` from
        ``metrics`` (missing entries plot as 0). Optionally saves the figure,
        then calls ``plt.show()``.
        """
        class_metrics = {
            'Precision': [],
            'Recall': [],
            'F1-Score': []
        }

        for class_name in self.class_names:
            class_metrics['Precision'].append(metrics.get(f'{class_name}_precision', 0))
            class_metrics['Recall'].append(metrics.get(f'{class_name}_recall', 0))
            class_metrics['F1-Score'].append(metrics.get(f'{class_name}_f1', 0))

        x = np.arange(len(self.class_names))
        width = 0.25

        fig, ax = plt.subplots(figsize=(12, 6))

        bars1 = ax.bar(x - width, class_metrics['Precision'], width, label='Precision')
        bars2 = ax.bar(x, class_metrics['Recall'], width, label='Recall')
        bars3 = ax.bar(x + width, class_metrics['F1-Score'], width, label='F1-Score')

        ax.set_xlabel('Classes')
        ax.set_ylabel('Score')
        ax.set_title('Per-Class Performance Metrics')
        ax.set_xticks(x)
        ax.set_xticklabels(self.class_names)
        ax.legend()
        ax.set_ylim(0, 1)

        # Label each bar with its value just above the bar top.
        for bars in [bars1, bars2, bars3]:
            for bar in bars:
                height = bar.get_height()
                ax.annotate(f'{height:.3f}',
                            xy=(bar.get_x() + bar.get_width() / 2, height),
                            xytext=(0, 3),
                            textcoords="offset points",
                            ha='center', va='bottom',
                            fontsize=8)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def compare_model_performance(self, baseline_metrics: Dict[str, float],
                                  improved_metrics: Dict[str, float],
                                  save_path: Optional[str] = None) -> Dict[str, float]:
        """Return the percent change of every metric present in both dicts.

        For each shared key ``k``, the result holds ``f'{k}_improvement'``:

        * ``(improved - baseline) / baseline * 100`` when the baseline is non-zero;
        * ``improved * 100`` when the baseline is 0, since the relative change
          is undefined there.

        The sign does not account for metric direction. For lower-is-better
        metrics (e.g. ``*_fpr``, ``*_diff``), a negative value is the desirable
        outcome.

        If any of ``mAP``, ``NDS``, ``accuracy`` or ``macro_f1`` is present in
        both dicts, the method also draws a grouped bar chart of those metrics,
        annotated with the signed change. It optionally saves the chart to
        ``save_path`` and then calls ``plt.show()``.
        """
        improvements = {}

        for metric_name, baseline_val in baseline_metrics.items():
            if metric_name not in improved_metrics:
                continue
            improved_val = improved_metrics[metric_name]

            if baseline_val != 0:
                improvement = ((improved_val - baseline_val) / baseline_val) * 100
            else:
                improvement = improved_val * 100

            improvements[f'{metric_name}_improvement'] = improvement

        metrics_to_plot = ['mAP', 'NDS', 'accuracy', 'macro_f1']
        available_metrics = [m for m in metrics_to_plot
                             if m in baseline_metrics and m in improved_metrics]

        if available_metrics:
            baseline_values = [baseline_metrics[m] for m in available_metrics]
            improved_values = [improved_metrics[m] for m in available_metrics]

            x = np.arange(len(available_metrics))
            width = 0.35

            fig, ax = plt.subplots(figsize=(10, 6))

            bars1 = ax.bar(x - width/2, baseline_values, width, label='Baseline', alpha=0.7)
            bars2 = ax.bar(x + width/2, improved_values, width, label='After Bias Mitigation', alpha=0.7)

            ax.set_xlabel('Metrics')
            ax.set_ylabel('Score')
            ax.set_title('Model Performance Comparison')
            ax.set_xticks(x)
            ax.set_xticklabels(available_metrics)
            ax.legend()

            # Annotate above the "improved" bar; '+' format spec prints the sign
            # for both gains and losses (no more "+-5.0%").
            for i, metric in enumerate(available_metrics):
                improvement = improvements.get(f'{metric}_improvement', 0)
                ax.annotate(f'{improvement:+.1f}%',
                            xy=(i + width/2, improved_values[i]),
                            xytext=(0, 5),
                            textcoords="offset points",
                            ha='center', va='bottom',
                            fontweight='bold',
                            color='green' if improvement > 0 else 'red')

            plt.tight_layout()

            if save_path:
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
            plt.show()

        return improvements