import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model
import numpy as np
import random
import time


# Configuration and Setup

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device: torch.device = (
    torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
)
batch_size: int = 16        # samples per optimizer step
epochs: int = 10            # full passes over the training set
lr: float = 1e-4            # initial Adam learning rate
# Simulated per-batch energy budget in joules. The VIET loop decays its
# learning rate when the epoch-average estimate exceeds this value.
# NOTE(review): with a 1e-9 J/parameter proxy, only models with more than
# 1.5e9 trainable parameters can exceed this budget — confirm it is intended.
energy_budget: float = 1.5
lambda_reg: float = 1e-3    # weight of the squared-weight ("Fisher") penalty
gamma_energy: float = 0.2   # weight of the simulated-energy term in the VIET loss


# Simulated energy tracker (a rough proxy only; for real measurements, use it alongside the device's `tegrastats` utility)

def estimate_batch_energy(model, input_tensor, joules_per_param=1e-9):
    """Return a simulated energy cost, in joules, for processing one batch.

    The estimate is ``(number of trainable parameters) * joules_per_param``.
    It is a static proxy. ``input_tensor`` is accepted so callers can pass the
    batch, but it is not used. The result is therefore the same for every
    batch and every batch size.

    Args:
        model: Module whose parameters with ``requires_grad=True`` are counted.
        input_tensor: The batch being processed (currently ignored).
        joules_per_param: Simulated energy attributed to each trainable
            parameter. Defaults to the previous hard-coded 1e-9.

    Returns:
        float: The simulated energy in joules, as a plain Python number.
            It is not a tensor, so it carries no gradient.
    """
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return n_trainable * joules_per_param


# "Fisher information" penalty (implemented as the sum of squared trainable weights)

def fisher_information(model):
    """Return the sum of squares of every trainable parameter in ``model``.

    Despite the name, this is the squared L2 norm of the weights (the
    quantity an L2 / weight-decay penalty uses), not a Fisher-information
    estimate.

    Returns:
        A differentiable scalar tensor. If the model has no parameters with
        ``requires_grad=True``, returns the float ``0.0``.
    """
    squared_norms = (p.pow(2).sum() for p in model.parameters() if p.requires_grad)
    return sum(squared_norms, 0.0)


# Custom Loss Functions
def loss_occam(pred, target, model, lambda_):
    """Cross-entropy loss plus a weighted complexity penalty.

    Computes ``F.cross_entropy(pred, target) + 0.5 * lambda_ * fisher_information(model)``.

    Args:
        pred: Logits passed to ``F.cross_entropy``.
        target: Class targets passed to ``F.cross_entropy``.
        model: Module whose trainable weights feed the penalty.
        lambda_: Penalty strength.
    """
    penalty = fisher_information(model)
    data_term = F.cross_entropy(pred, target)
    return data_term + 0.5 * lambda_ * penalty

def loss_occam_energy(pred, target, model, lambda_, gamma, ee_stat):
    """Occam loss augmented with a weighted energy statistic.

    Returns ``loss_occam(pred, target, model, lambda_) + gamma * ee_stat``.
    When ``ee_stat`` is a plain Python number (not a tensor), the energy term
    shifts the loss value but contributes no gradient.
    """
    base = loss_occam(pred, target, model, lambda_)
    energy_penalty = gamma * ee_stat
    return base + energy_penalty


# Quantization (Post-training)

def quantize_model_weights(model, bits=8):
    """Fake-quantize every trainable parameter of ``model`` in place.

    Each parameter tensor is snapped to ``2 ** bits`` evenly spaced levels
    spanning its own minimum and maximum (per-tensor min/max quantization).
    Values stay in floating point; only their precision is reduced.

    Args:
        model: Module whose parameters with ``requires_grad=True`` are rounded.
            Frozen parameters are left untouched.
        bits: Number of bits per value; must be at least 1.

    Raises:
        ValueError: If ``bits`` is less than 1.
    """
    if bits < 1:
        raise ValueError(f"bits must be >= 1, got {bits}")
    levels = 2 ** bits
    # The rounding is a post-training weight edit, so autograd must not track it.
    with torch.no_grad():
        for p in model.parameters():
            if not p.requires_grad or p.numel() == 0:
                continue
            min_val = p.min()
            max_val = p.max()
            value_range = max_val - min_val
            if value_range == 0:
                # Constant tensor (including single-element ones): it is
                # already exactly representable, and a zero scale would
                # produce NaN on division.
                continue
            scale = value_range / (levels - 1)
            p.copy_(torch.round((p - min_val) / scale) * scale + min_val)


def load_data(n_samples=1000, num_classes=10):
    """Generate a random stand-in dataset.

    Returns:
        A pair ``(images, labels)``:
        ``images`` is float noise of shape ``(n_samples, 3, 224, 224)``.
        ``labels`` is an int64 tensor of shape ``(n_samples,)`` with values
        drawn uniformly from ``[0, num_classes)``.
    """
    image_shape = (3, 224, 224)
    images = torch.randn((n_samples, *image_shape))
    labels = torch.randint(low=0, high=num_classes, size=(n_samples,))
    return images, labels


