#!/usr/bin/env python3

import argparse
import os
import sys
import time
import json
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from typing import Dict, List, Optional, Any

# Import our modules
from config import Config
from device import DeviceManager, EdgeDevice
from lap_dtr import LAPDTRFramework
from training import DistributedTrainer
from inference import DistributedInference
from utils import (
    MetricsTracker, EnergyMonitor, CommunicationTracker, 
    ExperimentLogger, set_random_seeds, calculate_model_complexity,
    create_visualization_dashboard
)

class LAPDTRExperiment:
    """Drive one LAP-DTR experiment: data loading, training, inference,
    strategy comparison and report generation.

    The instance owns the device network, the partitioning framework and the
    experiment logger; each ``run_*`` method executes one mode and records its
    outcome through the logger.
    """

    # Caps on how many validation samples are fed to the inference benchmarks.
    INFERENCE_SAMPLE_LIMIT = 100
    EVALUATION_SAMPLE_LIMIT = 50

    # Strategies compared by run_evaluation().
    EVALUATION_STRATEGIES = ('meta', 'ensembled', 'hybrid')

    # Per-channel normalisation statistics shared by both transforms
    # (the commonly used ImageNet values).
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, config: "Config", args):
        """Build the device network, the framework and the logger.

        Args:
            config: Experiment configuration; ``seed`` and
                ``device.num_devices`` are read here.
            args: Parsed command-line namespace; ``strategy`` is used for the
                experiment name and later for partitioning.
        """
        self.config = config
        self.args = args

        # Seed first so everything constructed below is reproducible.
        set_random_seeds(config.seed)

        self.device_manager = DeviceManager(config)
        self.devices = self.device_manager.create_device_network(config.device.num_devices)
        self.framework = LAPDTRFramework(config, self.devices)

        # The name is "<strategy>_<unix time>", or "lap_dtr_<unix time>" when
        # no strategy was given.
        experiment_name = f"{args.strategy or 'lap_dtr'}_{int(time.time())}"
        self.logger = ExperimentLogger(config, experiment_name)

        # Created lazily by run_training() / run_inference().
        self.trainer = None
        self.inference_engine = None

        print(f"Initialized LAP-DTR Experiment: {experiment_name}")
        print(f"Configuration: {config}")
        print(f"Devices: {len(self.devices)} devices initialized")

    def _load_cifar10(self, transform_train, transform_val) -> tuple:
        """Return ``(train, val)`` CIFAR-10 datasets and set ``num_classes`` to 10."""
        train_dataset = datasets.CIFAR10(
            root=self.config.data_root,
            train=True,
            download=True,
            transform=transform_train
        )
        val_dataset = datasets.CIFAR10(
            root=self.config.data_root,
            train=False,
            download=True,
            transform=transform_val
        )
        self.config.model.num_classes = 10
        return train_dataset, val_dataset

    def setup_data_loaders(self) -> tuple:
        """Create the training and validation data loaders.

        Returns:
            ``(train_loader, val_loader)``.

        Raises:
            ValueError: If ``config.dataset`` is neither ``imagenet`` nor
                ``cifar10`` (case-insensitive).
        """
        print("Setting up data loaders...")

        size = (self.config.model.img_size, self.config.model.img_size)

        def build_transform(augmentations):
            # Resize -> optional augmentation -> tensor -> normalise.
            return transforms.Compose([
                transforms.Resize(size),
                *augmentations,
                transforms.ToTensor(),
                transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD)
            ])

        transform_train = build_transform([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
        ])
        transform_val = build_transform([])

        dataset_name = self.config.dataset.lower()
        if dataset_name == 'imagenet':
            # ImageNet must already be present under data_root; if it cannot
            # be opened we fall back to CIFAR-10 so the demo still runs.
            try:
                train_dataset = datasets.ImageNet(
                    root=self.config.data_root,
                    split='train',
                    transform=transform_train
                )
                val_dataset = datasets.ImageNet(
                    root=self.config.data_root,
                    split='val',
                    transform=transform_val
                )
            except Exception as exc:
                # `Exception` (not a bare except) so Ctrl-C and SystemExit
                # still propagate instead of triggering a dataset download.
                print("ImageNet not available, using CIFAR-10 as fallback")
                print(f"  Reason: {exc}")
                train_dataset, val_dataset = self._load_cifar10(transform_train, transform_val)
        elif dataset_name == 'cifar10':
            train_dataset, val_dataset = self._load_cifar10(transform_train, transform_val)
        else:
            raise ValueError(f"Unsupported dataset: {self.config.dataset}")

        # Pinned host memory only helps (and only avoids a warning) when a
        # CUDA device is present.
        loader_options = dict(
            batch_size=self.config.training.batch_size,
            num_workers=4,
            pin_memory=torch.cuda.is_available()
        )
        train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
        val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

        print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples")
        return train_loader, val_loader

    @staticmethod
    def _collect_test_samples(loader, max_samples: int) -> List[Any]:
        """Return at most ``max_samples`` individual inputs drawn from ``loader``.

        ``loader`` yields ``(batch, labels)`` pairs; labels are discarded and
        each batch is split along its first dimension. Reading stops as soon
        as enough samples are available, and the result is trimmed so the cap
        is exact even when it is not a multiple of the batch size.
        """
        samples = []
        for batch, _ in loader:
            samples.extend(batch[j] for j in range(batch.size(0)))
            if len(samples) >= max_samples:
                break
        return samples[:max_samples]

    @staticmethod
    def _restore_weights(model, checkpoint) -> bool:
        """Try to copy checkpoint weights into ``model``; return True on success.

        NOTE(review): the checkpoint layout written by the trainer is not
        visible from this file. This assumes either a bare state dict or a
        dict wrapping one under 'model_state_dict' / 'state_dict' -- confirm
        against DistributedTrainer.
        """
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for key in ('model_state_dict', 'state_dict'):
                if key in checkpoint:
                    state_dict = checkpoint[key]
                    break
        try:
            model.load_state_dict(state_dict)
        except (RuntimeError, KeyError, TypeError, AttributeError) as exc:
            # Keep the benchmark usable, but make it obvious that the
            # checkpoint was not applied.
            print(f"Warning: could not restore weights from checkpoint ({exc}); "
                  f"using freshly loaded weights")
            return False
        return True

    def run_training(self):
        """Run the distributed training experiment.

        Loads and partitions the model with ``args.strategy``, trains it with
        ``DistributedTrainer`` and logs timing, accuracy and partition data.

        Returns:
            Whatever ``DistributedTrainer.train`` returns.
        """
        print("Starting distributed training...")

        train_loader, val_loader = self.setup_data_loaders()

        model = self.framework.load_vit_model(self.config.model.model_name)
        print(f"Loaded model: {self.config.model.model_name}")

        # Complexity is measured on a single-image batch of the configured size.
        input_shape = (1, 3, self.config.model.img_size, self.config.model.img_size)
        complexity = calculate_model_complexity(model, input_shape)
        print(f"Model complexity: {complexity}")

        # Called for its effect on the framework; the return value is unused here.
        self.framework.partition_model(model, self.args.strategy)
        print(f"Partitioned model using {self.framework.current_partition.strategy} strategy")
        print(f"Partition summary: {self.framework.get_partition_summary()}")

        self.trainer = DistributedTrainer(self.framework, self.config)

        start_time = time.time()
        training_stats = self.trainer.train(train_loader, val_loader)
        training_time = time.time() - start_time

        print(f"Training completed in {training_time:.2f} seconds")
        print(f"Best validation accuracy: {self.trainer.best_accuracy:.2f}%")

        self.logger.log_experiment_result({
            'mode': 'training',
            'training_stats': training_stats,
            'training_time': training_time,
            'best_accuracy': self.trainer.best_accuracy,
            'model_complexity': complexity,
            'partition_summary': self.framework.get_partition_summary()
        })

        return training_stats

    def run_inference(self, model_path: Optional[str] = None):
        """Run the distributed inference benchmark.

        The model is always loaded and partitioned so the framework has a
        partition to execute. When ``model_path`` points to an existing file,
        its weights are applied to the model first.

        Args:
            model_path: Optional path to a checkpoint file.

        Returns:
            The statistics dict produced by ``benchmark_inference``.

        Raises:
            RuntimeError: If the validation loader yields no samples.
        """
        print("Starting distributed inference...")

        model = self.framework.load_vit_model(self.config.model.model_name)

        if model_path and os.path.exists(model_path):
            # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts.
            checkpoint = torch.load(model_path, map_location='cpu')
            if self._restore_weights(model, checkpoint):
                print(f"Loaded model from {model_path}")
        else:
            if model_path:
                print(f"Warning: model path not found: {model_path}")
            print("Using fresh model for inference")

        self.framework.partition_model(model, self.args.strategy)

        _, test_loader = self.setup_data_loaders()
        test_data = self._collect_test_samples(test_loader, self.INFERENCE_SAMPLE_LIMIT)
        if not test_data:
            raise RuntimeError("No test samples available for inference")

        print(f"Prepared {len(test_data)} test samples for inference")

        self.inference_engine = DistributedInference(self.framework, self.config)

        start_time = time.time()
        inference_stats = self.inference_engine.benchmark_inference(
            test_data, num_runs=len(test_data)
        )
        inference_time = time.time() - start_time

        print(f"Inference benchmark completed in {inference_time:.2f} seconds")
        print(f"Average latency: {inference_stats['latency']['mean']:.4f} seconds")
        print(f"Throughput: {inference_stats['throughput']['samples_per_second']:.2f} samples/second")

        # One extra single-sample pass so per-request figures can be reported.
        result = self.inference_engine.run_inference(test_data[0])
        print(f"Sample inference - Latency: {result.latency:.4f}s, Energy: {result.energy_consumed:.6f}Wh")

        self.logger.log_experiment_result({
            'mode': 'inference',
            'inference_stats': inference_stats,
            'inference_time': inference_time,
            'sample_result': {
                'latency': result.latency,
                'energy_consumed': result.energy_consumed,
                'communication_overhead': result.communication_overhead
            },
            'partition_summary': self.framework.get_partition_summary()
        })

        return inference_stats

    def run_evaluation(self):
        """Benchmark every strategy in ``EVALUATION_STRATEGIES`` on the same samples.

        For each strategy the device network is rebuilt, the model is
        repartitioned and an inference benchmark is run. A comparison table is
        printed and the results are logged.

        Note that ``config.partition.strategy`` is overwritten per strategy
        and is left at the last one evaluated.

        Returns:
            Dict mapping strategy name to its stats, partition summary and
            device utilization.

        Raises:
            RuntimeError: If the validation loader yields no samples.
        """
        print("Starting comprehensive evaluation...")

        results = {}

        # The same model object is repartitioned for every strategy.
        model = self.framework.load_vit_model(self.config.model.model_name)

        _, test_loader = self.setup_data_loaders()
        test_data = self._collect_test_samples(test_loader, self.EVALUATION_SAMPLE_LIMIT)
        if not test_data:
            raise RuntimeError("No test samples available for evaluation")

        for strategy in self.EVALUATION_STRATEGIES:
            print(f"\nEvaluating {strategy} strategy...")

            # Fresh devices so state from the previous strategy does not leak
            # into this one.
            self.devices = self.device_manager.create_device_network(self.config.device.num_devices)
            self.framework.devices = self.devices

            self.config.partition.strategy = strategy
            self.framework.partition_model(model, strategy)

            inference_engine = DistributedInference(self.framework, self.config)
            stats = inference_engine.benchmark_inference(test_data, num_runs=len(test_data))

            results[strategy] = {
                'inference_stats': stats,
                'partition_summary': self.framework.get_partition_summary(),
                'device_utilization': stats['device_utilization']
            }

            print(f"{strategy} - Avg Latency: {stats['latency']['mean']:.4f}s, "
                  f"Throughput: {stats['throughput']['samples_per_second']:.2f} samples/s")

        self._compare_strategies(results)

        self.logger.log_experiment_result({
            'mode': 'evaluation',
            'strategy_comparison': results,
            'best_strategy': self._find_best_strategy(results)
        })

        return results

    def _compare_strategies(self, results: Dict[str, Any]):
        """Print a fixed-width table with one row per evaluated strategy.

        Latency and energy means are multiplied by 1000 for display under the
        "ms" and "mWh" column headers.
        """
        rule = "=" * 60
        print("\n" + rule)
        print("STRATEGY COMPARISON RESULTS")
        print(rule)

        print(f"{'Strategy':<12} {'Latency (ms)':<12} {'Throughput':<12} {'Energy (mWh)':<12} {'Devices':<8}")
        print("-" * 60)

        for strategy, result in results.items():
            stats = result['inference_stats']
            latency_ms = stats['latency']['mean'] * 1000
            throughput = stats['throughput']['samples_per_second']
            energy_mwh = stats['energy_consumption']['mean'] * 1000
            active_devices = result['device_utilization']['active_devices']

            print(f"{strategy:<12} {latency_ms:<12.2f} {throughput:<12.2f} {energy_mwh:<12.3f} {active_devices:<8}")

        print(rule)

    def _find_best_strategy(self, results: Dict[str, Any]) -> Optional[str]:
        """Return the strategy with the highest weighted score.

        Latency and energy contribute through their reciprocals (floored at
        0.001 to avoid division by zero), throughput contributes directly.
        Weights come from ``config.get_partition_weights()``; the 'accuracy'
        weight is applied to throughput because no accuracy figure is
        collected during evaluation. The first strategy wins a tie.

        Returns:
            The winning strategy name, or ``None`` when ``results`` is empty.
        """
        weights = self.config.get_partition_weights()
        best_strategy = None
        best_score = -float('inf')

        for strategy, result in results.items():
            stats = result['inference_stats']

            latency_score = 1.0 / max(stats['latency']['mean'], 0.001)
            throughput_score = stats['throughput']['samples_per_second']
            energy_score = 1.0 / max(stats['energy_consumption']['mean'], 0.001)

            total_score = (weights['latency'] * latency_score +
                           weights['accuracy'] * throughput_score +
                           weights['energy'] * energy_score)

            if total_score > best_score:
                best_score = total_score
                best_strategy = strategy

        return best_strategy

    def create_report(self):
        """Generate the visualization dashboard and the text report.

        Returns:
            ``(report_path, dashboard_path)``.
        """
        print("Generating experiment report...")

        dashboard_path = create_visualization_dashboard(
            self.logger.metrics_tracker,
            self.logger.energy_monitor,
            self.logger.comm_tracker,
            self.logger.log_dir
        )

        # NOTE(review): the text report is always given empty stats; the
        # values produced by the run_* methods are not forwarded here.
        training_stats = {}
        validation_stats = {}
        inference_stats = {}

        report_path = self.logger.create_experiment_report(
            training_stats, validation_stats, inference_stats
        )

        print(f"Report generated: {report_path}")
        print(f"Dashboard saved: {dashboard_path}")

        return report_path, dashboard_path

