import sys
import numpy as np
sys.path.append('./code')
import loadData
import mjVoting as mv
import commonFuncs as cf
import DSEM
import expert
import argparse
from keras.utils import np_utils
from resnet import cifar10_resnet 
from keras.preprocessing.image import ImageDataGenerator
from scipy.special import rel_entr
from math import log2
import scipy.stats as stats

def kl_divergence(p, q):
	"""Return the Kullback-Leibler divergence KL(p || q) in bits.

	Args:
		p: indexable sequence of probabilities (the reference distribution).
		q: indexable sequence of probabilities, indexed in step with ``p``.

	Returns:
		sum_i p[i] * log2(p[i] / q[i]). Terms with p[i] == 0 contribute 0
		(the usual 0 * log 0 = 0 convention) instead of raising a math domain
		error. If some q[i] == 0 while p[i] > 0 the divergence is infinite
		and ``float("inf")`` is returned.
	"""
	total = 0
	for i, p_i in enumerate(p):
		q_i = q[i]
		if p_i == 0:
			# log2(0) is undefined; the limit of x*log(x) as x -> 0 is 0.
			continue
		if q_i == 0:
			# p puts mass where q has none: the divergence is unbounded.
			return float("inf")
		total += p_i * log2(p_i / q_i)
	return total

def avg_per_batch(y, batch_size):
	"""Replace every full batch of rows in ``y`` by that batch's mean row.

	``y`` is modified in place and also returned. Only complete batches are
	processed: the count is ``int(len(y) / batch_size)``, so a trailing
	partial batch of ``len(y) % batch_size`` rows is left untouched.

	Args:
		y: array supporting slice assignment and ``.mean(0)`` on slices
			(a NumPy array in the callers of this file).
		batch_size: number of consecutive rows averaged together.

	Returns:
		The same object ``y`` after the in-place update.
	"""
	full_batches = int(len(y) / batch_size)
	for batch_index in range(full_batches):
		lo = batch_index * batch_size
		hi = lo + batch_size
		# The mean row is broadcast back over every row of the batch.
		y[lo:hi] = y[lo:hi].mean(0)
	return y

def run_offline_cifar100(worker, noise_pattern, noise_ratio, miss_rate):
	"""Alternate CNN training and DSEM aggregation on CIFAR-100 for 10 rounds.

	Starts from majority-vote labels, then in every round trains the ResNet on
	the current aggregated labels, batch-averages the returned predictions,
	feeds them to ``DSEM.DSEM`` and re-encodes the resulting labels as the
	next training targets. Error rates are printed; nothing is returned.

	Args:
		worker: forwarded to ``expert.load_dataset_cifar100`` (first argument).
		noise_pattern: forwarded to the dataset loader.
		noise_ratio: forwarded to the dataset loader.
		miss_rate: forwarded to the dataset loader.
	"""
	num_class = 100
	rounds = 10

	X_train, y_train, y_train_t, X_test_t, y_test_t = expert.load_dataset_cifar100(
		worker, 100, noise_pattern, noise_ratio, miss_rate)
	L = y_train.tolist()
	trueLabels = np.argmax(y_train_t, axis=1).tolist()
	labelDomain = list(range(num_class))

	def one_hot(labels):
		# Encode hard labels as a (len(X_train), num_class) target matrix.
		encoded = np_utils.to_categorical(np.array(labels), num_classes=num_class)
		return encoded.reshape((len(X_train), num_class))

	model = cifar10_resnet(depth=7, num_classes=num_class)
	print("***********************")

	# Majority voting gives the baseline and the first training targets.
	result = mv.majorityVoting(L)
	print('mv erro rate is:', cf.erroRate(result, trueLabels))
	y_train_agg = one_hot(result)

	erroList = []
	for i in range(rounds):
		print("Iteration: ", i)
		p_prior = expert.train_CNN(X_train, y_train_agg, X_test_t, y_test_t, model, epoch=150, dataset='cifar-100')
		p_prior = avg_per_batch(p_prior, 256)
		# Only the first element returned by DSEM is used.
		result, _ = DSEM.DSEM(L, labelDomain, 500, p_prior)
		y_train_agg = one_hot(result)
		erroList.append(cf.erroRate(result, trueLabels))

	print('erro rate list', erroList)
	print('offline DSEM avg erro rate is:', sum(erroList)/float(rounds))


