import torch
import torch.nn as nn
import numpy as np
from typing import Dict, List, Tuple, Optional
from abc import ABC, abstractmethod
from dataclasses import dataclass
from device import EdgeDevice, DeviceCapabilities

@dataclass
class LayerProfile:
    """Static cost estimate for a single profiled layer.

    Built by ``ModelProfiler._profile_layer`` and consumed by the partitioning
    strategies, which key their device assignments on ``layer_id``.
    """
    # Sequential id assigned by the profiler in module traversal order.
    layer_id: int
    # Class name of the profiled module (``type(module).__name__``).
    layer_type: str
    # Shape the layer was costed against (the shape handed to the profiler).
    input_shape: Tuple[int, ...]
    # The profiler currently copies ``input_shape`` here (simplified assumption).
    output_shape: Tuple[int, ...]
    # Number of parameter elements counted for the module.
    parameters: int
    # Estimated operation count from ``ModelProfiler._estimate_flops``.
    flops: float
    # Parameter storage in GB as computed by the profiler (4 bytes per element).
    memory_usage: float
    # ``num_heads`` of the module when flagged as attention, otherwise 0.
    attention_heads: int = 0
    # ``embed_dim`` of the module when flagged as attention, otherwise 0.
    embedding_dim: int = 0
    # True for nn.MultiheadAttention or when the module name contains "attn".
    is_attention: bool = False

@dataclass
class PartitionProfile:
    """Outcome of running one partitioning strategy over a list of layers."""
    # Label of the producing strategy: 'meta', 'ensembled' or 'hybrid'.
    strategy: str
    # Mapping of layer_id -> device_id chosen for that layer.
    device_assignments: Dict[int, int]  # layer_id -> device_id
    # Sum of the per-layer energy costs charged while partitioning.
    energy_consumption: float
    # max(communication cost, busiest device's total compute time).
    latency: float
    # Strategies set this to 0.0; meant to be filled in after evaluation.
    accuracy: float
    # Total transfer time between consecutive layers on different devices.
    communication_cost: float
    # Weighted score written by ``PartitionSelector._score_candidates``.
    score: float = 0.0

