from torch.utils.data import Dataset
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as T

class CameraLidarFusionDataset(Dataset):
    """Map-style dataset pairing a camera image, a lidar array and a label per sample.

    Sample ``i`` is built from ``cam_img_paths[i]``, ``lidar_paths[i]`` and
    ``labels[i]``, so the three sequences must have the same length.

    Args:
        cam_img_paths: Indexable sequence of image paths opened with ``PIL.Image.open``.
        lidar_paths: Indexable sequence of paths loaded with ``np.load``.
        labels: Indexable sequence of per-sample labels; each is converted to an
            int64 tensor in ``__getitem__``.
        img_transform: Callable applied to the RGB ``PIL.Image``. When ``None``,
            defaults to resizing to 224x224 followed by ``ToTensor``.
        lidar_transform: Optional callable applied to the float32 lidar tensor.

    Raises:
        ValueError: If the three input sequences differ in length.
    """

    def __init__(self, cam_img_paths, lidar_paths, labels, img_transform=None, lidar_transform=None):
        n_labels = len(labels)
        if len(cam_img_paths) != n_labels or len(lidar_paths) != n_labels:
            # Fail fast: since __len__ only looks at labels, a mismatch would
            # otherwise silently drop samples or raise IndexError later in
            # __getitem__.
            raise ValueError(
                "cam_img_paths, lidar_paths and labels must have the same length; "
                f"got {len(cam_img_paths)}, {len(lidar_paths)} and {n_labels}"
            )
        self.cam_img_paths = cam_img_paths
        self.lidar_paths = lidar_paths
        self.labels = labels
        # Compare against None rather than relying on truthiness, so a
        # caller-supplied transform that happens to be falsy is still honored.
        if img_transform is None:
            img_transform = T.Compose([
                T.Resize((224, 224)),
                T.ToTensor(),
            ])
        self.img_transform = img_transform
        self.lidar_transform = lidar_transform  # Optional additional lidar transforms

    def __len__(self) -> int:
        """Return the number of samples (the length of ``labels``)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(img, lidar_tensor, label)``.

        ``img`` is the output of ``img_transform`` applied to the RGB image.
        ``lidar_tensor`` is the loaded array cast to float32, with a leading
        dimension of size 1 added when the array is 2-D, then passed through
        ``lidar_transform`` if one was given. ``label`` is an int64 tensor.
        """
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so it remains usable after the file is closed.
        with Image.open(self.cam_img_paths[idx]) as pil_img:
            img = pil_img.convert("RGB")
        img = self.img_transform(img)

        lidar = np.load(self.lidar_paths[idx])  # expects [H, W] or [N, 3]
        lidar_tensor = torch.tensor(lidar, dtype=torch.float32)
        if lidar_tensor.ndim == 2:
            # [H, W] -> [1, H, W].
            # NOTE(review): a 2-D [N, 3] point cloud also takes this branch and
            # becomes [1, N, 3] -- confirm downstream code expects that.
            lidar_tensor = lidar_tensor.unsqueeze(0)
        if self.lidar_transform is not None:
            lidar_tensor = self.lidar_transform(lidar_tensor)

        label = torch.tensor(self.labels[idx]).long()
        return img, lidar_tensor, label
