
import math
import os

from yade import pack, ymport
from yade import utils, wrapper

### Output directories used by the export functions below
# exist_ok=True so that re-running the script in the same working directory
# does not abort with FileExistsError on directories created by a previous run.
import os

for _outDir in ('VTK', 'fstat', 'data'):
        os.makedirs(os.path.join(os.getcwd(), _outDir), exist_ok=True)
def savePropData(O):
        """Export the spheres of the current step to a VTK file under ``./VTK/``.

        The export is labelled with ``O.iter``. Each entry of ``what`` is a
        per-sphere expression string (``b`` is the body); the dict keys name the
        exported fields:

        * the distance of the sphere centre from the origin,
        * the linear and angular velocity,
        * the mass and the material id,
        * the liquid volume reported by the ``lqc`` LiqControl engine,
        * the body id and the number of interactions.

        Args:
            O: the yade Omega instance.
        """
        from yade import export

        # Directory prefix (with trailing separator) used as the exporter's base name.
        path = os.getcwd()+'/VTK/'
        vtkExporter = export.VTKExporter(path)
        vtkExporter.exportSpheres(numLabel=O.iter, what=dict(
                dist='b.state.pos.norm()',
                linVelocity='b.state.vel',
                angVelocity='b.state.angVel',
                mass='b.state.mass',
                mat_rand='b.material.id',
                lqc='lqc.liqBody(b.id)',
                id='b.id',
                numOfContacts='len(b.intrs())',
        ))
                               

###### Material properties

# math.floor is used below for `it`; the script's later `import math`
# (just before O.run) executes too late for this line.
import math

# nCG multiplies the radius, Gamma, Eta, the stiffnesses and (squared) the
# damping coefficients below -- presumably a coarse-graining factor, confirm.
nCG = 3
# frictionAngle of both materials; yade takes this in radians
# (0.9 rad ~ 51.6 deg) -- NOTE(review): confirm that is intended.
fr = 0.9

rho = 3000  # passed as density= to both materials

r = 0.0005 * nCG  # mean sphere radius (rMean of makeCloud)

Gamma = 0.146 * nCG  # passed as gamma= to both materials
Theta = 0.0          # passed as theta= to the particle material only
# Vb: 2 % of the volume of a sphere of radius r (pi approximated as 3.14).
vB = 0.02 * (4/3)*3.14 * (r**3)
Eta = 0.001 * nCG    # passed as eta= to the particle material only
CapType = "Lambert"  # CapillarType of both materials

kkN = 2*(nCG)*(100)       # kn
kkS = 2*(nCG)*(30)        # ks
ccN = 2*(nCG**2)*(0.008)  # cn
ccS = 2*(nCG**2)*(0.008)  # cs
###en=0.3

##### Time parameters
O.dt = 5*1e-6 * nCG
# Period (in steps) of the output PyRunners: number of O.dt steps in 0.05 time units.
it = math.floor(0.05/O.dt)
simT = 10  # total simulated time; the run length is simT/O.dt steps

##### Drum
drumr = 0.1   # drum radius; outputs use [0, 2*drumr] as the x/y domain
druml = 0.06  # drum length; outputs wrap z into [0, druml]

# Particle material.
mat = O.materials.append(
        ViscElCapMat(frictionAngle=fr, density=rho, Vb=vB, gamma=Gamma, eta=Eta, theta=Theta, Capillar=True, CapillarType=CapType, kn=kkN, ks=kkS, cn=ccN, cs=ccS)
)
# Drum-wall material (used for drum.stl): identical to `mat` except eta=0 and theta=90.
mat2 = O.materials.append(
        ViscElCapMat(frictionAngle=fr, density=rho, Vb=vB, gamma=Gamma, eta=0, theta=90, Capillar=True, CapillarType=CapType, kn=kkN, ks=kkS, cn=ccN, cs=ccS)
)


# Defining the spheres: a cloud with mean radius r in an axis-aligned box that
# extends 0.7*drumr around the drum axis in x/y and spans z in [0.0015, 0.057].
# num is the requested count (NOTE: makeCloud may place fewer if they do not fit).
_cloudHalfWidth = 0.7*drumr
sp = pack.SpherePack()
sp.makeCloud(
        (drumr - _cloudHalfWidth, drumr - _cloudHalfWidth, 0.0015),
        (drumr + _cloudHalfWidth, drumr + _cloudHalfWidth, 0.057),
        rMean=r,
        num=13334,
)
sp.toSimulation(material=mat)

