mirror of
https://github.com/xuebinqin/DIS.git
synced 2024-11-26 08:43:17 +01:00
gt.cuda() creating error for CPU. Change to .to(device)
This commit is contained in:
parent
a802503bee
commit
8836d5b4f4
@ -14,6 +14,8 @@ from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandom
|
||||
from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch,
|
||||
from models import *
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
|
||||
|
||||
# train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
|
||||
@ -240,7 +242,7 @@ def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
||||
|
||||
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
||||
with torch.no_grad():
|
||||
gt = torch.tensor(gt).cuda()
|
||||
gt = torch.tensor(gt).to(device)
|
||||
|
||||
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
|
||||
|
||||
@ -476,7 +478,7 @@ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
||||
|
||||
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
||||
with torch.no_grad():
|
||||
gt = torch.tensor(gt).cuda()
|
||||
gt = torch.tensor(gt).to(device)
|
||||
|
||||
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user