import torch
import numpy as np
import time
import json
import os
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Optional, Any, Tuple
from collections import defaultdict, deque
from dataclasses import dataclass, asdict
import threading
import psutil

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

@dataclass
class MetricSnapshot:
    """One point-in-time record of the metrics tracked by this module.

    Snapshots are appended to ``MetricsTracker.metrics_history`` by
    ``MetricsTracker.log_snapshot`` and converted to plain dicts with
    ``dataclasses.asdict`` when exported. No units or value ranges are
    enforced here; the notes below are assumptions to confirm with the
    code that produces the snapshots.
    """
    timestamp: float  # presumably seconds since the epoch (time.time()) -- TODO confirm
    accuracy: float
    loss: float
    energy_consumption: float  # presumably Wh, as elsewhere in this file -- TODO confirm
    latency: float  # presumably seconds -- TODO confirm
    communication_overhead: float  # presumably MB -- TODO confirm
    device_utilization: Dict[int, float]  # int key -> utilization; key is presumably the device id

class MetricsTracker:
    """Thread-safe in-memory store for scalar metrics and metric snapshots.

    Scalar metrics are kept per name as a list of ``(timestamp, value)``
    tuples in ``current_metrics``; full snapshots are kept in
    ``metrics_history``. All access to both containers is guarded by
    ``self.lock``. When the configuration enables it and the ``wandb``
    package is importable, every ``log_metrics`` call is mirrored to wandb.
    """

    def __init__(self, config=None):
        """Create an empty tracker.

        Args:
            config: Optional configuration object. When given, it must expose
                ``monitoring.use_wandb`` and, if wandb is used,
                ``monitoring.wandb_project``.
        """
        self.config = config
        self.metrics_history: List[MetricSnapshot] = []
        self.current_metrics = defaultdict(list)
        self.lock = threading.Lock()

        # Wandb integration
        self.use_wandb = config.monitoring.use_wandb if config else False
        if self.use_wandb and WANDB_AVAILABLE:
            wandb.init(project=config.monitoring.wandb_project)

    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
        """Record every entry of ``metrics`` with the current wall-clock time.

        Args:
            metrics: Mapping of metric name to value.
            step: Optional step index forwarded to ``wandb.log``.
        """
        with self.lock:
            timestamp = time.time()

            # Store in history
            for key, value in metrics.items():
                self.current_metrics[key].append((timestamp, value))

            # Log to wandb if available
            if self.use_wandb and WANDB_AVAILABLE:
                wandb.log(metrics, step=step)

    def log_snapshot(self, snapshot: "MetricSnapshot"):
        """Append a complete ``MetricSnapshot`` to ``metrics_history``."""
        with self.lock:
            self.metrics_history.append(snapshot)

    def get_metric_statistics(self, metric_name: str, window_size: int = 100) -> Dict[str, float]:
        """Summarize the most recent values of one metric.

        Args:
            metric_name: Name used in ``log_metrics``.
            window_size: Number of most recent values to include.

        Returns:
            A dict with ``mean``, ``std``, ``min``, ``max``, ``median`` (plain
            Python floats, so the result is JSON-serializable) and ``count``.
            An empty dict when the metric is unknown, has no values, or
            ``window_size`` is not positive.
        """
        # A non-positive window must not fall through to the slice below:
        # ``values[-0:]`` would silently select the *entire* history.
        if window_size <= 0:
            return {}

        with self.lock:
            # ``.get`` avoids creating an empty entry in the defaultdict.
            series = self.current_metrics.get(metric_name)
            if not series:
                return {}
            recent_values = [value for _, value in series[-window_size:]]

        return {
            'mean': float(np.mean(recent_values)),
            'std': float(np.std(recent_values)),
            'min': float(np.min(recent_values)),
            'max': float(np.max(recent_values)),
            'median': float(np.median(recent_values)),
            'count': len(recent_values)
        }

    def plot_metrics(self, metrics: List[str], save_path: Optional[str] = None):
        """Plot each requested metric against elapsed time, one subplot each.

        Args:
            metrics: Names of the metrics to plot. Unknown names leave their
                subplot empty; an empty list is a no-op.
            save_path: When given, the figure is written there and closed;
                otherwise it is shown with ``plt.show()``.
        """
        if not metrics:
            # plt.subplots() rejects zero rows, so there is nothing to draw.
            return

        # Copy the series under the lock so concurrent log_metrics() calls
        # cannot mutate the lists while they are being plotted.
        with self.lock:
            series_by_name = {
                name: list(self.current_metrics[name])
                for name in metrics
                if name in self.current_metrics
            }

        # squeeze=False always yields a 2-D axes array, even for one metric.
        fig, axes = plt.subplots(len(metrics), 1, figsize=(12, 4 * len(metrics)),
                                 squeeze=False)

        for ax, metric_name in zip(axes[:, 0], metrics):
            series = series_by_name.get(metric_name)
            if not series:
                continue

            timestamps, values = zip(*series)
            first_timestamp = timestamps[0]
            elapsed_hours = [(t - first_timestamp) / 3600 for t in timestamps]  # Convert to hours

            ax.plot(elapsed_hours, values, label=metric_name)
            ax.set_xlabel('Time (hours)')
            ax.set_ylabel(metric_name.replace('_', ' ').title())
            ax.legend()
            ax.grid(True, alpha=0.3)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            # Release the figure; otherwise every saved plot stays alive in pyplot.
            plt.close(fig)
        else:
            plt.show()

    def export_metrics(self, filepath: str):
        """Write all snapshots and scalar metrics to ``filepath`` as JSON.

        Values that ``json`` cannot encode natively are stringified
        (``default=str``).
        """
        # Take a consistent copy under the lock, then do the disk I/O without
        # blocking concurrent loggers.
        with self.lock:
            export_data = {
                'metrics_history': [asdict(snapshot) for snapshot in self.metrics_history],
                'current_metrics': {
                    name: list(points) for name, points in self.current_metrics.items()
                },
                'export_timestamp': time.time()
            }

        with open(filepath, 'w') as f:
            json.dump(export_data, f, indent=2, default=str)

    def clear_metrics(self):
        """Drop all recorded scalar metrics and snapshots."""
        with self.lock:
            self.current_metrics.clear()
            self.metrics_history.clear()

