import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
from skimage.feature import canny
from skimage.transform import hough_line, hough_line_peaks
from skimage.measure import profile_line
from scipy.interpolate import interp1d
from scipy.signal import find_peaks
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Global matplotlib text styling, applied to every figure created below.
#   * Text is laid out by matplotlib itself (no external LaTeX install).
#   * Math text ($...$) uses the STIX fonts, a Times-like family.
#   * NOTE: matplotlib only honours "mathtext.rm"/"mathtext.it" when
#     "mathtext.fontset" is "custom"; with "stix" those two entries are inert.
mpl.rc('text', usetex=False)
mpl.rc('font', family='serif', serif=['Times New Roman'], size=36)
mpl.rc('mathtext', fontset='stix',
       rm='Times New Roman', it='Times New Roman:italic')

# Final text font: this overrides the font family and size set just above,
# so text is drawn in 'times new roman' at 22 pt.
rc('font', family='times new roman', size=22)

# ------------------------------------------------
# 1) Load the 61x61 scan to analyse
# ------------------------------------------------
# To analyse other data, replace the assignments below (e.g. with
# `data = np.loadtxt(...)`); `data` must end up as a 2-D array.
npy_path = r'X60_Y60_dx1_dy1_Xc-250_Yc0_Z-4.0_f150mm_DMD100Hz_output_2D_array.npy'
grating = np.load(npy_path)
data = grating[1, :, :]  # second 2-D slice along the first axis

# --------------------------------
# 2) Detect strongest Hough line
# --------------------------------
edges = canny(data, sigma=1.0)
h, angles, dists = hough_line(edges)
accums, angle_peaks, dist_peaks = hough_line_peaks(
    h, angles, dists, threshold=0.3 * np.max(h)
)
if len(angle_peaks) == 0:
    raise RuntimeError(
        "No Hough line found above 30% of the accumulator maximum; "
        "check the edge detection on `data`."
    )
# hough_line_peaks lists peaks by decreasing accumulator value
theta = angle_peaks[0]
dist  = dist_peaks[0]

# Unit normal of the Hough line x*cos(theta) + y*sin(theta) = dist
# (x = column, y = row); this is the direction of the perpendicular.
dx_perp, dy_perp = np.cos(theta), np.sin(theta)

# Anchor point on the Hough line: the foot of the perpendicular dropped from
# the image centre. The perpendicular through (x0, y0) therefore crosses the
# image centre and spans the image for every theta.
# (The foot from the origin, dist*(cos, sin), makes the perpendicular run
# through the corner pixel (0, 0); for theta < 0 it then only touches the
# image at that corner and the cross-section collapses to a single point.)
yc = (data.shape[0] - 1) / 2.0
xc = (data.shape[1] - 1) / 2.0
offset = xc * dx_perp + yc * dy_perp - dist   # signed distance centre -> line
x0, y0 = xc - offset * dx_perp, yc - offset * dy_perp

# --------------------------------
# FIGURE1: overlay on 2D data
# --------------------------------
fig1, ax1 = plt.subplots(figsize=(2.8, 2.8))
im = ax1.imshow(data, cmap='gray', origin='lower')

# Horizontal colour bar in its own axes strip above the image.
divider = make_axes_locatable(ax1)
cax = divider.append_axes(
    "top",      # position: top of ax1
    size="7%",  # thickness of colorbar
    pad=0.15,   # distance between plot and bar
)
cbar = fig1.colorbar(
    im,
    cax=cax,
    orientation='horizontal',
)

# move ticks and label above the bar
cax.xaxis.set_label_position('top')
cax.xaxis.set_ticks_position('top')

# Both lines are drawn through (x0, y0) from two points instead of a slope:
# no cot/tan of theta is evaluated (the slope -cos/sin divides by zero at
# theta = 0), so vertical and horizontal lines need no special case.
# Hough line: direction (-sin(theta), cos(theta)) (red dotted)
ax1.axline((x0, y0), (x0 - dy_perp, y0 + dx_perp),
           color='red', linestyle=':', linewidth=2,
           label='Hough line')

# perpendicular: along the normal (cos(theta), sin(theta)) (blue)
ax1.axline((x0, y0), (x0 + dx_perp, y0 + dy_perp),
           color='blue', linewidth=2,
           label='Perpendicular')

# axline() includes its anchor points in the data limits, so autoscaling can
# enlarge the view beyond the image; pin the axes to the image extent.
img_extent = im.get_extent()
ax1.set_xlim(img_extent[0], img_extent[1])
ax1.set_ylim(img_extent[2], img_extent[3])

# Pixel ticks every 60 pixels
ax1.set_xticks(np.arange(0, data.shape[1], 60))
ax1.set_yticks(np.arange(0, data.shape[0], 60))
ax1.set_xlabel(r'$\mathit{x}$, pixels')
ax1.set_ylabel(r'$\mathit{y}$, pixels')

