From 8836d5b4f47ae74b4063befae685bd22a1131d48 Mon Sep 17 00:00:00 2001 From: Mahesh Deshwal Date: Sun, 17 Jul 2022 14:52:08 +0530 Subject: [PATCH] gt.cuda() creating error for CPU. Change to .to(device) --- IS-Net/train_valid_inference_main.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/IS-Net/train_valid_inference_main.py b/IS-Net/train_valid_inference_main.py index a377597..54648d6 100644 --- a/IS-Net/train_valid_inference_main.py +++ b/IS-Net/train_valid_inference_main.py @@ -14,6 +14,8 @@ from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandom from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch, from models import * +device = 'cuda' if torch.cuda.is_available() else 'cpu' + def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000): # train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list, @@ -240,7 +242,7 @@ def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0): gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255 with torch.no_grad(): - gt = torch.tensor(gt).cuda() + gt = torch.tensor(gt).to(device) pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar) @@ -476,7 +478,7 @@ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0): gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255 with torch.no_grad(): - gt = torch.tensor(gt).cuda() + gt = torch.tensor(gt).to(device) pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)