import numpy as np
from scipy.signal import find_peaks, medfilt, butter, filtfilt
from scipy.stats import mode
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt

# Global Matplotlib style applied to every figure this module draws.
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 14

def lowpass_filter(data, cutoff_freq, fs, order=1):
    """Apply a zero-phase low-pass Butterworth filter along the last axis.

    Parameters
    ----------
    data : array_like
        1-D trace, or an N-D array whose last axis is the sample axis
        (e.g. ``(channels, samples)``); every trace is filtered independently.
    cutoff_freq : float
        Cutoff frequency, in the same units as *fs*.
    fs : float
        Sampling rate.
    order : int, optional
        Butterworth order (default 1). ``filtfilt`` runs the filter forward
        and backward, so the combined filter has twice this order.

    Returns
    -------
    numpy.ndarray
        Filtered data with the same shape as *data*.
    """
    data = np.asarray(data)  # accept lists as well as arrays
    nyquist = 0.5 * fs
    b, a = butter(order, cutoff_freq / nyquist, btype='low', analog=False)
    # axis=-1 filters each row of a multi-channel block in one call and keeps
    # the input shape for any dimensionality.
    return filtfilt(b, a, data, axis=-1)

def _dominant_spacing(indices):
    """Return the most common gap between consecutive sample indices.

    Ties go to the smallest gap, matching ``scipy.stats.mode``.

    Raises
    ------
    ValueError
        If fewer than two indices are given, so no spacing can be measured.
    """
    gaps = np.diff(indices)
    if gaps.size == 0:
        raise ValueError("Need at least two edges of each polarity to estimate the spacing.")
    values, counts = np.unique(gaps, return_counts=True)
    # np.unique sorts and argmax returns the first maximum -> smallest tied gap.
    return int(values[np.argmax(counts)])


def find_edges(data, distance,
               hscale=0.2,
               mode_kernel=5,
               debug=True):
    """Find start/stop sample indices of the periodic segments in *data*.

    The trace is median filtered, and rising/falling edges are located as
    peaks of its first difference. Both edge lists are then rebuilt on a
    regular grid using their most common spacing.

    Parameters
    ----------
    data : 1-D array_like
        Trace to segment.
    distance : int
        Minimum separation between edges of the same polarity
        (``find_peaks`` ``distance``).
    hscale : float
        Edge threshold as a fraction of the largest step in each direction.
    mode_kernel : int
        Median-filter kernel size; an even value is bumped to the next odd one.
    debug : bool
        Print a summary and plot the first segment.

    Returns
    -------
    numpy.ndarray of int
        Interleaved ``[start0, stop0, start1, stop1, ...]``. Every index lies
        in ``[0, len(data) - 1]``, so the result can be fed to
        ``signal_averaging``.

    Raises
    ------
    ValueError
        For ``mode_kernel < 1``, fewer than two samples, fewer than two edges
        of either polarity, or start/stop counts that cannot be paired.
    """
    if mode_kernel < 1:
        raise ValueError("mode_kernel must be ≥ 1.")
    if mode_kernel % 2 == 0:      # medfilt requires an odd kernel
        mode_kernel += 1
    # np.diff needs two samples; with fewer, diff.max() below has nothing to reduce.
    if len(data) < 2:
        raise ValueError('Length of data must be greater than 1.')

    data = medfilt(data, mode_kernel)
    n = len(data)

    diff = np.diff(data)
    pksL, _ = find_peaks(diff, height=diff.max() * hscale, distance=distance)
    pksR, _ = find_peaks(-diff, height=(-diff).max() * hscale, distance=distance)

    if len(pksL) < 1:
        raise ValueError("No peaks found")
    if len(pksR) < 1:
        raise ValueError("No peaks found")

    # shift a couple of samples so we're surely inside the plateau
    pksL = pksL - 2
    pksR = pksR + 2

    # whichever polarity occurs first opens the first segment
    if pksL[0] < pksR[0]:
        start, stop = pksL, pksR
    else:
        start, stop = pksR, pksL

    # Rebuild both lists on a regular grid: this fills in edges find_peaks
    # missed and discards spurious ones.
    start_period = _dominant_spacing(start)
    stop_period = _dominant_spacing(stop)
    start = np.arange(start[0], n, start_period)
    stop = np.arange(stop[0], n, stop_period)

    if debug:
        print(f"{len(pksL)} left and {len(pksR)} right edges found starting with {start[0]} and {stop[0]}, number of datapoints per position is {stop_period}")

    # A stop within a quarter period of the first start means the "start"
    # edges actually close segments: swap the roles.
    if (stop[0] - start[0]) < start_period / 4:
        start, stop = stop, start[1:]

    # balance array lengths
    if len(start) == len(stop) + 1:
        # final segment is still open at the end of the record: close it there
        stop = np.append(stop, n - 1)
    elif len(stop) == len(start) + 1:
        # an unmatched trailing stop closes nothing; drop it rather than
        # inventing a start after it (which gave a start > stop segment)
        stop = stop[:len(start)]
    elif len(start) != len(stop):
        plt.figure()
        plt.plot(data[:stop[2]] if len(stop) > 2 else data)
        plt.plot(stop[:2], data[stop[:2]], 'b*')
        plt.plot(start[:2], data[start[:2]], 'r*')
        raise ValueError("Edge count mismatch – tweak 'distance' or 'hscale'.")

    # pksL - 2 can fall before sample 0; keep every index inside the record.
    start = np.clip(start, 0, n - 1)
    stop = np.clip(stop, 0, n - 1)

    if debug:
        plt.figure()
        plt.plot(data[:start[1]] if len(start) > 1 else data)
        plt.plot(start[0], data[start[0]], 'r*')
        plt.plot(stop[0], data[stop[0]], 'b*')

    edges = np.empty(start.size + stop.size, dtype=int)
    edges[0::2] = start
    edges[1::2] = stop
    return edges



