% This script compares two confusion matrices. It loads the two CM*.mat files
% in the current folder, pads both matrices to the union of their classes,
% computes per-class accuracy (diagonal / row sum) and precision
% (diagonal / column sum) in percent, and exports horizontal bar charts of the
% scores and of their delta (secondary - primary).
%
% Requires cbrewer and export_fig (not part of base MATLAB) on the path.

clear       % 'clear all' would also flush cached functions/MEX files, which only slows MATLAB down
close all
clc

%%% Inputs %%%
    % Part of the file name that identifies the 'primary' confusion matrix.
    % The delta is computed w.r.t. the primary (delta = secondary - primary).
    primary_CM_name = 'CM_2012';
    
    % Names used in the plot legends for the primary and secondary matrices
    primary_name = 'Normal';
    secondary_name = 'Outdated';
    
%%% Retrieve the confusion matrix data %%%
    % Exactly two CM*.mat files (in the current folder) are compared
    CM_file_list = dir('CM*.mat');
    
    number_CM_files = numel(CM_file_list);
    if number_CM_files < 2
        error('Expected two CM*.mat files in the current folder, found %d.', number_CM_files);
    elseif number_CM_files > 2
        warning('Found %d CM*.mat files; only %s and %s are compared.', ...
            number_CM_files, CM_file_list(1).name, CM_file_list(2).name);
    end
    
    CM_data_cell = cell(1, 2);
    CM_labels_cell = cell(1, 2);
    CM_years_cell = cell(1, 2);
    
    % Index (1 or 2) of the primary matrix; stays empty if no file name matches
    primary_CM = [];
    
    for m = 1 : 2
        % Confusion matrix
        CM_file_name = CM_file_list(m).name;
        
        if contains(CM_file_name, primary_CM_name)
            if ~isempty(primary_CM)
                % Both names match: keep the previous behaviour (last match wins) but say so
                warning('Both CM files contain ''%s''; using %s as the primary.', ...
                    primary_CM_name, CM_file_name);
            end
            primary_CM = m;
        end
        
        CM_file = load(CM_file_name);
        
        % The confusion matrix object is stored under one of two variable names
        if isfield(CM_file, 'confusion_matrix_total')
            confusion_matrix = CM_file.confusion_matrix_total;
        elseif isfield(CM_file, 'confusion_matrix_aggregate')
            confusion_matrix = CM_file.confusion_matrix_aggregate;
        else
            error('%s contains neither ''confusion_matrix_total'' nor ''confusion_matrix_aggregate''.', ...
                CM_file_name);
        end
        
        % Year: all digits found in the file name, concatenated
        year_ind = isstrprop(CM_file_name, 'digit');
        year = CM_file_name(year_ind);
        CM_years_cell{m} = year;
        
        % Data and classes (assumes NormalizedValues is square and ordered like ClassLabels)
        CM_data = confusion_matrix.NormalizedValues;
        CM_labels = confusion_matrix.ClassLabels;
        
        % Append the data
        CM_data_cell{m} = CM_data;
        CM_labels_cell{m} = CM_labels;
    end
    
    if isempty(primary_CM)
        error('Neither %s nor %s contains ''%s''; cannot identify the primary confusion matrix.', ...
            CM_file_list(1).name, CM_file_list(2).name, primary_CM_name);
    end
    
%%% Create 'complete' matrices, in case classes are missing %%%
    % All classes: sorted union of both label sets. Each label vector is
    % flattened to a column first, so vectors of different lengths or
    % orientations can be combined (horzcat fails for column vectors of
    % different lengths).
    stacked_labels = cellfun(@(labels) labels(:), CM_labels_cell, 'UniformOutput', false);
    CM_labels_total = unique(vertcat(stacked_labels{:}));
    
    number_classes = numel(CM_labels_total);
    
    CM_matrix_cell = cell(1, 2);
    
    for m = 1 : 2
        % This confusion matrix' data
        CM_data = CM_data_cell{m};
        CM_labels = CM_labels_cell{m};
        
        % Position of each of this matrix' classes within CM_labels_total
        % (every label is found, since CM_labels_total is the union)
        [~, class_indices] = ismember(CM_labels, CM_labels_total);
        
        % Scatter the matrix into a number_classes x number_classes matrix;
        % classes missing from this matrix keep all-zero rows and columns
        CM_data_complete = zeros(number_classes);
        CM_data_complete(class_indices, class_indices) = CM_data;
        
        % Append the confusion matrix
        CM_matrix_cell{m} = CM_data_complete;
    end
    
%%% Determine the accuracy and precision of each matrix %%%
    % Per class k of each complete matrix, in percent:
    %   accuracy(k)  = M(k,k) / sum(M(k,:))   (row-normalised diagonal)
    %   precision(k) = M(k,k) / sum(M(:,k))   (column-normalised diagonal)
    % NOTE(review): if rows are the true classes, 'accuracy' is the per-class
    % recall / producer's accuracy - confirm against how the charts were built.
    % A class whose row and column maxima are both 0 scores 0, and so does
    % any ratio that evaluates to NaN (0/0).
    accuracy_cell = cell(1, 2);
    precision_cell = cell(1, 2);
    
    for m = 1 : 2
        CM_complete = CM_matrix_cell{m};
        class_accuracies = zeros(1, number_classes);
        class_precisions = zeros(1, number_classes);
        
        for k = 1 : number_classes
            row_values = CM_complete(k, :);
            column_values = CM_complete(:, k);
            
            % Only score classes that occur somewhere in this matrix
            if max(row_values) ~= 0 || max(column_values) ~= 0
                diagonal_value = CM_complete(k, k);
                
                score = diagonal_value / sum(row_values) * 100;
                if ~isnan(score)
                    class_accuracies(k) = score;
                end
                
                score = diagonal_value / sum(column_values) * 100;
                if ~isnan(score)
                    class_precisions(k) = score;
                end
            end
        end
        
        accuracy_cell{m} = class_accuracies;
        precision_cell{m} = class_precisions;
    end
    