class EnergyMonitor:
    """Estimate and accumulate energy consumption for a set of devices.

    Energy is accumulated in three places: a per-device history of
    ``(timestamp, energy)`` samples, per-device epoch counters, and per-device
    inference counters. A background thread can periodically add an estimate
    derived from each device's reported utilization. All counters are guarded
    by ``self.lock``.

    Devices are expected to expose ``capabilities.device_id``,
    ``capabilities.utilization``, ``capabilities.compute_capacity`` and
    ``is_energy_sufficient()``.
    """

    def __init__(self, devices, update_interval: float = 1.0):
        """
        Args:
            devices: Iterable of device objects (see class docstring).
            update_interval: Seconds between two background samples.

        Note:
            Construction samples CPU usage over one second to derive the
            system power baseline, so it blocks for about that long.
        """
        self.devices = devices
        self.update_interval = update_interval

        # Energy tracking
        self.device_energy_history = defaultdict(list)
        self.total_energy_consumed = 0.0
        self.epoch_energy_counters = defaultdict(float)
        self.inference_energy_counters = defaultdict(float)

        # System monitoring
        self.system_power_baseline = self._get_system_power_baseline()
        self.monitoring_active = False
        self.lock = threading.Lock()

        # Handle of the running background loop and the event used to stop it.
        self._monitor_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()

    def _get_system_power_baseline(self) -> float:
        """Return a rough system power estimate in watts.

        Falls back to 75.0 W when the system probes are unavailable or fail.
        """
        try:
            # Estimate baseline power consumption
            cpu_percent = psutil.cpu_percent(interval=1)
            memory_percent = psutil.virtual_memory().percent

            # Simple power model: base power + CPU factor + memory factor
            base_power = 50.0  # watts
            cpu_power = cpu_percent * 0.5  # 0.5W per percent CPU usage
            memory_power = memory_percent * 0.2  # 0.2W per percent memory usage

            return base_power + cpu_power + memory_power
        except Exception:
            # Best effort: any probing failure yields the default, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            return 75.0  # Default baseline

    def start_monitoring(self):
        """Start the background sampling thread (no-op if already running)."""
        running = self._monitor_thread
        if self.monitoring_active and running is not None and running.is_alive():
            # A second loop would double-count every sample.
            return

        self.monitoring_active = True
        # Each loop gets its own stop event, so a loop that was stopped but is
        # still waking up cannot be revived by this restart.
        stop_event = threading.Event()
        self._stop_event = stop_event
        self._monitor_thread = threading.Thread(
            target=self._monitoring_loop, args=(stop_event,), daemon=True
        )
        self._monitor_thread.start()

    def stop_monitoring(self):
        """Signal the background thread to stop and wait briefly for it."""
        self.monitoring_active = False
        self._stop_event.set()

        thread = self._monitor_thread
        self._monitor_thread = None
        if thread is not None and thread is not threading.current_thread():
            # The loop waits on the event, so it normally exits immediately;
            # the timeout only bounds the wait if a device query is slow.
            thread.join(timeout=1.0)

    def _monitoring_loop(self, stop_event: Optional[threading.Event] = None):
        """Sample every device once per ``update_interval`` until stopped.

        Args:
            stop_event: Event that ends this loop; defaults to the monitor's
                current stop event.
        """
        if stop_event is None:
            stop_event = self._stop_event

        while self.monitoring_active and not stop_event.is_set():
            delay = self.update_interval
            try:
                timestamp = time.time()

                for device in self.devices:
                    device_id = device.capabilities.device_id

                    # Estimate energy consumption based on device utilization
                    utilization = device.capabilities.utilization
                    compute_capacity = device.capabilities.compute_capacity

                    # Energy model: base consumption + utilization-based consumption
                    base_energy = 0.001  # 1mWh base consumption per interval
                    utilization_energy = utilization * compute_capacity * 0.0001  # Scale factor

                    total_energy = base_energy + utilization_energy

                    with self.lock:
                        self.device_energy_history[device_id].append((timestamp, total_energy))
                        self.total_energy_consumed += total_energy

            except Exception as e:
                print(f"Error in energy monitoring: {e}")
                # Back off after a failure, as before.
                delay = self.update_interval * 2

            # Interruptible sleep: returns early as soon as stop is requested.
            stop_event.wait(delay)

    def update_device_energy(self, device_id: int, energy_consumed: float):
        """Record ``energy_consumed`` for one device in all counters."""
        with self.lock:
            timestamp = time.time()
            self.device_energy_history[device_id].append((timestamp, energy_consumed))
            self.total_energy_consumed += energy_consumed
            self.epoch_energy_counters[device_id] += energy_consumed
            self.inference_energy_counters[device_id] += energy_consumed

    def update_batch_energy(self, batch_time: float):
        """Estimate the energy of one training batch and spread it over devices.

        The estimate is ``batch_time * system_power_baseline / 3600`` and is
        split evenly among devices whose ``is_energy_sufficient()`` is true.

        Args:
            batch_time: Batch duration (seconds, given the /3600 conversion to Wh).
        """
        # Estimate energy consumption for a training batch
        estimated_energy = batch_time * self.system_power_baseline / 3600.0  # Convert to Wh

        with self.lock:
            self.total_energy_consumed += estimated_energy

            # Distribute energy among active devices
            active_devices = [d for d in self.devices if d.is_energy_sufficient()]
            if active_devices:
                energy_per_device = estimated_energy / len(active_devices)
                for device in active_devices:
                    device_id = device.capabilities.device_id
                    self.epoch_energy_counters[device_id] += energy_per_device

    def reset_epoch_counters(self):
        """Clear the per-device epoch counters."""
        with self.lock:
            self.epoch_energy_counters.clear()

    def reset_inference_counters(self):
        """Clear the per-device inference counters."""
        with self.lock:
            self.inference_energy_counters.clear()

    def get_epoch_statistics(self) -> Dict[str, float]:
        """Summarize the epoch counters.

        Returns:
            ``total_energy``; ``avg_device_energy`` and ``max_device_energy``
            over devices that have a counter; and ``energy_efficiency``, which
            is the total divided by the number of configured devices.
        """
        with self.lock:
            per_device = list(self.epoch_energy_counters.values())
            total_epoch_energy = sum(per_device)

            return {
                'total_energy': total_epoch_energy,
                'avg_device_energy': float(np.mean(per_device)) if per_device else 0.0,
                'max_device_energy': max(per_device) if per_device else 0.0,
                'energy_efficiency': total_epoch_energy / len(self.devices) if self.devices else 0.0
            }

    def get_inference_energy(self) -> float:
        """Return the sum of all inference counters."""
        with self.lock:
            return sum(self.inference_energy_counters.values())

    def get_device_energy_history(self, device_id: int, window_hours: float = 1.0) -> List[Tuple[float, float]]:
        """Return the ``(timestamp, energy)`` samples of the last ``window_hours``.

        Unknown device ids yield an empty list (and are not added to the history).
        """
        with self.lock:
            if device_id not in self.device_energy_history:
                return []

            cutoff_time = time.time() - (window_hours * 3600)

            return [(t, e) for t, e in self.device_energy_history[device_id] if t >= cutoff_time]

    def calculate_energy_efficiency(self, accuracy: float) -> float:
        """Return ``accuracy`` per unit of inference energy (0.0 if none was used)."""
        # Energy efficiency metric: accuracy per unit energy
        total_energy = self.get_inference_energy()
        if total_energy > 0:
            return accuracy / total_energy
        return 0.0

