import torch
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class ModelConfig:
    """Architecture hyperparameters of the model.

    Plain value holder: no field is validated or cross-checked here (for
    example, nothing enforces that ``embed_dim`` is divisible by
    ``num_heads``).
    """
    # Model identifier; presumably a timm-style name -- TODO confirm with the model factory.
    model_name: str = 'vit_base_patch16_224'
    # Size of the classification output.
    num_classes: int = 1000
    # Input image side length in pixels (assumes square inputs -- TODO confirm).
    img_size: int = 224
    # Patch side length in pixels.
    patch_size: int = 16
    # Token embedding width.
    embed_dim: int = 768
    # Number of transformer blocks.
    depth: int = 12
    # Attention heads per block.
    num_heads: int = 12
    # NOTE(review): the two ``max_*`` fields default to the same values as
    # ``num_heads`` / ``embed_dim``; presumably upper bounds used when the
    # model is resized for a device -- verify against the consumer.
    max_attention_heads: int = 12
    max_hidden_dim: int = 768

@dataclass
class DeviceConfig:
    """Settings describing the pool of devices.

    Each ``*_range`` field is a 2-tuple of floats; presumably ``(min, max)``
    bounds for per-device capabilities -- TODO confirm against the code that
    reads them. Nothing in this class checks the ordering of the two values.
    """
    # Number of devices; Config.validate_config() rejects values below 1.
    num_devices: int = 4
    energy_range: tuple = (0.5, 1.0)  # Energy as fraction of max
    compute_range: tuple = (1.0, 8.0)  # GFLOPs
    memory_range: tuple = (2.0, 16.0)  # GB
    bandwidth_range: tuple = (10.0, 100.0)  # Mbps
    # Presumably expressed in the same "fraction of max" unit as energy_range -- verify.
    energy_threshold: float = 0.1  # Minimum energy threshold

@dataclass
class PartitionConfig:
    """Settings that control how the model is partitioned.

    ``strategy`` is the only field checked elsewhere in this file:
    Config.validate_config() rejects any value other than the three listed
    below. The layer bounds are not cross-checked here (nothing enforces
    ``min_layers_per_device <= max_layers_per_device``).
    """
    strategy: str = 'hybrid'  # 'meta', 'ensembled', 'hybrid'
    lambda_energy: float = 0.5  # Energy weight in redistribution
    compression_ratio: float = 0.3  # Feature compression ratio
    # Upper and lower bound on the number of layers placed on one device.
    max_layers_per_device: int = 6
    min_layers_per_device: int = 1

@dataclass
class TrainingConfig:
    """Optimisation hyperparameters and multi-objective weights.

    The three ``weight_*`` fields are exposed through
    Config.get_partition_weights() and must sum to 1.0 (within 1e-6) for
    Config.validate_config() to pass; the defaults satisfy that.
    """
    batch_size: int = 32
    learning_rate: float = 1e-4
    num_epochs: int = 100
    weight_decay: float = 0.01
    warmup_epochs: int = 10
    # Optimizer / scheduler are selected by name; the set of accepted names is
    # defined by whatever consumes this config and is not checked here.
    optimizer: str = 'adamw'
    scheduler: str = 'cosine'

    # Weights for multi-objective optimization
    weight_accuracy: float = 0.5
    weight_energy: float = 0.3
    weight_latency: float = 0.2

@dataclass
class InferenceConfig:
    """Inference-time settings.

    Plain value holder; the boolean fields are feature switches whose effect
    is defined by the code that reads them, not by this class.
    """
    batch_size: int = 1
    max_latency: float = 1.0  # seconds
    # Feature toggles (all enabled by default).
    cache_intermediate: bool = True
    fault_tolerance: bool = True
    dynamic_adjustment: bool = True

@dataclass
class MonitoringConfig:
    """Logging, checkpointing and experiment-tracking settings.

    Paths are stored as given (relative by default); this class neither
    resolves nor creates them.
    """
    # NOTE(review): the unit of both intervals (steps vs. epochs) is not
    # visible in this file -- confirm against the training loop.
    log_interval: int = 10
    save_interval: int = 100
    # Output locations.
    tensorboard_dir: str = 'runs'
    checkpoint_dir: str = 'checkpoints'
    metrics_file: str = 'metrics.json'
    # Weights & Biases tracking is off by default; wandb_project names the project to use.
    use_wandb: bool = False
    wandb_project: str = 'lap-dtr'

