classdef Laboratory < handle
    % LABORATORY  Runs multiple [Experiment]s, optionally repeating them, and
    % plots their outputs.
    %
    % Results can be cached on disk, both for the laboratory as a whole and for
    % each individual [Experiment]. Using `part_idx` and `part_count` from the
    % configuration, a single call to [run] can be restricted to a subset of all
    % experiments.
    %
    % See also: LaboratoryConfig, PlotConfig

    properties
        % CONF  The configuration for this [Laboratory].
        conf (1, 1) LaboratoryConfig;
        % EXPS  The [Experiment]s that were executed.
        %
        % Each row corresponds to a different experimental setup, and each column
        % corresponds to a repeated execution of that setup.
        exps (:, :) cell;  % cell<Experiment>
    end


    methods
        function obj = Laboratory(conf)
            % LABORATORY  Constructs a new [Laboratory] from the [LaboratoryConfig].
            %
            % [exps] is pre-allocated with one row per experiment configuration and
            % one column per repetition; the cells are filled by [run].

            arguments% (Input)
                conf (1, 1) LaboratoryConfig;
            end
            % arguments (Output)
            %     obj (1, 1) Laboratory;
            % end

            obj.conf = conf;
            obj.exps = cell([numel(obj.conf.exp_confs), obj.conf.repeat_count]);
        end


        function run(obj)
            % RUN  Runs all registered [Experiment]s.
            %
            % If `part_idx` is 0, all experiments are handled and the combined result
            % is loaded from / saved to the laboratory cache. Otherwise, only the
            % experiments belonging to part `part_idx` out of `part_count` are
            % handled, and the laboratory cache is not touched.
            %
            % Per experiment, depending on `load_enabled` and `load_exp_behavior`:
            % * If loading is disabled, or the behavior is "run", or no loadable cache
            %   file exists, the experiment is run.
            % * Otherwise, if the behavior is "skip", the experiment is skipped.
            % * Otherwise, the experiment is loaded from its cache file.
            %
            % Each experiment that is run is seeded with the base seed plus its linear
            % index in [exps].

            arguments% (Input)
                obj (1, 1) Laboratory;
            end

            logger = obj.conf.create_logger();
            lab_path = get_path(obj.conf.cache_dir, ...
                                sprintf(obj.conf.save_lab_target, obj.conf.save_version, obj.conf.cache_id()));

            % Load cache
            if obj.conf.part_idx == 0 && obj.conf.load_enabled && is_loadable(lab_path, obj.conf.load_validate)
                logger.print("Loading cached laboratory.\n");
                obj.exps = do_load(lab_path);
                return
            end

            % Seed
            if obj.conf.seed >= 0
                base_seed = obj.conf.seed;
                logger.print("Using configured seed %d.\n", base_seed);
            else
                base_seed = randi(2^31);
                logger.print("Using random seed %d.\n", base_seed);
            end

            % Select experiments
            exp_count = numel(obj.exps);
            assert(obj.conf.part_idx <= obj.conf.part_count, "'part_idx' must be less than or equal to 'part_count'.");
            if obj.conf.part_idx == 0
                exp_idxs = 1:exp_count;
            elseif obj.conf.part_count > exp_count
                % More parts than experiments: one experiment per part, surplus
                % parts have nothing to do
                if obj.conf.part_idx > exp_count; return; end

                exp_idxs = obj.conf.part_idx;
            else
                % Split `1:exp_count` into `part_count` contiguous, non-empty,
                % non-overlapping ranges. Part `k` covers `edges(k) + 1` up to and
                % including `edges(k + 1)`, so every index belongs to exactly one part
                edges = round(linspace(0, exp_count, obj.conf.part_count + 1));
                exp_idxs = (edges(obj.conf.part_idx) + 1):edges(obj.conf.part_idx + 1);
            end

            % Calculate number of workers
            workers = obj.calc_workers();
            if workers > 0 && exist("ParforProgressbar", "class") > 0
                ppm = ParforProgressbar(numel(exp_idxs));
                ppm_cleanup = onCleanup(@() delete(ppm));  %#ok<NASGU> Deletes `ppm` when `run` exits
            else
                % Stand-in with the same `increment` call that does nothing
                ppm = struct();
                ppm.increment = @() 0;
            end

            % Local variables to reduce communication to workers
            obj_exps = cell(size(obj.exps));
            obj_conf = obj.conf;

            for exp_idx = exp_idxs
                [base_idx, rep_idx] = ind2sub(size(obj.exps), exp_idx);
                obj_exps{exp_idx} = Experiment([base_idx, rep_idx], ...
                                               obj.conf.create_logger(), ...
                                               obj.conf.exp_confs{base_idx});
            end

            % Run
            logger.print("Running %d experiment(s).\n", numel(exp_idxs));
            total_tic = tic;

            % for exp_idx = exp_idxs
            parfor (exp_idx = exp_idxs, workers)
                experiment = obj_exps{exp_idx};
                exp_path = get_path( ...
                    obj_conf.cache_dir, ...
                    sprintf(obj_conf.save_exp_target, ...
                            obj_conf.save_version, ...
                            experiment.conf.cache_id(), ...
                            experiment.idx(2)) ...
                );  %#ok<PFBNS> `obj_conf` isn't that big

                % Load or run experiment
                use_cache = obj_conf.load_exp_behavior ~= "run" && ...
                    obj_conf.load_enabled && ...
                    is_loadable(exp_path, obj_conf.load_validate);
                if use_cache
                    if obj_conf.load_exp_behavior == "skip"
                        % Still count skipped experiments as progress
                        ppm.increment();  %#ok<PFBNS> Required
                        continue;
                    end

                    experiment = do_load(exp_path);
                else
                    % Also reached when loading is disabled or the behavior is "run",
                    % so that the experiment is always executed in those cases
                    rng(base_seed + exp_idx, "twister");
                    experiment = experiment.run();
                    experiment.models = cell(1, 1);  % Drop models before write-back and caching
                end

                % Write back to main worker
                obj_exps{exp_idx} = experiment;

                % Cache
                if obj_conf.save_enabled
                    do_save(exp_path, experiment);
                end

                % Progress
                ppm.increment();  %#ok<PFBNS> Required
            end

            obj.exps = obj_exps;  % Write back to `obj`
            logger.print("Completed all experiments in %.3f seconds.\n", toc(total_tic));

            % Cache
            if obj.conf.part_idx == 0 && obj.conf.save_enabled
                logger.print("Saving laboratory to cache.\n");
                do_save(lab_path, obj.exps);
            end
        end

        function plot(obj)
            % PLOT  Plots the obtained results.
            %
            % Does nothing if only a part of the experiments was run (`part_idx` is
            % not 0), if plots are neither shown nor saved, or if `load_exp_behavior`
            % is "skip".
            %
            % For each plot configuration, the experiments are filtered, grouped into
            % lines, and averaged over their repetitions. Each line is then drawn into
            % a new figure, which is optionally saved in all configured formats.
            % Figures that are not shown are closed again after saving.

            arguments% (Input)
                obj (1, 1) Laboratory;
            end

            if obj.conf.part_idx ~= 0 || ...
                   ~obj.conf.plot_show && ~obj.conf.plot_save || ...
                   obj.conf.load_exp_behavior == "skip"
                return;
            end

            logger = obj.conf.create_logger();

            for plot_idx = 1:numel(obj.conf.plot_confs)
                plot_conf = obj.conf.plot_confs{plot_idx};

                % Pre-process data
                filtered_exps = reshape(obj.exps(cellfun(plot_conf.filter_exp, obj.exps)), [], width(obj.exps));

                % Line keys, in order of first occurrence
                line_by_exp = cellfun(plot_conf.exp_to_line, filtered_exps);
                line_keys = unique(line_by_exp, "stable");

                names = strings(numel(line_keys), 1);
                points = cell(numel(line_keys), 1);
                for i = 1:numel(line_keys)
                    line_key = line_keys(i);
                    line_exps = reshape(filtered_exps(line_by_exp == line_key), [], width(line_by_exp));
                    if isempty(line_exps); continue; end

                    line_data_raw = cellfun(plot_conf.exp_to_data, line_exps, UniformOutput = false);

                    % Within each row (setup), bring all repetitions to the height of
                    % the tallest one so that they can be averaged
                    line_data_heights = num2cell(repmat(max(cellfun(@height, line_data_raw), [], 2), ...
                                                        [1, width(line_data_raw)]));
                    line_data_normal = cellfun(@(data, data_height) resize(data, data_height, ...
                                                                           Dimension = 1, ...
                                                                           FillValue = missing), ...
                                               line_data_raw, ...
                                               line_data_heights, ...
                                               UniformOutput = false);

                    % Average over repetitions, separately for each row (setup)
                    line_data_avg = arrayfun(@(it) mean(cat(3, line_data_normal{it, :}), 3), ...
                                             1:height(line_data_normal), ...
                                             UniformOutput = false)';

                    if isa(plot_conf.line_to_label, "function_handle")
                        names(i) = plot_conf.line_to_label(line_exps{1});
                    end
                    points{i} = vertcat(line_data_avg{:});
                end

                nonempty_mask = ~cellfun(@isempty, points);
                names = names(nonempty_mask);
                points = points(nonempty_mask);

                if isempty(points)
                    logger.print("Skipping plot %d: No data to display.\n", plot_idx);
                    continue;
                end

                % Plot
                % Explicit handles are used so that plotting does not depend on which
                % figure happens to be the current one
                fig = figure(Visible = obj.conf.plot_show);
                ax = axes(fig);

                xlabel(ax, plot_conf.x_label);
                ylabel(ax, plot_conf.y_label);
                if ~plot_conf.y_sci_ticks; ax.YAxis.Exponent = 0; end
                if plot_conf.y_log; ax.YScale = "log"; end

                hold(ax, "on");
                for i = 1:height(points)
                    series = points{i};
                    X = series(:, 1);
                    Y = series(:, 2);

                    plot(ax, X, Y, DisplayName = names(i));
                end
                if any(names ~= "", "all"); legend(ax, "Location", "NorthWest"); end
                hold(ax, "off");

                if obj.conf.plot_save
                    for save_format = obj.conf.plot_save_formats'
                        saveas(fig, get_path(obj.conf.plot_dir, plot_conf.title), save_format);
                    end
                end

                if obj.conf.plot_show
                    title(ax, plot_conf.title);  % Not included in saved figure
                else
                    close(fig);  % Invisible figures would otherwise stay open indefinitely
                end
            end
        end
    end

    methods (Access = private)
        function workers = calc_workers(obj)
            % CALC_WORKERS  Calculates the number of parallel workers needed for this
            % [Laboratory].
            %
            % Returns 0 if parallelism is disabled or there is at most one experiment.
            % Otherwise, returns `parallel_max_workers` if that is non-negative, and
            % `Inf` if it is negative.

            arguments% (Input)
                obj (1, 1) Laboratory;
            end
            % arguments (Output)
            %     workers (1, 1) {mustBeNumeric};
            % end

            if ~obj.conf.parallel || numel(obj.exps) <= 1
                workers = 0;
            elseif obj.conf.parallel_max_workers >= 0
                workers = obj.conf.parallel_max_workers;
            else
                workers = Inf;
            end
        end
    end
