diff --git a/IS-Net/train_valid_inference_main.py b/IS-Net/train_valid_inference_main.py index a377597..54648d6 100644 --- a/IS-Net/train_valid_inference_main.py +++ b/IS-Net/train_valid_inference_main.py @@ -14,6 +14,8 @@ from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandom from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch, from models import * +device = 'cuda' if torch.cuda.is_available() else 'cpu' + def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000): # train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list, @@ -240,7 +242,7 @@ def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0): gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255 with torch.no_grad(): - gt = torch.tensor(gt).cuda() + gt = torch.tensor(gt).to(device) pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar) @@ -476,7 +478,7 @@ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0): gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255 with torch.no_grad(): - gt = torch.tensor(gt).cuda() + gt = torch.tensor(gt).to(device) pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)