import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def grid_data(df, value):
    """Pivot long-format scan results into a 2-D grid indexed by y (rows) and x (columns).

    Duplicate (y, x) pairs are averaged. Both the row labels and the column
    labels are returned in ascending order, and (y, x) combinations absent
    from ``df`` appear as NaN.
    """
    table = df.pivot_table(values=value, index="y", columns="x", aggfunc="mean")
    rows_sorted = table.sort_index(axis=0)
    return rows_sorted.sort_index(axis=1)

def crop_figure(img, half):
    """Return the window of ``img`` centred on its brightest pixel.

    The window reaches ``half`` pixels on each side of the (first) maximum,
    i.e. up to ``2*half + 1`` pixels square, and is clipped at the image
    borders, so crops near an edge are smaller. For a NumPy array the
    result is a slice (view) of ``img``.
    """
    peak_row, peak_col = np.unravel_index(np.argmax(img), img.shape)
    n_rows = img.shape[0]
    n_cols = img.shape[1]
    rows = slice(max(0, peak_row - half), min(n_rows, peak_row + half + 1))
    cols = slice(max(0, peak_col - half), min(n_cols, peak_col + half + 1))
    return img[rows, cols]

def fill_missing_by_nn(a, missing_is_zero=False, max_iters=200):
    """
    Replace missing cells with the mean of their valid 8-connected neighbours.

    Missing means NaN, plus exact zeros when ``missing_is_zero`` is True.
    The fill is repeated for up to ``max_iters`` passes so that holes several
    cells wide (e.g. corner blocks) are filled from their edges inwards; each
    pass averages only the values present at the start of that pass. The
    loop stops early once nothing is missing or a pass fills nothing.
    Neighbours never wrap around the array borders.

    Returns a new float array; ``a`` itself is not modified.
    """
    out = np.array(a, dtype=float)
    if missing_is_zero:
        out[out == 0] = np.nan

    missing = np.isnan(out)
    if not missing.any():
        return out  # nothing to fill

    # Same neighbour order on every pass (keeps float summation order fixed).
    offsets = ((-1, 0), (1, 0), (0, -1), (0, 1),
               (-1, -1), (-1, 1), (1, -1), (1, 1))

    for _ in range(max_iters):
        total = np.zeros_like(out, dtype=float)
        hits = np.zeros_like(out, dtype=float)

        for dy, dx in offsets:
            shifted = np.roll(out, shift=(dy, dx), axis=(0, 1))
            # np.roll wraps around; blank the row/column that wrapped in.
            if dy:
                shifted[0 if dy > 0 else -1, ...] = np.nan
            if dx:
                shifted[:, 0 if dx > 0 else -1, ...] = np.nan

            ok = ~np.isnan(shifted)
            total[ok] += shifted[ok]
            hits[ok] += 1

        target = missing & (hits > 0)
        out[target] = total[target] / hits[target]

        still_missing = np.isnan(out)
        if not still_missing.any() or np.array_equal(still_missing, missing):
            break  # done, or no further progress possible
        missing = still_missing

    return out
    
# --- load your data ---
fn = "fwhm_peak_values.xlsx"
df = pd.read_excel(fn, sheet_name="Values")

# --- compute improvements ---
# Ratio (and the same ratio in percent) of peak value after vs. before correction.
df["peak_impr"] = df["peak_after"] / df["peak_before"]
df["peak_impr_pct"] = 100 * df["peak_impr"]

# --- reference intensity map from the per-position camera frames ---
# Row i / column j of total_intensity correspond to yc_list[i] / xc_list[j],
# both ascending, matching the ascending y/x ordering of grid_data().
xc_list = np.arange(-200, 201, 25)
yc_list = np.arange(-100, 101, 25)

folder = "aberration_correction"
total_intensity = np.full((len(yc_list), len(xc_list)), np.nan)
for i, yc in enumerate(yc_list):
    for j, xc in enumerate(xc_list):
        filename = f"xc{xc}_yc{yc}_reference_cam.npy"
        try:
            data = np.load(os.path.join(folder, filename))
        except (OSError, ValueError, EOFError):
            # Missing or unreadable frame: leave NaN so it is interpolated
            # below. Only load errors are caught, so programming errors are
            # not silently converted into NaNs.
            continue
        # Mean over an (up to) 31x31 window around the brightest pixel.
        total_intensity[i, j] = np.mean(crop_figure(data, 15))

if np.isnan(total_intensity).all():
    raise FileNotFoundError(
        f"no reference frames could be loaded from {folder!r}"
    )
total_intensity = fill_missing_by_nn(total_intensity)
total_intensity = total_intensity / np.nanmax(total_intensity)
np.save('total_intensity_for_aberration_correction.npy', total_intensity)

