"""
@author: j.h.koo@tudelft.nl
"""

import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import KFold
from tensorflow import keras
from tensorflow.keras import layers

# Absolute Windows base directory shared by the input and output folders.
_PROJECT_ROOT = "D:\\Programs\\PDMPC_ML_C3\\"

# Folder the result workbook is written to. The trailing separator is part of
# the value because file names are appended by plain string concatenation.
result_path = _PROJECT_ROOT + "Test_\\"
# Folder the input workbooks are read from (same trailing-separator convention).
input_path = _PROJECT_ROOT + "DC\\"


def build_model(input_dim, output_dim, lr, nodes, layer, activation, loss):
    """Assemble and compile a fully connected Keras ``Sequential`` network.

    The stack is: an input of width ``input_dim``, one hidden ``Dense`` layer
    that is always present, ``layer - 1`` further hidden ``Dense`` layers
    (none when ``layer <= 1``), and a final ``Dense`` layer of width
    ``output_dim`` created without an activation argument.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output units.
        lr: Learning rate passed to the Adam optimizer.
        nodes: Width of every hidden layer.
        layer: Requested number of hidden layers.
        activation: Activation used by every hidden layer.
        loss: Loss passed to ``compile``.

    Returns:
        The compiled model. It reports ``'mae'`` and ``'mse'`` as metrics,
        in that order.
    """
    network = keras.Sequential()
    network.add(keras.Input(shape=(input_dim,)))

    # The first hidden layer is unconditional; only the remaining ones
    # depend on ``layer``.
    network.add(layers.Dense(nodes, activation=activation))
    extra_hidden = layer - 1
    for _ in range(extra_hidden):
        network.add(layers.Dense(nodes, activation=activation))

    network.add(layers.Dense(output_dim))

    network.compile(
        loss=loss,
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        metrics=['mae', 'mse'],
    )
    return network


def opt_BO(train_X, train_y, params_ls, n_splits=5, random_state=None):
    """Score one hyper-parameter set with shuffled K-fold cross-validation.

    For every fold a new network is built with ``build_model``, trained on
    the fitting part with early stopping on the held-out part, and then
    evaluated on that same held-out part.

    Args:
        train_X: Feature table, one row per sample (DataFrame or 2-D
            array-like).
        train_y: Targets aligned with ``train_X`` (DataFrame, Series or
            array-like). A 1-D input is treated as a single output column.
        params_ls: Sequence whose first seven items are
            ``[lr, nodes, layer, activation, loss, batch_size, epochs]``.
        n_splits: Number of cross-validation folds. Defaults to 5.
        random_state: Seed forwarded to ``KFold`` so the shuffled split can
            be repeated across calls. ``None`` (default) gives a different
            split on every call.

    Returns:
        tuple: ``(penalty, stop_epochs)`` where ``penalty`` is the mean
        held-out MSE over the folds and ``stop_epochs`` lists, per fold, the
        number of epochs that were actually run.
    """
    features = np.asarray(train_X)
    targets = np.asarray(train_y)
    if targets.ndim == 1:
        # A single target column has no second axis to read the output
        # width from, so give it one.
        targets = targets.reshape(-1, 1)

    lr, nodes, layer, activation, loss, batch_size, epochs = params_ls[:7]

    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # NOTE(review): min_delta is an absolute change in val_loss, so the value
    # 10 is tied to the scale of train_y -- confirm it matches the targets.
    earlystop_callback = keras.callbacks.EarlyStopping(
        monitor="val_loss", min_delta=10, patience=50, verbose=0)

    mse_ls = []
    stop_epochs = []
    for fit_idx, held_out_idx in kfold.split(features, targets):
        fit_X, fit_y = features[fit_idx], targets[fit_idx]
        held_X, held_y = features[held_out_idx], targets[held_out_idx]

        model = build_model(fit_X.shape[1], fit_y.shape[1], lr=lr, nodes=nodes,
                            layer=layer, activation=activation, loss=loss)
        history = model.fit(fit_X, fit_y,
                            validation_data=(held_X, held_y),
                            batch_size=batch_size, epochs=epochs,
                            verbose=0, callbacks=[earlystop_callback])

        # build_model compiles with metrics=['mae', 'mse'], so evaluate()
        # returns [loss, mae, mse] and index 2 is the MSE.
        eval_results = model.evaluate(held_X, held_y, verbose=0)
        mse_ls.append(eval_results[2])
        # One 'loss' entry is recorded per epoch that was run.
        stop_epochs.append(len(history.history['loss']))

    penalty = sum(mse_ls) / len(mse_ls)

    return penalty, stop_epochs


