import os
import sys
import numpy as np
from keras.datasets import mnist, cifar10, cifar100
from keras.utils import np_utils
from util import other_class
from numpy.testing import assert_array_almost_equal
import copy
from scipy.stats import truncnorm

# Seed NumPy's global random generator once at import time so that the
# np.random calls made by the noise generators in this module are repeatable
# from run to run.
np.random.seed(123)

# Number of label classes for each dataset name; used as the class count when
# generating noisy labels and when one-hot encoding.
NUM_CLASSES = {'mnist': 10, 'svhn': 10, 'cifar-10': 10, 'cifar-100': 100}
# Distribution parameters for the parametric noise patterns, keyed by pattern
# name: a single truncated normal ('truncnorm': mu, sigma) and a pair of
# truncated normals ('bimodal': mu_1/sigma_1 and mu_2/sigma_2).
NOISE_PARAMS = {'truncnorm': {'mu': 1, 'sigma': 0.5}, 'bimodal': {'mu_1': 2, 'sigma_1': 1, 'mu_2': 7, 'sigma_2': 3} }

def bimodal_noise(Y_train, noise_ratio, n_classes, mu_1, sigma_1, mu_2, sigma_2):
    """Corrupt a percentage of labels using two truncated-normal peaks.

    ``noise_ratio`` percent of the samples (rounded down) are picked at
    random without replacement.  The first half of the picked samples gets
    labels drawn from a normal(mu_1, sigma_1) truncated to [0, n_classes],
    the second half from a normal(mu_2, sigma_2) truncated to the same
    range.  Each draw is truncated to an int and redrawn until it differs
    from the sample's current label, so every picked sample really changes.

    Args:
        Y_train: 1-D array of integer class labels (left untouched).
        noise_ratio: percentage (0-100) of labels to corrupt.
        n_classes: number of classes; upper truncation bound.
        mu_1, sigma_1: mean / std-dev of the first peak.
        mu_2, sigma_2: mean / std-dev of the second peak.

    Returns:
        A copy of ``Y_train`` with the selected labels replaced.
    """
    corrupted = copy.copy(Y_train)
    if not noise_ratio > 0:
        return corrupted

    total = Y_train.shape[0]
    flip_count = int(noise_ratio * total / 100)
    chosen = np.random.choice(total, flip_count, replace=False)
    half = int(len(chosen) / 2)

    low, high = 0, n_classes
    # Both distributions are built up front; sampling then proceeds peak by
    # peak, in the order the indices were drawn.
    first_peak = truncnorm((low - mu_1) / sigma_1, (high - mu_1) / sigma_1, loc=mu_1, scale=sigma_1)
    second_peak = truncnorm((low - mu_2) / sigma_2, (high - mu_2) / sigma_2, loc=mu_2, scale=sigma_2)

    for members, peak in ((chosen[:half], first_peak), (chosen[half:], second_peak)):
        for pos in members:
            current = Y_train[pos]
            # int() truncates the draw, giving a label in [low, high - 1]
            drawn = int(peak.rvs())
            while drawn == current:
                drawn = int(peak.rvs())
            corrupted[pos] = drawn

    return corrupted

def truncnorm_noise(Y_train, noise_ratio, n_classes, mu, sigma):
    """Corrupt a percentage of labels with draws from a truncated normal.

    ``noise_ratio`` percent of the samples (rounded down) are picked at
    random without replacement and relabelled with an integer obtained by
    truncating a draw from normal(mu, sigma) restricted to [0, n_classes].
    A draw equal to the sample's current label is rejected and redrawn.

    Args:
        Y_train: 1-D array of integer class labels (left untouched).
        noise_ratio: percentage (0-100) of labels to corrupt.
        n_classes: number of classes; upper truncation bound.
        mu: mean of the underlying normal.
        sigma: standard deviation of the underlying normal.

    Returns:
        A copy of ``Y_train`` with the selected labels replaced.
    """
    corrupted = copy.copy(Y_train)
    if not noise_ratio >= 0:
        return corrupted

    total = Y_train.shape[0]
    flip_count = int(noise_ratio * total / 100)
    chosen = np.random.choice(total, flip_count, replace=False)

    low, high = 0, n_classes
    dist = truncnorm((low - mu) / sigma, (high - mu) / sigma, loc=mu, scale=sigma)

    for pos in chosen:
        current = Y_train[pos]
        # int() truncates the draw, giving a label in [low, high - 1]
        drawn = int(dist.rvs())
        while drawn == current:
            drawn = int(dist.rvs())
        corrupted[pos] = drawn
    return corrupted

