classdef Experiment < handle
    % EXPERIMENT  Runs a distributed averaging experiment.
    %
    % An experiment consists of generating a network graph, post-processing it
    % (stretching, leaf minimisation, optimisation), and then, for each
    % repetition, assigning a random scalar model to every node and repeatedly
    % averaging models between neighbours until a stopping criterion is met.
    %
    % See also: ExperimentConfig, StretchSim

    properties (SetAccess = private)
        % IDX  This [Experiment]'s unique identifier, containing both its identity and
        % its repetition number.
        idx (1, 2) {mustBeInteger, mustBePositive} = [1, 1];
        % CONF  The configuration on how to run the [Experiment].
        conf (1, 1) ExperimentConfig;

        % ATTEMPT_IDX  The current graph generating attempt.
        attempt_idx (1, 1) {mustBeInteger, mustBeNonnegative} = 0;
        % GRAPH_RAW  The created graph before any additional processing takes place.
        graph_raw (1, 1) graph;
        % GRAPH_STRETCHED  The graph after performing stretching.
        graph_stretched (1, 1) graph;
        % GRAPH_LEAVES_MINIMISED  The graph after minimising the number of leaves.
        graph_leaves_minimised (1, 1) graph;
        % GRAPH_OPTIMISED  The graph after performing optimisation.
        graph_optimised (1, 1) graph;
        % GRAPH  The final graph used for the distributed algorithm.
        graph (1, 1) graph;

        % STARTED  `true` if and only if this experiment has started running.
        started (1, 1) logical = false;
        % ROUND_IDX  The number of distributed averaging rounds performed so far, with
        % one entry per repetition.
        round_idx (1, :) {mustBeInteger, mustBeNonnegative};
    end
    properties
        % METRICS  A struct for caching arbitrary meta-data on the experiment.
        %
        % This field can be used as follows:
        % * Assign a new metric using `obj.metrics(1).foo = "bar"`.
        % * Update an existing metric using `obj.metrics.foo = "baz"`.
        % * Read an existing metric using `obj.metrics.foo`.
        %
        % Calculated metrics are not stored in the cache.
        metrics struct = struct([]);
    end


    methods
        function obj = Experiment(idx, conf)
            % EXPERIMENT  Constructs the experiment identified by [idx] using [conf].
            %
            % The round counter is initialised to zero for each of the
            % `conf.da_repetitions` repetitions.

            arguments% (Input)
                idx (1, 2) {mustBeInteger, mustBePositive};
                conf (1, 1) ExperimentConfig;
            end
            % arguments (Output)
            %     obj (1, 1) Experiment;
            % end

            obj.idx = idx;
            obj.conf = conf;
            obj.round_idx = zeros([1, obj.conf.da_repetitions]);
        end


        function run(obj)
            % RUN  Runs this experiment.
            %
            % Generates the network graph (retrying up to
            % `conf.network_max_attempts` times), post-processes and validates it,
            % and then runs `conf.da_repetitions` repetitions of distributed
            % averaging on it.
            %
            % Raises an error if the experiment was already started, if no
            % acceptable graph could be generated, or if the final graph fails
            % validation (disconnected, girth too low, or not simple).

            assert(~obj.started, "Cannot re-start experiment.");
            obj.started = true;


            %% Create network graph
            % Generate
            % NOTE(review): if `node_count` is a scalar, `randi` draws uniformly from
            % 1..node_count rather than using node_count itself — confirm this is the
            % intended behaviour.
            n = randi(obj.conf.node_count);

            % Resolve the layout parameters once; they are functions of `n` and do not
            % change between generation attempts. The "complete" and "empty" layouts
            % take no parameters.
            switch obj.conf.network_layout
                case "erdos_renyi"
                    p = obj.conf.network_erdos_renyi_p(n);
                case "watts_strogatz"
                    k = obj.conf.network_watts_strogatz_k(n);
                    p = obj.conf.network_watts_strogatz_p(n);
                case "barabasi_albert"
                    m = obj.conf.network_barabasi_albert_m(n);
                case "geometric_random"
                    d = obj.conf.network_geometric_random_d(n);
                    r = obj.conf.network_geometric_random_r(n);
            end

            obj.attempt_idx = 0;
            while true
                obj.attempt_idx = obj.attempt_idx + 1;
                assert(obj.attempt_idx <= obj.conf.network_max_attempts, ...
                       "Failed to generate network in maximum number of attempts.");

                switch obj.conf.network_layout
                    case "erdos_renyi"
                        obj.graph_raw = Graphs.generate_erdos_renyi(n, p);
                    case "watts_strogatz"
                        obj.graph_raw = Graphs.generate_watts_strogatz(n, k, p);
                    case "barabasi_albert"
                        obj.graph_raw = Graphs.generate_barabasi_albert(n, m);
                    case "geometric_random"
                        obj.graph_raw = Graphs.generate_geometric_random(d, n, r);
                    case "complete"
                        obj.graph_raw = Graphs.generate_complete(n);
                    case "empty"
                        obj.graph_raw = Graphs.generate_empty(n);
                    otherwise
                        % Fail loudly instead of silently continuing with a stale or
                        % default graph.
                        error("Unknown network layout '%s'.", string(obj.conf.network_layout));
                end

                % Accept the graph unless connectivity is required and not achieved.
                if ~obj.conf.network_require_connected || Graphs.is_connected(obj.graph_raw); break; end
            end

            % Stretch
            obj.graph_stretched = ...
                Graphs.stretch( ...
                    obj.graph_raw, ...
                    girth = obj.conf.network_stretch_girth, ...
                    method = obj.conf.network_stretch_method ...
                );

            % Minimise leaves
            if obj.conf.network_minimise_leaves_method ~= "none"
                obj.graph_leaves_minimised = ...
                    Graphs.minimise_leaves( ...
                        obj.graph_stretched, ...
                        method = obj.conf.network_minimise_leaves_method, ...
                        girth = obj.conf.network_stretch_girth ...
                    );
            else
                obj.graph_leaves_minimised = obj.graph_stretched;
            end

            % Optimise
            if obj.conf.network_optimise_metric ~= "none"
                obj.graph_optimised = ...
                    Graphs.optimise( ...
                        obj.graph_leaves_minimised, ...
                        metric = obj.conf.network_optimise_metric, ...
                        direction = obj.conf.network_optimise_direction, ...
                        girth = obj.conf.network_stretch_girth, ...
                        leaves_minimised = obj.conf.network_minimise_leaves_method ~= "none" ...
                    );
            else
                obj.graph_optimised = obj.graph_leaves_minimised;
            end

            % Validate
            obj.graph = obj.graph_optimised;

            assert(Graphs.is_connected(obj.graph), "Created disconnected graph.");
            assert(Graphs.girth(obj.graph) >= obj.conf.network_stretch_girth, "Created low-girth graph.");
            % A simple graph is unchanged by `simplify`.
            assert(isisomorphic(simplify(obj.graph), obj.graph), "Created non-simple graph.");


            %% Run distributed averaging
            for repetition = 1:obj.conf.da_repetitions
                % Initialize models: one random scalar in [0, 50) per node.
                models = 50 * rand([n, 1]);
                true_mean = mean(models);
                true_norm = norm(models);

                % Run rounds
                while true
                    obj.round_idx(repetition) = obj.round_idx(repetition) + 1;

                    % Average
                    if obj.conf.da_coordination == "synchronous"
                        % Update: every node simultaneously takes the mean of its own
                        % model and all of its neighbours' models.
                        models = (models + adjacency(obj.graph) * models) ./ (1 + degree(obj.graph));
                    elseif obj.conf.da_coordination == "asynchronous"
                        % Select
                        sel_node = randelem(1:n);
                        sel_node_neighs = neighbors(obj.graph, sel_node);
                        sel_node_deg = numel(sel_node_neighs);

                        if obj.conf.da_neighbor_method == "single"
                            sel_neighs = randelem(sel_node_neighs);
                        elseif obj.conf.da_neighbor_method == "multiple"
                            % Include each neighbour independently with probability
                            % 1/deg, resampling until at least one is selected.
                            while true
                                sel_neighs = sel_node_neighs(rand([sel_node_deg, 1]) < (1 / sel_node_deg));
                                if ~isempty(sel_neighs); break; end
                            end
                        elseif obj.conf.da_neighbor_method == "all"
                            sel_neighs = sel_node_neighs;
                        end

                        % Update
                        updated_model = mean([models(sel_node); models(sel_neighs)]);
                        if obj.conf.da_update_direction == "pull" || obj.conf.da_update_direction == "push_pull"
                            models(sel_node) = updated_model;
                        end
                        if obj.conf.da_update_direction == "push" || obj.conf.da_update_direction == "push_pull"
                            models(sel_neighs) = updated_model;
                        end
                    end

                    % Analyse (only needed when a convergence threshold is enabled)
                    if obj.conf.da_convergence_threshold >= 0
                        if obj.conf.da_convergence_method == "mutual_distance"
                            convergence = max(models) - min(models);
                        elseif obj.conf.da_convergence_method == "max_error"
                            % Use the absolute deviation: a model far below the true
                            % mean is just as wrong as one far above it, and a signed
                            % maximum could be negative and trigger a false stop.
                            convergence = max(abs(models - true_mean));
                        elseif obj.conf.da_convergence_method == "error_norm"
                            convergence = norm(models - true_mean) / true_norm;
                        end
                    end

                    % Break
                    % Compare this repetition's own counter; `round_idx` is a vector
                    % and is not a valid operand for `&&` when there are multiple
                    % repetitions.
                    if obj.conf.da_max_rounds >= 0 && ...
                            obj.round_idx(repetition) >= obj.conf.da_max_rounds; break; end
                    if obj.conf.da_convergence_threshold >= 0 && ...
                            convergence <= obj.conf.da_convergence_threshold; break; end
                end
            end
        end


        function fig = plot_stretching(obj)
            % PLOT_STRETCHING  Visualizes how this experiment's graph got stretched.
            %
            % Creates a figure with three linked 3D subplots, from left to right:
            % the initial graph, the graph after edge removal, and the final graph.
            % All three reuse the node coordinates of the final graph's layout so
            % they can be compared directly.
            %
            % Returns the handle [fig] of the created figure.

            fig = figure;
            sgtitle(sprintf("Stretching to girth %d", obj.conf.network_stretch_girth));

            % Plot the final graph first, because its force-directed layout provides
            % the node coordinates reused by the other two subplots.
            P1 = subplot(1, 3, 3);
            H = plot(obj.graph, Layout = "force3");
            title(sprintf("Edges compensated"));

            % NOTE(review): `run` calls `Graphs.stretch` with `girth`/`method`, whereas
            % this calls `Graphs.stretched` with different options — confirm that this
            % function and these options exist in `Graphs`.
            P2 = subplot(1, 3, 2);
            plot(Graphs.stretched(obj.graph_raw, ...
                                  girth = obj.conf.network_stretch_girth, ...
                                  remove_filter = "all_cycles", ...
                                  remove_select = "random", ...
                                  place = "none"), ...
                 XData = H.XData, YData = H.YData, ZData = H.ZData);
            title("Edges removed");

            P3 = subplot(1, 3, 1);
            plot(obj.graph_raw, XData = H.XData, YData = H.YData, ZData = H.ZData);
            title("Initial graph");

            % Store the link object on the figure so that it stays alive; the camera
            % and axis limits of the three subplots then move together.
            setappdata(fig, MyCameraLink = linkprop([P1, P2, P3], ...
                                                    {"CameraUpVector", "CameraPosition", "CameraTarget", ...
                                                     "XLim", "YLim", "ZLim"}));
        end
    end
end