class ModelProfiler:
    """Produce rough, analytical per-layer cost profiles for a PyTorch model.

    Nothing is executed: parameters are counted from the modules and FLOPs
    come from closed-form estimates.  Every layer is costed against the one
    ``input_shape`` passed to ``profile_vit_model`` rather than its real
    activation shape, so the numbers are relative indicators only.
    """

    def __init__(self):
        # Result of the most recent profile_vit_model() call.
        self.layer_profiles: List[LayerProfile] = []

    def profile_vit_model(self, model: nn.Module, input_shape: Tuple[int, ...]) -> List[LayerProfile]:
        """Profile every computation layer found in ``model``.

        Args:
            model: Module tree walked with ``named_modules()``.
            input_shape: Shape used as the input of every profiled layer.

        Returns:
            One ``LayerProfile`` per computation layer in traversal order;
            ``layer_id`` equals the profile's index in the returned list.
            The list is also stored on ``self.layer_profiles``.
        """
        profiles = []

        for name, module in model.named_modules():
            if self._is_computation_layer(module):
                # Ids are handed out densely, so the next id is the list length.
                profiles.append(self._profile_layer(module, len(profiles), name, input_shape))

        self.layer_profiles = profiles
        return profiles

    def _is_computation_layer(self, module: nn.Module) -> bool:
        """Return True when ``module`` is one of the layer types that get a profile."""
        computation_layers = (
            nn.Linear, nn.Conv2d, nn.MultiheadAttention,
            nn.LayerNorm, nn.GELU, nn.ReLU
        )
        return isinstance(module, computation_layers)

    def _count_parameters(self, module: nn.Module) -> int:
        """Count the parameter elements attributed to ``module``'s own profile.

        Parameters of descendants that are computation layers themselves are
        left out: ``named_modules()`` yields those descendants too (for
        example the ``out_proj`` Linear inside ``nn.MultiheadAttention``), so
        they get their own profile and would otherwise be counted twice.
        """
        total = sum(p.numel() for p in module.parameters(recurse=False))
        for child in module.children():
            if not self._is_computation_layer(child):
                total += self._count_parameters(child)
        return total

    def _profile_layer(self, module: nn.Module, layer_id: int, name: str,
                      input_shape: Tuple[int, ...]) -> LayerProfile:
        """Build the ``LayerProfile`` for one module.

        Args:
            module: The layer to profile.
            layer_id: Id to store on the profile.
            name: Qualified module name; used only for the "attn" heuristic.
            input_shape: Shape the FLOP estimate is based on.
        """
        parameters = self._count_parameters(module)
        flops = self._estimate_flops(module, input_shape)

        # Parameter storage only (activations are not modelled):
        # 4 bytes per float32 element, expressed in GB.
        memory_usage = parameters * 4 / (1024**3)

        # A layer counts as attention by type or by an "attn" substring in its name.
        is_attention = isinstance(module, nn.MultiheadAttention) or 'attn' in name.lower()
        attention_heads = getattr(module, 'num_heads', 0) if is_attention else 0
        embedding_dim = getattr(module, 'embed_dim', 0) if is_attention else 0

        return LayerProfile(
            layer_id=layer_id,
            layer_type=type(module).__name__,
            input_shape=input_shape,
            output_shape=input_shape,  # Simplified assumption
            parameters=parameters,
            flops=flops,
            memory_usage=memory_usage,
            attention_heads=attention_heads,
            embedding_dim=embedding_dim,
            is_attention=is_attention
        )

    def _estimate_flops(self, module: nn.Module, input_shape: Tuple[int, ...]) -> float:
        """Estimate the operation count of ``module`` for ``input_shape``.

        Always returns a plain Python ``float``.  Shape products are
        accumulated in float64 so large shapes cannot overflow an integer
        dtype.

        Raises:
            ValueError: For ``nn.Conv2d`` when ``input_shape`` has fewer than
                two dimensions.
        """
        if isinstance(module, nn.Linear):
            flops = np.prod(input_shape, dtype=np.float64) * module.out_features
        elif isinstance(module, nn.Conv2d):
            # The last two dimensions are taken as the spatial size.
            h, w = input_shape[-2:]
            flops = (h * w * module.in_channels * module.out_channels *
                     module.kernel_size[0] * module.kernel_size[1])
        elif isinstance(module, nn.MultiheadAttention):
            seq_len = input_shape[1] if len(input_shape) > 1 else 1
            # Factor 3 keeps the original Q, K, V accounting.
            flops = seq_len * seq_len * module.embed_dim * 3
        else:
            flops = np.prod(input_shape, dtype=np.float64) * 2  # Default estimate
        return float(flops)

class PartitioningStrategy(ABC):
    """Base class for strategies that map model layers onto edge devices.

    Device ids stored in assignments are used directly as indices into
    ``self.devices``.
    """

    def __init__(self, devices: List[EdgeDevice], config=None):
        """Store the candidate devices and optional configuration.

        Args:
            devices: Devices layers may be placed on.
            config: Optional configuration object, kept as-is for subclasses.
        """
        self.devices = devices
        self.config = config
        self.profiler = ModelProfiler()

    @abstractmethod
    def partition(self, layer_profiles: List[LayerProfile]) -> PartitionProfile:
        """Assign every layer in ``layer_profiles`` to a device."""
        pass

    def calculate_communication_cost(self, assignments: Dict[int, int],
                                   layer_profiles: List[LayerProfile]) -> float:
        """Sum the transfer time between consecutive layers on different devices.

        Consecutive means adjacent in ``layer_profiles``.  Each layer's device
        is looked up by its ``layer_id`` (the key used in ``assignments``), so
        the result is also correct for lists whose ids are not ``0..n-1``.

        Args:
            assignments: layer_id -> device_id; missing layers default to
                device 0.
            layer_profiles: Layers in execution order.

        Returns:
            Total transfer time (output size in MB divided by the slower of
            the two devices' ``bandwidth`` values).
        """
        total_cost = 0.0

        for current, following in zip(layer_profiles, layer_profiles[1:]):
            current_device = assignments.get(current.layer_id, 0)
            next_device = assignments.get(following.layer_id, 0)

            if current_device == next_device:
                continue

            # Size of the data crossing the device boundary:
            # 4 bytes per element, expressed in MB.
            output_size = float(np.prod(current.output_shape, dtype=np.float64)) * 4 / (1024**2)
            # The slower endpoint limits the transfer.
            bandwidth = min(self.devices[current_device].capabilities.bandwidth,
                            self.devices[next_device].capabilities.bandwidth)
            total_cost += output_size / bandwidth

        return total_cost

