import numpy as np
import os.path
import sys

# Announce start-up (same text and newline that print() would emit).
sys.stdout.write("starting\n")

def gauss(e1, e2, sigma):
    """Return a unit-area Gaussian of width *sigma* evaluated at e1 - e2.

    Works element-wise when any argument is a NumPy array.
    """
    variance = sigma**2
    prefactor = 1.0 / np.sqrt(2 * np.pi * variance)
    return prefactor * np.exp(-((e1 - e2)**2) / (2 * variance))

def contribution(fname, shape=None, start=394):
    """Read the volumetric grid stored in a CHG/PARCHG-style text file.

    Parameters
    ----------
    fname : str
        Path of the file to read.
    shape : tuple of int, optional
        Grid dimensions ``(nx, ny, nz)``.  Defaults to the module-level
        ``nx``, ``ny``, ``nz`` that the script reads from CHG.
    start : int, optional
        0-based index of the first line that holds grid values.  The default
        (394) is specific to the structure this script was written for.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``shape``.  Values are returned exactly as
        stored in the file; no volume normalisation is applied.

    Raises
    ------
    ValueError
        If fewer than ``nx*ny*nz`` values follow line ``start``.
    """
    from itertools import islice

    if shape is None:
        # Fall back to the grid size read from CHG at module level.
        shape = (nx, ny, nz)
    npt = int(np.prod(shape))

    with open(fname, 'r') as f:
        # Skip the header, then stream numbers and stop after exactly npt of
        # them, so whatever follows the grid in the file is never parsed.
        values = (float(tok)
                  for line in islice(f, start, None)
                  for tok in line.split())
        data = np.fromiter(islice(values, npt), dtype=float)

    if data.size < npt:
        raise ValueError("%s: expected %d grid values from line %d, found %d"
                         % (fname, npt, start + 1, data.size))

    # Values are listed with x varying fastest, then y, then z, which is
    # Fortran (column-major) order for an (nx, ny, nz) array.
    return data.reshape(shape, order='F')

# Read the Fermi energy from the DOSCAR header.  Only the first six lines are
# needed, so the rest of the file is never loaded or split.
print("opening DOSCAR")
with open("DOSCAR", 'r') as f:
    header = [f.readline() for _ in range(6)]

# Fermi energy: 4th whitespace-separated field of line 6.
ef = float(header[5].split()[3])
print("Fermi energy:", ef)

# Extract E(k,n): band energies en[k, n] and k-point weights kw[k], read from
# EIGENVAL when it exists, otherwise from PROCAR.
# NOTE(review): the EIGENVAL branch stores absolute energies (its "-ef" shift
# is commented out) while the PROCAR branch stores E - ef, so emin/emax used
# later mean different things depending on which file is present -- confirm
# which reference is intended and make the two branches agree.
fname = "EIGENVAL"
if os.path.isfile(fname):
    # READ EIGENVAL
    with open(fname, 'r') as f:
        text = f.readlines()
    nlin = len(text)

    # Line 6 holds the k-point count (field 2) and band count (field 3).
    header = text[5].split()
    nkp = int(header[1])
    print("Number of k-points:", nkp)
    nbands = int(header[2])
    print("Number of bands:", nbands)
    en = np.zeros((nkp, nbands))   # energies
    kw = np.zeros(nkp)             # k-point weights

    kp = 0
    for i in range(6, nlin):
        fields = text[i].split()
        # A line with more than 3 fields is taken as a k-point line whose 4th
        # field is the weight (assumes band lines have at most 3 fields).
        if len(fields) > 3:
            kw[kp] = float(fields[3])
            print("kpt:", kp, "weight:", kw[kp])
            # The next nbands lines are "<band index> <energy> ...".
            for j in range(1, nbands + 1):
                line = text[i + j].split()
                bnd = int(line[0]) - 1
                en[kp, bnd] = float(line[1])   #-ef
                print("bnd:", bnd, "energy:", en[kp, bnd])
            kp += 1
else:
    # Read PROCAR
    with open("PROCAR", 'r') as f:
        text = f.readlines()

    # Line 2: "# of k-points: N  # of bands: M  # of ions: K"
    header = text[1].split()
    nkp = int(header[3])
    print("Number of k-points:", nkp)
    nbands = int(header[7])
    print("Number of bands:", nbands)
    nions = int(header[11])
    print("Number of ions:", nions)
    en = np.zeros((nkp, nbands))   # energies, relative to ef
    kw = np.zeros(nkp)             # k-point weights

    for raw in text:
        fields = raw.split()
        if not fields:
            continue
        if fields[0] == "k-point":
            # Subsequent "band" lines belong to this k-point.
            kp = int(fields[1]) - 1
            kw[kp] = float(fields[8])
            print("kpt:", kp, "weight:", kw[kp])
        elif fields[0] == "band":
            bnd = int(fields[1]) - 1
            en[kp, bnd] = float(fields[4]) - ef
            print("bnd:", bnd, "energy:", en[kp, bnd])


        

# Read the grid dimensions from CHG.  Only the header up to the grid-size line
# is needed, so reading stops there instead of loading the whole data block.
# NOTE: the line index is specific to this structure; contribution() starts
# reading grid values on the following line (index 394).
grid_line = 393
with open("CHG", 'r') as f:
    for lineno, line in enumerate(f):
        if lineno == grid_line:
            dims = line.split()
            break
    else:
        raise ValueError("CHG has fewer than %d lines" % (grid_line + 1))

# Extract data
nx = int(dims[0])
ny = int(dims[1])
nz = int(dims[2])
npt = nx * ny * nz

# Energy window: nen samples starting at emin, spaced de apart (emax itself is
# not sampled), and the Gaussian broadening width, in the units of en.
emin = 5.0
emax = 6.0
nen = 10
de = (emax - emin) / nen
sigma = 0.05

energies = [(emax - emin) * ie / nen + emin for ie in range(nen)]

rho = np.zeros((nen, nx, ny, nz), dtype=float)

# rho[ie] accumulates, over every (k-point, band) whose PARCHG file exists,
# gauss(E_ie, E_kn, sigma) * de * PARCHG_kn * kw[k].  Each PARCHG file is read
# at most once rather than once per energy; per energy, the terms are still
# added in the same (k-point, band) order.
for ik in range(nkp):
    print(ik)
    for ib in range(nbands):
        fname = "PARCHG." + str(ib + 1).zfill(4) + "." + str(ik + 1).zfill(4)
        if not os.path.isfile(fname):
            continue
        print(ib, en[ik, ib], ef)
        chg = None  # loaded lazily: only if some energy gets a non-negligible weight
        for ie, e in enumerate(energies):
            smearing = gauss(e, en[ik, ib], sigma) * de
            print(e, en[ik, ib], smearing)
            # Skip states whose broadened weight at this energy is negligible.
            if smearing > 1e-10:
                if chg is None:
                    chg = contribution(fname)
                rho[ie] += smearing * chg * kw[ik]
# LDOS profile along y: average each energy slice of rho over x and z.
# (Averaging over axes (1, 2) instead gives the profile along z.)
rho_y = rho.mean(axis=(1, 3))  # shape (nen, ny)

# One output row per (energy, y) sample, energy-major: [iy, energy, rho_y].
ie_idx, iy_idx = np.meshgrid(np.arange(nen), np.arange(ny), indexing='ij')
data = np.column_stack((
    iy_idx.ravel(),
    ((emax - emin) * ie_idx / nen + emin).ravel(),
    rho_y.ravel(),
))

np.savetxt("ldos_scat.dat", data)

