# -*- coding: utf-8 -*-
"""
Created on Tue Feb 26 11:53:11 2019

@author: PezijM1
"""

# this module contains general functions for use in Pastas modelling 
# approaches.

# import statements
from glob import glob
import os
import pandas as pd
from pandas.plotting import register_matplotlib_converters
import numpy as np
import datetime
import matplotlib.pyplot as plt
plt.close('all')

# register Pandas formatters and converters with Matplotlib
register_matplotlib_converters()

###############################################################################
###############################################################################
###############################################################################
def read_prec_data():
    '''
    function to read converted precipitation data (from KNMI rasters)
    '''
    # set filename
    fName = os.path.join('data', 'precipitation', 'df_prec_2016_2018.pkl')
    
    # read data in pandas dataframe
    df_prec = pd.read_pickle(fName)
    
    # set nodata values
    df_prec[df_prec==-9999] = 0#np.nan
    
    # return dataframe
    return df_prec

###############################################################################
###############################################################################
###############################################################################
def read_prec_data_csv():
    '''
    function to read in converted precipitation data (from KNMI rasters)
    '''
    # set filename
    fName = os.path.join('data', 'precipitation', 'prec_2016_2019_raam.csv')
    
    # read data in pandas dataframe
    df_prec = pd.read_csv(fName,
                          parse_dates=True,
                          index_col=[0])
    
    # set nodata values
#    df_prec[df_prec==-9999] = 0#np.nan
    
    # return dataframe
    return df_prec

###############################################################################
###############################################################################
###############################################################################
def read_et_ref_data_csv():
    '''
    function to read in converted precipitation data (from KNMI rasters)
    '''
    # set filename
    fName = os.path.join('data', 'reference_ET', 'knmi_Makkink_ET_2016_2018.csv')
    
    # read data in pandas dataframe
    df_et = pd.read_csv(fName,
                          parse_dates=True,
                          index_col=[0],
                          sep=';')
    
    # set nodata values
#    df_prec[df_prec==-9999] = 0#np.nan
    
    # return dataframe
    return df_et

###############################################################################
###############################################################################
###############################################################################
def read_knmi_data():
    '''
    function to read in knmi data (Twenthe station)
    '''
    
    # set filename
    fName = os.path.join('..', 'data', 'knmi', 'knmi_2018.pkl')
    
    # read data in pandas dataframe
    df_knmi = pd.read_pickle(fName)
    
    # retrn dataframe
    return df_knmi

###############################################################################
###############################################################################
###############################################################################
def read_ETact_data():
    '''
    function to read in ETact data
    '''
    
    # set filename
    fName_ETact = os.path.join('..', 'data','actual_ET','df_et_Act.pkl')
    
    # read pickle
    df_ETact = pd.read_pickle(fName_ETact)/100
    
    return df_ETact

###############################################################################
###############################################################################
###############################################################################
def read_ETdef_data():
    '''
    function to read in ETact data
    '''
    
    # set filename
    fName_ETdef = os.path.join('..', 'data','ET_def','df_et_def.pkl')
    
    # read pickle
    df_ETdef = pd.read_pickle(fName_ETdef)/100
    
    return df_ETdef

###############################################################################
###############################################################################
###############################################################################
def read_sm_metadata():
    '''
    function to read Raam data
    '''
    
    fName_metadata = os.path.join('data', 'location', 
                                  r'Twente_stations_2016_RD_char.csv')
    
    df_station_data = pd.read_csv(fName_metadata,
                                  sep=';')
    
    return df_station_data
  
###############################################################################
###############################################################################
###############################################################################    
def read_sm_smap():
    fName = os.path.join('data', 'smap', 'SMAP_Twente_2015-2019.csv')
    
    df = pd.read_csv(fName,
                     sep=';',
                     parse_dates=True,
                     na_values=-9999,
                     index_col=0)
    
    return df

###############################################################################
###############################################################################
###############################################################################
def read_sm_smap_raam():
    fName = os.path.join('data', 'smap', 'SMAP_Raam_2015-2019.csv')
    
    df = pd.read_csv(fName,
                     sep=';',
                     parse_dates=True,
                     na_values=-9999,
                     index_col=0)
    
    return df

