import numpy as np
import pandas as pd
import os
import sys
from util import UserPerformance
from util import load_answers, load_user_info, find_valid_users, read_decisions, get_condition_users
from util import calc_ATI_scale, calc_familiarity, calc_propensity_to_trust, load_debugging_performance
from util import find_valid_users, load_answers, read_decisions, get_assessment, calc_user_reliance_measures
from scipy.stats import wilcoxon, kruskal, mannwhitneyu


# Task identifiers of the two batches: "p1-0" ... "p1-9" and "p2-0" ... "p2-9".
p1_tasks = [f"p1-{index}" for index in range(10)]
p2_tasks = [f"p2-{index}" for index in range(10)]

def compare_performance_debugging(users_with_tutorial, user_batch_order, usertask_dict, user_flip_dict, answer_dict):
	"""Print a paired first-batch vs. second-batch comparison of reliance measures.

	For every user the two task batches are ordered according to
	``user_batch_order[user][0]`` (1: p1 tasks first, 2: p2 tasks first). The
	measures returned by ``calc_user_reliance_measures`` are stored per batch in
	a ``UserPerformance`` object, and every measure is then compared across the
	two batches with a one-sided Wilcoxon signed-rank test.

	Args:
		users_with_tutorial: iterable of user ids to include in the comparison.
		user_batch_order: mapping user id -> sequence whose first element is 1 or 2.
		usertask_dict: passed unchanged to ``calc_user_reliance_measures``.
		user_flip_dict: passed unchanged to ``calc_user_reliance_measures``.
		answer_dict: passed unchanged to ``calc_user_reliance_measures``.

	Returns:
		None. All results are written to stdout.
	"""
	user_performance_list = [
		_collect_user_performance(user, user_batch_order, usertask_dict, user_flip_dict, answer_dict)
		for user in users_with_tutorial
	]

	# The paired tests below are meaningless (and may fail) without any participant.
	if not user_performance_list:
		print("No participants to compare; skipping analysis")
		return

	print("Compare participants's performance before debugging and after debugging:")
	_compare_batch_performance(user_performance_list)


# Measures stored per batch and compared between the two batches.
_RELIANCE_MEASURE_KEYS = (
	"accuracy",
	"agreement_fraction",
	"switching_fraction",
	"relative_positive_ai_reliance",
	"relative_positive_self_reliance",
)

# Significance level used by every test in this comparison.
# NOTE(review): presumably 0.05 with a Bonferroni-style correction for three comparisons - confirm.
_SIGNIFICANCE_THRESHOLD = 0.05 / 3


def _collect_user_performance(user, user_batch_order, usertask_dict, user_flip_dict, answer_dict):
	"""Build one user's ``UserPerformance`` with a "first_batch" and a "second_batch" group."""
	first_batch_id = user_batch_order[user][0]
	if first_batch_id == 1:
		ordered_batches = (("first_batch", p1_tasks), ("second_batch", p2_tasks))
	else:
		assert first_batch_id == 2, "unexpected batch order for user {}".format(user)
		ordered_batches = (("first_batch", p2_tasks), ("second_batch", p1_tasks))

	tp_performance = UserPerformance(username=user)
	for group, tasks in ordered_batches:
		# The fourth returned value (initial disagreement) is not used in this analysis.
		(tp_correct, agreement_fraction, switching_fraction, _initial_disagreement,
			relative_positive_ai_reliance, relative_positive_self_reliance) = calc_user_reliance_measures(
				user, usertask_dict, user_flip_dict, answer_dict, tasks)
		tp_performance.add_performance(
			# Divide by the batch size (10 for both task lists) instead of a literal;
			# presumably tp_correct counts the correct decisions within this batch.
			accuracy=tp_correct / float(len(tasks)),
			agreement_fraction=agreement_fraction,
			switching_fraction=switching_fraction,
			relative_positive_ai_reliance=relative_positive_ai_reliance,
			relative_positive_self_reliance=relative_positive_self_reliance,
			group=group,
		)
	return tp_performance


