from operator import mod
import os
import argparse
from matplotlib.pyplot import axis
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torch.optim import lr_scheduler
from src.helper_functions.helper_functions import mAP, CocoDetection, CocoDetectionM, CocoDetectionNo, COCO_missing_dataset, CutoutPIL, ModelEma, add_weight_decay
from src.helper_functions.helper_functions import Voc2007Classification
from src.helper_functions.helper_functions import DatasetProcessingFlickr
from src.models import create_model
from src.loss_functions.losses import SPLNCB, AsymmetricLoss, Hill, SPLC, SPLNC, SPLNCF, SPLNCT, SPLNCFC, SPCBCM, SPLNCFW, GLC, SPBCAG
from randaugment import RandAugment
from torch.cuda.amp import GradScaler, autocast
from measure import compute_metrics
import numpy as np
# Select the 'file_system' strategy for sharing tensors between processes
# (module-level side effect, applied as soon as this file is imported or run).
torch.multiprocessing.set_sharing_strategy('file_system')

# Command-line interface. ``parser`` stays at module level because ``main()``
# calls ``parser.parse_args()`` on it.
parser = argparse.ArgumentParser(description='PyTorch MS_COCO Training')
parser.add_argument('--dataset', help='select dataset', default='./dataset/coco_train_0.4left.txt')
parser.add_argument('--data', metavar='DIR', help='path to dataset', default='/home/MSCOCO_2014/')
parser.add_argument('--lr', default=1e-4, type=float)
parser.add_argument('--model-name', default='resnet50', choices=['resnet50', 'resnet101'])
# type=int: without it a value given on the command line stays a string and
# breaks numeric uses such as np.zeros(args.num_classes).
parser.add_argument('--num-classes', default=38, type=int)
# Help texts use %(default)s so the documented default can never drift from
# the real one again.
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('--image-size', default=224, type=int,
                    metavar='N', help='input image size (default: %(default)s)')
parser.add_argument('--thre', default=0.5, type=float,
                    metavar='N', help='threshold value')
parser.add_argument('-b', '--batch-size', default=32, type=int,
                    metavar='N', help='mini-batch size (default: %(default)s)')
parser.add_argument('--print-freq', '-p', default=64, type=int,
                    metavar='N', help='print frequency (default: %(default)s)')
parser.add_argument('--loss', default='Hill', type=str,
                    help='select loss function',
                    choices=['BCE', 'Focal', 'Asl', 'Hill', 'SPLC', 'SPLNC', 'SPLNCF', 'SPLNCFW',
                             'SPLNCT', 'SPLNCFC', 'SPCBCM', 'SPLNCB', 'GLC', 'SPBCAG'])
parser.add_argument('--missing-rate', '-mrate', type=float, default=0.4, help='The label missing rate.')
parser.add_argument('--corruption_prob', '-pnoise', type=float, default=0.4, help='The label noise rate.')
parser.add_argument('--corruption_type', '-ctype', type=str, default='unif',
                    help='Type of corruption ("unif" or "flip").')
parser.add_argument('--seed', type=int, default=1)

