# -*- coding: utf-8 -*-
"""
Created on Fri Feb  3 16:35:48 2023

@author: lmveldman
"""

from qutip import *
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib
import math as math
import numpy as np
import os
import glob
import cv2
import scipy.integrate as integrate

### CONSTANTS ###
# Physical constants converted to frequency units (GHz per unit quantity),
# using 1 eV = 241799 GHz and h = 4.136e-6 eV/GHz (= 4.136e-15 eV*s).
kb = 8.617e-5*241799 # Boltzmann constant: 8.617e-5 eV/K -> ~20.84 GHz/K
gn = -0.31539 # Nuclear g-factor (dimensionless)
mu_B = 5.788e-5/4.136e-6 # Bohr magneton: 5.788e-5 eV/T -> ~13.99 GHz/T
mu_N = 3.152e-8/4.136e-6 # Nuclear magneton: 3.152e-8 eV/T -> ~7.62e-3 GHz/T
e = 1 # Electron charge, set to 1 (used as a unit charge)

####################
### HAMILTONIANS ###
####################

def get_dim(s):
    """Return the Hilbert-space dimension of every spin in ``s``.

    Each spin quantum number (e.g. 0.5, 1, 2.5) maps to ``int(2*spin) + 1``.
    """
    dims = []
    for spin in s:
        dims.append(int(2 * spin) + 1)
    return dims

def get_r(x,y):
    """Return the (N, N) matrix of pairwise atom-atom distances.

    Entry [i, j] is the Euclidean distance between atoms i and j, multiplied
    by 0.42/sqrt(2). Presumably this converts lattice coordinates to nm,
    since the coupling functions in this module use nm; confirm.
    """
    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)
    # Broadcasting: dx[i, j] = x[j] - x[i], dy[i, j] = y[j] - y[i]
    dx = xs[np.newaxis, :] - xs[:, np.newaxis]
    dy = ys[np.newaxis, :] - ys[:, np.newaxis]
    return np.sqrt(dx**2 + dy**2)*0.42/np.sqrt(2)

def get_angle(x,y):
    """Return the (N, N) matrix of in-plane bond angles between atoms.

    Entry [i, j] is arctan2(y[j] - y[i], x[j] - x[i]) in radians, i.e. the
    direction from atom i to atom j. The diagonal is arctan2(0, 0) = 0.
    """
    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)
    dx = xs[np.newaxis, :] - xs[:, np.newaxis]  # dx[i, j] = x[j] - x[i]
    dy = ys[np.newaxis, :] - ys[:, np.newaxis]  # dy[i, j] = y[j] - y[i]
    return np.arctan2(dy, dx)

def get_J(r):
    """Return the exponential exchange coupling J(r) in GHz.

    J = J_0 * exp(-(r - r_0)/d) with r_0 = 0.72 nm, J_0 = 27.7 GHz and
    d = 0.094 nm (values from Bae 2018, supplementary). ``r`` is a distance
    (scalar or array) in nm.
    """
    J_0, r_0, d = 27.7, 0.72, 0.094  # [GHz], [nm], [nm]
    exponent = -(r - r_0) / d
    return J_0 * np.exp(exponent)

def get_J_D_from_fit(r,theta,j_0,d_0,dec):
    """Return (J, D): fitted exchange and dipolar couplings for a geometry.

    J = j_0 * exp(-(r - 1.0525)/dec)                            [GHz]
    D = d_0 * (1 - 3*cos^2(theta + field_angle)) / (2*pi*r^3)

    ``r`` and ``theta`` may be scalars or arrays (e.g. from get_r/get_angle).
    Example fit values recorded in the original source: j_0 = 0.772,
    d_0 = 0.157, dec = 0.122.
    """
    # In-plane field-angle correction, currently disabled. A previously used
    # value was -(90 - 14) * pi / 180.
    field_angle = 0
    exchange = j_0 * np.exp(-1.0 * (r - 1.0525) / dec)  # [GHz]
    angular = 1 - 3 * (np.cos(theta + field_angle) ** 2)
    dipolar = d_0 * angular / (2 * math.pi * r ** 3)
    return exchange, dipolar