def sym_noise(Y_train, noise_ratio, n_classes):
    """Corrupt a percentage of labels uniformly at random (symmetric noise).

    ``noise_ratio`` percent of the samples (rounded down) are picked at
    random without replacement; each one is relabelled with a class drawn
    uniformly from [0, n_classes) that differs from its current label.
    The effective corrupted fraction is printed before sampling.

    Args:
        Y_train: 1-D array of integer class labels (left untouched).
        noise_ratio: percentage (0-100) of labels to corrupt.
        n_classes: number of classes to draw replacement labels from.

    Returns:
        A copy of ``Y_train`` with the selected labels replaced.
    """
    corrupted = copy.copy(Y_train)
    if not noise_ratio >= 0:
        return corrupted

    total = Y_train.shape[0]
    flip_count = int(noise_ratio * total / 100)
    print(f'ratio {flip_count/total}')

    for pos in np.random.choice(total, flip_count, replace=False):
        current = Y_train[pos]
        # keep drawing until the replacement differs from the current label
        while True:
            candidate = np.random.randint(0, n_classes)
            if candidate != current:
                break
        corrupted[pos] = candidate
    return corrupted


def asym_noise(y_train, y_train_clean, noise_ratio, dataset):
    """Apply class-dependent (asymmetric) label noise.

    For 'mnist' and 'cifar-10', ``noise_ratio`` percent (rounded down) of the
    samples of each source class are relabelled as the paired target class.
    Samples are selected by their label in ``y_train_clean``.

    For 'cifar-100', labels are flipped with a block-diagonal transition
    matrix: inside each consecutive group of 5 classes (20 groups), two
    randomly chosen classes are swapped with probability ``noise_ratio/100``.

    The input arrays are never modified; flips are written to a copy.  This
    also keeps bidirectional pairs (e.g. 5 <-> 6) from chaining when the same
    array is passed as both ``y_train`` and ``y_train_clean``.

    Args:
        y_train: 1-D integer label array to add noise to.
        y_train_clean: 1-D integer label array used to pick the samples of
            each source class (mnist / cifar-10 only).
        noise_ratio: percentage (0-100) of noise.
        dataset: 'mnist', 'cifar-10' or 'cifar-100'.

    Returns:
        A new label array with noise applied.  For any other dataset name a
        message is printed and the tuple ``(None, None)`` is returned (kept
        as-is for backward compatibility).
    """
    # source class -> target class flips, given as parallel lists
    pair_flips = {
        # 7 -> 1, 2 -> 7, 3 -> 8, 5 <-> 6
        'mnist': ([7, 2, 3, 5, 6], [1, 7, 8, 6, 5]),
        # truck -> automobile, bird -> airplane, cat <-> dog, deer -> horse
        'cifar-10': ([9, 2, 3, 5, 4], [1, 0, 5, 3, 7]),
    }

    if dataset not in pair_flips and dataset != 'cifar-100':
        print('Asymmetric noise is not supported now for dataset: %s' % dataset)
        return None, None

    # Work on a copy so the caller's array is left intact.
    y_noisy = np.array(y_train)

    if dataset == 'cifar-100':
        P = np.eye(NUM_CLASSES[dataset])
        n = noise_ratio / 100.0
        nb_superclasses = 20
        nb_subclasses = 5

        if n > 0.0:
            for i in range(nb_superclasses):
                init, end = i * nb_subclasses, (i + 1) * nb_subclasses
                P[init:end, init:end] = build_for_cifar100(nb_subclasses, n)

            y_train_noisy = multiclass_noisify(y_noisy, P=P, random_state=0)
            actual_noise = (y_train_noisy != y_noisy).mean()
            assert actual_noise > 0.0
            y_noisy = y_train_noisy
        return y_noisy

    source_class, target_class = pair_flips[dataset]
    for s, t in zip(source_class, target_class):
        cls_idx = np.where(y_train_clean == s)[0]
        n_noisy = int(noise_ratio * cls_idx.shape[0] / 100)
        noisy_sample_index = np.random.choice(cls_idx, n_noisy, replace=False)
        y_noisy[noisy_sample_index] = t

    return y_noisy

