%{
Automated translational alignment and axis measurement of ellipsoid objects
Requires DIPImage (https://diplib.org/), Parallel Computing Toolbox, Signal Processing Toolbox


Copyright Enya Berrevoets, TU Delft, 2024
Licensed under the Apache License, Version 2.0
%}

%% Environment setup
% Uncomment to start from a clean workspace. Stack_save/Stack can be kept,
% since the main loop can reuse Stack_save when loadStack is 0:
% clearvars -except Stack_save Stack
% clear
% close all

dipstart                        % DIPimage start-up command
addpath(genpath(pwd))           % this folder and every subfolder go on the path

% Project dependency check: any non-empty message stops the script here
message = PSP_CodeCheck();
if ~isempty(message), return, end

%% User input

projDir = 'Demo';        % Directory with the raw images
locDir =  'Demo';        % Directory to store rotated images

headDir = 'Demo_data';
pcaCol = 'Sum';        % Perform PCA on 'Orange'/'Red'/'Both'/'Sum'
Colours = ["Sum","Orange", "Red"];  % Channel colours to align. Put pcaCol first.
Nsegs = 10;                     % Number of bins
binmet = 'AbCov';       % Binning: 'FixN' = equal number of particles per bin, 'AbCov' = bins bounded by the sort-metric edges from PSP_VarSort
sellocs = [1];          % Probe positions
sortmet = "Length";     % Method to sort over bins ['Var'/'Length'/'Vol']
fitcol = ["Sum"]; % Channel colours to measure PSP size over

sigma_psf = 0;
px = 30;                % pixel size [nm]
shiftCol = pcaCol;      % Images to calculate translation from ['Orange'/'Red'/'Both'/'Sum']

savedata = 1;           % Save registration/fit results as .mat files
saveims = 1;            % Save (and close) the slice figures
rerun = 1;              % Redo the registration (0: load results of a previous run)
refit = 1;              % Refit the radii of the averaged images
fitraws = 1;            % Also fit the individual (unaveraged) images
loadStack = 1;          % Load the aligned stack from disk (0: use Stack_save from the workspace)

sliceplt = 0;           % Plot slices through the averaged images

% The main loop only runs if there is something to compute or plot
loopLS = any([rerun, refit, sliceplt]);

Nits = 2;       % Number of iterations for translational alignment
init = 'Sum';   % Initial registration ["S1", "Sum", "Empty"]
dxPeak = 1;

Npar = 2;       % Number of chunks the segments are split into (parfor when > 1)
% More chunks than segments would produce empty chunks (no StackSeg/maxps for
% them), so keep Npar within [1, Nsegs].
if Npar > Nsegs
    warning('Npar (%d) exceeds Nsegs (%d); using Npar = Nsegs.', Npar, Nsegs);
end
Npar = max(1, min(Npar, Nsegs));
% Chunk ps covers segments parSegs(ps):parSegs(ps+1)-1 (Npar = 1 gives [1, Nsegs+1])
parSegs = round(linspace(1, Nsegs+1, Npar+1));

RF_method = "50/50";      % Method to fit radius ["dropoff","50/50", "meanR"]
check2 = 1;               % Check for double paraspeckles
scaling_cutoff = 1.5;     % Ellipticity cut-off for plotting
Vthr = 0.5;               % Threshold to include in volume estimate, normalised to max image intensity

%% Load data
% Loads the per-probe variables from the PCA step and, when present, the
% results of a previous registration/fit run.

loaddir = [locDir  '\' headDir '\PCA_' pcaCol];

% File name (without .mat) -> variable stored in that file. Missing files are
% skipped; cellVarSum is optional (checked in the main loop).
fileNames = {'Orange_cellvars', 'Red_cellvars', 'Sum_cellvars', 'Orange_cellints', ...
    'Red_cellints', 'cellRvol', 'cellOutbox', 'cellVol'};
cellNames = {'cellVarA', 'cellVarB', 'cellVarSum', 'cellIntA', ...
    'cellIntB', 'cellRvol', 'cellOutbox', 'cellVol'};

for i = 1:numel(fileNames)
    filePath = fullfile(loaddir, [fileNames{i} '.mat' ]);
    if exist(filePath, 'file')
        loadedData = load(filePath);
        % Names come from the fixed list above, so eval only builds known assignments
        eval([cellNames{i} ' = loadedData.' cellNames{i} ';']);
    end
end

[locs, Nlocs, leg] = PSP_getLocs([projDir '\' headDir]);

% Common prefix of the result files written by the "Save data" section
varpath = [locDir '\' headDir '\ParticleRegistration\pca' pcaCol '\' num2str(Nsegs) 'segs\Sort' char(sortmet) '_'  binmet];

% Defaults, replaced by previous results where available. exist(...,'file')
% only finds .mat files when the extension is included. Each file is checked
% on its own because Allrawrad/DoubleOutraw are only saved when fitraws is set.
Allfitrad = zeros([Nlocs, Nsegs,length(fitcol),3]);
Allrawrad = cell(10,1);
DoubleOut = cell(2,1);
DoubleOutraw = cell(2,1);

if exist([varpath 'Allfitrad.mat'],'file')
    Allfitrad = load([varpath 'Allfitrad.mat']).Allfitrad;
end
if exist([varpath 'Allrawrad.mat'],'file')
    Allrawrad = load([varpath 'Allrawrad.mat']).Allrawrad;
end
if exist([varpath 'DoubleOut.mat'],'file')
    DoubleOut = load([varpath 'DoubleOut.mat']).DoubleOut;
end
if exist([varpath 'DoubleOutraw.mat'],'file')
    DoubleOutraw = load([varpath 'DoubleOutraw.mat']).DoubleOutraw;
end

% Double-PSP flags and "outside box" flags, per probe (averaged / individual images)
dPSP = DoubleOut{1};
outbox = DoubleOut{2};
dPSPraw = DoubleOutraw{1};
outraw = DoubleOutraw{2};

%% Registration, volume estimation and radius fitting per probe and colour

% Raw images of all fitted segments, concatenated along dimension 4
StackSort = [];

if loopLS

    for ll = sellocs

        % Sort the particles of probe position ll and get the bin edges
        if exist('cellVarSum','var')
            [varssort, varsegs, segs, idx_vars] = PSP_VarSort(sortmet, Nsegs, ...
                cellVarA{ll}, cellVarB{ll}, cellVarSum{ll}, cellRvol{ll}, cellVol{ll});
        else
            [varssort, varsegs, segs, idx_vars] = PSP_VarSort(sortmet, Nsegs, ...
                cellVarA{ll}, cellVarB{ll}, [], cellRvol{ll}, cellVol{ll});
        end

        IntA = cellIntA{ll};

        if any(strcmp(Colours, "Red"))
            IntB = cellIntB{ll};
        end

        % Per-chunk shift data; passed into and returned from fRegPar for every
        % colour, so later colours receive what earlier colours stored
        shiftSave = cell(Npar,1);

        %%

        for icol = 1:length(Colours)

            col = char(Colours(icol));

            % Per-particle intensities in sorted order ("Sum" uses IntA)
            if strcmp(Colours(icol), "Orange")
                AllInt = IntA(idx_vars);
            elseif strcmp(Colours(icol), "Red")
                AllInt = IntB(idx_vars);
            else
                AllInt = IntA(idx_vars);
            end

            AllVol = cellVol{ll}(idx_vars);

            if rerun
                if loadStack
                    Stack = struct2cell(load([projDir '\' headDir '\' locs(ll).name '\PCA_' pcaCol '\' col '_aligned.mat']));
                    Stack = Stack{1}(:,:,:,idx_vars);
                elseif exist('Stack_save','var')
                    Stack = Stack_save{1}(:,:,:,idx_vars);
                end

                sz = size(Stack);

                Nims = sz(end);

                if ~exist([locDir '\' headDir '\' locs(ll).name],'dir')      % Make folder to save reference segmentations in
                    mkdir([locDir '\' headDir '\' locs(ll).name]);
                end
                save([locDir '\' headDir '\' locs(ll).name  '\Stacksize.mat',''], 'sz');
            else
                sz = load([locDir '\' headDir '\' locs(ll).name '\Stacksize.mat']).sz;
            end

            segs = linspace(0,sz(end),Nsegs+1);

            % Split segments into Npar chunks LS{ps}; segP{ps}{k} holds the
            % (sorted-order) particle indices of the k-th segment of chunk ps.
            % maxps is the chunk containing the most particles.
            maxL = 0;
            for ps = 1:Npar

                LS{ps} = parSegs(ps):(parSegs(ps+1)-1);
                Lps = 0;
                for ss = LS{ps}

                    if strcmp(binmet,'FixN')
                        segP{ps}{ss - min(LS{ps})+1} = floor(segs(ss)+1):floor(segs(ss+1));
                    elseif strcmp(binmet, 'AbCov')
                        segP{ps}{ss - min(LS{ps})+1}= find(varssort>=varsegs(ss)&varssort<varsegs(ss+1))';
                    end

                    Lps = Lps+length(segP{ps}{ss - min(LS{ps})+1});

                    if rerun
                        StackSeg{ps}{ss - min(LS{ps})+1} = Stack(:,:,:,segP{ps}{ss - min(LS{ps})+1});
                    end

                end
                if Lps>maxL
                    maxL = Lps;
                    maxps = ps;
                end

            end

            idss = 0;

            %% Particle registration

            if rerun
                RegAP = cell(Npar,1);
                RawP = cell(Npar,1);
                RegInitP = cell(Npar,1);

                if Npar>1

                    parfor ps = 1:Npar

                        % Put DIPimage (hard-coded install location) on the worker's path
                        pp = 'C:\Program Files\DIPimage 2.9\common\dipimage';
                        addpath(pp);
                        dip_initialise;

                        [RegInitP{ps}, RegAP{ps}, RawP{ps}, shiftSave{ps}, StackSaveP{ps}] = fRegPar(sz, ll, Nlocs, col,  Nsegs, ps, StackSeg{ps}, parSegs, init, Nits, shiftCol, dxPeak, shiftSave{ps}, maxps);

                    end
                else
                    ps = 1;

                    [RegInitP{ps}, RegAP{ps}, RawP{ps}, shiftSave{ps}, StackSaveP{ps}] = fRegPar(sz, ll, Nlocs, col,  Nsegs, ps, StackSeg{ps}, parSegs, init, Nits, shiftCol, dxPeak, shiftSave{ps}, maxps);

                end
            end

            msg = fprintf('Probe %d/%d, %s, Segment %d/%d', ll, Nlocs, col, 0, Nsegs);

            for ps = 1:Npar

                for ss = LS{ps}

                    idss = idss+1;

                    istart =  min(LS{ps})-1;
                    seg = segP{ps}{ss - istart};
                    % Overwrite the previous progress message in place
                    fprintf(repmat('\b',1,msg))
                    msg = fprintf('Probe %d/%d, %s, Segment %d/%d', ll, Nlocs, col, ss, Nsegs);

                    if isempty(seg)
                        continue
                    end

                    savedir = [locDir '\' headDir '\ParticleRegistration\pca' pcaCol '\' num2str(Nsegs) 'segs\' locs(ll).name '\' col];
                    if ~exist(savedir,'dir')      % Make folder to save reference segmentations in
                        mkdir(savedir);
                    end

                    % Note: savename already carries the .mat extension
                    savename = [ binmet 'Init' char(init) '_Im' num2str(seg(1)) 'to' num2str(seg(end)) '_' num2str(Nits) 'its.mat'];

                    Nims = length(seg);

                    if rerun
                        RegA = RegAP{ps}(:,:,:,ss-istart);
                        RegInit = RegInitP{ps}(:,:,:,ss-istart);
                        StackSave = StackSaveP{ps}{ss-istart};
                        Raw = RawP{ps}{ss-istart};

                        if savedata

                            save([savedir '\' col '_RegInit_' savename,''], 'RegInit','-v7');
                            save([savedir '\' col '_Reg_' savename,''], 'RegA','-v7');
                            save([savedir '\' col '_RawParameters_' savename,''], 'Raw','-v7');
                            save([savedir '\' col '_RawStack_' savename,''], 'StackSave','-v7');

                        end

                    else
                        % Reuse the registration results of a previous run
                        RegInit = load([savedir '\' col '_RegInit_' savename]).RegInit;
                        RegA = load([savedir '\' col '_Reg_' savename]).RegA;
                        Raw = load([savedir '\' col '_RawParameters_' savename]).Raw;
                    end


                    SegInts(ll,ss,icol) = mean(AllInt(seg));
                    IntErr(ll,ss,icol) = std(AllInt(seg));

                    % Volume estimates: (1) voxels above Vthr*max, (2) convex hull of
                    % the Vthr*max isosurface. Both use the same threshold relative
                    % to the image maximum, as Vthr is defined.
                    thr = max(RegA,[],'all')*Vthr;
                    SegVol(ll,ss,icol,1) = nnz(RegA>=thr)*px^3;

                    vertices = isosurface(RegA,thr).vertices;
                    vertices = vertices - sz(1:3)/2;

                    % Shrink each vertex radially by sigma_psf*sqrt(2*ln 2)/px
                    % (PSF half width at half maximum, in pixels if sigma_psf is in nm)
                    [az,pol,RR] = cart2sph(vertices(:,1), vertices(:,2), vertices(:,3));
                    RRdecon = RR - (sigma_psf*sqrt(2*log(2))/px);
                    [vertices(:,1), vertices(:,2), vertices(:,3)] = sph2cart(az, pol, RRdecon);

                    [~,V_ch] = convhull(vertices);
                    SegVol(ll,ss,icol,2) = V_ch*px^3;

                    if refit && any(strcmp(fitcol, col))

                        RegA = CentreIm(RegA);
                        RegA = RegA./max(RegA,[],'all');

                        [rotmat, CoM, eigvals] = PSP_Align(RegA);

                        if col=="Orange"
                            VarAReg(ll,ss,:) = sqrt(eigvals);
                        elseif col=="Red"
                            VarBReg(ll,ss,:) = sqrt(eigvals);
                        end

                        %% Fit averaged radii

                        % The PCA colour determines dPSP, outbox and the axis
                        % scaling; the other colours are fitted with those values
                        if strcmp(col, pcaCol)
                            [fitrad, scalingReg(ss,:), raddist, dPSP{ll}(ss), Lvol(ss,:), outbox{ll}(ss)] = FitRad(RegA, "SameMinor", px, 1, sigma_psf, RF_method,check2, 1, 0);
                        else
                            [fitrad, ~, raddist, dPSP{ll}(ss), ~, outbox{ll}(ss)] = FitRad(RegA, "SameMinor", px, 1, sigma_psf, RF_method,0,dPSP{ll}(ss), outbox{ll}(ss), scalingReg(ss,:));
                        end

                        % Clamp dPSP to [1, 2] (presumably 1 = single, 2 = double paraspeckle)
                        if dPSP{ll}(ss)<1
                            dPSP{ll}(ss)=1;
                        elseif dPSP{ll}(ss)>2
                            dPSP{ll}(ss)=2;
                        end
                        Allfitrad(ll,ss,icol,:) = fitrad;

                        %% Fit individual images
                        if fitraws
                            StackSave = load([savedir '\' col '_RawStack_' savename]).StackSave;

                            % cat skips the initial [], so no empty leading frame is
                            % created (indexing with end+1 on [] starts at frame 2)
                            StackSort = cat(4, StackSort, StackSave);

                            for cc = 1:size(StackSave,4)

                                if strcmp(col, pcaCol)
                                    [rawfitrad, rawscaling{ss}(cc,:), ~, ~, ~, outraw{ll}(min(seg)+cc-1)] = ...
                                        FitRad(StackSave(:,:,:,cc), "SameMinor", px, 1, sigma_psf, RF_method, 0, 0);
                                else
                                    [rawfitrad, ~, ~, ~] = ...
                                        FitRad(StackSave(:,:,:,cc), "SameMinor", px, 1, sigma_psf, RF_method, 0, dPSP{ll}(ss), outraw{ll}(min(seg)+cc-1), rawscaling{ss}(cc,:), 0.95);
                                end

                                if iscell(Raw)
                                    Raw{ss}.FitRad(cc,:) = rawfitrad;
                                else
                                    Raw.FitRad(cc,:) = rawfitrad;
                                end

                                Allrawrad{ll}(min(seg)+cc-1,:,icol) = rawfitrad;
                                dPSPraw{ll}(min(seg)+cc-1) = dPSP{ll}(ss);

                                if strcmp(col, pcaCol)
                                    Im = im2mat(erosion(dilation(StackSave(:,:,:,cc),10),10));
                                    ImVol{ll}(min(seg)+cc-1) = nnz(Im>=max(Im,[],'all')*Vthr)*px^3;
                                end

                            end
                        end

                        % Write the refitted parameters back to the same file that
                        % is saved above and loaded when rerun is 0 (savename
                        % already ends in '.mat')
                        if savedata && fitraws
                            save([savedir '\' col '_RawParameters_' savename], 'Raw','-v7');
                        end

                        legS{ss} = [num2str(seg(1)) ':' num2str(seg(end))];

                    else
                        % Not refitted: axis scaling from the stored radii of the
                        % first colour. reshape (not squeeze) keeps this a 1x3 row
                        % even when Nsegs == 1.
                        AFR = reshape(Allfitrad(ll,ss,1,:), 1, []);
                        scalingReg(ss,:) = AFR./AFR(3);
                    end

                    %% Plot averaged images

                    if sliceplt

                        % Show registrations

                        theta = 0:0.01:2*pi;

                        x_1 = Allfitrad(ll,ss,icol,3)/px*sin(theta)+ceil(sz(1)/2);
                        y_1 = Allfitrad(ll,ss,icol,2)/px*cos(theta)+ceil(sz(1)/2);

                        x_2 = Allfitrad(ll,ss,icol,3)/px*sin(theta)+ceil(sz(1)/2);
                        y_2 = Allfitrad(ll,ss,icol,1)/px*cos(theta)+ceil(sz(1)/2);

                        x_3 = Allfitrad(ll,ss,icol,2)/px*cos(theta)+ceil(sz(1)/2);
                        y_3 = Allfitrad(ll,ss,icol,1)/px*sin(theta)+ceil(sz(1)/2);

                        figure()

                        % PC 2 and 3
                        subplot(1,3,1)
                        imsl = squeeze(sum(RegA(ceil(end/2)-1:ceil(end/2)+1,:,:),1));
                        imsl = imsl./max(imsl,[],'all');
                        imshow(imsl)
                        title([num2str(seg(1)) ':' num2str(seg(end)) ', ZX'])
                        if strcmp(binmet, 'AbCov')
                            tt = get(gca, 'Title').String;
                            title(sprintf([num2str(tt) '\n \\sigma=' num2str(round(varsegs(ss),1)) ':' num2str(round(varsegs(ss+1),1)) '%g\t']))
                        end
                        hold on

                        % PC 3 and 1
                        subplot(1,3,2)
                        imsl = squeeze(sum(RegA(:,ceil(end/2)-1:ceil(end/2)+1,:),2));
                        imsl = imsl./max(imsl,[],'all');
                        imshow(imsl)
                        title([num2str(seg(1)) ':' num2str(seg(end)) ', ZY'])
                        hold on

                        % PC 2 and 1
                        subplot(1,3,3)
                        imsl = squeeze(sum(RegA(:,:,ceil(end/2)-1:ceil(end/2)+1),3));
                        imsl = imsl./max(imsl,[],'all');
                        imshow(imsl)
                        title([num2str(seg(1)) ':' num2str(seg(end)) ', XY'])
                        hold on

                        if exist("outbox",'var') && outbox{ll}(ss)
                            sgtitle(['PSP outside box'])
                        end

                        if saveims
                            saveas(gcf,[savedir '\' binmet '_sort' char(sortmet) '_' col '_' num2str(Nsegs) 'segs_Seg' num2str(idss)])
                        end

                        % Overlay the fitted ellipse outlines
                        subplot(1,3,1)
                        plot(x_1, y_1, 'red')

                        subplot(1,3,2)
                        plot(x_2, y_2, 'red')

                        subplot(1,3,3)
                        plot(x_3, y_3, 'red')

                        if saveims
                            saveas(gcf,[savedir '\' binmet '_sort' char(sortmet) '_' col '_' num2str(Nsegs) 'segs_Seg' num2str(idss) '_Fit'])
                            close(gcf)
                        end

                    end

                    RegStack(:,:,:,ss) = RegA;

                end
                fprintf('\n');

            end

            if refit&&fitraws&&any(strcmp(fitcol, col))
                Alloutraw{ll} = cell2mat(outraw);
            end

            idx = 1:size(SegInts,2);

        end

    end

end
%% Save data
% Writes the fit results and the per-segment intensities/volumes under the
% ParticleRegistration folder (same file names as before).
if savedata
    % Common prefix of every output file
    savebase = [locDir '\' headDir '\ParticleRegistration\pca' pcaCol '\' num2str(Nsegs) 'segs\Sort' char(sortmet) '_' binmet];

    if refit && ~isempty(intersect(fitcol, Colours))
        if exist("VarAReg",'var')
            save([savebase 'VarAReg.mat'], 'VarAReg','-v7');
        end
        if exist("VarBReg",'var')
            save([savebase 'VarBReg.mat'], 'VarBReg','-v7');
        end
        save([savebase 'Allfitrad.mat'], 'Allfitrad','-v7');

        DoubleOut = {dPSP,outbox};
        save([savebase 'DoubleOut.mat'], 'DoubleOut','-v7');

        if fitraws
            save([savebase 'Allrawrad.mat'], 'Allrawrad','-v7');
            % ImVol is only filled when pcaCol is one of the fitted colours
            if exist('ImVol','var')
                save([savebase 'ImVol.mat'], 'ImVol','-v7');
            end
            DoubleOutraw = {dPSPraw;outraw};
            save([savebase 'DoubleOutraw.mat'], 'DoubleOutraw','-v7');
        end
    end

    % Per-segment results only exist if the main loop processed a segment
    if exist('SegInts','var')
        save([savebase '_SegInts.mat'], 'SegInts','-v7');
    end

    if exist('SegVol','var')
        % Volumes of the first colour as loc x segment x method, where method is
        % 1: voxel count, 2: convex hull, 3: ellipsoid from the fitted radii.
        % Explicit reshapes keep SegVol and Allfitrad aligned even when only
        % some probe positions were processed or there is a single segment.
        nL = size(SegVol,1);
        nS = size(SegVol,2);
        SV = reshape(SegVol(:,:,1,:), nL, nS, []);
        FR = reshape(Allfitrad(1:nL,1:nS,1,:), nL, nS, []);
        SV(:,:,3) = 4/3*pi*prod(FR,3);
        % Drop singleton dimensions as before (one probe gives Nsegs x 3)
        SegVol = squeeze(SV);
        save([savebase '_SegVol.mat'], 'SegVol','-v7');
    end

end