%% Pre-initialisation
% Temporarily store old workspace.
% If a previous run left `lab_old` behind (it is only cleared at the end of the
% Run section, so this happens when that run errored part-way), restore it as
% `lab` before taking a fresh backup below.
if exist("lab_old", "var"); lab = lab_old; end  % In case initialisation errored
clearvars("lab_old");
% Back up the previous `lab` (if any); the Run section reuses it when the new
% configuration has the same cache ID.
if exist("lab", "var"); lab_old = lab; end  % Used later on
% Remove every other workspace variable so the script starts from a clean slate.
clearvars("-except", "lab_old");

% Read parameters from SLURM
% Outside a SLURM array job these variables are unset, str2double yields NaN,
% and the script falls back to a single part with index 0.
% NOTE(review): SLURM_ARRAY_TASK_MAX is the *highest* array index, not the number
% of tasks. With a 0-based array (e.g. --array=0-9) this gives part_count = 9
% while ten tasks run; SLURM_ARRAY_TASK_COUNT may be what is intended -- confirm
% against how LaboratoryConfig interprets part_count/part_idx.
part_count = str2double(getenv("SLURM_ARRAY_TASK_MAX"));
if isnan(part_count); part_count = 1; end
part_idx = str2double(getenv("SLURM_ARRAY_TASK_ID"));
if isnan(part_idx); part_idx = 0; end

% Create logger
logger = Logger();
logger.println("Initialising StretchSim.");


