from config import *

import pandas as pd
import pynetlogo
import fastdtw
import logging
logger = logging.getLogger(__name__)

# Fallback error used by error_calculation when a reporter series is entirely zero:
# 100 x the DTW distance between TARGET and a flat zero series TICK_LIMIT long.
# Evaluated once at import time because it depends only on configuration values.
DISTANCE_TARGET_ZEROS = fastdtw.fastdtw(TARGET, np.zeros(TICK_LIMIT))[0] * 100

def error_calculation(data, TARGET):
    '''
    Calculate the Dynamic Time Warping (DTW) distance between a reporter series and the target.

    Parameters:
    - data is a sequence of floats: the values of one reporter over a model run
    - TARGET is a sequence of floats: the series that data is compared against

    Returns:
    - A float. If data contains no non-zero value (including an empty series), the
      precomputed module-level penalty DISTANCE_TARGET_ZEROS is returned, negated
      when MINIMISATION is false. Otherwise the DTW distance computed by fastdtw.
    '''
    # Only a series with no non-zero element counts as "all zeros". The previous
    # ``np.sum(data) == 0`` test also matched mixed-sign series such as [1, -1]
    # and wrongly assigned them the zero-series penalty.
    if not np.any(data):
        # NOTE(review): only this branch flips sign with MINIMISATION; the DTW
        # distance below is returned as-is. Confirm this matches the optimiser.
        return DISTANCE_TARGET_ZEROS if MINIMISATION else -DISTANCE_TARGET_ZEROS
    distance, _path = fastdtw.fastdtw(data, TARGET)
    return distance
    
def make_fitness_from_data(model_run_data):
    '''
    Aggregate per-run errors into one fitness vector per candidate.

    Parameters:
    - model_run_data is a DataFrame with a "candidate_id" column and one
      "<reporter>_error" column for every reporter in REPORTER_LIST

    Returns:
    - A dict mapping each candidate id (in order of first appearance) to a list
      holding the mean of each "<reporter>_error" column over that candidate's rows,
      in REPORTER_LIST order
    '''
    error_columns = [f"{reporter}_error" for reporter in REPORTER_LIST]

    def _mean_errors(cid):
        rows = model_run_data.loc[model_run_data["candidate_id"] == cid]
        return [rows[column].mean() for column in error_columns]

    return {cid: _mean_errors(cid) for cid in model_run_data["candidate_id"].unique()}

###################################################################

def initialise_netlogo():
    '''
    Start a headless NetLogo workspace and load model/<MODEL_NAME>.nlogo.

    Returns:
    - The pynetlogo.NetLogoLink with the model loaded.

    On any failure the error is logged together with its traceback and the
    process exits with status 1.
    '''
    import os
    # sys.exit is always available; the bare ``exit`` builtin is only injected
    # by the site module and is missing under ``python -S`` or frozen builds.
    import sys

    try:
        netlogo = pynetlogo.NetLogoLink(gui=False, jvmargs=JVMARGS, netlogo_home=NETLOGO_HOME)
        netlogo.load_model(os.path.join('model', f'{MODEL_NAME}.nlogo'))
    except Exception as e:
        # logger.exception keeps the traceback, which the message alone loses.
        logger.exception(f"Failed to start NetLogo: {e}")
        sys.exit(1)
    return netlogo

def run_single_experiment(candidate_id, candidate_actions, parameters_dict, run, netlogo=None):
    '''
    Run one NetLogo simulation for a candidate and return its results as a one-row DataFrame.

    Parameters:
    - candidate_id: identifier stored in the "candidate_id" column
    - candidate_actions: list of NetLogo command strings placed between the
      mandatory begin/end actions of the go procedure
    - parameters_dict: mapping of NetLogo global name -> value; each pair is applied
      with ``set`` before ``setup`` and also recorded as a column
    - run: repetition index, stored in the "run" column
    - netlogo: an initialised pynetlogo link with the model loaded. It is required,
      despite the None default.

    Returns:
    - A one-row DataFrame with candidate_id, actions_after, run, one column per reporter
      (each cell holding the whole reported series as a list), one column per
      parameter, and one "<reporter>_error" column per reporter

    Raises:
    - TypeError if no NetLogo link is supplied
    '''
    if netlogo is None:
        raise TypeError("run_single_experiment requires an initialised NetLogo link (netlogo=...)")

    data_run = {
        "candidate_id": candidate_id,
        # Wrapped in a list so the whole action list lands in a single DataFrame cell.
        "actions_after": [candidate_actions],
        "run": run
    }
    # Placeholders reserve the reporter columns' position (before the parameter
    # columns); every one of them is overwritten after the simulation.
    data_run.update({reporter: [] for reporter in REPORTER_LIST})

    # Apply and record the run parameters, then initialise the model.
    for key, value in parameters_dict.items():
        netlogo.command(f"set {key} {value}")
        data_run[key] = value
    netlogo.command("setup")

    # Assemble the go procedure: standard and mandatory actions wrap the candidate's actions.
    go_function = ACTIONS_STANDARD_BEGIN + ACTIONS_MANDATORY_BEGIN + candidate_actions + ACTIONS_MANDATORY_END + ACTIONS_STANDARD_END

    # Run the joined go commands TICK_LIMIT times via pynetlogo and collect every reporter.
    # NOTE(review): pynetlogo's repeat_report has an ``include_t0`` option that also
    # records the value before the first go; confirm the resulting series length is
    # the one TARGET / DISTANCE_TARGET_ZEROS assume.
    counts = netlogo.repeat_report(REPORTER_LIST, TICK_LIMIT, go=" ".join(go_function))
    for reporter in REPORTER_LIST:
        series = list(counts[reporter])
        data_run[reporter] = [series]  # one cell holding the whole time series
        data_run[f"{reporter}_error"] = error_calculation(series, TARGET)

    data_run = pd.DataFrame(data_run)

    if data_run.isnull().values.any():
        logger.error(f"NaN values in data_run for candidate {candidate_id}, run {run}.")

    return data_run

def run_model_per_candidate(candidates):
    '''
    Run every candidate under every parameter set NUMBER_OF_RUNS times and collect the results.

    Parameters:
    - candidates: iterable of objects exposing ``unique_id`` and ``actions``

    Returns:
    - A DataFrame with one row per (candidate, parameter set, run), concatenated
      from the frames returned by run_single_experiment. It is an empty DataFrame
      when there is nothing to run.

    Exits the process with status 1 if NetLogo cannot be started.
    '''
    # sys.exit does not depend on the site-injected ``exit`` builtin.
    import sys

    netlogo = initialise_netlogo()
    if netlogo is None:
        logger.error("Failed to start NetLogo.")
        # Non-zero status: the previous bare exit() reported success on this failure.
        sys.exit(1)

    run_frames = []
    try:
        for candidate in candidates:
            for parameters_dict in RUN_PARAMETERS_DICT_LIST:
                for run in range(NUMBER_OF_RUNS):
                    run_frames.append(
                        run_single_experiment(candidate.unique_id, candidate.actions, parameters_dict, run, netlogo)
                    )
    finally:
        # Shut the NetLogo workspace down even if an experiment raised.
        netlogo.kill_workspace()

    # A single concat avoids re-copying the growing frame on every run (quadratic).
    all_data_run = pd.concat(run_frames, ignore_index=True) if run_frames else pd.DataFrame()

    if all_data_run.isnull().values.any():
        logger.error("NaN values in all_data_run after aggregation in run_model_per_candidate.")

    return all_data_run

