%{
Code to perform polarization analysis on rotationally and translationally
aligned 3D images of spherical to ellipsoidal particles

Requires DIPImage (https://diplib.org/), Parallel Computing Toolbox, Signal Processing Toolbox

Copyright Enya Berrevoets, TU Delft, 2024
Licensed under the Apache License, Version 2.0

%}
%% Setup: start the timer, reset the workspace and put project folders on the path
tic

clear
addpath(genpath(pwd))               % every sub-folder of the working directory becomes callable

% Stop the script when PSP_CodeCheck returns a non-empty message
% (presumably a missing-dependency report - confirm in PSP_CodeCheck)
message = PSP_CodeCheck();
if ~isempty(message), return, end

%% User input

locDir ='Demo';      % Directory to store rotated images

headDir = 'Demo_data';  % Dataset folder inside locDir
sellocs = 1;            % Index of the location (probe) to analyse
rerun = 1;              % 1: recompute spherical-harmonic coefficients; 0: load them from a previous run
save_data = 1;          % 1: save coefficients and polarisation results to disk
savevid = 0;            % NOTE(review): not read anywhere in this script
saveStack = 0;          % 1: also save the image stacks sorted by degree of polarisation
saveims = 0;            % NOTE(review): not read anywhere in this script

plot_ims = 0;           % 1: plot every image with its surface projection and expansion (plot_Pol)
coeffplt = 0;           % 1: histogram of the degree of polarisation for each bin
coeffisoplt = 1;        % 1: isosurface renderings of the 10 least and 10 most polarised particles. WARNING: requires loading ALL registered images into memory, consuming significant working memory

% Registered images must be read from disk whenever any of these options is enabled
getims = any([rerun, plot_ims, coeffisoplt, saveStack]);


cols = ["magenta","green"];     % Isosurface colours of the Orange and Red channel, respectively
cut = 'under';      % Look at particles 'under' or 'above' elongation threshold (scaling_cutoff)

% Image params
pcaCol = 'Sum';        % Channel used for PCA; selects the PCA_<pcaCol> input folder. NOTE(review): an older comment listed 'Orange'/'Red'/'Both' - confirm the valid options
px = 30;            % pixel size [nm]
Colours = ["Orange", "Red"];    % Channel names; also used as folder/file names of the registered stacks
shiftCol = "Sum";   % NOTE(review): not read anywhere in this script
Nits = 2;           % Presumably the number of registration iterations; part of the registered-stack file names
Nsegs = 10;         % Number of bins the particles are divided into
init = 'Sum';   % Initial registration ["S1", "Sum", "Empty"]
dslice = 100/px;    % NOTE(review): not read anywhere in this script
binmet = 'AbCov';   % Binning method: 'AbCov' or 'FixN' (equal-width bins over the stack)
sortmet = "Length"; % Quantity passed to PSP_VarSort to sort the particles
res = 0;      % Resolution [nm]; NOTE(review): not read anywhere in this script
OnlyIn = 1;   % NOTE(review): not read anywhere in this script

% Spherical Harmonics params
sz =    [37,37,37];     % Size of each registered image [px]
Rmax =  floor(sz/2)+0.5*(1-rem(sz,2));      % NOTE(review): not read anywhere in this script
Nsecs =         ceil((2*pi*360/100)^2);     % Number of sampling points on the sphere. NOTE(review): origin of this constant is undocumented
Apsp =          4*pi;               % Surface area of PSP
Nrads =         6;      % Number of radial sections for spherical coordinate grid for volume to surface projection
proj_met =      'cones';     % Method to project image volume on spherical surface ('cones' = cone method used in paper; 'no_overlap' = alternative method to sample volume without gaps/overlaps)

LL = ceil(sqrt(Nsecs));     % Maximum spherical-harmonic degree of the expansion
poldeg = 1;                 % Degree whose share of the angular power spectrum is the degree of polarisation
% scaling_cutoff = 1.3;
scaling_cutoff = 100;       % Elongation threshold on Allrawrad(:,1)./Allrawrad(:,3), applied according to cut

% Fibonacci lattice of Nsecs sampling points on the unit sphere
all_theta = acos(1 - 2*(1:Nsecs)./Nsecs);   % Polar angle of each lattice point
all_phi = rem(pi * (1 + sqrt(5)) * (1:Nsecs), 2*pi);    % Azimuthal angle of each lattice point

%% Plot axes
marcol = parula(Nsegs);     % NOTE(review): marcol is not read later in this script

% Axis limits span the full image volume
ax = [1,sz(1)];
ay = [1,sz(2)];
az = [1,sz(3)];

% A tick every 12 px along each axis
xtcks = ax(1):12:ax(2);
ytcks = ay(1):12:ay(2);
ztcks = az(1):12:az(2);

% Tick labels in nm, relative to the central tick
ax_leg = arrayfun(@num2str, (xtcks-median(xtcks))*px, 'UniformOutput', false);
ay_leg = arrayfun(@num2str, (ytcks-median(ytcks))*px, 'UniformOutput', false);
az_leg = arrayfun(@num2str, (ztcks-median(ztcks))*px, 'UniformOutput', false);

%% Load variables
% Per-particle results of the PCA step for the selected channel (pcaCol).
pcaDir = [locDir  '\' headDir '\PCA_' pcaCol];

cellVarA = load([pcaDir '\Orange_cellvars.mat']).cellVarA;
cellIntA = load([pcaDir '\Orange_cellints.mat']).cellIntA;

cellVarB = load([pcaDir '\Red_cellvars.mat']).cellVarB;
cellIntB = load([pcaDir '\Red_cellints.mat']).cellIntB;

cellVarSum = load([pcaDir '\Sum_cellvars.mat']).cellVarSum;

% The path is passed as a row char vector. It was previously transposed
% with a stray ' into a column vector, which is not a valid file name.
cellRvol = load([pcaDir '\cellRvol.mat']).cellRvol;
cellVol = load([pcaDir '\cellVol.mat']).cellVol;

% Common prefix of the particle-registration result files for this binning
regPrefix = [locDir  '\' headDir '\ParticleRegistration\pca' pcaCol '\' num2str(Nsegs) 'segs\Sort' char(sortmet) '_' binmet];
% NOTE(review): SegVol has an extra '_' after binmet while the other files do
% not - kept as-is to match the existing file names; confirm against the writer.
SegVol = load([regPrefix '_SegVol.mat']).SegVol;
DoubleOutraw = load([regPrefix 'DoubleOutraw.mat']).DoubleOutraw;


%% Load loc variables

[locs, Nlocs] = PSP_getLocs([locDir  '\' headDir]);
ll = sellocs;                       % location (probe) under analysis

outraw = DoubleOutraw{2}{ll};

% Raw and fitted radii share the ParticleRegistration file prefix
radPrefix = [locDir  '\' headDir '\ParticleRegistration\pca' pcaCol '\' num2str(Nsegs) 'segs\Sort' char(sortmet) '_' binmet];
Allrawrad = load([radPrefix 'Allrawrad.mat']).Allrawrad{ll};
Allfitrad = squeeze(load([radPrefix 'Allfitrad.mat']).Allfitrad);

% Only the first page of the raw radii is used, whatever the number of locations
Allrawrad = Allrawrad(:,:,1);
if Nlocs == 1
    Allfitrad = squeeze(Allfitrad(:,1,:));          % leading location dim already squeezed away
else
    Allfitrad = squeeze(Allfitrad(ll,:,1,:));
end

% Either start with empty per-bin containers or reload a previous run
if rerun
    CoeffsAll = cell(Nsegs,1);
    ExpanAll = cell(Nsegs,1);
else
    savedir = [locDir  '\' headDir '\ParticlePolarisation\pca' pcaCol '\' num2str(Nsegs) 'segs\' locs(ll).name];
    savename = [char(sortmet)  '_' binmet];
    CoeffsAll = load([savedir '\' savename '_CoeffsAll.mat']).CoeffsAll;
    ExpanAll = load([savedir '\' savename '_ExpanAll.mat']).ExpanAll;
end

%% Create the ParticlePolarisation output folder when results will be saved
if save_data && ~exist([locDir  '\' headDir '\ParticlePolarisation'],'dir')
    mkdir([locDir  '\' headDir '\ParticlePolarisation']);
end
%% Initialise segs

% Sort the particles of this location and obtain the bin edges (segs)
[varssort, varsegs, segs, idx_vars] = PSP_VarSort(sortmet, Nsegs, [],[],[],cellRvol{ll},[]);

szStack = load([locDir  '\' headDir '\' locs(ll).name  '\Stacksize.mat']).sz;

% 'FixN': Nsegs equal-width bins over the image indices 0..szStack(end)
if strcmp(binmet,'FixN'), segs = linspace(0,szStack(end),Nsegs+1); end

sidx = 0;                               % running image positions, extended per bin
[vecsall, dcomall, allvnorm] = deal([]);
noSeg = ones(Nsegs,1);                  % set to 0 for bins without images

% Spherical harmonics on the sampling points are only needed when recomputing
if rerun
    SH_all = SphericalHarmonics(LL, all_theta, all_phi);
end

% Progress indicator: fprintf returns the number of bytes written, which is
% kept in msg so the previous message can be erased with backspaces.
msg = fprintf('Probe %d/%d, %.2f %%', ll, Nlocs, 0);
for ss = 1:Nsegs

    if ~getims
        % Images are not loaded, so report progress once per bin
        fprintf(repmat('\b',1,msg))
        msg = fprintf('Probe %d/%d, %.2f %%', ll, Nlocs, round(ss/Nsegs*100,2));
    end

    % Indices of the images that fall into bin ss
    seg = floor(segs(ss)+1):floor(segs(ss+1));
    Nims = length(seg);

    if Nims==0
        noSeg(ss) = 0;      % mark the empty bin and skip it
        continue
    end

    % Running positions of this bin's images in Stack and SCoeffs
    sidx = sidx(end)+1:sidx(end)+Nims;

    % getims = any([rerun, plot_ims, coeffisoplt, saveStack]), so it already
    % covers each of these flags
    if getims
        % Load the registered stacks of both channels for this bin
        Regs = zeros([sz,2]);
        CStackSave = cell(2,1);
        for icol = 1:2
            col = char(Colours(icol));
            savedir = [locDir  '\' headDir '\ParticleRegistration\pca' pcaCol '\' num2str(Nsegs) 'segs\' locs(ll).name '\' col];
            savename = [binmet 'Init' char(init) '_Im' num2str(seg(1)) 'to' num2str(seg(end)) '_' num2str(Nits) 'its'];
            CStackSave{icol} = load([savedir '\' col '_RawStack_' savename '.mat']).StackSave;
            Regs(:,:,:,icol) = sum(CStackSave{icol},4);     % sum of the bin's images

        end

        segrads = Allrawrad(seg,:);
        AllRegs(:,:,:,ss) = sum(Regs,4);

        % Radii of each image relative to its third radius
        Allscaling(seg,:) = segrads./segrads(:,3);

        for ii = 1:Nims

            % Per-image progress. The last image has index floor(segs(end)); the
            % previous denominator segs(end)-1 reported more than 100 %.
            fprintf(repmat('\b',1,msg))
            msg = fprintf('Probe %d/%d, %.2f %%', ll, Nlocs, 100*seg(ii)/floor(segs(end)));

            scaling = segrads(ii,:)./segrads(ii,3);

            Im1 = CStackSave{1}(:,:,:,ii);
            Im2 = CStackSave{2}(:,:,:,ii);

            Stack{1}(:,:,:,sidx(ii)) = Im1;
            Stack{2}(:,:,:,sidx(ii)) = Im2;

            % Normalise each channel by the mean of its brightest 1 % of voxels
            Amax = mean(maxk(Im1(:),round(prod(sz)/100)));
            Bmax = mean(maxk(Im2(:),round(prod(sz)/100)));

            ImCombi = Im1/Amax - Im2/Bmax;

            % Project the channel-difference volume onto the Nsecs surface points
            [cone_sum, dTh, dPh] = sph2surf(ImCombi, proj_met, Nsecs, Apsp, Nrads, scaling);

            % Standardise the projected values (column 1) to zero mean and unit std
            cone_sum(:,1) = cone_sum(:,1) - mean(cone_sum(:,1));
            cone_sum(:,1) = cone_sum(:,1)/std(cone_sum(:,1));

            if rerun

                % One row per (l, m): [coefficient, degree l, order m], l = 0..LL, m = -l..l
                coeff = zeros(sum((0:(LL))*2+1),3);
                expan = zeros(size(cone_sum,1),1);

                for lsph = 0:LL
                    for msph = -lsph:lsph

                        % Row of (l, m): l^2 + (m + l) + 1
                        count = sum((0:(lsph-1))*2+1) + msph + lsph + 1;

                        % Discrete surface integral of f * conj(Y_l^m) * sin(theta) dTheta dPhi
                        % (cone_sum(:,2) presumably holds the polar angle - confirm in sph2surf)
                        SH = SH_all(:,count);
                        coeff(count,:) = [sum(conj(SH(:,1)).*cone_sum(:,1).*sin(cone_sum(:,2)).*dTh'.*dPh'),...
                            lsph,msph];

                        % Accumulate the truncated expansion
                        expan = expan + coeff(count,1).*SH;

                    end
                end

                CoeffsAll{ss}(ii,:,:) = squeeze(coeff);
                ExpanAll{ss}(ii,:) = expan;
            else
                coeff = squeeze(CoeffsAll{ss}(ii,:,:));
                expan = squeeze(ExpanAll{ss}(ii,:));
            end

            if plot_ims

                ImagesSph(:,:,:,1) = ellip2sph(Im1, 0, 'SameVol', sz/2, scaling);
                ImagesSph(:,:,:,2) = ellip2sph(Im2, 0, 'SameVol', sz/2, scaling);

                Images(:,:,:,1) = Im1;
                Images(:,:,:,2) = Im2;

                plot_Pol(cone_sum, expan, ii, Images, ImagesSph, coeff, px);

            end

        end

        SCoeffs(sidx,:,:) = CoeffsAll{ss};

    end


    %% Polarisation degree

    degs = CoeffsAll{ss}(1,:,2);

    aps = (abs(CoeffsAll{ss}(:,:,1))).^2;                      % Angular power spectrum
    aps1 = (abs(CoeffsAll{ss}(:,degs==poldeg,1))).^2;

    polrel(seg) = sum(aps1,2)./sum(aps,2);                           % Normalised contribution of first degree to angular power spectrum

    if coeffplt
        coeffhist = figure();
        histogram(polrel(seg),0:0.01:0.5)
        ylim([0,8])
        ylabel('Degree of polarization')
        switch binmet
            case 'FixN'
                title(['Bin ' num2str(ss) '/' num2str(Nsegs)])
            case 'AbCov'
                title(sprintf(['Bin ' num2str(ss) '/' num2str(Nsegs)  ...
                    '\nCovariance ' num2str(round(varsegs(ss),1)) ':' num2str(round(varsegs(ss+1),1))]))
        end

        frames(ss) = getframe(gcf);
        close(coeffhist)
    end

    %% Polarisation vector

    % Degree-1 coefficients: columns 2, 3, 4 hold orders m = -1, 0, +1
    coa = CoeffsAll{ss}(:,2,1);
    cob = CoeffsAll{ss}(:,3,1);
    coc = CoeffsAll{ss}(:,4,1);

    vecs = [coa-coc, -1i*(coa+coc), sqrt(2)*cob];
    allvnorm(end+1:end+Nims,:) = vecnorm(vecs,2,2);

end
fprintf('\n');

%% Save data

if save_data
    savedir = [locDir  '\' headDir '\ParticlePolarisation\pca' pcaCol '\' num2str(Nsegs) 'segs\' locs(ll).name];
    savename = [char(sortmet)  '_' binmet];

    % Create the per-location output folder on first use
    if ~exist(savedir,'dir')
        mkdir(savedir);
    end

    % Per-bin coefficient and expansion cells, each in its own .mat file
    for varName = {'CoeffsAll', 'ExpanAll'}
        save([savedir '\' savename '_' varName{1} '.mat'], varName{1});
    end
end


%% Get spherical or elongated particles

% All per-image coefficients, images x coefficients x [value, degree, order].
% vertcat skips the [] left in CoeffsAll for bins without images.
SCoeffs = vertcat(CoeffsAll{:});

% Elongation of each image: first raw radius relative to the third
elong = Allrawrad(:,1)./Allrawrad(:,3);
switch cut
    case 'under'
        idxSph = elong<=scaling_cutoff;     % ellipticity at or below the cut-off (spherical)
    case 'above'
        idxSph = elong>scaling_cutoff;      % ellipticity above the cut-off
    otherwise
        % Previously idxSph was left undefined, failing later with an unrelated error
        error('cut must be ''under'' or ''above'', got ''%s''.', cut);
end

scalingSph = elong(idxSph);

CoeffsSph = SCoeffs(idxSph,:,:);                                % Coefficients of selected images

degs = SCoeffs(1,:,2);                                          % Expansion degrees

aps = (abs(SCoeffs(idxSph,:,1))).^2;                      % Angular power spectrum
aps1 = (abs(SCoeffs(idxSph,degs==poldeg,1))).^2;

polrel = sum(aps1,2)./sum(aps,2);                           % Normalised contribution of degree poldeg to angular power spectrum

% Order all selected particles by increasing degree of polarisation
[polrel,polsort] = sort(polrel);
CoeffsSph = CoeffsSph(polsort,:,:);

scalingSph = scalingSph(polsort);


%% Polarisation vector

% Degree-1 coefficients: columns 2, 3, 4 hold orders m = -1, 0, +1
[coa, cob, coc] = deal(CoeffsSph(:,2,1), CoeffsSph(:,3,1), CoeffsSph(:,4,1));

% Vector built from the l = 1 coefficients (one row per particle)
vecs = [coa - coc, -1i*(coa + coc), sqrt(2)*cob];
vs = car2sph(vecs);

% Reorder to (minor, middle, major), normalise per row, take magnitudes
vecseb = vecs(:,[3,1,2]);
vecseb = abs(vecseb./vecnorm(vecseb,2,2));

% Angle [deg] between each vector and each of the three axes; Minor, middle, major
polangs = acos(vecseb*eye(3))/pi*180;

%% Make isoplots
if coeffisoplt

    StackSph{1} = Stack{1}(:,:,:,idxSph);                           % Orange selected images
    StackSph{2} = Stack{2}(:,:,:,idxSph);                           % Red selected images

    % Sorted by increasing degree of polarisation, matching polrel and vecs
    StackSort{1} = StackSph{1}(:,:,:,polsort);
    StackSort{2} = StackSph{2}(:,:,:,polsort);

    % Show the 10 least and the 10 most polarised particles. Indices are clipped
    % and de-duplicated so that fewer than 20 particles does not index out of range.
    Npol = numel(polrel);
    plotidx = unique([1:min(10,Npol), max(1,Npol-9):Npol]);

    cc=0;
    figure();
    for ii = plotidx

        % Rotation (Rodrigues' formula) turning vpre = y-axis onto the polarisation vector
        vec = vecs(ii,:)./vecnorm(vecs(ii,:));
        vpre = [0,1,0];
        vcross = cross(vpre,vec);
        sintheta = norm(vcross);
        costheta = dot(vpre, vec);

        if sintheta < eps
            % (Anti)parallel vectors: the general formula would divide by zero
            if real(costheta) > 0
                rotmat = eye(3);
            else
                rotmat = diag([1,-1,-1]);       % 180 degrees about x maps y onto -y
            end
        else
            vx = [0, -vcross(3), vcross(2);...
                vcross(3), 0, -vcross(1);...
                -vcross(2), vcross(1), 0];

            rotmat = eye(3) + vx + vx^2 * ((1 - costheta) / (sintheta^2));
        end

        cc = cc+1;

        for dd = 1:2
            ImCurr = StackSort{dd}(:,:,:,ii);
            ImCurr = im2mat(rotation3d(ImCurr,'direct',rotmat, ceil(sz/2)-1));
            ImCurr = ImCurr./max(ImCurr,[],'all');

            subplot(4,5, cc)

            [~,isort] = sort(ImCurr(:));

            % Isosurfaces at the 97 % ... 99.9 % intensity quantiles
            isovals = linspace(0.97, 0.999,5);


            for iv = isovals
                isoval = ImCurr(isort(ceil(iv*end)));

                isoplot(ImCurr, isoval, cols(dd), 1/numel(isovals))

            end
            axis equal
            set(gca, 'GridColor', 'white', 'GridAlpha',0.4)
            hold on

            xlim(ax)
            ylim(ay)
            zlim(az)

            set(gca,'ZTick',ztcks)
            set(gca,'ZTickLabel',az_leg)

            set(gca,'YTick',ytcks)
            set(gca,'YTickLabel',ay_leg)

            set(gca,'XTick',xtcks)
            set(gca,'XTickLabel',ax_leg)

            xlabel('[nm]')
            ylabel('[nm]')
            zlabel('[nm]')

        end
        title([num2str(round(polrel(ii)*100)) '%'])

    end
end

%% Save output

if save_data

    savedir = [locDir  '\' headDir '\ParticlePolarisation\pca' pcaCol '\' num2str(Nsegs) 'segs\' locs(ll).name];
    % Create the per-location output folder on first use
    if ~exist(savedir,'dir')
        mkdir(savedir);
    end

    % File-name suffix shared by all outputs, e.g. 'under1000.mat' for cut = 'under', scaling_cutoff = 100
    suffix = [cut num2str(scaling_cutoff*10) '.mat'];

    save([savedir '\AllCoeffs_' suffix], 'CoeffsSph');
    save([savedir '\Polarisation_' suffix], 'polrel');
    save([savedir '\PolSorting_' suffix], 'polsort');
    save([savedir '\PolAngle_' suffix], 'polangs');

    % The sorted stacks are full copies of the image data, so only build them
    % when they are actually going to be saved
    if saveStack && exist('Stack','var')

        StackSph{1} = Stack{1}(:,:,:,idxSph);                           % Orange selected images
        StackSph{2} = Stack{2}(:,:,:,idxSph);                           % Red selected images

        StackSort{1} = StackSph{1}(:,:,:,polsort);
        StackSort{2} = StackSph{2}(:,:,:,polsort);

        save([savedir '\StackSort_' suffix], 'StackSort');
    end

end

toc