%% Configure
% For documentation, see classes [LaboratoryConfig], [ExperimentConfig], and [PlotConfig].
%
% Overview of the configuration below:
%  - `metrics`: named functions of an experiment `it`. Each evaluates one of four
%    graph metrics (eigenratio, algebraic connectivity, closeness centrality,
%    efficiency) on one of three graphs held by the experiment (stretched,
%    leaves-minimised, optimised). Plots read them back via `it.metrics.(...)`.
%  - `exp_confs`: one ExperimentConfig per element returned by
%    `Config.combinations(...)`. The inputs are a base config, four graph
%    families, stretch girths 3..10, three stretch methods, four
%    leaf-minimisation methods, and five optimisation targets.
%    NOTE(review): presumably `Config.combinations` forms the Cartesian product
%    of these cell arrays -- confirm in [Config].
%    The random-parameter lower bounds (p >= log(n)/n for Erdos-Renyi,
%    r >= 1.1*sqrt(log(n)/(n*pi)) for geometric graphs) look like connectivity
%    thresholds -- presumably chosen so generated graphs are likely connected.
%  - `plot_confs`: PlotConfigs for the stretching, leaf-minimisation and
%    optimisation figures. Every plot's `exp_to_line` copies the experiment
%    config, fixes the girth to 4 and returns its cache ID. Presumably this puts
%    experiments that differ only in girth on the same line.
%    NOTE(review): the "Edges Changed During Optimisation" tiles count differing
%    entries of the two adjacency matrices; if the graphs are undirected, each
%    changed edge is counted twice (A(i,j) and A(j,i)) -- confirm this is intended.
lab = Laboratory( ...
    LaboratoryConfig(struct( ...
        seed = 17, ...
        parallel = true, ...
        parallel_max_workers = Inf, ...
        ...
        part_count = part_count, ...
        part_idx = part_idx, ...
        ...
        load_enabled = true, ...
        load_exp_behavior = "load", ...
        load_validate = false, ...
        save_enabled = true, ...
        ...
        repeat_count = 500, ...
        ...
        plot_show = false, ...
        plot_save = true, ...
        plot_dir = "figs/", ...
        ...
        metrics = struct( ...
            stretched_eigenratio = @(it) Graphs.eigenratio(it.graph_stretched), ...
            stretched_algebraic_connectivity = @(it) Graphs.algebraic_connectivity(it.graph_stretched), ...
            stretched_closeness_centrality = @(it) Graphs.closeness_centrality(it.graph_stretched), ...
            stretched_efficiency = @(it) Graphs.efficiency(it.graph_stretched), ...
            ...
            leaves_minimised_eigenratio = @(it) Graphs.eigenratio(it.graph_leaves_minimised), ...
            leaves_minimised_algebraic_connectivity = @(it) Graphs.algebraic_connectivity(it.graph_leaves_minimised), ...
            leaves_minimised_closeness_centrality = @(it) Graphs.closeness_centrality(it.graph_leaves_minimised), ...
            leaves_minimised_efficiency = @(it) Graphs.efficiency(it.graph_leaves_minimised), ...
            ...
            optimised_eigenratio = @(it) Graphs.eigenratio(it.graph_optimised), ...
            optimised_algebraic_connectivity = @(it) Graphs.algebraic_connectivity(it.graph_optimised), ...
            optimised_closeness_centrality = @(it) Graphs.closeness_centrality(it.graph_optimised), ...
            optimised_efficiency = @(it) Graphs.efficiency(it.graph_optimised) ...
        ), ...
        ...
        exp_confs = cellfun( ...
            @(it) ExperimentConfig(it), ...
            Config.combinations( ...
                ... Base config
                { ...
                    struct( ...
                        node_count = [25, 100], ...
                        network_max_attempts = 1000, ...
                        ...
                        da_repetitions = 10, ...
                        da_coordination = "asynchronous", ...
                        da_neighbor_method = "single", ...
                        da_update_direction = "push_pull", ...
                        da_convergence_method = "error_norm", ...
                        da_convergence_threshold = 0.01, ...
                        da_max_rounds = -1 ...
                    ) ...
                }, ...
                ... Graph families
                horzcat( ...
                    {struct(network_layout = "erdos_renyi", ...
                            network_erdos_renyi_p = @(n) randf([log(n) / n, 1]))}, ...
                    {struct(network_layout = "watts_strogatz", ...
                            network_watts_strogatz_k = @(n) randi([1, floor(n / 2) - 1]), ...
                            network_watts_strogatz_p = @(~) randf([0, 1]))}, ...
                    {struct(network_layout = "barabasi_albert", ...
                            network_barabasi_albert_m = @(n) randi([1, n - 1]))}, ...
                    {struct(network_layout = "geometric_random", ...
                            network_geometric_random_d = @(~) 2, ...
                            network_geometric_random_r = @(n, ~) randf([1.1 * sqrt(log(n) / (n * pi)), 1]))} ...
                ), ...
                ... Graph post-processing
                arrayfun(@(it) {struct(network_stretch_girth = it)}, 3:10), ...
                arrayfun(@(it) {struct(network_stretch_method = it)}, ...
                         ["random", "least_cycles_steps", "most_cycles_steps"]), ...
                         ... "random"), ... ...
                arrayfun(@(it) {struct(network_minimise_leaves_method = it)}, ...
                         ["none", "random", "closest", "furthest"]), ...
                         ... "none"), ...
                cellfun(@(metric) {struct(network_optimise_metric = metric(1), ...
                                          network_optimise_direction = metric(2))}, ...
                        {["none", "maximum"], ...
                         ["eigenratio", "maximum"], ...
                         ["algebraic_connectivity", "maximum"], ...
                         ["closeness_centrality", "maximum"], ...
                         ["efficiency", "maximum"]}) ...
                        ... {["none", "maximum"]}) ...
            ) ...
        ), ...
        ...
        plot_confs = cellfun( ...
            @(it) PlotConfig(it), ...
            Config.combinations( ...
                ... Base config
                {struct( ...
                    save_formats = ["bin.tikz", "png"] ...
                )}, ...
                horzcat( ...
                    ... Stretching
                    {struct( ...
                        title = "Proportion of Edges Removed During Stretching", ...
                        ax_conf = struct( ...
                            XLabel = struct(String = "Stretched girth"), ...
                            YLabel = struct(String = "Proportion"), ...
                            YLim = [0, 1] ...
                        ), ...
                        filter_exp = @(it) it.conf.network_minimise_leaves_method == "none" && it.conf.network_optimise_metric == "none", ...
                        exp_to_line = @(it) copy(it.conf).set(struct(network_stretch_girth = 4)).cache_id(), ...
                        exp_to_data = @(it) [it.conf.network_stretch_girth, (numedges(it.graph_raw) - numedges(it.graph_stretched)) / numedges(it.graph_raw)], ...
                        line_to_style = @(it) style_conf(it.conf, ["layout", "colour"], ["stretch_method", "marker"]), ...
                        confidence_conf = struct(FaceAlpha = 0.1), ...
                        legend_conf = struct(visible = "off") ...
                    )}, ...
                    {struct( ...
                        title = "Number of Leaves After Stretching", ...
                        ax_conf = struct( ...
                            XLabel = struct(String = "Stretched girth"), ...
                            YLabel = struct(String = "Leaves"), ...
                            YTick = 0:10:50 ...
                        ), ...
                        filter_exp = @(it) it.conf.network_minimise_leaves_method == "none" && it.conf.network_optimise_metric == "none", ...
                        exp_to_line = @(it) copy(it.conf).set(struct(network_stretch_girth = 4)).cache_id(), ...
                        exp_to_data = @(it) [it.conf.network_stretch_girth, Graphs.leaf_count(it.graph_stretched)], ...
                        line_to_style = @(it) style_conf(it.conf, ["layout", "colour"], ["stretch_method", "marker"]), ...
                        legend_conf = struct(visible = "off") ...
                    )}, ...
                    {struct( ...
                        type = "tiles", ...
                        title = "Metrics After Stretching", ...
                        tiles_confs = { ...
                            reshape(arrayfun( ...
                                @(metric) {PlotConfig(struct( ...
                                    save_formats = strings(0), ...
                                    title = human("optimise_metric", metric), ...
                                    ax_conf = struct( ...
                                        XLabel = struct(String = "Stretched girth"), ...
                                        YLabel = struct(String = sentence_case(human("optimise_metric", metric))), ...
                                        XLim = [3, nan], ...
                                        YLim = struct(eigenratio = [3e-3, 6e-1], algebraic_connectivity = [2e-2, 3e1], closeness_centrality = [1.6e-1, 10^-0.1], efficiency = [2.2e-1, 10^-.1]).(metric), ...
                                        YTick = struct(eigenratio = 10.^(-2:0), algebraic_connectivity = 10.^(-1:2), closeness_centrality = 10.^[-.5, -.1], efficiency = 10.^[-.4, -.1]).(metric), ...
                                        YScale = "log", ...
                                        YMinorGrid = false ...
                                    ), ...
                                    filter_exp = @(it) it.conf.network_minimise_leaves_method == "none" && it.conf.network_optimise_metric == "none", ...
                                    exp_to_line = @(it) copy(it.conf).set(struct(network_stretch_girth = 4)).cache_id(), ...
                                    exp_to_data = @(it) [it.conf.network_stretch_girth, it.metrics.(sprintf("stretched_%s", metric))], ...
                                    line_to_style = @(it) style_conf(it.conf, ["layout", "colour"], ["stretch_method", "marker"]), ...
                                    legend_conf = struct(visible = "off") ...
                                ))}, ...
                                category_classes("optimise_metric", "none") ...
                            ), 2, 2)' ...
                        } ...
                    )}, ...
                    {struct( ...
                        title = "Convergence Time After Stretching", ...
                        ax_conf = struct( ...
                            XLabel = struct(String = "Stretched girth"), ...
                            YLabel = struct(String = "Convergence time"), ...
                            YScale = "log", ...
                            YLim = [4e2, 3.5e4], ...
                            YMinorGrid = true ...
                        ), ...
                        filter_exp = @(it) it.conf.network_minimise_leaves_method == "none" && it.conf.network_optimise_metric == "none", ...
                        exp_to_line = @(it) copy(it.conf).set(struct(network_stretch_girth = 4)).cache_id(), ...
                        exp_to_data = @(it) [it.conf.network_stretch_girth, mean(it.round_idx)], ...
                        line_to_style = @(it) style_conf(it.conf, ["layout", "colour"], ["stretch_method", "marker"]), ...
                        confidence_conf = struct(FaceAlpha = 0.1), ...
                        legend_conf = struct(visible = "off") ...
                    )}, ...
                    ...
                    ... Leaf minimisation
                    {struct( ...
                        type = "tiles", ...
                        title = "Number of Leaves After Minimisation", ...
                        tiles_confs = { ...
                            reshape(arrayfun( ...
                                @(layout) {PlotConfig(struct( ...
                                    save_formats = strings(0), ...
                                    title = human("layout", layout), ...
                                    ax_conf = struct( ...
                                        XLabel = struct(String = "Stretched girth"), ...
                                        YLabel = struct(String = "Leaves"), ...
                                        XLim = [3, nan], ...
                                        YLim = [0, 50], ...
                                        YTick = 0:10:50 ...
                                    ), ...
                                    filter_exp = @(it) it.conf.network_layout == layout && it.conf.network_optimise_metric == "none", ...
                                    exp_to_line = @(it) copy(it.conf).set(struct(network_stretch_girth = 4)).cache_id(), ...
                                    exp_to_data = @(it) [it.conf.network_stretch_girth, Graphs.leaf_count(it.graph_leaves_minimised)], ...
                                    line_to_style = @(it) style_conf(it.conf, ["stretch_method", "marker"], ["minimise_leaves_method", "colour"]), ...
                                    legend_conf = struct(visible = "off") ...
                                ))}, ...
                                category_classes("layout", "none") ...
                            ), 2, 2)' ...
                        } ...
                    )}, ...
                    {struct( ...
                        type = "tiles", ...
                        title = "Edges Added During Leaf Minimisation", ...
                        tiles_confs = { ...
                            reshape(arrayfun( ...
                                @(layout) {PlotConfig(struct( ...
                                    save_formats = strings(0), ...
                                    title = human("layout", layout), ...
                                    ax_conf = struct( ...
                                        XLabel = struct(String = "Stretched girth"), ...
                                        YLabel = struct(String = "Edges added"), ...
                                        XLim = [3, nan], ...
                                        YLim = [0, 25], ...
                                        YTick = 0:5:25 ...
                                    ), ...
                                    filter_exp = @(it) it.conf.network_layout == layout && it.conf.network_optimise_metric == "none", ...
                                    exp_to_line = @(it) copy(it.conf).set(struct(network_stretch_girth = 4)).cache_id(), ...
                                    exp_to_data = @(it) [it.conf.network_stretch_girth, numedges(it.graph_leaves_minimised) - numedges(it.graph_stretched)], ...
                                    line_to_style = @(it) style_conf(it.conf, ["stretch_method", "marker"], ["minimise_leaves_method", "colour"]), ...
                                    legend_conf = struct(visible = "off") ...
                                ))}, ...
                                category_classes("layout", "none") ...
                            ), 2, 2)' ...
                        } ...
                    )}, ...
                    {struct( ...
                        type = "tiles", ...
                        title = "Convergence Time After Leaf Minimisation", ...
                        tiles_confs = { ...
                            reshape(arrayfun( ...
                                @(layout) {PlotConfig(struct( ...
                                    save_formats = strings(0), ...
                                    title = human("layout", layout), ...
                                    ax_conf = struct( ...
                                        XLabel = struct(String = "Stretched girth"), ...
                                        YLabel = struct(String = "Convergence time"), ...
                                        XLim = [3, nan], ...
                                        YLim = [4e2, 3.5e4], ...
                                        YScale = "log", ...
                                        YMinorGrid = false ...
                                    ), ...
                                    filter_exp = @(it) it.conf.network_layout == layout && it.conf.network_optimise_metric == "none", ...
                                    exp_to_line = @(it) copy(it.conf).set(struct(network_stretch_girth = 4)).cache_id(), ...
                                    exp_to_data = @(it) [it.conf.network_stretch_girth, mean(it.round_idx)], ...
                                    line_to_style = @(it) style_conf(it.conf, ["stretch_method", "marker"], ["minimise_leaves_method", "colour"]), ...
                                    confidence_conf = struct(FaceAlpha = ternary(layout == "erdos_renyi", nan, 0.1)), ...
                                    legend_conf = struct(visible = "off") ...
                                ))}, ...
                                category_classes("layout", "none") ...
                            ), 2, 2)' ...
                        } ...
                    )}, ...
                    ...
                    ... Optimisation
                    {struct( ...
                        type = "tiles", ...
                        title = "Edges Changed During Optimisation", ...
                        tiles_confs = { ...
                            reshape(arrayfun( ...
                                @(layout) {PlotConfig(struct( ...
                                    save_formats = strings(0), ...
                                    title = human("layout", layout), ...
                                    ax_conf = struct( ...
                                        XLabel = struct(String = "Stretched girth"), ...
                                        YLabel = struct(String = "Edges changed"), ...
                                        XLim = [3, nan], ...
                                        YLim = [4, 3.5e3], ...
                                        YScale = "log", ...
                                        YMinorGrid = false ...
                                    ), ...
                                    filter_exp = @(it) it.conf.network_layout == layout && it.conf.network_minimise_leaves_method == "none" && it.conf.network_optimise_metric ~= "none", ...
                                    exp_to_line = @(it) copy(it.conf).set(struct(network_stretch_girth = 4)).cache_id(), ...
                                    exp_to_data = @(it) [it.conf.network_stretch_girth, full(sum(adjacency(it.graph_optimised) ~= adjacency(it.graph_leaves_minimised), "all"))], ...
                                    line_to_style = @(it) style_conf(it.conf, ["optimise_metric", "colour", "none"], ["stretch_method", "marker"]), ...
                                    legend_conf = struct(visible = "off") ...
                                ))}, ...
                                category_classes("layout", "none") ...
                            ), 2, 2)' ...
                        } ...
                    )}, ...
                    {struct( ...
                        type = "tiles", ...
                        title = "Convergence Time After Optimising", ...
                        tiles_confs = { ...
                            arrayfun( ...
                                @(layout, metric) {PlotConfig(struct( ...
                                    save_formats = strings(0), ...
                                    ax_conf = struct( ...
                                        XLabel = struct(String = ternary(metric == "efficiency", "Stretched girth", "")), ...
                                        YLabel = struct(String = ternary(layout == "erdos_renyi", "Convergence time", "")), ...
                                        XLim = [3, nan], ...
                                        YLim = [4e2, ternary(layout == "barabasi_albert", 14e3, 6e3)], ...
                                        YScale = "log", ...
                                        YMinorGrid = false ...
                                    ), ...
                                    filter_exp = @(it) it.conf.network_layout == layout && it.conf.network_optimise_metric == metric, ...
                                    exp_to_line = @(it) copy(it.conf).set(struct(network_stretch_girth = 4)).cache_id(), ...
                                    exp_to_data = @(it) [it.conf.network_stretch_girth, mean(it.round_idx)], ...
                                    line_to_style = @(it) style_conf(it.conf, ["stretch_method", "marker"], ["minimise_leaves_method", "colour"]), ...
                                    confidence_conf = struct(FaceAlpha = ternary(layout == "barabasi_albert", 0.1, nan)), ...
                                    legend_conf = struct(visible = "off") ...
                                ))}, ...
                                repmat(category_classes("layout"), numel(category_classes("optimise_metric", "none")), 1), ...
                                repmat(category_classes("optimise_metric", "none"), numel(category_classes("layout")), 1)' ...
                            ) ...
                        }, ...
                        tiles_row_titles = arrayfun(@(it) title_case(human("optimise_metric", it)), category_classes("optimise_metric", "none")), ...
                        tiles_col_titles = arrayfun(@(it) title_case(human("layout", it)), category_classes("layout")) ...
                    )} ...
                ) ...
            ) ...
        ) ...
    )) ...
);


