import numpy as np

# Physical constants
from scipy.constants import hbar, e, nano, physical_constants
from kwant.digest import uniform

# Normal-state flux quantum h/e. scipy's "mag. flux quantum" entry is the
# superconducting value h/(2e), hence the factor of two.
phi_0 = physical_constants["mag. flux quantum"][0] * 2  # [Wb]

# Energy scale and further hopping amplitudes; gamma3/gamma4 are expressed
# in units of t (t itself presumably in eV -- confirm).
t = 3.16
gamma3 = -0.38 / t
gamma4 = 0.14 / t

def onsite_1(site, mu, dmu, m, m2, L_device, L_lead, chi):
    """
    Return the onsite energy ``-(mu + disorder) + gap`` of a site.

    Parameters:
    -----------

    site: kwant site
        ``repr(site)`` seeds the disorder draw; ``site.pos`` (an (x, y)
        pair) is read only when ``m2`` is non-zero.
    mu: float
        Chemical potential.
    dmu: float
        Disorder amplitude. When non-zero, ``dmu * (u - 0.5)`` is added to
        ``mu``, where ``u = kwant.digest.uniform(repr(site))`` is a
        pseudo-random number in [0, 1) derived from the site.
    m: float
        Uniform gap (mass) term.
    m2: float
        Amplitude of an extra, position-dependent gap term. When zero the
        gap is just ``m`` and the site position is not inspected.
    L_device, L_lead: float
        Lengths in the units of ``site.pos`` that locate the extra gap term.
        It is confined to ``|y| < (L_device - L_lead) / 2`` (taking the
        value 0.5 exactly on the edges). In x it is present well below
        ``x = chi - L_device`` and fades out above it.
    chi: float
        Width of the tanh switch in x; must be non-zero whenever ``m2`` is
        (it is used as a divisor).
    """
    energy = mu
    if dmu:
        energy = mu + dmu * (uniform(repr(site)) - 0.5)

    mass = m
    if m2:
        x, y = site.pos
        # Smooth step in x: ~1 for x << chi - L_device, ~0 for x >> it.
        x_profile = 0.5 * (1 - np.tanh((x + L_device - chi) / chi))
        # Sharp window in y; both Heaviside factors use 0.5 at the edge.
        below_top = 1 - np.heaviside(y - L_device / 2 + L_lead / 2, 0.5)
        above_bottom = np.heaviside(y + L_device / 2 - L_lead / 2, 0.5)
        mass = m + m2 * (x_profile * below_top * above_bottom)

    return -energy + mass


def onsite_2(site, mu, dmu, m, m2, L_device, L_lead, chi):
    """
    Return the same onsite energy as ``onsite_1``, but with both gap terms
    (``m`` and ``m2``) sign-reversed.

    All parameters have the same meaning as in ``onsite_1``.
    """
    flipped_m, flipped_m2 = -m, -m2
    return onsite_1(site, mu, dmu, flipped_m, flipped_m2, L_device, L_lead, chi)


def onsite_leads_1(site, mu, m):
    """
    Return the onsite energy ``-mu + m`` used for lead sites.

    Parameters:
    -----------

    site: kwant site
        Unused; present only because kwant passes the site first.
    mu: float
        Chemical potential.
    m: float
        Uniform gap (mass) term.
    """
    potential = -mu
    return potential + m


def onsite_leads_2(site, mu, m):
    """Return ``onsite_leads_1`` evaluated with the gap term ``m`` negated."""
    flipped_m = -m
    return onsite_leads_1(site, mu, flipped_m)


def hopping(site1, site2, B):
    """
    Return the unit-modulus Peierls phase factor for a bond.

    The vector potential is taken in a Landau gauge with only an x component
    proportional to ``B * y``. It is evaluated at the bond's mid-point in y,
    so the phase is ``2 * pi / phi_0 * (x2 - x1) * B * y_mid``.

    Parameters:
    -----------

    site1: kwant site
    site2: kwant site
        Only ``site.pos`` (an (x, y) array) is used. It is multiplied by
        ``nano * 0.142`` to convert to metres. Presumably positions are in
        units of the 0.142 nm carbon-carbon distance; confirm against the
        lattice definition.
    B: float
        Magnetic field (not a flux). With ``phi_0`` in Wb and positions in
        metres, the phase is dimensionless when ``B`` is in tesla.

    Returns:
    --------
    complex
        ``exp(1j * phase)``.
    """
    x1, y1 = site1.pos * nano * 0.142
    x2, y2 = site2.pos * nano * 0.142
    y_mid = (y1 + y2) / 2
    # x component of the vector potential at the bond mid-point.
    gauge = B * y_mid
    phase = 2 * np.pi / phi_0 * (x2 - x1) * gauge
    return np.exp(1j * phase)

def hopping_t(site1, site2, B):
    """
    Return the hopping of amplitude -1 times the Peierls phase from ``hopping``.

    Energies appear to be in units of t here (the module's gamma3/gamma4
    are divided by t), so -1 corresponds to -t.
    """
    peierls = hopping(site1, site2, B)
    return -1 * peierls

