import os
from copy import deepcopy
import random
import time
from copy import deepcopy

import numpy as np
from PIL import Image
from torchvision import datasets as datasets
import torch
from PIL import ImageDraw
from pycocotools.coco import COCO

import csv
import tarfile
from urllib.parse import urlparse
import torch.utils.data as data
from torch.utils.data.dataset import Dataset


def parse_args(parser):
    """Parse the arguments and fill in dataset-dependent defaults.

    Args:
        parser: Object exposing ``parse_args()`` (e.g. an
            ``argparse.ArgumentParser``) whose result carries the
            ``dataset_type`` and ``th`` attributes.

    Returns:
        The parsed namespace with ``do_bottleneck_head`` set (True only when
        ``dataset_type`` is ``'OpenImages'``) and, if ``th`` was left as
        ``None``, ``th`` defaulted to 0.995 for OpenImages and 0.7 otherwise.
    """
    args = parser.parse_args()
    is_open_images = bool(args.dataset_type == 'OpenImages')
    args.do_bottleneck_head = is_open_images
    # Identity test: ``== None`` can be fooled by objects overriding __eq__.
    if args.th is None:
        args.th = 0.995 if is_open_images else 0.7
    return args


def average_precision(output, target):
    """Compute the average precision of one class.

    Args:
        output: 1-D numpy array of scores, one per example.
        target: 1-D numpy array of the same length; an entry equal to 1
            marks a positive example.

    Returns:
        The sum of precision@i over the ranks i holding a positive example,
        divided by (number of positives + 1e-8).
    """
    epsilon = 1e-8

    # Rank examples from the highest score to the lowest.
    ranking = output.argsort()[::-1]
    ranks = np.arange(1, len(output) + 1, dtype=np.float64)

    is_positive = target[ranking] == 1
    hits_so_far = np.cumsum(is_positive)
    num_positives = hits_so_far[-1]

    # Precision only counts at the ranks where a positive was retrieved.
    precision_at_hits = np.where(is_positive, hits_so_far, 0) / ranks
    return np.sum(precision_at_hits) / (num_positives + epsilon)


def mAP(targs, preds):
    """Return the mean average precision over all classes, as a percentage.

    Args:
        targs: 2-D array (examples x classes) of ground-truth labels.
        preds: 2-D array (examples x classes) of predicted scores.

    Return:
        100 times the mean of the per-class average precisions, or 0 when
        ``preds`` is empty.
    """
    if np.size(preds) == 0:
        return 0

    num_classes = preds.shape[1]
    per_class_ap = np.zeros((num_classes))
    for k in range(num_classes):
        # Column k holds the scores / targets of class k.
        per_class_ap[k] = average_precision(preds[:, k], targs[:, k])
    return 100 * per_class_ap.mean()


def uniform_mix_C(mixing_ratio, num_classes):
    '''
    Build a (num_classes x num_classes) matrix that linearly interpolates
    between the uniform matrix (every entry 1 / num_classes, weight
    ``mixing_ratio``) and the identity matrix (weight ``1 - mixing_ratio``).
    '''
    uniform_part = np.full((num_classes, num_classes), 1 / num_classes)
    identity_part = np.eye(num_classes)
    return mixing_ratio * uniform_part + (1 - mixing_ratio) * identity_part