def main():
    """Parse the command line, build model and data loaders, and start training.

    Reads the module-level ``parser``, creates the network on the GPU (wrapped
    in ``nn.DataParallel`` when several GPUs are visible), builds the Flickr
    train/validation datasets (noisy, augmented-noisy and clean variants of the
    training split) and hands everything to ``train_multi_label_coco_aug``.
    """
    args = parser.parse_args()
    args.do_bottleneck_head = False

    # Setup model
    print('creating model...')
    model = create_model(args).cuda()
    print('done\n')

    if torch.cuda.device_count() > 1:
        device_id = list(range(torch.cuda.device_count()))
        model = nn.DataParallel(model, device_ids=device_id)

    normalize = transforms.Normalize(mean=[0.38933693, 0.42026136, 0.43608778],
                                     std=[0.23933473, 0.23539561, 0.24641008])

    # transforms.Resize replaces the deprecated alias transforms.Scale, which
    # recent torchvision releases no longer provide; with an int argument the
    # behaviour is the same.
    transformations = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    augment1 = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    val_dataset = DatasetProcessingFlickr('test', transform=transformations)

    # Three views of the training split: noisy labels, noisy labels with
    # augmentation, and uncorrupted labels (corruption_prob=0.0).
    train_dataset = DatasetProcessingFlickr('train',
                                            corruption_prob=args.corruption_prob,
                                            corruption_type=args.corruption_type,
                                            seed=args.seed,
                                            transform=transformations)
    train_aug_dataset = DatasetProcessingFlickr('train',
                                                corruption_prob=args.corruption_prob,
                                                corruption_type=args.corruption_type,
                                                seed=args.seed,
                                                transform=augment1)
    train_dataset_clean = DatasetProcessingFlickr('train',
                                                  corruption_prob=0.0,
                                                  corruption_type=args.corruption_type,
                                                  seed=args.seed,
                                                  transform=transformations)

    print("len(val_dataset)): ", len(val_dataset))
    print("len(train_dataset)): ", len(train_dataset))

    def make_loader(dataset, pin_memory):
        # shuffle stays off: the training routine draws batches from the noisy,
        # clean and augmented loaders side by side and compares them, so all
        # loaders must walk their dataset in the same order.
        return torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=pin_memory)

    # Pytorch Data loader
    train_loader = make_loader(train_dataset, pin_memory=True)
    train_loader_clean = make_loader(train_dataset_clean, pin_memory=True)
    train_loader_aug = make_loader(train_aug_dataset, pin_memory=True)
    val_loader = make_loader(val_dataset, pin_memory=False)

    # Alternative entry points kept for reference:
    #   fp_fn_calc(train_loader, train_loader_clean, args.num_classes)
    #   train_multi_label_coco(args, model, train_loader, val_loader, args.lr)
    #   train_multi_label_coco_wfpfn(args, model, train_loader, val_loader, train_loader_clean, args.lr)
    train_multi_label_coco_aug(args, model, train_loader, val_loader, train_loader_clean, train_loader_aug, args.lr)