class MetaPartitioning(PartitioningStrategy):
    """Greedy strategy that places every layer independently on its best device."""

    def partition(self, layer_profiles: List[LayerProfile]) -> PartitionProfile:
        """Assign each layer to the lowest-scoring device and charge its energy.

        Side effect: ``update_energy`` is called on the chosen device for every
        layer, so device state changes while partitioning.

        Returns:
            A ``PartitionProfile`` labelled ``'meta'``; ``accuracy`` is left
            at 0.0 for later evaluation.
        """
        assignments = {}
        total_energy = 0.0

        for profile in layer_profiles:
            best_device_id = self._find_best_device_for_layer(profile)
            assignments[profile.layer_id] = best_device_id

            # Charge the layer's energy to the device before placing the next
            # layer, so later choices see the updated device state.
            device = self.devices[best_device_id]
            energy_cost = self._calculate_layer_energy_cost(profile, device)
            device.update_energy(energy_cost)
            total_energy += energy_cost

        communication_cost = self.calculate_communication_cost(assignments, layer_profiles)

        return PartitionProfile(
            strategy='meta',
            device_assignments=assignments,
            energy_consumption=total_energy,
            latency=max(communication_cost, self._estimate_computation_latency(assignments, layer_profiles)),
            accuracy=0.0,  # To be updated after evaluation
            communication_cost=communication_cost
        )

    def _find_best_device_for_layer(self, profile: LayerProfile) -> int:
        """Return the device id with the lowest score for ``profile``.

        The score is ``compute_ratio + memory_ratio + (1 - energy_ratio)``;
        lower is better.  Devices that report insufficient energy or cannot
        handle the layer are skipped.  When no device qualifies the result
        falls back to device id 0.
        """
        best_device_id = 0
        best_score = float('inf')

        for device in self.devices:
            if not device.is_energy_sufficient():
                continue

            energy_ratio = device.get_energy_ratio()
            compute_ratio = profile.flops / device.capabilities.compute_capacity
            memory_ratio = profile.memory_usage / device.capabilities.memory_capacity

            score = compute_ratio + memory_ratio + (1.0 - energy_ratio)

            if score < best_score and device.can_handle_layer({
                'flops': profile.flops,
                'memory': profile.memory_usage
            }):
                best_score = score
                best_device_id = device.capabilities.device_id

        return best_device_id

    def _calculate_layer_energy_cost(self, profile: LayerProfile, device: EdgeDevice) -> float:
        """Return the simplified energy cost of running ``profile`` on ``device``.

        The cost scales with compute time (factor 0.01) and is capped at 10%
        of the device's current energy.
        """
        base_energy = profile.flops / device.capabilities.compute_capacity * 0.01
        return min(base_energy, device.capabilities.current_energy * 0.1)

    def _estimate_computation_latency(self, assignments: Dict[int, int],
                                    layer_profiles: List[LayerProfile]) -> float:
        """Return the total compute time of the busiest device.

        Profiles are resolved by ``layer_id`` (the key used in
        ``assignments``) rather than by list position, so the method also
        works when ``layer_profiles`` is a subset whose ids are not
        ``0..n-1``.

        Raises:
            KeyError: If ``assignments`` references a layer id that is not in
                ``layer_profiles``.
        """
        profiles_by_id = {p.layer_id: p for p in layer_profiles}
        device_workloads = {}

        for layer_id, device_id in assignments.items():
            capacity = self.devices[device_id].capabilities.compute_capacity
            computation_time = profiles_by_id[layer_id].flops / capacity
            device_workloads[device_id] = device_workloads.get(device_id, 0.0) + computation_time

        return max(device_workloads.values(), default=0.0)

