import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic, QuantStub, DeQuantStub
from torch.ao.quantization import get_default_qconfig, prepare, convert
import copy


def quantize_to_8bit(model, example_inputs=None):
    """Return a dynamically quantized (int8 weights) copy of ``model``.

    Args:
        model: Module to quantize. It is deep-copied; the original is not
            modified.
        example_inputs: Unused. Accepted only so existing callers that pass
            it keep working.

    Returns:
        The quantized copy, in eval mode. If ``model`` itself is an
        ``nn.Linear`` the returned object is its quantized replacement;
        otherwise it is the copied module with its ``nn.Linear`` descendants
        replaced.
    """
    # quantize_dynamic only swaps *children* of the module it is handed, so a
    # bare nn.Linear passed directly would come back unchanged. Wrapping the
    # copy in a container turns the module itself into a swappable child.
    wrapper = nn.Sequential(copy.deepcopy(model))
    wrapper.eval()

    # Only nn.Linear is listed: dynamic quantization has no Conv2d
    # replacement, so convolution layers stay in floating point either way.
    torch.quantization.quantize_dynamic(
        wrapper,
        {nn.Linear},
        dtype=torch.qint8,  # 8-bit quantization
        inplace=True,  # the wrapper already holds a private copy
    )

    return wrapper[0]

def quantize_to_16bit(model):
    """Return a half-precision (float16) copy of ``model`` in eval mode.

    The input module is deep-copied first, so the caller's module keeps its
    original dtype and training flag.
    """
    half_copy = copy.deepcopy(model)
    half_copy.eval()
    # half() casts the floating-point parameters and buffers to float16.
    return half_copy.half()

class MixedPrecisionViT(nn.Module):
    """Run a ViT-style model with selected submodules at reduced precision.

    The wrapper keeps a private deep copy of ``base_model`` and, after
    ``configure_precision`` is called, a dict of reduced-precision
    replacements keyed by submodule name (``'patch_embed'``, ``'blocks.<i>'``,
    ``'norm'``, ``'head'``). ``forward`` uses a replacement when one exists and
    the original submodule otherwise.

    Note: ``quantized_submodules`` is a plain dict, so the replacements are
    not registered as child modules of this wrapper.
    """

    # Precision labels accepted by configure_precision.
    _PRECISIONS = ('8bit', '16bit', '32bit')

    def __init__(self, base_model):
        super().__init__()
        self.base_model = copy.deepcopy(base_model)
        self.precision_config = {}
        self.quantized_submodules = {}

    def configure_precision(self, config):
        """Build reduced-precision replacements according to ``config``.

        Args:
            config: Mapping from submodule name to ``'8bit'``, ``'16bit'`` or
                ``'32bit'``. Names may be dotted paths into ``base_model``
                (e.g. ``'blocks.3'``). A ``'blocks'`` entry applies to every
                block that has no explicit ``'blocks.<i>'`` entry. Names that
                do not exist in the model are ignored.

        Raises:
            ValueError: If a precision label is not one of the three above.
        """
        # Validate up front so a typo cannot silently leave a module in fp32.
        for name, precision in config.items():
            if precision not in self._PRECISIONS:
                raise ValueError(
                    f"Unsupported precision {precision!r} for module {name!r}; "
                    f"expected one of {self._PRECISIONS}"
                )

        # named_modules() yields dotted paths, so 'blocks.0' resolves here
        # (named_children() would only ever yield 'blocks').
        available = dict(self.base_model.named_modules())

        # Build into a fresh dict so entries from a previous configuration do
        # not linger, and so a failure leaves the old state intact.
        replacements = {}
        for name, precision in self._expand_blocks_entry(config).items():
            module = available.get(name)
            if module is None:
                continue
            if precision == '8bit':
                replacements[name] = quantize_to_8bit(module)
            elif precision == '16bit':
                replacements[name] = quantize_to_16bit(module)
            # For '32bit', we keep the original module

        self.precision_config = config
        self.quantized_submodules = replacements

    def _expand_blocks_entry(self, config):
        """Return a copy of ``config`` with ``'blocks'`` spread over each block."""
        blocks = getattr(self.base_model, 'blocks', None)
        if blocks is None or 'blocks' not in config:
            return dict(config)
        expanded = {f'blocks.{i}': config['blocks'] for i, _ in enumerate(blocks)}
        # Explicit per-block entries take priority over the blanket setting.
        expanded.update((k, v) for k, v in config.items() if k != 'blocks')
        return expanded

    def _run_submodule(self, name, original, x):
        """Apply the replacement for ``name`` if configured, else ``original``."""
        replacement = self.quantized_submodules.get(name)
        if replacement is None:
            return original(x)

        # A half-precision replacement cannot consume float32 activations.
        # Cast to the replacement's floating dtype and cast the result back so
        # the rest of the pipeline keeps a single activation dtype.
        param = next(
            (p for p in replacement.parameters() if p.is_floating_point()), None
        )
        if param is not None and param.dtype != x.dtype:
            return replacement(x.to(param.dtype)).to(x.dtype)
        return replacement(x)

    def forward(self, x):
        """Run the (optional) patch_embed, blocks, norm and head stages in order."""
        base = self.base_model

        if hasattr(base, 'patch_embed'):
            x = self._run_submodule('patch_embed', base.patch_embed, x)

            pos_embed = getattr(base, 'pos_embed', None)
            cls_token = getattr(base, 'cls_token', None)

            if cls_token is not None:
                # If pos_embed has exactly one row per patch token it must be
                # added before the class token is prepended. Otherwise it is
                # added after concatenation, which is what is needed when
                # pos_embed also has a slot for the class token.
                if (
                    pos_embed is not None
                    and pos_embed.dim() == x.dim()
                    and pos_embed.shape[1] == x.shape[1]
                ):
                    x = x + pos_embed
                    pos_embed = None
                x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)

            if pos_embed is not None:
                x = x + pos_embed

        if hasattr(base, 'blocks'):
            for i, block in enumerate(base.blocks):
                x = self._run_submodule(f'blocks.{i}', block, x)

        if hasattr(base, 'norm'):
            x = self._run_submodule('norm', base.norm, x)

        if hasattr(base, 'head'):
            # The head is applied to the first token only.
            x = self._run_submodule('head', base.head, x[:, 0])

        return x

