import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import os
import json
from typing import Tuple, List, Optional, Dict, Any
from PIL import Image
import random

class CustomImageDataset(Dataset):
    """Image-folder dataset: one sub-directory per class under ``data_dir``.

    Class indices are assigned in sorted order of the sub-directory names.
    Files inside each class directory are also visited in sorted order, so the
    sample order (and therefore any index-based split) is reproducible across
    filesystems.

    Args:
        data_dir: Root directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded RGB PIL image.
        target_transform: Optional callable applied to each integer label.
        extensions: File suffixes (matched case-insensitively) treated as
            images. Defaults to ``IMG_EXTENSIONS``.
    """

    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def __init__(self, data_dir: str, transform=None, target_transform=None,
                 extensions: Optional[Tuple[str, ...]] = None):
        self.data_dir = data_dir
        self.transform = transform
        self.target_transform = target_transform

        if extensions is None:
            extensions = self.IMG_EXTENSIONS
        suffixes = tuple(ext.lower() for ext in extensions)

        # (image path, class index) pairs, plus the class-name -> index map.
        self.samples: List[Tuple[str, int]] = []
        self.class_to_idx: Dict[str, int] = {}

        # Every sub-directory of data_dir is a class; sorted for stable indices.
        class_names = sorted(
            entry for entry in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, entry))
        )

        for idx, class_name in enumerate(class_names):
            self.class_to_idx[class_name] = idx
            class_dir = os.path.join(data_dir, class_name)

            # Sorted so sample order does not depend on filesystem listing order.
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                # isfile() skips sub-directories that merely look like images.
                if img_name.lower().endswith(suffixes) and os.path.isfile(img_path):
                    self.samples.append((img_path, idx))

        self.classes = class_names
        self.num_classes = len(class_names)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Load sample ``idx`` as RGB and apply the optional transforms.

        The image is a ``torch.Tensor`` only if ``transform`` produces one;
        otherwise it is returned as a PIL image.
        """
        img_path, target = self.samples[idx]

        # The context manager releases the file handle even if decoding fails;
        # convert() returns a new, fully loaded image that outlives the file.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

class DatasetSampler:
    """Shuffle a dataset's indices once and split them into per-device shards.

    Shard sizes differ by at most one sample: the first
    ``len(dataset) % num_devices`` shards receive one extra index.

    Args:
        dataset: Any sized, indexable dataset.
        num_devices: Number of shards to create; must be >= 1.
        seed: Optional seed for a private RNG, making the shards reproducible
            without touching global state. When ``None`` (the default) the
            module-level ``random`` state is used.

    Raises:
        ValueError: if ``num_devices`` is smaller than 1.
    """

    def __init__(self, dataset: Dataset, num_devices: int, seed: Optional[int] = None):
        if num_devices < 1:
            raise ValueError(f"num_devices must be >= 1, got {num_devices}")
        self.dataset = dataset
        self.num_devices = num_devices
        self._rng = random if seed is None else random.Random(seed)
        self.device_indices = self._create_device_splits()

    def _create_device_splits(self) -> List[List[int]]:
        """Return ``num_devices`` disjoint index lists covering the dataset."""
        indices = list(range(len(self.dataset)))
        self._rng.shuffle(indices)

        # Spread the remainder over the first shards instead of piling it onto
        # the last one (which also left all other shards empty whenever there
        # were more devices than samples).
        base, extra = divmod(len(indices), self.num_devices)
        device_indices = []
        start = 0
        for device_id in range(self.num_devices):
            stop = start + base + (1 if device_id < extra else 0)
            device_indices.append(indices[start:stop])
            start = stop

        return device_indices

    def get_device_dataset(self, device_id: int) -> Subset:
        """Return the ``Subset`` of the dataset assigned to ``device_id``.

        Raises:
            ValueError: if ``device_id`` is not in ``[0, num_devices)``.
        """
        # Reject negatives too: they would otherwise silently index from the end.
        if not 0 <= device_id < self.num_devices:
            raise ValueError(
                f"Device ID {device_id} is out of range for {self.num_devices} devices"
            )

        return Subset(self.dataset, self.device_indices[device_id])

    def get_device_loader(self, device_id: int, batch_size: int, shuffle: bool = True,
                          num_workers: int = 2, pin_memory: bool = True) -> DataLoader:
        """Wrap the device's shard in a ``DataLoader`` with the given options."""
        device_dataset = self.get_device_dataset(device_id)
        return DataLoader(device_dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=pin_memory)

