###############################################################################################################################
###############################################################################################################################
###############################################################################################################################

        # Implementation of the Yamada model (c) by W. Yao, L. Puts, K. Williams and D. Lenstra
        # as part of the work published in "Phase-space analysis of a two-section InP laser as an all-optical spiking 
        # neuron: dependency on control and design parameters" published by IOP Publishing Ltd under
        # doi 10.1088/2634-4386/ad575e, url: https://iopscience.iop.org/article/10.1088/2634-4386/ad575e
        # 
        # This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
        # 
        # You should have received a copy of the license along with this work. If not, see 
        # <https://creativecommons.org/licenses/by-nc-sa/4.0/>.
        
        # Eindhoven University of Technology, Eindhoven, the Netherlands, June 2024
        
        # Contact: l.puts@tue.nl

###############################################################################################################################
###############################################################################################################################
###############################################################################################################################


import numpy as np
import matplotlib.pyplot as plt
from math import exp, log, sqrt
from scipy.signal import savgol_filter
import time
from scipy.integrate import solve_ivp
import pandas as pd
from matplotlib import colors
import scipy.signal as signal



def yamada_switchedLn_phot(t, y, p):
    """Right-hand side of the two-section laser (Yamada-type) rate equations.

    The state is ``y = (S, H, Nq)``: S is the photon number in the cavity,
    H = ln(Ng) is the natural log of the gain-section carrier density Ng
    [m^-3], and Nq is the absorber carrier density [m^-3]. Because ln(Ng) is
    integrated instead of Ng, ``dHdt`` below is the carrier-rate bracket
    divided by ``Ng`` (d ln(Ng)/dt = (dNg/dt) / Ng).

    Parameters
    ----------
    t : float
        Time [s]. Only used to look up the injected stimulus.
    y : sequence of three floats
        Current state ``(S, H, Nq)``.
    p : dict
        Device parameters as built by get_Parameters(). This function reads
        tau_p, tau_g, tau_q, a_NG, a_NQ, Ntr, Nq0, gamma, vg, V_gain,
        V_loss, S_sat and I.

    Returns
    -------
    list of float
        ``[dS/dt, dH/dt, dNq/dt]``.

    Notes
    -----
    Not every input comes from ``p``. The stimulus term calls get_iStim(t),
    which reads the module-level ``ts``/``stim``, and divides by the
    module-level ``tau_r``. Those globals must be set before solving.
    """
    
    tau_p = p['tau_p']  # photon lifetime
    tau_g = p['tau_g']  # gain carrier lifetime
    tau_q = p['tau_q']  # loss (absorber) carrier lifetime
    a_NG  = p['a_NG']  # gain cross section
    a_NQ  = p['a_NQ']  # loss cross section
    Ntr  = p['Ntr']  # transparency carrier density
    Nq0  = p['Nq0']  # voltage dependent fixed absorber loss carrier density
    gamma  = p['gamma']  # confinement factor 4xQW
    vg  = p['vg']  # group velocity
    V_gain = p['V_gain'] # gain volume [m^3]
    V_loss = p['V_loss'] # absorber volume [m^3]
    S_sat = p['S_sat'] # saturation photon number in cavity
    I = p['I'] # injection current in A
    
    # physical parameters
    e = 1.6E-19  # elementary charge [C] (rounded value)

    # unpack state: photon number, ln(gain carrier density), absorber carrier density
    S, H, Nq = y
    Ng = exp(H)
    
    # logarithmic gain model a_NG*Ntr*ln(Ng/Ntr), written in terms of H = ln(Ng)
    gain_modal = a_NG*Ntr*(H-log(Ntr))
    # quadratic fit in Ng/Ntr (parametrized for the InP generic platform) for a carrier-dependent
    # term that is subtracted from the modal gain (presumably internal loss; confirm against Eq. 4)
    c1 = -1.057054
    c2 = 14.7170818
    c3 = -13.445131
    N_ratio = Ng/Ntr
    alpha_gain = (c1*N_ratio**2+c2*N_ratio+c3)/0.01       #/0.01 for normalization to m^-1
    gain_net = gamma*gain_modal-alpha_gain #Eq. 4 in paper

    # saturable absorption: alpha_net < 0 while Nq < Nq0, and it saturates with photon number S
    alpha_net = gamma * a_NQ * (Nq-Nq0)
    alpha_sat = alpha_net/(1+(S)/(S_sat)) #Eq. 5 in paper

    # spontaneous emission into the mode (beta, B as in Eq. 7; names assumed to be the
    # coupling factor and recombination coefficient — confirm against the paper)
    beta = 1e-4
    B = 1e-16           
    S_sp = V_gain*beta*B*Ng**2 #Eq. 7 in paper

    # photons: net gain + (negative) saturable absorption - mirror losses + spontaneous emission
    # + injected stimulus photon number divided by the round-trip time tau_r (module global)
    dSdt = vg * gain_net * S + vg * alpha_sat * (S) - S/tau_p   + S_sp + get_iStim(t)/(tau_r**1)
    # gain carriers: pumping - stimulated emission - spontaneous decay, divided by Ng for the ln(Ng) state
    dHdt =1/Ng*((I)/(e*V_gain) - vg*gain_net*(S)/V_gain - Ng/tau_g)
    # absorber carriers: generated by absorbed photons (-alpha_sat > 0) and decaying with tau_q
    dNqdt = -vg*alpha_sat*(S)/V_loss - (Nq)/tau_q

    return [dSdt, dHdt, dNqdt]

def get_Stimulus(t, input, i_unit, times, t_unit, T_pulse=200e-12):
    """Build a stimulus trace by convolving scaled impulses with a Gaussian pulse.

    Parameters
    ----------
    t : array
        Sorted sample times [s] of the simulation (needed by np.searchsorted).
    input : sequence
        Per-pulse amplitude factors (0 disables a pulse).
    i_unit : float
        Scale applied to ``input`` (photon-number units in this script).
    times : sequence
        Pulse times in units of ``t_unit``.
    t_unit : float
        Time unit [s]. The Gaussian kernel spans [-t_unit, t_unit).
    T_pulse : float
        Pulse duration [s]. The gausspulse centre frequency is 1/(0.5*T_pulse).

    Returns
    -------
    numpy.ndarray
        Stimulus samples, same shape as ``t``.
    """
    time_axis = t

    # one impulse per requested pulse, at the first sample not earlier than its time
    impulses = np.zeros(np.shape(time_axis))
    positions = np.searchsorted(time_axis, np.asarray(times) * t_unit)
    impulses[positions] = np.multiply(input, i_unit)

    # Gaussian envelope sampled on roughly the same step as time_axis
    samples_per_unit = int(len(time_axis) / (max(time_axis) / t_unit))
    kernel_time = np.linspace(-t_unit, t_unit, 2 * samples_per_unit, endpoint=False)
    _, _, envelope = signal.gausspulse(
        kernel_time, fc=1 / (0.5 * T_pulse), retquad=True, retenv=True
    )

    return signal.convolve(impulses, envelope, mode='same')

def get_iStim(t, time_axis=None, stimulus=None):
    """Return the stimulus sample whose time stamp is nearest to ``t``.

    Parameters
    ----------
    t : float
        Query time [s].
    time_axis : array-like, optional
        Sample times. Defaults to the module-level ``ts``, read at call time.
        This is how yamada_switchedLn_phot uses this function.
    stimulus : array-like, optional
        Stimulus values matching ``time_axis``. Defaults to the module-level
        ``stim``.

    Returns
    -------
    The element of ``stimulus`` at the nearest time sample. On a tie the
    earlier sample wins.
    """
    if time_axis is None:
        time_axis = ts
    if stimulus is None:
        stimulus = stim
    # A single argmin pass is enough. This runs on every ODE right-hand-side
    # evaluation, so the former second scan (np.where(ts == value)) was pure overhead.
    nearest = np.abs(np.asarray(time_axis) - t).argmin()
    return stimulus[nearest]

