from AbstractMutator import Mutator
from config import *

from sklearn.ensemble import RandomForestRegressor
import numpy as np
import pandas as pd
import logging
logger = logging.getLogger(__name__)

class ForestMutator(Mutator):
    """Mutator that picks mutations using a random-forest fitness-change predictor.

    While ``is_trained`` is False, ``decide_mutation`` falls back to the base
    class's ``random_mutation``. Once trained, every candidate mutation produced by
    ``build_mutation_batch_from_actions`` is scored with the forest. The candidate
    with the best predicted fitness change is returned: lowest when
    ``MINIMISATION`` is set, highest otherwise.
    """

    def __init__(self):
        # NOTE(review): assigned before super().__init__(), presumably so the base
        # class can read it during initialisation. Confirm that the base class does
        # not overwrite _learning itself.
        self._learning = True
        super().__init__()  # Initialize the base class
        self.model = RandomForestRegressor(
            n_estimators=N_ESTIMATORS_FOREST,
            min_samples_split=MIN_SAMPLES_SPLIT,
        )
        # Gate checked by decide_mutation. NOTE(review): nothing in this class sets
        # it to True while train_model below is commented out; presumably the base
        # class's training path does. Confirm.
        self.is_trained = False
        self.log_hyperparameters()

    def log_hyperparameters(self):
        """Log the forest hyperparameters read from ``config`` at INFO level."""
        logger.info("n_estimators: %s", N_ESTIMATORS_FOREST)
        logger.info("weighted: %s", WEIGHTED_FOREST)
        logger.info("min_samples_split: %s", MIN_SAMPLES_SPLIT)

    def get_name(self):
        """Return the human-readable name of this mutator."""
        return "Forest Mutator"
    
    # def train_model(self, batch):
    #     '''
    #     Trains the model on a batch of history records. The model is trained on the input features and the output labels.
    #     - X_train: The input features of the history records
    #     - y_train: The output labels of the history records
    #     - weights: The weights of the history records, normalized to sum to 1

    #     Training features for X_train of a history record are:
    #     history_record = {
    #         "initial_actions": [candidate_actions],
    #         "mutation_type": mutation[0],
    #         "action": mutation[1],
    #         "index": mutation[2],
    #         "second_index": mutation[3]
    #     }

    #     Returns:
    #     - None
    #     '''

    #     y_train = batch["actual_fitness_change"].astype(float)

    #     weights = np.abs(batch["predicted_fitness_change"].astype(float) - batch["actual_fitness_change"].astype(float))
    #     weights = weights / weights.sum()

    #     # Drop features that are not used for training
    #     batch = batch[TRAINING_COLUMNS]

    #     batch = self.transform_dataframe(batch)
    #     if WEIGHTED_FOREST:
    #         self.model.fit(batch, y_train, sample_weight=weights)
    #     else:
    #         self.model.fit(batch, y_train)
    #     self.is_trained = True

    def decide_mutation(self, candidate_actions):
        """Choose the next mutation to apply to ``candidate_actions``.

        If the model is untrained, this delegates to ``self.random_mutation``.
        Otherwise it scores every candidate mutation with the forest and returns
        ``[mutation_type, action, index, second_index, predicted_fitness_change]``
        for the candidate with the lowest (``MINIMISATION``) or highest predicted
        fitness change. Ties resolve to the first such candidate.
        """
        if not self.is_trained:
            # Untrained model: default to a random mutation.
            return self.random_mutation(candidate_actions)

        batch = self.build_mutation_batch_from_actions(candidate_actions)
        batch_encoded = self.transform_dataframe(batch)
        predicted_fitness_change = self.model.predict(batch_encoded)

        # Keep the prediction alongside each candidate row (assigned positionally).
        batch["predicted_fitness_change"] = predicted_fitness_change

        # Select the best row by *position*. idxmin/idxmax would return index
        # labels, which iloc interprets as positions, so the wrong row is chosen
        # whenever the batch index is not a default 0..n-1 RangeIndex.
        if MINIMISATION:
            best_position = int(np.argmin(predicted_fitness_change))
        else:
            best_position = int(np.argmax(predicted_fitness_change))
        best_mutation_row = batch.iloc[best_position]

        return [
            best_mutation_row["mutation_type"],
            best_mutation_row["action"],
            best_mutation_row["index"],
            best_mutation_row["second_index"],
            best_mutation_row["predicted_fitness_change"],
        ]
        
def test():
    """Manual smoke test: untrained decision, one training step, trained decision.

    Needs the project modules (``AbstractMutator``, ``config``) and scikit-learn.
    It prints intermediate values instead of asserting on them.
    """
    candidate_actions = ['action1', 'action2', 'action3', 'action4', 'action5']
    mutator = ForestMutator()
    mutation = mutator.decide_mutation(candidate_actions)
    print(mutation)

    # NOTE(review): assumes the untrained (random) mutation has at least five
    # items, in the same order as decide_mutation's trained result. Confirm
    # against Mutator.random_mutation.
    history_record = {
        # Wrapped in an outer list so the whole action list lands in one cell.
        "initial_actions": [candidate_actions],
        "mutation_type": mutation[0],
        "action": mutation[1],
        "index": mutation[2],
        "second_index": mutation[3],
        "predicted_fitness_change": mutation[4],
        "actual_fitness_change": 2,
    }
    print(history_record)

    # Build the one-row frame once. Concatenating it onto an empty DataFrame
    # added nothing, and newer pandas versions may warn about empty concat entries.
    history_batch = pd.DataFrame(history_record)
    print(history_batch)

    # NOTE(review): train_model is commented out in ForestMutator; presumably it
    # is inherited from Mutator. Confirm.
    mutator.train_model(history_batch)
    print("successful batch training")
    print(mutator.decide_mutation(candidate_actions))
    print("Tests completed successfully!")


# if __name__ == "__main__":
#     test()