import io
import os

import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset

# Optional dataset backends. Each import is guarded on its own so that a
# missing TensorFlow does not also disable nuScenes support (and vice versa);
# a name is set to None only when its own package is unavailable.
try:
    from nuscenes.nuscenes import NuScenes
except ImportError:
    NuScenes = None

try:
    import tensorflow as tf
except ImportError:
    tf = None


class ObjectData(Dataset):
    """Image dataset over KITTI, nuScenes, or Waymo directory layouts.

    Each item is ``(image, label)``: ``image`` is ``transform`` applied to an
    RGB PIL image and ``label`` is a scalar tensor. Labels are currently a
    placeholder (always 0); no annotations are parsed.

    Args:
        dataset_name: ``"kitti"``, ``"nuscenes"`` or ``"waymo"``
            (case-insensitive).
        root_dir: Dataset root directory.
        transform: Callable applied to each PIL image. When ``None``, images
            are resized to 224x224 and converted with ``ToTensor``.
        version: nuScenes version passed to ``NuScenes``; unused otherwise.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
        ImportError: If the optional package required by the chosen dataset
            (nuscenes-devkit or TensorFlow) is not installed.
    """

    # Waymo items are addressed as (file, offset) with this many slots per
    # .tfrecord file. It is an assumed count, not read from the files, so
    # indices beyond a file's real record count raise IndexError.
    WAYMO_EXAMPLES_PER_FILE = 100

    def __init__(self, dataset_name, root_dir, transform=None, version="v1.0-mini"):
        self.dataset_name = dataset_name.lower()
        self.root_dir = root_dir
        # `is not None` instead of `or`: a valid transform that happens to be
        # falsy (e.g. an empty nn.Sequential, which defines __len__) is kept.
        self.transform = (
            transform
            if transform is not None
            else T.Compose([T.Resize((224, 224)), T.ToTensor()])
        )
        self.version = version
        self.samples = []

        if self.dataset_name == 'kitti':
            self._load_kitti()
        elif self.dataset_name == 'nuscenes':
            # Explicit raise rather than assert: asserts vanish under `python -O`.
            if NuScenes is None:
                raise ImportError("nuscenes-devkit not installed")
            self._load_nuscenes()
        elif self.dataset_name == 'waymo':
            if tf is None:
                raise ImportError("TensorFlow not installed")
            self._load_waymo()
        else:
            raise ValueError(f"Unsupported dataset: {self.dataset_name}")

    def _load_kitti(self):
        """Index every entry of ``root_dir/image_2`` in sorted filename order."""
        self.image_dir = os.path.join(self.root_dir, "image_2")
        # Not read anywhere in this class (labels are placeholders).
        self.label_dir = os.path.join(self.root_dir, "label_2")
        self.samples = sorted(os.listdir(self.image_dir))

    def _load_nuscenes(self):
        """Open the nuScenes database and use its sample records as items."""
        self.nusc = NuScenes(version=self.version, dataroot=self.root_dir, verbose=False)
        self.samples = self.nusc.sample

    def _load_waymo(self):
        """Collect the ``.tfrecord`` files directly under ``root_dir``.

        The listing is sorted because ``os.listdir`` order is arbitrary; without
        sorting, the same index could map to different files between runs.
        """
        self.tfrecord_files = [
            os.path.join(self.root_dir, f)
            for f in sorted(os.listdir(self.root_dir))
            if f.endswith('.tfrecord')
        ]

    def __len__(self):
        """Return the item count (an estimate for Waymo)."""
        if self.dataset_name == 'waymo':
            # Estimate: assumes WAYMO_EXAMPLES_PER_FILE records per file.
            return len(self.tfrecord_files) * self.WAYMO_EXAMPLES_PER_FILE
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx`` from the configured dataset."""
        if self.dataset_name == 'kitti':
            return self._get_kitti(idx)
        if self.dataset_name == 'nuscenes':
            return self._get_nuscenes(idx)
        if self.dataset_name == 'waymo':
            return self._get_waymo(idx)
        # Only reachable if dataset_name was mutated after construction.
        raise ValueError(f"Unsupported dataset: {self.dataset_name}")

    @staticmethod
    def _open_rgb(source):
        """Open ``source`` (a path or binary file object) as an RGB PIL image.

        The ``with`` block closes the underlying file. ``convert`` loads the
        pixel data and returns a new image, so the result remains usable.
        """
        with Image.open(source) as img:
            return img.convert("RGB")

    def _get_kitti(self, idx):
        """Load the ``idx``-th KITTI image from ``image_dir``."""
        img_path = os.path.join(self.image_dir, self.samples[idx])
        image = self.transform(self._open_rgb(img_path))
        label = 0  # Placeholder
        return image, torch.tensor(label)

    def _get_nuscenes(self, idx):
        """Load the CAM_FRONT image of the ``idx``-th nuScenes sample."""
        sample = self.samples[idx]
        cam_data = self.nusc.get('sample_data', sample['data']['CAM_FRONT'])
        img_path = os.path.join(self.nusc.dataroot, cam_data['filename'])
        image = self.transform(self._open_rgb(img_path))
        label = 0  # Placeholder
        return image, torch.tensor(label)

    def _get_waymo(self, idx):
        """Load the image stored in record ``idx % N`` of file ``idx // N``.

        ``N`` is ``WAYMO_EXAMPLES_PER_FILE``. Assumes each record is a
        ``tf.train.Example`` with an encoded image under ``'image/encoded'``
        (TODO confirm against the actual Waymo files).

        Raises:
            IndexError: If the file has no record at the requested offset.
        """
        file_index, example_index = divmod(idx, self.WAYMO_EXAMPLES_PER_FILE)
        dataset = tf.data.TFRecordDataset(self.tfrecord_files[file_index])
        # skip/take advance inside TF's reader instead of a Python-level loop.
        for data in dataset.skip(example_index).take(1):
            example = tf.train.Example()
            example.ParseFromString(data.numpy())
            img_data = example.features.feature['image/encoded'].bytes_list.value[0]
            # Image.open needs a path or file object; the original passed it a
            # decoded numpy array, which fails. Let PIL decode the raw bytes.
            image = self.transform(self._open_rgb(io.BytesIO(img_data)))
            label = 0  # Placeholder
            return image, torch.tensor(label)
        raise IndexError("Waymo index out of range")
