import torch
import numpy as np
import psutil
import time
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from threading import Lock

try:
    import pynvml
    NVML_AVAILABLE = True
except ImportError:
    NVML_AVAILABLE = False

@dataclass
class DeviceCapabilities:
    """Capacity figures and mutable runtime state describing one device.

    ``device_id``, ``compute_capacity``, ``memory_capacity`` and
    ``current_energy`` are required; every other field has a default.
    The units quoted in the field comments are conventions only -- nothing
    in this class validates or converts them.
    """
    device_id: int  # Identifier; EdgeDevice also uses it as the CUDA index for 'gpu' devices
    compute_capacity: float  # GFLOPs
    memory_capacity: float  # GB
    current_energy: float  # Fraction of max energy (0-1)
    max_energy: float = 1.0  # Denominator used by EdgeDevice.get_energy_ratio()
    bandwidth: float = 100.0  # Mbps
    device_type: str = 'gpu'  # 'gpu' or anything else (treated as CPU by EdgeDevice)

    # Dynamic states
    current_workload: float = 0.0  # NOTE(review): not written anywhere in this file -- confirm who owns it
    temperature: float = 0.0  # Copied from the monitor's 'temperature' metric (0 when absent)
    utilization: float = 0.0  # Copied from the monitor's gpu/cpu utilization metric (fraction)
    is_available: bool = True  # Devices with False are excluded by DeviceManager.get_available_devices()
    last_update: float = field(default_factory=time.time)  # time.time() stamp; refreshed by EdgeDevice.update_energy()