# Only spheres have been added so far, so this is the particle count used
# later in the output headers and averages.
Nprtcl = len(O.bodies)
print(Nprtcl)



# liquidMigration: initial liquid state of every sphere.
# VV is the liquid volume as a fraction of the sphere volume (0.0 -> start dry);
# Vmin is written to each sphere's state.Vmin.
VV = 0.0
Vmin = 0.0

for body in O.bodies:
        if type(body.shape) == wrapper.Sphere:
                # pi approximated as 3.14, as in the other volume formulas of this script
                body.state.Vf = VV * (4/3) * 3.14 * body.shape.radius**3
                body.state.Vmin = Vmin

#Drum=geom.facetCylinder(material=mat2,center=(0.05,0.05,0.015), segmentsNumber=32, wallMask=4, radius=drumr,height=20*druml,orientation=Quaternion(Vector3(0,0,1),(pi/2.0)))
# Drum wall: facets read from drum.stl with the wall material; their ids are
# kept in `walls` so the RotationEngine can spin them.
walls = O.bodies.append(ymport.stl('drum.stl', material=mat2))

# Periodic cell. Its z length is the drum length druml (0.06), the same period
# the output writers use when wrapping z coordinates.
O.periodic = True
O.cell.setBox(0.5, 0.5, druml)




## Engine list -- executed in this order every step.
O.engines = [
        ForceResetter(),
        # Facets (drum) and spheres get bounding boxes. allowBiggerThanPeriod=True
        # presumably because the drum facets may exceed the periodic cell -- confirm.
        InsertionSortCollider([Bo1_Sphere_Aabb(), Bo1_Facet_Aabb()], allowBiggerThanPeriod = True),
        # Sphere-sphere and facet-sphere contacts use the viscoelastic capillary model.
        InteractionLoop(
                [Ig2_Sphere_Sphere_ScGeom(),Ig2_Facet_Sphere_ScGeom()],
                [Ip2_ViscElCapMat_ViscElCapMat_ViscElCapPhys()],
                [Law2_ScGeom_ViscElCapPhys_Basic()],
                
        ),
       
        # Gravity acts along -y.
        NewtonIntegrator(gravity=[0, -9.8, 0]),
        # Spins the drum facets about the z axis through (0.1, 0.1, 0.03), i.e. the
        # drum axis at x=y=drumr and z=druml/2.
        RotationEngine(ids=walls,rotationAxis=[0,0,1],rotateAroundZero=True, zeroPoint=[0.1,0.1,0.03], angularVelocity=0.45),
        # liquidSpray only acts on the call where O.iter == 2*it.
        PyRunner(command='liquidSpray(O)', iterPeriod=it),
        # Labelled 'lqc' so the output functions can call lqc.liqBody(id).
        LiqControl(label='lqc'),
        # Output every `it` steps: VTK snapshot, particle data, energy, contact data.
        PyRunner(command='savePropData(O)', iterPeriod=it),
        PyRunner(command='writeDataSeparate(O)', iterPeriod=it),
        PyRunner(command='energyOut(O)', iterPeriod=it),
        #PyRunner(command='liqVariation(O)', iterPeriod=it),
        PyRunner(command='writeFstatSeparate(O)', iterPeriod=it),
        # Checkpoint (test.bz2) and interaction dump every 10*it steps.
        PyRunner(command='saving(O)',iterPeriod=10*it)
        #VTKRecorder(iterPeriod=10000,recorders=['facets', 'bstresses'],fileName='test-')
        	]

#Functions 


def frameSuffix(index):
        """Return the output-file suffix for frame ``index``, zero-padded to 4 digits.

        For 0 <= index <= 9999 this equals the former ``str(10000 + index)[1:]``
        scheme. That scheme wrapped back to ``"0000"`` at index 10000, so later
        frames appended onto earlier files; this one keeps growing instead.
        """
        return f"{index:04d}"


