import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import logging
from typing import Dict, List, Tuple, Optional
import time
import os
from torch.utils.tensorboard import SummaryWriter
import json

# Import our custom modules
from enhanced_vi_transformer import EnhancedVariationalVisionTransformer, mixed_distribution_loss, adaptive_training_step
from unified_dataloader import create_camera_lidar_dataloaders, generate_sample_data_config

class MixedPrecisionVariationalTrainer:
    """Drive training, validation, testing and reporting for a camera + LiDAR model.

    The trainer owns the three data loaders, a TensorBoard ``SummaryWriter``,
    a file/console logger and ``self.history``, a dict of per-epoch metric
    lists that is stored in every checkpoint and in the final JSON report.

    NOTE(review): the ``'bit_assignments'`` history list is created but never
    populated anywhere in this class.
    """

    def __init__(self,
                 model: "EnhancedVariationalVisionTransformer",
                 train_loader,
                 val_loader,
                 test_loader,
                 device: str = 'cuda',
                 log_dir: str = 'logs',
                 checkpoint_dir: str = 'checkpoints'):
        """Move the model to ``device`` and prepare logging/checkpoint folders.

        Args:
            model: Network to train; it is called as ``model(camera, lidar)``
                and must expose ``vi_distribution``, ``laplace_b`` and
                ``mixed_precision_config`` (read when checkpointing).
            train_loader: Iterable of ``(camera, lidar, targets)`` batches.
            val_loader: Same batch layout, used for validation.
            test_loader: Same batch layout, used for the final test.
            device: Device string every batch and the model are moved to.
            log_dir: Directory for TensorBoard events, ``training.log`` and
                the final results/plot files (created if missing).
            checkpoint_dir: Directory for ``.pth`` checkpoints (created if
                missing).
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device

        # Setup logging
        self.log_dir = log_dir
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir)
        self.setup_logging()

        # Training history. Train lists get one entry per epoch; validation
        # lists get one entry per validation round, and 'val_epochs' stores
        # the 1-based position in the train lists each round belongs to.
        self.history = {
            'train_loss': [],
            'train_nll': [],
            'train_kl': [],
            'train_sparsity': [],
            'val_loss': [],
            'val_accuracy': [],
            'val_epochs': [],
            'distribution_switches': [],
            'bit_assignments': []
        }

    def setup_logging(self):
        """Configure root logging to ``<log_dir>/training.log`` and the console.

        ``logging.basicConfig`` does nothing if the root logger already has
        handlers, so an application that configured logging earlier keeps
        its own setup.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(os.path.join(self.log_dir, 'training.log')),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _mean_or_nan(total: float, count: int) -> float:
        """Return ``total / count``, or NaN when ``count`` is zero."""
        return total / count if count else float('nan')

    def train(self,
              epochs: int = 100,
              learning_rate: float = 1e-4,
              lambda_var: float = 0.01,
              lambda_sparse: float = 0.001,
              switch_every: int = 10,
              validate_every: int = 5,
              save_every: int = 20,
              early_stopping_patience: int = 15):
        """Run the full training loop, then test and persist the results.

        Args:
            epochs: Maximum number of epochs (also the cosine schedule length).
            learning_rate: Initial AdamW learning rate.
            lambda_var: Weight forwarded to the training step and to the
                validation loss.
            lambda_sparse: Forwarded to ``_train_epoch``.
                NOTE(review): it is not passed on to
                ``adaptive_training_step`` there, so it currently has no
                effect on the loss.
            switch_every: Forwarded to ``adaptive_training_step``.
            validate_every: Validate on every N-th epoch.
            save_every: Write a periodic checkpoint on every N-th epoch.
            early_stopping_patience: Number of consecutive validation rounds
                without a new best validation loss before training stops.

        Returns:
            ``self.history``.
        """
        # Initialize optimizer
        optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

        best_val_loss = float('inf')
        patience_counter = 0

        self.logger.info(f"Starting training for {epochs} epochs")
        self.logger.info(f"Device: {self.device}")
        self.logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        # Pre-bind so the final checkpoint below still works when the loop
        # body never runs (epochs < 1).
        epoch = 0
        for epoch in range(1, epochs + 1):
            epoch_start_time = time.time()

            # Training phase
            train_metrics = self._train_epoch(epoch, optimizer, lambda_var, lambda_sparse, switch_every)

            # Validation phase
            val_metrics = None
            if epoch % validate_every == 0:
                val_metrics = self._validate_epoch(epoch, lambda_var)

            # Record the epoch before any checkpoint is written or the loop
            # is left, so saved histories always include the current epoch.
            self._update_history(train_metrics, val_metrics)

            # Early stopping check
            stop_early = False
            if val_metrics is not None:
                if val_metrics['val_loss'] < best_val_loss:
                    best_val_loss = val_metrics['val_loss']
                    patience_counter = 0
                    self._save_checkpoint(epoch, optimizer, 'best_model.pth')
                else:
                    patience_counter += 1
                stop_early = patience_counter >= early_stopping_patience

            if stop_early:
                self.logger.info(f"Early stopping triggered at epoch {epoch}")
                break

            # Learning rate scheduling
            scheduler.step()

            # Save checkpoint
            if epoch % save_every == 0:
                self._save_checkpoint(epoch, optimizer, f'checkpoint_epoch_{epoch}.pth')

            # Log epoch time
            epoch_time = time.time() - epoch_start_time
            self.logger.info(f"Epoch {epoch} completed in {epoch_time:.2f}s")

        self.logger.info("Training completed!")

        # Final testing (uses the weights currently in memory, i.e. the last
        # trained epoch, not the 'best_model.pth' checkpoint).
        test_metrics = self.test()

        # Save final model and results
        self._save_checkpoint(epoch, optimizer, 'final_model.pth')
        self._save_training_results(test_metrics)

        # Make sure buffered TensorBoard events reach disk.
        self.writer.flush()

        return self.history

    def _train_epoch(self, epoch: int, optimizer, lambda_var: float, lambda_sparse: float, switch_every: int) -> Dict:
        """Train for one epoch and return sample-weighted average metrics.

        Args:
            epoch: 1-based epoch number (used for logging and passed to the
                training step).
            optimizer: Optimizer handed to ``adaptive_training_step``.
            lambda_var: Forwarded to ``adaptive_training_step``.
            lambda_sparse: Currently unused in this method.
            switch_every: Forwarded to ``adaptive_training_step``.

        Returns:
            Dict with ``train_loss``, ``train_nll``, ``train_kl``,
            ``train_sparsity`` (NaN if the loader yielded no samples) and
            ``distribution`` (value reported by the last batch, or ``None``).
        """
        self.model.train()

        total_loss = 0.0
        total_nll = 0.0
        total_kl = 0.0
        total_sparsity = 0.0
        total_samples = 0

        # Track current distribution
        current_distribution = None

        for batch_idx, (camera_data, lidar_data, targets) in enumerate(self.train_loader):
            camera_data = camera_data.to(self.device)
            lidar_data = lidar_data.to(self.device)
            targets = targets.to(self.device)

            # Adaptive training step with distribution switching
            metrics = adaptive_training_step(
                self.model, camera_data, lidar_data, targets,
                optimizer, lambda_var, epoch, switch_every
            )

            # Accumulate metrics. float() is a no-op for Python numbers and
            # unwraps 0-d tensors, so the history never holds tensors
            # (keeps it JSON-serialisable and free of device references).
            batch_size = camera_data.size(0)
            total_loss += float(metrics['total_loss']) * batch_size
            total_nll += float(metrics['nll']) * batch_size
            total_kl += float(metrics['kl_loss']) * batch_size
            total_sparsity += float(metrics['sparsity_loss']) * batch_size
            total_samples += batch_size

            current_distribution = metrics['distribution']

            # Log batch progress
            if batch_idx % 50 == 0:
                self.logger.info(
                    f"Epoch {epoch}, Batch {batch_idx}/{len(self.train_loader)}: "
                    f"Loss={metrics['total_loss']:.4f}, "
                    f"NLL={metrics['nll']:.4f}, "
                    f"KL={metrics['kl_loss']:.4f}, "
                    f"Distribution={current_distribution}"
                )

        if total_samples == 0:
            self.logger.warning(
                f"Epoch {epoch}: training loader yielded no samples; metrics are NaN"
            )

        # Calculate average metrics
        avg_metrics = {
            'train_loss': self._mean_or_nan(total_loss, total_samples),
            'train_nll': self._mean_or_nan(total_nll, total_samples),
            'train_kl': self._mean_or_nan(total_kl, total_samples),
            'train_sparsity': self._mean_or_nan(total_sparsity, total_samples),
            'distribution': current_distribution
        }

        # Log to tensorboard
        self.writer.add_scalar('Train/Loss', avg_metrics['train_loss'], epoch)
        self.writer.add_scalar('Train/NLL', avg_metrics['train_nll'], epoch)
        self.writer.add_scalar('Train/KL', avg_metrics['train_kl'], epoch)
        self.writer.add_scalar('Train/Sparsity', avg_metrics['train_sparsity'], epoch)

        self.logger.info(
            f"Epoch {epoch} Training: "
            f"Loss={avg_metrics['train_loss']:.4f}, "
            f"NLL={avg_metrics['train_nll']:.4f}, "
            f"KL={avg_metrics['train_kl']:.4f}, "
            f"Distribution={current_distribution}"
        )

        return avg_metrics

    def _validate_epoch(self, epoch: int, lambda_var: float) -> Dict:
        """Evaluate on the validation loader without gradient tracking.

        Args:
            epoch: Epoch number used as the TensorBoard step and in logs.
            lambda_var: Forwarded to ``mixed_distribution_loss``.

        Returns:
            Dict with ``val_loss`` (sample-weighted mean loss) and
            ``val_accuracy`` (percent of argmax predictions equal to the
            targets). Both are NaN if the loader yielded no samples.
        """
        self.model.eval()

        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for camera_data, lidar_data, targets in self.val_loader:
                camera_data = camera_data.to(self.device)
                lidar_data = lidar_data.to(self.device)
                targets = targets.to(self.device)

                # Forward pass
                outputs = self.model(camera_data, lidar_data)

                # Compute loss (only the total is aggregated here)
                loss, nll, kl_loss, sparsity_loss = mixed_distribution_loss(
                    self.model, outputs, targets, lambda_var
                )

                # Accumulate metrics
                batch_size = camera_data.size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size

                # Calculate accuracy
                pred = outputs.argmax(dim=1)
                total_correct += pred.eq(targets).sum().item()

        if total_samples == 0:
            self.logger.warning(
                f"Epoch {epoch}: validation loader yielded no samples; metrics are NaN"
            )

        # Calculate average metrics
        avg_loss = self._mean_or_nan(total_loss, total_samples)
        accuracy = self._mean_or_nan(100.0 * total_correct, total_samples)

        val_metrics = {
            'val_loss': avg_loss,
            'val_accuracy': accuracy
        }

        # Log to tensorboard
        self.writer.add_scalar('Val/Loss', avg_loss, epoch)
        self.writer.add_scalar('Val/Accuracy', accuracy, epoch)

        self.logger.info(
            f"Epoch {epoch} Validation: "
            f"Loss={avg_loss:.4f}, "
            f"Accuracy={accuracy:.2f}%"
        )

        return val_metrics

    def test(self) -> Dict:
        """Evaluate the current model weights on the test loader.

        Returns:
            Dict with ``test_accuracy`` (percent, NaN when the loader yielded
            no samples), ``class_accuracies`` (label -> percent for every
            label seen in the targets) and ``total_samples``.
        """
        self.logger.info("Starting testing...")

        self.model.eval()

        total_correct = 0
        total_samples = 0
        class_correct = {}
        class_total = {}

        with torch.no_grad():
            for camera_data, lidar_data, targets in self.test_loader:
                camera_data = camera_data.to(self.device)
                lidar_data = lidar_data.to(self.device)
                targets = targets.to(self.device)

                # Forward pass
                outputs = self.model(camera_data, lidar_data)
                pred = outputs.argmax(dim=1)

                # Overall accuracy
                correct = pred.eq(targets)
                total_correct += correct.sum().item()
                total_samples += targets.size(0)

                # Per-class accuracy: pull the whole batch to Python once
                # instead of calling .item() for every sample.
                for label, hit in zip(targets.tolist(), correct.tolist()):
                    class_total[label] = class_total.get(label, 0) + 1
                    class_correct[label] = class_correct.get(label, 0) + int(hit)

        if total_samples == 0:
            self.logger.warning("Test loader yielded no samples; accuracy is NaN")

        overall_accuracy = self._mean_or_nan(100.0 * total_correct, total_samples)

        # Calculate per-class accuracy (every label here has class_total >= 1)
        class_accuracies = {
            label: 100.0 * class_correct[label] / seen
            for label, seen in class_total.items()
        }

        test_metrics = {
            'test_accuracy': overall_accuracy,
            'class_accuracies': class_accuracies,
            'total_samples': total_samples
        }

        self.logger.info(f"Test Accuracy: {overall_accuracy:.2f}%")
        for label, acc in class_accuracies.items():
            self.logger.info(f"Class {label} Accuracy: {acc:.2f}%")

        return test_metrics

    def _save_checkpoint(self, epoch: int, optimizer, filename: str):
        """Write model/optimizer state, history and model config to disk.

        Args:
            epoch: Epoch number stored in the checkpoint.
            optimizer: Optimizer whose ``state_dict`` is stored.
            filename: File name inside ``self.checkpoint_dir``.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'history': self.history,
            'model_config': {
                'vi_distribution': self.model.vi_distribution,
                'laplace_b': self.model.laplace_b,
                'mixed_precision_config': self.model.mixed_precision_config
            }
        }

        path = os.path.join(self.checkpoint_dir, filename)
        torch.save(checkpoint, path)
        self.logger.info(f"Checkpoint saved: {path}")

    def _update_history(self, train_metrics: Dict, val_metrics: Optional[Dict]):
        """Append one epoch of metrics to ``self.history``.

        Args:
            train_metrics: Dict as returned by ``_train_epoch``.
            val_metrics: Dict as returned by ``_validate_epoch``, or ``None``
                for epochs without validation.
        """
        self.history['train_loss'].append(train_metrics['train_loss'])
        self.history['train_nll'].append(train_metrics['train_nll'])
        self.history['train_kl'].append(train_metrics['train_kl'])
        self.history['train_sparsity'].append(train_metrics['train_sparsity'])
        self.history['distribution_switches'].append(train_metrics['distribution'])

        if val_metrics:
            self.history['val_loss'].append(val_metrics['val_loss'])
            self.history['val_accuracy'].append(val_metrics['val_accuracy'])
            # Remember which train entry (1-based) this validation round
            # belongs to so plots can align the two series. setdefault keeps
            # this working for history dicts that lack the key.
            self.history.setdefault('val_epochs', []).append(
                len(self.history['train_loss'])
            )

    def _save_training_results(self, test_metrics: Dict):
        """Dump history, test metrics and model config as JSON in ``log_dir``.

        Args:
            test_metrics: Dict as returned by ``test``.
        """
        results = {
            'training_history': self.history,
            'test_metrics': test_metrics,
            'model_config': {
                'vi_distribution': self.model.vi_distribution,
                'laplace_b': self.model.laplace_b,
                'mixed_precision_config': self.model.mixed_precision_config
            }
        }

        results_path = os.path.join(self.log_dir, 'training_results.json')
        with open(results_path, 'w') as f:
            json.dump(results, f, indent=2)

        self.logger.info(f"Training results saved: {results_path}")

    def plot_training_curves(self):
        """Plot loss, accuracy, NLL/KL and sparsity curves and save them as PNG.

        The figure is written to ``<log_dir>/training_curves.png`` and then
        shown with ``plt.show()``. All x-axes use 1-based epoch numbers;
        validation points are placed at the epochs recorded in
        ``history['val_epochs']``.
        """
        def epoch_axis(series):
            """1-based positions for a series with one entry per epoch."""
            return range(1, len(series) + 1)

        def validation_axis(series):
            """Recorded validation epochs, or plain positions as a fallback."""
            positions = self.history.get('val_epochs', [])
            if len(positions) == len(series):
                return positions
            return epoch_axis(series)

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Loss curves
        train_loss = self.history['train_loss']
        axes[0, 0].plot(epoch_axis(train_loss), train_loss, label='Training Loss')
        val_loss = self.history['val_loss']
        if val_loss:
            axes[0, 0].plot(validation_axis(val_loss), val_loss, label='Validation Loss')
        axes[0, 0].set_title('Loss Curves')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True)

        # Accuracy curve
        val_accuracy = self.history['val_accuracy']
        if val_accuracy:
            axes[0, 1].plot(validation_axis(val_accuracy), val_accuracy, label='Validation Accuracy')
            # Only build a legend when there is something to label.
            axes[0, 1].legend()
        axes[0, 1].set_title('Validation Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy (%)')
        axes[0, 1].grid(True)

        # KL and NLL
        train_nll = self.history['train_nll']
        train_kl = self.history['train_kl']
        axes[1, 0].plot(epoch_axis(train_nll), train_nll, label='NLL')
        axes[1, 0].plot(epoch_axis(train_kl), train_kl, label='KL Divergence')
        axes[1, 0].set_title('NLL and KL Divergence')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Loss')
        axes[1, 0].legend()
        axes[1, 0].grid(True)

        # Sparsity
        train_sparsity = self.history['train_sparsity']
        axes[1, 1].plot(epoch_axis(train_sparsity), train_sparsity, label='Sparsity Loss')
        axes[1, 1].set_title('Sparsity Regularization')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Loss')
        axes[1, 1].legend()
        axes[1, 1].grid(True)

        plt.tight_layout()
        plt.savefig(os.path.join(self.log_dir, 'training_curves.png'), dpi=300, bbox_inches='tight')
        plt.show()


def main_training_pipeline():
    """Build data loaders and model, train, plot, and hand back the results.

    Steps, in order: generate a sample data configuration, create the
    train/val/test loaders, construct the model, wrap everything in a
    ``MixedPrecisionVariationalTrainer``, train, and plot the curves.

    Returns:
        Tuple ``(trainer, history)`` where ``history`` is the value returned
        by ``trainer.train``.
    """
    # Run settings. CUDA is used when available, otherwise the CPU.
    run_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    settings = dict(
        data_root='data',
        batch_size=16,
        num_workers=4,
        epochs=100,
        learning_rate=1e-4,
        lambda_var=0.01,
        lambda_sparse=0.001,
        switch_every=10,
        device=run_device,
        img_size=224,
        num_classes=10,
    )

    print("Setting up data...")

    # Sample data configuration (replace with your actual data)
    sample_config = generate_sample_data_config(settings['data_root'], num_samples=1000)

    # Loaders come back as a (train, val, test) triple
    loaders = create_camera_lidar_dataloaders(
        sample_config,
        batch_size=settings['batch_size'],
        num_workers=settings['num_workers']
    )
    train_loader, val_loader, test_loader = loaders

    print("Initializing model...")

    # Model starts with the Gaussian variational distribution
    network = EnhancedVariationalVisionTransformer(
        img_size=settings['img_size'],
        patch_size=16,
        in_channels=4,  # RGB + LiDAR
        num_classes=settings['num_classes'],
        vi_distribution='gaussian',
        laplace_b=1.0
    )

    print("Starting training...")

    # Trainer uses its default log and checkpoint directories
    trainer = MixedPrecisionVariationalTrainer(
        model=network,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=settings['device']
    )

    # Forward only the hyper-parameters configured above; the rest keep
    # the trainer's defaults.
    train_kwargs = {
        name: settings[name]
        for name in ('epochs', 'learning_rate', 'lambda_var', 'lambda_sparse', 'switch_every')
    }
    history = trainer.train(**train_kwargs)

    # Plot results
    trainer.plot_training_curves()

    print("Training completed successfully!")

    return trainer, history


# Script entry point: run the full pipeline only when this file is executed
# directly; the resulting trainer and history are kept as module-level names.
if __name__ == "__main__":
    trainer, history = main_training_pipeline()