def train_multi_label_coco(args, model, train_loader, val_loader, lr):
    """Train ``model`` with the loss named by ``args.loss``; validate every epoch.

    Uses Adam with a one-cycle learning-rate schedule, an autocast forward pass
    with gradient scaling, and an exponential-moving-average (EMA) copy of the
    model. After each epoch the current weights are written to ``models/`` and,
    whenever the validation mAP improves, the better of the regular/EMA weights
    is stored as ``models/model-highest.ckpt``. Checkpointing is best effort: a
    failed save is reported and training continues.

    Args:
        args: Parsed command-line namespace; only ``args.loss`` is read here.
        model: Network to train; inputs are moved to the GPU before the forward
            pass, so the model is expected to live there too.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Loader passed to the validation helpers.
        lr: Peak learning rate of the one-cycle schedule.

    Raises:
        ValueError: If ``args.loss`` is not supported by this routine.
    """
    ema = ModelEma(model, 0.9997)  # 0.9997^641=0.82

    # set optimizer
    Epochs = 80
    Stop_epoch = 80
    weight_decay = 1e-4
    checkpoint_dir = 'models/'

    # Zero-argument factories, so only the selected criterion is instantiated.
    criterion_factories = {
        'BCE': lambda: AsymmetricLoss(gamma_neg=0, gamma_pos=0, clip=0),
        'Focal': lambda: AsymmetricLoss(gamma_neg=2, gamma_pos=2, clip=0),
        'Asl': lambda: AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05),
        'Hill': Hill,
        'SPLC': SPLC,
        'SPLNC': SPLNC,
        'SPLNCF': SPLNCF,
        'SPLNCT': SPLNCT,
        'SPLNCB': SPLNCB,
        'SPCBCM': SPCBCM,
        'SPLNCFW': SPLNCFW,
        'GLC': GLC,
    }
    # Losses whose call additionally receives the current epoch.
    epoch_aware_losses = {'SPLC', 'SPLNC', 'SPLNCF', 'SPLNCT', 'SPCBCM', 'SPLNCB', 'SPLNCFW'}

    if args.loss not in criterion_factories:
        raise ValueError("Loss function dose not exist.")
    criterion = criterion_factories[args.loss]()

    parameters = add_weight_decay(model, weight_decay)
    optimizer = torch.optim.Adam(params=parameters, lr=lr, weight_decay=0)  # true wd, filter_bias_and_bn
    steps_per_epoch = len(train_loader)
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=steps_per_epoch, epochs=Epochs,
                                        pct_start=0.2)

    # Create the checkpoint directory up front; previously a missing directory
    # made every save fail silently.
    try:
        os.makedirs(checkpoint_dir, exist_ok=True)
    except OSError as err:
        print('could not create checkpoint directory {}: {}'.format(checkpoint_dir, err))

    highest_mAP = 0
    best_epoch = None  # stays None until some epoch beats highest_mAP
    scaler = GradScaler()
    for epoch in range(Epochs):
        if epoch > Stop_epoch:
            break
        for i, (inputData, target) in enumerate(train_loader):
            inputData = inputData.cuda()
            target = target.cuda()
            with autocast():  # mixed precision
                output = model(inputData).float()  # sigmoid will be done in loss !

            if args.loss in epoch_aware_losses:
                loss = criterion(output, target, epoch)
            else:
                loss = criterion(output, target)
            model.zero_grad()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()

            ema.update(model)
            # progress report
            if i % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], LR {:.1e}, Loss: {:.1f}'
                      .format(epoch, Epochs, str(i).zfill(3), str(steps_per_epoch).zfill(3),
                              scheduler.get_last_lr()[0],
                              loss.item()))

        # steps_per_epoch equals the last batch index + 1, and (unlike the loop
        # variable) is defined even when the loader yields nothing.
        epoch_ckpt = os.path.join(checkpoint_dir, 'model-{}-{}.ckpt'.format(epoch + 1, steps_per_epoch))
        try:
            torch.save(model.state_dict(), epoch_ckpt)
        except Exception as err:  # best effort: a failed save must not abort training
            print('could not save {}: {}'.format(epoch_ckpt, err))

        model.eval()
        mAP_score, if_ema_better = validate_multi(val_loader, model, ema)
        dict_res = validate_multi_metrics(val_loader, model)
        model.train()
        if mAP_score > highest_mAP:
            highest_mAP = mAP_score
            best_epoch = epoch
            best_state = ema.module.state_dict() if if_ema_better else model.state_dict()
            try:
                torch.save(best_state, os.path.join(checkpoint_dir, 'model-highest.ckpt'))
            except Exception as err:  # best effort, see above
                print('store failed')
                print(err)
        print('current_mAP = {:.2f}, highest_mAP = {:.2f}, best_epoch={}\n'.format(mAP_score, highest_mAP, best_epoch))
        print('current metrics ===> ', dict_res)


