import subprocess
import os
from xyzClass import XyzStruc
import numpy as np

class cp2kInterface:
    """Driver for CP2K DFTB geometry optimisations of an ``XyzStruc``.

    ``geo_opt`` writes a CP2K input file for the loaded structure and runs
    ``cp2k.popt`` under ``mpirun``. This happens in a scratch ``Geo_Opt``
    directory next to this module. The lowest-energy frame of the resulting
    trajectory is then read back into the structure.
    """

    def __init__(self, LIB_PATH, sk_file_path, xyz_structure: "XyzStruc | None" = None):
        """Store paths and an optional starting structure.

        Args:
            LIB_PATH: stored as ``self.bin_path`` (not used by this class's methods).
            sk_file_path: Slater-Koster parameter directory, written as
                ``PARAM_FILE_PATH`` into the DFTB ``&PARAMETER`` section.
            xyz_structure: optional structure to optimise; it can also be set
                later with ``set_xyz_strucutre``.
        """
        self.bin_path = LIB_PATH
        self.sk_path = sk_file_path

        self.xyz_struc: "XyzStruc | None" = xyz_structure
        # Energy of the last optimisation; None before any run, 0 if it failed.
        self.energy = None

    def set_xyz_strucutre(self, xyz_strucutre: "XyzStruc"):
        """Load the structure to optimise, replacing any previously loaded one.

        The (misspelled) method name is kept because callers depend on it.

        Raises:
            AssertionError: if ``xyz_strucutre`` is not an ``XyzStruc``.
        """
        assert isinstance(xyz_strucutre, XyzStruc), f"Not giving the XyzStruc class, {type(xyz_strucutre)}"
        if self.xyz_struc is not None:
            print("overwriting XyzStrcutre readied in optimizer", flush=True)
        self.xyz_struc = xyz_strucutre

    def geo_opt(self, max_steps):
        """Run a CP2K geometry optimisation of the loaded structure.

        Steps:
        1. Empty (or create) ``Geo_Opt`` next to this module.
        2. Write ``geo_opt.in`` there and run CP2K.
        3. Read the lowest-energy frame of ``OPT-pos-1.xyz``.
        4. Store its coordinates via ``self.xyz_struc.set_xyz_info`` and its
           energy in ``self.energy``.

        If reading or applying the trajectory raises, a warning is printed and
        ``self.energy`` is set to 0. That value is kept for backward
        compatibility, but it cannot be told apart from a genuine zero energy.
        The caller's working directory is always restored, even when CP2K
        cannot be launched.

        Args:
            max_steps: value written as the GEO_OPT ``MAX_ITER``.
        """
        assert self.xyz_struc is not None, "No xyz strucutre loaded"
        return_path = os.getcwd()
        work_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "Geo_Opt")
        try:
            self._prepare_work_dir(work_dir)
            os.chdir(work_dir)
            self.make_in_file_CP2K(max_steps)
            return_code = self.run_cp2k()
            if return_code:
                print(f"cp2k exited with code {return_code}, see geo_opt.out", flush=True)

            try:
                new_atoms, new_xyz, E = self.read_child_opt("OPT-pos-1.xyz")
                self.xyz_struc.set_xyz_info(new_atoms, new_xyz)
                self.energy = E
            except Exception as err:
                # Best-effort: a failed run must not abort the caller, but it
                # must not be silent either.
                print(f"could not read optimised geometry: {err!r}", flush=True)
                self.energy = 0
        finally:
            os.chdir(return_path)

    @staticmethod
    def _prepare_work_dir(path):
        """Ensure ``path`` exists as an empty directory, deleting any old contents."""
        import shutil

        if os.path.isdir(path):
            # rmtree also copes with symlinked sub-directories, which os.rmdir cannot remove.
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)

    def make_in_file_CP2K(self, max_steps):
        """Write the CP2K input file ``geo_opt.in`` into the current directory.

        The input is a GEO_OPT (BFGS minimisation) run using DFTB with
        dispersion and Ewald summation. The cell is periodic and orthorhombic,
        with edges taken from ``self.xyz_struc.cell_lengths``. Cell lengths and
        coordinates are written verbatim; presumably they are in Angstrom,
        CP2K's default (confirm with XyzStruc).

        Args:
            max_steps: maximum number of optimisation steps (GEO_OPT MAX_ITER).
        """
        cl = self.xyz_struc.cell_lengths
        sk_file_path = self.sk_path
        sk_params = self.make_block_sk_param('           SK_FILE atom1 atom2 atom1-atom2.skf', sk_file_path)
        xyz = self.make_xyz_string()
        # NOTE(review): CP2K keywords are usually written "KEYWORD value";
        # confirm the parser accepts the "MAX_ITER = ..." form used below.
        instructions = f""" &GLOBAL
   PRINT_LEVEL MEDIUM
   PROJECT_NAME OPT
   RUN_TYPE GEO_OPT
 &END GLOBAL
 
 &MOTION
   &GEO_OPT
     OPTIMIZER BFGS
     TYPE MINIMIZATION
     MAX_ITER = {max_steps}
   &END GEO_OPT
 &END MOTION
 
 &FORCE_EVAL
  METHOD QS
   &DFT
   &PRINT
     &MULLIKEN off
     &END MULLIKEN
     &HIRSHFELD off
     &END HIRSHFELD
   &END PRINT
     &MGRID
       CUTOFF 450
       REL_CUTOFF 60
       NGRIDS 4
     &END MGRID

     &QS
       METHOD DFTB
       &DFTB
         DISPERSION T
         DO_EWALD  T
         &PARAMETER
{sk_params}         &END PARAMETER
       &END DFTB
     &END QS
     
     &SCF
       SCF_GUESS ATOMIC
       EPS_SCF 0.1E-04
       &OT ON
         PRECONDITIONER FULL_SINGLE_INVERSE
         MINIMIZER DIIS
       &END

     &END SCF
     &POISSON
       POISSON_SOLVER PERIODIC
       PERIODIC XYZ
     &END POISSON
   &END DFT
   &SUBSYS
     &CELL
       A    {cl[0]}    0.0000000000000000    0.0000000000000000
       B    0.0000000000000000   {cl[1]}    0.000000000000000
       C    0.0000000000000000    0.0000000000000000   {cl[2]}
       PERIODIC  XYZ
       !MULTIPLE_UNIT_CELL  1 1 1
     &END CELL
     &COORD
{xyz}     &END COORD

     &PRINT
       &TOPOLOGY_INFO
         XYZ_INFO
       &END TOPOLOGY_INFO
     &END PRINT

   &END SUBSYS
 &END FORCE_EVAL"""
        with open("geo_opt.in", "w") as f:
            f.write(instructions)

    def make_block_sk_param(self, name_format, param_file_dir):
        """Build the body of the DFTB ``&PARAMETER`` section.

        Args:
            name_format: line template in which the placeholders ``atom1`` and
                ``atom2`` are replaced by atom types.
            param_file_dir: written as ``PARAM_FILE_PATH``.

        Returns:
            A ``PARAM_FILE_PATH`` line, then one line for every ordered pair of
            entries in ``self.xyz_struc.atom_types``. Each line is
            newline-terminated.
        """
        lines = [f"	   PARAM_FILE_PATH {param_file_dir}\n"]
        atom_types = list(self.xyz_struc.atom_types)
        for atom1 in atom_types:
            for atom2 in atom_types:
                lines.append(name_format.replace('atom1', atom1).replace('atom2', atom2) + "\n")
        return "".join(lines)

    def make_xyz_string(self):
        """Return the ``&COORD`` body: one ``symbol x y z`` line per atom, newline-terminated."""
        return "".join(
            f"{atom} {xyz[0]} {xyz[1]} {xyz[2]}\n"
            for atom, xyz in zip(self.xyz_struc.atoms, self.xyz_struc.xyz)
        )

    @staticmethod
    def run_cp2k(n_procs=16, input_file="geo_opt.in", output_file="geo_opt.out"):
        """Run ``cp2k.popt`` under ``mpirun`` in the current directory and wait for it.

        stdout and stderr are both written directly to ``output_file``. Merging
        them avoids the deadlock that occurs when one pipe is drained to EOF
        while the child blocks writing to the other.

        Args:
            n_procs: number of MPI ranks passed to ``mpirun -np``.
            input_file: CP2K input file name.
            output_file: file receiving the combined output.

        Returns:
            The process exit code.
        """
        command = ['mpirun', '-np', str(n_procs), 'cp2k.popt', input_file]
        with open(output_file, "w") as outfile:
            completed = subprocess.run(command, stdout=outfile, stderr=subprocess.STDOUT, check=False)
        return completed.returncode

    @staticmethod
    def read_child_opt(file_name):
        """Read the lowest-energy frame of a CP2K ``*-pos-*.xyz`` trajectory.

        Each frame is expected to consist of three parts:
        - an atom-count line;
        - a comment line containing `` i = <step>, E = <energy>`` (energy is
          the last token);
        - one ``symbol x y z`` line per atom.

        The frame with the lowest energy is selected; on ties the last such
        frame wins.

        Returns:
            ``(atoms, xyz, energy)``:
            - ``atoms``: an ``np.str_`` array of symbols;
            - ``xyz``: a float64 array of shape ``(n_atoms, 3)``, with
              coordinates as written in the file (CP2K writes Angstrom by
              default);
            - ``energy``: the frame's energy as a float.

        Raises:
            ValueError: if the file contains no frame header with an energy.
        """
        with open(file_name, "r") as f:
            lines = f.readlines()

        # Start from +inf so the minimum is found even if every energy is positive.
        min_E = float("inf")
        start_idx = None
        for idx, line in enumerate(lines):
            if " i =    " in line:
                E = float(line.split()[-1])
                if E <= min_E:
                    min_E = E
                    start_idx = idx + 1
        if start_idx is None:
            raise ValueError(f"no ' i = ..., E = ...' frame header found in {file_name}")

        all_atoms = []
        all_xyz = []
        for line in lines[start_idx:]:
            fields = line.split()
            # The next frame's atom-count line, or a blank/short line, ends this frame.
            if len(fields) < 4:
                break
            all_atoms.append(fields[0])
            all_xyz.append(fields[1:4])

        return np.array(all_atoms, dtype=np.str_), np.array(all_xyz, dtype=np.float64), min_E