# Custom 8-bit matrix multiplication implementation
class Int8MatrixMultiply(nn.Module):
    """Linear layer with int8 weights and a float scale and bias.

    The input is multiplied by 127, rounded and clamped to the int8 range,
    multiplied with the int8 weight matrix using int32 accumulation, then
    converted back to float via ``scale`` and ``bias``.

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Integer tensors cannot require gradients, so the weight must be
        # created with requires_grad=False. randint's upper bound is
        # exclusive, hence 128 to cover the full int8 range.
        self.weight = nn.Parameter(
            torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8),
            requires_grad=False,
        )
        self.scale = nn.Parameter(torch.ones(1), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        """Return ``quantize(x) @ weight.T * scale + bias`` as a float tensor."""
        # Round before clamping so values are not biased toward zero.
        x_quantized = torch.clamp(torch.round(x * 127), -128, 127).to(torch.int32)

        # Widen to int32 before multiplying: an int8 x int8 matmul produces
        # int8 and wraps around on overflow.
        # NOTE(review): integer matmul may not be available on every device
        # backend - confirm before running this off-CPU.
        output_int32 = torch.matmul(x_quantized, self.weight.to(torch.int32).t())

        # Convert back to floating point with scaling
        return output_int32.float() * self.scale + self.bias


class FP16MatrixMultiply(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Initialize weights as 16-bit floats
        self.weight = nn.Parameter(torch.randn(out_features, in_features).half())
        self.bias = nn.Parameter(torch.zeros(out_features).half())
        
    def forward(self, x):
        if x.dtype != torch.float16:
            x = x.half()
            
        output = torch.matmul(x, self.weight.t()) + self.bias
        
        return output

def analyze_model_sensitivity(model, dataset, num_samples=100):
    """Measure the accuracy drop from 8-bit quantizing each top-level child.

    For every direct child of ``model``, a copy of the model is made with only
    that child replaced by its ``quantize_to_8bit`` version, and its accuracy
    is compared with the unmodified model on the same batches.

    Args:
        model: Model to analyse; it is not modified (copies are quantized).
        dataset: Dataset yielding ``(inputs, targets)`` pairs.
        num_samples: Approximate number of samples to evaluate on; whole
            batches of 16 are drawn until at least this many are collected.

    Returns:
        Dict mapping child name to ``original_accuracy - quantized_accuracy``.
    """
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=16, shuffle=True, num_workers=2
    )

    # Draw the evaluation batches once and reuse them for every measurement.
    # Re-iterating a shuffled loader would score each variant on a different
    # random subset, so the accuracy differences would mostly be noise.
    eval_batches = []
    collected = 0
    for inputs, targets in data_loader:
        eval_batches.append((inputs, targets))
        collected += targets.size(0)
        if collected >= num_samples:
            break

    # Original accuracy
    original_accuracy = evaluate_accuracy(model, eval_batches, num_samples)

    sensitivity_scores = {}

    # Test each module with reduced precision
    for name, _ in model.named_children():
        # Work on a copy so the caller's model is left untouched.
        test_model = copy.deepcopy(model)

        # Quantize just this module to 8-bit
        quantized_module = quantize_to_8bit(getattr(test_model, name))
        setattr(test_model, name, quantized_module)

        # Evaluate accuracy
        quantized_accuracy = evaluate_accuracy(test_model, eval_batches, num_samples)

        # Calculate sensitivity (drop in accuracy)
        sensitivity = original_accuracy - quantized_accuracy
        sensitivity_scores[name] = sensitivity

        print(f"Module {name}: Sensitivity = {sensitivity:.4f}")

    return sensitivity_scores

def auto_mixed_precision_config(model, dataset, num_samples=100,
                                high_threshold=0.05, medium_threshold=0.01):
    """Derive a per-module precision config from measured sensitivity.

    Args:
        model: Model passed through to ``analyze_model_sensitivity``.
        dataset: Dataset passed through to ``analyze_model_sensitivity``.
        num_samples: Sample budget passed through to the analysis.
        high_threshold: Sensitivity strictly above this keeps the module at
            ``'32bit'``.
        medium_threshold: Sensitivity strictly above this (and not above
            ``high_threshold``) maps to ``'16bit'``; anything else maps to
            ``'8bit'``.

    Returns:
        Dict mapping module name to ``'32bit'``, ``'16bit'`` or ``'8bit'``,
        ordered from most to least sensitive.

    Raises:
        ValueError: If ``medium_threshold`` exceeds ``high_threshold``.
    """
    if medium_threshold > high_threshold:
        raise ValueError(
            "medium_threshold must not exceed high_threshold "
            f"(got {medium_threshold} > {high_threshold})"
        )

    sensitivity_scores = analyze_model_sensitivity(model, dataset, num_samples)

    # Most sensitive first; this fixes the iteration order of the result.
    sorted_modules = sorted(
        sensitivity_scores.items(), key=lambda item: item[1], reverse=True
    )

    config = {}
    for name, sensitivity in sorted_modules:
        if sensitivity > high_threshold:  # High sensitivity
            config[name] = '32bit'
        elif sensitivity > medium_threshold:  # Medium sensitivity
            config[name] = '16bit'
        else:  # Low sensitivity
            config[name] = '8bit'

    return config

def evaluate_accuracy(model, data_loader, num_samples):
    """Return top-1 accuracy of ``model`` over roughly ``num_samples`` samples.

    Args:
        model: Module called as ``model(inputs)``; it is switched to eval
            mode and left that way.
        data_loader: Iterable of ``(inputs, targets)`` batches.
        num_samples: Evaluation stops after the first batch that brings the
            running sample count to at least this value.

    Returns:
        Fraction of correct predictions as a float, or ``0.0`` if the
        iterable yielded no samples.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = model(inputs)
            # Predicted class is the index of the largest value along dim 1.
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)

            if total >= num_samples:
                break

    # Guard against an empty loader instead of raising ZeroDivisionError.
    if total == 0:
        return 0.0

    return correct / total

