# -*- coding: utf-8 -*-
"""
Created on Fri Aug 11 13:31:54 2023

@author: Abberior_Admin
"""
import numpy as np
import sys
import os
import scipy
import skimage.io
import pprint
import raster_geometry as rg
import diplib as dip
import time
import easygui
import matplotlib.pyplot as plt
from scipy.ndimage import correlate
from skimage.feature import match_template

def GetParams(im, ParOv, ParROIconf, ParROI):
    """Read scan settings of the "Overview", "ROIconf" and "ROI" configurations.

    Each configuration is activated in turn on the current measurement and
    its ExpControl scan-range values are copied into the matching dict.
    All three dicts are filled in place and also returned.

    Keys written
    ------------
    ParOv
        xinit/yinit/zinit and xoff_g/yoff_g/zoff_g. These are the coarse
        stage offsets; both sets are read from the same settings, so they
        start out equal. Also px, Len, Depth, scanmode,
        Npx (= [Depth, Len, Len] / px), Ndims and order.
    ParROIconf
        px, Len, Depth, Npx, scanmode, Ndims.
    ParROI
        The same stage keys as ParOv, plus px, Len, Depth, scanmode,
        Ndims and order.

    ``order`` is the scan-mode string with each axis letter replaced by an
    array-axis index: z -> 0, y -> 1, x -> 2 (for a 2-D ROI scan: y -> 0,
    x -> 1). This presumably matches the (z, y, x) layout of the acquired
    image arrays.
    """
    stage_paths = (
        ("x", 'ExpControl/scan/range/coarse_x/g_off'),
        ("y", 'ExpControl/scan/range/coarse_y/g_off'),
        ("z", 'ExpControl/scan/range/coarse_z/g_off'),
    )

    def read(path):
        # Current value of one Imspector parameter.
        return im.value_at(path).get()

    def read_stage(target):
        # Fill the *init keys first, then the *off_g keys (same settings).
        for suffix in ("init", "off_g"):
            for axis, path in stage_paths:
                target[axis + suffix] = read(path)

    msr = im.active_measurement()

    # --- Overview configuration -------------------------------------------
    msr.activate(msr.configuration("Overview"))
    read_stage(ParOv)
    ParOv["px"] = read('ExpControl/scan/range/x/psz')
    ParOv["Len"] = read('ExpControl/scan/range/x/len')
    ParOv["Depth"] = read('ExpControl/scan/range/z/len')
    ParOv["scanmode"] = read('ExpControl/scan/range/scanmode')
    ParOv["Npx"] = np.array([ParOv["Depth"], ParOv["Len"], ParOv["Len"]]) / ParOv["px"]
    ParOv["Ndims"] = len(ParOv["scanmode"])

    ov_order = list(ParOv["scanmode"])
    for letter, axis_index in (("x", 2), ("y", 1), ("z", 0)):
        ov_order[ov_order.index(letter)] = axis_index
    ParOv["order"] = ov_order

    # --- ROI-check configuration ------------------------------------------
    msr.activate(msr.configuration("ROIconf"))
    ParROIconf["px"] = read('ExpControl/scan/range/x/psz')
    ParROIconf["Len"] = read('ExpControl/scan/range/x/len')
    ParROIconf["Depth"] = read('ExpControl/scan/range/z/len')
    ParROIconf["Npx"] = np.array([ParROIconf["Depth"], ParROIconf["Len"], ParROIconf["Len"]]) / ParROIconf["px"]
    ParROIconf["scanmode"] = read('ExpControl/scan/range/scanmode')
    ParROIconf["Ndims"] = len(ParROIconf["scanmode"])

    # --- ROI configuration (may be 2-D) -----------------------------------
    msr.activate(msr.configuration("ROI"))
    read_stage(ParROI)
    ParROI["px"] = read('ExpControl/scan/range/x/psz')
    ParROI["Len"] = read('ExpControl/scan/range/x/len')
    ParROI["Depth"] = read('ExpControl/scan/range/z/len')
    ParROI["scanmode"] = read('ExpControl/scan/range/scanmode')
    ParROI["Ndims"] = len(ParROI["scanmode"])

    # Shift the x/y axis indices down when there is no z axis.
    missing = 3 - ParROI["Ndims"]
    roi_order = list(ParROI["scanmode"])
    roi_order[roi_order.index("x")] = 2 - missing
    roi_order[roi_order.index("y")] = 1 - missing
    if len(roi_order) == 3:
        roi_order[roi_order.index("z")] = 0
    ParROI["order"] = roi_order

    return ParOv, ParROIconf, ParROI

