# -*- coding: utf-8 -*-
"""
@author: j.h.koo@tudelf.nl
"""

import pandas as pd
import numpy as np
from pyomo.environ import value
from PDMPC_formulation_4O_3W import MPC_formula
from PDMPC_BO_P_4O_3W_6O_simple import call_OT, Sol_to_Weight
from PDMPC_solver import g_solver

# NOTE(review): hard-coded Windows drive roots; MPC_main builds input file
# names by plain string concatenation onto input_path.
result_path = "D:\\"
input_path = "D:\\"

# Level-volume curve: column '0' holds water levels and column '1' the
# matching storages (this is how LtoS/StoL index it). Try the absolute path
# first, then fall back to a file in the current working directory.
try:
    LV_curve = pd.read_csv(r"/home/koojah/RL/LV_curve.csv")
except OSError:
    # FileNotFoundError (a subclass of OSError) is raised when the absolute
    # path does not exist on this machine. Other failures, such as a malformed
    # CSV or Ctrl-C, now propagate instead of being silently swallowed.
    LV_curve = pd.read_csv(r"LV_curve.csv")


def LtoS(lvl, curve=None):
    """Convert water level(s) to reservoir storage via the level-volume curve.

    Uses linear interpolation (``np.interp``). Levels outside the tabulated
    range are clamped to the first/last storage value of the table.

    Parameters
    ----------
    lvl : float or array-like
        Water level(s) to convert.
    curve : pandas.DataFrame, optional
        Level-volume table with levels in column ``'0'`` and storages in
        column ``'1'``. The levels must be increasing, as ``np.interp``
        requires. Defaults to the module-level ``LV_curve``.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Interpolated storage(s): a scalar for scalar input, otherwise an array.
    """
    table = LV_curve if curve is None else curve
    return np.interp(lvl, table['0'], table['1'])

def StoL(sto, curve=None):
    """Convert reservoir storage(s) to water level via the level-volume curve.

    This is the inverse lookup of ``LtoS``, done by linear interpolation
    (``np.interp``). Storages outside the tabulated range are clamped to the
    first/last level of the table.

    Parameters
    ----------
    sto : float or array-like
        Storage value(s) to convert.
    curve : pandas.DataFrame, optional
        Level-volume table with levels in column ``'0'`` and storages in
        column ``'1'``. The storages must be increasing, as ``np.interp``
        requires. Defaults to the module-level ``LV_curve``.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Interpolated level(s): a scalar for scalar input, otherwise an array.
    """
    table = LV_curve if curve is None else curve
    return np.interp(sto, table['1'], table['0'])

# Reservoir limits and targets. Units are not stated in this file. The
# one-step storage balance in MPC_main multiplies flows by 3600, which
# suggests flows in m^3/s and storages in m^3 -- confirm against LV_curve.csv.
# Storages derived from levels are truncated toward zero by int().
MT = 264  # the maximum discharge via turbines
MSP = 11680  # the maximum discharge via spillway gates
LWL = 60.0  # NOTE(review): presumably the lowest allowed water level -- confirm
FWL = 80.0  # NOTE(review): presumably the full (maximum) water level -- confirm
LWS = int(LtoS(LWL))  # storage at LWL
FWS = int(LtoS(FWL))  # storage at FWL

IWL = 76.5  # initial water level
IWS = int(LtoS(IWL))  # storage at IWL
TWL = 76.5  # Target water level at the end of the episode
# Alternative target levels; MPC_main passes their storages to the model
# as TWS_N, TWS_L and TWS_F.
TWL_N = 76.5
TWL_F = 78.5
TWL_L = 76.0
TWS = int(LtoS(TWL))
TWS_F = int(LtoS(TWL_F))
TWS_L = int(LtoS(TWL_L))
TWS_N = int(LtoS(TWL_N))

# Demand value that MPC_main repeats over the horizon as the model's 'Md'.
O_demand = 52