def signal_averaging(data, edges):
    """Return the mean of *data* over each ``[start, stop)`` segment.

    Parameters
    ----------
    data : 1-D array_like
        Signal to average.
    edges : 1-D array_like of int
        Interleaved ``[start0, stop0, start1, stop1, ...]`` indices into *data*.

    Returns
    -------
    numpy.ndarray of float
        One mean per segment. Segments with ``stop <= start`` yield NaN
        instead of a meaningless ratio.

    Raises
    ------
    ValueError
        If *edges* does not hold an even number of indices.
    """
    edges = np.asarray(edges, dtype=np.intp)
    if edges.size % 2:
        raise ValueError("edges must hold an even number of indices (start/stop pairs).")
    if edges.size == 0:
        return np.empty(0)

    lengths = edges[1::2] - edges[0::2]
    # reduceat sums data[edges[i]:edges[i+1]]; the even-numbered results are
    # exactly the [start, stop) segments (vectorised, no Python loop).
    sums = np.add.reduceat(data, edges)[::2]

    # For start >= stop reduceat returns data[start], which is not a segment
    # sum, so mark those entries as NaN.
    valid = lengths > 0
    means = np.full(lengths.shape, np.nan)
    means[valid] = sums[valid] / lengths[valid]
    return means

def _plot_map(image, xwidth, ywidth, grid=False, save_path=None):
    """Draw *image* as a grey-scale map over (0, xwidth) x (0, ywidth); optionally grid and save it."""
    plt.figure()
    plt.imshow(image, cmap='gray', extent=(0, xwidth, 0, ywidth))
    plt.xlabel("x pixels")
    plt.ylabel("y pixels")
    if grid:
        plt.grid()
    plt.colorbar()
    if save_path is not None:
        plt.savefig(save_path)


