import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.model_selection import train_test_split
import shap
from skmultilearn.problem_transform import LabelPowerset, BinaryRelevance, ClassifierChain
from xgboost import XGBClassifier
from sklearn.metrics import classification_report, roc_curve, auc
from sklearn.linear_model import LogisticRegression
import scikitplot as skplt
from sklearn.metrics import precision_recall_curve
from imblearn.over_sampling import SMOTE, ADASYN
import joblib

#===================================================================================================

class extremeEvents:
    def __init__(self):
        """Create an instance with no data loaded and no model fitted."""
        # Input frames; assigned by load_data / load_scenario_data.
        self.heatwave_df = self.drought_df = self.impact_df = None
        # Prefix that is string-concatenated with the file names given to the loaders.
        self.datadir = 'Data/'
        # Fitted estimator, plus slots for tuned classifiers (within this class
        # only lp_classifier_best is ever assigned, by xgb_model).
        self.model = None
        self.br_classifier_best = self.cc_classifier_best = self.lp_classifier_best = None
    
    def load_data(self, heatwave=None, drought=None, impact=None):
        if heatwave is not None:
            self.heatwave_df = pd.read_csv(self.datadir + heatwave, sep=' ', quotechar='"', engine='python')
            
        if drought is not None:
            self.drought_df = pd.read_csv(self.datadir + drought, sep=' ', quotechar='"', engine='python')
        
        if impact is not None:
            self.impact_df = pd.read_csv(self.datadir + impact)
        
    def train_preprocessor(self, categorize, impact_num, impactType=None):
        #========================================================================================================================================
        #Preprocess heatwave data
        #function to expand the heatwave data to cover all months of the year. Missing months are filled with 0 indicating 
        #absence of extreme event
        ### Create a MultiIndex from 'year' and 'month' columns
        self.heatwave_df.set_index(['year', 'month'], inplace=True)
        # Create a DataFrame with a MultiIndex covering all years and months
        year_range = range(self.heatwave_df.index.get_level_values('year').min(), self.heatwave_df.index.get_level_values('year').max() + 1)
        month_range = range(1, 13)
        index = pd.MultiIndex.from_product([year_range, month_range], names=['year', 'month'])

        expanded_df = pd.DataFrame(index=index)

        # Merge the expanded DataFrame with the original data
        self.heatwave_df = expanded_df.merge(self.heatwave_df, left_index=True, right_index=True, how='left')

        # Fill missing values with 0
        self.heatwave_df.fillna(0, inplace=True)
        
        #==================================================================================================================
        #Preprocess drought/soil moisture data
        # Extract year and month from the 'date' column
        self.drought_df['year'] = pd.to_datetime(self.drought_df['date']).dt.year
        self.drought_df['month'] = pd.to_datetime(self.drought_df['date']).dt.month
        
        self.drought_df.set_index(['year', 'month'], inplace=True)
        
        
        #======================================================================================================================
        #Preprocess impact data
        
        self.impact_df = self.impact_df.fillna(0).astype(int)

        ### Create a MultiIndex from 'year' and 'month' columns
        self.impact_df.set_index(['year', 'month'], inplace=True)
        
        if impactType == 'heatwave':
            impact_df = self.impact_df
        
            all_data = self.heatwave_df
            #all_data = pd.merge(self.heatwave_df, self.drought_df, on=['year', 'month'])
            all_data = pd.merge(all_data, impact_df, on=['year', 'month'])
            ### dropping unnecessary variables
            all_data = all_data.reset_index()
            all_data = all_data.drop(columns=['year'], axis=1)
        
            # Separate features (X) and target labels (y)
            X = all_data.drop(columns=['d1', 'd2', 'd3', 'd4', 'd5', 'd6', 'd7', 'd8', 'd9', 
                                       'd10', 'd11', 'd12', 'd13', 'd14', 'd15', 'h1', 'h4', 'h8', 'h9', 'h12', 
                                       'h13', 'h14'], axis=1)
            y = all_data[['h14']]
            
            smote = SMOTE(sampling_strategy='auto', random_state=42)

            X_smote_list = []
            y_smote_combined_list = []

            label_indices = {label_col: idx for idx, label_col in enumerate(y.columns)}

            for label_col in y.columns:
                label_y = y[label_col]
                X_label = X.loc[label_y.index]  # Select data with the current label

                if len(X_label) > 0:  # Apply SMOTE only if there are positive samples
                    label_idx = label_indices[label_col]
                    X_smote, y_smote = smote.fit_resample(X_label, label_y)

                    X_smote_list.append(X_smote)
                    y_smote_combined_list.append(y_smote)

            # Concatenate the resampled data for each label
           

            X_months = pd.get_dummies(X_smote_list[0]['month'], prefix='month', drop_first=False)
            X_merged = pd.concat([X_months, X_smote_list[0]], axis=1)
            X_merged = X_merged.drop(columns=['month'], axis=1)

            X_train, X_test, y_train, y_test = train_test_split(X_merged, y_smote_combined_list[0], test_size=0.2, random_state=1000)
            y_smote_combined_list = y_smote_combined_list[0]
            
        elif impactType == 'drought':
        
            if categorize == 'yes':

                self.impact_df['cd1'] = self.impact_df[['d1', 'd2', 'd3', 'd4', 'd5', 'd6']].any(axis=1).astype(int) #agriculture and forestry
                self.impact_df['cd2'] = self.impact_df[['d7', 'd8', 'd13', 'd14', 'd15']].any(axis=1).astype(int)
                self.impact_df['cd3'] = self.impact_df[['d9', 'd10', 'd11', 'd12']].any(axis=1).astype(int)
                #self.impact_df['cd4'] = self.impact_df[['d12', 'd13', 'd14', 'd15']].any(axis=1).astype(int)

                impact_df = self.impact_df

                all_data = self.drought_df
                #all_data = pd.merge(self.heatwave_df, self.drought_df, on=['year', 'month'])
                all_data = pd.merge(all_data, impact_df, on=['year', 'month'])
                ### dropping unnecessary variables
                all_data = all_data.reset_index()
                all_data = all_data.drop(columns=['id', 'date', 'SM_drought', 'year'], axis=1)

                # Separate features (X) and target labels (y)
                # Separate features (X) and target labels (y)
                X = all_data.drop(columns=['d1', 'd2', 'd3', 'd4', 'd5', 'd6', 'd7', 'd8', 'd9', 
                                       'd10', 'd11', 'd12', 'd13', 'd14', 'd15', 'h1', 'h4', 'h8', 'h9', 'h12', 
                                       'h13', 'h14','cd1', 'cd2', 'cd3'], axis=1)
                y = all_data[['cd1', 'cd2', 'cd3']]



            else:
                #all_data = pd.merge(self.heatwave_df, self.drought_df, on=['year', 'month'])

                impact_df = self.impact_df

                all_data = self.drought_df
                #all_data = pd.merge(self.heatwave_df, self.drought_df, on=['year', 'month'])
                all_data = pd.merge(all_data, impact_df, on=['year', 'month'])
                ### dropping unnecessary variables
                all_data = all_data.reset_index()
                all_data = all_data.drop(columns=['id', 'date', 'SM_drought', 'year'], axis=1)

                # Separate features (X) and target labels (y)
                X = all_data.drop(columns=['d1', 'd2', 'd3', 'd4', 'd5', 'd6', 'd7', 'd8', 'd9', 'd10', 
                                           'd11', 'd12', 'd13', 'd14', 'd15', 'h1', 'h4', 'h8', 'h9', 'h12', 
                                           'h13', 'h14'], axis=1)
                y = all_data[['d1', 'd2', 'd3', 'd4', 'd5', 'd6', 'd7', 'd8', 'd9', 'd10', 'd11', 
                              'd12', 'd13', 'd14', 'd15']]

    

        #=========================================================================================================
        # Select specific impact data and combine predictors and response variables
        
         # Split data into train and test/validation sets
       
        
        
            smote = SMOTE(sampling_strategy='auto', random_state=42)

            X_smote_list = []
            y_smote_combined_list = []

            label_indices = {label_col: idx for idx, label_col in enumerate(y.columns)}

            for label_col in y.columns:
                label_y = y[label_col]
                X_label = X.loc[label_y.index]  # Select data with the current label

                if len(X_label) > 0:  # Apply SMOTE only if there are positive samples
                    label_idx = label_indices[label_col]
                    X_smote, y_smote = smote.fit_resample(X_label, label_y)

                    X_smote_list.append(X_smote)
                    y_smote_combined_list.append(y_smote)

            # Concatenate the resampled data for each label
            impact_num = impact_num - 1

            X_months = pd.get_dummies(X_smote_list[impact_num]['month'], prefix='month', drop_first=False)
            X_merged = pd.concat([X_months, X_smote_list[impact_num]], axis=1)
            X_merged = X_merged.drop(columns=['month'], axis=1)

            X_train, X_test, y_train, y_test = train_test_split(X_merged, y_smote_combined_list[impact_num], test_size=0.2, 
                                                                random_state=1000)
            y_smote_combined_list = y_smote_combined_list[impact_num]
        
        return X_merged, y_smote_combined_list, X_train, X_test, y_train.values, y_test

        # X_months = pd.get_dummies(X['month'], prefix='month', drop_first=False)
        # X_merged = pd.concat([X_months, X], axis=1)
        # X_merged = X_merged.drop(columns=['month'], axis=1)
        # X_train, X_test, y_train, y_test = train_test_split(X_merged, y, test_size=0.2, random_state=1000)
        # return X_merged, y, X_train, X_test, y_train, y_test
    
    def xgb_model(self, predictors, response, X_train, X_test, y_train, y_test, nter, cv=None):
        """Tune an XGBClassifier by randomized search and refit it on the training split.

        Parameters
        ----------
        predictors, response, X_test, y_test
            Accepted for signature parity with ``rf_model``; not used here.
        X_train, y_train
            Training data used both for the search and for the final fit.
        nter : int
            Number of parameter settings sampled by ``RandomizedSearchCV``.
        cv : optional
            Cross-validation setting forwarded to ``RandomizedSearchCV``.
            ``None`` (default) keeps the historical behaviour of using ``nter``
            as the fold count.
            NOTE(review): coupling the fold count to the number of sampled
            settings looks unintended (``rf_model`` uses 5 folds) - confirm and
            pass ``cv`` explicitly.

        Returns
        -------
        The fitted ``XGBClassifier``; it is also stored in ``self.model`` and
        ``self.lp_classifier_best``.
        """
        #=============================================================================================================
        #HYPERPARAMETER TUNING
        #=============================================================================================================
        # Candidate values sampled by the randomized search.
        param_dist = {
            'n_estimators': [int(x) for x in np.linspace(start=100, stop=2000, num=10)],
            'tree_method': ['exact'],
            'subsample': [0.7, 0.8, 0.9, 1.0],  # Values between 0 and 1
            'booster': ['gbtree'],
            'learning_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
            'gamma': [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
            'sampling_method': ['uniform'],
            'colsample_bytree': [0.7, 0.8, 0.9, 1.0],  # Values between 0 and 1
            'max_leaves': [None],  # Adjust based on problem
            'min_child_weight': [1, 2, 3, 4, 5],
            'max_depth': [2, 4, 6, 8, 10, 12, 14],
            'alpha': [0.1, 0.2, 0.3, 0.4, 0.5],  # Positive values for L1 regularization
            'lambda': [0.1, 0.2, 0.3, 0.4, 0.5],  # Positive values for L2 regularization
            'grow_policy': ['depthwise', 'lossguide'],
        }

        n_folds = nter if cv is None else cv
        search = RandomizedSearchCV(XGBClassifier(), param_distributions=param_dist, n_iter=nter,
                                    cv=n_folds, random_state=42, n_jobs=6)
        search.fit(X_train, y_train)

        # Refit a fresh classifier with the winning parameters on the training split.
        self.lp_classifier_best = XGBClassifier(**search.best_params_)
        self.model = self.lp_classifier_best.fit(X_train, y_train)
        return self.model
    
    
    def rf_model(self, predictors, response, X_train, X_test, y_train, y_test, nter):
        """Tune a Label Powerset random forest by randomized search and fit it.

        Parameters
        ----------
        predictors, response
            Data the randomized search is fitted on.
            NOTE(review): ``xgb_model`` searches on ``X_train``/``y_train``
            instead; if ``predictors`` also contains the test rows, the tuning
            here sees them - confirm this is intended.
        X_train, y_train
            Training split the final model is fitted on.
        X_test, y_test
            Not used.
        nter : int
            Number of parameter settings sampled by ``RandomizedSearchCV``.

        Returns
        -------
        The fitted ``LabelPowerset`` model (also stored in ``self.model``),
        mirroring ``xgb_model``.
        """
        #=============================================================================================================
        #HYPERPARAMETER TUNING
        #=============================================================================================================
        # Keys carry the 'classifier__' prefix so they reach the wrapped forest.
        param_dist = {
        'classifier__n_estimators': np.arange(50, 1000, 10),
        'classifier__max_depth': [None] + list(np.arange(5, 30, 5)),
        'classifier__min_samples_leaf': np.arange(1, 11),
        'classifier__min_samples_split': np.arange(2, 11),
        'classifier__criterion': ['gini', 'entropy'],
        'classifier__bootstrap': [True],
        'classifier__random_state': [100],
        'classifier__n_jobs': [-1],
        'classifier__max_samples': [None] + list(np.arange(0.1, 1.1, 0.1)),
        'classifier__max_leaf_nodes': [None] + list(range(10, 101, 10)),
        'classifier__class_weight': [None, 'balanced', 'balanced_subsample'],
        }

        rf_classifier = RandomForestClassifier()

        # Label Powerset wrapper with the random forest as base classifier.
        search = RandomizedSearchCV(LabelPowerset(classifier=rf_classifier),
                                    param_distributions=param_dist, n_iter=nter,
                                    cv=5, verbose=2, random_state=42, n_jobs=6)
        search.fit(predictors, response)

        #=====================================================================================================
        #TRAIN RANDOM FOREST MODEL BASED ON THE OPTIMAL PARAMETER SET
        #=====================================================================================================
        # Apply the best parameter combination and fit on the training split.
        final_model = LabelPowerset(classifier=rf_classifier)
        final_model.set_params(**search.best_params_)
        self.model = final_model.fit(X_train, y_train)
        return self.model
    
    def evaluate_model_cr(self, X_test, y_test, model):
        """Load a saved model and print its classification report on test data.

        ``model`` is handed to ``joblib.load``; the loaded estimator replaces
        ``self.model``. Nothing is returned.
        """
        #=====================================================================================================
        #EVALUATE PERFORMANCE OF TRAINED MODEL ON TEST DATA
        #=====================================================================================================
        # NOTE: joblib.load unpickles the file, so only load trusted model files.
        self.model = joblib.load(model)
        report = classification_report(y_test, self.model.predict(X_test))
        print(report)
        
    def plot_label_cooccurrence(self, X_test, y_test, model):
        """Load a saved model and show the confusion matrix of its test predictions.

        ``model`` is handed to ``joblib.load``; the loaded estimator replaces
        ``self.model``. The figure is displayed with ``plt.show()``.
        """
        self.model = joblib.load(model)
        predictions = self.model.predict(X_test)
        # Despite the method name, the plot drawn here is a confusion matrix.
        skplt.metrics.plot_confusion_matrix(y_test, predictions)
        plt.show()
        
    def plot_precision_recall(self, X_test, y_test, model):
        """Load a saved model and show per-class precision-recall curves.

        ``model`` is handed to ``joblib.load``; the loaded estimator replaces
        ``self.model``. Curves are drawn from ``predict_proba(X_test)`` and the
        figure is displayed with ``plt.show()``.
        """
        self.model = joblib.load(model)
        y_pred_prob = self.model.predict_proba(X_test)
        # ``curves`` is a listing of curve names; the trailing comma makes this a
        # real one-element tuple instead of a bare string in parentheses.
        skplt.metrics.plot_precision_recall_curve(y_test, y_pred_prob, curves=('each_class',))
        plt.show()
        
        
    def plot_roc(self, X_test, y_test, model):
        """Load a saved model and show per-class ROC curves.

        ``model`` is handed to ``joblib.load``; the loaded estimator replaces
        ``self.model``. Curves are drawn from ``predict_proba(X_test)`` and the
        figure is displayed with ``plt.show()``.
        """
        self.model = joblib.load(model)
        y_pred_prob = self.model.predict_proba(X_test)
        # ``curves`` is a listing of curve names; the trailing comma makes this a
        # real one-element tuple instead of a bare string in parentheses.
        skplt.metrics.plot_roc_curve(y_test, y_pred_prob, curves=('each_class',))
        plt.show()
    
    #=========================================================================================================================
    #FUTURE PREDICTION
    #============================================================================================================================
        
    def load_scenario_data(self, heatwave=None, drought=None):
        date_format = '%d.%m.%Y'  # Adjust this format according to your data
        if heatwave is not None:
            self.heatwave_df = pd.read_csv(self.datadir + heatwave, sep=' ', quotechar='"', engine='python', 
                                           parse_dates=['date'], date_parser=lambda x: pd.to_datetime(x, format=date_format))
            
        self.drought_df = pd.read_csv(self.datadir + drought, sep=' ', quotechar='"', engine='python', 
                                          parse_dates=['date'], date_parser=lambda x: pd.to_datetime(x, format=date_format))
        
        
    def predict_preprocessor(self, start_year, end_year, risk_type):
       
        #Preprocess drought/soil moisture data
        # Extract year and month from the 'date' column
        if risk_type == 'drought':
            self.drought_df['year'] = pd.to_datetime(self.drought_df['date']).dt.year
            self.drought_df['month'] = pd.to_datetime(self.drought_df['date']).dt.month

            self.drought_df = self.drought_df[(self.drought_df['year'] >= start_year) & (self.drought_df['year'] <= end_year)]

            self.drought_df.set_index(['year', 'month'], inplace=True)

            all_data = self.drought_df
            all_data = all_data.reset_index()
            X = all_data.drop(columns=['id', 'date', 'SM_drought', 'year'], axis=1)

            X_months = pd.get_dummies(X['month'], prefix='month', drop_first=False)
            X_merged = pd.concat([X_months, X], axis=1)
            X_merged = X_merged.drop(columns=['month'], axis=1)
            
        elif risk_type=='heatwave':
            
            self.heatwave_df = self.heatwave_df.reset_index()
            self.heatwave_df['year'] = self.heatwave_df['year'].astype(int)

            self.heatwave_df = self.heatwave_df[(self.heatwave_df['year'] >= start_year) & (self.heatwave_df['year'] <= end_year)]

            self.heatwave_df.set_index(['year', 'month'], inplace=True)
            
            
            year_range = range(self.heatwave_df.index.get_level_values('year').min(), 
                               self.heatwave_df.index.get_level_values('year').max() + 1)
            month_range = range(1, 13)
            index = pd.MultiIndex.from_product([year_range, month_range], names=['year', 'month'])

            expanded_df = pd.DataFrame(index=index)

            # Merge the expanded DataFrame with the original data
            self.heatwave_df = expanded_df.merge(self.heatwave_df, left_index=True, right_index=True, how='left')
            self.heatwave_df = self.heatwave_df.drop(columns=['index'], axis=1)

            # Fill missing values with 0
            self.heatwave_df.fillna(0, inplace=True)

            all_data = self.heatwave_df
            all_data = all_data.reset_index()
            X = all_data.drop(columns=['year'], axis=1)

            X_months = pd.get_dummies(X['month'], prefix='month', drop_first=False)
            X_merged = pd.concat([X_months, X], axis=1)
            X_merged = X_merged.drop(columns=['month'], axis=1)
            
        return X_merged
        
    
    
    def model_predict(self, model, X):
        predicted = self.model.predict(X)
        predicted = pd.DataFrame(predicted.toarray(), columns=y_test.columns)
        return predicted
