import os
import numpy as np
from scipy.special import factorial
from joblib import Parallel, delayed, parallel_backend

class LensPatternGenerator:
    """
    Generates binary Fresnel lens patterns with optional single or multiple Zernike aberrations.
    Supports pure sweeps, spatial scans, combined sweep-and-scan, and custom pattern centers.

    All patterns use amplitude-thresholding ((cos(phase)+1)/2 > 0.5) and zero phase outside aperture.
    Filenames encode Zernike mode, coefficient, and Xc/Yc offsets where applicable.
    Returned/saved patterns are uint8 arrays of shape (height, width) containing 0 or 255.

    Coordinate conventions (as implemented):
      * ``__init__`` places the lens centre at array column ``width/2 + xcenter``
        and row ``height/2 + ycenter``.
      * Every ``np.roll``-based shift (spatial scans, ``generate_correction_pattern``)
        moves the pattern by ``-offset`` along each axis, wrapping at the edges.
    NOTE(review): these two conventions have opposite signs. Confirm which one
    matches the physical DMD orientation before relying on absolute positions.
    """
    def __init__(self, width=1920, height=1080,
                 pixel_size=10.8e-6, wavelength=633e-9,
                 focal_length=0.15, cutoff=True,
                 xcenter=0, ycenter=0):
        """
        Precompute the coordinate grid, aperture mask, base lens phase and the
        normalised polar grid used for Zernike terms.

        width, height: pattern size in pixels (arrays have shape (height, width))
        pixel_size, wavelength, focal_length: lengths in one common unit
            (the defaults are in metres)
        cutoff: if True, limit the aperture to radius wavelength*focal_length/(2*pixel_size);
            otherwise the whole grid is the aperture
        xcenter, ycenter: center of the lens pattern in pixels (DMD coords)
        """
        self.width = width
        self.height = height
        self.pixel_size = pixel_size
        self.wavelength = wavelength
        self.focal_length = focal_length
        self.xcenter = xcenter
        self.ycenter = ycenter
        # Physical coordinate grid around the requested centre.
        x = ((np.arange(width) - width/2) - xcenter) * pixel_size
        y = ((np.arange(height) - height/2) - ycenter) * pixel_size
        self.X, self.Y = np.meshgrid(x, y)
        rho2 = self.X**2 + self.Y**2
        if cutoff:
            # Radius at which the local fringe period lambda*f/r shrinks to two
            # pixels; beyond it the binary pattern would be undersampled.
            r_cut = wavelength * focal_length / (2 * pixel_size)
            self.aperture = rho2 <= r_cut**2
        else:
            self.aperture = np.ones_like(rho2, bool)
        # Quadratic lens phase k*r^2/(2f), zeroed outside the aperture.
        k = 2 * np.pi / wavelength
        phase = (k / (2 * focal_length)) * rho2
        phase[~self.aperture] = 0
        self.phase_lens = phase
        # Zernike polar grid: rho is normalised to 1 at the outermost aperture
        # pixel (0 outside); theta is NaN outside the aperture.
        mask = self.aperture
        rho = np.zeros_like(self.X)
        if mask.any():
            rho_max = np.sqrt(np.max(rho2[mask]))
            # An empty aperture (handled by mask.any()) or a lone on-axis pixel
            # (rho_max == 0) would otherwise raise or divide 0/0.
            if rho_max > 0:
                rho[mask] = np.sqrt(rho2[mask]) / rho_max
        theta = np.arctan2(self.Y, self.X)
        theta[~mask] = np.nan
        self.rho = rho
        self.theta = theta

    def _radial_poly(self, n, m, rho):
        """
        Zernike radial polynomial R_n^|m|(rho), evaluated with the standard finite sum.

        Raises ValueError unless n >= |m| and n - |m| is even (otherwise the
        sum does not define a Zernike polynomial).
        """
        m_abs = abs(m)
        if n < m_abs or (n - m_abs) % 2:
            raise ValueError(f"invalid Zernike indices (n={n}, m={m})")
        rho = np.asarray(rho, dtype=float)
        R = np.zeros_like(rho)
        for k in range((n - m_abs)//2 + 1):
            num = (-1)**k * factorial(n - k)
            den = factorial(k) * factorial((n + m_abs)//2 - k) * factorial((n - m_abs)//2 - k)
            R += (num/den) * rho**(n - 2*k)
        return R

    def _zernike(self, n, m):
        """
        Unnormalised Zernike polynomial Z_n^m on the precomputed (rho, theta) grid.

        m > 0 uses cos(m*theta), m < 0 uses sin(|m|*theta), m == 0 is purely radial.
        For m != 0 the result is NaN outside the aperture because theta is NaN there.
        """
        R = self._radial_poly(n, m, self.rho)
        if m > 0:
            return R * np.cos(m * self.theta)
        elif m < 0:
            return R * np.sin(-m * self.theta)
        else:
            return R

    def _add_aberrations(self, phase, terms):
        """Add coeff*pi*Z_n^m for each ((n, m), coeff) in *terms* to *phase* in place, inside the aperture only."""
        for mode, c in terms:
            Z = self._zernike(*mode)
            # np.where rather than multiplying by the boolean mask: Z is NaN
            # outside the aperture for m != 0, and NaN * 0 is still NaN.
            phase += np.where(self.aperture, float(c) * np.pi * Z, 0.0)
        return phase

    def _binarize(self, phase):
        """Threshold (cos(phase)+1)/2 at 0.5 and clear everything outside the aperture (bool array)."""
        amp = (np.cos(phase) + 1) / 2
        return (amp > 0.5) & self.aperture

    @staticmethod
    def _shift(pattern, dx, dy):
        """Circularly shift *pattern* by -dy rows and -dx columns (content wraps at the edges)."""
        return np.roll(np.roll(pattern, -dy, axis=0), -dx, axis=1)

    @staticmethod
    def _scan_offsets(span, step, center):
        """
        Scan positions from -span/2 up to span/2 (the arange stop is span/2 + 1 so
        the upper end is included for step >= 1) in increments of *step*, shifted by
        *center*. A zero span or zero step yields the single position [center].
        """
        if step == 0 or span == 0:
            return [center]
        return list(np.arange(-span/2, span/2 + 1, step) + center)

    def generate_pattern(self, mode=None, coeff=None):
        """
        Generate a single binary lens pattern with optional one Zernike mode.
        mode: (n,m), coeff: scalar aberration coefficient (phase amplitude coeff*pi)
        The mode is applied only when mode is truthy and coeff is not None.
        Returns uint8 2D array with values 0 or 255.
        """
        phase = self.phase_lens.copy()
        if mode and coeff is not None:
            self._add_aberrations(phase, [(mode, coeff)])
        binary = self._binarize(phase)
        return (binary.astype(np.uint8) * 255)

    def generate_correction_pattern(self, zernike_coeffs=None, xcenter=None, ycenter=None):
        """
        Generate a pattern combining multiple Zernike terms and roll it so that the pattern
        center is at (xcenter, ycenter) in DMD pixel coordinates (overrides instance center if given).
        zernike_coeffs: dict {(n,m): coeff}
        xcenter, ycenter: desired pattern center offsets (pixels)
        Returns uint8 2D array.
        """
        # Use override center if provided
        xc = self.xcenter if xcenter is None else xcenter
        yc = self.ycenter if ycenter is None else ycenter
        phase = self.phase_lens.copy()
        if zernike_coeffs:
            self._add_aberrations(phase, zernike_coeffs.items())
        binary = self._binarize(phase)
        # Shift relative to the instance centre the base pattern was built around.
        # NOTE(review): __init__ puts the centre at width/2 + self.xcenter and the
        # roll moves it by -(xc - self.xcenter), so it lands at
        # width/2 + 2*self.xcenter - xc rather than width/2 + xc. Confirm the
        # intended sign against the DMD orientation.
        dx = int(xc - self.xcenter)
        dy = int(yc - self.ycenter)
        shifted = self._shift(binary, dx, dy)
        return (shifted*255).astype(np.uint8)

    def sweep_coefficients(self, mode, coeffs, out_dir=None, prefix=''):
        """
        Generate base patterns for each aberration coefficient and save them as .npy files.

        Coefficients are formatted with two decimals in the filename, so values
        that round to the same string overwrite each other's file.
        Returns the list of saved file paths.
        """
        if out_dir is None:
            out_dir = f"Patterns_Z{mode[0]}{mode[1]}"
        os.makedirs(out_dir, exist_ok=True)
        paths = []
        for c in coeffs:
            pat = self.generate_pattern(mode, c)
            c_str = f"{c:.2f}"
            fname = f"{prefix}Z_{mode[0]}_{mode[1]}_{c_str}.npy"
            p = os.path.join(out_dir, fname)
            np.save(p, pat)
            paths.append(p)
        return paths

    def scan_spatial(self, xwidth, ywidth, xstep, ystep,
                     out_dir=None, prefix=''):
        """
        Generate spatial scan of the base pattern over offsets and save each as .npy.

        Positions are self.xcenter/self.ycenter plus offsets spanning
        +-xwidth/2, +-ywidth/2 in steps of xstep/ystep; a zero width or step
        scans a single position on that axis. Returns the list of saved paths.
        """
        if out_dir is None:
            fm = int(self.focal_length*1e3)
            out_dir = f"Scan_X{xwidth}_Y{ywidth}_dx{xstep}_dy{ystep}_f{fm}mm"
        os.makedirs(out_dir, exist_ok=True)
        # NOTE(review): base is already centred at self.xcenter/self.ycenter and
        # the positions also include them, so with the -offset roll the instance
        # centre cancels out of the final position while still appearing in the
        # filename. Confirm this is intended when the instance centre is nonzero.
        xs = self._scan_offsets(xwidth, xstep, self.xcenter)
        ys = self._scan_offsets(ywidth, ystep, self.ycenter)
        paths = []
        base = self.generate_pattern()
        for px in xs:
            for py in ys:
                dx, dy = int(px), int(py)
                pat = self._shift(base, dx, dy)
                fname = f"{prefix}Xc{int(px):d}_Yc{int(py):d}.npy"
                p = os.path.join(out_dir, fname)
                np.save(p, pat)
                paths.append(p)
        return paths

    def sweep_and_scan(self, mode, coeffs, xwidth, ywidth, xstep, ystep,
                       xcenter=0, ycenter=0,
                       out_dir=None, prefix=''):
        """Combine aberration sweep and spatial scan into a single batch.
           Allows overriding the pattern center per call via xcenter/ycenter;
           pass None to use the instance center (the default of 0 does NOT
           fall back to the instance center).
           A zero width or step scans a single position on that axis.
           Returns the list of saved .npy paths.
        """
        # Determine center offsets
        xc = self.xcenter if xcenter is None else xcenter
        yc = self.ycenter if ycenter is None else ycenter
        if out_dir is None:
            fm = int(self.focal_length*1e3)
            # Use the resolved centre so a None override does not yield "xcNone".
            out_dir = f"SweepScan_Z{mode[0]}{mode[1]}_X{xwidth}_Y{ywidth}_dx{xstep}_dy{ystep}_xc{xc}_yc{yc}_f{fm}mm"
        os.makedirs(out_dir, exist_ok=True)
        xs = self._scan_offsets(xwidth, xstep, xc)
        ys = self._scan_offsets(ywidth, ystep, yc)
        paths = []
        for c in coeffs:
            base = self.generate_pattern(mode, c)
            c_str = f"{c:.2f}"
            for px in xs:
                for py in ys:
                    dx, dy = int(px), int(py)
                    pat = self._shift(base, dx, dy)
                    fname = f"{prefix}Xc{int(px):d}_Yc{int(py):d}_Z_{mode[0]}_{mode[1]}_{c_str}.npy"
                    p = os.path.join(out_dir, fname)
                    np.save(p, pat)
                    paths.append(p)
        return paths

