import numpy as np

def appx_mul4x4(A, B):
    """Return the approximate product of two small operands.

    Rules implemented here:

    * either operand equal to 0 -> 0
    * ``A == 1`` -> ``B`` (exact)
    * even ``A`` greater than 1 -> ``B`` is treated as 0 when ``B < 8`` and
      as 16 otherwise, giving ``0`` or ``32 * (A // 2)``
    * every other ``A`` -> one exact copy of ``B`` plus the rounded
      contribution of the remaining ``A - 1``, i.e. ``B`` when ``B < 8``
      and ``B + 32 * ((A - 1) // 2)`` otherwise
    """
    # Trivial cases are exact.
    if A == 0 or B == 0:
        return 0
    if A == 1:
        return B

    even_multiplier = A > 1 and A % 2 == 0
    small_multiplicand = B < 8

    if small_multiplicand:
        # A small B contributes nothing beyond the single exact copy (if any).
        return 0 if even_multiplier else B

    if even_multiplier:
        return 32 * (A // 2)

    # Remaining values of A: keep one exact B and round the rest.
    return B + 32 * ((A - 1) // 2)

def approximate_multiplication(A, B, precision):
    """Approximately multiply ``A`` and ``B`` at the given bit precision.

    The operands are split into an upper and a lower half and the four
    partial products are combined with their positional weights::

        A * B ~= (hi*hi << 2k) + ((lo*hi + hi*lo) << k) + (lo*lo)

    where ``k`` is the half width (8 for 16-bit, 4 for 8-bit operands).

    Args:
        A: Left operand. Only the low ``precision`` bits are used for the
            16- and 8-bit paths.
        B: Right operand, treated like ``A``.
        precision: 16 splits into 8-bit halves that are multiplied with the
            8-bit path; 8 splits into 4-bit halves that are multiplied with
            ``appx_mul4x4``. Any other value passes the operands straight to
            ``appx_mul4x4`` without masking.

    Returns:
        The approximate product.
    """
    if precision == 16:
        # Split into 8-bit halves.
        a_hi = (A >> 8) & 0xFF
        a_lo = A & 0xFF
        b_hi = (B >> 8) & 0xFF
        b_lo = B & 0xFF

        hi_hi = approximate_multiplication(a_hi, b_hi, 8)
        lo_hi = approximate_multiplication(a_lo, b_hi, 8)
        hi_lo = approximate_multiplication(a_hi, b_lo, 8)
        lo_lo = approximate_multiplication(a_lo, b_lo, 8)

        # The product of the two upper halves is the most significant term,
        # the product of the two lower halves the least significant one.
        return (hi_hi << 16) + ((lo_hi + hi_lo) << 8) + lo_lo

    if precision == 8:
        # Split into 4-bit halves.
        a_hi = (A >> 4) & 0xF
        a_lo = A & 0xF
        b_hi = (B >> 4) & 0xF
        b_lo = B & 0xF

        hi_hi = appx_mul4x4(a_hi, b_hi)
        lo_hi = appx_mul4x4(a_lo, b_hi)
        hi_lo = appx_mul4x4(a_hi, b_lo)
        lo_lo = appx_mul4x4(a_lo, b_lo)

        # Same weighting as the 16-bit path, with a 4-bit half width.
        return (hi_hi << 8) + ((lo_hi + hi_lo) << 4) + lo_lo

    # precision == 4 (and any other value): single 4x4 approximate multiply.
    return appx_mul4x4(A, B)


class ApproximateMultiplicationLayer(object):
    """
    Layer that performs approximate multiplication with configurable precision.

    Calling the layer with two scalars returns the scalar result of
    ``approximate_multiplication``. Otherwise the operands are broadcast
    against each other following NumPy rules and the approximate product is
    computed element by element.
    """

    def __init__(self, precision=8):
        # Bit precision forwarded to approximate_multiplication on every call.
        self.precision = precision

    def __call__(self, A, B):
        """Return the element-wise approximate product of ``A`` and ``B``.

        Args:
            A: Scalar, NumPy array or array-like (e.g. a list).
            B: Scalar, NumPy array or array-like, broadcastable with ``A``.

        Returns:
            A scalar when both inputs are scalars; otherwise an array with
            the broadcast shape of the inputs and the dtype NumPy promotion
            gives for the two operands.

        Raises:
            ValueError: If the operand shapes cannot be broadcast together.
        """
        # Handle scalar inputs
        if np.isscalar(A) and np.isscalar(B):
            return approximate_multiplication(A, B, self.precision)

        # Scalars are kept as-is so they take part in dtype promotion the
        # same way they would in ``A * B``; everything else becomes an array
        # so that lists and tuples are accepted as well.
        lhs = A if np.isscalar(A) else np.asarray(A)
        rhs = B if np.isscalar(B) else np.asarray(B)
        out_dtype = np.result_type(lhs, rhs)

        # Broadcast both operands to a common shape so a single index is
        # valid for each of them (covers scalar/array and mixed-shape input).
        lhs, rhs = np.broadcast_arrays(lhs, rhs)

        # NOTE(review): a narrow integer dtype may not be able to hold the
        # product - confirm the dtypes callers are expected to pass.
        result = np.zeros(lhs.shape, dtype=out_dtype)

        # Apply element-wise approximate multiplication
        for idx in np.ndindex(result.shape):
            result[idx] = approximate_multiplication(
                lhs[idx], rhs[idx], self.precision
            )

        return result


class MixedPrecisionMultiplier(object):
    """
    Dispatches multiplications to per-precision approximate multiplier layers.

    Each layer name can be mapped to a precision through ``precision_rules``;
    names without a rule use ``default_precision``.
    """

    def __init__(self, default_precision=8, precision_rules=None):
        """
        Args:
            default_precision: Precision used for layer names without a rule.
            precision_rules: Optional mapping ``layer name -> precision``.
                ``None`` (or an empty mapping) means "no rules".
        """
        self.default_precision = default_precision
        self.precision_rules = precision_rules or {}
        self.multipliers = {}

        # Create one multiplier per distinct precision. The normalised
        # attribute is used here so that precision_rules=None is valid.
        for precision in {default_precision, *self.precision_rules.values()}:
            self.multipliers[precision] = ApproximateMultiplicationLayer(precision)

    def get_precision_for_layer(self, layer_name):
        """Return the precision configured for ``layer_name`` (exact match),
        or the default precision when there is no rule for it."""
        return self.precision_rules.get(layer_name, self.default_precision)

    def multiply(self, A, B, layer_name=None):
        """Approximately multiply ``A`` and ``B`` with the precision selected
        for ``layer_name``."""
        precision = self.get_precision_for_layer(layer_name)
        layer = self.multipliers.get(precision)
        if layer is None:
            # The rules mapping is stored by reference, so it may have gained
            # a precision that did not exist when the multipliers were built.
            layer = ApproximateMultiplicationLayer(precision)
            self.multipliers[precision] = layer
        return layer(A, B)


def apply_to_vit_model(model, mixed_precision_config=None):
    """Attach precision-reporting forward hooks to a model's Linear/Conv2d layers.

    Args:
        model: Object exposing ``named_modules()``; every ``torch.nn.Linear``
            and ``torch.nn.Conv2d`` it yields gets a forward hook.
        mixed_precision_config: Optional mapping ``rule -> precision``. A rule
            is either a full module name or a fragment of one. For each
            hooked module an exact name match wins; otherwise the longest
            rule that occurs inside the module name is used; otherwise the
            default precision (8) applies. When omitted, a built-in default
            configuration is used.

    Returns:
        The same ``model`` object, with the hooks registered.
    """
    # Imported locally: this function needs torch, the rest of the module
    # does not.
    import torch

    if mixed_precision_config is None:
        # Default configuration: higher precision for critical layers, lower for others
        mixed_precision_config = {
            # Critical layers get higher precision
            "embedding": 16,
            "position_embedding": 16,
            "cls_token": 16,
            "mlp.0": 8,  # First MLP layer in each block
            "mlp.2": 4,  # Second MLP layer in each block
            "attn.qkv": 8,  # Attention query/key/value projections
            "attn.proj": 8,  # Attention output projection
            "norm": 16,  # Layer normalization
            "head": 16,  # Classification head
        }

    hooked_types = (torch.nn.Linear, torch.nn.Conv2d)
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, hooked_types)
    ]

    # named_modules() reports dotted, fully qualified names, whereas rules
    # such as "mlp.0" are fragments. Resolve every hooked module to a
    # concrete rule so the exact-match lookup of the multiplier finds it.
    fragments = [
        key for key in mixed_precision_config if isinstance(key, str) and key
    ]
    resolved_rules = dict(mixed_precision_config)
    for name, _ in targets:
        if name in resolved_rules:
            continue
        matching = [fragment for fragment in fragments if fragment in name]
        if matching:
            # The longest fragment is the most specific rule.
            best = max(matching, key=len)
            resolved_rules[name] = mixed_precision_config[best]

    # Create mixed precision multiplier
    multiplier = MixedPrecisionMultiplier(
        default_precision=8,
        precision_rules=resolved_rules
    )

    def hook_fn(name):
        # Factory so each hook keeps its own layer name.
        def hook(module, inputs, output):
            # For demonstration purposes, we're just showing the concept
            # In a real implementation, you would modify the forward pass
            print(f"Layer {name} using precision {multiplier.get_precision_for_layer(name)}")
            return output
        return hook

    # Register hooks for each layer
    for name, module in targets:
        module.register_forward_hook(hook_fn(name))

    print("Applied mixed precision multiplication to model")
    return model


