import numpy as np
import pandas as pd
import os
import sys
from util import UserPerformance
from util import load_answers, load_user_info, find_valid_users, read_decisions, get_condition_users
from util import calc_ATI_scale, calc_familiarity, calc_propensity_to_trust, load_debugging_performance
from util import find_valid_users, load_answers, read_decisions, get_assessment, calc_user_reliance_measures
from scipy.stats import wilcoxon, kruskal, mannwhitneyu


# Task identifiers of the two batches: "p1-0" ... "p1-9" and "p2-0" ... "p2-9".
p1_tasks = [f"p1-{idx}" for idx in range(10)]
p2_tasks = [f"p2-{idx}" for idx in range(10)]

# def post_hoc_comparison(data_list_1, data_list_2, name1, name2):
# 	print("Use pots-hoc analysis")
# 	statistic, pvalue = mannwhitneyu(data_list_1, data_list_2, alternative='less')
# 	if pvalue < (0.05 / 3):
# 		print("Alternative {} < {},".format(name1, name2), "pvalue %.4f"%pvalue, "statistic %.4f"%statistic)
# 		# print("Alternative {} <> {},".format(name1, name2), "pvalue %.4f"%pvalue, "statistic %.4f"%statistic)
# 		# statistic, pvalue = mannwhitneyu(data_list_1, data_list_2, alternative='greater')
# 		# if pvalue < (0.05 / 3):
# 		# 	print("Alternative {} > {},".format(name1, name2), "pvalue %.4f"%pvalue, "statistic %.4f"%statistic)
# 		# statistic, pvalue = mannwhitneyu(data_list_1, data_list_2, alternative='less')
# 		# if pvalue < (0.05 / 3):
# 		# 	print("Alternative {} < {},".format(name1, name2), "pvalue %.4f"%pvalue, "statistic %.4f"%statistic)
# 	else:
# 		print("No significant difference with post-hoc analysis")

def post_hoc_comparison(data_list_1, data_list_2, name1, name2, threshold=0.05 / 3):
	"""Run two one-sided Mann-Whitney U tests and print the significant direction(s).

	The "greater" alternative is tested first, then "less". Every alternative
	whose p-value is below ``threshold`` is reported on stdout; if neither is,
	a "no significant difference" line is printed instead.

	Args:
		data_list_1: First sample of observations.
		data_list_2: Second sample of observations.
		name1: Label used for the first sample in the printed report.
		name2: Label used for the second sample in the printed report.
		threshold: Significance level each one-sided p-value is compared
			against. Defaults to 0.05 / 3, the value that used to be
			hard-coded here (presumably a Bonferroni-style correction for
			three comparisons -- confirm against the analysis plan).

	Returns:
		None. The outcome is only printed.
	"""
	# NOTE(review): "pots-hoc" looks like a typo for "post-hoc"; the text is
	# left untouched so that existing outputs stay comparable.
	print("Use pots-hoc analysis")
	found_difference = False
	# (alternative handed to scipy, symbol shown in the report); the order
	# fixes the order of the printed lines.
	for alternative, symbol in (("greater", ">"), ("less", "<")):
		statistic, pvalue = mannwhitneyu(data_list_1, data_list_2, alternative=alternative)
		if pvalue < threshold:
			print("Alternative {} {} {},".format(name1, symbol, name2), "pvalue %.4f"%pvalue, "statistic %.4f"%statistic)
			found_difference = True
	if not found_difference:
		print("No significant difference with post-hoc analysis")

def test_wilcoxon(list_1, list_2, abstract_value=False, hypo='greater'):
	"""Paired Wilcoxon signed-rank comparison of two equally long samples.

	Prints the scipy result, calls ``post_hoc_comparison`` when the p-value is
	below 0.05 / 3, and finally prints mean and standard deviation (numpy's
	default ``np.std``) of both samples followed by a separator line.

	Args:
		list_1: First sample.
		list_2: Second sample, paired element-wise with ``list_1``.
		abstract_value: When True, the absolute value of every element is
			taken before testing and summarising ("abstract" is used in the
			sense of "absolute" throughout this function).
		hypo: Forwarded to ``scipy.stats.wilcoxon`` as ``alternative``.
	"""
	assert len(list_1) == len(list_2)
	if abstract_value:
		print("With abstract value, the results are:")
		sample_first = list(map(abs, list_1))
		sample_second = list(map(abs, list_2))
	else:
		sample_first, sample_second = list_1, list_2

	res = wilcoxon(x=sample_first, y=sample_second, alternative=hypo)
	if abstract_value:
		print("abstract_res", res)
	else:
		print(res)

	# Second element of the scipy result is the p-value.
	significance_level = 0.05 / 3
	if res[1] < significance_level:
		post_hoc_comparison(sample_first, sample_second, "before tutorial", "after tutorial")

	for label, sample in (("first", sample_first), ("second", sample_second)):
		print("Mean: M({0}):{1:.3f}, SD({0}):{2:.3f}".format(label, np.mean(sample), np.std(sample)))
	print("-" * 17)

