# -*- coding: utf-8 -*-
"""
Created on Mon Aug 15 17:50:12 2022

@author: jinglaaynes
"""

import scipy.signal as sp
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import scipy.constants as c
# Use one font family/size for every figure produced by this script.
font = dict(family='DejaVu Sans',
            weight='normal',
            size='15')
plt.rc('font', **font)
# Auxiliary functions

def my_round(num,dec=1):
    '''
    Rounds numbers in scientific notation and converts them into strings

    Parameters
    ----------
    num : number (Python int/float, NumPy scalar or 0-d array) or a 1-D
        iterable of numbers
    dec : integer, number of decimals

    Returns
    -------
    out : string representing the rounded number when num is a scalar,
        otherwise a list with one such string per element
    '''
    spec = '.{}e'.format(dec)
    # np.ndim is 0 for every scalar kind (float, int, np.float64, 0-d array).
    # A plain `type(num) != float` test sends NumPy/int scalars to the
    # iteration branch, which then raises TypeError.
    if np.ndim(num) == 0:
        return format(np.asarray(num).item(), spec)
    return [format(n, spec) for n in num]

def Fermi_vel_BLG(n):
    '''
    Fermi velocity of bilayer graphene (m/s) without a band gap.

    The result depends only on the carrier density and follows the
    parabolic band shape at low density. Only |n| enters, so electrons
    and holes give the same velocity.

    Parameters
    ----------
    n : float or array, carrier density (presumably in m^-2, to match
        the SI constants used below — confirm against callers)

    Returns
    -------
    Fermi velocity in m/s, same shape as n
    '''
    v_dirac = 1E6                 # velocity scale of the model (m/s)
    gamma1 = 0.39 * c.e           # 0.39 eV expressed in J (presumably interlayer coupling)
    alpha = 1 / (c.pi * c.hbar**2 * v_dirac**2)
    density_term = 4 * np.abs(n) / (alpha * gamma1**2)
    inner = np.sqrt(1 + density_term) - 1
    return np.sqrt(2) * v_dirac * np.sqrt(inner)

def fit_func(x, a):
    '''Pure quadratic model a*x**2 (no linear or constant term), for curve_fit.'''
    x_squared = x**2
    return a * x_squared

def fit_func2(x, a, b, c):
    '''
    Full quadratic model a*x**2 + b*x + c, for curve_fit.

    Note: inside this function ``c`` is the constant term and shadows the
    module-level ``scipy.constants`` alias.
    '''
    quadratic = a * x**2
    linear = b * x
    return quadratic + linear + c

def read_multiple_files(filename_list, skrows=14):
    '''
    Load several whitespace-delimited data files and stack their rows.

    Parameters
    ----------
    filename_list : iterable of paths, all files with the same number of
        columns
    skrows : int, number of header lines skipped at the top of every file.
        The default is 14.

    Returns
    -------
    data : numpy array with the rows of all files in the given order. With a
        single file the array returned by np.loadtxt is passed back unchanged.

    Raises
    ------
    ValueError : if filename_list is empty.
    '''
    arrays = [np.loadtxt(f, skiprows=skrows) for f in filename_list]
    if not arrays:
        raise ValueError('filename_list must contain at least one file')
    if len(arrays) == 1:
        return arrays[0]
    # One vstack instead of re-stacking the accumulated data for every file.
    return np.vstack(arrays)

