from config import *
from AbstractMutator import Mutator

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern
import numpy as np
import logging
logger = logging.getLogger(__name__)

class BayesianMutator(Mutator):
    """Mutator that ranks candidate mutations with a Gaussian-process model.

    A ``GaussianProcessRegressor`` with a Matern kernel predicts the fitness
    change of every candidate mutation together with a standard deviation.
    The candidate with the best optimistic confidence bound is selected.
    Until ``is_trained`` is set to ``True`` the mutator falls back to
    ``random_mutation``.

    All hyperparameters (``MATERN_NU``, ``KERNEL_LOWER``, ``KERNEL_UPPER``,
    ``N_RESTARTS_OPTIMIZER``, ``ALPHA``, ``KAPPA``, ``MINIMISATION``) come
    from the ``config`` module.
    """

    def __init__(self):
        """Build the untrained Gaussian-process surrogate model."""
        # Assigned before the base initialiser runs; the ordering is kept in
        # case Mutator.__init__ reads the flag (not visible from this file).
        self._learning = True
        super().__init__()
        kernel = Matern(
            nu=MATERN_NU,
            length_scale_bounds=(KERNEL_LOWER, KERNEL_UPPER),
        )
        self.model = GaussianProcessRegressor(
            kernel=kernel,
            n_restarts_optimizer=N_RESTARTS_OPTIMIZER,
            alpha=ALPHA,
        )
        # NOTE(review): nothing in this class sets this flag to True while
        # train_model below stays commented out; presumably the base class
        # or a caller does so - confirm.
        self.is_trained = False

    def log_hyperparameters(self):
        """Log the configured hyperparameters at INFO level."""
        logger.info("n_restarts_optimizer: %s", N_RESTARTS_OPTIMIZER)
        logger.info("matern_nu: %s", MATERN_NU)
        logger.info("acquisition_function: %s", ACQUISITION_FUNCTION)
        logger.info("acquisition_parameter: %s", ACQUISITION_PARAMETER)
        logger.info("alpha: %s", ALPHA)
        logger.info("kappa: %s", KAPPA)

    def get_name(self):
        """Return the human-readable name of this mutator."""
        return "Bayesian Mutator"

    # def train_model(self, batch):
    #     y_train = batch["actual_fitness_change"].astype(float)
    #     batch = batch[TRAINING_COLUMNS]
    #     X_train = self.transform_dataframe(batch)
    #     self.model.fit(X_train, y_train)
    #     self.is_trained = True

    def decide_mutation(self, candidate_actions):
        """Choose the most promising mutation among ``candidate_actions``.

        Args:
            candidate_actions: Candidates forwarded unchanged to
                ``random_mutation`` or ``build_mutation_batch_from_actions``.

        Returns:
            The result of ``random_mutation(candidate_actions)`` while the
            model is untrained. Otherwise a list
            ``[mutation_type, action, index, second_index, predicted_mean]``
            describing the selected row of the mutation batch, where
            ``predicted_mean`` is the model's mean prediction for that row.
        """
        if not self.is_trained:
            return self.random_mutation(candidate_actions)

        batch = self.build_mutation_batch_from_actions(candidate_actions)
        batch_encoded = self.transform_dataframe(batch)
        predicted_fitness_change, sigma = self.model.predict(
            batch_encoded, return_std=True
        )

        # Optimism in the face of uncertainty: the exploration term must push
        # in the direction of improvement. When minimising that is the lower
        # confidence bound (mean - kappa * sigma); adding the term and taking
        # argmin would instead penalise unexplored candidates.
        exploration_bonus = KAPPA * sigma
        if MINIMISATION:
            decision_metric = predicted_fitness_change - exploration_bonus
            best_mutation_index = np.argmin(decision_metric)
        else:
            decision_metric = predicted_fitness_change + exploration_bonus
            best_mutation_index = np.argmax(decision_metric)
        best_mutation_row = batch.iloc[best_mutation_index]

        return [
            best_mutation_row["mutation_type"],
            best_mutation_row["action"],
            best_mutation_row["index"],
            best_mutation_row["second_index"],
            # Report the plain mean prediction, not the confidence bound.
            predicted_fitness_change[best_mutation_index],
        ]