def compare_self_assessment(users_with_tutorial, user_batch_order, user_task_dict, user_flip_dict, answer_dict, self_assessment_first, self_assessment_second):
	"""Compare per-user self-assessment miscalibration between the two batches.

	For every user the miscalibration of a batch is the user's self-assessment
	for that batch minus the first value returned by
	``calc_user_reliance_measures`` for the same batch. The two resulting
	lists are passed to ``test_wilcoxon`` on absolute values with the
	alternative "first batch > second batch".

	Args:
		users_with_tutorial: Iterable of user identifiers to include.
		user_batch_order: Maps user -> sequence whose first entry (1 or 2)
			tells which task batch the user saw first.
		user_task_dict, user_flip_dict, answer_dict: Forwarded unchanged to
			``calc_user_reliance_measures``.
		self_assessment_first: Maps user -> self-assessment for the first batch.
		self_assessment_second: Maps user -> self-assessment for the second batch.
	"""
	miscalibration_list_first = []
	miscalibration_list_second = []
	for user in users_with_tutorial:
		# Work out which task list was shown first to this user.
		if user_batch_order[user][0] == 1:
			first_batch, second_batch = p1_tasks, p2_tasks
		else:
			assert user_batch_order[user][0] == 2
			first_batch, second_batch = p2_tasks, p1_tasks

		# Only the leading return value is used; the other five are ignored.
		correct_first, _agree, _switch, _disagree, _ai_rel, _self_rel = calc_user_reliance_measures(
			user, user_task_dict, user_flip_dict, answer_dict, first_batch)
		gap_first = self_assessment_first[user] - correct_first

		correct_second, _agree, _switch, _disagree, _ai_rel, _self_rel = calc_user_reliance_measures(
			user, user_task_dict, user_flip_dict, answer_dict, second_batch)
		gap_second = self_assessment_second[user] - correct_second

		miscalibration_list_first.append(gap_first)
		miscalibration_list_second.append(gap_second)

	print(f"For H1, we have {len(miscalibration_list_first)} participants with tutorial")
	print("MTP miscalibration, hypo: first batch > second batch")
	# Hypothesis: miscalibration shrinks, so the first batch should be the larger one.
	test_wilcoxon(miscalibration_list_first, miscalibration_list_second, abstract_value=True, hypo='greater')


def compare_ai_asssessment(users_with_tutorial, ai_assessment_first, ai_assessment_second):
	"""Compare per-user AI-assessment miscalibration between the two batches.

	Miscalibration is the user's AI assessment minus the constant 8
	(presumably the AI's actual score on a batch -- confirm). Both lists are
	handed to ``test_wilcoxon`` on absolute values with the alternative
	"first batch > second batch".

	Args:
		users_with_tutorial: Iterable of user identifiers to include.
		ai_assessment_first: Maps user -> AI assessment for the first batch.
		ai_assessment_second: Maps user -> AI assessment for the second batch.
	"""
	miscalibration_list_first = []
	miscalibration_list_second = []
	for user in users_with_tutorial:
		for assessments, collected in (
			(ai_assessment_first, miscalibration_list_first),
			(ai_assessment_second, miscalibration_list_second),
		):
			collected.append(assessments[user] - 8)

	print(f"For H1, we have {len(miscalibration_list_first)} participants with tutorial")
	print("MAP miscalibration, hypo: first batch > second batch")
	# Hypothesis: miscalibration shrinks, so the first batch should be the larger one.
	test_wilcoxon(miscalibration_list_first, miscalibration_list_second, abstract_value=True, hypo='greater')