###############################################################################
###############################################################################
###############################################################################
def read_sm_data(df_station_info, resample='yes'):
    '''
    function to read ITCSM data
    '''

    fNames_sm = glob(os.path.join('data', 'soil_moisture', '*.csv'))
    
    dict_sm = dict()
    
    for index, row in df_station_info.iterrows():
        
        # for every station
        for fName in fNames_sm:
            
            # if station_name
            if row.station_name in fName:
                
                # read data into pandas dataframe
                df_sm = pd.read_csv(fName,
                                    sep=';')
                
        # get proper datetime
        datet = pd.to_datetime(df_sm['Measurement Time'])
        df_sm.index = datet
        
        del df_sm['Measurement Time']
        
        if resample=='yes':
            df_sm = df_sm.resample('D').mean()
        
        df_sm[df_sm==-9999] = np.nan
        df_sm[df_sm<0] = np.nan
        
        dict_sm[row.station_name] = df_sm
    
    return dict_sm

###############################################################################
###############################################################################
###############################################################################                
def exploratory_plotting(dict_data):
    '''
    function to plot data stored in dict
    '''
    # get number of subplots
    subplots = len(dict_data)
    
    if subplots == 0:
        print('No data available for plotting...')
    
    elif subplots == 1:
        n_rows = 1
        n_cols = 1
        
    else:
        n_cols=2
        n_rows = int(np.ceil(subplots/n_cols))
        
    
        
    # open plotting object with n_rows and n_col    
    fig, axarr = plt.subplots(nrows=n_rows, ncols=n_cols)
    
    # plot all dict data
    for key, ax in zip(dict_data, axarr.reshape(-1)):
        
        if key == 'prec':
            ax.bar(dict_data[key].index, dict_data[key].values)
            ax.set_ylim(0, 30)
            
        elif key == 'et_act':
            ax.bar(dict_data[key].index, dict_data[key].values)
            ax.set_ylim(0, 6)
    
        elif key == 'et_def':
            ax.bar(dict_data[key].index, dict_data[key].values)
            ax.set_ylim(0, 2)
        
        elif key == 'sm5cm' or key=='sm5_prev' or key == 'sm40cm':
            ax.plot(dict_data[key])
            ax.set_ylim(0, 0.7)
        
        else:
            ax.plot(dict_data[key])
        
        ax.set_title(key)
        
    plt.tight_layout()
            
###############################################################################
###############################################################################
###############################################################################  
def slice_time_series(dictionary):
    '''
    function to slice time series based on shortest time series
    '''
    
    first_date = []
    last_date = []

    # get first and last date
    for key, values in dictionary.items():
        
        if not first_date:
            first_date = values.index[0]
           
        else:
            if values.index[0] > first_date:
                first_date = []
                first_date = values.index[0]
                
        if not last_date:
            last_date = values.index[-1]
            
        else:
            if values.index[-1] < last_date:
                last_date = []
                last_date = values.index[-1]
            
    # create empty sliced dictionary
    data_stat_sliced = dict()

    for key, values in dictionary.items():
        
        sliced = values[first_date:last_date]
        sliced = sliced[~sliced.index.duplicated(keep='first')]
        data_stat_sliced[key] = sliced
        
    return data_stat_sliced

###############################################################################
###############################################################################
###############################################################################
def setlabel(ax, label, loc=2, borderpad=0.2, **kwargs):
    """
    function for annotating subplots (a, b, c....)
    Derived from:
        https://stackoverflow.com/questions/22508590/enumerate-plots-in-matplotlib-figure
    """

    legend = ax.get_legend()
    if legend:
        ax.add_artist(legend)

    line, = ax.plot(np.NaN, np.NaN, color="none", label=label)

    label_legend = ax.legend(
        handles=[line],
        loc=loc,
        handlelength=0,
        handleheight=0,
        handletextpad=0,
        borderaxespad=0,
        borderpad=borderpad,
        frameon=False,
        **kwargs
    )
    label_legend.remove()

    ax.add_artist(label_legend)

    line.remove()

