from operator import add
import torch
import os

from torch.utils import data
import _pickle as cPickle
import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset

def uniform_mix_C(mixing_ratio, num_classes):
    '''
    Build a corruption matrix that blends "uniform noise" with "no noise".

    The result is ``mixing_ratio * U + (1 - mixing_ratio) * I`` where ``U`` is
    the ``num_classes x num_classes`` matrix whose entries all equal
    ``1 / num_classes`` and ``I`` is the identity matrix of the same size.
    '''
    uniform_part = np.full((num_classes, num_classes), 1 / num_classes)
    identity_part = np.eye(num_classes)
    return mixing_ratio * uniform_part + (1 - mixing_ratio) * identity_part


def flip_labels_C(corruption_prob, num_classes, seed=1):
    '''
    Build a "flip" corruption matrix.

    Every row keeps ``1 - corruption_prob`` on the diagonal and puts the whole
    ``corruption_prob`` mass on a single, randomly drawn off-diagonal column.
    The global NumPy random generator is re-seeded with ``seed`` first, so the
    drawn columns are reproducible for a given seed.
    '''
    np.random.seed(seed)
    matrix = (1 - corruption_prob) * np.eye(num_classes)
    all_classes = np.arange(num_classes)
    for row in range(num_classes):
        # Candidate targets are all classes except the row's own class.
        candidates = np.delete(all_classes, row)
        target = np.random.choice(candidates)
        matrix[row, target] = corruption_prob
    return matrix

def preprocess_flickr(label_path='../data/Flickr-25/test_label.txt'):
    '''
    Read a label file and return its lines with whitespace normalised.

    Each line is split on arbitrary whitespace and re-joined with single
    spaces, so leading/trailing blanks and repeated separators are removed.

    Args:
        label_path: path of the text file to read. Defaults to the location
            that used to be hard-coded in this function.

    Returns:
        A list with one normalised string per line of the file (previously
        the list was built and then discarded, so the function returned None).
    '''
    with open(label_path, 'r') as f:
        return [" ".join(line.split()) for line in f]