def full_coupling(r, theta, d_0, j_0, dec):
    """Return the combined dipolar + exchange coupling 2*D + J.

    D = d_0 * (1 - 3*cos^2(theta)) / (2*pi*r^3)   (dipolar)
    J = j_0 * exp(-(r - r_0)/dec), r_0 = 1.0525 nm (exchange)

    Note the argument order here is (d_0, j_0, dec).
    """
    r_0 = 1.0525  # [nm]
    dipolar = d_0 * (1 - 3 * (np.cos(theta) ** 2)) / (2 * math.pi * r ** 3)
    exchange = j_0 * np.exp(-1.0 * (r - r_0) / dec)
    return 2 * dipolar + exchange

def build_H_ident(dim):
    """Return a list holding one identity operator per atom.

    Entry k is ``qeye(dim[k])``, so its size matches atom k's spin space.
    Callers overwrite selected entries with spin operators before tensoring.
    """
    return [qeye(d) for d in dim]

def get_spin_mat(s):
    """Return the single-site spin operators of every atom.

    Entry k is ``[spin_Jx(s[k]), spin_Jy(s[k]), spin_Jz(s[k])]``, i.e. the
    x, y and z spin matrices for atom k's spin quantum number.
    """
    return [[spin_Jx(spin), spin_Jy(spin), spin_Jz(spin)] for spin in s]

def get_spin_readout_mat(s,dim,N,atom_nr):
    """Return [Sx, Sy, Sz] of atom ``atom_nr`` embedded in the N-atom space.

    Each returned operator is tensor(I_0, ..., S_atom_nr, ..., I_{N-1}),
    where I_j = qeye(dim[j]) and S is the x, y or z spin matrix of s[atom_nr].

    Parameters
    ----------
    s : list of float
        Spin quantum number of each atom.
    dim : list of int
        Spin dimension of each atom.
    N : int
        Number of atoms (number of tensor factors).
    atom_nr : int
        Index of the atom whose spin operators are embedded.

    Returns
    -------
    list of 3 Qobj: [Sx, Sy, Sz] on the full tensor-product space.
    """
    single_site = [spin_Jx(s[atom_nr]), spin_Jy(s[atom_nr]), spin_Jz(s[atom_nr])]
    readout = []
    for op in single_site:
        # Identity on every atom except atom_nr, which carries the spin operator.
        factors = [op if j == atom_nr else qeye(dim[j]) for j in range(N)]
        readout.append(tensor(factors))
    return readout

def build_H_heis(dim,s,J):
    """Build the isotropic Heisenberg exchange Hamiltonian.

    H = sum_{i<j} J[i,j] * (Sx_i Sx_j + Sy_i Sy_j + Sz_i Sz_j)

    Parameters
    ----------
    dim : list of int
        Spin dimension of each atom (its length sets the number of atoms).
    s : list of float
        Spin quantum number of each atom.
    J : 2-D array
        Coupling matrix; only the upper triangle J[i, j] with i < j is used.

    Returns
    -------
    Qobj on the full tensor-product space, or int 0 for fewer than 2 atoms.
    """
    H_h = 0
    N = len(dim)
    spin_mat = get_spin_mat(s)
    for i in range(N):
        for j in range(i+1,N):  # start at i+1 so each pair is counted once
            # Use a distinct loop name so the spin list parameter `s` is not shadowed.
            for axis in range(3):  # 0 = x, 1 = y, 2 = z
                factors = build_H_ident(dim)
                factors[i] = spin_mat[i][axis]
                factors[j] = spin_mat[j][axis]
                H_h = H_h + J[i,j]*tensor(factors)
    return H_h


def build_H_zeeman(dim,s,g,B,B_tip,tipspin):
    """Build the Zeeman Hamiltonian for the external field plus a tip field.

    H_z = sum_axis sum_i g[i][axis] * mu_B * S_i^axis * B_i[axis],
    where B_i = B + B_tip for atom ``tipspin`` and B_i = B for all others.

    Parameters
    ----------
    dim : list of int
        Spin dimension of each atom (its length sets the number of atoms).
    s : list of float
        Spin quantum number of each atom.
    g : sequence of 3-sequences
        Per-atom g-factors along x, y, z, indexed as g[i][axis].
    B, B_tip : 3-sequences
        External and tip field components (x, y, z). Presumably these are
        in tesla, since mu_B is in GHz/T; confirm.
    tipspin : int
        Index of the atom under the tip. If it matches no index, no atom
        receives B_tip.

    Returns
    -------
    Qobj on the full tensor-product space (int 0 if there are no atoms).
    """
    # mu_B is the module-level Bohr magneton in GHz/T. The former local
    # redefinition had the identical value.
    H_z = 0
    spin_mat = get_spin_mat(s)
    N = len(dim)
    # The outer loop runs over axes and the inner loop over atoms. This is
    # the original summation order.
    for axis in range(3):
        for i in range(N):
            field = B[axis] + B_tip[axis] if i == tipspin else B[axis]
            factors = build_H_ident(dim)
            factors[i] = g[i][axis]*mu_B*spin_mat[i][axis]*field
            H_z = H_z + tensor(factors)
    return H_z

