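% This script aggregates confusion matrices: the classes that were used to classify the data are merged into a new, smaller set of class groups
% Note that this requires exactly two Grouped_Classes*.xls files in the current folder: one with the class groups used for classification, and one (with fewer classes) with the new aggregation groups
% The confusion matrices themselves are read from the CM*.mat files in the current folder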

% Start from a clean workspace, with no open figures and an empty command window
clear('all'); close('all'); clc;

%%% Retrieve the grouped classes %%%
    % Each Grouped_Classes*.xls file is read with readtable; per row, the first
    % column is taken as the group name and the remaining columns as the numeric
    % class values of that group (NaN entries, i.e. empty cells, are dropped)
    grouped_classes_list = dir('Grouped_Classes*.xls');
    
    % Exactly two files are required: the groups used for classification and
    % the groups the confusion matrices are aggregated into
    number_files = numel(grouped_classes_list);
    if number_files ~= 2
        error('Expected exactly 2 Grouped_Classes*.xls files, but found %d.', number_files);
    end
    
    % The original as well as the aggregated class groups are retrieved
    number_classes_list = zeros(1, 2);
    grouped_class_values_cell = cell(1, 2);
    class_names_cell = cell(1, 2);
    class_values_cell = cell(1, 2);
    
    for g = 1 : 2
        grouped_classes_file = grouped_classes_list(g).name;

        % Read the file
        grouped_classes_data = readtable(grouped_classes_file);

        % One group per table row
        number_classes = size(grouped_classes_data, 1);

        grouped_class_values = cell(1, number_classes);
        class_names = cell(1, number_classes);
        class_values = zeros(1, number_classes);

        for c = 1:number_classes
            % The name of this group of classes
            group_name = grouped_classes_data{c, 1};
            class_names{c} = group_name{1};

            % The values of the classes within this group
            group_class_values = grouped_classes_data{c, 2 : end};
            group_class_values(isnan(group_class_values)) = [];

            % A group without any values cannot be given a representative value
            if isempty(group_class_values)
                error('Group "%s" in %s does not contain any class values.', class_names{c}, grouped_classes_file);
            end

            grouped_class_values{c} = group_class_values;

            % The first value, used to denote this group in the future
            class_values(c) = group_class_values(1);
        end
        
        % The missing class (represented by the value 999) is removed, if present
        missing_ind = (class_values == 999);
        class_values(missing_ind) = [];
        class_names(missing_ind) = [];
        grouped_class_values(missing_ind) = [];

        % Count what is left instead of assuming exactly one class was removed;
        % a file without a 999 class would otherwise lose its last group later on
        number_classes = numel(class_values);
        
        % The variables are appended
        number_classes_list(g) = number_classes;
        grouped_class_values_cell{g} = grouped_class_values;
        class_names_cell{g} = class_names;
        class_values_cell{g} = class_values;
    end
    
%%% The aggregation preparation %%%
    % To determine which groups were used for classification, and which for aggregation,
    % the number of classes are compared: the file with fewer classes holds the aggregation groups
    [~, order] = sort(number_classes_list);
    
    % sort is stable, so with equal class counts the first file returned by dir() becomes the aggregation file
    if number_classes_list(1) == number_classes_list(2)
        warning('Both Grouped_Classes files contain %d classes; the aggregation file cannot be identified unambiguously.', number_classes_list(1));
    end
    
    aggregation_ind = order(1);     % The lowest number of classes is the aggregation file
    classification_ind = order(2);
    
    % The variables belonging to classification/aggregation
    number_classes_classification = number_classes_list(classification_ind);
    number_classes_aggregation = number_classes_list(aggregation_ind);
    
    class_values_classification = class_values_cell{classification_ind};
    grouped_classes_aggregation = grouped_class_values_cell{aggregation_ind};
    
    class_names_classification = class_names_cell{classification_ind};
    class_names_aggregation = class_names_cell{aggregation_ind};
    
    % It is determined which aggregation group each classification class falls into,
    % identified by the class's representative (first) value; 0 means "no group"
    class_indices = zeros(1, number_classes_classification);
    
    for c = 1:number_classes_classification
        class_value = class_values_classification(c);
        
        for g = 1 : number_classes_aggregation
            if ismember(class_value, grouped_classes_aggregation{g})
                class_indices(c) = g;
                break   % stop at the first group that contains this class
            end
        end
    end
    
    % Classes without a group keep index 0 and therefore end up in no cell of the aggregated matrix
    unassigned = (class_indices == 0);
    if any(unassigned)
        warning('These classes are not part of any aggregation group and will be ignored: %s', ...
            strjoin(class_names_classification(unassigned), ', '));
    end
    
%%% Aggregate the confusion matrices %%%
    % Every CM*.mat file in the current folder is processed; each is expected to hold
    % a variable confusion_matrix_total exposing NormalizedValues and ClassLabels
    CM_file_list = dir('CM*.mat');
    
    % Skip the outputs of a previous run of this script (CM*_Aggregate.mat): they hold
    % confusion_matrix_aggregate rather than confusion_matrix_total and would fail to load
    is_previous_output = endsWith({CM_file_list.name}, '_Aggregate.mat');
    CM_file_list = CM_file_list(~is_previous_output);
    number_confusion_matrices = length(CM_file_list);
    
    for m = 1:number_confusion_matrices
        % Load the confusion matrix data
        CM_file_name = CM_file_list(m).name;
        CM_file = load(CM_file_name);
        CM_matrix_data = CM_file.confusion_matrix_total;

        CM_matrix = CM_matrix_data.NormalizedValues;
        % Normalize the labels to a 1-by-N cell array of character vectors so that
        % cellstr, string and categorical labels are all handled the same way
        CM_labels = reshape(cellstr(string(CM_matrix_data.ClassLabels)), 1, []);
        
        % Not all classes necessarily exist in the CM: map each CM label to the
        % aggregation group of the classification class with the same name
        % (0 when the label is unknown or its class belongs to no aggregation group)
        [is_known, class_position] = ismember(CM_labels, class_names_classification);
        CM_classes_indices = zeros(1, numel(CM_labels));
        CM_classes_indices(is_known) = class_indices(class_position(is_known));
        
        if any(~is_known)
            warning('%s: labels not found in the classification groups are ignored: %s', ...
                CM_file_name, strjoin(CM_labels(~is_known), ', '));
        end
        
        % The confusion matrix is created: entry (i, j) is the sum of all CM entries whose
        % row class belongs to aggregation group i and whose column class belongs to group j
        CM_matrix_aggregation = zeros(number_classes_aggregation);
        
        for i = 1:number_classes_aggregation
            rows = (CM_classes_indices == i);
            
            for j = 1:number_classes_aggregation
                columns = (CM_classes_indices == j);
                
                selected_values = CM_matrix(rows, columns);
                CM_matrix_aggregation(i, j) = sum(selected_values(:));
            end
        end
        
        % Create the new confusion matrix
        figure(1)
        % Set the size and white background color
        set(gcf, 'Units', 'Normalized', 'Position', [0 0 1 1]);
        set(gcf, 'color', [1, 1, 1]);
        
        confusion_matrix_aggregate = confusionchart(CM_matrix_aggregation, class_names_aggregation);
        sortClasses(confusion_matrix_aggregate, class_names_aggregation);
        confusion_matrix_aggregate.RowSummary = 'row-normalized';
        confusion_matrix_aggregate.ColumnSummary = 'column-normalized'; 
        confusion_matrix_aggregate.Title = ''; 

        set(gca, 'FontSize', 15);
        
        % Save the merged confusion matrix as <original name>_Aggregate; fileparts only
        % strips the extension, whereas erase would remove '.mat' anywhere in the name
        [~, CM_name] = fileparts(CM_file_name);
        CM_name = [CM_name, '_Aggregate'];
        
        % export_fig is tried first; if it fails (e.g. it is not on the path), the figure
        % frame is captured and written as an indexed PNG instead
        try
            export_fig([CM_name, '.png'])
        catch
            frame = getframe(1);
            im = frame2im(frame);
            [imind, color_map] = rgb2ind(im, 256);
            imwrite(imind, color_map, [CM_name, '.png']);
        end
        
        save([CM_name, '.mat'], 'confusion_matrix_aggregate');
        
        close(1)
    end




    