def PositionOverview(im, ParOv, oo):
    """Move the coarse stage to overview tile ``oo``.

    Tiles are laid out row by row on a grid ``ceil(sqrt(ParOv["Nov"]))``
    tiles wide, spaced ``ParOv["Len"]`` apart, starting from
    ``(xinit, yinit)``. ``ParOv`` is updated in place with
    xoff_g/yoff_g/zoff_g (z stays at zinit) and returned.

    Safety stop: if the new position is more than 3e-3 from the start in x,
    or more than 5e-3 in y, a message is printed and ``sys.exit()`` is
    called before anything is sent to the stage. The units are those of
    the stage offsets (presumably metres — confirm).
    NOTE(review): the x and y limits differ; confirm this is intentional.
    """
    ncols = np.ceil(np.sqrt(ParOv["Nov"]))
    column = np.remainder(oo, ncols)
    row = np.floor(oo / ncols)

    ParOv["xoff_g"] = ParOv["xinit"] + column * ParOv["Len"]       # x offset overview image
    ParOv["yoff_g"] = ParOv["yinit"] + row * ParOv["Len"]
    ParOv["zoff_g"] = ParOv["zinit"]

    dx = abs(ParOv["xoff_g"] - ParOv["xinit"])
    dy = abs(ParOv["yoff_g"] - ParOv["yinit"])
    if dx > 3e-3 or dy > 5e-3:
        print("Stop it you're moving too far")
        sys.exit()

    targets = (
        ('ExpControl/scan/range/coarse_x/g_off', ParOv["xoff_g"]),
        ('ExpControl/scan/range/coarse_y/g_off', ParOv["yoff_g"]),
        ('ExpControl/scan/range/coarse_z/g_off', ParOv["zinit"]),
    )
    for path, value in targets:
        im.value_at(path).set(value)

    return ParOv

def makeKernel(ParIm, Par, OVimage):
    """Build one blurred-sphere matching template per radius in ``ParIm["Rpsp"]``.

    For each radius, a solid sphere is rasterised with
    ``raster_geometry.ellipsoid`` inside an odd-sized box. It is then
    blurred with a Gaussian whose sigmas come from the PSF settings.

    Parameters
    ----------
    ParIm : dict
        Uses ``"Rpsp"`` (sphere radii) and ``"d_nb"`` (margin added to each
        radius when sizing the box). Also uses ``"PSFconf_xy"`` and
        ``"PSFconf_z"`` (Gaussian sigmas). All are in the same length
        unit as ``Par["px"]``.
    Par : dict
        Uses ``"px"`` (pixel size, applied to every axis). It is updated in
        place with three object arrays, one entry per radius:
        ``"sph"`` (the float templates; the first axis uses the z size/sigma),
        ``"kernelxy"`` and ``"kernelz"`` (box edge lengths in voxels).
    OVimage : ndarray
        Not used; kept so existing call sites keep working.

    Returns
    -------
    dict
        ``Par`` itself.
    """
    radii = ParIm["Rpsp"]
    n_templates = np.size(radii)

    Par["sph"] = np.empty(n_templates, dtype=object)
    Par["kernelxy"] = np.empty(n_templates, dtype=object)
    Par["kernelz"] = np.empty(n_templates, dtype=object)

    # Blur widths in voxels; they do not depend on the radius.
    sigmaxy = ParIm["PSFconf_xy"] / Par["px"]
    sigmaz = ParIm["PSFconf_z"] / Par["px"]

    for k, radius in enumerate(radii):
        # xy and z share one pixel size (Par["px"]), so the sphere radius
        # and box size come out identical along every axis.
        radxy = radz = np.ceil(radius / Par['px']).astype('int')

        kernelxy = np.ceil((radius + ParIm["d_nb"]) * 2 / Par["px"]).astype('int')
        kernelxy += (1 - kernelxy % 2)    # force an odd edge length
        kernelz = kernelxy

        sph = rg.ellipsoid([kernelz.item(), kernelxy.item(), kernelxy.item()],
                           [radz.item(), radxy.item(), radxy.item()]).astype(float)

        # Blur in floating point. gaussian_filter returns the input dtype by
        # default, so an integer input would truncate the blurred values.
        sph = scipy.ndimage.gaussian_filter(sph, [sigmaz, sigmaxy, sigmaxy])

        Par["sph"][k] = sph
        Par["kernelxy"][k] = kernelxy
        Par["kernelz"][k] = kernelz

    return Par

