import os
import argparse
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torch.optim import lr_scheduler
from src.helper_functions.helper_functions import mAP, CocoDetection, CocoDetectionM, CocoDetectionNo, COCO_missing_dataset, CutoutPIL, ModelEma, add_weight_decay
from src.helper_functions.helper_functions import Voc2007Classification
from src.models import create_model
from src.loss_functions.losses import AsymmetricLoss, Hill, SPLC, SPLNC, SPLNCF, SPLNCT, SPLNCFC
from randaugment import RandAugment
from torch.cuda.amp import GradScaler, autocast
from measure import compute_metrics
# Share tensors between DataLoader worker processes via files rather than file
# descriptors (presumably to avoid fd exhaustion with many workers).
torch.multiprocessing.set_sharing_strategy('file_system')

# Command-line interface for the training script.
parser = argparse.ArgumentParser(description='PyTorch MS_COCO Training')
parser.add_argument('--dataset', help='select dataset', default='./dataset/coco_train_0.4left.txt')
parser.add_argument('--data', metavar='DIR', help='path to dataset', default='/home/MSCOCO_2014/')
parser.add_argument('--lr', default=1e-4, type=float)
parser.add_argument('--model-name', default='resnet50', choices=['resnet50', 'resnet101'])
# type=int: without it a value given on the command line would stay a str.
parser.add_argument('--num-classes', default=20, type=int)
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                    help='number of data loading workers (default: 16)')
parser.add_argument('--image-size', default=448, type=int,
                    metavar='N', help='input image size (default: 448)')
parser.add_argument('--thre', default=0.5, type=float,
                    metavar='N', help='threshold value')
parser.add_argument('-b', '--batch-size', default=32, type=int,
                    metavar='N', help='mini-batch size (default: 32)')
parser.add_argument('--print-freq', '-p', default=64, type=int,
                    metavar='N', help='print frequency (default: 64)')
parser.add_argument('--loss', default='Hill', type=str,
                    help='select loss function',
                    choices=['BCE', 'Focal', 'Asl', 'Hill', 'SPLC', 'SPLNC', 'SPLNCF', 'SPLNCT', 'SPLNCFC'])
parser.add_argument('--missing-rate', '-mrate', type=float, default=0.4, help='The label missing rate.')
parser.add_argument('--corruption_prob', '-pnoise', type=float, default=0.4, help='The label noise rate.')
parser.add_argument('--corruption_type', '-ctype', type=str, default='unif',
                    help='Type of corruption ("unif" or "flip").')
parser.add_argument('--seed', type=int, default=1)