def build_for_cifar100(size, noise):
    """Build a ``size`` x ``size`` transition matrix swapping two classes.

    Two distinct classes are picked at random; each keeps its own label
    with probability ``1 - noise`` and flips to the other with probability
    ``noise``.  All remaining classes map to themselves.

    Args:
        size: number of classes in the block (at least 2).
        noise: flip probability, must lie in [0, 1].

    Returns:
        The row-stochastic transition matrix as a float array.
    """
    assert 0. <= noise <= 1.

    P = np.eye(size)
    first, second = np.random.choice(range(size), size=2, replace=False)
    pair = [first, second]
    swapped = [second, first]
    P[pair, swapped] = noise        # off-diagonal: flip to the partner class
    P[pair, pair] = 1.0 - noise     # diagonal: keep the original class

    # every row must still sum to one (checked to 1 decimal place)
    assert_array_almost_equal(P.sum(axis=1), 1, 1)
    return P

def multiclass_noisify(y, P, random_state=0):
    """Flip classes according to the transition probability matrix P.

    Row ``i`` of ``P`` is the probability distribution of the new label for
    a sample whose current label is ``i``.

    Args:
        y: 1-D array of integer labels, each between 0 and the number of
            classes - 1.
        P: square, row-stochastic, non-negative transition matrix.
        random_state: seed for the private ``RandomState`` used to sample
            the flips, so results do not depend on the global NumPy RNG.

    Returns:
        A copy of ``y`` with every label resampled from its row of ``P``.

    Raises:
        AssertionError: if ``P`` is not square, not row-stochastic, has
            negative entries, or ``y`` contains a label outside ``P``.
    """

    assert P.shape[0] == P.shape[1]
    assert np.max(y) < P.shape[0]

    # row stochastic matrix
    assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
    assert (P >= 0.0).all()

    new_y = y.copy()
    flipper = np.random.RandomState(random_state)

    for idx, label in enumerate(y):
        # draw a one-hot vector; the position of the single 1 is the new label
        one_hot = flipper.multinomial(1, P[label, :], 1)[0]
        # argmax yields a scalar index; assigning the size-1 array returned by
        # np.where(...)[0] to one element is deprecated in recent NumPy.
        new_y[idx] = np.argmax(one_hot)

    return new_y