%% Run
% Load old workspace if compatible
% If the previous run's `lab` (kept as `lab_old` by the pre-initialisation
% section) was built from a configuration with the same cache ID as the new one,
% keep using that object, and whatever experiments it already holds, instead of
% the freshly constructed `lab`. The new config object is carried over onto it.
% NOTE(review): the meaning of `logger.level = 0` is defined in [Logger];
% presumably it resets the reused lab's verbosity -- confirm.
if exist("lab_old", "var") && lab.conf.cache_id() == lab_old.conf.cache_id()
    logger.println("Reusing already-loaded experiments.");
    lab_old.conf = lab.conf;
    lab = lab_old;
    lab.logger.level = 0;
end
% Initialisation succeeded, so drop the backup. If the script errors before this
% line, the pre-initialisation section restores `lab` from `lab_old` next time.
clearvars("lab_old");

% Actually run
% NOTE(review): whether experiments are computed or loaded presumably depends on
% the load_*/save_* options set in the LaboratoryConfig above -- see [Laboratory].
lab.run();
lab.plot();

logger.println("StretchSim has completed.");


%% Helper functions for plots
function out = category_classes(name, exclude)
    % CATEGORY_CLASSES  Lists every class that belongs to category [name], leaving
    % out the class [exclude] (nothing is left out when [exclude] is omitted or
    % does not occur in the category).
    %
    % Known categories are "layout", "stretch_method", "minimise_leaves_method",
    % and "optimise_metric"; any other [name] raises an error. The classes are
    % returned as a string row vector in a fixed order.

    % TODO[Deprecate, R2022b]: Enable output argument validation everywhere
    arguments% (Input)
        name (1, 1) {mustBeText};
        exclude (1, 1) {mustBeText} = "";
    end
    % arguments (Output)
    %     out (1, :) {mustBeText};
    % end

    % Lookup table from category name to its ordered list of classes.
    catalogue = struct( ...
        layout = ["erdos_renyi", "watts_strogatz", "barabasi_albert", "geometric_random"], ...
        stretch_method = ["random", "least_cycles_steps", "most_cycles_steps"], ...
        minimise_leaves_method = ["none", "random", "closest", "furthest"], ...
        optimise_metric = ["none", "eigenratio", "algebraic_connectivity", "closeness_centrality", "efficiency"] ...
    );

    if ~isfield(catalogue, name)
        error("Unknown category with name '%s'.", name);
    end

    members = catalogue.(name);
    keep = members ~= exclude;
    out = members(keep);
