classdef Metrics < handle
    % METRICS  Tracks and displays metrics.
    %
    % A [Metrics] object keeps a per-round [history] table of the metrics that have
    % been recorded through the `record_*` methods, and can display those metrics
    % textually (through a logger) and/or graphically (through a
    % `trainingProgressMonitor`).
    %
    % Typical usage:
    %   1. Construct the object and call [start].
    %   2. For every round, call [record_round] first, then any of the other
    %      `record_*` methods, and finally [update_display].

    properties (Access = private)
        % Whether [update_display] prints metrics using the logger.
        show_text (1, 1) logical;
        % Whether [update_display] records metrics in the [monitor].
        show_gui (1, 1) logical;
        % Whether aggregate (min/mean/max) metrics are displayed. Despite the name,
        % this flag is used for both the textual and the graphical display.
        show_gui_global (1, 1) logical;
        % Whether the GUI shows one metric per node.
        show_gui_by_node (1, 1) logical;

        % Number of nodes for which metrics are tracked.
        node_count (1, 1) {mustBeInteger, mustBePositive} = 1;
        % Total number of rounds; the progress bar is only updated if this is positive.
        round_count (1, 1) {mustBeInteger};

        track_convergence (1, 1) logical;

        track_training_loss (1, 1) logical;
        % Most recently recorded training loss of each node; `inf` if never recorded.
        training_loss_by_node (:, 1) {mustBeFloat};

        track_validation_accuracy (1, 1) logical;
        % Most recently recorded validation accuracy of each node; `-inf` if never
        % recorded.
        validation_accuracy_by_node (:, 1) {mustBeFloat};
    end

    properties
        % The GUI created by [start] or [force_redraw_gui]; not a handle object until
        % one of those methods has created the GUI.
        monitor (1, 1);
        % One row per recorded round, one variable per recorded metric.
        history (:, :) table;
    end


    methods
        function obj = Metrics(show_text, show_gui, show_gui_by_node, node_count, round_count, track_convergence, ...
                               track_training_loss, track_validation_accuracy)
            % METRICS  Constructs a new metrics tracker.
            %
            % Metrics are tracked separately for all [node_count] nodes over a period of up
            % to [round_count] rounds. [show_text] and [show_gui] determine where metrics
            % are displayed. If both [show_gui] and [show_gui_by_node] are `true`, metrics
            % will be displayed in the GUI separately for each node. The `track_*` arguments
            % determine which metrics are gathered.
            %
            % Aggregate (min/mean/max) metrics are displayed if there is more than one node
            % or if [show_gui_by_node] is `false`. If [round_count] is not positive, the
            % progress indicator of the GUI is never updated.
            %
            % The method [start] should be invoked before any metrics are recorded.

            arguments% (Input)
                show_text (1, 1) logical = false;
                show_gui (1, 1) logical = false;
                show_gui_by_node (1, 1) logical = false;
                node_count (1, 1) {mustBeInteger, mustBePositive} = 1;
                round_count (1, 1) {mustBeInteger} = -1;
                track_convergence (1, 1) logical = false;
                track_training_loss (1, 1) logical = false;
                track_validation_accuracy (1, 1) logical = false;
            end

            obj.show_text = show_text;
            obj.show_gui = show_gui;
            obj.show_gui_global = node_count > 1 || ~show_gui_by_node;
            obj.show_gui_by_node = show_gui_by_node;
            obj.node_count = node_count;
            obj.round_count = round_count;

            obj.track_convergence = track_convergence;

            % Non-finite sentinels mark nodes without a recorded value; they are filtered
            % out before aggregate statistics are calculated.
            obj.track_training_loss = track_training_loss;
            obj.training_loss_by_node = inf([node_count, 1]);

            obj.track_validation_accuracy = track_validation_accuracy;
            obj.validation_accuracy_by_node = -inf([node_count, 1]);

            obj.history = table();
        end


        function obj = start(obj)
            % START  Displays the GUI if [show_gui] is `true`.

            if obj.show_gui
                obj.monitor = obj.create_monitor();
                obj.monitor.Visible = 1;
            end
        end

        function record_round(obj, round_idx, selected_node)
            % RECORD_ROUND  Records a new round.
            %
            % Records the fact that round number [round_idx] has started and that
            % [selected_node] is the node that has been selected to run the round. This
            % method must be called before calling any of the other `record_*` methods.

            arguments% (Input)
                obj (1, 1) Metrics;
                round_idx (1, 1) {mustBeInteger, mustBePositive};
                selected_node (1, 1) {mustBeInteger, mustBePositive};
            end

            % Set every existing variable in this round's row to `missing` before filling
            % in the round's own data.
            if round_idx > 1; obj.history{round_idx, :} = missing; end
            obj.history.Round(round_idx) = round_idx;
            obj.history.SelectedNode(round_idx) = selected_node;
        end

        function record_convergence(obj, round_idx, convergence)
            % RECORD_CONVERGENCE  Records the fact that the convergence in [round_idx] was
            % [convergence].
            %
            % Note that [record_round] must be called on [round_idx] before calling this
            % method. The display is not changed until [update_display] is called.
            %
            % Fails if tracking convergence was disabled in the constructor.

            arguments% (Input)
                obj (1, 1) Metrics;
                round_idx (1, 1) {mustBeInteger, mustBePositive};
                convergence (1, 1) {mustBeFloat};
            end

            assert(obj.track_convergence, "Tracking convergence has been disabled.");
            assert(round_idx <= height(obj.history), "Call `record_round` before calling other `record_*` methods.");

            obj.history.Convergence(round_idx) = convergence;
        end

        function record_training_loss(obj, round_idx, training_loss)
            % RECORD_TRAINING_LOSS  Records the fact that the selected node in [round_idx]
            % obtained [training_loss] in that round.
            %
            % In addition to the loss itself, the minimum, mean, and maximum over the most
            % recent finite losses of all nodes are recorded. If no node has a finite loss,
            % these three statistics are recorded as `NaN`.
            %
            % Note that [record_round] must be called on [round_idx] before calling this
            % method. The display is not changed until [update_display] is called.
            %
            % Fails if tracking training loss was disabled in the constructor.

            arguments% (Input)
                obj (1, 1) Metrics;
                round_idx (1, 1) {mustBeInteger, mustBePositive};
                training_loss (1, 1) {mustBeFloat, mustBeNonnegative};
            end

            assert(obj.track_training_loss, "Tracking training loss has been disabled.");
            assert(round_idx <= height(obj.history), "Call `record_round` before calling other `record_*` methods.");

            obj.training_loss_by_node(obj.history.SelectedNode(round_idx)) = training_loss;

            [lowest, average, highest] = Metrics.summarize_finite(obj.training_loss_by_node);
            obj.history.TrainingLoss(round_idx) = training_loss;
            obj.history.MinTrainingLoss(round_idx) = lowest;
            obj.history.MeanTrainingLoss(round_idx) = average;
            obj.history.MaxTrainingLoss(round_idx) = highest;
        end

        function record_validation_accuracy(obj, round_idx, validation_accuracy)
            % RECORD_VALIDATION_ACCURACY  Records the fact that the selected node in
            % [round_idx] obtained [validation_accuracy] in that round.
            %
            % In addition to the accuracy itself, the minimum, mean, and maximum over the
            % most recent finite accuracies of all nodes are recorded. If no node has a
            % finite accuracy, these three statistics are recorded as `NaN`.
            %
            % Note that [record_round] must be called on [round_idx] before calling this
            % method. The display is not changed until [update_display] is called.
            %
            % Fails if tracking validation accuracy was disabled in the constructor.

            arguments% (Input)
                obj (1, 1) Metrics;
                round_idx (1, 1) {mustBeInteger, mustBePositive};
                validation_accuracy (1, 1) {mustBeFloat, mustBeNonnegative};
            end

            assert(obj.track_validation_accuracy, "Tracking validation accuracy has been disabled.");
            assert(round_idx <= height(obj.history), "Call `record_round` before calling other `record_*` methods.");

            obj.validation_accuracy_by_node(obj.history.SelectedNode(round_idx)) = validation_accuracy;

            [lowest, average, highest] = Metrics.summarize_finite(obj.validation_accuracy_by_node);
            obj.history.ValidationAccuracy(round_idx) = validation_accuracy;
            obj.history.MinValidationAccuracy(round_idx) = lowest;
            obj.history.MeanValidationAccuracy(round_idx) = average;
            obj.history.MaxValidationAccuracy(round_idx) = highest;
        end

        function update_display(obj, logger, round_idx)
            % UPDATE_DISPLAY   Updates the display with the information of [round_idx].
            %
            % If [Metrics#show_text] is `true`, metrics are displayed textually using
            % [logger]. If [Metrics#show_gui] is `true`, metrics are displayed in the
            % [Metrics#monitor].
            %
            % Only metrics whose `track_*` flag is enabled are displayed; each of those
            % must already have been recorded for [round_idx].

            arguments% (Input)
                obj (1, 1) Metrics;
                logger (1, 1) Logger;
                round_idx (1, 1) {mustBeInteger, mustBePositive};
            end

            if obj.show_text
                logger.header();
                if obj.track_convergence
                    logger.print("Convergence: %.8f.\n", obj.history.Convergence(round_idx));
                end
                if obj.track_training_loss
                    if obj.show_gui_global
                        logger.print("Loss: %.3f  global(min: %.3f, mean: %.3f, max: %.3f).\n", ...
                                     obj.history.TrainingLoss(round_idx), ...
                                     obj.history.MinTrainingLoss(round_idx), ...
                                     obj.history.MeanTrainingLoss(round_idx), ...
                                     obj.history.MaxTrainingLoss(round_idx));
                    else
                        logger.print("Loss: %.3f.\n", obj.history.TrainingLoss(round_idx));
                    end
                end
                if obj.track_validation_accuracy
                    if obj.show_gui_global
                        logger.print("Accuracy: %.3f  global(min: %.3f, mean: %.3f, max: %.3f).\n", ...
                                     obj.history.ValidationAccuracy(round_idx), ...
                                     obj.history.MinValidationAccuracy(round_idx), ...
                                     obj.history.MeanValidationAccuracy(round_idx), ...
                                     obj.history.MaxValidationAccuracy(round_idx));
                    else
                        logger.print("Accuracy: %.3f.\n", obj.history.ValidationAccuracy(round_idx));
                    end
                end
                logger.footer();
            end
            if obj.show_gui
                logger.print("Updating monitor... ");
                % Use a dedicated timer so that the caller's own `tic`/`toc` measurements are
                % not reset by this method.
                monitor_timer = tic;

                selected_node = obj.history.SelectedNode(round_idx);

                % Collect all metrics of this round so they are recorded in a single call.
                round_metrics = struct();
                if obj.track_convergence
                    round_metrics.Convergence = obj.history.Convergence(round_idx);
                end
                if obj.track_training_loss
                    if obj.show_gui_by_node
                        % Field name must match the per-node metric names from [create_monitor].
                        round_metrics.(sprintf("TrainingLoss%d", selected_node)) = ...
                            obj.history.TrainingLoss(round_idx);
                    end

                    if obj.show_gui_global
                        round_metrics.MinTrainingLoss = obj.history.MinTrainingLoss(round_idx);
                        round_metrics.MeanTrainingLoss = obj.history.MeanTrainingLoss(round_idx);
                        round_metrics.MaxTrainingLoss = obj.history.MaxTrainingLoss(round_idx);
                    end
                end
                if obj.track_validation_accuracy
                    if obj.show_gui_by_node
                        % Field name must match the per-node metric names from [create_monitor].
                        round_metrics.(sprintf("ValidationAccuracy%d", selected_node)) = ...
                            obj.history.ValidationAccuracy(round_idx);
                    end

                    if obj.show_gui_global
                        round_metrics.MinValidationAccuracy = obj.history.MinValidationAccuracy(round_idx);
                        round_metrics.MeanValidationAccuracy = obj.history.MeanValidationAccuracy(round_idx);
                        round_metrics.MaxValidationAccuracy = obj.history.MaxValidationAccuracy(round_idx);
                    end
                end
                recordMetrics(obj.monitor, round_idx, round_metrics);

                updateInfo(obj.monitor, ...
                           CurrentNode = selected_node, ...
                           CurrentNodeRound = obj.get_rounds_recorded_for(selected_node), ...
                           UnselectedNodeCount = numel(obj.get_unrecorded_nodes()));

                % Progress can only be calculated if the total number of rounds is known.
                if obj.round_count > 0
                    obj.monitor.Progress = 100 * round_idx / obj.round_count;
                end

                logger.append("done in %.3f seconds.\n", toc(monitor_timer));
            end
        end


        function gather(obj)
            % GATHER  Moves all data in [obj] from GPU to CPU.
            %
            % Every variable in [Metrics#history] that is a `dlarray` is replaced by its
            % underlying data, and every variable that is a `gpuArray` is gathered.

            arguments
                obj (1, 1) Metrics;
            end

            for metric = string(obj.history.Properties.VariableNames)
                metric_data = obj.history.(metric);
                % Unwrap first, so that GPU data inside a `dlarray` is also gathered.
                if isdlarray(metric_data); metric_data = extractdata(metric_data); end
                if isgpuarray(metric_data); metric_data = gather(metric_data); end
                obj.history.(metric) = metric_data;
            end
        end

        function force_redraw_gui(obj)
            % FORCE_REDRAW_GUI   Closes the [Metrics#monitor] if it is open, and redraws it
            % completely.
            %
            % Use this method to reconstruct visual output after loading data from a `.mat`
            % file, for example from yesterday or obtained from a server.
            %
            % After calling this method, [Metrics#show_gui] is `true`.

            arguments
                obj (1, 1) Metrics;
            end

            % Close old GUI, but only if one has been created and it still exists
            if isa(obj.monitor, "handle") && isvalid(obj.monitor)
                obj.monitor.Visible = 0;
            end

            % Create new GUI
            obj.show_gui = true;
            obj.monitor = obj.create_monitor();

            % Update GUI by replaying all rounds, without repeating textual output
            old_show_text = obj.show_text;
            obj.show_text = false;
            try
                for round_idx = 1:height(obj.history)
                    obj.update_display(VoidLogger(), round_idx);
                end
            catch replay_error
                % Do not leave textual output disabled if replaying fails.
                obj.show_text = old_show_text;
                rethrow(replay_error);
            end
            obj.show_text = old_show_text;
        end


        function record = get_record(obj, name, round_idx)
            % GET_RECORD  Returns the [record] with [name] of [round_idx], or `missing` if
            % no such record exists.
            %
            % A record does not exist if [round_idx] is outside the range of recorded
            % rounds, or if no metric called [name] has been recorded.

            arguments% (Input)
                obj (1, 1) Metrics;
                name (1, 1) {mustBeText};
                round_idx (1, 1) {mustBeInteger};
            end
            % arguments (Output)
            %     record (1, 1);
            % end

            if round_idx <= 0 || round_idx > height(obj.history)
                record = missing;
            elseif ~ismember(name, obj.history.Properties.VariableNames)
                record = missing;
            else
                record = obj.history{round_idx, name};
            end
        end

        function round_count = get_rounds_recorded_for(obj, node)
            % GET_ROUNDS_RECORDED_FOR  Returns the number of recorded rounds that [node] has
            % participated in.
            %
            % Returns 0 if no rounds have been recorded yet.

            arguments% (Input)
                obj (1, 1) Metrics;
                node (1, 1) {mustBeInteger, mustBePositive};
            end

            round_count = sum(obj.get_selected_nodes() == node);
        end

        function unrecorded_nodes = get_unrecorded_nodes(obj)
            % GET_UNRECORDED_NODES  Returns all nodes that have not participated in any
            % recorded rounds.
            %
            % Returns all nodes if no rounds have been recorded yet.

            arguments% (Input)
                obj (1, 1) Metrics;
            end

            selected_nodes = obj.get_selected_nodes();
            % Rounds for which no node has been recorded contain `NaN`.
            recorded_nodes = unique(selected_nodes(~isnan(selected_nodes)));
            unrecorded_nodes = setdiff(1:obj.node_count, recorded_nodes);
        end
    end

    methods (Access = private)
        function monitor = create_monitor(obj)
            % CREATE_MONITOR  Creates and displays a new [monitor] (aka GUI).
            %
            % The [monitor] contains one metric for each tracked metric that should be
            % displayed, as determined by the `track_*` and `show_gui_*` properties.

            arguments% (Input)
                obj (1, 1) Metrics;
            end
            % arguments (Output)
            %     monitor (1, 1);
            % end

            monitor = trainingProgressMonitor(XLabel = "Round", ...
                                              Metrics = [], ...
                                              Info = ["CurrentNode", "CurrentNodeRound", "UnselectedNodeCount"]);

            if obj.track_convergence
                monitor.Metrics = [monitor.Metrics, "Convergence"];
            end

            if obj.track_training_loss
                obj.add_metric_group(monitor, "TrainingLoss");
            end

            if obj.track_validation_accuracy
                obj.add_metric_group(monitor, "ValidationAccuracy");
            end
        end

        function add_metric_group(obj, monitor, base_name)
            % ADD_METRIC_GROUP  Registers the metrics derived from [base_name] in [monitor].
            %
            % If [Metrics#show_gui_global] is `true`, the metrics `Min<base_name>`,
            % `Mean<base_name>`, and `Max<base_name>` are added and grouped in the subplot
            % `Global<base_name>`. If [Metrics#show_gui_by_node] is `true`, one metric
            % `<base_name><node>` is added for each node, grouped in the subplot
            % `<base_name>`.

            arguments% (Input)
                obj (1, 1) Metrics;
                monitor (1, 1);
                base_name (1, 1) string;
            end

            if obj.show_gui_by_node
                % One name per node, as a row vector of strings.
                node_metrics = splitlines(strtrim(sprintf(base_name + "%d\n", 1:obj.node_count)))';
            else
                node_metrics = [];
            end

            if obj.show_gui_global
                global_metrics = ["Min", "Mean", "Max"] + base_name;
            else
                global_metrics = [];
            end

            monitor.Metrics = [monitor.Metrics, global_metrics, node_metrics];

            if obj.show_gui_global; groupSubPlot(monitor, "Global" + base_name, global_metrics); end
            if obj.show_gui_by_node; groupSubPlot(monitor, base_name, node_metrics); end
        end

        function selected_nodes = get_selected_nodes(obj)
            % GET_SELECTED_NODES  Returns the selected node of each recorded round, or an
            % empty column vector if no rounds have been recorded yet.

            arguments% (Input)
                obj (1, 1) Metrics;
            end

            % The variable `SelectedNode` is only created by the first [record_round].
            if ismember("SelectedNode", obj.history.Properties.VariableNames)
                selected_nodes = obj.history.SelectedNode;
            else
                selected_nodes = zeros(0, 1);
            end
        end
    end

    methods (Static, Access = private)
        function [lowest, average, highest] = summarize_finite(values)
            % SUMMARIZE_FINITE  Returns the minimum, mean, and maximum of all finite
            % elements of [values], or `NaN` for each if there are no finite elements.

            finite_values = values(isfinite(values));
            if isempty(finite_values)
                % `min` and `max` return an empty array here, which cannot be stored as the
                % value of a single round.
                lowest = NaN;
                average = NaN;
                highest = NaN;
            else
                lowest = min(finite_values);
                average = mean(finite_values);
                highest = max(finite_values);
            end
        end
    end
end