def plot_contact_MR(filename, Vbgset=3):
    '''
    Generates the plots for the contact magnetoresistance shown in Fig. S4.

    For each selected temperature (10, 20, ..., 100 K after rounding),
    Vx*1E7*1E-3 is plotted against B - B0. B0 is the module-level field
    offset. Only points whose Vbg has the same sign as Vbgset are used. The
    value at the measured field closest to B = 0 is collected for every
    temperature. The figure is saved as 'TdepRcVbg<Vbg>V.pdf'.

    Parameters
    ----------
    filename : string containing the measurement filename
    Vbgset : numeric, Vbg value to be plotted. Only its sign is used to
        select the data. The default is 3.

    Returns
    -------
    T_sel : array of the selected temperatures
    RB0 : list with Vx*1E7*1E-3 at the point closest to B = 0, one value per
        entry of T_sel (plotted with a kOhm label)
    '''
    data1 = np.loadtxt(filename, skiprows=14)
    B, Vbg, _, Vx, _, _, Tarr = data1.T
    T_round = np.round(Tarr, 0)
    T_set = np.unique(T_round)
    f_T_set = [T in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100] for T in T_set]
    T_sel = T_set[f_T_set]
    # Quick visual check of which temperatures were found in the file.
    plt.plot(T_sel, 'o')
    plt.grid()
    plt.show()
    # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap is the
    # long-standing pyplot equivalent.
    colormap = plt.get_cmap('jet')
    plt.rcParams['axes.prop_cycle'] = plt.cycler(color=[colormap(k) for k in np.linspace(0, 1, 11)])
    RB0 = []
    try:
        for T in T_sel:
            filt_T_set = T_round == T
            BT = B[filt_T_set]
            VxT = Vx[filt_T_set]
            VbgT = Vbg[filt_T_set]
            # Vbg filtering: keep the gate branch with the sign of Vbgset.
            filt_Vg = VbgT > 0 if Vbgset > 0 else VbgT < 0
            BT1 = BT[filt_Vg]
            VxT1 = VxT[filt_Vg]
            # Two points at each end of the sweep are left out of the plot.
            plt.plot(BT1[2:-2]-B0, VxT1[2:-2]*1E7*1E-3, label='T={}K'.format(T))
            # Point closest to B = 0: compare |B| (not the signed B) with
            # min|B|, otherwise a closest point at negative field is missed
            # and the lookup comes back empty.
            i_B0 = np.argmin(np.abs(BT1))
            RB0.append(VxT1[i_B0]*1E7*1E-3)
        plt.tick_params(direction='in', top=True, right=True)
        plt.xlabel('B (T)', fontsize=20)
        plt.ylabel(r'$R_\mathrm{det}$ (k$\Omega$)', fontsize=20)
        plt.legend()
        plt.xlim([-1.5, 1.5])
        plt.savefig('TdepRcVbg{}V.pdf'.format((VbgT[filt_Vg])[0]))
    finally:
        # Always undo the 'jet' colour cycle, even if plotting fails.
        default_color_cycle()
    return T_sel, RB0

