import torch
import config_Att
import time
from config_Att import args_setting
from dataset import RoadSequenceDataset, RoadSequenceDatasetList
from model_att import generate_model
from torchvision import transforms
from torch.optim import lr_scheduler
import torch.nn as nn
import os
from radam import RAdam
import UNet_TwoConvGRU
import numpy as np
import cv2

def train(args, epoch, model, train_loader, device, optimizer, criterion):
    """Run one training epoch of ``model`` over ``train_loader``.

    Every batch is a dict with ``'data'`` and ``'label'`` entries; the data is
    moved to ``device`` and the label is cast to an int64 tensor on
    ``device``.  ``model`` must return an ``(output, aux)`` pair -- only
    ``output`` is fed to ``criterion`` -- and one ``optimizer`` step is taken
    per batch.  No learning-rate scheduler is touched here.

    Progress is printed every ``args.log_interval`` batches and the epoch's
    wall-clock duration is printed when it finishes.
    """
    start_time = time.time()
    model.train()
    for step, batch in enumerate(train_loader):
        inputs = batch['data'].to(device)
        labels = batch['label'].type(torch.LongTensor).to(device)

        optimizer.zero_grad()
        # model yields (output, aux); aux is not used for the loss
        logits, _aux = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        if step % args.log_interval != 0:
            continue
        seen = step * len(inputs)
        percent = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset), percent, loss.item()))

    minutes, seconds = divmod(time.time() - start_time, 60)
    print('Train Epoch: {} complete in {:.0f}m {:.0f}s'.format(epoch, minutes, seconds))