def run_opt_cifar100(worker, noise_pattern, noise_ratio, miss_rate):
	"""Run the CIFAR-100 loop with a rule that decides when to retrain the CNN.

	Each of the 10 rounds either retrains the ResNet and runs ``DSEM.DSEM``
	with the batch-averaged predictions, or skips training and runs
	``DSEM.DSEM2``. After every round the KL divergences between the mean
	prediction vector and a uniform prior are recorded, and a threshold rule
	on ``js`` and on a Gaussian density of ``worker`` picks the next mode.
	Error rates and KL histories are printed; nothing is returned.

	Args:
		worker: forwarded to the dataset loader and evaluated in the
			``stats.norm.pdf`` switching rule.
		noise_pattern: forwarded to ``expert.load_dataset_cifar100``.
		noise_ratio: forwarded to ``expert.load_dataset_cifar100``.
		miss_rate: forwarded to ``expert.load_dataset_cifar100``.
	"""
	num_class = 100
	rounds = 10
	# Thresholds of the switching rule.
	k_th, n_th, n2_th = 0.0, 0.1, 0.04
	# NOTE(review): var_w is passed as the third positional argument of
	# stats.norm.pdf, which scipy treats as the scale (standard deviation).
	mu_w, var_w = 4.5, 2.5

	X_train, y_train, y_train_t, X_test_t, y_test_t = expert.load_dataset_cifar100(
		worker, 100, noise_pattern, noise_ratio, miss_rate)
	L = y_train.tolist()
	trueLabels = np.argmax(y_train_t, axis=1).tolist()
	labelDomain = list(range(num_class))

	def one_hot(labels):
		# Encode hard labels as a (len(X_train), num_class) target matrix.
		encoded = np_utils.to_categorical(np.array(labels), num_classes=num_class)
		return encoded.reshape((len(X_train), num_class))

	model = cifar10_resnet(depth=7, num_classes=num_class)
	print("***********************")
	prior_gt = [1.0/num_class] * num_class

	result = mv.majorityVoting(L)
	print('mv erro rate is:', cf.erroRate(result, trueLabels))
	y_train_agg = one_hot(result)

	erroList = []
	kl_st = []
	kl_ts = []
	norm_arr = []
	unit = True
	for i in range(rounds):
		print("Iteration: ", i)
		retrained = unit
		if retrained:
			p_prior = expert.train_CNN(X_train, y_train_agg, X_test_t, y_test_t, model, epoch=150, dataset='cifar-100')
			p_prior = avg_per_batch(p_prior, 256)
			result, _ = DSEM.DSEM(L, labelDomain, 500, p_prior)
			p_prior_total = np.average(p_prior, axis=0)
		else:
			# No retraining: p_prior_total keeps its value from the last
			# round that did retrain.
			result, _ = DSEM.DSEM2(L, labelDomain, 500, unit)
		y_train_agg = one_hot(result)

		kl_st.append(kl_divergence(p_prior_total, prior_gt))
		kl_ts.append(kl_divergence(prior_gt, p_prior_total))
		# NOTE(review): js is built from the FIRST recorded KL values, not
		# the latest ones -- confirm this is intended.
		js = 0.5 * (kl_st[0] + kl_ts[0])
		if retrained:
			# The printed value is the symmetric divergence of this round.
			print("js: ", 0.5 * (kl_st[-1] + kl_ts[-1]))

		if retrained:
			unit = bool(js >= k_th and stats.norm.pdf(worker, mu_w, var_w) >= n_th)
		else:
			# Accumulate density mass over skipped rounds; retrain once it
			# reaches n2_th and reset the accumulator.
			norm_arr.append(stats.norm.pdf(worker, mu_w, var_w*10))
			unit = bool(sum(norm_arr) >= n2_th)
			if unit:
				norm_arr = []

		erroList.append(cf.erroRate(result, trueLabels))

	print('erro rate list', erroList)
	print('offline DSEM avg erro rate is:', sum(erroList)/float(rounds))
	print('KL-divergence softmax || ground truth: ', kl_st)
	print('KL-divergence ground truth || softmax: ', kl_ts)