def find_peak_area(xdata, ydata, ind_i, ind_f, color=None, plots=True):
    r'''
    Area of a peak above a straight baseline drawn between two minima.

    The baseline is the line through (xdata[ind_i], ydata[ind_i]) and
    (xdata[ind_f], ydata[ind_f]). The area between the data and this line
    is integrated with the trapezoidal rule over the points ind_i..ind_f.
    Both ends are included, so the baseline meets the data at both limits.

    Parameters
    ----------
    xdata, ydata : 1-D sequences of equal length. If xdata runs downwards
        overall, both are reversed internally so the integral is taken with
        increasing x.
    ind_i, ind_f : indices (ind_i < ind_f) into the xdata/ydata *as given*,
        marking the minima that bound the peak.
    color : optional Matplotlib colour for the plots. None (or an empty
        list, the previous default) lets Matplotlib choose.
    plots : if True, draw the peak, its baseline and the filled area.

    Returns
    -------
    xmean : midpoint in x between the two minima
    area_under_peak : integral of (ydata - baseline) between the minima
    '''
    xdata = np.asarray(xdata)
    ydata = np.asarray(ydata)
    if np.sum(np.diff(xdata)) < 0:
        # Reverse so x increases (np.interp needs increasing xp and the
        # integral gets a positive direction). The indices refer to the
        # original order, so they have to be mirrored as well.
        last = len(xdata) - 1
        xdata = xdata[::-1]
        ydata = ydata[::-1]
        ind_i, ind_f = last - ind_f, last - ind_i
    # Baseline: linear interpolation between the two minima.
    xbase = [xdata[ind_i], xdata[ind_f]]
    ybase = [ydata[ind_i], ydata[ind_f]]
    # ind_f is included: the peak runs from minimum to minimum.
    x_peak = xdata[ind_i:ind_f + 1]
    y_peak = ydata[ind_i:ind_f + 1]
    b_peak = np.interp(x_peak, xbase, ybase)
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    trapz = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz
    A1 = trapz(y_peak, x_peak)
    A2 = trapz(b_peak, x_peak)
    area_under_peak = A1 - A2
    if plots:
        use_default = color is None or (isinstance(color, list) and len(color) == 0)
        line_kw = {} if use_default else {'c': color}
        fill_kw = {} if use_default else {'color': color}
        plt.plot(x_peak, y_peak, '-o', **line_kw)
        plt.plot(x_peak, b_peak, '-', **line_kw)
        plt.fill_between(x_peak, y_peak, b_peak, alpha=0.5, **fill_kw)
        plt.tick_params(direction='in', top=True, right=True)
        plt.xlabel('$B$ (T)', fontsize=20)
        plt.ylabel(r'$R_\mathrm{nl}/R_\mathrm{c}$ ($\times10^3$)', fontsize=20)  # Rc is in kOhms
    xmean = (xdata[ind_i] + xdata[ind_f]) / 2
    return xmean, area_under_peak

def scattering_rate(area, area0, L, vF):
    '''
    Scattering rate from the decay of a focusing-peak area, following
    Science 353, 1526 (2016): rate = -(2*vF/(pi*L)) * ln(area/area0).

    Parameters
    ----------
    area : float or array, peak area at the temperature(s) of study
    area0 : float, peak area at base temperature
    L : contact separation in m
    vF : Fermi velocity in m/s

    Returns
    -------
    Scattering rate (same shape as area), in units of vF/L, i.e. 1/s for
    the SI inputs above. It is zero when area == area0.
    '''
    prefactor = 2 * vF / (np.pi * L)
    return -prefactor * np.log(area / area0)