def get_InitialCond():
    """Return the default ODE initial state as the tuple ``(S, ln(Ng), Nq)``.

    S = 1 photon, Ng = 1e-5 (stored as its natural log, matching the ln(Ng)
    state of yamada_switchedLn_phot) and Nq = 0.
    """
    photons, log_gain_density, absorber_density = 1, log(1e-5), 0
    return (photons, log_gain_density, absorber_density)

def get_Parameters():
    """Collect the current module-level model parameters into a dict.

    The dict is passed as ``p`` to yamada_switchedLn_phot and is also written
    to the parameter .csv by plot_results. Values are read from the module
    globals at call time, so call it again after update_params().

    Returns
    -------
    dict
        Parameter name -> current value.
    """
    p = {'tau_p': tau_p,         # photon lifetime [s], 1/(vg*alpha_mirror): mirror losses only
          'tau_g': tau_g,        # gain carrier lifetime [s]
          'tau_q': tau_q,        # absorber carrier lifetime [s], bias dependent via update_fit(I, Vbias)
          'a_NG': a_NG,          # gain cross section [m^2], from Erwin PhIwiki
          'a_NQ': a_NQ,          # differential absorption [m^2], from Heck, thesis
          'Ntr': Ntr,            # transparency carrier density [m^-3], from Erwin PhIwiki
          'Nq0': Nq0,            # absorber transparency carrier density [m^-3], linear in Vbias (a*Vbias + b)
          'gamma': gamma,        # confinement factor MQW COBRA, Erwin data
          'vg': vg,              # group velocity [m/s]
          'V_gain': V_gain,      # volume of gain section [m^3], L_gain x w_active x t_active
          'V_loss': V_loss,      # volume of absorber section [m^3], L_absorber x w_active x t_active
          'S_sat': S_sat,        # saturation photon number, from the saturation power E_sat/T_pulse
          'I': I,                # injection current in [A]
          't_active': t_active,  # active medium thickness [m]
          'Vbias': Vbias,        # reverse bias voltage on absorber [V]
          'E_sat': E_sat,        # saturation energy [J]
          'alpha_G': alpha_G,    # linewidth enhancement factor, gain section
          'alpha_Q': alpha_Q,    # linewidth enhancement factor, absorber section
          'T_pulse': T_pulse}    # pulse width in [s]
    return p

def solve_ODE(modelfunction, initial, time, params):
    """Integrate ``modelfunction(t, y, params)`` over the sample grid ``time``.

    Uses LSODA with rtol=1e-8 and atol=1e-12. The solution is evaluated at
    every point of ``time`` and a dense interpolant is kept (``sol.sol``),
    which plot_results/analyse_results rely on.

    Returns
    -------
    The ``OdeResult`` from scipy.integrate.solve_ivp.
    """
    def rhs(t, state):
        return modelfunction(t, state, params)

    return solve_ivp(
        fun=rhs,
        t_span=[time[0], time[-1]],
        y0=initial,
        method="LSODA",
        t_eval=time,
        dense_output=True,
        rtol=1e-8,
        atol=1e-12,
    )

def get_PhotonNumber(P_in=1e-3, tau_p=300e-12, wl=1550e-9, eta_c=0.6):
    """Convert an optical power to a photon number (Eq. 10 in paper).

    S = P_in * tau_p * wl / (h * c0 * eta_c), with h = 6.63e-34 J s and
    c0 = 299792458 m/s.
    """
    planck = 6.6300e-34
    light_speed = 299792458
    return P_in*tau_p*wl/(planck*light_speed*eta_c)

def get_PowerOutCavity(S_in=1e6, tau_p=7e-12, wl=1550e-9, eta_c=0.6):
    """Convert a cavity photon number to coupled output power (Eq. 9 in paper).

    P_out = eta_c / tau_p * h * c0 / wl * S_in. Works element-wise when
    ``S_in`` is a numpy array.
    """
    planck = 6.6300e-34
    light_speed = 299792458
    return eta_c/tau_p*planck*light_speed/wl*S_in

def set_global_save_path(suffix_t = 'None', suffix_value_t = 0):
    """Rebuild the module-level ``global_save_path`` used for output files.

    The path encodes the current stimulus (max(stim), input), the simulated
    time window (min/max of ts), the sweep state (variable_name, sweep_tmp,
    variable_name2, sweep_tmp2) and a suffix.

    Parameters
    ----------
    suffix_t : str
        New suffix label. The string ``'None'`` keeps the current global
        ``suffix``/``suffix_value``.
    suffix_value_t :
        Value stored in the global ``suffix_value`` when ``suffix_t`` is given.
    """
    global global_save_path, suffix, suffix_value

    if suffix_t != 'None':
        suffix, suffix_value = suffix_t, suffix_value_t

    path_parts = [
        'nodes_physical_output/',
        'maxStim_', str(max(stim)),
        '_input_', str(input),
        '_time_', str(min(ts)), '_', str(max(ts)),
        'sweep_', str(variable_name), str(sweep_tmp),
        str(variable_name2), str(sweep_tmp2),
        suffix, str(suffix_value),
    ]
    global_save_path = ''.join(path_parts)

def get_riseIndex(data, window=101):
    """Estimate the sample index of the first rising edge in ``data``.

    The second derivative is estimated with a Savitzky-Golay filter
    (quadratic, ``window`` samples). Samples where its magnitude exceeds half
    its maximum are grouped into regions separated by gaps larger than
    ``window``. The returned index is the midpoint of the first region.

    Parameters
    ----------
    data : 1-D array-like
        Signal to inspect.
    window : int, optional
        Savitzky-Golay window length and gap threshold (default 101).

    Returns
    -------
    numpy integer index, or -1 when no edge can be located. That happens when
    ``data`` is shorter than ``window`` or is flat or NaN.
    """
    try:
        der2 = savgol_filter(data, window_length=window, polyorder=2, deriv=2)
    except ValueError:
        # e.g. window longer than data (savgol_filter 'interp' mode requirement)
        return -1

    abs_der2 = np.abs(der2)
    large = np.flatnonzero(abs_der2 > np.max(abs_der2) / 2)
    if large.size == 0:
        # flat or NaN data: nothing exceeds the threshold
        return -1

    # the first region ends at the first sample followed by a gap > window
    gap_after = np.flatnonzero(np.diff(large) > window)
    first_end = large[gap_after[0]] if gap_after.size else large[-1]
    return (large[0] + first_end) // 2

