import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import timm
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
import time
import threading
from collections import defaultdict

from device import EdgeDevice, DeviceManager
from partitioning import PartitionSelector, ModelProfiler, LayerProfile, PartitionProfile

class AttentionSummarizer(nn.Module):
    """Re-weight, bottleneck and optionally sub-select the last axis of a tensor.

    The module owns a learnable per-channel weight vector (softmax-normalised
    before use) and a two-layer bottleneck MLP
    ``input_dim -> int(input_dim * compression_ratio) -> input_dim``.

    Args:
        input_dim: Size of the last axis of the tensors passed to ``forward``.
        compression_ratio: Fraction used both for the bottleneck width and
            for the number of channels kept when an attention matrix is given.
    """

    def __init__(self, input_dim: int, compression_ratio: float = 0.3):
        super().__init__()
        self.compression_ratio = compression_ratio
        self.compressed_dim = int(input_dim * compression_ratio)

        self.compressor = nn.Sequential(
            nn.Linear(input_dim, self.compressed_dim),
            nn.ReLU(),
            nn.Linear(self.compressed_dim, input_dim)
        )

        self.attention_weights = nn.Parameter(torch.ones(input_dim))

    def forward(self, x: torch.Tensor,
                attention_matrix: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compress ``x`` and, if attention is supplied, keep the top channels.

        Args:
            x: Tensor whose last axis has ``input_dim`` entries.
            attention_matrix: Optional tensor of any rank >= 1. It is averaged
                over every axis except the last to obtain one importance
                score per position of its last axis.

        Returns:
            The bottleneck output with the same shape as ``x`` when
            ``attention_matrix`` is ``None``; otherwise that output indexed
            along its last axis by the top-k importance positions.

        NOTE(review): the selected positions come from the last axis of
        ``attention_matrix`` but index the last axis of ``x``; this assumes
        the two axes line up -- TODO confirm against the caller.
        """
        channel_weights = torch.softmax(self.attention_weights, dim=0)
        compressed = self.compressor(x * channel_weights)

        if attention_matrix is None:
            return compressed

        importance_scores = attention_matrix
        if importance_scores.dim() > 1:
            # Collapse every leading axis so the scores are always 1-D.
            # A fixed dim=(0, 1) left a 2-D result for (batch, heads, N, N)
            # input, which produced 2-D indices and an extra output axis.
            leading_axes = tuple(range(importance_scores.dim() - 1))
            importance_scores = torch.mean(importance_scores, dim=leading_axes)

        # topk raises if k exceeds the number of available scores.
        top_k = min(int(x.size(-1) * self.compression_ratio),
                    importance_scores.numel())
        _, top_indices = torch.topk(importance_scores, top_k)

        # Index tensors must live on the same device as the indexed tensor.
        return compressed[..., top_indices.to(compressed.device)]

class ElasticViTLayer(nn.Module):
    """Wrap a layer and track a reduced head count / width for it.

    Args:
        original_layer: The module that actually performs the computation.
        max_heads: Upper bound for ``current_heads``.
        max_dim: Upper bound for ``current_dim``.
    """

    # Lower bound applied to ``current_dim`` (never above ``max_dim``).
    _MIN_DIM = 64

    def __init__(self, original_layer: nn.Module, max_heads: int, max_dim: int):
        super().__init__()
        self.original_layer = original_layer
        self.max_heads = max_heads
        self.max_dim = max_dim
        self.current_heads = max_heads
        self.current_dim = max_dim

    def adjust_complexity(self, energy_ratio: float):
        """Scale ``current_heads`` / ``current_dim`` by ``energy_ratio``.

        ``energy_ratio`` is clamped to ``[0, 1]`` so the layer can never be
        configured above its maxima, and ``current_dim`` is capped at
        ``max_dim`` even when ``max_dim`` is below the 64-wide floor.
        If the wrapped layer exposes ``num_heads`` / ``head_dim`` attributes
        they are overwritten with the new values.
        """
        ratio = min(1.0, max(0.0, energy_ratio))

        self.current_heads = max(1, int(self.max_heads * ratio))
        self.current_dim = min(
            self.max_dim, max(self._MIN_DIM, int(self.max_dim * ratio))
        )

        # Modify the attention mechanism if it exists
        if hasattr(self.original_layer, 'num_heads'):
            self.original_layer.num_heads = self.current_heads

        if hasattr(self.original_layer, 'head_dim'):
            self.original_layer.head_dim = self.current_dim // self.current_heads

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run the wrapped layer, resizing the last axis of ``x`` if needed.

        Resizing only happens when the wrapped layer has a ``qkv`` attribute
        and the last axis of ``x`` differs from ``current_dim``: wider input
        is truncated, narrower input is zero-padded on the right.

        NOTE(review): this presumes the wrapped layer accepts
        ``current_dim``-wide input -- confirm for the layers being wrapped.
        """
        if x.size(-1) != self.current_dim and hasattr(self.original_layer, 'qkv'):
            if x.size(-1) > self.current_dim:
                x = x[..., :self.current_dim]
            else:
                padding = self.current_dim - x.size(-1)
                x = F.pad(x, (0, padding))

        return self.original_layer(x, **kwargs)

class SubNetworkWrapper(nn.Module):
    """Sequentially run the group of layers assigned to one device.

    Args:
        layers: Modules to execute, in order.
        device_id: Identifier used in error messages.
        config: Configuration object; ``config.partition.compression_ratio``
            is read only when the first layer has an ``embed_dim`` attribute.
    """

    def __init__(self, layers: List[nn.Module], device_id: int, config):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.device_id = device_id
        self.config = config

        # A summarizer is only built when the first layer advertises its width.
        wants_summarizer = bool(layers) and hasattr(layers[0], 'embed_dim')
        self.summarizer = (
            AttentionSummarizer(
                getattr(layers[0], 'embed_dim', 768),
                config.partition.compression_ratio,
            )
            if wants_summarizer
            else None
        )

    def forward(self, x: torch.Tensor, cache_intermediate: bool = True) -> Dict[str, torch.Tensor]:
        """Feed ``x`` through every layer and collect the results.

        Returns a dict holding the last activation under ``'final'`` and,
        when ``cache_intermediate`` is true, a detached CPU copy of each
        layer's output under ``'layer_<index>'``. If a layer raises, the
        error is printed, the remaining layers are skipped, and ``'final'``
        holds the activation reached so far.
        """
        results = {}
        collected_attention = []

        for idx, layer in enumerate(self.layers):
            try:
                # Only non-elastic, attention-like layers that expose
                # forward_with_attention take the attention-returning path.
                returns_attention = (
                    not isinstance(layer, ElasticViTLayer)
                    and (hasattr(layer, 'attention')
                         or 'attn' in str(type(layer)).lower())
                    and hasattr(layer, 'forward_with_attention')
                )
                if returns_attention:
                    x, attn_map = layer.forward_with_attention(x)
                    collected_attention.append(attn_map)
                else:
                    x = layer(x)

                if cache_intermediate:
                    results[f'layer_{idx}'] = x.detach().cpu()

            except Exception as e:
                print(f"Error in layer {idx} on device {self.device_id}: {e}")
                break

        # Summarize only when a summarizer exists and attention was collected.
        if self.summarizer is not None and collected_attention:
            mean_attention = torch.mean(torch.stack(collected_attention), dim=0)
            x = self.summarizer(x, mean_attention)

        results['final'] = x
        return results

class KnowledgeTransfer:
    """Project feature tensors between widths and move them between devices.

    Projection matrices are cached in ``transfer_matrices`` under the key
    ``"<from_dim>_<to_dim>"``.
    """

    def __init__(self, config):
        self.config = config
        self.transfer_matrices = {}

    def create_transfer_matrix(self, from_dim: int, to_dim: int, device: torch.device) -> torch.Tensor:
        """Return the cached ``(to_dim, from_dim)`` matrix on ``device``.

        On first use the matrix is the identity when both sizes match,
        otherwise a random matrix with L2-normalised rows. The cache key
        does not include the device, so the cached tensor is moved to the
        requested device on the way out; previously a matrix created for one
        device was returned unchanged for any other.
        """
        key = f"{from_dim}_{to_dim}"
        matrix = self.transfer_matrices.get(key)
        if matrix is None:
            if from_dim == to_dim:
                matrix = torch.eye(from_dim, device=device)
            else:
                matrix = torch.randn(to_dim, from_dim, device=device)
                matrix = F.normalize(matrix, dim=1)
            self.transfer_matrices[key] = matrix

        # Tensor.to is a no-op when the tensor already lives on ``device``.
        return matrix.to(device)

    def transfer_features(self, features: torch.Tensor, from_device: int,
                         to_device: int, target_dim: int) -> torch.Tensor:
        """Move ``features`` to the target device, projecting if widths differ.

        Args:
            features: Tensor whose last axis is the feature width.
            from_device: Unused; kept for interface compatibility.
            to_device: CUDA ordinal of the destination; ignored (CPU is
                used) when CUDA is unavailable.
            target_dim: Required size of the last axis of the result.
        """
        device = torch.device(f'cuda:{to_device}' if torch.cuda.is_available() else 'cpu')
        features = features.to(device)

        current_dim = features.size(-1)
        if current_dim == target_dim:
            return features

        transfer_matrix = self.create_transfer_matrix(current_dim, target_dim, device)
        if features.is_floating_point():
            # matmul requires matching dtypes (e.g. half or double features).
            transfer_matrix = transfer_matrix.to(dtype=features.dtype)
        return torch.matmul(features, transfer_matrix.T)

class LAPDTRFramework:
    def __init__(self, config, devices: List[EdgeDevice]):
        self.config = config
        self.devices = devices
        self.device_manager = DeviceManager(config)
        self.device_manager.devices = devices
        
        self.partition_selector = PartitionSelector(devices, config)
        self.model_profiler = ModelProfiler()
        self.knowledge_transfer = KnowledgeTransfer(config)
        
        self.partitioned_models: Dict[int, SubNetworkWrapper] = {}
        self.current_partition: Optional[PartitionProfile] = None
        self.monitoring_active = False
        self.performance_metrics = defaultdict(list)
        
    def load_vit_model(self, model_name: str = None) -> nn.Module:
        if model_name is None:
            model_name = self.config.model.model_name
        
        try:
            # Load pre-trained ViT model using timm
            model = timm.create_model(
                model_name,
                pretrained=True,
                num_classes=self.config.model.num_classes,
                img_size=self.config.model.img_size
            )
            
            # Make layers elastic
            model = self._make_model_elastic(model)
            
            return model
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            # Fallback to basic ViT
            return self._create_basic_vit()
    
    def _make_model_elastic(self, model: nn.Module) -> nn.Module:
        for name, module in model.named_modules():
            if hasattr(module, 'num_heads') or 'attn' in name.lower():
                # Replace with elastic version
                max_heads = getattr(module, 'num_heads', self.config.model.max_attention_heads)
                max_dim = getattr(module, 'embed_dim', self.config.model.max_hidden_dim)
                
                elastic_layer = ElasticViTLayer(module, max_heads, max_dim)
                
                # Replace in model
                parent = model
                names = name.split('.')
                for n in names[:-1]:
                    parent = getattr(parent, n)
                setattr(parent, names[-1], elastic_layer)
        
        return model
    
    def _create_basic_vit(self) -> nn.Module:
        # Create a basic ViT as fallback
        from timm.models.vision_transformer import VisionTransformer
        
        model = VisionTransformer(
            img_size=self.config.model.img_size,
            patch_size=self.config.model.patch_size,
            embed_dim=self.config.model.embed_dim,
            depth=self.config.model.depth,
            num_heads=self.config.model.num_heads,
            num_classes=self.config.model.num_classes
        )
        
        return self._make_model_elastic(model)
    
    def partition_model(self, model: nn.Module, strategy: str = None) -> Dict[int, SubNetworkWrapper]:
        if strategy:
            self.config.partition.strategy = strategy
        
        # Profile the model
        input_shape = (1, 3, self.config.model.img_size, self.config.model.img_size)
        layer_profiles = self.model_profiler.profile_vit_model(model, input_shape)
        
        # Define requirements
        requirements = {
            'accuracy': self.config.required_accuracy,
            'energy': self.config.required_energy,
            'latency': self.config.required_latency
        }
        
        # Select best partition
        self.current_partition = self.partition_selector.select_best_partition(
            layer_profiles, requirements
        )
        
        # Create sub-networks for each device
        self.partitioned_models = self._create_sub_networks(model, self.current_partition)
        
        return self.partitioned_models
    
    def _create_sub_networks(self, model: nn.Module, partition: PartitionProfile) -> Dict[int, SubNetworkWrapper]:
        # Group layers by device assignment
        device_layers = defaultdict(list)
        
        model_layers = list(model.modules())[1:]  # Skip the model itself
        
        for layer_id, device_id in partition.device_assignments.items():
            if layer_id < len(model_layers):
                device_layers[device_id].append(model_layers[layer_id])
        
        # Create sub-network wrappers
        sub_networks = {}
        for device_id, layers in device_layers.items():
            if layers:  # Only create if there are layers assigned
                sub_network = SubNetworkWrapper(layers, device_id, self.config)
                
                # Move to appropriate device
                device = self.devices[device_id].torch_device
                sub_network = sub_network.to(device)
                
                sub_networks[device_id] = sub_network
        
        return sub_networks
    
    def dynamic_redistribute(self, performance_metrics: Dict) -> bool:
        if not self.current_partition:
            return False
        
        # Check if redistribution is needed
        need_redistribution = False
        
        for device in self.devices:
            device_id = device.capabilities.device_id
            
            # Check energy threshold
            if not device.is_energy_sufficient(self.config.device.energy_threshold):
                need_redistribution = True
                break
            
            # Check performance degradation
            if device_id in performance_metrics:
                metrics = performance_metrics[device_id]
                if (metrics.get('latency', 0) > self.config.required_latency or
                    metrics.get('utilization', 0) > 0.9):
                    need_redistribution = True
                    break
        
        if need_redistribution:
            return self._perform_redistribution()
        
        return False
    
    def _perform_redistribution(self) -> bool:
        try:
            # Get current layer assignments
            current_assignments = self.current_partition.device_assignments.copy()
            
            # Find overloaded or low-energy devices
            problem_devices = []
            available_devices = []
            
            for device in self.devices:
                device_id = device.capabilities.device_id
                
                if (not device.is_energy_sufficient(self.config.device.energy_threshold) or
                    device.capabilities.utilization > 0.9):
                    problem_devices.append(device_id)
                elif device.is_energy_sufficient(0.3):  # Has reasonable energy
                    available_devices.append(device_id)
            
            # Redistribute layers from problem devices
            redistributed = False
            for problem_device_id in problem_devices:
                # Find layers assigned to this device
                layers_to_move = [layer_id for layer_id, device_id in current_assignments.items()
                                if device_id == problem_device_id]
                
                # Move layers to available devices
                for i, layer_id in enumerate(layers_to_move):
                    if available_devices:
                        target_device = available_devices[i % len(available_devices)]
                        current_assignments[layer_id] = target_device
                        redistributed = True
            
            if redistributed:
                # Update partition profile
                self.current_partition.device_assignments = current_assignments
                
                # Update device assignments
                for device_id in problem_devices:
                    if device_id in self.partitioned_models:
                        device = self.devices[device_id]
                        device.assigned_layers.clear()
                
                for layer_id, device_id in current_assignments.items():
                    device = self.devices[device_id]
                    device.assign_layer(layer_id)
                
                return True
        
        except Exception as e:
            print(f"Error during redistribution: {e}")
        
        return False
    
    def adjust_layer_complexity(self, device_id: int, energy_ratio: float):
        if device_id in self.partitioned_models:
            sub_network = self.partitioned_models[device_id]
            
            for layer in sub_network.layers:
                if isinstance(layer, ElasticViTLayer):
                    layer.adjust_complexity(energy_ratio)
    
    def start_monitoring(self):
        self.monitoring_active = True
        monitor_thread = threading.Thread(target=self._monitoring_loop, daemon=True)
        monitor_thread.start()
    
    def stop_monitoring(self):
        """Ask the monitoring loop to stop.

        Only clears the ``monitoring_active`` flag and returns immediately.
        ``_monitoring_loop`` tests the flag at the top of each iteration, so
        a running loop exits once its current iteration and sleep finish.
        """
        self.monitoring_active = False
    
    def _monitoring_loop(self):
        while self.monitoring_active:
            try:
                # Update device states
                self.device_manager.update_device_states()
                
                # Collect performance metrics
                current_metrics = {}
                for device in self.devices:
                    device_id = device.capabilities.device_id
                    metrics = {
                        'energy_ratio': device.get_energy_ratio(),
                        'utilization': device.capabilities.utilization,
                        'temperature': device.capabilities.temperature,
                        'assigned_layers': len(device.assigned_layers)
                    }
                    current_metrics[device_id] = metrics
                    
                    # Adjust layer complexity based on energy
                    energy_ratio = device.get_energy_ratio()
                    if energy_ratio < 0.5:  # Low energy threshold
                        self.adjust_layer_complexity(device_id, energy_ratio)
                
                # Check if redistribution is needed
                self.dynamic_redistribute(current_metrics)
                
                # Store metrics
                self.performance_metrics['timestamp'].append(time.time())
                for device_id, metrics in current_metrics.items():
                    for metric_name, value in metrics.items():
                        self.performance_metrics[f'device_{device_id}_{metric_name}'].append(value)
                
                time.sleep(1)  # Monitor every second
                
            except Exception as e:
                print(f"Error in monitoring loop: {e}")
                time.sleep(5)  # Wait longer on error
    
    def get_partition_summary(self) -> Dict:
        if not self.current_partition:
            return {}
        
        summary = {
            'strategy': self.current_partition.strategy,
            'total_devices': len(self.devices),
            'active_devices': len([d for d in self.devices if d.is_energy_sufficient()]),
            'energy_consumption': self.current_partition.energy_consumption,
            'estimated_latency': self.current_partition.latency,
            'communication_cost': self.current_partition.communication_cost,
            'device_assignments': {}
        }
        
        # Device assignment details
        for device_id, layers in self.current_partition.device_assignments.items():
            device = self.devices[device_id]
            summary['device_assignments'][device_id] = {
                'layers': layers if isinstance(layers, list) else [layers],
                'energy_ratio': device.get_energy_ratio(),
                'compute_capacity': device.capabilities.compute_capacity,
                'memory_capacity': device.capabilities.memory_capacity
            }
        
        return summary
    
    def save_partition_state(self, filepath: str):
        state = {
            'partition_profile': self.current_partition.__dict__ if self.current_partition else None,
            'device_states': [device.capabilities.__dict__ for device in self.devices],
            'performance_metrics': dict(self.performance_metrics),
            'config': self.config.__dict__
        }
        
        torch.save(state, filepath)
    
    def load_partition_state(self, filepath: str):
        state = torch.load(filepath)
        
        # Restore device states
        for i, device_state in enumerate(state['device_states']):
            if i < len(self.devices):
                for key, value in device_state.items():
                    if hasattr(self.devices[i].capabilities, key):
                        setattr(self.devices[i].capabilities, key, value)
        
        # Restore performance metrics
        self.performance_metrics = defaultdict(list, state['performance_metrics'])