from __future__ import print_function, division
from keras.datasets import cifar10, cifar100
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.regularizers import l2
from keras.utils import np_utils
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.convolutional import UpSampling2D, Conv2D, MaxPooling2D
from keras.models import Sequential, Model
from keras.optimizers import Adam, SGD
from keras import backend as K
import tensorflow as tf
import random
from sklearn import metrics
from loss import total_loss, custom_sparse_categorical_accuracy, custom_loss
from scipy.stats import entropy
from numpy.testing import assert_array_almost_equal
from callback_util import SGDLearningRateTracker
from util import get_lr_scheduler

#K.common.set_image_dim_ordering('th')

import os
import time
import numpy as np
import pickle
import copy

from scipy.stats import truncnorm


# Module-level side effect: importing this file sets CUDA_VISIBLE_DEVICES to "0".
# NOTE(review): presumably meant to restrict training to the first GPU -- confirm
# before running on a multi-GPU host or importing this module from other code.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

def build_for_cifar100(size, noise):
    """Build a ``size`` x ``size`` transition matrix that swaps one random class pair.

    Two distinct class indices are drawn with ``np.random.choice`` (global
    NumPy RNG). Each of the two keeps its own label with probability
    ``1 - noise`` and takes the other's label with probability ``noise``;
    every other row stays an identity row.

    :param size: number of classes (matrix is ``size`` x ``size``)
    :param noise: flip probability, must lie in [0, 1]
    :return: row-stochastic ``numpy.ndarray`` of shape ``(size, size)``
    """
    assert 0. <= noise <= 1.

    transition = np.eye(size)
    first, second = np.random.choice(range(size), size=2, replace=False)

    # Fill both directions of the swap symmetrically.
    for src, dst in ((first, second), (second, first)):
        transition[src, dst] = noise
        transition[src, src] = 1.0 - noise

    # Sanity check: every row must still be a probability distribution.
    assert_array_almost_equal(transition.sum(axis=1), 1, 1)
    return transition

def uniform_mix_C(mixing_ratio, num_classes):
    '''
    returns a matrix with (1 - mixing_ratio) on the diagonal and mixing_ratio
    spread evenly over the (num_classes - 1) off-diagonal entries of each row
    '''
    C = (1 - mixing_ratio) * np.eye(num_classes)

    # With a single class there is no off-diagonal entry to fill (and the
    # per-entry share below would divide by zero), so leave C untouched.
    if num_classes > 1:
        off_diagonal = ~np.eye(num_classes, dtype=bool)
        C[off_diagonal] = mixing_ratio * float(1.0 / (num_classes - 1))

    return C
    
def flip_labels_C(corruption_prob, num_classes, seed=123):
    '''
    returns a matrix with (1 - corruption_prob) on the diagonal and
    corruption_prob placed on exactly one randomly chosen off-diagonal
    entry of each row.

    NOTE: reseeds the global NumPy RNG with `seed` as a side effect.
    '''
    np.random.seed(seed)
    all_classes = np.arange(num_classes)
    C = (1 - corruption_prob) * np.eye(num_classes)

    for row in all_classes:
        # pick the single class this row's label gets flipped to
        candidates = all_classes[all_classes != row]
        flip_target = np.random.choice(candidates)
        C[row][flip_target] = corruption_prob

    return C

