"""
@author: j.h.koo@tudelft.nl
"""

import pandas as pd
import numpy as np
from pyomo.environ import value
from PDMPC_formulation_4O_3W import MPC_formula
from PDMPC_BO_P_4O_3W_6O_simple import Sol_to_Weight
from PDMPC_solver import g_solver
from W_ML import W_ML_

# Base directories used to build file names; MPC_main reads its input CSVs
# from input_path. result_path is not used in the visible part of the file.
result_path = "D:\\"
input_path = "D:\\"

# Load the lookup table used by LtoS / StoL (columns '0' and '1').
# The absolute path is tried first; if that file cannot be opened, fall back
# to a copy in the current working directory.
try:
    LV_curve = pd.read_csv(r"/home/koojah/RL/LV_curve.csv")
except OSError:
    # Only a missing/unreadable file triggers the fallback. Parse errors and
    # KeyboardInterrupt are no longer swallowed, so a corrupt primary file is
    # reported instead of being silently replaced by the local copy.
    LV_curve = pd.read_csv(r"LV_curve.csv")

def LtoS(lvl):
    """Interpolate column '1' of LV_curve at ``lvl``.

    Column '0' of LV_curve supplies the sample points and column '1' the
    values; the result is a piecewise-linear interpolation (np.interp).
    Presumably this maps a water level to a storage volume -- confirm
    against the contents of LV_curve.csv.
    """
    sample_points = LV_curve['0']
    sample_values = LV_curve['1']
    return np.interp(lvl, sample_points, sample_values)

def StoL(sto):
    """Interpolate column '0' of LV_curve at ``sto``.

    Inverse lookup of LtoS: column '1' of LV_curve supplies the sample
    points and column '0' the values (np.interp).
    NOTE(review): np.interp expects increasing sample points, so this
    assumes column '1' is monotonically increasing -- confirm.
    """
    sample_points = LV_curve['1']
    sample_values = LV_curve['0']
    return np.interp(sto, sample_points, sample_values)

# --- Reservoir constants ---------------------------------------------------
# Levels (*WL) are converted once to the LtoS() lookup value (*WS) and
# truncated to int.
MT = 264  # the maximum discharge via turbines
MSP = int(11680)  # the maximum discharge via spillway gates
# NOTE(review): LWL / FWL look like lower and upper water-level bounds
# (names suggest "low" / "flood" water level) -- confirm.
LWL = 60.0
FWL = 80.0
LWS = int(LtoS(LWL))
FWS = int(LtoS(FWL))

IWL = 76.6  # initial water level
IWS = int(LtoS(IWL))
TWL = 76.5  # Target water level at the end of the episode
# Three further target levels; their converted values (TWS_N, TWS_F, TWS_L)
# are the ones passed to the solver in MPC_main.
# NOTE(review): the _N / _F / _L suffixes are not explained here -- confirm.
TWL_N = 76.5
TWL_F = 78.5
TWL_L = 76.0
TWS = int(LtoS(TWL))
TWS_F = int(LtoS(TWL_F))
TWS_L = int(LtoS(TWL_L))
TWS_N = int(LtoS(TWL_N))
# I_O fills the initial "previous outflow" plan in MPC_main (one value per
# horizon step); O_demand fills the per-step 'Md' input of the solver.
I_O = 150
O_demand = 52

# --- Weight-prediction models (built once at import time) -------------------
# W_ML_ is a project-local helper; the notes below only describe how its
# results are used in this file.
W_ML = W_ML_()
# Column-wise max / min of the training table, used for min-max scaling.
scaler_max_ = W_ML.train_r.max(axis=0)
scaler_min_ = W_ML.train_r.min(axis=0)
# rfC.predict() selects the class in MPC_main; rfR.predict() and
# lfR.predict() supply the weight predictions for z_ == 'wrfr' / 'wlfr'.
rfC = W_ML.rfc()
rfR = W_ML.rfr(max_depth=30, n_estimators=150)
lfR = W_ML.lfr()
# Return value is discarded; presumably this fits the clustering model that
# W_ML.clust_K_op() queries later (z_ == 'cluster_K_op') -- confirm.
W_ML.clust_K_op__(n_cluster=100)
cluster, d_info_train, _ = W_ML.clust_K_nc(n_cluster=5)
# Drop the last four entries so the scalers cover only the state-vector
# columns built in MPC_main.
# NOTE(review): assumes the last four columns of train_r are not state
# inputs (e.g. targets/labels) -- confirm.
scaler_max = scaler_max_[:-4].to_numpy()
scaler_min = scaler_min_[:-4].to_numpy()


