import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import hsv_to_rgb
from matplotlib.patches import Circle
from scipy.special import factorial

fontsize = 24
# -- Global text style: Times New Roman text, STIX math, every size = fontsize --
plt.rc('font', family='Times New Roman', size=fontsize)
plt.rc('mathtext', fontset='stix')  # Times-like math glyphs to match the text face
plt.rc('axes', titlesize=fontsize, labelsize=fontsize)
plt.rc('xtick', labelsize=fontsize)
plt.rc('ytick', labelsize=fontsize)
plt.rc('legend', fontsize=fontsize)
plt.rc('figure', titlesize=fontsize)
# --------------------------------------------------------------------------------

def complex_plot(A):
    """Colour-code the phase of a complex array on the HSV hue wheel.

    The argument of each element, in (-pi, pi], is mapped linearly onto
    hue in [0, 1]; saturation and value are fixed at 1, so the magnitude
    of ``A`` does not affect the result.

    Parameters
    ----------
    A : numpy.ndarray
        Complex-valued array of any shape.

    Returns
    -------
    numpy.ndarray
        float64 RGB array of shape ``A.shape + (3,)``.
    """
    # Start from all-ones so saturation (channel 1) and value (channel 2) are 1.
    hsv = np.ones(A.shape + (3,), dtype=float)
    hsv[..., 0] = (np.angle(A) + np.pi) / (2 * np.pi)
    return hsv_to_rgb(hsv).astype(np.float64)

def radial_poly(n, m, rho):
    """Evaluate the Zernike radial polynomial R_n^|m| at ``rho``.

    Uses the explicit finite sum

        R_n^m(rho) = sum_{k=0}^{(n-m)/2} (-1)^k (n-k)! /
                     (k! ((n+m)/2 - k)! ((n-m)/2 - k)!) * rho^(n-2k)

    Parameters
    ----------
    n : int
        Radial order.
    m : int
        Azimuthal frequency; only its absolute value is used.
    rho : array_like
        Radial coordinate(s). NaN entries propagate to the result, except
        for a pure ``rho**0`` term (n == |m| == 0), which evaluates to 1.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``rho``. R_n^m is identically zero when
        n - |m| is odd or |m| > n.
    """
    rho = np.asarray(rho)
    # An integer-typed accumulator would reject the in-place float additions.
    if not np.issubdtype(rho.dtype, np.inexact):
        rho = rho.astype(float)
    R = np.zeros_like(rho)
    m = abs(m)
    # By definition R_n^m vanishes unless n - m is a non-negative even number;
    # evaluating the sum anyway yields a spurious non-zero polynomial.
    if m > n or (n - m) % 2:
        return R
    half_sum = (n + m) // 2
    half_diff = (n - m) // 2
    for k in range(half_diff + 1):
        coeff = (-1) ** k * factorial(n - k) / (
            factorial(k) * factorial(half_sum - k) * factorial(half_diff - k)
        )
        R += coeff * rho ** (n - 2 * k)
    return R

def zernike(n, m, rho, theta):
    """Evaluate the real Zernike polynomial Z_n^m at polar points (rho, theta).

    The radial part comes from ``radial_poly``. The angular factor is
    cos(m * theta) for m > 0 and sin(|m| * theta) for m < 0; for m == 0
    the radial polynomial is returned unchanged.
    """
    radial = radial_poly(n, m, rho)
    if m > 0:
        return radial * np.cos(m * theta)
    if m < 0:
        return radial * np.sin(-m * theta)
    return radial


# Position tag of the dataset; used only to build the coefficient filename
# below (NOTE(review): presumably a field/beam position -- confirm).
xc = 40
yc = 0

folder = 'data'
# The file holds a pickled object that the loop below unpacks as a dict
# {(n, m): coefficient}. allow_pickle=True is not safe against malicious
# files, so only load coefficient files from a trusted source.
zernike_coeffs = np.load(
    os.path.join(folder, f'xc{xc}_yc{yc}_final_coeffs.npy'),
    allow_pickle=True,
).item()

# Sample the square [-1, 1] x [-1, 1] on a 400 x 400 grid, in polar coords.
x, y = np.meshgrid(np.linspace(-1, 1, 400), np.linspace(-1, 1, 400))
rho = np.sqrt(x**2 + y**2)
theta = np.arctan2(y, x)

# Mark points outside the unit disk with NaN coordinates.
mask = rho > 1
rho[mask] = np.nan
theta[mask] = np.nan

# Phase map: sum of Zernike modes, each coefficient scaled by pi
# (i.e. coefficients are expressed in units of pi radians).
phase = np.zeros_like(x)
for (n, m), c in zernike_coeffs.items():
    phase += c * zernike(n, m, rho, theta) * np.pi
# NaN coordinates do not make every term NaN (e.g. the piston term uses
# rho**0 == 1), so mark the outside of the disk explicitly.
phase[mask] = np.nan

# build RGBA image
# Outside the disk the phase is NaN, and colouring NaN yields NaN RGB values
# that poison any later reduction (np.min/np.max propagate NaN). Those pixels
# are made fully transparent below, so give them a dummy phase of 0 instead.
rgb = complex_plot(np.exp(1j * np.where(mask, 0.0, phase)))
# complex_plot fixes saturation = value = 1, so hsv_to_rgb already returns
# channels in [0, 1] (with one channel 0 per pixel); no min-shift is needed.
alpha = np.where(mask, 0.0, 1.0)
rgba = np.dstack((rgb, alpha))

# plot clipped to circle with colorbar
fig, ax = plt.subplots(figsize=(3, 4))
im = ax.imshow(
    rgba,
    origin='lower',          # grid row 0 corresponds to y = -1
    extent=[-1, 1, -1, 1],
    interpolation='none',
)

# clip to unit circle so the square image corners are not drawn
circle = Circle((0, 0), 1, transform=ax.transData)
im.set_clip_path(circle)

# remove axes/frame
ax.set_axis_off()

# add a phase -> hue colorbar. The colormap is sampled from complex_plot
# itself so the bar shows exactly the colours used in the image (the
# built-in 'hsv' colormap is a close but not identical hue wheel).
phase_samples = np.linspace(-np.pi, np.pi, 256)
phase_cmap = mpl.colors.ListedColormap(complex_plot(np.exp(1j * phase_samples)))
norm = plt.Normalize(0, 2*np.pi)
sm = plt.cm.ScalarMappable(cmap=phase_cmap, norm=norm)
sm.set_array([])

cbar = fig.colorbar(sm, ax=ax,
    orientation='vertical',
    fraction=0.06, pad=0.04, aspect=15,
    ticks=[0, 2*np.pi])

# Colorbar position 0 corresponds to hue 0, i.e. phase -pi; 2*pi to phase +pi.
labels = [r'$-\pi$', r'$\pi$']
cbar.set_ticklabels(labels)

plt.show()