class EdgeDevice:
    """One compute node that hosts model layers and tracks its own state.

    The instance wraps a ``DeviceCapabilities`` record and adds the mutable
    bookkeeping around it: the layers assigned to the device, a cache of
    intermediate tensors, and a bounded history of performance metrics.
    All access to that bookkeeping goes through ``self.lock``.
    """

    # Maximum number of metric records kept in ``performance_history``.
    MAX_HISTORY = 100

    def __init__(self, capabilities: "DeviceCapabilities"):
        """Bind the device to ``capabilities`` and pick its torch device.

        Args:
            capabilities: Capacity/state record for this device. It is stored
                by reference and mutated by ``update_energy``.
        """
        self.capabilities = capabilities
        self.assigned_layers: List[int] = []
        self.intermediate_cache: Dict = {}
        self.performance_history: List[Dict] = []
        self.lock = Lock()

        # Only target CUDA when the requested index really exists; a
        # 'cuda:N' device for a missing GPU would fail later, on first use.
        use_cuda = (
            capabilities.device_type == 'gpu'
            and torch.cuda.is_available()
            and 0 <= capabilities.device_id < torch.cuda.device_count()
        )
        if use_cuda:
            self.torch_device = torch.device(f'cuda:{capabilities.device_id}')
        else:
            self.torch_device = torch.device('cpu')

    def update_energy(self, energy_consumption: float):
        """Subtract ``energy_consumption`` from the remaining energy (floored at 0)."""
        with self.lock:
            remaining = self.capabilities.current_energy - energy_consumption
            self.capabilities.current_energy = max(0.0, remaining)
            self.capabilities.last_update = time.time()

    def get_energy_ratio(self) -> float:
        """Return ``current_energy / max_energy``.

        A non-positive ``max_energy`` yields ``0.0`` instead of raising
        ``ZeroDivisionError``, so such a device simply reads as depleted.
        """
        max_energy = self.capabilities.max_energy
        if max_energy <= 0:
            return 0.0
        return self.capabilities.current_energy / max_energy

    def is_energy_sufficient(self, threshold: float = 0.1) -> bool:
        """Return True when the energy ratio is strictly above ``threshold``."""
        return self.get_energy_ratio() > threshold

    @staticmethod
    def _load_ratio(required: float, capacity: float) -> float:
        """Return ``required / capacity``, treating a non-positive capacity as unusable."""
        if capacity > 0:
            return required / capacity
        return float('inf') if required > 0 else 0.0

    def calculate_layer_cost(self, layer_complexity: Dict) -> float:
        """Return the larger of the layer's compute and memory load ratios.

        Args:
            layer_complexity: Mapping with optional ``'flops'`` and
                ``'memory'`` entries; missing entries count as 0.

        Returns:
            ``max(flops / compute_capacity, memory / memory_capacity)``.
            A resource with non-positive capacity contributes ``inf`` when
            the layer needs any of it, and ``0.0`` otherwise.
        """
        compute_cost = self._load_ratio(
            layer_complexity.get('flops', 0), self.capabilities.compute_capacity)
        memory_cost = self._load_ratio(
            layer_complexity.get('memory', 0), self.capabilities.memory_capacity)
        return max(compute_cost, memory_cost)

    def can_handle_layer(self, layer_complexity: Dict) -> bool:
        """Return True when the layer fits both capacities and energy is sufficient."""
        required_compute = layer_complexity.get('flops', 0)
        required_memory = layer_complexity.get('memory', 0)

        return (required_compute <= self.capabilities.compute_capacity and
                required_memory <= self.capabilities.memory_capacity and
                self.is_energy_sufficient())

    def assign_layer(self, layer_id: int):
        """Record ``layer_id`` as assigned to this device (no duplicates)."""
        with self.lock:
            if layer_id not in self.assigned_layers:
                self.assigned_layers.append(layer_id)

    def remove_layer(self, layer_id: int):
        """Drop ``layer_id`` from this device; unknown ids are ignored."""
        with self.lock:
            if layer_id in self.assigned_layers:
                self.assigned_layers.remove(layer_id)

    def cache_intermediate_result(self, layer_id: int, result: torch.Tensor):
        """Store a detached CPU copy of ``result`` under ``layer_id``."""
        with self.lock:
            self.intermediate_cache[layer_id] = result.detach().cpu()

    def get_cached_result(self, layer_id: int) -> Optional[torch.Tensor]:
        """Return the cached tensor for ``layer_id``, or None when absent."""
        # Readers take the same lock as the writers above.
        with self.lock:
            return self.intermediate_cache.get(layer_id)

    def clear_cache(self):
        """Remove every cached intermediate result."""
        with self.lock:
            self.intermediate_cache.clear()

    def update_performance_metrics(self, metrics: Dict):
        """Append a timestamped copy of ``metrics`` to the bounded history.

        The caller's dict is left untouched; the ``'timestamp'`` key is added
        to the stored copy only. At most ``MAX_HISTORY`` records are kept.
        """
        record = dict(metrics)
        record['timestamp'] = time.time()
        with self.lock:
            self.performance_history.append(record)
            overflow = len(self.performance_history) - self.MAX_HISTORY
            if overflow > 0:
                # Keep only recent history
                del self.performance_history[:overflow]

    def get_average_performance(self, window: int = 10) -> Dict:
        """Average each metric over the last ``window`` history records.

        Args:
            window: Number of most recent records to consider. Values
                ``<= 0`` select nothing (a plain ``[-0:]`` slice would
                select the *entire* history).

        Returns:
            Mapping of metric name to mean value, excluding ``'timestamp'``.
            A metric is averaged over the records that actually contain it,
            so records with differing keys do not raise. Metrics whose
            values cannot be averaged by ``np.mean`` are omitted. Empty
            dict when there is nothing to average.
        """
        if window <= 0:
            return {}

        with self.lock:
            recent_metrics = self.performance_history[-window:]

        samples: Dict = {}
        for record in recent_metrics:
            for key, value in record.items():
                if key != 'timestamp':
                    samples.setdefault(key, []).append(value)

        avg_metrics = {}
        for key, values in samples.items():
            try:
                avg_metrics[key] = np.mean(values)
            except (TypeError, ValueError):
                # Non-numeric or ragged values: skip this metric only.
                continue
        return avg_metrics

