import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd
import math

#Equation of dipolar and exchange coupling
def full_coupling(r, theta, d_0, j_0, dec, *, field_angle=0.0, r_0=1.0525):
    """Return the model coupling ``2*D + J`` of a dimer.

    Dipolar term::

        D = d_0 * (1 - 3*cos(theta - field_angle)**2) / (2*pi*r**3)

    Exchange term::

        J = j_0 * exp(-(r - r_0) / dec)

    All arithmetic uses NumPy element-wise operations, so ``r`` and ``theta``
    may be scalars or broadcast-compatible arrays.

    Parameters
    ----------
    r : float or numpy.ndarray
        Dimer separation, in the same length unit as ``r_0`` and ``dec``
        (the calling script feeds the ``'r (nm)'`` column, i.e. nm).
    theta : float or numpy.ndarray
        Angle of the dimer axis; passed to ``np.cos``, so radians.
    d_0 : float
        Dipolar prefactor.
    j_0 : float
        Exchange amplitude at ``r == r_0``.
    dec : float
        Exchange decay length (same unit as ``r``).
    field_angle : float, keyword-only
        In-plane angle (radians) of the external field. The default of 0
        reproduces the original hard-coded value; e.g. ``np.deg2rad(45)``
        for a field rotated by 45 degrees.
    r_0 : float, keyword-only
        Reference separation of the exchange term (default 1.0525 nm).

    Returns
    -------
    float or numpy.ndarray
        ``2*D + J`` evaluated element-wise.
    """
    D = d_0 * (1 - 3 * np.cos(theta - field_angle) ** 2) / (2 * np.pi * r ** 3)
    J = j_0 * np.exp(-(r - r_0) / dec)
    return 2 * D + J

# Wrapper around full_coupling: curve_fit passes all independent variables as one
# xdata array, so unpack the stacked (theta, r) rows into separate arguments.
def fit_func(D, *args):
    """Evaluate ``full_coupling`` on stacked (theta, r) coordinates for curve_fit.

    Parameters
    ----------
    D : sequence of two arrays
        ``D[0]`` holds the angles theta, ``D[1]`` the separations r.
    *args
        Model parameters forwarded positionally to ``full_coupling``
        (``d_0, j_0, dec``).

    Returns
    -------
    numpy.ndarray
        Float64 array with the same shape as ``D[1]``.
    """
    angles, radii = D
    model = full_coupling(radii, angles, *args)
    result = np.zeros(radii.shape)
    # in-place add (same as ``result += model``) keeps result float64, shaped like radii
    np.add(result, model, out=result)
    return result

data = pd.read_excel('dimer_fit_data.xlsx', sheet_name='Sheet1', index_col=0, nrows=32)

# Number of leading sheet rows (dimers) that enter the fit.
N_DIMERS = 22
if len(data) < N_DIMERS:
    raise ValueError(
        f"dimer_fit_data.xlsx has only {len(data)} rows; expected at least {N_DIMERS}"
    )

# Select rows by position: with index_col=0, plain Series[i] looks i up as an index
# *label*, which picks wrong rows / raises for non-0-based integer labels and relies
# on deprecated positional fallback for non-integer labels.
rows = data.iloc[:N_DIMERS]

# coupling per dimer (MHz)
coupling = rows['J'].to_numpy(dtype=float)
# dist[0] = theta, dist[1] = atomic spacing r (nm); fit_func unpacks in this order
dist = np.vstack((
    rows['theta'].to_numpy(dtype=float),
    rows['r (nm)'].to_numpy(dtype=float),
))

for c_i, theta_i, r_i in zip(coupling, dist[0], dist[1]):
    print(c_i, theta_i, r_i)

# Cartesian positions of the dimers (theta goes through np.cos/np.sin, so radians)
x = dist[1, :] * np.cos(dist[0, :])
y = dist[1, :] * np.sin(dist[0, :])

p0 = [165, 710, 0.1249]  # initial guess of parameters d_0, j_0, decay
param_opt, param_cov = curve_fit(fit_func, dist, coupling, p0=p0)
# one-sigma parameter uncertainties from the diagonal of the covariance matrix
p_err = np.sqrt(np.diag(param_cov))

print("Fit function parameters result")
print("D_0 = %.1f +- %.1f (MHz)" % (param_opt[0], p_err[0]))
print("J_0 = %.1f +- %.1f (MHz)" % (param_opt[1], p_err[1]))
print("decay = %.3f +- %.3f (nm)" % (param_opt[2], p_err[2]))


########################################################################################
# Generate plot of coupling function with fit parameters and experimental data
#########################################################################################

# Plot grid: (min, max, number of samples). Note dr / dt are sample COUNTS, not steps.
rmin, rmax, dr = 1.05, 1.7, 200
tmin, tmax, dt = 0, 2*np.pi, 200

r, t = np.linspace(rmin, rmax, dr), np.linspace(tmin, tmax, dt)
# meshgrid (default 'xy' indexing) returns arrays of shape (len(t), len(r))
R, T = np.meshgrid(r, t)
# full_coupling is element-wise NumPy, so evaluate the whole grid at once; Z has
# the same shape as R and T regardless of whether dr == dt.
Z = full_coupling(R, T, *param_opt)

# Cartesian coordinates of the grid, shared by all three surface plots
X, Y = R * np.cos(T), R * np.sin(T)

fig, axs = plt.subplots(1, 3, figsize=(6.4*2, 4.8), subplot_kw={"projection": "3d"})
for ax in axs:
    ax.tick_params(direction='in')
    ax.set_xlabel('x (nm)')
    ax.set_ylabel('y (nm)')

# Panel 0: perspective view
axs[0].plot_surface(X, Y, Z, cmap='coolwarm', alpha=0.6)
axs[0].scatter3D(x, y, coupling, color="black", alpha=1)
axs[0].set_zlabel('ESR splitting (MHz)')
axs[0].view_init(30, 45)

# Panel 1: side view at zero elevation, azimuth rotated 90 deg from panel 0
axs[1].plot_surface(X, Y, Z, cmap='coolwarm', alpha=0.6)
axs[1].scatter3D(x, y, coupling, color="black")
axs[1].set_zlabel('ESR splitting (MHz)')
axs[1].view_init(0, 45 - 90)

# Panel 2: top-down view; the colorbar replaces the z axis
axs[2].scatter3D(x, y, coupling, color="black", alpha=1, zorder=1)
im = axs[2].plot_surface(X, Y, Z, cmap='coolwarm', alpha=0.6)
axs[2].set_zticks([])
axs[2].view_init(90, 0)
cbar = fig.colorbar(im, ax=axs[2], fraction=0.03)
cbar.set_label('ESR splitting (MHz)')
plt.savefig('../plots/dimer_analysis.pdf')
plt.show()