class EnsembledPartitioning(PartitioningStrategy):
    """Strategy that keeps runs of similar consecutive layers on one device."""

    def partition(self, layer_profiles: List[LayerProfile]) -> PartitionProfile:
        """Group similar consecutive layers and place each group on one device.

        Side effect: ``update_energy`` is called once per group on the chosen
        device with the group's summed energy cost.

        Returns:
            A ``PartitionProfile`` labelled ``'ensembled'``; ``accuracy`` is
            left at 0.0.
        """
        layer_groups = self._group_similar_layers(layer_profiles)
        assignments = {}
        total_energy = 0.0

        for group_layers in layer_groups:
            best_device_id = self._find_best_device_for_group(group_layers)

            for profile in group_layers:
                assignments[profile.layer_id] = best_device_id

            # Charge the whole group's energy in one update.
            device = self.devices[best_device_id]
            group_energy = sum(self._calculate_layer_energy_cost(p, device) for p in group_layers)
            device.update_energy(group_energy)
            total_energy += group_energy

        communication_cost = self.calculate_communication_cost(assignments, layer_profiles)

        return PartitionProfile(
            strategy='ensembled',
            device_assignments=assignments,
            energy_consumption=total_energy,
            latency=max(communication_cost, self._estimate_computation_latency(assignments, layer_profiles)),
            accuracy=0.0,
            communication_cost=communication_cost
        )

    def _group_similar_layers(self, layer_profiles: List[LayerProfile]) -> List[List[LayerProfile]]:
        """Split ``layer_profiles`` into runs of consecutive similar layers.

        Each layer is compared with the first layer of the current run; a
        dissimilar layer starts a new run.  Order is preserved.
        """
        groups = []

        for profile in layer_profiles:
            if groups and self._is_similar_layer(profile, groups[-1][0]):
                groups[-1].append(profile)
            else:
                groups.append([profile])

        return groups

    @staticmethod
    def _similarity_ratio(a: float, b: float) -> float:
        """Return min/max of two values, treating a 0/0 pair as identical (1.0)."""
        larger = max(a, b)
        if larger == 0:
            # Both values are zero (e.g. parameter-free activations): equal cost.
            return 1.0
        return min(a, b) / larger

    def _is_similar_layer(self, profile1: LayerProfile, profile2: LayerProfile) -> bool:
        """Return True when two layers share a type and have comparable cost.

        Comparable means the smaller FLOP count and the smaller memory usage
        are each more than 70% of the larger one.
        """
        if profile1.layer_type != profile2.layer_type:
            return False

        flops_ratio = self._similarity_ratio(profile1.flops, profile2.flops)
        memory_ratio = self._similarity_ratio(profile1.memory_usage, profile2.memory_usage)

        return flops_ratio > 0.7 and memory_ratio > 0.7

    def _find_best_device_for_group(self, group_layers: List[LayerProfile]) -> int:
        """Return the device id with the lowest score for the whole group.

        The group's summed FLOPs and memory are scored with
        ``compute_ratio + memory_ratio + (1 - energy_ratio)``; lower is
        better.  Devices that cannot handle the combined load are skipped, and
        device id 0 is returned when none qualifies.
        """
        total_flops = sum(p.flops for p in group_layers)
        total_memory = sum(p.memory_usage for p in group_layers)

        best_device_id = 0
        best_score = float('inf')

        for device in self.devices:
            if not device.can_handle_layer({'flops': total_flops, 'memory': total_memory}):
                continue

            energy_ratio = device.get_energy_ratio()
            compute_ratio = total_flops / device.capabilities.compute_capacity
            memory_ratio = total_memory / device.capabilities.memory_capacity

            score = compute_ratio + memory_ratio + (1.0 - energy_ratio)

            if score < best_score:
                best_score = score
                best_device_id = device.capabilities.device_id

        return best_device_id

    def _calculate_layer_energy_cost(self, profile: LayerProfile, device: EdgeDevice) -> float:
        """Return the simplified energy cost of running ``profile`` on ``device``.

        The cost scales with compute time (factor 0.01) and is capped at 10%
        of the device's current energy.
        """
        base_energy = profile.flops / device.capabilities.compute_capacity * 0.01
        return min(base_energy, device.capabilities.current_energy * 0.1)

    def _estimate_computation_latency(self, assignments: Dict[int, int],
                                    layer_profiles: List[LayerProfile]) -> float:
        """Return the total compute time of the busiest device.

        Profiles are resolved by ``layer_id`` (the key used in
        ``assignments``) rather than by list position, so the method also
        works when ``layer_profiles`` is a subset whose ids are not
        ``0..n-1``.

        Raises:
            KeyError: If ``assignments`` references a layer id that is not in
                ``layer_profiles``.
        """
        profiles_by_id = {p.layer_id: p for p in layer_profiles}
        device_workloads = {}

        for layer_id, device_id in assignments.items():
            capacity = self.devices[device_id].capabilities.compute_capacity
            computation_time = profiles_by_id[layer_id].flops / capacity
            device_workloads[device_id] = device_workloads.get(device_id, 0.0) + computation_time

        return max(device_workloads.values(), default=0.0)

