"""quality_checks.py: Contains all functions to identify data quality insufficiencies. """

__author__      = "Niklas Biermann"
__copyright__   = "Copyright 2023, Niklas Biermann"
import csv
import logging
import re
import statistics
from typing import List

import omni.ext
import omni.usd
from pxr import Usd, UsdGeom

# Module-level logger shared by the quality checks in this file.
logger = logging.getLogger(__name__)

# Destination of the CSV report written by each check (the file is overwritten on every run).
outputpath_psc: str = "c:/omniverse/position_check.csv"  # DQCheck.position_check
outputpath_dmc: str = "c:/omniverse/2D_check.csv"  # DQCheck.dimension_check
outputpath_pfc: str = "c:/omniverse/performance_check.csv"  # DQCheck.performance_check
outputpath_efc: str = "c:/omniverse/empty_file_check.csv"  # DQCheck.empty_file_check
outputpath_ncc: str = "c:/omniverse/naming_check.csv"  # DQCheck.naming_check
outputpath_scc: str = "c:/omniverse/scaling_check.csv"  # DQCheck.scaling_check


class DQCheck:
    """
    Main class of the extension including all functions to perform automatic identification of data quality
    insufficiencies.

    All methods are static and are meant to be called on the class, e.g. ``DQCheck.position_check(stage)``.
    Each ``*_check`` method scans a ``Usd.Stage`` and writes its findings to a semicolon-separated CSV file,
    by default the matching module-level ``outputpath_*`` path (pass ``output_path=`` to override it).
    """

    # ------------------------------------------------------------------ private helpers

    @staticmethod
    def _write_csv(path, fieldnames, rows):
        """
        Write ``rows`` (dicts keyed by ``fieldnames``) to ``path`` as a semicolon-separated CSV file.

        The file is overwritten and starts with a header row; keys missing from a row become empty cells.
        """
        # Explicit UTF-8 so non-ASCII prim names cannot fail on a platform-default codepage.
        with open(path, "w", newline="", encoding="utf-8") as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=fieldnames, delimiter=";")
            writer.writeheader()
            writer.writerows(rows)

    @staticmethod
    def _interval(values, z_score):
        """
        Return ``(mean - z_score * stdev, mean + z_score * stdev)`` of ``values``.

        Returns None when there are fewer than two values, since the sample standard deviation is undefined.
        """
        if len(values) < 2:
            return None
        mean = statistics.mean(values)
        spread = z_score * statistics.stdev(values)
        return mean - spread, mean + spread

    @staticmethod
    def _bbox_cache():
        """
        Construct a BBoxCache for the default time code and the ``default`` purpose, using extentsHint where
        authored. For details, see:
        https://openusd.org/dev/api/class_usd_geom_b_box_cache.html#ad4dbc60e68f738ff79c437fec435322d
        """
        return UsdGeom.BBoxCache(
            Usd.TimeCode.Default(), includedPurposes=[UsdGeom.Tokens.default_], useExtentsHint=True
        )

    @staticmethod
    def _naming_conforms(name, path):
        """
        Return True if a prim name follows the naming convention relative to its path.

        The reference is the fifth non-empty component of the prim path (``/c0/c1/c2/c3/c4/...`` -> ``c4``).
        The name conforms when its first or its sixth underscore-separated part equals that component.
        Paths with fewer than five components do not conform, and neither do names that would need a
        sixth part they do not have.
        """
        path_parts = [part for part in path.split("/") if part]
        if len(path_parts) < 5:
            return False
        folder = path_parts[4]
        name_parts = name.split("_")
        return name_parts[0] == folder or (len(name_parts) > 5 and name_parts[5] == folder)

    # ------------------------------------------------------------------ public API

    @staticmethod
    def get_selection():
        """Return the paths of the prims currently selected in the active USD context."""
        return omni.usd.get_context().get_selection().get_selected_prim_paths()

    @staticmethod
    def get_vertice_count(prim):
        """
        Return the vertex count of a mesh prim, i.e. the number of entries in its ``points`` attribute.

        Returns 0 when ``points`` has no value at the default time code.
        """
        # faceVertexCounts has one entry per *face*; its length is the face count, not the vertex count.
        points = UsdGeom.Mesh(prim).GetPointsAttr().Get()
        return 0 if points is None else len(points)

    @staticmethod
    def compute_path_bbox(prim_path):
        """
        Compute Bounding Box using omni.usd.UsdContext.compute_path_world_bounding_box
        See https://docs.omniverse.nvidia.com/kit/docs/omni.usd/latest/omni.usd/omni.usd.UsdContext.html#omni.usd.UsdContext.compute_path_world_bounding_box

        Args:
            prim_path: A prim path to compute the bounding box.
        Returns:
            A range (i.e. bounding box) as a minimum point and maximum point.
        """
        return omni.usd.get_context().compute_path_world_bounding_box(prim_path)

    @staticmethod
    def position_check(stage, *, output_path=None, z_score=3.89):
        """
        Flag prims whose position is a statistical outlier and export the result as a CSV file.

        For every xformable prim, the translation of its local (parent-relative) transformation is read and
        rounded to two decimals; prims located exactly at the origin are skipped. Per axis, the sample mean
        and standard deviation are computed, and a prim lies inside the area of acceptance on that axis when
        its coordinate is within ``mean +/- z_score * stdev``. The default of 3.89 standard deviations covers
        about 99.99% of a normal distribution.

        Args:
            stage: The ``Usd.Stage`` to scan.
            output_path: CSV file to write; defaults to ``outputpath_psc``.
            z_score: Number of standard deviations bounding the area of acceptance.

        The CSV contains the columns Prim, X, Y, Z and one boolean "<axis>-Axis Confidence Interval" column
        per axis (False = outlier). With fewer than two prims no interval can be computed; the three interval
        columns are then left empty and a warning is logged.
        """
        position_data = []
        for prim in stage.TraverseAll():
            xformable = UsdGeom.Xformable(prim)
            if not xformable:
                # Non-xformable prims (materials, shaders, scopes, ...) carry no position.
                continue
            # Row 3 of the (row-vector convention) transformation matrix holds the translation.
            translation = xformable.GetLocalTransformation()[3]
            x, y, z = (round(float(translation[i]), 2) for i in range(3))
            # Skip only prims exactly at the origin; a single zero coordinate (e.g. z = 0 on the floor) is valid.
            if x == 0 and y == 0 and z == 0:
                continue
            position_data.append({"Prim": prim.GetPath(), "X": x, "Y": y, "Z": z})

        axes = ("X", "Y", "Z")
        if len(position_data) < 2:
            logger.warning("position_check: %d prim(s) found, at least 2 are needed for an interval", len(position_data))
        else:
            bounds = {axis: DQCheck._interval([row[axis] for row in position_data], z_score) for axis in axes}
            for row in position_data:
                for axis, (lower, upper) in bounds.items():
                    row[f"{axis}-Axis Confidence Interval"] = lower <= row[axis] <= upper

        fieldnames = ["Prim", *axes, *(f"{axis}-Axis Confidence Interval" for axis in axes)]
        DQCheck._write_csv(output_path if output_path is not None else outputpath_psc, fieldnames, position_data)

    @staticmethod
    def scaling_check(stage, *, output_path=None, z_score=3.5):
        """
        Flag meshes whose size is an upward statistical outlier (a hint for a scaling error) and export a CSV.

        For every mesh, its world bounding box is computed. The box volume and its z-extent (height of the
        world-axis-aligned box; assumes a Z-up stage, i.e. x,y,z and NOT x,z,y) are truncated to integers.
        A mesh gets ``Scaling = False`` when its volume or its z-extent exceeds ``mean + z_score * stdev`` of
        the respective values over all meshes. The default of 3.5 standard deviations is a one-sided bound
        covering about 99.98% of a normal distribution.

        Args:
            stage: The ``Usd.Stage`` to scan.
            output_path: CSV file to write; defaults to ``outputpath_scc``.
            z_score: Number of standard deviations above the mean at which a mesh is flagged.

        Meshes with an empty bounding box are skipped. With fewer than two meshes no bound can be computed;
        the Scaling column is then left empty and a warning is logged.
        """
        boxcache = DQCheck._bbox_cache()
        scaling_data = []
        for prim in stage.TraverseAll():
            if not UsdGeom.Mesh(prim):
                continue
            box = boxcache.ComputeWorldBound(prim)
            # ComputeAlignedRange() applies the box's matrix; GetRange() would return the untransformed box
            # and thereby hide exactly the scale this check is looking for.
            world_range = box.ComputeAlignedRange()
            if world_range.IsEmpty():
                # Empty ranges have min=+max / max=-max, whose difference is -inf and cannot be converted to int.
                logger.debug("scaling_check: skipping %s (empty bounding box)", prim.GetPath())
                continue
            z_length = int(world_range.GetMax()[2] - world_range.GetMin()[2])
            volume = int(box.GetVolume())
            scaling_data.append({"Prim": prim.GetPath(), "Volume": volume, "Z-Extents": z_length})

        if len(scaling_data) < 2:
            logger.warning("scaling_check: %d mesh(es) found, at least 2 are needed for a bound", len(scaling_data))
        else:
            _, v_upper = DQCheck._interval([item["Volume"] for item in scaling_data], z_score)
            _, z_ext_upper = DQCheck._interval([item["Z-Extents"] for item in scaling_data], z_score)
            for item in scaling_data:
                # Strict '>' so that zero spread (all meshes equal) does not flag every mesh.
                item["Scaling"] = not (item["Volume"] > v_upper or item["Z-Extents"] > z_ext_upper)

        DQCheck._write_csv(
            output_path if output_path is not None else outputpath_scc,
            ["Prim", "Volume", "Z-Extents", "Scaling"],
            scaling_data,
        )

    @staticmethod
    def dimension_check(stage, *, output_path=None):
        """
        Compute each mesh's z-extent based on the extents of its bounding box. A prim's bounding box is a cuboidal
        block that fully contains a prim. Export the results as part of a CSV file to identify those prims whose
        z-height equals 0.0 (and that are therefore considered to be 2D).

        The z-extent is taken from the untransformed range of the bounding box (``GfBBox3d.GetRange()``),
        rounded to one decimal, and assumes a Z-up (x,y,z) convention. Meshes with an empty bounding box are
        skipped.

        Args:
            stage: The ``Usd.Stage`` to scan.
            output_path: CSV file to write; defaults to ``outputpath_dmc``.
        """
        boxcache = DQCheck._bbox_cache()
        extents_data = []
        for prim in stage.TraverseAll():
            if not UsdGeom.Mesh(prim):
                continue
            b_range = boxcache.ComputeWorldBound(prim).GetRange()
            if b_range.IsEmpty():
                # An empty range would report a height of -inf rather than a meaningful value.
                logger.debug("dimension_check: skipping %s (empty bounding box)", prim.GetPath())
                continue
            height = round(float(b_range.GetMax()[2] - b_range.GetMin()[2]), 1)
            extents_data.append({"Prim": prim.GetPath(), "Z-Extents": height})

        DQCheck._write_csv(
            output_path if output_path is not None else outputpath_dmc, ["Prim", "Z-Extents"], extents_data
        )

    @staticmethod
    def performance_check(stage, *, output_path=None):
        """
        Record the vertex count of every mesh in a USD stage and export the results as a CSV file.

        "Number" is the 1-based position of the mesh among all meshes in traversal order. Meshes with zero
        vertices are left out of the CSV but still advance that counter.

        Args:
            stage: The ``Usd.Stage`` to scan.
            output_path: CSV file to write; defaults to ``outputpath_pfc``.
        """
        performance_list = []
        meshes = 0
        for prim in stage.TraverseAll():
            if not UsdGeom.Mesh(prim):
                continue
            meshes += 1
            vertex_count = DQCheck.get_vertice_count(prim)
            if vertex_count != 0:
                performance_list.append({"Prim": prim.GetPath(), "Vertices": vertex_count, "Number": meshes})

        DQCheck._write_csv(
            output_path if output_path is not None else outputpath_pfc,
            ["Prim", "Vertices", "Number"],
            performance_list,
        )

    @staticmethod
    def empty_file_check(stage, *, output_path=None):
        """
        Identify empty prim files within a USD stage and export their paths as a CSV file.

        The converter assigns a fixed marker to the names of empty files from source systems converted to USD.
        These files appear as empty prim files in the Nvidia Omniverse applications; every prim whose name
        contains the marker is reported.

        Args:
            stage: The ``Usd.Stage`` to scan.
            output_path: CSV file to write; defaults to ``outputpath_efc``.
        """
        empty_file_marker = "EMPTY_FILE"
        empty_files_data = [
            {"Prim": prim.GetPath()} for prim in stage.TraverseAll() if empty_file_marker in prim.GetName()
        ]
        DQCheck._write_csv(output_path if output_path is not None else outputpath_efc, ["Prim"], empty_files_data)

    @staticmethod
    def naming_check(stage, *, output_path=None):
        """
        Check the conformity of unloaded prims' names with their overarching folder structure.

        Every prim for which ``IsLoaded()`` is False (e.g. an unloaded payload) is exported together with a
        "Naming Convention" flag: True when its name's first or sixth underscore-separated part equals the
        fifth component of its path, False otherwise (including paths too short to contain that component).

        Args:
            stage: The ``Usd.Stage`` to scan.
            output_path: CSV file to write; defaults to ``outputpath_ncc``.
        """
        naming_data = []
        for prim in stage.TraverseAll():
            if prim.IsLoaded():
                continue
            path = prim.GetPath().pathString
            conforms = DQCheck._naming_conforms(prim.GetName(), path)
            naming_data.append({"Prim": path, "Naming Convention": conforms})

        DQCheck._write_csv(
            output_path if output_path is not None else outputpath_ncc,
            ["Prim", "Naming Convention"],
            naming_data,
        )