end

function out = human(category, class)
    % HUMAN  Returns a human-readable string for the [class] from [category]
    % returned by [category_classes].
    %
    % Raises an error naming the offending value when [category] is not a known
    % category, or when [class] is not one of that category's classes.

    arguments% (Input)
        category (1, 1) {mustBeText};
        class (1, 1) {mustBeText};
    end
    % arguments(Output)
    %     out (1, 1) {mustBeText};
    % end

    switch category
        case "layout"
            labels = struct( ...
                erdos_renyi = "Erdős--Rényi", ...
                watts_strogatz = "Watts--Strogatz", ...
                barabasi_albert = "Barabási--Albert", ...
                geometric_random = "Geometric" ...
            );
        case "stretch_method"
            labels = struct( ...
                random = "Random stretching", ...
                least_cycles_steps = "Least-cycles stretching", ...
                most_cycles_steps = "Most-cycles stretching" ...
            );
        case "minimise_leaves_method"
            labels = struct( ...
                none = "No minimisation", ...
                random = "Random minimisation", ...
                closest = "Closest minimisation", ...
                furthest = "Furthest minimisation" ...
            );
        case "optimise_metric"
            labels = struct( ...
                none = "No optimisation", ...
                eigenratio = "Eigenratio", ...
                algebraic_connectivity = "Algebraic connectivity", ...
                closeness_centrality = "Closeness centrality", ...
                efficiency = "Efficiency" ...
            );
        otherwise
            % Report the argument actually passed in (this function has no `name`).
            error("Unknown category with name '%s'.", category);
    end

    % Give a descriptive error instead of a bare "non-existent field" failure.
    if ~isfield(labels, class)
        error("Unknown class '%s' in category '%s'.", class, category);
    end
    out = labels.(class);
