from xyzClass import XyzStruc
from interface import cp2kInterface
import numpy as np
import random
import globals
import os
import random
from typing import Dict, List, Tuple

# Module-level aliases for settings defined in the project's `globals` module.

# Expected coordination number per atom type, indexed by chemical symbol.
NORM_COORDINATION = globals.NORM_COORDINATION
# Bond length used in this file whenever an "H" atom is placed.
O_H_BONDLENGTH = globals.O_H_BONDLENGTH
# Bond length used in this file whenever an "O" atom is placed and when a bond is broken.
SI_O_BONDLENGTH = globals.SI_O_BONDLENGTH
# Per-atom-type flag, indexed by chemical symbol and compared against True/False.
# NOTE(review): presumably True means the atom type is treated as positively charged
# when over-coordinated -- confirm against the `globals` module.
OVER_POS = globals.OVER_POSITIVE
# Flag described in Generation.choose_atoms as filtering out already-saturated indices.
NO_DOUBLE_SATURATION = globals.NO_DOUBLE_SATURATION
# Only referenced from commented-out `probe_surface` calls in this part of the file.
# The name is misspelled ("WITDH") but is kept because other code may rely on it.
PROBE_WITDH = globals.PROBE_WIDTH
PROBE_INTERVAL = globals.PROBE_INTERVAL
# Passed to `get_surface_area` when the "total" silanol density is computed.
RADII = globals.RADII

