from __future__ import print_function, division
from keras.datasets import cifar10
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.regularizers import l2
from keras.utils import np_utils
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.convolutional import UpSampling2D, Conv2D, MaxPooling2D
from keras.models import Sequential, Model
from keras.optimizers import Adam, SGD
from keras import backend as K
import tensorflow as tf
import random
from sklearn import metrics
from loss import total_loss, custom_sparse_categorical_accuracy, custom_loss
from scipy.stats import entropy

# Select the Theano-style ('th') image dimension ordering, i.e. channels-first
# tensors; GAN.img_shape further down is laid out as (channels, rows, cols).
K.set_image_dim_ordering('th')

import os
import time
import numpy as np
import pickle
import copy

from scipy.stats import truncnorm


# Expose only GPU 0 to CUDA.
# NOTE(review): this assignment runs after `import tensorflow` above; confirm
# it still takes effect before the first GPU initialisation in your setup.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
#os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import matplotlib.pyplot as plt

def uniform_mix_C(mixing_ratio, num_classes):
    '''
    Build a symmetric label-noise transition matrix.

    Every diagonal entry is (1 - mixing_ratio); the remaining mixing_ratio of
    each row is split evenly over the other (num_classes - 1) entries, so the
    rows sum to one whenever num_classes > 1.

    :param mixing_ratio: total probability mass moved off the diagonal per row
    :param num_classes: size of the square matrix
    :return: (num_classes, num_classes) numpy array
    '''
    identity = np.eye(num_classes)
    C = (1 - mixing_ratio) * identity
    off_diagonal = identity == 0
    # Guarded so a matrix without off-diagonal cells (0 or 1 class) never
    # evaluates 1 / (num_classes - 1).
    if off_diagonal.any():
        C[off_diagonal] = mixing_ratio * float(1.0 / (num_classes - 1))

    return C
    
def flip_labels_C(corruption_prob, num_classes, seed=123):
    '''
    Build a "flip" transition matrix: (1 - corruption_prob) on the diagonal
    and corruption_prob placed in exactly one randomly chosen off-diagonal
    entry of every row.

    Side effect: re-seeds numpy's global random generator with `seed`, so the
    chosen flip targets are reproducible for a given seed.

    :param corruption_prob: probability of flipping to the row's single target
    :param num_classes: size of the square matrix
    :param seed: value passed to np.random.seed before drawing the targets
    :return: (num_classes, num_classes) numpy array
    '''
    np.random.seed(seed)
    classes = np.arange(num_classes)
    C = (1 - corruption_prob) * np.eye(num_classes)
    for row in classes:
        # One draw per row, taken among every class except the row's own.
        flip_target = np.random.choice(classes[classes != row])
        C[row, flip_target] = corruption_prob
    return C

def flip_labels_C_t(dataset, corruption_prob, num_classes, seed=1):
    '''
    Build a class-dependent (asymmetric) label-noise transition matrix.

    "cifar10":  starts from the identity; each of the fixed source classes
                [9, 2, 3, 5, 4] keeps (1 - corruption_prob) on its diagonal
                and sends corruption_prob to the paired target class
                [1, 0, 5, 3, 7]. All other rows stay clean.
    "cifar100": starts from the identity; when corruption_prob > 0 every 5x5
                diagonal block (20 blocks) is replaced by the result of
                build_for_cifar100(5, corruption_prob).

    Side effect: re-seeds numpy's global random generator with `seed`.

    :param dataset: "cifar10" or "cifar100"
    :param corruption_prob: probability mass moved off the diagonal
    :param num_classes: size of the square matrix (must cover the class
                        indices used by the chosen dataset)
    :param seed: value passed to np.random.seed
    :return: (num_classes, num_classes) numpy array
    :raises ValueError: if `dataset` is not one of the supported names
                        (previously this surfaced as an UnboundLocalError)
    '''
    if dataset == "cifar10":
        source_class = [9, 2, 3, 5, 4]
        target_class = [1, 0, 5, 3, 7]
        np.random.seed(seed)
        C = np.eye(num_classes)
        for source, target in zip(source_class, target_class):
            C[source][target] = corruption_prob
            C[source][source] = 1.0 - corruption_prob
    elif dataset == "cifar100":
        nb_superclasses = 20
        nb_subclasses = 5
        np.random.seed(seed)
        C = np.eye(num_classes)

        if corruption_prob > 0.0:
            for i in range(nb_superclasses):
                init, end = i * nb_subclasses, (i + 1) * nb_subclasses
                # NOTE(review): build_for_cifar100 is not defined in the
                # visible part of this file; it has to be provided elsewhere
                # for this branch to run.
                C[init:end, init:end] = build_for_cifar100(nb_subclasses, corruption_prob)
    else:
        raise ValueError(
            "unsupported dataset %r: expected 'cifar10' or 'cifar100'" % (dataset,))

    return C

