import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import torchvision.transforms as T
import os
import json
from typing import List, Tuple, Dict, Optional, Union
import cv2

class CameraLidarFusionDataset(Dataset):
    """Dataset of paired camera images and LiDAR scans with integer labels.

    Indexing yields a ``(camera, lidar, label)`` triple:

    * the camera image is opened with PIL, converted to RGB and passed through
      ``camera_transform``;
    * the LiDAR scan is read from a ``.npy``, KITTI-style ``.bin`` or ASCII
      ``.pcd`` file, converted to the representation chosen by
      ``lidar_format`` and passed through ``lidar_transform``;
    * the label is returned as a ``torch.long`` scalar tensor.

    Loading is best-effort: a sample that fails to load is reported with
    ``print`` and replaced by a placeholder (grey image, all-zero LiDAR
    tensor) instead of raising.

    Args:
        data_root: Directory that every entry of ``camera_paths`` and
            ``lidar_paths`` is joined onto.
        camera_paths: Per-sample image paths, relative to ``data_root``.
        lidar_paths: Per-sample LiDAR paths, relative to ``data_root``.
        labels: Per-sample integer class labels.
        camera_transform: Callable applied to the PIL image. Defaults to
            resize + (optional augmentation) + tensor conversion + ImageNet
            mean/std normalisation.
        lidar_transform: Callable applied to the LiDAR tensor. Defaults to
            (optional augmentation) + min-max normalisation.
        fusion_mode: Stored on the instance; not used by this class itself.
        lidar_format: ``'bev'``, ``'range'`` or ``'pointcloud'``.
        target_size: ``(height, width)`` of the camera tensor and of the
            BEV / range image.
        augment: Whether the *default* transforms include augmentation.

    Raises:
        ValueError: If the three sample lists differ in length or
            ``lidar_format`` is not one of the supported values.
    """

    # LiDAR representations that ``_load_lidar_data`` knows how to produce.
    _LIDAR_FORMATS = ('bev', 'range', 'pointcloud')

    def __init__(self,
                 data_root: str,
                 camera_paths: List[str],
                 lidar_paths: List[str],
                 labels: List[int],
                 camera_transform: Optional[callable] = None,
                 lidar_transform: Optional[callable] = None,
                 fusion_mode: str = 'early',
                 lidar_format: str = 'bev',  # 'bev', 'range', 'pointcloud'
                 target_size: Tuple[int, int] = (224, 224),
                 augment: bool = True):

        # Explicit raises instead of ``assert`` so validation survives ``python -O``.
        if not (len(camera_paths) == len(lidar_paths) == len(labels)):
            raise ValueError(
                "Camera paths, LiDAR paths, and labels must have same length")
        # Fail fast: an unknown format would otherwise turn every sample into
        # a silent all-zero placeholder at load time.
        if lidar_format not in self._LIDAR_FORMATS:
            raise ValueError(f"Unknown LiDAR format: {lidar_format}")

        self.data_root = data_root
        self.camera_paths = camera_paths
        self.lidar_paths = lidar_paths
        self.labels = labels
        self.fusion_mode = fusion_mode
        self.lidar_format = lidar_format
        self.target_size = target_size
        self.augment = augment

        # Default camera transforms
        if camera_transform is None:
            self.camera_transform = self._get_default_camera_transform()
        else:
            self.camera_transform = camera_transform

        # Default LiDAR transforms
        if lidar_transform is None:
            self.lidar_transform = self._get_default_lidar_transform()
        else:
            self.lidar_transform = lidar_transform

    def _get_default_camera_transform(self):
        """Build the default PIL-image -> normalised-tensor pipeline.

        Order: resize, then (if ``self.augment``) colour jitter, horizontal
        flip and small rotation, then ``ToTensor`` and ImageNet mean/std
        normalisation.
        """
        steps = [T.Resize(self.target_size)]

        if self.augment:
            # NOTE(review): the flip and rotation are applied to the camera
            # image only; the LiDAR sample is not transformed to match.
            # Confirm this is acceptable for the fusion model.
            steps += [
                T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                T.RandomHorizontalFlip(0.5),
                T.RandomRotation(5),
            ]

        steps += [
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        return T.Compose(steps)

    def _get_default_lidar_transform(self):
        """Build the default LiDAR pipeline: optional augmentation, then min-max scaling.

        The steps are methods of this class rather than a lambda / nested
        function so the dataset can be pickled for spawn-based DataLoader
        workers.
        """
        steps = []
        if self.augment:
            steps.append(self._augment_lidar)
        steps.append(self._normalize_lidar)
        return T.Compose(steps)

    @staticmethod
    def _normalize_lidar(tensor):
        """Min-max scale ``tensor`` to ``[0, 1]``; constant input is returned unchanged."""
        lowest, highest = tensor.min(), tensor.max()
        if highest > lowest:
            tensor = (tensor - lowest) / (highest - lowest)
        return tensor

    def _augment_lidar(self, lidar_data):
        """Randomly perturb a LiDAR tensor.

        With probability 0.3, Gaussian noise (std 0.01) is added. For
        ``'pointcloud'`` data, with probability 0.2 roughly 10% of the points
        (rows) are dropped.

        Everything stays in torch so the result keeps the input's dtype, and
        the random draws come from torch's RNG, which DataLoader seeds per
        worker.
        """
        lidar_data = torch.as_tensor(lidar_data)

        if torch.rand(1).item() < 0.3:
            lidar_data = lidar_data + torch.randn_like(lidar_data) * 0.01

        # Random dropout for point clouds
        if self.lidar_format == 'pointcloud' and torch.rand(1).item() < 0.2:
            keep = torch.rand(lidar_data.shape[0], device=lidar_data.device) > 0.1
            lidar_data = lidar_data[keep]

        return lidar_data

    def _load_camera_image(self, path: str) -> torch.Tensor:
        """Load ``path`` (relative to ``data_root``) and apply ``camera_transform``.

        On any failure the error is printed and a uniform grey image of
        ``target_size`` is transformed and returned instead.
        """
        full_path = os.path.join(self.data_root, path)

        try:
            # The context manager releases the file handle deterministically.
            with Image.open(full_path) as raw_image:
                image = raw_image.convert('RGB')
            return self.camera_transform(image)
        except Exception as e:
            print(f"Error loading camera image {full_path}: {e}")
            # Return dummy image; PIL sizes are (width, height).
            height, width = self.target_size
            dummy_image = Image.new('RGB', (width, height), color=(128, 128, 128))
            return self.camera_transform(dummy_image)

    def _load_lidar_data(self, path: str) -> torch.Tensor:
        """Load a LiDAR scan, convert it to ``lidar_format`` and apply ``lidar_transform``.

        Supported files: ``.npy`` (via ``np.load``), ``.bin`` (flat float32
        reshaped to 4 columns) and ASCII ``.pcd``.

        On any failure the error is printed and zeros are returned: shape
        ``(1, *target_size)`` for ``'bev'`` / ``'range'``, ``(1000, 4)`` for
        ``'pointcloud'``.
        """
        full_path = os.path.join(self.data_root, path)

        try:
            if path.endswith('.npy'):
                data = np.load(full_path)
            elif path.endswith('.bin'):
                # KITTI format: [x, y, z, intensity] as float32
                data = np.fromfile(full_path, dtype=np.float32).reshape(-1, 4)
            elif path.endswith('.pcd'):
                # Point Cloud Library format
                data = self._load_pcd(full_path)
            else:
                raise ValueError(f"Unsupported LiDAR format: {path}")

            # Convert based on LiDAR format
            if self.lidar_format == 'bev':
                lidar_tensor = self._convert_to_bev(data)
            elif self.lidar_format == 'range':
                lidar_tensor = self._convert_to_range_image(data)
            elif self.lidar_format == 'pointcloud':
                lidar_tensor = torch.tensor(data, dtype=torch.float32)
            else:
                raise ValueError(f"Unknown LiDAR format: {self.lidar_format}")

            return self.lidar_transform(lidar_tensor)

        except Exception as e:
            print(f"Error loading LiDAR data {full_path}: {e}")
            # Return dummy LiDAR data
            if self.lidar_format in ['bev', 'range']:
                return torch.zeros(1, *self.target_size)
            else:
                return torch.zeros(1000, 4)  # Dummy point cloud

    def _convert_to_bev(self, points: np.ndarray) -> torch.Tensor:
        """Rasterise a point cloud into a single-channel bird's-eye-view image.

        Points inside x, y in [-50, 50] and z in [-3, 5] are binned on a grid
        of 0.1 units per cell (image axis 0 follows x, axis 1 follows y). Each
        cell accumulates intensity (column 3) when present, otherwise height
        (column 2). The grid is resized to ``target_size``.

        Returns:
            Float32 tensor of shape ``(1, height, width)``; all zeros when no
            point falls inside the ranges.

        Raises:
            ValueError: If ``points`` is not 2-D with at least 3 columns.
        """
        if points.ndim != 2 or points.shape[1] < 3:
            raise ValueError("Points must have at least x, y, z coordinates")

        # Parameters for BEV conversion
        x_range = (-50, 50)  # meters
        y_range = (-50, 50)  # meters
        z_range = (-3, 5)    # meters
        resolution = 0.1     # meters per pixel

        # Filter points within range
        mask = ((points[:, 0] >= x_range[0]) & (points[:, 0] <= x_range[1]) &
                (points[:, 1] >= y_range[0]) & (points[:, 1] <= y_range[1]) &
                (points[:, 2] >= z_range[0]) & (points[:, 2] <= z_range[1]))
        points_filtered = points[mask]

        if len(points_filtered) == 0:
            return torch.zeros(1, *self.target_size)

        x_size = int((x_range[1] - x_range[0]) / resolution)
        y_size = int((y_range[1] - y_range[0]) / resolution)

        # Convert to pixel coordinates. The range filter is inclusive, so a
        # point exactly on the upper bound would index one past the grid;
        # clip it into the last cell.
        x_pixels = np.clip(
            ((points_filtered[:, 0] - x_range[0]) / resolution).astype(int),
            0, x_size - 1)
        y_pixels = np.clip(
            ((points_filtered[:, 1] - y_range[0]) / resolution).astype(int),
            0, y_size - 1)

        # Create BEV image
        bev_image = np.zeros((x_size, y_size))

        # Accumulate intensity when available, otherwise height.
        value_column = 3 if points_filtered.shape[1] >= 4 else 2
        np.add.at(bev_image, (x_pixels, y_pixels), points_filtered[:, value_column])

        # Resize to target size. cv2 expects dsize as (width, height).
        height, width = self.target_size
        bev_resized = cv2.resize(bev_image, (width, height), interpolation=cv2.INTER_LINEAR)

        return torch.tensor(bev_resized, dtype=torch.float32).unsqueeze(0)

    def _convert_to_range_image(self, points: np.ndarray) -> torch.Tensor:
        """Convert point cloud to range image representation.

        Points are projected spherically (azimuth in [-180, 180) at 0.35 deg
        per column, elevation in [-25, 15) at 0.4 deg per row; row 0 is the
        lowest elevation). Each pixel stores the range of the nearest point
        that lands in it; empty pixels are 0. The image is resized to
        ``target_size``.

        Returns:
            Float32 tensor of shape ``(1, height, width)``.

        Raises:
            ValueError: If ``points`` is not 2-D with at least 3 columns.
        """
        if points.ndim != 2 or points.shape[1] < 3:
            raise ValueError("Points must have at least x, y, z coordinates")

        # Spherical projection parameters
        h_fov = (-180, 180)  # degrees
        v_fov = (-25, 15)    # degrees
        h_res = 0.35         # degrees per pixel
        v_res = 0.4          # degrees per pixel

        # Calculate spherical coordinates
        x, y, z = points[:, 0], points[:, 1], points[:, 2]
        r = np.sqrt(x**2 + y**2 + z**2)

        # Points at the origin (r == 0) or with non-finite range have no
        # defined direction; drop them before dividing by r.
        usable = np.isfinite(r) & (r > 0)
        x, y, z, r = x[usable], y[usable], z[usable], r[usable]

        # Horizontal angle (azimuth)
        h_angle = np.arctan2(y, x) * 180 / np.pi

        # Vertical angle (elevation); clip guards arcsin against rounding
        # pushing |z / r| marginally above 1.
        v_angle = np.arcsin(np.clip(z / r, -1.0, 1.0)) * 180 / np.pi

        # Convert to pixel coordinates. floor (not truncation toward zero)
        # keeps angles just below the field of view out of row/column 0.
        h_pixels = np.floor((h_angle - h_fov[0]) / h_res).astype(int)
        v_pixels = np.floor((v_angle - v_fov[0]) / v_res).astype(int)

        # Create range image
        h_size = int((h_fov[1] - h_fov[0]) / h_res)
        v_size = int((v_fov[1] - v_fov[0]) / v_res)

        # Filter valid pixels
        valid_mask = ((h_pixels >= 0) & (h_pixels < h_size) &
                      (v_pixels >= 0) & (v_pixels < v_size))

        # Start at +inf and take the minimum so that, when several points hit
        # the same pixel, the nearest one wins deterministically.
        range_image = np.full((v_size, h_size), np.inf)
        np.minimum.at(
            range_image,
            (v_pixels[valid_mask], h_pixels[valid_mask]),
            r[valid_mask])
        range_image[np.isinf(range_image)] = 0.0

        # Resize to target size. cv2 expects dsize as (width, height).
        height, width = self.target_size
        range_resized = cv2.resize(range_image, (width, height), interpolation=cv2.INTER_LINEAR)

        return torch.tensor(range_resized, dtype=torch.float32).unsqueeze(0)

    def _load_pcd(self, path: str) -> np.ndarray:
        """Load an ASCII PCD file (basic implementation).

        Everything after the first line starting with ``DATA`` is parsed as
        whitespace-separated floats, one point per line; blank lines are
        skipped. If no ``DATA`` line exists the whole file is parsed that way.

        Returns:
            Float64 array with one row per point.

        Raises:
            ValueError: If the ``DATA`` line declares a non-ASCII encoding,
                a row is not numeric, rows differ in length, or the file
                holds no points.
        """
        # This is a simplified PCD loader - for production use open3d or similar.
        # errors='replace' lets a binary payload reach the explicit check
        # below instead of failing with an opaque decode error.
        with open(path, 'r', errors='replace') as f:
            lines = f.readlines()

        # Find data start
        data_start = 0
        for i, line in enumerate(lines):
            if line.startswith('DATA'):
                tokens = line.split()
                if len(tokens) > 1 and tokens[1].lower() != 'ascii':
                    raise ValueError(
                        f"Only ASCII PCD files are supported, got DATA {tokens[1]}: {path}")
                data_start = i + 1
                break

        # Parse points
        points = [
            [float(value) for value in line.split()]
            for line in lines[data_start:]
            if line.strip()
        ]
        if not points:
            raise ValueError(f"PCD file contains no points: {path}")

        return np.array(points, dtype=np.float64)

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(camera_tensor, lidar_tensor, label_tensor)`` for sample ``idx``."""
        camera_data = self._load_camera_image(self.camera_paths[idx])
        lidar_data = self._load_lidar_data(self.lidar_paths[idx])
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return camera_data, lidar_data, label


def create_camera_lidar_dataloaders(
    data_config: Dict,
    batch_size: int = 32,
    num_workers: int = 4,
    train_split: float = 0.8,
    val_split: float = 0.1,
    test_split: float = 0.1,
    seed: int = 42
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Split the configured samples and wrap each split in a DataLoader.

    Samples are shuffled with a permutation seeded by ``seed`` and cut into
    train / validation / test subsets. Only the training dataset is built with
    ``augment=True``; only the training loader shuffles and drops the last
    incomplete batch.

    Args:
        data_config: Mapping with required keys ``'data_root'``,
            ``'camera_paths'``, ``'lidar_paths'`` and ``'labels'``, and
            optional keys ``'fusion_mode'``, ``'lidar_format'`` and
            ``'target_size'``.
        batch_size: Batch size for all three loaders.
        num_workers: Worker process count for all three loaders.
        train_split: Fraction of samples used for training.
        val_split: Fraction of samples used for validation.
        test_split: Fraction of samples used for testing. When the three
            fractions sum to 1 (or more) the test set is whatever remains
            after train and validation, which absorbs rounding; when they sum
            to less than 1 the test set is limited to this fraction.
        seed: Seed for the split permutation. NOTE: this seeds torch's
            *global* RNG via ``torch.manual_seed``.

    Returns:
        ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If any split fraction lies outside ``[0, 1]``.
        KeyError: If a required ``data_config`` key is missing.
    """
    # NOTE(review): no collate_fn is supplied; with lidar_format='pointcloud'
    # the number of points can differ per sample, which the default collate
    # cannot stack. Confirm batching expectations for that mode.
    camera_paths = data_config['camera_paths']
    lidar_paths = data_config['lidar_paths']
    labels = data_config['labels']

    split_fractions = {
        'train_split': train_split,
        'val_split': val_split,
        'test_split': test_split,
    }
    for split_name, fraction in split_fractions.items():
        if not 0.0 <= fraction <= 1.0:
            raise ValueError(f"{split_name} must be between 0 and 1, got {fraction}")

    # Create train/val/test splits
    torch.manual_seed(seed)
    total_samples = len(labels)
    # Plain ints index any sequence type and avoid per-item 0-d tensor overhead.
    indices = torch.randperm(total_samples).tolist()

    train_size = int(train_split * total_samples)
    val_size = int(val_split * total_samples)
    remaining = max(total_samples - train_size - val_size, 0)
    if train_split + val_split + test_split < 1.0 - 1e-9:
        # Fractions deliberately leave samples unused: honour test_split.
        test_size = min(int(test_split * total_samples), remaining)
    else:
        test_size = remaining

    test_start = train_size + val_size
    train_indices = indices[:train_size]
    val_indices = indices[train_size:test_start]
    test_indices = indices[test_start:test_start + test_size]

    # Create datasets
    common_config = {
        'data_root': data_config['data_root'],
        'fusion_mode': data_config.get('fusion_mode', 'early'),
        'lidar_format': data_config.get('lidar_format', 'bev'),
        'target_size': data_config.get('target_size', (224, 224))
    }

    def build_dataset(subset_indices, augment):
        # Materialise one split's samples into its own dataset instance.
        return CameraLidarFusionDataset(
            camera_paths=[camera_paths[i] for i in subset_indices],
            lidar_paths=[lidar_paths[i] for i in subset_indices],
            labels=[labels[i] for i in subset_indices],
            augment=augment,
            **common_config
        )

    train_dataset = build_dataset(train_indices, augment=True)
    val_dataset = build_dataset(val_indices, augment=False)
    test_dataset = build_dataset(test_indices, augment=False)

    # Create dataloaders
    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': True,
    }

    train_loader = DataLoader(
        train_dataset,
        shuffle=True,
        drop_last=True,
        **loader_kwargs
    )
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader


