import torch
import torch.nn as nn
from timm.models.vision_transformer import vit_small_patch16_224

# Simulated energy cost charged per quantization call, keyed by precision label.
PRECISION_COST = {
    "32bit": 1.0,
    "16bit": 0.6,
    "8bit": 0.3,
    "4bit": 0.1,
}

# Module-wide energy accounting: the spending ceiling and the running total.
ENERGY_BUDGET = 100.0
energy_used = 0.0

# Precision-aware value approximation
def approximate_tensor(tensor: torch.Tensor, bits: int) -> torch.Tensor:
    """Simulate reduced precision with uniform min-max quantization.

    The range ``[tensor.min(), tensor.max()]`` is split into ``2 ** bits``
    evenly spaced levels and every element is snapped to the nearest level.

    Args:
        tensor: Values to quantize. The input is not modified.
        bits: Bit-width of the simulated representation.

    Returns:
        A new tensor with the same shape as ``tensor`` holding the quantized
        values. Empty and constant tensors are returned as an unchanged copy.
    """
    # min()/max() raise on an empty tensor, and there is nothing to quantize.
    if tensor.numel() == 0:
        return tensor.clone()

    levels = 2 ** bits
    min_val, max_val = tensor.min(), tensor.max()
    span = max_val - min_val

    # A constant tensor has zero range: the step size would be 0 and the
    # division below would turn every element into NaN. All values already
    # sit on a quantization level, so return them as they are.
    if span == 0:
        return tensor.clone()

    # With a single level (bits <= 0) fall back to a unit step.
    scale = span / (levels - 1) if levels > 1 else 1.0
    return torch.round((tensor - min_val) / scale) * scale + min_val

# Quantize at the caller's preferred bit-width, charging its cost against the global energy budget
def apply_precision(tensor: torch.Tensor, preferred_bits: int) -> torch.Tensor:
    """Quantize ``tensor`` and charge the operation to the global energy total.

    The cost is looked up in ``PRECISION_COST`` under ``"<bits>bit"``; a
    bit-width without an entry is charged 1.0. If adding that cost would push
    ``energy_used`` past ``ENERGY_BUDGET``, the call switches to 32 bits and
    is charged the 32-bit cost instead. The cost is always added to
    ``energy_used``, even when the budget is already exceeded.

    NOTE(review): the over-budget fallback picks the most expensive precision,
    which spends more energy rather than less - confirm this is intended.

    Args:
        tensor: Values to quantize.
        preferred_bits: Requested bit-width.

    Returns:
        The result of ``approximate_tensor`` at the bit-width finally used.
    """
    global energy_used

    cost = PRECISION_COST.get(f"{preferred_bits}bit", 1.0)
    over_budget = energy_used + cost > ENERGY_BUDGET
    if over_budget:
        preferred_bits, cost = 32, PRECISION_COST["32bit"]

    energy_used += cost
    return approximate_tensor(tensor, preferred_bits)

# Identify block position
def get_block_position(block_index, total_blocks):
    """Label a block as "input", "middle" or "output" by its depth.

    The stack is cut at ``total_blocks // 3`` and ``2 * total_blocks // 3``:
    indices below the first cut are "input", indices below the second cut are
    "middle", and everything else is "output".
    """
    first_cut = total_blocks // 3
    second_cut = 2 * total_blocks // 3

    if block_index < first_cut:
        return "input"
    if block_index < second_cut:
        return "middle"
    return "output"

