# coding: utf-8
# In[1]:
import numpy as np
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
from keras.models import clone_model
from datasets import get_data, get_training_data
from models import get_model, resnet_v1, resnet_v2
from util import select_clean_uncertain, combine_result, inject_noise, active_selection
from keras.datasets import mnist, cifar10, cifar100, fashion_mnist
import time
import argparse
from tensorflow.python.lib.io import file_io
from keras.utils import np_utils, multi_gpu_model
from keras import backend as K
from io import BytesIO
from loss_acc_plot import loss_acc_plot
import os
import pickle
import pdb
import copy
# Restrict the CUDA devices visible to this process to device "0".
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
# Selection-method key for active learning; in this file it is only referenced
# by the commented-out `active_selection(...)` call inside the training loop.
# NOTE(review): "ent" presumably means entropy-based selection - confirm in util.
al_method = "ent"
def add_new_oracle_batch(original_oracle_x, new_oracle_x, original_oracle_y, new_oracle_y, max_batches=5):
    """Push a new oracle batch onto the rolling x/y buffers.

    The two buffers are plain lists that are kept in lock-step: entry ``k`` of
    ``original_oracle_x`` belongs to entry ``k`` of ``original_oracle_y``.
    At most ``max_batches`` batches are retained; when the buffer is full the
    oldest entries are evicted before the new batch is appended.

    Args:
        original_oracle_x: list of previously stored input batches (mutated in place).
        new_oracle_x: the input batch to add.
        original_oracle_y: list of previously stored label batches (mutated in place).
        new_oracle_y: the label batch to add.
        max_batches: size of the rolling window; defaults to 5, the value that
            used to be hard-coded.

    Returns:
        The same two list objects ``(original_oracle_x, original_oracle_y)``
        after the update.

    Raises:
        ValueError: if ``max_batches`` is smaller than 1.
    """
    if max_batches < 1:
        raise ValueError("max_batches must be at least 1")
    # Number of old entries that must go so that, after appending, the buffer
    # holds at most max_batches batches.  Using an overflow count (instead of
    # an equality test) also shrinks a buffer that is already too long.
    overflow = len(original_oracle_x) - max_batches + 1
    if overflow > 0:
        del original_oracle_x[:overflow]
        del original_oracle_y[:overflow]
    original_oracle_x.append(new_oracle_x)
    original_oracle_y.append(new_oracle_y)
    return original_oracle_x, original_oracle_y
def oracle_combine(oracle_buffer):
    """Combine all the batches stored in ``oracle_buffer`` into one array.

    Args:
        oracle_buffer: non-empty sequence of array-like batches that agree on
            every dimension except the first.

    Returns:
        ``oracle_buffer[0]`` itself (no copy) when the buffer holds a single
        batch; otherwise a new array with all batches stacked along axis 0 in
        buffer order.

    Raises:
        IndexError: if ``oracle_buffer`` is empty.
    """
    if len(oracle_buffer) == 0:
        raise IndexError("oracle_buffer is empty; nothing to combine")
    if len(oracle_buffer) == 1:
        # Single batch: hand back the stored object unchanged.
        return oracle_buffer[0]
    # One concatenate call copies every batch exactly once, whereas chaining
    # np.append re-copies the accumulated result for each additional batch.
    return np.concatenate(oracle_buffer, axis=0)
# Number of target classes for each dataset key known to this script.
NUM_CLASSES = {
    'mnist': 10,
    'svhn': 10,
    'cifar-10': 10,
    'cifar-100': 100,
    'celeb': 20,
}

# Configuration handed to get_data for the initial training set.
dataset = "cifar-10"
init_noise_ratio = 0
data_ratio = 20.0

# Besides the train/test split, get_data returns `un_selected_index`, which
# the incremental training loop further down uses to index the training set.
(X_train, y_train,
 X_test, y_test,
 un_selected_index) = get_data(dataset, init_noise_ratio, data_ratio, random_shuffle=False)
image_shape = X_train.shape[1:]

