import numpy as np
import copy
import mjVoting as mv

def initPara(L, labelDomain):
	"""Build the initial state shared by the Dawid-Skene EM routines.

	Args:
		L: J x K nested list of raw labels; L[j][k] is the label worker k gave
			item j. None or -1 means "no label".
		labelDomain: list of possible label values; a label's position in this
			list is used as its integer class index.

	Returns:
		(tL, T, countTKI) where
		tL[j][k] is the class index of L[j][k], or None when there is no label;
		T[j] is the class index of the label mv.majorityVoting(L) produced for
			item j (assumed: one label per row of L -- TODO confirm);
		countTKI[t][k][c] is the number of items currently assigned class t
			that worker k labelled as class c (nested Python lists of ints,
			updated in place by the EM routines).

	Raises:
		ValueError: if a raw or majority-voted label is not in labelDomain
			(propagated from list.index).
	"""
	K = len(L[0])
	J = len(L)
	LD = len(labelDomain)
	# Initial truth estimate: majority vote, mapped to class indices.
	T = [labelDomain.index(c) for c in mv.majorityVoting(L)]
	# Translate raw labels to class indices; None and -1 both mean "no label".
	tL = [[None] * K for _ in range(J)]
	for j in range(J):
		for k in range(K):
			lab = L[j][k]
			if lab is not None and lab != -1:
				tL[j][k] = labelDomain.index(lab)
	countTKI = [[[0] * LD for _ in range(K)] for _ in range(LD)]
	for j in range(J):
		for k in range(K):
			c = tL[j][k]
			# tL only holds None or a non-negative index, so no -1 check here.
			if c is not None:
				countTKI[T[j]][k][c] += 1
	return tL, T, countTKI

def DSEM2(L, labelDomain,  maxIter, p_prior):
	"""Hard-assignment Dawid-Skene EM with a global, re-estimated class prior.

	Args:
		L: J x K nested list of raw labels (None or -1 = no label).
		labelDomain: list of possible label values.
		maxIter: maximum number of M-step/E-step rounds.
		p_prior: unused; kept so the signature matches DSEM.

	Returns:
		(result, T): result[i] is the inferred label value of item i and T[i]
		its class index in labelDomain.
	"""
	tL, T, countTKI = initPara(L, labelDomain)
	I = len(L); K = len(L[0]); LD = len(labelDomain)
	PI = np.zeros([LD, K, LD])
	P = [0 for _ in range(LD)]
	# tL never changes, so collect each item's (worker, class) answers once.
	obs = [[(k, c) for k, c in enumerate(row) if c is not None] for row in tL]
	for _ in range(maxIter):
		oldT = list(T)
		# M-step: PI[t][k][c] = share of class-t items that worker k labelled c.
		# Rows with no observations keep their previous values.
		for t in range(LD):
			for k in range(K):
				s = float(sum(countTKI[t][k]))
				if s != 0:
					for c in range(LD):
						PI[t][k][c] = countTKI[t][k][c] / s
		# Class prior = fraction of items currently assigned to each class.
		for t in range(LD):
			P[t] = T.count(t) / float(I)
		# E-step: reassign every item to its highest-scoring class.
		for i in range(I):
			# Move item i's votes from its old class to its new one so that
			# countTKI stays consistent with T for the next M-step.
			for k, c in obs[i]:
				countTKI[T[i]][k][c] -= 1
			pList = []
			for q in range(LD):
				s = P[q]
				for k, c in obs[i]:
					s = s * PI[q][k][c]
				pList.append(s)
			# Ties resolve to the lowest class index.
			T[i] = pList.index(max(pList))
			for k, c in obs[i]:
				countTKI[T[i]][k][c] += 1
		dnum = sum(1 for new, old in zip(T, oldT) if new != old)
		if (float(dnum) / I) < 0.0001:
			break

	result = [labelDomain[t] for t in T]
	return result, T