class DatasetProcessingFlickr(Dataset):
    '''
    Multi-label Flickr dataset with a trusted ("gold") / noisy training split.

    Only samples whose label vector has more than one non-zero entry are kept.

    * ``split='train'``, ``gold=True``: the first
      ``int(gold_fraction * num_samples)`` kept samples (in ``shuffle_indices``
      order, or in a freshly drawn random order) are exposed with clean labels.
    * ``split='train'``, ``gold=False``: the remaining samples are exposed and
      their labels are corrupted in place with a uniform-mix corruption matrix.
    * ``split='test'``: the kept samples are exposed unchanged.

    Any other ``split`` value leaves the dataset unpopulated.

    Args:
        data_path: root directory; list files are read from
            ``data_path + '/flickr/'``.
        split: ``'train'`` or ``'test'``.
        gold: select the gold subset (train split only).
        gold_fraction: fraction of the kept training samples used as gold.
        corruption_prob: mixing ratio passed to ``uniform_mix_C``.
        corruption_type: only ``'unif'`` is implemented.
        distinguish_gold: if true, ``num_classes`` is added to every entry of
            the gold label array.
            NOTE(review): for 0/1 label vectors this turns entries into
            38/39 -- confirm that consumers expect this encoding.
        shuffle_indices: index order to apply before splitting; if ``None``
            an order is drawn (gold) or the natural order is used (noisy).
        transform: optional callable applied to the loaded RGB image.
        seed: seed for the global NumPy random generator.

    Raises:
        Exception: ``corruption_type == 'flip'`` on the noisy split.
        AssertionError: any other unsupported ``corruption_type`` on the
            noisy split.
    '''

    def __init__(self, data_path, split, gold=False, gold_fraction=0.1,
                 corruption_prob=0.0, corruption_type='unif', distinguish_gold=True,
                 shuffle_indices=None, transform=None, seed=1):
        self.transform = transform
        self.split = split
        self.gold = gold
        self.gold_fraction = gold_fraction
        self.corruption_prob = corruption_prob
        self.num_classes = 38
        # NOTE(review): img_path is stored but never read in this class;
        # __getitem__ opens the entries of the image list file as they are.
        self.img_path = data_path + '/flickr/Flickr-25'
        if self.split == 'train':
            self.img_filepath = data_path + '/flickr/database_image.txt'
            self.label_filepath = data_path + '/flickr/database_label.txt'
        elif self.split == 'test':
            self.img_filepath = data_path + '/flickr/test_image.txt'
            self.label_filepath = data_path + '/flickr/test_label.txt'

        if self.split == 'train':
            self.data, self.labels = self.pre_precossing()
            self.train_data, self.train_labels = self._select_multi_label(self.data, self.labels)
            num_samples = len(self.train_labels)
            self.num_samples = num_samples
            num_gold = int(gold_fraction * num_samples)

            if gold is True:
                if shuffle_indices is None:
                    avg_labels_per_image = 2.9  # From ASL paper
                    indices = np.arange(num_samples)

                    np.random.seed(seed)
                    np.random.shuffle(indices)

                    # Reshuffle until, for every class, the gold subset holds
                    # more than avg_labels_per_image positive samples.
                    # NOTE(review): this does not terminate if the data cannot
                    # satisfy the condition (e.g. a very rare class).
                    while np.count_nonzero(
                            self.train_labels[indices][:num_gold].sum(axis=0)
                            > avg_labels_per_image) < self.num_classes:
                        np.random.shuffle(indices)
                else:
                    indices = shuffle_indices
                self.train_data = self.train_data[indices][:num_gold]
                self.train_labels = self.train_labels[indices][:num_gold]
                if distinguish_gold:
                    self.train_labels = self.train_labels + self.num_classes
                # Kept so the same order can be reused for the noisy split.
                self.shuffle_indices = indices
            else:
                indices = np.arange(len(self.train_data)) if shuffle_indices is None else shuffle_indices
                self.train_data = self.train_data[indices][num_gold:]
                self.train_labels = self.train_labels[indices][num_gold:]

                if corruption_type == 'flip':
                    raise Exception('Corruption type "flip" not implemeneted')
                if corruption_type != 'unif':
                    # Explicit raise (same exception type as the former
                    # `assert False`) so the check survives `python -O`.
                    raise AssertionError("Invalid corruption type '{}' given. "
                                         "Must be in ['unif', 'flip']".format(corruption_type))

                C = uniform_mix_C(self.corruption_prob, self.num_classes)
                self.C = C
                # Corrupts self.train_labels in place and returns the
                # empirical, row-normalised transition matrix.
                self.C_true = self._corrupt_labels(self.train_labels, C, seed)
                self.corruption_matrix = C

        elif self.split == 'test':
            print("Test data ...")
            self.data, self.labels = self.pre_precossing()
            self.test_data, self.test_labels = self._select_multi_label(self.data, self.labels)
            self.num_samples = len(self.test_labels)

    @staticmethod
    def _select_multi_label(paths, labels):
        '''Return (paths, labels) restricted to samples with more than one non-zero label.'''
        keep = np.array(
            [idx for idx in range(len(paths)) if np.count_nonzero(labels[idx]) > 1],
            dtype=np.intp)
        return np.array(paths[keep]), np.array(labels[keep])

    @staticmethod
    def _corrupt_labels(labels, C, seed):
        '''
        Corrupt the rows of ``labels`` in place and return the empirical
        transition matrix.

        For every label that is positive in the ORIGINAL row, a replacement
        class is drawn from row ``label`` of ``C``. If the draw differs from
        the label, the label is moved to a class that is neither an original
        positive nor already set (redrawing until such a class is found), so
        the result is a wrong label rather than a missing one.

        Only the original positives are visited: a label that was just flipped
        onto a higher index is not corrupted a second time.

        NOTE(review): the redraw loop does not terminate if no admissible
        class can be drawn from ``C[label]``.

        Returns:
            ``C_true`` where ``C_true[a][b]`` is the fraction of visited labels
            ``a`` that ended up as ``b``; rows of classes that never occurred
            stay all-zero.
        '''
        num_classes = C.shape[0]
        C_true = np.zeros((num_classes, num_classes), dtype=np.float64)
        np.random.seed(seed)
        for sample in labels:
            # Snapshot of the clean positives, taken before any mutation.
            true_labels = np.nonzero(sample)[0]
            for label in true_labels:
                if sample[label] != 1:
                    continue
                new_label = np.random.choice(num_classes, p=C[label])
                if new_label == label:
                    C_true[label][label] += 1
                    continue
                while sample[new_label] == 1 or new_label in true_labels:
                    new_label = np.random.choice(num_classes, p=C[label])
                sample[label] = 0
                sample[new_label] = 1
                C_true[label][new_label] += 1
        # keepdims=True makes each ROW divide by its own sum; a plain
        # `C_true / C_true.sum(axis=1)` would broadcast the sums over columns.
        row_sums = C_true.sum(axis=1, keepdims=True)
        np.divide(C_true, row_sums, out=C_true, where=row_sums > 0)
        return C_true

    def pre_precossing(self):
        '''
        Load the image list and the label matrix for the current split.

        Returns:
            ``(img_filename, label)``: an array with one stripped line of
            ``self.img_filepath`` per entry, and the integer array parsed from
            ``self.label_filepath`` (always 2-D, also for a single-row file).
        '''
        with open(self.img_filepath, 'r') as fp:
            self.img_filename = np.array([x.strip() for x in fp])
        self.label = np.loadtxt(self.label_filepath, dtype=np.int64, ndmin=2)

        return self.img_filename, self.label

    def __getitem__(self, index):
        '''Return ``(image, label_vector)`` for ``index``; the image is RGB and transformed if a transform is set.'''
        if self.split == 'train':
            img_path, target = self.train_data[index], self.train_labels[index]
        else:
            img_path, target = self.test_data[index], self.test_labels[index]

        # convert() returns an independent image, so the file can be closed
        # right away instead of staying open until garbage collection.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        '''Number of samples exposed for the current split.'''
        if self.split == 'train':
            return len(self.train_labels)
        elif self.split == 'test':
            return len(self.test_labels)
        