end


function do_save(file, obj)
    % DO_SAVE  Writes [obj] to the MAT-file [file] under the variable name `obj`.
    %
    % `save` cannot be called directly inside a `parfor` loop, so the call is
    % wrapped in this function. The file is written in the `-v7.3` format.

    arguments% (Input)
        file (1, 1) {mustBeText};
        obj;
    end

    % Each field of the struct becomes one variable in the file, so this stores
    % exactly one variable named `obj`
    payload.obj = obj;
    save(file, "-struct", "payload", "-v7.3");
end

function loadable = is_loadable(file, validate)
    % IS_LOADABLE  Checks if [file] exists, and, if [validate] is `true`, whether
    % [file] contains a loadable object.
    %
    % This function is not strictly necessary, but is a nice counterpart to
    % [do_save] that transparently treats the name under which data is stored.
    %
    % Only regular files at the given location count as existing; folders and
    % files that merely happen to be on the MATLAB search path do not. If
    % [validate] is `true` and [file] cannot be inspected at all (for example
    % because it is not a valid MAT-file), `false` is returned instead of raising
    % an error.

    arguments% (Input)
        file (1, 1) {mustBeText};
        validate (1, 1) logical = true;
    end
    % arguments (Output)
    %     loadable (1, 1) logical;
    % end

    loadable = isfile(file);
    if ~loadable || ~validate; return; end

    try
        % [do_save] stores its data under the variable name `obj`
        loadable = ~isempty(whos("-file", file, "obj"));
    catch
        % An unreadable file is, by definition, not loadable
        loadable = false;
    end
end

function obj = do_load(file)
    % DO_LOAD  Reads the variable `obj` from the MAT-file [file] and returns it.
    %
    % Counterpart to [do_save]: callers do not need to know under which variable
    % name the data is stored inside the file.

    arguments% (Input)
        file (1, 1) {mustBeText};
    end
    % arguments (Output)
    %     obj;
    % end

    contents = load(file);
    obj = contents.obj;
end

function path = get_path(dir, filename)
    % GET_PATH  Returns the path `[dir]/[filename]`, creating [dir] if necessary.
    %
    % The existence check uses `isfolder`, which only looks at the given location
    % (absolute, or relative to the current folder). A folder with the same
    % relative name elsewhere on the MATLAB search path therefore does not prevent
    % [dir] from being created.

    arguments% (Input)
        dir (1, 1) {mustBeText};
        filename (1, 1) {mustBeText};
    end
    % arguments (Output)
    %     path (1, 1) {mustBeText};
    % end

    if ~isfolder(dir); mkdir(dir); end

    path = sprintf("%s/%s", dir, filename);
end