def compare_confident_correct_dicisions(users_with_tutorial, user_batch_order, user_task_dict, user_confidence_dict, answer_dict):
	"""Compare the number of confident and correct decisions between batches.

	A task counts for a user when the confidence stored under
	``(task, "advice")`` is greater than 1 and the decision stored under the
	same key equals ``answer_dict[task]``. The per-user counts of both batches
	are passed to ``test_wilcoxon`` on raw values with the alternative
	"first batch < second batch".

	Args:
		users_with_tutorial: Collection of user identifiers to include.
		user_batch_order: Maps user -> sequence whose first entry (1 or 2)
			tells which task batch the user saw first.
		user_task_dict: Maps user -> {(task, "advice"): decision}.
		user_confidence_dict: Maps user -> {(task, "advice"): confidence}.
		answer_dict: Maps task -> correct answer.
	"""
	def count_confident_correct(user, batch):
		# Number of tasks in `batch` that are both confident (> 1) and correct.
		return sum(
			1
			for task in batch
			if user_confidence_dict[user][(task, "advice")] > 1
			and user_task_dict[user][(task, "advice")] == answer_dict[task]
		)

	counts_first = []
	counts_second = []
	for user in users_with_tutorial:
		if user_batch_order[user][0] == 1:
			first_batch, second_batch = p1_tasks, p2_tasks
		else:
			assert user_batch_order[user][0] == 2
			first_batch, second_batch = p2_tasks, p1_tasks
		counts_first.append(count_confident_correct(user, first_batch))
		counts_second.append(count_confident_correct(user, second_batch))

	print(f"For H1, we have {len(users_with_tutorial)} participants with tutorial")
	print("CCD, hypo: first batch < second batch")
	# Hypothesis: the second batch should show more confident correct decisions.
	test_wilcoxon(counts_first, counts_second, abstract_value=False, hypo='less')

def obtain_all_performance(users, user_condition_dict, user_batch_order, usertask_dict, user_flip_dict, answer_dict):
	"""Build a ``UserPerformance`` record with first/second batch measures per user.

	For each batch the values returned by ``calc_user_reliance_measures`` are
	stored through ``add_performance`` under the group "first_batch" or
	"second_batch"; accuracy is the first returned value divided by 10.0.

	Args:
		users: Iterable of user identifiers.
		user_condition_dict: Maps user -> condition; only looked up, not used further.
		user_batch_order: Maps user -> sequence whose first entry (1 or 2)
			tells which task batch the user saw first.
		usertask_dict, user_flip_dict, answer_dict: Forwarded unchanged to
			``calc_user_reliance_measures``.

	Returns:
		dict mapping each user to its ``UserPerformance`` object.
	"""
	user2performance = {}
	for user in users:
		# Result is unused; the lookup is kept so an unknown user still raises KeyError.
		_condition = user_condition_dict[user]
		if user_batch_order[user][0] == 1:
			ordered_batches = (p1_tasks, p2_tasks)
		else:
			assert user_batch_order[user][0] == 2
			ordered_batches = (p2_tasks, p1_tasks)

		record = UserPerformance(username=user)
		for group, batch in zip(("first_batch", "second_batch"), ordered_batches):
			# The fourth value (initial disagreement) is not stored.
			correct, agreement, switching, _initial_disagreement, ai_reliance, self_reliance = \
				calc_user_reliance_measures(user, usertask_dict, user_flip_dict, answer_dict, batch)
			record.add_performance(
				accuracy=correct / 10.0,
				agreement_fraction=agreement,
				switching_fraction=switching,
				relative_positive_ai_reliance=ai_reliance,
				relative_positive_self_reliance=self_reliance,
				group=group,
			)
		user2performance[user] = record
	return user2performance

def address_outlier(users_with_tutorial, ai_assessment_first, ai_assessment_second, self_assessment_first, self_assessment_second):
	"""Collect users whose AI and self assessments are both below 5 in a batch.

	Prints the number of flagged users per batch on one line and the size of
	their union on the next.

	Args:
		users_with_tutorial: Iterable of user identifiers to check.
		ai_assessment_first, ai_assessment_second: Map user -> AI assessment
			for the first / second batch.
		self_assessment_first, self_assessment_second: Map user -> self
			assessment for the first / second batch.

	Returns:
		set of users flagged in the first batch, the second batch, or both.
	"""
	flagged_first = set()
	flagged_second = set()
	checks = (
		(ai_assessment_first, self_assessment_first, flagged_first),
		(ai_assessment_second, self_assessment_second, flagged_second),
	)
	for user in users_with_tutorial:
		for ai_scores, self_scores, flagged in checks:
			if ai_scores[user] < 5 and self_scores[user] < 5:
				flagged.add(user)

	flagged_any = flagged_first | flagged_second
	print(len(flagged_first), len(flagged_second))
	print(len(flagged_any))
	return flagged_any