# PASCAL VOC object class names. The position of a name in this list is the
# index of that class in the label vectors built by the VOC helpers in this
# file, and the names double as the CSV column headers.
object_categories = ['aeroplane', 'bicycle', 'bird', 'boat',
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor']

class AverageMeter(object):
    """Keep the latest value, weighted sum, count, mean and an EMA of a metric."""

    def __init__(self):
        self.initialized = False
        self.val = self.sum = self.cnt = self.avg = self.ema = None

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times; the first call seeds the meter."""
        handler = self.add if self.initialized else self.initialize
        handler(val, n)

    def initialize(self, val, n):
        """Seed every statistic from the first observation."""
        self.val = self.avg = self.ema = val
        self.sum = val * n
        self.cnt = n
        self.initialized = True

    def add(self, val, n):
        """Fold a new observation into the running statistics."""
        self.val = val
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt
        # Exponential moving average with a fixed 0.99 / 0.01 split.
        self.ema = self.ema * 0.99 + self.val * 0.01

def read_image_label(file):
    """Read a per-class annotation file into a ``{image_name: label}`` dict.

    Each non-blank line is split on single spaces; the first field is the
    image name and the last field is parsed as the integer label. Blank
    lines (e.g. a trailing empty line) are skipped instead of raising.

    Args:
        file: Path of the text file to read.

    Returns:
        dict mapping image name (str) to label (int). When a name occurs
        more than once the last occurrence wins.

    Raises:
        ValueError: If the last field of a non-blank line is not an integer.
    """
    print('[dataset] read ' + file)
    labels_by_name = dict()
    with open(file, 'r') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            fields = line.split(' ')
            # int() ignores surrounding whitespace, including the newline.
            labels_by_name[fields[0]] = int(fields[-1])
    return labels_by_name

def read_object_labels(root, dataset, set):
    """Assemble one label vector per image from the per-class VOC files.

    For every class in ``object_categories`` the file
    ``<root>/VOCdevkit/<dataset>/ImageSets/Main/<class>_<set>.txt`` is read.
    The first class creates the label vectors; each later class fills in its
    own position and raises ``KeyError`` for an image the first class did
    not list.

    Returns:
        dict mapping image name to a numpy vector with one entry per class.
    """
    path_labels = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main')
    num_classes = len(object_categories)
    labeled_data = {}

    for class_idx, category in enumerate(object_categories):
        label_file = os.path.join(path_labels, category + '_' + set + '.txt')
        per_image = read_image_label(label_file)

        for name, label in per_image.items():
            if class_idx == 0:
                vector = np.zeros(num_classes)
                vector[class_idx] = label
                labeled_data[name] = vector
            else:
                labeled_data[name][class_idx] = label

    return labeled_data

def write_object_labels_csv(file, labeled_data):
    """Write the per-image label vectors to a CSV file.

    The header is ``name`` followed by every entry of ``object_categories``;
    each row holds the image name and its labels converted to ``int``.

    Args:
        file: Destination path; an existing file is overwritten.
        labeled_data: Mapping of image name to a label sequence with at
            least ``len(object_categories)`` entries.

    Raises:
        IndexError: If a label sequence is shorter than the category list.
    """
    print('[dataset] write file %s' % file)
    fieldnames = ['name']
    fieldnames.extend(object_categories)
    # newline='' is what the csv module expects; without it Windows emits an
    # empty row after every record. The ``with`` block closes the file.
    with open(file, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for name, labels in labeled_data.items():
            example = {'name': name}
            for i, category in enumerate(object_categories):
                example[category] = int(labels[i])
            writer.writerow(example)

def read_object_labels_csv(file, header=True):
    """Load ``(name, labels)`` pairs from a label CSV.

    Args:
        file: Path of the CSV; column 0 is the image name, the remaining
            columns are numeric labels.
        header: When truthy, the first row is treated as a header and skipped.

    Returns:
        list of ``(name, labels)`` tuples where ``labels`` is a float32 torch
        tensor in which every -1 has been replaced by 0. The number of label
        columns is taken from the first data row.
    """
    images = []
    num_categories = 0
    print('[dataset] read', file)
    with open(file, 'r') as f:
        for rownum, row in enumerate(csv.reader(f)):
            if header and rownum == 0:
                continue  # skip the header row
            if num_categories == 0:
                num_categories = len(row) - 1
            values = np.asarray(row[1:num_categories + 1]).astype(np.float32)
            # -1 entries are treated as negatives.
            values[values == -1.0] = 0.0
            images.append((row[0], torch.from_numpy(values)))
    return images


def find_images_classification(root, dataset, set):
    """Return the raw lines of ``<root>/VOCdevkit/<dataset>/ImageSets/Main/<set>.txt``.

    The lines are returned exactly as read, i.e. each entry keeps its
    trailing newline.
    """
    list_file = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main', set + '.txt')
    with open(list_file, 'r') as handle:
        return handle.readlines()


class Voc2007Classification(data.Dataset):
    """PASCAL VOC 2007 multi-label dataset with optional uniform label noise.

    On first use the per-class annotation files are merged into a CSV cached
    under ``<root>/files/VOC2007/<corruption_type>/<corruption_prob>/``; later
    runs read that CSV directly.

    In training mode, images without any positive label are dropped and the
    remaining labels are corrupted in place: each positive label is, with the
    probability given by its row of ``self.C``, moved to a class that is not
    currently positive for that image. ``self.C_true`` holds the empirical
    row-normalised transition frequencies. In test mode the labels are used
    unchanged.

    Args:
        set: Name of the image set; selects the annotation files and the CSV
            file name.
        transform: Optional callable applied to the loaded PIL image.
        target_transform: Optional callable applied to the numpy label
            vector. Its result is handed to ``torch.from_numpy``, so it must
            return a numpy array.
        adj: Unused; kept for interface compatibility.
        corruption_prob: Mixing ratio passed to ``uniform_mix_C``.
        corruption_type: Only ``'unif'`` is implemented.
        shuffle_indices: Unused; kept for interface compatibility.
        train: Truthy for the (corrupted) training split, falsy for test.
        root: Optional dataset root. When ``None`` the machine-specific
            default for the chosen split is used.

    Raises:
        Exception: In training mode when ``corruption_type`` is ``'flip'``.
        AssertionError: In training mode for any other unknown
            ``corruption_type``.
    """

    # Machine-specific defaults, used only when ``root`` is not supplied.
    DEFAULT_TRAIN_ROOT = '/media/masoud/DATA/masoud_data/nus_wide/robust_multi_label-main/pascalvoc/VOCtrainval/'
    DEFAULT_TEST_ROOT = '/media/masoud/DATA/masoud_data/nus_wide/robust_multi_label-main/pascalvoc/VOCtest/'

    def __init__(self, set, transform=None, target_transform=None, adj=None, corruption_prob=0.4, corruption_type='unif',
                 shuffle_indices=None, train=True, root=None):
        self.train = train
        self.corruption_prob = corruption_prob
        self.corruption_type = corruption_type

        if root is not None:
            self.root = root
        elif train:
            self.root = self.DEFAULT_TRAIN_ROOT
        else:
            self.root = self.DEFAULT_TEST_ROOT

        self.set = set
        self.path_devkit = os.path.join(self.root, 'VOCdevkit')
        self.path_images = os.path.join(self.root, 'VOCdevkit', 'VOC2007', 'JPEGImages')
        self.transform = transform
        self.target_transform = target_transform
        self.num_classes = len(object_categories)
        seed = 1234  # fixed so the injected noise is reproducible

        # Location of the cached label CSV for this corruption setting.
        path_csv = os.path.join(self.root, 'files', 'VOC2007', self.corruption_type, str(self.corruption_prob))
        file_csv = os.path.join(path_csv, 'classification_' + set + '.csv')

        # Build the CSV from the raw annotation files when it is missing.
        if not os.path.exists(file_csv):
            os.makedirs(path_csv, exist_ok=True)
            labeled_data = read_object_labels(self.root, 'VOC2007', self.set)
            write_object_labels_csv(file_csv, labeled_data)

        self.classes = object_categories
        self.images = read_object_labels_csv(file_csv)

        print('[dataset] VOC 2007 classification set=%s number of classes=%d  number of images=%d' % (
            set, len(self.classes), len(self.images)))

        if self.train:
            names = []
            labels = []
            for name, label in self.images:
                label_np = label.cpu().numpy()
                # Keep only images that have at least one positive label.
                if np.count_nonzero(label_np) > 0:
                    names.append(name)
                    labels.append(label_np)
            self.train_data = np.array(names)
            self.train_labels = np.array(labels)

            if corruption_type == 'flip':
                raise Exception('Corruption type "flip" not implemeneted')
            if corruption_type != 'unif':
                # Explicit raise so the check survives ``python -O``.
                raise AssertionError("Invalid corruption type '{}' given. "
                                     "Must be in ['unif', 'flip']".format(corruption_type))

            self.C = uniform_mix_C(self.corruption_prob, self.num_classes)
            np.random.seed(seed)
            self._corrupt_labels_uniform()
        else:
            self.test_data = np.array([name for name, _ in self.images])
            self.test_labels = np.array([label.cpu().numpy() for _, label in self.images])

    def _corrupt_labels_uniform(self):
        """Corrupt ``self.train_labels`` in place according to ``self.C``.

        Sets ``self.C_true`` (row-normalised empirical transition
        frequencies) and ``self.corruption_matrix``.
        """
        self.C_true = np.zeros((self.num_classes, self.num_classes), dtype=np.float64)
        for i in range(len(self.train_labels)):
            true_labels = np.nonzero(self.train_labels[i])[0]
            # NOTE(review): the scan looks at the live label vector, so a
            # noisy label placed at a higher index is visited again as if it
            # were a true label. Kept as-is so the generated noise for a
            # given seed does not change.
            for label in range(len(self.train_labels[i])):
                if self.train_labels[i][label] != 1:
                    continue
                new_label = np.random.choice(self.num_classes, p=self.C[label])
                if new_label == label:
                    self.C_true[label][label] += 1
                    continue
                # Resample until the replacement is a WRONG label, i.e. one
                # that is neither originally nor currently positive.
                while self.train_labels[i][new_label] == 1 or new_label in true_labels:
                    new_label = np.random.choice(self.num_classes, p=self.C[label])
                self.train_labels[i][label] = 0
                self.train_labels[i][new_label] = 1
                self.C_true[label][new_label] += 1
        self.corruption_matrix = self.C

        # Normalise every ROW by its own total. keepdims makes the division
        # broadcast along rows; a 1-D sum would divide column j by the total
        # of row j instead. Rows with no observations stay all-zero.
        row_totals = self.C_true.sum(axis=1, keepdims=True)
        row_totals[row_totals == 0] = 1
        self.C_true /= row_totals

    def __getitem__(self, index):
        """Return ``(image, label_tensor)`` for the sample at ``index``."""
        if self.train:
            path, target = self.train_data[index], self.train_labels[index]
        else:
            path, target = self.test_data[index], self.test_labels[index]
        img = Image.open(os.path.join(self.path_images, path + '.jpg')).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, torch.from_numpy(target)

    def __len__(self):
        """Number of samples in the active split."""
        if self.train:
            return len(self.train_labels)
        return len(self.test_labels)

    def get_number_classes(self):
        """Number of object classes."""
        return len(self.classes)

class CocoDetection(datasets.coco.CocoDetection):
    """COCO dataset whose target is a (3, 80) presence matrix split by object size.

    Row 0 marks categories with an annotation of area below 32*32, row 1
    those below 96*96, and row 2 all larger ones.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        self.root = root
        self.coco = COCO(annFile)
        self.ids = list(self.coco.imgToAnns.keys())
        self.transform = transform
        self.target_transform = target_transform
        # Map the raw category ids onto consecutive indices, in the order
        # the categories are listed by the COCO object.
        self.cat2cat = {cat: position for position, cat in enumerate(self.coco.cats.keys())}

    def __getitem__(self, index):
        """Return ``(image, size_bucketed_target)`` for the image at ``index``."""
        image_id = self.ids[index]
        annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=image_id))

        size_buckets = torch.zeros((3, 80), dtype=torch.long)
        for ann in annotations:
            area = ann['area']
            if area < 32 * 32:
                row = 0
            elif area < 96 * 96:
                row = 1
            else:
                row = 2
            size_buckets[row][self.cat2cat[ann['category_id']]] = 1
        target = size_buckets

        file_name = self.coco.loadImgs(image_id)[0]['file_name']
        img = Image.open(os.path.join(self.root, file_name)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

class CocoDetectionNo(datasets.coco.CocoDetection):
    """COCO multi-label dataset with optional uniform label noise.

    Image paths and multi-hot label vectors are collected up front; images
    are only opened in ``__getitem__``. Images without any positive label
    are dropped in both splits.

    In training mode with ``corruption_type='unif'`` and
    ``corruption_prob > 0`` every positive label is, with the probability
    given by its row of ``self.C``, moved to a class that is not currently
    positive for that image. ``self.C_true`` holds the row-normalised
    empirical transition frequencies (all zeros when no noise is injected).

    Args:
        root: Directory containing the images.
        annFile: Path of the COCO annotation file.
        transform: Optional callable applied to the loaded PIL image.
        target_transform: Optional callable applied to the numpy label
            vector. Its result is handed to ``torch.from_numpy``.
        train: Truthy for the (possibly corrupted) training split.
        corruption_prob: Mixing ratio passed to ``uniform_mix_C``.
        corruption_type: Only ``'unif'`` is implemented.
        shuffle_indices: Unused; kept for interface compatibility.
        seed: Seed given to ``np.random.seed`` before noise injection.

    Raises:
        Exception: In training mode when ``corruption_type`` is ``'flip'``.
        AssertionError: In training mode for any other unknown
            ``corruption_type``.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None,
                train=True, corruption_prob=0, corruption_type='unif', shuffle_indices=None, seed=1):
        self.root = root
        self.coco = COCO(annFile)

        self.ids = list(self.coco.imgToAnns.keys())
        self.transform = transform
        self.target_transform = target_transform
        # Map the raw category ids onto consecutive indices.
        self.cat2cat = {cat: position for position, cat in enumerate(self.coco.cats.keys())}

        self.train = train
        self.corruption_prob = corruption_prob
        self.num_classes = 80  # presumably the COCO category count — confirm

        paths, labels = self._collect_samples()
        self.num_samples = len(labels)

        if not self.train:
            self.test_data = paths
            self.test_labels = labels
            return

        self.train_data = paths
        self.train_labels = labels

        if corruption_type == 'flip':
            raise Exception('Corruption type "flip" not implemeneted')
        if corruption_type != 'unif':
            # Explicit raise so the check survives ``python -O``.
            raise AssertionError("Invalid corruption type '{}' given. "
                                 "Must be in ['unif', 'flip']".format(corruption_type))

        self.C = uniform_mix_C(self.corruption_prob, self.num_classes)
        self.C_true = np.zeros((self.num_classes, self.num_classes), dtype=np.float64)
        np.random.seed(seed)
        if corruption_prob > 0:
            self._corrupt_labels_uniform()

    def _collect_samples(self):
        """Return ``(paths, labels)`` arrays for all images with a positive label."""
        paths = []
        labels = []
        for index in range(len(self.ids)):
            img_path, target = self.get_item_coco_numpy(index)
            if np.count_nonzero(target) > 0:
                paths.append(img_path)
                labels.append(target)
        return np.array(paths), np.array(labels)

    def _corrupt_labels_uniform(self):
        """Corrupt ``self.train_labels`` in place and fill ``self.C_true``."""
        for i in range(len(self.train_labels)):
            true_labels = np.nonzero(self.train_labels[i])[0]
            # NOTE(review): the scan looks at the live label vector, so a
            # noisy label placed at a higher index is visited again as if it
            # were a true label. Kept as-is so the generated noise for a
            # given seed does not change.
            for label in range(len(self.train_labels[i])):
                if self.train_labels[i][label] != 1:
                    continue
                new_label = np.random.choice(self.num_classes, p=self.C[label])
                if new_label == label:
                    self.C_true[label][label] += 1
                    continue
                # Resample until the replacement is a WRONG label, i.e. one
                # that is neither originally nor currently positive.
                while self.train_labels[i][new_label] == 1 or new_label in true_labels:
                    new_label = np.random.choice(self.num_classes, p=self.C[label])
                self.train_labels[i][label] = 0
                self.train_labels[i][new_label] = 1
                self.C_true[label][new_label] += 1
        self.corruption_matrix = self.C

        # Normalise every ROW by its own total. keepdims makes the division
        # broadcast along rows; a 1-D sum would divide column j by the total
        # of row j instead. Rows with no observations stay all-zero.
        row_totals = self.C_true.sum(axis=1, keepdims=True)
        row_totals[row_totals == 0] = 1
        self.C_true /= row_totals

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``(image, label_tensor)``."""
        if self.train:
            img_path, target = self.train_data[index], self.train_labels[index]
        else:
            img_path, target = self.test_data[index], self.test_labels[index]

        # Only paths are stored; the image itself is opened here.
        img = Image.open(os.path.join(self.root, img_path)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, torch.from_numpy(target)

    def __len__(self):
        """Number of samples kept for the active split."""
        return self.num_samples

    def get_item_coco_numpy(self, index):
        """Return ``(file_name, multi_hot_label)`` for the image at ``index``."""
        coco = self.coco
        img_id = self.ids[index]
        annotations = coco.loadAnns(coco.getAnnIds(imgIds=img_id))

        target = np.zeros(self.num_classes)
        for obj in annotations:
            target[self.cat2cat[obj['category_id']]] = 1

        path = coco.loadImgs(img_id)[0]['file_name']

        return path, target


class CocoDetectionM(datasets.coco.CocoDetection):
    """COCO multi-label dataset with simulated missing labels.

    Image paths and multi-hot label vectors are collected up front; images
    are only opened in ``__getitem__``.

    In training mode the behaviour depends on ``missing_rate``:

    * ``< 1.0``: each positive label is removed when a draw from its row of
      ``self.C`` selects a different class. ``self.C_true`` then holds the
      row-normalised counts of those draws.
    * ``== 1.0``: a single randomly chosen positive label is kept per image.
    * ``== 2.0``: every positive label is removed and one randomly chosen
      negative class is set to positive instead.
    * any other value: labels are left untouched.

    Args:
        root: Directory containing the images.
        annFile: Path of the COCO annotation file.
        transform: Optional callable applied to the loaded PIL image.
        target_transform: Optional callable applied to the numpy label
            vector. Its result is handed to ``torch.from_numpy``.
        train: Truthy for the training split.
        missing_rate: Selects the missing-label scheme described above.
        shuffle_indices: Unused; kept for interface compatibility.
        seed: Seed given to ``np.random.seed`` before labels are altered.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None, train=True, missing_rate=0.30, shuffle_indices=None, seed=1):
        self.root = root
        self.coco = COCO(annFile)

        self.ids = list(self.coco.imgToAnns.keys())
        self.transform = transform
        self.target_transform = target_transform
        # Map the raw category ids onto consecutive indices.
        self.cat2cat = {cat: position for position, cat in enumerate(self.coco.cats.keys())}

        self.train = train

        self.missing_rate = missing_rate
        self.num_classes = 80  # presumably the COCO category count — confirm

        if not self.train:
            # NOTE(review): the test split keeps only images with MORE than
            # one positive label, while training keeps images with at least
            # one. Preserved as found — confirm this asymmetry is intended.
            self.test_data, self.test_labels = self._collect_samples(min_positives=2)
            self.num_samples = len(self.test_labels)
            return

        self.train_data, self.train_labels = self._collect_samples(min_positives=1)
        self.num_samples = len(self.train_labels)

        self.C = uniform_mix_C(self.missing_rate, self.num_classes)
        self.C_true = np.zeros((self.num_classes, self.num_classes), dtype=np.float64)
        np.random.seed(seed)
        print("missing_rate_given:", self.missing_rate)
        if self.missing_rate < 1.0:
            self._drop_labels_at_random()
        elif self.missing_rate == 1.0:
            self._keep_single_positive()
        elif self.missing_rate == 2.0:
            self._replace_with_single_negative()

    def _collect_samples(self, min_positives):
        """Return ``(paths, labels)`` arrays of images with >= ``min_positives`` labels."""
        paths = []
        labels = []
        for index in range(len(self.ids)):
            img_path, target = self.get_item_coco_numpy(index)
            if np.count_nonzero(target) >= min_positives:
                paths.append(img_path)
                labels.append(target)
        return np.array(paths), np.array(labels)

    @staticmethod
    def _report_missing_rate(count_missing, count_positive):
        """Print the achieved missing rate; 0.0 when there were no positives."""
        rate = count_missing / count_positive if count_positive else 0.0
        print("Missing Rate: ", rate)

    def _drop_labels_at_random(self):
        """Remove each positive label whose draw from ``self.C`` picks another class."""
        count_missing = 0
        count_positive = 0
        for i in range(len(self.train_labels)):
            count_positive += np.count_nonzero(self.train_labels[i])
            for label in range(len(self.train_labels[i])):
                if self.train_labels[i][label] != 1:
                    continue
                new_label = np.random.choice(self.num_classes, p=self.C[label])
                if new_label == label:
                    self.C_true[label][label] += 1
                    continue
                # The label is dropped; no replacement label is added.
                self.train_labels[i][label] = 0
                self.C_true[label][new_label] += 1
                count_missing += 1
        self.corruption_matrix = self.C

        # Normalise every ROW by its own total. keepdims makes the division
        # broadcast along rows; a 1-D sum would divide column j by the total
        # of row j instead. Rows with no observations stay all-zero.
        row_totals = self.C_true.sum(axis=1, keepdims=True)
        row_totals[row_totals == 0] = 1
        self.C_true /= row_totals
        self._report_missing_rate(count_missing, count_positive)

    def _keep_single_positive(self):
        """Keep one randomly chosen positive label per image, drop the rest."""
        count_missing = 0
        count_positive = 0
        for i in range(len(self.train_labels)):
            true_labels = np.nonzero(self.train_labels[i])[0]
            count_positive += len(true_labels)
            candidate_label = np.random.choice(true_labels)
            for label in true_labels:
                if label != candidate_label:
                    self.train_labels[i][label] = 0
                    count_missing += 1
        self._report_missing_rate(count_missing, count_positive)

    def _replace_with_single_negative(self):
        """Drop all positive labels and mark one random negative class instead."""
        count_missing = 0
        count_positive = 0
        for i in range(len(self.train_labels)):
            true_labels = np.nonzero(self.train_labels[i])[0]
            zero_labels = np.where(self.train_labels[i] == 0)[0]
            candidate_zero_label = np.random.choice(zero_labels)
            count_positive += len(true_labels)
            # The drawn value is unused, but the draw itself is kept so the
            # random stream (and thus the generated labels for a given seed)
            # stays the same as before.
            np.random.choice(true_labels)
            for label in true_labels:
                self.train_labels[i][label] = 0
                count_missing += 1

            self.train_labels[i][candidate_zero_label] = 1
        self._report_missing_rate(count_missing, count_positive)

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``(image, label_tensor)``."""
        if self.train:
            img_path, target = self.train_data[index], self.train_labels[index]
        else:
            img_path, target = self.test_data[index], self.test_labels[index]

        # Only paths are stored; the image itself is opened here.
        img = Image.open(os.path.join(self.root, img_path)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, torch.from_numpy(target)

    def __len__(self):
        """Number of samples kept for the active split."""
        return self.num_samples

    def get_item_coco_numpy(self, index):
        """Return ``(file_name, multi_hot_label)`` for the image at ``index``."""
        coco = self.coco
        img_id = self.ids[index]
        annotations = coco.loadAnns(coco.getAnnIds(imgIds=img_id))

        target = np.zeros(self.num_classes)
        for obj in annotations:
            target[self.cat2cat[obj['category_id']]] = 1

        path = coco.loadImgs(img_id)[0]['file_name']

        return path, target

class COCO_missing_dataset(torch.utils.data.Dataset):
    """Dataset backed by a text file with one ``path,idx idx ...`` record per line.

    Args:
        root: Directory the image paths are relative to.
        annFile: Text file; each line is ``<relative path>,<class indices>``
            with the indices separated by whitespace.
        transform: Optional callable applied to the PIL image.
        target_transform: Optional callable applied to the label tensor.
        class_num: Length of the multi-hot label vector.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None,class_num:int = None):
        self.root = root
        with open(annFile, 'r') as f:
            self.name = f.readlines()
        self.transform = transform
        self.class_num = class_num
        self.target_transform = target_transform
        print('load class_nums = ',self.class_num)

    def __getitem__(self, index):
        """Return ``(image, label)`` for record ``index``.

        ``label`` is a long tensor of length ``class_num``. When the image
        file does not exist, a black 448x448 image and an all-zero label are
        returned instead.
        """
        fields = self.name[index].strip('\n').split(',')
        path = fields[0]
        # split() tolerates repeated spaces and an empty index list; the
        # explicit integer dtype keeps an empty array usable as an index.
        class_ids = np.array([int(tok) for tok in fields[1].split()], dtype=np.int64)

        label = np.zeros([self.class_num])
        label[class_ids] = 1

        full_path = os.path.join(self.root, path)
        if os.path.exists(full_path):
            img = Image.open(full_path).convert('RGB')
        else:
            # Missing file: fall back to a blank image with no labels.
            label = np.zeros([self.class_num])
            img = Image.fromarray(np.uint8(np.zeros((448, 448, 3))))
        label = torch.tensor(label, dtype=torch.long)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            # Previously this referenced an undefined ``target`` and raised
            # NameError whenever a target transform was supplied.
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        """Number of records in the annotation file."""
        return len(self.name)


class ModelEma(torch.nn.Module):
    """Hold an exponential moving average of another model's state.

    ``self.module`` is a deep copy of the model, put in eval mode and
    optionally moved to ``device``.
    """

    def __init__(self, model, decay=0.9997, device=None):
        super().__init__()
        # Independent copy that accumulates the averaged state.
        shadow = deepcopy(model)
        shadow.eval()
        self.module = shadow
        self.decay = decay
        # When set, averaging happens on this device instead of the model's.
        self.device = device
        if device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        """Overwrite every EMA state entry with ``update_fn(ema, model)``."""
        ema_state = self.module.state_dict()
        model_state = model.state_dict()
        with torch.no_grad():
            for ema_v, model_v in zip(ema_state.values(), model_state.values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        """Blend ``model``'s state into the average using ``self.decay``."""
        def blend(ema_v, model_v):
            return self.decay * ema_v + (1. - self.decay) * model_v

        self._update(model, update_fn=blend)

    def set(self, model):
        """Copy ``model``'s state into the average verbatim."""
        def take_model(ema_v, model_v):
            return model_v

        self._update(model, update_fn=take_model)


class CutoutPIL(object):
    """Cutout augmentation: paint one random-colour rectangle on a PIL image.

    Args:
        cutout_factor: Side lengths of the rectangle as a fraction of the
            image height and width (before clipping at the borders).
    """

    def __init__(self, cutout_factor=0.5):
        self.cutout_factor = cutout_factor

    def __call__(self, x):
        """Draw the rectangle on ``x`` in place and return ``x``."""
        img_draw = ImageDraw.Draw(x)
        # PIL's ``size`` is (width, height). Unpacking it as (h, w) swapped
        # the axes, misplacing the patch on non-square images.
        w, h = x.size[0], x.size[1]
        h_cutout = int(self.cutout_factor * h + 0.5)
        w_cutout = int(self.cutout_factor * w + 0.5)
        # Random centre; the rectangle is clipped to the image bounds.
        y_c = np.random.randint(h)
        x_c = np.random.randint(w)

        y1 = np.clip(y_c - h_cutout // 2, 0, h)
        y2 = np.clip(y_c + h_cutout // 2, 0, h)
        x1 = np.clip(x_c - w_cutout // 2, 0, w)
        x2 = np.clip(x_c + w_cutout // 2, 0, w)
        fill_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        img_draw.rectangle([x1, y1, x2, y2], fill=fill_color)

        return x


def add_weight_decay(model, weight_decay=1e-4, skip_list=()):
    """Split trainable parameters into decayed and non-decayed groups.

    A parameter is exempt from weight decay when it is one-dimensional, when
    its name ends with ``.bias``, or when its name appears in ``skip_list``.
    Parameters with ``requires_grad`` set to False are left out entirely.

    Returns:
        A two-element list of parameter-group dicts: first the exempt group
        (``weight_decay`` 0), then the decayed group (``weight_decay`` as
        given).
    """
    regularized = []
    exempt = []
    for param_name, tensor in model.named_parameters():
        if tensor.requires_grad:
            skip_decay = (
                len(tensor.shape) == 1
                or param_name.endswith(".bias")
                or param_name in skip_list
            )
            bucket = exempt if skip_decay else regularized
            bucket.append(tensor)
    exempt_group = dict(params=exempt, weight_decay=0.)
    decayed_group = dict(params=regularized, weight_decay=weight_decay)
    return [exempt_group, decayed_group]

def preprocess_flickr(label_path='../data/Flickr-25/test_label.txt'):
    """Read a whitespace-separated label file and normalise its rows.

    Each line is split on whitespace and re-joined with single spaces.

    Args:
        label_path: path of the label file; defaults to the location that was
            previously hard-coded.

    Returns:
        A list with one normalised string per line of the file (a blank line
        yields an empty string). Earlier versions built this list but
        discarded it and returned ``None``.
    """
    with open(label_path, 'r') as f:
        return [" ".join(line.split()) for line in f]


class DatasetProcessingFlickr(Dataset):
    """Flickr-25 multi-label dataset with optional synthetic label noise.

    Image paths and multi-hot label rows are read from two parallel text
    files under ``root``. The ``'train'`` split keeps samples with more than
    one positive label and corrupts their labels according to
    ``corruption_type``; the ``'test'`` split keeps samples with at least one
    positive label and leaves the labels untouched.

    Attributes set for the train split: ``train_data``, ``train_labels``,
    ``C`` / ``corruption_matrix`` (the requested corruption matrix) and
    ``C_true`` (row-normalised empirical transition frequencies).
    Attributes set for the test split: ``test_data``, ``test_labels``.
    """

    # Directory holding the list files; used when ``root`` is not given.
    DEFAULT_ROOT = '/media/masoud/DATA/masoud_data/nus_wide/robust_multi_label-main/flickr'

    def __init__(self, split,
                 corruption_prob=0.0, corruption_type='unif', shuffle_indices=None, transform=None, seed=1,
                 root=None):
        """Load one split and, for training data, inject label noise.

        Args:
            split: ``'train'`` or ``'test'``.
            corruption_prob: mixing ratio passed to ``uniform_mix_C``.
            corruption_type: ``'unif'`` is the only implemented option.
            shuffle_indices: accepted for interface compatibility; unused.
            transform: optional callable applied to each loaded RGB image.
            seed: seed given to ``np.random.seed`` before corrupting labels
                (this reseeds NumPy's global random state).
            root: directory containing the ``*_image.txt`` / ``*_label.txt``
                list files; defaults to ``DEFAULT_ROOT``.

        Raises:
            ValueError: if ``split`` is not ``'train'`` or ``'test'``.
            NotImplementedError: if ``corruption_type`` is ``'flip'``.
            AssertionError: if ``corruption_type`` is not recognised.
        """
        if split not in ('train', 'test'):
            # Previously an unknown split produced a half-initialised object
            # whose __len__ returned None.
            raise ValueError(
                "Invalid split '{}' given. Must be in ['train', 'test']".format(split))
        if root is None:
            root = self.DEFAULT_ROOT

        self.transform = transform
        self.split = split
        self.corruption_prob = corruption_prob
        self.num_classes = 38
        self.img_path = os.path.join(root, 'Flickr-25')
        list_prefix = 'database' if split == 'train' else 'test'
        self.img_filepath = os.path.join(root, list_prefix + '_image.txt')
        self.label_filepath = os.path.join(root, list_prefix + '_label.txt')

        if split == 'train':
            self.data, self.labels = self.pre_precossing()
            # Training keeps only samples with more than one positive label.
            self.train_data, self.train_labels = self._select_samples(min_positive=1)
            self.num_samples = len(self.train_labels)

            if corruption_type == 'flip':
                raise NotImplementedError('Corruption type "flip" not implemeneted')
            if corruption_type != 'unif':
                # Raised explicitly (not via ``assert``) so the check is not
                # stripped when Python runs with -O.
                raise AssertionError("Invalid corruption type '{}' given. "
                                     "Must be in ['unif', 'flip']".format(corruption_type))

            C = uniform_mix_C(self.corruption_prob, self.num_classes)
            self.C = C
            self._corrupt_train_labels(seed)
            self.corruption_matrix = C
        else:
            print("Test data ...")
            self.data, self.labels = self.pre_precossing()
            # Testing keeps every sample with at least one positive label.
            self.test_data, self.test_labels = self._select_samples(min_positive=0)
            self.num_samples = len(self.test_labels)

    def _select_samples(self, min_positive):
        """Return copies of the paths/labels having more than ``min_positive`` positives."""
        keep = [idx for idx in range(len(self.data))
                if np.count_nonzero(self.labels[idx]) > min_positive]
        # Integer-list indexing copies, so later label corruption does not
        # modify ``self.labels``.
        return self.data[keep], self.labels[keep]

    def _corrupt_train_labels(self, seed):
        """Flip training labels in place using ``self.C`` and record ``self.C_true``."""
        self.C_true = np.zeros((self.num_classes, self.num_classes), dtype=np.float64)
        np.random.seed(seed)
        for row in self.train_labels:
            true_labels = np.nonzero(row)[0]
            # NOTE(review): a label flipped to a higher index is visited again
            # later in this scan and may be flipped a second time. Kept as-is
            # so seeded runs reproduce earlier results; confirm intent.
            for label in range(len(row)):
                if row[label] != 1:
                    continue
                new_label = np.random.choice(self.num_classes, p=self.C[label])
                if new_label == label:
                    self.C_true[label][label] += 1
                    continue
                # Redraw until the target is a genuinely WRONG label: neither
                # currently set nor one of the sample's original labels.
                while row[new_label] == 1 or new_label in true_labels:
                    new_label = np.random.choice(self.num_classes, p=self.C[label])
                row[label] = 0
                row[new_label] = 1
                self.C_true[label][new_label] += 1

        # Normalise each ROW by its own total. ``keepdims`` is required:
        # without it the row sums broadcast across columns. Rows with no
        # observations stay all-zero instead of becoming NaN.
        row_totals = self.C_true.sum(axis=1, keepdims=True)
        np.divide(self.C_true, row_totals, out=self.C_true, where=row_totals > 0)

    def pre_precossing(self):
        """Read the image-path list and the label matrix for this split.

        Returns:
            ``(img_filename, label)``: an array of stripped lines from
            ``self.img_filepath`` and an int64 2-D array loaded from
            ``self.label_filepath``. Both are also stored on ``self``.
        """
        with open(self.img_filepath, 'r') as fp:
            self.img_filename = np.array([line.strip() for line in fp])
        # ndmin=2 keeps a single-row label file two-dimensional.
        self.label = np.loadtxt(self.label_filepath, dtype=np.int64, ndmin=2)

        return self.img_filename, self.label

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``.

        The image is opened from the stored path, converted to RGB, and passed
        through ``self.transform`` when one is set. ``self.img_path`` is not
        prepended to the stored path.
        """
        if self.split == 'train':
            img_path, target = self.train_data[index], self.train_labels[index]
        else:
            img_path, target = self.test_data[index], self.test_labels[index]
        img = Image.open(img_path)
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Return the number of samples kept for the active split."""
        if self.split == 'train':
            return len(self.train_labels)
        return len(self.test_labels)