def run_alg_offline_cifar(worker, noise_pattern, noise_ratio, miss_rate):
	"""Run the CIFAR-10 loop with a rule that decides when to retrain the CNN.

	Each of the 10 rounds either retrains ``expert.cnn`` and runs
	``DSEM.DSEM`` with the batch-averaged predictions, or skips training and
	runs ``DSEM.DSEM2``. KL divergences between the mean prediction vector
	and a uniform prior are recorded every round, and a threshold rule on
	``js`` and on a Gaussian density of ``worker`` picks the next mode.
	Error rates and KL histories are printed; nothing is returned.

	Args:
		worker: forwarded to the dataset loader and evaluated in the
			``stats.norm.pdf`` switching rule.
		noise_pattern: forwarded to ``expert.load_dataset``.
		noise_ratio: forwarded to ``expert.load_dataset``.
		miss_rate: forwarded to ``expert.load_dataset``.
	"""
	num_class = 10
	rounds = 10
	# Thresholds of the switching rule.
	k_th, n_th, n2_th = 0.20, 0.2, 0.14
	# NOTE(review): var_w is passed as the third positional argument of
	# stats.norm.pdf, which scipy treats as the scale (standard deviation).
	mu_w, var_w = 6, 1

	X_train, y_train, y_train_t, X_test_t, y_test_t = expert.load_dataset(
		worker, 10, noise_pattern, noise_ratio, miss_rate)
	L = y_train.tolist()
	trueLabels = np.argmax(y_train_t, axis=1).tolist()
	labelDomain = list(range(num_class))

	def one_hot(labels):
		# Encode hard labels as a (len(X_train), num_class) target matrix.
		encoded = np_utils.to_categorical(np.array(labels), num_classes=num_class)
		return encoded.reshape((len(X_train), num_class))

	model = expert.cnn(X_train.shape[1:], num_class)
	prior_gt = [1.0/num_class] * num_class
	print("***********************")

	result = mv.majorityVoting(L)
	print('mv erro rate is:', cf.erroRate(result, trueLabels))
	y_train_agg = one_hot(result)

	erroList = []
	kl_st = []
	kl_ts = []
	norm_arr = []
	unit = True
	for i in range(rounds):
		print("Iteration: ", i)
		retrained = unit
		if retrained:
			p_prior = expert.train_CNN(X_train, y_train_agg, X_test_t, y_test_t, model, epoch=100)
			p_prior = avg_per_batch(p_prior, 256)
			result, _ = DSEM.DSEM(L, labelDomain, 500, p_prior)
			p_prior_total = np.average(p_prior, axis=0)
		else:
			# No retraining: p_prior_total keeps its value from the last
			# round that did retrain.
			result, _ = DSEM.DSEM2(L, labelDomain, 500, unit)
		y_train_agg = one_hot(result)

		kl_st.append(kl_divergence(p_prior_total, prior_gt))
		kl_ts.append(kl_divergence(prior_gt, p_prior_total))
		# NOTE(review): js is built from the FIRST recorded KL values, not
		# the latest ones -- confirm this is intended.
		js = 0.5 * (kl_st[0] + kl_ts[0])

		if retrained:
			unit = bool(js >= k_th and stats.norm.pdf(worker, mu_w, var_w) >= n_th)
		else:
			# Accumulate density mass over skipped rounds; retrain once it
			# reaches n2_th and reset the accumulator.
			norm_arr.append(stats.norm.pdf(worker, mu_w, var_w*10))
			unit = bool(sum(norm_arr) >= n2_th)
			if unit:
				norm_arr = []

		erroList.append(cf.erroRate(result, trueLabels))

	print('erro rate list', erroList)
	print('offline DSEM avg erro rate is:', sum(erroList)/float(rounds))
	print('KL-divergence softmax || ground truth: ', kl_st)
	print('KL-divergence ground truth || softmax: ', kl_ts)

def cross_entropy(pred, target):
	"""Return the mean cross-entropy of ``pred`` against ``target``.

	Computes ``-sum(target * log(pred)) / N`` with ``N = pred.shape[0]``.
	Entries where ``target`` is 0 contribute exactly 0, so a predicted
	probability of 0 for a class the target does not select no longer turns
	the whole result into NaN (``0 * log(0)``). A zero prediction on an entry
	with non-zero target still yields ``inf``, as before.

	Args:
		pred: array of predicted probabilities; its first axis length is N.
		target: array of target weights broadcastable against ``pred``.

	Returns:
		The cross-entropy as a scalar (natural-log units).
	"""
	pred = np.asarray(pred)
	target = np.asarray(target)
	N = pred.shape[0]
	# log(0) and 0 * -inf are expected on masked-out entries; silence the
	# warnings and drop those entries via the where-mask.
	with np.errstate(divide='ignore', invalid='ignore'):
		terms = np.where(target != 0, target * np.log(pred), 0.0)
	ce = -np.sum(terms)/N
	return ce