class Child(XyzStruc):
    """
    Class for the storing and manipulating the data for the children made during the genetic algorithm.

    Inherits from:
      XyzStruc

    params:
      
    """
    def __init__(self, atoms, coordinates):
        """
        Creates a child structure from its atoms and coordinates.

        args:
          atoms: atom labels, forwarded unchanged to XyzStruc.__init__
          coordinates: atomic coordinates, forwarded unchanged to XyzStruc.__init__
        """
        super().__init__(atoms, coordinates)
        # Energy of the structure; stays None until one is assigned from outside.
        self.energy = None
        # Indices of atoms considered already saturated; filled by set_already_saturated().
        self.already_saturated = None

    def check_cn(self):
        """
        Checks the over an under coordination of all the atoms in the structure
        
        globals:
          NORM_COORDINATION: dictionary of all of atom types and their expected coordination.

        args: 
          -none-
        
        returns:
          over_cn (Dict[str, List[int]]) -- dict containing all of the atom typse that have overcoordinated versions of themselves
          idx repeats for each extra coordination. ie. 6-coordinate Si will have its index appear twice.
          undr_cn (Dict[str, List[int]]) -- dict containing all of the atom types that have undercoordinated versions of themselves
          idx repeats for each missing coordination. ie. 0-coordinate O will have its index appear twice.
        """
        all_cn = self.give_all_cn(CN_IdxAtom=True)
        over_cn = {}
        undr_cn = {}

        for atom_type in all_cn.keys():
            assert atom_type in NORM_COORDINATION, f"Trying to handle unknow atom type, {atom_type}"

            for cn, all_idx in all_cn[atom_type].items():
                if cn > NORM_COORDINATION[atom_type]:
                    if atom_type not in over_cn.keys():
                        over_cn[atom_type] = []
                    for _ in range(cn - NORM_COORDINATION[atom_type]):
                        for idx in all_idx:
                            over_cn[atom_type].append(idx)

                elif cn < NORM_COORDINATION[atom_type]:
                    if atom_type not in undr_cn.keys():
                        undr_cn[atom_type] = []
                    for _ in range(NORM_COORDINATION[atom_type] - cn):
                        for idx in all_idx:
                            undr_cn[atom_type].append(idx)

        return over_cn, undr_cn

    def possible_atoms(self, array_to_check):
        """
        
        """
        counts = np.bincount(self.chosen_idx)
        for element, count in enumerate(counts):
            if count > 0:
                indices = np.where(array_to_check == element)[0][:count]
                array_to_check = np.delete(array_to_check, indices)
        return array_to_check

    def _make_point(self, idx_ref_atom: int, bond_length: float, tolerance: float) -> np.ndarray:
        """
        Generates a cartesian point at a given distance away from a reference point and checks if any other
        atoms are within a given distance away from it.

        args:
          idx_ref_atom (int): Reference atom around which the point will be generated
          bond_length (float): Desired bond length for new atom
          tolerance (float): Minimum distance the new points must be away from all other atoms
        
        returns:
          np.ndarray: array of the cartesian points
        """
        idx_ref_atom = idx_ref_atom % len(self.atoms)
        chosen_xyz = self.xyz[idx_ref_atom]
        all_idx = np.arange(len(self.atoms))
        mic_xyz = self.mic(idx_ref_atom, all_idx)
        mic_xyz_abs = np.abs(mic_xyz)

        mask_chosen_atom = all_idx != idx_ref_atom
        mask_x = mic_xyz_abs[:, 0] <= 5.0
        mask_y = mic_xyz_abs[:, 1] <= 5.0
        mask_z = mic_xyz_abs[:, 2] <= 5.0

        # Add the values of chosen_xyz to move the mic points back to their "origional"/"intended" positions.
        close_xyz = mic_xyz[np.logical_and.reduce([mask_chosen_atom, mask_x, mask_y, mask_z])] + chosen_xyz

        range_theta = np.linspace(0, 2 * np.pi, 250)
        range_phi = np.linspace(0, 2 * np.pi, 250)
        mesh_theta, mesh_phi = np.meshgrid(range_theta, range_phi, indexing='xy')

        r = bond_length
        x = r * np.sin(mesh_phi) * np.cos(mesh_theta)
        y = r * np.sin(mesh_phi) * np.sin(mesh_theta)
        z = r * np.cos(mesh_phi)
        chosen_xyz = chosen_xyz[:, np.newaxis, np.newaxis]
        new_points = np.array((chosen_xyz + np.array([x, y, z])))

        dists = np.linalg.norm(close_xyz[:, :, np.newaxis, np.newaxis] - new_points, axis=1)
        dist_points = np.min(dists, axis=0)
        mask = np.where(dist_points >= tolerance, True, False)
        possible_points = new_points.transpose(1, 2, 0)[mask]

        # For the edge case that there are two points which are have the same distance from the bulk.
        if len(possible_points.flatten()) == 3:
            # only one possible point here
            new_point = possible_points
        else:
            # have no tolerable points so choose the one farthestaway from everything
            mask = np.where(dist_points == np.max(dist_points), True, False)

            # make sure that there is only one maximum
            if np.sum(mask) > 1:
                new_point = new_points.transpose(1, 2, 0)[mask]
                new_point = new_point[0]
            else:
                new_point = new_points.transpose(1, 2, 0)[mask]
            # print(f"point taken lower than tolerance by {tolerance - np.max(dist_points)} angstrom")

        return new_point.flatten()

    def place_new_atom(self, idx_ref_atom: int, new_atom_type: str, bond_length: float, tolerance:float=1.5) -> None:
        """
        Places a new instance of a given atom on the structure

        args:
          idx_ref_atom (int): Reference atom around which the point will be generated
          new_atom_type (str): The chemical symbol of the new atom
          bond_length (float): Desired bond length for new atom
          tolerance (float): Minimum distance the new points must be away from all other atoms
        """
        assert idx_ref_atom <= len(self.atoms), f"Trying to access an atom that does not exist. {idx_ref_atom=}, {len(self.atoms)=}"
        new_point = self._make_point(idx_ref_atom, bond_length, tolerance)
        self.add_new_atom(new_atom_type, new_point)

    def move_atom(self, idx_move: int, idx_ref: int, new_dist: float) -> None:
        """
        Move an already existing atom in space to a given distance away from another atom

        args:
         idx_move (int): Index of the atom which will be moved
         idx_ref (int): Index of the atom which will be used to reference distance
         new_dist (float): The new distance between the atoms
        """
        atom_move = self.atoms[idx_move]
        xyz_ref = self.xyz[idx_ref]

        direction = self.mic(idx_move, idx_ref)
        unit_direction = direction / np.linalg.norm(direction)
        new_point = xyz_ref - new_dist * unit_direction
        self.replace_atom(idx_move, atom_move, new_point) 

    def break_bond(self, idx_atom_move: int, idx_atom_ref: int) -> None:
        """
        Method to break the bond between two atom through rotating the position of one.

        args:
          idx_atom_move (int): Index of the atom that will be moved.
          idx_atom_ref (int): Atom to keep fixed as the reference point
        """
        new_point = self._make_point(idx_atom_ref, SI_O_BONDLENGTH, 2.01)
        self.replace_atom(idx_atom_move, self.atoms[idx_atom_move], new_point)

    def check_bound_to_under_coord(self, idx_check_atom: int) -> bool:
        """
        Method to check if an atom is bound to an under-coordinated atom

        args: 
          idx_check_atom (int): Index of the atom for which the method will check
        
        return:
          bool: whether the atom has any neighbors which are under-coordinated
        """
        cn_atoms = self.give_all_cn(IdxAtom_CN=True)
        bound_to_under_coord = False
        for idx_neighbor in self.give_nl(idx_check_atom):
            atom_neighbor = self.atoms[idx_neighbor]
            cn_neighbor = cn_atoms[atom_neighbor][idx_neighbor]

            if cn_neighbor < NORM_COORDINATION[atom_neighbor]:
                bound_to_under_coord = True
        return bound_to_under_coord

    def over_pos(self, idx_atom_chosen):
        chosen_atom = self.atoms[idx_atom_chosen]
        chosen_atom_cn = self.give_cn(idx_atom_chosen)
        assert OVER_POS[chosen_atom] is True, f"{chosen_atom} is negativly charged when under-coordinated"
        assert chosen_atom_cn > NORM_COORDINATION[chosen_atom]
        if self.already_saturated is None:
            self.set_already_saturated()

        if chosen_atom_cn <= 3:
            nl = self.give_nl(idx_atom_chosen)

            # filter out the atoms which are over or under-coordinated
            filtered_nl = []
            for idx_neighbor in nl:
                if self.give_cn(idx_neighbor) == NORM_COORDINATION[self.atoms[idx_neighbor]] and idx_neighbor not in self.already_saturated and not self.check_bound_to_under_coord(idx_neighbor):
                    filtered_nl.append(idx_neighbor)
            if len(filtered_nl) == 0:
                filtered_nl = [idx for idx in nl if self.give_cn(idx) == NORM_COORDINATION[self.atoms[idx]] and not self.check_bound_to_under_coord(idx)]
            if len(filtered_nl) == 0:
                filtered_nl = [idx for idx in nl]
            filtered_nl = np.array(filtered_nl, dtype=np.int32)

            if len(filtered_nl) == 1:
                idx_atom_ref = filtered_nl[0]
            else:
                nl_mic = self.mic(idx_atom_chosen, filtered_nl)
                dists = np.linalg.norm(nl_mic, axis=1)
                idx_atom_ref = filtered_nl[dists == np.max(dists)][0]

            self.move_atom(idx_atom_ref, idx_atom_chosen, new_dist = 2.75) # Go back and rename all of the variables to make this read logically
            self.place_new_atom(idx_atom_ref, "O", SI_O_BONDLENGTH, 2.01)
            self.place_new_atom(-1, "H", O_H_BONDLENGTH, 1.5)

        else:
            nl = self.give_nl(idx_atom_chosen)

            # filter out the atoms which are over or under-coordinated
            filtered_nl = []
            for idx_neighbor in nl:
                if self.give_cn(idx_neighbor) == NORM_COORDINATION[self.atoms[idx_neighbor]]:
                    filtered_nl.append(idx_neighbor)
            if len(filtered_nl) == 0:
                filtered_nl = [idx for idx in nl]
            filtered_nl = np.array(filtered_nl, dtype=np.int32)

            if len(filtered_nl) == 1:
                idx_atom_move = filtered_nl[0]
            else:
                xyz_mic = self.mic(idx_atom_chosen, nl)
                dists = np.linalg.norm(xyz_mic, axis=1)
                idx_atom_move = nl[dists == np.max(dists)][0]
            idx_atom_ref = random.choice(self.give_nl(idx_atom_move)[self.give_nl(idx_atom_move) != idx_atom_chosen])

            self.break_bond(idx_atom_move, idx_atom_ref)
            self.place_new_atom(idx_atom_move, "O", SI_O_BONDLENGTH, 2.01)
            self.place_new_atom(-1, "H", O_H_BONDLENGTH, 1.5)

    def over_neg(self, idx_atom_chosen):
        chosen_atom = self.atoms[idx_atom_chosen]
        chosen_atom_cn = self.give_cn(idx_atom_chosen)
        assert OVER_POS[chosen_atom] is False, f"{chosen_atom} is positivly charged when under-coordinated"
        assert chosen_atom_cn > NORM_COORDINATION[chosen_atom]
        if self.already_saturated is None:
            self.set_already_saturated()

        if chosen_atom_cn <= 3:
            nl = self.give_nl(idx_atom_chosen)

            # filter out the atoms which are over or under-coordinated
            filtered_nl = []
            for idx_neighbor in nl:
                if self.give_cn(idx_neighbor) == NORM_COORDINATION[self.atoms[idx_neighbor]] and idx_neighbor not in self.already_saturated and not self.check_bound_to_under_coord(idx_neighbor):
                    filtered_nl.append(idx_neighbor)
            if len(filtered_nl) == 0:
                filtered_nl = [idx for idx in nl if self.give_cn(idx) == NORM_COORDINATION[self.atoms[idx]] and not self.check_bound_to_under_coord(idx)]
            if len(filtered_nl) == 0:
                filtered_nl = [idx for idx in nl]
            filtered_nl = np.array(filtered_nl, dtype=np.int32)

            if len(filtered_nl) == 1:
                idx_atom_ref = filtered_nl[0]
            else:
                nl_mic = self.mic(idx_atom_chosen, filtered_nl)
                dists = np.linalg.norm(nl_mic, axis=1)
                idx_atom_ref = filtered_nl[dists == np.max(dists)][0]

            self.move_atom(idx_atom_ref, idx_atom_chosen, new_dist = 2.75) # Go back and rename all of the variables to make this read logically
            self.place_new_atom(idx_atom_ref, "H", O_H_BONDLENGTH, 1.5)

        else:
            nl = self.give_nl(idx_atom_chosen)

            # filter out the atoms which are over or under-coordinated
            filtered_nl = []
            for idx_neighbor in nl:
                if self.give_cn(idx_neighbor) == NORM_COORDINATION[self.atoms[idx_neighbor]]:
                    filtered_nl.append(idx_neighbor)
            if len(filtered_nl) == 0:
                filtered_nl = [idx for idx in nl]
            filtered_nl = np.array(filtered_nl, dtype=np.int32)

            if len(filtered_nl) == 1:
                idx_atom_move = filtered_nl[0]
            else:
                xyz_mic = self.mic(idx_atom_chosen, nl)
                dists = np.linalg.norm(xyz_mic, axis=1)
                idx_atom_move = nl[dists == np.max(dists)][0]
            idx_atom_ref = random.choice(self.give_nl(idx_atom_move)[self.give_nl(idx_atom_move) != idx_atom_chosen])

            self.break_bond(idx_atom_move, idx_atom_ref)
            self.place_new_atom(idx_atom_move, "H", O_H_BONDLENGTH, 1.5)

    def undr_pos(self, idx_atom_chosen: int) -> None:
        """
        Method to saturate an under-coordinated atom which has a postive formal charge with a -OH fragment

        args:
          idx_atom_chosen (int): atom to bind the -OH fragment to

        raises:
          AssertError: If the atom indexed is not under-coordinated
          AssertError: If the atom indexed is not considered negatively charged when under-coordinated
        """
        chosen_atom = self.atoms[idx_atom_chosen]
        chosen_atom_cn = self.give_cn(idx_atom_chosen)
        assert not OVER_POS[chosen_atom] is True, f"{chosen_atom} is negativly charged when under-coordinated"
        assert chosen_atom_cn < NORM_COORDINATION[chosen_atom]

        self.place_new_atom(idx_atom_chosen, "O", SI_O_BONDLENGTH, tolerance=2.0)
        self.place_new_atom(-1, "H", O_H_BONDLENGTH, tolerance=1.5)

    def undr_neg(self, idx_atom_chosen: int) -> None:
        """
        Method to saturate an under-coordinated atom which has a negative formal charge with a -H fragment

        args:
          idx_atom_chosen (int): atom to bind the -H fragment to

        raises:
          AssertError: If the atom indexed is not under-coordinated
          AssertError: If the atom indexed is not considered positively charged when under-coordinated
        """
        chosen_atom = self.atoms[idx_atom_chosen]
        chosen_atom_cn = self.give_cn(idx_atom_chosen)
        assert not OVER_POS[chosen_atom] is False, f"{chosen_atom} is positivly charged when under-coordinated"
        assert chosen_atom_cn < NORM_COORDINATION[chosen_atom]

        self.place_new_atom(idx_atom_chosen, "H", O_H_BONDLENGTH, tolerance=1.5)

    def set_already_saturated(self) -> None:
        """
        Method to set the "already_saturated" parameter of the child to avoid adding excessive fragment to an atom.
        Done through checking the first and second closest neighbors of each H atom in the structure.
        """
        idx_H_atoms = np.where(self.atoms == "H")[0]
        output = np.empty(0, dtype=int)

        for idx_H in idx_H_atoms:
            idx_neighbors_H = self.give_nl(idx_H)
            for idx_first_closest in idx_neighbors_H:
                output = np.append(output, idx_first_closest)
                for idx_second_closest in self.give_nl(idx_first_closest):
                    output = np.append(output, idx_second_closest)
        self.already_saturated = output[self.atoms[output] != "H"]

    def ring_paths(self, size: int):
        neighbors = {}
        all_paths = []

        for idx_atom, _ in enumerate(self.atoms):
            neighbors[idx_atom] = self.give_nl(idx_atom)

        def find_paths(neighbors, current_node, depth, current_path, all_paths):
            if depth == 0:
                all_paths.append(current_path)
                return

            for neighbor in neighbors.get(current_node, []):
                find_paths(neighbors, neighbor, depth - 1, current_path + [neighbor], all_paths)

        def get_all_paths(neighbors, start_node, max_depth):
            all_paths = []
            for neighbor in neighbors.get(start_node, []):
                find_paths(neighbors, neighbor, max_depth - 1, [neighbor], all_paths)
            return all_paths

        def filter_paths(all_paths_from_node, start_node):
            filter_1 = [path for path in all_paths_from_node if path[-1] == start_node]
            filter_2 = [path for path in filter_1 if len(set(path))==len(path)]
            return filter_2

        for node in neighbors.keys():
            all_paths_from_node = get_all_paths(neighbors, node, size)
            for path in filter_paths(all_paths_from_node, node):
                all_paths.append(path)

        all_paths_sorted = [sorted(path) for path in all_paths]
        all_unique_paths = np.array(list(map(list, set(map(tuple, all_paths_sorted)))), dtype=int)
        return all_unique_paths

    def idx_in_rings(self, size: int = None, paths = None):
        if size is None and paths is None:
            raise AssertionError("Need to choose one option. Either set the size of the ring or provide paths.")

        if not size is None:
            paths = self.ring_paths(size)
        if not paths is None:
            paths = paths

        flattened = paths.flatten()
        all_Si_in_ring = flattened[self.atoms[flattened] == "Si"]
        all_O_in_ring = flattened[self.atoms[flattened] == "O"]
        return all_Si_in_ring, all_O_in_ring

    def check_silanol_density(self, side: str): 
        def check_top():      
            # self.probe_surface(PROBE_WITDH, PROBE_INTERVAL, from_top=True)
            self.probe_surface_rays(max_iter=7, from_top=True)
            surface_Si_top = [idx for idx in self.idx_surface_atoms_top if self.atoms[idx] == "Si"]
            surface_O_top = [idx for idx in self.idx_surface_atoms_top if self.atoms[idx] == "O"]

            for idx_Si in surface_Si_top:
                for idx_neighbor in self.give_nl(idx_Si):
                    surface_O_top.append(idx_neighbor)

            idx_H = set()
            surface_O_top = np.unique(surface_O_top)
            for idx_O in surface_O_top:
                for idx_neighbor in self.give_nl(idx_O):
                    idx_H.add(idx_neighbor)
            idx_H = [idx for idx in idx_H if self.atoms[idx]=="H"]
            return np.array(idx_H, dtype=int)

        def check_bottom():
            # self.probe_surface(PROBE_WITDH, PROBE_INTERVAL, from_top=False)
            self.probe_surface_rays(max_iter=7, from_top=False)
            surface_Si_bot = [idx for idx in self.idx_surface_atoms_bot if self.atoms[idx] == "Si"]
            surface_O_bot = [idx for idx in self.idx_surface_atoms_bot if self.atoms[idx] == "O"]
            for idx_Si in surface_Si_bot:
                for idx_neighbor in self.give_nl(idx_Si):
                    surface_O_bot.append(idx_neighbor)

            idx_H = set()
            surface_O_bot = np.unique(surface_O_bot)
            for idx_O in surface_O_bot:
                for idx_neighbor in self.give_nl(idx_O):
                    idx_H.add(idx_neighbor)
            idx_H = [idx for idx in idx_H if self.atoms[idx]=="H"]
            return np.array(idx_H, dtype=int)

        match side.lower():
            case "top":
                idx_H = check_top()
                silanol_density = len(idx_H) / (
                    (0.1 * self.cell_lengths[0]) * (0.1 * self.cell_lengths[1])
                )

            case "bottom":
                idx_H = check_bottom()
                silanol_density = len(idx_H) / ((0.1*self.cell_lengths[0]) * (0.1*self.cell_lengths[1]))

            case "both":
                idx_H_top = check_bottom()
                idx_H_bottom = check_bottom()

                idx_H = np.unique(np.append(idx_H_top, idx_H_bottom))
                silanol_density = len(idx_H) / ((0.1*self.cell_lengths[0]) * (0.1*self.cell_lengths[1]))

            case "total":
                num_H = [idx for idx, atom in enumerate(self.atoms) if atom == "H"]
                self.get_surface_area([173,173,400], RADII)
                silanol_density = len(num_H) / (0.01*self.total_surface_area)
            case _:
                raise ValueError("invalid side. Either: top, bottom, both")

        return silanol_density