%%% Difference between the two matrices %%%
    % primary_CM is 1 or 2, so the other matrix is the secondary one
    secondary_CM = 3 - primary_CM;
    
    % Unpack the per-class scores as (primary, secondary) pairs
    [accuracy_list_p, accuracy_list_s] = accuracy_cell{[primary_CM, secondary_CM]};
    [precision_list_p, precision_list_s] = precision_cell{[primary_CM, secondary_CM]};
    
    % A positive delta means the secondary matrix scores higher than the primary
    accuracy_difference = accuracy_list_s - accuracy_list_p;
    precision_difference = precision_list_s - precision_list_p;

%%% Create the histogram %%%
    % Colours: row 1 = primary, row 2 = secondary, row 3 = delta
    cmap = cbrewer('qual', 'Set1', 3);

    % Accuracy
    plot_delta_histogram(accuracy_list_p, accuracy_list_s, accuracy_difference, ...
        CM_labels_total, primary_name, secondary_name, cmap, ...
        'Accuracy [%]', 'Accuracy_Delta_Histogram');

    % Precision
    plot_delta_histogram(precision_list_p, precision_list_s, precision_difference, ...
        CM_labels_total, primary_name, secondary_name, cmap, ...
        'Precision [%]', 'Precision_Delta_Histogram');


%%% Local functions (must stay at the end of the script) %%%
function plot_delta_histogram(list_p, list_s, list_difference, class_labels, ...
        primary_name, secondary_name, cmap, x_label, file_stem)
% PLOT_DELTA_HISTOGRAM  Horizontal bar chart of per-class primary/secondary
% scores and their delta, exported as <file_stem>.png and <file_stem>.fig.
%
%   list_p, list_s   - 1 x N per-class scores [%] of the primary/secondary matrix
%   list_difference  - 1 x N per-class delta between the two (as computed by the caller)
%   class_labels     - N class labels, in the same order as the score lists
%   primary_name     - legend entry for the primary bars
%   secondary_name   - legend entry for the secondary bars
%   cmap             - colour map; rows 1..3 colour primary, secondary and delta
%   x_label          - x-axis label
%   file_stem        - output file name without extension

    number_classes = numel(list_p);
    positions = 1 : number_classes;

    % The primary bar is split in two so both bars stay visible:
    % list_p_1 keeps primary values >= secondary and is drawn under the
    % secondary bar; list_p_2 keeps primary values <= secondary and is
    % drawn on top of it.
    list_p_1 = list_p;
    list_p_1(list_p < list_s) = 0;
    list_p_2 = list_p;
    list_p_2(list_p > list_s) = 0;

    % Use a new figure rather than a fixed figure number, so an already-open
    % figure is never drawn into
    fig = figure;

    % Set the size and white background color
    set(fig, 'Units', 'Normalized', 'Position', [0 0 1.0 1.0]);
    set(fig, 'color', [1, 1, 1]);

    % Data are reversed so the first class ends up at the top of the chart
    hold on
    b_p_1 = barh(positions, fliplr(list_p_1), 'EdgeColor', 'k', 'LineWidth', 0.1, 'DisplayName', primary_name);
    b_s = barh(positions, fliplr(list_s), 'EdgeColor', 'k', 'LineWidth', 0.1, 'DisplayName', secondary_name);
    b_p_2 = barh(positions, fliplr(list_p_2), 'EdgeColor', 'k', 'LineWidth', 0.1, 'HandleVisibility', 'Off');
    b_d = barh(positions, fliplr(list_difference), 'EdgeColor', 'k', 'LineWidth', 0.1, 'DisplayName', '\Delta');

    legend('show')

    b_p_1.FaceColor = cmap(1, :);
    b_p_2.FaceColor = cmap(1, :);
    b_s.FaceColor = cmap(2, :);
    b_d.FaceColor = cmap(3, :);

    grid on
    % Lower x limit: most negative delta rounded down to a multiple of 10,
    % but never above 0 so bars starting at 0 are not clipped
    x_lb = min(0, 10 * floor(min(list_difference) / 10));

    xlim([x_lb, 100]);
    xticks(x_lb : 5 : 100);

    % Tick positions first, then labels; labels are flattened and reversed
    % to match the reversed data regardless of their orientation
    yticks(positions);
    yticklabels(flip(class_labels(:)));

    xlabel(x_label);
    ylabel('');

    set(gca, 'FontSize', 15);
    set(gca, 'LineWidth', 2);

    export_fig([file_stem '.png']);
    export_fig([file_stem '.fig']);

    close(fig)
end