def wrapPeriodic(value, lo, hi):
        """Map ``value`` into ``[lo, hi]`` on a periodic axis of length ``hi - lo``.

        Values already inside ``[lo, hi]`` are returned unchanged. For ``lo == 0``
        this reproduces the inline wrapping previously used here:
        ``value % L`` above the range and ``hi - (-value) % L`` below it.
        """
        length = hi - lo
        if value > hi:
                value = lo + (value - lo) % length
        if value < lo:
                value = hi - (lo - value) % length
        return value


def writeDataSeparate(O):
        """Append a particle snapshot for the current step to ``data/data.<frame>``.

        Output layout:

        * a header line ``N time xmin ymin zmin xmax ymax zmax``;
        * one line per sphere: position (z wrapped into ``[0, druml]``),
          velocity, radius, the first three components of ``state.ori``,
          angular velocity, and the liquid volume from ``lqc.liqBody``.

        ``<frame>`` is ``int(O.iter / it)``, zero-padded to at least 4 digits.

        Args:
            O: the yade Omega instance.
        """
        N = Nprtcl
        time = O.time
        xmin, xmax = 0.0, 2*drumr
        ymin, ymax = 0.0, 2*drumr
        zmin, zmax = 0.0, druml
        frame = int(O.iter/it)
        fileName = os.path.join(os.getcwd(), 'data', "data." + frameSuffix(frame))

        lines = [f" {N} {time:.8f} {xmin} {ymin} {zmin} {xmax} {ymax} {zmax} \n"]
        for b in O.bodies:
                if not type(b.shape) == wrapper.Sphere:
                        continue

                pos = b.state.pos
                px = pos[0]
                py = pos[1]
                # Positions can leave the periodic cell along z; fold them back
                # into [zmin, zmax] for post-processing.
                pz = wrapPeriodic(pos[2], zmin, zmax)

                vel = b.state.vel
                vx, vy, vz = vel[0], vel[1], vel[2]

                radius = b.shape.radius

                angVel = b.state.angVel
                omegx, omegy, omegz = angVel[0], angVel[1], angVel[2]

                # NOTE(review): written as alpha/beta/gamma, but these are the first
                # three components of the orientation quaternion, presumably not
                # Euler angles -- confirm what the consumer of this file expects.
                ori = b.state.ori
                alpha, beta, gamma = ori[0], ori[1], ori[2]

                # Use the body id rather than a position in O.bodies, which would
                # diverge from the id if any body were ever erased.
                vf = lqc.liqBody(b.id)

                lines.append(f"{px:.16f} {py:.16f} {pz:.16f} {vx:.8f} {vy:.8f} {vz:.8f} {radius} {alpha} {beta} {gamma} {omegx:.8f} {omegy:.8f} {omegz:.8f} {vf:.12f} \n")

        with open(fileName, "a") as xballsdata:
                xballsdata.writelines(lines)
        

