% This script creates a table figure for each confusion matrix, which might otherwise not be shown in full.

% Start from a clean state: remove workspace variables, close all figures and clear
% the command window. clearvars is used rather than `clear all`, which also flushes
% compiled functions and MEX links and slows MATLAB down for no benefit here.
clearvars
close all
clc

%%% Retrieve the confusion matrix data %%%
    % Every CM*.mat file in the current folder holds one confusion matrix, stored under
    % either 'confusion_matrix_total' or 'confusion_matrix_aggregate'. The object is read
    % through its NormalizedValues / ClassLabels properties (presumably a
    % ConfusionMatrixChart - confirm). The digits in the file name label its year.
    CM_file_list = dir('CM*.mat');
    number_CM = length(CM_file_list);
    
    if number_CM == 0
        warning('No CM*.mat files found in %s; no figures will be created.', pwd);
    end
    
    CM_data_cell = cell(1, number_CM);
    CM_labels_cell = cell(1, number_CM);
    CM_years_cell = cell(1, number_CM);
    
    accuracy_cell = cell(1, number_CM);
    precision_cell = cell(1, number_CM);
    
    for m = 1 : number_CM
        % Confusion matrix, under whichever of the two supported variable names is present
        CM_file_name = CM_file_list(m).name;
        CM_file = load(CM_file_name);
        
        if isfield(CM_file, 'confusion_matrix_total')
            confusion_matrix = CM_file.confusion_matrix_total;
        elseif isfield(CM_file, 'confusion_matrix_aggregate')
            confusion_matrix = CM_file.confusion_matrix_aggregate;
        else
            error('%s contains neither ''confusion_matrix_total'' nor ''confusion_matrix_aggregate''.', CM_file_name);
        end
        
        % Year: all digits of the file name, concatenated (e.g. 'CM_2019.mat' -> '2019')
        CM_years_cell{m} = CM_file_name(isstrprop(CM_file_name, 'digit'));
        
        % Data and class labels. Rows are treated as true classes and columns as
        % predicted classes by the accuracy/precision computation below.
        CM_data = confusion_matrix.NormalizedValues;
        CM_labels = confusion_matrix.ClassLabels;
        
        CM_data_cell{m} = CM_data;
        CM_labels_cell{m} = CM_labels;
        
        % Per-class accuracy (diagonal / row sum) and precision (diagonal / column sum), in percent
        correct_pixels = diag(CM_data);
        accuracy_list = (correct_pixels ./ sum(CM_data, 2))' * 100;
        precision_list = (correct_pixels ./ sum(CM_data, 1)')' * 100;
        
        % A row/column full of 0s gives 0/0 = NaN; report 0 for such classes instead
        accuracy_list(isnan(accuracy_list)) = 0;
        precision_list(isnan(precision_list)) = 0;
        
        % Append the performance
        accuracy_cell{m} = accuracy_list;
        precision_cell{m} = precision_list;
    end
    
%%% Retrieve classes information %%%
    % The first Grouped_Classes*.xls file in the current folder lists one group per row:
    % column 1 holds the group name, columns 2:end the class values belonging to it.
    % NaN entries (presumably empty cells) are ignored.
    grouped_classes_list = dir('Grouped_Classes*.xls');
    if isempty(grouped_classes_list)
        error('No Grouped_Classes*.xls file found in %s.', pwd);
    end
    grouped_classes_file = grouped_classes_list(1).name;

    % Read the file
    grouped_classes_data = readtable(grouped_classes_file);

    % Retrieve the grouped classes
    number_classes = height(grouped_classes_data);

    class_names = cell(1, number_classes);
    class_values = zeros(1, number_classes);

    for c = 1 : number_classes
        % The name of this group; char() accepts a cell-wrapped char vector or a string
        class_names{c} = char(grouped_classes_data{c, 1});

        % The values of the classes within this group
        group_class_values = grouped_classes_data{c, 2 : end};
        group_class_values(isnan(group_class_values)) = [];
        if isempty(group_class_values)
            error('Group ''%s'' in %s lists no class values.', class_names{c}, grouped_classes_file);
        end

        % The first value is used to denote this group from here on
        class_values(c) = group_class_values(1);
    end
        
%%% Confusion matrix figure %%%
    % One table figure per confusion matrix. Rows are true classes and columns are
    % predicted classes, both ordered by class value. An extra accuracy column and an
    % extra precision row are added. Each figure is saved as PNG (export_fig) and FIG.
    for m = 1 : number_CM
        % This confusion matrix' data
        CM_data = CM_data_cell{m};
        CM_labels = CM_labels_cell{m};
        year = CM_years_cell{m};
        
        accuracy_list = accuracy_cell{m};
        precision_list = precision_cell{m};
        
        % Look up each class label (by name) in the grouped class list; on duplicate
        % names the last match is used
        number_labels = length(CM_labels);
        
        CM_values = zeros(1, number_labels);
        CM_ind = zeros(1, number_labels);
        
        for l = 1 : number_labels
            ind = find(strcmp(class_names, CM_labels{l}), 1, 'last');
            if isempty(ind)
                error('Class label ''%s'' of confusion matrix %d (year ''%s'') is not in the grouped classes file.', CM_labels{l}, m, year);
            end
            CM_values(l) = class_values(ind);
            CM_ind(l) = ind;
        end
        
        % Order the classes by value. The matrix and the per-class metrics are permuted
        % the same way, so every label keeps describing its own row/column.
        [CM_values, order] = sort(CM_values);
        CM_ind = CM_ind(order);
        CM_data = CM_data(order, order);
        accuracy_list = accuracy_list(order);
        precision_list = precision_list(order);
        
        % Column labels: the class value; row labels: 'value - name'
        class_numbers = cellstr(num2str(CM_values'));
        
        class_numbers_names = cell(number_labels, 1);
        for c = 1 : number_labels
            class_numbers_names{c} = sprintf('%s - %s', class_numbers{c}, class_names{CM_ind(c)});
        end

        % Accuracy becomes an extra column and precision an extra row; the bottom-right
        % corner has no meaning and is shown blank
        CM_data = [CM_data, accuracy_list'];
        CM_data = [CM_data; [precision_list, NaN]];
        
        % Format every cell as text.
        % NOTE(review): '%3.2g' keeps two significant digits, so e.g. 100 is displayed as
        % '1e+02' and large counts in exponent form - confirm this is the intended display.
        CM_text = arrayfun(@(value) sprintf('%3.2g', value), CM_data, 'UniformOutput', false);
        CM_text{end} = '';
        
        x_labels = [class_numbers; 'Acc. [%]'];
        y_labels = [class_numbers_names; 'Prec. [%]'];
        
        % Column widths: equal class columns, a wider accuracy column
        CW = num2cell([33.5 * ones(1, number_labels), 50]);
        
        f = figure(1);
        
        % Full-screen size and white background
        set(f, 'Units', 'Normalized', 'Position', [0 0 1 1], 'Color', [1, 1, 1])
        
        % Invisible axes spanning the figure, used to place the true/predicted labels
        a = axes('Parent', f, 'Units', 'Normalized', 'Position', [0, 0, 1, 1], 'Visible', 'off', 'XLim', [0, 1], 'YLim', [0, 1], 'NextPlot', 'add');
        
        % Table, then shrunk/grown to exactly fit its content
        uit = uitable('Parent', f, 'Data', CM_text, 'ColumnName', x_labels', 'RowName', y_labels, 'Units', 'Normalized', 'Position', [0.03, 0, 0, 0], 'FontSize', 5.5, 'ColumnWidth', CW);
        uit.Position(3:4) = uit.Extent(3:4);

        % Resize the row header, presumably so the long 'value - name' row labels fit.
        % This reaches into the underlying Java component via findjobj. Child 4 being the
        % row-header viewport is an assumption about the MATLAB/Java version - confirm.
        row_header_viewport_width = 200;
        uit_j = findjobj(uit);
        row_header_viewport = uit_j.getComponent(4);
        row_header = row_header_viewport.getComponent(0);
        try
            row_header_viewport.setPreferredSize(java.awt.Dimension(row_header_viewport_width, 0));
        catch
            % Best effort: if the preferred size cannot be set, the header is still resized below
        end
        % java.awt.Component.setSize takes (width, height)
        row_header.setSize(50, row_header_viewport_width);
        
        % 'True class', rotated, halfway between the figure's left edge and the table,
        % vertically centred
        t = text(0, 0, 'True class', 'Parent', a, 'Rotation', 90, 'FontSize', 20, 'FontName', 'Arial');
        t.Position(1) = uit.Position(1) / 2;
        t.Position(2) = 0.5 - t.Extent(4) / 2;
        
        % 'Predicted class', horizontally centred above the table. The 16/9 factor converts
        % the horizontal margin into a vertical gap (presumably assuming a 16:9 screen).
        p = text(0, 0, 'Predicted class', 'Parent', a, 'FontSize', 20, 'FontName', 'Arial');
        p.Position(1) = 0.5 - p.Extent(3) / 2;
        p.Position(2) = uit.Position(4) + uit.Position(1) / 2 / 9 * 16;
        
        % Save the figure. With several confusion matrices the year (or the index when
        % the file name had no digits) is appended, so later figures do not overwrite
        % earlier ones.
        if number_CM > 1
            if isempty(year)
                output_name = sprintf('CM_Table_%d', m);
            else
                output_name = sprintf('CM_Table_%s', year);
            end
        else
            output_name = 'CM_Table';
        end
        export_fig([output_name, '.png']);
        savefig(f, [output_name, '.fig']);
        
        close(f);
    end
    