class Config:
    """Top-level container for all LAP-DTR settings.

    Holds one dataclass instance per area (``model``, ``device``,
    ``partition``, ``training``, ``inference``, ``monitoring``) plus a few
    global scalar settings, and provides helpers to override values from a
    plain dictionary and to sanity-check the result.
    """

    def __init__(self):
        """Populate every section and every global setting with its default."""
        self.model = ModelConfig()
        self.device = DeviceConfig()
        self.partition = PartitionConfig()
        self.training = TrainingConfig()
        self.inference = InferenceConfig()
        self.monitoring = MonitoringConfig()

        # Global settings
        self.seed = 42
        cuda_available = torch.cuda.is_available()
        self.device_type = 'cuda' if cuda_available else 'cpu'
        self.mixed_precision = True
        # NCCL only works with CUDA devices; fall back to Gloo on CPU-only
        # hosts so the default backend matches the selected device type.
        self.distributed_backend = 'nccl' if cuda_available else 'gloo'

        # Dataset paths
        self.data_root = './data'
        self.dataset = 'imagenet'

        # Requirements validation
        self.required_accuracy = 0.70  # Minimum acceptable accuracy
        self.required_energy = 10.0  # Maximum energy consumption (Wh)
        self.required_latency = 2.0  # Maximum latency (seconds)

    def get_partition_weights(self) -> Dict[str, float]:
        """Return the multi-objective weights keyed by objective name.

        Returns:
            A new dict with the keys ``'accuracy'``, ``'energy'`` and
            ``'latency'`` taken from the ``training`` section.
        """
        training = self.training
        return {
            'accuracy': training.weight_accuracy,
            'energy': training.weight_energy,
            'latency': training.weight_latency,
        }

    def update_from_dict(self, config_dict: Dict) -> None:
        """Override settings in place from a (possibly nested) dictionary.

        Two entry shapes are understood:

        * ``{'training': {'batch_size': 64}}`` -- a dict value updates the
          matching fields of the named section; keys the section does not
          already have are ignored.
        * ``{'seed': 7}`` -- a non-dict value replaces the global setting of
          that name.

        Names that do not exist on this object are ignored, so a dictionary
        carrying extra entries can be applied safely.

        Args:
            config_dict: Mapping of setting or section names to new values.

        Raises:
            TypeError: If a non-dict value targets one of the dataclass
                sections, which would otherwise replace the whole section.
        """
        from dataclasses import is_dataclass

        for section, values in config_dict.items():
            if not hasattr(self, section):
                continue  # unknown names are skipped (best-effort override)
            current = getattr(self, section)

            if isinstance(values, dict):
                # Nested override: only touch fields the section already has.
                for key, value in values.items():
                    if hasattr(current, key):
                        setattr(current, key, value)
                continue

            if is_dataclass(current):
                raise TypeError(
                    f"Section '{section}' expects a dict of field overrides, "
                    f"got {type(values).__name__}"
                )
            # Only plain instance attributes may be replaced; this keeps
            # methods and other class-level names from being overwritten.
            if section in vars(self):
                setattr(self, section, values)

    def validate_config(self) -> bool:
        """Check the settings that have hard constraints.

        Returns:
            ``True`` when every check passes.

        Raises:
            ValueError: If there are fewer than one device, the partitioning
                strategy is not one of ``'meta'``, ``'ensembled'`` or
                ``'hybrid'``, or the objective weights do not sum to 1.0
                (within 1e-6).
        """
        # Validate device configuration
        if self.device.num_devices < 1:
            raise ValueError("Number of devices must be at least 1")

        # Validate partition configuration
        if self.partition.strategy not in ('meta', 'ensembled', 'hybrid'):
            raise ValueError("Invalid partitioning strategy")

        # Validate weights sum to 1
        total_weight = sum(self.get_partition_weights().values())
        if abs(total_weight - 1.0) > 1e-6:
            raise ValueError("Partition weights must sum to 1.0")

        return True

    def __str__(self) -> str:
        """Return a short multi-line summary; every line ends with a newline."""
        summary_lines = (
            "LAP-DTR Configuration:",
            f"Model: {self.model.model_name}",
            f"Devices: {self.device.num_devices}",
            f"Strategy: {self.partition.strategy}",
            f"Batch Size: {self.training.batch_size}",
        )
        return "".join(f"{line}\n" for line in summary_lines)