# %%
import time as t
import sympy as sp
import numpy as np
import scipy.constants as c
import matplotlib.pyplot as plt
from functools import partial
import xarray as xr
from scipy.optimize import curve_fit
import itertools as it
from tqdm import tqdm

from matplotlib import rc

# Global plot style: LaTeX-rendered Helvetica (sans-serif) text, thin lines,
# and 15 pt text and legend fonts.
plt.rcParams.update(
    {
        "font.family": "sans-serif",
        "font.sans-serif": ["Helvetica"],
        "text.usetex": True,
        "lines.linewidth": 0.654,
        "font.size": 15,
        "legend.fontsize": 15,
    }
)

# %%
# Import k-resolved spectral density at the Fermi energy
fermi_surf = xr.load_dataset('./data/fermi_surf.nc')

# %%
# Convert kx, ky coordinates to polar and find maximum for each polar angle.
# The maximum of the spectral density corresponds to the Fermi surface.

max_k = []

# Generate a mesh grid with cartesian coordinates in k-space, shifted by the
# point stored in the dataset attribute "bz_point".
kx = fermi_surf.kx + fermi_surf.attrs["bz_point"][0]
ky = fermi_surf.ky + fermi_surf.attrs["bz_point"][1]
kx, ky = np.meshgrid(kx, ky)
# Polar angle of every mesh point. It does not depend on the wedge being
# examined, so it is computed once outside the loop.
k_angle = np.arctan2(ky, kx)
# Generate polar angles
thetas = np.linspace(-np.pi, np.pi, 500)
dt = np.diff(thetas)[0]
# The loop variable is deliberately not named `t`: that name is the `time`
# module imported at the top of the file.
for theta_c in thetas[1:]:
    # Keep only mesh points whose polar angle lies strictly inside the wedge
    # (theta_c - dt / 2, theta_c + dt / 2).
    in_wedge = (k_angle > theta_c - dt / 2) & (k_angle < theta_c + dt / 2)
    cut = fermi_surf.where(in_wedge).sel(alpha=1)
    dummy_stacked = cut.stack(z=["kx", "ky"])
    # Find maximum of spectral density in the given interval
    idx = dummy_stacked.idxmax("z")
    kx_max = idx.kx.data + fermi_surf.attrs["bz_point"][0]
    ky_max = idx.ky.data + fermi_surf.attrs["bz_point"][1]
    max_k.append(np.sqrt(kx_max**2 + ky_max**2))
max_k = np.array(max_k)
thetas = thetas[1:]

# %%
# Model constants.
# a: length used to convert k from units of 1/a (presumably the graphene
#    C-C distance in metres — TODO confirm).
# vF: presumably the Fermi velocity in m/s; not referenced elsewhere in this file.
a, vF = 1.42e-10, 1e6


def trigonal_warping(phi, valley, phic=0):
    """
    Angular modulation of the Fermi momentum due to trigonal warping.

    Returns valley * cos(3 * phi + phic): a threefold-symmetric modulation
    whose sign is opposite in the two valleys.

    Parameters:
    -----------

    phi: float or array
        Polar angle.
    valley: int (-1, 1)
        Valley quantum number.
    phic: float
        Crystal orientation angle (phase offset of the modulation).
    """
    three_fold = np.cos(3 * phi + phic)
    return valley * three_fold


def k_F(phi, kF0, delta_k, valley=1, phic=0):
    """
    Magnitude of the Fermi momentum including trigonal warping.

    Evaluates kF0 + delta_k * trigonal_warping(phi, valley, phic).

    Parameters:
    -----------

    phi: float or array
        Polar angle.
    kF0: float
        Fermi momentum without trigonal warping.
    delta_k: float
        Amplitude of trigonal warping modulation.
    valley: int (-1, 1)
        Valley quantum number.
    phic: float
        Crystal orientation angle.
    """
    warping = trigonal_warping(phi=phi, valley=valley, phic=phic)
    return kF0 + delta_k * warping


def kx_ky(phi, kF0, delta_k, phic=0, valley=1):
    """
    Cartesian components of the trigonally warped Fermi momentum.

    Parameters:
    -----------

    phi: float or array
        Polar angle.
    kF0: float
        Fermi momentum without trigonal warping.
    delta_k: float
        Amplitude of trigonal warping modulation.
    phic: float
        Crystal orientation angle.
    valley: int (-1, 1)
        Valley quantum number.

    Returns:
    --------
    numpy.ndarray
        [k_F(phi) * cos(phi), k_F(phi) * sin(phi)]; for an array `phi` of
        length N the result has shape (2, N).
    """
    magnitude = k_F(phi=phi, kF0=kF0, delta_k=delta_k, valley=valley, phic=phic)
    kx_part = magnitude * np.cos(phi)
    ky_part = magnitude * np.sin(phi)
    return np.array([kx_part, ky_part])