def _to_grid(values, Ny, Nx):
    """Zero-pad or truncate *values* to ``Nx*Ny`` entries and reshape to ``(Ny, Nx)``."""
    n_pixels = Nx * Ny
    if len(values) < n_pixels:
        values = np.concatenate([values, np.zeros(n_pixels - len(values))])
    return values[:n_pixels].reshape(Ny, Nx)


def create2D(filename, dscale=0.9, hscale=0.2, mode_kernel=5, lowpass=0, plot=True, save_plot=False):
    """Convert a recorded trace into 2-D maps with one averaged value per grid position.

    The trace is split into segments with ``find_edges``, and each segment is
    averaged with ``signal_averaging``.

    Parameters
    ----------
    filename : str or path-like
        ``.npz`` archive holding ``data``, ``sampling_rate``, ``fDMD``,
        ``xwidth``, ``ywidth``, ``xstep`` and ``ystep``.
    dscale : float
        The ``find_edges`` minimum edge distance is ``round(fs / fDMD * dscale)``.
    hscale, mode_kernel :
        Forwarded to ``find_edges``.
    lowpass : float
        If > 0, edges are detected on a copy low-pass filtered at
        ``lowpass * fDMD``. The averages still use the unfiltered signal.
    plot : bool
        Show each map.
    save_plot : bool
        Also save each map as ``<filename without extension>_2D_<i>.png``.

    Returns
    -------
    numpy.ndarray
        Shape ``(3, Ny, Nx)``. For multi-channel data the rows are the SD, UD
        and SUM combinations; for 1-D data only row 0 is filled.

    Raises
    ------
    ValueError
        If multi-channel data has fewer than 4 channels, or a 1-D recording
        yields too few, irregularly spaced segments to fill the grid.
    """
    import os

    # NOTE(review): allow_pickle=True lets a crafted file run code on load;
    # only open trusted recordings.
    archive = np.load(filename, allow_pickle=True)
    try:
        data = archive['data']
        fs = archive['sampling_rate']
        fDMD = archive['fDMD']
        xwidth = archive['xwidth']
        ywidth = archive['ywidth']
        dx = archive['xstep']
        dy = archive['ystep']
    finally:
        # NpzFile keeps the underlying file open until it is closed.
        if hasattr(archive, 'close'):
            archive.close()

    # round() avoids ratios like 0.3 / 0.1 = 2.999... truncating away a pixel
    Nx = int(round(float(xwidth) / float(dx))) + 1
    Ny = int(round(float(ywidth) / float(dy))) + 1
    output = np.zeros((3, Ny, Nx))
    distance = int(round(fs / fDMD * dscale))
    stem = os.path.splitext(os.fspath(filename))[0]

    if data.ndim > 1:
        if data.shape[0] < 4:
            raise ValueError("Multi-channel data must have at least 4 channels.")
        raw = data[:4].astype(float)  # float so integer recordings are not truncated
        quad = np.empty((4, raw.shape[1]))
        quad[2] = raw[0] + raw[1] + raw[2] + raw[3]                  # SUM (before DC removal)
        centred = raw - raw.mean(axis=1, keepdims=True)
        quad[0] = centred[0] - centred[1] - centred[2] + centred[3]  # SD
        quad[1] = centred[0] + centred[1] - centred[2] - centred[3]  # UD
        quad[3] = np.sqrt(quad[0] ** 2 + quad[1] ** 2)               # magnitude of (SD, UD)

        # Detect edges on the combination with the largest peak-to-peak swing.
        idx = int(np.argmax(np.ptp(quad, axis=1)))
        print(f"{('SD', 'UD', 'SUM', 'QUAD')[idx]} used for data processing")

        edge_source = quad[idx]
        if lowpass > 0:
            edge_source = lowpass_filter(edge_source, fDMD * lowpass, fs)
        edges = find_edges(edge_source, distance, hscale, mode_kernel)

        # output holds SD, UD and SUM; the magnitude row only guides edge detection.
        for i in range(output.shape[0]):
            output[i] = _to_grid(signal_averaging(quad[i], edges), Ny, Nx)
            if plot:
                _plot_map(output[i], xwidth, ywidth, grid=True,
                          save_path=f"{stem}_2D_{i}.png" if save_plot else None)
    else:
        edge_source = lowpass_filter(data, fDMD * lowpass, fs) if lowpass > 0 else data
        edges = find_edges(edge_source, distance, hscale, mode_kernel)
        signal1D = signal_averaging(data, edges)

        # A short result is zero-padded only when the start edges are
        # regularly spaced; otherwise the mapping to the grid is unreliable.
        if len(signal1D) < Nx * Ny and not np.std(np.diff(edges[::2])) < 1:
            raise ValueError(
                f"Found {len(signal1D)} segments for a {Ny}x{Nx} grid and the "
                "start edges are irregular; cannot map them to 2D."
            )
        output[0] = _to_grid(signal1D, Ny, Nx)
        if plot:
            _plot_map(output[0], xwidth, ywidth,
                      save_path=f"{stem}_2D_0.png" if save_plot else None)
            plt.show()

    return output