def MPC_main(F, E, PP, random_it, z_ = 'wrfr', init_RWL = 76.5):
    """Run the step-by-step MPC loop over one inflow scenario.

    Two CSV files are read from ``input_path``:
    ``dc_F{F}_E{E}_{PP}_{random_it}.csv`` (inflow used for planning) and
    ``dc_F{F}_E{E}_PT.csv`` (inflow used to update the state). The first
    column of each file is dropped.

    At every step k the function builds a state vector, predicts the three
    objective weights, solves the MPC problem and advances the storage using
    the first planned outflow.

    Parameters
    ----------
    F : int
        Number of horizon steps (used for ``range(F)`` and in the file names).
    E, PP, random_it :
        Only used to build the input file names.
    z_ : str
        Weight predictor: 'wrfr' (rfR), 'cluster_K_op' (W_ML.clust_K_op) or
        'wlfr' (lfR).
        NOTE(review): any other value leaves ``W`` unassigned, so the call
        ``MPC_process(W)`` raises UnboundLocalError.
    init_RWL : float
        Initial level; converted with LtoS() at step 0.

    NOTE(review): the visible part of this function ends with the state
    update inside the loop; any result recording or return value is not
    visible here.
    """

    inf_ = pd.read_csv(input_path + "dc_F{}_E{}_{}_{}.csv".format(F, E, PP, random_it))
    inf_real = pd.read_csv(input_path + "dc_F{}_E{}_{}.csv".format(F, E, 'PT'))
    # Drop the first column of both tables (presumably a saved index -- confirm).
    inf_.drop(inf_.columns[0], axis=1, inplace=True)
    inf_real.drop(inf_real.columns[0], axis=1, inplace=True)

    max_step = len(inf_) - F - 1
    # Initial "previous outflow" plan: I_O for every horizon step.
    I_SO_k = [I_O for i in range(F)]

    model = MPC_formula()
    solver = g_solver(model)

    for k in range(0, max_step):

        # Row k of each table; the first F entries are indexed below.
        QIN_k_ = inf_.iloc[k].to_list()
        QIN_real_k_ = inf_real.iloc[k].to_list()
        Md_k_ = [O_demand for i in range(F)]

        if k == 0:
            # First step: start from the initial plan and initial level.
            SO_p_k = I_SO_k
            I_RWS_k = LtoS(init_RWL)

        # State vector = [current level] + inflow row + previous outflow plan,
        # min-max scaled with the module-level scalers.
        # NOTE(review): its length must equal len(scaler_min) -- confirm.
        state_X = np.array([[StoL(I_RWS_k)] + QIN_k_ + SO_p_k])
        state_Xs = (state_X[0] - scaler_min) / (scaler_max - scaler_min)
        state_Xs = [state_Xs.tolist()]

        # NOTE(review): the classifier receives the unscaled state_X while the
        # regressors below receive the scaled state_Xs -- confirm intended.
        class_predict = rfC.predict(state_X)

        # Only class 2 uses predicted weights, rescaled so each row sums to 3;
        # every other class uses equal weights [1, 1, 1].
        if z_ == 'wrfr':
            if class_predict[0] == 2:
                # NOTE(review): unlike the two branches below, this prediction
                # is not clipped at 0 before normalising.
                Z_predict = rfR.predict(state_Xs)
                Z_predict_re = np.array([[j * 3 / sum(i) for j in i] for i in Z_predict])
                W = Z_predict_re[0].tolist()
            else:
                W = [1, 1, 1]
        elif z_ == 'cluster_K_op':
            if class_predict[0] == 2:
                Z_predict = W_ML.clust_K_op(state_Xs)
                Z_predict = Z_predict.clip(min=0)
                Z_predict_re = np.array([[j * 3 / sum(i) for j in i] for i in Z_predict])
                W = Z_predict_re[0].tolist()
            else:
                W = [1, 1, 1]
        elif z_ == 'wlfr':
            if class_predict[0] == 2:
                Z_predict = lfR.predict(state_Xs)
                Z_predict = Z_predict.clip(min=0)
                Z_predict_re = np.array([[j * 3 / sum(i) for j in i] for i in Z_predict])
                W = Z_predict_re[0].tolist()
            else:
                W = [1, 1, 1]

        def MPC_process(W):
            """Solve one MPC step with weights W.

            Returns (weights, next_RWS, SO_n_k): the converted weights, the
            storage after applying the first planned outflow, and the full
            outflow plan read from the solved instance.
            """
            weights = Sol_to_Weight(W, F)
            # Data dictionary handed to the solver for this step.
            State_k = {None: {'I_RWS': {None: I_RWS_k},
                              'Z1': {None: weights[0]},
                              'Z2': {None: weights[1]},
                              'Z3': {None: weights[2]},
                              'TWS_N': {None: TWS_N},
                              'TWS_L': {None: TWS_L},
                              'TWS_F': {None: TWS_F},
                              't': {None: [i for i in range(F)]},
                              'QIN': {i: QIN_k_[i] for i in range(F)},
                              'SO_P': {i: SO_p_k[i] for i in range(F)},
                              'Md': {i: Md_k_[i] for i in range(F)}}}
            inst, results = solver.call_solver(State_k)
            SO_n_k = [value(inst.SO[i]) for i in inst.t]
            # Balance over one step using the "real" inflow and the first
            # planned outflow.
            # NOTE(review): 3600 presumably converts a per-second flow to an
            # hourly volume -- confirm units.
            next_RWS = I_RWS_k + (QIN_real_k_[0] - SO_n_k[0]) * 3600
            return weights, next_RWS, SO_n_k

        weights, next_RWS, SO_n_k = MPC_process(W)

        # Carry storage and outflow plan over to the next step.
        I_RWS_k = next_RWS
        SO_p_k = SO_n_k