def MPC_main(F, E, PP, random_it, init_RWL = 76.5, init_OUT = MT, BOS = 0.0001):
    """Run the MPC loop over one inflow scenario.

    At each step ``k`` the function does the following:

    1. Builds the model data ``State_k`` from row ``k`` of the forecast file.
    2. Lets ``call_OT`` choose the objective weights ``Z1``-``Z3``.
    3. Re-solves the model with those weights.
    4. Applies the first planned release ``SO[0]`` (rounded) to the storage
       balance, using the row-``k`` inflow from the 'PT' file.
    5. Keeps the rounded release plan as the previous plan ``SO_P`` for the
       next step.

    Parameters
    ----------
    F : int
        Horizon length: the number of values read from each forecast row and
        the size of the model's ``t`` set. Also part of both input file names.
    E, PP, random_it
        Only used to build the input file names
        ``dc_F{F}_E{E}_PT.csv`` and ``dc_F{F}_E{E}_{PP}_{random_it}.csv``
        (both read from ``input_path``).
    init_RWL : float
        Initial water level. It is converted to storage with ``LtoS`` at
        ``k == 0``.
    init_OUT : number
        Release used for every step of the initial previous plan ``SO_P``.
        The default ``MT`` is bound when this ``def`` statement executes.
    BOS : float
        Passed unchanged to ``call_OT``.
    """

    inf_real = pd.read_csv(input_path + "dc_F{}_E{}_{}.csv".format(F, E, 'PT'))
    # Drop the first CSV column (presumably the index written by to_csv -- confirm).
    inf_real.drop(inf_real.columns[0], axis=1, inplace=True)
    # First remaining column of the 'PT' file; used below to compute max_INF.
    inf_real_ = inf_real[inf_real.columns[0]]
    # Forecast inflows. Row k must hold at least F values, since QIN is built
    # from QIN_k_[0..F-1] below.
    inf_ = pd.read_csv(input_path + "dc_F{}_E{}_{}_{}.csv".format(F, E, PP, random_it))
    inf_.drop(inf_.columns[0], axis=1, inplace=True)

    # The loop below visits rows k = 0 .. len(inf_) - F - 2.
    max_step = len(inf_) - F - 1
    # Initial "previous" release plan: init_OUT at every horizon step.
    I_SO_k = [init_OUT for i in range(F)]

    ## ====================== start MPC iterations ================================================================

    # The model and solver wrapper are built once and reused at every step.
    model = MPC_formula()
    solver = g_solver(model)

    for k in range(0, max_step):

        QIN_k_ = inf_.iloc[k].to_list()  # forecast inflows over the horizon
        QIN_real_k_ = inf_real.iloc[k].to_list()  # row k of the 'PT' file
        # Constant demand over the horizon (loop-invariant).
        Md_k_ = [O_demand for i in range(F)]

        if k == 0:
            # First step: cap = MT, previous plan = init_OUT everywhere,
            # storage taken from init_RWL. Later steps reuse SO_p_k and
            # I_RWS_k as updated at the end of the previous iteration.
            max_INF = MT
            SO_p_k = I_SO_k
            I_RWS_k = LtoS(init_RWL)
        else:
            # Largest 'PT' inflow in rows before k, but never below MT.
            max_INF = max(max(inf_real_[:k].to_list()), MT)

        # Model data: scalar entries are keyed by None, and horizon-indexed
        # entries map i in range(F) to a value. Z1-Z3 start at 1 and are
        # overwritten after call_OT.
        State_k = {None: {'I_RWS': {None: I_RWS_k},
                          'Z1': {None: 1},
                          'Z2': {None: 1},
                          'Z3': {None: 1},
                          'TWS_N': {None: TWS_N},
                          'TWS_L': {None: TWS_L},
                          'TWS_F': {None: TWS_F},
                          't': {None: [i for i in range(F)]},
                          'QIN': {i: QIN_k_[i] for i in range(F)},
                          'SO_P': {i: SO_p_k[i] for i in range(F)},
                          'F': {None: F},
                          'Md': {i: Md_k_[i] for i in range(F)}}}


        # Only final_params feeds the following steps here.
        # NOTE(review): check whether best_params, best_values and
        # final_values are needed at all.
        best_params, best_values, final_params, final_values = call_OT(solver, State_k, max_INF=max_INF, BOS=BOS)
        solution = final_params.copy()

        # Map the chosen solution to the three objective weights and write
        # them into State_k in place.
        weights = Sol_to_Weight(solution, F)
        State_k[None]['Z1'][None] = weights[0]
        State_k[None]['Z2'][None] = weights[1]
        State_k[None]['Z3'][None] = weights[2]

        # Re-solve with the chosen weights.
        inst, results = solver.call_solver(State_k)

        SO_n_k = [value(inst.SO[i]) for i in inst.t]
        # One-step storage balance: 'PT' inflow minus the rounded first
        # release. The factor 3600 presumably converts a per-second flow over
        # an hourly step -- confirm units.
        next_RWS = I_RWS_k + (QIN_real_k_[0] - round(SO_n_k[0])) * 3600

        # Carry state into step k + 1: the new storage, and the rounded
        # release plan as the previous plan SO_P.
        I_RWS_k = next_RWS
        SO_p_k = [round(i) for i in SO_n_k]
