import numpy as np
import pandas as pd
import os
import sys
from util import UserPerformance
from util import load_answers, load_user_info, find_valid_users, read_decisions, read_user_question_order, get_condition_users
from util import calc_ATI_scale, calc_familiarity, calc_propensity_to_trust
from util import find_valid_users, load_answers, read_decisions, get_assessment, calc_user_reliance_measures
from scipy.stats import wilcoxon, kruskal
import seaborn as sns
import matplotlib.pyplot as plt

# Task identifiers of the two batches: "p1-0" .. "p1-9" and "p2-0" .. "p2-9".
p1_tasks = [f"p1-{index}" for index in range(10)]
p2_tasks = [f"p2-{index}" for index in range(10)]

def draw_box_plot(data_dict, mean_dict, sd_dict, median_dict):
	"""Show a seaborn box plot of ``data_dict`` with a mean/SD label on each box.

	Args:
		data_dict: Mapping from column label to the values of that column. It is
			turned into a float ``DataFrame``; one box is drawn per column.
		mean_dict: Mapping from the same labels to the mean of each column. The
			mean is used both in the label text and as its vertical position.
		sd_dict: Mapping from the same labels to the standard deviation shown
			in the label text.
		median_dict: Accepted for call compatibility; not used by this function.

	Returns:
		None. The figure is displayed with ``plt.show()``.
	"""
	df = pd.DataFrame(data_dict, dtype=float)
	# One call on purpose: ``sns.set`` is an alias of ``set_theme`` whose
	# ``style`` defaults to "darkgrid", so a separate ``sns.set(font=...)``
	# after ``set_theme(style="whitegrid")`` would discard the white grid.
	sns.set_theme(style="whitegrid", font="Arial")
	ax = sns.boxplot(data=df)

	# Boxes are drawn at x = 0, 1, 2, ... in column order, which is the key
	# order of ``data_dict``; put each label slightly above the group's mean.
	for position, key in enumerate(data_dict):
		ax.text(
			position,
			mean_dict[key] + 0.03,
			"M: {:.2f}\nSD: {:.2f}".format(mean_dict[key], sd_dict[key]),
			horizontalalignment='center',
			size=18,
			color='w',
			weight='semibold',
		)

	ax.tick_params(labelsize=18)
	ax.set_xlabel("Dimension", fontsize=24)
	ax.set_ylabel("Value", fontsize=24)
	plt.margins(0.015, tight=True)
	plt.show()

def multi_line_plot(data_dict):
	"""Show one seaborn line per column of ``data_dict`` on a shared axis.

	Args:
		data_dict: Mapping from series label to its values (wide format). It is
			turned into a float ``DataFrame``; every column becomes one solid
			line and one legend entry.

	Returns:
		None. The figure is displayed with ``plt.show()``.
	"""
	df = pd.DataFrame(data_dict, dtype=float)
	# One call on purpose: ``sns.set`` is an alias of ``set_theme`` whose
	# ``style`` defaults to "darkgrid", so a separate ``sns.set(font=...)``
	# after ``set_theme(style="whitegrid")`` would discard the white grid.
	sns.set_theme(style="whitegrid", font="Arial")
	ax = sns.lineplot(data=df, dashes=False)
	ax.tick_params(labelsize=18)
	ax.set_xlabel("Tasks", fontsize=24)
	ax.set_ylabel("Confidence", fontsize=24)
	legend = ax.get_legend()
	plt.setp(legend.get_texts(), fontsize='14')  # legend entries
	plt.setp(legend.get_title(), fontsize='14')  # legend title
	plt.margins(0.015, tight=True)
	plt.show()