# Spot asymmetry |FWHM_x - FWHM_y| / (FWHM_x + FWHM_y): 0 when both widths match.
df["eps_after"] = (df["FWHM_x_after"] - df["FWHM_y_after"]).abs() / (df["FWHM_x_after"] + df["FWHM_y_after"])
df["eps_before"] = (df["FWHM_x_before"] - df["FWHM_y_before"]).abs() / (df["FWHM_x_before"] + df["FWHM_y_before"])

Z_before = grid_data(df, 'eps_before')
Z_after = grid_data(df, 'eps_after')

# Normalise both peak maps by the single global maximum of the "after" map
# (computed first; it was previously referenced before assignment).
raw_peak_before = grid_data(df, 'peak_before')
raw_peak_after = grid_data(df, 'peak_after')
if raw_peak_after.shape != total_intensity.shape:
    raise ValueError(
        f"peak grid shape {raw_peak_after.shape} does not match the "
        f"reference-intensity grid shape {total_intensity.shape}"
    )
peak_scale = np.nanmax(raw_peak_after.to_numpy())
Z_peak_before = raw_peak_before / peak_scale
Z_peak_after = raw_peak_after / peak_scale

# Compensate the "after" peaks for the reference intensity at each position.
Z_peak_after_norm = Z_peak_after / total_intensity

# One global scalar (DataFrame.max() would be per-column and rescale each
# column independently).
norm_scale = np.nanmax(Z_peak_after_norm.to_numpy())

np.save(r'eps_before_2D.npy', Z_before.to_numpy())
np.save(r'eps_after_2D.npy', Z_after.to_numpy())
np.save(r'peak_before_2D.npy', (Z_peak_before / norm_scale).to_numpy())
np.save(r'peak_after_2D.npy', (Z_peak_after / norm_scale).to_numpy())
np.save(r'peak_after_norm_2D.npy', (Z_peak_after_norm / norm_scale).to_numpy())

import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl

# --- figure styling: Times New Roman for text and mathtext ---
fontsize = 32
mpl.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'font.size': fontsize,
    'axes.titlesize': fontsize,
    'axes.labelsize': fontsize,
    'xtick.labelsize': fontsize - 4,
    'ytick.labelsize': fontsize - 4,

    # make mathtext use Times New Roman
    'mathtext.fontset': 'custom',
    'mathtext.rm': 'Times New Roman',
    'mathtext.it': 'Times New Roman:italic',
    'mathtext.bf': 'Times New Roman:bold',
})

eps_before = np.load('eps_before_2D.npy')
eps_after = np.load('eps_after_2D.npy')

peak_before = np.load('peak_before_2D.npy')
peak_after = np.load('peak_after_2D.npy')
peak_after_norm = np.load('peak_after_norm_2D.npy')

# Scan range in pixels.
xmin = -200
xmax = 200
ymin = -100
ymax = 100

# Shared colour scale for both ellipticity maps so they are directly
# comparable; nanmax so a missing (NaN) grid point cannot make the limit NaN.
epsmin = 0
epsmax = float(max(np.nanmax(eps_before), np.nanmax(eps_after)))


def _save_map(arr, filename, vmin, vmax, ticklabels):
    """Draw ``arr`` as an x_c/y_c map with a two-tick colorbar and save it."""
    plt.figure(figsize=(6, 5))
    # Grids built by grid_data() are sorted by ascending y, so row 0 is the
    # smallest y; origin='lower' keeps the image consistent with the extent.
    im = plt.imshow(arr, extent=[xmin, xmax, ymin, ymax], origin='lower',
                    vmin=vmin, vmax=vmax)
    cbar = plt.colorbar(im, pad=0.07, aspect=10, fraction=0.045)
    cbar.set_ticks([vmin, vmax])
    cbar.set_ticklabels(ticklabels)
    plt.xlabel(r'$x_c$, pixels')
    plt.ylabel(r'$y_c$, pixels')
    plt.savefig(filename, bbox_inches='tight', pad_inches=0)


eps_labels = [f'{epsmin:.2f}', f'{epsmax:.2f}']
_save_map(eps_before, 'eps_before.svg', epsmin, epsmax, eps_labels)
_save_map(eps_after, 'eps_after.svg', epsmin, epsmax, eps_labels)

# Peak maps are pre-normalised by the analysis step, so use a fixed 0..1 scale.
for arr, out_name in [
    (peak_before, 'peak_before.svg'),
    (peak_after, 'peak_after.svg'),
    (peak_after_norm, 'peak_after_norm.svg'),
]:
    _save_map(arr, out_name, 0, 1, ['0', '1'])

plt.show()