# Modify weights based on role and budget
def process_attention_block(attn_module, block_pos):
    """Quantize the fused QKV projection weight of an attention module in place.

    The weight is split into three equal chunks along dim 0 (query, key,
    value). Query and key are quantized at 8 bits, value at 16 bits, each via
    ``apply_precision`` so every chunk is charged to the energy budget.

    Args:
        attn_module: Module exposing ``qkv.weight`` with Q, K and V stacked
            along the first dimension.
        block_pos: Position label of the enclosing block. Accepted for
            interface compatibility; it does not influence the bit-widths.
    """
    with torch.no_grad():
        weight = attn_module.qkv.weight
        q_w, k_w, v_w = torch.chunk(weight, 3, dim=0)

        # Keep the Q, K, V call order so energy is charged in the same sequence.
        q = apply_precision(q_w, 8)
        k = apply_precision(k_w, 8)
        v = apply_precision(v_w, 16)

        # Write back in place instead of wrapping the result in a fresh
        # nn.Parameter: that would reset requires_grad to True and leave any
        # optimizer or tied reference pointing at the stale parameter object.
        weight.copy_(torch.cat([q, k, v], dim=0))

def process_mlp_block(mlp_module):
    """Quantize the two linear weights of an MLP module in place.

    ``fc1.weight`` is quantized at 16 bits and ``fc2.weight`` at 8 bits, each
    via ``apply_precision`` so both are charged to the energy budget.

    Args:
        mlp_module: Module exposing ``fc1.weight`` and ``fc2.weight``.
    """
    with torch.no_grad():
        fc1_weight = mlp_module.fc1.weight
        fc2_weight = mlp_module.fc2.weight

        fc1_new = apply_precision(fc1_weight, 16)
        fc2_new = apply_precision(fc2_weight, 8)

        # Copy in place rather than assigning a new nn.Parameter, so the
        # parameter objects (and their requires_grad flags) stay the ones any
        # optimizer or external reference already holds.
        fc1_weight.copy_(fc1_new)
        fc2_weight.copy_(fc2_new)

# Unstructured sparsity (element-wise magnitude-threshold pruning of attention weights)
def apply_sparsity(module, threshold=1e-3):
    """Zero small-magnitude entries of attention weight tensors in place.

    Only parameters whose name contains ``'attn'`` and that have more than one
    dimension are touched; every element with absolute value below
    ``threshold`` is set to 0.

    Args:
        module: Module whose ``named_parameters()`` are scanned.
        threshold: Magnitude below which an element is zeroed.
    """
    with torch.no_grad():
        for param_name, weights in module.named_parameters():
            # Skip 1-D tensors (e.g. biases) and anything outside attention.
            if weights.dim() <= 1 or 'attn' not in param_name:
                continue
            weights.masked_fill_(weights.abs() < threshold, 0.0)


def prune_attention_heads(model, threshold=0.1, num_tokens=197):
    """Zero the QKV projection of attention heads with low attention entropy.

    For each block, a random probe sequence is normalized with the block's
    ``norm1`` and projected by ``attn.qkv``. The softmax attention map of
    every head is computed from the resulting queries and keys, and its
    entropy is averaged over query positions and the batch. Heads whose
    average entropy is below ``threshold`` are pruned by zeroing their rows
    in the query, key and value sections of the fused QKV weight (and bias,
    when present).

    The probe is drawn with ``torch.randn``, so results depend on the global
    RNG state; seed it beforehand for reproducible pruning.

    Args:
        model: Object with a ``blocks`` iterable. Each block must expose
            ``norm1`` and ``attn``; ``attn`` must expose a linear ``qkv``
            layer and ``num_heads``.
        threshold: Entropy below which a head is pruned.
        num_tokens: Length of the random probe sequence.
    """
    for block in model.blocks:
        attn = block.attn
        qkv_layer = attn.qkv
        with torch.no_grad():
            # The probe must be (batch, tokens, channels) so that the shape
            # unpacking and per-head reshape below are valid.
            probe = torch.randn(
                1,
                num_tokens,
                qkv_layer.in_features,
                device=qkv_layer.weight.device,
                dtype=qkv_layer.weight.dtype,
            )
            # The pre-attention norm belongs to the block, not to ``attn``.
            qkv = qkv_layer(block.norm1(probe))
            q, k, _ = torch.chunk(qkv, 3, dim=-1)
            batch, tokens, channels = q.shape
            head_dim = channels // attn.num_heads
            q = q.reshape(batch, tokens, attn.num_heads, head_dim).transpose(1, 2)
            k = k.reshape(batch, tokens, attn.num_heads, head_dim).transpose(1, 2)

            # Scaled dot-product attention probabilities per head.
            attn_scores = (q @ k.transpose(-2, -1)) / (head_dim ** 0.5)
            attn_probs = torch.softmax(attn_scores, dim=-1)

            # Entropy per query row, averaged over rows -> (batch, num_heads),
            # then over the batch -> (num_heads,).
            entropy = -(attn_probs * torch.log(attn_probs + 1e-8)).sum(dim=-1).mean(dim=-1)
            avg_entropy = entropy.mean(dim=0)

            inactive_heads = (avg_entropy < threshold).nonzero(as_tuple=True)[0].tolist()
            if not inactive_heads:
                continue

            print(f"Pruning {len(inactive_heads)} attention heads in block.")
            for head in inactive_heads:
                # Q, K and V are stacked along dim 0 of the fused projection,
                # so the same head occupies one slice in each third.
                for section_start in (0, channels, 2 * channels):
                    start = section_start + head * head_dim
                    end = start + head_dim
                    qkv_layer.weight[start:end, :] = 0.0
                    if qkv_layer.bias is not None:
                        qkv_layer.bias[start:end] = 0.0

