﻿import numpy as np
import time
import matplotlib.pyplot as plt
from ALP4 import ALP4
from pylablib.devices import IMAQdx
from lens_generator import LensPatternGenerator
from data_analysis import gaussian_1d, plot_max_crop, merit_Q, fit_gaussian
from scipy.optimize import curve_fit

def _gaussian2d(coords, A, x0, y0, sx, sy, offset):
    """Axis-aligned 2D Gaussian plus constant offset, flattened for ``curve_fit``.

    ``coords`` is an ``(x, y)`` pair of arrays (or scalars) evaluated
    element-wise; the result is always returned as a 1D array.
    """
    x, y = coords
    exponent = (x - x0) ** 2 / (2 * sx ** 2) + (y - y0) ** 2 / (2 * sy ** 2)
    surface = offset + A * np.exp(-exponent)
    return np.ravel(surface)

def fit_gaussian2d_and_fwhm(img, half):
    """Crop around the brightest pixel, fit an axis-aligned 2D Gaussian, return peak & FWHM in px.

    Parameters
    ----------
    img : 2D array
        Camera frame. The crop is centred on ``np.argmax(img)`` (the first
        maximum in row-major order) and clipped at the image borders.
    half : int
        Half-width of the square crop; the window is at most ``2*half + 1``
        pixels on a side.

    Returns
    -------
    dict
        ``{"peak", "fwhm_x", "fwhm_y", "ok"}``. On success ``peak`` is the
        fitted amplitude plus offset and the FWHMs are
        ``2*sqrt(2 ln 2) * |sigma|`` in pixels. If the fit fails, ``ok`` is
        False, ``peak`` is the raw patch maximum and the FWHMs come from the
        initial sigma guesses (patch size / 6, at least 1 px).
    """
    # Gaussian sigma -> FWHM factor (~2.35482). Computed locally so the
    # function does not depend on a module-level constant that is only
    # assigned later in the script.
    fwhm_k = 2.0 * np.sqrt(2.0 * np.log(2.0))

    r0, c0 = np.unravel_index(np.argmax(img), img.shape)
    r1, r2 = max(r0 - half, 0), min(r0 + half + 1, img.shape[0])
    c1, c2 = max(c0 - half, 0), min(c0 + half + 1, img.shape[1])
    patch = img[r1:r2, c1:c2].astype(float)

    # Pixel meshes (x = column, y = row) and initial guesses.
    yy, xx = np.mgrid[0:patch.shape[0], 0:patch.shape[1]]
    iy, ix = np.unravel_index(np.argmax(patch), patch.shape)
    sx0 = max(patch.shape[1] / 6.0, 1.0)
    sy0 = max(patch.shape[0] / 6.0, 1.0)
    p0 = [patch.max() - patch.min(), float(ix), float(iy), sx0, sy0, patch.min()]

    try:
        popt = curve_fit(_gaussian2d, (xx, yy), patch.ravel(), p0=p0, maxfev=20000)[0]
    except (RuntimeError, ValueError, TypeError, np.linalg.LinAlgError):
        # Fit-specific failures are handled here:
        #   - RuntimeError: no convergence within maxfev.
        #   - ValueError: non-finite data.
        #   - TypeError: fewer pixels than the 6 parameters.
        #   - LinAlgError: numerical breakdown.
        # Any other exception is a genuine bug and propagates.
        return {
            "peak": float(patch.max()),
            "fwhm_x": float(fwhm_k * sx0),
            "fwhm_y": float(fwhm_k * sy0),
            "ok": False,
        }

    amp, _, _, sx, sy, off = popt
    return {
        "peak": float(amp + off),
        "fwhm_x": float(fwhm_k * abs(sx)),
        "fwhm_y": float(fwhm_k * abs(sy)),
        "ok": True,
    }