def processing_scattering_rate(Bpeak, Tout, Areapeak, BmaxL, BminL, Vbg, L, vF,
                               polyfit = False, col2=False):
    '''
    Calculates the scattering rate using the formula from
    Science 353, 1526 (2016) for the peaks inside each field window and
    plots it (in ps^-1) against temperature together with a fit. The fit
    parameters and their standard errors are printed and written on the
    plot, but not returned.

    Parameters
    ----------
    Bpeak : list of B fields for each peak at temperature T in Tout and area
        in Areapeak
    Tout : list of temperatures with the same length as Bpeak
    Areapeak : list of peak areas
    BmaxL : sequence of upper window bounds (exclusive) on Bpeak, one per
        window; a single scalar is treated as a one-element list
    BminL : sequence of lower window bounds (exclusive) on Bpeak, matching
        BmaxL; a single scalar is treated as a one-element list
    Vbg : Backgate voltage to analyse (used in the plot title)
    L : Separation between injector and detector
    vF : Fermi velocity
    polyfit : Boolean. If True the rate is fit to aT^2+bT+c, if False to aT^2
    col2 : if the plot is the second to be plotted it makes the line of color
        C2 (and labels it p=2) instead of C0

    Notes
    -----
    Within each window the first selected area is used as the reference
    area0, so the entries are assumed to be ordered by temperature.
    '''
    Bpeak = np.asarray(Bpeak)
    Tout = np.asarray(Tout)
    Areapeak = np.asarray(Areapeak)
    # A single window may be given as scalars (float, int or NumPy scalar).
    if np.ndim(BminL) == 0:
        BminL = [BminL]
    if np.ndim(BmaxL) == 0:
        BmaxL = [BmaxL]
    for i, (Bmin, Bmax) in enumerate(zip(BminL, BmaxL)):
        if col2:
            i = 1
        color = 'C{}'.format(i*2)
        filt_Bpeak = (Bpeak > Bmin) & (Bpeak < Bmax)
        T_sel = Tout[filt_Bpeak]
        AreapeakF = Areapeak[filt_Bpeak]
        # 1E-12 converts 1/s to 1/ps (matches the y-axis label).
        rate_ps = scattering_rate(AreapeakF, AreapeakF[0], L, vF)*1E-12
        plt.plot(T_sel, rate_ps, 'o',
                 label=' $p={}$'.format(i+1), c=color)
        plt.title(r'$V_\mathrm{}={}$ V'.format('{bg}', Vbg), fontsize=20)
        if polyfit:
            popt, pcov = curve_fit(fit_func2, T_sel, rate_ps)
            plt.plot(T_sel, fit_func2(T_sel, *popt), c=color)
            print('Fit coefficient y=aT^2+bT+c:{}\n Error: {}'.format(popt,np.sqrt(np.diag(pcov))))
            plt.text(20,0,'$\\tau_p^{}=aT^2+bT+c$ \n{}'.format('{-1}',my_round(popt)),
                     fontsize=15)
        else:
            popt, pcov = curve_fit(fit_func, T_sel, rate_ps)
            plt.plot(T_sel, fit_func(T_sel, *popt), c=color)
            print('Fit coefficient y=aT^2:{}\n Error: {}'.format(popt,np.sqrt(np.diag(pcov))))
            plt.text(20,0,'$y=aT^2$:{}\n Error: {}'.format(my_round(popt),np.round(np.sqrt(np.diag(pcov)),2)),fontsize=15)
    plt.tick_params(direction='in', top=True, right=True)
    plt.ylabel('$\\tau_p^{-1}$ (ps$^{-1}$)', fontsize=20)
    plt.xlabel('$T$ (K)', fontsize=20)
    return



def plot_focusing_temp(filename_list, Vbgset=3):
    '''
    Generates the plots in Fig. 3a and b of the main manuscript.
    NOTE(review): the script cell that calls this says Fig. 4a and b —
    confirm the figure number.

    For each selected temperature (2, 10, 20, ..., 100 K after rounding),
    Vx*1E7 is plotted against B. Only points whose Vbg has the sign of
    Vbgset are used, and two points at each sweep end are dropped. The
    figure is saved as 'TdepRnlVbg<Vbg>V.pdf'.

    Parameters
    ----------
    filename_list : list of filenames (strings)
    Vbgset : Vbg value to plot. Can only be +3 and -3 (only its sign is used).
    '''
    data = read_multiple_files(filename_list, skrows=14)
    B, Vbg, _, Vx, _, _, Tarr = data.T
    T_round = np.round(Tarr, 0)
    # np.unique already returns sorted values.
    T_set = np.unique(T_round)
    f_T_set = [T in [2, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] for T in T_set]
    # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap is the
    # long-standing pyplot equivalent.
    colormap = plt.get_cmap('jet')
    plt.rcParams['axes.prop_cycle'] = plt.cycler(color=[colormap(k) for k in np.linspace(0, 1, 11)])
    try:
        for T in T_set[f_T_set]:
            filt_T_set = T_round == T
            BT = B[filt_T_set]
            VxT = Vx[filt_T_set]
            VbgT = Vbg[filt_T_set]
            filt_Vg = VbgT > 0 if Vbgset > 0 else VbgT < 0
            VbgT = VbgT[filt_Vg]
            BT1 = BT[filt_Vg]
            VxT1 = VxT[filt_Vg]
            plt.plot(BT1[2:-2], VxT1[2:-2]*1E7, label='T={}K'.format(T))
        plt.tick_params(direction='in', top=True, right=True)
        plt.xlabel('B (T)', fontsize=20)
        plt.ylabel(r'$R_\mathrm{nl}$ ($\Omega$)', fontsize=20)
        plt.legend()
        plt.savefig('TdepRnlVbg{}V.pdf'.format(np.round(VbgT[0])))
    finally:
        # Always undo the 'jet' colour cycle, even if plotting fails.
        default_color_cycle()
    return



