from __future__ import absolute_import
from __future__ import print_function

import os
import sys
import time
import numpy as np
import keras.backend as K
import argparse
from keras.utils import np_utils
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model
from keras.optimizers import SGD
from keras.callbacks import ModelCheckpoint
import tensorflow as tf
from sklearn import metrics

from util import get_lr_scheduler
from datasets import get_data
from models import get_model
from loss import symmetric_cross_entropy, cross_entropy, lsr, joint_optimization_loss, generalized_cross_entropy, boot_soft, boot_hard, forward, backward
from callback_util import LoggerCallback, SGDLearningRateTracker
from scipy.stats import entropy

# Per-dataset 'init_epoch' / 'epoch_win' settings.
# NOTE(review): not referenced in the visible part of this file -- confirm
# whether another module still imports it.
D2L = {
    'mnist': {'init_epoch': 5, 'epoch_win': 5},
    'svhn': {'init_epoch': 20, 'epoch_win': 5},
    'cifar-10': {'init_epoch': 40, 'epoch_win': 5},
    'cifar-100': {'init_epoch': 60, 'epoch_win': 5},
}

# Create the working directories this script writes into, if they are missing.
folders = ['data', 'model', 'log', 'results']
for folder in folders:
    path = os.path.join('./', folder)
    if os.path.exists(path):
        continue
    os.makedirs(path)

# Defaults used by train() for each accepted dataset: number of epochs,
# batch size, and the (alpha, beta) pair handed to the 'sl' loss.
MODEL_PARAMS = {
    'cifar-10': {
        'epoch': 120,
        'batch_size': 128,
        'sl': {'alpha': 0.1, 'beta': 1.0},
    },
    'cifar-100': {
        'epoch': 200,
        'batch_size': 128,
        'sl': {'alpha': 6.0, 'beta': 1.0},
    },
}

# Maps every supported ``model_name`` to a builder ``(alpha, beta, P) -> loss``.
# Keeping the names in one table lets train() reject an unknown name before
# any data is loaded; the lambdas defer creating the loss until it is needed.
_LOSS_BUILDERS = {
    'ce': lambda alpha, beta, P: cross_entropy,
    'sl': lambda alpha, beta, P: symmetric_cross_entropy(alpha, beta),
    'lsr': lambda alpha, beta, P: lsr,
    'joint': lambda alpha, beta, P: joint_optimization_loss,
    'gce': lambda alpha, beta, P: generalized_cross_entropy,
    'boot_hard': lambda alpha, beta, P: boot_hard,
    'boot_soft': lambda alpha, beta, P: boot_soft,
    'forward': lambda alpha, beta, P: forward(P),
    'backward': lambda alpha, beta, P: backward(P),
}