class Generation:
    num_generation = 0
    def __init__(self, max_children):
        self.parents = {}
        self.number_parents = 0

        self.children = {}
        self.max_children = max_children
        self.number_children = 0

        Generation.num_generation += 1

    def add_parent(self, new_parent: Child):
        """
        Adds a new parent to the generation. Into a dictionary of all parents.
        Will take it an name it according to the order in which is was added.
        First parent is 'parent_1' and second will be 'parent_2'.

        args:
          new_parent (Child): The new parent for the generation

        raises:
          AssertionError: if the new_parent is not of type child.
        """
        assert isinstance(new_parent, Child), f"Not putting in the correct type as new parent, {type(new_parent)}"
        self.number_parents += 1
        parent_name = f"parent_{self.number_parents}"
        self.parents[parent_name] = new_parent

    def get_parent(self, parent_name: str) -> Child:
        """
        Getter method for extracting a given parent from the parent distionary.

        args:
          parent_name (str): The name of the parent structure to return
        """
        return self.parents[parent_name]

    def get_child(self, child_name: str) -> Child:
        """
        Getter method for the getting of a given child from the generation.

        args:
          child_name (str): Name of the child to get
        """
        return self.children[child_name]

    def choose_parent(self) -> Child:
        """
        Method to choose a parent structure to make a new child structure. Will do this based on the Boltzmann-weighted
        electronic energy of the structures. There is no use of a reference state so it will only truly work for structures
        of identical compostition/number of electrons. Good for this case # TODO:Add the option of a reference state.

        return:
          A structure from the current 'parents' dictionary.
        """
        assert self.number_parents != 0, "Choosing parents before there are any"

        if self.number_parents == 1:
            weightings = [1]
        else:
            energies = []
            for _, parent in self.parents.items():
                if parent.energy is None:
                    print("No energy set for given parent")
                    energies.append(0)
                else:
                    energies.append(parent.energy)
            weightings = np.exp(-np.array(energies).astype(float))
            weightings /= np.sum(weightings)

        # parents = self.parents.items()
        chosen_parent = np.random.choice(list(self.parents), p=weightings)
        return self.get_parent(chosen_parent)

    def new_parents(self, number_new_parents: int) -> List[Child]:
        """
        Method for choosing the new parents for the next generation. Does this based on calculated energies of the child structures
        and takes the lowest energy structures according to the value passed for `number_new_parents`.
        If a child does not have a calculated energy then it will be assumed to be 0.

        args:
          number_new_parents (int): The number of structures to return.

        returns:
          A list of the parent structures for the next generation.
        """
        new_parents = []

        names = []
        energies = []
        for child_name, child in self.children.items():
            names.append(child_name)
            if child.energy is None:
                energies.append(0)
            else:
                energies.append(child.energy)

        combined = list(zip(energies, names))
        _, sorted_children = zip(*sorted(combined))
        for child in sorted_children[:number_new_parents]:
            new_parents.append(self.get_child(child))
        return new_parents

    def add_child(self, xyz_structure: Child):
        """
        Structure to add a new child structure to the class. Will name the strucutre based on the current
        number of strucutres already put in and the current number of generations which have been made.

        args:
          xyz_structure (Child): New structure to add.
        """
        self.number_children += 1
        child_name = f"child_gen_{Generation.num_generation}_num_{self.number_children}"
        self.children[child_name] = xyz_structure

    def dump_children(self):
        """
        Method to dump all of the xyz structures of the children within the current generation.
        Will put their energy into the comment line of the xyz file.
        """
        with open(f"all_children_gen{Generation.num_generation}.xyz", "w") as f:
            for child in self.children.values():
                if child.energy is None:
                    child.energy = "not optimized"

                atoms, coordinates = (child.atoms, child.xyz)
                f.write(f"{len(atoms)}\n")
                f.write(f"{child.energy}\n")
                for atom, coords in zip(atoms, coordinates):
                    f.write(f"{atom}")
                    for coord in coords:
                        f.write(f"\t{coord}")
                    f.write("\n")

    def make_gen_sat(self, max_steps: int=50):
        """
        Method of making a new generation of structures. Will make children until it had made as many as defined as the
        maximum number of children per generation.
        Chooses a parent, determines if a under- or over-coordinated atom will be used to fulfill saturation. Under-coordinated
        atoms get priority. Then chooses the indicies of the atoms that will be used for saturation and add H2O according to
        internal logic. Finishes with optimizing the geometry of the child.

        args:
          max_steps (int): maximum number of geometry optmization steps
        """
        for _ in range(self.max_children):
            # choose the parent and initialize new child
            chosen_parent: Child = self.choose_parent()
            atoms, xyz = chosen_parent.atoms, chosen_parent.xyz

            new_child = Child(atoms, xyz)
            new_child.set_cl(chosen_parent.cell_lengths)
            new_child.set_cut_offs(chosen_parent.cut_offs)
            new_child.set_already_saturated()

            # find which atoms are over an under coordinated
            over_cn, undr_cn = new_child.check_cn()
            remaining_undr_pos = 0
            remaining_undr_neg = 0
            # find if there are remaining atoms of negtive and positive charge which are under coordinated
            for atom_type, atom_idx in undr_cn.items():
                if len(atom_idx) != 0:
                    if not OVER_POS[atom_type]:
                        remaining_undr_pos += 1
                    else:
                        remaining_undr_neg += 1

            # assign the label of over- or under-coordinated to the postively and negatively charged atom
            pos_atom = None
            match remaining_undr_pos:
                case 0:
                    pos_atom = "over"
                case _:
                    pos_atom = "undr"

            neg_atom = None
            match remaining_undr_neg:
                case 0:
                    neg_atom = "over"
                case _:
                    neg_atom = "undr"

            # chosen the correct over/under-coordinated atom
            pos_atom_idx, neg_atom_idx = self.choose_atoms(pos_atom, neg_atom, over_cn, undr_cn, new_child.already_saturated)

            self.saturate(new_child, pos_atom_idx, pos_atom, neg_atom_idx, neg_atom)
            new_child.update_atom_types()
            new_child.wrap_coordinates()
            self.add_child(new_child)
            self.dump_children()

        if pos_atom == "undr" and neg_atom == "undr":
            # optimizing it less to save time
            self.opt_children(int(np.ceil(0.5*max_steps)))
        else:
            self.opt_children(max_steps)
        print(f"finished generation {self.num_generation}")

    @staticmethod
    def saturate(new_child: Child, pos_atom_idx: int, pos_atom: str, neg_atom_idx: int, neg_atom: str):
        """
        Static helper method for saturating new children with control flow for whether either is over or under coordinated.
        'undr' for under-coordinated and 'over' for over-coordinated atom.

        args:
          pos_atom_idx (int): Indexing of the positive atom to saturate
          pos_atom (string): Whether the positive atom is over- or under-coordinated
          neg_atom_idx (int): Indexing of the negative atom to saturate
          neg_atom (string): Whether the negative atom is over- or under-coordinated

        raises:
          ValueError: The program is not supposed to two over-coordinated atoms.
        """
        print(f'positive atom: {pos_atom};\tneagtive atom:{neg_atom}')
        match (pos_atom, neg_atom):
            case ("undr", "undr"):
                new_child.undr_pos(pos_atom_idx)
                new_child.undr_neg(neg_atom_idx)

            case("undr", "over"):
                new_child.undr_pos(pos_atom_idx)
                new_child.over_neg(neg_atom_idx)

            case("over", "undr"):
                new_child.over_pos(pos_atom_idx)
                new_child.undr_neg(neg_atom_idx)

            case("over", "over"):
                raise ValueError

    @staticmethod
    def split_pos_neg(over_cn: dict, undr_cn: dict) -> tuple[list[int], list[int]]:
        """
        Static method that sorts the indices of all mis-coordinated atoms into formally positive and formally
        negative atoms.

        An atom type flagged in `OVER_POS` counts as positive when it is over-coordinated and as negative when it is
        under-coordinated; for an atom type that is not flagged the assignment is reversed.

        globals:
          OVER_POS: mapping atom type -> truthy if the over-coordinated form of that type is treated as positive.

        args:
          over_cn: dictionary (atom type -> list of indices) of the over-coordinated atoms.
          undr_cn: dictionary (atom type -> list of indices) of the under-coordinated atoms.

        returns:
          Two lists: the indices of all positive atoms and the indices of all negative atoms. Within each list the
          over-coordinated entries come first, followed by the under-coordinated ones.
        """
        positive: list[int] = []
        negative: list[int] = []

        # Walk the over-coordinated table first, then the under-coordinated one, so the output ordering is stable.
        for coordination_table, is_over in ((over_cn, True), (undr_cn, False)):
            for element, indices in coordination_table.items():
                # Positive when "over and flagged" or "under and not flagged".
                is_positive = bool(OVER_POS[element]) == is_over
                bucket = positive if is_positive else negative
                bucket.extend(indices)

        return positive, negative

    @staticmethod
    def choose_atoms(
        pos_atom: str, 
        neg_atom: str, 
        over_cn: Dict[str, List[int]], 
        undr_cn: Dict[str, List[int]], 
        already_saturated: List[int]
        ) -> Tuple[int, int]:
        """
        Static method to choose the specific index of the atom to saturate depending on if we will saturate an over- or under-coordinated atoms
        for the 'formally positive' and 'formally negative' atoms. If the flag for no doubles `NO_DOUBLE_SATURATION` is set to True it will
        filter out the indicies which have already been saturated once from the list.

        args:
          pos_atom (str): flag for it the positive atom will have to be over- or under-coordinated. Should be either 'undr' or 'over'
          neg_atom (str): flag for it the negative atom will have to be over- or under-coordinated. Should be either 'undr' or 'over'
          over_cn (Dict[str, List[int]]): Dictionary of the indices of all over-coordinated atoms within the structure organized as
                                          key-atom type, value-list of indicies of over-coordinated atoms of that type.
          undr_cn (Dict[str, List[int]]): Dictionary of the indices of all under-coordinated atoms within the structure organized as
                                          key-atom type, value-list of indicies of under-coordinated atoms of that type.
          already_saturated (List[int]): List of all the indicies of atoms which have already been saturated.

        returns:
          A tuple of the index of the positive and negative atoms

        raises:
          ValueError: if pos_atom or neg_atom have an incorrect string for the match function.
        """
        pos_atom_idx = None
        neg_atom_idx = None
        match pos_atom:
            case "undr":
                possible_choice = []
                for atom_type, atom_idx in undr_cn.items():
                    possible_choice.append(
                        [idx for idx in atom_idx if not OVER_POS[atom_type]]
                    )
                flattened = [item for sub_list in possible_choice for item in sub_list]
                pos_atom_idx = random.choice(flattened)

            case "over":
                possible_choice = []
                for atom_type, atom_idx in over_cn.items():
                    possible_choice.append([idx for idx in atom_idx if OVER_POS[atom_type]])
                flattened = [item for sub_list in possible_choice for item in sub_list]

                if NO_DOUBLE_SATURATION:
                    choices = [idx for idx in flattened if idx not in already_saturated]
                    if len(choices) == 0:
                        choices == flattened
                else:
                    choices = flattened
                pos_atom_idx = random.choice(choices)
            case _:
                raise ValueError("Something went wrong")

        match neg_atom:
            case "undr":
                possible_choice = []
                for atom_type, atom_idx in undr_cn.items():
                    possible_choice.append(
                        [idx for idx in atom_idx if OVER_POS[atom_type]]
                    )
                flattened = [item for sub_list in possible_choice for item in sub_list]
                neg_atom_idx = random.choice(flattened)

            case "over":
                possible_choice = []
                for atom_type, atom_idx in over_cn.items():
                    possible_choice.append([idx for idx in atom_idx if not OVER_POS[atom_type]])
                flattened = [item for sub_list in possible_choice for item in sub_list]

                if NO_DOUBLE_SATURATION:
                    choices = [idx for idx in flattened if idx not in already_saturated]
                    if len(choices) == 0:
                        choices == flattened
                else:
                    choices = flattened
                neg_atom_idx = random.choice(choices)
            case _:
                raise ValueError("Something went wrong")

        return pos_atom_idx, neg_atom_idx

    def make_gen_break_bonds(self, side: str, max_steps=50, ring_weighting = None):
        """
        NOT RECOMMENDED TO USE.

        Creates `self.max_children` new children. For each child a parent is chosen, copied, a 4-coordinate Si atom
        (`stationary_Si`), one of its neighbours (`chosen_O`) and a second neighbour of that atom (`ref_Si`) are picked
        at random, `break_bond(chosen_O, ref_Si)` is called and three new atoms (O, H, H) are placed. Afterwards all
        children are dumped and optimized.

        args:
          side (str): which atoms are eligible as `stationary_Si`. One of 'top', 'bottom', 'both' (surface atoms found by
                      `probe_surface_rays`) or 'total' (every Si atom). Case-insensitive.
          max_steps (int): forwarded to `opt_children`.
          ring_weighting: NOTE(review): this argument is never read; it is unconditionally overwritten inside the loop.

        raises:
          ValueError: if `side` is not one of the accepted values.

        NOTE(review): if no eligible Si atom is left, `np.random.choice` is called on an empty list -- presumably this
        raises; confirm whether callers guard against it.
        """
        for _ in range(self.max_children):
            # choose the parent and initialize new child
            chosen_parent: Child = self.choose_parent()
            atoms, xyz = chosen_parent.atoms, chosen_parent.xyz

            new_child = Child(atoms, xyz)
            new_child.set_cl(chosen_parent.cell_lengths)
            new_child.set_cut_offs(chosen_parent.cut_offs)
            new_child.set_already_saturated()
            # coordination numbers, indexed below as all_cn[atom_type][atom_idx]
            all_cn = new_child.give_all_cn(IdxAtom_CN=True)

            # collect the candidate Si atoms for the requested side; surface atoms are probed lazily on the parent
            match side.lower():
                case "top":
                    if chosen_parent.idx_surface_atoms_top is None:
                        chosen_parent.probe_surface_rays(max_iter=10, from_top=True)
                    new_child.idx_surface_atoms_top = (
                        chosen_parent.idx_surface_atoms_top.copy()
                    )
                    possible_Si = [
                        idx
                        for idx in new_child.idx_surface_atoms_top
                        if new_child.atoms[idx] == "Si"
                    ]

                case "bottom":
                    if chosen_parent.idx_surface_atoms_bot is None:
                        chosen_parent.probe_surface_rays(max_iter = 10, from_top=False)
                    new_child.idx_surface_atoms_bot = chosen_parent.idx_surface_atoms_bot.copy()
                    possible_Si = [idx for idx in new_child.idx_surface_atoms_bot if new_child.atoms[idx] == "Si"]

                case "both":
                    if chosen_parent.idx_surface_atoms_top is None:
                        chosen_parent.probe_surface_rays(max_iter = 10, from_top = True)
                    new_child.idx_surface_atoms_top = chosen_parent.idx_surface_atoms_top.copy()
                    if chosen_parent.idx_surface_atoms_bot is None:
                        chosen_parent.probe_surface_rays(max_iter = 10, from_top=False)
                    new_child.idx_surface_atoms_bot = chosen_parent.idx_surface_atoms_bot.copy()
                    possible_Si = [idx for idx in set(new_child.idx_surface_atoms_top).union(set(new_child.idx_surface_atoms_bot)) if new_child.atoms[idx] == "Si"]

                case "total":
                    # NOTE(review): reads self.atoms, not new_child.atoms as the other branches do -- confirm intended
                    possible_Si = [idx for idx, atom in enumerate(self.atoms) if atom == "Si"]

                case _:
                    # NOTE(review): the message does not mention the accepted value 'total'
                    raise ValueError("invalid side. Either: top, bottom, both")
            # keep only Si atoms that were not saturated before and that are exactly 4-coordinate
            possible_Si = [idx for idx in possible_Si if idx not in new_child.already_saturated]
            possible_Si = [idx for idx in possible_Si if all_cn["Si"][idx] == 4]

            idx_Si_rings, idx_O_rings = new_child.idx_in_rings(size=4)
            possible_Si_ring = [idx for idx in possible_Si if idx in idx_Si_rings]

            # Probability of picking from the ring candidates: 1 - exp(-0.2 * n).
            # This gives ~0.33 for n = 2 and crosses 0.5 between n = 3 and n = 4.
            # NOTE(review): this overwrites the `ring_weighting` argument of the method.
            ring_weighting = 1 - np.exp(-0.2 * len(possible_Si_ring))
            chosing_ring = np.random.choice([True, False], p=[ring_weighting, 1-ring_weighting])

            if chosing_ring and len(possible_Si_ring) != 0:
                # ring branch: up to 16 attempts (iter = 0..15) to find a usable (stationary_Si, chosen_O, ref_Si) triple
                iter = 0
                ref_Si = None
                while ref_Si is None and iter <= 15:
                    stationary_Si = np.random.choice(possible_Si_ring)
                    # prefer 2-coordinate O neighbours that are part of a ring, else any 2-coordinate O neighbour
                    chosen_O = [idx for idx in new_child.give_nl(stationary_Si) if idx in idx_O_rings and all_cn["O"][idx]==2]
                    if len(chosen_O) == 0:
                        chosen_O = [idx for idx in new_child.give_nl(stationary_Si) if all_cn["O"][idx]==2]
                    chosen_O = np.random.choice(chosen_O)
                    ref_Si = [idx for idx in new_child.give_nl(chosen_O) if idx != stationary_Si]
                    ref_Si = np.random.choice(ref_Si) #there will always be atleast one element in the list but there can be more. makes next line easier
                    # NOTE(review): `and` binds tighter than `or`, so this reads
                    # `chosen_O saturated or (ref_Si saturated and iter <= 14)`. A saturated chosen_O on the
                    # last attempt therefore leaves ref_Si as None when the loop ends -- confirm intended.
                    if chosen_O in new_child.already_saturated or ref_Si in new_child.already_saturated and iter <= 14:
                        ref_Si = None
                    iter += 1

            else:
                # non-ring branch: same retry scheme, drawing from every candidate Si
                iter = 0
                ref_Si = None
                while ref_Si is None and iter <= 15:
                    stationary_Si = np.random.choice(possible_Si)
                    # prefer 2-coordinate O neighbours, else fall back to every neighbour of the Si
                    chosen_O = [idx for idx in new_child.give_nl(stationary_Si) if all_cn["O"][idx]==2]
                    if len(chosen_O) == 0:
                        chosen_O = [idx for idx in new_child.give_nl(stationary_Si)]
                    chosen_O = np.random.choice(chosen_O)
                    ref_Si = [idx for idx in new_child.give_nl(chosen_O) if idx != stationary_Si]
                    ref_Si = np.random.choice(ref_Si) #there will always be atleast one element in the list but there can be more. makes next line easier
                    # NOTE(review): same operator-precedence caveat as in the ring branch above
                    if chosen_O in new_child.already_saturated or ref_Si in new_child.already_saturated and iter <= 14:
                        ref_Si = None
                    iter += 1

            # break the chosen bond, then place an O on stationary_Si and two H atoms
            new_child.break_bond(chosen_O, ref_Si)
            new_child.place_new_atom(stationary_Si, "O", SI_O_BONDLENGTH, 2.4)
            # presumably index -1 refers to the O atom placed on the previous line -- TODO confirm
            new_child.place_new_atom(-1, "H", O_H_BONDLENGTH, 1.5)
            new_child.place_new_atom(chosen_O, "H", O_H_BONDLENGTH, 1.5)

            # new_child.wrap_coordinates()
            self.add_child(new_child)
        self.dump_children()
        self.opt_children(max_steps)
        ...

    def opt_children(self, max_steps: int) -> None:
        """
        Method to begin optimizing the children in the current generation.

        A single `cp2kInterface` object is created and reused for every child in `self.children`. For each child the
        structure is handed to the interface, a geometry optimization is run and the child's `energy` attribute is set
        from the interface.

        args: 
          max_steps (int): maximum number of optimization steps to take for the geometry optimization.
        """
        cp2k = cp2kInterface("", "../matsci-0-3")
        for idx, child in enumerate(self.children.values()):
            # progress message is 1-based; flushed so it shows up immediately in buffered logs
            print(f"optimizing child {idx+1}", flush=True)
            cp2k.set_xyz_strucutre(child)
            cp2k.geo_opt(max_steps)
            # NOTE(review): attribute is spelled `energ` -- confirm it matches the name exposed by cp2kInterface
            child.energy = cp2k.energ