def default_color_cycle():
    '''
    Put the default ten-colour axes cycle back into plt.rcParams.

    The plotting functions in this file replace 'axes.prop_cycle' with
    colours sampled from the 'jet' colour map; this restores the usual
    sequence for subsequent plots.
    '''
    tab10 = ('#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
             '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf')
    plt.rcParams['axes.prop_cycle'] = plt.cycler(color=list(tab10))

def area_under_peaks(flist, Rm3B0, Rp3B0, Vbgset=3, Tsave=(20, 30, 40, 50), plots=True):
    '''
    Plots each TEF spectrum normalized by the contact resistance.
    Vbgset specifies the Vbg value to be plotted, Tsave, the temperatures to
    save the plot for Fig. S5.

    For every temperature the signal Vx*1E7/Rc is taken on one field
    polarity (B<0 for Vbg>0, B>0 otherwise). A linear background is
    removed, and consecutive minima of the lightly smoothed signal delimit
    the peaks whose areas are measured.

    Parameters
    ----------
    flist : list of filenames (strings)
    Rm3B0 : list of contact resistances at B=0 and Vbg=-3V
    Rp3B0 : list of contact resistances at B=0 and Vbg=+3V
    Vbgset : Vbg value to process. can only be +3 and -3 (only its sign is used).
    Tsave : TEF plots with areas underneath to save to pdf.
        The default is (20, 30, 40, 50).
    plots : Boolean to decide if results are plotted. The default is True.

    Returns
    -------
    Tout, Bpeak, Areapeak : numpy arrays with one entry per detected peak:
        its temperature, the midpoint field between its bounding minima,
        and its area. Temperatures beyond the length of the resistance
        lists are dropped by zip.
    '''
    data = read_multiple_files(flist, skrows=14)
    B, Vbg, _, Vx, _, _, Tarr = data.T
    T_round = np.round(Tarr, 0)
    T_set = np.unique(T_round)
    # Skip temperatures strictly between 41 and 50 K and strictly between 2 and 10 K.
    skip_mid = (T_set > 41) & (T_set < 50)
    skip_low = (T_set > 2) & (T_set < 10)
    f_T_set = ~(skip_mid | skip_low)
    # Here I repeat the contact resistance at 10K for the 2K case.
    Rm3B0 = np.insert(Rm3B0, 0, Rm3B0[0])
    Rp3B0 = np.insert(Rp3B0, 0, Rp3B0[0])
    # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap is the
    # long-standing pyplot equivalent.
    colormap = plt.get_cmap('jet')
    plt.rcParams['axes.prop_cycle'] = plt.cycler(color=[colormap(k) for k in np.linspace(0, 1, 10)])
    Tout, Bpeak, Areapeak = [], [], []
    try:
        for T, Rm3, Rp3 in zip(T_set[f_T_set], Rm3B0, Rp3B0):
            filt_T_set = T_round == T
            BT = B[filt_T_set]
            VxT = Vx[filt_T_set]
            VbgT = Vbg[filt_T_set]
            if Vbgset > 0:
                filt_Vg = VbgT > 0
                Rc = Rp3
            else:
                filt_Vg = VbgT < 0
                Rc = Rm3
            VbgT = VbgT[filt_Vg]
            BT1 = BT[filt_Vg]
            IxT1 = VxT[filt_Vg]*1E7/Rc
            filt_Bsign = BT1 < 0 if Vbgset > 0 else BT1 > 0
            BT1 = BT1[filt_Bsign]
            IxT1 = IxT1[filt_Bsign]
            # Remove a linear background fitted to the whole selected sweep.
            pol = np.poly1d(np.polyfit(BT1, IxT1, 1))
            signal = IxT1 - pol(BT1)
            if plots:
                plt.plot(BT1, signal, 'k')
            # Minima of the smoothed signal bound consecutive peaks.
            minima, _ = sp.find_peaks(-sp.savgol_filter(signal, 5, 1),
                                      distance=1, prominence=0.1)
            for k, (ind_i, ind_f) in enumerate(zip(minima[:-1], minima[1:])):
                xmean, area = find_peak_area(BT1, signal, ind_i, ind_f,
                                             color='C{}'.format(k % 10),
                                             plots=plots)
                Tout.append(T)
                Bpeak.append(xmean)
                Areapeak.append(area)
            if plots:
                plt.title(r'T={} $V_\mathrm{}={}$ V'.format(T, '{bg}', np.round(VbgT[0])))
                if T in Tsave:
                    plt.savefig('FocusingAreaUnderPeakVbg{}VT{}K.pdf'.format(np.round(VbgT[0]), T))
                plt.show()
    finally:
        # Always undo the 'jet' colour cycle, even if processing fails.
        default_color_cycle()
    return np.array(Tout), np.array(Bpeak), np.array(Areapeak)