class DatasetProcessingFlickrSingle(Dataset):
    '''
    Flickr dataset restricted to samples with exactly one non-zero label.

    The image list and label matrix are read from
    ``data_path + '/flickr/database_image.txt'`` and
    ``data_path + '/flickr/database_label.txt'``.

    Args:
        data_path: root directory containing the ``flickr`` folder.
        transform: optional callable applied to the loaded RGB image.
        seed: accepted for interface compatibility; not used by this class.
    '''

    def __init__(self, data_path, transform=None, seed=1):
        self.transform = transform
        self.num_classes = 38
        # NOTE(review): img_path is stored but never read in this class;
        # __getitem__ opens the entries of the image list file as they are.
        self.img_path = data_path + '/flickr/Flickr-25'
        self.img_filepath = data_path + '/flickr/database_image.txt'
        self.label_filepath = data_path + '/flickr/database_label.txt'

        self.data, self.labels = self.pre_precossing()
        # Keep only the samples that carry exactly one positive label.
        keep = np.array(
            [idx for idx in range(len(self.data))
             if np.count_nonzero(self.labels[idx]) == 1],
            dtype=np.intp)
        self.train_data = np.array(self.data[keep])
        self.train_labels = np.array(self.labels[keep])
        self.num_samples = len(self.train_labels)

    def pre_precossing(self):
        '''
        Load the image list and the label matrix.

        Returns:
            ``(img_filename, label)``: an array with one stripped line of
            ``self.img_filepath`` per entry, and the integer array parsed from
            ``self.label_filepath`` (always 2-D, also for a single-row file).
        '''
        with open(self.img_filepath, 'r') as fp:
            self.img_filename = np.array([x.strip() for x in fp])
        self.label = np.loadtxt(self.label_filepath, dtype=np.int64, ndmin=2)

        return self.img_filename, self.label

    def __getitem__(self, index):
        '''Return ``(image, label_vector)`` for ``index``; the image is RGB and transformed if a transform is set.'''
        img_path, target = self.train_data[index], self.train_labels[index]
        # convert() returns an independent image, so the file can be closed
        # right away instead of staying open until garbage collection.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        '''Number of single-label samples.'''
        return len(self.train_labels)