def multiple_plot_new(y_list):
	"""Show the four per-condition curves on a gap-free 2x2 grid of axes.

	Args:
		y_list: Four sequences in the order Control, Debugging-R, Debugging-D,
			Debugging-I. Each is plotted against task numbers 1..20, so each
			must hold 20 values.

	Returns:
		None. The figure is displayed with ``plt.show()``.
	"""
	y_control, y_R, y_D, y_I = y_list
	task_numbers = range(1, 21)

	figure = plt.figure()
	# No spacing between panels; columns share the x axis, rows share the y axis.
	grid = figure.add_gridspec(2, 2, hspace=0, wspace=0)
	(top_left, top_right), (bottom_left, bottom_right) = grid.subplots(sharex='col', sharey='row')

	# (axes, series, extra positional format args, title); Control keeps the
	# default line colour, hence the empty format tuple.
	panels = (
		(top_left, y_control, (), 'Control'),
		(top_right, y_R, ('tab:orange',), 'Debugging-R'),
		(bottom_left, y_D, ('tab:green',), 'Debugging-D'),
		(bottom_right, y_I, ('tab:red',), 'Debugging-I'),
	)
	for axes, series, fmt, title in panels:
		axes.plot(task_numbers, series, *fmt)
		axes.set_title(title, loc="left", y=1.0, pad=-14)

	plt.show()

def simple_plot(y_list):
	"""Show all four per-condition curves on a single axis with a legend.

	Args:
		y_list: Sequences indexed 0..3 in the order Control, Debugging-R,
			Debugging-D, Debugging-I. They are plotted without explicit x
			values, so matplotlib uses the positions 0, 1, 2, ... as x.

	Returns:
		None. The figure is displayed with ``plt.show()``.
	"""
	figure, axis = plt.subplots()

	# Keyword arguments per curve; Control keeps the default line colour.
	curve_styles = (
		{"label": 'Control', "marker": '+'},
		{"color": 'tab:orange', "label": 'Debugging-R', "marker": 'o'},
		{"color": 'tab:green', "label": 'Debugging-D', "marker": '*'},
		{"color": 'tab:red', "label": 'Debugging-I', "marker": 'd'},
	)
	for position, style in enumerate(curve_styles):
		axis.plot(y_list[position], **style)

	# Dashed vertical line at x = 10, drawn as 21 points from 0.4 up to about 1.2.
	axis.plot([10] * 21, np.arange(0.4, 1.24, 0.04), 'tab:brown', linestyle='--')

	legend = axis.legend(loc='upper left')
	axis.tick_params(labelsize=18)
	axis.set_xlabel("Tasks", fontsize=18)
	axis.set_ylabel("Confidence", fontsize=18)
	plt.setp(legend.get_texts(), fontsize='14')  # legend entries
	plt.setp(legend.get_title(), fontsize='14')  # legend title
	plt.margins(0.015, tight=True)
	axis.set_ylim([0.4, 1.2])
	plt.show()

def multiple_plot(y_list):
	"""Show the four per-condition curves in a 2x2 grid of independent axes.

	Every panel uses the y range [0.4, 1.2]. The three Debugging panels also
	get a dashed vertical line at x = 10; the Control panel does not.

	Args:
		y_list: Four sequences in the order Control, Debugging-R, Debugging-D,
			Debugging-I. Each is plotted against task numbers 1..20, so each
			must hold 20 values.

	Returns:
		None. The figure is displayed with ``plt.show()``.
	"""
	y_control, y_R, y_D, y_I = y_list
	task_numbers = range(1, 21)

	fig1, axs = plt.subplots(ncols=2, nrows=2, constrained_layout=True)

	# (axes, series, extra positional format args, title, draw divider?)
	panels = (
		(axs[0, 0], y_control, (), 'Control', False),
		(axs[0, 1], y_R, ('tab:orange',), 'Debugging-R', True),
		(axs[1, 0], y_D, ('tab:green',), 'Debugging-D', True),
		(axs[1, 1], y_I, ('tab:red',), 'Debugging-I', True),
	)
	for axes, series, fmt, title, with_divider in panels:
		axes.plot(task_numbers, series, *fmt)
		axes.set_title(title)
		axes.set_ylim([0.4, 1.2])
		if with_divider:
			# Dashed vertical line at x = 10 built from 21 stacked points.
			axes.plot([10] * 21, np.arange(0.4, 1.24, 0.04), 'tab:brown', linestyle='--')

	for axes in axs.flat:
		axes.set(xlabel='Tasks', ylabel='Confidence')

	plt.show()