def ratio_peak_areas(Tout, Bpeak, Areapeak, BmaxL, BminL):
    '''
    Calculates the ratio between the areas under the two peaks determined
    by the BmaxL and BminL lists, and plots it against temperature.

    Parameters
    ----------
    Tout : list of temperatures with the same length as Bpeak
    Bpeak : list of B fields for each peak at temperature T in Tout and area
        in Areapeak
    Areapeak : list of peak areas
    BmaxL : list with 2 elements. Maximum value of Bpeak to be considered
    BminL : list with 2 elements. Minimum value of Bpeak to be considered

    Returns
    -------
    numpy array with A2/A1 for each distinct temperature (ascending). The
    first peak found inside each window is used. An IndexError is raised if
    a temperature has no peak in one of the windows.
    '''
    # Lists are accepted as documented; the element-wise comparisons below
    # need arrays (a list would make `Tout == T` a single bool).
    Tout = np.asarray(Tout)
    Bpeak = np.asarray(Bpeak)
    Areapeak = np.asarray(Areapeak)
    ToutSet = np.unique(Tout)
    A2overA1 = []
    for T in ToutSet:
        filt_Tout = Tout == T
        BpeakT = Bpeak[filt_Tout]
        AreapeakT = Areapeak[filt_Tout]
        in_win1 = (BpeakT > BminL[0]) & (BpeakT < BmaxL[0])
        in_win2 = (BpeakT > BminL[1]) & (BpeakT < BmaxL[1])
        A2overA1.append(AreapeakT[in_win2][0]/AreapeakT[in_win1][0])
    plt.plot(ToutSet, A2overA1, 'o-')
    plt.tick_params(direction='in', top=True, right=True)
    plt.xlabel('$T$ (K)', fontsize=20)
    plt.ylabel('$A_2/A_1$', fontsize=20)
    return np.array(A2overA1)

# Relevant parameters
# Field offset (T) subtracted from B in the contact-MR plots (plot_contact_MR).
B0 = 0.074
# Injector-detector separation in m (the L passed to scattering_rate).
L = 2E-6
# Carrier density. Presumably gate capacitance per area (0.00144, assumed
# F/m^2) times |Vbg| = 3 V divided by e, i.e. in m^-2 — TODO confirm.
n = 0.00144 * 3/c.e
# Fermi velocity (m/s) at that density, used for the scattering rates.
vF = Fermi_vel_BLG(n)

#%% Here I plot the focusing data at different temperatures Fig. 4a and b manuscript
# Focusing measurements; the files are concatenated by read_multiple_files.
flist = ['M15_I_32to18_100nA_V02to30_100x_FocusingForTdep.dat',
         'M16_I_32to18_100nA_V02to30_100x_FocusingT10.dat',
         'M17_I_32to18_100nA_V02to30_100x_FocusingT20to50K.dat',
         'M18_I_32to18_100nA_V02to30_100x_FocusingT40to50K.dat',
         'M19_I_32to18_100nA_V02to30_100x_FocusingT60to100K.dat']