# Uncomment to show the line legend:
# ax1.legend(loc='lower right', fontsize=14)

# bbox_inches='tight' keeps the axis labels and colour-bar ticks inside the
# saved file for this small 2.8-inch figure.
fig1.savefig("grating_with_lines.svg", bbox_inches='tight')



# --------------------------------
# 3) Cross-section along the perpendicular
# --------------------------------
# The perpendicular is (x0, y0) + t * (dx_perp, dy_perp). Intersect it with
# the border lines x = 0, x = wdt-1, y = 0 and y = hgt-1, keep the
# intersections that lie on the image, and use the two extreme ones as the
# segment end points.
hgt, wdt = data.shape
t_vals = []
if dx_perp != 0:
    t_vals += [(0 - x0) / dx_perp, ((wdt - 1) - x0) / dx_perp]
if dy_perp != 0:
    t_vals += [(0 - y0) / dy_perp, ((hgt - 1) - y0) / dy_perp]

# Round-off can leave an exact border hit a hair outside the image
# (e.g. x_i = -1e-15), which a strict bounds test would discard; accept a
# small tolerance and clamp the point onto the border.
eps = 1e-9
pts = []
for t in t_vals:
    x_i = x0 + t * dx_perp
    y_i = y0 + t * dy_perp
    if -eps <= x_i <= (wdt - 1) + eps and -eps <= y_i <= (hgt - 1) + eps:
        pts.append((min(max(x_i, 0), wdt - 1), min(max(y_i, 0), hgt - 1), t))
pts = sorted(pts, key=lambda p: p[2])
if not pts or pts[-1][2] - pts[0][2] <= eps:
    raise ValueError(
        f"The perpendicular through ({x0:.2f}, {y0:.2f}) does not cross the "
        f"{wdt}x{hgt} image; no cross-section can be extracted."
    )
x_start, y_start, _ = pts[0]
x_end,   y_end,   _ = pts[-1]

# profile_line takes (row, col) points and includes both end points.
raw_prof = profile_line(
    data,
    (y_start, x_start),
    (y_end,   x_end),
    mode='reflect', order=1
)
seg_len = np.hypot(x_end - x_start, y_end - y_start)
raw_dist = np.linspace(0, seg_len, len(raw_prof))
if len(raw_prof) < 4:
    raise ValueError(
        f"Cross-section has only {len(raw_prof)} samples; "
        "cubic interpolation needs at least 4."
    )

# Resample on a fine grid with a cubic interpolant to locate the minima.
interp = interp1d(raw_dist, raw_prof, kind='cubic', fill_value='extrapolate')
N_fine = 500
d_fine = np.linspace(0, seg_len, N_fine)
p_fine = interp(d_fine)

# minima = peaks of the inverted profile; prominence is in the units of
# `data`, distance is counted in fine-grid samples.
min_idx, _ = find_peaks(-p_fine, prominence=0.1, distance=10)
d_min = d_fine[min_idx]
p_min = p_fine[min_idx]

# Larger text for the wider cross-section figure (rcParams change applies to
# everything created after this point).
rc('font', family='times new roman', size=24)

# --------------------------------
# FIGURE2: intensity vs. distance
# --------------------------------
# Cubic interpolant (dashed), raw profile_line samples (circles) and the
# detected minima (triangles), all against distance along the perpendicular.
fig2, ax2 = plt.subplots(figsize=(7, 4))
for xs, ys, fmt, extra in (
    (d_fine, p_fine, 'k--', {'label': 'Interpolated Profile'}),
    (raw_dist, raw_prof, 'bo', {'label': 'Measured data'}),
    (d_min, p_min, 'rv', {'label': 'Minima', 'markersize': 6}),
):
    ax2.plot(xs, ys, fmt, **extra)
ax2.set(xlabel='Pixels', ylabel='Difference signal, V')
ax2.legend(loc='upper left', fontsize=18, framealpha=0.3)
ax2.grid(True)

fig2.tight_layout()
fig2.savefig("cross-section.svg")
plt.show()

# Export every series to grating_profile.csv, one series per column. The
# series differ in length, so shorter columns are padded with NaN.
names = ['d_fine', 'p_fine', 'd_raw', 'p_raw', 'd_min', 'p_min']
arrays = [d_fine, p_fine, raw_dist, raw_prof, d_min, p_min]

max_len = max(len(series) for series in arrays)   # longest column
out = np.full((max_len, len(arrays)), np.nan)
for col, series in enumerate(arrays):
    out[:len(series), col] = series

# comments='' writes the header as a plain first line (no leading '# ').
np.savetxt('grating_profile.csv', out, delimiter=',',
           header=','.join(names), comments='')