def plot_results(sol, ts, stim, savefig=False, savedata=False):
    """Save the data and/or a figure for a single simulation solution.

    Parameters
    ----------
    sol : solve_ODE result
        Must have been computed with ``dense_output=True`` so that
        ``sol.sol(ts)`` can be evaluated. Rows are S, ln(Ng), Nq.
    ts : array
        Sample times [s].
    stim : array
        Stimulus trace in photon-number units (as built by get_Stimulus).
    savefig : bool
        If True, draw the 4-panel figure and write ``<global_save_path>.png``
        plus ``<global_save_path>.csv`` with the parameters and initial
        conditions.
    savedata : bool
        If True, write ``<global_save_path>_ts,S,Ng,Nq,stimulus.npy`` with the
        columns t, output power [W], Ng, Nq and stimulus power [W].
        This is independent of ``savefig``; sweep_simulation2 requests data only.

    Reads the module globals ``global_save_path``, ``tau_p`` and ``T_pulse``.
    The output directory is created if it does not exist.
    """
    import os

    h = 6.6300e-34
    c0 = 299792458
    tstart = min(ts)
    tend = max(ts)
    solution = sol.sol(ts)  # rows: S, ln(Ng), Nq

    if not (savefig or savedata):
        return

    # np.save / plt.savefig / to_csv do not create missing directories
    out_dir = os.path.dirname(str(global_save_path))
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    has_states = solution.shape[0] == 3
    if has_states:
        power_out = get_PowerOutCavity(solution[0, :], tau_p, 1550e-9)
        Ng = np.exp(solution[1, :])
        Nq = solution[2, :]
        # stimulus photon number -> optical power (eta_c = 0.6, 1550 nm)
        stim_power = 1/0.6/T_pulse*h*c0/(1550e-9)*stim

        if savedata:
            # save data in generic format for later use
            data_file = str(global_save_path) + '_ts,S,Ng,Nq,stimulus.npy'
            np.save(data_file, np.column_stack((ts, power_out, Ng, Nq, stim_power)))
            print("Data saved.")

    if not savefig:
        return

    if has_states:
        _plot_results_figure(ts, power_out, Ng, Nq, stim_power, tstart, tend)
    plt.tight_layout()
    plt.savefig(str(global_save_path) + '.png', dpi=600)
    _save_results_info()


def _plot_results_figure(ts, power_out, Ng, Nq, stim_power, tstart, tend):
    """Draw the output-power / Ng / Nq / stimulus panels and return the figure."""
    params = get_Parameters()
    fig, axs = plt.subplots(nrows=4, figsize=(7, 10))
    fig.patch.set_alpha(0)
    axs[0].plot(ts, power_out, linewidth=2, label='S')
    axs[1].plot(ts, Ng, linewidth=2, label='Ng')
    axs[2].plot(ts, Nq, linewidth=2, label='Nq')
    axs[3].plot(ts, stim_power, linewidth=2, label='stimulus')

    # reference levels drawn across the plotted time window (local tstart/tend,
    # not the module-level tStart/tEnd which may belong to a different ts)
    axs[1].plot([tstart, tend], [params['Ntr']] * 2, '--', linewidth=2, label='Ntr')
    axs[2].plot([tstart, tend], [params['Nq0']] * 2, '--', linewidth=2, label='Nq0')

    ylabels = ("coupled output (W)",
               "gain carrier \n concentration (m^-3)",
               "absorber carrier \n concentration (m^-3)",
               "stimulus (W)")
    for ax, ylabel in zip(axs, ylabels):
        ax.set_xlabel("time (s)")
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid()
        ax.set_xlim(tstart, tend)

    axs[0].set_ylim(0, 1.2*max(power_out))
    axs[1].set_ylim(0, 1.4 *max(Ng))
    axs[2].set_ylim(0, 1.5 * params['Nq0'])
    # both limits in the same (power) units as the plotted stimulus
    axs[3].set_ylim(min(stim_power), 1.2*max(stim_power))
    return fig


def _save_results_info():
    """Write ``<global_save_path>.csv`` with the parameters and initial conditions."""
    info = pd.DataFrame(list(get_Parameters().items()), columns=['Parameter', 'Value'])
    init = pd.DataFrame({'Parameter': ['S', 'log(Ng)', 'Nq'],
                         'Value': list(get_InitialCond())})
    # DataFrame.append was removed in pandas 2.0; concat is the replacement
    info = pd.concat([info, init])
    info.to_csv(str(global_save_path) + '.csv', index=False)
    
def analyse_results(sol, stim, ts, plotfig = False):
    """Extract summary figures of merit from a single simulation.

    Parameters
    ----------
    sol : solve_ODE result with ``dense_output=True``; row 0 of ``sol.sol(ts)`` is S.
    stim : array, stimulus trace on ``ts``.
    ts : array, sample times [s].
    plotfig : bool, if True draw S, the stimulus and the detected points.

    Returns
    -------
    list
        ``[maxResponse, maxStim, offset_StimResponse, avgResponse]``:
        max |S|, max |stim|, the time between the rising edges of stim and S
        [s] (0 if either edge is not found, i.e. get_riseIndex returned -1),
        and mean |S|.
    """
    S = sol.sol(ts)[0, :]
    abs_S = np.absolute(S)
    ts_arr = np.asarray(ts)
    stim_arr = np.asarray(stim)

    maxResponse = np.amax(abs_S)
    # argmax always yields an index; matching S == max|S| is empty when the peak is negative
    index_maxResp = int(np.argmax(abs_S))
    avgResponse = np.average(abs_S)
    maxStim = max(np.absolute(stim))

    # rising-edge indices; get_riseIndex returns the plain int -1 on failure
    index_sol = get_riseIndex(S)
    index_stim = get_riseIndex(stim)
    edges_found = index_sol != -1 and index_stim != -1

    if plotfig:
        plt.figure()
        plt.plot(ts, stim)
        plt.plot(ts, S)
        if edges_found:
            plt.plot(ts_arr[int(index_stim)], stim_arr[int(index_stim)], 'ro')
            plt.plot(ts_arr[int(index_sol)], S[int(index_sol)], 'ro')
        plt.plot(ts_arr[index_maxResp], S[index_maxResp], 'bo')

    if edges_found:
        offset_StimResponse = float(np.abs(ts_arr[int(index_stim)] - ts_arr[int(index_sol)]))
    else:
        offset_StimResponse = 0

    return [maxResponse, maxStim, offset_StimResponse, avgResponse]

def sweep_simulation2(variable_name_t, sweep_start, sweep_end, sweep_points, variable_name2_t, sweep_start2, sweep_end2, sweep_points2):
    """Simulate and analyse a 2D sweep over two module-level parameters.

    For every (value1, value2) on the np.linspace grids, this sets the module
    globals named ``variable_name_t`` / ``variable_name2_t``. It then
    recomputes derived parameters and the stimulus, refreshes
    ``global_save_path``, solves the model, asks plot_results to save the
    data, and analyses the response.

    Returns
    -------
    tuple of three (sweep_points, sweep_points2) arrays
        Max response, stimulus-response offset and average response.

    Raises
    ------
    ValueError
        If a variable name is not a valid Python identifier.
    """
    global sweep_tmp
    global sweep_tmp2
    global variable_name
    global variable_name2

    for name in (variable_name_t, variable_name2_t):
        if not name.isidentifier():
            raise ValueError(f"invalid sweep variable name: {name!r}")

    sweep_var = np.linspace(sweep_start, sweep_end, sweep_points)
    sweep_var2 = np.linspace(sweep_start2, sweep_end2, sweep_points2)

    #arrays to save generated data: maximum response, offset and average output
    maxResponse_res = np.zeros((sweep_points,sweep_points2))
    offset_StimResponse_res = np.zeros((sweep_points,sweep_points2))
    avgResponse_res = np.zeros((sweep_points,sweep_points2))

    module_globals = globals()
    total = len(sweep_var) * len(sweep_var2)
    for i, value1 in enumerate(sweep_var):
        for j, value2 in enumerate(sweep_var2):
            # assign the swept parameters in the module namespace (replaces exec)
            module_globals[variable_name_t] = value1
            module_globals[variable_name2_t] = value2
            sweep_tmp = value1
            sweep_tmp2 = value2
            variable_name = variable_name_t
            variable_name2 = variable_name2_t
            update_params()
            update_stimulus()
            # after update_stimulus, so the file name reflects this point's stimulus
            set_global_save_path()
            sol_lsoda = solve_ODE(yamada_switchedLn_phot, get_InitialCond(), ts, get_Parameters())
            plot_results(sol_lsoda, ts, stim, savefig=False, savedata=True)
            print('Processing... '+str(len(sweep_var2)*i+j+1)+'/'+ str(total))
            [maxResponse, maxStim, offset_StimResponse, avgResponse] = analyse_results(sol_lsoda, stim, ts, plotfig=False)
            maxResponse_res[i,j] = maxResponse
            offset_StimResponse_res[i,j] = offset_StimResponse
            avgResponse_res[i,j] = avgResponse

    return maxResponse_res, offset_StimResponse_res, avgResponse_res