class CommunicationTracker:
    """Track inter-device data transfers and observed bandwidth.

    Transfer sizes are recorded in bytes; statistics report mebibytes
    (bytes / 1024**2). All state is guarded by ``self.lock``.
    """

    _BYTES_PER_MB = 1024 * 1024

    def __init__(self):
        self.transfer_history: List[Dict] = []
        # NOTE: despite the name, this maps a wall-clock timestamp (not a
        # device) to the bandwidth samples recorded at that instant.
        self.device_bandwidth_usage = defaultdict(list)
        self.total_data_transferred = 0.0
        self.epoch_data_counters = defaultdict(float)
        self.inference_data_counters = defaultdict(float)
        # Number of track_transfer() calls since the last epoch reset.
        self._epoch_transfer_count = 0
        self.lock = threading.Lock()

    def track_transfer(self, from_device: int, to_device: int, data_size: float):
        """Record one transfer of ``data_size`` bytes, attributed to the sender."""
        timestamp = time.time()

        with self.lock:
            transfer_record = {
                'timestamp': timestamp,
                'from_device': from_device,
                'to_device': to_device,
                'data_size': data_size,  # in bytes
                'data_size_mb': data_size / self._BYTES_PER_MB  # in MB
            }

            self.transfer_history.append(transfer_record)
            self.total_data_transferred += data_size
            self.epoch_data_counters[from_device] += data_size
            self.inference_data_counters[from_device] += data_size
            self._epoch_transfer_count += 1

    def add_transfer_time(self, transfer_time: float, data_size: float):
        """Record a bandwidth sample of ``data_size / transfer_time`` bytes/s.

        Non-positive transfer times are ignored.
        """
        if transfer_time > 0:
            bandwidth = data_size / transfer_time  # bytes per second

            with self.lock:
                self.device_bandwidth_usage[time.time()].append(bandwidth)

    def reset_epoch_counters(self):
        """Clear the per-device epoch byte counters and the transfer count."""
        with self.lock:
            self.epoch_data_counters.clear()
            self._epoch_transfer_count = 0

    def reset_inference_counters(self):
        """Clear the per-device inference byte counters."""
        with self.lock:
            self.inference_data_counters.clear()

    def get_epoch_statistics(self) -> Dict[str, float]:
        """Summarize the transfers recorded since the last epoch reset.

        Returns:
            ``total_data_transferred`` (MB), ``avg_transfer_size`` (MB per
            transfer), ``num_transfers`` (number of recorded transfers) and
            ``communication_efficiency``.
        """
        with self.lock:
            total_data = sum(self.epoch_data_counters.values())
            # Count actual transfers; the per-device counters only tell how
            # many distinct senders there were.
            num_transfers = self._epoch_transfer_count

            if num_transfers:
                avg_transfer_size = total_data / num_transfers / self._BYTES_PER_MB
            else:
                avg_transfer_size = 0.0

            return {
                'total_data_transferred': total_data / self._BYTES_PER_MB,  # MB
                'avg_transfer_size': avg_transfer_size,
                'num_transfers': num_transfers,
                'communication_efficiency': self._calculate_communication_efficiency()
            }

    def get_inference_data_transferred(self) -> float:
        """Return the data sent since the last inference reset, in MB."""
        with self.lock:
            return sum(self.inference_data_counters.values()) / self._BYTES_PER_MB  # MB

    def _calculate_communication_efficiency(self) -> float:
        """Return ``1 / (1 + transfers_per_byte)`` over the last 100 transfers.

        Must be called with ``self.lock`` held (the lock is not re-entrant).
        Returns 0.0 when no transfer has been recorded; higher is better.
        """
        # Calculate efficiency as useful data vs overhead
        if not self.transfer_history:
            return 0.0

        recent_transfers = self.transfer_history[-100:]  # Last 100 transfers
        total_size = sum(t['data_size'] for t in recent_transfers)
        overhead_ratio = len(recent_transfers) / total_size if total_size > 0 else 0

        return 1.0 / (1.0 + overhead_ratio)  # Higher is better

    def get_bandwidth_statistics(self, window_minutes: float = 5.0) -> Dict[str, float]:
        """Summarize bandwidth samples recorded in the last ``window_minutes``.

        Returns:
            Mean, max, min and standard deviation under ``*_bandwidth_mbps``
            keys. The values are bytes/s divided by 1024**2 (i.e. MiB/s, not
            megabits/s). Empty dict when there are no recent samples.
        """
        cutoff_time = time.time() - (window_minutes * 60)

        with self.lock:
            recent_bandwidths = [
                bandwidth
                for timestamp, bandwidths in self.device_bandwidth_usage.items()
                if timestamp >= cutoff_time
                for bandwidth in bandwidths
            ]

        if not recent_bandwidths:
            return {}

        return {
            'avg_bandwidth_mbps': float(np.mean(recent_bandwidths)) / self._BYTES_PER_MB,
            'max_bandwidth_mbps': float(np.max(recent_bandwidths)) / self._BYTES_PER_MB,
            'min_bandwidth_mbps': float(np.min(recent_bandwidths)) / self._BYTES_PER_MB,
            'std_bandwidth_mbps': float(np.std(recent_bandwidths)) / self._BYTES_PER_MB
        }

