import numpy as np
import pandas as pd
import os
import sys
from util import UserPerformance
from util import load_answers, load_user_info, find_valid_users, read_decisions, get_condition_users, read_nasa_tlx
from util import calc_ATI_scale, calc_familiarity, calc_propensity_to_trust
from util import find_valid_users, load_answers, read_decisions, get_assessment, calc_user_reliance_measures
from scipy.stats import wilcoxon, kruskal, spearmanr, mannwhitneyu

# Task identifiers for the two batches of 10 tasks each:
# "p1-0" ... "p1-9" and "p2-0" ... "p2-9".
p1_tasks, p2_tasks = ([f"{prefix}-{idx}" for idx in range(10)] for prefix in ("p1", "p2"))

def compare_performance_across_conditions(users, user_condition_dict, user_batch_order, usertask_dict, user_flip_dict, answer_dict):
	"""Print per-condition statistics and Kruskal-Wallis tests for each performance measure.

	Users are grouped by ``user_condition_dict[user]`` into conditions 0-3
	(Debugging-R, Debugging-D, Debugging-I, Control). For every measure two
	analyses are printed: the change from the first to the second task batch
	(second - first), and the raw value on the second batch. Each analysis
	prints the per-condition mean and population SD (``np.std`` default),
	followed by a Kruskal-Wallis H-test across the four conditions.

	Args:
		users: iterable of user ids to analyse.
		user_condition_dict: maps user -> condition index in ``range(4)``.
		user_batch_order: maps user -> sequence whose first element (1 or 2)
			selects which task batch the user saw first.
		usertask_dict, user_flip_dict, answer_dict: passed through to
			``obtain_all_performance`` (and from there to
			``calc_user_reliance_measures``).

	Returns:
		None. Results are printed to stdout.
	"""
	condition_names = ["Debugging-R", "Debugging-D", "Debugging-I", "Control"]
	measures = ["accuracy", "agreement_fraction", "switching_fraction", "relative_positive_ai_reliance", "relative_positive_self_reliance"]

	# Reuse the shared per-user computation so first/second batch measures are
	# defined in exactly one place.
	user2performance = obtain_all_performance(users, user_condition_dict, user_batch_order, usertask_dict, user_flip_dict, answer_dict)
	condition_performance_dict = {condition: [] for condition in range(len(condition_names))}
	for user, tp_performance in user2performance.items():
		condition_performance_dict[user_condition_dict[user]].append(tp_performance)

	def improvement(tp_performance, var_name):
		# Within-participant change: second batch minus first batch.
		return tp_performance.performance["second_batch"][var_name] - tp_performance.performance["first_batch"][var_name]

	def second_batch_value(tp_performance, var_name):
		return tp_performance.performance["second_batch"][var_name]

	def report(extract_value, kruskal_label):
		"""Print mean/SD per condition and a Kruskal-Wallis test for every measure."""
		for var_name in measures:
			groups = [
				[extract_value(tp_performance, var_name) for tp_performance in condition_performance_dict[condition]]
				for condition in range(len(condition_names))
			]
			print(var_name)
			for condition_name, values in zip(condition_names, groups):
				print("{} Mean: M:{:.2f}, SD:{:.2f}".format(condition_name, np.mean(values), np.std(values)))
			try:
				res = kruskal(*groups)
			except ValueError as err:
				# scipy's kruskal can raise ValueError on degenerate input (e.g. all
				# values identical); report it rather than aborting the other measures.
				res = "not computable ({})".format(err)
			print(kruskal_label, res)
			print("-" * 17)

	print("Compare participants's performance improvement with different intervention (second - first):")
	report(improvement, "Kruskal results:")
	print("Compare participants's performance on the second batch with different intervention:")
	report(second_batch_value, "Kruskal results for performance on the second batch:")
	# print("-" * 34)

