import hashlib
import threading
import time
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor, Future
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Any, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lap_dtr import LAPDTRFramework, SubNetworkWrapper
from utils import MetricsTracker, EnergyMonitor, CommunicationTracker

@dataclass
class InferenceResult:
    """Outcome of one inference call together with the metrics recorded for it."""

    # Output of the distributed forward pass; None when no stage produced a
    # 'final' output.
    predictions: Optional[torch.Tensor]
    # Softmax of ``predictions`` over the last dimension (empty tensor when
    # there are no predictions).
    confidence_scores: torch.Tensor
    # Wall-clock duration of the call in seconds, measured with time.time().
    latency: float
    # Energy reported by the energy monitor for this call; 0.0 for cache hits
    # and failure results.
    energy_consumed: float
    # Amount of data reported by the communication tracker for this call
    # (presumably bytes -- confirm against CommunicationTracker).
    communication_overhead: float
    # Per-device statistics keyed by device id. The producer in this module
    # stores a dict of named ratios per device, hence ``Any`` rather than float.
    device_utilization: Dict[int, Any]
    # Raw per-device stage outputs keyed by device id; None when they were not
    # requested or not available.
    intermediate_results: Optional[Dict[int, Any]] = None

class InferenceCache:
    """Lock-protected, size-bounded tensor cache with least-recently-used eviction.

    Tensors are cloned on the way in and on the way out, so callers can never
    mutate a stored entry through a reference they hold.
    """

    def __init__(self, max_size: int = 1000):
        """Create an empty cache.

        Args:
            max_size: Maximum number of entries. A value of zero or less
                disables storage entirely.
        """
        self.cache = {}
        self.access_times = {}
        self.max_size = max_size
        self.lock = threading.Lock()

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Return a clone of the tensor stored under ``key``, or None if absent.

        A hit refreshes the entry's access time so it is evicted later.
        """
        with self.lock:
            if key in self.cache:
                self.access_times[key] = time.time()
                return self.cache[key].clone()
            return None

    def put(self, key: str, value: torch.Tensor):
        """Store a clone of ``value`` under ``key``.

        When the cache is full and ``key`` is new, the entry with the oldest
        access time is evicted first. Overwriting an existing key never evicts.
        """
        with self.lock:
            if self.max_size <= 0:
                # Nothing can be stored; also avoids min() on an empty dict.
                return

            # Only a *new* key grows the cache, so only then is eviction needed.
            if key not in self.cache and len(self.cache) >= self.max_size:
                lru_key = min(self.access_times, key=self.access_times.get)
                del self.cache[lru_key]
                del self.access_times[lru_key]

            self.cache[key] = value.clone()
            self.access_times[key] = time.time()

    def clear(self):
        """Remove every entry and its access-time record."""
        with self.lock:
            self.cache.clear()
            self.access_times.clear()

class FaultToleranceManager:
    """Tracks which devices have failed and asks the framework to re-home their layers."""

    def __init__(self, framework: LAPDTRFramework):
        """Start with no failed devices and a limit of three recovery attempts per device."""
        self.max_recovery_attempts = 3
        self.recovery_attempts = defaultdict(int)
        self.backup_routes = {}
        self.failed_devices = set()
        self.framework = framework

    def mark_device_failed(self, device_id: int):
        """Record a failure of ``device_id`` and redistribute its layers while attempts remain."""
        self.failed_devices.add(device_id)
        attempts_so_far = self.recovery_attempts[device_id] + 1
        self.recovery_attempts[device_id] = attempts_so_far

        # Past the attempt limit the failure is only recorded.
        if attempts_so_far > self.max_recovery_attempts:
            return
        self._redistribute_failed_device(device_id)

    def _redistribute_failed_device(self, failed_device_id: int):
        """Ask the device manager to move the failed device's layers and remember the mapping.

        Nothing is done when every other device is either failed or reports
        insufficient energy.
        """
        healthy_devices = []
        for candidate in self.framework.devices:
            if candidate.capabilities.device_id in self.failed_devices:
                continue
            if candidate.is_energy_sufficient():
                healthy_devices.append(candidate)

        if len(healthy_devices) == 0:
            print(f"No available devices for redistribution from device {failed_device_id}")
            return

        manager = self.framework.device_manager
        redistribution_map = manager.redistribute_failed_device_layers(failed_device_id)

        # An empty or None mapping means nothing was moved, so nothing is stored.
        if not redistribution_map:
            return
        print(f"Redistributed {len(redistribution_map)} layers from failed device {failed_device_id}")
        self.backup_routes[failed_device_id] = redistribution_map

    def is_device_available(self, device_id: int) -> bool:
        """Return True unless ``device_id`` is currently marked as failed."""
        return not (device_id in self.failed_devices)

    def reset_device(self, device_id: int):
        """Clear the failed mark and the attempt counter of a previously failed device."""
        if device_id not in self.failed_devices:
            return
        self.failed_devices.discard(device_id)
        self.recovery_attempts[device_id] = 0

class DistributedInference:
    def __init__(self, framework: LAPDTRFramework, config):
        self.framework = framework
        self.config = config
        self.devices = framework.devices
        
        # Inference components
        self.cache = InferenceCache()
        self.fault_manager = FaultToleranceManager(framework)
        self.metrics_tracker = MetricsTracker()
        self.energy_monitor = EnergyMonitor(self.devices)
        self.comm_tracker = CommunicationTracker()
        
        # Performance optimization
        self.batch_queue = deque()
        self.batch_processor = None
        self.processing_active = False
        
        # Thread pool for parallel execution
        self.executor = ThreadPoolExecutor(max_workers=len(self.devices))
        
        self._setup_inference_pipeline()
    
    def _setup_inference_pipeline(self):
        # Set all models to evaluation mode
        for sub_network in self.framework.partitioned_models.values():
            sub_network.eval()
        
        # Enable inference optimizations
        torch.backends.cudnn.benchmark = True
        if hasattr(torch.backends.cudnn, 'allow_tf32'):
            torch.backends.cudnn.allow_tf32 = True
    
    def run_inference(self, input_data: Union[torch.Tensor, List[torch.Tensor]], 
                     batch_size: Optional[int] = None) -> Union[InferenceResult, List[InferenceResult]]:
        if isinstance(input_data, list):
            return self._batch_inference(input_data, batch_size)
        else:
            return self._single_inference(input_data)
    
    def _single_inference(self, input_tensor: torch.Tensor) -> InferenceResult:
        start_time = time.time()
        
        # Reset monitoring for this inference
        self.energy_monitor.reset_inference_counters()
        self.comm_tracker.reset_inference_counters()
        
        try:
            # Check cache first
            cache_key = self._generate_cache_key(input_tensor)
            cached_result = self.cache.get(cache_key)
            
            if cached_result is not None and self.config.inference.cache_intermediate:
                latency = time.time() - start_time
                return InferenceResult(
                    predictions=cached_result,
                    confidence_scores=torch.softmax(cached_result, dim=-1),
                    latency=latency,
                    energy_consumed=0.0,  # Cached result
                    communication_overhead=0.0,
                    device_utilization={}
                )
            
            # Run distributed inference
            predictions, intermediate_results = self._distributed_inference_forward(input_tensor)
            
            # Calculate confidence scores
            confidence_scores = torch.softmax(predictions, dim=-1) if predictions is not None else torch.tensor([])
            
            # Cache result if enabled
            if self.config.inference.cache_intermediate and predictions is not None:
                self.cache.put(cache_key, predictions)
            
            # Calculate metrics
            latency = time.time() - start_time
            energy_consumed = self.energy_monitor.get_inference_energy()
            comm_overhead = self.comm_tracker.get_inference_data_transferred()
            device_util = self._calculate_device_utilization()
            
            return InferenceResult(
                predictions=predictions,
                confidence_scores=confidence_scores,
                latency=latency,
                energy_consumed=energy_consumed,
                communication_overhead=comm_overhead,
                device_utilization=device_util,
                intermediate_results=intermediate_results if self.config.inference.cache_intermediate else None
            )
        
        except Exception as e:
            print(f"Error during inference: {e}")
            return self._handle_inference_failure(input_tensor, start_time)
    
    def _distributed_inference_forward(self, input_tensor: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        intermediate_results = {}
        current_input = input_tensor
        final_output = None
        
        # Get inference execution plan
        execution_plan = self._create_execution_plan()
        
        for stage in execution_plan:
            device_id = stage['device_id']
            sub_network = stage['sub_network']
            device = self.framework.devices[device_id]
            
            # Check device availability
            if not self.fault_manager.is_device_available(device_id):
                continue
            
            if not device.is_energy_sufficient():
                # Adjust layer complexity for low energy
                energy_ratio = device.get_energy_ratio()
                self.framework.adjust_layer_complexity(device_id, energy_ratio)
            
            try:
                # Transfer input to device
                current_input = self._transfer_to_device(current_input, device.torch_device)
                
                # Monitor energy before computation
                pre_computation_energy = device.capabilities.current_energy
                
                # Forward pass with timeout
                with torch.no_grad():
                    if self.config.mixed_precision:
                        with torch.cuda.amp.autocast():
                            sub_outputs = sub_network(current_input, cache_intermediate=True)
                    else:
                        sub_outputs = sub_network(current_input, cache_intermediate=True)
                
                # Store intermediate results
                intermediate_results[device_id] = sub_outputs
                
                # Update current input for next stage
                if 'final' in sub_outputs:
                    current_input = sub_outputs['final']
                    final_output = current_input
                
                # Track energy consumption
                post_computation_energy = device.capabilities.current_energy
                energy_consumed = pre_computation_energy - post_computation_energy
                self.energy_monitor.update_device_energy(device_id, energy_consumed)
                
                # Track communication if not last stage
                if stage != execution_plan[-1]:
                    next_device_id = execution_plan[execution_plan.index(stage) + 1]['device_id']
                    data_size = current_input.numel() * current_input.element_size()
                    self.comm_tracker.track_transfer(device_id, next_device_id, data_size)
                    
                    # Knowledge transfer
                    current_input = self.framework.knowledge_transfer.transfer_features(
                        current_input, device_id, next_device_id, current_input.size(-1)
                    )
            
            except Exception as e:
                print(f"Error on device {device_id}: {e}")
                
                if self.config.inference.fault_tolerance:
                    self.fault_manager.mark_device_failed(device_id)
                    # Try to continue with cached intermediate results
                    cached_result = device.get_cached_result(stage.get('layer_id', -1))
                    if cached_result is not None:
                        current_input = cached_result.to(current_input.device)
                        print(f"Recovered using cached result from device {device_id}")
                else:
                    raise e
        
        # Ensure final output has correct shape
        if final_output is not None and final_output.dim() > 2:
            final_output = torch.mean(final_output, dim=1)  # Global average pooling
        
        return final_output, intermediate_results
    
    def _create_execution_plan(self) -> List[Dict]:
        plan = []
        
        # Sort devices by layer assignment order
        if self.framework.current_partition:
            device_order = sorted(self.framework.partitioned_models.keys())
            
            for device_id in device_order:
                if device_id in self.framework.partitioned_models:
                    plan.append({
                        'device_id': device_id,
                        'sub_network': self.framework.partitioned_models[device_id],
                        'priority': len(self.framework.devices[device_id].assigned_layers)
                    })
        
        return plan
    
    def _transfer_to_device(self, tensor: torch.Tensor, target_device: torch.device) -> torch.Tensor:
        if tensor.device != target_device:
            start_time = time.time()
            tensor = tensor.to(target_device, non_blocking=True)
            transfer_time = time.time() - start_time
            
            # Track transfer time
            data_size = tensor.numel() * tensor.element_size()
            self.comm_tracker.add_transfer_time(transfer_time, data_size)
        
        return tensor
    
    def _generate_cache_key(self, tensor: torch.Tensor) -> str:
        # Generate a hash-based key for caching
        tensor_hash = hash(tensor.detach().cpu().numpy().tobytes())
        shape_str = '_'.join(map(str, tensor.shape))
        return f"{tensor_hash}_{shape_str}"
    
    def _batch_inference(self, input_list: List[torch.Tensor], 
                        batch_size: Optional[int] = None) -> List[InferenceResult]:
        if batch_size is None:
            batch_size = self.config.inference.batch_size
        
        results = []
        
        # Process in batches
        for i in range(0, len(input_list), batch_size):
            batch = input_list[i:i + batch_size]
            
            # Stack batch tensors
            try:
                batch_tensor = torch.stack(batch)
            except:
                # Handle different sized tensors
                batch_tensor = torch.cat([t.unsqueeze(0) for t in batch], dim=0)
            
            # Run inference on batch
            batch_result = self._single_inference(batch_tensor)
            
            # Split results back to individual samples
            if batch_result.predictions is not None:
                for j in range(len(batch)):
                    individual_result = InferenceResult(
                        predictions=batch_result.predictions[j:j+1],
                        confidence_scores=batch_result.confidence_scores[j:j+1],
                        latency=batch_result.latency / len(batch),
                        energy_consumed=batch_result.energy_consumed / len(batch),
                        communication_overhead=batch_result.communication_overhead / len(batch),
                        device_utilization=batch_result.device_utilization
                    )
                    results.append(individual_result)
            else:
                # Handle failed inference
                for j in range(len(batch)):
                    results.append(self._create_failure_result())
        
        return results
    
    def _handle_inference_failure(self, input_tensor: torch.Tensor, start_time: float) -> InferenceResult:
        latency = time.time() - start_time
        
        # Try fallback strategies
        if self.config.inference.fault_tolerance:
            # Try using cached intermediate results
            for device in self.framework.devices:
                if device.intermediate_cache:
                    # Use the most recent cached result
                    cached_keys = sorted(device.intermediate_cache.keys())
                    if cached_keys:
                        cached_result = device.intermediate_cache[cached_keys[-1]]
                        
                        # Create a basic prediction from cached result
                        if cached_result.dim() > 2:
                            cached_result = torch.mean(cached_result, dim=1)
                        
                        if cached_result.size(-1) != self.config.model.num_classes:
                            # Add a simple classifier
                            classifier = nn.Linear(cached_result.size(-1), self.config.model.num_classes)
                            cached_result = classifier(cached_result)
                        
                        return InferenceResult(
                            predictions=cached_result,
                            confidence_scores=torch.softmax(cached_result, dim=-1),
                            latency=latency,
                            energy_consumed=0.0,
                            communication_overhead=0.0,
                            device_utilization={}
                        )
        
        return self._create_failure_result(latency)
    
    def _create_failure_result(self, latency: float = 0.0) -> InferenceResult:
        # Return empty result for failed inference
        return InferenceResult(
            predictions=torch.tensor([]),
            confidence_scores=torch.tensor([]),
            latency=latency,
            energy_consumed=0.0,
            communication_overhead=0.0,
            device_utilization={}
        )
    
    def _calculate_device_utilization(self) -> Dict[int, float]:
        utilization = {}
        
        for device in self.framework.devices:
            device_id = device.capabilities.device_id
            utilization[device_id] = {
                'compute_utilization': device.capabilities.utilization,
                'memory_utilization': len(device.assigned_layers) / max(1, len(self.framework.current_partition.device_assignments)) if self.framework.current_partition else 0,
                'energy_ratio': device.get_energy_ratio()
            }
        
        return utilization
    
    def benchmark_inference(self, test_data: List[torch.Tensor], 
                           num_runs: int = 100) -> Dict[str, Any]:
        latencies = []
        energy_consumptions = []
        accuracies = []
        communication_overheads = []
        
        print(f"Running inference benchmark with {num_runs} iterations...")
        
        for i in range(num_runs):
            # Select random test sample
            test_sample = test_data[i % len(test_data)]
            
            # Run inference
            result = self._single_inference(test_sample)
            
            latencies.append(result.latency)
            energy_consumptions.append(result.energy_consumed)
            communication_overheads.append(result.communication_overhead)
            
            # Clear cache periodically to simulate realistic conditions
            if i % 20 == 0:
                self.cache.clear()
        
        # Calculate statistics
        benchmark_results = {
            'latency': {
                'mean': np.mean(latencies),
                'std': np.std(latencies),
                'min': np.min(latencies),
                'max': np.max(latencies),
                'p95': np.percentile(latencies, 95),
                'p99': np.percentile(latencies, 99)
            },
            'energy_consumption': {
                'mean': np.mean(energy_consumptions),
                'std': np.std(energy_consumptions),
                'total': np.sum(energy_consumptions)
            },
            'communication_overhead': {
                'mean': np.mean(communication_overheads),
                'std': np.std(communication_overheads),
                'total': np.sum(communication_overheads)
            },
            'throughput': {
                'samples_per_second': num_runs / np.sum(latencies)
            },
            'device_utilization': self._get_average_device_utilization()
        }
        
        return benchmark_results
    
    def _get_average_device_utilization(self) -> Dict[str, float]:
        total_util = 0.0
        total_energy = 0.0
        active_devices = 0
        
        for device in self.framework.devices:
            if self.fault_manager.is_device_available(device.capabilities.device_id):
                total_util += device.capabilities.utilization
                total_energy += device.get_energy_ratio()
                active_devices += 1
        
        if active_devices == 0:
            return {'avg_utilization': 0.0, 'avg_energy_ratio': 0.0}
        
        return {
            'avg_utilization': total_util / active_devices,
            'avg_energy_ratio': total_energy / active_devices,
            'active_devices': active_devices,
            'total_devices': len(self.framework.devices)
        }
    
    def start_continuous_inference(self, input_stream):
        self.processing_active = True
        
        def process_stream():
            while self.processing_active:
                try:
                    if hasattr(input_stream, '__next__'):
                        input_data = next(input_stream)
                        result = self._single_inference(input_data)
                        yield result
                    else:
                        time.sleep(0.01)  # Small delay to prevent busy waiting
                except StopIteration:
                    break
                except Exception as e:
                    print(f"Error in continuous inference: {e}")
                    time.sleep(0.1)
        
        return process_stream()
    
    def stop_continuous_inference(self) -> None:
        """Ask a running continuous-inference generator to stop.

        Only clears the ``processing_active`` flag; the generator returned by
        ``start_continuous_inference`` checks it at the top of each loop
        iteration and ends once it sees the flag cleared.
        """
        self.processing_active = False
    
    def cleanup(self) -> None:
        """Stop continuous inference, shut down the thread pool and empty the result cache."""
        self.stop_continuous_inference()
        # wait=True: block until the executor has finished shutting down.
        self.executor.shutdown(wait=True)
        self.cache.clear()