import os
import json
import numpy as np
import cv2
from typing import Dict, List, Tuple, Optional
import argparse
from collections import Counter
import logging

class NuScenesPreprocessor:
    """Build a 2D-box annotation set for selected nuScenes classes and save it as JSON.

    Annotations of ``target_classes`` are read with the nuScenes devkit and
    projected into the ``CAMERA_CHANNELS`` keyframe images. The results can be
    split into train/val sets, summarised and written under ``output_dir``.
    The referenced images can optionally be resized.
    """

    # Camera channels whose keyframe images are scanned for annotations.
    CAMERA_CHANNELS = ('CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT')

    def __init__(self, nuscenes_root: str, output_dir: str):
        """Store paths and configure logging.

        Args:
            nuscenes_root: nuScenes ``dataroot`` directory.
            output_dir: Directory that JSON files and processed images are written to.
        """
        self.nuscenes_root = nuscenes_root
        self.output_dir = output_dir
        # Only annotations whose ``category_name`` is in this list are kept.
        self.target_classes = [
            "human.pedestrian.adult",
            "vehicle.bicycle",
            "vehicle.motorcycle",
        ]
        self.setup_logging()

    def setup_logging(self):
        """Configure root logging at INFO and bind ``self.logger``.

        ``logging.basicConfig`` does nothing if the root logger already has handlers.
        """
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def extract_annotations(self, version: str = "v1.0-trainval") -> List[Dict]:
        """Walk every scene's keyframes and collect 2D boxes for the target classes.

        Args:
            version: Dataset version passed to ``NuScenes(version=...)``.

        Returns:
            One dict per (annotation, camera) pair whose box lands in that camera
            image. Each dict holds the keys produced by ``_extract_sample_annotations``
            plus ``image_path`` (relative to ``nuscenes_root``), ``scene_token``,
            ``sample_token`` and ``camera_channel``.
        """
        # Imported here so the other methods work without the devkit installed.
        from nuscenes.nuscenes import NuScenes

        nusc = NuScenes(version=version, dataroot=self.nuscenes_root, verbose=True)

        annotations = []
        num_scenes = len(nusc.scene)

        for scene_idx, scene in enumerate(nusc.scene):
            if scene_idx % 10 == 0:
                self.logger.info(f"Processing scene {scene_idx}/{num_scenes}")

            # Samples of a scene are chained via 'next'; '' marks the last one.
            sample_token = scene['first_sample_token']
            while sample_token != '':
                sample = nusc.get('sample', sample_token)

                for camera_channel in self.CAMERA_CHANNELS:
                    camera_token = sample['data'].get(camera_channel)
                    if camera_token is None:
                        continue
                    camera_data = nusc.get('sample_data', camera_token)

                    sample_annotations = self._extract_sample_annotations(
                        nusc, sample, camera_token, camera_channel
                    )
                    for ann in sample_annotations:
                        ann['image_path'] = camera_data['filename']
                        ann['scene_token'] = scene['token']
                        ann['sample_token'] = sample['token']
                        ann['camera_channel'] = camera_channel
                        annotations.append(ann)

                sample_token = sample['next']

        return annotations

    def _extract_sample_annotations(self, nusc, sample: Dict, camera_token: str,
                                    camera_channel: str) -> List[Dict]:
        """Return the target-class annotations of ``sample`` that project into one camera.

        ``camera_channel`` is not used here; the caller attaches it to each result.

        Returns:
            Dicts with ``category_name``, ``bbox`` ([x_min, y_min, x_max, y_max]
            pixels), ``visibility`` (the raw ``visibility_token``),
            ``instance_token`` and ``annotation_token``.
        """
        sample_annotations = []

        for ann_token in sample['anns']:
            ann = nusc.get('sample_annotation', ann_token)

            if ann['category_name'] not in self.target_classes:
                continue

            bbox_2d = self._get_2d_bbox(nusc, ann, camera_token)
            if bbox_2d is None:
                continue

            sample_annotations.append({
                'category_name': ann['category_name'],
                'bbox': bbox_2d,
                'visibility': ann['visibility_token'],
                'instance_token': ann['instance_token'],
                'annotation_token': ann['token'],
            })

        return sample_annotations

    def _get_2d_bbox(self, nusc, annotation: Dict, camera_token: str) -> Optional[List[float]]:
        """Project an annotation's 3D box into a camera image and return its 2D extent.

        nuScenes stores annotation boxes in the global frame. The box is first
        moved into the ego-vehicle frame, using the ego pose recorded with this
        image. It is then moved into the camera frame, using the camera's
        calibrated extrinsics. Only after both steps is it projected.

        Returns:
            ``[x_min, y_min, x_max, y_max]`` in pixels, clipped to the image. Returns
            ``None`` in three cases: the box is entirely behind the camera, the box
            is outside the image, or a lookup/projection step raises (which is
            logged as a warning).
        """
        from pyquaternion import Quaternion

        try:
            camera_data = nusc.get('sample_data', camera_token)
            ego_pose = nusc.get('ego_pose', camera_data['ego_pose_token'])
            calibration = nusc.get('calibrated_sensor', camera_data['calibrated_sensor_token'])

            box = nusc.get_box(annotation['token'])

            # global -> ego-vehicle frame at the time this image was captured
            box.translate(-np.array(ego_pose['translation']))
            box.rotate(Quaternion(ego_pose['rotation']).inverse)
            # ego-vehicle -> camera frame
            box.translate(-np.array(calibration['translation']))
            box.rotate(Quaternion(calibration['rotation']).inverse)

            image_size = None
            if 'width' in camera_data and 'height' in camera_data:
                image_size = (camera_data['width'], camera_data['height'])

            return self._corners_to_image_bbox(
                box.corners(), calibration['camera_intrinsic'], image_size
            )
        except Exception as e:
            self.logger.warning(
                f"Failed to extract 2D bbox for annotation {annotation.get('token')}: {e}"
            )
            return None

    @staticmethod
    def _corners_to_image_bbox(corners_cam, camera_intrinsic,
                               image_size: Optional[Tuple[int, int]] = None
                               ) -> Optional[List[float]]:
        """Pinhole-project camera-frame corners and return their clipped 2D extent.

        Args:
            corners_cam: 3xN corner coordinates in the camera frame (row 2 is depth).
            camera_intrinsic: 3x3 intrinsic matrix.
            image_size: Optional ``(width, height)`` used to clip the right/bottom edges.

        Returns:
            ``[x_min, y_min, x_max, y_max]`` as Python floats. Returns ``None`` if no
            corner is in front of the camera or the clipped box has zero area.
        """
        corners = np.asarray(corners_cam, dtype=float)
        # Corners at/behind the image plane would project mirrored or to infinity.
        corners = corners[:, corners[2, :] > 0]
        if corners.shape[1] == 0:
            return None

        projected = np.asarray(camera_intrinsic, dtype=float) @ corners
        xs = projected[0, :] / projected[2, :]
        ys = projected[1, :] / projected[2, :]

        # Clip (rather than reject) boxes cut off by the border, on every side alike.
        x_min, x_max = max(xs.min(), 0.0), xs.max()
        y_min, y_max = max(ys.min(), 0.0), ys.max()
        if image_size is not None:
            width, height = image_size
            x_max = min(x_max, float(width))
            y_max = min(y_max, float(height))

        if x_max > x_min and y_max > y_min:
            return [float(x_min), float(y_min), float(x_max), float(y_max)]
        return None

    def analyze_class_distribution(self, annotations: List[Dict]) -> Dict[str, int]:
        """Log per-target-class counts and percentages; return counts for all classes seen.

        Percentages are reported as 0 when ``annotations`` is empty.
        """
        class_counts = Counter(ann['category_name'] for ann in annotations)
        total_annotations = len(annotations)

        self.logger.info("Class Distribution Analysis:")
        self.logger.info(f"Total annotations: {total_annotations}")

        for class_name in self.target_classes:
            count = class_counts[class_name]
            percentage = (count / total_annotations) * 100 if total_annotations else 0.0
            self.logger.info(f"{class_name}: {count} ({percentage:.2f}%)")

        return dict(class_counts)

    def create_balanced_splits(self, annotations: List[Dict],
                               train_ratio: float = 0.8,
                               seed: Optional[int] = None,
                               group_key: Optional[str] = None
                               ) -> Tuple[List[Dict], List[Dict]]:
        """Split target-class annotations into shuffled train/validation lists.

        By default each class is split separately, so both splits keep the class
        proportions. That default assigns individual annotations, so the same object
        seen in consecutive keyframes or in several cameras can end up in both splits.
        Pass ``group_key`` (e.g. ``'scene_token'``) to keep all annotations sharing
        that key in a single split instead. With ``group_key``, ``train_ratio``
        applies to the number of groups.

        Args:
            annotations: Dicts with ``category_name`` (and ``group_key`` if given).
                Annotations of non-target classes are dropped from both splits.
            train_ratio: Fraction in [0, 1] assigned to the training split.
            seed: If given, a private ``np.random.RandomState(seed)`` makes the
                split reproducible. Otherwise NumPy's global RNG is used, as before.
            group_key: Optional annotation key to split by group instead of per class.

        Returns:
            ``(train_annotations, val_annotations)``.

        Raises:
            ValueError: If ``train_ratio`` is outside [0, 1].
        """
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

        rng = np.random.RandomState(seed) if seed is not None else np.random

        if group_key is None:
            train_annotations, val_annotations = self._split_per_class(
                annotations, train_ratio, rng
            )
        else:
            train_annotations, val_annotations = self._split_by_group(
                annotations, train_ratio, group_key, rng
            )

        # Interleave classes/groups in the final lists.
        rng.shuffle(train_annotations)
        rng.shuffle(val_annotations)

        self.logger.info(f"Train split: {len(train_annotations)} annotations")
        self.logger.info(f"Validation split: {len(val_annotations)} annotations")

        return train_annotations, val_annotations

    def _split_per_class(self, annotations: List[Dict], train_ratio: float,
                         rng) -> Tuple[List[Dict], List[Dict]]:
        """Shuffle and split each target class independently by ``train_ratio``."""
        class_annotations = {class_name: [] for class_name in self.target_classes}
        for ann in annotations:
            if ann['category_name'] in class_annotations:
                class_annotations[ann['category_name']].append(ann)

        train, val = [], []
        for class_anns in class_annotations.values():
            rng.shuffle(class_anns)
            split_idx = int(len(class_anns) * train_ratio)
            train.extend(class_anns[:split_idx])
            val.extend(class_anns[split_idx:])
        return train, val

    def _split_by_group(self, annotations: List[Dict], train_ratio: float,
                        group_key: str, rng) -> Tuple[List[Dict], List[Dict]]:
        """Assign every group (annotations sharing ``ann[group_key]``) wholly to one split."""
        groups: Dict[str, List[Dict]] = {}
        for ann in annotations:
            if ann['category_name'] in self.target_classes:
                groups.setdefault(ann[group_key], []).append(ann)

        # Sorted first so that a given seed always yields the same assignment.
        group_ids = sorted(groups)
        rng.shuffle(group_ids)
        split_idx = int(len(group_ids) * train_ratio)
        train = [ann for gid in group_ids[:split_idx] for ann in groups[gid]]
        val = [ann for gid in group_ids[split_idx:] for ann in groups[gid]]
        return train, val

    def save_annotations(self, annotations: List[Dict], filename: str):
        """Write ``annotations`` as indented UTF-8 JSON to ``<output_dir>/<filename>``.

        Any JSON-serializable object is accepted (e.g. a statistics dict).
        """
        os.makedirs(self.output_dir, exist_ok=True)
        filepath = os.path.join(self.output_dir, filename)

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(annotations, f, indent=2)

        self.logger.info(f"Saved {len(annotations)} annotations to {filepath}")

    def preprocess_images(self, annotations: List[Dict], target_size: Tuple[int, int] = (224, 224)):
        """Resize every distinct image referenced by ``annotations`` into ``processed_images/``.

        Args:
            annotations: Annotation dicts; only ``image_path`` is read.
            target_size: ``(width, height)``, passed to ``cv2.resize`` as ``dsize``.

        Output files keep their basename. The ``bbox`` values in the annotation
        files are NOT rescaled; they remain in original-image pixels.
        """
        processed_dir = os.path.join(self.output_dir, 'processed_images')
        os.makedirs(processed_dir, exist_ok=True)

        # Sorted for a stable processing order (str-set order varies between runs).
        unique_images = sorted({ann['image_path'] for ann in annotations})
        skipped = 0

        for i, image_path in enumerate(unique_images):
            if i % 100 == 0:
                self.logger.info(f"Processing image {i}/{len(unique_images)}")

            full_image_path = os.path.join(self.nuscenes_root, image_path)
            # cv2.imread returns None (it does not raise) for missing/unreadable files.
            image = cv2.imread(full_image_path)
            if image is None:
                skipped += 1
                self.logger.warning(f"Could not read image {full_image_path}; skipping")
                continue

            resized_image = cv2.resize(image, tuple(target_size))

            output_path = os.path.join(processed_dir, os.path.basename(image_path))
            if not cv2.imwrite(output_path, resized_image):
                self.logger.warning(f"Failed to write {output_path}")

        if skipped:
            self.logger.warning(f"Skipped {skipped} of {len(unique_images)} images")

    def generate_statistics_report(self, annotations: List[Dict]) -> Dict:
        """Summarise counts and bbox geometry in a JSON-serializable dict.

        Bbox statistics are population mean/std of area and width/height ratio.
        The ratio is 0 for zero-height boxes. All four statistics are ``None``
        when ``annotations`` is empty.
        """
        stats = {
            'total_annotations': len(annotations),
            # Counted directly so the class report is not logged a second time.
            'class_distribution': dict(Counter(ann['category_name'] for ann in annotations)),
            'unique_images': len({ann['image_path'] for ann in annotations}),
            'unique_scenes': len({ann['scene_token'] for ann in annotations}),
        }

        if not annotations:
            # np.mean([]) would yield NaN, which is not valid JSON.
            stats['bbox_statistics'] = {
                'mean_area': None,
                'std_area': None,
                'mean_aspect_ratio': None,
                'std_aspect_ratio': None,
            }
            return stats

        bboxes = np.asarray([ann['bbox'] for ann in annotations], dtype=float)
        widths = bboxes[:, 2] - bboxes[:, 0]
        heights = bboxes[:, 3] - bboxes[:, 1]
        areas = widths * heights
        aspect_ratios = np.divide(widths, heights, out=np.zeros_like(widths), where=heights > 0)

        stats['bbox_statistics'] = {
            'mean_area': float(np.mean(areas)),
            'std_area': float(np.std(areas)),
            'mean_aspect_ratio': float(np.mean(aspect_ratios)),
            'std_aspect_ratio': float(np.std(aspect_ratios)),
        }

        return stats

