from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys
import pickle

import torch.utils.data as data
from torchvision.datasets.utils import download_url, check_integrity
from numpy.testing import assert_array_almost_equal
import torch
import torch.nn.functional as F
from torch.autograd import Variable as V
import wideresnet as wrn
import torchvision.transforms as transforms


def uniform_mix_C(mixing_ratio, num_classes):
    '''
    returns a num_classes x num_classes corruption matrix with (1 - mixing_ratio)
    on the diagonal and mixing_ratio split evenly across the num_classes - 1
    off-diagonal entries of every row (so each row sums to 1)
    '''
    C = (1 - mixing_ratio) * np.eye(num_classes)
    if num_classes > 1:
        # every off-diagonal cell receives the same share of the corrupted mass
        off_diagonal = ~np.eye(num_classes, dtype=bool)
        C[off_diagonal] = mixing_ratio * float(1.0 / (num_classes - 1))
    return C
    
def flip_labels_C(corruption_prob, num_classes, seed=1):
    '''
    returns a corruption matrix that keeps (1 - corruption_prob) on the diagonal
    and moves corruption_prob to a single randomly chosen other column per row;
    the global numpy RNG is reseeded with `seed` before drawing the targets
    '''
    np.random.seed(seed)
    classes = np.arange(num_classes)
    C = (1 - corruption_prob) * np.eye(num_classes)
    for row in range(num_classes):
        candidates = np.delete(classes, row)  # every class except the row's own
        target = np.random.choice(candidates)
        C[row][target] = corruption_prob
    return C

def flip_labels_C_t(dataset, corruption_prob, num_classes, seed=1):
    '''
    returns a class-dependent ("pair flip") corruption matrix for a named dataset

    - "cifar10": identity, except rows 9, 2, 3, 5, 4 keep (1 - corruption_prob)
      on the diagonal and send corruption_prob to columns 1, 0, 5, 3, 7
      respectively (presumably truck->automobile, bird->airplane, cat<->dog,
      deer->horse under the usual CIFAR-10 class order -- confirm).
    - "cifar100": block diagonal; inside each of the 20 consecutive blocks of
      5 classes one random pair of classes is swapped with probability
      corruption_prob (see build_for_cifar100). Identity when corruption_prob <= 0.

    In both cases the global numpy RNG is reseeded with `seed`.

    raises ValueError if `dataset` is neither "cifar10" nor "cifar100".
    '''
    if dataset == "cifar10":
        source_class = [9, 2, 3, 5, 4]
        target_class = [1, 0, 5, 3, 7]
        # C is deterministic here; the reseed only resets the global RNG state,
        # kept so callers observe the same RNG state as with the cifar100 path.
        np.random.seed(seed)
        C = np.eye(num_classes)
        for src, dst in zip(source_class, target_class):
            C[src][dst] = corruption_prob
            C[src][src] = 1.0 - corruption_prob
        return C

    if dataset == "cifar100":
        nb_superclasses = 20
        nb_subclasses = 5
        np.random.seed(seed)
        C = np.eye(num_classes)
        if corruption_prob > 0.0:
            for i in range(nb_superclasses):
                init, end = i * nb_subclasses, (i + 1) * nb_subclasses
                C[init:end, init:end] = build_for_cifar100(nb_subclasses, corruption_prob)
        return C

    # Without this, an unknown name fell through to `return C` with C unbound.
    raise ValueError(
        "Unsupported dataset '{}' for flip corruption; "
        "expected 'cifar10' or 'cifar100'".format(dataset))