def train(dataset='mnist', model_name='sl', batch_size=None, epochs=None, noise_ratio=0, noise_pattern='asym', alpha=None, beta=None, random_shuffle=False):
    """
    Train one model and write its test-set results to disk.

    The model is fitted on a randomly chosen half of the training images
    using real-time augmentation (none for 'mnist'/'svhn'; random shifts and
    horizontal flips for 'cifar-10'; additionally random rotations for any
    other dataset). Weights are checkpointed to ``model/`` after every epoch.
    After training, a classification report is printed, the confusion matrix
    is stored with ``np.save`` and the test accuracy is appended to a text
    file, both under ``results/``.

    :param dataset: dataset name; must be a key of ``MODEL_PARAMS``. The
        default 'mnist' is only accepted if ``MODEL_PARAMS`` defines it.
    :param model_name: name of the loss to train with; one of the keys of
        ``_LOSS_BUILDERS``.
    :param batch_size: mini-batch size; ``None`` uses the dataset default
        from ``MODEL_PARAMS``.
    :param epochs: number of epochs; ``None`` uses the dataset default from
        ``MODEL_PARAMS``.
    :param noise_ratio: percentage of noisy labels, forwarded to ``get_data``.
    :param noise_pattern: label-noise pattern, forwarded to ``get_data``.
    :param alpha: first argument of ``symmetric_cross_entropy``; when ``None``
        and ``model_name == 'sl'`` the dataset default is used.
    :param beta: second argument of ``symmetric_cross_entropy``; when ``None``
        and ``model_name == 'sl'`` the dataset default is used.
    :param random_shuffle: forwarded unchanged to ``get_data``.
    :return: None
    :raises ValueError: if ``dataset`` or ``model_name`` is not supported.
    """
    # Validate up front (with real exceptions, not asserts) so a typo fails
    # before the dataset is loaded and the network is built.
    if dataset not in MODEL_PARAMS:
        raise ValueError("dataset parameter must be one of %s, got %r"
                         % (sorted(MODEL_PARAMS), dataset))
    if model_name not in _LOSS_BUILDERS:
        raise ValueError("Model %s is unimplemented!" % model_name)

    model_params = MODEL_PARAMS[dataset]

    if epochs is None:
        epochs = model_params['epoch']

    if batch_size is None:
        batch_size = model_params['batch_size']

    if model_name == 'sl':
        if alpha is None:
            alpha = model_params['sl']['alpha']
        if beta is None:
            beta = model_params['sl']['beta']

    print('Dataset: %s, model: %s, batch: %s, epochs: %s, noise ratio: %s%%, noise_pattern: %s, alpha: %s, beta: %s' %
          (dataset, model_name, batch_size, epochs, noise_ratio, noise_pattern, alpha, beta))

    # load data
    X_train, y_train, y_train_clean, X_test, y_test, file_suffix = get_data(dataset, noise_ratio, noise_pattern, random_shuffle)
    n_images = X_train.shape[0]
    image_shape = X_train.shape[1:]
    num_classes = y_train.shape[1]

    # Only a random half of the training images is used for fitting.
    # NOTE(review): the other half is never used -- confirm this is intended.
    shuffled_idx = np.arange(n_images)
    np.random.shuffle(shuffled_idx)
    train_idx = shuffled_idx[:n_images // 2]
    print("n_images", n_images, "num_classes", num_classes, "image_shape:", image_shape)

    # Matrix handed to the forward/backward losses; the identity is used here.
    P = np.eye(num_classes)

    # load model
    model = get_model(dataset, input_tensor=None, input_shape=image_shape, num_classes=num_classes)

    if dataset == 'cifar-100':
        optimizer = SGD(lr=0.1, decay=5e-3, momentum=0.9)
    else:
        optimizer = SGD(lr=0.1, decay=1e-4, momentum=0.9)

    # create loss
    loss = _LOSS_BUILDERS[model_name](alpha, beta, P)

    model.compile(
        loss=loss,
        optimizer=optimizer,
        metrics=['accuracy']
    )

    # '{epoch:02d}' is filled in by ModelCheckpoint for each saved file.
    model_save_file = "model/%s_%s_%s_%s.{epoch:02d}.hdf5" % (noise_pattern, model_name, dataset, noise_ratio)

    callbacks = [
        # Save the weights after every epoch, regardless of val_loss.
        ModelCheckpoint(model_save_file,
                        monitor='val_loss',
                        verbose=0,
                        save_best_only=False,
                        save_weights_only=True,
                        period=1),
        # learning rate scheduler for SGD
        get_lr_scheduler(dataset),
        SGDLearningRateTracker(model),
        # Project logging callback; it receives the full training set with
        # both the noisy and the clean labels, plus the test set.
        LoggerCallback(model, X_train, y_train, y_train_clean, X_test, y_test, dataset, model_name, noise_ratio, noise_pattern, epochs, alpha, beta),
    ]

    # data augmentation
    if dataset in ['mnist', 'svhn']:
        datagen = ImageDataGenerator()
    elif dataset in ['cifar-10']:
        datagen = ImageDataGenerator(
            width_shift_range=0.2,
            height_shift_range=0.2,
            horizontal_flip=True)
    else:
        datagen = ImageDataGenerator(
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            horizontal_flip=True)
    datagen.fit(X_train)

    # One epoch is one full pass over the selected half. Rounding up to an
    # integer keeps the final partial batch and avoids passing a float.
    steps_per_epoch = int(np.ceil(len(train_idx) / float(batch_size)))

    # train model
    model.fit_generator(datagen.flow(X_train[train_idx], y_train[train_idx], batch_size=batch_size),
                        steps_per_epoch=steps_per_epoch, epochs=epochs,
                        validation_data=(X_test, y_test),
                        verbose=1,
                        callbacks=callbacks
                        )

    # Evaluate on the test set using class indices instead of one-hot rows.
    y_pred = model.predict(X_test)
    true_label = np.argmax(y_test, axis=1)
    estimated_label = np.argmax(y_pred, axis=1)
    print(metrics.classification_report(true_label, estimated_label, digits=3))

    cm = metrics.confusion_matrix(true_label, estimated_label)
    file_name = "results/cm_result_%s_%s_%s_noiseratio%s_epochs%d_batch_size%d%s.txt" % (dataset, model_name, noise_pattern, noise_ratio, epochs, batch_size, file_suffix)
    # np.save appends '.npy' because the name does not already end with it,
    # so the matrix is stored in binary form as '<file_name>.npy'.
    np.save(file_name, cm)

    # Normalise each row by its total so the diagonal holds the fraction of
    # correctly classified samples per true class.
    cm = cm.astype('float')/cm.sum(axis=1)[:, np.newaxis]

    print(cm.diagonal())
    test_accuracy = np.sum(true_label == estimated_label)/true_label.shape[0]
    print(f'Test accuracy: {test_accuracy}')

    if model_name == 'sl':
        file_suffix = ('_alpha%.2f_beta%.2f' % (alpha, beta)) + file_suffix

    # Append, so repeated runs with the same settings accumulate one accuracy
    # per line in the same file.
    file_name = "results/test_accuracy_%s_%s_%s_noiseratio%s_epochs%d_batch_size%d%s.txt" % (dataset, model_name, noise_pattern, noise_ratio, epochs, batch_size, file_suffix)
    with open(file_name, "a+") as f_results:
        f_results.write("%.3f\n" % test_accuracy)


def main(args):
    """
    Call ``train`` once for each requested repetition.

    ``random_shuffle`` is passed to ``train`` as True only when more than one
    repetition is requested; nothing is run when ``args.repetitions`` is
    zero or negative.

    :param args: parsed command-line namespace providing ``dataset``,
        ``model_name``, ``batch_size``, ``epochs``, ``noise_ratio``,
        ``noise_pattern``, ``alpha``, ``beta`` and ``repetitions``.
    :return: None
    """
    shuffle_between_runs = args.repetitions > 1

    for _ in range(args.repetitions):
        train(
            dataset=args.dataset,
            model_name=args.model_name,
            batch_size=args.batch_size,
            epochs=args.epochs,
            noise_ratio=args.noise_ratio,
            noise_pattern=args.noise_pattern,
            alpha=args.alpha,
            beta=args.beta,
            random_shuffle=shuffle_between_runs,
        )

def build_arg_parser():
    """
    Build the command-line parser for this training script.

    Only ``--dataset`` and ``--model_name`` are mandatory. Options left out
    on the command line get the defaults declared here; a ``None`` default
    is passed through to ``train`` unchanged.

    :return: a configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-d', '--dataset',
        help="Dataset to use; must be a key of MODEL_PARAMS, "
             "e.g. 'cifar-10' or 'cifar-100'.",
        required=True, type=str
    )
    parser.add_argument(
        '-m', '--model_name',
        help="Model name: 'ce', 'sl', 'lsr', 'joint', 'gce', 'boot_hard', "
             "'boot_soft', 'forward' or 'backward'.",
        required=True, type=str
    )
    parser.add_argument(
        '-e', '--epochs',
        help="The number of epochs to train for.",
        required=False, type=int, default=None
    )
    parser.add_argument(
        '-b', '--batch_size',
        help="The batch size to use for training.",
        required=False, type=int, default=None
    )
    parser.add_argument(
        '-r', '--noise_ratio',
        help="The percentage of noisy labels [0, 100].",
        required=False, type=int, default=0
    )
    # Optional: a required option can never fall back to its default, which
    # would make the 'sym' default below unreachable.
    parser.add_argument(
        '-n', '--noise_pattern',
        help="noise pattern.",
        required=False, type=str, default='sym'
    )
    parser.add_argument(
        '-alpha', '--alpha',
        help="alpha parameter.",
        required=False, type=float, default=None
    )
    parser.add_argument(
        '-beta', '--beta',
        help="beta parameter.",
        required=False, type=float, default=None
    )
    parser.add_argument(
        '-R', '--repetitions',
        help="repetitions.",
        required=False, type=int, default=1
    )
    return parser


if __name__ == "__main__":
    # Example argument sets:
    #   -d cifar-10  -m boot_hard -e 120 -b 128 -r 40 -alpha 6.0 -beta 0.1
    #   -d cifar-100 -m sl        -e 150 -b 128 -r 40 -alpha 6.0 -beta 0.1

    # Expose only GPU 0 to CUDA for this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    args = build_arg_parser().parse_args()

    main(args)

    # Release the Keras backend session once all repetitions have finished.
    K.clear_session()