def getCentres(Image, ParIm, Par, thr):
    """Locate template-matched particle centres in a 3-D image.

    ``Image`` is correlated with every template in ``Par["sph"]`` using
    ``skimage.feature.match_template(..., pad_input=True)``. The voxel-wise
    maximum over all templates is kept.

    A voxel is reported as a centre when all of the following hold:
      * its score exceeds ``thr``;
      * ``index + 1 < Par["Npx"] - ParIm["Nnb"]`` on every axis, which
        excludes voxels close to the upper edge;
      * no in-bounds voxel in its ``(2*Nnb+1)``-wide neighbourhood has a
        strictly higher score.

    NOTE(review): there is no matching exclusion at the lower edge. The
    historical test ``index + 1 > -Nnb`` is always true for non-negative
    indices. Low-edge voxels therefore remain candidates and are compared
    only with neighbours that exist, rather than with wrapped
    negative-index voxels from the opposite side.

    Parameters
    ----------
    Image : ndarray
        Image to search. Voxel indices are treated as 3 coordinates below,
        so it is expected to be 3-D.
    ParIm : dict
        Uses ``"Rpsp"`` (one entry per template) and ``"Nnb"`` (integer
        neighbourhood half-width in voxels).
    Par : dict
        Uses ``"sph"`` (templates, e.g. from ``makeKernel``) and ``"Npx"``
        (nominal image size per axis).
    thr : float
        Minimum correlation score for a candidate voxel.

    Returns
    -------
    ndarray or list
        ``(k, 3)`` integer array of voxel indices in array-axis order, or
        an empty list when no centre is found.
    """
    Nnb = ParIm["Nnb"]

    Ncc = np.zeros(np.shape(Image))
    for k, _radius in enumerate(ParIm["Rpsp"]):
        ncc = match_template(Image, Par["sph"][k], pad_input=True)
        Ncc = np.maximum(Ncc, ncc)

    offsets = np.arange(-Nnb, Nnb + 1)
    upper = np.asarray(Par["Npx"]) - Nnb

    centres = []
    for row in np.column_stack(np.where(Ncc > thr)):
        ind = (row[0], row[1], row[2])

        if not np.all(np.asarray(ind) + 1 < upper):
            continue

        # All voxel coordinates in the cube around `ind`.
        inb = np.array(np.meshgrid(*[c + offsets for c in ind])).T.reshape(-1, len(ind))
        # Drop out-of-range neighbours: negative indices would silently wrap
        # around, and indices past the end would raise IndexError.
        inb = inb[np.all((inb >= 0) & (inb < Ncc.shape), axis=1)]

        nb = Ncc[inb[:, 0], inb[:, 1], inb[:, 2]]
        if not np.any(nb > Ncc[ind]):
            centres.append(ind)

    if not centres:
        return []
    return np.array(centres)



