function [datasets_train, datasets_test] = dataset_split(dataset, ...
                                                         classes, ...
                                                         node_count, ...
                                                         target_sample_count_per_node, ...
                                                         split_ratio, ...
                                                         split_by_type, ...
                                                         is_iid, ...
                                                         dirichlet_alpha)
    % DATASET_SPLIT  Splits a non-empty [dataset] into [node_count] different
    % datasets, each of which is further split into a training set and a test set.
    %
    % Returns two [node_count]-by-1 cell arrays of tables with columns `"inputs"` and
    % `"labels"`; element `n` of each holds the training and test samples of node `n`.
    % No sample is assigned to more than one node, nor to both the training and the
    % test set of a node.
    %
    % * The [dataset] must have a column `"inputs"` of type `cell` and a column
    %   `"labels"` of type `categorical`.
    % * If [split_by_type] is `true`, [dataset] must also have a column `"types"` of
    %   type `categorical`.
    % * Only rows whose label is one of [classes] are used. The `"labels"` column of
    %   every returned table has exactly [classes] as its categories, and the rows of
    %   every returned table are in random order.
    % * If [split_by_type] is `false` and [target_sample_count_per_node] is positive,
    %   its meaning depends on [is_iid]. If [is_iid] is `true`, this is the number of
    %   samples assigned to each node, exact up to rounding down. Otherwise, if
    %   [is_iid] is `false`, this is the mean number of samples per node; assigning
    %   exactly the same number of samples to each node is not possible with
    %   non-i.i.d. data. An error is raised if [dataset] is too small for the target.
    % * If [split_by_type] is `false` and [target_sample_count_per_node] is not
    %   positive, all of [dataset] is distributed over the nodes (up to rounding), so
    %   node sizes are proportional to the number of rows in [dataset].
    % * If [split_by_type] is `true`, [target_sample_count_per_node] does not limit
    %   node sizes; it only excludes types with fewer samples than it (empty types
    %   are always excluded). Each node receives all samples of one type, and no two
    %   nodes share a type.
    % * The number [target_sample_count_per_node] includes both training and test
    %   samples.
    % * The training-test split is [split_ratio], where the first number is the
    %   numerator and the second number the denominator. For example, a ratio of
    %   `[3, 4]` means that 75% of samples is for training (rounded down per class).
    % * For each node, the training set and test set are i.i.d. with respect to the
    %   `"labels"` column.
    % * Options [split_by_type] and [is_iid] must not both be `true`.
    % * If [is_iid] is `true`, for each class, every node receives the same number
    %   of samples of that class.
    % * If [is_iid] is `false`, for each class, a distribution over the nodes is
    %   determined based on a Dirichlet distribution with `k = [node_count]` and
    %   `alpha = [dirichlet_alpha]`. A low alpha (<0.1) creates an extremely
    %   imbalanced distribution, whereas a high alpha (>100) approaches an i.i.d.
    %   distribution.

    arguments% (Input)
        dataset (:, :) table {mustHaveColumns(dataset, ["inputs", "labels"]), mustBeNonempty};
        classes (:, :) categorical;
        node_count (1, 1) {mustBeInteger, mustBePositive};
        target_sample_count_per_node (1, 1) {mustBeInteger};
        split_ratio (1, 2) {mustBeInteger, mustBePositive};
        split_by_type (1, 1) logical;
        is_iid (1, 1) logical;
        dirichlet_alpha (1, 1) {mustBeFloat};
    end
    % arguments (Output)
    %     datasets_train (:, 1) cell;
    %     datasets_test (:, 1) cell;
    % end

    assert(iscell(dataset.inputs), "Dataset inputs must be of type 'cell'.");
    assert(iscategorical(dataset.labels), "Dataset labels must be of type 'categorical'.");
    assert(~(split_by_type && is_iid), "Cannot enforce i.i.d. data when splitting by type.");

    class_count = numel(classes);
    output_columns = ["inputs", "labels"];

    % Shuffle once, so that every contiguous slice taken below is a random sample
    dataset = dataset(randperm(height(dataset)), :);

    % Zero-row table with the output columns. It seeds every `vertcat` below so that
    % the result is a correctly-typed table even when there is nothing to concatenate
    empty_rows = dataset([], output_columns);

    datasets_train = cell(node_count, 1);
    datasets_test = cell(node_count, 1);

    if ~split_by_type
        % Determine how many samples each (node, class) pair receives
        dataset_by_class = arrayfun(@(it) dataset(dataset.labels == it, :), classes, UniformOutput = false);
        % Row vector, whatever the orientation of [classes]
        sample_count_by_class = reshape(cellfun(@height, dataset_by_class), 1, []);

        if is_iid
            sample_count_by_node_by_class = ...
                repmat(floor(sample_count_by_class / node_count), node_count, 1);
            sample_count_by_node = sum(sample_count_by_node_by_class, 2);

            if any(sample_count_by_node < target_sample_count_per_node)
                error("%d samples per node were requested, but cannot assign more than %d samples per node.", ...
                      target_sample_count_per_node, ...
                      min(sample_count_by_node));
            elseif target_sample_count_per_node > 0
                scaling_factors = sample_count_by_node ./ target_sample_count_per_node;
                sample_count_by_node_by_class = floor(sample_count_by_node_by_class ./ scaling_factors);
            end
        else
            % NOTE(review): assumes `dirichlet` returns a [class_count]-by-[node_count]
            % matrix whose columns are per-class shares over the nodes — TODO confirm
            sample_count_by_node_by_class = ...
                floor(dirichlet(node_count, dirichlet_alpha, class_count)' .* sample_count_by_class);
            sample_count_by_node = sum(sample_count_by_node_by_class, 2);

            if target_sample_count_per_node > 0
                mean_sample_count = mean(sample_count_by_node);
                % Scaling up beyond the available mean would assign samples that do not
                % exist (and a zero mean would divide by zero)
                if mean_sample_count < target_sample_count_per_node
                    error("%d samples per node were requested on average, but cannot assign more than %g samples per node on average.", ...
                          target_sample_count_per_node, ...
                          mean_sample_count);
                end
                scaling_factors = mean_sample_count ./ target_sample_count_per_node;
                sample_count_by_node_by_class = floor(sample_count_by_node_by_class ./ scaling_factors);
            end
        end

        if any(sum(sample_count_by_node_by_class, 1) > sample_count_by_class)
            error("Attempted to assign more samples than exist for a class. This is a bug.");
        end

        % Split counts into training and test
        train_sample_count_by_node_by_class = floor(sample_count_by_node_by_class * split_ratio(1) / split_ratio(2));
        test_sample_count_by_node_by_class = sample_count_by_node_by_class - train_sample_count_by_node_by_class;

        % Row `n` holds, per class, how many rows nodes 1..n-1 consumed in total
        % (training AND test), i.e. where node `n` starts reading. Offsetting by the
        % total keeps the nodes' samples disjoint.
        start_offset_by_node_by_class = [zeros(1, class_count); cumsum(sample_count_by_node_by_class, 1)];

        % Populate datasets
        for node = 1:node_count
            train_parts = cell(1, class_count);
            test_parts = cell(1, class_count);
            for class_idx = 1:class_count
                class_dataset = dataset_by_class{class_idx};
                start_offset = start_offset_by_node_by_class(node, class_idx);
                train_sample_size = train_sample_count_by_node_by_class(node, class_idx);
                test_sample_size = test_sample_count_by_node_by_class(node, class_idx);

                train_parts{class_idx} = class_dataset(start_offset + (1:train_sample_size), output_columns);
                test_parts{class_idx} = ...
                    class_dataset(start_offset + train_sample_size + (1:test_sample_size), output_columns);
            end

            datasets_train{node} = finalize_node_dataset(vertcat(empty_rows, train_parts{:}), classes);
            datasets_test{node} = finalize_node_dataset(vertcat(empty_rows, test_parts{:}), classes);
        end
    else
        % Validate
        assert(ismember("types", dataset.Properties.VariableNames), "Dataset misses required column 'types'.");
        assert(iscategorical(dataset.types), "Dataset types must be of type 'categorical'.");

        % Select a distinct type for each node
        types = string(categories(dataset.types));
        dataset_by_type = arrayfun(@(it) dataset(dataset.types == it, :), types, UniformOutput = false);
        sample_count_by_type = cellfun(@height, dataset_by_type);

        % Empty types are never eligible; with a positive target, small types are not either
        min_sample_count = max(target_sample_count_per_node, 1);
        eligible_types = find(sample_count_by_type >= min_sample_count);
        assert(numel(eligible_types) >= node_count, ...
               "Dataset does not have %d types with at least %d samples.", ...
               node_count, ...
               min_sample_count);

        % `randperm` rather than `randsample`: `randsample(v, k)` treats a scalar `v`
        % as the population `1:v`, which would pick a wrong type when only one is eligible
        datasets = dataset_by_type(eligible_types(randperm(numel(eligible_types), node_count)));

        % Split each node's data into training and test sets, per class
        for node = 1:node_count
            node_dataset = datasets{node};
            train_parts = cell(1, class_count);
            test_parts = cell(1, class_count);
            for class_idx = 1:class_count
                class_dataset = node_dataset(node_dataset.labels == classes(class_idx), output_columns);
                train_sample_size = floor(height(class_dataset) * split_ratio(1) / split_ratio(2));

                train_parts{class_idx} = class_dataset(1:train_sample_size, :);
                test_parts{class_idx} = class_dataset((train_sample_size + 1):end, :);
            end

            datasets_train{node} = finalize_node_dataset(vertcat(empty_rows, train_parts{:}), classes);
            datasets_test{node} = finalize_node_dataset(vertcat(empty_rows, test_parts{:}), classes);
        end
    end
end

function node_dataset = finalize_node_dataset(node_dataset, classes)
    % FINALIZE_NODE_DATASET  Shuffles the rows of [node_dataset] and sets the
    % categories of its `"labels"` column to exactly [classes].
    node_dataset = node_dataset(randperm(height(node_dataset)), :);
    node_dataset.labels = categorical(node_dataset.labels, classes);
end
