mirror of
https://github.com/xuebinqin/DIS.git
synced 2024-11-26 08:43:17 +01:00
dataloader for inference w/o gt
This commit is contained in:
parent
c1480958ac
commit
7c61e1876d
@ -274,7 +274,7 @@ class GOSDatasetCache(Dataset):
|
||||
|
||||
|
||||
im_id = cached_dataset["im_name"][i]
|
||||
|
||||
print("im_path: ", im_path)
|
||||
im = im_reader(im_path)
|
||||
im, im_shp = im_preprocess(im,self.cache_size)
|
||||
im_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_im.pt")
|
||||
@ -285,12 +285,14 @@ class GOSDatasetCache(Dataset):
|
||||
ims_pt_list.append(torch.unsqueeze(im,0))
|
||||
# ims_list.append(im.cpu().data.numpy().astype(np.uint8))
|
||||
|
||||
|
||||
gt = im_reader(self.dataset["gt_path"][i])
|
||||
gt = np.zeros(im.shape[0:2])
|
||||
if len(self.dataset["gt_path"])!=0:
|
||||
gt = im_reader(self.dataset["gt_path"][i])
|
||||
gt, gt_shp = gt_preprocess(gt,self.cache_size)
|
||||
gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
|
||||
torch.save(gt,gt_cache_file)
|
||||
cached_dataset["gt_path"][i] = gt_cache_file
|
||||
# cached_dataset["gt_path"][i] = gt_cache_file
|
||||
cached_dataset["gt_path"].append(gt_cache_file)
|
||||
if(self.cache_boost):
|
||||
gts_pt_list.append(torch.unsqueeze(gt,0))
|
||||
# gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
|
||||
|
@ -474,7 +474,10 @@ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
||||
mi = torch.min(pred_val)
|
||||
pred_val = (pred_val-mi)/(ma-mi) # max = 1
|
||||
|
||||
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
||||
if len(valid_dataset.dataset["ori_gt_path"]) != 0:
|
||||
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
||||
else:
|
||||
gt = np.zeros((shapes_val[t][0],shapes_val[t][1]))
|
||||
with torch.no_grad():
|
||||
gt = torch.tensor(gt).cuda()
|
||||
|
||||
@ -648,10 +651,17 @@ if __name__ == "__main__":
|
||||
"im_ext": ".jpg",
|
||||
"gt_ext": ".png",
|
||||
"cache_dir":"../DIS5K-Cache/DIS-TE4"}
|
||||
### test your own dataset
|
||||
dataset_demo = {"name": "your-dataset",
|
||||
"im_dir": "../your-dataset/im",
|
||||
"gt_dir": "",
|
||||
"im_ext": ".jpg",
|
||||
"gt_ext": "",
|
||||
"cache_dir":"../your-dataset/cache"}
|
||||
|
||||
train_datasets = [dataset_tr] ## users can create mutiple dictionary for setting a list of datasets as training set
|
||||
# valid_datasets = [dataset_vd] ## users can create mutiple dictionary for setting a list of datasets as vaidation sets or inference sets
|
||||
valid_datasets = [dataset_vd] #, dataset_te1, dataset_te2, dataset_te3, dataset_te4] # and hypar["mode"] = "valid" for inference,
|
||||
valid_datasets = [dataset_vd] # dataset_vd, dataset_te1, dataset_te2, dataset_te3, dataset_te4] # and hypar["mode"] = "valid" for inference,
|
||||
|
||||
### --------------- STEP 2: Configuring the hyperparamters for Training, validation and inferencing ---------------
|
||||
hypar = {}
|
||||
@ -662,7 +672,7 @@ if __name__ == "__main__":
|
||||
## "valid": for validation and inferening,
|
||||
## in "valid" mode, it will calculate the accuracy as well as save the prediciton results into the "hypar["valid_out_dir"]", which shouldn't be ""
|
||||
## otherwise only accuracy will be calculated and no predictions will be saved
|
||||
hypar["interm_sup"] = False ## indicate if activate intermediate feature supervision
|
||||
hypar["interm_sup"] = False ## in-dicate if activate intermediate feature supervision
|
||||
|
||||
if hypar["mode"] == "train":
|
||||
hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory
|
||||
@ -671,9 +681,9 @@ if __name__ == "__main__":
|
||||
hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process
|
||||
hypar["gt_encoder_model"] = ""
|
||||
else: ## configure the segmentation output path and the to-be-used model weights path
|
||||
hypar["valid_out_dir"] = "../DIS5K-Results-test" ## output inferenced segmentation maps into this fold
|
||||
hypar["model_path"] ="../saved_models/IS-Net" ## load trained weights from this path
|
||||
hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
|
||||
hypar["valid_out_dir"] = "../your-results/"##"../DIS5K-Results-test" ## output inferenced segmentation maps into this fold
|
||||
hypar["model_path"] = "../saved_models/IS-Net" ## load trained weights from this path
|
||||
hypar["restore_model"] = "isnet.pth"##"isnet.pth" ## name of the to-be-loaded weights
|
||||
|
||||
# if hypar["restore_model"]!="":
|
||||
# hypar["start_ite"] = int(hypar["restore_model"].split("_")[2])
|
||||
|
BIN
saved_models/.DS_Store
vendored
BIN
saved_models/.DS_Store
vendored
Binary file not shown.
Loading…
Reference in New Issue
Block a user