classdef Laboratory < handle
    % LABORATORY  Runs multiple [Experiment]s, optionally repeating them, and
    % parsing their outputs.
    %
    % See also: LaboratoryConfig, PlotConfig

    properties% (SetAccess = private)
        % CONF  The configuration for this [Laboratory].
        conf (1, 1) LaboratoryConfig;
        % LOGGER  The [Logger] to log with during execution.
        logger (1, 1) Logger;

        % EXPS  The [Experiment]s that were executed.
        %
        % Each row corresponds to a different experimental setup, and each column
        % corresponds to a repeated execution of that setup. When only one part of the
        % repetitions is run (`conf.part_idx > 0`), only that part's columns are present.
        exps (:, :) Experiment;

        % RUN_DONE  Set to `true` once [run] has finished or has loaded a cached
        % laboratory; further calls to [run] then do nothing.
        run_done (1, 1) logical = false;
    end


    methods
        function obj = Laboratory(conf)
            % LABORATORY  Constructs a new [Laboratory] from the [LaboratoryConfig].

            arguments% (Input)
                conf (1, 1) LaboratoryConfig;
            end
            % arguments (Output)
            %     obj (1, 1) Laboratory;
            % end

            obj.conf = conf;
            obj.logger = Logger();
        end


        function run(obj)
            % RUN  Runs all registered [Experiment]s.
            %
            % If loading is enabled and a cached laboratory exists (and all parts are
            % run), the cached experiments are loaded instead. Otherwise, the experiments
            % selected by `part_idx`/`part_count` are instantiated; each one is loaded
            % from its own cache file or executed, missing metrics are computed, and the
            % result is saved if enabled.

            if obj.run_done; return; end

            lab_path = get_path(obj.conf.cache_dir, ...
                                sprintf(obj.conf.save_lab_target, obj.conf.save_version, obj.conf.cache_id()));

            % Validate configuration
            assert(obj.conf.part_idx == 0 || obj.conf.save_enabled, "Experiment output is never saved.");
            % Check every setup (all of `exp_confs` is instantiated below), regardless of
            % whether `exp_confs` is stored as a row, column, or matrix
            all_exp_confs = obj.conf.exp_confs(:);
            assert(numel(unique(arrayfun(@cache_id, all_exp_confs))) == numel(all_exp_confs), ...
                   "Non-unique experiments detected, hash collisions expected.");

            % Load cache
            obj.logger.print("Checking for cached laboratory.\n");
            if obj.conf.part_idx == 0 && obj.conf.load_enabled && is_loadable(lab_path, obj.conf.load_validate)
                obj.logger.header("Loading cached laboratory.\n"); time = tic;
                obj.exps = do_load(lab_path);
                obj.logger.footer("Loaded cached laboratory in %.3f second(s).\n", toc(time));
                % The loaded experiments are complete, so later calls need not reload them
                obj.run_done = true;
                return;
            end

            % Seed
            if obj.conf.seed >= 0
                base_seed = obj.conf.seed;
                obj.logger.print("Using configured seed %d.\n", base_seed);
            else
                base_seed = randi(2^31);
                obj.logger.print("Using random seed %d.\n", base_seed);
            end

            % Select experiments
            assert(obj.conf.part_idx <= obj.conf.part_count, "'part_idx' must be less than or equal to 'part_count'.");
            if obj.conf.part_idx == 0
                col_idxs = 1:obj.conf.repeat_count;
            elseif obj.conf.part_count > obj.conf.repeat_count
                if obj.conf.part_idx > obj.conf.repeat_count; return; end

                col_idxs = obj.conf.part_idx;
            else
                % Split `1:repeat_count` into `part_count` disjoint, contiguous, non-empty
                % ranges; part `k` covers the repetitions in `(bound(k - 1), bound(k)]`.
                % Integer arithmetic avoids repetitions being assigned to two parts.
                part_bound = @(k) floor(k * obj.conf.repeat_count / obj.conf.part_count);
                col_idxs = (part_bound(obj.conf.part_idx - 1) + 1):part_bound(obj.conf.part_idx);
            end

            % Instantiate experiments
            obj.logger.print("Instantiating experiment(s).\n");
            [base_idxs, rep_idxs] = ndgrid(1:numel(obj.conf.exp_confs), col_idxs);
            exps_cell = arrayfun(@(base_idx, rep_idx) Experiment([base_idx, rep_idx], obj.conf.exp_confs(base_idx)), ...
                                 base_idxs, ...
                                 rep_idxs, ...
                                 UniformOutput = false);
            obj.exps = reshape(vertcat(exps_cell{:}), [], numel(col_idxs));  % Storing as cell and then converting is faster

            % Create progress bar
            if exist("ProgressBar", "class") > 0
                ppm = ProgressBar(numel(obj.exps), taskname = "StretchSim", no_log = true);
                ppm_inc = @() count(ppm);
                ppm_cleanup = onCleanup(@() delete(ppm));  %#ok<NASGU> Kept alive until `run` returns
            else
                ppm_inc = @() 0;
            end

            % Create local variables to reduce communication
            obj_exps = obj.exps;
            obj_conf = obj.conf;
            do_write_back = obj.conf.will_save_lab() || obj.conf.will_plot();

            % Obtain the pool to run on, and a logging channel that works with and without it
            pool_desc = obj.parpool();
            if isequal(pool_desc, 0)
                % The loop runs on the client (possibly without the parallel toolbox being
                % installed), so log directly
                log_send = @(data) obj.logger.dispatch(data(1), data(2));
            else
                % Workers cannot use `obj.logger`, so forward their messages to the client
                log_queue = parallel.pool.DataQueue;
                afterEach(log_queue, @(data) obj.logger.dispatch(data(1), data(2)));
                log_send = @(data) send(log_queue, data);
            end

            % Run
            obj.logger.header("Running %d experiment(s).\n", numel(obj.exps)); exps_time = tic;

            % for exp_idx = 1:numel(obj.exps)
            parfor (exp_idx = 1:numel(obj.exps), pool_desc)
                exp = obj_exps(exp_idx);
                exp_cache_id = exp.conf.cache_id();
                exp_path = get_path( ...
                    obj_conf.cache_dir + "/" + extractBetween(exp_cache_id, 1, 1) + "/" + extractBetween(exp_cache_id, 2, 2), ...
                    sprintf(obj_conf.save_exp_target, ...
                            obj_conf.save_version, ...
                            exp_cache_id, ...
                            exp.idx(2)) ...
                );  %#ok<PFBNS> `obj_conf` isn't that big

                % Load or run experiments
                save_me = false;
                if obj_conf.load_enabled && ...
                        obj_conf.load_exp_behavior ~= "run" && ...
                        is_loadable(exp_path, obj_conf.load_validate)
                    if obj_conf.load_exp_behavior == "skip"; continue; end

                    exp = do_load(exp_path);
                    if obj_conf.load_validate
                        assert(isequal(exp.conf.cache_cull(), obj_exps(exp_idx).conf.cache_cull()), ...
                               "Loaded colliding experiment.");
                    end
                    assert(exp.started);
                else
                    % Seed deterministically from the base seed, the experiment's cache ID,
                    % and its repetition index
                    h = exp.conf.cache_id();
                    s = base_seed + hex2dec(extractAfter(h, strlength(h) - 12)) + exp.idx(2);
                    rng(mod(s, 2^32), "twister");

                    log_send(["header", sprintf("Running experiment #%d-%d.\n", exp.idx(1), exp.idx(2))]);  %#ok<PFBNS>
                    time = tic;
                    exp.run();
                    log_send(["footer", sprintf("Ran experiment #%d-%d in %.3f second(s).\n", ...
                                                exp.idx(1), exp.idx(2), toc(time))]);  %#ok<PFBNS>

                    save_me = true;
                end

                % (Re)calculate metrics
                ks = fieldnames(obj_conf.metrics);
                for i = 1:numel(ks)
                    k = ks{i};
                    v = obj_conf.metrics.(k);

                    if isfield(exp.metrics, k); continue; end

                    exp.metrics(1).(k) = v(exp);
                    save_me = true;
                end

                % Save and write back to `obj_exps`
                if obj_conf.save_enabled && save_me; do_save(exp_path, exp, "-v7"); end
                if do_write_back; obj_exps(exp_idx) = exp; end

                % Progress
                ppm_inc();  %#ok<PFBNS> Required
            end
            obj.logger.footer("Ran experiment(s) in %.3f second(s).\n", toc(exps_time));

            if do_write_back
                obj.logger.header("Writing back experiment(s) to main thread.\n"); write_time = tic;
                obj.exps = obj_exps;
                obj.logger.footer("Wrote back experiment(s) in %.3f second(s).\n", toc(write_time));
            end
            obj.run_done = true;

            % Cache
            if obj.conf.will_save_lab()
                obj.logger.header("Saving laboratory to cache.\n"); time = tic;
                do_save(lab_path, obj.exps, "-v7.3");
                obj.logger.footer("Saved laboratory to cache in %.3f second(s).\n", toc(time));
            end
        end

        function plot(obj)
            % PLOT  Plots the obtained results.
            %
            % Creates one figure per entry of `obj.conf.plot_confs`; does nothing if
            % `obj.conf.will_plot()` is `false`.

            if ~obj.conf.will_plot(); return; end

            obj.logger.header("Creating plots.\n"); time = tic;

            for plot_idx = 1:numel(obj.conf.plot_confs)
                plot_conf = obj.conf.plot_confs(plot_idx);
                obj.logger.print("Creating plot #%d: %s.\n", plot_idx, plot_conf.title);

                f = MyFigure.create();
                obj.plot_single(plot_conf, f);
            end

            obj.logger.footer("Done creating plots in %.3f seconds.\n", toc(time));
        end
    end

    methods (Access = private)
        function pool_desc = parpool(obj, create)
            % PARPOOL  Returns a parallel pool descriptor for this [Laboratory].
            %
            % A pool descriptor is any second argument accepted by the function [parpool].
            % This function returns either a pool instance or the integer 0.
            %
            % If the parallel processing toolbox is not installed or `obj.conf.parallel` is
            % `false`, then the integer 0 is always returned. Otherwise, if a parallel pool
            % with the desired [obj.conf.parallel_max_workers] already exists, a descriptor
            % for that pool is returned; if no such pool exists, a new parallel pool is
            % created if and only if [create] is `true`. An existing pool of a different
            % size is only replaced if [create] is `true`.

            arguments% (Input)
                obj (1, 1) Laboratory;
                create (1, 1) logical = true;
            end
            % arguments (Output)
            %     pool_desc (1, 1);  % Either a pool or an integer
            % end

            if isempty(ver("parallel")) || ~obj.conf.parallel
                pool_desc = 0;
            else
                cluster = parcluster("local");
                if ~isinf(obj.conf.parallel_max_workers); cluster.NumWorkers = obj.conf.parallel_max_workers; end

                old_pool = gcp("nocreate");
                if ~isempty(old_pool) && old_pool.NumWorkers == cluster.NumWorkers
                    pool_desc = old_pool;
                elseif ~create
                    % Neither create a new pool nor replace a mismatching one
                    pool_desc = 0;
                else
                    delete(old_pool);
                    pool_desc = cluster.parpool(cluster.NumWorkers);
                end
            end
        end

        function plot_single(obj, plot_conf, f)
            % PLOT_SINGLE  Plots a single plot specified by `plot_conf` into the figure `f`.
            %
            % For non-"tiles" plots, the experiments accepted by `plot_conf.filter_exp`
            % are grouped into lines by `plot_conf.exp_to_line` (in order of first
            % appearance), and mapped to `[x, y]` data by `plot_conf.exp_to_data`. Data
            % with equal `x` is averaged per line, and a 95% confidence margin is
            % computed. For "tiles" plots, each entry of `plot_conf.tiles_confs` is
            % plotted recursively and its axes are moved into a tiled layout in `f`.
            %
            % Afterwards, the figure is saved and/or shown as configured in [obj.conf].

            set(0, "currentfigure", f);

            % Pre-process data
            names = strings(0);
            if plot_conf.type ~= "tiles"
                % Work on a column of experiments; linear (column-major) order is kept, and
                % no matrix shape needs to be restored after uneven filtering
                all_exps = obj.exps(:);
                filtered_exps = all_exps(arrayfun(plot_conf.filter_exp, all_exps));

                % Lines are ordered by their first appearance
                line_by_exp = arrayfun(plot_conf.exp_to_line, filtered_exps);
                line_ids = unique(line_by_exp(:), "stable");
                line_count = numel(line_ids);

                names = strings(line_count, 1);
                styles = repmat({{}}, line_count, 1);
                raw = cell(line_count, 1);
                points = cell(line_count, 1);
                errors = cell(line_count, 1);
                for i = 1:line_count
                    line_exps = filtered_exps(line_by_exp == line_ids(i));
                    if isempty(line_exps); continue; end

                    line_data_raw = arrayfun(plot_conf.exp_to_data, line_exps, UniformOutput = false);
                    line_data_xs = cellfun(@(it) it(1), line_data_raw);
                    unique_xs = unique(line_data_xs);
                    line_data_avg = arrayfun(@(x) mean(cell2mat(line_data_raw(line_data_xs == x)), 1), ...
                                             unique_xs, ...
                                             UniformOutput = false);
                    line_data_err = arrayfun(@(x) ci(cell2mat(line_data_raw(line_data_xs == x))), ...
                                             unique_xs, ...
                                             UniformOutput = false);

                    if isa(plot_conf.line_to_label, "function_handle")
                        names(i) = plot_conf.line_to_label(line_exps(1));
                    end
                    if isa(plot_conf.line_to_style, "function_handle")
                        styles(i) = {plot_conf.line_to_style(line_exps(1))};
                    end
                    raw{i} = vertcat(line_data_raw{:});
                    points{i} = vertcat(line_data_avg{:});
                    errors{i} = vertcat(line_data_err{:});

                    assert(width(raw{i}) == 2, "'exp_to_data' does not output 2 columns for line %d.", i);
                end

                nonempty_mask = ~cellfun(@isempty, points);
                if sum(nonempty_mask, "all") == 0
                    obj.logger.print("Skipping plot: No data to display.\n");
                    return;
                end

                names = names(nonempty_mask);
                styles = styles(nonempty_mask);
                raw = raw(nonempty_mask);
                points = points(nonempty_mask);
                errors = errors(nonempty_mask);
            end

            % Plot
            ax = gca;
            if plot_conf.type ~= "tiles"
                % Regular plot
                if plot_conf.type == "line" || plot_conf.type == "confidence" || plot_conf.type == "scatter"
                    grid(ax, "on");
                    hold(ax, "on");
                end
                for i = 1:height(points)
                    if plot_conf.type == "line"
                        line_points = points{i};
                        plot(ax, line_points(:, 1), line_points(:, 2), styles{i}{:}, DisplayName = names(i));
                    elseif plot_conf.type == "confidence"
                        line_points = points{i};
                        conf_x = [line_points(:, 1); flipud(line_points(:, 1))];
                        conf_y = [line_points(:, 2) - errors{i}(:, 2); flipud(line_points(:, 2) + errors{i}(:, 2))];

                        % The line's color already reflects any "Color" given in its style
                        plot_line = plot(ax, line_points(:, 1), line_points(:, 2), styles{i}{:}, DisplayName = names(i));
                        shade = fill(ax, conf_x, conf_y, 1, ...
                                     FaceColor = plot_line.Color, EdgeColor = "none", FaceAlpha = 0.2);
                        copy_into(plot_conf.confidence_conf, shade);
                    elseif plot_conf.type == "boxplot"
                        subplot(height(points), 1, i);
                        data = raw{i};
                        boxplot(data(:, 2), data(:, 1));
                    elseif plot_conf.type == "scatter"
                        data = raw{i};
                        scatter(ax, data(:, 1), data(:, 2), styles{i}{:}, DisplayName = names(i));
                    end
                end
                if plot_conf.type == "confidence"
                    % Place intervals behind lines
                    if ~isa(ax.Children(1), "matlab.graphics.primitive.Patch") || ...
                        ~isa(ax.Children(2), "matlab.graphics.chart.primitive.Line")
                        warning("Functions 'plot' and 'fill' are being called in the wrong order.");
                    end

                    ax.Children = ax.Children([2:2:end 1:2:end]);
                end
            else
                % Tiles plot
                d = size(plot_conf.tiles_confs);
                t = tiledlayout(f, d(1), d(2));

                for row = 1:d(1)
                    for col = 1:d(2)
                        idx = sub2ind(flip(d), col, row);

                        tile_conf = plot_conf.tiles_confs{row, col};
                        tile_f = MyFigure.create();
                        obj.plot_single(tile_conf, tile_f);
                        tile_ax = gca;
                        tile_ax.Parent = t;
                        tile_ax.Layout.Tile = idx;

                        if row == 1 && ~isempty(plot_conf.tiles_col_titles) && plot_conf.tiles_col_titles(col) ~= ""
                            % Display col titles as title above figures on first row
                            title(tile_ax, plot_conf.tiles_col_titles(col));
                        end
                        if col == 1 && ~isempty(plot_conf.tiles_row_titles) && plot_conf.tiles_row_titles(row) ~= ""
                            % Display row titles as extra label next to figures in first column
                            bold_title = ...
                                "\textbf{" + ...
                                join(split(plot_conf.tiles_row_titles(row), " "), "}\\\textbf{") + ...
                                "}";

                            ylabel( ...
                                tile_ax, ...
                                {bold_title, tile_ax.YLabel.String}, ...
                                Interpreter = "latex" ...
                            );
                        end
                        if tile_conf.title ~= ""
                            % Display tile title as extra label below figure
                            xlabel( ...
                                tile_ax, ...
                                {tile_ax.XLabel.String, "{\normalsize{(" + char('a' + idx - 1) + ") " + tile_conf.title + "}}"}, ...
                                Interpreter = "latex" ...
                            );
                        end
                    end
                end
            end

            % Legend
            leg = [];
            if plot_conf.type == "line" || plot_conf.type == "confidence" || plot_conf.type == "scatter" || plot_conf.type == "tiles"
                legend_entries = plot_conf.legend_entries();
                if any(names ~= "", "all") || ~isempty(legend_entries)
                    if plot_conf.type ~= "tiles"; target_ax = ax; else; target_ax = tile_ax; end

                    if isempty(legend_entries)
                        leg = legend(target_ax);
                    else
                        % Draw invisible dummy lines for the custom entries in a hidden figure
                        figure(Visible = false);
                        leg_ax = gca;
                        hold(leg_ax, "on");

                        leg_p = gobjects(height(legend_entries), 1);
                        for i = 1:height(legend_entries)
                            leg_p(i) = plot(leg_ax, ...
                                            nan, nan, ...
                                            legend_entries{i, 2}{:}, ...
                                            DisplayName = legend_entries{i, 1});
                        end
                        leg = legend(target_ax, leg_p);

                        % `figure` made the hidden figure current; restore `f` so that code
                        % relying on `gca` (e.g. the caller of a tile) targets the right axes
                        set(0, "currentfigure", f);
                    end
                    copy_into(plot_conf.legend_conf, leg);
                end
            end

            % Apply custom configuration
            copy_into(plot_conf.fig_conf, f);
            copy_into(plot_conf.ax_conf, ax);

            if ~isnan(plot_conf.font_size)
                % TODO[Deprecate, R2022a]: Use function "fontsize"
                set(findall(f, "-property", "FontSize"), "FontSize", plot_conf.font_size);
            end

            % Save plot
            if obj.conf.plot_save
                if plot_conf.legend_only
                    if isempty(leg)
                        obj.logger.print("Skipping plot: No legend to display.\n");
                        return;
                    end

                    set(leg, "visible", "on");

                    ax.Position = [-0.1, -0.1, 0.01, 0.01];
                    f.Position(3:4) = leg.Position(3:4) .* f.Position(3:4) .* 1.1;
                    leg.Position(1:2) = 0.05;
                end

                filename = strrep(plot_conf.title, " ", "-");
                for save_format = plot_conf.save_formats
                    if save_format == "bin.tikz"
                        MyFigure.export( ...
                            f, ...
                            get_path(obj.conf.plot_dir, filename + ".bin.tikz"), ...
                            ... Disable grid line at x=3, as otherwise it will be above the axis line
                            extra_axis_options = {"extra tick style={grid=none}"}, ...
                            extra_x_ticks = 3 ...
                        );  %#ok<STRSCALR> False positive
                    else
                        % `saveas` is preferred over `exportgraphics` because the latter has ugly fonts
                        saveas(f, get_path(obj.conf.plot_dir, filename), save_format);
                    end
                end
            end

            % Show plot
            if obj.conf.plot_show && ~plot_conf.legend_only
                if plot_conf.type == "line" || plot_conf.type == "confidence" || plot_conf.type == "scatter"
                    title(ax, plot_conf.title);
                elseif plot_conf.type == "boxplot"
                    sgtitle(f, plot_conf.title);
                end
                set(f, "visible", "on");
            end
        end
    end