def obtain_all_performance(users, user_condition_dict, user_batch_order, usertask_dict, user_flip_dict, answer_dict):
	"""Compute first- and second-batch performance measures for every user.

	``user_batch_order[user][0]`` selects the order of the two task batches:
	1 means ``p1_tasks`` came first, 2 means ``p2_tasks`` came first.

	Args:
		users: iterable of user ids.
		user_condition_dict: user -> condition index. Not used by the
			computation; kept for signature compatibility with callers.
		user_batch_order: user -> sequence whose first element is 1 or 2.
		usertask_dict, user_flip_dict, answer_dict: forwarded unchanged to
			``calc_user_reliance_measures``.

	Returns:
		dict mapping each user to a ``UserPerformance`` that has had
		``add_performance`` called for group "first_batch" and then
		"second_batch".

	Raises:
		ValueError: if a user's first batch indicator is neither 1 nor 2.
	"""
	user2performance = {}
	for user in users:
		first_indicator = user_batch_order[user][0]
		if first_indicator == 1:
			batches = (("first_batch", p1_tasks), ("second_batch", p2_tasks))
		elif first_indicator == 2:
			batches = (("first_batch", p2_tasks), ("second_batch", p1_tasks))
		else:
			# Explicit check instead of `assert`, which is stripped under `python -O`.
			raise ValueError("User {} has unexpected batch order {!r}; expected 1 or 2 first".format(user, user_batch_order[user]))

		tp_performance = UserPerformance(username=user)
		for group, batch in batches:
			# The 4th returned value (initial disagreement) is not recorded here.
			tp_correct, tp_agreement_fraction, tp_switching_fraction, _, relative_positive_ai_reliance, \
				relative_positive_self_reliance = calc_user_reliance_measures(user, usertask_dict, user_flip_dict, answer_dict, batch)
			tp_performance.add_performance(
				accuracy=tp_correct / len(batch),  # fraction of tasks in this batch answered correctly
				agreement_fraction=tp_agreement_fraction,
				switching_fraction=tp_switching_fraction,
				relative_positive_ai_reliance=relative_positive_ai_reliance,
				relative_positive_self_reliance=relative_positive_self_reliance,
				group=group)
		user2performance[user] = tp_performance
	return user2performance

if __name__ == "__main__":
	# Merge both answer files into one dict keyed like the task ids in
	# p1_tasks / p2_tasks ("p1-<id>", "p2-<id>").
	answer_dict = {}
	for prefix, answer_file in (("p1", "reviews_p1.json"), ("p2", "reviews_p2.json")):
		prefix_answers = load_answers(answer_file)
		for task_id in prefix_answers:
			answer_dict["{}-{}".format(prefix, task_id)] = prefix_answers[task_id]

	valid_users, approved_users = find_valid_users()
	user_condition_dict, user_batch_order = load_user_info(reserved_users=valid_users)

	print(f"For H3, we have {len(valid_users)} participants")
	condition_users = get_condition_users(user_condition_dict)
	users_dict = {
		"Debugging-R": valid_users & condition_users[0],
		"Debugging-D": valid_users & condition_users[1],
		"Debugging-I": valid_users & condition_users[2],
		"Control"    : valid_users & condition_users[3]
	}

	for condition in users_dict:
		print("Condition {} have {} valid participants".format(condition, len(users_dict[condition])))

	user_task_dict, user_confidence_dict, user_flip_dict = read_decisions(user_batch_order, reserved_users=valid_users)
	compare_performance_across_conditions(valid_users, user_condition_dict, user_batch_order, user_task_dict, user_flip_dict, answer_dict)

	# NASA-TLX comparison across the three debugging conditions; Control
	# participants are excluded.
	debug_conditions = ["Debugging-R", "Debugging-D", "Debugging-I"]
	# Threshold 0.05 / 3, used for both the omnibus and the pairwise tests.
	alpha = 0.05 / 3
	user_nasa_tlx_dict = read_nasa_tlx(valid_users - users_dict["Control"])
	for scale in ["Mental Demand", "Physical Demand", "Temporal Demand", "Performance", "Effort", "Frustration"]:
		temp_dict = {
			condition: [user_nasa_tlx_dict[user][scale] for user in users_dict[condition]]
			for condition in debug_conditions
		}
		statistic, pvalue = kruskal(*(temp_dict[condition] for condition in debug_conditions))
		print(scale)
		print("Kruskal results:", statistic, pvalue)
		if pvalue < alpha:
			# Post-hoc Mann-Whitney U tests over each unordered pair of conditions.
			for idx, c1 in enumerate(debug_conditions):
				for c2 in debug_conditions[idx + 1:]:
					pair_statistic, pair_pvalue = mannwhitneyu(temp_dict[c1], temp_dict[c2], alternative='two-sided')
					if pair_pvalue < alpha:
						print("Alternative {} <> {},".format(c1, c2), "pvalue %.4f"%pair_pvalue, "statistic %.4f"%pair_statistic)
						# Report the direction of the significant difference.
						for alternative, symbol in (("greater", ">"), ("less", "<")):
							dir_statistic, dir_pvalue = mannwhitneyu(temp_dict[c1], temp_dict[c2], alternative=alternative)
							if dir_pvalue < alpha:
								print("Alternative {} {} {},".format(c1, symbol, c2), "pvalue %.4f"%dir_pvalue, "statistic %.4f"%dir_statistic)
		print("-" * 17)




