Merge pull request #6 from deshwalmahesh/main

gt.cuda() creating error for CPU. Change to .to(device)
This commit is contained in:
Xuebin Qin 2022-07-30 09:27:04 -07:00 committed by GitHub
commit 7bf8a3be2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 395 additions and 2 deletions

391
Colab_Demo.ipynb Normal file

File diff suppressed because one or more lines are too long

View File

@ -14,6 +14,8 @@ from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandom
from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch,
from models import *
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
# train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
@ -240,7 +242,7 @@ def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
with torch.no_grad():
gt = torch.tensor(gt).cuda()
gt = torch.tensor(gt).to(device)
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
@ -479,7 +481,7 @@ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
else:
gt = np.zeros((shapes_val[t][0],shapes_val[t][1]))
with torch.no_grad():
gt = torch.tensor(gt).cuda()
gt = torch.tensor(gt).to(device)
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)