# Example usage and data configuration generator
def generate_sample_data_config(data_root: str, num_samples: int = 1000) -> Dict:
    """Build a synthetic data-config dict with ``num_samples`` placeholder entries.

    Camera and LiDAR paths are zero-padded, index-based relative file names;
    labels cycle through 10 classes. No files are created or checked.

    Args:
        data_root: Stored verbatim under the ``'data_root'`` key.
        num_samples: Number of camera / LiDAR / label entries to generate.

    Returns:
        Dict with keys ``'data_root'``, ``'camera_paths'``, ``'lidar_paths'``,
        ``'labels'``, ``'fusion_mode'``, ``'lidar_format'``, ``'target_size'``
        and ``'num_classes'``.
    """
    sample_ids = range(num_samples)

    # Per-sample entries, all derived from the same index sequence.
    config = dict(
        data_root=data_root,
        camera_paths=[f"camera/image_{sid:06d}.jpg" for sid in sample_ids],
        lidar_paths=[f"lidar/scan_{sid:06d}.npy" for sid in sample_ids],
        labels=[sid % 10 for sid in sample_ids],  # 10 classes
    )

    # Fixed dataset settings.
    config.update(
        fusion_mode='early',
        lidar_format='bev',
        target_size=(224, 224),
        num_classes=10,
    )
    return config