class PerformanceProfiler:
    """Collect execution-time and memory-delta samples for decorated functions.

    Samples are stored per profile name; calls slower than one second are
    additionally recorded as bottlenecks. All state is guarded by ``self.lock``.
    """

    def __init__(self):
        self.execution_times = defaultdict(list)
        self.memory_usage = defaultdict(list)
        # Cleared by reset_profiling_data(); nothing in this class fills it.
        self.gpu_utilization = defaultdict(list)
        self.bottlenecks = []
        self.lock = threading.Lock()

    def profile_function(self, func_name: str):
        """Return a decorator that records timing and memory use of a function.

        Args:
            func_name: Key under which the samples are stored.

        Returns:
            A decorator. The wrapped function keeps its name and docstring,
            returns the original result, and re-raises any exception after
            printing it. Only successful calls are recorded.
        """
        from functools import wraps

        def decorator(func):
            @wraps(func)  # keep __name__/__doc__ of the profiled function
            def wrapper(*args, **kwargs):
                started_at = time.time()  # wall-clock timestamp for the bottleneck record
                start_memory = self._get_memory_usage()
                # Monotonic clock for the duration: immune to system clock
                # adjustments and started after the memory probe.
                tick = time.perf_counter()

                try:
                    result = func(*args, **kwargs)
                except Exception as e:
                    print(f"Error in profiled function {func_name}: {e}")
                    raise

                execution_time = time.perf_counter() - tick
                memory_delta = self._get_memory_usage() - start_memory

                with self.lock:
                    self.execution_times[func_name].append(execution_time)
                    self.memory_usage[func_name].append(memory_delta)

                    # Detect potential bottlenecks
                    if execution_time > 1.0:  # Functions taking more than 1 second
                        self.bottlenecks.append({
                            'function': func_name,
                            'execution_time': execution_time,
                            'timestamp': started_at
                        })

                return result

            return wrapper
        return decorator

    def _get_memory_usage(self) -> float:
        """Return current memory use in MB, or 0.0 if it cannot be determined.

        Uses CUDA allocated memory when CUDA is available, otherwise the
        system-wide used memory reported by psutil.
        """
        try:
            if torch.cuda.is_available():
                return torch.cuda.memory_allocated() / (1024 * 1024)  # MB
            return psutil.virtual_memory().used / (1024 * 1024)  # MB
        except Exception:
            # Best effort: profiling must never break the profiled call.
            return 0.0

    def get_performance_summary(self) -> Dict[str, Any]:
        """Return per-function timing statistics plus the last 10 bottlenecks.

        Returns:
            A dict mapping each profile name to ``avg_time``, ``std_time``,
            ``min_time``, ``max_time``, ``total_calls`` and ``total_time``,
            with an extra ``'bottlenecks'`` key holding the most recent
            bottleneck records.
        """
        with self.lock:
            summary = {
                func_name: {
                    'avg_time': float(np.mean(times)),
                    'std_time': float(np.std(times)),
                    'min_time': float(np.min(times)),
                    'max_time': float(np.max(times)),
                    'total_calls': len(times),
                    'total_time': float(np.sum(times))
                }
                for func_name, times in self.execution_times.items()
                if times
            }

            # Bottleneck analysis
            summary['bottlenecks'] = self.bottlenecks[-10:]  # Last 10 bottlenecks

            return summary

    def reset_profiling_data(self):
        """Discard all collected samples and bottleneck records."""
        with self.lock:
            self.execution_times.clear()
            self.memory_usage.clear()
            self.gpu_utilization.clear()
            self.bottlenecks.clear()