def val(args, model, val_loader, device, criterion, best_acc, epoch):
    """Evaluate ``model`` on the validation set and save a checkpoint.

    Computes the mean of the per-batch ``criterion`` values and the pixel
    accuracy (percentage of pixels whose arg-max class equals the label),
    prints both, then saves ``model.state_dict()`` under
    ``./model/model_Layer2/<args.model>/``.  The accuracy, the first three
    characters of the optimizer's repr, the learning rate, the batch size and
    the epoch are encoded in the file name.  A checkpoint is written on every
    call, whatever the accuracy.

    NOTE(review): ``optimizer`` is not a parameter; it is read from module
    globals (bound in the ``__main__`` block), so calling ``val`` from another
    module raises NameError.

    Args:
        args: must provide ``model``, ``lr`` and ``batch_size``.
        model: network returning an ``(output, aux)`` pair; only ``output``
            (class scores along dim 1) is used.
        val_loader: iterable of dicts with ``'data'`` and ``'label'`` tensors.
        device: device the batches are moved to.
        criterion: loss applied to ``(output, target)``.
        best_acc: unused; kept for call-site compatibility.
        epoch: epoch number embedded in the checkpoint name.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    total_pixels = 0
    with torch.no_grad():
        for sample_batched in val_loader:
            data = sample_batched['data'].to(device)
            target = sample_batched['label'].type(torch.LongTensor).to(device)
            output, aux = model(data)
            test_loss += criterion(output, target).item()  # sum of per-batch losses
            pred = output.max(1, keepdim=True)[1]  # predicted class index per pixel
            correct += pred.eq(target.view_as(pred)).sum().item()
            # count the pixels actually compared, so accuracy is bounded by 100%
            total_pixels += target.numel()

    # Divide by the real number of batches (the last batch may be partial).
    test_loss /= max(len(val_loader), 1)
    val_acc = 100. * int(correct) / total_pixels if total_pixels else 0.0
    print('\nAverage loss: {:.4f}, Accuracy: {}/{} ({:.5f}%)\n'.format(
        test_loss, int(correct), total_pixels, val_acc))

    save_dir = './model/model_Layer2/{}'.format(args.model)
    os.makedirs(save_dir, exist_ok=True)  # torch.save does not create missing directories
    torch.save(model.state_dict(), '{}/{}_{}_lr{}_batch{}_epoch{}.pth'.format(
        save_dir, val_acc, str(optimizer)[0:3], args.lr, args.batch_size, epoch))



def get_parameters(model, layer_name):
    """Yield the parameters of selected top-level children of ``model``.

    A direct child of ``model`` is selected when its name satisfies
    ``name in layer_name``.  Note that if ``layer_name`` is a string, this is a
    substring test.  For every selected child, the parameters of each of its
    own direct sub-modules are yielded.  Sub-modules that are ReLU, MaxPool2d,
    Dropout2d or UpsamplingBilinear2d are skipped.

    This is a lazy generator, suitable for building optimizer param groups.

    NOTE(review): a model wrapped in ``nn.DataParallel`` has a single child
    named ``'module'``, so pass the unwrapped model to match real layer names.
    """
    skipped_types = (nn.ReLU, nn.MaxPool2d, nn.Dropout2d, nn.UpsamplingBilinear2d)
    selected = (child for child_name, child in model.named_children()
                if child_name in layer_name)
    for child in selected:
        for sub_module in child.children():
            if not isinstance(sub_module, skipped_types):
                yield from sub_module.parameters()

def _dilated_precision_recall(pred_mask, label_mask, kernel):
    """Return tolerant (precision, recall) for one binary mask pair, or None if either mask is empty.

    Precision counts a predicted pixel as correct if it lies on the label mask
    dilated by ``kernel``.  Recall counts a label pixel as found if it lies on
    the prediction mask dilated by ``kernel``.  ``None`` is returned when
    either mask has no non-zero pixel, because one of the two ratios would
    then be undefined.
    """
    pred_pixels = np.count_nonzero(pred_mask)
    label_pixels = np.count_nonzero(label_mask)
    if pred_pixels == 0 or label_pixels == 0:
        return None
    label_dilated = cv2.dilate(label_mask, kernel)
    pred_dilated = cv2.dilate(pred_mask, kernel)
    precision = np.count_nonzero(np.logical_and(pred_mask, label_dilated)) / pred_pixels
    recall = np.count_nonzero(np.logical_and(pred_dilated, label_mask)) / label_pixels
    return float(precision), float(recall)


def evaluate_model(model, test_loader, device, criterion, epoch):
    """Evaluate ``model`` on ``test_loader``, print the metrics and return them.

    Metrics:
      * average loss -- mean of the per-batch ``criterion`` values;
      * pixel accuracy -- percentage of pixels whose arg-max class equals the label;
      * precision / recall -- per-sample scores with a 3x3-dilation tolerance
        (see ``_dilated_precision_recall``), averaged over the samples whose
        prediction and label masks are both non-empty;
      * F1_measure -- harmonic mean of the averaged precision and recall
        (0.0 when both are 0).

    The masks are built with ``torch.squeeze`` and dilated with ``cv2``, so the
    precision/recall part is meant for a loader yielding one sample per batch
    (as the ``__main__`` test loader does with ``batch_size=1``).

    Args:
        model: network returning an ``(output, aux)`` pair; only ``output``
            (class scores along dim 1) is used.
        test_loader: DataLoader of dicts with ``'data'`` and ``'label'`` tensors.
        device: device the batches are moved to.
        criterion: loss applied to ``(output, target)``.
        epoch: unused; kept for call-site compatibility.

    Returns:
        dict with keys ``'precision'``, ``'recall'``, ``'F1_measure'``, ``'test_acc'``.
    """
    model.eval()
    kernel = np.ones((3, 3), dtype=np.uint8)
    test_loss = 0.0
    correct = 0
    total_pixels = 0
    precision_sum = 0.0
    recall_sum = 0.0
    scored = 0  # samples where both precision and recall are defined
    with torch.no_grad():
        for sample_batched in test_loader:
            data = sample_batched['data'].to(device)
            target = sample_batched['label'].type(torch.LongTensor).to(device)
            output, aux = model(data)
            # max() returns (values, indices); the indices are the predicted classes
            pred = output.max(1, keepdim=True)[1]

            # loss and pixel accuracy use every sample
            test_loss += criterion(output, target).item()
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_pixels += target.numel()

            # binary masks scaled to 0/255 as uint8, the input type cv2.dilate expects
            img = (torch.squeeze(pred).cpu().numpy() * 255).astype(np.uint8)
            lab = (torch.squeeze(target).cpu().numpy() * 255).astype(np.uint8)
            scores = _dilated_precision_recall(img, lab, kernel)
            if scores is None:
                continue
            precision_sum += scores[0]
            recall_sum += scores[1]
            scored += 1

    # test_loss is a sum of per-batch means; divide by the real batch count.
    test_loss /= max(len(test_loader), 1)
    test_acc = 100. * int(correct) / total_pixels if total_pixels else 0.0
    print('\nAverage loss: {:.4f}, Accuracy: {}/{} ({:.5f}%)'.format(
        test_loss, int(correct), total_pixels, test_acc))

    precision = precision_sum / scored if scored else 0.0
    recall = recall_sum / scored if scored else 0.0
    if precision + recall > 0:
        F1_measure = 2 * precision * recall / (precision + recall)
    else:
        F1_measure = 0.0
    print('Precision: {:.5f}, Recall: {:.5f}, F1_measure: {:.5f}\n'.format(precision, recall, F1_measure))
    evaluate_result = {'precision': precision, 'recall': recall, 'F1_measure': F1_measure, 'test_acc': test_acc}
    return evaluate_result

#from torch.utils.tensorboard import SummaryWriter

if __name__ == '__main__':
    #writer = SummaryWriter()
    print(os.getcwd())
    args = args_setting()
    torch.manual_seed(args.seed)
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    torch.backends.cudnn.benchmark = True
    print(torch.cuda.is_available())

    # turn image into floatTensor
    op_tranforms = transforms.Compose([transforms.ToTensor()])

    # These (ConvLSTM / ConvGRU) models read samples through
    # RoadSequenceDatasetList; every other model uses RoadSequenceDataset.
    list_dataset_models = {
        'SegNet-ConvLSTM', 'UNet-ConvLSTM', 'SCNN_SegNet_ConvLSTM',
        'SCNNV_UNetLight_ConvGRU', 'SCNN_UNet_ConvLSTM', 'UNet_TwoConvGRU',
        'SCNN_UNet_ConvGRU', 'SCNN_SegNet_ConvGRU', 'SCNN_UNetLight_ConvGRU',
        'SCNN_UNetLight_ConvLSTM', 'SCNNV_UNetLight_ConvLSTM',
        'SCNNV_UNet_ConvLSTM', 'SCNNV2_UNetLight_ConvLSTM',
    }
    dataset_cls = RoadSequenceDatasetList if args.model in list_dataset_models else RoadSequenceDataset
    num_workers = config_Att.data_loader_numworkers

    # load data for batches, num_workers for multiprocess
    train_loader = torch.utils.data.DataLoader(
        dataset_cls(file_path=config_Att.train_path, transforms=op_tranforms),
        batch_size=args.batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(
        dataset_cls(file_path=config_Att.val_path, transforms=op_tranforms),
        batch_size=args.test_batch_size, shuffle=True, num_workers=num_workers)
    # batch_size=1: evaluate_model squeezes each batch into a single 2-D mask.
    test_loader = torch.utils.data.DataLoader(
        dataset_cls(file_path=config_Att.test_path, transforms=op_tranforms),
        batch_size=1, shuffle=False, num_workers=num_workers)

    #load model
    model = generate_model(args)
    # Move to the selected device rather than calling .cuda() unconditionally,
    # so the script also starts on hosts without CUDA.
    model.to(device)
    model = nn.DataParallel(model)  # state-dict keys gain a 'module.' prefix

    #optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    #optimizer = RAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.999)) #change optimizer note
    # optimizer = torch.optim.Adam([
    #     {'params': get_parameters(model, layer_name=["inc", "down1", "down2", "down3", "down4"]), 'lr': args.lr * 0.0},
    #     {'params': get_parameters(model, layer_name=["outc", "up1", "up2", "up3", "up4"]), 'lr': args.lr * 0.1},
    #     {'params': get_parameters(model, layer_name=["convlstm"]), 'lr': args.lr * 1},
    # ], lr=args.lr)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    # optimizer = torch.optim.SGD([
    #     {'params': get_parameters(model, layer_name=["conv1_block", "conv2_block", "conv3_block", "conv4_block", "conv5_block"]), 'lr': args.lr * 0.5},
    #     {'params': get_parameters(model, layer_name=["upconv5_block", "upconv4_block", "upconv3_block", "upconv2_block", "upconv1_block"]), 'lr': args.lr * 0.33},
    #     {'params': get_parameters(model, layer_name=["Conv3D_block"]), 'lr': args.lr * 0.5},
    # ], lr=args.lr,momentum=0.9)

    scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)  # lr *= 0.9 every 2 scheduler steps
    class_weight = torch.Tensor(config_Att.class_weight)
    criterion = torch.nn.CrossEntropyLoss(weight=class_weight).to(device)
    best_acc = 0  # best test F1_measure seen so far

    # Warm-start from a saved checkpoint: only keys that also exist in this
    # model are copied; all other weights keep their fresh initialisation.
    pretrained_dict = torch.load(
        './model/model_Layer2/UNet-ConvLSTM/98.00694539380919_SGD_lr0.047239200000000016_batch72_epoch9.pth',
        map_location=device)  # remap GPU-saved tensors when running on CPU
    model_dict = model.state_dict()
    model_dict.update({k: v for k, v in pretrained_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)

    # train
    for epoch in range(1, args.epochs + 1):
        print('lr---------', scheduler.get_last_lr())
        train(args, epoch, model, train_loader, device, optimizer, criterion)
        # get_last_lr() is the lr actually in use; get_lr() is only meant to be
        # called inside step() and, for StepLR, returns lr * gamma at step boundaries.
        if scheduler.get_last_lr()[0] > 0.000000001:  # 0.0000001
            scheduler.step()
        else:
            print('lr----no--change--------')
        print('lr---------', scheduler.get_last_lr())
        val(args, model, val_loader, device, criterion, best_acc, epoch)

        result = evaluate_model(model, test_loader, device, criterion, epoch)
        #writer.add_scalar("Loss/train", result['test_acc'], epoch)
        run_name = ('__test_acc=%s' % result['test_acc']
                    + '__precision=%s' % result['precision']
                    + '__recall=%s' % result['recall']
                    + '__F1_measure=%s' % result['F1_measure']
                    + '_epoch=%s' % epoch + '_')
        if result['F1_measure'] > best_acc:
            best_acc = result['F1_measure']
            print('best testing-------------', run_name)
            print('test acc-------------', result['test_acc'])
            print('precision-----------', result['precision'])
            print('recall-----------', result['recall'])
            print('F1_measure-----------', result['F1_measure'])
        elif result['F1_measure'] > 0.905:
            print('current testing evaluation F1 beat 0.905-------------', run_name)
    #writer.flush()