def run_dist_alg_cifar10(worker, noise_pattern, noise_ratio, miss_rate):
	"""Run the CIFAR-10 loop, retraining when divergence or loss demands it.

	Each of the 10 rounds either retrains ``expert.cnn`` and runs
	``DSEM.DSEM`` with the batch-averaged predictions, or skips training and
	runs ``DSEM.DSEM2``. After the round the model's cross-entropy against
	the new aggregated labels is appended to a history; the next round
	retrains if the symmetric KL value ``js`` is at least ``k_th`` or the
	cross-entropy increased. Error rates and KL histories are printed;
	nothing is returned.

	Args:
		worker: forwarded to ``expert.load_dataset`` (first argument).
		noise_pattern: forwarded to the dataset loader.
		noise_ratio: forwarded to the dataset loader.
		miss_rate: forwarded to the dataset loader.
	"""
	num_class = 10
	rounds = 10
	k_th = 0.20

	X_train, y_train, y_train_t, X_test_t, y_test_t = expert.load_dataset(
		worker, 10, noise_pattern, noise_ratio, miss_rate)
	L = y_train.tolist()
	trueLabels = np.argmax(y_train_t, axis=1).tolist()
	labelDomain = list(range(num_class))

	def one_hot(labels):
		# Encode hard labels as a (len(X_train), num_class) target matrix.
		encoded = np_utils.to_categorical(np.array(labels), num_classes=num_class)
		return encoded.reshape((len(X_train), num_class))

	model = expert.cnn(X_train.shape[1:], num_class)
	prior_gt = [1.0/num_class] * num_class
	print("***********************")

	result = mv.majorityVoting(L)
	print('mv erro rate is:', cf.erroRate(result, trueLabels))
	y_train_agg = one_hot(result)

	erroList = []
	kl_st = []
	kl_ts = []
	# Seeded with 0 so the first round compares its loss against 0.
	cross_entropy_arr = [0]
	unit = True
	for i in range(rounds):
		print("Iteration: ", i)

		if unit:
			p_prior = expert.train_CNN(X_train, y_train_agg, X_test_t, y_test_t, model, epoch=100)
			p_prior = avg_per_batch(p_prior, 256)
			result, _ = DSEM.DSEM(L, labelDomain, 500, p_prior)
			p_prior_total = np.average(p_prior, axis=0)
		else:
			# No retraining: p_prior_total keeps its value from the last
			# round that did retrain.
			result, _ = DSEM.DSEM2(L, labelDomain, 500, unit)
		y_train_agg = one_hot(result)

		forward_kl = kl_divergence(p_prior_total, prior_gt)
		reverse_kl = kl_divergence(prior_gt, p_prior_total)
		kl_st.append(forward_kl)
		kl_ts.append(reverse_kl)
		js = 0.5 * (forward_kl + reverse_kl)

		fx = model.predict(X_train)
		cross_entropy_arr.append(cross_entropy(fx, y_train_agg))
		print("ce: ", cross_entropy_arr)

		loss_delta = cross_entropy_arr[-1] - cross_entropy_arr[-2]
		unit = bool(js >= k_th or loss_delta > 0)
		print("unit: ", unit)

		erroList.append(cf.erroRate(result, trueLabels))

	print('erro rate list', erroList)
	print('min error rate: ', min(erroList))
	print('offline DSEM avg erro rate is:', sum(erroList)/float(rounds))
	print('KL-divergence softmax || ground truth: ', kl_st)
	print('KL-divergence ground truth || softmax: ', kl_ts)

def run_dist_alg_cifar100(worker, noise_pattern, noise_ratio, miss_rate):
	"""Run the CIFAR-100 loop, retraining when divergence or loss demands it.

	Each of the 10 rounds either retrains the ResNet and runs ``DSEM.DSEM``
	with the batch-averaged predictions, or skips training and runs
	``DSEM.DSEM2``. After the round the model's cross-entropy against the new
	aggregated labels is appended to a history; the next round retrains if
	the symmetric KL value ``js`` is at least ``k_th`` or the cross-entropy
	increased. Error rates and KL histories are printed; nothing is returned.

	Args:
		worker: forwarded to ``expert.load_dataset_cifar100`` (first argument).
		noise_pattern: forwarded to the dataset loader.
		noise_ratio: forwarded to the dataset loader.
		miss_rate: forwarded to the dataset loader.
	"""
	num_class = 100
	rounds = 10
	k_th = 0.21

	X_train, y_train, y_train_t, X_test_t, y_test_t = expert.load_dataset_cifar100(
		worker, 100, noise_pattern, noise_ratio, miss_rate)
	L = y_train.tolist()
	trueLabels = np.argmax(y_train_t, axis=1).tolist()
	labelDomain = list(range(num_class))

	def one_hot(labels):
		# Encode hard labels as a (len(X_train), num_class) target matrix.
		encoded = np_utils.to_categorical(np.array(labels), num_classes=num_class)
		return encoded.reshape((len(X_train), num_class))

	model = cifar10_resnet(depth=7, num_classes=num_class)
	print("***********************")
	prior_gt = [1.0/num_class] * num_class

	result = mv.majorityVoting(L)
	print('mv erro rate is:', cf.erroRate(result, trueLabels))
	y_train_agg = one_hot(result)

	erroList = []
	kl_st = []
	kl_ts = []
	# Seeded with 0 so the first round compares its loss against 0.
	cross_entropy_arr = [0]
	unit = True
	for i in range(rounds):
		print("Iteration: ", i)
		retrained = unit
		if retrained:
			p_prior = expert.train_CNN(X_train, y_train_agg, X_test_t, y_test_t, model, epoch=150, dataset='cifar-100')
			p_prior = avg_per_batch(p_prior, 256)
			result, _ = DSEM.DSEM(L, labelDomain, 500, p_prior)
			p_prior_total = np.average(p_prior, axis=0)
		else:
			# No retraining: p_prior_total keeps its value from the last
			# round that did retrain.
			result, _ = DSEM.DSEM2(L, labelDomain, 500, unit)
		y_train_agg = one_hot(result)

		forward_kl = kl_divergence(p_prior_total, prior_gt)
		reverse_kl = kl_divergence(prior_gt, p_prior_total)
		kl_st.append(forward_kl)
		kl_ts.append(reverse_kl)
		js = 0.5 * (forward_kl + reverse_kl)
		if retrained:
			# The divergence is only reported on rounds that retrained.
			print("js: ", js)

		fx = model.predict(X_train)
		cross_entropy_arr.append(cross_entropy(fx, y_train_agg))
		print("ce: ", cross_entropy_arr)

		loss_delta = cross_entropy_arr[-1] - cross_entropy_arr[-2]
		unit = bool(js >= k_th or loss_delta > 0)
		print("unit: ", unit)

		erroList.append(cf.erroRate(result, trueLabels))

	print('erro rate list', erroList)
	print('min error rate: ', min(erroList))
	print('offline DSEM avg erro rate is:', sum(erroList)/float(rounds))
	print('KL-divergence softmax || ground truth: ', kl_st)
	print('KL-divergence ground truth || softmax: ', kl_ts)