# Alternative model set-ups that are currently disabled:
#model_quality = resnet_v2(X_train.shape[1:], 56, NUM_CLASSES[dataset])
#model_quality = multi_gpu_model(resnet_v2(X_train.shape[1:], 56, NUM_CLASSES[dataset]), gpus=4)
#model_classifier = multi_gpu_model(get_model(dataset, input_tensor=None, input_shape=image_shape, num_classes=NUM_CLASSES[dataset]), gpus=4)
model_quality_NoSel = get_model(
    dataset,
    input_tensor=None,
    input_shape=image_shape,
    num_classes=NUM_CLASSES[dataset],
)

# One SGD instance; it is reused by every compile() call in this script.
optimizer = SGD(lr=0.01, momentum=0.9, decay=1e-4)
model_quality_NoSel.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
# model_classifier.compile(loss='categorical_crossentropy',
#               optimizer=optimizer,
#               metrics=['accuracy'])
# Augmentation used for every fit_generator call in this script: small random
# horizontal/vertical shifts plus random horizontal flips.
datagen = ImageDataGenerator(
    # Normalisation and whitening options - all switched off.
    featurewise_center=False,             # would zero the mean over the dataset
    samplewise_center=False,              # would zero the mean of each sample
    featurewise_std_normalization=False,  # would divide by the dataset std
    samplewise_std_normalization=False,   # would divide each sample by its std
    zca_whitening=False,                  # would apply ZCA whitening
    # Geometric augmentation.
    rotation_range=0,         # rotation range in degrees (0 to 180); 0 = none
    width_shift_range=0.1,    # horizontal shift, as a fraction of total width
    height_shift_range=0.1,   # vertical shift, as a fraction of total height
    horizontal_flip=True,     # randomly flip images horizontally
    vertical_flip=False,      # no vertical flips
)
# NOTE(review): per the Keras docs fit() is only required for the featurewise /
# ZCA options, all of which are disabled above - confirm the call is still wanted.
datagen.fit(X_train)
# Hyper-parameters of the initial training phase.
epochs_init = 60
batch_size = 64
# following code is used only under multi_gpu_model
# rand_repeat = 4  - X_train.shape[0] % 4
# if rand_repeat != 4:
#     chosen_repeat = np.random.choice(np.arange(0,X_train.shape[0]), rand_repeat, replace=False)
#     X_train = np.append(X_train[chosen_repeat], X_train,axis=0)
#     y_train = np.append(y_train[chosen_repeat], y_train,axis=0)

# Initial training of the baseline model on augmented batches.  steps_per_epoch
# uses floor division, so an epoch runs only as many steps as there are full
# batches in X_train.
h_quality_NoSel = model_quality_NoSel.fit_generator(
    datagen.flow(X_train, y_train, batch_size=batch_size),
    steps_per_epoch=X_train.shape[0] // batch_size,
    epochs=epochs_init,
    validation_data=(X_test, y_test),
)

# Start the "OptSel" experiment from the same trained state.
# copy.deepcopy() on a Keras model is not supported by every Keras version and
# needlessly duplicates backend state; clone_model() rebuilds the architecture
# with fresh weights, which are then overwritten with the trained ones.  The
# clone is left uncompiled here because the incremental training loop compiles
# both models before it fits them.
model_quality_OptSel = clone_model(model_quality_NoSel)
model_quality_OptSel.set_weights(model_quality_NoSel.get_weights())

# Duplicate the training history without deep-copying the model and the
# validation data that the History callback holds references to: take a
# shallow copy and give it its own copies of the mutable bookkeeping.
h_quality_OptSel = copy.copy(h_quality_NoSel)
h_quality_OptSel.history = copy.deepcopy(h_quality_NoSel.history)
h_quality_OptSel.epoch = list(h_quality_NoSel.epoch)
h_quality_OptSel.set_model(model_quality_OptSel)
#
# h_classifier  =  model_classifier.fit_generator(datagen.flow(X_train, y_train, batch_size=batch_size),
#                         steps_per_epoch=X_train.shape[0]//batch_size, epochs=epochs_init,
#                         validation_data=(X_test, y_test)
#                         )
# Rolling buffers for previously oracle-labelled batches; in this file they are
# only consumed by the commented-out code inside the training loop.
last_oracle_x = []
last_oracle_y = []

# Reload the full CIFAR-10 data set, replacing the arrays returned by get_data.
# NOTE(review): CIFAR-10 is hard-coded here regardless of `dataset` - confirm
# this is intended and that the preprocessing below matches get_data's.
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