end

function out = style_conf(conf, varargin, opts)
    % STYLE_CONF  Builds style config arguments to pass to a function such as [plot]
    % or [scatter].
    %
    % [varargin] is a set of style descriptors. Each descriptor describes how a
    % category of classes is visually distinguishable. The format for each
    % descriptor is `[category, visualisation, exclude]`, where `exclude` is
    % optional. The [visualisation] chooses the type of distinction, and [category]
    % and [exclude] determine the set of classes using [category_classes].
    %
    % [visualisation] must be one of: "colour", "marker", "line".
    %
    % Finally, [opts] may contain some parameter affecting the display style in
    % a custom way.
    %
    % The result is a cell row starting with "LineWidth" and "MarkerSize"
    % name-value pairs, followed by one name-value pair per descriptor. The value
    % is chosen by the position of the experiment's class (read from
    % `conf.network_<category>`) within the category's class list.
    %
    % Raises an error if that class is not in the list (for example, because it is
    % the excluded class) or if [visualisation] is unknown.

    arguments% (Input)
        conf (1, 1) ExperimentConfig;
    end
    arguments (Repeating)% (Input, Repeating)
        varargin (1, :) {mustBeText};
    end
    arguments% (Input)
        opts.line_width (1, 1) {mustBeNumeric} = 0.5;
        opts.marker_size (1, 1) {mustBeNumeric} = 4;
    end
    % arguments (Output)
    %     out (:, :) cell;
    % end

    colours = rgb2hex(colororder);
    markers = ["+", "*", "^"];
    lines = ["-", "--", ":", "-."];

    out = cell([numel(varargin), 1]);
    for idx = 1:numel(varargin)
        descriptor = varargin{idx};
        category = descriptor(1);
        visualisation = descriptor(2);
        if numel(descriptor) > 2; exclude = descriptor(3); else; exclude = ""; end

        classes = category_classes(category, exclude);
        cls = conf.(sprintf("network_%s", category));
        % Logical config values are matched by their text form ("true"/"false").
        if islogical(cls); cls = mat2str(cls); end
        class_idx = find(classes == cls);
        % Without this check an empty index would silently yield e.g.
        % {"Color", <empty>}, which only fails later inside the plotting call.
        if isempty(class_idx)
            error("Class '%s' is not one of the classes of category '%s'.", string(cls), category);
        end

        switch visualisation
            case "colour"
                out(idx) = {{"Color", colours(class_idx)}};
            case "marker"
                out(idx) = {{"Marker", markers(class_idx)}};
            case "line"
                out(idx) = {{"LineStyle", lines(class_idx)}};
            otherwise
                error("Unknown visualisation type '%s'.", visualisation);
        end
    end

    out = horzcat({"LineWidth", opts.line_width, "MarkerSize", opts.marker_size}, out{:});