def _post_hoc_comparison(data_list_1, data_list_2, name1, name2):
	"""Print which one-sided Mann-Whitney U alternative, if any, is significant."""
	# NOTE(review): Mann-Whitney U treats the two lists as independent samples,
	# while the caller passes per-user paired values - confirm this is intended.
	print("Use pots-hoc analysis")
	found_difference = False
	for alternative, symbol in (("greater", ">"), ("less", "<")):
		statistic, pvalue = mannwhitneyu(data_list_1, data_list_2, alternative=alternative)
		if pvalue < _SIGNIFICANCE_THRESHOLD:
			print("Alternative {} {} {},".format(name1, symbol, name2), "pvalue %.4f"%pvalue, "statistic %.4f"%statistic)
			found_difference = True
	if not found_difference:
		print("No significant difference with post-hoc analysis")


def _compare_batch_performance(user_performance_list):
	"""Print descriptive statistics and the Wilcoxon test of every measure across batches."""
	for var_name in _RELIANCE_MEASURE_KEYS:
		performance_list_1 = [tp.performance["first_batch"][var_name] for tp in user_performance_list]
		performance_list_2 = [tp.performance["second_batch"][var_name] for tp in user_performance_list]

		print(var_name)
		print(len(performance_list_1), len(performance_list_2))
		print("Mean: M(first):{:.2f}, SD(first):{:.2f}".format(np.mean(performance_list_1), np.std(performance_list_1)))
		print("Mean: M(second):{:.2f}, SD(second):{:.2f}".format(np.mean(performance_list_2), np.std(performance_list_2)))

		# According to H2, performance improves after the debugging intervention,
		# so the alternative hypothesis is first < second ('less').
		try:
			res = wilcoxon(x=performance_list_1, y=performance_list_2, alternative='less')
		except ValueError as err:
			# Some SciPy versions reject degenerate input (e.g. all paired differences
			# are zero); report it and continue with the remaining measures.
			print("wilcoxon test skipped:", err)
			print("-" * 17)
			continue
		print("wilcoxon results:", res)

		# A one-sided p-value close to 1 points to an effect in the opposite
		# direction, so both tails trigger the post-hoc comparison.
		pvalue = res.pvalue
		if pvalue < _SIGNIFICANCE_THRESHOLD or (1 - pvalue) < _SIGNIFICANCE_THRESHOLD:
			_post_hoc_comparison(performance_list_1, performance_list_2, "before debugging", "after debugging")
		print("-" * 17)
	# print("-" * 34)

def obtain_all_performance(users, user_condition_dict, user_batch_order, usertask_dict, user_flip_dict, answer_dict):
	"""Return a dict mapping each user to a ``UserPerformance`` covering both batches.

	The batch a user saw first is taken from ``user_batch_order[user][0]``
	(1: p1 tasks first, 2: p2 tasks first). For each batch the measures from
	``calc_user_reliance_measures`` are stored under the group names
	"first_batch" and "second_batch".
	"""
	performance_by_user = {}
	for uid in users:
		# The condition is looked up but not used afterwards; the lookup still
		# raises KeyError for a user missing from user_condition_dict.
		_condition = user_condition_dict[uid]

		starting_batch = user_batch_order[uid][0]
		if starting_batch == 1:
			ordered_batches = (("first_batch", p1_tasks), ("second_batch", p2_tasks))
		else:
			assert starting_batch == 2
			ordered_batches = (("first_batch", p2_tasks), ("second_batch", p1_tasks))

		record = UserPerformance(username=uid)
		for group_label, batch_tasks in ordered_batches:
			measures = calc_user_reliance_measures(uid, usertask_dict, user_flip_dict, answer_dict, batch_tasks)
			# The fourth value (initial disagreement) is returned but not stored.
			n_correct, agreement, switching, _initial_disagreement, ai_reliance, self_reliance = measures
			record.add_performance(
				accuracy=n_correct / 10.0,
				agreement_fraction=agreement,
				switching_fraction=switching,
				relative_positive_ai_reliance=ai_reliance,
				relative_positive_self_reliance=self_reliance,
				group=group_label,
			)

		performance_by_user[uid] = record
	return performance_by_user