def DSEM(L, labelDomain,  maxIter, p_prior):
	"""Hard-assignment Dawid-Skene EM using a fixed per-item class prior.

	Args:
		L: J x K nested list of raw labels (None or -1 = no label).
		labelDomain: list of possible label values.
		maxIter: maximum number of M-step/E-step rounds.
		p_prior: per-item prior; p_prior[i][q] multiplies the likelihood of
			class q for item i. It is only read, never updated.

	Returns:
		(result, T): inferred label value and class index for each item.
	"""
	tL, T, countTKI = initPara(L, labelDomain)
	n_items = len(L)
	n_workers = len(L[0])
	n_classes = len(labelDomain)
	PI = np.zeros([n_classes, n_workers, n_classes])

	# Each item's (worker, class) answers; tL is constant across iterations.
	answers = [
		[(w, lab) for w, lab in enumerate(row) if lab is not None and lab != -1]
		for row in tL
	]

	def refit_confusions():
		# Row-normalise the counts; empty rows keep their previous estimate.
		for t in range(n_classes):
			for w in range(n_workers):
				row = countTKI[t][w]
				total = float(sum(row))
				if total == 0:
					continue
				for c in range(n_classes):
					PI[t][w][c] = row[c] / total

	def most_likely_class(i):
		# Prior times product of confusion entries, in worker order.
		scores = []
		for q in range(n_classes):
			score = p_prior[i][q]
			for w, lab in answers[i]:
				score = score * PI[q][w][lab]
			scores.append(score)
		best = max(scores)
		return scores.index(best)

	for _ in range(maxIter):
		before = T[:]
		refit_confusions()
		for i in range(n_items):
			# Shift item i's votes to whichever class it ends up in.
			for w, lab in answers[i]:
				countTKI[T[i]][w][lab] -= 1
			T[i] = most_likely_class(i)
			for w, lab in answers[i]:
				countTKI[T[i]][w][lab] += 1
		changed = len([1 for a, b in zip(T, before) if a != b])
		if changed / float(n_items) < 0.0001:
			break

	return [labelDomain[t] for t in T], T

def Sliding_Win_DSEM(L, labelDomain,  maxIter, PI_b=None, P_b=None, p_prior=None):
	"""Hard-assignment Dawid-Skene EM (E-step first) returning model state.

	Args:
		L: J x K nested list of raw labels (None or -1 = no label).
		labelDomain: list of possible label values.
		maxIter: maximum number of E-step/M-step rounds.
		PI_b: currently unused.
		P_b: currently unused.
		p_prior: per-item prior, p_prior[i][q] for item i and class q. When
			None (the default), a uniform prior 1/len(labelDomain) is used for
			every item, so the confusion-matrix likelihood alone decides.

	Returns:
		(result, PI, P, T): inferred label values, the confusion tensor PI
		(numpy array of shape (LD, K, LD), PI[t][k][c] = estimated probability
		that worker k answers c for a true class-t item), the prior P that was
		used, and the class index of each item.
	"""
	tL, T, countTKI = initPara(L, labelDomain)
	I = len(L); K = len(L[0]); LD = len(labelDomain)
	PI = np.zeros([LD, K, LD])
	# Without this fallback the default p_prior=None failed at P[i][q].
	if p_prior is None:
		P = [[1.0 / LD] * LD for _ in range(I)]
	else:
		P = p_prior
	# tL never changes, so collect each item's (worker, class) answers once.
	obs = [[(k, c) for k, c in enumerate(row) if c is not None] for row in tL]

	def m_step():
		# Row-normalise the counts into PI; empty rows keep previous values.
		for t in range(LD):
			for k in range(K):
				s = float(sum(countTKI[t][k]))
				if s != 0:
					for c in range(LD):
						PI[t][k][c] = countTKI[t][k][c] / s

	# Initial M-step from the majority-vote assignment.
	m_step()

	for _ in range(maxIter):
		oldT = list(T)
		# E-step: reassign every item to its highest-scoring class.
		for i in range(I):
			# Move item i's votes from its old class to its new one so that
			# countTKI stays consistent with T for the next M-step.
			for k, c in obs[i]:
				countTKI[T[i]][k][c] -= 1
			pList = []
			for q in range(LD):
				s = P[i][q]
				for k, c in obs[i]:
					s = s * PI[q][k][c]
				pList.append(s)
			# Ties resolve to the lowest class index.
			T[i] = pList.index(max(pList))
			for k, c in obs[i]:
				countTKI[T[i]][k][c] += 1
		m_step()

		dnum = sum(1 for new, old in zip(T, oldT) if new != old)
		if (float(dnum) / I) < 0.0001:
			break
	result = [labelDomain[t] for t in T]
	return result, PI, P, T