def FindROI(im, ParOv, ParIm, oo, ParROI):
    """Detect particle candidates in the current overview image.

    Activates the "Overview" configuration and takes the first element of
    stack 0's data (squeezed) as the overview image. When
    ``ParIm["saveims"]`` is set, it saves that image and the ExpControl
    settings. It then builds the templates (``makeKernel``) and runs
    ``getCentres`` with ``ParOv["thr"]``.

    ``ParROI`` is not used; it is kept for call-site compatibility.

    Returns
    -------
    centres : ndarray
        ``(Npsp, 3)`` voxel indices, or an empty 1-D array when nothing
        is found.
    Npsp : int
        Number of candidates.
    """
    msr = im.active_measurement()
    msr.activate(msr.configuration("Overview"))      # Activate the overview configuration
    config = msr.active_configuration()     # Get the experimental configuration

    OVimage = np.squeeze(config.stack(0).data()[0])

    if ParIm["saveims"]:
        # NOTE(review): the tiff name uses an unpadded tile number while the
        # text files use "%03d"; kept as-is in case downstream tools rely on it.
        skimage.io.imsave(ParIm["Folder"] + "/Overview/Overview"+str(oo)+".tiff", OVimage, check_contrast=False)
        with open(ParIm["Folder"] + "/Overview/Overview"+"%03d"%(oo)+"_ExpVals.txt", "wt") as out:
            pprint.pprint(im.value_at('ExpControl').get(), stream=out)

    ParOv = makeKernel(ParIm, ParOv, OVimage)

    # getCentres computes the multi-template NCC and applies the threshold
    # itself. It returns an empty list when no voxel qualifies, so a
    # separate correlation pass just to pre-check the threshold is not
    # needed.
    centres = getCentres(OVimage, ParIm, ParOv, ParOv["thr"])
    Npsp = np.shape(centres)[0]
    if Npsp == 0:
        centres = np.array([])

    return centres, Npsp
    
        
