import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.distributions import Normal, Laplace, kl_divergence
import copy

class EnhancedVariationalVisionTransformer(nn.Module):
    """Mean-field variational ViT over early-fused camera + LiDAR input.

    Wraps a torchvision ``VisionTransformer`` (12 layers, 12 heads, hidden
    size 768, MLP size 3072) whose patch-embedding convolution is rebuilt to
    accept ``in_channels`` input channels.  Every backbone parameter whose
    name contains ``'weight'`` gets a learnable posterior mean and a learnable
    log-scale tensor:

    * ``'gaussian'`` - the scale tensor holds ``log(sigma ** 2)``; the KL term
      is taken against an ``N(0, 1)`` prior.
    * ``'laplace'`` - the scale tensor holds ``log(b)``; the KL term is taken
      against a ``Laplace(0, 1)`` prior.

    Both families share the same scale tensors, so ``switch_distribution``
    reinterprets the learned values rather than converting them.

    In training mode each forward pass draws a fresh weight sample with the
    reparameterization trick and runs the backbone functionally on it, so the
    data loss back-propagates into the means and log-scales.  In eval mode the
    backbone runs on the posterior means.  The backbone's own tensors are
    therefore not used for any name present in ``weight_means``.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=4, num_classes=1000,
                 vi_distribution='gaussian', laplace_b=1.0, mixed_precision_config=None):
        """Build the backbone and its variational parameters.

        Args:
            img_size: input height/width expected by the ViT.
            patch_size: patch-embedding kernel size and stride.
            in_channels: channels of the fused input (RGB + LiDAR by default).
            num_classes: size of the classification head.
            vi_distribution: ``'gaussian'`` or ``'laplace'``.
            laplace_b: initial Laplace scale; used to initialize the
                log-scales when ``vi_distribution='laplace'``.
            mixed_precision_config: stored as ``self.mixed_precision_config``;
                not read by any method of this class.

        Raises:
            ValueError: if ``vi_distribution`` is not supported.
        """
        super().__init__()

        # Imported here rather than at module level, so importing this file
        # does not require torchvision.
        from torchvision.models.vision_transformer import VisionTransformer
        self.model = VisionTransformer(
            image_size=img_size,
            patch_size=patch_size,
            num_classes=num_classes,
            num_layers=12,
            num_heads=12,
            hidden_dim=768,
            mlp_dim=3072
        )

        # Rebuild the patch embedding for in_channels inputs (RGB + LiDAR);
        # the replacement convolution is freshly initialized.
        original_conv = self.model.conv_proj
        self.model.conv_proj = nn.Conv2d(
            in_channels, original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride
        )

        self.vi_distribution = vi_distribution
        self.laplace_b = laplace_b
        # Both dicts are keyed by dotted backbone parameter name
        # (e.g. 'conv_proj.weight').
        self.weight_means = {}
        # log(sigma**2) for Gaussian, log(b) for Laplace.
        self.weight_scales = {}
        self.mixed_precision_config = mixed_precision_config or {}

        self._initialize_variational_parameters()

        # Backbone parameter name -> bit width, filled by update_bit_assignments.
        self.current_precisions = {}

    def _initialize_variational_parameters(self):
        """Create and register a mean / log-scale pair for every backbone weight.

        Means start as copies of the current backbone weights.  Log-scales
        start at -3.0 (Gaussian log-variance) or at ``log(laplace_b)``
        (Laplace log-scale).

        Raises:
            ValueError: if ``self.vi_distribution`` is not supported.
        """
        if self.vi_distribution == 'gaussian':
            initial_log_scale = -3.0  # log-variance, i.e. sigma = exp(-1.5)
        elif self.vi_distribution == 'laplace':
            initial_log_scale = float(np.log(self.laplace_b))
        else:
            raise ValueError(f"Unknown distribution: {self.vi_distribution}")

        for name, param in self.model.named_parameters():
            if 'weight' not in name:
                continue
            mean = nn.Parameter(param.data.clone())
            scale = nn.Parameter(torch.full_like(param.data, initial_log_scale))
            self.weight_means[name] = mean
            self.weight_scales[name] = scale

            # nn.Module forbids '.' in parameter names, so flatten the dotted
            # backbone path (e.g. 'conv_proj.weight' -> 'conv_proj_weight').
            safe_name = name.replace('.', '_')
            self.register_parameter(f"{safe_name}_mean", mean)
            self.register_parameter(f"{safe_name}_scale", scale)

    def switch_distribution(self, new_distribution, laplace_b=None):
        """Select the variational family used by sampling and the KL term.

        The existing log-scale tensors are reused as-is under the new
        interpretation.  ``laplace_b`` only updates the stored attribute; the
        log-scales are not re-initialized from it here.

        Raises:
            ValueError: if ``new_distribution`` is not 'gaussian' or 'laplace'.
        """
        if new_distribution not in ['gaussian', 'laplace']:
            raise ValueError("Distribution must be 'gaussian' or 'laplace'")

        self.vi_distribution = new_distribution
        if laplace_b is not None:
            self.laplace_b = laplace_b

    def _sample_weights_gaussian(self):
        """Draw w = mu + sigma * eps, eps ~ N(0, 1), for every variational weight.

        Returns:
            dict mapping backbone parameter name -> sampled tensor.  Samples
            stay in the autograd graph of the means and log-variances.
        """
        samples = {}
        for name, mean in self.weight_means.items():
            std = torch.exp(0.5 * self.weight_scales[name])
            samples[name] = mean + std * torch.randn_like(mean)
        return samples

    def _sample_weights_laplace(self):
        """Draw w ~ Laplace(mu, b) by inverse CDF for every variational weight.

        Uses w = mu - b * sgn(u) * log(1 - |u|) with u ~ U[eps - 1, 1).  The
        lower bound keeps |u| < 1, so the log stays finite.  Sampling
        ``torch.rand`` directly can return exactly 0 and yield an infinite
        weight.

        Returns:
            dict mapping backbone parameter name -> sampled tensor.  Samples
            stay in the autograd graph of the means and log-scales.
        """
        samples = {}
        for name, mean in self.weight_means.items():
            scale = torch.exp(self.weight_scales[name])
            eps = torch.finfo(mean.dtype).eps
            u = torch.empty_like(mean).uniform_(eps - 1.0, 1.0)
            samples[name] = mean - scale * torch.sign(u) * torch.log1p(-torch.abs(u))
        return samples

    def _sample_weights(self):
        """Draw one weight sample from the active variational family.

        Raises:
            ValueError: if ``self.vi_distribution`` is not supported.
        """
        if self.vi_distribution == 'gaussian':
            return self._sample_weights_gaussian()
        if self.vi_distribution == 'laplace':
            return self._sample_weights_laplace()
        raise ValueError(f"Unknown distribution: {self.vi_distribution}")

    @staticmethod
    def _functional_call(module, params, args):
        """Run ``module`` with ``params`` substituted for its own tensors."""
        try:
            from torch.func import functional_call
        except ImportError:  # torch < 2.0
            from torch.nn.utils.stateless import functional_call
        return functional_call(module, params, args)

    def apply_quantization(self, layer_name, precision):
        """Fake-quantize one backbone tensor with symmetric uniform quantization.

        Variational weights are quantized from their posterior mean, which is
        what eval-mode forward uses.  Any other name is looked up in the
        backbone's state dict.

        Args:
            layer_name: dotted backbone parameter/buffer name.
            precision: bit width; 4, 8 and 16 are quantized, any other value
                returns the tensor unchanged.

        Returns:
            The quantized tensor, with a straight-through gradient to the
            source tensor.  Returns the unmodified tensor for unsupported
            precisions or an all-zero tensor, and ``None`` if
            ``layer_name`` is unknown.
        """
        weight = self.weight_means.get(layer_name)
        if weight is None:
            # keep_vars=True returns the live tensors (not detached copies),
            # so the straight-through estimator below can carry gradients.
            weight = self.model.state_dict(keep_vars=True).get(layer_name)
        if weight is None:
            return None

        levels = {4: 16, 8: 256, 16: 65536}.get(precision)
        if levels is None:
            return weight  # No quantization (e.g. 32-bit)

        w_max = torch.max(torch.abs(weight))
        if w_max == 0:
            # All zeros is exactly representable; also avoids 0 / 0 = NaN.
            return weight
        scale = 2 * w_max / (levels - 1)

        weight_q = torch.round(weight / scale) * scale
        # Straight-through estimator: forward value is weight_q, gradient is
        # the identity with respect to weight.
        return weight + (weight_q - weight).detach()

    def forward(self, camera_data, lidar_data):
        """Early-fuse camera and LiDAR channels and run the backbone.

        Training mode runs on a fresh reparameterized weight sample; eval mode
        runs on the posterior means.

        Args:
            camera_data: image batch, presumably [B, 3, H, W].
            lidar_data: [B, 1, H', W'] or [B, H', W']; a missing channel axis
                is added, and the map is resized bilinearly to the camera's
                spatial size when they differ.

        Returns:
            The backbone output for the fused [B, C_cam + 1, H, W] input.
        """
        weights = self._sample_weights() if self.training else dict(self.weight_means)

        if lidar_data.dim() == 3:
            lidar_data = lidar_data.unsqueeze(1)  # Add channel dimension

        if lidar_data.shape[-2:] != camera_data.shape[-2:]:
            lidar_data = F.interpolate(lidar_data, size=camera_data.shape[-2:],
                                       mode='bilinear', align_corners=False)

        fused_input = torch.cat([camera_data, lidar_data], dim=1)

        return self._functional_call(self.model, weights, (fused_input,))

    def kl_divergence_gaussian(self):
        """Sum of KL(N(mu, sigma^2) || N(0, 1)) over all variational weights.

        Closed form: -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2).  Returns
        the float 0.0 when there are no variational weights.
        """
        kl_sum = 0.0
        for name, mean in self.weight_means.items():
            log_var = self.weight_scales[name]
            kl_sum += -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        return kl_sum

    def kl_divergence_laplace(self):
        """Sum of KL(Laplace(mu, b) || Laplace(0, 1)) over all variational weights.

        For q = Laplace(mu_q, b_q) and p = Laplace(mu_p, b_p):
            KL = log(b_p / b_q) + |mu_q - mu_p| / b_p
                 + (b_q / b_p) * exp(-|mu_q - mu_p| / b_q) - 1
        With mu_p = 0 and b_p = 1 this reduces to
            -log b + |mu| + b * exp(-|mu| / b) - 1.
        Returns the float 0.0 when there are no variational weights.
        """
        kl_sum = 0.0
        for name, mean in self.weight_means.items():
            log_scale = self.weight_scales[name]
            scale = torch.exp(log_scale)
            abs_mean = torch.abs(mean)
            kl_sum += torch.sum(-log_scale + abs_mean + scale * torch.exp(-abs_mean / scale) - 1)
        return kl_sum

    def kl_divergence(self):
        """KL term for the active variational family.

        Raises:
            ValueError: if ``self.vi_distribution`` is not supported.
        """
        if self.vi_distribution == 'gaussian':
            return self.kl_divergence_gaussian()
        if self.vi_distribution == 'laplace':
            return self.kl_divergence_laplace()
        raise ValueError(f"Unknown distribution: {self.vi_distribution}")

    def update_bit_assignments(self, camera_data, lidar_data, target_sparsity=0.7):
        """Assign 16/8/4-bit precision to backbone parameters by gradient norm.

        Call after ``backward()``.  Parameters are ranked by the L2 norm of
        their gradient:
        - the top ``int(0.3 * n)`` get 16 bits;
        - the next ``int(0.5 * n)`` get 8 bits;
        - the rest get 4 bits.

        Results are written to ``self.current_precisions``, keyed by backbone
        parameter name.  Parameters without a gradient are left unassigned.

        ``camera_data``, ``lidar_data`` and ``target_sparsity`` are accepted
        for backward compatibility but unused.  An extra forward pass here
        discarded its output and captured no activations.
        """
        gradients = {}
        for name, param in self.model.named_parameters():
            # Variational weights are substituted during forward, so their
            # gradient lives on the posterior mean, not the backbone tensor.
            source = self.weight_means.get(name, param)
            if source.grad is not None:
                gradients[name] = torch.norm(source.grad).item()

        ranked = sorted(gradients.items(), key=lambda item: item[1], reverse=True)
        high_precision_count = int(len(ranked) * 0.3)    # top 30% -> 16-bit
        medium_precision_count = int(len(ranked) * 0.5)  # next 50% -> 8-bit

        for rank, (layer_name, _) in enumerate(ranked):
            if rank < high_precision_count:
                bits = 16
            elif rank < high_precision_count + medium_precision_count:
                bits = 8
            else:
                bits = 4
            self.current_precisions[layer_name] = bits

def mixed_distribution_loss(model, outputs, targets, lambda_var=0.01, lambda_sparse=0.001):
    """Combine data fit, variational KL and L1 sparsity into one training loss.

    Args:
        model: object exposing ``kl_divergence()`` and a ``weight_means``
            mapping of name -> tensor.
        outputs: logits passed to ``F.cross_entropy``.
        targets: targets passed to ``F.cross_entropy``.
        lambda_var: weight of the KL term.
        lambda_sparse: weight of the L1 penalty on the variational means.

    Returns:
        ``(total_loss, nll, kl_loss, sparsity_loss)`` where
        ``total_loss = nll + lambda_var * kl_loss + lambda_sparse * sparsity_loss``.
    """
    data_term = F.cross_entropy(outputs, targets)
    kl_term = model.kl_divergence()

    # L1 norm of every variational mean.  The 0.0 start means an empty
    # mapping yields a plain float.
    l1_term = sum((torch.sum(torch.abs(mu)) for mu in model.weight_means.values()), 0.0)

    combined = data_term + lambda_var * kl_term + lambda_sparse * l1_term
    return combined, data_term, kl_term, l1_term

def adaptive_training_step(model, camera_data, lidar_data, targets, optimizer, lambda_var=0.01,
                           epoch=0, switch_every=10, lambda_sparse=0.001):
    """Run one optimization step, alternating the variational family by epoch.

    On epochs that are multiples of ``switch_every`` the family is
    (re)selected: Gaussian on even multiples, Laplace on odd multiples.  Other
    epochs keep whichever family is active.

    Args:
        model: variational model exposing ``switch_distribution``,
            ``kl_divergence``, ``weight_means``, ``update_bit_assignments``
            and ``vi_distribution``.
        camera_data, lidar_data: inputs forwarded to ``model``.
        targets: targets for ``F.cross_entropy`` inside the loss.
        optimizer: optimizer stepped once after the bit-assignment update.
        lambda_var: weight of the KL term.
        epoch: current epoch index, used for distribution switching.
        switch_every: epochs between distribution switches.
        lambda_sparse: weight of the L1 penalty on the variational means
            (0.001 matches ``mixed_distribution_loss``'s default).

    Returns:
        dict of Python floats for each loss term plus the active
        distribution name.
    """
    if epoch % switch_every == 0:
        current_dist = 'gaussian' if epoch % (2 * switch_every) == 0 else 'laplace'
        model.switch_distribution(current_dist)

    outputs = model(camera_data, lidar_data)
    total_loss, nll, kl_loss, sparsity_loss = mixed_distribution_loss(
        model, outputs, targets, lambda_var, lambda_sparse
    )

    optimizer.zero_grad()
    total_loss.backward()

    # Reads the gradients just computed, so it must run after backward().
    model.update_bit_assignments(camera_data, lidar_data)

    optimizer.step()

    # float() accepts both 0-dim tensors and the plain 0.0 the loss terms
    # fall back to when the model has no variational weights.
    return {
        'total_loss': float(total_loss),
        'nll': float(nll),
        'kl_loss': float(kl_loss),
        'sparsity_loss': float(sparsity_loss),
        'distribution': model.vi_distribution
    }