plot_focusing_temp(flist,Vbgset=3)
plt.show()
plot_focusing_temp(flist,Vbgset=-3)
# Finish the -3 V figure too. Otherwise, when the file runs as one script,
# the next cell's first plot is drawn onto it.
plt.show()
#%% Here I plot the contact resistance Fig. S4
# Detector contact resistance vs temperature for both gate polarities.
filename = 'M20_I_32to30_100nA_V32to30_10x_DetRes1p8to100K.dat'
# NOTE(review): the Vbgset=3 result is stored in Rm3B0 and the Vbgset=-3
# result in Rp3B0. The T arrays Tp3/Tm3 follow the call's sign, but the R
# names do not. The plot below labels Rm3B0 'Vbg=-3 V', and
# area_under_peaks uses Rp3B0 for Vbg>0. Confirm this polarity swap is
# intentional (e.g. a different gate sign convention in the M20
# measurement) and not a mix-up.
Tp3, Rm3B0 = plot_contact_MR(filename, Vbgset=3)
plt.show()
Tm3, Rp3B0 = plot_contact_MR(filename, Vbgset=-3)
plt.show()
# Summary figure: value at the field point closest to B = 0 vs T.
plt.figure(figsize=[3,2])
plt.plot(Tm3, Rm3B0,c='C0',label='Vbg=$-$3 V')
plt.plot(Tp3, Rp3B0,c='C7', label='Vbg=$+$3 V')
plt.tick_params(direction='in',top=True, right=True)
plt.ylabel('$R_c$ (k$\\Omega$)', fontsize=20)
plt.xlabel('T (K)', fontsize=20)

plt.legend()
#plt.tight_layout()
plt.savefig('RcvsT.pdf')
plt.show()

# Subtract fixed offsets (503/2 and 359/2 Ohm, expressed in kOhm) from the
# two contact-resistance series before normalising the focusing spectra.
# NOTE(review): the origin of 503 and 359 is not shown here — presumably
# separately measured resistances, halved; confirm.
Rm3B0 = np.array(Rm3B0) - 503*1E-3/2
Rp3B0 = np.array(Rp3B0) - 359*1E-3/2

#%% Here I analyze the results Fig. S5

# Each run: (Vbgset, BmaxL, BminL, polyfit, output pdf). The two field
# windows are the ones labelled p=1 and p=2 by processing_scattering_rate.
# Runs are done in this order; the last one leaves Vbgset, Tout, Bpeak and
# Areapeak set for Vbg = -3 V.
scattering_runs = [
    (3, [-0.15, -0.35], [-0.25, -0.45], True, 'ScatteringRateVbgp3V.pdf'),
    (-3, [0.35, 0.6], [0.2, 0.4], True, 'ScatteringRateVbgm3V.pdf'),
    (3, [-0.15, -0.35], [-0.25, -0.45], False, 'ScatteringRate_aT2_Vbgp3V.pdf'),
    (-3, [0.35, 0.6], [0.2, 0.4], False, 'ScatteringRate_aT2_Vbgm3V.pdf'),
]
for Vbgset, Bmax_windows, Bmin_windows, use_polyfit, pdf_name in scattering_runs:
    Tout, Bpeak, Areapeak = area_under_peaks(flist, Rm3B0, Rp3B0, Vbgset=Vbgset)
    processing_scattering_rate(Bpeak, Tout, Areapeak, BmaxL=Bmax_windows,
                               BminL=Bmin_windows, Vbg=Vbgset, L=L, vF=vF,
                               polyfit=use_polyfit)
    plt.legend(loc=4)
    plt.savefig(pdf_name)
    plt.show()