def sweepwrapper2D(v1, v1_start, v1_end, v1_N, v2, v2_start, v2_end, v2_N, v3 = None, v3_value=None):
    """Run a 2D parameter sweep and save the average-response map.

    Parameters
    ----------
    v1, v2 : str
        Names of module-level variables to sweep (outer and inner loop).
    v1_start, v1_end, v1_N, v2_start, v2_end, v2_N :
        np.linspace arguments for each sweep axis.
    v3 : str, optional
        Name of a third module-level variable held fixed during the sweep.
    v3_value : float, optional
        Value for ``v3``. It is rounded to 5 decimals, assigned to the module
        global named ``v3`` and used as the file-name suffix.

    Writes ``<global_save_path>_CONTOUR_<v3><v3_value>N_<v1_N*v2_N>`` with the
    ``_X.npy``, ``_Y.npy`` and ``_avgRes.npy`` endings, where
    ``global_save_path`` is the value left by the last sweep point.

    Raises
    ------
    ValueError
        If ``v3`` is given but is not a valid Python identifier.
    """
    if v3 is None:
        # no fixed third parameter: keep the currently configured file-name suffix
        suffix_label = ''
        set_global_save_path()
    else:
        if not v3.isidentifier():
            raise ValueError(f"invalid parameter name: {v3!r}")
        v3_value = round(v3_value, 5)
        # Apply the fixed parameter in the module namespace so update_params() and the
        # sweep actually run at v3_value (an exec on a local name never reaches the globals).
        globals()[v3] = v3_value
        suffix_label = v3 + str(v3_value)
        set_global_save_path(suffix_t=v3, suffix_value_t=v3_value)

    start_lsoda = time.process_time()
    maxRes, offset, avgRes = sweep_simulation2(
        variable_name_t=v1, sweep_start=v1_start, sweep_end=v1_end, sweep_points=v1_N,
        variable_name2_t=v2, sweep_start2=v2_start, sweep_end2=v2_end, sweep_points2=v2_N)
    print('Sweep simulation time is: ', time.process_time() - start_lsoda)

    # save the 2D map of results (rows follow v1, columns follow v2)
    x = np.linspace(v2_start, v2_end, v2_N)
    y = np.linspace(v1_start, v1_end, v1_N)
    X, Y = np.meshgrid(x, y)
    prefix = global_save_path + '_CONTOUR_' + suffix_label + 'N_' + str(v1_N*v2_N)
    np.save(prefix + '_X.npy', X)
    np.save(prefix + '_Y.npy', Y)
    np.save(prefix + '_avgRes.npy', avgRes)

def update_stimulus():
    """Rebuild the module-level stimulus from the current pulse settings.

    Reads the globals Peak_pulse, T_pulse, ts, input, times and t_unit.
    Rebinds the globals ``i_unit`` (photon-number scale of one pulse) and
    ``stim``.
    """
    global i_unit, stim
    i_unit = get_PhotonNumber(Peak_pulse, T_pulse)
    stim = get_Stimulus(ts, input, i_unit, times, t_unit, T_pulse)

def update_fit(I,Vbias):
    """Absorber carrier lifetime tau_q [s] versus reverse bias Vbias [V].

    Exponential fit: tau_q = 1.086e-10 * exp(-1.626 * Vbias). ``I`` is not
    used; it is kept so existing calls stay valid.
    """
    prefactor, decay_rate = 1.086e-10, -1.626
    return prefactor * exp(decay_rate * Vbias)

def update_params():
    """Recompute derived device parameters from the current base globals.

    Reads a, b, Vbias, E_sat, T_pulse, L_absorber, L_gain, L_passive, r1, r2,
    vg, t_active, w_active and I. Rebinds the globals Nq0, S_sat,
    alpha_mirror, tau_p, V_gain, V_loss, tau_r and tau_q. Call it after
    changing any base parameter, e.g. between sweep points.
    """
    global Nq0, S_sat, alpha_mirror, tau_p, V_gain, V_loss, tau_r, tau_q
    Nq0 = a*Vbias + b
    S_sat = get_PhotonNumber(E_sat/T_pulse, T_pulse)
    cavity_length = L_absorber + L_gain + L_passive
    alpha_mirror = 1/cavity_length*log(1/r1/r2)
    tau_p = 1/(vg*alpha_mirror)
    V_gain = L_gain*t_active*w_active
    V_loss = L_absorber*t_active*w_active
    # round-trip time scales the injected stimulus in the model, so it must
    # follow changes of the section lengths or vg as well
    tau_r = 2*cavity_length/vg
    tau_q = update_fit(I, Vbias)


def phasespace_trajectories(modelfunction, initial, ts, params):
    """Solve the model on ``ts`` and return its state trajectories.

    Returns ``[S_traj, Ng_traj, Nq_traj]``, the first three rows of the
    solver's ``y`` array. For yamada_switchedLn_phot the middle row is
    ln(Ng), not Ng.
    """
    solution = solve_ODE(modelfunction, initial, ts, params)
    states = solution.y
    return [states[0], states[1], states[2]]

# =====================   Simulation initialization  ===============================
# Described below are all the parameters needed to calculate the response of the Yamada model
# to an input perturbation. Some of them (e.g. I, Vbias, r1, input, ts) are changed by the
# figure sections below and therefore redefined there. The functions above read these names as
# module globals, so the order of these assignments matters.

# ======================= time parameters ======================
tStart, tEnd = 0, 3.0e-9   # simulation window [s]
N_points = 20000
N_POINTS = N_points
ts = np.linspace(tStart, tEnd, N_POINTS)

# ================ stimulus parameters ======================
input = [0, 1,0] # per-pulse amplitude factors multiplied by i_unit (0 = no pulse); up to three pulses
times = [0, 1,0] # pulse times in units of t_unit (1 -> 1 ns)
T_pulse = 50e-12 #time duration of the pulse
Peak_pulse = 0.005    # unit of Watts
i_unit = get_PhotonNumber(P_in=Peak_pulse, tau_p=T_pulse)  # unit in photon numbers, scales input
t_unit = 1e-9  # unit in s, scales times

# create stimulus vectors
stim = get_Stimulus(ts, input, i_unit, times, t_unit, T_pulse)

# ================ Device physical parameters ======================
tau_g = 300e-12         # gain carrier lifetime 300 ps, Moskalenko thesis
a_NG = 17.5e-20         # gain cross section [m^2], or differential gain
a_NQ = 40e-20           # differential absorption [m^2]  from Heck thesis
Ntr = 0.5e24            # transparency carrier density [m^-3] 
gamma = 0.053           # confinement factor MQW 
vg = 299792458 / 3.65   # group velocity [m/s]
t_active = 100e-9       # active section thickness in [m], ~100nm where QW/QB stack is located
w_active = 2e-6         # width of active waveguide
r1 = 0.866              # 300 um long DBR grating, field reflection, design value
r2 = sqrt(0.4)          # MIR power reflectivity = 0.4, Emil Kleijn thesis
alpha_G = 3             # linewidth enhancement gain
alpha_Q = 2             # linewidth enhancement absorber
I = 40.5e-3             # injection current in [A], self-pulsation
L_absorber = 50e-6      # length of absorber section
L_gain = 500e-6         # length of gain section
L_passive = 1000e-6     # length of passive cavity section
E_sat = 1e-12           # saturation energy [J]; S_sat is derived from the saturation power E_sat/T_pulse
Vbias = 0.6             # reverse bias voltage, positive value here = negative bias
a = 3.625e+22           # slope of the linear fit Nq0(Vbias) = a*Vbias + b [m^-3/V]
b = 4.75e+22            # offset of the linear fit Nq0(Vbias) = a*Vbias + b [m^-3]
Nq0 = a * Vbias + b     # absorber transparency carrier density [m^-3]