def parse_arguments():
    """Parse the LAP-DTR command line and return the resulting namespace.

    The options are declared in a table and registered in order, so the
    generated ``--help`` output lists them in the same sequence.
    """
    parser = argparse.ArgumentParser(
        description='LAP-DTR: Layer-Adaptive Partitioning with Dynamic Task Redistribution'
    )

    # (flag, add_argument keyword arguments)
    option_table = (
        ('--mode', dict(type=str, choices=['train', 'inference', 'evaluate'],
                        default='train', help='Execution mode')),
        ('--strategy', dict(type=str, choices=['meta', 'ensembled', 'hybrid', 'auto'],
                            default='hybrid', help='Partitioning strategy')),
        ('--config', dict(type=str, help='Path to configuration file')),
        ('--model_path', dict(type=str, help='Path to trained model')),
        ('--devices', dict(type=int, default=4, help='Number of edge devices')),
        ('--batch_size', dict(type=int, default=32, help='Batch size')),
        ('--epochs', dict(type=int, default=10, help='Number of training epochs')),
        ('--learning_rate', dict(type=float, default=1e-4, help='Learning rate')),
        ('--dataset', dict(type=str, choices=['imagenet', 'cifar10'],
                           default='cifar10', help='Dataset to use')),
        ('--data_root', dict(type=str, default='./data', help='Data root directory')),
        ('--output_dir', dict(type=str, default='./output', help='Output directory')),
        ('--verbose', dict(action='store_true', help='Verbose output')),
        ('--seed', dict(type=int, default=42, help='Random seed')),
        ('--compare_strategies', dict(action='store_true',
                                      help='Compare all partitioning strategies')),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()

def main():
    """Command-line entry point: build the configuration, run the requested
    mode and generate the report.

    Exits with status 1 if the run is interrupted or raises.
    """
    args = parse_arguments()

    os.makedirs(args.output_dir, exist_ok=True)

    config = Config()

    # Optional JSON configuration file, applied before the CLI overrides.
    if args.config:
        if os.path.exists(args.config):
            with open(args.config, 'r', encoding='utf-8') as f:
                config_dict = json.load(f)
            config.update_from_dict(config_dict)
        else:
            # Previously ignored silently; tell the user the file was skipped.
            print(f"Warning: configuration file not found: {args.config}; using defaults")

    # NOTE(review): every option below has an argparse default, so these
    # assignments replace values from the configuration file even for flags
    # the user did not pass.
    config.device.num_devices = args.devices
    config.training.batch_size = args.batch_size
    config.training.num_epochs = args.epochs
    config.training.learning_rate = args.learning_rate
    config.dataset = args.dataset
    config.data_root = args.data_root
    config.seed = args.seed
    config.partition.strategy = args.strategy

    config.validate_config()

    print("LAP-DTR Framework - Energy-aware Model Partitioning for Vision Transformers")
    print("="*80)
    print(f"Mode: {args.mode}")
    print(f"Strategy: {args.strategy}")
    print(f"Devices: {args.devices}")
    print(f"Dataset: {args.dataset}")
    print("="*80)

    try:
        experiment = LAPDTRExperiment(config, args)

        if args.mode == 'train':
            experiment.run_training()
            print("\nTraining completed successfully!")

        elif args.mode == 'inference':
            experiment.run_inference(args.model_path)
            print("\nInference completed successfully!")

        elif args.mode == 'evaluate':
            experiment.run_evaluation()
            print("\nEvaluation completed successfully!")

        # --compare_strategies used to be unreachable: it sat in an elif
        # behind the train/inference branches, and --mode is always one of
        # the three choices. Run the comparison in addition to the selected
        # mode (evaluate mode already performs it).
        if args.compare_strategies and args.mode != 'evaluate':
            experiment.run_evaluation()
            print("\nEvaluation completed successfully!")

        experiment.create_report()

        print(f"\nExperiment completed. Results saved to: {experiment.logger.log_dir}")

    except KeyboardInterrupt:
        print("\nExperiment interrupted by user.")
        sys.exit(1)

    except Exception as e:
        # Top-level boundary: report the failure with a traceback, exit non-zero.
        print(f"\nError during experiment: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()