class DataAugmentation:
    """Build torchvision transform pipelines that emit square ``img_size`` images.

    Args:
        img_size: Side length, in pixels, of the square output images.
        augment_strength: ``'light'``, ``'medium'`` or ``'strong'`` selects the
            training augmentations. Any other value yields a plain
            resize + normalize training pipeline with no random augmentation.
    """

    def __init__(self, img_size: int = 224, augment_strength: str = 'medium'):
        self.img_size = img_size
        self.augment_strength = augment_strength

    @staticmethod
    def _tensor_and_normalize() -> list:
        # Tail shared by every pipeline: PIL image -> tensor, then per-channel
        # normalization with the mean/std constants below.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

    def get_train_transforms(self) -> transforms.Compose:
        """Return the training pipeline for the configured ``augment_strength``."""
        side = self.img_size
        level = self.augment_strength

        if level == 'light':
            augmentations = [
                transforms.Resize((side, side)),
                transforms.RandomHorizontalFlip(0.5),
            ]
        elif level == 'medium':
            # Resize 32 px larger, then randomly crop back down to the target.
            augmentations = [
                transforms.Resize((side + 32, side + 32)),
                transforms.RandomCrop((side, side)),
                transforms.RandomHorizontalFlip(0.5),
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            ]
        elif level == 'strong':
            # Larger over-size margin plus vertical flips and affine jitter.
            augmentations = [
                transforms.Resize((side + 64, side + 64)),
                transforms.RandomCrop((side, side)),
                transforms.RandomHorizontalFlip(0.5),
                transforms.RandomVerticalFlip(0.2),
                transforms.RandomRotation(30),
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2),
                transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.8, 1.2)),
            ]
        else:
            # Unrecognized strength: deterministic resize only.
            augmentations = [transforms.Resize((side, side))]

        return transforms.Compose(augmentations + self._tensor_and_normalize())

    def get_val_transforms(self) -> transforms.Compose:
        """Return the deterministic evaluation pipeline (resize + normalize)."""
        pipeline = [transforms.Resize((self.img_size, self.img_size))]
        pipeline.extend(self._tensor_and_normalize())
        return transforms.Compose(pipeline)