def multiple_plot_14(y_list):
	"""Show the four per-condition curves one after another, one figure each.

	Label and tick font sizes are set through ``plt.rcParams`` first. Each
	curve is drawn on the current axes with the y range [0.4, 1.1] and then
	displayed; the three Debugging figures also get a dashed vertical line at
	x = 10.

	Args:
		y_list: Four sequences in the order Control, Debugging-R, Debugging-D,
			Debugging-I. Each is plotted against task numbers 1..20, so each
			must hold 20 values.

	Returns:
		None. Each figure is displayed with ``plt.show()``.
	"""
	task_numbers = range(1, 21)
	y_control, y_R, y_D, y_I = y_list

	base_font = 24
	plt.rcParams.update({
		'axes.labelsize': base_font,
		'axes.titlesize': base_font,
		'xtick.labelsize': base_font * 0.75,
		'ytick.labelsize': base_font * 0.75,
	})

	# (series, extra positional format args, draw divider?); Control keeps the
	# default line colour and has no divider.
	figures = (
		(y_control, (), False),
		(y_R, ('tab:orange',), True),
		(y_D, ('tab:green',), True),
		(y_I, ('tab:red',), True),
	)
	final_position = len(figures) - 1
	for position, (series, fmt, with_divider) in enumerate(figures):
		ax = plt.gca()
		ax.plot(task_numbers, series, *fmt)
		ax.set_ylim([0.4, 1.1])
		ax.set(xlabel='Tasks', ylabel='Confidence')
		if with_divider:
			# Dashed vertical line at x = 10 built from 21 stacked points.
			ax.plot([10] * 21, np.arange(0.4, 1.24, 0.04), 'tab:brown', linestyle='--')
		# Swap in plt.savefig(<file name>) here to write the figure to disk instead.
		plt.show()
		if position != final_position:
			# Clear the figure before drawing the next curve; nothing follows the last one.
			plt.clf()

	# # print(y_control)
	# # print(y_R)
	# fig, axs = plt.subplots(1, 4)
	# # fig1, axs = plt.subplots(ncols=4, nrows=4, constrained_layout=True)

	# # axs[0, 0].plot(x, y_control, marker='+')
	# axs[0].plot(x, y_control)
	# axs[0].set_title('Control')
	# axs[0].set_ylim([0.4, 1.2])
	# # axs[0, 0].plot([10] * 21, np.arange(0.4, 1.24, 0.04), 'tab:brown', linestyle='--')
	# # axs[0, 0].set_marker("+")
	# axs[1].plot(x, y_R, 'tab:orange')
	# axs[1].set_title('Debugging-R')
	# axs[1].set_ylim([0.4, 1.2])
	# axs[1].plot([10] * 21, np.arange(0.4, 1.24, 0.04), 'tab:brown', linestyle='--')

	# axs[2].plot(x, y_D, 'tab:green')
	# axs[2].set_title('Debugging-D')
	# axs[2].set_ylim([0.4, 1.2])
	# axs[2].plot([10] * 21, np.arange(0.4, 1.24, 0.04), 'tab:brown', linestyle='--')

	# axs[3].plot(x, y_I, 'tab:red')
	# axs[3].set_title('Debugging-I')
	# axs[3].set_ylim([0.4, 1.2])
	# axs[3].plot([10] * 21, np.arange(0.4, 1.24, 0.04), 'tab:brown', linestyle='--')

	# for ax in axs.flat:
	#     ax.set(xlabel='Tasks', ylabel='Confidence')

	# Hide x labels and tick labels for top plots and y ticks for right plots.
	# for ax in axs.flat:
	#     ax.label_outer()

	# plt.show()