# ==================== Device dependent params ========================
# (update_params() recomputes these after the base parameters above are changed)
S_sat = get_PhotonNumber(E_sat / T_pulse, T_pulse)                       # saturation photon number, saturation power calculated back to photon number , power = E_sat/Tpulse
# NOTE(review): yamada_switchedLn_phot also cites "Eq. 7" (spontaneous emission); one of the two equation references may be off — check the paper
alpha_mirror = 1 / (L_absorber + L_gain + L_passive) * log(1 / r1 / r2)  # Eq. 7, mirror losses in cavity
tau_p = 1 / (vg * alpha_mirror)                                          # Eq. 6, photon lifetime describes only mirror losses as rest has been accounted for
V_gain = L_gain * t_active * w_active                                    # volume of gain section [m^3]  500 um x 2 um x 100 nm block
V_loss = L_absorber * t_active * w_active                                # volume of absorber section [m^3]  50 um x 2 um x 100 nm block
tau_r = 2 * (L_absorber + L_gain + L_passive) / vg                       # cavity round-trip time; the model divides the injected stimulus by it
tau_q = update_fit(I, Vbias)                                             # update lifetime depending on Vbias

# ===============  initialize variables for sweep or non sweep ================
# Sweep state read by set_global_save_path to build output file names.
# suffix = 'None' (a string) is that function's "no new suffix" sentinel.
sweep_tmp = None
sweep_tmp2 = None
global_save_path = None
variable_name = None
variable_name2 = None
suffix = 'None'
suffix_value = None
    
#%% =====================================    Fig. 2: reverse bias voltage vs absorber carrier lifetime ===================================== ##
# Fig. 2 in NCE paper   
# Create the compact model figure using the data provided in M. Trajkovic, “High speed electro-absorption modulators in indium phosphide generic integration technologies,” Technische Universiteit Eindhoven, 2019.
# Plots the absorber lifetime model update_fit(I, Vbias) for 20 reverse-bias points between 0.5 V and 2.0 V.

#plotting options
plt.rcParams['font.size'] = 12
plt.rcParams['font.family' ] ='Times New Roman'
plt.rcParams['mathtext.default'] = 'regular'

# evaluate tau_Q [s] on the bias grid (update_fit ignores I)
v_bias=np.linspace(0.5,2.0,20)
tau_model=np.zeros(len(v_bias))
for i in range(len(v_bias)):
    tau_model[i]=update_fit(I,v_bias[i])

# plot in ps
plt.figure(figsize=(5.0,3.0))
plt.plot(v_bias,tau_model*1e12,'k')
plt.xlabel('Reverse bias voltage [V]')
plt.ylabel(r'$\tau_Q$[ps]' )
plt.grid()
plt.tight_layout()
plt.savefig('Fig.2.png',format='png', bbox_inches='tight', dpi=1200)
plt.close('all')

#%% =====================================    Fig. 3: reverse bias voltage vs NQ0 ===================================== ##
# Fig. 3 in NCE paper
# Create the compact model figure from the analysis described in the paper on pages 4 and 5. Nq0_data_1550 is derived
# from the measurement data published in M. J. R. Heck, “Ultrafast integrated semiconductor laser technology at 1.55 µm,” Technische Universiteit Eindhoven, 2008.

#plotting options
plt.rcParams['font.size'] = 12
plt.rcParams['font.family' ] ='Times New Roman'
plt.rcParams['mathtext.default'] = 'regular'

# absorber transparency densities [m^-3] at the four bias points, offset by 1.57e23
Nq0_data_1550 = (np.array([0.05,0.12, 0.185, 0.27]))*1e24+1.57e23

fig=plt.figure(figsize=(6.4*0.8,4.8*0.65))

#data axis: bias voltages 0, -2, -4, -6 V, one per Nq0 data point
V = np.linspace(-0,-6,4)
Nq0_data = Nq0_data_1550

#create linear fit V ~ a_fit*Nq0 + b_fit to the four points (polyfit returns slope, intercept)
a_fit,b_fit= np.polyfit(Nq0_data, V, 1)

def func_lin(a_fit,b_fit,Nq0_data):
    """Evaluate the straight line ``a_fit * Nq0_data + b_fit``."""
    scaled = a_fit * Nq0_data
    return scaled + b_fit

#plot points and fit; voltages are negated so the reverse bias is shown as a positive value
plt.plot(Nq0_data, -func_lin(a_fit,b_fit,Nq0_data), 'k-')
plt.plot(Nq0_data, -V, 'ks')
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0),useMathText=True)
plt.xlabel(r'$N_{Q0}\, [m^{-3}]$')
plt.ylabel(r'Reverse bias voltage$ \,[V]$')
plt.grid()
fig.patch.set_alpha(0)
plt.tight_layout()
plt.savefig('Fig.3.png',format='png', bbox_inches='tight', dpi=1200)
plt.close('all')

#%% =====================================    Fig. 4 3D phase space diagram and corresponding time trace  ===================================== ##
# Fig. 4 in NCE paper
# Create the 3D plot with 2D projections of the trajectory (top) and the S(t) time trace (bottom).

#plotting options
plt.rcParams['font.size'] = 12
plt.rcParams['font.family' ] ='Times New Roman'
plt.rcParams['mathtext.default'] = 'regular'

#set initial cond: S, ln(Ng), Nq (the model integrates ln(Ng))
initial=[100000, log(7.7e23), 2.50e22]

# set biasing operation point and pulse injection properties
I=0.0405
Vbias_sweep=0.6
# apply the bias explicitly so this cell does not depend on the Vbias left behind by
# another cell (e.g. the Fig. 5 loop)
Vbias=Vbias_sweep
T_pulse = 50e-12 #time duration of the pulse
Peak_pulse = 0.005
input=[0,1,0] # per-pulse amplitude factors (0 = no pulse), up to three pulses
times = [0, 1.0,0] # pulse times in units of t_unit

#update all parameters and create stimulus
update_stimulus()
update_params()

#perform simulation
[S_traj,Ng_traj,Nq_traj]=phasespace_trajectories(yamada_switchedLn_phot, initial,ts,get_Parameters())

#plot the results in 3D phase space (top fig); Ng_traj is ln(Ng), hence np.exp
plt.close('all')
ax = plt.figure().add_subplot(projection='3d')
ax.plot3D(Nq_traj,np.exp(Ng_traj),S_traj,'k',linewidth=1.5,alpha=1)

#plot shadows: projections onto the Nq=0 plane, the far Ng wall and the S=0 floor
ax.plot3D(np.zeros(len(Nq_traj)),np.exp(Ng_traj),S_traj,'k',linewidth=1.0,alpha=0.6)
ax.plot3D(Nq_traj,ax.get_ylim()[1]*np.ones(len(np.exp(Ng_traj))),S_traj,'k',linewidth=1.0,alpha=0.6)
ax.plot3D(Nq_traj,np.exp(Ng_traj),np.zeros(len(S_traj)),'k',linewidth=1.0,alpha=0.6)