class DatasetFactory:
    """Build datasets and data loaders for the supported dataset families.

    Args:
        config: Object exposing ``config.model.img_size``, used to size the
            transform pipelines.
        augment_strength: Training augmentation level passed to
            ``DataAugmentation`` (default ``'medium'``).
    """

    def __init__(self, config, augment_strength: str = 'medium'):
        self.config = config
        self.data_augmentation = DataAugmentation(
            img_size=config.model.img_size,
            augment_strength=augment_strength
        )

    def create_dataset(self, dataset_name: str, data_root: str,
                       split: str = 'train') -> Tuple[Dataset, int]:
        """Return ``(dataset, num_classes)`` for ``dataset_name`` (case-insensitive).

        Supported names: ``imagenet``, ``cifar10``, ``cifar100``, ``custom``.
        ``split == 'train'`` selects training data and augmentations; any other
        value selects the evaluation data and transforms.

        Raises:
            ValueError: if ``dataset_name`` is not supported.
        """
        builders = {
            'imagenet': self._create_imagenet_dataset,
            'cifar10': self._create_cifar10_dataset,
            'cifar100': self._create_cifar100_dataset,
            'custom': self._create_custom_dataset,
        }
        builder = builders.get(dataset_name.lower())
        if builder is None:
            raise ValueError(f"Unsupported dataset: {dataset_name}")
        return builder(data_root, split)

    def _transform_for(self, split: str):
        """Training augmentations for ``'train'``, evaluation transforms otherwise."""
        if split == 'train':
            return self.data_augmentation.get_train_transforms()
        return self.data_augmentation.get_val_transforms()

    def _create_imagenet_dataset(self, data_root: str, split: str) -> Tuple[Dataset, int]:
        transform = self._transform_for(split)
        imagenet_split = 'train' if split == 'train' else 'val'
        dataset = datasets.ImageNet(root=data_root, split=imagenet_split, transform=transform)
        return dataset, 1000  # ImageNet has 1000 classes

    def _create_cifar10_dataset(self, data_root: str, split: str) -> Tuple[Dataset, int]:
        transform = self._transform_for(split)
        dataset = datasets.CIFAR10(root=data_root, train=(split == 'train'),
                                   download=True, transform=transform)
        return dataset, 10  # CIFAR-10 has 10 classes

    def _create_cifar100_dataset(self, data_root: str, split: str) -> Tuple[Dataset, int]:
        transform = self._transform_for(split)
        dataset = datasets.CIFAR100(root=data_root, train=(split == 'train'),
                                    download=True, transform=transform)
        return dataset, 100  # CIFAR-100 has 100 classes

    def _create_custom_dataset(self, data_root: str, split: str) -> Tuple[Dataset, int]:
        # Custom data is expected under <data_root>/<split>/<class_name>/...
        split_dir = os.path.join(data_root, split)
        dataset = CustomImageDataset(split_dir, transform=self._transform_for(split))
        return dataset, dataset.num_classes

    @staticmethod
    def _check_class_mappings(train_dataset: Dataset, val_dataset: Dataset) -> None:
        """Raise if both splits expose ``class_to_idx`` and the mappings differ.

        The splits are built independently (e.g. separate folders for custom
        data), so a class missing from one split would silently shift every
        later label.
        """
        train_map = getattr(train_dataset, 'class_to_idx', None)
        val_map = getattr(val_dataset, 'class_to_idx', None)
        if train_map is not None and val_map is not None and train_map != val_map:
            raise ValueError(
                "Train and val splits define different class-to-index mappings; "
                "labels would be inconsistent between splits"
            )

    def create_distributed_loaders(self, dataset_name: str, data_root: str,
                                   num_devices: int, batch_size: int,
                                   num_workers: int = 4) -> Dict[str, Any]:
        """Build per-device shard loaders plus global train/val loaders.

        Returns:
            Dict with keys ``'device_loaders'`` (device id -> ``{'train', 'val'}``
            loaders over that device's random shard), ``'global_train_loader'``,
            ``'global_val_loader'``, ``'num_classes'``, ``'train_size'`` and
            ``'val_size'``. ``num_workers`` applies to the global loaders.

        Raises:
            ValueError: for an unsupported dataset, or if the train and val
                class mappings disagree.
        """
        train_dataset, num_classes = self.create_dataset(dataset_name, data_root, 'train')
        val_dataset, _ = self.create_dataset(dataset_name, data_root, 'val')
        self._check_class_mappings(train_dataset, val_dataset)

        # Random per-device shards of each split.
        train_sampler = DatasetSampler(train_dataset, num_devices)
        val_sampler = DatasetSampler(val_dataset, num_devices)

        device_loaders = {
            device_id: {
                'train': train_sampler.get_device_loader(device_id, batch_size, shuffle=True),
                'val': val_sampler.get_device_loader(device_id, batch_size, shuffle=False),
            }
            for device_id in range(num_devices)
        }

        # Loaders over the full splits, for centralized evaluation.
        global_train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            num_workers=num_workers, pin_memory=True
        )
        global_val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=True
        )

        return {
            'device_loaders': device_loaders,
            'global_train_loader': global_train_loader,
            'global_val_loader': global_val_loader,
            'num_classes': num_classes,
            'train_size': len(train_dataset),
            'val_size': len(val_dataset)
        }