def draw_line_plot(data_dict):
	"""Show a single seaborn line of "Confidence" against "Task".

	Args:
		data_dict: Long-format mapping that provides the columns ``"Task"`` and
			``"Confidence"``. It is turned into a float ``DataFrame``.

	Returns:
		None. The figure is displayed with ``plt.show()``.
	"""
	df = pd.DataFrame(data_dict, dtype=float)
	# One call on purpose: ``sns.set`` is an alias of ``set_theme`` whose
	# ``style`` defaults to "darkgrid", so a separate ``sns.set(font=...)``
	# after ``set_theme(style="whitegrid")`` would discard the white grid.
	sns.set_theme(style="whitegrid", font="Arial")
	ax = sns.lineplot(x="Task", y="Confidence", data=df)
	ax.tick_params(labelsize=18)
	ax.set_xlabel("Tasks", fontsize=24)
	ax.set_ylabel("Confidence", fontsize=24)
	plt.margins(0.015, tight=True)
	plt.show()

def show_confidence_dynamics(users_with_tutorial, user_question_order, user_confidence_dict, plot=False, num_tasks=20):
	"""Average the post-advice confidence at each position of the task sequence.

	For every user, the task shown at position ``i`` of that user's question
	order contributes ``user_confidence_dict[user][(task, "advice")]`` to
	position ``i``. The values collected per position are then averaged.

	Args:
		users_with_tutorial: Iterable of the users to include.
		user_question_order: Mapping user -> ordered sequence of task ids.
		user_confidence_dict: Mapping user -> mapping keyed by
			``(task, "advice")`` -> confidence value.
		plot: When true, also draw the averages with ``draw_line_plot``.
		num_tasks: Number of positions to report. Defaults to 20.

	Returns:
		A list of ``num_tasks`` means, one per position. A position that
		received no value is reported as NaN.

	Raises:
		KeyError: If a user's question order has more than ``num_tasks``
			entries, or a user/task is missing from the input mappings.
	"""
	data = {position: [] for position in range(num_tasks)}
	for user in users_with_tutorial:
		for position, task in enumerate(user_question_order[user]):
			data[position].append(user_confidence_dict[user][(task, "advice")])

	# Compute each mean once. An empty position is reported as NaN directly,
	# which is what np.mean([]) yields but without its empty-slice warning.
	confidence_dynamic_avg = [
		np.mean(data[position]) if data[position] else np.nan
		for position in range(num_tasks)
	]

	if plot:
		draw_line_plot({
			"Task": list(range(num_tasks)),
			"Confidence": list(confidence_dynamic_avg),
		})
	return confidence_dynamic_avg