# Scale pixels to [0, 1] in the shape (N, 32, 32, 3).
X_train = X_train.reshape(-1, 32, 32, 3) / 255.0
X_test = X_test.reshape(-1, 32, 32, 3) / 255.0

# Centre both splits with the per-pixel mean of the training split; the
# division by the standard deviation is deliberately left disabled.
means = X_train.mean(axis=0)
# std = np.std(X_train)
X_train = (X_train - means).astype('float32')  # / std
X_test = (X_test - means).astype('float32')  # / std

# CIFAR labels are 2D originally: flatten them, then one-hot encode.
y_train = np_utils.to_categorical(y_train.ravel(), NUM_CLASSES[dataset])
y_test = np_utils.to_categorical(y_test.ravel(), NUM_CLASSES[dataset])
# Hyper-parameters of the incremental (chunk-by-chunk) training phase.
epochs_training = 20
# Passed straight to inject_noise.
# NOTE(review): presumably a percentage of labels to corrupt - confirm in util.
training_noise_level = 30
oracle_length = []
# training_steps is the number of entries of un_selected_index used per chunk
training_steps = 1000
# Ceiling division, so a final shorter chunk is processed as well.
steps = int(np.ceil(len(un_selected_index) / float(training_steps)))
for i in np.arange(steps):
    # NOTE(review): both models are re-compiled for every chunk and share the
    # single `optimizer` instance - confirm how this affects optimizer state
    # (momentum, decay counter) and that it is intended for the comparison.
    model_quality_NoSel.compile(loss='categorical_crossentropy',
                                optimizer=optimizer,
                                metrics=['accuracy'])
    model_quality_OptSel.compile(loss='categorical_crossentropy',
                                 optimizer=optimizer,
                                 metrics=['accuracy'])
    # model_classifier.compile(loss='categorical_crossentropy',
    #               optimizer=optimizer,
    #               metrics=['accuracy'])

    # Slicing clamps at the end of the sequence, so one expression covers the
    # first, the middle and the (possibly shorter) last chunk alike.
    chunk_start = i * training_steps
    sub_un_selected_list = un_selected_index[chunk_start:chunk_start + training_steps]

    X_clean_iteration = X_train[sub_un_selected_list]
    y_clean_iteration = y_train[sub_un_selected_list]
    # A separately indexed copy of the labels is handed to inject_noise so that
    # y_clean_iteration cannot be affected should inject_noise modify its input.
    y_noisy_iteration, noisy_idx = inject_noise(dataset, y_train[sub_un_selected_list], training_noise_level)

    # noisy_idx is applied as positions within the chunk: np.delete drops those
    # positions and leaves the data-set indices of the remaining samples.
    clean_idx = (np.delete(np.array(sub_un_selected_list, dtype='int64'), noisy_idx)).tolist()
    X_OptSel_iteration = X_train[clean_idx]
    # Index the labels once and reuse the result for both flow() and the step count.
    y_OptSel_iteration = y_train[clean_idx]

    # ---- active-learning selection path (currently disabled) ----
    #clean_list, uncertain_list = select_clean_uncertain(X_clean_iteration, y_noisy_iteration, model_quality, model_quality)
    #clean_list, uncertain_list, predict_prob = select_clean_uncertain(X_clean_iteration, y_noisy_iteration, model_quality, model_quality)
    # threshold = 0.1 means we could only ask the oracle for 10% of the whole batch
    #threshold = 0.1
    #if len(uncertain_list) > int(threshold*training_steps):
        #oracle_list = np.random.choice(uncertain_list, int(threshold*training_steps), replace=False)
    #else:
        #oracle_list = uncertain_list
    #oracle_list = []
    #num_al = 200
    #oracle_list, inf_ind = active_selection (predict_prob,uncertain_list, al_method, num_al)
    #oracle_length.append((len(oracle_list)))
    #y_noisy_iteration[oracle_list] = y_clean_iteration[oracle_list]
    #training_list = np.append(clean_list, oracle_list)
    #training_list = map(int, training_list)
    # in this case, this is the first batch with noisy label data.
    # last_oracle_x exists because each oracle-labelled batch should be used more than once
    #if last_oracle_x == []:
     #   X_training_iteration = X_clean_iteration[training_list]
      #  y_training_iteration = y_noisy_iteration[training_list]
    #else:
     #   X_training_iteration = np.append(X_clean_iteration[training_list], oracle_combine(last_oracle_x),axis=0)
      #  y_training_iteration = np.append(y_noisy_iteration[training_list], oracle_combine(last_oracle_y),axis=0)
    # rand_repeat = 4  - (X_training_iteration.shape[0] % 4)
    # if rand_repeat != 4:
    #     chosen_repeat = np.random.choice(np.arange(0,X_training_iteration.shape[0]), rand_repeat, replace=False)
    #     X_training_iteration = np.append(X_training_iteration[chosen_repeat], X_training_iteration,axis=0)
    #     y_training_iteration = np.append(y_training_iteration[chosen_repeat], y_training_iteration,axis=0)

    # Baseline: train on the whole chunk with the noisy labels.
    # steps_per_epoch is the integer ceiling of samples / batch_size; the former
    # "n // batch_size + 1" ran one surplus step whenever n was an exact
    # multiple of batch_size.
    h_training_epoch_quality_NoSel = model_quality_NoSel.fit_generator(
        datagen.flow(X_clean_iteration, y_noisy_iteration, batch_size=batch_size),
        steps_per_epoch=(y_noisy_iteration.shape[0] + batch_size - 1) // batch_size,
        epochs=epochs_training,
        validation_data=(X_test, y_test),
    )
    # Optimal selection: train only on the samples whose labels were not corrupted.
    h_training_epoch_quality_OptSel = model_quality_OptSel.fit_generator(
        datagen.flow(X_OptSel_iteration, y_OptSel_iteration, batch_size=batch_size),
        steps_per_epoch=(y_OptSel_iteration.shape[0] + batch_size - 1) // batch_size,
        epochs=epochs_training,
        validation_data=(X_test, y_test),
    )
    # h_training_epoch_classifier =  model_classifier.fit_generator(datagen.flow(X_training_iteration, y_training_iteration, batch_size=batch_size),
    #                 steps_per_epoch=y_training_iteration.shape[0]//batch_size+1, epochs=epochs_training,
    #                 validation_data=(X_test, y_test)
    #                 )
    #last_oracle_x, last_oracle_y = add_new_oracle_batch(last_oracle_x, X_clean_iteration[oracle_list], last_oracle_y, y_noisy_iteration[oracle_list])
    #pdb.set_trace()

    # Fold this chunk's history into the running history of each experiment.
    h_quality_NoSel = combine_result(h_quality_NoSel, h_training_epoch_quality_NoSel)
    h_quality_OptSel = combine_result(h_quality_OptSel, h_training_epoch_quality_OptSel)
    #pdb.set_trace()
    # h_classifier  = combine_result(h_classifier, h_training_epoch_classifier)
    #print(oracle_length)
print(steps)
#print(h_quality.history['val_acc'])
#print(h_quality.history['val_loss'])

# Both result files share the tail "<epochs_init>+<epochs_training>_<steps>chunks.pickle"
# and differ only in the experiment prefix.
result_suffix = str(epochs_init)+'+'+str(epochs_training)+'_'+str(steps)+'chunks.pickle'
for result_prefix, result_history in (('resultsNoSel', h_quality_NoSel),
                                      ('resultsOptSel', h_quality_OptSel)):
    # Pickle the combined history object of one experiment.
    with open(result_prefix + result_suffix, 'wb') as file_pi:
        pickle.dump(result_history, file_pi)
# print(h_classifier.history['val_acc'])
# print(h_classifier.history['val_loss'])
# dest = 'gs://ml_engine_data_bucket/result/CELEB_FORMAT128_20CLASS_20PERCENT.npy'
# np.save(file_io.FileIO(dest, 'w'), h_quality.history)
#np.save(file_io.FileIO(job_dir + 'result/cifar10_quality_resnet2_56_0percent_1.npy', 'w'), h_quality.history)
# np.save(file_io.FileIO(job_dir + 'result/cifar10_classifier_vgg_0percent_1.npy', 'w'), h_classifier.history)