###############################################################################
###############################################################################
###############################################################################
def plot_3panels(ob_series, sim, sim2, insitu_crop, prec, et, stat_name):
    '''
    function to plot soil moisture estimates, precipitation and et_ref.
    '''
    
    # open plot objects
    fig, ax = plt.subplots(figsize=(6, 4), 
                           nrows=3, 
                           sharex=True, 
                           gridspec_kw={"height_ratios": [5, 2, 2]})
    
    # plot SMAP observations
    ax[0].scatter(
        ob_series.index, ob_series.values, 
        label="SMAP", 
        color="#7570b3", 
        s=1, 
        zorder=1)

    # plot pastas training
    ax[0].plot(sim.index, sim.values, 
               label="TFN training", 
               color="#e6ab02", 
               zorder=2)

    # plot pastas validation
    ax[0].plot(sim2.index, sim2.values, 
               label="TFN prediction", 
               color="#1b9e77", 
               zorder=3)

    # plot in situ measurements
    ax[0].plot(insitu_crop.index, insitu_crop.values, label="In situ", color="#d95f02")

    # plot precipitation data
    ax[1].bar(prec.index, prec.values)
    
    # plot reference evapotranspiration data
    ax[2].bar(et.index, et.values)

    # figure settings
    ax[0].set_xlim(datetime.date(2016, 1, 1), datetime.date(2019, 1, 1))
    ax[0].set_ylim(0, 0.8)
    ax[1].set_ylim(0, 25)
    ax[2].set_ylim(0, 6)
    ax[0].grid()
    ax[1].grid()
    ax[2].grid()
    ax[0].legend(fancybox=False, framealpha=1, ncol=4, fontsize=9, labelspacing=1)

    # labels
    ax[0].set_title("Station " + str(int(stat_name[-2:])))
    ax[2].set_xlabel("Date [year-month]")
    ax[0].set_ylabel("Volumetric moisture \n content \n [$m^3 m^{-3}$]")
    ax[1].set_ylabel("P \n[mm]")
    ax[2].set_ylabel("ET \n[mm]")
    fig.autofmt_xdate()
    fig.tight_layout()

    ax[0].set_ylim(0, 0.8)
    ax[1].set_ylim(0, 30)
    ax[1].yaxis.set_ticks([0, 30])
    ax[2].yaxis.set_ticks([0, 6])

    setlabel(ax[0], "(A)")
    setlabel(ax[1], "(B)")
    setlabel(ax[2], "(C)")

    return fig, ax

###############################################################################
###############################################################################
###############################################################################
def plot_1panels(ob_series, sim, sim2, insitu_crop, prec, et, stat_name):
    '''
    function to plot soil moisture estimates, precipitation and et_ref.
    '''
    
    # open plot objects
    fig, ax = plt.subplots(figsize=(6, 3.5))
    
    # plot SMAP observations
    ax.scatter(
        ob_series.index, ob_series.values, 
        label="SMAP", 
        color="#7570b3", 
        s=1, 
        zorder=1)

    # plot pastas training
    ax.plot(sim.index, sim.values, 
               label="TFN training", 
               color="#e6ab02", 
               zorder=2)

    # plot pastas validation
    ax.plot(sim2.index, sim2.values, 
               label="TFN prediction", 
               color="#1b9e77", 
               zorder=3)

    # plot in situ measurements
    ax.plot(insitu_crop.index, insitu_crop.values, 
           label="In situ", 
           color="#d95f02")


    # figure settings
    ax.set_xlim(datetime.date(2016, 1, 1), datetime.date(2019, 1, 1))
    ax.set_ylim(0, 0.8)
    ax.grid()
    ax.legend(fancybox=False, framealpha=1, ncol=4, fontsize=9, labelspacing=1)

    # labels
    ax.set_title("Station " + str(int(stat_name[-2:])))
    ax.set_xlabel("Date [year-month]")
    ax.set_ylabel("Volumetric moisture \n content \n [$m^3 m^{-3}$]")
    fig.autofmt_xdate()
    fig.tight_layout()



    return fig, ax