def main():
    """Parse CLI arguments, build the VOC2007 datasets/loaders and run training.

    Evaluation uses the VOC2007 ``test`` split. Training uses ``trainval`` with
    label noise controlled by ``--corruption_prob`` / ``--corruption_type``.

    ``--loss SPLNCFC`` also needs the clean (uncorrupted) training targets. For
    that loss, a second copy of ``trainval`` with ``corruption_prob=0.0`` is
    loaded and training runs through ``train_multi_label_coco_wfpfn``.
    ``train_multi_label_coco`` raises ``ValueError`` for SPLNCFC. Every other
    loss goes through ``train_multi_label_coco``.
    """
    args = parser.parse_args()
    args.do_bottleneck_head = False

    # Setup model
    print('creating model...')
    model = create_model(args).cuda()
    print('done\n')

    if torch.cuda.device_count() > 1:
        device_id = range(torch.cuda.device_count())
        model = nn.DataParallel(model, device_ids=device_id)

    # COCO paths: only referenced by the commented-out COCO datasets below.
    instances_path_val = os.path.join(args.data, 'annotations/instances_val2014.json')
    instances_path_train = os.path.join(args.data, 'annotations/instances_train2014.json')
    # instances_path_train = args.dataset

    data_path_val = f'{args.data}/val2014'    # args.data
    data_path_train = f'{args.data}/train2014'  # args.data
    transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])
    val_dataset = Voc2007Classification('test', transform=transform, train=False)
    # val_dataset = CocoDetection(data_path_val,
    #                             instances_path_val,
    #                             transforms.Compose([
    #                                 transforms.Resize((args.image_size, args.image_size)),
    #                                 transforms.ToTensor(),
    #                                 # normalize, # no need, toTensor does normalization
    #                             ]))
    # train_dataset = CocoDetectionM(data_path_train,
    #                             instances_path_train,
    #                             transforms.Compose([
    #                                 transforms.Resize((args.image_size, args.image_size)),
    #                                 CutoutPIL(cutout_factor=0.5),
    #                                 RandAugment(),
    #                                 transforms.ToTensor(),
    #                                 # normalize,
    #                             ]),
    #                             train=True,
    #                             missing_rate=args.missing_rate
    #                             )
    # train_dataset = CocoDetectionNo(root=data_path_train,
    #                                     annFile=instances_path_train,
    #                                     transform=transforms.Compose([
    #                                     transforms.Resize((args.image_size, args.image_size)),
    #                                     CutoutPIL(cutout_factor=0.5),
    #                                     RandAugment(),
    #                                     transforms.ToTensor(),
    #                                     # normalize,
    #                                     ]),
    #                                     train=True,
    #                                     corruption_prob=args.corruption_prob,
    #                                     corruption_type=args.corruption_type,
    #                                     seed=args.seed)

    train_dataset = Voc2007Classification('trainval',
                                          transform=transform,
                                          corruption_prob=args.corruption_prob,
                                          corruption_type=args.corruption_type,
                                          train=True)

    # train_dataset_clean = CocoDetectionNo(root=data_path_train,
    #                                     annFile=instances_path_train,
    #                                     transform=transforms.Compose([
    #                                     transforms.Resize((args.image_size, args.image_size)),
    #                                     CutoutPIL(cutout_factor=0.5),
    #                                     RandAugment(),
    #                                     transforms.ToTensor(),
    #                                     # normalize,
    #                                     ]),
    #                                     train=True,
    #                                     corruption_prob=0.0,
    #                                     corruption_type=args.corruption_type,
    #                                     seed=args.seed)

    # train_dataset = COCO_missing_dataset(data_path_train,
    #                             instances_path_train,
    #                             transforms.Compose([
    #                                 transforms.Resize((args.image_size, args.image_size)),
    #                                 CutoutPIL(cutout_factor=0.5),
    #                                 RandAugment(),
    #                                 transforms.ToTensor(),
    #                                 # normalize,
    #                             ]),class_num=args.num_classes)
    print("len(val_dataset)): ", len(val_dataset))
    print("len(train_dataset)): ", len(train_dataset))

    # Pytorch Data loader
    # The training loader is deliberately not shuffled. With SPLNCFC, batch k of
    # train_loader and batch k of train_loader_clean must contain the same
    # images. This assumes both Voc2007Classification instances list images in
    # the same order (TODO confirm).
    # NOTE(review): for the other losses a shuffled loader would usually be
    # preferable; confirm the fixed order is intended.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=False)

    # Actual training
    if args.loss == 'SPLNCFC':
        # Only this loss consumes clean targets, so only build them here.
        train_dataset_clean = Voc2007Classification('trainval',
                                                    transform=transform,
                                                    corruption_prob=0.0,
                                                    corruption_type=args.corruption_type,
                                                    train=True)
        train_loader_clean = torch.utils.data.DataLoader(
            train_dataset_clean, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        train_multi_label_coco_wfpfn(args, model, train_loader, val_loader, train_loader_clean, args.lr)
    else:
        train_multi_label_coco(args, model, train_loader, val_loader, args.lr)
    # train_multi_label_coco_wfpfn(args, model, train_loader, val_loader, train_loader_clean, args.lr)

def train_multi_label_coco(args, model, train_loader, val_loader, lr):
    """Train ``model`` with mixed precision, an EMA copy and a one-cycle LR schedule.

    After every epoch:

    * the raw weights are saved to ``models/model-<epoch>-<steps>.ckpt``;
    * the raw model and its EMA copy are evaluated on ``val_loader``;
    * whenever the best mAP improves, the better of the two is saved as
      ``models/model-highest.ckpt``.

    Checkpoint writes are best-effort: a failure is printed and training continues.

    Args:
        args: parsed CLI namespace; only ``args.loss`` is read here.
        model: network already moved to the GPU (optionally ``nn.DataParallel``).
        train_loader: yields ``(images, targets)`` training batches.
        val_loader: yields ``(images, targets)`` evaluation batches.
        lr: peak learning rate for ``OneCycleLR``.

    Raises:
        ValueError: if ``args.loss`` is not supported here. SPLNCFC needs clean
            targets and is only handled by ``train_multi_label_coco_wfpfn``.
    """
    ema = ModelEma(model, 0.9997)  # 0.9997^641=0.82

    # set optimizer
    Epochs = 80
    Stop_epoch = 80
    weight_decay = 1e-4
    loss_factories = {
        'BCE': lambda: AsymmetricLoss(gamma_neg=0, gamma_pos=0, clip=0),
        'Focal': lambda: AsymmetricLoss(gamma_neg=2, gamma_pos=2, clip=0),
        'Asl': lambda: AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05),
        'Hill': Hill,
        'SPLC': SPLC,
        'SPLNC': SPLNC,
        'SPLNCF': SPLNCF,
        'SPLNCT': SPLNCT,
    }
    # Losses whose forward call also takes the current epoch.
    epoch_aware_losses = ('SPLC', 'SPLNC', 'SPLNCF', 'SPLNCT')
    if args.loss not in loss_factories:
        raise ValueError("Loss function {!r} does not exist for train_multi_label_coco.".format(args.loss))
    criterion = loss_factories[args.loss]()

    parameters = add_weight_decay(model, weight_decay)
    optimizer = torch.optim.Adam(params=parameters, lr=lr, weight_decay=0)  # true wd, filter_bias_and_bn
    steps_per_epoch = len(train_loader)
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=steps_per_epoch, epochs=Epochs,
                                        pct_start=0.2)

    highest_mAP = 0
    best_epoch = None  # stays None until some epoch beats highest_mAP (avoids an unbound name)
    scaler = GradScaler()
    # torch.save does not create parent directories; create it once up front.
    try:
        os.makedirs('models', exist_ok=True)
    except OSError as err:
        print('could not create models/ directory: {}'.format(err))

    for epoch in range(Epochs):
        if epoch > Stop_epoch:
            break
        for i, (inputData, target) in enumerate(train_loader):
            inputData = inputData.cuda()
            target = target.cuda()
            with autocast():  # mixed precision
                output = model(inputData).float()  # sigmoid will be done in loss !
            if args.loss in epoch_aware_losses:
                loss = criterion(output, target, epoch)
            else:
                loss = criterion(output, target)
            model.zero_grad()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # OneCycleLR is sized per optimizer step, so it is stepped per batch.
            scheduler.step()

            ema.update(model)
            if i % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], LR {:.1e}, Loss: {:.1f}'
                      .format(epoch, Epochs, str(i).zfill(3), str(steps_per_epoch).zfill(3),
                              scheduler.get_last_lr()[0],
                              loss.item()))

        # steps_per_epoch equals the last batch index + 1 and stays defined even
        # if the loader produced no batches.
        try:
            torch.save(model.state_dict(), os.path.join(
                'models/', 'model-{}-{}.ckpt'.format(epoch + 1, steps_per_epoch)))
        except (OSError, RuntimeError) as err:
            print('checkpoint save failed: {}'.format(err))

        model.eval()
        mAP_score, if_ema_better = validate_multi(val_loader, model, ema)
        dict_res = validate_multi_metrics(val_loader, model)
        model.train()
        if mAP_score > highest_mAP:
            highest_mAP = mAP_score
            best_epoch = epoch
            best_module = ema.module if if_ema_better else model
            try:
                torch.save(best_module.state_dict(), os.path.join(
                    'models/', 'model-highest.ckpt'))
            except (OSError, RuntimeError) as err:
                print('store failed: {}'.format(err))
        print('current_mAP = {:.2f}, highest_mAP = {:.2f}, best_epoch={}\n'.format(mAP_score, highest_mAP, best_epoch))
        print('current metrics ===> ', dict_res)