def get_data(dataset='mnist', noise_ratio=0, noise_pattern='uniform', random_shuffle=False):
    """
    Load a dataset and return its training labels with the requested
    symmetric / asymmetric / parametric label noise applied.

    Images are scaled to [0, 1] and cast to float32; for the CIFAR datasets
    the per-pixel training mean is subtracted from both splits.  When
    ``noise_ratio > 0`` the noisy labels are loaded from a cached ``.npy``
    file under ``data/`` if one exists, otherwise they are generated and
    saved there.  Per-class label counts and the measured noise rate are
    printed.  All label arrays are returned one-hot encoded.

    :param dataset: 'mnist', 'cifar-10' or 'cifar-100'.
    :param noise_ratio: percentage (0-100) of training labels to corrupt.
    :param noise_pattern: 'sym', 'asym', 'truncnorm' or 'bimodal'.  Any
        other value (including the default 'uniform') terminates the process
        via ``sys.exit(1)`` when ``noise_ratio > 0`` and no cache file exists.
    :param random_shuffle: if True, shuffle the training set (images, noisy
        labels and clean labels with the same permutation).
    :return: ``(X_train, y_train, y_train_clean, X_test, y_test,
        file_suffix)``; a tuple of six ``None`` for an unknown dataset.
    """
    if dataset == 'mnist':
        (X_train, y_train), (X_test, y_test) = mnist.load_data()

        X_train = X_train.reshape(-1, 28, 28, 1) / 255.0
        X_test = X_test.reshape(-1, 28, 28, 1) / 255.0

    elif dataset in ('cifar-10', 'cifar-100'):
        loader = cifar10 if dataset == 'cifar-10' else cifar100
        (X_train, y_train), (X_test, y_test) = loader.load_data()

        X_train = X_train.reshape(-1, 32, 32, 3) / 255.0
        X_test = X_test.reshape(-1, 32, 32, 3) / 255.0

        # centre both splits on the training mean (no std scaling)
        means = X_train.mean(axis=0)
        X_train = X_train - means
        X_test = X_test - means

        # they are 2D originally in cifar
        y_train = y_train.ravel()
        y_test = y_test.ravel()
    else:
        # same arity as the successful return so callers can always unpack
        return None, None, None, None, None, None

    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')

    y_train_clean = np.copy(y_train)

    # suffix encoding the distribution parameters of parametric patterns
    file_suffix = ''
    if noise_pattern in NOISE_PARAMS:
        file_suffix += '_params'
        for k, v in NOISE_PARAMS[noise_pattern].items():
            file_suffix += '_%s_%.2f' % (k, v)

    if noise_ratio > 0:
        # np.save appends '.npy' to names lacking it, so the extension must be
        # part of the path for the existence check and np.load to find the file.
        data_file = "data/%s_%s_train_labels_%s%s.npy" % (dataset, noise_pattern, noise_ratio, file_suffix)

        if os.path.isfile(data_file):
            y_train = np.load(data_file)
        else:
            # generate random noisy labels
            if noise_pattern == 'sym':
                y_train = sym_noise(y_train, noise_ratio, NUM_CLASSES[dataset])
            elif noise_pattern == 'asym':
                y_train = asym_noise(y_train, y_train_clean, noise_ratio, dataset)
            elif noise_pattern == 'truncnorm':
                params = NOISE_PARAMS['truncnorm']
                y_train = truncnorm_noise(y_train, noise_ratio, NUM_CLASSES[dataset], params['mu'], params['sigma'])
            elif noise_pattern == 'bimodal':
                params = NOISE_PARAMS['bimodal']
                y_train = bimodal_noise(y_train, noise_ratio, NUM_CLASSES[dataset], params['mu_1'], params['sigma_1'], params['mu_2'], params['sigma_2'])
            else:
                print('Unknown noise pattern %s' % noise_pattern)
                sys.exit(1)
            # make sure the cache directory exists before writing
            os.makedirs(os.path.dirname(data_file), exist_ok=True)
            np.save(data_file, y_train)

        # print statistics
        print("Print noisy label generation statistics:")
        for i in range(NUM_CLASSES[dataset]):
            n_noisy = np.sum(y_train == i)
            print("Noisy class %s, has %s samples." % (i, n_noisy))

    if random_shuffle:
        # random shuffle
        idx_perm = np.random.permutation(X_train.shape[0])
        X_train, y_train, y_train_clean = X_train[idx_perm], y_train[idx_perm], y_train_clean[idx_perm]

    # measured fraction of training labels that differ from the clean ones
    noise_r = int(np.count_nonzero(y_train_clean != y_train))
    print(noise_r / len(y_train))

    # one-hot-encode the labels
    y_train_clean = np_utils.to_categorical(y_train_clean, NUM_CLASSES[dataset])
    y_train = np_utils.to_categorical(y_train, NUM_CLASSES[dataset])
    y_test = np_utils.to_categorical(y_test, NUM_CLASSES[dataset])

    print("X_train:", X_train.shape)
    print("y_train:", y_train.shape)
    print("X_test:", X_test.shape)
    print("y_test", y_test.shape)

    return X_train, y_train, y_train_clean, X_test, y_test, file_suffix


if __name__ == "__main__":
    # Smoke test: generate 40% symmetric noise on MNIST and count how many
    # training labels stayed correct.  get_data returns six values, and the
    # noise pattern must be named explicitly because the default 'uniform'
    # is not a pattern get_data can generate.
    X_train, Y_train, Y_train_clean, X_test, Y_test, _ = get_data(
        dataset='mnist', noise_ratio=40, noise_pattern='sym')
    # labels come back one-hot encoded; convert to class indices
    Y_train = np.argmax(Y_train, axis=1)
    Y_clean_train = np.argmax(Y_train_clean, axis=1)
    clean_selected = np.argwhere(Y_train == Y_clean_train).reshape((-1,))
    noisy_selected = np.argwhere(Y_train != Y_clean_train).reshape((-1,))
    print("#correct labels: %s, #incorrect labels: %s" % (len(clean_selected), len(noisy_selected)))
