import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable as V
import torch.nn.functional as F
import pickle
import argparse
from helper_functions_twitter import *
import os 

# Expose only the first GPU to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

parser = argparse.ArgumentParser(description='Twitter label corruption experiments')
parser.add_argument('--method', type=str, default='ours',
                    choices=['ours', 'glc', 'forward_gold', 'TMatrix', 'sTMatrix', 'sgforward'])
parser.add_argument('--corruption_type', type=str, default='flip_labels',
                    choices=['uniform_mix', 'flip_labels', 'bimodal'])
parser.add_argument('--alpha', type=float, default=1.0, help='alpha scale')
parser.add_argument('--beta', type=float, default=1.0, help='beta scale')
args = parser.parse_args()
print(args)

window_size = 1

# POS tags are mapped to the integers 0..24 in exactly the order listed below.
tag_to_number = {
    tag: number
    for number, tag in enumerate([
        u'N', u'O', u'S', u'^', u'Z', u'L', u'M',
        u'V', u'A', u'R', u'!', u'D', u'P', u'&', u'T',
        u'X', u'Y', u'#', u'@', u'~', u'U', u'E', u'$',
        u',', u'G',
    ])
}

embeddings = embeddings_to_dict('./data/Tweets/embeddings-twitter.txt')
vocab = embeddings.keys()

# <s> has no embedding, so </s> is used as the start symbol instead; it is a
# better stand-in than the UNK embedding. Splits are loaded in order train, dev, devtest.
(X_train, Y_train), (X_dev, Y_dev), (X_test, Y_test) = (
    data_to_mat(path, vocab, tag_to_number, window_size=window_size, start_symbol=u'</s>')
    for path in ('./data/Tweets/tweets-train.txt',
                 './data/Tweets/tweets-dev.txt',
                 './data/Tweets/tweets-devtest.txt')
)


def prepare_data(corruption_matrix, gold_fraction=0.5, merge_valset=True):
    '''
    Shuffle the training examples and corrupt the labels of the leading "silver" part.

    The module-level training split (optionally followed by the dev split) is shuffled
    with a fixed seed (np.random.seed(1)). The first ``num_silver`` labels are then
    resampled one by one from ``corruption_matrix[clean_label]``; the trailing
    ``num_gold`` examples keep their clean labels.

    :param corruption_matrix: square matrix whose row k is the distribution used to
        draw a replacement label for an example with clean label k.
    :param gold_fraction: fraction of examples kept clean (count truncated with int()).
    :param merge_valset: if True, append X_dev / Y_dev to the training data first.
    :return: (dataset, gold, num_gold, num_silver). ``dataset`` and ``gold`` are dicts
        with 'x' and 'y' arrays; ``gold`` views the last ``num_gold`` rows of ``dataset``.
    '''
    np.random.seed(1)

    twitter_tweets, twitter_labels = np.asarray(X_train), np.asarray(Y_train)
    if merge_valset:
        twitter_tweets = np.concatenate([twitter_tweets, X_dev], axis=0)
        twitter_labels = np.concatenate([twitter_labels, Y_dev])

    indices = np.arange(len(twitter_labels))
    np.random.shuffle(indices)

    # Fancy indexing copies, so the module-level arrays are never modified.
    # np.int64 (not the removed np.long alias) gives the int64 labels torch expects.
    twitter_tweets = twitter_tweets[indices]
    twitter_labels = twitter_labels[indices].astype(np.int64)

    num_gold = int(len(twitter_labels) * gold_fraction)
    num_silver = len(twitter_labels) - num_gold

    # Resample silver labels sequentially so the RNG stream is consumed in a fixed order.
    for i in range(num_silver):
        row = corruption_matrix[twitter_labels[i]]
        twitter_labels[i] = np.random.choice(len(row), p=row)

    dataset = {'x': twitter_tweets, 'y': twitter_labels}
    gold = {'x': dataset['x'][num_silver:], 'y': dataset['y'][num_silver:]}

    return dataset, gold, num_gold, num_silver