def PositionROI(im, ParIm, ParROIconf, ParROI, ParOv, ii, oo):
    """Check one overview candidate with a "ROIconf" scan and centre "ROI" on it.

    Steps:
    1. Convert the candidate's overview voxel index into x/y/z offsets from
       the overview centre, and write x/y to the "ROIconf" configuration.
    2. If the z offset is within ``ParOv["Depth"]``:
       * move ``coarse_z``;
       * run "ROIconf" and take the first element of stack 0's data;
       * optionally save it.
    3. Detect centres twice. The first pass uses ``ParOv["thr"]`` and
       rejects empty or crowded ROIs. The second uses
       ``ParROIconf["thr"]`` and rejects poor template fits.
    4. Log the mean intensity around the centre and compute the refined
       ROI offsets. Then activate "ROI" and, if the z offset is acceptable,
       write the offsets to the instrument.

    Parameters
    ----------
    im : instrument connection object (Imspector/specpy style; assumed)
    ParIm : dict
        Global settings. Uses "saveims", "Folder", "Rpsp" and the kernel
        settings consumed by makeKernel/getCentres.
    ParROIconf, ParROI : dict
        Updated in place with "xoff"/"yoff"/"zoff". ParROIconf also
        receives the template entries written by makeKernel.
    ParOv : dict
        Overview parameters ("order", "px", "Len", "Depth", "zinit", "thr").
    ii : tuple
        ``(roi_number, voxel_index)`` as produced by ``enumerate`` over the
        overview centres.
    oo : int
        Overview tile number, used in file names.

    Returns
    -------
    (zerror, empty) : tuple of int
        zerror is 1 when a z offset exceeds ``ParOv["Depth"]``. empty is 1
        when the ROI was rejected. The caller images the ROI only when
        both are 0.
    """
    log_path = ParIm["Folder"] + "/Overview/Overview" + "%03d" % (oo) + "_output.txt"

    def log(text):
        # Echo to the console and, when saving, append to the tile's log.
        print(text)
        if ParIm["saveims"]:
            with open(log_path, "at") as out:
                print(text, file=out)

    msr = im.active_measurement()
    msr.activate(msr.configuration("ROIconf"))      # Activate the ROI-check configuration
    order = ParOv["order"]

    # Overview voxel index -> offset from the overview centre.
    ParROIconf["xoff"] = ii[1][order[0]]*ParOv["px"]-ParOv["Len"]/2
    ParROIconf["yoff"] = ii[1][order[1]]*ParOv["px"]-ParOv["Len"]/2
    ParROIconf["zoff"] = ii[1][order[2]]*ParOv["px"]-ParOv["Depth"]/2

    im.value_at('ExpControl/scan/range/x/off').set(ParROIconf["xoff"])
    im.value_at('ExpControl/scan/range/y/off').set(ParROIconf["yoff"])

    empty = 0

    if np.abs(ParROIconf["zoff"]) > ParOv["Depth"]:
        log("This ROI is too far from the overview plane")
        return 1, empty

    zerror = 0
    im.value_at('ExpControl/scan/range/coarse_z/g_off').set(ParOv["zinit"]+ParROIconf["zoff"])
    msr.activate(msr.configuration("ROIconf"))      # Re-activate the ROI-check configuration
    config = msr.active_configuration()

    im.run(msr)
    ROIconf = np.squeeze(config.stack(0).data()[0])

    if ParIm["saveims"]:
        skimage.io.imsave(ParIm["Folder"] + "/ROIconf/Overview"+"%03d"%(oo)+"ROI"+"%03d"%(ii[0])+".tiff", ROIconf, check_contrast=False)

    ParROIconf = makeKernel(ParIm, ParROIconf, ROIconf)

    # Pass 1 (overview threshold): require exactly one particle.
    CoM = np.squeeze(getCentres(ROIconf, ParIm, ParROIconf, ParOv["thr"]))     # Gives z, y, x
    if CoM.size < ParROIconf["Ndims"]:
        log("            Empty ROI, not imaged")
        return zerror, 1
    if CoM.size > ParROIconf["Ndims"]:
        log("            Multiple PSPs in ROI, not imaged")
        return zerror, 1

    # Pass 2 (ROIconf threshold): require exactly one good template fit.
    CoM = np.squeeze(getCentres(ROIconf, ParIm, ParROIconf, ParROIconf["thr"]))     # Gives z, y, x
    if CoM.size < ParROIconf["Ndims"]:
        log("            Poor fit with template, not imaged")
        return zerror, 1
    if CoM.size > ParROIconf["Ndims"]:
        log("            Multiple PSPs in ROI, not imaged")
        return zerror, 1

    # Mean intensity in a cube of half-width `nb` voxels around the centre.
    nb = int(np.min(ParIm["Rpsp"])/ParROIconf["px"])
    offsets = np.arange(-nb, nb + 1)
    inb = np.array(np.meshgrid(*[CoM[i] + offsets for i in range(len(CoM))])).T.reshape(-1, len(CoM))
    # Keep only voxels inside the image. Near an edge, negative indices would
    # wrap around and indices past the end would raise IndexError.
    inb = inb[np.all((inb >= 0) & (inb < np.shape(ROIconf)), axis=1)]
    peak = np.mean(ROIconf[inb[:, 0], inb[:, 1], inb[:, 2]])

    z_px, y_px, x_px = CoM

    log("            Mean PSP intensity: "+str(peak))

    # Refined offsets: ROIconf offset plus the centre's offset from the
    # middle of the ROIconf field of view.
    ParROI["xoff"] = ParROIconf["xoff"] + (x_px*ParROIconf["px"] - ParROIconf["Len"]/2)
    ParROI["yoff"] = ParROIconf["yoff"] + (y_px*ParROIconf["px"] - ParROIconf["Len"]/2)
    ParROI["zoff"] = ParROIconf["zoff"] + (z_px*ParROIconf["px"] - ParROIconf["Depth"]/2)

    msr.activate(msr.configuration("ROI"))      # Activate the ROI configuration

    if np.abs(ParROI["zoff"]) > ParOv["Depth"]:
        log("            This ROI is too far from the overview plane")
        return 1, empty

    im.value_at('ExpControl/scan/range/x/off').set(ParROI["xoff"])
    im.value_at('ExpControl/scan/range/y/off').set(ParROI["yoff"])
    im.value_at('ExpControl/scan/range/coarse_z/g_off').set(ParOv["zinit"]+ParROI["zoff"])

    return zerror, empty