def writeFstatSeparate(O):
        """Append the contact data of the current step to ``fstat/fstat.<frame>``.

        Output layout:

        * three ``#`` header lines: time and format version, domain bounds,
          and a row of zeros;
        * one line per interaction in ``O.interactions``:
          - the body ids,
          - the contact point (z wrapped into ``[0, druml]``),
          - the penetration depth and a zero tangential-overlap placeholder,
          - the normal and shear force magnitudes,
          - the contact normal and the unit shear-force direction,
          - the normal force projected on the contact normal,
          - the ``Fn``, ``Fv`` and ``Vb`` fields of the interaction physics.

        ``<frame>`` is ``int(O.iter / it)``, zero-padded to at least 4 digits.

        Args:
            O: the yade Omega instance.
        """
        fstatVersion = 1
        time = O.time
        xmin, xmax = 0.0, 2*drumr
        ymin, ymax = 0.0, 2*drumr
        zmin, zmax = 0.0, druml
        period = zmax - zmin

        frame = int(O.iter/it)
        # f"{frame:04d}" equals the old str(10000+frame)[1:] for frames 0..9999,
        # but does not wrap back to "0000" (and append onto frame 0) at 10000.
        fileName = os.path.join(os.getcwd(), 'fstat', "fstat." + f"{frame:04d}")

        lines = [
                f"# {time:.8f} {fstatVersion} \n",
                f"# {xmin} {ymin} {zmin} {xmax} {ymax} {zmax} \n",
                f"# {0} {0} {0} {0} {0} {0} {0}\n",
        ]

        for c in O.interactions:
                i = c.id1
                j = c.id2
                cp = c.geom.contactPoint
                cx, cy, cz = cp[0], cp[1], cp[2]

                # Fold the contact point's z back into [zmin, zmax] (periodic direction).
                if cz > zmax:
                        cz = zmin + cz % period
                if cz < zmin:
                        cz = zmax - (-cz) % period

                delta = c.geom.penetrationDepth
                deltat = 0.0
                fn = c.phys.normalForce
                fnnorm = fn.norm()
                ft = c.phys.shearForce
                ftnorm = ft.norm()

                # Unit shear direction; integer zeros when there is no shear force
                # (kept as ints so the written text is unchanged).
                if ftnorm == 0:
                        tx = ty = tz = 0
                else:
                        tx = ft[0]/ftnorm
                        ty = ft[1]/ftnorm
                        tz = ft[2]/ftnorm

                normal = c.geom.normal
                fndotn = fn[0]*normal[0] + fn[1]*normal[1] + fn[2]*normal[2]
                Fnc = c.phys.Fn
                Fnv = c.phys.Fv
                Vbr = c.phys.Vb

                lines.append(f"{time:.8f} {i} {j} {cx:.8f} {cy:.8f} {cz:.8f} {delta:.8f} {deltat} {fnnorm:.8f} {ftnorm:.8f} {normal[0]:.8f} {normal[1]:.8f} {normal[2]:.8f} {tx} {ty} {tz} {fndotn:.8f} {Fnc:.8f} {Fnv:.8f} {Vbr:.12f}\n")

        with open(fileName, "a") as fstatdata:
                fstatdata.writelines(lines)


def energyOut(O):
        """Append one energy line for the current time to ``test.ene``.

        Only the kinetic energy (``utils.kineticEnergy()``) is computed. The
        gravitational, rotational and elastic energies and the centre-of-mass
        coordinates are written as placeholder zeros.

        Args:
            O: the yade Omega instance.
        """
        time = O.time
        ene_kin = utils.kineticEnergy()
        # Placeholders -- not computed yet.
        ene_gra = ene_rot = ene_ela = 0.0
        x_com = y_com = z_com = 0.0

        nline = f" {time} {ene_gra} {ene_kin} {ene_rot} {ene_ela} {x_com} {y_com} {z_com}\n"
        with open("test.ene", "a") as energydata:
                energydata.write(nline)


def relativeLiquidVariance(volumes, n):
        """Return the variance of per-particle liquid volumes divided by the squared mean.

        Both the mean and the variance are taken over ``n`` particles, i.e.
        ``(sum((mean - v)**2) / n) / mean**2``. Returns 0.0 when ``n`` is zero or
        the mean is zero (e.g. before any liquid has been sprayed), instead of
        raising ZeroDivisionError.
        """
        vols = list(volumes)
        if n == 0:
                return 0.0
        mean = sum(vols) / n
        if mean == 0.0:
                return 0.0
        return sum((mean - v) ** 2 for v in vols) / n / mean ** 2


def bridgeRadii(bridges):
        """Return ``(r_num, r_Vb)`` from an iterable of ``(Vb, distance)`` pairs.

        ``r_num`` is the plain mean distance over all bridges. ``r_Vb`` is the
        bridge-volume-weighted mean distance. Either is 0.0 when its denominator
        (bridge count or total volume) is zero.
        """
        count = 0
        sumDis = 0.0
        sumVbDis = 0.0
        sumVb = 0.0
        for vb, dis in bridges:
                count += 1
                sumDis += dis
                sumVbDis += vb * dis
                sumVb += vb
        r_num = sumDis / count if count else 0.0
        r_Vb = sumVbDis / sumVb if sumVb != 0.0 else 0.0
        return r_num, r_Vb


