import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import numpy as np
import json
import os
from typing import Dict, List, Tuple, Optional
from collections import Counter
import albumentations as A
from albumentations.pytorch import ToTensorV2

class NuScenesDataset(Dataset):
    """Image/label dataset built from a JSON annotation list, with class-balance helpers.

    At construction time the annotation file is loaded, filtered down to the
    requested classes, and two statistics are derived from it:

    * ``class_distribution`` -- number of annotations per category name.
    * ``class_weights`` -- normalised inverse-frequency weight per class.

    Args:
        data_path: Root directory that each annotation's ``'image_path'`` is
            joined onto.
        annotations_path: Path to a JSON file containing a sequence of dicts.
            Every dict must have a ``'category_name'`` key; ``'image_path'``
            is read when an item is fetched.
        class_names: Categories to keep. A class's position in this list is
            the integer label returned by ``__getitem__``.
        transform: Optional callable applied to the loaded image.
        augment_underrepresented: When true (and ``transform`` is set), items
            whose category is in ``UNDERREPRESENTED_CLASSES`` go through the
            stronger targeted augmentation pipeline *instead of* ``transform``.
    """

    # Categories routed to the targeted augmentation pipeline. Kept as a class
    # attribute so subclasses can override it without touching __getitem__.
    UNDERREPRESENTED_CLASSES = ('vehicle.bicycle', 'vehicle.motorcycle')

    # Cache for the targeted augmentation pipeline; built on first use so the
    # pipeline is not reconstructed for every sample.
    _targeted_augment = None

    def __init__(self, data_path: str, annotations_path: str,
                 class_names: List[str], transform=None, augment_underrepresented=False):
        self.data_path = data_path
        self.class_names = class_names
        self.class_to_idx = {name: idx for idx, name in enumerate(class_names)}
        self.transform = transform
        self.augment_underrepresented = augment_underrepresented

        # Order matters: the distribution needs the filtered annotations and
        # the weights need the distribution.
        self.annotations = self._load_annotations(annotations_path)
        self.class_distribution = self._analyze_class_distribution()
        self.class_weights = self._calculate_class_weights()

    def _load_annotations(self, annotations_path: str) -> List[Dict]:
        """Load the JSON annotation file and keep only entries of known classes.

        Raises:
            KeyError: If an annotation has no ``'category_name'`` key.
        """
        with open(annotations_path, 'r', encoding='utf-8') as f:
            annotations = json.load(f)

        # class_to_idx has the same keys as class_names but gives O(1) lookups.
        known_classes = self.class_to_idx
        return [ann for ann in annotations if ann['category_name'] in known_classes]

    def _analyze_class_distribution(self) -> Dict[str, int]:
        """Return a mapping of category name to its number of annotations."""
        return dict(Counter(ann['category_name'] for ann in self.annotations))

    def _calculate_class_weights(self) -> Dict[str, float]:
        """Return inverse-frequency class weights normalised to sum to 1.

        A class with no annotations is treated as if it had exactly one, so it
        still receives a finite weight. The weight of class ``c`` is
        ``(1 / count_c) / sum_j(1 / count_j)``; any constant scaling of the
        frequencies (such as expressing them as percentages of the dataset
        size) cancels in the normalisation, so it is omitted. This also keeps
        the computation well defined when the dataset is empty.
        """
        inverse_counts = {
            class_name: 1.0 / self.class_distribution.get(class_name, 1)
            for class_name in self.class_names
        }
        total_weight = sum(inverse_counts.values())
        return {
            class_name: weight / total_weight
            for class_name, weight in inverse_counts.items()
        }

    def get_sampling_weights(self) -> torch.Tensor:
        """Return a 1-D float32 tensor with one class weight per annotation.

        Element ``i`` is the class weight of ``self.annotations[i]``, which is
        the layout ``WeightedRandomSampler`` expects.
        """
        weights = [self.class_weights[ann['category_name']] for ann in self.annotations]
        return torch.tensor(weights, dtype=torch.float32)

    def __len__(self):
        """Return the number of annotations kept after class filtering."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(image, label, annotation)`` for the annotation at ``idx``.

        ``label`` is the index of the annotation's category in ``class_names``
        and ``annotation`` is the raw annotation dict. No augmentation of any
        kind is applied when ``transform`` is unset, even if
        ``augment_underrepresented`` is true.
        """
        ann = self.annotations[idx]

        # Load image (placeholder - adapt based on your data structure)
        image_path = os.path.join(self.data_path, ann['image_path'])
        image = self._load_image(image_path)

        label = self.class_to_idx[ann['category_name']]

        if self.transform:
            if (self.augment_underrepresented
                    and ann['category_name'] in self.UNDERREPRESENTED_CLASSES):
                # Apply stronger augmentation for underrepresented classes
                image = self._apply_targeted_augmentation(image)
            else:
                # NOTE(review): the transform is invoked positionally; if it
                # is an albumentations pipeline it presumably needs
                # ``transform(image=image)['image']`` -- confirm with callers.
                image = self.transform(image)

        return image, label, ann

    def _load_image(self, image_path: str):
        """Placeholder loader: ignores ``image_path`` and returns random data.

        Returns:
            A float32 array of shape ``(224, 224, 3)`` with values in [0, 1).
        """
        # Placeholder implementation
        return np.random.rand(224, 224, 3).astype(np.float32)

    def _build_targeted_augmentation(self):
        """Construct the stronger augmentation pipeline for rare classes."""
        # NOTE(review): albumentations documents CLAHE as uint8-only, while
        # the placeholder loader yields float32 -- confirm the dtype produced
        # by the real image loader.
        return A.Compose([
            A.RandomRotate90(p=0.3),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.3),
            A.OneOf([
                A.MotionBlur(p=0.2),
                A.MedianBlur(blur_limit=3, p=0.1),
                A.Blur(blur_limit=3, p=0.1),
            ], p=0.2),
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2),
            A.OneOf([
                A.OpticalDistortion(p=0.3),
                A.GridDistortion(p=0.1),
                A.PiecewiseAffine(p=0.3),
            ], p=0.2),
            A.OneOf([
                A.CLAHE(clip_limit=2),
                A.Sharpen(),
                A.Emboss(),
                A.RandomBrightnessContrast(),
            ], p=0.3),
            A.HueSaturationValue(p=0.3),
            ToTensorV2(),
        ])

    def _apply_targeted_augmentation(self, image):
        """Run ``image`` through the targeted pipeline and return the result.

        The pipeline ends with ``ToTensorV2``, so the returned image is a
        tensor rather than a NumPy array.
        """
        if self._targeted_augment is None:
            # Built once per dataset instance and reused for later samples.
            self._targeted_augment = self._build_targeted_augmentation()
        return self._targeted_augment(image=image)['image']