def train_multi_label_coco_wfpfn(args, model, train_loader, val_loader, train_loader_clean, lr):
    """Train ``model`` while tracking false positives/negatives against clean labels.

    The optimisation loop is the same as in ``train_multi_label_coco`` (Adam,
    one-cycle schedule, autocast forward pass with gradient scaling, EMA copy,
    per-epoch validation and checkpoints under ``models/``). In addition, every
    batch of ``train_loader`` is paired with the next batch of
    ``train_loader_clean``; thresholded predictions are compared with the clean
    targets through ``fnfp_calc`` and the per-class counts are accumulated per
    epoch. After training the counts are saved as two ``.npy`` files in the
    working directory and their per-epoch totals are printed.

    Args:
        args: Parsed command-line namespace; ``loss``, ``num_classes`` and
            ``corruption_prob`` are read.
        model: Network to train; inputs are moved to the GPU before the forward
            pass, so the model is expected to live there too.
        train_loader: Iterable of ``(inputs, targets)`` batches used for training.
        val_loader: Loader passed to the validation helpers.
        train_loader_clean: Iterable of ``(inputs, targets)`` batches whose
            targets serve as reference; it is restarted if it runs out first.
        lr: Peak learning rate of the one-cycle schedule.

    Raises:
        ValueError: If ``args.loss`` is not supported by this routine.
    """
    ema = ModelEma(model, 0.9997)  # 0.9997^641=0.82

    # set optimizer
    Epochs = 80
    Stop_epoch = 80
    weight_decay = 1e-4
    checkpoint_dir = 'models/'
    # Thresholds are defined once, before the loops, because they are also
    # needed for the output file names after training.
    tau = 0.6
    tau_p = 0.6  # only appears in the output file names
    # int(): the command-line value may arrive as a string.
    num_classes = int(args.num_classes)

    # Zero-argument factories, so only the selected criterion is instantiated.
    criterion_factories = {
        'BCE': lambda: AsymmetricLoss(gamma_neg=0, gamma_pos=0, clip=0),
        'Focal': lambda: AsymmetricLoss(gamma_neg=2, gamma_pos=2, clip=0),
        'Asl': lambda: AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05),
        'Hill': Hill,
        'SPLC': SPLC,
        'SPLNC': SPLNC,
        'SPLNCF': SPLNCF,
        'SPLNCT': SPLNCT,
        'SPLNCB': SPLNCB,
        'SPLNCFC': SPLNCFC,
        'GLC': lambda: GLC(gamma_neg=0, gamma_pos=0, clip=0),
    }
    # Losses whose call additionally receives the current epoch.
    epoch_aware_losses = {'SPLC', 'SPLNC', 'SPLNCF', 'SPLNCT', 'SPLNCB'}

    if args.loss not in criterion_factories:
        raise ValueError("Loss function dose not exist.")
    criterion = criterion_factories[args.loss]()

    parameters = add_weight_decay(model, weight_decay)
    optimizer = torch.optim.Adam(params=parameters, lr=lr, weight_decay=0)  # true wd, filter_bias_and_bn
    steps_per_epoch = len(train_loader)
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=steps_per_epoch, epochs=Epochs,
                                        pct_start=0.2)

    # Create the checkpoint directory up front; previously a missing directory
    # made every save fail silently.
    try:
        os.makedirs(checkpoint_dir, exist_ok=True)
    except OSError as err:
        print('could not create checkpoint directory {}: {}'.format(checkpoint_dir, err))

    highest_mAP = 0
    best_epoch = None  # stays None until some epoch beats highest_mAP
    scaler = GradScaler()
    list_fp = []
    list_fn = []
    for epoch in range(Epochs):
        if epoch > Stop_epoch:
            break
        data_clean_iter = iter(train_loader_clean)
        fp_total_arr = np.zeros(num_classes)
        fn_total_arr = np.zeros(num_classes)
        for i, (inputData, target) in enumerate(train_loader):
            try:
                _, target_c = next(data_clean_iter)
            except StopIteration:
                # clean loader exhausted first: start it over
                data_clean_iter = iter(train_loader_clean)
                _, target_c = next(data_clean_iter)

            inputData = inputData.cuda()
            target = target.cuda()
            target_c = target_c.cuda()

            with autocast():  # mixed precision
                output = model(inputData).float()  # sigmoid will be done in loss !

            if args.loss in epoch_aware_losses:
                loss = criterion(output, target, epoch)
            elif args.loss == 'SPLNCFC':
                loss = criterion(output, target, target_c, epoch)
            else:
                loss = criterion(output, target)

            # Diagnostics only, so keep them out of the autograd graph.
            with torch.no_grad():
                # Logits of classes marked 1 in the training target are lowered
                # by 1.0 before thresholding the sigmoid at tau.
                logits_f = torch.where(target == 1, output - 1.0, output).float()
                predict = (torch.sigmoid(logits_f) > tau).float()

            fp_n, fn_n = fnfp_calc(predict, target_c, num_classes)
            fp_total_arr += fp_n
            fn_total_arr += fn_n

            model.zero_grad()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()

            ema.update(model)
            # progress report
            if i % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], LR {:.1e}, Loss: {:.1f}'
                      .format(epoch, Epochs, str(i).zfill(3), str(steps_per_epoch).zfill(3),
                              scheduler.get_last_lr()[0],
                              loss.item()))
        list_fp.append(fp_total_arr)
        list_fn.append(fn_total_arr)

        # steps_per_epoch equals the last batch index + 1, and (unlike the loop
        # variable) is defined even when the loader yields nothing.
        epoch_ckpt = os.path.join(checkpoint_dir, 'model-{}-{}.ckpt'.format(epoch + 1, steps_per_epoch))
        try:
            torch.save(model.state_dict(), epoch_ckpt)
        except Exception as err:  # best effort: a failed save must not abort training
            print('could not save {}: {}'.format(epoch_ckpt, err))

        model.eval()
        mAP_score, if_ema_better = validate_multi(val_loader, model, ema)
        dict_res = validate_multi_metrics(val_loader, model)
        model.train()
        if mAP_score > highest_mAP:
            highest_mAP = mAP_score
            best_epoch = epoch
            best_state = ema.module.state_dict() if if_ema_better else model.state_dict()
            try:
                torch.save(best_state, os.path.join(checkpoint_dir, 'model-highest.ckpt'))
            except Exception as err:  # best effort, see above
                print('store failed')
                print(err)
        print('current_mAP = {:.2f}, highest_mAP = {:.2f}, best_epoch={}\n'.format(mAP_score, highest_mAP, best_epoch))
        print('current metrics ===> ', dict_res)

    # One row per epoch, one column per class.
    list_fp = np.array(list_fp)
    list_fn = np.array(list_fn)
    name_suffix = '{}_noise{}_tau{}_taup{}.npy'.format(args.loss, args.corruption_prob, tau, tau_p)
    np.save('flickr_list_fp_' + name_suffix, list_fp)
    np.save('flickr_list_fn_' + name_suffix, list_fn)

    # Sum over classes -> one total per epoch.
    total_fp = np.sum(list_fp, axis=1)
    total_fn = np.sum(list_fn, axis=1)
    print(total_fp, total_fn)