def flip_labels_C_t(dataset, corruption_prob, num_classes, seed=1):
    '''
    returns a class-dependent (asymmetric) label transition matrix.

    * "cifar10" / "cifar-10": each source class in [9, 2, 3, 5, 4] keeps its
      label with probability (1 - corruption_prob) and is flipped to the
      matching target class in [1, 0, 5, 3, 7] with probability
      corruption_prob. All other rows are identity rows.
    * "cifar100" / "cifar-100": the first 100 classes are treated as 20
      consecutive groups of 5; inside each group one random pair of classes
      is swapped via build_for_cifar100.
      NOTE(review): assumes fine labels are ordered so that every run of 5
      consecutive indices forms one superclass -- confirm against the caller.

    :param dataset: dataset name, one of the spellings listed above
    :param corruption_prob: flip probability
    :param num_classes: size of the returned square matrix
    :param seed: value used to reseed the global NumPy RNG (side effect)
    :return: numpy.ndarray of shape (num_classes, num_classes)
    :raises ValueError: if `dataset` is not supported (previously this fell
        through and failed with an UnboundLocalError)
    '''
    if dataset in ("cifar10", "cifar-10"):
        source_class = [9, 2, 3, 5, 4]
        target_class = [1, 0, 5, 3, 7]
        np.random.seed(seed)
        C = np.eye(num_classes)
        for src, dst in zip(source_class, target_class):
            C[src][dst] = corruption_prob
            C[src][src] = 1.0 - corruption_prob
        return C

    if dataset in ("cifar100", "cifar-100"):
        nb_superclasses = 20
        nb_subclasses = 5
        np.random.seed(seed)
        C = np.eye(num_classes)
        # With zero noise the identity matrix is already the answer.
        if corruption_prob > 0.0:
            for i in range(nb_superclasses):
                init, end = i * nb_subclasses, (i + 1) * nb_subclasses
                C[init:end, init:end] = build_for_cifar100(nb_subclasses, corruption_prob)
        return C

    raise ValueError("unsupported dataset for flip_labels_C_t: %r" % (dataset,))

def truncnorm_mix_C(mixing_ratio, num_classes, mu=3.0, sigma=1.0):
    '''
    returns a linear interpolation of an identity matrix and a matrix whose
    rows all equal one normalised truncated-normal profile over the classes.

    mu and sigma are rescaled by num_classes / 10 before use.
    '''
    from scipy.stats import truncnorm

    first_class, last_class = 0, num_classes - 1
    center = mu * num_classes / 10
    spread = sigma * num_classes / 10

    dist = truncnorm((first_class - center) / spread,
                     (last_class - center) / spread,
                     loc=center, scale=spread)
    profile = dist.pdf(range(num_classes))
    profile = profile / np.sum(profile)  # make the profile sum to 1

    # Broadcasting adds the same noise profile to every row.
    return (1 - mixing_ratio) * np.eye(num_classes) + mixing_ratio * profile