end


%% Helper functions
function do_save(file, obj, format)
    % DO_SAVE  Writes [obj] to the MAT-file [file], under the variable name "obj",
    % using the MAT-file version given by [format].
    %
    % Version "-v7" suits small files; version "-v7.3" suits big files.
    %
    % `save` cannot be called directly inside a `parfor` loop, which is why this
    % wrapper exists.

    arguments% (Input)
        file (1, 1) {mustBeText};
        obj;
        format (1, 1) {mustBeMember(format, ["-v7", "-v7.3"])};
    end

    var_name = "obj";
    save(file, var_name, format);
end

function loadable = is_loadable(file, validate)
    % IS_LOADABLE  Checks if [file] is an existing regular file and, if [validate]
    % is `true`, whether [file] contains a variable named "obj" (as written by
    % [do_save]).
    %
    % `isfile` is used rather than `exist(file, "file")`, because the latter also
    % reports folders and may resolve names via the MATLAB search path; only the
    % exact location [file] should count.
    %
    % If inspecting [file] throws (e.g. because it is corrupted), a message is
    % printed and `false` is returned.

    arguments% (Input)
        file (1, 1) {mustBeText};
        validate (1, 1) logical = true;
    end
    % arguments (Output)
    %     loadable (1, 1) logical;
    % end

    try
        loadable = isfile(file) && (~validate || ~isempty(whos("obj", "-file", file)));
    catch
        fprintf("File %s is corrupted.\n", file);
        loadable = false;
    end
