"""
@author: j.h.koo@tudelf.nl
"""

import pandas as pd
import numpy as np
from pyomo.environ import value
from pyomo.opt import SolverStatus, TerminationCondition
import PDMPC_Evaluator_6O_simple as EV_ETS
import optuna
from optuna import Trial
from optuna.samplers import TPESampler

# Load the level/volume curve table. The hard-coded absolute path is tried
# first; if it cannot be opened, fall back to LV_curve.csv in the working dir.
try:
    LV_curve = pd.read_csv(r"/home/koojah/RL/LV_curve.csv")
except OSError:
    # Only file-access errors (missing file, permission denied) trigger the
    # fallback; a bare `except:` also hid KeyboardInterrupt and parse errors.
    LV_curve = pd.read_csv(r"LV_curve.csv")

def LtoS(lvl):
    """Convert water level(s) to storage by linear interpolation on LV_curve.

    Column '0' of LV_curve is the level axis and column '1' the storage axis.
    np.interp expects the level column to be increasing (presumably true for
    this curve — confirm). Out-of-range inputs follow np.interp's default
    end-value behavior.
    """
    level_axis = LV_curve['0']
    storage_axis = LV_curve['1']
    return np.interp(lvl, xp=level_axis, fp=storage_axis)
def StoL(sto):
    """Convert storage value(s) to water level by linear interpolation on LV_curve.

    This is the inverse lookup of LtoS: column '1' of LV_curve is the storage
    axis and column '0' the level axis. np.interp expects the storage column
    to be increasing (presumably true for this curve — confirm).
    """
    storage_axis = LV_curve['1']
    level_axis = LV_curve['0']
    return np.interp(sto, xp=storage_axis, fp=level_axis)

# Discharge limits.
MT = 264  # maximum discharge via the turbines
MSP = 11680  # maximum discharge via the spillway gates

# Reference water levels (presumably low / full / normal-high water levels — confirm).
LWL = 60.0
FWL = 80.0
NHWL = 76.5
IWL = 76.5  # initial water level

# Matching storages from the level/volume curve, truncated to int.
LWS = int(LtoS(LWL))
FWS = int(LtoS(FWL))
NHWS = int(LtoS(NHWL))
IWS = int(LtoS(IWL))

# ================================================================================

def Sol_to_Weight(solution, F=6):
    """Scale the first three entries of ``solution`` into objective weights.

    Each entry is multiplied by 1000 and divided by a reference range:
    entries 0 and 2 by MSP, entry 1 by (FWS - LWS).

    ``F`` is accepted but not used.

    Returns:
        A list of three scaled weights.
    """
    # NOTE(review): entries 0 and 2 share the MSP scale; confirm entry 2 was
    # not meant to use a different reference (e.g. MT).
    denominators = (MSP, FWS - LWS, MSP)
    return [solution[idx] * 1000 / denom for idx, denom in enumerate(denominators)]


def PDMPC_main(solver, State_k, max_INF, params_ls, TWL_H=78.5, EV_weights=(1, 1, 1, 1, 1, 1)):
    """Solve one MPC problem with weights from ``params_ls`` and score the result.

    Args:
        solver: object whose ``call_solver(State_k)`` returns ``(inst, results)``;
            ``inst`` must expose ``t``, ``SO``, ``SP``, ``ST`` and ``S`` components.
        State_k: nested input dict indexed as ``State_k[None][name][None]``
            (presumably Pyomo's ``create_instance`` data format — confirm).
            It is MUTATED: the W1/W2/W3 entries are overwritten with the
            weights derived from ``params_ls``.
        max_INF: passed through to ``EV_ETS.eval_ETS``.
        params_ls: three raw weight values, scaled by ``Sol_to_Weight``.
        TWL_H: passed through to ``EV_ETS.eval_ETS``.
        EV_weights: currently unused (kept for interface compatibility). The
            default is a tuple so no mutable object is shared between calls.

    Returns:
        ``(penalty, E_RWL_k)``. ``penalty`` is minus the first value returned
        by ``eval_ETS``. It is further reduced by 10,000,000 when the solver
        status is not ok or the termination condition is not optimal. Higher
        is better. ``E_RWL_k`` is the solved ``S`` trajectory converted to
        levels via StoL.
    """
    # Write the scaled weights into the solver input (in place).
    for key, weight in zip(('W1', 'W2', 'W3'), Sol_to_Weight(params_ls)):
        State_k[None][key][None] = weight

    # Target storages converted to levels (rounded to 3 decimals) for evaluation.
    TWL_F = round(StoL(State_k[None]['TWS_F'][None]), 3)
    TWL_L = round(StoL(State_k[None]['TWS_L'][None]), 3)

    inst, results = solver.call_solver(State_k)

    penalty = 0
    # Flat penalty for a failed or non-optimal solve (applied at most once).
    if (results.solver.status != SolverStatus.ok
            or results.solver.termination_condition != TerminationCondition.optimal):
        penalty -= 10000000
    # NOTE(review): after a failed solve, value() on unset variables may raise
    # before this penalty is returned — confirm call_solver always loads values.

    SO_p_k = list(State_k[None]['SO_P'].values())
    SO_n_k = [value(inst.SO[i]) for i in inst.t]
    SP_n_k = [value(inst.SP[i]) for i in inst.t]
    ST_n_k = [value(inst.ST[i]) for i in inst.t]
    E_RWL_k = [StoL(value(inst.S[i])) for i in inst.t]

    penaltys_EV, _, _, _ = EV_ETS.eval_ETS(SO_p_k, SP_n_k, ST_n_k, SO_n_k, E_RWL_k,
                                          TWL_F, TWL_L, TWL_H, max_INF)
    penalty -= penaltys_EV
    return penalty, E_RWL_k