# %%
# Fit the extracted Fermi surface with the warping model `k_F`, restricted to
# valley +1 so that only kF0 and delta_k are free parameters.
k_F_valley = partial(k_F, valley=1)
values, _ = curve_fit(k_F_valley, thetas, max_k)

# Unpack fitting parameters
kF0_fit, delta_k_fit = values

# %%
# Evaluate the fitted model on the same polar angles as the extracted data
fitted_fermi = k_F_valley(thetas, kF0_fit, delta_k_fit)

# %%
# Inspect fitted Fermi surface: extracted maxima (dots) against the fit (red)
fig, ax = plt.subplots(subplot_kw={"polar": True})

ax.scatter(thetas, max_k, s=3)
ax.plot(thetas, fitted_fermi, zorder=2, c='r')
ax.set_xlabel('')
ax.set_ylabel('')

plt.show()

# %%
# Load the angle distribution from the tight-binding calculations
ang_dist = xr.load_dataset('./data/angle_distribution.nc')
# Centred rolling mean over 300 consecutive theta samples, to average out
# over many C-C bonds. Incomplete windows at the edges yield NaN, which
# dropna removes.
ds_roll = (
    ang_dist
    .rolling(theta=300, center=True)
    .mean()
    .dropna('theta')
)


# %%
def gaussian(theta, w, theta_0, beta):
    """
    Gaussian bump of height `beta` centred at `theta_0`.

    Evaluates beta * exp(-((theta - theta_0) / w) ** 2). The exponent has no
    factor 1/2, so `w` equals sqrt(2) times the standard deviation.
    """
    reduced = (theta - theta_0) / w
    return beta * np.exp(-reduced ** 2)


def two_gaussians(theta, w, theta_0, beta):
    """
    Average of two equal Gaussians centred symmetrically at +theta_0 and -theta_0.

    Both bumps share width `w` and height `beta` (see `gaussian`).
    """
    def bump(center):
        return gaussian(theta=theta, w=w, theta_0=center, beta=beta)

    return (bump(theta_0) + bump(-theta_0)) / 2


# %%
# Fit the alpha = 1 angular distribution with two symmetric gaussians
ds_roll_alpha = ds_roll.sel(alpha=1)
popt, _ = curve_fit(
    two_gaussians,
    ds_roll_alpha.theta.data,
    ds_roll_alpha.current.data,
    p0=(0.2, 0.5, 5e-4),  # initial guess for (w, theta_0, beta)
)
# Extract fitting parameters
w_fit, theta_0_fit, beta_fit = popt

# %%
# Evaluate the fitted distribution on the data grid
fitted_dist = two_gaussians(ds_roll_alpha.theta.data, *popt)

# %%
# Two-panel figure: Fermi-surface fit (left) and angular-distribution fit
# (right); saved to ./images/fitting_tb.pdf.
fig = plt.figure()
gs = fig.add_gridspec(1, 2)

# Right panel: rolling-averaged `current` against angle, with its fit
ax1 = fig.add_subplot(gs[1], polar=True)
ax1.scatter(ds_roll_alpha.theta.data, ds_roll_alpha.current, s=10)
ax1.plot(ds_roll_alpha.theta.data, fitted_dist, c='r', lw=1.5)
ax1.set_thetamin(-90)
ax1.set_thetamax(90)
ax1.set_theta_zero_location('N')
ax1.set_title('')
ax1.set_yticks([0, np.max(fitted_dist)], [r'$0$', r'$I_{\mathrm{max}}$'], fontsize=20)
ax1.set_xticks(np.pi / 6 * np.arange(-3, 4))
ax1.set_xlabel(r'$\theta~[\mathrm{^\circ}]$', labelpad=-15, fontsize=20)
ax1.set_ylabel('')
# Optional panel label:
# ax1.text(x=np.pi / 6, y=1.5 * np.max(fitted_dist), s='(d)', fontsize=20)