def train_multi_label_coco_wfpfn(args, model, train_loader, val_loader, train_loader_clean, lr):
    """Train like the basic trainer, also drawing clean targets in lockstep.

    Training uses mixed precision, an EMA copy and a one-cycle LR schedule.
    For every noisy batch from ``train_loader``, the next batch from
    ``train_loader_clean`` is fetched. Its targets are passed to the loss only
    for ``SPLNCFC``. The pairing is positional, so both loaders are assumed to
    yield the same images in the same order (TODO confirm at the call site). If
    the clean loader runs out first, it is restarted from its first batch.

    After every epoch:

    * the raw weights are saved to ``models/model-<epoch>-<steps>.ckpt``;
    * the raw model and its EMA copy are evaluated;
    * the better one is saved as ``models/model-highest.ckpt`` whenever the best
      mAP improves.

    Checkpoint writes are best-effort.

    Args:
        args: parsed CLI namespace; only ``args.loss`` is read here.
        model: network already moved to the GPU (optionally ``nn.DataParallel``).
        train_loader: yields ``(images, noisy_targets)`` training batches.
        val_loader: yields ``(images, targets)`` evaluation batches.
        train_loader_clean: yields ``(images, clean_targets)`` batches.
        lr: peak learning rate for ``OneCycleLR``.

    Raises:
        ValueError: if ``args.loss`` is not a supported loss name.
    """
    ema = ModelEma(model, 0.9997)  # 0.9997^641=0.82

    # set optimizer
    Epochs = 80
    Stop_epoch = 80
    weight_decay = 1e-4
    loss_factories = {
        'BCE': lambda: AsymmetricLoss(gamma_neg=0, gamma_pos=0, clip=0),
        'Focal': lambda: AsymmetricLoss(gamma_neg=2, gamma_pos=2, clip=0),
        'Asl': lambda: AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05),
        'Hill': Hill,
        'SPLC': SPLC,
        'SPLNC': SPLNC,
        'SPLNCF': SPLNCF,
        'SPLNCT': SPLNCT,
        'SPLNCFC': SPLNCFC,
    }
    # Losses whose forward call also takes the current epoch.
    epoch_aware_losses = ('SPLC', 'SPLNC', 'SPLNCF', 'SPLNCT')
    if args.loss not in loss_factories:
        raise ValueError("Loss function {!r} does not exist.".format(args.loss))
    criterion = loss_factories[args.loss]()

    parameters = add_weight_decay(model, weight_decay)
    optimizer = torch.optim.Adam(params=parameters, lr=lr, weight_decay=0)  # true wd, filter_bias_and_bn
    steps_per_epoch = len(train_loader)
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=steps_per_epoch, epochs=Epochs,
                                        pct_start=0.2)

    highest_mAP = 0
    best_epoch = None  # stays None until some epoch beats highest_mAP (avoids an unbound name)
    scaler = GradScaler()
    # torch.save does not create parent directories; create it once up front.
    try:
        os.makedirs('models', exist_ok=True)
    except OSError as err:
        print('could not create models/ directory: {}'.format(err))

    for epoch in range(Epochs):
        if epoch > Stop_epoch:
            break
        data_clean_iter = iter(train_loader_clean)

        for i, (inputData, target) in enumerate(train_loader):
            try:
                _inputs_clean, target_c = next(data_clean_iter)
            except StopIteration:
                # Restarting re-aligns to batch 0; only reached if the clean
                # loader is shorter than the noisy one.
                data_clean_iter = iter(train_loader_clean)
                _inputs_clean, target_c = next(data_clean_iter)

            inputData = inputData.cuda()
            target = target.cuda()
            target_c = target_c.cuda()

            with autocast():  # mixed precision
                output = model(inputData).float()  # sigmoid will be done in loss !
            if args.loss in epoch_aware_losses:
                loss = criterion(output, target, epoch)
            elif args.loss == 'SPLNCFC':
                loss = criterion(output, target, target_c, epoch)
            else:
                loss = criterion(output, target)
            model.zero_grad()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # OneCycleLR is sized per optimizer step, so it is stepped per batch.
            scheduler.step()

            ema.update(model)
            if i % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], LR {:.1e}, Loss: {:.1f}'
                      .format(epoch, Epochs, str(i).zfill(3), str(steps_per_epoch).zfill(3),
                              scheduler.get_last_lr()[0],
                              loss.item()))

        # steps_per_epoch equals the last batch index + 1 and stays defined even
        # if the loader produced no batches.
        try:
            torch.save(model.state_dict(), os.path.join(
                'models/', 'model-{}-{}.ckpt'.format(epoch + 1, steps_per_epoch)))
        except (OSError, RuntimeError) as err:
            print('checkpoint save failed: {}'.format(err))

        model.eval()
        mAP_score, if_ema_better = validate_multi(val_loader, model, ema)
        dict_res = validate_multi_metrics(val_loader, model)
        model.train()
        if mAP_score > highest_mAP:
            highest_mAP = mAP_score
            best_epoch = epoch
            best_module = ema.module if if_ema_better else model
            try:
                torch.save(best_module.state_dict(), os.path.join(
                    'models/', 'model-highest.ckpt'))
            except (OSError, RuntimeError) as err:
                print('store failed: {}'.format(err))
        print('current_mAP = {:.2f}, highest_mAP = {:.2f}, best_epoch={}\n'.format(mAP_score, highest_mAP, best_epoch))
        print('current metrics ===> ', dict_res)