#plot begin (circle) and end (square) points
ax.plot3D(Nq_traj[0],np.exp(Ng_traj[0]),S_traj[0],'ko')
ax.plot3D(Nq_traj[-1],np.exp(Ng_traj[-1]),S_traj[-1],'ks')

#set labels
ax.set_xlabel(r'$N_{Q} \,[m^{-3}]$')
ax.set_ylabel(r'$N_{G} \,[m^{-3}]$')
ax.set_zlabel('S [-]')
ax.ticklabel_format(style='sci', scilimits=(0,0),useMathText=True)

ax.patch.set_alpha(0)
plt.savefig('Fig.4.Top.png',transparent=True,dpi=1200)

#plot timetrace (bottom fig); stim is in photon-number units like S
plt.figure(figsize=(5,2))
plt.plot(ts*1e9,S_traj,'k',linewidth=2.0,label='Response')
plt.plot(ts*1e9,stim,'k--',linewidth=2.0,label='Injected')
plt.xlabel('Time [ns]')
plt.ylabel('S [-]')
plt.ticklabel_format(axis='y',style='sci', scilimits=(0,0),useMathText=True)
plt.legend()
plt.grid()
plt.tight_layout()
plt.savefig('Fig.4.Bottom.png',transparent=True,dpi=1200)
plt.close('all')
    
#%% =====================================    Fig. 5(a)-(d) examples of different operation states ===================================== ##
# Generate four figures, each with four panels: the trajectory projected onto the (Ng, S), (Nq, S)
# and (Ng, Nq) planes, plus the S(t) time trace.

#set plot params
plt.rcParams['font.size'] = 12
plt.rcParams['font.family' ] ='Times New Roman'
plt.rcParams['mathtext.default'] = 'regular'

#get trajectories
#initial cond: S, ln(Ng), Nq (the model integrates ln(Ng); exp(-11.51) is about 1e-5)
S_init=1#np.linspace(1e2,1e5,7)
Ng_init=-11.51#np.linspace(54.7,55,7)
Nq_init=0#np.linspace(2.00e22,2.50e22,7)

#plotting options
# NOTE(review): this rebinds `colors`, shadowing `from matplotlib import colors` for any code run after this cell
colors=['k','k']
linestyles=['-','--']
alphas=[1,1]
linewidths=[1.2,1.2]
subfig=[1,2,3,4]
subfig_letter=['a','b','c','d']

#biasing operation points: injection current [A] and reverse bias [V] per subfigure
I_sweep=[0.0871,0.0871,0.0550, 0.0550]
Vbias_sweep=[1.800,1.797,1.000,1.500]

for i in range(len(subfig)):
    fig, axs = plt.subplots(2,2)
    fig.patch.set_alpha(0)

    # set I and V for the simulations
    I=I_sweep[i]
    Vbias=Vbias_sweep[i]

    #make sure all parameters are updated correctly
    update_params()

    #plotting options
    color = colors[0]
    linestyle = linestyles[0]
    alpha=alphas[0]
    linewidth=linewidths[0]

    #set initial values
    initial=[S_init, Ng_init, Nq_init]

    if i==0 or i==1:
        #set simulation time parameters for cw and spiking case
        tStart, tEnd = 0, 10e-9
        N_points = 8000
        N_POINTS = N_points
        ts = np.linspace(tStart, tEnd, N_POINTS)
        # spiking and cw, disable the injected pulse
        T_pulse = 50e-12
        Peak_pulse = 0.005      # unit of Watts
        input=[0,0,0]
        update_stimulus()
    else:
        #set simulation time parameters for on-set and excitable case
        tStart, tEnd = 0, 3e-9
        N_points = 8000
        N_POINTS = N_points
        ts = np.linspace(tStart, tEnd, N_POINTS)
        # on-set and excitable, enable the injected pulse
        T_pulse = 50e-12
        Peak_pulse = 0.005      # unit of Watts
        input=[0,1,0]
        update_stimulus()

    #perform simulation
    [S_traj,Ng_traj,Nq_traj]=phasespace_trajectories(yamada_switchedLn_phot, initial,ts,get_Parameters())

    #plot the four subfigures for Ng vs S, Nq vs S, Ng vs Nq and the timetrace (Ng_traj is ln(Ng))
    line1=axs[0,0].plot(np.exp(Ng_traj),S_traj,linewidth=linewidth,alpha=alpha,color=color,linestyle=linestyle,label=str(Vbias))
    line2=axs[1,0].plot(Nq_traj,S_traj,linewidth=linewidth,alpha=alpha,color=color,linestyle=linestyle,label=str(Vbias))
    line3=axs[0,1].plot(np.exp(Ng_traj),Nq_traj,linewidth=linewidth,alpha=alpha,color=color,linestyle=linestyle,label=str(Vbias))
    line4=axs[1,1].plot(ts*1e9,S_traj,linewidth=linewidth,alpha=alpha,color=color,linestyle=linestyle,label='Response')
    if i==2 or i==3: #add injected signal to plot if the on-set or excited state is simulated
        line4=axs[1,1].plot(ts*1e9,stim,linewidth=linewidth,alpha=alpha,color=color,linestyle='--',label='Injection')

    #start (circle) and end (square) markers, axis formatting and labels
    axs[0,0].plot(np.exp(Ng_traj)[0],S_traj[0],'o',color=color, markersize=5)
    axs[0,0].plot(np.exp(Ng_traj)[-1],S_traj[-1],'s',color=color, markersize=5)
    axs[0,0].ticklabel_format(style='sci', axis='both', scilimits=(0,0),useMathText=True)
    axs[1,0].plot(Nq_traj[0],S_traj[0],'o',color=color, markersize=5)
    axs[1,0].plot(Nq_traj[-1],S_traj[-1],'s',color=color, markersize=5)
    axs[1,0].ticklabel_format(style='sci', axis='both', scilimits=(0,0),useMathText=True)
    axs[0,1].plot(np.exp(Ng_traj)[0],Nq_traj[0],'o',color=color, markersize=5)
    axs[0,1].plot(np.exp(Ng_traj)[-1],Nq_traj[-1],'s',color=color, markersize=5)
    axs[0,1].ticklabel_format(style='sci', axis='both', scilimits=(0,0),useMathText=True)
    axs[1,1].plot(ts[0]*1e9,S_traj[0],'o',color=color, markersize=5)
    axs[1,1].plot(ts[-1]*1e9,S_traj[-1],'s',color=color, markersize=5)
    axs[1,1].ticklabel_format(style='sci', axis='y', scilimits=(0,0),useMathText=True)
    # raw strings: '\,' is a mathtext thin space, not a Python escape (avoids SyntaxWarning)
    axs[0,0].set_xlabel(r'$N_{G} \,[m^{-3}]$')
    axs[0,0].set_ylabel(r'$S \,[-]$')
    axs[1,0].set_xlabel(r'$N_{Q} \,[m^{-3}]$')
    axs[1,0].set_ylabel(r'$S \,[-]$')
    axs[0,1].set_xlabel(r'$N_{G} \,[m^{-3}]$')
    axs[0,1].set_ylabel(r'$N_{Q} \,[m^{-3}]$')
    axs[1,1].set_ylabel(r'$S \,[-]$')
    axs[1,1].set_xlabel('Time [ns]')
    plt.tight_layout(pad=0.1, w_pad=0.1, h_pad=0.1)
    if i==2 or i==3: #add legend only when injected signal is added to timetrace
        # the time-trace panel (this is the axes plt.legend() targeted implicitly)
        axs[1,1].legend()
    # NOTE(review): the file name labels I with "_mA_" although I_sweep is in A (0.0871 A = 87.1 mA);
    # the name is kept unchanged so existing outputs still match
    plt.savefig('Fig.5.'+(subfig_letter[i])+'_'+str("{:.4f}".format(I_sweep[i]))+'_mA_'+str("{:.4f}".format(Vbias_sweep[i]))+'_V_S0,Ng0,Nq0'+str("{:.4e}".format(S_init))+','+str("{:.4f}".format(Ng_init))+','+str("{:.4e}".format(Nq_init))+'.png',dpi=1200)
    plt.close('all')