def train_multi_label_coco_aug(args, model, train_loader, val_loader, train_loader_clean, train_aug_loader, lr):
    """Train ``model`` using plain and augmented batches; track FP/FN against clean labels.

    Every batch of ``train_loader`` is paired with the next batch of
    ``train_loader_clean`` and of ``train_aug_loader``. The model is run on both
    the plain and the augmented inputs; the ``SPBCAG`` loss consumes both
    outputs, every other loss only the plain one. As in the sibling training
    routines the loop uses Adam, a one-cycle schedule, an autocast forward pass
    with gradient scaling, an EMA copy, per-epoch validation and checkpoints
    under ``models/``. Thresholded predictions are compared with the clean
    targets through ``fnfp_calc``; the per-class counts are accumulated per
    epoch, saved as two ``.npy`` files after training, and their per-epoch
    totals printed.

    Args:
        args: Parsed command-line namespace; ``loss``, ``num_classes`` and
            ``corruption_prob`` are read.
        model: Network to train; inputs are moved to the GPU before the forward
            pass, so the model is expected to live there too.
        train_loader: Iterable of ``(inputs, targets)`` batches used for training.
        val_loader: Loader passed to the validation helpers.
        train_loader_clean: Iterable of ``(inputs, targets)`` batches whose
            targets serve as reference.
        train_aug_loader: Iterable of ``(inputs, targets)`` batches providing
            the augmented inputs.
        lr: Peak learning rate of the one-cycle schedule.

    Raises:
        ValueError: If ``args.loss`` is not supported by this routine.
    """
    ema = ModelEma(model, 0.9997)  # 0.9997^641=0.82

    # set optimizer
    Epochs = 80
    Stop_epoch = 80
    weight_decay = 1e-4
    checkpoint_dir = 'models/'
    # Thresholds are defined once, before the loops, because they are also
    # needed for the output file names after training.
    tau = 0.6
    tau_p = 0.6  # only appears in the output file names
    # int(): the command-line value may arrive as a string.
    num_classes = int(args.num_classes)

    # Zero-argument factories, so only the selected criterion is instantiated.
    criterion_factories = {
        'BCE': lambda: AsymmetricLoss(gamma_neg=0, gamma_pos=0, clip=0),
        'Focal': lambda: AsymmetricLoss(gamma_neg=2, gamma_pos=2, clip=0),
        'Asl': lambda: AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05),
        'Hill': Hill,
        'SPLC': SPLC,
        'SPLNC': SPLNC,
        'SPLNCF': SPLNCF,
        'SPLNCT': SPLNCT,
        'SPLNCB': SPLNCB,
        'SPLNCFC': SPLNCFC,
        'SPBCAG': SPBCAG,
        'GLC': lambda: GLC(gamma_neg=0, gamma_pos=0, clip=0),
    }
    # Losses whose call additionally receives the current epoch.
    epoch_aware_losses = {'SPLC', 'SPLNC', 'SPLNCF', 'SPLNCT', 'SPLNCB'}

    if args.loss not in criterion_factories:
        raise ValueError("Loss function dose not exist.")
    criterion = criterion_factories[args.loss]()

    parameters = add_weight_decay(model, weight_decay)
    optimizer = torch.optim.Adam(params=parameters, lr=lr, weight_decay=0)  # true wd, filter_bias_and_bn
    steps_per_epoch = len(train_loader)
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=steps_per_epoch, epochs=Epochs,
                                        pct_start=0.2)

    # Create the checkpoint directory up front; previously a missing directory
    # made every save fail silently.
    try:
        os.makedirs(checkpoint_dir, exist_ok=True)
    except OSError as err:
        print('could not create checkpoint directory {}: {}'.format(checkpoint_dir, err))

    highest_mAP = 0
    best_epoch = None  # stays None until some epoch beats highest_mAP
    scaler = GradScaler()
    list_fp = []
    list_fn = []
    for epoch in range(Epochs):
        if epoch > Stop_epoch:
            break
        data_clean_iter = iter(train_loader_clean)
        data_aug_iter = iter(train_aug_loader)

        fp_total_arr = np.zeros(num_classes)
        fn_total_arr = np.zeros(num_classes)
        for i, (inputData, target) in enumerate(train_loader):
            try:
                _, target_c = next(data_clean_iter)
                inputData_aug, _ = next(data_aug_iter)
            except StopIteration:
                # If either side loader runs out, restart both so they stay
                # paired with each other.
                data_clean_iter = iter(train_loader_clean)
                data_aug_iter = iter(train_aug_loader)
                _, target_c = next(data_clean_iter)
                inputData_aug, _ = next(data_aug_iter)

            inputData = inputData.cuda()
            inputData_aug = inputData_aug.cuda()
            target = target.cuda()
            target_c = target_c.cuda()

            with autocast():  # mixed precision
                output = model(inputData).float()  # sigmoid will be done in loss !
                # Kept unconditional on purpose: skipping this forward pass for
                # losses that ignore it could alter layers that keep running
                # statistics while in training mode.
                output_aug = model(inputData_aug).float()

            if args.loss in epoch_aware_losses:
                loss = criterion(output, target, epoch)
            elif args.loss == 'SPLNCFC':
                loss = criterion(output, target, target_c, epoch)
            elif args.loss == 'SPBCAG':
                loss = criterion(output, output_aug, target, epoch)
            else:
                loss = criterion(output, target)

            # Diagnostics only, so keep them out of the autograd graph.
            with torch.no_grad():
                # Logits of classes marked 1 in the training target are lowered
                # by 1.0 before thresholding the sigmoid at tau.
                logits_f = torch.where(target == 1, output - 1.0, output).float()
                predict = (torch.sigmoid(logits_f) > tau).float()

            fp_n, fn_n = fnfp_calc(predict, target_c, num_classes)
            fp_total_arr += fp_n
            fn_total_arr += fn_n

            model.zero_grad()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()

            ema.update(model)
            # progress report
            if i % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], LR {:.1e}, Loss: {:.1f}'
                      .format(epoch, Epochs, str(i).zfill(3), str(steps_per_epoch).zfill(3),
                              scheduler.get_last_lr()[0],
                              loss.item()))
        list_fp.append(fp_total_arr)
        list_fn.append(fn_total_arr)

        # steps_per_epoch equals the last batch index + 1, and (unlike the loop
        # variable) is defined even when the loader yields nothing.
        epoch_ckpt = os.path.join(checkpoint_dir, 'model-{}-{}.ckpt'.format(epoch + 1, steps_per_epoch))
        try:
            torch.save(model.state_dict(), epoch_ckpt)
        except Exception as err:  # best effort: a failed save must not abort training
            print('could not save {}: {}'.format(epoch_ckpt, err))

        model.eval()
        mAP_score, if_ema_better = validate_multi(val_loader, model, ema)
        dict_res = validate_multi_metrics(val_loader, model)
        model.train()
        if mAP_score > highest_mAP:
            highest_mAP = mAP_score
            best_epoch = epoch
            best_state = ema.module.state_dict() if if_ema_better else model.state_dict()
            try:
                torch.save(best_state, os.path.join(checkpoint_dir, 'model-highest.ckpt'))
            except Exception as err:  # best effort, see above
                print('store failed')
                print(err)
        print('current_mAP = {:.2f}, highest_mAP = {:.2f}, best_epoch={}\n'.format(mAP_score, highest_mAP, best_epoch))
        print('current metrics ===> ', dict_res)

    # One row per epoch, one column per class.
    list_fp = np.array(list_fp)
    list_fn = np.array(list_fn)
    name_suffix = '{}_noise{}_tau{}_taup{}.npy'.format(args.loss, args.corruption_prob, tau, tau_p)
    np.save('flickr_list_fp_' + name_suffix, list_fp)
    np.save('flickr_list_fn_' + name_suffix, list_fn)

    # Sum over classes -> one total per epoch.
    total_fp = np.sum(list_fp, axis=1)
    total_fn = np.sum(list_fn, axis=1)
    print(total_fp, total_fn)