end

function out = legend_conf(varargin)
    % LEGEND_CONF  Builds the custom legend entries to be drawn; see also
    % [PlotConfig#legend_entries].
    %
    % Every argument in [varargin] is a style descriptor in the same
    % `[category, visualisation, exclude]` format accepted by [style_conf]. Each
    % descriptor contributes one row `{label, style}` per class of its category.
    % `label` comes from [human], and `style` is a cell of name-value pairs.
    % Marker and line entries are drawn in black (markers without a connecting
    % line), so only the distinguishing property varies.
    %
    % The special one-element descriptor `["pad"]` instead contributes a single
    % blank row, which skips one entry in the legend.
    %
    % All rows are stacked into one two-column cell array.

    arguments (Repeating)% (Input, Repeating)
        varargin (:, 1) {mustBeText};
    end
    % arguments (Output)
    %     out (:, :) cell;
    % end

    colours = rgb2hex(colororder);
    markers = ["+", "*", "^"];
    lines = ["-", "--", ":", "-."];

    groups = cell([numel(varargin), 1]);
    for d = 1:numel(varargin)
        descriptor = varargin{d};

        % Padding descriptor: one empty label with an invisible line.
        if descriptor(1) == "pad"  %#ok<CLARRSTR>
            groups{d} = {"", {"LineStyle", "none"}};
            continue;
        end

        category = descriptor(1);
        visualisation = descriptor(2);
        exclude = "";
        if numel(descriptor) > 2
            exclude = descriptor(3);
        end

        classes = category_classes(category, exclude);

        % Maps a class position to the name-value pairs that draw it.
        switch visualisation
            case "colour"
                to_style = @(k) {"Color", colours(k)};
            case "marker"
                to_style = @(k) {"Marker", markers(k), "LineStyle", "none", "Color", "black"};
            case "line"
                to_style = @(k) {"LineStyle", lines(k), "Color", "black"};
            otherwise
                error("Unknown visualisation type '%s'.", visualisation);
        end

        rows = arrayfun( ...
            @(k) {human(category, classes(k)), to_style(k)}, ...
            1:numel(classes), ...
            UniformOutput = false ...
        );
        groups{d} = vertcat(rows{:});
    end

    out = vertcat(groups{:});
end

function out = title_case(str)
    % TITLE_CASE  Capitalises the first letter of every word in [str] and
    % lower-cases the rest. Words are separated by spaces, hyphens, or en dashes,
    % and the separators are kept unchanged (so "--" stays "--").

    [parts, separators] = split(str, [" ", "-", "–"]);

    % Empty pieces (between consecutive separators) are left untouched.
    mask = strlength(parts) ~= 0;
    heads = upper(extractBefore(parts(mask), 2));
    tails = lower(extractAfter(parts(mask), 1));
    parts(mask) = heads + tails;

    out = join(parts, separators);
end

function out = sentence_case(str)
    % SENTENCE_CASE  Capitalises the first word of [str] (via [title_case]) and
    % lower-cases every following word. Words are split on spaces and hyphens,
    % and the separators are preserved.

    [tokens, separators] = split(str, [" ", "-"]);
    first = title_case(tokens(1));
    tokens = lower(tokens);
    tokens(1) = first;
    out = join(tokens, separators);
end

function out = ternary(condition, outTrue, outFalse)
    % TERNARY  Picks `outTrue` when `condition` holds and `outFalse` otherwise,
    % so a conditional value can be written inline inside an expression.
    out = outFalse;
    if condition
        out = outTrue;
    end
end