class DeviceMonitor:
    """Reads live hardware metrics from NVML (GPU) and psutil (host CPU/RAM).

    NVML is optional: when ``pynvml`` is missing or fails to initialise,
    ``nvml_initialized`` is False and ``get_gpu_metrics`` returns an empty
    dict.
    """

    def __init__(self):
        self.monitoring_active = False
        self.nvml_initialized = False
        if NVML_AVAILABLE:
            try:
                pynvml.nvmlInit()
                self.nvml_initialized = True
            except Exception as e:
                # Best effort: a missing driver or GPU must not break
                # construction, but say why GPU metrics will be unavailable.
                # ``Exception`` (not a bare except) lets KeyboardInterrupt
                # and SystemExit propagate.
                print(f"NVML initialization failed: {e}")

    def get_gpu_metrics(self, device_id: int) -> Dict:
        """Return memory, utilization, temperature and power for one GPU.

        Args:
            device_id: NVML device index.

        Returns:
            Dict with ``memory_used``/``memory_total`` (GB),
            ``memory_utilization`` and ``gpu_utilization`` (fractions),
            ``temperature`` (as reported by NVML) and ``power_usage``
            (Watts). Empty when NVML is not initialised. If a query fails
            part-way, the error is printed and the metrics gathered so far
            are returned.
        """
        metrics = {}
        if not self.nvml_initialized:
            return metrics

        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)

            # Memory info
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            metrics['memory_used'] = mem_info.used / (1024**3)  # GB
            metrics['memory_total'] = mem_info.total / (1024**3)  # GB
            metrics['memory_utilization'] = (
                mem_info.used / mem_info.total if mem_info.total else 0.0)

            # Utilization
            util = pynvml.nvmlDeviceGetUtilizationRates(handle)
            metrics['gpu_utilization'] = util.gpu / 100.0

            # Temperature
            temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
            metrics['temperature'] = temp

            # Power (NVML reports milliwatts)
            power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0  # Watts
            metrics['power_usage'] = power

        except Exception as e:
            print(f"Error getting GPU metrics for device {device_id}: {e}")

        return metrics

    def get_cpu_metrics(self) -> Dict:
        """Return host CPU utilization and RAM usage.

        Returns:
            Dict with ``cpu_utilization`` and ``memory_utilization``
            (fractions) plus ``memory_used``/``memory_total`` (GB).
        """
        # One snapshot, so used/total/percent describe the same instant.
        vm = psutil.virtual_memory()
        return {
            # interval=None is non-blocking: it measures since the previous
            # call (per the psutil docs the very first call returns 0.0).
            'cpu_utilization': psutil.cpu_percent(interval=None) / 100.0,
            'memory_utilization': vm.percent / 100.0,
            'memory_used': vm.used / (1024**3),  # GB
            'memory_total': vm.total / (1024**3)  # GB
        }

