from config import *
from ParentSelection import *
from ReproductionStrategy import *
from FitnessEvaluator import *
from Candidate import Candidate
from Visualisation import Visualisation # just for testing, change later to original file

import pandas as pd
from typing import List
from tqdm import tqdm
import logging
logger = logging.getLogger(__name__)

class Population:
    def __init__(self, mutator, SELECTION_STRATEGY, REPRODUCTION_STRATEGY, visualisation: bool = False):
        logger.info("Initialising Population with...")
        self.mutator = mutator
        logger.info("Mutator: {}".format(mutator))
        self.selection_strategy = SELECTION_STRATEGY
        logger.info("Selection Strategy: {}".format(SELECTION_STRATEGY.__name__))
        self.reproduction_strategy = REPRODUCTION_STRATEGY
        logger.info("Reproduction Strategy: {}".format(REPRODUCTION_STRATEGY.__name__))

        self.visualisation = visualisation
        if visualisation:
            self.visualiser = Visualisation()
            # self.visualiser.initialise_plot()
            logger.info("Visualisation: Enabled")
        else:
            logger.info("Visualisation: Disabled")

        self.size = POPULATION_SIZE
        logger.info("Population Size: {}".format(POPULATION_SIZE))
        self.generations = GENERATIONS
        logger.info("Generations: {}".format(GENERATIONS))

        if self.mutator.learning and not RECORDING:
            logger.info("Learning mutator, setting up mutation recording...")
        elif not self.mutator.learning and RECORDING:
            logger.info("Mutator does not support learning, but recording is enabled.")
        elif self.mutator.learning and RECORDING:
            logger.info("Learning mutator and recording enabled, setting up mutation recording...")
        else:
            logger.info("Mutator does not support learning, and recording is disabled.")
        self.history_all = pd.DataFrame() if self.mutator.learning or RECORDING else None

        logger.info("Initialising population...")
        self.candidates = [Candidate(ACTIONS) for _ in range(self.size)]

        fitness_values = make_fitness_from_data(run_model_per_candidate(self.candidates))
        self.population_update_fitness(fitness_values)

        logger.info("Population Initialised!")

    def population_update_fitness(self, fitness_values):
        candidate_dict = {candidate.unique_id: candidate for candidate in self.candidates}
        for candidate_id, fitness_value in fitness_values.items():
            candidate = candidate_dict.get(candidate_id)
            if candidate:
                candidate.update_fitness(fitness_value)
            else:
                logger.warning(f"Candidate with ID {candidate_id} not found.")

    def select_parents(self) -> List[Candidate]:
        '''
        Selects two parents using the provided selection strategy.
        '''
        return self.selection_strategy(self.candidates)
    
    def reproduce(self, parent_selection):
        '''
        Takes the actions of two parents and returns actions for two offspring
        '''
        children = self.reproduction_strategy(parent_selection)
        return children
    
    def record_evolution_step(self, history_batch: pd.DataFrame, generation: int):
        '''
        Add one generation's history to the running record and save it to disk.

        The columns "actions_before" and "actions_after" are wrapped in the
        mandatory begin/end actions before the batch is stored and saved.

        Args:
            history_batch: This generation's history as a DataFrame (or anything
                ``pd.DataFrame`` accepts); must provide the columns
                "actions_before" and "actions_after" holding lists.
            generation: Generation index, forwarded to ``save_file``.
        '''
        # Work on a copy so the caller's frame stays untouched and a repeated call
        # cannot wrap the same action lists in the mandatory actions twice.
        batch = pd.DataFrame(history_batch).copy()

        for column in ("actions_before", "actions_after"):
            batch[column] = batch[column].apply(
                lambda actions: ACTIONS_MANDATORY_BEGIN + actions + ACTIONS_MANDATORY_END
            )

        if self.history_all is None or self.history_all.empty:
            # Nothing recorded yet: the first batch becomes the history.
            self.history_all = batch.reset_index(drop=True)
        else:
            # ignore_index keeps the row labels unique across generations.
            self.history_all = pd.concat([self.history_all, batch], ignore_index=True)

        # Save this generation's batch to its own file.
        self.save_file(batch, generation)

    def evolve_population_once(self, generation: int):
        '''
        Performs one evolutionary step on the population, which consists of
        0. Assemble the history_batch DataFrame to record this evolution process
        1. Candidates are passed to the Mutator to mutate, a mutation is applied and the candidate's actions are updated.
        2. Run the model with the new actions and collect the new data
        3. Learn from the mutation outcomes and save data
        4. Selection and Reproduction
        '''
        # Step 1: Pass candidates to the Mutator to mutate and apply the mutation
        mutation_decisions = [self.mutator.decide_mutation(candidate.actions) for candidate in self.candidates]
        
        # Apply mutation to each candidate
        for candidate, (mutation_type, action, index, second_index, _) in zip(self.candidates, mutation_decisions):
            candidate.apply_mutation_to_candidate(mutation_type, action, index, second_index)
        
        
        # # Step 2: Run the model with the new actions and collect the new data sequentially 
        # fitness_values = make_fitness_from_data(run_model_per_candidate(self.candidates)) # returns dictionary candidate_id : fitness_list
        # self.population_update_fitness(fitness_values)
        # # Initialize an empty DataFrame only if necessary
        # if self.mutator.learning or RECORDING:
        #     # Assuming mutation_decisions is a list of tuples/lists with consistent length
        #     columns_mutation = ["mutation_type", "action", "index", "second_index", "predicted_fitness_change"]
        #     # Construct the DataFrame using a single dictionary of lists
        #     history_batch = pd.DataFrame({
        #         "candidate_id": [candidate.unique_id for candidate in self.candidates],
        #         "actions_before": [candidate.previous_actions.copy() for candidate in self.candidates],
        #         "actions_after": [candidate.actions for candidate in self.candidates],
        #         "fitness_before": [candidate.previous_fitness for candidate in self.candidates],
        #         "fitness_after": [candidate.fitness for candidate in self.candidates],
        #         **{col: [mutation[i] for mutation in mutation_decisions] for i, col in enumerate(columns_mutation)}
        #     }, dtype=object)
        #     history_batch["actual_fitness_change"] = history_batch.apply(
        #         lambda row: np.subtract(row["fitness_after"], row["fitness_before"]), axis=1
        #     )
        #     # Replace NaNs with None
        #     history_batch = history_batch.replace({np.nan: None})

        history_batch = pd.DataFrame()
        if self.mutator.learning or RECORDING:
            history_batch["candidate_id"] = [candidate.unique_id for candidate in self.candidates]
            history_batch["actions_before"] = [candidate.previous_actions.copy() for candidate in self.candidates]
        # Step 2: Run the model with the new actions and collect the new data sequentially 
        model_run_data = run_model_per_candidate(self.candidates)
        self.population_update_fitness(make_fitness_from_data(model_run_data))
        if self.mutator.learning or RECORDING:
            history_batch["actions_after"] = [candidate.actions.copy() for candidate in self.candidates]
            history_batch["mutation_type"] = [mutation[0] for mutation in mutation_decisions]
            history_batch["action"] = [mutation[1] for mutation in mutation_decisions]
            history_batch["index"] = [mutation[2] for mutation in mutation_decisions]
            history_batch["second_index"] = [mutation[3] for mutation in mutation_decisions]
            history_batch["predicted_fitness_change"] = [mutation[4] for mutation in mutation_decisions]
            history_batch["fitness_before"] = [candidate.previous_fitness for candidate in self.candidates]
            history_batch["fitness_after"] = [candidate.fitness for candidate in self.candidates]
            history_batch["actual_fitness_change"] = history_batch.apply(
                lambda row: np.subtract(row["fitness_after"], row["fitness_before"]), axis=1
            )

            history_batch = history_batch.replace({np.nan: None})

        # Step 3: Learning from mutation outcomes and save data
        if self.mutator.learning:
            self.mutator.train_model(history_batch)

        if RECORDING:
            # model_run_data drops column "actions_after"
            model_run_data = model_run_data.drop(columns=["actions_after"])
            # merge model_run_data with history_batch on candidate_id
            history_batch = pd.merge(history_batch, model_run_data, on="candidate_id")

            # print("After merge, these are the columns:\n", history_batch.columns)
            # print(history_batch[TRAINING_COLUMNS])
            self.record_evolution_step(history_batch, generation)

        # Step 6: Selection and Reproduction
        new_generation = []
            # Select parents
        parents_selection = self.select_parents()
            # Reproduce offspring
        new_generation.extend(self.reproduce(parents_selection))
            # Update the population with the new generation
        self.candidates = new_generation[:self.size]

    def evolve_population(self):
        for generation in range(self.generations):
        # for generation in range(self.generations):
            self.evolve_population_once(generation)
            self.update_visualisation(generation, [candidate.fitness for candidate in self.candidates])
        self.finalise_visualisation()
        if RECORDING:
            self.save_file(self.history_all)

    # Update the visualisation after each generation (does nothing when visualisation is disabled)
    def update_visualisation(self, generation, fitness_stats):
        if self.visualisation:
            self.visualiser.update_plot(generation, fitness_stats)

    # Call this method at the end of your evolutionary loop
    def finalise_visualisation(self):
        if self.visualisation:
            self.visualiser.finalise_plot()

    def save_file(self, history_batch, generation=None, file_format="CSV"):
        '''
        Write a history DataFrame to the "history" folder.

        The folder is located in the parent of the directory containing this
        file and is created if it does not exist yet. The file is named
        ``RUN_ID`` for the overall history, or ``RUN_ID`` followed by
        ``_gen<generation>`` when a generation is given.

        Args:
            history_batch: DataFrame to write.
            generation: Generation index used in the file name, or None.
            file_format: "CSV" or "XLSX" (case-insensitive). Any other value
                logs a warning and writes nothing.
        '''
        base_dir = os.path.dirname(__file__)
        overall_folder = os.path.dirname(base_dir)
        history_dir = os.path.join(overall_folder, 'history')
        title = RUN_ID if generation is None else RUN_ID + f"_gen{generation}"

        normalised_format = file_format.upper()
        if normalised_format not in ("CSV", "XLSX"):
            logger.warning("Invalid file format. Please choose either 'XLSX' or 'CSV'.")
            return

        # The target folder may not exist yet; without it the write below fails.
        os.makedirs(history_dir, exist_ok=True)

        if normalised_format == "CSV":
            history_path = os.path.join(history_dir, f'{title}.csv')
            history_batch.to_csv(history_path, index=False, encoding='utf-8', sep=SEPARATOR)
        else:
            history_path = os.path.join(history_dir, f'{title}.xlsx')
            # DataFrame.to_excel takes no `sep` argument (and no `encoding` in
            # pandas >= 2.0); passing them raises TypeError.
            history_batch.to_excel(history_path, index=False)

        # TODO: optional removal of the per-generation files after the final
        # save is not implemented.