def run_opt_offline_cifar(worker, noise_pattern, noise_ratio, miss_rate):
	"""Run the CIFAR-10 loop, retraining every ``unit``-th round.

	On rounds where ``i % unit == 0`` the CNN is retrained, its batch-averaged
	predictions are passed to ``DSEM.DSEM``, and the KL divergences between
	the mean prediction vector and the empirical class prior of the true
	labels are recorded. Other rounds run ``DSEM.DSEM2`` without training.
	With ``unit = 1`` as set here, every round retrains. Error rates and KL
	histories are printed; nothing is returned.

	Args:
		worker: forwarded to ``expert.load_dataset`` (first argument).
		noise_pattern: forwarded to the dataset loader.
		noise_ratio: forwarded to the dataset loader.
		miss_rate: forwarded to the dataset loader.
	"""
	num_class = 10
	rounds = 10

	X_train, y_train, y_train_t, X_test_t, y_test_t = expert.load_dataset(
		worker, 10, noise_pattern, noise_ratio, miss_rate)
	L = y_train.tolist()
	trueLabels = np.argmax(y_train_t, axis=1).tolist()
	labelDomain = list(range(num_class))

	def one_hot(labels):
		# Encode hard labels as a (len(X_train), num_class) target matrix.
		encoded = np_utils.to_categorical(np.array(labels), num_classes=num_class)
		return encoded.reshape((len(X_train), num_class))

	model = expert.cnn(X_train.shape[1:], num_class)

	# Empirical class frequencies of the true labels.
	sample_count = float(len(trueLabels))
	prior_gt = [trueLabels.count(t) / sample_count for t in range(num_class)]
	print("***********************")

	result = mv.majorityVoting(L)
	print('mv erro rate is:', cf.erroRate(result, trueLabels))
	y_train_agg = one_hot(result)

	erroList = []
	kl_st = []
	kl_ts = []
	unit = 1
	for i in range(rounds):
		if i % unit == 0:
			p_prior = expert.train_CNN(X_train, y_train_agg, X_test_t, y_test_t, model, epoch=100)
			p_prior = avg_per_batch(p_prior, 256)
			result, _ = DSEM.DSEM(L, labelDomain, 500, p_prior)
			p_prior_total = np.average(p_prior, axis=0)
			kl_st.append(kl_divergence(p_prior_total, prior_gt))
			kl_ts.append(kl_divergence(prior_gt, p_prior_total))
		else:
			result, _ = DSEM.DSEM2(L, labelDomain, 500, unit)
		y_train_agg = one_hot(result)
		erroList.append(cf.erroRate(result, trueLabels))

	print('erro rate list', erroList)
	print('offline DSEM avg erro rate is:', sum(erroList)/float(rounds))
	print('KL-divergence softmax || ground truth: ', kl_st)
	print('KL-divergence ground truth || softmax: ', kl_ts)