if __name__ == '__main__':
    # Evaluates a single hyper-parameter combination (index ``j`` of the grid
    # built below) with opt_BO and writes the result to one Excel workbook.

    # NOTE(review): presumably experiment indices per data split. They are not
    # referenced again in this script; the splits are read from the
    # pre-collected workbooks below.
    test_E = [1, 2, 25]
    val_E = [22, 23, 24]
    train_E = [i for i in range(1, 29) if i not in val_E + test_E]
    hyper_E = [3, 9, 10, 19, 27]

    OW_train = pd.read_excel(input_path + "Collected_OW_train.xlsx", index_col=0)
    OW_test = pd.read_excel(input_path + "Collected_OW_test.xlsx", index_col=0)
    OW_hyper = pd.read_excel(input_path + "Collected_OW_hyper.xlsx", index_col=0)

    O_columns = ['I_RWL_t', 'PI_0', 'PI_1', 'PI_2', 'PI_3', 'PI_4', 'PI_5', 'p_OUT_0', 'p_OUT_1', 'p_OUT_2', 'p_OUT_3',
                 'p_OUT_4', 'p_OUT_5', 'OUT_0', 'OUT_1', 'OUT_2', 'OUT_3', 'OUT_4', 'OUT_5']

    train_O = OW_train[O_columns]
    test_O = OW_test[O_columns]
    hyper_O = OW_hyper[O_columns]

    # Min-max scaling: the statistics come from the training table only and
    # are applied unchanged to the other tables.
    col_min = train_O.min(axis=0)
    col_range = train_O.max(axis=0) - col_min
    # A column that is constant in the training table has zero range; dividing
    # by it would fill the scaled column with NaN/inf. Use 1 instead so such a
    # column scales to 0 in the training table.
    col_range = col_range.where(col_range != 0, 1)

    train_s = (train_O - col_min) / col_range
    test_s = (test_O - col_min) / col_range
    hyper_s = (hyper_O - col_min) / col_range

    X_columns = ['I_RWL_t', 'PI_0', 'PI_1', 'PI_2', 'PI_3', 'PI_4', 'PI_5', 'p_OUT_0', 'p_OUT_1', 'p_OUT_2', 'p_OUT_3',
                 'p_OUT_4', 'p_OUT_5']
    y_columns = ['OUT_1', 'OUT_2', 'OUT_3', 'OUT_4', 'OUT_5']

    # The search is run on the "hyper" table: scaled inputs, but targets taken
    # from the unscaled table and rounded.
    train_X = hyper_s[X_columns]
    train_y = hyper_O[y_columns].round()
    test_X = test_s[X_columns]
    test_y = test_O[y_columns].round()

    # Candidate values per hyper-parameter.
    PS_lr = [0.0001, 0.0005, 0.001, 0.005, 0.01]
    PS_nodes = [16, 32, 64, 128]
    PS_layer = [6]
    PS_activation = ['relu']
    PS_loss = ['mse']
    PS_batch_size = [10, 50, 100, 200]
    PS_epochs = [1000]

    # Full grid in the order opt_BO expects:
    # [lr, nodes, layer, activation, loss, batch_size, epochs].
    # product() varies the last factor fastest, like nested for-loops would.
    PS_ = [list(combo) for combo in itertools.product(
        PS_lr, PS_nodes, PS_layer, PS_activation, PS_loss, PS_batch_size, PS_epochs)]

    # Index of the grid entry evaluated by this run.
    j = 0
    params_ls = PS_[j]
    mean_penaltys, stop_epochs = opt_BO(train_X, train_y, params_ls)

    # One stop-epoch column per fold, numbered from 1.
    stop_epoch_columns = ['stop_epochs_{}'.format(k) for k in range(1, len(stop_epochs) + 1)]
    PP = pd.DataFrame(
        [params_ls + [mean_penaltys] + stop_epochs],
        columns=['lr', 'nodes', 'layers', 'activation', 'loss', 'batch_size', 'epochs',
                 'mean_penaltys'] + stop_epoch_columns)

    # Create the output folder if needed so to_excel does not fail on a
    # missing directory.
    os.makedirs(result_path, exist_ok=True)
    PP.to_excel(result_path + "GS_results_J{}.xlsx".format(j))