def agreement_confidence_analysis(users, user_question_order, user_task_dict, user_confidence_dict, user_flip_dict, answer_dict):
	"""Print how confidence changes after advice, grouped by (dis)agreement with the AI.

	For every user and task the confidence change is the "advice" confidence
	minus the "base" confidence. The system advice equals the correct answer,
	unless the task is listed in ``user_flip_dict[user]``, in which case it is
	the opposite binary label. Each change is added to these groups:

	* final agreement / final disagreement: the second choice equals / differs
	  from the advice;
	* initial agreement / initial disagreement: the first choice equals /
	  differs from the advice;
	* switch: initial disagreement and the second choice differs from the first;
	* under initial disagreement only, positive/negative AI reliance and
	  positive/negative self reliance, depending on whether the advice and the
	  second choice are correct.

	Mean and population standard deviation (``np.std`` default) of each group
	are printed, along with three Kruskal-Wallis tests.

	Args:
		users: Iterable of users to analyse.
		user_question_order: Mapping user -> sequence of task ids.
		user_task_dict: Mapping user -> {(task, "base" | "advice"): choice}.
		user_confidence_dict: Mapping user -> {(task, "base" | "advice"): confidence}.
		user_flip_dict: Mapping user -> container of tasks whose advice was flipped.
		answer_dict: Mapping task -> correct answer.

	Returns:
		None. All results are written to standard output.
	"""
	separator = "-" * 17

	def report(label, changes):
		# One summary line: label, mean, standard deviation.
		print(label, np.mean(changes), np.std(changes))

	initial_agree_changes, initial_disagree_changes = [], []
	final_agree_changes, final_disagree_changes = [], []
	switch_changes = []
	ai_positive, ai_negative = [], []
	self_positive, self_negative = [], []

	for user in users:
		for task in user_question_order[user]:
			truth = answer_dict[task]
			choice_before = user_task_dict[user][(task, "base")]
			choice_after = user_task_dict[user][(task, "advice")]
			confidence_before = user_confidence_dict[user][(task, "base")]
			confidence_after = user_confidence_dict[user][(task, "advice")]

			# Flipped tasks received the wrong (opposite) label as advice.
			advice = truth
			if task in user_flip_dict[user]:
				advice = 0 if truth == 1 else 1

			delta = confidence_after - confidence_before

			if choice_after == advice:
				final_agree_changes.append(delta)
			else:
				final_disagree_changes.append(delta)

			if choice_before == advice:
				# The user already agreed with the AI before seeing the advice.
				initial_agree_changes.append(delta)
				continue

			# From here on: the user initially disagreed with the advice.
			initial_disagree_changes.append(delta)
			if choice_before != choice_after:
				switch_changes.append(delta)

			ended_correct = choice_after == truth
			if advice == truth:
				# Correct advice: ending correct is positive AI reliance,
				# ending wrong is negative self reliance.
				(ai_positive if ended_correct else self_negative).append(delta)
			else:
				# Wrong advice: ending correct is positive self reliance,
				# ending wrong is negative AI reliance.
				(self_positive if ended_correct else ai_negative).append(delta)

	print(separator)
	report("Confidence change in initial agreement:", initial_agree_changes)
	report("Confidence change in initial disagreement:", initial_disagree_changes)
	print("Kruskal results for agreement vs disagreement:", kruskal(initial_agree_changes, initial_disagree_changes))
	print(separator)
	report("Confidence change in switch behavior:", switch_changes)
	report("Confidence change in agreement behavior:", final_agree_changes)
	print("Kruskal results for agreement vs switch:", kruskal(final_agree_changes, switch_changes))
	print(separator)
	report("Confidence change in final disagreement behavior:", final_disagree_changes)
	print(separator)
	report("Confidence change in positive AI reliance:", ai_positive)
	report("Confidence change in positive self reliance:", self_positive)
	report("Confidence change in negative AI reliance:", ai_negative)
	report("Confidence change in negative self reliance:", self_negative)
	print(separator)
	ai_vs_self = kruskal(ai_positive + ai_negative, self_positive + self_negative)
	print("Kruskal results for AI reliance vs Self reliance under initial disagreement:", ai_vs_self)
	print(separator)

	# user_data = {
	# 	"initial agreement": initial_agreement_change_list,
	# 	"initial disagreement": disagreement_change_list,
	# 	"final agreement": agreement_list,
	# 	"final disagreement": disagreement_list,
	# 	"switch behavior": switch_list,
	# 	"positive AI reliance": positive_ai_reliance_list,
	# 	"negative AI reliance": negative_ai_reliance_list,
	# 	"positive self-reliance": positive_self_reliance_list,
	# 	"negative self-reliance": negative_self_reliance_list
	# }

	# mean_dict, sd_dict, median_dict = {}, {}, {}
	# for scale in user_data:
	# 	mean_dict[scale] = np.mean(user_data[scale])
	# 	sd_dict[scale] = np.std(user_data[scale], ddof=1)
	# 	median_dict[scale] = np.median(user_data[scale])
	# draw_box_plot(user_data, mean_dict, sd_dict, median_dict)