def gaussian_peak_fit(x, y):
    """Fit a 1D Gaussian plus offset to a coefficient sweep and return its centre.

    Samples whose ``y`` lies 3 standard deviations or more from the median
    are dropped before fitting. If the data are numerically constant, a unit
    scale is used instead, so nothing is dropped.

    Parameters
    ----------
    x, y : array_like
        Sweep coordinates and the measured values (same length).

    Returns
    -------
    tuple
        ``(mu, sigma, popt, ok)``:
        - ``mu`` and ``sigma`` are Python floats; ``sigma`` is at least 1e-3.
        - ``popt`` is the ``curve_fit`` parameter vector in the order of
          ``p0``, i.e. ``[A, mu, sigma, offset]``. It is ``None`` when no fit
          was obtained.
        - ``ok`` is True only when the fit succeeded.
        Without a fit, ``mu`` is the x of the largest sample and ``sigma`` is
        a span-based guess.

    Raises
    ------
    ValueError
        If ``x`` is empty.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.size == 0:
        raise ValueError("gaussian_peak_fit() needs at least one sample")

    med, std = np.median(y), np.std(y)
    scale = std if std > 1e-12 else 1.0
    keep = np.abs(y - med) < 3 * scale
    x_fit, y_fit = x[keep], y[keep]

    # Too few points for a 4-parameter fit: report the best raw sample.
    if x_fit.size < 4:
        i = int(np.argmax(y))
        span = float(x.max() - x.min())
        return float(x[i]), max(span / 4, 1e-3), None, False

    A0 = y_fit.max() - y_fit.min()
    mu0 = x_fit[np.argmax(y_fit)]
    sigma0 = max((x_fit.max() - x_fit.min()) / 4, 1e-3)
    off0 = y_fit.min()

    try:
        # assumes gaussian_1d(x, A, mu, sigma, offset) -- matches p0 order; confirm in data_analysis
        popt = curve_fit(gaussian_1d, x_fit, y_fit, p0=[A0, mu0, sigma0, off0], maxfev=10000)[0]
    except (RuntimeError, ValueError, np.linalg.LinAlgError):
        # curve_fit failure (no convergence / bad data): fall back to the initial guesses.
        return float(mu0), float(sigma0), None, False

    _, mu, sig, _ = popt
    return float(mu), float(abs(sig)), popt, True

def measure_amplitude_for_coeff(cur_coeffs, mode, c):
    """Display a trial correction on the DMD and return the measured spot amplitude.

    The trial coefficients are a copy of ``cur_coeffs`` with ``mode`` set to
    ``float(c)``; ``cur_coeffs`` itself is not modified. The return value is
    ``fit_gaussian`` applied to one camera snap, cropped with ``half_crop``.

    Relies on the script-level globals ``gen``, ``dmd``, ``camera``, ``xc``,
    ``yc``, ``settle_cycles``, ``fDMD`` and ``half_crop``.
    """
    trial = dict(cur_coeffs)
    trial[mode] = float(c)
    frame = gen.generate_correction_pattern(trial, xcenter=xc, ycenter=yc)

    # Swap the displayed one-frame sequence for the trial pattern.
    dmd.Halt()
    dmd.SeqPut(frame[None, ...])
    dmd.Run()

    # Give the DMD `settle_cycles` frame periods before snapping.
    time.sleep(settle_cycles / fDMD)

    return float(fit_gaussian(camera.snap(), half=half_crop))
# Pattern generator that builds every DMD frame in the scan.
gen = LensPatternGenerator()

# DMD frame rate: one frame period is 1 / fDMD seconds. It is used for the
# settle delays and converted to the ALP picture time in the scan loop.
fDMD = 250



# -----------------------------
# Scan grid and result arrays
# -----------------------------

ystep = 25
xstep = 25

# Pattern-centre grid; each stop is one step past the last wanted value so
# that value is included by np.arange.
yc_list = np.arange(75, 101, ystep)
xc_list = np.arange(-125, 201, xstep)

# One (len(yc_list), len(xc_list)) map per metric, named/saved as before
# ("wo") and after ("aft") optimisation. Every map is its own zero-filled
# array, never an alias of another.
_grid_shape = (yc_list.size, xc_list.size)
pk_array_wo_opt, pk_array_aft_opt = np.zeros(_grid_shape), np.zeros(_grid_shape)
fwhmx_array_wo_opt, fwhmx_array_aft_opt = np.zeros(_grid_shape), np.zeros(_grid_shape)
fwhmy_array_wo_opt, fwhmy_array_aft_opt = np.zeros(_grid_shape), np.zeros(_grid_shape)

from contextlib import ExitStack

# (n, m) mode indices optimised at every grid point, in sweep order. They are
# the keys of the coefficient dict passed to gen.generate_correction_pattern.
aberrations = [(2, 0), (2, 2), (2, -2),
               (3, 1), (3, -1), (3, 3), (3, -3),
               (4, 0), (4, 2), (4, -2),
               (5, 1), (5, -1), (5, 3), (5, -3)]

# Sweep settings. These are module-level because measure_amplitude_for_coeff
# reads settle_cycles / half_crop as globals.
passes = 1              # number of sweep passes; passes > 1 narrow the range each pass
initial_half_range = 2  # sweep c in [center - 2, center + 2] on pass 1
points = 21             # samples per sweep
shrink = 0.5            # later passes: half-range = shrink * 2 * fitted sigma (clamped)
min_half_range = 0.10   # don't shrink below this
settle_cycles = 2       # wait this many DMD cycles before every snap
tol = 0.02              # early-stop if max delta < tol after a pass
half_crop = 10          # single source of truth for the crop half-width (px)
picture_time = int(1 / fDMD * 1e6)
FWHM_K = 2.0 * np.sqrt(2.0 * np.log(2.0))  # ≈2.35482, sigma -> FWHM


def _show_and_snap(pattern, wait_cycles):
    """Display ``pattern`` as a one-frame DMD sequence, wait ``wait_cycles`` frames, snap the camera.

    Uses the module-level ``dmd`` and ``camera`` opened in the scan loop below.
    """
    dmd.Halt()
    dmd.SeqPut(pattern[np.newaxis, :])
    dmd.Run()
    time.sleep(wait_cycles / fDMD)
    return camera.snap()


for i, yc in enumerate(yc_list):
    for j, xc in enumerate(xc_list):
        # The ExitStack releases the hardware in reverse acquisition order:
        # DMD halt, sequence free, device free, then camera close. This runs
        # even when a step raises, so a failed grid point cannot leave the
        # DMD allocated.
        with ExitStack() as hw:
            camera = IMAQdx.IMAQdxCamera('cam2')
            hw.callback(camera.close)
            camera.setup_acquisition("snap")

            dmd = ALP4(version='4.3')
            dmd.Initialize()
            hw.callback(dmd.Free)
            dmd.SeqAlloc(nbImg=1, bitDepth=8)
            hw.callback(dmd.FreeSeq)
            hw.callback(dmd.Halt)
            dmd.SetTiming(pictureTime=picture_time)

            aberration_coeff = {}         # running best values carried forward
            sigma_estimate = {}           # fitted sigma per mode, sets the next sweep width
            results_by_pass = []          # per-pass bookkeeping, saved below
            last_half_range = {mode: initial_half_range for mode in aberrations}

            for p in range(passes):
                this_pass = {"pass": p+1, "modes": {}}
                max_delta = 0.0

                for mode in aberrations:
                    # Sweep centre and half-range.
                    center = aberration_coeff.get(mode, 0.0)
                    if p == 0:
                        half_range = initial_half_range
                    else:
                        prev_sigma = sigma_estimate.get(mode, initial_half_range/3)
                        prev_sigma = float(np.clip(prev_sigma, 1e-3, initial_half_range))
                        candidate = shrink * 2.0 * prev_sigma
                        half_range = max(min_half_range, min(initial_half_range, last_half_range[mode], candidate))
                        last_half_range[mode] = half_range

                    sweep = np.linspace(center - half_range, center + half_range, points)
                    amps = [measure_amplitude_for_coeff(aberration_coeff, mode, c) for c in sweep]
                    mu_fit, sig_fit, popt, ok = gaussian_peak_fit(sweep, amps)

                    # NOTE: mu_fit is deliberately not clamped to the sweep
                    # range, so a fit may move the coefficient outside it.
                    old = aberration_coeff.get(mode, 0.0)
                    aberration_coeff[mode] = mu_fit
                    sigma_estimate[mode] = sig_fit
                    max_delta = max(max_delta, abs(mu_fit - old))

                    # Diagnostic plot: the Gaussian fit when one was obtained,
                    # otherwise a cubic through the samples.
                    xs = np.linspace(sweep.min(), sweep.max(), 200)
                    if popt is not None:
                        ys = gaussian_1d(xs, *popt)
                    else:
                        ys = np.polyval(np.polyfit(sweep, amps, 3), xs)

                    plt.figure(figsize=(6,4))
                    plt.plot(sweep, amps, 'o', alpha=0.4, label='samples')
                    plt.plot(xs, ys, '-', label=f'fit μ={mu_fit:.3f}, σ≈{sig_fit:.3f}')
                    plt.axvline(mu_fit, linestyle='--', alpha=0.6)
                    plt.xlabel('Coefficient')
                    plt.ylabel('Amplitude (crop)')
                    n, m = mode
                    plt.title(f'Pass {p+1} — mode (n={n}, m={m})')
                    plt.legend(); plt.grid(True); plt.tight_layout(); plt.show()

                    # Snapshot after optimising this aberration, with the same
                    # settling time as the sweep measurements.
                    img_now = _show_and_snap(
                        gen.generate_correction_pattern(aberration_coeff, xcenter=xc, ycenter=yc),
                        settle_cycles,
                    )
                    met = fit_gaussian2d_and_fwhm(img_now, half=half_crop)

                    this_pass["modes"][mode] = {
                        "center_used": float(center),
                        "half_range": float(half_range),
                        "coeffs": sweep.tolist(),
                        "amp": amps,
                        "mu_fit": float(mu_fit),
                        "sigma_fit": float(sig_fit),
                        "fit_ok": bool(ok),
                        "peak_opt": met["peak"],
                        "fwhm_x_opt": met["fwhm_x"],
                        "fwhm_y_opt": met["fwhm_y"],
                    }

                results_by_pass.append(this_pass)
                # Early stop if nothing moved enough this pass.
                if max_delta < tol:
                    print(f"Early stop after pass {p+1}: max Δ = {max_delta:.4f} < tol {tol}")
                    break

            # Save summaries (pickled dicts inside object arrays).
            np.save(f'xc{xc}_yc{yc}_passes_summary.npy', results_by_pass)
            np.save(f'xc{xc}_yc{yc}_final_coeffs.npy', aberration_coeff)

            # Corrected frame: the optimised coefficients -> "after" maps.
            img_corr = _show_and_snap(
                gen.generate_correction_pattern(aberration_coeff, xcenter=xc, ycenter=yc),
                settle_cycles,
            )
            corrx, corry = plot_max_crop(img_corr, half_crop, return_slice=True)
            met_corr = fit_gaussian2d_and_fwhm(img_corr, half=half_crop)
            print(f"({xc}, {yc}) - After optimisation: peak={met_corr['peak']:.1f}, FWHM_x={met_corr['fwhm_x']:.2f}px, FWHM_y={met_corr['fwhm_y']:.2f}px")
            np.save(f"xc{xc}_yc{yc}_corrected_cam.npy", img_corr)
            pk_array_aft_opt[i, j] = met_corr['peak']
            fwhmx_array_aft_opt[i, j] = met_corr['fwhm_x']
            fwhmy_array_aft_opt[i, j] = met_corr['fwhm_y']

            # Reference frame: pattern generated without coefficients
            # (presumably the uncorrected lens) -> "before" maps.
            img_ref = _show_and_snap(
                gen.generate_correction_pattern(xcenter=xc, ycenter=yc),
                settle_cycles,
            )
            met_ref = fit_gaussian2d_and_fwhm(img_ref, half=half_crop)
            print(f"({xc}, {yc}) - Before optimisation: peak={met_ref['peak']:.1f}, FWHM_x={met_ref['fwhm_x']:.2f}px, FWHM_y={met_ref['fwhm_y']:.2f}px")
            refx, refy = plot_max_crop(img_ref, half_crop, return_slice=True)
            np.save(f"xc{xc}_yc{yc}_reference_cam.npy", img_ref)
            pk_array_wo_opt[i, j] = met_ref['peak']
            fwhmx_array_wo_opt[i, j] = met_ref['fwhm_x']
            fwhmy_array_wo_opt[i, j] = met_ref['fwhm_y']

            plt.figure(); plt.plot(corrx, label='corrected x'); plt.plot(refx, label='reference x'); plt.legend(); plt.show()
            plt.figure(); plt.plot(corry, label='corrected y'); plt.plot(refy, label='reference y'); plt.legend(); plt.show()
        

# Persist the 2D maps. The file prefix encodes the Y and X scan ranges as
# min-step-max.
_scan_prefix = f'Y{min(yc_list)}-{ystep}-{max(yc_list)}_X{min(xc_list)}-{xstep}-{max(xc_list)}'
for _metric, _map in (
    ('2Dpks_after_opt', pk_array_aft_opt),
    ('2Dpks_before_opt', pk_array_wo_opt),
    ('2Dfwhmx_after_opt', fwhmx_array_aft_opt),
    ('2Dfwhmx_before_opt', fwhmx_array_wo_opt),
    ('2Dfwhmy_after_opt', fwhmy_array_aft_opt),
    ('2Dfwhmy_before_opt', fwhmy_array_wo_opt),
):
    np.save(f'{_scan_prefix}_{_metric}.npy', _map)