def get_agreement_filter(user2performance):
	"""Return the users whose agreement fraction is exactly 1.0 in either batch.

	Args:
		user2performance: Maps user -> object whose ``performance`` attribute
			holds ``{"first_batch": {...}, "second_batch": {...}}`` with an
			"agreement_fraction" entry in each.

	Returns:
		set of users with an agreement fraction of 1.0 in the first batch,
		the second batch, or both.
	"""
	full_agreement_users = set()
	for user, record in user2performance.items():
		# Read both batches before comparing.
		fractions = [
			record.performance[group]["agreement_fraction"]
			for group in ("first_batch", "second_batch")
		]
		if any(fraction == 1.0 for fraction in fractions):
			full_agreement_users.add(user)
	return full_agreement_users

if __name__ == "__main__":
	# Merge the answer keys of both review files under "p1-<id>" / "p2-<id>" task names.
	answer_dict_p1 = load_answers("reviews_p1.json")
	answer_dict_p2 = load_answers("reviews_p2.json")
	answer_dict = {"p1-{}".format(task_id): answer for task_id, answer in answer_dict_p1.items()}
	answer_dict.update({"p2-{}".format(task_id): answer for task_id, answer in answer_dict_p2.items()})

	# Load participants, their conditions / batch order and their decisions.
	valid_users, approved_users = find_valid_users()
	user_condition_dict, user_batch_order = load_user_info(reserved_users=valid_users)

	user_task_dict, user_confidence_dict, user_flip_dict = read_decisions(user_batch_order, reserved_users=valid_users)
	condition_users = get_condition_users(user_condition_dict)
	condition_names = ("Debugging-R", "Debugging-D", "Debugging-I", "Control")
	users_dict = {
		name: valid_users & condition_users[index]
		for index, name in enumerate(condition_names)
	}

	for condition, members in users_dict.items():
		print("Condition {} have {} valid participants".format(condition, len(valid_users & members)))
	print("-" * 34)

	# Everyone except condition 3 takes part in the H1 analysis
	# (presumably 3 is the Control group, matching index 3 above -- confirm).
	users_with_tutorial = {user for user in valid_users if user_condition_dict[user] != 3}
	if not users_with_tutorial:
		print("no user valid for H1 analysis")
		sys.exit(-1)
	ai_assessment_first, ai_assessment_second, self_assessment_first, self_assessment_second = get_assessment(reserved_users=valid_users)

	# Optional participant filters, currently disabled:
	# outlier_set = address_outlier(users_with_tutorial, ai_assessment_first, ai_assessment_second, self_assessment_first, self_assessment_second)
	# sys.exit(-1)

	# users_with_tutorial = users_with_tutorial - outlier_set # remove participants with unserious estimation, results are similar

	# user_debugging_performance = load_debugging_performance(valid_users)
	# serious_users = set()
	# for user in users_with_tutorial:
	# 	if user_debugging_performance[user] < 0.7:
	# 		continue
	# 	serious_users.add(user)
	# print(len(serious_users))
	# user2performance = obtain_all_performance(valid_users, user_condition_dict, user_batch_order, user_task_dict, user_flip_dict, answer_dict)
	# outlier_agreement = get_agreement_filter(user2performance)

	def _run_h1_comparisons(participants):
		# Global level: did the assessment of AI performance improve?
		compare_ai_asssessment(participants, ai_assessment_first, ai_assessment_second)
		# Global level: did the assessment of team performance improve?
		compare_self_assessment(participants, user_batch_order, user_task_dict, user_flip_dict, answer_dict, self_assessment_first, self_assessment_second)
		# Instance level: do users make more correct decisions with confidence?
		compare_confident_correct_dicisions(participants, user_batch_order, user_task_dict, user_confidence_dict, answer_dict)

	# All tutorial participants together ...
	_run_h1_comparisons(users_with_tutorial)

	print("-" * 34)

	# ... then each tutorial condition on its own.
	for condition in condition_names[:3]:
		print(condition)
		_run_h1_comparisons(users_dict[condition])

	# 	print("-" * 34)
