from config import *

from abc import ABC, abstractmethod
import random
import numpy as np
from sklearn.preprocessing import StandardScaler, MaxAbsScaler
from sklearn.feature_extraction import FeatureHasher
import pandas as pd
import logging
# Module-level logger used by the Mutator class below for info and warning messages.
logger = logging.getLogger(__name__)

class Mutator(ABC):
    '''
    Abstract base class for mutators that propose mutations of a candidate's action list.

    Subclasses must implement get_name, log_hyperparameters and decide_mutation, and must
    set ``learning`` before calling ``super().__init__()``. When ``learning`` is truthy the
    instance gets a FeatureHasher (categorical features) and a StandardScaler (numerical
    features) used by transform_dataframe; train_model additionally expects the subclass
    to provide ``self.model`` with a ``fit(X, y[, sample_weight=...])`` method.
    '''

    def __init__(self):
        '''Create the feature encoders if this mutator learns from history.'''
        super().__init__()
        # Reads self._learning, which this base class never sets: a subclass that has not
        # set ``learning`` beforehand gets an AttributeError here.
        if self.learning:
            self.hasher = FeatureHasher(n_features=N_FEATURES_HASHING, input_type='pair')
            logger.info("FeatureHasher initialized with n_features=%s", N_FEATURES_HASHING)
            self.scaler = StandardScaler()
            logger.info("StandardScaler initialized")
        else:
            logger.info("Mutator does not require learning.")

    @property
    def learning(self):
        '''Whether this mutator learns from mutation history (stored in ``self._learning``).'''
        return self._learning

    @learning.setter
    def learning(self, value):
        self._learning = value

    def classify_columns(self, df):
        '''
        Split the columns of ``df`` into numerical and categorical ones by dtype.

        Parameters:
        - df: DataFrame whose columns are classified.

        Returns:
        - (numerical_cols, categorical_cols): lists of column labels in ``df`` column order.
          Numeric dtypes are numerical; categorical and object dtypes are categorical.
          Columns of any other dtype (e.g. datetime) are left out of both lists and a
          warning is logged.
        '''
        numerical_cols = []
        categorical_cols = []
        for col, series in df.items():
            dtype = series.dtype
            if pd.api.types.is_numeric_dtype(dtype):
                numerical_cols.append(col)
            # isinstance check replaces pd.api.types.is_categorical_dtype (deprecated in pandas 2.1).
            elif isinstance(dtype, pd.CategoricalDtype) or pd.api.types.is_object_dtype(dtype):
                categorical_cols.append(col)
            else:
                logger.warning('Column %s is neither strictly numerical nor categorical.', col)
        return numerical_cols, categorical_cols

    def transform_dataframe(self, df):
        '''
        Turn a batch of mutation records into a purely numerical feature matrix.

        The ``actions_before`` column (one action list per row) is expanded into the columns
        ``action_1`` .. ``action_{MAXIMUM_LENGTH}`` (padded with None). Categorical columns are
        hashed with ``self.hasher`` (missing values become the string 'None') and numerical
        columns are standardised with ``self.scaler``.

        Parameters:
        - df: DataFrame with an ``actions_before`` column plus further feature columns.

        Returns:
        - DataFrame with the same index as ``df``: hashed categorical features first,
          followed by the scaled numerical features (if any).
        '''
        action_columns = [f'action_{i+1}' for i in range(MAXIMUM_LENGTH)]
        expanded_actions = pd.DataFrame(
            df['actions_before'].apply(self.expand_candidate_actions).tolist(),
            columns=action_columns,
            index=df.index,
        )
        df = pd.concat([expanded_actions, df.drop(columns=['actions_before'])], axis=1)

        numerical_cols, categorical_cols = self.classify_columns(df)
        hashed = self._hash_categorical(df, categorical_cols)

        if not numerical_cols:
            return pd.DataFrame(hashed, index=df.index)

        # NOTE(review): the scaler is re-fitted on every call, including on prediction
        # batches, so numerical features are standardised with the statistics of whatever
        # batch is being transformed rather than those of the training data.
        scaled_numerical = self.scaler.fit_transform(df[numerical_cols])
        return pd.DataFrame(np.concatenate([hashed, scaled_numerical], axis=1), index=df.index)

    def _hash_categorical(self, df, categorical_cols):
        '''Hash ``categorical_cols`` of each row of ``df`` into a dense (n_rows, n_features) array.'''
        if len(df) == 0:
            # FeatureHasher.transform refuses an empty sequence.
            return np.zeros((0, self.hasher.n_features))
        # to_numpy keeps one (possibly empty) row per record even when there are no
        # categorical columns, so every record still gets an (all-zero) hashed row.
        values = df[categorical_cols].to_numpy(dtype=object)
        rows = [
            [(col, 'None' if pd.isnull(value) else value) for col, value in zip(categorical_cols, row)]
            for row in values
        ]
        # One hasher call for the whole batch instead of one per row.
        return self.hasher.transform(rows).toarray()

    @abstractmethod
    def log_hyperparameters(self):
        '''
        Subclasses are required to implement this method
        to log the hyperparameters of the mutator.
        '''
        pass

    @abstractmethod
    def get_name(self):
        '''
        Subclasses are required to implement this method
        to return the unique name of the mutator.
        '''
        pass

    @abstractmethod
    def decide_mutation(self, candidate_actions):
        '''
        Receive a candidate's actions and return a mutation for it.
        This is handled by the individual Mutator subclass.

        Parameters:
        - candidate_actions: the actions of a candidate as a list

        Returns:
        One element of the mutation space (see create_mutatation_space_human) extended
        with a predicted fitness change:
        [mutation_type, action, index, second_index, predicted_fitness_change]
        '''
        pass

    def __str__(self):
        # The string form of a mutator is its unique name.
        return self.get_name()

    def apply_dict_except_indices(self, input_list, transformation_dict, skip_indices):
        """
        Helper function.
        Map every element of a list through a dictionary, except elements at given indices.

        Parameters:
        - input_list: The list to transform.
        - transformation_dict: The mapping to apply; elements without a key are kept as-is.
        - skip_indices: A set of indices (negative indices allowed) to leave untouched.

        Returns:
        - A new list with the mapping applied everywhere except at the skipped indices.
        """
        length = len(input_list)
        # Normalise negative indices so they can be compared with enumerate positions.
        skipped = {index if index >= 0 else length + index for index in skip_indices}
        return [
            item if position in skipped else transformation_dict.get(item, item)
            for position, item in enumerate(input_list)
        ]

    def create_mutatation_space_human(self, actions):
        '''
        Create a human-readable list of the possible mutations of a candidate.

        Parameters:
        - actions: The actions of a candidate as a list

        Returns:
        - A list of mutations, each of the form [mutation_type, action, index, second_index]:
          - [1, action, index, None]: insert an action not yet present at index 0..len(actions);
            only offered while len(actions) < MAXIMUM_LENGTH.
          - [2, None, index, None]: remove the action at index unless it is in
            ACTIONS_MANDATORY; only offered while len(actions) > MINIMUM_LENGTH.
          - [3, None, index, second_index]: permute the actions at index and
            second_index (second_index < index); only offered when len(actions) > 2.

        Raises:
        - AssertionError if no mutation is possible.
        '''
        mutation_space_human = []
        n_actions = len(actions)

        if n_actions < MAXIMUM_LENGTH:
            mutation_space_human += [
                [1, action, index, None]
                for action in ACTIONS
                for index in range(n_actions + 1)
                if action not in actions
            ]
        if n_actions > MINIMUM_LENGTH:
            mutation_space_human += [
                [2, None, index, None]
                for index in range(n_actions)
                if actions[index] not in ACTIONS_MANDATORY
            ]
        # NOTE(review): a two-action candidate is never offered a swap; confirm that
        # `> 2` (rather than `>= 2`) is intended.
        if n_actions > 2:
            mutation_space_human += [
                [3, None, index, second_index]
                for index in range(n_actions)
                for second_index in range(index)
            ]

        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type and message are unchanged for existing callers.
        if not mutation_space_human:
            raise AssertionError("No mutations possible for the given candidate.")

        return mutation_space_human

    def random_mutation(self, candidate_actions):
        '''
        Pick a uniformly random mutation from the candidate's mutation space.

        Parameters:
        - candidate_actions: The actions of a candidate as a list

        Returns:
        [mutation_type, action, index, second_index, 0]; the trailing 0 means "no predicted
        fitness change".
        '''
        mutation = random.choice(self.create_mutatation_space_human(candidate_actions))
        return mutation + [0]

    def expand_candidate_actions(self, candidate_actions):
        '''
        Pad the candidate actions with None up to MAXIMUM_LENGTH.

        Parameters:
        - candidate_actions: The actions of a candidate (any sequence)

        Returns:
        A new list of exactly MAXIMUM_LENGTH elements.

        Raises:
        - ValueError if the candidate has more than MAXIMUM_LENGTH actions (the result
          could not fill the fixed action_* columns).
        '''
        padded = list(candidate_actions)
        missing = MAXIMUM_LENGTH - len(padded)
        if missing < 0:
            raise ValueError(
                f"Candidate has {len(padded)} actions, more than MAXIMUM_LENGTH={MAXIMUM_LENGTH}."
            )
        return padded + [None] * missing

    def build_mutation_batch_from_actions(self, candidate_actions):
        '''
        Build a DataFrame with one row per possible mutation of ``candidate_actions``.

        Columns: actions_before (the candidate's action list, identical in every row),
        mutation_type, action, index, second_index. Rows follow the order returned by
        create_mutatation_space_human.

        Returns:
        - DataFrame with a fresh 0..n-1 index.
        '''
        # One single-row frame per mutation, concatenated once: avoids the quadratic
        # concat-in-a-loop while keeping pandas' per-column dtype resolution the same.
        frames = [
            pd.DataFrame(
                {
                    "actions_before": [candidate_actions],
                    "mutation_type": mutation[0],
                    "action": mutation[1],
                    "index": mutation[2],
                    "second_index": mutation[3],
                },
                index=[0],
            )
            for mutation in self.create_mutatation_space_human(candidate_actions)
        ]
        # ignore_index gives every row a unique label (previously all rows were labelled 1).
        return pd.concat(frames, ignore_index=True)

    def train_model(self, batch, use_weights=False):
        """
        Train ``self.model`` on a batch of history records.

        Parameters:
        - batch: DataFrame of history records with TRAINING_COLUMNS, "actual_fitness_change"
          and, when use_weights is set, "predicted_fitness_change".
        - use_weights: Whether to weight records by their absolute prediction error.

        Returns:
        - None. Sets ``self.is_trained`` to True.
        """
        y_train = batch["actual_fitness_change"].astype(float)

        weights = self._compute_sample_weights(batch) if use_weights else None
        if use_weights and weights is None:
            logger.warning("Sample weights are all zero or undefined; training without sample weights.")

        X_train = self.transform_dataframe(batch[TRAINING_COLUMNS])

        if weights is not None:
            self.model.fit(X_train, y_train, sample_weight=weights)
        else:
            self.model.fit(X_train, y_train)

        self.is_trained = True

    @staticmethod
    def _compute_sample_weights(batch):
        '''
        Weight each record by |predicted - actual| fitness change, normalised to sum to 1.

        Returns None when the total error is zero or not finite, where normalising would
        divide by zero and produce NaN weights.
        '''
        errors = (
            batch["predicted_fitness_change"].astype(float)
            - batch["actual_fitness_change"].astype(float)
        ).abs()
        total = errors.sum()
        if not np.isfinite(total) or total <= 0:
            return None
        return errors / total

    def record_mutation(self):
        '''Hook with no behaviour in the base class. NOTE(review): intended use is not visible here.'''
        pass