def liqVariation(O):
        """Append liquid-distribution statistics for the current step to ``test.liqB``.

        The written line is ``time liqVar r_num r_Vb``:

        * ``liqVar``: the relative variance of sphere liquid volumes
          (see ``relativeLiquidVariance``).
        * ``r_num``: the mean distance of active liquid bridges from ``refP``.
        * ``r_Vb``: the same distance weighted by bridge volume.

        A bridge's position is taken as the midpoint of its two body centres.

        Args:
            O: the yade Omega instance.
        """
        import math

        time = O.time
        # NOTE(review): distances are measured from the origin, not from the drum
        # axis (0.1, 0.1, .) -- confirm this reference point is intended.
        refP = (0.0, 0.0, 0.0)

        # Body ids (not positions in O.bodies) are what lqc.liqBody expects.
        volumes = [lqc.liqBody(b.id) for b in O.bodies if type(b.shape) == wrapper.Sphere]
        liqVar = relativeLiquidVariance(volumes, Nprtcl)

        bridges = []
        for c in O.interactions:
                if not c.phys.liqBridgeActive:
                        continue
                pos1 = O.bodies[c.id1].state.pos
                pos2 = O.bodies[c.id2].state.pos
                mid = [(pos1[k] + pos2[k]) / 2 for k in range(3)]
                dis = math.sqrt(sum((mid[k] - refP[k]) ** 2 for k in range(3)))
                bridges.append((c.phys.Vb, dis))
        r_num, r_Vb = bridgeRadii(bridges)

        with open("test.liqB", "a") as liqdata:
                liqdata.write(f" {time} {liqVar} {r_num} {r_Vb}\n")

def saving(O):
        """Checkpoint the simulation and dump the interaction state.

        Writes the checkpoint to ``test.bz2`` and the interactions to
        ``tmpIntrState.csv``, one row per interaction in ``O.interactions``.
        The columns are the body ids, ``Fn``, ``Fv``, ``sCrit`` and the
        components of ``normalForce`` and ``shearForce``.

        Rows are collected first and turned into a single DataFrame:
        ``DataFrame.append`` no longer exists in pandas >= 2.0, and appending
        row by row was quadratic in the number of interactions.

        Args:
            O: the yade Omega instance.
        """
        O.save('test.bz2')

        import pandas as pd

        columns = ['id1','id2','Fn','Fv','sCrit','normalForce0','normalForce1','normalForce2', 'shearForce0','shearForce1','shearForce2']
        rows = []
        for ii in O.interactions:
                phys = ii.phys
                fn = phys.normalForce
                fs = phys.shearForce
                rows.append([ii.id1, ii.id2, phys.Fn, phys.Fv, phys.sCrit, fn[0], fn[1], fn[2], fs[0], fs[1], fs[2]])

        # dtype=object keeps the values as the Python objects they were read as,
        # like the original object-dtype frame.
        intrState = pd.DataFrame(rows, columns=columns, dtype=object)
        intrState.to_csv('tmpIntrState.csv')

def liquidSpray(O):
        """Add a one-off liquid dose to spheres near the drum axis.

        Acts only when ``O.iter == 2*it``; at any other iteration it does nothing.
        A sphere qualifies when its current ``state.Vf`` is below the dose and
        its centre lies within 0.05 of the point (0.1, 0.1) in the x-y plane.
        Each qualifying sphere gets ``0.05 * (4/3) * 3.14 * r**3`` added to
        ``state.Vf``, and its ``state.Vmin`` is set to 0.0.

        Args:
            O: the yade Omega instance.
        """
        dose = 0.05 * (4/3) * 3.14 * r**3
        vmin = 0.0
        if O.iter != 2*it:
                return

        for body in O.bodies:
                if not type(body.shape) == wrapper.Sphere:
                        continue
                pos = body.state.pos
                dx = pos[0] - 0.1
                dy = pos[1] - 0.1
                if body.state.Vf < dose and dx**2 + dy**2 < (0.05)**2:
                        body.state.Vf = body.state.Vf + dose
                        body.state.Vmin = vmin


import math

# Run the whole simulation: simT of simulated time in steps of O.dt.
# The second argument (True) is yade's `wait` flag, so the script blocks
# here until the run has finished.
duration = simT/O.dt
O.run(math.floor(duration), True)


### Saving for restart
# Final checkpoint (test.bz2) and interaction dump (tmpIntrState.csv).
# This is exactly what the periodic saving() PyRunner does, so reuse it
# instead of keeping a second copy of that code.
saving(O)


###