def compare_performance_debugging(users_with_tutorial, user_batch_order, usertask_dict, user_flip_dict, answer_dict):
	"""Collect each user's measures on their first and second task batch and compare them.

	A user whose ``user_batch_order[user][0]`` is 1 worked on ``p1_tasks``
	first and ``p2_tasks`` second; a value of 2 means the reverse. For both
	batches the reliance measures are obtained from
	``calc_user_reliance_measures`` and stored on a ``UserPerformance`` under
	the groups "first_batch" and "second_batch". Accuracy is the number of
	correct answers divided by 10.0.

	Args:
		users_with_tutorial: Iterable of users to include.
		user_batch_order: Mapping user -> sequence whose first item is 1 or 2.
		usertask_dict: Forwarded unchanged to ``calc_user_reliance_measures``.
		user_flip_dict: Forwarded unchanged to ``calc_user_reliance_measures``.
		answer_dict: Forwarded unchanged to ``calc_user_reliance_measures``.

	Returns:
		None. Progress lines are printed and the collected list is handed to
		``compare_performance``.

	Raises:
		AssertionError: If a user's first batch is neither 1 nor 2.
	"""
	# NOTE(review): `compare_performance` is neither defined nor imported in the
	# visible part of this file - TODO confirm where it is meant to come from.
	user_performance_list = []
	for user in users_with_tutorial:
		leading_batch = user_batch_order[user][0]
		if leading_batch == 1:
			batches = (("first_batch", p1_tasks), ("second_batch", p2_tasks))
		else:
			assert leading_batch == 2
			batches = (("first_batch", p2_tasks), ("second_batch", p1_tasks))

		tp_performance = UserPerformance(username=user)
		for group_name, batch_tasks in batches:
			# The fourth returned value (initial disagreement) is not used here.
			correct_count, agreement, switching, _initial_disagreement, positive_ai, positive_self = \
				calc_user_reliance_measures(user, usertask_dict, user_flip_dict, answer_dict, batch_tasks)
			tp_performance.add_performance(
				accuracy=correct_count / 10.0,
				agreement_fraction=agreement,
				switching_fraction=switching,
				relative_positive_ai_reliance=positive_ai,
				relative_positive_self_reliance=positive_self,
				group=group_name,
			)
		user_performance_list.append(tp_performance)

	print(f"For H2, we have {len(user_performance_list)} participants with tutorial")
	print("Compare participants's performance before debugging and after debugging:")
	compare_performance(user_performance_list)


if __name__ == "__main__":
	# Merge the answers of both batches into one dict keyed "p1-<id>" / "p2-<id>".
	answer_dict = {}
	for prefix, answer_file in (("p1", "reviews_p1.json"), ("p2", "reviews_p2.json")):
		batch_answers = load_answers(answer_file)
		for task_id in batch_answers:
			answer_dict["{}-{}".format(prefix, task_id)] = batch_answers[task_id]

	valid_users, approved_users = find_valid_users()
	user_condition_dict, user_batch_order = load_user_info(reserved_users=valid_users)
	user_question_order = read_user_question_order(user_batch_order, reserved_users=valid_users)

	# Condition indices 0..2 are the Debugging variants, 3 is the control group.
	condition_users = get_condition_users(user_condition_dict)
	condition_index = {"Debugging-R": 0, "Debugging-D": 1, "Debugging-I": 2, "Control": 3}
	users_dict = {
		name: valid_users & condition_users[index]
		for name, index in condition_index.items()
	}

	for condition in range(4):
		valid_in_condition = valid_users & condition_users[condition]
		print("Condition {} have {} valid participants".format(condition, len(valid_in_condition)))

	# Split the valid users into the control group and everyone else.
	users_control = {user for user in valid_users if user_condition_dict[user] == 3}
	users_with_tutorial = {user for user in valid_users if user_condition_dict[user] != 3}
	if not users_with_tutorial:
		print("no user valid for H1 analysis")
		sys.exit(-1)

	user_task_dict, user_confidence_dict, user_flip_dict = read_decisions(user_batch_order, reserved_users=valid_users)

	# Average confidence per task position, one curve per condition.
	conditions = ["Control", "Debugging-R", "Debugging-D", "Debugging-I"]
	y_list = [
		show_confidence_dynamics(users_dict[condition], user_question_order, user_confidence_dict)
		for condition in conditions
	]
	# Same data keyed by condition name, as expected by multi_line_plot.
	data_wide_format = {
		condition: [curve[position] for position in range(20)]
		for condition, curve in zip(conditions, y_list)
	}

	# Plot variants; enable the one that is needed.
	# multi_line_plot(data_wide_format)
	# multiple_plot(y_list)
	# multiple_plot_new(y_list)
	# simple_plot(y_list)
	# multiple_plot_14(y_list) # we used this in paper

	agreement_confidence_analysis(valid_users, user_question_order, user_task_dict, user_confidence_dict, user_flip_dict, answer_dict)