def hopping_gamma3(site1, site2, B, alpha):
    """
    Return the hopping of amplitude ``gamma3 * alpha`` times the Peierls
    phase from ``hopping``.

    ``alpha`` is a dimensionless prefactor on the module-level ``gamma3``.
    """
    amplitude = gamma3 * alpha
    return amplitude * hopping(site1, site2, B)

def hopping_gamma4(site1, site2, B, alpha):
    """
    Return the hopping of amplitude ``gamma4 * alpha`` times the Peierls
    phase from ``hopping``.

    ``alpha`` is a dimensionless prefactor on the module-level ``gamma4``.
    """
    amplitude = gamma4 * alpha
    return amplitude * hopping(site1, site2, B)

def hopping_upper(site1, site2, B, L_device):
    """
    Return the Peierls phase factor with the Landau-gauge vector potential
    frozen at the constant height ``y = +1.5 * L_device``.

    Because the vector potential is position-independent here, it carries no
    magnetic field in the region where it is used. It only adds a constant
    phase per unit x displacement. Presumably this is meant for a region
    above the device (e.g. a lead) that needs a y-independent gauge;
    confirm with the system builder.

    Parameters:
    -----------

    site1: kwant site
    site2: kwant site
        Only the x coordinate of ``site.pos`` is used, scaled by
        ``nano * 0.142`` to metres.
    B: float
        Magnetic field (tesla, given ``phi_0`` in Wb and lengths in metres).
    L_device: float
        Device length in the same units as ``site.pos``; fixes the height at
        which the vector potential is evaluated.

    Returns:
    --------
    complex
        ``exp(1j * phase)`` with unit modulus.
    """
    # Only x enters the phase; the y components are discarded.
    x1, _ = site1.pos * nano * 0.142
    x2, _ = site2.pos * nano * 0.142
    # Vector potential B * y evaluated at y = 1.5 * L_device (in metres).
    gauge = B * 1.5 * L_device * nano * 0.142
    phase = 2 * np.pi / phi_0 * (x2 - x1) * gauge
    return np.exp(1j * phase)

def hopping_upper_t(site1, site2, B, L_device):
    """Return the amplitude -1 hopping times the ``hopping_upper`` phase factor."""
    peierls = hopping_upper(site1, site2, B, L_device)
    return -1 * peierls

def hopping_upper_gamma3(site1, site2, B, L_device, alpha):
    """Return ``gamma3 * alpha`` times the ``hopping_upper`` phase factor."""
    amplitude = gamma3 * alpha
    return amplitude * hopping_upper(site1, site2, B, L_device)

def hopping_upper_gamma4(site1, site2, B, L_device, alpha):
    """Return ``gamma4 * alpha`` times the ``hopping_upper`` phase factor."""
    amplitude = gamma4 * alpha
    return amplitude * hopping_upper(site1, site2, B, L_device)

def hopping_lower(site1, site2, B, L_device):
    """
    Return the Peierls phase factor with the Landau-gauge vector potential
    frozen at the constant height ``y = -1.5 * L_device``.

    Because the vector potential is position-independent here, it carries no
    magnetic field in the region where it is used. It only adds a constant
    phase per unit x displacement. Presumably this is meant for a region
    below the device (e.g. a lead) that needs a y-independent gauge;
    confirm with the system builder.

    Parameters:
    -----------

    site1: kwant site
    site2: kwant site
        Only the x coordinate of ``site.pos`` is used, scaled by
        ``nano * 0.142`` to metres.
    B: float
        Magnetic field (tesla, given ``phi_0`` in Wb and lengths in metres).
    L_device: float
        Device length in the same units as ``site.pos``; fixes the height at
        which the vector potential is evaluated.

    Returns:
    --------
    complex
        ``exp(1j * phase)`` with unit modulus.
    """
    # Only x enters the phase; the y components are discarded.
    x1, _ = site1.pos * nano * 0.142
    x2, _ = site2.pos * nano * 0.142
    # Vector potential B * y evaluated at y = -1.5 * L_device (in metres).
    gauge = -B * 1.5 * L_device * nano * 0.142
    phase = 2 * np.pi / phi_0 * (x2 - x1) * gauge
    return np.exp(1j * phase)

def hopping_lower_t(site1, site2, B, L_device):
    """Return the amplitude -1 hopping times the ``hopping_lower`` phase factor."""
    peierls = hopping_lower(site1, site2, B, L_device)
    return -1 * peierls

def hopping_lower_gamma3(site1, site2, B, L_device, alpha):
    """Return ``gamma3 * alpha`` times the ``hopping_lower`` phase factor."""
    amplitude = gamma3 * alpha
    return amplitude * hopping_lower(site1, site2, B, L_device)

def hopping_lower_gamma4(site1, site2, B, L_device, alpha):
    """Return ``gamma4 * alpha`` times the ``hopping_lower`` phase factor."""
    amplitude = gamma4 * alpha
    return amplitude * hopping_lower(site1, site2, B, L_device)