def build_H_dipole(dim,s,D):
    """Build the pairwise dipolar Hamiltonian.

    H_d = sum_{i<j} D[i,j] * (2 Sz_i Sz_j - Sx_i Sx_j - Sy_i Sy_j)

    Parameters
    ----------
    dim : list of int
        Spin dimension of each atom.
    s : list of float
        Spin quantum number of each atom (its length sets the number of atoms).
    D : 2-D array
        Dipolar coupling matrix; only the upper triangle D[i, j] with i < j
        is used.

    Returns
    -------
    Qobj on the full tensor-product space, or int 0 for fewer than 2 atoms.
    """
    H_d = 0
    N = len(s)
    spin_mat = get_spin_mat(s)
    for i in range(N):
        for j in range(i+1,N):  # start at i+1 so each pair is counted once
            # Use a distinct loop name so the spin list parameter `s` is not shadowed.
            for axis in range(3):  # 0 = x, 1 = y, 2 = z
                factors = build_H_ident(dim)
                factors[i] = spin_mat[i][axis]
                factors[j] = spin_mat[j][axis]
                if axis == 2:
                    H_d = H_d + 2*D[i,j]*tensor(factors)
                else:
                    H_d = H_d - D[i,j]*tensor(factors)
    return H_d

def Build_B(mag,theta):
    """Return a field vector [x, y, z] of magnitude ``mag`` in the x-z plane.

    ``theta`` is the angle from the z axis in degrees:
    x = mag*sin(theta), y = 0, z = mag*cos(theta).
    """
    angle = theta*np.pi/180  # degrees -> radians
    return [mag*np.sin(angle), 0, mag*np.cos(angle)]

def Boltzmann(E,T):
    """Return Boltzmann populations P_k = exp(-E_k/(kb*T)) / Z.

    Parameters
    ----------
    E : array_like
        Eigenstate energies in GHz, e.g. qutip eigenenergies. A plain list
        is also accepted.
    T : float
        Temperature in K (kb is in GHz/K).

    Returns
    -------
    numpy.ndarray of populations that sum to 1 (empty if E is empty).
    """
    E = np.asarray(E)
    # Shifting by the minimum energy cancels in P = Boltz/Z. It keeps the
    # largest exponent at 0, so exp() cannot overflow, and the ground state
    # term never underflows. Without the shift, low T or large |E| gave
    # inf/inf or 0/0 = NaN.
    shift = E.min() if E.size else 0
    Boltz = np.exp(-(E - shift)/(kb*T))
    Z = np.sum(Boltz)
    P = Boltz/Z
    return P

###########################################
### SOLVING TIME DEPENDENT HAMILTONIAN ###
###########################################

def Build_initial_state(P,edim,ndim):
    """Return a diagonal density matrix carrying the populations ``P``.

    For a 1-D ``P`` this is diag(P) as a Qobj with dims
    [[edim, ndim], [edim, ndim]], i.e. an operator on a two-factor space.
    NOTE(review): diag(P) places P in the product (computational) basis,
    not in the Hamiltonian's eigenbasis. Confirm this matches how the
    caller ordered P.
    """
    populations = np.diag(P)
    return Qobj(populations, dims=[[edim, ndim], [edim, ndim]])