def create_mixed_precision_config(model, importance_threshold=0.5):
    """Build a ``module name -> precision`` mapping from name-based heuristics.

    Every ``torch.nn.Linear`` / ``torch.nn.Conv2d`` yielded by
    ``model.named_modules()`` is given a placeholder importance score based
    on its name, which is then mapped to a precision:

    * score > 0.7                     -> 16
    * score > ``importance_threshold`` -> 8
    * otherwise                       -> 4

    Args:
        model: Object exposing ``named_modules()``.
        importance_threshold: Minimum score (exclusive) for a layer to get
            8-bit instead of 4-bit precision.

    Returns:
        dict mapping each matching module's name to 16, 8 or 4.
    """
    # Imported locally: this function needs torch, the rest of the module
    # does not.
    import torch

    config = {}
    for name, module in model.named_modules():
        if not isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            continue

        # The module's own position is the last component of its dotted
        # name. Looking only at that component keeps a block index such as
        # the ".0" in "blocks.0.mlp.2" from being mistaken for the index of
        # the layer inside the MLP.
        leaf = name.rsplit(".", 1)[-1]

        # Example importance calculation (placeholder)
        # In reality, you might use gradient magnitudes, activation ranges, etc.
        if "attn" in name:
            # Attention layers generally more important
            importance = 0.8
        elif "mlp" in name and leaf == "0":
            # First MLP layers somewhat important
            importance = 0.6
        elif "mlp" in name and leaf == "2":
            # Second MLP layers less important
            importance = 0.4
        elif "head" in name or "embed" in name:
            # Input embedding and output head are important
            importance = 0.9
        else:
            importance = 0.5

        # Assign precision based on importance
        if importance > 0.7:
            config[name] = 16  # High precision for very important layers
        elif importance > importance_threshold:
            config[name] = 8   # Medium precision for moderately important layers
        else:
            config[name] = 4   # Low precision for less important layers

    return config