%{
Code to simulate spherical two-component particles and perform polarization
analysis

Requires DIPImage (https://diplib.org/)

Copyright Enya Berrevoets, TU Delft, 2025
Licensed under the Apache License, Version 2.0

%}
%%

% Run control: which stages of the pipeline to (re)compute.
regen = 1;          % Regenerate the raw point images (CStackRaw)
reblur = 1;         % Blur the raw images into CStack (also runs whenever regen = 1)
reSH = 1;           % Recompute the spherical-harmonics expansion

plotims = 0;        % Careful, this plots ALL Nims and WILL crash matlab if Nims is large
plotiso = 1;        % Plot isosurfaces of a selection of aligned particles
save_pngs = 0;      % Save an aligned isosurface PNG for every image


% Image parameters
Nims = 1000;            % Number of images
sz = [37, 37, 37];      % Image size [px]
px = 30;                % Pixel size [nm]
Np = 30;                % Mean number of Neat1_2 per PSP
Np_var = 5;             % Standard deviation of the number of Neat1_2 per PSP (multiplies randn)
rmin = 180;             % PSP radius [nm] (converted to px as rmin/px)
Colours = ["Orange", "Red"];
sig = 1;                % Sets the thickness of the shell molecules are placed on [px]
res = 110;              % Resolution [nm]; the Gaussian blur uses sigma = res/sqrt(2) nm
d53 = 86;               % Distance between 5' and 3' end [nm]

shape = 'random';       % Neat1_2 shape ['V', 'random']

L_range = [1.3,2];      % Elongation range ([1,1] = all particles spherical)

% Polarisation analysis parameters
Nsecs = 512;                    % Number of surface sections PSP
Apsp = 4*pi;                    % Surface area of PSP (4*pi, that of a unit sphere)
Ndegs = ceil(sqrt(Nsecs));      % Highest degree of the spherical harmonics expansion
NSH = sum((0:Ndegs)*2+1);       % Number of spherical harmonics in expansion, equals (Ndegs+1)^2
poldeg = 1;                     % Degree used to assess degree of polarisation
gr = (1+sqrt(5))/2;             % Golden ratio

plotcols = ["magenta","green"];

PolQ = 'Power spectrum';    % Method used to calculate degree of polarisation ['Power spectrum', 'Angular power spectrum'];
% PolQ = 'Angular power spectrum';

%% Plot axes
% Axis limits span the whole voxel grid. Ticks sit every 12 voxels and are
% labelled in nm relative to the middle tick of each axis.
ax = [1, sz(1)];
ay = [1, sz(2)];
az = [1, sz(3)];

xtcks = ax(1):12:ax(2);
ytcks = ay(1):12:ay(2);
ztcks = az(1):12:az(2);

ax_leg = arrayfun(@num2str, (xtcks - median(xtcks))*px, 'UniformOutput', false);
ay_leg = arrayfun(@num2str, (ytcks - median(ytcks))*px, 'UniformOutput', false);
az_leg = arrayfun(@num2str, (ztcks - median(ztcks))*px, 'UniformOutput', false);
%% Image generation
% For every image, place molecules on a thin (possibly elongated) shell and
% store the raw point images of both colour channels in CStackRaw{1..2}.
% Then blur each with a 3-D Gaussian and normalise it to a peak of 1 in
% CStack{1..2}. Blurring runs whenever this section runs, i.e. when either
% regen or reblur is set.

if any([regen, reblur])

    % Voxel grid [px] centred on the image centre. The grid only matches sz
    % exactly for odd image sizes.
    [xx,yy,zz] = ndgrid(-floor(sz(1)/2):floor(sz(1)/2), ...
                        -floor(sz(2)/2):floor(sz(2)/2), ...
                        -floor(sz(3)/2):floor(sz(3)/2));

    % Elongation along x. It is drawn ONCE per run, so all particles of this
    % run share the same shell shape.
    Rlong = (L_range(2)-L_range(1)) * rand + L_range(1);
    coords = sqrt((xx/Rlong).^2+yy.^2+zz.^2);

    % Shell voxels: exp(-u^4) > 0.95 <=> |u| < ~0.48, so these voxels lie
    % within ~0.48*sig px of the radius rmin/px.
    idxRing = find(exp(-((coords - (rmin/px))/sig).^4)>0.95);
    [xxRing, yyRing, zzRing] = ind2sub(sz, idxRing);

    if regen
        % Preallocate the raw stacks. This avoids growing 4-D arrays in the
        % loop and discards stale images from earlier runs that used a
        % different sz/Nims.
        CStackRaw = {zeros([sz, Nims]), zeros([sz, Nims])};
    elseif ~exist('CStackRaw', 'var') || numel(CStackRaw) < 2 || size(CStackRaw{1}, 4) < Nims
        error('reblur = 1 with regen = 0 needs CStackRaw from a previous run with at least Nims images.');
    end

    CStack = {zeros([sz, Nims]); zeros([sz, Nims])};

    % Allowed 5'-3' separation window [px] for the 'V' shape.
    dmin53 = d53/px*0.5;
    dmax53 = d53/px*1.5;

    for ii = 1:Nims

        if regen
            % 5' channel: about Np (std Np_var, at least 1) molecules on random
            % shell voxels.
            Image5 = zeros(sz);
            idxNeat5 = idxRing(ceil(length(idxRing)*rand(max(1,round(Np_var*randn+Np)),1)));
            Image5(idxNeat5) = 1;
            CStackRaw{1}(:,:,:,ii) = Image5;

            % 3' channel.
            Image3 = zeros(sz);
            switch shape
                case 'V'
                    % For each 5' molecule, put one 3' end on a random shell
                    % voxel at a distance in (dmin53, dmax53] px from it.
                    [x1, y1, z1] = ind2sub(sz, idxNeat5);
                    for pp = 1:size(idxNeat5,1)
                        dd = sqrt((x1(pp) - xxRing).^2 + (y1(pp) - yyRing).^2 + (z1(pp) - zzRing).^2);
                        idxNeat3 = idxRing(dd <= dmax53 & dd > dmin53);
                        if ~isempty(idxNeat3)   % randi(0) would error
                            Image3(idxNeat3(randi(numel(idxNeat3)))) = 1;
                        end
                    end
                case 'random'
                    % 3' ends placed independently of the 5' ends.
                    idxNeat3 = idxRing(ceil(length(idxRing)*rand(max(1,round(Np_var*randn+Np)),1)));
                    Image3(idxNeat3) = 1;
                otherwise
                    error('Unknown shape ''%s''; expected ''V'' or ''random''.', shape);
            end
            CStackRaw{2}(:,:,:,ii) = Image3;
        end

        for icol = 1:2
            Image = CStackRaw{icol}(:,:,:,ii);
            % Image = im2mat(gaussf(Image, res/px/sqrt(2)));   % DIPimage alternative
            Image = imgaussfilt3(Image, res/px/sqrt(2));
            Imax = max(Image,[],'all');
            if Imax > 0                 % avoid NaNs for an empty image
                Image = Image./Imax;    % normalise to a peak of 1
            end
            CStack{icol}(:,:,:,ii) = Image;
        end

    end
end
%% SH expansion
% For every image, compute the difference of the two peak-normalised colour
% channels and sample it on Nsecs surface sections via sph2surf. Then expand
% it in spherical harmonics up to degree Ndegs.
%   CoeffsAll(ii, k, :) = [coefficient, degree l, order m],
%       with k = l^2 + l + m + 1
%   ExpanAll(ii, :)     = the reconstructed expansion at the sections
if reSH
    tic
    % Fibonacci-lattice sample directions on the sphere.
    all_theta = acos(1 - 2*(1:Nsecs)./Nsecs);
    all_phi = rem(2*pi/gr * (1:Nsecs), 2*pi);

    % Angular steps between consecutive lattice points (phi steps unwrapped).
    % NOTE(review): sph2surf below overwrites dTh/dPh for every image, so these
    % values are only used if Nims == 0. Confirm this is intended.
    dTh = diff(all_theta);
    dTh = [dTh(end) dTh];
    dPh = diff(all_phi)+(diff(all_phi)<0)*2*pi;
    dPh = [dPh(1) dPh];

    SH_all = SphericalHarmonics(Ndegs, all_theta, all_phi);

    %%
    CoeffsAll = zeros(Nims, NSH, 3);
    ExpanAll = zeros(Nims,Nsecs);

    msg = fprintf('Image %d of %d', 0, Nims);  % Initial message
    for ii = 1:Nims

        fprintf(repmat('\b', 1, msg));  % Delete the previous message by backspacing
        msg = fprintf('Image %d of %d', ii, Nims);  % Print the updated message

        Im1 = CStack{1}(:,:,:,ii);
        Im2 = CStack{2}(:,:,:,ii);

        % Scale each channel by the mean of its brightest 1% of voxels.
        Amax = mean(maxk(Im1(:),round(prod(sz)/100)));
        Bmax = mean(maxk(Im2(:),round(prod(sz)/100)));

        ImCombi = Im1/Amax - Im2/Bmax;

        [cone_sum, voxels, dTh, dPh, cs_old] = sph2surf(ImCombi, Nsecs, Apsp);

        % Zero mean, then unit L1 norm. The std division is kept from the
        % original. It does not change the final L1-normalised values.
        cone_sum(:,1) = cone_sum(:,1) - mean(cone_sum(:,1));
        cone_sum(:,1) = cone_sum(:,1)/std(cone_sum(:,1));
        cone_sum(:,1) = cone_sum(:,1)./sum(abs(cone_sum(:,1)));

        % Quadrature weights per section (signal * sin(theta) * dTheta * dPhi).
        % These are independent of (l, m), so they are computed once per image.
        % dTh(:)/dPh(:) force column vectors (the original used dTh'/dPh').
        w = cone_sum(:,1).*sin(cone_sum(:,2)).*dTh(:).*dPh(:);

        coeff = zeros(NSH,3);   % one row per (l, m): (Ndegs+1)^2 rows
        expan = zeros(size(cone_sum,1),1);

        for lsph = 0:Ndegs
            for msph = -lsph:lsph

                count = lsph^2 + lsph + msph + 1;   % column of (l, m) in SH_all

                SH = SH_all(:,count);
                coeff(count,:) = [sum(conj(SH).*w), lsph, msph];
                expan = expan + coeff(count,1).*SH;

            end
        end

        CoeffsAll(ii,:,:) = coeff;
        ExpanAll(ii,:) = expan;

        if plotims
            ImagesSph(:,:,:,1) = Im1;
            ImagesSph(:,:,:,2) = Im2;

            Images(:,:,:,1) = Im1;
            Images(:,:,:,2) = Im2;

            plot_Pol(cone_sum, expan, ii, Images, ImagesSph, coeff, px);
        end
    end
    fprintf('\n');
    toc

end
%%
% Normalisation check of the sampled spherical harmonics. Each entry of SHsq
% is |Y_lm|^2 times the weight sin(theta)*dTh*dPh*gr at a sample point, and
% ONcheck(k) sums this over all points for harmonic k. Presumably each entry
% should be close to 1 for orthonormal harmonics; confirm against
% SphericalHarmonics.
% With the Fibonacci steps from the SH section (dPh = 2*pi/gr, and
% sin(theta)*dTh ~ 2/Nsecs), the factor gr makes each weight ~4*pi/Nsecs,
% i.e. the area per point.
% NOTE(review): this needs SH_all/all_theta from the reSH section. At this
% point dTh/dPh are the values last returned by sph2surf inside that loop,
% not the Fibonacci steps. TODO confirm this is intended.
SHsq = SH_all.*conj(SH_all).*repmat(sin(all_theta)',1,NSH).*repmat(dTh',1,NSH).*repmat(dPh',1,NSH)*gr;
ONcheck = sum(SHsq,1);
%% Polarisation analysis
% Degree of polarisation per image: the (angular) power in degree poldeg
% relative to the total power over all degrees of the expansion.

degnorm = 1./(2*CoeffsAll(:,:,2)+1);   % 1/(2l+1) for every coefficient

degs = CoeffsAll(1,:,2);               % Expansion degree l of each coefficient column
selDeg = (degs == poldeg);             % Columns belonging to degree poldeg

switch PolQ
    case 'Power spectrum'
        modsq = abs(CoeffsAll(:,:,1)).^2;       % Modulus squared of all coefficients
        modsq1 = modsq(:,selDeg);               % ... restricted to degree poldeg

        polrel = sum(modsq1,2)./sum(modsq,2);   % Relative weight of degree poldeg
    case 'Angular power spectrum'
        aps = degnorm.*abs(CoeffsAll(:,:,1)).^2;  % |a_lm|^2/(2l+1)
        aps1 = aps(:,selDeg);

        polrel = sum(aps1,2)./sum(aps,2);
    otherwise
        % Without this, polrel would be undefined and later sections would fail.
        error('Unknown PolQ ''%s''; expected ''Power spectrum'' or ''Angular power spectrum''.', PolQ);
end


% Order images by increasing degree of polarisation and reorder the expansion
% coefficients to match (polsort maps sorted -> original image index).
[polrel, polsort] = sort(polrel);
CoeffsSort = CoeffsAll(polsort, :, :);

% Degree-1 coefficients. Column k holds (l, m) with k = l^2 + l + m + 1, so
% columns 2, 3, 4 are (1,-1), (1,0), (1,+1).
coa = CoeffsSort(:, 2, 1);
cob = CoeffsSort(:, 3, 1);
coc = CoeffsSort(:, 4, 1);

% Cartesian direction encoded by the l = 1 coefficients. This assumes
% SphericalHarmonics uses the Condon-Shortley phase convention; TODO confirm.
vecs = [coa - coc, ...
        -1i*(coa + coc), ...
        sqrt(2)*cob];
vs = car2sph(vecs);
vecsph = (vs(:, 2:3) - [pi/2, pi]) / pi * 180;   % shifted angles, in degrees

%% Plotting
% Histogram (19 equal bins between 0 and 100) of the degree of polarisation
% of all images, in percent.

figure();

histogram(100*polrel, linspace(0, 100, 20))
grid minor
title(sprintf(['Polarisation \nResolution ' num2str(res) 'nm, ' num2str(Nsecs) ' cones \n' PolQ]))
xlabel('Polarisation [%]')
ylabel('Number of images')

%%
% Isosurface plots of the 5 least, 5 middle and 5 most polarised images
% (after sorting by polrel), laid out in a 5x3 grid with one column per group.
% Each image is passed to rotation3d with the rotation matrix that maps
% [0,1,0] onto its degree-1 polarisation vector.
if plotiso
    % Sorted copies of the stacks. They are kept as workspace variables for
    % later use.
    StackSort{1} = CStack{1}(:,:,:,polsort);
    StackSort{2} = CStack{2}(:,:,:,polsort);

    % For Nims >= 15 this list is already sorted and duplicate-free. unique()
    % and the range filter only matter for smaller Nims, where the raw list
    % would repeat or exceed valid indices.
    % Alternatives: [1:10, Nims-9:Nims], 1:min(20,Nims), Nims-19:Nims
    sel = unique([1:5, ceil((Nims/2-2):(Nims/2+2)), Nims-4:Nims]);
    sel = sel(sel >= 1 & sel <= Nims);

    isovals = linspace(0.97, 0.999, 5);   % intensity quantiles for the isosurfaces
    vpre = [0,1,0];

    cc = 0;
    figure();
    for ii = sel

        % Rotation mapping vpre onto the unit polarisation vector (Rodrigues).
        % vecs is built from complex coefficients, so any round-off imaginary
        % part is dropped to keep the rotation matrix real.
        vec = real(vecs(ii,:));
        vec = vec./vecnorm(vec);
        vcross = cross(vpre,vec);
        costheta = dot(vpre, vec);

        vx = [0, -vcross(3), vcross(2);...
            vcross(3), 0, -vcross(1);...
            -vcross(2), vcross(1), 0];

        if 1 + costheta > 1e-8
            % (1-cos)/sin^2 == 1/(1+cos). This form stays finite when vec is
            % parallel to vpre, where the original divided 0 by 0.
            rotmat = eye(3) + vx + vx^2 / (1 + costheta);
        else
            % vec is anti-parallel to vpre: rotate 180 degrees about the x axis.
            rotmat = diag([1, -1, -1]);
        end

        cc = cc+1;
        sp_idx = (rem(cc-1,5))*3+floor((cc-1)/5)+1;   % fill the 5x3 grid column by column

        for dd = 1:2

            ImCurr = StackSort{dd}(:,:,:,ii);

            ImCurr = im2mat(rotation3d(ImCurr,'direct',rotmat, ceil(sz/2)-1));
            ImCurr = ImCurr./max(ImCurr,[],'all');

            subplot(5,3, sp_idx)
            sortedVals = sort(ImCurr(:));

            for iv = isovals
                isoval = sortedVals(ceil(iv*numel(sortedVals)));
                isoplot(ImCurr, isoval, plotcols(dd), 1/numel(isovals))
            end
            axis equal
            set(gca, 'GridColor', 'white', 'GridAlpha',0.4)
            hold on

            xlim(ax)
            ylim(ay)
            zlim(az)

            set(gca,'ZTick',ztcks)
            set(gca,'ZTickLabel',az_leg)

            set(gca,'YTick',ytcks)
            set(gca,'YTickLabel',ay_leg)

            set(gca,'XTick',xtcks)
            set(gca,'XTickLabel',ax_leg)

        end
        title([num2str(round(polrel(ii)*100)) '%'])

    end

end
%%
% Save one PNG per image, in order of increasing polarisation. Each PNG shows
% both colour channels as isosurfaces, rotated as in the plotting section.
% The sorted polarisation values are saved to polrel.mat.
if save_pngs

    savedir = 'Sim_PNGs';
    if ~exist(savedir,'dir')      % Make folder to save the PNGs in
        mkdir(savedir);
    end

    isovals = linspace(0.960, 0.999,5);   % intensity quantiles for the isosurfaces
    vpre = [0,1,0];

    for ii = 1:length(polrel)

        % Rotation mapping vpre onto the unit polarisation vector (Rodrigues).
        % It is the same for both channels, so it is computed once per image.
        % The round-off imaginary part of vecs is dropped to keep it real.
        vec = real(vecs(ii,:));
        vec = vec./vecnorm(vec);
        vcross = cross(vpre,vec);
        costheta = dot(vpre, vec);

        vx = [0, -vcross(3), vcross(2);...
            vcross(3), 0, -vcross(1);...
            -vcross(2), vcross(1), 0];

        if 1 + costheta > 1e-8
            % (1-cos)/sin^2 == 1/(1+cos); finite when vec is parallel to vpre.
            rotmat = eye(3) + vx + vx^2 / (1 + costheta);
        else
            % vec is anti-parallel to vpre: rotate 180 degrees about the x axis.
            rotmat = diag([1, -1, -1]);
        end

        fig = figure();
        for dd = 1:2

            % Index CStack through polsort rather than StackSort, which only
            % exists when plotiso is enabled.
            Im = CStack{dd}(:,:,:,polsort(ii));
            Im = im2mat(rotation3d(Im,'direct',rotmat, ceil(sz/2)-1));

            sortedVals = sort(Im(:));
            for iv = isovals
                isoval = sortedVals(ceil(iv*numel(sortedVals)));
                isoplot(Im, isoval, plotcols(dd), 1/numel(isovals))
            end
            axis equal
            set(gca, 'GridColor', 'white', 'GridAlpha',0.4)
            hold on

            xlim(ax)
            ylim(ay)
            zlim(az)

            set(gca,'XTick',xtcks)
            set(gca,'XTickLabel',ax_leg)

            set(gca,'ZTick',ztcks)
            set(gca,'ZTickLabel',az_leg)

            set(gca,'YTick',ytcks)
            set(gca,'YTickLabel',ay_leg)

            xlabel('[nm]')
            ylabel('[nm]')
            zlabel('[nm]')

        end

        title([num2str(round(polrel(ii)*100)) '%'])

        set(fig, 'InvertHardcopy', 'off')

        % fullfile keeps the path portable (the original '\' only worked on Windows).
        saveas(fig, fullfile(savedir, ['Sim_Aligned' num2str(ii) '_' num2str(round(polrel(ii)*100)) '.png']))

        close(fig)
    end

    save(fullfile(savedir, 'polrel.mat'), 'polrel');
end