# Left panel: extracted Fermi-surface maxima with the trigonal-warping fit
ax2 = fig.add_subplot(gs[0], polar=True)
ax2.scatter(thetas, max_k, s=3)
ax2.plot(thetas, fitted_fermi, zorder=2, c='r', lw=1.5)
ax2.set_yticks([0, 0.1])
ax2.set_ylim([0, 0.15])
ax2.set_xlabel(r'$\phi~[\mathrm{^\circ}]$', fontsize=20)
ax2.set_ylabel(r'$k~[1 / a]$', labelpad=33, fontsize=20)
# Optional panel label:
# ax2.text(x=3 * np.pi / 4, y=0.23, s='(b)', fontsize=20)

plt.tight_layout()
plt.savefig('./images/fitting_tb.pdf', bbox_inches='tight')
plt.show()

# %%
# Prepare paper figure: rolling-averaged current and its two-gaussian fit on a
# half-disc polar axis with 0 degrees pointing up.
fig, ax = plt.subplots(subplot_kw={"polar": True})
ds_roll_alpha.current.plot(ax=ax)
ax.plot(ds_roll_alpha.theta.data, fitted_dist)

ax.set_thetamin(-90)
ax.set_thetamax(90)
ax.set_theta_zero_location('N')
plt.show()


# %%
def xdet_warping(valley, phii, B, r_fit, misalignment=0, x0=0, color=None, plot=False):
    """
    Compute position where cyclotron orbits return to y=0.

    The orbit is sampled at 5000 angles in [phii, phii + 2*pi). Its shape is
    `r_fit(phi=..., valley=...)`, scaled by hbar / (a * e), with the sign of
    its y component flipped. Only samples strictly above the starting height
    are kept. Their horizontal extent divided by B is the distance covered
    before returning to the starting height.

    Parameters:
    -----------

    valley: int (-1, 1)
        Valley quantum number.
    phii: float
        Initial angle.
    B: float
        Magnetic field in tesla. Must be nonzero.
    r_fit: callable
        Called as `r_fit(phi=phi, valley=valley)`. Must return a (2, N) array
        (x row, y row) for an array `phi` of length N.
    misalignment: float
        Added to the returned angle at which the next bounce starts.
    x0: float
        Initial position; only used to offset the plotted trajectory.
    color: float or None
        If given, the trajectory is drawn as a scatter whose colour, marker
        size, transparency and z-order are derived from this weight.
    plot: bool
        Plot trajectory.

    Returns:
    --------
    (float, float)
        Distance along x to the return point (in m) and the angle at which the
        next bounce starts. If fewer than two samples lie above the starting
        height, returns (0, phii).

    Raises:
    -------
    ValueError
        If B is zero.
    """
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if B == 0:
        raise ValueError("B cannot be zero.")
    # Generate array of angles
    phi = np.linspace(phii, phii + 2 * np.pi, 5000, endpoint=False)
    # Scale the k-space contour; dividing by the magnetic field (below) then
    # gives units of m.
    r = r_fit(phi=phi, valley=valley) * c.hbar / a / c.e
    r[1, :] *= -1
    y0 = r[1, 0]
    # Keep only the part of the orbit strictly above the starting height.
    above = r[1] - y0 > 0
    if np.sum(above) < 2:
        return 0, phii
    rf = r[:, above]
    if plot:
        x = (rf[0] - rf[0, 0]) / B + x0
        y = (rf[1] - y0) / B
        if color is not None:
            color = color * np.ones(len(x))
            plt.scatter(
                x * 1e6,
                y * 1e6,
                c=valley * color,
                cmap="RdBu_r",
                vmin=-1,
                vmax=1,
                s=color,
                alpha=color / 2,
                zorder=np.max(color),
            )
        else:
            # Valley +1 is drawn in red, the other valley in blue.
            line_color = "r" if valley == 1 else "b"
            plt.plot(x * 1e6, y * 1e6, c=line_color, lw=3, ls="--")

    # The angle of the last retained sample is where the next bounce starts.
    phi_next = phi[above][-1] + misalignment
    sx = rf[0, -1] - rf[0, 0]
    return sx / B, phi_next


