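% Analyses (large) confusion matrices: for every CM*.mat file in the current
% folder it computes the per-class accuracy and precision, lists the classes
% each class is most often confused with, and saves the results as .xls
% tables and an accuracy/precision bar chart.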

clear       % Plain 'clear' empties the workspace; 'clear all' would also flush compiled functions and slow the next run
close all
clc

%%% Inputs %%%
    confusion_limit = 5;   % Number of classes listed per class: those it is most often confused with

%%% Retrieve the confusion matrix data %%%
    % Each CM*.mat file in the current folder must contain a confusion-matrix
    % object (exposing NormalizedValues and ClassLabels) under one of these
    % variable names, tried in this order
    CM_variable_names = {'confusion_matrix_total', 'confusion_matrix_aggregate', 'confusion_matrix'};

    CM_file_list = dir('CM*.mat');
    number_CM = length(CM_file_list);
    
    if number_CM == 0
        warning('No CM*.mat files found in %s', pwd);
    end
    
    CM_data_cell = cell(1, number_CM);
    CM_labels_cell = cell(1, number_CM);
    CM_years_cell = cell(1, number_CM);
    
    for m = 1 : number_CM
        CM_file_name = CM_file_list(m).name;
        CM_file = load(CM_file_name);
        
        % Confusion matrix: the first known variable name present in the file.
        % An explicit lookup (instead of nested try/catch, which would hide any
        % error) lets a file without any of them report the expected names.
        variable_ind = find(isfield(CM_file, CM_variable_names), 1);
        if isempty(variable_ind)
            error('%s contains none of the variables: %s', CM_file_name, strjoin(CM_variable_names, ', '));
        end
        confusion_matrix = CM_file.(CM_variable_names{variable_ind});
        
        % Year: every digit of the file name, concatenated (e.g. 'CM_2018.mat'
        % gives '2018'; any other digits in the name would be appended too)
        CM_years_cell{m} = CM_file_name(isstrprop(CM_file_name, 'digit'));
        
        % Data and classes
        CM_data_cell{m} = confusion_matrix.NormalizedValues;
        CM_labels_cell{m} = confusion_matrix.ClassLabels;
    end
    
%%% Analyse the confusion matrices %%%
    % For every confusion matrix: per-class accuracy and precision, the classes
    % each class is most often confused with (one .xls per class and year),
    % the overall statistics (one .xls per year) and a bar chart per year.
    %
    % Accuracy of class c  = CM(c, c) / sum of row c    * 100
    % Precision of class c = CM(c, c) / sum of column c * 100
    % NOTE(review): rows are presumably the reference classes and columns the
    % predicted ones (MATLAB confusionchart convention) - confirm. The values
    % come from NormalizedValues, so they are pixel counts only if the saved
    % chart was not normalized.
    output_dir = 'Accuracy_Precision_Classes';
    water_class_name = 'Water bodies';      % Excluded from the "land" statistics

    % Delete the excel files of a previous run, if any
    if exist(output_dir, 'dir')
        delete(fullfile(output_dir, '*.xls'));
    end
    
    format bank     % Display the tables below with two decimals

    for m = 1 : number_CM
        % This confusion matrix' data. The labels are forced into a row so that
        % fliplr and horizontal concatenation with other labels work.
        CM_data = CM_data_cell{m};
        CM_labels = reshape(CM_labels_cell{m}, 1, []);
        year_str = CM_years_cell{m};
        
        number_classes = numel(CM_labels);
        
        % In case there are fewer other classes than the limit
        confusion_limit_CM = min(confusion_limit, number_classes - 1);
        
        accuracy_list = zeros(1, number_classes);
        precision_list = zeros(1, number_classes);
        number_pixels_list = zeros(1, number_classes);     % Diagonal value of each class
        
        for c = 1 : number_classes
            % This class' data
            class_name = CM_labels{c};
            
            class_pixels = CM_data(c, c);
            class_row = CM_data(c, :);
            class_column = CM_data(:, c).';     % Row vector, like class_row
            
            number_pixels_list(c) = class_pixels;
            
            % If this class wasn't in the domain, continuing is pointless
            if max(class_row) == 0 && max(class_column) == 0
                continue
            end
            
            % Accuracy and precision; an all-zero row/column gives 0/0 = NaN,
            % which is reported as 0
            row_total = sum(class_row);
            column_total = sum(class_column);
            
            class_accuracy = class_pixels / row_total * 100;
            class_precision = class_pixels / column_total * 100;
            if isnan(class_accuracy)
                class_accuracy = 0;
            end
            if isnan(class_precision)
                class_precision = 0;
            end
            
            accuracy_list(c) = class_accuracy;
            precision_list(c) = class_precision;
            
            %--% The most commonly confused classes %--%
            is_other = true(1, number_classes);
            is_other(c) = false;
            other_labels = CM_labels(is_other);
            
            % Inaccuracy: largest off-diagonal entries of this class' row
            [row_confusions, ind_row] = sort(class_row(is_other), 'descend');
            row_confusions = row_confusions(1 : confusion_limit_CM);
            row_confusions_relative = row_confusions / row_total * 100;
            row_confusions_relative(isnan(row_confusions_relative)) = 0;
            classes_row = other_labels(ind_row(1 : confusion_limit_CM));
            
            % Imprecision: largest off-diagonal entries of this class' column
            [column_confusions, ind_column] = sort(class_column(is_other), 'descend');
            column_confusions = column_confusions(1 : confusion_limit_CM);
            column_confusions_relative = column_confusions / column_total * 100;
            column_confusions_relative(isnan(column_confusions_relative)) = 0;
            classes_column = other_labels(ind_column(1 : confusion_limit_CM));
            
            %--% Display and save the results %--%
            table_name = strrep(sprintf('%s %s.xls', class_name, year_str), '/', '-');
            
            xlswrite(table_name, {'Inaccuracy'}, 'sheet1', 'A1');
            xlswrite(table_name, {'Imprecision'}, 'sheet1', 'A6');
            
            % Each table has one column for this class plus one per confused class
            inaccuracy_labels = table({'Pixels [-]'; '(In)accuracy [%]'}, 'VariableNames', {'Class [-]'});
            inaccuracy_values = array2table( ...
                [class_pixels, row_confusions; class_accuracy, row_confusions_relative], ...
                'VariableNames', [{upper(class_name)}, classes_row]);
            inaccuracy_table = [inaccuracy_labels, inaccuracy_values];
            
            disp('Inaccuracy')
            disp(inaccuracy_table)
            writetable(inaccuracy_table, table_name, 'FileType', 'spreadsheet', 'Range', 'A2');
            
            imprecision_labels = table({'Pixels [-]'; '(Im)precision [%]'}, 'VariableNames', {'Class [-]'});
            imprecision_values = array2table( ...
                [class_pixels, column_confusions; class_precision, column_confusions_relative], ...
                'VariableNames', [{upper(class_name)}, classes_column]);
            imprecision_table = [imprecision_labels, imprecision_values];
            
            disp('Imprecision')
            disp(imprecision_table)
            writetable(imprecision_table, table_name, 'FileType', 'spreadsheet', 'Range', 'A7');
        end
        
        % Overall accuracy (trace / grand total) and the precision averaged
        % with the diagonal values (number_pixels_list) as weights
        total_accuracy = sum(diag(CM_data)) / sum(CM_data(:)) * 100;
        average_precision = sum(precision_list .* number_pixels_list) / sum(number_pixels_list);
        
        % The same statistics with the water class left out
        is_land = ~strcmp(CM_labels, water_class_name);
        CM_data_l = CM_data(is_land, is_land);
        total_accuracy_land = sum(diag(CM_data_l)) / sum(CM_data_l(:)) * 100;
        average_precision_land = sum(precision_list(is_land) .* number_pixels_list(is_land)) ...
                                 / sum(number_pixels_list(is_land));
        
        % Accuracy and precision table
        performance_labels = table({'Nr. Pixels [-]'; 'Accuracy [%]'; 'Precision [%]'}, 'VariableNames', {' '});
        performance_values = array2table([number_pixels_list; accuracy_list; precision_list], ...
                                         'VariableNames', CM_labels);
        performance_table = [performance_labels, performance_values];
        
        performance_table_name = sprintf('%s_Accuracy_Precision.xls', year_str);
        writetable(performance_table, performance_table_name);
        
        xlswrite(performance_table_name, {'Over-all accuracy'}, 'sheet1', 'A6');
        xlswrite(performance_table_name, total_accuracy, 'sheet1', 'B6');
        xlswrite(performance_table_name, {'Average precision'}, 'sheet1', 'A7');
        xlswrite(performance_table_name, average_precision, 'sheet1', 'B7');
        
        xlswrite(performance_table_name, {'Over-all land accuracy'}, 'sheet1', 'A9');
        xlswrite(performance_table_name, total_accuracy_land, 'sheet1', 'B9');
        xlswrite(performance_table_name, {'Average land precision'}, 'sheet1', 'A10');
        xlswrite(performance_table_name, average_precision_land, 'sheet1', 'B10');
        
        % Move the .xls files (NOTE: every .xls in the current folder is moved)
        if ~exist(output_dir, 'dir')
            mkdir(output_dir);
        end
        movefile('./*.xls', ['./', output_dir]);
        
        % Accuracy and precision bar chart: 'Total' at the bottom, the classes
        % in reverse order above it, so they read top-down in their original order
        hist_matrix = [total_accuracy, fliplr(accuracy_list); ...
                       average_precision, fliplr(precision_list)];
        
        % Colours
        cmap = cbrewer('qual', 'Set1', 3);
        
        figure(1)
        set(gcf, 'Units', 'Normalized', 'Position', [0 0 1.0 1.0]);    % Full screen
        set(gcf, 'color', [1, 1, 1]);                                   % White background
        
        % Pass one row per bar group and one column per series explicitly:
        % the untransposed 2x(n+1) matrix relied on barh transposing it, which
        % is ambiguous (and groups the wrong way) when it is 2x2, i.e. one class
        b = barh(1 : number_classes + 1, hist_matrix.', 'EdgeColor', 'k', 'LineWidth', 0.1);
        b(1).FaceColor = cmap(1, :);
        b(2).FaceColor = cmap(2, :);
        b(1).BarWidth = 1.0;
        b(2).BarWidth = 1.0;
        legend({'Accuracy', 'Precision'});

        grid on
        xlim([0, 100]);
        xticks(0 : 5 : 100);
        
        yticks(1 : number_classes + 1);
        yticklabels([{'\bf Total'}, fliplr(CM_labels)]);

        xlabel('%');
        ylabel('');

        set(gca, 'FontSize', 15);
        set(gca, 'LineWidth', 2);
        
        % One chart per year (a fixed name would be overwritten by every year),
        % named like the year's performance table
        export_fig(sprintf('%s_Performance_Histogram.png', year_str));
        export_fig(sprintf('%s_Performance_Histogram.fig', year_str));
        
        close(1)
    end

    
    
    
    