import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.func import functional_call
from torch.quantization import quantize_dynamic
from torch.utils.data import DataLoader, Dataset
from torchvision.models.vision_transformer import VisionTransformer

class VariationalVisionTransformer(nn.Module):
    """ViT-B/16-style classifier whose weight tensors are diagonal Gaussians.

    Every parameter of the wrapped torchvision ``VisionTransformer`` whose
    name contains ``"weight"`` is given a variational mean (initialised from
    the backbone's value) and a log-variance (initialised to
    ``init_log_var``).

    * Training mode: each forward pass draws a fresh weight sample with the
      reparameterization trick and runs the backbone on it through
      ``functional_call``. The task loss therefore back-propagates into both
      the means and the log-variances.
    * Eval mode: the means are copied into the backbone's own weight tensors
      and the backbone runs deterministically. The copy also happens when
      ``eval()`` / ``train(False)`` is called. As a result, quantizing the
      module after ``eval()`` packs the posterior means.

    Args:
        img_size: input image side length passed to ``VisionTransformer``.
        patch_size: patch side length passed to ``VisionTransformer``.
        in_channels: number of input channels. It is not forwarded to the
            torchvision constructor, which takes no channel argument, so only
            3 is accepted.
        num_classes: size of the classification head.
        init_log_var: initial value of every log-variance entry. The default
            -10 gives sigma = exp(-5), so the posterior starts almost
            deterministic.

    Raises:
        ValueError: if ``in_channels`` is not 3.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
                 init_log_var=-10.0):
        super().__init__()
        if in_channels != 3:
            # Fail at construction instead of with a shape error at the first
            # forward pass.
            raise ValueError(
                f"VisionTransformer only supports in_channels=3, got {in_channels}"
            )

        self.model = VisionTransformer(
            image_size=img_size,
            patch_size=patch_size,
            num_classes=num_classes,
            num_layers=12,
            num_heads=12,
            hidden_dim=768,
            mlp_dim=3072,
        )
        self._init_variational_params(init_log_var)

    def _init_variational_params(self, init_log_var):
        """Register a mean / log-variance pair for each weight of ``self.model``.

        ``register_parameter`` rejects names containing ``"."``, so the dotted
        backbone name (e.g. ``"heads.head.weight"``) is flattened by replacing
        dots with underscores. The pair is registered as ``<flat>_mean`` and
        ``<flat>_log_var``. ``self._variational_keys`` maps each dotted
        backbone name to its flattened prefix.

        Raises:
            ValueError: if two backbone names flatten to the same prefix.
        """
        self._variational_keys = {}
        for name, param in self.model.named_parameters():
            if 'weight' not in name:
                continue
            key = name.replace('.', '_')
            if f"{key}_mean" in self._parameters:
                # register_parameter would silently overwrite the earlier pair.
                raise ValueError(f"flattened name {key!r} for {name!r} is not unique")
            self.register_parameter(f"{key}_mean", nn.Parameter(param.detach().clone()))
            self.register_parameter(
                f"{key}_log_var", nn.Parameter(torch.full_like(param, init_log_var))
            )
            # The backbone tensor is now only the slot the means are copied into
            # for eval/quantization; the variational pair is what gets trained.
            param.requires_grad_(False)
            self._variational_keys[name] = key

    @property
    def weight_means(self):
        """Dict mapping each dotted backbone weight name to its mean parameter."""
        return {name: getattr(self, f"{key}_mean")
                for name, key in self._variational_keys.items()}

    @property
    def weight_log_vars(self):
        """Dict mapping each dotted backbone weight name to its log-variance parameter."""
        return {name: getattr(self, f"{key}_log_var")
                for name, key in self._variational_keys.items()}

    def forward(self, x):
        """Run the backbone on sampled weights (training) or on the means (eval)."""
        if self.training:
            return functional_call(self.model, self._sample_weights(), (x,))
        self._load_mean_weights()
        return self.model(x)

    def _sample_weights(self):
        """Draw one weight sample per variational parameter.

        Uses w = mu + sigma * eps with eps ~ N(0, 1) and
        sigma = exp(0.5 * log_var). The returned tensors stay in the autograd
        graph of the mean and log-variance; nothing in ``self.model`` is
        modified.

        Returns:
            dict mapping dotted backbone parameter names to sampled tensors.
        """
        samples = {}
        for name, key in self._variational_keys.items():
            mean = getattr(self, f"{key}_mean")
            log_var = getattr(self, f"{key}_log_var")
            samples[name] = mean + torch.exp(0.5 * log_var) * torch.randn_like(mean)
        return samples

    def _load_mean_weights(self):
        """Copy the variational means into the backbone's own weight tensors.

        Names that are no longer parameters of ``self.model`` are skipped.
        For example, ``nn.Linear`` weights disappear once dynamic quantization
        has replaced the layer.
        """
        backbone = dict(self.model.named_parameters())
        with torch.no_grad():
            for name, mean in self.weight_means.items():
                target = backbone.get(name)
                if target is not None:
                    target.copy_(mean)

    def train(self, mode=True):
        """Set train/eval mode; entering eval mode loads the means into the backbone."""
        super().train(mode)
        if not mode:
            self._load_mean_weights()
        return self

    def kl_divergence(self):
        """Return KL(q(w) || N(0, 1)) summed over every variational weight element.

        For q = N(mu, sigma^2) the closed form per element is
        -0.5 * (1 + log sigma^2 - mu^2 - sigma^2).

        Returns:
            0-dim tensor; zero when there are no variational weights.
        """
        terms = []
        for key in self._variational_keys.values():
            mean = getattr(self, f"{key}_mean")
            log_var = getattr(self, f"{key}_log_var")
            terms.append(-0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp()))
        if not terms:
            return torch.zeros(())
        return torch.stack(terms).sum()

def occams_razor_loss(model, outputs, targets, beta=0.01):
    """Compute the penalised training objective: data fit plus weighted complexity.

    The data term is the mean cross-entropy of ``outputs`` (logits) against
    ``targets`` (class indices). The complexity term is whatever
    ``model.kl_divergence()`` returns, scaled by ``beta``.

    NOTE: for ``VariationalVisionTransformer`` the KL is a *sum* over every
    weight element, so it can be orders of magnitude larger than the NLL.
    Pick ``beta`` with that scale in mind.

    Args:
        model: object exposing a ``kl_divergence()`` method.
        outputs: logits passed to ``F.cross_entropy``.
        targets: class indices passed to ``F.cross_entropy``.
        beta: multiplier applied to the KL term.

    Returns:
        Tuple ``(total_loss, nll, kl)``.
    """
    data_fit = F.cross_entropy(outputs, targets)
    complexity = model.kl_divergence()
    return data_fit + beta * complexity, data_fit, complexity

def train_vipt(model, train_dataset, val_dataset=None, learning_rate=0.001, num_epochs=100,
               batch_size=32, beta=0.01, weight_decay=1e-5, device='cuda'):
    """Train ``model`` on the NLL + beta * KL objective, then quantize it.

    Runs ``num_epochs`` passes of Adam over ``train_dataset``. It prints the
    current batch losses every 20 batches (and on the last batch) and the
    per-epoch averages. When ``val_dataset`` is truthy it is evaluated after
    every epoch with ``validate``. The returned validation loss drives a
    ``ReduceLROnPlateau`` scheduler (mode 'min', factor 0.5, patience 5).

    Args:
        model: module exposing ``kl_divergence()``; moved to ``device`` in place.
        train_dataset: dataset yielding ``(input, target)`` pairs.
        val_dataset: optional validation dataset of the same form.
        learning_rate: Adam learning rate.
        num_epochs: number of passes over ``train_dataset``.
        batch_size: mini-batch size used for both loaders.
        beta: weight of the KL term passed to ``occams_razor_loss``.
        weight_decay: Adam weight decay.
        device: device the model and every batch are moved to.

    Returns:
        The result of ``quantize_model`` applied to the trained model.
    """
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = None
    if val_dataset:
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    num_batches = len(train_loader)

    for epoch in range(1, num_epochs + 1):
        model.train()
        sums = {'loss': 0.0, 'nll': 0.0, 'kl': 0.0}

        for step, (inputs, labels) in enumerate(train_loader, start=1):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss, nll, kl = occams_razor_loss(model, model(inputs), labels, beta=beta)
            loss.backward()
            optimizer.step()

            current = {'loss': loss.item(), 'nll': nll.item(), 'kl': kl.item()}
            for key, value in current.items():
                sums[key] += value

            if step % 20 == 0 or step == num_batches:
                print(f'Epoch: {epoch}/{num_epochs} | Batch: {step}/{num_batches} | '
                      f'Loss: {current["loss"]:.4f} | NLL: {current["nll"]:.4f} | KL: {current["kl"]:.4f}')

        avg_loss = sums['loss'] / num_batches
        avg_nll = sums['nll'] / num_batches
        avg_kl = sums['kl'] / num_batches
        print(f'Epoch: {epoch}/{num_epochs} | Avg Loss: {avg_loss:.4f} | '
              f'Avg NLL: {avg_nll:.4f} | Avg KL: {avg_kl:.4f}')

        if val_loader is not None:
            scheduler.step(validate(model, val_loader, beta, device))

    print("Performing post-training quantization to 8 bits...")
    return quantize_model(model)

def validate(model, val_loader, beta, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Switches the model to eval mode and accumulates the ``occams_razor_loss``
    total over every batch. It also counts how many arg-max predictions match
    the targets, then prints the mean loss and the accuracy percentage.

    Args:
        model: module exposing ``kl_divergence()``.
        val_loader: DataLoader yielding ``(input, target)`` batches.
        beta: weight of the KL term in the loss.
        device: device every batch is moved to.

    Returns:
        Mean per-batch validation loss as a Python float.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            batch_loss = occams_razor_loss(model, logits, labels, beta=beta)[0]
            loss_sum += batch_loss.item()
            n_correct += (logits.argmax(dim=1) == labels).sum().item()

    avg_loss = loss_sum / len(val_loader)
    accuracy = 100. * n_correct / len(val_loader.dataset)
    print(f'Validation Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}%')
    return avg_loss

def quantize_model(model):
    """Return an 8-bit dynamically quantized copy of ``model``.

    Every ``nn.Linear`` is replaced by a dynamically quantized linear layer
    with ``torch.qint8`` weights. ``quantize_dynamic`` (with its default
    ``inplace=False``) works on a deep copy, so ``model`` keeps its float
    layers.

    PyTorch's dynamic quantization kernels exist only for CPU, so quantizing
    a CUDA-resident model fails. ``model`` is therefore first moved to CPU in
    place, and it is also switched to eval mode.

    Args:
        model: float ``nn.Module`` to quantize.

    Returns:
        The quantized module, on CPU.
    """
    # No CUDA kernels for dynamic int8 Linear: bring the weights to CPU first,
    # otherwise a model trained with device='cuda' crashes here.
    model.to("cpu")
    model.eval()
    return quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

def save_model(model, path="vipt_model.pth"):
    """Write ``model.state_dict()`` to ``path`` with ``torch.save`` and report where."""
    state = model.state_dict()
    torch.save(state, path)
    print(f"Model saved to {path}")

def main():
    """Build, train, quantize and save a VariationalVisionTransformer.

    The ImageFolder roots below are placeholders. If they cannot be loaded,
    small random TensorDatasets are used instead so the pipeline can still be
    exercised end to end on the selected device.
    """
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Hyperparameters
    img_size = 224
    patch_size = 16
    num_classes = 1000
    learning_rate = 0.001
    num_epochs = 100
    batch_size = 32
    beta = 0.01  # KL divergence weight

    model = VariationalVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        in_channels=3,
        num_classes=num_classes
    )

    # One preprocessing pipeline shared by train and val: resize to the model's
    # input size, convert to a tensor, normalize with the ImageNet mean/std.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((img_size, img_size)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    try:
        # Example using a subset of ImageNet
        train_dataset = torchvision.datasets.ImageFolder(root='path_to_train_data', transform=transform)
        val_dataset = torchvision.datasets.ImageFolder(root='path_to_val_data', transform=transform)
    except (OSError, RuntimeError) as err:
        # Only dataset-loading failures trigger the fallback. A bare `except`
        # would also swallow KeyboardInterrupt/SystemExit and genuine bugs.
        # A missing or empty root surfaces as FileNotFoundError (an OSError);
        # depending on the torchvision version it may be a RuntimeError.
        print("error loading dataset for device test.")
        print(f"  ({type(err).__name__}: {err})")
        train_dataset = torch.utils.data.TensorDataset(
            torch.randn(100, 3, img_size, img_size),
            torch.randint(0, num_classes, (100,))
        )
        val_dataset = torch.utils.data.TensorDataset(
            torch.randn(20, 3, img_size, img_size),
            torch.randint(0, num_classes, (20,))
        )

    trained_model = train_vipt(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        learning_rate=learning_rate,
        num_epochs=num_epochs,
        batch_size=batch_size,
        beta=beta,
        device=device
    )

    # Save the trained and quantized model
    save_model(trained_model, "vipt_quantized_model.pth")

    print("Training and quantization complete!")


if __name__ == "__main__":
    main()