class BiasAwareDataLoader:
    """Factory for a ``DataLoader`` over a ``NuScenesDataset``.

    Depending on ``use_weighted_sampling`` the loader either draws samples
    with replacement according to the dataset's per-annotation class weights,
    or simply shuffles the dataset.
    """

    def __init__(self, dataset: NuScenesDataset, batch_size: int,
                 use_weighted_sampling: bool = True, num_workers: int = 4):
        """Store the loader configuration; no loader is created yet."""
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.use_weighted_sampling = use_weighted_sampling

    def get_dataloader(self) -> DataLoader:
        """Build and return a new ``DataLoader`` for the stored dataset.

        Returns:
            A loader using a ``WeightedRandomSampler`` (one draw per
            annotation, with replacement) when weighted sampling is enabled,
            otherwise a plain shuffling loader. Both pin memory.
        """
        # Settings shared by both loader flavours.
        loader_kwargs = {
            'batch_size': self.batch_size,
            'num_workers': self.num_workers,
            'pin_memory': True,
        }

        if not self.use_weighted_sampling:
            return DataLoader(self.dataset, shuffle=True, **loader_kwargs)

        sample_weights = self.dataset.get_sampling_weights()
        balanced_sampler = WeightedRandomSampler(
            sample_weights, len(sample_weights), replacement=True
        )
        return DataLoader(self.dataset, sampler=balanced_sampler, **loader_kwargs)

def create_balanced_split(annotations: List[Dict], train_ratio: float = 0.8,
                          seed: Optional[int] = None) -> Tuple[List[Dict], List[Dict]]:
    """Split annotations into train/validation sets, class by class.

    Annotations are grouped by ``'category_name'``, each group is shuffled,
    and the first ``int(len(group) * train_ratio)`` entries go to the training
    set with the remainder going to validation. The input list itself is not
    reordered.

    Args:
        annotations: Annotation dicts, each with a ``'category_name'`` key.
        train_ratio: Fraction of every class assigned to training; must lie
            in ``[0, 1]``.
        seed: Optional seed for a reproducible split. When ``None`` the
            shuffle draws from NumPy's global random state.

    Returns:
        ``(train_data, val_data)``; together they contain every input
        annotation exactly once.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
        KeyError: If an annotation has no ``'category_name'`` key.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be within [0, 1], got {train_ratio!r}")

    # A seeded call gets its own generator; an unseeded call keeps using the
    # global state so a caller's np.random.seed(...) is still honoured.
    rng = np.random.RandomState(seed) if seed is not None else np.random

    # Group annotations by class, preserving first-seen class order.
    class_data: Dict[str, List[Dict]] = {}
    for ann in annotations:
        class_data.setdefault(ann['category_name'], []).append(ann)

    train_data: List[Dict] = []
    val_data: List[Dict] = []

    for class_annotations in class_data.values():
        rng.shuffle(class_annotations)
        split_idx = int(len(class_annotations) * train_ratio)
        if split_idx == 0 and train_ratio > 0:
            # Truncation would otherwise send every sample of a very small
            # class to validation, leaving nothing of that class to train on.
            split_idx = 1
        train_data.extend(class_annotations[:split_idx])
        val_data.extend(class_annotations[split_idx:])

    return train_data, val_data