class PartitionedDataLoader:
    """Give each device in ``device_assignments`` a contiguous slice of the dataset.

    Only the keys (device ids) of ``device_assignments`` are used here; the
    assigned layer ids are not consulted. Devices are ranked by sorted id, and
    the sample range ``[0, len(dataset))`` is cut into contiguous pieces whose
    sizes differ by at most one, in that rank order.

    Args:
        dataset: Any sized, indexable dataset.
        device_assignments: Mapping of device id -> assigned layer ids.
        batch_size: Batch size for the per-device loaders.
    """

    def __init__(self, dataset: Dataset, device_assignments: Dict[int, List[int]],
                 batch_size: int):
        self.dataset = dataset
        self.device_assignments = device_assignments
        self.batch_size = batch_size

        # Slices are assigned by each device's rank rather than its raw id, so
        # non-contiguous ids (e.g. {5, 7}) still receive in-range, non-empty
        # slices. An empty mapping simply yields no slices.
        device_ids = sorted(device_assignments)
        self.device_indices: Dict[int, List[int]] = {}
        if device_ids:
            ranges = self._contiguous_ranges(len(dataset), len(device_ids))
            for device_id, index_range in zip(device_ids, ranges):
                self.device_indices[device_id] = list(index_range)

    @staticmethod
    def _contiguous_ranges(total: int, parts: int) -> List[range]:
        """Split ``range(total)`` into ``parts`` contiguous near-equal ranges."""
        base, extra = divmod(total, parts)
        ranges = []
        start = 0
        for rank in range(parts):
            stop = start + base + (1 if rank < extra else 0)
            ranges.append(range(start, stop))
            start = stop
        return ranges

    def get_device_loader(self, device_id: int, shuffle: bool = True,
                          num_workers: int = 2, pin_memory: bool = True) -> DataLoader:
        """Return a ``DataLoader`` over ``device_id``'s slice.

        Raises:
            ValueError: if ``device_id`` is not a key of ``device_assignments``.
        """
        if device_id not in self.device_indices:
            raise ValueError(f"Device {device_id} not found in assignments")

        device_subset = Subset(self.dataset, self.device_indices[device_id])
        return DataLoader(
            device_subset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory
        )

class DatasetAnalyzer:
    """Compute summary statistics and a class-distribution plot for a dataset.

    ``analyze_dataset`` stacks sampled images with ``torch.stack``, so the
    dataset must return ``(image, label)`` pairs whose images are tensors of a
    single shape (presumably (C, H, W), given the mean/std axes used — confirm
    for your data).
    """

    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    def analyze_dataset(self, max_samples: int = 100) -> Dict[str, Any]:
        """Return statistics computed from a random sample of up to ``max_samples`` items.

        ``num_classes`` and ``class_distribution`` describe only the sampled
        items; ``memory_usage_mb`` extrapolates the sample to the full dataset.

        Raises:
            ValueError: if the dataset is empty or ``max_samples`` < 1.
        """
        print("Analyzing dataset...")

        dataset_size = len(self.dataset)
        if dataset_size == 0:
            raise ValueError("Cannot analyze an empty dataset")
        if max_samples < 1:
            raise ValueError(f"max_samples must be >= 1, got {max_samples}")

        sample_indices = random.sample(range(dataset_size), min(max_samples, dataset_size))

        images = []
        labels = []
        for idx in sample_indices:
            image, label = self.dataset[idx]
            images.append(image)
            labels.append(label)

        images = torch.stack(images)
        labels = torch.tensor(labels)

        return {
            'dataset_size': dataset_size,
            'num_classes': len(torch.unique(labels)),
            'image_shape': list(images[0].shape),
            # Per-channel statistics over batch and both spatial axes (dim 1 kept).
            'mean_pixel_values': images.mean(dim=[0, 2, 3]).tolist(),
            'std_pixel_values': images.std(dim=[0, 2, 3]).tolist(),
            'class_distribution': self._calculate_class_distribution(labels),
            'data_type': str(images.dtype),
            'memory_usage_mb': self._estimate_memory_usage(images)
        }

    def _calculate_class_distribution(self, labels: torch.Tensor) -> Dict[int, int]:
        """Map each distinct label value to its count."""
        unique_labels, counts = torch.unique(labels, return_counts=True)
        return dict(zip(unique_labels.tolist(), counts.tolist()))

    def _estimate_memory_usage(self, images: torch.Tensor) -> float:
        """Extrapolate the sample's tensor memory to the full dataset, in MB."""
        sample_memory = images.element_size() * images.numel()
        samples_analyzed = images.size(0)
        estimated_total_memory = sample_memory * (len(self.dataset) / samples_analyzed)
        return estimated_total_memory / (1024 * 1024)  # bytes -> MB

    def plot_class_distribution(self, save_path: Optional[str] = None):
        """Bar-plot label counts over the whole dataset; save to ``save_path`` or show.

        Note: this loads every item via ``__getitem__`` to read its label.
        """
        import matplotlib.pyplot as plt

        all_labels = []
        for i in range(len(self.dataset)):
            _, label = self.dataset[i]
            # Single-element tensor labels become plain Python numbers.
            if isinstance(label, torch.Tensor) and label.numel() == 1:
                label = label.item()
            all_labels.append(label)

        unique_labels, counts = np.unique(all_labels, return_counts=True)

        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            ax.bar(unique_labels, counts)
            ax.set_xlabel('Class Label')
            ax.set_ylabel('Number of Samples')
            ax.set_title('Class Distribution in Dataset')
            ax.set_xticks(unique_labels)

            if save_path:
                fig.savefig(save_path, dpi=300, bbox_inches='tight')
            else:
                plt.show()
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)