def fit_gaussian(data_full, half=10, threshold_factor=0.1):
    """
    Fit a rotated 2-D Gaussian (``gaussian_2d``) around the brightest pixel
    and return its amplitude.

    A window of up to ``(2*half+1) x (2*half+1)`` pixels centred on the maximum
    is cropped, clamped at the image border. The background offset is
    initialised from the median of the window's two outermost rows and
    columns. Only pixels above ``offset + threshold_factor * amplitude`` enter
    ``scipy.optimize.curve_fit``.

    Parameters
    ----------
    data_full : 2-D array_like
        Image to fit.
    half : int
        Half-width of the crop window; also the upper bound on both sigmas.
    threshold_factor : float
        Fraction of the initial amplitude above the offset a pixel must exceed
        to be fitted.

    Returns
    -------
    float or None
        Fitted amplitude, or None if the fit does not converge, has fewer
        usable pixels than free parameters, or its bounds are inconsistent.
    """
    # locate peak and crop a window symmetric about it
    iy0, ix0 = np.unravel_index(np.argmax(data_full), data_full.shape)
    ymin, ymax = max(0, iy0 - half), min(data_full.shape[0], iy0 + half + 1)
    xmin, xmax = max(0, ix0 - half), min(data_full.shape[1], ix0 + half + 1)
    data = data_full[ymin:ymax, xmin:xmax].astype(float)

    # absolute pixel coordinates of the window
    X, Y = np.meshgrid(np.arange(xmin, xmax), np.arange(ymin, ymax))

    # estimate background offset from border median
    border = np.concatenate([data[:2, :].ravel(), data[-2:, :].ravel(),
                             data[:, :2].ravel(), data[:, -2:].ravel()])
    offset0 = float(np.median(border))

    A0 = float(data.max() - offset0)
    # a quarter of the window, clipped into the sigma bounds so p0 stays
    # feasible for windows shrunk by the image border
    sigma0 = float(np.clip(np.mean(data.shape) / 4, 1.0, half))
    p0 = [A0, float(ix0), float(iy0), sigma0, sigma0, 0.0, offset0]

    # focus fitting on high-signal region
    mask = data > offset0 + threshold_factor * A0
    if np.count_nonzero(mask) < len(p0):
        return None  # curve_fit needs at least as many points as parameters

    # The offset may fall up to half its magnitude below the estimate; for a
    # negative offset the old offset0*0.5 lay above offset0 (infeasible p0).
    lb = [0, ix0 - half, iy0 - half, 1.0, 1.0, -np.pi, offset0 - 0.5 * abs(offset0)]
    ub = [np.inf, ix0 + half, iy0 + half, half, half, np.pi, data.max()]

    try:
        popt, _ = curve_fit(
            gaussian_2d,
            (X[mask], Y[mask]),
            data[mask],
            p0=p0,
            bounds=(lb, ub),
            maxfev=10000,
        )
    except (RuntimeError, ValueError):
        # RuntimeError: no convergence within maxfev.
        # ValueError: degenerate bounds or non-finite data.
        return None
    return float(popt[0])