class DeviceManager:
    """Owns the list of edge devices and coordinates them as a network.

    Responsibilities: build a (simulated) device network, look devices up,
    refresh their utilization/temperature from the monitor, move layers off
    a failed device, and report aggregate statistics.
    """

    def __init__(self, config=None):
        """Create an empty manager.

        Args:
            config: Optional object exposing ``config.device.energy_range``,
                ``compute_range``, ``memory_range`` and ``bandwidth_range``
                as ``(low, high)`` pairs. Built-in ranges are used when it
                is falsy.
        """
        self.devices: List[EdgeDevice] = []
        self.monitor = DeviceMonitor()
        self.config = config

    def create_device_network(self, num_devices: int = 4) -> List["EdgeDevice"]:
        """Replace the current devices with ``num_devices`` freshly generated ones."""
        # Simulate diverse device capabilities
        self.devices = [
            EdgeDevice(self._generate_device_capabilities(i))
            for i in range(num_devices)
        ]
        return self.devices

    def _generate_device_capabilities(self, device_id: int) -> "DeviceCapabilities":
        """Draw random capabilities for ``device_id`` from the configured ranges."""
        if self.config:
            energy_range = self.config.device.energy_range
            compute_range = self.config.device.compute_range
            memory_range = self.config.device.memory_range
            bandwidth_range = self.config.device.bandwidth_range
        else:
            energy_range = (0.5, 1.0)
            compute_range = (1.0, 8.0)
            memory_range = (2.0, 16.0)
            bandwidth_range = (10.0, 100.0)

        # A device is a GPU only when a physical CUDA device backs its index.
        has_gpu = torch.cuda.is_available() and device_id < torch.cuda.device_count()

        return DeviceCapabilities(
            device_id=device_id,
            compute_capacity=np.random.uniform(*compute_range),
            memory_capacity=np.random.uniform(*memory_range),
            current_energy=np.random.uniform(*energy_range),
            bandwidth=np.random.uniform(*bandwidth_range),
            device_type='gpu' if has_gpu else 'cpu'
        )

    def get_device_by_id(self, device_id: int) -> Optional["EdgeDevice"]:
        """Return the first device whose id equals ``device_id``, or None."""
        return next(
            (d for d in self.devices if d.capabilities.device_id == device_id),
            None,
        )

    def get_available_devices(self, energy_threshold: float = 0.1) -> List["EdgeDevice"]:
        """Return devices flagged available whose energy ratio exceeds the threshold."""
        return [device for device in self.devices
                if device.capabilities.is_available and
                device.is_energy_sufficient(energy_threshold)]

    def update_device_states(self):
        """Refresh utilization and temperature of every device from the monitor.

        Devices for which the monitor returns no metrics are left unchanged.
        """
        for device in self.devices:
            if device.capabilities.device_type == 'gpu':
                metrics = self.monitor.get_gpu_metrics(device.capabilities.device_id)
            else:
                metrics = self.monitor.get_cpu_metrics()

            if metrics:
                device.capabilities.utilization = metrics.get(
                    'gpu_utilization', metrics.get('cpu_utilization', 0))
                device.capabilities.temperature = metrics.get('temperature', 0)
                device.update_performance_metrics(metrics)

    def redistribute_failed_device_layers(self, failed_device_id: int) -> Dict[int, int]:
        """Move every layer of a failed device onto the least-loaded healthy devices.

        The failed device is never a candidate target, even if it still
        reports itself as available. Its ``is_available`` flag is not
        modified here.

        Args:
            failed_device_id: Id of the device whose layers must be moved.

        Returns:
            Mapping ``layer_id -> new device_id``. Empty when the device is
            unknown or no other available device exists (in which case the
            layers stay where they are).
        """
        failed_device = self.get_device_by_id(failed_device_id)
        if failed_device is None:
            return {}

        candidates = [device for device in self.get_available_devices()
                      if device is not failed_device]
        if not candidates:
            return {}

        redistribution_map = {}
        # Iterate over a copy: remove_layer() mutates the list being walked.
        for layer_id in list(failed_device.assigned_layers):
            # Pick the candidate currently holding the fewest layers.
            target = min(candidates, key=lambda d: len(d.assigned_layers))

            failed_device.remove_layer(layer_id)
            target.assign_layer(layer_id)
            redistribution_map[layer_id] = target.capabilities.device_id

        return redistribution_map

    def get_network_statistics(self) -> Dict:
        """Return aggregate capacity, energy, utilization and availability figures.

        For an empty network every aggregate is 0 (``average_utilization``
        is ``0.0`` rather than the ``nan`` that ``np.mean([])`` produces).
        """
        total_devices = len(self.devices)
        available_devices = len(self.get_available_devices())

        total_compute = sum(d.capabilities.compute_capacity for d in self.devices)
        total_memory = sum(d.capabilities.memory_capacity for d in self.devices)
        total_energy = sum(d.capabilities.current_energy for d in self.devices)

        if self.devices:
            avg_utilization = np.mean([d.capabilities.utilization for d in self.devices])
        else:
            avg_utilization = 0.0

        return {
            'total_devices': total_devices,
            'available_devices': available_devices,
            'total_compute_capacity': total_compute,
            'total_memory_capacity': total_memory,
            'total_remaining_energy': total_energy,
            'average_utilization': avg_utilization,
            'network_efficiency': available_devices / total_devices if total_devices > 0 else 0
        }