dataloader for inference w/o gt

This commit is contained in:
Xuebin Qin 2022-07-17 13:17:27 -07:00
parent c1480958ac
commit 7c61e1876d
4 changed files with 22 additions and 10 deletions

BIN
.DS_Store vendored

Binary file not shown.

View File

@ -274,7 +274,7 @@ class GOSDatasetCache(Dataset):
im_id = cached_dataset["im_name"][i]
print("im_path: ", im_path)
im = im_reader(im_path)
im, im_shp = im_preprocess(im,self.cache_size)
im_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_im.pt")
@ -285,12 +285,14 @@ class GOSDatasetCache(Dataset):
ims_pt_list.append(torch.unsqueeze(im,0))
# ims_list.append(im.cpu().data.numpy().astype(np.uint8))
gt = im_reader(self.dataset["gt_path"][i])
gt = np.zeros(im.shape[0:2])
if len(self.dataset["gt_path"])!=0:
gt = im_reader(self.dataset["gt_path"][i])
gt, gt_shp = gt_preprocess(gt,self.cache_size)
gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
torch.save(gt,gt_cache_file)
cached_dataset["gt_path"][i] = gt_cache_file
# cached_dataset["gt_path"][i] = gt_cache_file
cached_dataset["gt_path"].append(gt_cache_file)
if(self.cache_boost):
gts_pt_list.append(torch.unsqueeze(gt,0))
# gts_list.append(gt.cpu().data.numpy().astype(np.uint8))

View File

@ -474,7 +474,10 @@ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
mi = torch.min(pred_val)
pred_val = (pred_val-mi)/(ma-mi) # max = 1
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
if len(valid_dataset.dataset["ori_gt_path"]) != 0:
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
else:
gt = np.zeros((shapes_val[t][0],shapes_val[t][1]))
with torch.no_grad():
gt = torch.tensor(gt).cuda()
@ -648,10 +651,17 @@ if __name__ == "__main__":
"im_ext": ".jpg",
"gt_ext": ".png",
"cache_dir":"../DIS5K-Cache/DIS-TE4"}
### test your own dataset
dataset_demo = {"name": "your-dataset",
"im_dir": "../your-dataset/im",
"gt_dir": "",
"im_ext": ".jpg",
"gt_ext": "",
"cache_dir":"../your-dataset/cache"}
train_datasets = [dataset_tr] ## users can create mutiple dictionary for setting a list of datasets as training set
# valid_datasets = [dataset_vd] ## users can create mutiple dictionary for setting a list of datasets as vaidation sets or inference sets
valid_datasets = [dataset_vd] #, dataset_te1, dataset_te2, dataset_te3, dataset_te4] # and hypar["mode"] = "valid" for inference,
valid_datasets = [dataset_vd] # dataset_vd, dataset_te1, dataset_te2, dataset_te3, dataset_te4] # and hypar["mode"] = "valid" for inference,
### --------------- STEP 2: Configuring the hyperparamters for Training, validation and inferencing ---------------
hypar = {}
@ -662,7 +672,7 @@ if __name__ == "__main__":
## "valid": for validation and inferening,
## in "valid" mode, it will calculate the accuracy as well as save the prediciton results into the "hypar["valid_out_dir"]", which shouldn't be ""
## otherwise only accuracy will be calculated and no predictions will be saved
hypar["interm_sup"] = False ## indicate if activate intermediate feature supervision
hypar["interm_sup"] = False ## in-dicate if activate intermediate feature supervision
if hypar["mode"] == "train":
hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory
@ -671,9 +681,9 @@ if __name__ == "__main__":
hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process
hypar["gt_encoder_model"] = ""
else: ## configure the segmentation output path and the to-be-used model weights path
hypar["valid_out_dir"] = "../DIS5K-Results-test" ## output inferenced segmentation maps into this fold
hypar["model_path"] ="../saved_models/IS-Net" ## load trained weights from this path
hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
hypar["valid_out_dir"] = "../your-results/"##"../DIS5K-Results-test" ## output inferenced segmentation maps into this fold
hypar["model_path"] = "../saved_models/IS-Net" ## load trained weights from this path
hypar["restore_model"] = "isnet.pth"##"isnet.pth" ## name of the to-be-loaded weights
# if hypar["restore_model"]!="":
# hypar["start_ite"] = int(hypar["restore_model"].split("_")[2])

BIN
saved_models/.DS_Store vendored

Binary file not shown.