def x_focus(
    phii,
    valley,
    p,
    B,
    color=None,
    plot=False,
    injection="armchair",
    misalignment=0,
    alpha=1,
):
    """
    Compute the positions where consecutive cyclotron orbits return to y=0.

    Parameters:
    -----------

    phii: float
        Initial angle.
    valley: int (-1, 1)
        Valley quantum number.
    p: int
        Number cyclotron orbit bounces.
    B: float
        Magnetic field in tesla.
    color: float or None
        Passed on to `xdet_warping` to colour the plotted trajectory.
    plot: bool
        Plot trajectory.
    injection: str ("armchair" or "zigzag")
        Edge orientation. For "zigzag" the valley index is flipped before
        every bounce.
    misalignment: float
        Crystal misalignment angle; added to the warping orientation and
        passed on to `xdet_warping`.
    alpha: float
        Scale factor applied to the fitted warping amplitude `delta_k_fit`.

    Returns:
    --------
    numpy.ndarray
        Length-p array of cumulative landing positions.

    Raises:
    -------
    ValueError
        If `injection` is neither "armchair" nor "zigzag".
    """
    # We use the fitting parameters and rotate the contour by the angle below.
    # The rotation is because r = dk/dphi.
    phic_by_injection = {"zigzag": -np.pi, "armchair": -np.pi / 2}
    if injection not in phic_by_injection:
        raise ValueError(
            f"Unknown injection {injection!r}; expected 'armchair' or 'zigzag'."
        )
    phic = phic_by_injection[injection]
    r_fit = partial(
        kx_ky,
        kF0=kF0_fit,
        delta_k=delta_k_fit * alpha,
        phic=phic + misalignment,
    )
    xdet = 0
    # Cumulative landing positions, one per bounce.
    xdet_list = []
    for _ in range(p):
        if injection == "zigzag":
            valley *= -1
        # Advance by one bounce; the returned angle seeds the next one.
        x, phii = xdet_warping(
            valley=valley,
            phii=phii,
            B=B,
            x0=xdet,
            color=color,
            plot=plot,
            r_fit=r_fit,
            misalignment=misalignment,
        )
        xdet += x
        xdet_list.append(xdet)
    return np.array(xdet_list)


def angle_converter(phiarr, valley, dk_fit):
    """
    Return the polar angle, in (-pi, pi], of the vector dk_fit(phiarr, valley).

    When `dk_fit` is a `kx_ky` partial with a positive radius, this is `phiarr`
    wrapped into (-pi, pi].

    Parameters:
    -----------

    phiarr: array
        Angles passed to `dk_fit`.
    valley: int (-1, 1)
        Valley quantum number.
    dk_fit: callable
        Called as `dk_fit(phi=phiarr, valley=valley)`; must return a pair of
        arrays (kx, ky).
    """
    # No unit conversion is needed: arctan2 only depends on the direction of
    # (dkx, dky), which a common positive scale factor does not change.
    dkx, dky = dk_fit(phi=phiarr, valley=valley)
    return np.arctan2(dky, dkx)


def TEF_calc(valley, p, B, L, Wc, misalignment=0, alpha=1, injection="armchair"):
    """
    Compute transverse electron focusing (TEF) current reaching the collector.

    Injection angles are sampled on a dense grid. Each angle is weighted by a
    Gaussian injection distribution (normalised to peak 1) and followed for
    `p` bounces. The result is the sum of the weights of the trajectories
    that land inside the collector at least once.

    Parameters:
    -----------
    valley: int (-1, 1)
        Valley quantum number.
    p: int
        Maximum number of peaks (bounces followed per trajectory).
    B: float
        Magnetic field in tesla.
    L: float
        Distance between injector and collector.
    Wc: float
        Collector width.
    misalignment: float
        Crystal misalignment angle.
    alpha: float
        Scale factor for the fitted trigonal-warping amplitude. Used both for
        the injection-angle map and for the trajectories.
    injection: str ("armchair" or "zigzag")
        Edge orientation of the injector.

    Returns:
    --------
    float
        Sum of normalised injection weights of trajectories with at least one
        landing point strictly inside (L, L + Wc).

    Raises:
    -------
    ValueError
        If `injection` is neither "armchair" nor "zigzag".
    """
    # Orientation of the warping and centre of the injection lobe per edge.
    if injection == "zigzag":
        phic = -np.pi / 6
        theta_0 = np.abs(theta_0_fit) + np.pi / 6
    elif injection == "armchair":
        phic = 0
        theta_0 = np.abs(theta_0_fit)
    else:
        raise ValueError(
            f"Unknown injection {injection!r}; expected 'armchair' or 'zigzag'."
        )
    dk_fit = partial(
        kx_ky,
        kF0=kF0_fit,
        delta_k=delta_k_fit * alpha,
        phic=phic + misalignment,
    )

    nphi = 5000
    phiarr = np.linspace(-2 * np.pi, 2 * np.pi, nphi, endpoint=True)
    phidist = angle_converter(phiarr=phiarr, valley=valley, dk_fit=dk_fit)
    # Keep only angles in (1e-2, pi - 1e-2), i.e. away from the two grazing
    # directions along the edge.
    keep = (phidist > 1e-2) & (phidist < np.pi - 1e-2)
    phidist = phidist[keep]
    phiarr = phiarr[keep]
    # Injection weight of each trajectory. The width w_fit / 8 is narrower than
    # the fitted lobe; the reason is not visible here — TODO confirm intent.
    dIdphi = gaussian(
        theta=phidist + phic,
        w=w_fit / 8,
        theta_0=valley * theta_0 + misalignment + np.pi / 2,
        beta=beta_fit,
    )
    dIdphi /= np.max(dIdphi)

    # Compute focusing positions; `alpha` is forwarded so the trajectories use
    # the same warping amplitude as the injection-angle map above.
    x_focus_partial = partial(
        x_focus,
        valley=valley,
        p=p,
        B=B,
        misalignment=misalignment,
        injection=injection,
        alpha=alpha,
    )
    vec_x_focus = np.vectorize(x_focus_partial, signature="()->(n)")
    xc = vec_x_focus(phii=phiarr + np.pi / 2)
    # A trajectory is collected if any of its landing points hits the collector.
    collected = np.any((xc > L) & (xc < L + Wc), axis=1)
    return np.sum(dIdphi[collected])