def run_offline_cifar(worker, noise_pattern, noise_ratio, miss_rate):
	"""Alternate CNN training and DSEM aggregation on CIFAR-10 for 10 rounds.

	Starts from majority-vote labels, then in every round trains
	``expert.cnn`` on the current aggregated labels, batch-averages the
	returned predictions, feeds them to ``DSEM.DSEM`` and re-encodes the
	resulting labels as the next training targets. Error rates are printed;
	nothing is returned.

	Args:
		worker: forwarded to ``expert.load_dataset`` (first argument).
		noise_pattern: forwarded to the dataset loader.
		noise_ratio: forwarded to the dataset loader.
		miss_rate: forwarded to the dataset loader.
	"""
	num_class = 10
	rounds = 10

	X_train, y_train, y_train_t, X_test_t, y_test_t = expert.load_dataset(
		worker, 10, noise_pattern, noise_ratio, miss_rate)
	L = y_train.tolist()
	trueLabels = np.argmax(y_train_t, axis=1).tolist()
	labelDomain = list(range(num_class))

	def one_hot(labels):
		# Encode hard labels as a (len(X_train), num_class) target matrix.
		encoded = np_utils.to_categorical(np.array(labels), num_classes=num_class)
		return encoded.reshape((len(X_train), num_class))

	model = expert.cnn(X_train.shape[1:], num_class)
	print("***********************")

	# Majority voting gives the baseline and the first training targets.
	result = mv.majorityVoting(L)
	print('mv erro rate is:', cf.erroRate(result, trueLabels))
	y_train_agg = one_hot(result)

	erroList = []
	for _ in range(rounds):
		p_prior = expert.train_CNN(X_train, y_train_agg, X_test_t, y_test_t, model, epoch=100)
		p_prior = avg_per_batch(p_prior, 256)
		# Only the first element returned by DSEM is used.
		result, _unused = DSEM.DSEM(L, labelDomain, 500, p_prior)
		y_train_agg = one_hot(result)
		erroList.append(cf.erroRate(result, trueLabels))

	print('erro rate list', erroList)
	print('offline DSEM avg erro rate is:', sum(erroList)/float(rounds))

def run_offline(L, trueLabels, labelDomain):
	"""Print the majority-vote error rate and 10 repeated DSEM error rates.

	Args:
		L: label collection passed to ``mv.majorityVoting`` and ``DSEM.DSEM``.
		trueLabels: reference labels passed to ``cf.erroRate``.
		labelDomain: label domain passed to ``DSEM.DSEM``.
	"""
	repeats = 10
	print("***********************")
	baseline = mv.majorityVoting(L)
	print('mv erro rate is:', cf.erroRate(baseline, trueLabels))

	erroList = []
	for _ in range(repeats):
		# NOTE(review): the other callers in this file pass a fourth argument
		# to DSEM.DSEM and unpack a 2-tuple; here the raw return value goes
		# straight to cf.erroRate -- confirm against the DSEM module.
		result = DSEM.DSEM(L, labelDomain, 500)
		erroList.append(cf.erroRate(result, trueLabels))

	print('erro rate list', erroList)
	print('offline DSEM avg erro rate is:', sum(erroList)/float(repeats))

def run_online(L, labelDomain, PI=None, P=None):
	"""Call ``DSEM.Sliding_Win_DSEM`` on ``L`` and return whatever it returns.

	``PI`` and ``P`` are passed through unchanged; callers in this file feed
	back the values unpacked from a previous call, or leave them as ``None``.
	"""
	outcome = DSEM.Sliding_Win_DSEM(L, labelDomain, 500, PI, P)
	return outcome

def motivation_comparison(L, trueLabels, labelDomain, first_size, set_size, set_num):
	"""Print offline vs. sliding-window error rates on growing prefixes of ``L``.

	The first ``first_size`` samples are processed once, then ``set_num``
	further chunks of ``set_size`` samples. For every prefix the offline
	routine is rerun on the whole prefix, while the online routine only sees
	the new chunk and its labels are appended to the running online result.

	Args:
		L: per-sample label rows.
		trueLabels: reference labels aligned with ``L``.
		labelDomain: label domain forwarded to both routines.
		first_size: size of the initial prefix.
		set_size: number of samples added per step.
		set_num: number of steps after the initial prefix.
	"""
	print("total samples:", first_size)
	run_offline(L[0:first_size], trueLabels[0:first_size], labelDomain)
	online_result, PI, P = run_online(L[0:first_size], labelDomain)
	print("sliding window DSEM error rate is:", cf.erroRate(online_result, trueLabels[0:first_size]))

	seen = first_size
	for _ in range(set_num):
		upto = seen + set_size
		print()
		print("total samples:", upto)
		run_offline(L[0:upto], trueLabels[0:upto], labelDomain)
		# Each chunk is started with PI and P set to None here; the values
		# returned by the previous call are not passed on.
		result, PI, P = run_online(L[seen:upto], labelDomain, None, None)
		online_result.extend(result)
		print("sliding window DSEM error rate is:", cf.erroRate(online_result, trueLabels[0:upto]))
		seen = upto