end

function obj = do_load(file)
    % DO_LOAD  Reads the MAT-file [file] and returns the variable stored in it
    % under the name "obj".
    %
    % Counterpart of [do_save]: it keeps the variable name under which data is
    % stored in one place.

    arguments% (Input)
        file (1, 1) {mustBeText};
    end
    % arguments (Output)
    %     obj;
    % end

    contents = load(file);
    obj = contents.obj;
end

function path = get_path(dir, filename)
    % GET_PATH  Returns the path `[dir]/[filename]`, creating [dir] if necessary.
    %
    % `isfolder` checks exactly [dir], whereas `exist(dir, "dir")` may also match a
    % folder found via the MATLAB search path, in which case [dir] would not be
    % created and later writes to the returned path would fail.

    arguments% (Input)
        dir (1, 1) {mustBeText};
        filename (1, 1) {mustBeText};
    end
    % arguments (Output)
    %     path (1, 1) {mustBeText};
    % end

    if ~isfolder(dir); mkdir(dir); end

    path = sprintf("%s/%s", dir, filename);
end

function margin = ci(population)
    % CI  Calculates the margin of the 95% confidence interval of the mean of the
    % given population of samples.
    %
    % Each row of [population] is one sample and each column is a separate
    % variable, so an `n`-by-`m` input yields a `1`-by-`m` row of margins. The
    % margin is `1.960 * s / sqrt(n)`, where `s` is the sample standard deviation
    % of the column and `n` the number of samples (rows). A single sample yields a
    % margin of 0.

    sample_count = size(population, 1);
    margin = 1.960 * (std(population, 0, 1) / sqrt(sample_count));
end