# %%
# Distance between injector and collector (m)
L = 2e-6
# Collector width (m)
Wc = 1e-7
# Maximum number of peaks. The ratio is rounded before converting: int()
# truncates, so a mathematically integral L / Wc that comes out just below the
# integer in floating point would silently lose one peak.
pmax = int(round(L / Wc)) + 1

# %%
# Accumulators (not read again at module level in this file)
Ic, invars = [], []

# TEF_calc with the device geometry and warping scale fixed; valley, B and
# misalignment are supplied per evaluation.
partial_TEF = partial(TEF_calc, p=pmax, L=L, Wc=Wc, alpha=1)

# %%
from dask.distributed import Client

# Connect to the dask scheduler listening on this address
client = Client(address="tcp://localhost:10345")

# %%
client

# %%
# Parameter grid for the calculation; key order fixes the result dimensions.
params = dict(
    valley=np.array([-1, 1], dtype='float'),
    B=np.linspace(3, 0, 500, endpoint=False),
    misalignment=np.array([0., 0.05]),
)

def wrapped_current(args):
    """
    Evaluate `partial_TEF` for one parameter row.

    `args` is indexed as (valley, B, misalignment), e.g. one row of the
    Cartesian product of the parameter values.
    """
    valley = float(args[0])
    field = float(args[1])
    misalignment = float(args[2])
    return partial_TEF(valley=valley, B=field, misalignment=misalignment)

# Cartesian product of the parameter values: one row per
# (valley, B, misalignment) combination, in `params` key order.
values = list(params.values())
args = np.array(list(it.product(*values)))
# Grid shape, used later to reshape the flat list of results.
shapes = [len(v) for v in values]

# Run calculation: one dask task per parameter row, gathered in submission order.
result_ungathered = [client.submit(wrapped_current, row) for row in args]
result = client.gather(result_ungathered)

# %%
# Prepare dataset: put the flat results back onto the parameter grid
Is = np.reshape(result, shapes)
ds = xr.Dataset(
    {"Is": (("valley", "B", "misalignment"), Is)},
    coords=params,
)

# %%
# Current per valley (one curve per valley), then summed over valleys;
# one panel column per misalignment value.
ds.Is.plot(hue='valley', col='misalignment')
ds.Is.sum(dim='valley').plot(col='misalignment')
plt.show()