def sliding_em(L, trueLabels, labelDomain, first_size, set_size, set_num, worker_num = None):
	"""Run the sliding-window routine chunk by chunk and return the last error.

	Processes the first ``first_size`` samples, then ``set_num`` chunks of
	``set_size`` samples, then any remaining samples, passing the ``PI`` and
	``P`` values returned by each call into the next one. After every chunk
	the cumulative error rate of all labels produced so far is printed and
	recorded.

	Args:
		L: per-sample label rows.
		trueLabels: reference labels aligned with ``L``.
		labelDomain: label domain forwarded to ``run_online``.
		first_size: size of the initial chunk.
		set_size: size of each following chunk.
		set_num: number of following chunks.
		worker_num: if not None, every row of ``L`` is truncated to its
			first ``worker_num`` entries.

	Returns:
		The last recorded cumulative error rate.
	"""
	if worker_num is not None:
		L = [row[:worker_num] for row in L]

	print("total samples:", first_size)
	online_result, PI, P = run_online(L[0:first_size], labelDomain)
	print(PI)
	print(P)
	err = cf.erroRate(online_result, trueLabels[0:first_size])
	print("sliding window DSEM error rate is:", err)
	e_result = [err]

	consumed = first_size
	for _ in range(set_num):
		upto = consumed + set_size
		print()
		print("total samples:", upto)
		result, PI, P = run_online(L[consumed:upto], labelDomain, PI, P)
		online_result.extend(result)
		err = cf.erroRate(online_result, trueLabels[0:upto])
		print("sliding window DSEM error rate is:", err)
		e_result.append(err)
		consumed = upto

	# Whatever did not fit into the regular chunks is handled as a last one.
	if consumed < len(L):
		print("total samples:", len(L))
		result, PI, P = run_online(L[consumed:], labelDomain, PI, P)
		online_result.extend(result)
		err = cf.erroRate(online_result, trueLabels)
		print("sliding window DSEM error rate is:", err)
		e_result.append(err)

	print("Sliding EM results", e_result)
	return e_result[-1]

#print('RTE:')
#L, trueLabels, labelDomain = loadData.loadRte()
#motivation_comparison(L, trueLabels, labelDomain, 25, 25, 31)
#motivation_comparison(L, trueLabels, labelDomain, 50, 50, 15)
#motivation_comparison(L, trueLabels, labelDomain, 100, 100, 7)
#motivation_comparison(L, trueLabels, labelDomain, 200, 200, 3)
#motivation_comparison(L, trueLabels, labelDomain, 400, 400, 1)
#motivation_comparison(L, trueLabels, labelDomain, 800, 800, 0)

#labels_url = './data/cifar10/ana2d-cifar10-train-noisy-labels-bimodal-noise0.6-empty-prop0.1.txt'
#groud_truth = './data/cifar10/cifar10-train-labels.txt'

#labels_url = './data/pendigits/ana2d-pendigits-train-noisy-labels-bimodal-noise0.6-empty-prop0.1.txt'
#groud_truth = './data/pendigits/pendigits-train-labels.txt'

#L, trueLabels, labelDomain = loadData.loadRte()
#L, trueLabels, labelDomain = loadData.loadAge()
#L, trueLabels, labelDomain = loadData.loadAdult()
#L, trueLabels, labelDomain = loadData.loadHeart()
#print(labelDomain)
#print("pendigits")
#sliding_em(L, trueLabels, labelDomain, 200, 200, 36)
#sliding_em(L, trueLabels, labelDomain, 500, 500, 13)

#print("RTE")
#sliding_em(L, trueLabels, labelDomain, 25, 25, 31)
#sliding_em(L, trueLabels, labelDomain, 50, 50, 15)
#sliding_em(L, trueLabels, labelDomain, 800, 800, 0)

#print("Age")
#sliding_em(L, trueLabels, labelDomain, 50, 50, 19)
#sliding_em(L, trueLabels, labelDomain, 25, 25, 39)
#sliding_em(L, trueLabels, labelDomain, 500, 500, 1)
#run_offline(L, trueLabels, labelDomain)
#run_offline_cifar()