def validate_multi_metrics(val_loader, model):
    """Run ``model`` over ``val_loader`` and return the result of ``compute_metrics``.

    Sigmoid outputs of every batch are collected on the CPU, concatenated, and
    passed together with the concatenated targets and the fixed value 0.5 to
    ``compute_metrics``.

    Args:
        val_loader: Iterable of ``(inputs, targets)`` batches.
        model: Network to evaluate; inputs are moved to the GPU before the call.

    Returns:
        Whatever ``compute_metrics`` returns for the collected predictions.
    """
    print("starting validation for different metrics")
    Sig = torch.nn.Sigmoid()
    preds_regular = []
    targets = []
    for input_data, target in val_loader:
        with torch.no_grad():
            with autocast():
                output_regular = Sig(model(input_data.cuda())).cpu()

        preds_regular.append(output_regular.detach())
        targets.append(target.cpu().detach())

    # The mAP that used to be computed here was never used, so it is no longer
    # calculated; validate_multi reports mAP.
    all_preds = torch.cat(preds_regular).numpy()
    all_targets = torch.cat(targets).numpy()
    dict_results = compute_metrics(all_preds, all_targets, 0.5)

    return dict_results

def validate_multi(val_loader, model, ema_model):
    """Compute validation mAP for the regular model and for its EMA copy.

    Args:
        val_loader: Iterable of ``(inputs, targets)`` batches.
        model: Network to evaluate; inputs are moved to the GPU before the call.
        ema_model: Object whose ``module`` attribute is the EMA network.

    Returns:
        Tuple ``(mAP_max, if_ema_better)``: the larger of the two mAP scores and
        ``True`` when the EMA score is greater than or equal to the regular one.
    """
    print("starting validation")
    Sig = torch.nn.Sigmoid()
    preds_regular = []
    preds_ema = []
    targets = []
    # The loop variable used to be called ``input`` and shadowed the builtin.
    for images, target in val_loader:
        # Move the batch to the GPU once and reuse it for both networks.
        images = images.cuda()
        # compute output
        with torch.no_grad():
            with autocast():
                output_regular = Sig(model(images)).cpu()
                output_ema = Sig(ema_model.module(images)).cpu()

        # for mAP calculation
        preds_regular.append(output_regular.detach())
        preds_ema.append(output_ema.detach())
        targets.append(target.cpu().detach())

    all_targets = torch.cat(targets).numpy()
    mAP_score_regular = mAP(all_targets, torch.cat(preds_regular).numpy())
    mAP_score_ema = mAP(all_targets, torch.cat(preds_ema).numpy())
    print("mAP score regular {:.2f}, mAP score EMA {:.2f}".format(mAP_score_regular, mAP_score_ema))
    mAP_max = max(mAP_score_regular, mAP_score_ema)
    if_ema_better = bool(mAP_score_ema >= mAP_score_regular)

    return mAP_max, if_ema_better