class ExperimentLogger:
    """Bundle the trackers of one experiment and persist its artifacts.

    On construction a directory ``<tensorboard_dir>/<experiment_name>`` is
    created and ``experiment_metadata.json`` is written into it. Results and
    a Markdown report can be written to the same directory later.
    """

    def __init__(self, config, experiment_name: Optional[str] = None):
        """
        Args:
            config: Configuration object exposing
                ``monitoring.tensorboard_dir`` (plus whatever
                ``MetricsTracker`` and ``create_experiment_report`` read).
            experiment_name: Directory/report name; defaults to
                ``lap_dtr_experiment_<unix time>``.
        """
        self.config = config
        self.experiment_name = experiment_name or f"lap_dtr_experiment_{int(time.time())}"

        # Create experiment directory
        self.log_dir = os.path.join(config.monitoring.tensorboard_dir, self.experiment_name)
        os.makedirs(self.log_dir, exist_ok=True)

        # Initialize loggers
        self.metrics_tracker = MetricsTracker(config)
        self.energy_monitor = EnergyMonitor([])
        self.comm_tracker = CommunicationTracker()
        self.profiler = PerformanceProfiler()

        # Experiment metadata
        self.experiment_metadata = {
            'experiment_name': self.experiment_name,
            'start_time': time.time(),
            'config': self._serialize_config(config),
            'git_commit': self._get_git_commit(),
            'python_version': self._get_python_version(),
            'torch_version': torch.__version__
        }

        self._save_metadata()

    @staticmethod
    def _serialize_config(config):
        """Return ``config`` as a dict if it is a dataclass instance, else ``str(config)``."""
        from dataclasses import is_dataclass

        # asdict() only accepts dataclass *instances*; testing for ``__dict__``
        # would send ordinary objects into asdict() and raise TypeError.
        if is_dataclass(config) and not isinstance(config, type):
            return asdict(config)
        return str(config)

    @staticmethod
    def _last_or_default(values, default=0):
        """Return the last element of ``values`` or ``default`` if it is empty/None."""
        if values is None or len(values) == 0:
            return default
        return values[-1]

    @staticmethod
    def _mean_or_default(values, default=0):
        """Return the mean of ``values`` or ``default`` if it is empty/None."""
        if values is None or len(values) == 0:
            return default
        return float(np.mean(values))

    def _get_git_commit(self) -> str:
        """Return the current git commit hash, or ``'unknown'`` if unavailable."""
        import subprocess

        try:
            result = subprocess.run(['git', 'rev-parse', 'HEAD'],
                                    capture_output=True, text=True, timeout=5)
        except (OSError, ValueError, subprocess.SubprocessError):
            # git missing, timed out, or produced undecodable output.
            return 'unknown'
        return result.stdout.strip() if result.returncode == 0 else 'unknown'

    def _get_python_version(self) -> str:
        """Return the running interpreter version as ``major.minor.micro``."""
        import sys
        return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"

    def _save_metadata(self):
        """Write ``experiment_metadata`` to ``experiment_metadata.json``."""
        metadata_path = os.path.join(self.log_dir, 'experiment_metadata.json')
        with open(metadata_path, 'w') as f:
            json.dump(self.experiment_metadata, f, indent=2, default=str)

    def log_experiment_result(self, results: Dict[str, Any]):
        """Write ``results`` to ``experiment_results.json``.

        Note:
            ``results`` is modified in place: ``experiment_metadata``,
            ``end_time`` and ``total_duration`` are added before writing.
        """
        results_path = os.path.join(self.log_dir, 'experiment_results.json')

        # Add timestamp and experiment metadata
        results['experiment_metadata'] = self.experiment_metadata
        results['end_time'] = time.time()
        results['total_duration'] = results['end_time'] - self.experiment_metadata['start_time']

        with open(results_path, 'w') as f:
            json.dump(results, f, indent=2, default=str)

    def create_experiment_report(self, training_stats: Dict, validation_stats: Dict,
                               inference_stats: Dict) -> str:
        """Write a Markdown summary of the experiment and return its path.

        Missing or empty statistics are reported as 0 instead of raising.

        Args:
            training_stats: May contain ``train_accuracies``,
                ``energy_consumption`` and ``communication_overhead`` lists.
            validation_stats: May contain a ``val_accuracies`` list.
            inference_stats: May contain ``latency``, ``throughput``,
                ``energy_consumption`` and ``device_utilization`` dicts.

        Returns:
            Path of the written ``experiment_report.md``.
        """
        final_train_accuracy = self._last_or_default(training_stats.get('train_accuracies'))
        final_val_accuracy = self._last_or_default(validation_stats.get('val_accuracies'))
        total_energy = sum(training_stats.get('energy_consumption') or [])
        avg_comm_overhead = self._mean_or_default(training_stats.get('communication_overhead'))

        latency = inference_stats.get('latency', {})
        throughput = inference_stats.get('throughput', {})
        inference_energy = inference_stats.get('energy_consumption', {})
        utilization = inference_stats.get('device_utilization', {})

        report = f"""
# LAP-DTR Experiment Report: {self.experiment_name}

## Experiment Configuration
- Start Time: {time.ctime(self.experiment_metadata['start_time'])}
- Partitioning Strategy: {self.config.partition.strategy}
- Number of Devices: {self.config.device.num_devices}
- Model: {self.config.model.model_name}

## Training Results
- Final Training Accuracy: {final_train_accuracy:.2f}%
- Final Validation Accuracy: {final_val_accuracy:.2f}%
- Total Energy Consumption: {total_energy:.4f} Wh
- Average Communication Overhead: {avg_comm_overhead:.2f} MB

## Inference Performance
- Average Latency: {latency.get('mean', 0):.4f} seconds
- Throughput: {throughput.get('samples_per_second', 0):.2f} samples/second
- Energy Efficiency: {inference_energy.get('mean', 0):.6f} Wh per sample

## Device Utilization
- Average Device Utilization: {utilization.get('avg_utilization', 0):.2f}
- Active Devices: {utilization.get('active_devices', 0)}/{utilization.get('total_devices', 0)}

## Performance Bottlenecks
{self._format_bottlenecks()}

---
Generated at: {time.ctime()}
"""

        report_path = os.path.join(self.log_dir, 'experiment_report.md')
        # Explicit encoding: the experiment name may contain non-ASCII text.
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(report)

        return report_path

    def _format_bottlenecks(self) -> str:
        """Return up to the five most recent profiler bottlenecks as a Markdown list."""
        bottlenecks = self.profiler.get_performance_summary().get('bottlenecks', [])

        if not bottlenecks:
            return "No significant bottlenecks detected."

        return '\n'.join(
            f"- {bottleneck['function']}: {bottleneck['execution_time']:.3f}s"
            for bottleneck in bottlenecks[-5:]  # Last 5 bottlenecks
        )