def apply_mixed_precision_to_vit(vit_model, precision_config=None, dataset=None):
    """Wrap ``vit_model`` in a ``MixedPrecisionViT`` and configure it.

    The configuration is chosen as follows:
      * ``precision_config`` given: used as is.
      * otherwise, ``dataset`` given: generated by
        ``auto_mixed_precision_config`` and printed.
      * otherwise: a built-in default covering ``patch_embed``,
        ``blocks.0`` to ``blocks.11``, ``norm`` and ``head``.

    Returns:
        The configured ``MixedPrecisionViT`` wrapper.
    """
    wrapped = MixedPrecisionViT(vit_model)

    if precision_config is None:
        if dataset is not None:
            # Auto-generate configuration
            precision_config = auto_mixed_precision_config(vit_model, dataset)
            print("Auto-generated precision configuration:")
            for module_name in precision_config:
                print(f"  {module_name}: {precision_config[module_name]}")
        else:
            # Default configuration if no dataset is provided: first and last
            # blocks stay at 32-bit, the middle blocks drop to 8-bit.
            per_block = (
                ('32bit', '16bit', '16bit')
                + ('8bit',) * 5
                + ('16bit', '16bit', '32bit', '32bit')
            )
            precision_config = {'patch_embed': '16bit'}
            for index, level in enumerate(per_block):
                precision_config[f'blocks.{index}'] = level
            precision_config['norm'] = '32bit'
            precision_config['head'] = '32bit'

    # Apply configuration
    wrapped.configure_precision(precision_config)
    return wrapped

# Example usage
def example_mixed_precision_usage():
    """Load a pretrained DeiT-tiny via torch.hub and wrap it with a manual config.

    Returns:
        The wrapper produced by ``apply_mixed_precision_to_vit``.
    """
    # Load a pretrained ViT model
    vit_model = torch.hub.load(
        'facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True
    )

    # Option 1: Manually specify precision configuration.
    # Per-block precisions for blocks 0..11, in order.
    per_block = (
        ('32bit', '16bit')
        + ('8bit',) * 6
        + ('16bit', '16bit', '32bit', '32bit')
    )
    manual_config = {'patch_embed': '16bit'}
    for index, level in enumerate(per_block):
        manual_config[f'blocks.{index}'] = level
    manual_config['norm'] = '32bit'
    manual_config['head'] = '32bit'

    # Option 2: Auto-configure based on dataset sensitivity (if dataset available):
    # pass dataset=my_dataset instead of a config.
    return apply_mixed_precision_to_vit(vit_model, manual_config)