gt.cuda() creating error for CPU. Change to .to(device)

This commit is contained in:
Mahesh Deshwal 2022-07-17 14:52:08 +05:30
parent a802503bee
commit 8836d5b4f4

View File

@ -14,6 +14,8 @@ from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandom
from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch, from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch,
from models import * from models import *
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000): def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
# train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list, # train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
@ -240,7 +242,7 @@ def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255 gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
with torch.no_grad(): with torch.no_grad():
gt = torch.tensor(gt).cuda() gt = torch.tensor(gt).to(device)
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar) pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
@ -476,7 +478,7 @@ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255 gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
with torch.no_grad(): with torch.no_grad():
gt = torch.tensor(gt).cuda() gt = torch.tensor(gt).to(device)
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar) pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)