def _layer_scale_params(block):
    """Yield the LayerScale tensors found on ``block``, skipping absent ones."""
    # Blocks that store the scale directly; the attribute may exist but be
    # None when LayerScale is disabled.
    for attr in ('gamma_1', 'gamma_2'):
        gamma = getattr(block, attr, None)
        if gamma is not None:
            yield gamma
    # Blocks that wrap the scale in ``ls1`` / ``ls2`` submodules exposing a
    # ``gamma`` tensor (the layout used by timm ViT blocks with LayerScale
    # enabled); identity placeholders have no ``gamma`` and are skipped.
    for attr in ('ls1', 'ls2'):
        gamma = getattr(getattr(block, attr, None), 'gamma', None)
        if gamma is not None:
            yield gamma


def apply_layer_scale_calibration(model, scale_clip=1.0):
    """
    Applies LayerScale calibration by clamping every block's output scaling
    parameters to ``[-scale_clip, scale_clip]`` in place, to prevent
    instability after reduced-precision operations.

    Blocks that expose no LayerScale parameters are left unchanged.

    Args:
        model: Object with a ``blocks`` iterable.
        scale_clip: Maximum absolute value allowed for a scale entry.
    """
    with torch.no_grad():
        for block in model.blocks:
            for gamma in _layer_scale_params(block):
                gamma.clamp_(-scale_clip, scale_clip)

# Apply energy-aware precision tuning to the model
def precision_optimize_vit(model):
    """Run the full precision/pruning pipeline on ``model`` and return it.

    Every block has its attention and MLP weights quantized, then three
    model-wide passes run in order: magnitude sparsity, attention-head
    pruning and LayerScale calibration. The model is modified in place.
    """
    blocks = model.blocks
    depth = len(blocks)

    for position, vit_block in enumerate(blocks):
        role = get_block_position(position, depth)
        process_attention_block(vit_block.attn, role)
        process_mlp_block(vit_block.mlp)

    # Model-wide passes; the order matters because each one mutates weights.
    for model_pass in (apply_sparsity, prune_attention_heads, apply_layer_scale_calibration):
        model_pass(model)

    return model

# Main
if __name__ == "__main__":
    # Build ViT-Small without pretrained weights and switch to eval mode.
    vit = vit_small_patch16_224(pretrained=False)
    vit.eval()

    tuned_vit = precision_optimize_vit(vit)
    print("Total energy used:", round(energy_used, 2))

    # Smoke test: one forward pass on a random 224x224 RGB image.
    probe_image = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        logits = tuned_vit(probe_image)
        print("Output shape:", logits.shape)