def ana2d_pendigits_www(dataset_name = "pendigits"):
	"""Run ``sliding_em`` over the noisy-label file variants and print results.

	Nine runs are made in this order: the noise0.6/empty-prop0.1 file with
	rows truncated to 6, 8 and 10 entries; noise0.6 with empty-prop 0.1, 0.2
	and 0.3; and empty-prop0.1 with noise 0.4, 0.6 and 0.8. Each run uses
	chunks of 200 samples with 36 follow-up chunks.

	Args:
		dataset_name: directory and file-name stem under ``./data``.
	"""
	ground_truth_file = './data/%s/%s-train-labels.txt' % (dataset_name, dataset_name)

	def final_errors(noise, empty_prop, worker_counts=(None,)):
		# Load one noisy-label file, then run sliding_em once per worker count.
		noisy_labels_file = './data/%s/ana2d-%s-train-noisy-labels-bimodal-noise%s-empty-prop%s.txt' % (
			dataset_name, dataset_name, noise, empty_prop)
		L, trueLabels, labelDomain = loadData.load_all(noisy_labels_file, ground_truth_file)
		return [
			sliding_em(L, trueLabels, labelDomain, 200, 200, 36, worker_num=count)
			for count in worker_counts
		]

	res = []
	# change number of workers
	res.extend(final_errors('0.6', '0.1', worker_counts=(6, 8, 10)))
	# change empty prop
	for empty_prop in ('0.1', '0.2', '0.3'):
		res.extend(final_errors('0.6', empty_prop))
	# change noise rate
	for noise in ('0.4', '0.6', '0.8'):
		res.extend(final_errors(noise, '0.1'))
	print("ana2d_pendigits_www_res:", res)

def ana2d_cifar10_www(dataset_name = "cifar10"):
	"""Run ``sliding_em`` over the noisy-label file variants and print results.

	Nine runs are made in this order: the noise0.6/empty-prop0.1 file with
	rows truncated to 6, 8 and 10 entries; noise0.6 with empty-prop 0.1, 0.2
	and 0.3; and empty-prop0.1 with noise 0.4, 0.6 and 0.8. Each run uses
	chunks of 200 samples with 249 follow-up chunks.

	Args:
		dataset_name: directory and file-name stem under ``./data``.
	"""
	ground_truth_file = './data/%s/%s-train-labels.txt' % (dataset_name, dataset_name)

	def final_errors(noise, empty_prop, worker_counts=(None,)):
		# Load one noisy-label file, then run sliding_em once per worker count.
		noisy_labels_file = './data/%s/ana2d-%s-train-noisy-labels-bimodal-noise%s-empty-prop%s.txt' % (
			dataset_name, dataset_name, noise, empty_prop)
		L, trueLabels, labelDomain = loadData.load_all(noisy_labels_file, ground_truth_file)
		return [
			sliding_em(L, trueLabels, labelDomain, 200, 200, 249, worker_num=count)
			for count in worker_counts
		]

	res = []
	# change number of workers
	res.extend(final_errors('0.6', '0.1', worker_counts=(6, 8, 10)))
	# change empty prop
	for empty_prop in ('0.1', '0.2', '0.3'):
		res.extend(final_errors('0.6', empty_prop))
	# change noise rate
	for noise in ('0.4', '0.6', '0.8'):
		res.extend(final_errors(noise, '0.1'))
	print("ana2d_cifar10_www_res:", res)

#ana2d_pendigits_www()
#ana2d_cifar10_www()
def main(args):
	"""Print the parsed configuration and launch the selected experiment.

	Args:
		args: namespace providing ``worker``, ``patt_arr``, ``ratio`` and
			``miss_rate`` attributes.
	"""
	worker, patterns, ratios, miss_rate = args.worker, args.patt_arr, args.ratio, args.miss_rate
	print("Config:", worker, patterns, ratios, miss_rate)
	# Other drivers in this file accept the same four arguments and can be
	# swapped in here: run_offline_cifar, run_opt_offline_cifar,
	# run_alg_offline_cifar, run_opt_cifar100, run_offline_cifar100.
	run_dist_alg_cifar100(worker, patterns, ratios, miss_rate)


if __name__ == "__main__":
    # Command-line entry point. Argument parsing and the call to main() both
    # live inside the guard so that importing this module neither touches
    # sys.argv nor fails on the undefined `parser` name.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-w', '--worker',
        help="number of worker",
        required=True, type=int
    )
    parser.add_argument(
        '-p', '--patt_arr',  nargs='+',
        help="noise pattern: 'unif', 'bimodal', 'flip' ",
        required=True
    )
    parser.add_argument(
        '-r', '--ratio', nargs='+',
        help="noise ratio: 0.3, 0.6 ",
        type=float
    )
    parser.add_argument(
        '-m', '--miss_rate',
        help="missing rate: 0.0, 0.1, 0.3 ",
        required=False, type=float
    )
    args = parser.parse_args()
    main(args)