#%% =====================================    Fig. 6(a)-(d) examples of different operation states ===================================== ##
# Generate two 3D phase diagrams and two corresponding time traces for two laser designs (different reflectivity values)

#set plot params
plt.rcParams['font.size'] = 12
plt.rcParams['font.family' ] ='Times New Roman'
plt.rcParams['mathtext.default'] = 'regular'
plt.close('all')

#set simulation time parameters
tStart, tEnd = 0, 4e-9
N_points = 8000
N_POINTS = N_points
ts = np.linspace(tStart, tEnd, N_POINTS)

#set initial S, ln(Ng), Nq values (the model integrates ln(Ng); exp(54.7) is about 5.7e23 m^-3)
S_init=1e2
Ng_init=54.7
Nq_init=2.00e22

#input pulse params
T_pulse = 50e-12#time duration of the pulse
Peak_pulse = 0.005      # unit of Watts
input=[0,1,0]# per-pulse amplitude factors (0 = no pulse), up to three pulses
times = [0.0, 1.0,0.0]# pulse times in units of t_unit
update_stimulus() #update stimulus timetrace
update_params() #update simulation parameters

# r1 (field reflectivity) values swept: the 0.866 design value and 0.400
# NOTE(review): the original comment mapped these to "Fig. 9a/9b" although this cell produces Fig. 6 — confirm the panel mapping
r_sweep=[0.866,0.400]

# laser operation parameters: injection current I [A] and reverse-bias values [V]
I=0.050
V_sweep=[0,1,2,3]

#plotting settings 
subfig_letter_3d=['a','c']
subfig_letter_tt=['b','d']
#line style options, indexed by the V_sweep position
linestyle=['-','--','-.',':','-','-','-','-','-','-','-','-','-','-']
linewidth=1.4
alpha=0.8
for j in range(len(r_sweep)):
    #set reflectivity value
    r1=r_sweep[j]
    
    #select correct plotting settings
    if r1==0.866:
        color=['b','r','r','k','k','k','k','k','k','k','k','k','k','k','k','k','k']
    elif r1==0.400:
        color=['r','r','r','k','k','k','k','k','k','k','k','k','k','k','k','k','k']
    else:
        print('Other value for R is specified. Graph lines will all be black.')
        color=['k','k','k','k','k','k','k','k','k','k','k','k','k','k','k','k','k']
    start_color=color
    end_color=color
    
    #generate empty 3D plot
    ax = plt.figure(1).add_subplot(projection='3d')

    #generate empty trajectory plot
    fig_time = plt.figure(2,figsize=(6, 3))
    fig_time.patch.set_alpha(0)
    
    #perfom sweep
    for i in range(len(V_sweep)):
        
        #set correct reverse bias voltage
        Vbias=V_sweep[i]
        
        #update laser parameters
        update_params()
        update_stimulus()
        
        initial=[S_init, Ng_init, Nq_init]#S, Ng, Nq

        [S_traj,Ng_traj,Nq_traj]=phasespace_trajectories(yamada_switchedLn_phot, initial,ts,get_Parameters())
            
        #3D plots
        ax.plot3D(Nq_traj,np.exp(Ng_traj),S_traj,linewidth=linewidth,color=color[i],alpha=alpha,linestyle=linestyle[i])
        ax.plot3D(Nq_traj[0],np.exp(Ng_traj)[0],S_traj[0],'o',color=start_color[i])
        ax.plot3D(Nq_traj[-1],np.exp(Ng_traj)[-1],S_traj[-1],'s',color=end_color[i])
        ax.view_init(22, -45)
        ax.set_xlabel('$N_Q  \,[m^{-3}]$')
        ax.set_ylabel('$N_G  \,[m^{-3}]$')
        ax.set_zlabel('S [-]')
        ax.patch.set_alpha(0)
        ax.xaxis.major.formatter._useMathText = True
        ax.yaxis.major.formatter._useMathText = True
        ax.xaxis.major.formatter._useMathText = True
        ax.zaxis.major.formatter._useMathText = True
            
        # trajectory plots
        fig_time = plt.figure(2)
        plt.plot(ts[0]*1e9,S_traj[0],'o',color=color[i])
        plt.plot(ts[-1]*1e9,S_traj[-1],'s',color=color[i])
        plt.plot(ts*1e9,S_traj,color=color[i],alpha=alpha,linestyle=linestyle[i],label="$V_{rb}$"+' = '+str("{:.1f}".format(V_sweep[i]))+' V')
        
    # trajectory plots add injection signal      
    fig_time = plt.figure(2)
    plt.plot(ts*1e9,stim,linestyle=(0, (5,5)),color='k',label='Injection',alpha=1)
    plt.ylabel('S [-]')
    plt.xlabel('Time [ns]')
    plt.ticklabel_format(style='sci', axis='both', scilimits=(0,0),useMathText=True)
    plt.grid(True)
    plt.xlim([0.50,1.6])
    plt.legend()
    plt.tight_layout()
    plt.savefig('Fig.6.'+subfig_letter_tt[j]+'_r1_'+str(r1)+'_I_'+str("{:.6f}".format(I))+'_mA_V_bias_'+str((V_sweep))+'_V_S0,Ng0,Nq0'+str("{:.4e}".format(S_init))+','+str("{:.4f}".format(Ng_init))+','+str("{:.4e}".format(Nq_init))+'.png',dpi=1200)
    
    ax = plt.figure(1)
    plt.savefig('Fig.6.'+subfig_letter_3d[j]+'_r1_'+str(r1)+'_I_'+str("{:.6f}".format(I))+'_mA_V_bias_'+str((V_sweep))+'_V_S0,Ng0,Nq0'+str("{:.4e}".format(S_init))+','+str("{:.4f}".format(Ng_init))+','+str("{:.4e}".format(Nq_init))+'.png',dpi=1200)
    plt.close('all')
  
#%% =================================================    Information regarding plots 7 (a)-(d)   ================================================= ##
# For Fig. 7 (a)-(d), a two-dimensional (I, Vbias) sweep must first be calculated for every reflectivity r1; the analysis
# is done separately in the next section.
# To run that analysis, either generate the sweep data here first or use the provided datasets in the 'data_Fig_6abcd' folder.

#simulation time settings
tStart, tEnd = 0, 2.5e-8
N_points = 20000
N_POINTS = N_points
ts = np.linspace(tStart, tEnd, N_POINTS)

# Single pulse excitation
input = [0, 1, 0]      # boolean whether there should be a pulse or not (up to three possible)
times = [0, 1, 0]      # time at which a pulse should be injected
T_pulse = 50e-12       # time duration of the pulse
Peak_pulse = 0.0015    # unit of Watts
i_unit = get_PhotonNumber(P_in=Peak_pulse, tau_p=T_pulse)  # unit in photon numbers, scales input
t_unit = 1e-9  # unit in s, scales times
update_stimulus()

#set reflectivity sweep values
r_sweep = [0.3, 0.5, 0.7, 0.9]

#perform a 31 x 31 (I, Vbias) sweep for each reflectivity
#(31 * 31 = 961 points; the analysis section loads files whose names end in 'N_961')
for r1 in r_sweep:
    plt.close('all')
    sweepwrapper2D('I', 30e-3, 90e-3, 31, 'Vbias', -1, 4, 31, v3='r1', v3_value=r1)
     