def call_OT(solver, State_k, max_INF, BOS=0.0001, n_trials=100000,
            n_startup_trials=100, early_stopping_rounds=500):
    """Tune the three MPC objective weights with Optuna's TPE sampler.

    Each trial suggests w_1, w_2, w_3 in [0, 1] on a 0.01 grid. It scores them
    with PDMPC_main minus ``BOS`` times the spread of the weights around
    their mean. The study maximizes this score.

    Args:
        solver, State_k, max_INF: forwarded to PDMPC_main. State_k is mutated
            by every PDMPC_main call.
        BOS: scale of the weight-spread penalty added to the objective.
        n_trials: upper bound on the number of Optuna trials.
        n_startup_trials: random start-up trials for TPE. It is also the trial
            number before which the early-stopping callbacks do nothing.
        early_stopping_rounds: stop when this many trials pass without a new best.

    Returns:
        ``(best_params, best_values, final_params)``.
        - best_params: the best trial's values as ``[w_1, w_2, w_3]``.
        - best_values: the best trial's objective, which includes the BOS term.
        - final_params: whichever of best_params and ``[1, 1, 1]`` scores higher
          under PDMPC_main alone. Ties go to ``[1, 1, 1]``.
    """
    from functools import partial

    # Weight triples already evaluated. (0, 0, 0) is pre-seeded so the all-zero
    # combination is always pruned. A set keeps the lookup O(1) over many
    # trials; the previous list scan was O(n) per trial.
    suggested_params = {(0, 0, 0)}

    def opt_BO(trial: Trial):
        """Optuna objective: PDMPC_main score minus a BOS-scaled weight spread."""
        w_1 = trial.suggest_float('w_1', low=0, high=1, step=0.01)
        w_2 = trial.suggest_float('w_2', low=0, high=1, step=0.01)
        w_3 = trial.suggest_float('w_3', low=0, high=1, step=0.01)

        params_ls = [w_1, w_2, w_3]
        key = tuple(params_ls)
        if key in suggested_params:
            # Already evaluated (or the seeded all-zero triple): skip the solve.
            raise optuna.TrialPruned()
        suggested_params.add(key)

        penaltys_EV, _ = PDMPC_main(solver, State_k, max_INF, params_ls)
        # Penalize spread around the mean, nudging the search toward balanced weights.
        mean_p = sum(params_ls) / len(params_ls)
        spread = sum(abs(i - mean_p) for i in params_ls)
        return penaltys_EV - spread * BOS

    def early_stopping_check_1(study, trial, early_stopping_rounds=1000):
        """Stop once the best trial is early_stopping_rounds trials old (after warm-up)."""
        if trial.number >= n_startup_trials:
            if trial.number - study.best_trial.number >= early_stopping_rounds:
                study.stop()

    def early_stopping_check_2(study, trial, early_stopping_threshold=0):
        """Stop once the best objective reaches early_stopping_threshold (after warm-up)."""
        if trial.number >= n_startup_trials and study.best_trial.value >= early_stopping_threshold:
            study.stop()

    # Set verbosity before creating the study so the creation message is silenced too.
    optuna.logging.set_verbosity(optuna.logging.WARNING)
    study = optuna.create_study(direction='maximize',
                                sampler=TPESampler(n_startup_trials=n_startup_trials))
    study.optimize(
        opt_BO,
        n_trials=n_trials,
        callbacks=[
            partial(early_stopping_check_1, early_stopping_rounds=early_stopping_rounds),
            partial(early_stopping_check_2, early_stopping_threshold=0),
        ],
    )
    best_params = list(study.best_trial.params.values())
    best_values = study.best_trial.value

    # Re-score the TPE winner and the uniform default without the spread term.
    penaltys_TPE, _ = PDMPC_main(solver, State_k, max_INF, best_params)
    default_params1 = [1, 1, 1]
    penaltys_DEF1, _ = PDMPC_main(solver, State_k, max_INF, default_params1)

    # np.argmax returns the first maximum, so ties favor the default weights.
    # NOTE(review): PDMPC_main writes weights into State_k, so State_k is left
    # holding the [1, 1, 1] weights here even when best_params wins — confirm
    # callers do not rely on State_k afterwards.
    fp_index = np.argmax([penaltys_DEF1, penaltys_TPE])
    final_params = [default_params1, best_params][fp_index]

    return best_params, best_values, final_params