def uniform_mix_C(mixing_ratio):
    '''
    Blend a uniform matrix with the identity.

    Returns mixing_ratio * U + (1 - mixing_ratio) * I, where U is the
    (num_classes, num_classes) matrix with every entry equal to 1 / num_classes
    and num_classes is the module-level constant.
    '''
    uniform = np.full((num_classes, num_classes), 1.0 / num_classes)
    identity = np.eye(num_classes)
    return mixing_ratio * uniform + (1.0 - mixing_ratio) * identity

def flip_labels_C(corruption_prob):
    '''
    Build a label-flipping corruption matrix.

    Each row i has (1 - corruption_prob) on the diagonal and corruption_prob in a
    single other column chosen uniformly at random. The global NumPy RNG is reseeded
    with 1 first, so the chosen flip targets are the same on every call.
    '''
    np.random.seed(1)

    all_labels = np.arange(num_classes)
    C = (1.0 - corruption_prob) * np.eye(num_classes)
    for row in range(num_classes):
        other_labels = np.delete(all_labels, row)
        C[row, np.random.choice(other_labels)] = corruption_prob
    return C

def bimodal_mix_C(mixing_ratio, num_classes, mu1=3.0, sigma1=1.0, mu2=7.0, sigma2=0.5):
    '''
    Return (1 - mixing_ratio) * I plus mixing_ratio times a shared bimodal noise row.

    The noise row (identical for every clean label) is the equal-weight average of two
    truncated normal densities on [0, num_classes - 1], each evaluated at the integer
    labels and normalised to sum to 1. The mu/sigma arguments are given on a 10-class
    scale and are multiplied by num_classes / 10 before use.
    '''
    from scipy.stats import truncnorm

    lower, upper = 0, num_classes - 1

    def normalised_density(mu, sigma):
        # Rescale from the 10-class reference scale (same expression order as before).
        mu = mu * num_classes / 10
        sigma = sigma * num_classes / 10
        dist = truncnorm((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
        density = dist.pdf(range(num_classes))
        return density / np.sum(density)  # ensure sum = 1

    noise_row = mixing_ratio * (normalised_density(mu1, sigma1) + normalised_density(mu2, sigma2)) / 2.0

    C = (1.0 - mixing_ratio) * np.eye(num_classes)
    C += noise_row  # broadcast: the same row is added to every row of C
    return C


# Optimisation settings.
reg_str = 5e-5      # passed as Adam weight_decay in train_and_test
init_lr = 0.001
batch_size = 64
num_epochs_1 = 15   # epochs of the first training pass (silver examples only)
num_epochs_2 = 25   # epochs of the retraining pass with the corrected loss

# Model / data sizes.
num_classes = 25
hidden_size = 256
embedding_dimension = 50
# Network input width: (2 * window_size + 1) embeddings of size embedding_dimension
# per example (presumably a concatenated word window -- see word_list_to_embedding).
example_size = (2 * window_size + 1) * embedding_dimension
num_examples = Y_train.shape[0]
num_batches = num_examples // batch_size


# //////////////////////// defining graph ////////////////////////
class ThreeLayerNet(nn.Module):
    '''
    Fully connected classifier: Linear -> ReLU -> Linear -> ReLU -> Linear.

    :param in_features: input width; defaults to the module-level ``example_size``.
    :param hidden: width of both hidden layers (256, as originally hard-coded).
    :param out_features: number of output logits; defaults to ``num_classes``.
    '''

    def __init__(self, in_features=None, hidden=256, out_features=None):
        super(ThreeLayerNet, self).__init__()
        # Resolve module-level defaults at call time so the class can be reused with other sizes.
        if in_features is None:
            in_features = example_size
        if out_features is None:
            out_features = num_classes
        self.main = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

        self.init_weights()

    def init_weights(self):
        '''Redraw every Linear weight from N(0, std=1/sqrt(fan_in)) and zero its bias.'''
        # Layers are visited in order, so the torch RNG is consumed exactly as before.
        for layer in self.main:
            if isinstance(layer, nn.Linear):
                layer.weight.data.normal_(0, 1 / np.sqrt(layer.in_features))
                layer.bias.data.zero_()

    def forward(self, x):
        return self.main(x)


def to_embeds(x):
    '''Embed ``x`` with word_list_to_embedding using the module-level embeddings table.'''
    return word_list_to_embedding(x, embeddings, embedding_dimension)

def alpha_beta_tune(corruption_level, C):
    '''
    Return the (alpha, beta) weights of the CE and reverse-CE terms of the silver loss.

    For corruption_level < 0.5 the command-line values (args.alpha, args.beta) are
    returned. Otherwise both weights equal J * (1 - corruption_level), where J is
    Jain's fairness index sum(r)**2 / (len(r) * sum(r**2) + eps) of the off-diagonal
    ratios r = C[i][j] / C[j][j] (a zero diagonal entry is replaced by machine epsilon).

    :param corruption_level: corruption strength used to build ``C``.
    :param C: square corruption matrix (any size).
    :return: tuple (alpha, beta).
    '''
    if corruption_level < 0.5:
        return args.alpha, args.beta

    C = np.asarray(C, dtype=float)
    eps = np.finfo(float).eps

    # Divide column j by C[j][j], guarding zero diagonal entries with eps.
    diag = np.diag(C)
    ratios = C / np.where(diag != 0, diag, eps)
    # Off-diagonal entries, ordered j-major then i (same order as the former nested loop).
    rate_C = ratios.T[~np.eye(C.shape[0], dtype=bool)]

    jain_fair_idx = np.sum(rate_C) ** 2 / (len(rate_C) * np.sum(np.power(rate_C, 2)) + eps)
    alpha = jain_fair_idx * (1 - corruption_level)
    beta = jain_fair_idx * (1 - corruption_level)
    return alpha, beta

# Methods whose silver examples use the forward-corrected likelihood
# -log(sum_k softmax(f(x))_k * C_hat[k, noisy_label]).
_FORWARD_LOSS_METHODS = ('glc', 'TMatrix', 'forward_gold')
# Methods whose silver examples use alpha * CE + beta * reverse-CE.
_SYMMETRIC_LOSS_METHODS = ('ours', 'sTMatrix', 'sgforward')


def _predict_probs(net, x):
    '''Return softmax probabilities of ``net`` on ``to_embeds(x)`` as a numpy array (no autograd).'''
    with torch.no_grad():
        logits = net(torch.from_numpy(to_embeds(x)).cuda())
        return F.softmax(logits, dim=1).cpu().numpy()


def _test_accuracy(net):
    '''Return the accuracy of ``net`` on X_test / Y_test; leaves ``net`` in eval mode.'''
    net.eval()
    with torch.no_grad():
        data = torch.from_numpy(to_embeds(X_test)).cuda()
        target = torch.from_numpy(Y_test.astype(np.int64)).cuda()
        pred = net(data).max(1)[1]
        correct = pred.eq(target).sum().item()
    return float(correct) / len(Y_test)


def _percentile_higher(values, q):
    '''np.percentile(values, q) with the 'higher' rule on both old and new NumPy.'''
    try:
        return np.percentile(values, q, method='higher')  # NumPy >= 1.22
    except TypeError:
        return np.percentile(values, q, interpolation='higher')  # NumPy < 1.22


def _train_on_silver(net, optimizer, dataset, num_silver):
    '''Phase 1: plain cross-entropy on rows [0, num_silver) for num_epochs_1 epochs (partial batch dropped).'''
    num_batches = num_silver // batch_size
    indices = np.arange(num_silver)
    for _ in range(num_epochs_1):
        np.random.shuffle(indices)  # reshuffle every epoch
        for b in range(num_batches):
            batch = indices[b * batch_size:(b + 1) * batch_size]
            data = torch.from_numpy(to_embeds(dataset['x'][batch])).cuda()
            target = torch.from_numpy(dataset['y'][batch]).cuda()

            loss = F.cross_entropy(net(data), target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


def _estimate_C_hat(method, net, dataset, gold, C):
    '''
    Return the numpy correction matrix used by the silver loss.

    - 'ours' / 'glc': row k is the mean predicted distribution over gold examples with
      clean label k (a uniform row if class k has no gold example).
    - 'sgforward' / 'forward_gold': row k is the full predicted distribution of the
      example with the highest class-k probability after discarding the top 3% of
      that column, searched over the whole dataset.
    - 'TMatrix' / 'sTMatrix': the true corruption matrix ``C``.
    '''
    if method in ('ours', 'glc'):
        probs = _predict_probs(net, gold['x'])
        C_hat = np.zeros((num_classes, num_classes))
        for label in range(num_classes):
            in_class = gold['y'] == label
            if not in_class.any():
                C_hat[label] = np.ones(num_classes) / num_classes  # TODO: try a diagonal prior instead
            else:
                C_hat[label] = np.mean(probs[in_class], axis=0)
        return C_hat

    if method in ('sgforward', 'forward_gold'):
        probs = _predict_probs(net, dataset['x'])
        C_hat = np.zeros((num_classes, num_classes))
        for label in range(num_classes):
            # Copy the column: zeroing a view in place used to overwrite `probs`, so
            # later rows of C_hat could pick up zeroed entries.
            class_probs = probs[:, label].copy()
            class_probs[class_probs >= _percentile_higher(class_probs, 97)] = 0
            C_hat[label] = probs[np.argmax(class_probs)]
        return C_hat

    if method in ('TMatrix', 'sTMatrix'):
        return C

    raise ValueError('unknown method: %r' % (method,))


def _forward_silver_loss(output_s, target_s, C_hat):
    '''Summed -log(sum_k softmax(output_s)_k * C_hat[k, target]) over the silver rows.'''
    c_cols = C_hat.t()[target_s]  # row n is column target_s[n] of C_hat
    noisy_prob = (F.softmax(output_s, dim=1) * c_cols).sum(1)
    return -(torch.log(noisy_prob)).sum(0)


def _symmetric_silver_loss(output_s, target_s, C_hat, alpha, beta):
    '''alpha * CE(C_hat-weighted probabilities) + beta * mean reverse-CE over the silver rows.'''
    probs = F.softmax(output_s, dim=1)
    corrected = probs * C_hat.t()[target_s]
    # NOTE(review): `corrected` holds probabilities, yet cross_entropy treats its input as
    # logits (applies log_softmax). Also, ce / rce are means here while the caller divides
    # by batch_size again. Both kept unchanged to preserve existing results.
    ce = F.cross_entropy(corrected, target_s)

    pred = torch.clamp(probs, min=1e-7, max=1.0)
    label_one_hot = torch.clamp(F.one_hot(target_s, num_classes).float(), min=1e-4, max=1.0)
    rce = -1 * torch.sum(pred * torch.log(label_one_hot), dim=1)
    return alpha * ce + beta * rce.mean()


def _retrain_with_correction(net, optimizer, dataset, num_silver, silver_loss):
    '''
    Phase 2: train on all rows for num_epochs_2 epochs.

    Rows with index >= num_silver are gold and use summed cross-entropy; the rest are
    silver and use ``silver_loss(logits, targets)``. The sum is divided by batch_size.
    '''
    num_examples = dataset['y'].shape[0]
    num_batches = num_examples // batch_size
    indices = np.arange(num_examples)
    for _ in range(num_epochs_2):
        np.random.shuffle(indices)
        for b in range(num_batches):
            batch = indices[b * batch_size:(b + 1) * batch_size]
            data = to_embeds(dataset['x'][batch])
            target = dataset['y'][batch]
            is_gold = batch >= num_silver
            is_silver = ~is_gold

            loss_s = 0
            if is_silver.any():
                data_s = torch.FloatTensor(data[is_silver]).cuda()
                target_s = torch.from_numpy(target[is_silver]).long().cuda()
                loss_s = silver_loss(net(data_s), target_s)

            loss_g = 0
            if is_gold.any():
                data_g = torch.FloatTensor(data[is_gold]).cuda()
                target_g = torch.from_numpy(target[is_gold]).long().cuda()
                loss_g = F.cross_entropy(net(data_g), target_g, reduction='sum')

            loss = (loss_g + loss_s) / batch_size
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


def train_and_test(method='ours', corruption_level=0, gold_fraction=0.5, get_C=uniform_mix_C):
    '''
    Run one corruption experiment and return (test_acc, baseline_acc).

    1. Build the corruption matrix C with ``get_C`` and corrupt the silver data
       via ``prepare_data``.
    2. Train a fresh ThreeLayerNet on silver rows only. Its test accuracy is
       ``baseline_acc``.
    3. Estimate the correction matrix C_hat for ``method``.
    4. Re-initialise the weights and retrain on silver + gold with the method's
       corrected silver loss. Its test accuracy is ``test_acc``.

    NumPy and torch RNGs are seeded with 1. Requires a CUDA device.

    :raises ValueError: if ``method`` is not a supported method name.
    '''
    if method not in _FORWARD_LOSS_METHODS + _SYMMETRIC_LOSS_METHODS:
        raise ValueError('unknown method: %r' % (method,))

    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    net = ThreeLayerNet().cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr=init_lr, weight_decay=reg_str)

    if get_C is bimodal_mix_C:
        # NOTE(review): bimodal_mix_C multiplies mu/sigma by num_classes / 10 internally,
        # so mu2=18.0 becomes 45, outside the label range [0, num_classes - 1]. Confirm intent.
        C = get_C(corruption_level, num_classes, mu1=6.0, sigma1=2.5, mu2=18.0, sigma2=1.25)
    else:
        C = get_C(corruption_level)

    dataset, gold, _, num_silver = prepare_data(C, gold_fraction)

    # Phase 1: train on noisy (silver) labels only -> baseline accuracy.
    _train_on_silver(net, optimizer, dataset, num_silver)
    baseline_acc = _test_accuracy(net)

    # Estimate the correction matrix from the phase-1 network.
    C_hat = torch.from_numpy(_estimate_C_hat(method, net, dataset, gold, C).astype(np.float32)).cuda()

    # Phase 2: retrain from freshly initialised weights with the corrected loss.
    net.train()
    net.init_weights()
    optimizer = torch.optim.Adam(net.parameters(), lr=init_lr, weight_decay=reg_str)

    if method in _FORWARD_LOSS_METHODS:
        def silver_loss(output_s, target_s):
            return _forward_silver_loss(output_s, target_s, C_hat)
    else:
        # alpha/beta depend only on (corruption_level, C): compute once, not per batch.
        alpha, beta = alpha_beta_tune(corruption_level, C)

        def silver_loss(output_s, target_s):
            return _symmetric_silver_loss(output_s, target_s, C_hat, alpha, beta)

    _retrain_with_correction(net, optimizer, dataset, num_silver, silver_loss)
    test_acc = _test_accuracy(net)

    # nudge garbage collector
    del dataset, gold

    return test_acc, baseline_acc


# //////////////////////// run experiments ////////////////////////

# Choose the corruption-matrix builder; anything other than the two named types
# falls back to label flipping.
corruption_fnctn = {
    'uniform_mix': uniform_mix_C,
    'bimodal': bimodal_mix_C,
}.get(args.corruption_type, flip_labels_C)

filename = './' + args.method + '_' + args.corruption_type
results = {}
for gold_fraction in [0.01, 0.05, 0.25]:
    by_level = {}
    results[gold_fraction] = by_level
    for corruption_level in [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]:
        test_acc, baseline_acc = train_and_test(args.method, corruption_level, gold_fraction, corruption_fnctn)
        by_level[corruption_level] = {'method': test_acc, 'baseline': baseline_acc}
        print('Gold fraction:', gold_fraction, '| Corruption level:', corruption_level,
              '| Method acc:', test_acc,
              '| Baseline acc:', baseline_acc)
    print()

with open(filename, 'wb') as out_file:
    pickle.dump(results, out_file)
    print("Dumped results_ours in file: " + filename)