# Training Method: VIET (Energy-Aware)
def train_viet(model, images, labels):
    """Train ``model`` in place with the energy-aware Occam loss (VIET).

    Each step works as follows:
    1. Perturb the trainable weights with Gaussian noise (std 0.01), a cheap
       stand-in for sampling weights.
    2. Back-propagate ``loss_occam_energy`` at the perturbed point.
    3. Restore the unperturbed weights.
    4. Apply the Adam update to the restored weights.

    After every epoch the learning rate is multiplied by 0.9 if the
    epoch-average simulated energy exceeded ``energy_budget``.

    The decayed learning rate is kept local to this call. The module-level
    ``lr`` is only read, so later training runs start from the same value.

    Args:
        model: Network to train; updated in place.
        images: Input tensor indexed along dim 0 by sample.
        labels: Target tensor aligned with ``images`` along dim 0.

    Raises:
        ValueError: If ``images`` is empty.

    Note:
        ``estimate_batch_energy`` returns a plain Python number here, so the
        energy term shifts the reported loss but contributes no gradient.
    """
    n_samples = len(images)
    if n_samples == 0:
        raise ValueError("train_viet needs at least one training sample")
    # Actual number of (possibly partial) batches per epoch, used for averaging.
    n_batches = len(range(0, n_samples, batch_size))

    current_lr = lr  # local copy: never mutate the module-level default
    optimizer = torch.optim.Adam(model.parameters(), lr=current_lr)
    trainable = [p for p in model.parameters() if p.requires_grad]

    for epoch in range(epochs):
        model.train()
        indices = torch.randperm(n_samples)
        total_loss = 0.0
        total_energy = 0.0

        for i in range(0, n_samples, batch_size):
            batch_idx = indices[i:i + batch_size]
            x = images[batch_idx].to(device)
            y = labels[batch_idx].to(device)

            # Sample weights around their current values. Keep a copy so the
            # optimizer updates the unperturbed weights; otherwise the noise
            # accumulates into a random walk across steps.
            saved = [p.detach().clone() for p in trainable]
            with torch.no_grad():
                for p in trainable:
                    p.add_(0.01 * torch.randn_like(p))

            optimizer.zero_grad()
            out = model(x)
            energy = estimate_batch_energy(model, x)
            loss = loss_occam_energy(out, y, model, lambda_reg, gamma_energy, energy)
            loss.backward()

            # Gradients were taken at the sampled weights; step from the mean.
            with torch.no_grad():
                for p, original in zip(trainable, saved):
                    p.copy_(original)
            optimizer.step()

            total_loss += loss.item()
            total_energy += energy

        avg_loss = total_loss / n_batches
        avg_energy = total_energy / n_batches
        print(f"[VIET] Epoch {epoch+1} | Loss: {avg_loss:.4f} | Energy: {avg_energy:.4f}")

        # Adaptive learning rate: back off when over the simulated energy budget.
        if avg_energy > energy_budget:
            current_lr *= 0.9
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr


# Training Approach: VIPT (Quantized Post-training)

def train_vipt(model, images, labels):
    """Train ``model`` in place with the Occam loss, then fake-quantize it (VIPT).

    Each step works as follows:
    1. Perturb the trainable weights with Gaussian noise (std 0.01), a cheap
       stand-in for sampling weights.
    2. Back-propagate ``loss_occam`` at the perturbed point.
    3. Restore the unperturbed weights.
    4. Apply the Adam update to the restored weights.

    After the final epoch the weights are rounded with
    ``quantize_model_weights(model, bits=8)``.

    Args:
        model: Network to train and quantize; updated in place.
        images: Input tensor indexed along dim 0 by sample.
        labels: Target tensor aligned with ``images`` along dim 0.

    Raises:
        ValueError: If ``images`` is empty.
    """
    n_samples = len(images)
    if n_samples == 0:
        raise ValueError("train_vipt needs at least one training sample")
    # Actual number of (possibly partial) batches per epoch, used for averaging.
    n_batches = len(range(0, n_samples, batch_size))

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    trainable = [p for p in model.parameters() if p.requires_grad]

    for epoch in range(epochs):
        model.train()
        indices = torch.randperm(n_samples)
        total_loss = 0.0

        for i in range(0, n_samples, batch_size):
            batch_idx = indices[i:i + batch_size]
            x = images[batch_idx].to(device)
            y = labels[batch_idx].to(device)

            # Weight-sampling simulation. Keep a copy so the noise does not
            # accumulate in the weights from one step to the next.
            saved = [p.detach().clone() for p in trainable]
            with torch.no_grad():
                for p in trainable:
                    p.add_(0.01 * torch.randn_like(p))

            optimizer.zero_grad()
            out = model(x)
            loss = loss_occam(out, y, model, lambda_reg)
            loss.backward()

            # Gradients were taken at the sampled weights; step from the mean.
            with torch.no_grad():
                for p, original in zip(trainable, saved):
                    p.copy_(original)
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / n_batches
        print(f"[VIPT] Epoch {epoch+1} | Loss: {avg_loss:.4f}")

    print("[VIPT] Quantizing model to 8-bit")
    quantize_model_weights(model, bits=8)


# Main Runner

def main():
    """Train a VIET model and a VIPT model (both ViT-Small, 10 classes) on one synthetic dataset."""
    # ``load_data`` is this file's synthetic-data helper. Both models train on
    # the same images and labels.
    images, labels = load_data()

    print("***Training VIET (Energy-Aware)***")
    model_viet = create_model('vit_small_patch16_224', pretrained=False, num_classes=10).to(device)
    train_viet(model_viet, images, labels)

    print("\nTraining VIPT (Quantization after training)...")
    model_vipt = create_model('vit_small_patch16_224', pretrained=False, num_classes=10).to(device)
    train_vipt(model_vipt, images, labels)

    print("Done.")


# Run the comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()
