import argparse
import itertools

import torch
from timm import create_model
from torch.utils.data import DataLoader

from datasets.kitti_dataset import KITTIDataset
from datasets.nuscenes_dataset import NuScenesDataset
from datasets.waymo_dataset import WaymoDataset

from occam_training import train_vipt, train_viet

def get_dataset(name, path):
    """Instantiate the dataset class registered under ``name``.

    Args:
        name: Dataset identifier; one of ``'kitti'``, ``'nuscenes'`` or ``'waymo'``.
        path: Passed unchanged to the dataset constructor (presumably the
            dataset root directory -- confirm against the dataset classes).

    Returns:
        A ``KITTIDataset``, ``NuScenesDataset`` or ``WaymoDataset`` instance.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    if name == 'kitti':
        return KITTIDataset(path)
    if name == 'nuscenes':
        return NuScenesDataset(path)
    if name == 'waymo':
        return WaymoDataset(path)
    # Report the rejected value and the valid options so typos are easy to spot.
    raise ValueError(
        f"Unsupported dataset {name!r}; expected one of 'kitti', 'nuscenes', 'waymo'."
    )

def _parse_args(argv=None):
    """Build the CLI parser and parse ``argv`` (``sys.argv[1:]`` when None)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='kitti',
                        choices=['kitti', 'nuscenes', 'waymo'])
    parser.add_argument('--data_path', type=str, required=True)
    parser.add_argument('--mode', type=str, default='viet', choices=['viet', 'vipt'],
                        help="training routine: 'viet' -> train_viet, 'vipt' -> train_vipt")
    parser.add_argument('--max_samples', type=int, default=1000,
                        help='number of samples to load; 0 or negative loads the whole dataset')
    parser.add_argument('--num_classes', type=int, default=10,
                        help='number of output classes of the model head')
    parser.add_argument('--device', type=str, default=None,
                        help="torch device for the model (default: 'cuda' if available, else 'cpu')")
    return parser.parse_args(argv)


def _collect_samples(dataset, limit=None):
    """Read up to ``limit`` samples from ``dataset`` and stack them into tensors.

    Samples go through a ``DataLoader`` with ``batch_size=1`` so its default
    collate function performs tensor conversion; the singleton batch dimension
    is then dropped with ``[0]``.

    Args:
        dataset: Any map-style dataset yielding ``(image, label)`` pairs.
        limit: Maximum number of samples to read; ``None`` reads all of them.

    Returns:
        ``(images, labels)``, each stacked along a new leading dimension.

    Raises:
        ValueError: If no samples were read (``torch.stack`` would otherwise
            fail with an opaque error on an empty list).
    """
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    images, labels = [], []
    # islice stops after exactly `limit` batches without fetching an extra one.
    for x, y in itertools.islice(loader, limit):
        images.append(x[0])
        labels.append(y[0])
    if not images:
        raise ValueError("Dataset yielded no samples.")
    return torch.stack(images), torch.stack(labels)


def _select_device(requested=None):
    """Return ``requested`` if given, else ``'cuda'`` when available, otherwise ``'cpu'``."""
    if requested:
        return requested
    return 'cuda' if torch.cuda.is_available() else 'cpu'


def main(argv=None):
    """Parse CLI arguments, load a sample subset, build the model and train it.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``), so the
            entry point can also be driven programmatically.
    """
    args = _parse_args(argv)

    dataset = get_dataset(args.dataset, args.data_path)
    # Non-positive --max_samples means "no cap"; islice(..., None) reads everything.
    limit = args.max_samples if args.max_samples > 0 else None
    images, labels = _collect_samples(dataset, limit)

    # Fall back to CPU instead of crashing when CUDA is unavailable.
    device = _select_device(args.device)
    # NOTE(review): images/labels stay on the CPU, as before; presumably the
    # train_* helpers move batches to the model's device -- confirm.
    model = create_model('vit_small_patch16_224', pretrained=False,
                         num_classes=args.num_classes).to(device)

    trainer = train_viet if args.mode == 'viet' else train_vipt
    trainer(model, images, labels)

# Run the CLI only when executed as a script, not when this module is imported.
if __name__ == '__main__':
    main()