def truncnorm_alt_mix_C(mixing_ratio, num_classes, sigma=1.0):
    '''
    returns a matrix with a truncnorm centered on the diagonal for each row
    and extending to the adjacent classes with roll-around at the borders

    :param mixing_ratio: accepted for signature parity with the other
        *_mix_C builders; it is not used by this function
    :param num_classes: size of the returned square matrix
    :param sigma: spread of the truncated normal, rescaled by num_classes / 10
    :return: numpy.ndarray of shape (num_classes, num_classes) whose rows
        each sum to 1
    '''
    from scipy.stats import truncnorm

    # The shape must be passed as a tuple: np.zeros(n, n) treats the second
    # argument as a dtype and raises TypeError.
    C = np.zeros((num_classes, num_classes))
    mu = int(num_classes / 2)
    sigma = sigma * num_classes / 10
    lower = 0
    upper = num_classes - 1
    tn = truncnorm((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
    r = tn.pdf(range(num_classes))
    r = r / np.sum(r)  # ensure sum = 1
    for i in range(num_classes):
        # r peaks at index mu; shifting by (i - mu) moves the peak onto column i
        C[i] = np.roll(r, i - mu)

    return C

def bimodal_mix_C(mixing_ratio, num_classes, mu1=3.0, sigma1=1.0, mu2=7.0, sigma2=0.5):
    '''
    returns a linear interpolation of an identity matrix and a matrix whose
    rows all equal the average of two normalised truncated-normal profiles.

    NOTE(review): the mu1 / mu2 arguments are ignored -- both peak positions
    are redrawn with np.random.randint(num_classes) on every call. Confirm
    whether that is intended before relying on the arguments.
    '''
    from scipy.stats import truncnorm

    C = (1 - mixing_ratio) * np.eye(num_classes)

    # Two draws from the global RNG; the smaller becomes the first peak.
    first_draw = np.random.randint(num_classes)
    second_draw = np.random.randint(num_classes)
    if first_draw > second_draw:
        first_draw, second_draw = second_draw, first_draw

    lower, upper = 0, num_classes - 1
    center1 = first_draw * num_classes / 10
    spread1 = sigma1 * num_classes / 10
    center2 = second_draw * num_classes / 10
    spread2 = sigma2 * num_classes / 10

    def _normalised_profile(center, spread):
        # truncated-normal density sampled at each class index, scaled to sum to 1
        dist = truncnorm((lower - center) / spread, (upper - center) / spread,
                         loc=center, scale=spread)
        density = dist.pdf(range(num_classes))
        return density / np.sum(density)

    profile1 = _normalised_profile(center1, spread1)
    profile2 = _normalised_profile(center2, spread2)

    # Broadcasting adds the same averaged profile to every row.
    return C + mixing_ratio * (profile1 + profile2) / 2.0

def gen_conf_matrix(num_classes, num_workers, worker_noise_ptn, worker_noise_rto):
    """Build one label-confusion matrix per worker.

    :param num_classes: number of classes (each matrix is square of this size)
    :param num_workers: number of simulated annotators
    :param worker_noise_ptn: per-worker noise pattern names; supported values
        are 'unif', 'flip', 'truncnorm' and 'bimodal_mix_C'
    :param worker_noise_rto: per-worker noise ratio handed to the builder
    :return: numpy.ndarray of shape (num_workers, num_classes, num_classes)
    :raises ValueError: if either list does not hold exactly ``num_workers``
        entries, or a pattern name is not supported. Previously a length
        mismatch silently returned an all-zero matrix and an unknown pattern
        silently reused the previous worker's matrix.
    """
    supported_patterns = ('unif', 'flip', 'truncnorm', 'bimodal_mix_C')

    if len(worker_noise_ptn) != num_workers or len(worker_noise_rto) != num_workers:
        raise ValueError(
            "worker_noise_ptn and worker_noise_rto must each have num_workers=%d "
            "entries (got %d and %d)"
            % (num_workers, len(worker_noise_ptn), len(worker_noise_rto)))

    # Validate every pattern up front so no partially filled matrix is built.
    for pattern in worker_noise_ptn:
        if pattern not in supported_patterns:
            raise ValueError(
                "unknown worker noise pattern %r; expected one of %s"
                % (pattern, ", ".join(supported_patterns)))

    # corresponding to k, t, c dimensions in the proposal
    matrix = np.zeros([num_workers, num_classes, num_classes])
    for worker, (pattern, ratio) in enumerate(zip(worker_noise_ptn, worker_noise_rto)):
        if pattern == 'unif':
            C = uniform_mix_C(ratio, num_classes)
        elif pattern == 'flip':
            C = flip_labels_C(ratio, num_classes, seed=123)
        elif pattern == 'truncnorm':
            C = truncnorm_mix_C(ratio, num_classes, mu=3.0, sigma=1.0)
        else:  # 'bimodal_mix_C'
            C = bimodal_mix_C(ratio, num_classes, mu1=3.0, sigma1=1.0, mu2=7.0, sigma2=0.5)
        matrix[worker] = C
    return matrix

def gen_noisy_labels(original_data, num_workers, worker_noise_ptn, worker_noise_rto, empty_prop):
    """Simulate noisy labels from several workers for every training sample.

    :param original_data: true labels of the training samples, either flat
        ``(N,)`` or column-shaped ``(N, 1)``. Labels are used directly as row
        indices of the confusion matrices, so they are assumed to be the
        integers ``0 .. num_classes - 1``.
    :param num_workers: number of simulated annotators
    :param worker_noise_ptn: per-worker noise pattern names (see gen_conf_matrix)
    :param worker_noise_rto: per-worker noise ratios
    :param empty_prop: fraction of workers whose label is blanked per sample;
        only applied when ``num_workers > 1``
    :return: float array of shape ``(N, num_workers)``; ``-1`` marks a missing
        label. ``None`` if gen_conf_matrix returns ``None``.
    """
    # Flatten once so each label is a plain integer index, not a 1-element array.
    labels = np.asarray(original_data).ravel().astype(int)

    # The number of classes is the number of distinct labels present.
    num_classes = len(np.unique(labels))
    print("num classes:", num_classes)

    matrix = gen_conf_matrix(num_classes, num_workers, worker_noise_ptn, worker_noise_rto)
    if matrix is None:
        return None

    # Capped at num_workers so drawing without replacement stays valid.
    n_empty = min(int(empty_prop * num_workers), num_workers)

    noisy_labels = np.zeros((labels.shape[0], num_workers))
    for i, true_label in enumerate(labels):
        for j in range(num_workers):
            # sample worker j's answer from the row of its confusion matrix
            noisy_labels[i][j] = np.random.choice(num_classes, p=matrix[j][true_label], replace=False)

        if num_workers > 1:
            # Draw distinct workers: sampling with replacement could pick the
            # same worker twice and blank fewer labels than requested.
            empty_idx = np.random.choice(num_workers, n_empty, replace=False)
            noisy_labels[i][empty_idx] = -1
    return noisy_labels


def asy_noise(dataset, noise_ratio, y_train):
    """Inject asymmetric (class-dependent) label noise into ``y_train`` in place.

    For every (source, target) class pair, ``noise_ratio`` percent of the
    samples whose clean label is ``source`` are relabelled as ``target``.
    Selection is based on a copy of the clean labels, so a sample is never
    flipped twice.

    :param dataset: 'mnist', or 'cifar-10' (the 'cifar10' spelling is accepted too)
    :param noise_ratio: percentage (0-100) of each source class to flip
    :param y_train: numpy label array; modified in place
    :return: the same ``y_train`` object, now holding the noisy labels
    :raises ValueError: if ``dataset`` is not supported (previously this
        failed later with an UnboundLocalError)
    """
    if dataset == 'mnist':
        # 1 <- 7, 2 -> 7, 3 -> 8, 5 <-> 6
        source_class = [7, 2, 3, 5, 6]
        target_class = [1, 7, 8, 6, 5]
    elif dataset in ('cifar-10', 'cifar10'):
        # automobile <- truck, bird -> airplane, cat <-> dog, deer -> horse
        source_class = [9, 2, 3, 5, 4]
        target_class = [1, 0, 5, 3, 7]
    else:
        raise ValueError("unsupported dataset for asy_noise: %r" % (dataset,))

    # Snapshot the clean labels so earlier flips do not feed later pairs.
    y_train_clean = np.copy(y_train)

    for s, t in zip(source_class, target_class):
        cls_idx = np.where(y_train_clean == s)[0]
        n_noisy = int(noise_ratio * cls_idx.shape[0] / 100)
        noisy_sample_index = np.random.choice(cls_idx, n_noisy, replace=False)
        y_train[noisy_sample_index] = t

    return y_train



def other_class(n_classes, current_class):
    """
    Draws one class index uniformly at random from all classes except the given one
    :param n_classes: number of classes in the task
    :param current_class: the class index that must not be returned
    :return: one random class index that != current_class
    :raises ValueError: if current_class is outside [0, n_classes - 1]
    """
    out_of_range = current_class < 0 or current_class >= n_classes
    if out_of_range:
        raise ValueError("class_ind must be within the range (0, nb_classes - 1)")

    candidates = list(range(n_classes))
    candidates.remove(current_class)
    return np.random.choice(candidates)


def truncnorm_noise(X, Y_train, noise_ratio, n_classes, mu, sigma):
    """Relabel a random fraction of samples with draws from a truncated normal.

    ``noise_ratio`` is a fraction (0.3 means 30 %). The chosen samples get a
    new label obtained by truncating a draw from a normal(mu, sigma) that is
    restricted to [0, n_classes]; draws are repeated until the new label
    differs from the original one. ``Y_train`` itself is not modified.

    :return: tuple ``(X, noisy copy of Y_train)``
    """
    noise_percent = noise_ratio * 100
    noisy_labels = copy.copy(Y_train)
    if not noise_percent >= 0:
        return X, noisy_labels

    total = Y_train.shape[0]
    n_corrupt = int(noise_percent * total / 100)
    corrupt_idx = np.random.choice(total, n_corrupt, replace=False)

    low, high = 0, n_classes
    sampler = truncnorm((low - mu) / sigma, (high - mu) / sigma, loc=mu, scale=sigma)

    for idx in corrupt_idx:
        clean_label = Y_train[idx]
        # int() truncation maps the draw onto a class index in [low, high - 1]
        while True:
            candidate = int(sampler.rvs())
            if not candidate == clean_label:
                break
        noisy_labels[idx] = candidate

    return X, noisy_labels

def bimodal_noise(X, Y_train, noise_ratio, n_classes, mu_1, sigma_1, mu_2, sigma_2):
    """Relabel a random fraction of samples using two truncated-normal peaks.

    ``noise_ratio`` is a fraction (0.3 means 30 %). The chosen samples are
    split in two halves: the first half is relabelled from a normal(mu_1,
    sigma_1), the second from a normal(mu_2, sigma_2), both restricted to
    [0, n_classes]. Draws are repeated until the new label differs from the
    original one. ``Y_train`` itself is not modified.

    :return: tuple ``(X, noisy copy of Y_train)``
    """
    noise_percent = noise_ratio * 100
    noisy_labels = copy.copy(Y_train)
    if not noise_percent >= 0:
        return X, noisy_labels

    total = Y_train.shape[0]
    n_corrupt = int(noise_percent * total / 100)
    corrupt_idx = np.random.choice(total, n_corrupt, replace=False)
    half = int(len(corrupt_idx) / 2)

    low, high = 0, n_classes

    def _peak(center, spread):
        # frozen truncated normal restricted to [low, high]
        return truncnorm((low - center) / spread, (high - center) / spread,
                         loc=center, scale=spread)

    # Both distributions are built before any label is drawn.
    peaks = (_peak(mu_1, sigma_1), _peak(mu_2, sigma_2))
    halves = (corrupt_idx[:half], corrupt_idx[half:])

    for sampler, indices in zip(peaks, halves):  # first peak, then second peak
        for idx in indices:
            clean_label = Y_train[idx]
            # int() truncation maps the draw onto a class index in [low, high - 1]
            while True:
                candidate = int(sampler.rvs())
                if not candidate == clean_label:
                    break
            noisy_labels[idx] = candidate

    return X, noisy_labels


def inject_noise_uniform(X_train, Y_train, noise_ratio=30, n_classes=10):
    """Flip ``noise_ratio`` percent of the labels to a different random class.

    ``Y_train`` is modified in place. Each selected label is replaced with
    the result of ``other_class`` for that label.

    :param noise_ratio: percentage (0-100) of samples to corrupt
    :return: tuple ``(X_train, Y_train)`` -- the very objects passed in
    """
    if not noise_ratio > 0:
        return X_train, Y_train

    total = Y_train.shape[0]
    n_corrupt = int(noise_ratio * total / 100)
    for idx in np.random.choice(total, n_corrupt, replace=False):
        Y_train[idx] = other_class(n_classes=n_classes, current_class=Y_train[idx])

    return X_train, Y_train

def cnn(input_shape, num_classes):
    """Assemble the (uncompiled) Sequential classifier used by train_CNN.

    Three blocks of [Conv-BN-ReLU] x 2 followed by 2x2 max-pooling and
    dropout(0.3), with 64, 128 and 196 filters respectively; then a
    flattened, L2-regularised 256-unit dense layer and a softmax output
    with ``num_classes`` units.

    :param input_shape: shape handed to the first convolution's ``input_shape``
    :param num_classes: width of the final softmax layer
    :return: the Sequential model
    """
    # (filters, name of first conv, name of second conv, name of pooling layer)
    block_specs = (
        (64, 'block1_conv1', 'block1_conv2', 'block1_pool'),
        (128, 'block2_conv1', 'block2_conv2', 'block2_pool'),
        (196, 'block3_conv1', 'block3_conv2', 'block3_pool'),
    )

    model = Sequential()
    is_first_layer = True
    for n_filters, conv_a, conv_b, pool_name in block_specs:
        for conv_name in (conv_a, conv_b):
            conv_kwargs = dict(padding='same', kernel_initializer='random_uniform', name=conv_name)
            if is_first_layer:
                # only the very first layer of the model declares the input shape
                conv_kwargs['input_shape'] = input_shape
                is_first_layer = False
            model.add(Conv2D(n_filters, (3, 3), **conv_kwargs))
            model.add(BatchNormalization())
            model.add(Activation('relu'))
        model.add(MaxPooling2D((2, 2), strides=(2, 2), name=pool_name))
        model.add(Dropout(0.3))

    model.add(Flatten(name='flatten'))

    # fully connected head
    fc1 = Dense(256, kernel_initializer='random_uniform',
                kernel_regularizer=l2(0.01), bias_regularizer=l2(0.01),
                name='fc1')
    model.add(fc1)
    model.add(BatchNormalization())
    model.add(Activation('relu'))

    model.add(Dense(num_classes, kernel_initializer='random_uniform'))
    model.add(Activation('softmax'))

    return model

def train_CNN(x_train, y_agg, x_test, y_test, model, epoch=2, batch_size=128, dataset='cifar-10'):
    """Compile and train ``model`` on aggregated labels, then predict on the training set.

    * 'cifar-10': SGD(lr=0.01, momentum=0.9, decay=1e-4), plain ``fit``.
    * 'cifar-100': SGD(lr=0.1, momentum=0.9, decay=5e-3), training through an
      ``ImageDataGenerator`` (rotation, shifts, horizontal flip) with the
      learning-rate scheduler from ``get_lr_scheduler`` and an
      ``SGDLearningRateTracker`` callback.

    In both cases the test set doubles as validation data, test loss and
    accuracy are printed, and ``model`` is compiled and trained in place.

    :param x_train: training inputs
    :param y_agg: training targets (aggregated labels) for categorical cross-entropy
    :param x_test: test inputs
    :param y_test: test targets
    :param model: Keras model to compile and train
    :param epoch: number of training epochs
    :param batch_size: batch size for training and evaluation
    :param dataset: 'cifar-10' or 'cifar-100'; selects the training recipe
    :return: output of ``model.predict(x_train)`` after training
    :raises ValueError: if ``dataset`` is not one of the two supported names
        (previously this ran to the end and failed with an UnboundLocalError)
    """
    # Fail before compiling or training anything.
    if dataset not in ('cifar-10', 'cifar-100'):
        raise ValueError(
            "unsupported dataset for train_CNN: %r (expected 'cifar-10' or 'cifar-100')"
            % (dataset,))

    expert = model
    y_train = y_agg

    if dataset == 'cifar-10':
        sgd = SGD(lr=0.01, momentum=0.9, decay=1e-4)
        expert.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

        print("Train on the training and aggregated data:")
        expert.fit(x_train, y_train, batch_size=batch_size, epochs=epoch,
                   validation_data=(x_test, y_test))
    else:  # 'cifar-100'
        sgd = SGD(lr=0.1, decay=5e-3, momentum=0.9)
        expert.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
        callbacks = [get_lr_scheduler(dataset), SGDLearningRateTracker(model)]

        datagen = ImageDataGenerator(
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            horizontal_flip=True)
        datagen.fit(x_train)

        print("Train on the training and aggregated data:")
        expert.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                             steps_per_epoch=len(x_train) / batch_size, epochs=epoch,
                             validation_data=(x_test, y_test),
                             verbose=1,
                             callbacks=callbacks)

    print("Evaluate on test data:")
    results = expert.evaluate(x_test, y_test, batch_size=batch_size)
    print("test loss:", results[0])
    print("test acc:", results[1])

    # Raw model outputs for the training inputs, no argmax applied.
    return expert.predict(x_train)


def load_dataset(num_worker, num_class, worker_noise_ptn, worker_noise_rto, empty_prop):
    """Load CIFAR-10, normalise the images and simulate noisy worker labels.

    Images are cast to float, scaled by 1/255 and shifted by the mean of the
    training images taken over axis 0. The clean training labels are copied
    and passed to ``gen_noisy_labels`` to obtain per-worker noisy labels.

    :param num_worker: number of simulated annotators
    :param num_class: not used by this function
    :param worker_noise_ptn: per-worker noise pattern names
    :param worker_noise_rto: per-worker noise ratios
    :param empty_prop: fraction of missing labels, forwarded to gen_noisy_labels
    :return: ``(X_train, noisy worker labels, one-hot clean train labels,
        X_test, one-hot test labels)``
    """
    (train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

    train_images = train_images.astype('float') / 255.0
    test_images = test_images.astype('float') / 255.0

    # Centre both splits with statistics taken from the training split only.
    train_mean = train_images.mean(axis=0)
    train_images = train_images - train_mean
    test_images = test_images - train_mean

    worker_labels = gen_noisy_labels(copy.copy(train_labels), num_worker,
                                     worker_noise_ptn, worker_noise_rto, empty_prop)

    train_onehot = np_utils.to_categorical(train_labels)
    test_onehot = np_utils.to_categorical(test_labels)

    return train_images, worker_labels, train_onehot, test_images, test_onehot

def load_dataset_cifar100(num_worker, num_class, worker_noise_ptn, worker_noise_rto, empty_prop):
    """Load CIFAR-100, normalise the images and simulate noisy worker labels.

    Images are cast to float, scaled by 1/255 and shifted by the mean of the
    training images taken over axis 0. Labels are flattened, then a copy of
    the clean training labels is passed to ``gen_noisy_labels`` to obtain
    per-worker noisy labels.

    :param num_worker: number of simulated annotators
    :param num_class: not used by this function (one-hot width is fixed at 100)
    :param worker_noise_ptn: per-worker noise pattern names
    :param worker_noise_rto: per-worker noise ratios
    :param empty_prop: fraction of missing labels, forwarded to gen_noisy_labels
    :return: ``(X_train, noisy worker labels, one-hot clean train labels,
        X_test, one-hot test labels)``
    """
    (train_images, train_labels), (test_images, test_labels) = cifar100.load_data()

    train_images = train_images.astype('float') / 255.0
    test_images = test_images.astype('float') / 255.0

    # Centre both splits with statistics taken from the training split only.
    train_mean = train_images.mean(axis=0)
    train_images = train_images - train_mean
    test_images = test_images - train_mean

    train_labels = train_labels.ravel()
    test_labels = test_labels.ravel()

    worker_labels = gen_noisy_labels(copy.copy(train_labels), num_worker,
                                     worker_noise_ptn, worker_noise_rto, empty_prop)

    train_onehot = np_utils.to_categorical(train_labels, 100)
    test_onehot = np_utils.to_categorical(test_labels, 100)

    return train_images, worker_labels, train_onehot, test_images, test_onehot

    