class HybridPartitioning(PartitioningStrategy):
    """Strategy that mixes ensembled and meta partitioning by layer complexity."""

    def partition(self, layer_profiles: List[LayerProfile]) -> PartitionProfile:
        """Partition heavy layers with the ensembled strategy and the rest with meta.

        Layers whose FLOPs exceed the median count as high complexity.  The
        ensembled strategy runs first, then the meta strategy; both charge
        energy to the shared devices as they go.  The merged assignments are
        then rebalanced by ``_optimize_assignments``.

        Note: ``energy_consumption`` is the sum reported by the two
        sub-strategies and is not recomputed after rebalancing.

        Returns:
            A ``PartitionProfile`` labelled ``'hybrid'``; ``accuracy`` is left
            at 0.0.
        """
        meta_strategy = MetaPartitioning(self.devices, self.config)
        ensembled_strategy = EnsembledPartitioning(self.devices, self.config)

        assignments = {}
        total_energy = 0.0

        # Split around the median FLOP count, computed once for the whole list.
        high_complexity_layers = []
        low_complexity_layers = []
        if layer_profiles:
            median_flops = np.median([p.flops for p in layer_profiles])
            for p in layer_profiles:
                if p.flops > median_flops:
                    high_complexity_layers.append(p)
                else:
                    low_complexity_layers.append(p)

        # Ensembled for high complexity layers (better resource utilization),
        # then meta for low complexity layers (better parallelization).
        for strategy, subset in ((ensembled_strategy, high_complexity_layers),
                                 (meta_strategy, low_complexity_layers)):
            if subset:
                subset_assignments, subset_energy = self._partition_subset(strategy, subset)
                assignments.update(subset_assignments)
                total_energy += subset_energy

        # Optimize assignments based on current device states
        assignments = self._optimize_assignments(assignments, layer_profiles)

        communication_cost = self.calculate_communication_cost(assignments, layer_profiles)

        return PartitionProfile(
            strategy='hybrid',
            device_assignments=assignments,
            energy_consumption=total_energy,
            latency=max(communication_cost, self._estimate_computation_latency(assignments, layer_profiles)),
            accuracy=0.0,
            communication_cost=communication_cost
        )

    def _partition_subset(self, strategy: PartitioningStrategy,
                          subset: List[LayerProfile]) -> Tuple[Dict[int, int], float]:
        """Run ``strategy`` on a subset of layers and return global-id assignments.

        The subset keeps the layer ids of the full model, which are generally
        not ``0..len(subset)-1``.  The layers are therefore copied with ids
        renumbered to their position in the subset before the strategy runs,
        and the resulting assignments are mapped back to the original ids.

        Returns:
            ``(assignments keyed by original layer_id, energy consumed)``.
        """
        from dataclasses import replace

        local_profiles = [replace(p, layer_id=pos) for pos, p in enumerate(subset)]
        result = strategy.partition(local_profiles)
        global_assignments = {
            subset[local_id].layer_id: device_id
            for local_id, device_id in result.device_assignments.items()
        }
        return global_assignments, result.energy_consumption

    def _optimize_assignments(self, assignments: Dict[int, int],
                            layer_profiles: List[LayerProfile]) -> Dict[int, int]:
        """Move layers off devices whose assigned FLOPs exceed 80% of capacity.

        For each overloaded device, the second half of its layers is moved to
        the energy-sufficient device that had the fewest layers in the
        incoming ``assignments``.  Loads are not updated as layers move.

        Returns:
            A new assignment dict; the input is not modified.
        """
        optimized = assignments.copy()
        profiles_by_id = {p.layer_id: p for p in layer_profiles}

        # Layers per device according to the incoming assignments.
        device_loads = {}
        for layer_id, device_id in assignments.items():
            device_loads.setdefault(device_id, []).append(layer_id)

        for device_id, layer_ids in device_loads.items():
            device = self.devices[device_id]
            total_load = sum(profiles_by_id[lid].flops for lid in layer_ids)

            if total_load > device.capabilities.compute_capacity * 0.8:  # 80% threshold
                available_devices = [d for d in self.devices if d.is_energy_sufficient()]
                available_devices.sort(key=lambda d: len(device_loads.get(d.capabilities.device_id, [])))

                # Move half of the layers to the least loaded device; stay put
                # when no device has sufficient energy.
                layers_to_move = layer_ids[len(layer_ids)//2:]
                target_device = available_devices[0] if available_devices else device

                for lid in layers_to_move:
                    optimized[lid] = target_device.capabilities.device_id

        return optimized

    def _estimate_computation_latency(self, assignments: Dict[int, int],
                                    layer_profiles: List[LayerProfile]) -> float:
        """Return the total compute time of the busiest device.

        Profiles are resolved by ``layer_id`` (the key used in
        ``assignments``) rather than by list position.

        Raises:
            KeyError: If ``assignments`` references a layer id that is not in
                ``layer_profiles``.
        """
        profiles_by_id = {p.layer_id: p for p in layer_profiles}
        device_workloads = {}

        for layer_id, device_id in assignments.items():
            capacity = self.devices[device_id].capabilities.compute_capacity
            computation_time = profiles_by_id[layer_id].flops / capacity
            device_workloads[device_id] = device_workloads.get(device_id, 0.0) + computation_time

        return max(device_workloads.values(), default=0.0)

class PartitionSelector:
    """Runs the configured partitioning strategies and picks the best result."""

    def __init__(self, devices: List[EdgeDevice], config):
        """Create one instance of every strategy for the given devices.

        Args:
            devices: Devices shared by all strategies.
            config: Configuration object; ``config.partition.strategy`` and
                ``config.get_partition_weights()`` are read during selection.
        """
        self.devices = devices
        self.config = config
        self.strategies = {
            'meta': MetaPartitioning(devices, config),
            'ensembled': EnsembledPartitioning(devices, config),
            'hybrid': HybridPartitioning(devices, config)
        }

    def select_best_partition(self, layer_profiles: List[LayerProfile],
                            requirements: Dict) -> PartitionProfile:
        """Return the preferred partition profile for ``layer_profiles``.

        The strategy named by ``config.partition.strategy`` is run, or all of
        them when it is ``'auto'``; an unknown name falls back to hybrid.
        Candidates meeting ``requirements`` are scored and the best one is
        returned.  When none qualifies, the candidate with the highest
        accuracy is returned instead.
        """
        candidates = []

        for label, strategy in self.strategies.items():
            wanted = self.config.partition.strategy
            if wanted != 'auto' and label != wanted:
                continue
            candidates.append(strategy.partition(layer_profiles))

        if not candidates:
            # Nothing matched the configured name: fall back to hybrid.
            candidates.append(self.strategies['hybrid'].partition(layer_profiles))

        qualifying = self._filter_candidates(candidates, requirements)
        if qualifying:
            return self._score_candidates(qualifying, requirements)

        # No candidate meets the requirements: best by accuracy.
        return max(candidates, key=lambda c: c.accuracy)

    def _filter_candidates(self, candidates: List[PartitionProfile],
                          requirements: Dict) -> List[PartitionProfile]:
        """Keep candidates satisfying the accuracy, energy and latency limits.

        Missing requirement keys impose no limit (accuracy floor 0.0, energy
        and latency ceilings of infinity).
        """
        return [
            c for c in candidates
            if c.accuracy >= requirements.get('accuracy', 0.0)
            and c.energy_consumption <= requirements.get('energy', float('inf'))
            and c.latency <= requirements.get('latency', float('inf'))
        ]

    def _score_candidates(self, candidates: List[PartitionProfile],
                         requirements: Dict) -> PartitionProfile:
        """Score every candidate, store the score on it and return the best.

        Energy and latency are normalised by the largest value among the
        candidates, so the worst candidate on a metric earns 0 for it.  Ties
        keep the earlier candidate.
        """
        weights = self.config.get_partition_weights()
        winner = candidates[0]
        winning_score = -float('inf')

        # Normalisation bases for the two cost metrics.
        energy_cap = max(c.energy_consumption for c in candidates)
        latency_cap = max(c.latency for c in candidates)

        for contender in candidates:
            accuracy_term = weights['accuracy'] * contender.accuracy

            if energy_cap > 0:
                energy_term = weights['energy'] * (1.0 - contender.energy_consumption / energy_cap)
            else:
                energy_term = 0

            if latency_cap > 0:
                latency_term = weights['latency'] * (1.0 - contender.latency / latency_cap)
            else:
                latency_term = 0

            combined = accuracy_term + energy_term + latency_term
            contender.score = combined

            if combined > winning_score:
                winner, winning_score = contender, combined

        return winner