#%% =====================================    Fig. 7(a)-(d) maps    ===================================== ##
# Analyzes the 2D (I, Vbias) sweep data and generates V vs I excitability maps for different reflectivities.

#injection pulse settings, needed to calculate optical output power
#NOTE(review): h and c0 are not used directly below; presumably read as globals by get_PowerOutCavity - confirm
T_pulse = 50e-12
h = 6.6300e-34
c0 = 299792458

#four different reflectivities, for Fig. (a)-(d); kept as strings because they are spliced into the data file names
r_sweep = ['0.3', '0.5', '0.7', '0.9']

#detection values: an output power strictly between these values counts as an excitable response for this simulation
detect_valueA = 0.000008
detect_valueB = 0.000300

plt.rcParams.update({'font.size': 22})

#create two versions of every figure: marked=0 without and marked=1 with the red marker lines ((a) and (d) are marked in the paper)
for marked in range(2):

    for r_str in r_sweep:
        #load data
        filename = './maxStim_974781.9524078725_input_[0, 1, 0]_time_0.0_2.5e-08sweep_I0.09Vbias4.0r1' + r_str + '_CONTOUR_r1' + r_str + 'N_961'

        #for calculated dataset generated in block above
        # X = np.load('./nodes_physical_output/'+filename+'_X.npy')
        # Y = np.load('./nodes_physical_output/'+filename+'_Y.npy')
        # avgRes = np.load('./nodes_physical_output/'+filename+'_avgRes.npy')

        #for precalculated dataset stored in 'data_Fig_6abcd'
        X = np.load('./data_Fig_6abcd/' + filename + '_X.npy')
        Y = np.load('./data_Fig_6abcd/' + filename + '_Y.npy')
        avgRes = np.load('./data_Fig_6abcd/' + filename + '_avgRes.npy')

        #binary map: 1 where the response was excitable, NaN elsewhere (NaN is left blank by contourf)
        excitability_matrix_binary = np.full(avgRes.shape, np.nan)
        for idx in np.ndindex(avgRes.shape):
            p_out = get_PowerOutCavity(avgRes[idx], T_pulse, 1550e-9)  # evaluated once per grid point
            if detect_valueA < p_out < detect_valueB:
                excitability_matrix_binary[idx] = 1

        # plot binary surface (current axis converted from A to mA)
        fig, ax = plt.subplots(1, 1)
        ax.contourf(Y*1e3, X, excitability_matrix_binary, levels=[0, 1], colors='k')
        if marked:
            ax.vlines(50, -1, 4, 'r')
            ax.hlines(1.0, 30, 90, 'r')
            ax.hlines(0.14, 30, 90, linestyle='--', color='r')
        ax.grid()
        ax.set_ylabel('Reverse bias voltage (V)')
        ax.set_xlabel('Gain current (mA)')
        fig.tight_layout()
        fig.savefig('Fig.7.(a)-(d)' + '_r1_' + str(round(float(r_str), 3)) + '_marked_' + str(bool(marked)) + '.png',
                    transparent=True, bbox_inches='tight', dpi=600)
        plt.close('all')
    
#%% =================================================    Information regarding plots 7 (e)-(f)   ================================================= ##
# For Fig. 7 (e)-(f), a two-dimensional (r1, I) sweep must first be calculated at each of the two reverse-bias voltages;
# the analysis is done separately in the next section.
# To run that analysis, either generate the sweep data here first or use the provided datasets in the 'data_Fig_6ef' folder.

#set simulation time parameters
tStart, tEnd = 0, 2.5e-8
N_points = 20000
N_POINTS = N_points
ts = np.linspace(tStart, tEnd, N_POINTS)

# Single pulse excitation
input = [0, 1, 0]      # boolean whether there should be a pulse or not (up to three possible)
times = [0, 1, 0]      # time at which a pulse should be injected
T_pulse = 50e-12       # time duration of the pulse
Peak_pulse = 0.0015    # unit of Watts
i_unit = get_PhotonNumber(P_in=Peak_pulse, tau_p=T_pulse)  # unit in photon numbers, scales input
t_unit = 1e-9  # unit in s, scales times
update_stimulus()

#set r and V sweep parameters
#r_sweep must match the grid loaded by the analysis section (np.linspace(0.01, 1.0, 51)); a coarser grid would
#leave most of the files that section expects ungenerated
r_sweep = np.linspace(0.01, 1.0, 51)
v_sweep = [0.14, 1.00]

#perform 2D sweep: for each (Vbias, r1) pair, sweep the gain current over 51 points at a single Vbias value
for Vbias in v_sweep:
    for r1 in r_sweep:
        plt.close('all')
        print('Vbias = ' + str(Vbias) + ' V, r1 = ' + str(r1))
        sweepwrapper2D('I', 20e-3, 90e-3, 51, 'Vbias', Vbias, Vbias, 1, v3='r1', v3_value=r1)
        
#%% =====================================    Fig. 7(e)-(f) maps ===================================== ##
# Analyzes the (r1, I) sweep data and generates r vs I maps of the output power (get_PowerOutCavity, shown in mW)
# of the excitable response; grid points without an excitable response are shown as 0.

#specify r, I, V sweep parameters based on the generated data
r_sweep = np.linspace(0.01, 1.0, 51)
i_sweep = np.linspace(20e-3, 90e-3, 51)
v_sweep = [0.14, 1.0]

#specify injection pulse settings
#NOTE(review): h and c0 are not used directly below; presumably read as globals by get_PowerOutCavity - confirm
T_pulse = 50e-12
h = 6.6300e-34
c0 = 299792458

#specify thresholds for excitable response (strictly between A and B)
detect_valueA = 0.000008
detect_valueB = 0.000070

for v_val in v_sweep:
    #fresh matrix per voltage so values from the previous map can never leak into this one
    excitability_energy = np.zeros((len(r_sweep), len(i_sweep)))

    for i, r_val in enumerate(r_sweep):
        #load files
        r_tag = str(round(float(r_val), 4))
        filename = 'maxStim_974781.9524078725_input_[0, 1, 0]_time_0.0_2.5e-08sweep_I0.09Vbias' + str(v_val) + 'r1' + r_tag + '_CONTOUR_r1' + r_tag + 'N_51'

        #for calculated dataset generated in block above
        # avgRes = np.load('./nodes_physical_output/'+filename+'_avgRes.npy') #for newly generated dataset

        #for precalculated dataset stored in 'data_Fig_6ef'
        avgRes = np.load('./data_Fig_6ef/' + filename + '_avgRes.npy')

        #store the output power (mW) for excitable responses; non-excitable points keep their initial 0
        for j in range(avgRes.size):
            p_out = get_PowerOutCavity(avgRes[j], T_pulse, 1550e-9)  # evaluated once per current value
            if detect_valueA < p_out < detect_valueB:
                excitability_energy[i, j] = p_out*1e3

    # plot and save results (current axis converted from A to mA)
    fig, ax = plt.subplots(1, 1)
    cp = ax.pcolor(i_sweep*1e3, r_sweep, excitability_energy)
    ax.grid()
    ax.set_xlabel('Gain current (mA)')
    ax.set_ylabel('Reflectivity (-)')
    cbar = fig.colorbar(cp)
    cbar.set_label('Avg. power (mW)', rotation=90)
    fig.tight_layout()
    fig.savefig('Fig.7.(e)-(f)_r_vs_I_at_V=' + str(v_val) + 'V' + '.png', transparent=True, bbox_inches='tight', dpi=600)
    plt.close('all')