def truncnorm_mix_C(mixing_ratio, num_classes, mu=3.0, sigma=1.0):
    '''
    returns (1 - mixing_ratio) * I plus mixing_ratio times a normalized
    truncated-normal profile over the class indices, added identically to every
    row (so each row sums to 1). mu and sigma are given on a 10-class scale and
    are rescaled by num_classes / 10.
    '''
    from scipy.stats import truncnorm

    loc = mu * num_classes / 10
    scale = sigma * num_classes / 10
    lo, hi = 0, num_classes - 1
    dist = truncnorm((lo - loc) / scale, (hi - loc) / scale, loc=loc, scale=scale)
    profile = dist.pdf(range(num_classes))
    profile = profile / np.sum(profile)  # normalize the profile to sum to 1

    C = (1 - mixing_ratio) * np.eye(num_classes)
    # broadcast the same profile onto every row
    C += mixing_ratio * profile[np.newaxis, :]
    return C

def truncnorm_alt_mix_C(mixing_ratio, num_classes, sigma=1.0):
    '''
    returns a matrix whose row i is a normalized truncated-normal profile centered
    on column i, wrapping around the borders (each row is the same profile rolled
    so its peak sits on the diagonal; each row sums to 1)

    mixing_ratio: accepted for interface parity with the other *_mix_C builders
        but currently NOT used -- the noise level is controlled only by sigma.
    sigma: spread on a 10-class scale, rescaled by num_classes / 10.
    '''
    from scipy.stats import truncnorm

    # np.zeros takes the shape as one tuple; passing two ints made the second a dtype.
    C = np.zeros((num_classes, num_classes))
    mu = int(num_classes / 2)
    sigma = sigma * num_classes / 10
    lower = 0
    upper = num_classes - 1
    tn = truncnorm((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
    r = tn.pdf(range(num_classes))
    r = r / np.sum(r)  # ensure sum = 1
    for i in range(num_classes):
        # r peaks at index mu; rolling by (i - mu) moves the peak onto column i
        C[i] = np.roll(r, i - mu)

    return C

def bimodal_mix_C(mixing_ratio, num_classes, mu1=3.0, sigma1=1.0, mu2=7.0, sigma2=0.5):
    '''
    returns (1 - mixing_ratio) * I plus mixing_ratio times the average of two
    normalized truncated-normal profiles over the class indices, added
    identically to every row (so each row sums to 1). The means and spreads are
    given on a 10-class scale and rescaled by num_classes / 10.
    '''
    from scipy.stats import truncnorm

    def _profile(mu, sigma):
        # normalized truncnorm pdf evaluated at 0 .. num_classes - 1
        loc = mu * num_classes / 10
        scale = sigma * num_classes / 10
        dist = truncnorm((0 - loc) / scale, (num_classes - 1 - loc) / scale,
                         loc=loc, scale=scale)
        pdf = dist.pdf(range(num_classes))
        return pdf / np.sum(pdf)

    mixture = _profile(mu1, sigma1) + _profile(mu2, sigma2)
    C = (1 - mixing_ratio) * np.eye(num_classes)
    # the averaged two-mode profile is broadcast onto every row
    C += mixing_ratio * mixture / 2.0
    return C

def build_for_cifar100(size, noise):
    """Return a size x size identity matrix in which one pair of distinct classes,
    drawn with the global numpy RNG, is swapped with probability `noise`.
    """
    assert 0. <= noise <= 1.

    cls1, cls2 = np.random.choice(range(size), size=2, replace=False)
    pair = [cls1, cls2]
    P = np.eye(size)
    # write the symmetric 2x2 flip block for the chosen pair in one assignment
    P[np.ix_(pair, pair)] = [[1.0 - noise, noise],
                             [noise, 1.0 - noise]]

    assert_array_almost_equal(P.sum(axis=1), 1, 1)
    return P


class CIFAR10(data.Dataset):
    """CIFAR-10 dataset split into a trusted ("gold") part and a corrupted
    ("silver") part.

    With ``train=True`` and ``gold=True`` the first ``int(gold_fraction * 50000)``
    training examples (after an optional shuffle) are kept with clean labels.
    With ``train=True`` and ``gold=False`` the remaining examples are kept and
    their labels are resampled according to ``corruption_type``. With
    ``train=False`` the test split is loaded unchanged.
    """
    dataset_name = "cifar10"
    base_folder = 'cifar-10-batches-py'
    url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]

    def __init__(self, root='~/home-nfs/dan/cifar_data', train=True, gold=True, gold_fraction=0.1,
                 corruption_prob=0, corruption_type='unif', transform=None, target_transform=None,
                 dataset="cifar10", download=False, shuffle_indices=None, distinguish_gold=True, seed=1):
        """
        Args:
            root: directory containing ``base_folder``; a leading "~" is expanded.
            train: load the training split (True) or the test split (False).
            gold: with train=True, keep the trusted prefix (True) or the
                corrupted remainder (False).
            gold_fraction: fraction of the 50000 training images forming the gold part.
            corruption_prob: noise level handed to the corruption-matrix builder.
            corruption_type: 'unif', 'flip', 'truncnorm', 'truncnorm_alt',
                'bimodal', 'hierarchical' (CIFAR-100 only) or 'clabels'.
            transform, target_transform: applied per item in __getitem__.
            dataset: name passed to flip_labels_C_t for corruption_type='flip'.
                NOTE(review): the default is "cifar10" even on CIFAR100, so pass
                dataset="cifar100" explicitly there.
            download: download and extract the archive if it is not present.
            shuffle_indices: permutation to reuse, e.g. the gold split's
                ``shuffle_indices``, so the gold prefix and silver remainder are
                taken from the same ordering and do not overlap.
            distinguish_gold: with gold=True, add num_classes to every label so
                trusted examples can be told apart from untrusted ones.
            seed: numpy seed set right before the silver labels are resampled.

        Raises:
            RuntimeError: if the data files are missing or fail the md5 check.
            ValueError: for an unknown corruption_type, or 'hierarchical' used
                without CIFAR-100 labels.
        """
        # Expand "~" so the default root (and any "~/..." path) can actually be found.
        root = os.path.expanduser(root)
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.gold = gold
        self.gold_fraction = gold_fraction
        self.corruption_prob = corruption_prob
        self.dataset = dataset
        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # now load the picked numpy arrays
        if self.train:
            self.train_data = []
            self.train_labels = []
            self.train_coarse_labels = []
            for fentry in self.train_list:
                entry = self._load_pickle(os.path.join(root, self.base_folder, fentry[0]))
                self.train_data.append(entry['data'])
                if 'labels' in entry:
                    self.train_labels += entry['labels']
                    num_classes = 10
                else:
                    self.train_labels += entry['fine_labels']
                    self.train_coarse_labels += entry['coarse_labels']
                    num_classes = 100

            self.train_data = np.concatenate(self.train_data)
            self.train_data = self.train_data.reshape((50000, 3, 32, 32))
            self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC

            gold_size = int(gold_fraction * 50000)
            if gold is True:
                if shuffle_indices is None:
                    indices = self._gold_indices(self.train_labels, gold_size, num_classes)
                else:
                    indices = shuffle_indices

                self.train_data = self.train_data[indices][:gold_size]
                gold_labels = np.array(self.train_labels)[indices][:gold_size]
                if distinguish_gold:
                    # this ad-hoc move is done so we can identify which examples are
                    # gold/trusted and which are silver/untrusted
                    gold_labels = gold_labels + num_classes
                self.train_labels = list(gold_labels)
                self.shuffle_indices = indices
            else:
                if corruption_type == 'hierarchical' and num_classes != 100:
                    raise ValueError('You must use CIFAR-100 with the hierarchical corruption.')

                indices = np.arange(len(self.train_data)) if shuffle_indices is None else shuffle_indices
                self.train_data = self.train_data[indices][gold_size:]
                self.train_labels = list(np.array(self.train_labels)[indices][gold_size:])
                if corruption_type == 'hierarchical':
                    self.train_coarse_labels = list(
                        np.array(self.train_coarse_labels)[indices][gold_size:])

                # The labeler network is loaded before reseeding, as in the
                # matrix path where C is built before reseeding.
                net = None
                if corruption_type == 'clabels':
                    net = self._load_labeler(num_classes)
                else:
                    self.C = self._build_corruption_matrix(corruption_type, num_classes)

                np.random.seed(seed)
                if net is not None:
                    self._relabel_with_labeler(net, num_classes)
                else:
                    C = self.C
                    # Resample each label from the row of C indexed by its clean label.
                    for i in range(len(self.train_labels)):
                        self.train_labels[i] = np.random.choice(
                            num_classes, p=C[self.train_labels[i]], replace=False)
                    self.corruption_matrix = C
                    print("C", self.corruption_matrix)

        else:
            entry = self._load_pickle(os.path.join(root, self.base_folder, self.test_list[0][0]))
            self.test_data = entry['data']
            if 'labels' in entry:
                self.test_labels = entry['labels']
            else:
                self.test_labels = entry['fine_labels']
            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
            self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC

    @staticmethod
    def _load_pickle(path):
        """Load one pickled CIFAR batch file, closing it even if unpickling fails.

        NOTE: pickle.load can execute arbitrary code; __init__ only calls this
        after _check_integrity() has verified the files' md5 sums.
        """
        with open(path, 'rb') as fo:
            if sys.version_info[0] == 2:
                return pickle.load(fo)
            return pickle.load(fo, encoding='latin1')

    @staticmethod
    def _gold_indices(train_labels, gold_size, num_classes):
        """Return a permutation of range(len(train_labels)) whose first gold_size
        entries contain at least min(num_classes, gold_size, #distinct labels)
        distinct labels.

        The identity order is returned if it already qualifies; otherwise the
        global numpy RNG reshuffles until it does.
        """
        indices = np.arange(len(train_labels))
        # Cap the target so the loop can terminate: a prefix of gold_size
        # examples cannot hold more than gold_size distinct labels (otherwise
        # e.g. gold_fraction=0 would loop forever), nor more than actually occur.
        required = min(num_classes, gold_size, len(set(train_labels)))
        shuffled_labels = train_labels
        while len(set(shuffled_labels[:gold_size])) < required:
            np.random.shuffle(indices)
            shuffled_labels = list(np.array(train_labels)[indices])
        return indices

    def _build_corruption_matrix(self, corruption_type, num_classes):
        """Return the num_classes x num_classes label-corruption matrix for
        `corruption_type` (every type except 'clabels').

        Raises ValueError for an unknown type.
        """
        if corruption_type == 'unif':
            return uniform_mix_C(self.corruption_prob, num_classes)
        if corruption_type == 'flip':
            return flip_labels_C_t(self.dataset, self.corruption_prob, num_classes)
        if corruption_type == 'truncnorm':
            return truncnorm_mix_C(self.corruption_prob, num_classes)
        if corruption_type == 'truncnorm_alt':
            return truncnorm_alt_mix_C(self.corruption_prob, num_classes)
        if corruption_type == 'bimodal':
            return bimodal_mix_C(self.corruption_prob, num_classes)
        if corruption_type == 'hierarchical':
            return self._hierarchical_C(num_classes)
        # Literal braces are doubled so str.format doesn't treat them as a field.
        raise ValueError(
            "Invalid corruption type '{}' given. Must be in {{'unif', 'flip', 'truncnorm', "
            "'truncnorm_alt', 'bimodal', 'hierarchical', 'clabels'}}".format(corruption_type))

    def _hierarchical_C(self, num_classes):
        """Return a matrix with (1 - corruption_prob) on the diagonal and
        corruption_prob spread evenly over the other fine classes that share a
        coarse class with the row's class.

        Fine/coarse pairs come from self.train_labels / self.train_coarse_labels.
        """
        coarse_to_fine = [set() for _ in range(20)]
        for fine, coarse in zip(self.train_labels, self.train_coarse_labels):
            coarse_to_fine[coarse].add(fine)

        C = np.eye(num_classes) * (1 - self.corruption_prob)
        for fine_classes in coarse_to_fine:
            members = np.array(list(fine_classes))
            for j in range(len(members)):
                siblings = np.delete(members, j)
                C[members[j], siblings] += self.corruption_prob * 1 / len(siblings)
        return C

    @staticmethod
    def _load_labeler(num_classes):
        """Load the pretrained WideResNet labeler from ./cifar{num_classes}_labeler
        onto the GPU in eval mode."""
        net = wrn.WideResNet(40, num_classes, 2, dropRate=0.3).cuda()
        model_name = './cifar{}_labeler'.format(num_classes)
        net.load_state_dict(torch.load(model_name))
        net.eval()
        return net

    def _relabel_with_labeler(self, net, num_classes):
        """Replace every label with a sample from the labeler's temperature-5
        softmax for its image, and print how often the new labels agree with
        the old ones."""
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
        test_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean, std)])

        batch_size = 64
        num_examples = len(self.train_labels)

        # obtain sampling probabilities
        sampling_probs = []
        print('Starting labeling')
        with torch.no_grad():  # inference only; don't build autograd graphs
            # Step by batch_size so a size divisible by 64 yields no empty final batch.
            for start in range(0, num_examples, batch_size):
                images = self.train_data[start:start + batch_size]
                batch = torch.cat(
                    [test_transform(Image.fromarray(img)).unsqueeze(0) for img in images], dim=0)
                logits = net(batch.cuda())
                smax = F.softmax(logits / 5, dim=1)  # temperature of 5 softens the labeler
                sampling_probs.append(smax.cpu().numpy())

        sampling_probs = np.concatenate(sampling_probs, 0)
        print('Finished labeling 1')

        new_labeling_correct = 0
        argmax_labeling_correct = 0
        for i in range(num_examples):
            old_label = self.train_labels[i]
            new_label = np.random.choice(num_classes, p=sampling_probs[i])
            self.train_labels[i] = new_label
            if old_label == new_label:
                new_labeling_correct += 1
            if old_label == np.argmax(sampling_probs[i]):
                argmax_labeling_correct += 1
        print('Finished labeling 2')
        print('New labeling accuracy:', new_labeling_correct / len(self.train_labels))
        print('Argmax labeling accuracy:', argmax_labeling_correct / len(self.train_labels))

    def __getitem__(self, index):
        """Return (PIL image after `transform`, label after `target_transform`)."""
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Size of the selected split (gold prefix, silver remainder, or test)."""
        if self.train:
            if self.gold is True:
                return int(self.gold_fraction * 50000)
            else:
                return 50000 - int(self.gold_fraction * 50000)
        else:
            return 10000

    def _check_integrity(self):
        """Return True iff every train/test file exists under root with the expected md5."""
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        """Download the archive into root (unless already verified) and extract it there."""
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        root = self.root
        download_url(self.url, root, self.filename, self.tgz_md5)

        # Extract straight into root rather than chdir-ing, so the process cwd is
        # never changed (and can't be left changed if extraction fails).
        with tarfile.open(os.path.join(root, self.filename), "r:gz") as tar:
            tar.extractall(path=root)


class CIFAR100(CIFAR10):
    """CIFAR-100 variant of CIFAR10.

    Only the archive location and file lists differ. The inherited loader reads
    'fine_labels' (and 'coarse_labels') from these files, so num_classes is 100.
    """
    dataset_name = "cifar100"

    # archive location and checksum
    url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'

    # extracted folder and its [file name, md5] entries
    base_folder = 'cifar-100-python'
    train_list = [['train', '16019d7e3df5f24257cddd939b257f8d']]
    test_list = [['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc']]