def create_balanced_subset(dataset: Dataset, samples_per_class: int,
                           seed: Optional[int] = None) -> Subset:
    """Create a balanced subset with up to ``samples_per_class`` items per class.

    Classes with fewer items than requested contribute all of their items.
    Labels are read through ``dataset[idx]``, so every item is loaded once.

    Args:
        dataset: Indexable dataset returning ``(sample, label)`` pairs.
        samples_per_class: Maximum number of items drawn per class.
        seed: Optional seed for a private RNG. ``None`` uses the global
            ``random`` state, as before.

    Returns:
        A ``Subset`` whose indices are grouped by class in first-seen order.
    """
    rng = random if seed is None else random.Random(seed)

    # Group indices by label. Single-element tensors are unwrapped first:
    # tensors hash by identity, so otherwise every sample became its own class.
    class_indices: Dict[Any, List[int]] = {}
    for idx in range(len(dataset)):
        _, label = dataset[idx]
        if isinstance(label, torch.Tensor) and label.numel() == 1:
            label = label.item()
        class_indices.setdefault(label, []).append(idx)

    selected_indices: List[int] = []
    for indices in class_indices.values():
        if len(indices) >= samples_per_class:
            selected_indices.extend(rng.sample(indices, samples_per_class))
        else:
            selected_indices.extend(indices)

    return Subset(dataset, selected_indices)

def split_dataset(dataset: Dataset, split_ratios: List[float],
                  random_seed: int = 42) -> List[Subset]:
    """Randomly split ``dataset`` into disjoint subsets according to ``split_ratios``.

    Each split except the last receives ``int(ratio * len(dataset))`` items;
    the last receives all remaining items. The shuffle uses a private
    ``random.Random(random_seed)``, so results are reproducible and the
    process-wide ``random`` state is left untouched.

    Raises:
        ValueError: if any ratio is negative or the ratios do not sum to 1.0.
    """
    if any(ratio < 0 for ratio in split_ratios):
        raise ValueError("Split ratios must be non-negative")
    # Written as `not (... < tol)` so a NaN sum is rejected too.
    if not abs(sum(split_ratios) - 1.0) < 1e-6:
        raise ValueError("Split ratios must sum to 1.0")

    rng = random.Random(random_seed)

    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    rng.shuffle(indices)

    splits = []
    start_idx = 0
    last = len(split_ratios) - 1
    for i, ratio in enumerate(split_ratios):
        # The last split absorbs the rounding remainder.
        end_idx = dataset_size if i == last else start_idx + int(ratio * dataset_size)
        splits.append(Subset(dataset, indices[start_idx:end_idx]))
        start_idx = end_idx

    return splits

def save_dataset_info(dataset_info: Dict[str, Any], filepath: str):
    """Save dataset information to a UTF-8 JSON file.

    The dict is serialized before the file is opened, so a serialization
    error (e.g. a tuple dict key) leaves any existing file untouched instead
    of truncating it. Values JSON cannot represent are stringified via
    ``default=str``. JSON object keys are always strings, so int keys (such as
    a class distribution) come back as ``str`` when reloaded.
    """
    payload = json.dumps(dataset_info, indent=2, default=str)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(payload)

def load_dataset_info(filepath: str) -> Dict[str, Any]:
    """Load dataset information from a JSON file.

    The file is decoded as UTF-8 (the JSON interchange encoding) rather than
    the platform's locale default. Object keys are returned as strings.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)