def solve_master_decay(H,psi,times,T_1_e,T_2_e,dim,N,atom_nr,s,*,use_collapse_ops=False):
    """Time-evolve ``psi`` under ``H`` with qutip mesolve and record <Sz>.

    By default NO collapse operators are passed to mesolve, so the
    evolution is unitary and T_1_e / T_2_e are ignored. This is the
    behavior of the original implementation. Pass
    ``use_collapse_ops=True`` to add, for every atom j,
    sqrt(1/(2*pi*T_1_e)) * Sx_j (relaxation) and
    sqrt(1/(2*pi*T_2_e)) * Sz_j (decoherence). An earlier commented-out
    variant used lowering operators (spin_Jm) for relaxation instead.

    Parameters
    ----------
    H : Qobj
        System Hamiltonian (energies in GHz).
    psi : Qobj
        Initial state passed to mesolve.
    times : numpy.ndarray
        Times; multiplied by 2*pi because energies are in GHz.
    T_1_e, T_2_e : float
        Relaxation / decoherence times, used only if use_collapse_ops.
    dim, N, s :
        Per-atom dimensions, number of atoms, and spin quantum numbers.
    atom_nr : int
        Atom whose Sz expectation value is recorded.

    Returns
    -------
    list containing a single entry: mesolve's ``result.expect``.
    """
    readout_Sz = get_spin_readout_mat(s,dim,N,atom_nr)[2]

    c_ops = []
    if use_collapse_ops:
        gamma_1 = np.sqrt(1/(T_1_e*2*np.pi))
        gamma_2 = np.sqrt(1/(T_2_e*2*np.pi))
        relaxation = []
        decoherence = []
        for j in range(N):  # one pair of collapse operators per atom
            ops_j = get_spin_readout_mat(s,dim,N,j)  # build [Sx, Sy, Sz] once per atom
            relaxation.append(gamma_1*ops_j[0])
            decoherence.append(gamma_2*ops_j[2])
        c_ops = relaxation + decoherence

    result = mesolve(H, psi, times*(2*np.pi), c_ops, readout_Sz)
    return [result.expect]

###########################
### SIMULATING ESR DATA ###
###########################

def Fano(f_0,A,FWHM,f,q):
    """Return the asymmetric resonance line shape used to simulate ESR data.

    With reduced detuning delta = (f - f_0) / (FWHM/2) this returns
    A * (1 + 2*q*delta) / (1 + delta**2). That is a Lorentzian of height A
    and full width FWHM, plus an antisymmetric (dispersive) part weighted
    by q. q = 0 gives a pure Lorentzian.
    """
    half_width = FWHM/2
    delta = (f - f_0)/half_width
    numerator = 1 + delta*2*q
    denominator = 1 + delta**2
    return A*numerator/denominator

#################
### PLOTTING ####
#################

def plot_structure(x,y,colors = 'tab:blue'):
    """Draw the atom positions (x, y) over a square lattice and show the figure.

    - Grey markers sit at integer lattice points and darker markers at
      half-integer points.
    - The atoms are drawn on top in ``colors``.
    - Axis limits extend one unit beyond the atom coordinates.

    NOTE(review): marker sizes scale as 1/xmax with xmax = max(x) + 1, so
    structures with max(x) <= -1 give a zero or negative denominator.
    Confirm whether the scale should use the plot extent instead.
    """
    x_lo, x_hi = np.min(x) - 1, np.max(x) + 1
    y_lo, y_hi = np.min(y) - 1, np.max(y) + 1
    # Tick positions double as lattice sites; they span one extra unit per side.
    lattice_x = np.arange(round(x_lo) - 1, round(x_hi) + 1, 1)
    lattice_y = np.arange(round(y_lo) - 1, round(y_hi) + 1, 1)
    site_size = 5000/(3*x_hi)
    atom_size = 15000/(3*x_hi)

    fig = plt.figure()
    ax = fig.gca()

    ax.set_xticks(lattice_x)
    ax.set_yticks(lattice_y)
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_lo, y_hi)
    ax.tick_params(labelbottom=False, labelleft=False)
    ax.set_aspect('equal', adjustable='box')
    ax.grid(which='both', zorder=1)

    for i in lattice_x:
        for j in lattice_y:
            plt.scatter(i, j, color='darkgrey', alpha=0.8, s=site_size, zorder=2)
            plt.scatter(i + 0.5, j + 0.5, color='black', alpha=0.6, s=site_size, zorder=3)
    plt.scatter(x, y, s=atom_size, zorder=3, color=colors)

    for side in ('left', 'bottom', 'top'):
        ax.spines[side].set_visible(False)
    plt.show()