def set_random_seeds(seed: int = 42):
    """Seed every random number generator used for reproducible runs.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), then switches cuDNN to deterministic mode
    with autotuning disabled.

    Args:
        seed: Seed value applied to all generators.
    """
    import random

    # Python's own generator was previously left unseeded, so any code using
    # ``random`` (e.g. shuffling) was not reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ensure deterministic behavior
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def calculate_model_complexity(model: torch.nn.Module, input_shape: Tuple[int, ...]) -> Dict[str, Any]:
    """Calculate model complexity metrics.

    The FLOP figure is a simplified static estimate: every ``Linear`` layer
    counts ``in_features * out_features`` and every ``Conv2d`` layer is
    assumed to produce an output with the *input's* spatial size (stride,
    padding and earlier down-sampling are ignored). Grouped convolutions are
    counted with ``in_channels // groups`` inputs per filter.

    The model is only inspected: no forward pass is run and its
    training/eval mode is left untouched.

    Args:
        model: Module whose parameters and layers are inspected.
        input_shape: Input shape, normally ``(N, C, H, W)``; ``(C, H, W)``
            is also accepted. Only the spatial dimensions are used, and only
            when the model contains ``Conv2d`` layers.

    Returns:
        Dict with ``total_parameters``, ``trainable_parameters``,
        ``estimated_flops`` and ``model_size_mb`` (parameter storage in MB,
        based on each parameter's actual element size).

    Raises:
        ValueError: If the model contains a ``Conv2d`` layer and
            ``input_shape`` has fewer than three dimensions.
    """
    parameters = list(model.parameters())
    total_params = sum(p.numel() for p in parameters)
    trainable_params = sum(p.numel() for p in parameters if p.requires_grad)
    # Use the real element size so half/double precision models are sized
    # correctly; identical to ``total_params * 4`` for float32.
    size_bytes = sum(p.numel() * p.element_size() for p in parameters)

    # Estimate FLOPs (simplified)
    flops = 0
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            flops += module.in_features * module.out_features
        elif isinstance(module, torch.nn.Conv2d):
            if len(input_shape) < 3:
                raise ValueError(
                    f"input_shape {tuple(input_shape)} has no spatial dimensions "
                    "for a Conv2d layer"
                )
            # (N, C, H, W, ...) -> dims after batch and channel;
            # (C, H, W) -> last two dims.
            spatial = input_shape[2:] if len(input_shape) >= 4 else input_shape[-2:]
            kernel_h, kernel_w = module.kernel_size
            # Each filter only sees in_channels / groups input channels.
            kernel_flops = kernel_h * kernel_w * (module.in_channels // module.groups)
            output_elements = spatial[0] * spatial[1] * module.out_channels
            flops += kernel_flops * output_elements

    return {
        'total_parameters': total_params,
        'trainable_parameters': trainable_params,
        'estimated_flops': flops,
        'model_size_mb': size_bytes / (1024 * 1024),
    }

def create_visualization_dashboard(metrics_tracker: MetricsTracker, 
                                 energy_monitor: EnergyMonitor,
                                 comm_tracker: CommunicationTracker,
                                 save_dir: str):
    """Create a comprehensive visualization dashboard.

    Draws a 2x3 grid (training progress, epoch energy, epoch communication,
    device utilization placeholder, latency histogram, energy efficiency)
    into a single figure and saves it as ``dashboard.png``.

    Args:
        metrics_tracker: Source of the ``accuracy``, ``loss`` and ``latency``
            series.
        energy_monitor: Source of the epoch energy statistics.
        comm_tracker: Source of the epoch communication statistics.
        save_dir: Directory (created if missing) that receives the image.

    Returns:
        Path of the saved ``dashboard.png``.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Set style (the seaborn style name depends on the matplotlib version,
    # so only apply it when it is actually available).
    style_name = 'seaborn-v0_8'
    if style_name in plt.style.available:
        plt.style.use(style_name)
    sns.set_palette("husl")

    # Copy the series under the tracker's lock so concurrent logging cannot
    # mutate them while they are drawn.
    with metrics_tracker.lock:
        series = {
            name: list(metrics_tracker.current_metrics[name])
            for name in ('accuracy', 'loss', 'latency')
            if name in metrics_tracker.current_metrics
        }

    # One figure with explicit axes. Calling metrics_tracker.plot_metrics()
    # here would open a second figure (and show it), after which pyplot's
    # "current figure" is no longer the dashboard.
    fig, axes = plt.subplots(2, 3, figsize=(20, 15))
    ax1, ax2, ax3, ax4, ax5, ax6 = axes.flat

    try:
        # Training metrics
        progress = [(name, series[name]) for name in ('accuracy', 'loss') if series.get(name)]
        if progress:
            origin = min(points[0][0] for _, points in progress)
            for name, points in progress:
                ax1.plot([(t - origin) / 3600 for t, _ in points],
                         [value for _, value in points],
                         label=name)
            ax1.set_xlabel('Time (hours)')
            ax1.legend()
        ax1.set_title('Training Progress')

        # Energy consumption
        energy_stats = energy_monitor.get_epoch_statistics()
        if 'total_energy' in energy_stats:
            ax2.bar(['Total Energy'], [energy_stats['total_energy']])
            ax2.set_title('Energy Consumption')
            ax2.set_ylabel('Energy (Wh)')

        # Communication overhead
        comm_stats = comm_tracker.get_epoch_statistics()
        if 'total_data_transferred' in comm_stats:
            ax3.bar(['Total Data'], [comm_stats['total_data_transferred']])
            ax3.set_title('Communication Overhead')
            ax3.set_ylabel('Data Transferred (MB)')

        # Device utilization heatmap (placeholder: random data, not measurements)
        dummy_utilization = np.random.rand(4, 10)  # 4 devices, 10 time points
        sns.heatmap(dummy_utilization, annot=True, fmt='.2f', cmap='viridis', ax=ax4)
        ax4.set_title('Device Utilization Over Time')
        ax4.set_xlabel('Time Steps')
        ax4.set_ylabel('Device ID')

        # Latency distribution: histogram of the recorded values (last 100),
        # not of synthetic samples drawn from their mean/std.
        latency_values = [value for _, value in series.get('latency', [])[-100:]]
        if latency_values:
            ax5.hist(latency_values, bins=20, alpha=0.7)
            ax5.set_title('Latency Distribution')
            ax5.set_xlabel('Latency (seconds)')
            ax5.set_ylabel('Frequency')

        # Energy efficiency
        accuracy_stats = metrics_tracker.get_metric_statistics('accuracy')
        if accuracy_stats and 'total_energy' in energy_stats:
            # Floor the denominator to avoid dividing by (near) zero energy.
            efficiency = accuracy_stats.get('mean', 0) / max(energy_stats['total_energy'], 0.001)
            ax6.bar(['Energy Efficiency'], [efficiency])
            ax6.set_title('Energy Efficiency')
            ax6.set_ylabel('Accuracy per Wh')

        fig.tight_layout()
        dashboard_path = os.path.join(save_dir, 'dashboard.png')
        fig.savefig(dashboard_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    return dashboard_path