def gaussian_2d(coords, A, x0, y0, sigma_x, sigma_y, theta, offset):
    """Evaluate a rotated elliptical Gaussian plus constant offset, flattened.

    ``coords`` is an ``(x, y)`` pair of equally shaped arrays. The axes are
    rotated by ``theta`` about ``(x0, y0)`` before applying widths
    ``sigma_x``/``sigma_y``. Returns a 1-D array, as ``curve_fit`` expects.
    """
    xs, ys = coords
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    dx = x - x0 if False else xs - x0
    dy = ys - y0
    u = dx * cos_t + dy * sin_t       # coordinate along the rotated x axis
    v = (-dx) * sin_t + dy * cos_t    # coordinate along the rotated y axis
    model = offset + A * np.exp(-(u ** 2) / (2 * sigma_x ** 2) - (v ** 2) / (2 * sigma_y ** 2))
    return np.ravel(model)

def gaussian_1d(x, A, mu, sigma, offset):
    """Evaluate ``offset + A * exp(-(x - mu)**2 / (2 * sigma**2))`` element-wise."""
    exponent = -(x - mu) ** 2 / (2 * sigma ** 2)
    return offset + A * np.exp(exponent)

def plot_max_crop(img, half, return_slice=False):
    """Show the window around the brightest pixel of *img* in a new figure.

    The window spans ``half`` pixels before the peak and ``half - 1`` after it
    on each axis, clamped at the image border.

    Returns
    -------
    tuple or None
        When *return_slice* is true, ``(row_profile, column_profile)``: the
        peak's row and column restricted to the window. Otherwise None.
    """
    peak_row, peak_col = np.unravel_index(np.argmax(img), img.shape)
    rows = slice(max(0, peak_row - half), min(img.shape[0], peak_row + half))
    cols = slice(max(0, peak_col - half), min(img.shape[1], peak_col + half))

    plt.figure()
    plt.imshow(img[rows, cols], cmap='gray')
    plt.colorbar()
    plt.show()

    if not return_slice:
        return None
    return img[peak_row, cols], img[rows, peak_col]
    
def merit_Q(data_full, half=10):
    """Return the signal-weighted mean of ``half**2 / 2 - r**2`` around the brightest pixel.

    ``r`` is the pixel distance from the peak. The window extends up to
    ``half`` pixels on each side of the peak, clamped at the image border.
    Before weighting, the median of the window's two outermost rows and
    columns is subtracted as background. For non-negative signal, a larger
    value means the signal is concentrated closer to the peak.

    Parameters
    ----------
    data_full : 2-D array_like
        Image to evaluate.
    half : int
        Half-width of the evaluation window.

    Returns
    -------
    float
        The weighted mean; inf/NaN if the background-subtracted window sums
        to zero.
    """
    iy0, ix0 = np.unravel_index(np.argmax(data_full), data_full.shape)
    ymin, ymax = max(0, iy0 - half), min(data_full.shape[0], iy0 + half + 1)
    xmin, xmax = max(0, ix0 - half), min(data_full.shape[1], ix0 + half + 1)
    data = data_full[ymin:ymax, xmin:xmax].astype(float)

    # Offsets from the peak, built from the actual (possibly border-clamped)
    # window so the grid always matches data's shape.
    X, Y = np.meshgrid(np.arange(xmin, xmax) - ix0, np.arange(ymin, ymax) - iy0)

    # estimate background offset from border median
    border = np.concatenate([data[:2, :].ravel(), data[-2:, :].ravel(),
                             data[:, :2].ravel(), data[:, -2:].ravel()])
    data = data - float(np.median(border))

    weight = (half ** 2) / 2 - X ** 2 - Y ** 2
    return (data * weight).sum() / data.sum()