def ImageROI(im, ParOv, ParROIconf, ParROI, ParIm, oo):
    """Find particles in overview tile ``oo`` and image each accepted one.

    Calls ``FindROI`` and logs the candidate coordinates. For each
    candidate it then calls ``PositionROI``. When that reports neither a
    z error nor a rejected ROI, the "ROI" configuration is run. If
    ``ParIm["saveims"]`` is set, the ExpControl settings are saved, along
    with the first elements of stack 0 (Red/) and stack 1 (Orange/).

    Returns None.
    """
    log_path = ParIm["Folder"] + "/Overview/Overview" + "%03d" % (oo) + "_output.txt"

    def log(text):
        # Echo to the console and, when saving, append to the tile's log.
        print(text)
        if ParIm["saveims"]:
            with open(log_path, "at") as out:
                print(text, file=out)

    msr = im.active_measurement()

    centres, Npsp = FindROI(im, ParOv, ParIm, oo, ParROI)

    printtext = str(Npsp) + ' ROIs found at coordinates [z y x]:'
    print(printtext)
    print(centres)

    if ParIm["saveims"]:
        # Start a fresh log for this tile: header line plus the centre table.
        with open(log_path, "wt") as out:
            print(printtext, file=out)
            np.savetxt(out, centres, delimiter=',', fmt='%d')

    if Npsp == 0:
        return

    for ii in enumerate(centres[:Npsp, :]):
        log('      ROI '+str(ii[0]))

        zerror, empty = PositionROI(im, ParIm, ParROIconf, ParROI, ParOv, ii, oo)
        # Timing starts after PositionROI, so it covers only the ROI scan.
        tROI = time.time()

        if zerror or empty:
            continue

        msr.activate(msr.configuration("ROI"))      # Activate the ROI configuration
        config = msr.active_configuration()     # Get the experimental configuration

        im.run(msr)

        elapROI = time.time() - tROI
        # Round the total first so the seconds field never shows 60.
        minutes, seconds = divmod(int(round(elapROI)), 60)
        log('            Imaged in ' + str(minutes) + ' minutes, ' + str(seconds) + ' seconds')

        if ParIm["saveims"]:
            with open(ParIm["Folder"] + "/ExpVals/Overview"+"%03d"%(oo)+"ROI"+"%03d"%(ii[0])+".txt", "wt") as out:
                pprint.pprint(im.value_at('ExpControl').get(), stream=out)
            ROIimage = np.squeeze(config.stack(0).data()[0])
            skimage.io.imsave(ParIm["Folder"] + "/Red/Overview"+"%03d"%(oo)+"ROI"+"%03d"%(ii[0])+".tiff", ROIimage, check_contrast=False)
            ROIimage = np.squeeze(config.stack(1).data()[0])
            skimage.io.imsave(ParIm["Folder"] + "/Orange/Overview"+"%03d"%(oo)+"ROI"+"%03d"%(ii[0])+".tiff", ROIimage, check_contrast=False)

        
    