def main(argv: Optional[List[str]] = None):
    """Command-line entry point: extract, split, save and summarise nuScenes annotations.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]`` (useful for testing).
            ``None`` keeps the usual command-line behaviour.
    """
    parser = argparse.ArgumentParser(description='Preprocess nuScenes dataset for bias analysis')
    parser.add_argument('--nuscenes_root', type=str, required=True,
                        help='Path to nuScenes dataset root directory')
    parser.add_argument('--output_dir', type=str, default='./preprocessed_data',
                        help='Output directory for processed data')
    parser.add_argument('--version', type=str, default='v1.0-trainval',
                        help='nuScenes dataset version')
    parser.add_argument('--train_ratio', type=float, default=0.8,
                        help='Ratio of data for training split')
    parser.add_argument('--preprocess_images', action='store_true',
                        help='Preprocess and resize images')
    # Passed to cv2.resize as dsize, i.e. (width, height).
    parser.add_argument('--target_size', type=int, nargs=2, default=[224, 224],
                        metavar=('WIDTH', 'HEIGHT'),
                        help='Target image size for preprocessing')

    args = parser.parse_args(argv)

    # Fail fast (exit status 2 via argparse) before the expensive dataset load.
    if not 0.0 <= args.train_ratio <= 1.0:
        parser.error('--train_ratio must be between 0 and 1')
    if min(args.target_size) <= 0:
        parser.error('--target_size values must be positive')

    preprocessor = NuScenesPreprocessor(args.nuscenes_root, args.output_dir)

    print("Extracting annotations from nuScenes dataset...")
    annotations = preprocessor.extract_annotations(args.version)

    # Called for its logged class-distribution report; the returned counts are unused here.
    preprocessor.analyze_class_distribution(annotations)

    train_annotations, val_annotations = preprocessor.create_balanced_splits(
        annotations, args.train_ratio
    )

    preprocessor.save_annotations(annotations, 'all_annotations.json')
    preprocessor.save_annotations(train_annotations, 'train_annotations.json')
    preprocessor.save_annotations(val_annotations, 'val_annotations.json')

    stats = preprocessor.generate_statistics_report(annotations)
    preprocessor.save_annotations(stats, 'dataset_statistics.json')

    if args.preprocess_images:
        print("Preprocessing images...")
        preprocessor.preprocess_images(annotations, tuple(args.target_size))

    print(f"Preprocessing completed. Results saved to {args.output_dir}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()