def validate_multi_metrics(val_loader, model, thre=0.5):
    """Run ``model`` over ``val_loader`` and return ``compute_metrics`` results.

    Sigmoid probabilities are computed under ``torch.no_grad`` and autocast,
    gathered on the CPU, and passed to ``compute_metrics`` together with the
    stacked targets and the decision threshold.

    Args:
        val_loader: yields ``(images, targets)`` batches.
        model: network on the GPU; callers put it in eval mode beforehand.
        thre: threshold forwarded to ``compute_metrics``. It defaults to 0.5,
            the value previously hard-coded, so existing calls are unchanged.

    Returns:
        Whatever ``compute_metrics`` returns; the trainers print it as the
        "current metrics".
    """
    print("starting validation for different metrics")
    Sig = torch.nn.Sigmoid()
    preds_regular = []
    targets = []
    with torch.no_grad():
        for input_data, target in val_loader:
            with autocast():
                output_regular = Sig(model(input_data.cuda()))
            preds_regular.append(output_regular.cpu())
            targets.append(target.cpu())

    # The old version also computed an mAP here and discarded it; that pass is dropped.
    return compute_metrics(torch.cat(preds_regular).numpy(), torch.cat(targets).numpy(), thre)

def validate_multi(val_loader, model, ema_model):
    """Score ``model`` and its EMA copy on ``val_loader`` by mAP.

    Both networks see the same GPU copy of each batch, under ``torch.no_grad``
    and autocast. Sigmoid probabilities are gathered on the CPU and scored
    with ``mAP`` against the stacked targets.

    Args:
        val_loader: yields ``(images, targets)`` batches.
        model: network on the GPU; callers put it in eval mode beforehand.
        ema_model: EMA wrapper whose ``.module`` is the averaged network.

    Returns:
        A pair ``(mAP_max, if_ema_better)``. ``mAP_max`` is the higher of the two
        scores. ``if_ema_better`` is True when the EMA score is greater than or
        equal to the regular one, so ties favour the EMA model.
    """
    print("starting validation")
    Sig = torch.nn.Sigmoid()
    preds_regular = []
    preds_ema = []
    targets = []
    with torch.no_grad():
        for images, target in val_loader:
            images = images.cuda()  # one host-to-device copy, shared by both models
            with autocast():
                preds_regular.append(Sig(model(images)).cpu())
                preds_ema.append(Sig(ema_model.module(images)).cpu())
            targets.append(target.cpu())

    mAP_score_regular = mAP(torch.cat(targets).numpy(), torch.cat(preds_regular).numpy())
    mAP_score_ema = mAP(torch.cat(targets).numpy(), torch.cat(preds_ema).numpy())
    print("mAP score regular {:.2f}, mAP score EMA {:.2f}".format(mAP_score_regular, mAP_score_ema))
    if_ema_better = bool(mAP_score_ema >= mAP_score_regular)
    return max(mAP_score_regular, mAP_score_ema), if_ema_better


if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()
