import numpy as np
from Saturation import Child, Generation
import os
from glob import glob
from interface import cp2kInterface


def read_xyz(file_name):
    """Read element symbols and Cartesian coordinates from an XYZ file.

    The first two lines (atom count and comment line) are skipped. Every
    remaining non-blank line must contain exactly four whitespace-separated
    fields: ``symbol x y z``. Blank lines (e.g. a trailing newline at the
    end of the file) are ignored.

    Parameters
    ----------
    file_name : str or path-like
        Path to the ``.xyz`` file.

    Returns
    -------
    atoms : numpy.ndarray of str, shape (n,)
        Element symbols in file order.
    coordinates : numpy.ndarray of float, shape (n, 3)
        Coordinates in whatever units the file uses.

    Raises
    ------
    ValueError
        If a non-blank data line does not have exactly four fields.
    """
    symbols = []
    positions = []
    with open(file_name, "r") as f:
        lines = f.readlines()

    for line in lines[2:]:
        fields = line.split()
        # Trailing/blank lines are common in XYZ files; skip instead of failing.
        if not fields:
            continue
        if len(fields) != 4:
            raise ValueError(
                f"Not a .xyz file, try again (bad line: {line.rstrip()!r})"
            )
        symbols.append(fields[0])
        positions.append([float(value) for value in fields[1:]])

    # Build the arrays once; growing them with np.append/np.vstack per line
    # copies the whole array every iteration (quadratic).
    atoms = np.array(symbols, dtype=str)
    # reshape keeps shape (0, 3) when the file has no atoms.
    coordinates = np.array(positions, dtype=float).reshape(-1, 3)
    return atoms, coordinates


def read_LAMMPS_data_file_full(file_name: str, use_default_dict=True):
    """Read element symbols and coordinates from a LAMMPS data file (atom_style full).

    Only the ``Atoms`` section is parsed. Each data line is expected to be
    ``atom-ID molecule-ID atom-type q x y z [nx ny nz]``. The section ends at
    the first non-numeric line, which is the next section header
    (``Velocities``, ``Bonds``, ...), or at end of file.

    Parameters
    ----------
    file_name : str
        Path to the LAMMPS data file.
    use_default_dict : bool
        If True, atom types 1/2/3 map to Si/O/H. Any atom type missing from
        the mapping is asked for interactively via ``input()``.

    Returns
    -------
    atoms : numpy.ndarray of str, shape (n,)
        Element symbols in file order.
    xyz : numpy.ndarray of float, shape (n, 3)
        Atom coordinates.

    Raises
    ------
    ValueError
        If a data line in the Atoms section has fewer than 7 columns.
    """
    if use_default_dict:
        atom_numberings = {"1": "Si", "2": "O", "3": "H"}
    else:
        atom_numberings: dict = {}

    in_atoms_section = False
    atoms = []
    xyz = []
    with open(file_name, "r") as f:
        for raw_line in f:
            if not in_atoms_section:
                # Section header, e.g. "Atoms # full" as written by write_data.
                header = raw_line.split()
                if header and header[0] == "Atoms":
                    in_atoms_section = True
                continue

            # Drop inline comments before splitting into columns.
            fields = raw_line.split("#", 1)[0].split()
            if not fields:
                continue
            if not fields[0].isdigit():
                # Atom lines start with an integer ID, so any other line is
                # the next section header (Velocities, Bonds, Angles, ...).
                break
            if len(fields) < 7:
                raise ValueError(
                    f"Expected atom_style full line in Atoms section, got: {raw_line.rstrip()!r}"
                )

            key = fields[2]
            if key not in atom_numberings:
                user_input = input(f"What is the atom corresponding to {key}?\n")
                atom_numberings[key] = user_input

            atoms.append(str(atom_numberings[key]))
            xyz.append([float(fields[4]), float(fields[5]), float(fields[6])])

    # reshape keeps shape (0, 3) when no atoms were found.
    return np.array(atoms), np.array(xyz, dtype=float).reshape(-1, 3)


def silanol_density_check(next_gen, chosen_side, wanted_silanol_density):
    """Decide whether bond breaking should continue for the given parents.

    For every parent, the silanol density is queried via
    ``parent.check_silanol_density(region)`` for each region selected by
    ``chosen_side`` and printed with two decimals. For ``"total"``,
    ``parent.probe_surface_rays`` is additionally called once from the top
    and once from the bottom (``max_iter=5``) after the density query.

    Parameters
    ----------
    next_gen : iterable
        Parent structures exposing ``check_silanol_density`` (and
        ``probe_surface_rays`` for ``"total"``).
    chosen_side : str
        One of ``"top"``, ``"bottom"``, ``"both"`` (top and bottom) or
        ``"total"``; case-insensitive.
    wanted_silanol_density : float
        Target density, in whatever units ``check_silanol_density`` returns.

    Returns
    -------
    bool
        ``False`` as soon as any checked density of any parent exceeds
        ``wanted_silanol_density`` (all parents are still checked and
        printed); otherwise ``True``.

    Raises
    ------
    ValueError
        If ``chosen_side`` is not a recognised option.
    """
    regions_per_side = {
        "top": ("top",),
        "bottom": ("bottom",),
        "both": ("top", "bottom"),
        "total": ("total",),
    }
    # Wording used in the progress message for each region.
    region_labels = {
        "top": "at top of model",
        "bottom": "at bottom of model",
        "total": "of the whole model",
    }

    side = chosen_side.lower()
    if side not in regions_per_side:
        raise ValueError("invalid side. Either: top, bottom, both, total")

    continue_break_bonds = True
    for parent in next_gen:
        for region in regions_per_side[side]:
            current_silanol_density = parent.check_silanol_density(region)
            if region == "total":
                parent.probe_surface_rays(max_iter=5, from_top=True)
                parent.probe_surface_rays(max_iter=5, from_top=False)
            print(
                f"current silanol density {region_labels[region]} is "
                f"{current_silanol_density:.2f}",
                flush=True,
            )
            if current_silanol_density > wanted_silanol_density:
                continue_break_bonds = False
    return continue_break_bonds