def get_agreement_filter(user2performance):
	"""Return the users whose agreement fraction equals 1.0 in at least one batch.

	Args:
		user2performance: mapping user id -> object whose ``performance``
			attribute holds ``["first_batch"]["agreement_fraction"]`` and
			``["second_batch"]["agreement_fraction"]``.

	Returns:
		A set of user ids; empty when no user matches or the mapping is empty.
	"""
	outlier_agreement = set()
	for user, user_performance in user2performance.items():
		batches = user_performance.performance
		# Read both batches before testing so a missing group always raises KeyError.
		agreement_first = batches["first_batch"]["agreement_fraction"]
		agreement_second = batches["second_batch"]["agreement_fraction"]
		if agreement_first == 1.0 or agreement_second == 1.0:
			outlier_agreement.add(user)
	return outlier_agreement

if __name__ == "__main__":
	# Load the answers of both batches and merge them under "p1-<id>" / "p2-<id>" keys.
	answers_p1 = load_answers("reviews_p1.json")
	answers_p2 = load_answers("reviews_p2.json")
	answer_dict = {"p1-{}".format(task_id): answers_p1[task_id] for task_id in answers_p1}
	answer_dict.update({"p2-{}".format(task_id): answers_p2[task_id] for task_id in answers_p2})

	valid_users, approved_users = find_valid_users()
	user_condition_dict, user_batch_order = load_user_info(reserved_users=valid_users)

	# Every valid user except those in condition 3 belongs to the tutorial group.
	users_with_tutorial = {user for user in valid_users if user_condition_dict[user] != 3}
	if not users_with_tutorial:
		# NOTE(review): the message names H1 while the analysis below reports H2 - confirm.
		print("no user valid for H1 analysis")
		sys.exit(-1)

	user_task_dict, user_confidence_dict, user_flip_dict = read_decisions(user_batch_order, reserved_users=valid_users)
	print(f"For H2, we have {len(users_with_tutorial)} participants with tutorial")
	compare_performance_debugging(users_with_tutorial, user_batch_order, user_task_dict, user_flip_dict, answer_dict)

	# Disabled variant: repeat the comparison without users that agreed on every task of a batch.
	# user2performance = obtain_all_performance(valid_users, user_condition_dict, user_batch_order, user_task_dict, user_flip_dict, answer_dict)
	# outlier_agreement = get_agreement_filter(user2performance)

	# print(f"For H2, we have {len(users_with_tutorial - outlier_agreement)} participants with tutorial")
	# compare_performance_debugging(users_with_tutorial - outlier_agreement, user_batch_order, user_task_dict, user_flip_dict, answer_dict)

	# Disabled variant: repeat the comparison with users whose debugging performance is at least 0.6.
	# user_debugging_performance = load_debugging_performance(valid_users)
	# serious_users = set()
	# for user in users_with_tutorial:
	# 	if user_debugging_performance[user] < 0.6:
	# 		continue
	# 	serious_users.add(user)

	# print(f"For H2, we have {len(serious_users)} serious participants with tutorial")
	# compare_performance_debugging(serious_users, user_batch_order, user_task_dict, user_flip_dict, answer_dict)

	# Repeat the comparison separately for conditions 0, 1 and 2.
	condition_users = get_condition_users(user_condition_dict)
	condition_labels = ("Debugging-R", "Debugging-D", "Debugging-I")
	users_dict = {
		label: valid_users & condition_users[index]
		for index, label in enumerate(condition_labels)
	}

	for condition, members in users_dict.items():
		print("Condition {} have {} valid participants".format(condition, len(members)))
		print(condition)
		compare_performance_debugging(members, user_batch_order, user_task_dict, user_flip_dict, answer_dict)
	print("-" * 34)