def fp_fn_calc(data, clean_data, num_classes):
    """Count per-class label disagreements between ``data`` and ``clean_data``.

    Both arguments are iterated as sequences of ``(inputs, targets)`` batches
    and are walked in step; ``clean_data`` is restarted whenever it runs out
    before ``data`` does. A false positive is an entry that is set in the
    ``data`` target but not in the clean target; a false negative is the
    reverse.

    Args:
        data: Iterable of ``(inputs, targets)`` batches under inspection.
        clean_data: Iterable of ``(inputs, targets)`` batches used as reference.
        num_classes: Length of the per-class count vectors.

    Returns:
        dict with the per-class counts (``fp``, ``fn``), their sums
        (``total_fp``, ``total_fn``), the sums divided by the number of
        non-zero entries in the ``data`` targets (``fp_ratio``, ``fn_ratio``)
        and that number itself (``one_labels``). The dict is also printed.
    """
    fp_per_class = np.zeros(num_classes)
    fn_per_class = np.zeros(num_classes)
    positives_seen = 0
    reference_batches = iter(clean_data)

    for _, observed in data:
        reference_batch = next(reference_batches, None)
        if reference_batch is None:
            # reference side is exhausted: begin another pass over it
            reference_batches = iter(clean_data)
            reference_batch = next(reference_batches)
        _, reference = reference_batch

        observed = observed.cpu().numpy()
        reference = reference.cpu().numpy()
        positives_seen += np.count_nonzero(observed)

        only_in_observed = np.logical_not(reference) * observed
        only_in_reference = reference * np.logical_not(observed)
        fp_per_class = fp_per_class + np.sum(only_in_observed, axis=0).astype('float32')
        fn_per_class = fn_per_class + np.sum(only_in_reference, axis=0).astype('float32')

    fp_sum = np.sum(fp_per_class)
    fn_sum = np.sum(fn_per_class)

    result = {
        'fp': fp_per_class,
        'fn': fn_per_class,
        'total_fp': fp_sum,
        'total_fn': fn_sum,
        'fp_ratio': fp_sum / positives_seen,
        'fn_ratio': fn_sum / positives_seen,
        'one_labels': positives_seen,
    }

    print(result)
    return result

def fnfp_calc(target, target_c, num_classes):
    """Return per-class false-positive and false-negative counts for one batch.

    Args:
        target: Tensor under inspection (converted via ``.cpu().numpy()``).
        target_c: Reference tensor of the same layout.
        num_classes: Length of the returned count vectors.

    Returns:
        Tuple ``(fp, fn)`` of float64 arrays. ``fp`` sums, over the first axis,
        the entries of ``target`` where ``target_c`` is zero; ``fn`` sums the
        entries of ``target_c`` where ``target`` is zero.
    """
    observed = target.cpu().numpy()
    reference = target_c.cpu().numpy()

    # set in the observed tensor but absent from the reference
    spurious = np.logical_not(reference) * observed
    # present in the reference but missing from the observed tensor
    missed = reference * np.logical_not(observed)

    fp = np.zeros(num_classes) + np.sum(spurious, axis=0).astype('float32')
    fn = np.zeros(num_classes) + np.sum(missed, axis=0).astype('float32')
    return fp, fn


# Start training only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()