# %%
# Plot trajectories
def plot_main(p=3, misalignment=0, file=None, alpha=1):
    """
    Plot the central cyclotron trajectory of each valley for armchair injection.

    For each valley (-1, 1) the trajectory starts at
    valley * |theta_0_fit| + misalignment and is followed for `p` bounces
    at B = 1 T.

    Parameters:
    -----------

    p: int
        Number of cyclotron orbit bounces.
    misalignment: float
        Crystal misalignment angle; added to the injection angle and passed
        on to `x_focus`.
    file: str or None
        If given, the figure is saved as './images/' + file before showing it.
    alpha: float
        Scale factor for the fitted trigonal-warping amplitude, passed on to
        `x_focus`. The default 1 leaves the fitted amplitude unchanged.
    """
    injection = "armchair"
    phic = 0
    theta_0 = np.abs(theta_0_fit)

    for valley in (-1, 1):
        # Zigzag injection of valley +1 also launches a trajectory rotated by
        # -2*pi/3. This branch is not reached while `injection` is armchair.
        if injection == "zigzag" and valley == 1:
            offsets = (0, -2 * np.pi / 3)
        else:
            offsets = (0,)
        for offset in offsets:
            x_focus(
                phii=valley * theta_0 + misalignment - phic + offset,
                valley=valley,
                p=p,
                B=1,
                plot=True,
                injection=injection,
                alpha=alpha,
                misalignment=misalignment,
            )
    plt.ylim(0, 1)
    plt.axis("equal")
    if file:
        plt.savefig('./images/' + file)
    plt.show()


# %%
# Aligned armchair injection
plot_main(p=3, misalignment=0, file='trajectory_armchair_injection.pdf')

# %%
# Same geometry with a 0.05 rad crystal misalignment
plot_main(p=3, misalignment=0.05, file='trajectory_misaligned_injection.pdf')

# %%
# Plot a fan of trajectories (one per injection angle) for both valleys,
# coloured by valley and weighted by the injection distribution. Then overlay
# the central trajectory of each valley.
p = 2
angles = np.linspace(-np.pi, np.pi, 100)
injection = "armchair"
misalignment = 0.
alpha = 1

# Injection-lobe angle used for this illustration. It is kept in its own
# variable so that running this cell does not overwrite the fitted global
# `theta_0_fit`, which TEF_calc and plot_main read.
theta_0_plot = np.pi / 6

if injection == "zigzag":
    phic = -np.pi / 6
elif injection == "armchair":
    phic = 0
theta_0 = np.abs(theta_0_plot)
dk_fit = partial(
    kx_ky,
    kF0=kF0_fit,
    delta_k=delta_k_fit * alpha,
    phic=phic + misalignment,
)

convert_angle = partial(angle_converter, phiarr=angles, dk_fit=dk_fit)
valleys = [-1, 1]
# The converted angles depend only on the valley, so evaluate them once per
# valley rather than several times per (angle, valley) pair.
converted_angles = {v: convert_angle(valley=v) for v in valleys}

for i, _ in enumerate(angles[:-1]):
    for valley in valleys:
        fs_angle = converted_angles[valley][i]
        # Injection weight, divided by beta_fit so that a single lobe peaks at 1.
        if injection == "zigzag" and valley == 1:
            dIdphi = 0.5 * (
                gaussian(
                    theta=fs_angle + phic,
                    w=w_fit,
                    theta_0=valley * theta_0 + misalignment,
                    beta=beta_fit,
                ) / beta_fit + gaussian(
                    theta=fs_angle + phic,
                    w=w_fit,
                    theta_0=valley * theta_0 + misalignment - 2 * np.pi / 3,
                    beta=beta_fit,
                ) / beta_fit
            )
        else:
            dIdphi = gaussian(
                theta=fs_angle + phic,
                w=w_fit,
                theta_0=valley * theta_0 + misalignment,
                beta=beta_fit,
            ) / beta_fit
        x_focus(
            phii=float(fs_angle),
            valley=valley,
            color=dIdphi,
            p=p,
            B=1,
            plot=True,
            injection=injection,
            alpha=alpha,
            misalignment=misalignment,
        )

# Central trajectory of each valley; zigzag injection of valley +1 adds a
# second one rotated by -2*pi/3.
for valley in valleys:
    if injection == "zigzag" and valley == 1:
        offsets = (0, -2 * np.pi / 3)
    else:
        offsets = (0,)
    for offset in offsets:
        x_focus(
            phii=valley * theta_0 + misalignment - phic + offset,
            valley=valley,
            p=p,
            B=1,
            plot=True,
            injection=injection,
            alpha=alpha,
            misalignment=misalignment,
        )
plt.colorbar()
plt.ylim(0, 1)
plt.axis("equal")
plt.savefig('./images/all_trajectories.pdf')
plt.show()

# %%
# Persist the valley/B/misalignment-resolved focusing current to netCDF
ds.to_netcdf(path='./data/semiclassical_focusing.nc')