def truncnorm_mix_C(mixing_ratio, num_classes, mu=3.0, sigma=1.0):
    '''
    Mix the identity matrix with one truncated-normal row profile.

    Each row equals (1 - mixing_ratio) * identity_row + mixing_ratio * r,
    where r is a truncated normal density evaluated at the class indices
    0..num_classes-1 and normalised to sum to one. The same r is used for
    every row.

    :param mixing_ratio: weight of the truncated-normal profile
    :param num_classes: size of the square matrix
    :param mu: centre of the profile, multiplied by num_classes / 10
    :param sigma: spread of the profile, multiplied by num_classes / 10
    :return: (num_classes, num_classes) numpy array
    '''
    from scipy.stats import truncnorm

    center = mu * num_classes / 10
    spread = sigma * num_classes / 10
    low, high = 0, num_classes - 1
    dist = truncnorm((low - center) / spread, (high - center) / spread,
                     loc=center, scale=spread)

    weights = dist.pdf(np.arange(num_classes))
    weights = weights / np.sum(weights)  # ensure sum = 1

    # Broadcasting adds weights[j] to column j of every row.
    return (1 - mixing_ratio) * np.eye(num_classes) + mixing_ratio * weights

def truncnorm_alt_mix_C(mixing_ratio, num_classes, sigma=1.0):
    '''
    Build a matrix whose row i is a truncated-normal profile peaking at
    column i and spilling onto the adjacent classes, wrapping around at the
    matrix borders.

    A single profile centred on int(num_classes / 2) is evaluated at the
    class indices, normalised to sum to one, and rolled so that its peak
    lands on the diagonal of each row.

    :param mixing_ratio: accepted for signature parity with the other
                         *_mix_C builders; it is not used by this function
    :param num_classes: size of the square matrix
    :param sigma: spread of the profile, multiplied by num_classes / 10
    :return: (num_classes, num_classes) numpy array
    '''
    from scipy.stats import truncnorm

    # The shape must be a tuple: np.zeros(n, n) treats the second argument as
    # a dtype and raises TypeError.
    C = np.zeros((num_classes, num_classes))
    mu = int(num_classes / 2)
    sigma = sigma * num_classes / 10
    lower = 0
    upper = num_classes - 1
    tn = truncnorm((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
    r = tn.pdf(range(num_classes))
    r = r / np.sum(r)  # ensure sum = 1
    for i in range(num_classes):
        # The profile peaks at index mu; shifting by (i - mu) moves it to i.
        C[i] = np.roll(r, i - mu)

    return C

def bimodal_mix_C(mixing_ratio, num_classes, mu1=3.0, sigma1=1.0, mu2=7.0, sigma2=0.5):
    '''
    Mix the identity matrix with the average of two truncated-normal row
    profiles.

    Each row equals (1 - mixing_ratio) * identity_row +
    mixing_ratio * (r1 + r2) / 2, where r1 and r2 are truncated normal
    densities evaluated at the class indices and each normalised to sum to
    one. The same profile is used for every row.

    :param mixing_ratio: weight of the bimodal profile
    :param num_classes: size of the square matrix
    :param mu1, sigma1: centre / spread of the first mode (scaled by num_classes / 10)
    :param mu2, sigma2: centre / spread of the second mode (scaled by num_classes / 10)
    :return: (num_classes, num_classes) numpy array
    '''
    from scipy.stats import truncnorm

    lower = 0
    upper = num_classes - 1

    def _profile(center, spread):
        # Truncated-normal density over the class indices, normalised to sum to 1.
        dist = truncnorm((lower - center) / spread, (upper - center) / spread,
                         loc=center, scale=spread)
        density = dist.pdf(range(num_classes))
        return density / np.sum(density)

    r1 = _profile(mu1 * num_classes / 10, sigma1 * num_classes / 10)
    r2 = _profile(mu2 * num_classes / 10, sigma2 * num_classes / 10)

    # Broadcasting adds the shared profile to every row.
    return (1 - mixing_ratio) * np.eye(num_classes) + mixing_ratio * (r1 + r2) / 2.0

def gen_conf_matrix(num_classes, num_workers, worker_noise_ptn, worker_noise_rto):
    """
    Build one label-noise transition matrix per worker.

    Supported pattern names and the builder each one selects:
        'unif'          -> uniform_mix_C
        'flip'          -> flip_labels_C (seed=123)
        'truncnorm'     -> truncnorm_mix_C (mu=3.0, sigma=1.0)
        'bimodal_mix_C' -> bimodal_mix_C (mu1=3.0, sigma1=1.0, mu2=7.0, sigma2=0.5)

    :param num_classes: number of classes (size of each square matrix)
    :param num_workers: number of workers
    :param worker_noise_ptn: sequence of pattern names, one per worker
    :param worker_noise_rto: sequence of noise ratios, one per worker
    :return: array of shape (num_workers, num_classes, num_classes)
    :raises ValueError: if either sequence does not have num_workers entries,
                        or a pattern name is not supported. Previously a
                        length mismatch silently produced an all-zero matrix
                        and an unknown pattern either crashed on an unbound
                        variable or reused the previous worker's matrix.
    """
    if len(worker_noise_ptn) != num_workers or len(worker_noise_rto) != num_workers:
        raise ValueError(
            "worker_noise_ptn (%d entries) and worker_noise_rto (%d entries) "
            "must both have num_workers = %d entries"
            % (len(worker_noise_ptn), len(worker_noise_rto), num_workers))

    # corresponding to k, t, c dimensions in the proposal
    matrix = np.zeros([num_workers, num_classes, num_classes])
    for worker, (pattern, ratio) in enumerate(zip(worker_noise_ptn, worker_noise_rto)):
        if pattern == 'unif':
            C = uniform_mix_C(ratio, num_classes)
        elif pattern == 'flip':
            C = flip_labels_C(ratio, num_classes, seed=123)
        elif pattern == 'truncnorm':
            C = truncnorm_mix_C(ratio, num_classes, mu=3.0, sigma=1.0)
        elif pattern == 'bimodal_mix_C':
            C = bimodal_mix_C(ratio, num_classes, mu1=3.0, sigma1=1.0, mu2=7.0, sigma2=0.5)
        else:
            raise ValueError(
                "unknown noise pattern %r for worker %d: expected 'unif', "
                "'flip', 'truncnorm' or 'bimodal_mix_C'" % (pattern, worker))
        matrix[worker] = C
    return matrix
def gen_noisy_labels(original_data, num_workers, worker_noise_ptn, worker_noise_rto, empty_prop):
    """
    Simulate the labels several noisy workers would give to each sample.

    For every sample and worker, a label is drawn from the row of that
    worker's transition matrix (see gen_conf_matrix) selected by the sample's
    true label. When there is more than one worker,
    int(empty_prop * num_workers) distinct workers per sample are then marked
    as "no answer" with -1.

    The true labels are used directly as row indices, so they are expected to
    be the integers 0..K-1 where K is the number of distinct labels present.

    :param original_data: true labels, one per sample; shape (n,) or (n, 1)
    :param num_workers: number of simulated workers
    :param worker_noise_ptn: noise pattern name per worker
    :param worker_noise_rto: noise ratio per worker
    :param empty_prop: fraction of workers left without an answer per sample
    :return: float array of shape (n, num_workers) holding class indices or
             -1, or None if gen_conf_matrix returns None
    """
    # original_data has the true labels of the training samples
    original_data = np.asarray(original_data)
    num_samples = original_data.shape[0]
    # Flatten (n, 1) label columns once instead of calling int() per element.
    true_labels = original_data.reshape(num_samples).astype(int)

    num_classes = len(np.unique(original_data))
    print("num classes:", num_classes)

    matrix = gen_conf_matrix(num_classes, num_workers, worker_noise_ptn, worker_noise_rto)
    if matrix is None:
        return None

    # Cannot blank more workers than exist when sampling without replacement.
    num_empty = min(num_workers, int(empty_prop * num_workers))

    noisy_labels = np.zeros((num_samples, num_workers))
    for i, label in enumerate(true_labels):
        for j in range(num_workers):
            noisy_labels[i][j] = np.random.choice(num_classes, p=matrix[j][label], replace=False)

        # Skipped entirely when nothing is blanked so the random stream for
        # empty_prop == 0 is not disturbed.
        if num_workers > 1 and num_empty > 0:
            # replace=False: drawing with replacement could pick the same
            # worker twice and blank fewer workers than requested.
            empty_idx = np.random.choice(num_workers, num_empty, replace=False)
            noisy_labels[i][empty_idx] = -1
    return noisy_labels


def asy_noise(dataset, noise_ratio, y_train):
    """
    Inject asymmetric (class-dependent) label noise into y_train.

    For each (source, target) class pair of the chosen dataset,
    int(noise_ratio * n_source / 100) randomly chosen samples of the source
    class are relabelled as the target class. Source samples are located on
    a clean copy of the labels, so flips of one pair never feed another pair.

    Note: y_train is modified in place and also returned.

    :param dataset: 'mnist' or 'cifar-10'
    :param noise_ratio: percentage (0-100) of each source class to flip
    :param y_train: numpy array of integer labels, one row per sample
    :return: the same y_train object, with noisy labels
    :raises ValueError: if `dataset` is not supported (previously this
                        surfaced as an UnboundLocalError)
    """
    flip_pairs = {
        # 1 <- 7, 2 -> 7, 3 -> 8, 5 <-> 6
        'mnist': ([7, 2, 3, 5, 6], [1, 7, 8, 6, 5]),
        # automobile <- truck, bird -> airplane, cat <-> dog, deer -> horse
        'cifar-10': ([9, 2, 3, 5, 4], [1, 0, 5, 3, 7]),
    }
    if dataset not in flip_pairs:
        raise ValueError(
            "unsupported dataset %r: expected 'mnist' or 'cifar-10'" % (dataset,))
    source_class, target_class = flip_pairs[dataset]

    y_train_clean = np.copy(y_train)
    for s, t in zip(source_class, target_class):
        cls_idx = np.where(y_train_clean == s)[0]
        n_noisy = int(noise_ratio * cls_idx.shape[0] / 100)
        noisy_sample_index = np.random.choice(cls_idx, n_noisy, replace=False)
        y_train[noisy_sample_index] = t

    return y_train



def other_class(n_classes, current_class):
    """
    Draw one class index at random from all classes except current_class.
    :param n_classes: number of classes in the task
    :param current_class: the class index that must not be returned
    :return: one random class index in [0, n_classes) that != current_class
    :raises ValueError: if current_class lies outside [0, n_classes)
    """
    if current_class >= n_classes or current_class < 0:
        raise ValueError("class_ind must be within the range (0, nb_classes - 1)")

    candidates = list(range(n_classes))
    candidates.remove(current_class)
    return np.random.choice(candidates)


def truncnorm_noise(X, Y_train, noise_ratio, n_classes, mu, sigma):
    """
    Replace a fraction of the labels with draws from a truncated normal.

    A random subset of int(noise_ratio * n_samples) positions (up to float
    rounding) is chosen without replacement. Each chosen label is replaced by
    the integer part of a draw from a normal(mu, sigma) truncated to
    [0, n_classes], redrawn until it differs from the original label.
    Y_train itself is left untouched; a copy is modified.

    :param X: passed through unchanged
    :param Y_train: numpy array of labels
    :param noise_ratio: fraction (0-1) of samples to corrupt; a negative value
                        returns an unmodified copy
    :param n_classes: upper truncation bound of the distribution
    :param mu: mean of the underlying normal
    :param sigma: standard deviation of the underlying normal
    :return: (X, noisy copy of Y_train)
    """
    percent = noise_ratio * 100
    corrupted = copy.copy(Y_train)
    if not percent >= 0:
        return X, corrupted

    total = Y_train.shape[0]
    flip_count = int(percent * total / 100)
    flip_positions = np.random.choice(total, flip_count, replace=False)

    low, high = 0, n_classes
    sampler = truncnorm((low - mu) / sigma, (high - mu) / sigma, loc=mu, scale=sigma)
    for pos in flip_positions:
        original = Y_train[pos]
        candidate = int(sampler.rvs())  # integer part of a draw from [low, high]
        while candidate == original:
            candidate = int(sampler.rvs())
        corrupted[pos] = candidate
    return X, corrupted

def bimodal_noise(X, Y_train, noise_ratio, n_classes, mu_1, sigma_1, mu_2, sigma_2):
    """
    Replace a fraction of the labels with draws from two truncated normals.

    A random subset of int(noise_ratio * n_samples) positions (up to float
    rounding) is chosen without replacement. The first half of that subset is
    relabelled from a normal(mu_1, sigma_1) truncated to [0, n_classes], the
    remainder from a normal(mu_2, sigma_2) truncated to the same interval.
    Each replacement is the integer part of a draw, redrawn until it differs
    from the original label. Y_train itself is left untouched.

    :param X: passed through unchanged
    :param Y_train: numpy array of labels
    :param noise_ratio: fraction (0-1) of samples to corrupt; a negative value
                        returns an unmodified copy
    :param n_classes: upper truncation bound of both distributions
    :param mu_1, sigma_1: mean / standard deviation of the first peak
    :param mu_2, sigma_2: mean / standard deviation of the second peak
    :return: (X, noisy copy of Y_train)
    """
    percent = noise_ratio * 100
    corrupted = copy.copy(Y_train)
    if not percent >= 0:
        return X, corrupted

    total = Y_train.shape[0]
    flip_count = int(percent * total / 100)
    flip_positions = np.random.choice(total, flip_count, replace=False)
    half = int(len(flip_positions) / 2)

    low, high = 0, n_classes
    # Both samplers are built before any label is touched; the first peak is
    # processed completely before the second.
    peaks = (
        (flip_positions[:half],
         truncnorm((low - mu_1) / sigma_1, (high - mu_1) / sigma_1, loc=mu_1, scale=sigma_1)),
        (flip_positions[half:],
         truncnorm((low - mu_2) / sigma_2, (high - mu_2) / sigma_2, loc=mu_2, scale=sigma_2)),
    )

    for positions, sampler in peaks:
        for pos in positions:
            original = Y_train[pos]
            candidate = int(sampler.rvs())  # integer part of a draw from [low, high]
            while candidate == original:
                candidate = int(sampler.rvs())
            corrupted[pos] = candidate

    return X, corrupted


def inject_noise_uniform(X_train, Y_train, noise_ratio=30, n_classes=10):
    """
    Flip a percentage of the labels to a different, randomly chosen class.

    int(noise_ratio * n_samples / 100) positions are chosen without
    replacement and each is relabelled via other_class(). Y_train is modified
    in place and also returned.

    :param X_train: passed through unchanged
    :param Y_train: numpy array of labels, modified in place
    :param noise_ratio: percentage (0-100) of samples to corrupt; values that
                        are not > 0 leave the labels untouched
    :param n_classes: number of classes handed to other_class()
    :return: (X_train, Y_train)
    """
    if not noise_ratio > 0:
        return X_train, Y_train

    total = Y_train.shape[0]
    flip_count = int(noise_ratio * total / 100)
    for pos in np.random.choice(total, flip_count, replace=False):
        Y_train[pos] = other_class(n_classes=n_classes, current_class=Y_train[pos])

    return X_train, Y_train


class GAN():
    """
    Two-network label-denoising setup trained on CIFAR-10 with simulated
    noisy workers.

    * generator: a CNN image classifier (see build_generator) that outputs a
      softmax over the classes.
    * discriminator: an MLP (see build_discriminator) that receives the
      generator's prediction stacked with one one-hot row per label source
      and outputs a softmax over the classes.

    train() runs the whole pipeline; see its docstring for the stages.
    """

    # Epoch index -> learning rate at which both networks are recompiled with
    # a fresh SGD optimizer inside each alternating training phase.
    _LR_SCHEDULE = {39: 0.01, 79: 0.001}

    def __init__(self):
        self.img_rows = 32
        self.img_cols = 32
        self.channels = 3
        # Number of label rows fed to the discriminator next to the
        # generator's prediction. train() produces 5 simulated worker columns
        # plus the true-label column, i.e. 6 rows.
        self.num_worker = 6
        self.num_class = 10
        self.img_shape = (self.channels, self.img_rows, self.img_cols)
        self.input_dis = (self.num_worker+1, self.num_class)

        sgd = SGD(lr=0.01, momentum=0.9, decay=1e-4)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator(input_shape=self.input_dis)
        self.discriminator.compile(loss='categorical_crossentropy',
                                   optimizer=sgd,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator(input_shape=self.img_shape)

        # Symbolic wiring of the generator output and the label rows into the
        # discriminator. No combined Model is built from these tensors; the
        # two networks are trained separately in train().
        img_input = Input(shape=self.img_shape)
        dis1_input = tf.reshape(self.generator(img_input), (-1, 1, self.num_class))
        dis2_input = Input(shape=(self.num_worker, self.num_class))
        print(self.generator(img_input).shape[0], dis1_input.shape, dis2_input.shape)
        merged = concatenate([dis1_input, dis2_input], axis=1)
        validity = self.discriminator(merged)

        self.generator.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

    def build_generator(self, input_shape=None):
        """
        Build the (uncompiled) CNN classifier.

        Three blocks of [Conv-BN-ReLU] x 2 + 2x2 max-pool + dropout(0.3) with
        64, 128 and 196 filters, followed by a 256-unit L2-regularised dense
        layer and a softmax over self.num_class classes.

        :param input_shape: shape of one input image, passed to the first
                            convolution
        :return: a Keras Sequential model
        """
        model = Sequential()
        for block, filters in enumerate((64, 128, 196), start=1):
            for conv in (1, 2):
                # Only the very first layer of the model declares the input shape.
                extra = {'input_shape': input_shape} if (block == 1 and conv == 1) else {}
                model.add(Conv2D(filters, (3, 3), padding='same',
                                 kernel_initializer='random_uniform',
                                 name='block%d_conv%d' % (block, conv), **extra))
                model.add(BatchNormalization())
                model.add(Activation('relu'))
            model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block%d_pool' % block))
            model.add(Dropout(0.3))

        model.add(Flatten(name='flatten'))

        model.add(
            Dense(256, kernel_initializer='random_uniform', kernel_regularizer=l2(0.01), bias_regularizer=l2(0.01),
                  name='fc1'))
        model.add(BatchNormalization())
        model.add(Activation('relu'))

        model.add(Dense(self.num_class, kernel_initializer='random_uniform'))
        model.add(Activation('softmax'))

        return model

    def build_discriminator(self, input_shape=None):
        """
        Build the (uncompiled) MLP that maps the stacked label rows to class
        probabilities: Flatten -> Dense(256) -> Dense(512) -> softmax.

        :param input_shape: shape of one stacked input; defaults to
                            self.input_dis when omitted
        :return: a Keras Sequential model
        """
        if input_shape is None:
            input_shape = self.input_dis

        model = Sequential()
        model.add(Flatten(name='flatten', input_shape=input_shape))
        # The width of this layer's input is set by the Flatten output above,
        # so no input_dim is declared here.
        model.add(Dense(256, kernel_initializer='random_uniform'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512, kernel_initializer='random_uniform'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(self.num_class, activation='softmax', kernel_initializer='random_uniform'))

        return model

    @staticmethod
    def _estimate_confusion(probs, true_labels, num_classes):
        """Return the matrix whose row t is the mean of `probs` over samples with true label t."""
        c_hat = np.zeros((num_classes, num_classes))
        for label in range(num_classes):
            c_hat[label] = np.mean(probs[true_labels == label], axis=0)
        return c_hat

    @staticmethod
    def _aggregate_votes(noisy_y, c_est, num_worker, num_class):
        """
        Pick, per sample, the class t maximising the product over the first
        `num_worker` columns of c_est[j, t, vote_j]; votes equal to -1 are
        skipped. Ties resolve to the lowest class index.
        """
        noisy_y = np.asarray(noisy_y)
        scores = np.ones((len(noisy_y), num_class))
        for j in range(num_worker):
            votes = noisy_y[:, j]
            present = votes != -1
            # Column `vote` of c_est[j] holds one factor per candidate class t.
            scores[present] *= c_est[j][:, votes[present].astype(int)].T
        return np.argmax(scores, axis=1)

    def _recompile(self, lr):
        """Recompile both networks with one fresh SGD optimizer at learning rate `lr`."""
        sgd = SGD(lr=lr, momentum=0.9, decay=1e-4)
        self.discriminator.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
        self.generator.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

    def _report_test_accuracy(self, X_test, y_test_onehot):
        """Print the generator's classification report and mean per-class accuracy on the test set."""
        y_test_tt = np.argmax(y_test_onehot, axis=1)
        estimated_label = np.argmax(self.generator.predict(X_test), axis=1)
        print(metrics.classification_report(y_test_tt, estimated_label, digits=3))

        cm = metrics.confusion_matrix(y_test_tt, estimated_label)
        cm = cm.astype('float')/cm.sum(axis=1)[:, np.newaxis]

        print(cm.diagonal())
        print("ACC:", np.mean(cm.diagonal()))

    def _fit_phase(self, X, y_votes, y_target, n_epochs, batch_size, num_class, X_test, y_test_onehot):
        """
        Alternate discriminator / generator updates over (X, y_votes) for
        `n_epochs` epochs, reporting test accuracy after every epoch.

        Per batch: the generator's prediction is stacked on top of the label
        rows, the discriminator is trained towards `y_target`, and the
        generator is trained towards the discriminator's output computed
        before that update.
        """
        number_batch = int(X.shape[0] / batch_size)
        # Each phase starts at the beginning of its own data; the cursor then
        # carries over between epochs of the same phase.
        start_index = 0

        for epoch in range(n_epochs):
            # Learning Rate Scheduler:
            if epoch in self._LR_SCHEDULE:
                self._recompile(self._LR_SCHEDULE[epoch])

            print("epoch", epoch)

            for _ in range(number_batch):
                end_index = start_index + min(len(X) - start_index, batch_size)
                current_Xtrain = X[start_index:end_index]
                current_Ytrain = y_votes[start_index:end_index]
                current_Ytrain_t = y_target[start_index:end_index]

                gen_predict = self.generator.predict(current_Xtrain)
                gen_predict = gen_predict.reshape((-1, 1, num_class))
                input_discriminator = np.concatenate((gen_predict, current_Ytrain), axis=1)

                # Soft target for the generator, taken before the
                # discriminator is updated on this batch.
                clean_label_p = self.discriminator.predict(input_discriminator)
                self.discriminator.train_on_batch(input_discriminator, current_Ytrain_t)
                self.generator.train_on_batch(current_Xtrain, clean_label_p)

                start_index = end_index
                if end_index > len(X) - 1:
                    start_index = 0

            self._report_test_accuracy(X_test, y_test_onehot)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """
        Run the full pipeline on CIFAR-10.

        1. Load the data, scale to [0, 1] and subtract the per-pixel training
           mean.
        2. Simulate 5 workers with 'unif' noise at ratio 0.5 and append the
           true labels as a 6th label column.
        3. For each of the 6 label columns, train a fresh classifier for 120
           epochs and estimate a confusion matrix from its predictions on the
           training images grouped by true label.
        4. Phase 1: 100 epochs of alternating training on the gold subset
           (first 10% of the training set), using the true labels as the
           discriminator target.
        5. Aggregate the worker votes of the remaining samples with the
           estimated confusion matrices and print the aggregation accuracy.
        6. Phase 2: `epochs` epochs of alternating training on the remaining
           samples, using the aggregated labels as the discriminator target.
        7. Save both networks to './G' and './D'.

        :param epochs: number of epochs of phase 2
        :param batch_size: mini-batch size used in every stage
        :param sample_interval: not used
        """
        (X_train_t, y_train_t), (X_test_t, y_test_t) = cifar10.load_data()

        X_train_t = X_train_t.astype('float') / 255.0
        X_test_t = X_test_t.astype('float') / 255.0

        means = X_train_t.mean(axis=0)
        X_train_t = X_train_t - means
        X_test_t = X_test_t - means

        num_worker = 5
        worker_noise_ptn = ['unif', 'unif', 'unif', 'unif', 'unif']
        worker_noise_rto = [0.5, 0.5, 0.5, 0.5, 0.5]
        empty_prop = 0
        num_classes = 10
        gold_frac = 0.1

        Y_s = gen_noisy_labels(copy.copy(y_train_t), num_worker, worker_noise_ptn, worker_noise_rto, empty_prop)

        # Integer label table: one column per worker, true label last.
        no_hot_y_train = np.concatenate((Y_s, y_train_t), axis=1)
        print(no_hot_y_train.shape)

        y_train_t = np_utils.to_categorical(y_train_t)
        # NOTE(review): a -1 ("no answer") entry would be indexed as the last
        # class here; harmless while empty_prop is 0 -- confirm before
        # enabling missing answers.
        y_train = np_utils.to_categorical(no_hot_y_train)
        y_test_t = np_utils.to_categorical(y_test_t)

        # Gold subset: the leading gold_frac of the training set, reshuffled
        # only if it does not contain every class.
        n_train = len(X_train_t)
        n_gold = int(gold_frac * n_train)
        true_labels = np.argmax(y_train_t, axis=1)
        order = np.arange(n_train)
        while len(set(true_labels[order[:n_gold]])) < num_classes:
            np.random.shuffle(order)
        gold_index = order[:n_gold]
        rest_index = np.setdiff1d(np.arange(n_train), gold_index)

        # Estimating multiple CM
        c_est = np.zeros((num_worker+1, num_classes, num_classes))
        # NOTE(review): the same SGD instance is reused for every per-column
        # model; confirm that sharing optimizer state across models is intended.
        sgd = SGD(lr=0.01, momentum=0.9, decay=1e-4)
        for i in range(num_worker+1):
            print("Worker: ", i)
            idx_empty = np.where(no_hot_y_train[:, i] == -1)
            idx_no_empty = np.setdiff1d(np.arange(len(no_hot_y_train)), idx_empty)
            print(len(idx_no_empty))

            x_w_train = X_train_t[idx_no_empty]
            y_w_train = y_train[idx_no_empty, i, :]

            model = self.build_generator(input_shape=self.img_shape)
            model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
            datagen = ImageDataGenerator(width_shift_range=0.2, height_shift_range=0.2, horizontal_flip=True)
            datagen.fit(x_w_train)
            model.fit_generator(datagen.flow(x_w_train, y_w_train, batch_size=batch_size),
                                steps_per_epoch=int(np.ceil(len(x_w_train) / batch_size)),
                                epochs=120, validation_data=(X_test_t, y_test_t), verbose=1)

            # NOTE(review): the estimate uses the true labels of the whole
            # training set, not only the gold subset -- confirm this is intended.
            probs = model.predict(X_train_t)
            c_est[i, :, :] = self._estimate_confusion(probs, true_labels, num_classes)

        print(c_est)

        # Training with Ground truth
        self._fit_phase(X_train_t[gold_index], y_train[gold_index], y_train_t[gold_index],
                        100, batch_size, num_classes, X_test_t, y_test_t)

        # Aggregate the worker votes of the non-gold samples.
        y = np.argmax(y_train_t[rest_index], axis=1)
        predict_res = self._aggregate_votes(no_hot_y_train[rest_index], c_est, num_worker, num_classes)

        print('ACC: ', metrics.accuracy_score(y, predict_res))
        print('Error: ', 1 - metrics.accuracy_score(y, predict_res))

        # Fixed width so the one-hot targets stay num_classes wide even if
        # the highest class is never predicted.
        y_gt = np_utils.to_categorical(predict_res, num_classes)
        print(y_gt.shape)

        # Training with noisy labels
        self._fit_phase(X_train_t[rest_index], y_train[rest_index], y_gt,
                        epochs, batch_size, num_classes, X_test_t, y_test_t)

        self.generator.save('./G', include_optimizer=False)
        self.discriminator.save('./D', include_optimizer=False)
        

if __name__ == '__main__':
    # Script entry point: build both networks, then run the whole training
    # pipeline. `epochs` sets the length of the final training loop in
    # GAN.train(); sample_interval is accepted by train() but not read there.
    gan = GAN()
    gan.train(epochs=120, batch_size=128, sample_interval=200)
