# Mixed-Precision Variational Inference Transformer for Camera-LiDAR Fusion

This repository contains code for efficient Vision Transformer (ViT) training using variational inference (VI).

A PyTorch implementation of mixed-precision transformers with variational inference for camera and LiDAR sensor fusion. This project implements the core algorithm for training transformers with both Gaussian and Laplace posterior distributions, dynamic bit assignment, and approximate arithmetic.

## Features

- **Dual Distribution Support**: Switch between Gaussian and Laplace distributions during training for robust weight sampling
- **Mixed-Precision Training**: Dynamic bit assignment (4, 8, 16, 32-bit) based on layer sensitivity and gradient magnitudes
- **Sensor Fusion**: Unified processing of camera (RGB) and LiDAR data with multiple fusion strategies
- **Approximate Arithmetic**: Custom approximate multiplication for efficient low-precision operations
- **Variational Inference**: KL regularization that penalizes unnecessary model complexity (an Occam's razor effect), which supports model compression
- **Dynamic Quantization**: Adaptive bit-width assignment during training and inference

## Table of Contents

- [Installation](#installation)
- [Quick Start](#quick-start)
- [Algorithm Overview](#algorithm-overview)
- [Data Preparation](#data-preparation)
- [Training](#training)
- [Testing](#testing)
- [Configuration](#configuration)
- [Use Cases](#use-cases)
- [Advanced Features](#advanced-features)
- [Troubleshooting](#troubleshooting)

## Installation

### Prerequisites

- Python 3.8+
- CUDA 11.0+ (for GPU support)
- 16GB+ RAM recommended
- 8GB+ GPU memory recommended

### Install Dependencies

```bash
# Clone the repository
git clone https://github.com/dewantkatare/mixed-precision-vi-transformer.git
cd mixed-precision-vi-transformer

# Install requirements
pip install -r requirements.txt

# Optional: Install additional point cloud processing tools
# Quote the requirement so the shell does not treat ">=" as an output redirect
pip install "open3d>=0.17.0"
```

### Development Installation

```bash
# For development with additional tools
pip install -r requirements.txt
pip install -e .
```

## Quick Start

### Basic Training Example

```python
from training_pipeline import main_training_pipeline
from unified_dataloader import generate_sample_data_config
from enhanced_vi_transformer import EnhancedVariationalVisionTransformer

# Quick training with sample data
trainer, history = main_training_pipeline()
```

### Custom Data Training

```python
import torch
from training_pipeline import MixedPrecisionVariationalTrainer
from unified_dataloader import create_camera_lidar_dataloaders
from enhanced_vi_transformer import EnhancedVariationalVisionTransformer

# Prepare your data configuration
data_config = {
    'data_root': '/path/to/your/data',
    'camera_paths': ['camera/img_001.jpg', 'camera/img_002.jpg', ...],
    'lidar_paths': ['lidar/scan_001.npy', 'lidar/scan_002.npy', ...],
    'labels': [0, 1, 2, ...],  # Class labels
    'fusion_mode': 'early',
    'lidar_format': 'bev',  # 'bev', 'range', or 'pointcloud'
    'target_size': (224, 224),
    'num_classes': 10
}

# Create dataloaders
train_loader, val_loader, test_loader = create_camera_lidar_dataloaders(
    data_config, batch_size=16, num_workers=4
)

# Initialize model
model = EnhancedVariationalVisionTransformer(
    img_size=224,
    patch_size=16,
    in_channels=4,  # RGB + LiDAR
    num_classes=10,
    vi_distribution='gaussian'  # or 'laplace'
)

# Train
trainer = MixedPrecisionVariationalTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    device='cuda'
)

history = trainer.train(epochs=100, learning_rate=1e-4)
```

## Algorithm Overview

Our implementation follows this core algorithm:

```
Algorithm: Mixed-Precision Variational Inference Training

Input: Dataset D, Model M, Epochs E, Learning rate η,
       Bit-width range B, Variational parameters V, Balancing factors λ, λ_sparse

1. Initialize M with variational parameters V
2. For epoch e = 1 to E:
   a. For each batch in D:
      - Split batch into (Camera, LiDAR) data
      - Process through M with current precision settings
      - Apply quantization using dynamic bit assignment
      - Compute training loss L_train
      - Compute variational loss L_var (Gaussian or Laplace)
      - Compute sparsity regularization loss L_sparse
      - Total loss = L_train + λ * L_var + λ_sparse * L_sparse
      - Apply straight-through estimator for backpropagation
      - Update weights and variational parameters
      - Update bit assignments based on gradients
   b. If e > validation_start: Validate on D_val
3. Return optimized model M*
```
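
The loss composition in the loop above can be sketched in plain Python. This is a minimal illustration rather than the repository's implementation: it assumes a factorized Gaussian posterior with a standard normal prior for the KL term and an L1 penalty for the sparsity term, neither of which is confirmed by the source. All function names are hypothetical.

```python
import math

def gaussian_kl_to_standard_normal(mu, sigma):
    """KL( N(mu, sigma^2) || N(0, 1) ), summed over weights (closed form)."""
    return sum(
        0.5 * (s * s + m * m - 1.0) - math.log(s)
        for m, s in zip(mu, sigma)
    )

def l1_sparsity(weights):
    """Assumed sparsity term: L1 norm of the weight means."""
    return sum(abs(w) for w in weights)

def total_loss(l_train, l_var, l_sparse, lambda_var=0.01, lambda_sparse=0.001):
    """Total loss = L_train + lambda * L_var + lambda_sparse * L_sparse."""
    return l_train + lambda_var * l_var + lambda_sparse * l_sparse

# Toy posterior parameters for three weights
mu = [0.0, 0.5, -1.0]
sigma = [1.0, 0.5, 2.0]
l_var = gaussian_kl_to_standard_normal(mu, sigma)
loss = total_loss(l_train=1.2, l_var=l_var, l_sparse=l1_sparsity(mu))
```

The KL term is zero only when the posterior equals the prior, so a larger `lambda_var` pulls weights toward the prior and trades accuracy for a simpler model.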

### Key Components

1. **Variational Weight Sampling**:
   - **Gaussian**: `w = μ + σ * ε` where `ε ~ N(0,1)`
   - **Laplace**: `w = μ - b * sign(u-0.5) * log(1-2|u-0.5|)` where `u ~ U(0,1)`

2. **Dynamic Bit Assignment**:
   - Monitor gradient magnitudes
   - Assign higher precision to sensitive layers
   - Adaptive threshold-based assignment

3. **Approximate Arithmetic**:
   - Custom 4x4, 8-bit, and 16-bit approximate multipliers
   - Reduced computational complexity
   - Configurable precision per layer
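
The two sampling rules listed under Variational Weight Sampling can be sketched for a single scalar weight using only the standard library. This is an illustrative sketch, not the repository's tensor implementation; the function names and the clamp that guards against `log(0)` are assumptions.

```python
import math
import random

def sample_gaussian(mu, sigma, rng=random):
    """Reparameterized Gaussian sample: w = mu + sigma * eps, eps ~ N(0, 1)."""
    eps = rng.gauss(0.0, 1.0)
    return mu + sigma * eps

def sample_laplace(mu, b, rng=random):
    """Inverse-CDF sample: w = mu - b * sign(u - 0.5) * log(1 - 2|u - 0.5|)."""
    u = rng.random()  # u in [0, 1)
    d = u - 0.5
    # Clamp to avoid log(0) when u is exactly 0.0
    inner = max(1.0 - 2.0 * abs(d), 1e-12)
    return mu - b * math.copysign(1.0, d) * math.log(inner)

rng = random.Random(0)
gauss_samples = [sample_gaussian(1.0, 0.5, rng) for _ in range(20000)]
laplace_samples = [sample_laplace(1.0, 0.5, rng) for _ in range(20000)]
```

Both samplers express the weight as a deterministic function of the parameters and an independent noise source, which is what lets gradients flow to `mu`, `sigma`, and `b` in an autograd framework. For the same scale, the Laplace sampler produces heavier tails than the Gaussian one.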

## Data Preparation

### Supported Data Formats

#### Camera Data
- **Formats**: JPG, PNG, BMP
- **Preprocessing**: Resize, normalize, augmentation
- **Resolution**: Configurable (default: 224x224)

#### LiDAR Data
- **Point Clouds**: `.bin` (KITTI), `.pcd`, `.npy`
- **Representations**:
  - **BEV (Bird's Eye View)**: Top-down 2D projection
  - **Range Images**: Spherical projection
  - **Raw Point Clouds**: 3D coordinates + intensity
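
A BEV projection of the kind listed above can be sketched as follows. This is an illustrative rasterization in plain Python, not the code in `unified_dataloader`: the function name `points_to_bev`, the default region of interest, and the choice to store the maximum intensity per cell are all assumptions.

```python
def points_to_bev(points, x_range=(0.0, 40.0), y_range=(-20.0, 20.0), size=(224, 224)):
    """Project (x, y, z, intensity) points onto a top-down max-intensity grid.

    Returns a size[0] x size[1] list of lists; empty cells hold 0.0.
    """
    rows, cols = size
    x_min, x_max = x_range
    y_min, y_max = y_range
    bev = [[0.0] * cols for _ in range(rows)]
    for x, y, _z, intensity in points:
        if not (x_min <= x < x_max and y_min <= y < y_max):
            continue  # Drop points outside the region of interest
        # Scale metric coordinates to grid indices, guarding the upper edge
        row = min(int((x - x_min) / (x_max - x_min) * rows), rows - 1)
        col = min(int((y - y_min) / (y_max - y_min) * cols), cols - 1)
        bev[row][col] = max(bev[row][col], intensity)
    return bev

# Two points fall in the same cell; the third lies outside the x range
points = [(10.0, 0.0, -1.0, 0.3), (10.0, 0.0, 0.5, 0.8), (100.0, 0.0, 0.0, 1.0)]
bev = points_to_bev(points, size=(4, 4))
```

Height or point-density channels can be built the same way by changing what is written into each cell.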

### Data Directory Structure

```
data/
├── camera/
│   ├── train/
│   │   ├── class_0/
│   │   │   ├── img_001.jpg
│   │   │   └── ...
│   │   └── class_1/
│   └── val/
├── lidar/
│   ├── train/
│   │   ├── class_0/
│   │   │   ├── scan_001.npy
│   │   │   └── ...
│   │   └── class_1/
│   └── val/
└── labels.json
```
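
Given the layout above, the `camera_paths`, `lidar_paths`, and `labels` lists used elsewhere in this README could be assembled with a small helper. This is a sketch under stated assumptions: the helper name `collect_split` is hypothetical, images and scans are paired by sorted filename order within each class folder, and integer labels follow the sorted class folder names. The format of `labels.json` is not described here, so the sketch does not read it.

```python
from pathlib import Path

def collect_split(data_root, split="train"):
    """Pair camera images with LiDAR scans by class folder and sorted order."""
    root = Path(data_root)
    camera_paths, lidar_paths, labels = [], [], []
    class_dirs = sorted(p for p in (root / "camera" / split).iterdir() if p.is_dir())
    for label, class_dir in enumerate(class_dirs):
        images = sorted(class_dir.glob("*.jpg"))
        scans = sorted((root / "lidar" / split / class_dir.name).glob("*.npy"))
        if len(images) != len(scans):
            raise ValueError(f"Unpaired samples in {class_dir.name}")
        for image, scan in zip(images, scans):
            # Store paths relative to data_root
            camera_paths.append(str(image.relative_to(root)))
            lidar_paths.append(str(scan.relative_to(root)))
            labels.append(label)
    return camera_paths, lidar_paths, labels
```

Calling `collect_split('data', 'train')` returns three aligned lists that can be placed into a data configuration dictionary.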

### Data Conversion Scripts

```python
# Convert KITTI point clouds to BEV
from unified_dataloader import CameraLidarFusionDataset

dataset = CameraLidarFusionDataset(
    data_root='data',
    camera_paths=camera_paths,
    lidar_paths=lidar_paths,
    labels=labels,
    lidar_format='bev',  # Converts point clouds to BEV
    target_size=(224, 224)
)
```

## Training

### Configuration Options

```python
config = {
    # Model parameters
    'img_size': 224,
    'patch_size': 16,
    'num_classes': 10,
    'vi_distribution': 'gaussian',  # or 'laplace'
    
    # Training parameters
    'epochs': 100,
    'batch_size': 16,
    'learning_rate': 1e-4,
    'lambda_var': 0.01,        # VI regularization weight
    'lambda_sparse': 0.001,    # Sparsity regularization weight
    'switch_every': 10,        # Switch distribution every N epochs
    
    # Mixed precision
    'precision_config': {
        'patch_embed': 16,
        'blocks.0': 32,     # First layer high precision
        'blocks.1': 16,
        'blocks.2': 8,
        # ... more layers
        'head': 32          # Classification head high precision
    }
}
```
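
A `precision_config` like the one above could also be derived from gradient magnitudes, in the spirit of the dynamic bit assignment described in the algorithm overview. The sketch below is one possible threshold scheme, not the repository's method: the threshold values, the function names, and the symmetric uniform quantizer are assumptions chosen for illustration.

```python
def assign_bits(grad_norms, thresholds=(0.01, 0.1, 1.0), bit_widths=(4, 8, 16, 32)):
    """Map each layer's gradient magnitude to a bit-width via fixed thresholds."""
    config = {}
    for name, norm in grad_norms.items():
        # Count how many thresholds the gradient magnitude reaches
        level = sum(norm >= t for t in thresholds)
        config[name] = bit_widths[level]
    return config

def fake_quantize(values, bits):
    """Symmetric uniform quantization of a list of floats to `bits` bits."""
    max_abs = max(abs(v) for v in values)
    if max_abs == 0.0 or bits >= 32:
        return list(values)  # Nothing to quantize, or treated as full precision
    levels = 2 ** (bits - 1) - 1
    scale = max_abs / levels
    return [round(v / scale) * scale for v in values]

# Hypothetical per-layer gradient norms
grad_norms = {"patch_embed": 0.5, "blocks.0": 2.0, "blocks.2": 0.005, "head": 1.5}
precision_config = assign_bits(grad_norms)
# precision_config == {'patch_embed': 16, 'blocks.0': 32, 'blocks.2': 4, 'head': 32}
quantized = fake_quantize([0.1, -0.5, 1.0], bits=4)
```

Layers with large gradients are treated as sensitive and keep more bits, while layers with small gradients are quantized more aggressively.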

### Training Commands

```bash
# Basic training
python training_pipeline.py

# Training with custom config
python training_pipeline.py --config configs/custom_config.yaml

# Resume from checkpoint
python training_pipeline.py --resume checkpoints/checkpoint_epoch_50.pth

# Multi-GPU training
python -m torch.distributed.launch --nproc_per_node=4 training_pipeline.py
```

### Monitoring Training

- **TensorBoard**: `tensorboard --logdir logs/`
- **Weights & Biases**: Automatic logging of metrics
- **Console Output**: Real-time training progress

## Testing

### Model Evaluation

```python
import torch

# Load trained model (`model` and `trainer` are the objects built in Custom Data Training)
checkpoint = torch.load('checkpoints/best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])

# Test on test set
test_metrics = trainer.test()
print(f"Test Accuracy: {test_metrics['test_accuracy']:.2f}%")
```

### Inference Example

```python
# Single sample inference
model.eval()
with torch.no_grad():
    camera_input = torch.randn(1, 3, 224, 224)
    lidar_input = torch.randn(1, 1, 224, 224)
    
    output = model(camera_input, lidar_input)
    predicted_class = output.argmax(dim=1)
    confidence = torch.softmax(output, dim=1).max()
```

### Quantization Analysis

```python
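# Analyze bit assignment efficiency
import importlib

# "Vit-test" contains a hyphen, so it must be loaded with importlib
vit_test = importlib.import_module("Vit-test")
sensitivity_scores = vit_test.analyze_model_sensitivity(model, test_dataset)
# Assumes auto_mixed_precision_config is defined in the same module
optimal_config = vit_test.auto_mixed_precision_config(model, test_dataset)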
```

## Configuration

### Model Configuration

```yaml
# config.yaml
model:
  img_size: 224
  patch_size: 16
  in_channels: 4
  num_classes: 10
  vi_distribution: "gaussian"
  laplace_b: 1.0

training:
  epochs: 100
  batch_size: 16
  learning_rate: 1.0e-4  # decimal point needed so YAML 1.1 parsers such as PyYAML read a float, not a string
  lambda_var: 0.01
  lambda_sparse: 0.001
  switch_every: 10

data:
  data_root: "data"
  fusion_mode: "early"
  lidar_format: "bev"
  target_size: [224, 224]
  augment: true

precision:
  default: 8
  rules:
    "patch_embed": 16
    "head": 32
```
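
The `precision` section pairs a default bit-width with per-layer rules. One way such rules could be resolved against layer names is sketched below. The matching policy (longest dotted-prefix match wins) and the layer names are assumptions for illustration; the repository may match rules differently.

```python
def resolve_precision(layer_name, rules, default=8):
    """Return the bit-width for a layer: the longest matching prefix rule wins."""
    best_key = None
    for key in rules:
        matches = layer_name == key or layer_name.startswith(key + ".")
        if matches and (best_key is None or len(key) > len(best_key)):
            best_key = key
    return rules[best_key] if best_key is not None else default

rules = {"patch_embed": 16, "head": 32}
# Hypothetical layer names
layers = ["patch_embed.proj", "blocks.0.attn.qkv", "blocks.11.mlp.fc1", "head"]
resolved = {name: resolve_precision(name, rules, default=8) for name in layers}
# resolved == {'patch_embed.proj': 16, 'blocks.0.attn.qkv': 8,
#              'blocks.11.mlp.fc1': 8, 'head': 32}
```

Matching on a dotted prefix keeps a rule such as `"head"` from accidentally applying to an unrelated layer whose name merely starts with the same letters.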

### Hardware-Specific Optimization

```python
import torch

# Pick a default bit-width from the GPU's compute capability (requires a CUDA device)
major, _minor = torch.cuda.get_device_capability()
if major >= 8:  # A100, H100
    config['precision_config']['default'] = 4  # More aggressive quantization
elif major >= 7:  # V100, T4
    config['precision_config']['default'] = 8
else:  # Older GPUs
    config['precision_config']['default'] = 16
```


### Visualization

```python
# Plot training curves
trainer.plot_training_curves()

# Visualize bit assignments
plot_precision_distribution(model.current_precisions)

# Show sensor fusion results
visualize_fusion_results(model, test_loader)
```

## Use Cases

### Autonomous Driving
- Object detection and classification
- Scene understanding with camera + LiDAR
- Real-time inference with mixed precision

### Robotics
- SLAM with visual-LiDAR fusion
- Object manipulation planning
- Navigation in complex environments

### Surveillance
- Multi-modal scene analysis
- Intrusion detection
- Activity recognition

## Advanced Features

### Custom Approximate Multipliers

```python
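import importlib

# "AxC-Multipliers" contains a hyphen, so a plain `from ... import` is a syntax error
MixedPrecisionMultiplier = importlib.import_module("AxC-Multipliers").MixedPrecisionMultiplier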

# Create custom precision rules
precision_rules = {
    "attention": 8,
    "mlp": 4,
    "embedding": 16
}

multiplier = MixedPrecisionMultiplier(
    default_precision=8,
    precision_rules=precision_rules
)
```
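
To give a feel for what an approximate multiplier trades away, the sketch below uses operand truncation: the lowest bits of each operand are zeroed before an exact multiply. This is a generic textbook-style scheme chosen for illustration and is not necessarily the design used in `AxC-Multipliers`; the function names and the `drop_bits` parameter are hypothetical.

```python
def approx_multiply(a, b, bits=8, drop_bits=2):
    """Approximate unsigned multiply: zero the `drop_bits` low bits of each operand."""
    limit = 1 << bits
    if not (0 <= a < limit and 0 <= b < limit):
        raise ValueError(f"operands must fit in {bits} unsigned bits")
    mask = ~((1 << drop_bits) - 1)
    return (a & mask) * (b & mask)

def mean_relative_error(bits=8, drop_bits=2):
    """Exhaustively compare against exact products over all nonzero operand pairs."""
    total, count = 0.0, 0
    for a in range(1, 1 << bits):
        for b in range(1, 1 << bits):
            exact = a * b
            total += abs(exact - approx_multiply(a, b, bits, drop_bits)) / exact
            count += 1
    return total / count

exact_case = approx_multiply(200, 100)  # Both operands have zero low bits
lossy_case = approx_multiply(7, 7)      # 7 is truncated to 4 on each side
```

Truncation never overestimates the product, and raising `drop_bits` shrinks the effective multiplier width at the cost of a larger average error.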

### Dynamic Distribution Switching

```python
# Implement custom switching strategy
def custom_distribution_schedule(epoch, total_epochs):
    if epoch < total_epochs * 0.3:
        return 'gaussian'
    elif epoch < total_epochs * 0.7:
        return 'laplace'
    else:
        return 'gaussian'  # Return to Gaussian for final convergence
```
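
The `switch_every` option in the training configuration suggests a simpler periodic schedule. A minimal sketch of that idea follows, assuming 0-indexed epochs and that training starts with the Gaussian distribution; the function name is hypothetical and the trainer's actual switching logic may differ.

```python
def periodic_distribution_schedule(epoch, switch_every=10):
    """Alternate 'gaussian' and 'laplace' every `switch_every` epochs (0-indexed)."""
    if switch_every <= 0:
        raise ValueError("switch_every must be positive")
    return "gaussian" if (epoch // switch_every) % 2 == 0 else "laplace"

schedule = [periodic_distribution_schedule(e, switch_every=10) for e in range(40)]
# Epochs 0-9 use 'gaussian', 10-19 'laplace', 20-29 'gaussian', 30-39 'laplace'
```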

## Troubleshooting

### Common Issues

1. **CUDA Out of Memory**
   ```python
   # Reduce batch size or enable gradient checkpointing
   config['batch_size'] = 8
   torch.cuda.empty_cache()
   ```

2. **Slow Training**
   ```python
   # Enable mixed precision training
   from torch.cuda.amp import autocast, GradScaler
   scaler = GradScaler()
   ```

3. **NaN Loss Values**
   ```python
   # Reduce learning rate or VI regularization
   config['learning_rate'] = 1e-5
   config['lambda_var'] = 0.001
   ```

### Performance Optimization

- Use `DataLoader` with `pin_memory=True` and `num_workers > 0`
- Enable `torch.backends.cudnn.benchmark = True`
- Use gradient accumulation for large effective batch sizes
- Consider using `torch.compile()` for PyTorch 2.0+