def pre_opt(strucs: list[Child], max_steps=50):
    """Run a CP2K geometry optimisation for each structure and store its energy.

    Parameters
    ----------
    strucs : list of Child
        Structures to optimise. After each run, ``struc.energy`` is set to
        ``cp2k.energy``. Whether the structure's coordinates are also updated
        depends on ``cp2kInterface`` and is not visible here.
    max_steps : int
        Step limit passed to ``cp2kInterface.geo_opt``.
    """
    # NOTE(review): constructor arguments ("", "../matsci-0-3") are interpreted
    # by interface.cp2kInterface; the relative path assumes the script is run
    # from a specific working directory — confirm.
    cp2k = cp2kInterface("", "../matsci-0-3")
    for idx, struc in enumerate(strucs, start=1):
        print(f"optimizing child {idx}", flush=True)
        cp2k.set_xyz_strucutre(struc)  # (sic) spelled this way in cp2kInterface
        cp2k.geo_opt(max_steps)
        struc.energy = cp2k.energy


def _has_under_coordinated(parents):
    """Return True if any parent still reports under-coordinated atoms.

    ``check_cn()`` is called on every parent (no early exit); its second
    return value is iterated as a mapping whose values are sized
    collections of under-coordinated atoms.
    """
    found = False
    for parent in parents:
        _, under_cn = parent.check_cn()
        if any(len(list_under_cn) > 0 for list_under_cn in under_cn.values()):
            found = True
    return found


def main():
    """Saturate the input structure, then break bonds until the silanol target is met.

    Workflow:
      1. Load the first ``332*.xyz`` file (sorted by name) from the working directory.
      2. Repeatedly run ``Generation.make_gen_sat`` while any selected parent
         still has under-coordinated atoms.
      3. Repeatedly run ``Generation.make_gen_break_bonds`` while
         ``silanol_density_check`` says to continue.
      4. Write every final parent via ``write_xyz_file("final_struc")``.

    Raises
    ------
    FileNotFoundError
        If no ``332*.xyz`` file exists in the working directory.
    """
    print("started", flush=True)
    # Alternative input: a LAMMPS data file.
    # xyz_file_name = glob("*.data")[0]
    # atoms, xyz = read_LAMMPS_data_file_full(xyz_file_name)
    candidates = sorted(glob("332*.xyz"))  # glob order is arbitrary; sort for determinism
    if not candidates:
        raise FileNotFoundError(
            "no input structure matching '332*.xyz' in the working directory"
        )
    xyz_file_name = candidates[0]
    # xyz_file_name = "test_struc_break_bonds.xyz"
    atoms, xyz = read_xyz(xyz_file_name)

    # Pairwise bond cut-off distances between element types.
    cutoffs = {
        "Si": {"Si": 2.2, "O": 2.0, "H": 1.0},
        "O": {"Si": 2.0, "O": 1.8, "H": 1.3},
        "H": {"Si": 1.0, "O": 1.3, "H": 1.0},
    }

    test = Child(atoms, xyz)
    # NOTE(review): presumably the cell lengths of the model — confirm in Saturation.Child.
    test.set_cl([21.5, 21.5, 80])
    test.set_cut_offs(cutoffs)

    max_num_children = 7
    number_new_parents = 1
    next_gen = [test]
    # pre_opt(next_gen, max_steps=30)

    # Phase 1: keep saturating while any selected parent still has
    # under-coordinated atoms.
    while _has_under_coordinated(next_gen):
        new_gen = Generation(max_num_children)
        for parent in next_gen:
            new_gen.add_parent(parent)

        print("saturating", flush=True)
        new_gen.make_gen_sat(max_steps=25)
        new_gen.dump_children()
        next_gen = new_gen.new_parents(number_new_parents)

    # Phase 2: break bonds until silanol_density_check reports that the
    # target density has been exceeded.
    chosen_side = "both"
    wanted_silanol_density = 15
    while silanol_density_check(next_gen, chosen_side, wanted_silanol_density):
        print("breaking bonds", flush=True)
        new_gen = Generation(max_num_children)
        for parent in next_gen:
            new_gen.add_parent(parent)

        new_gen.make_gen_break_bonds(max_steps=25, side=chosen_side)
        new_gen.dump_children()
        next_gen = new_gen.new_parents(number_new_parents)

    for parent in next_gen:
        parent.write_xyz_file("final_struc")

    print("done", flush=True)


if __name__ == "__main__":
    # Script entry point: importing this module does not start the workflow.
    main()