#%% Area p=2/area p=1
# Ratio of the areas in the second and first field windows vs T for both
# gate polarities, drawn on one small figure.
Vbgset=3
plt.figure(figsize=[2,1.5])
Tout, Bpeak, Areapeak = area_under_peaks(flist, Rm3B0, Rp3B0, Vbgset=Vbgset, plots=False)
ratio_peak_areas(Tout, Bpeak, Areapeak,BmaxL=[-0.15,-0.35], BminL=[-0.25,-0.45])
Vbgset=-3
Tout, Bpeak, Areapeak = area_under_peaks(flist, Rm3B0, Rp3B0, Vbgset=Vbgset, plots=False)
ratio_peak_areas(Tout, Bpeak, Areapeak,BmaxL=[0.35,0.6], BminL=[0.2,0.4])
# Raw strings: '\m' is an invalid escape sequence (SyntaxWarning since 3.12).
plt.legend(labels=[r'$V_\mathrm{bg}=+3$ V', r'$V_\mathrm{bg}=-3$ V'])
plt.ylim([0.5,3.5])
plt.savefig('RatioPeak2vsPeak1Areas.pdf')
plt.show()
#%% Area p=3/area p=2
# Same ratio for the next pair of field windows, both gate polarities.
Vbgset=3
Tout, Bpeak, Areapeak = area_under_peaks(flist, Rm3B0, Rp3B0, Vbgset=Vbgset, plots=False)
ratio_peak_areas(Tout, Bpeak, Areapeak,BmaxL=[-0.35, -0.45], BminL=[-0.45,-0.7])
Vbgset=-3
Tout, Bpeak, Areapeak = area_under_peaks(flist, Rm3B0, Rp3B0, Vbgset=Vbgset, plots=False)
ratio_peak_areas(Tout, Bpeak, Areapeak,BmaxL=[0.6,0.81], BminL=[0.4,0.6])
# Raw strings: '\m' is an invalid escape sequence (SyntaxWarning since 3.12).
plt.legend(labels=[r'$V_\mathrm{bg}=+3$ V', r'$V_\mathrm{bg}=-3$ V'])
plt.savefig('RatioPeak3vsPeak2Areas.pdf')
plt.show()
#%% Here I generate the main paper plot Fig. 4c
from matplotlib.lines import Line2D
# Scattering rate from one field window per polarity: (-0.45, -0.35) T at
# +3 V (colour C0) and (0.4, 0.6) T at -3 V (colour C2), on the same axes.
Vbgset=3
Tout3, Bpeak3, Areapeak3 = area_under_peaks(flist, Rm3B0, Rp3B0, Vbgset=Vbgset)
processing_scattering_rate(Bpeak3, Tout3, Areapeak3, BmaxL=[-0.35],
                           BminL=[-0.45],Vbg=Vbgset, L=L, vF=vF,polyfit = True)

Vbgset=-3
# plots=False: area_under_peaks then opens no new figures, so the -3 V
# curve lands on the same axes as the +3 V one.
Tout, Bpeak, Areapeak = area_under_peaks(flist, Rm3B0, Rp3B0, Vbgset=Vbgset, plots=False)
processing_scattering_rate(Bpeak, Tout, Areapeak, BmaxL=[0.6],
                           BminL=[0.4],Vbg=Vbgset, L=L, vF=vF, polyfit=True, col2 = True)
# Custom legend with one marker per gate polarity (raw strings avoid the
# invalid '\m' escape).
handles = [Line2D([0], [0], marker='o', color='w', label=r'$V_\mathrm{bg}=+3$ V',
                  markerfacecolor='C0', markersize=8),
           Line2D([0], [0], marker='o', color='w', label=r'$V_\mathrm{bg}=-3$ V',
                  markerfacecolor='C2', markersize=8)]
plt.legend(handles=handles, loc=4)
plt.savefig('ScatteringTimeVsTMainManuscript.pdf')
plt.show()
    