Mirror of https://github.com/xuebinqin/DIS.git (synced 2024-11-29 18:23:08 +01:00)

commit 7c61e1876d: dataloader for inference w/o gt
parent: c1480958ac
@@ -274,7 +274,7 @@ class GOSDatasetCache(Dataset):


             im_id = cached_dataset["im_name"][i]
-
+            print("im_path: ", im_path)
             im = im_reader(im_path)
             im, im_shp = im_preprocess(im,self.cache_size)
             im_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_im.pt")
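Note: each iteration of this loop reads one image, preprocesses it to self.cache_size, and serializes the resulting tensor to a per-image .pt file. A minimal sketch of that step as a stand-alone helper; the function and its reader/preprocess parameters are hypothetical stand-ins for the repo's im_reader/im_preprocess:

import os
import torch

def cache_image(im_path, cache_folder, data_name, im_id, reader, preprocess, cache_size):
    # One iteration of the caching loop: read, resize/normalize, save tensor.
    im = reader(im_path)
    im, im_shp = preprocess(im, cache_size)  # tensor plus original (h, w)
    im_cache_file = os.path.join(cache_folder, data_name + "_" + im_id + "_im.pt")
    torch.save(im, im_cache_file)
    return im_cache_file, im_shp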
@@ -285,12 +285,14 @@ class GOSDatasetCache(Dataset):
                 ims_pt_list.append(torch.unsqueeze(im,0))
             # ims_list.append(im.cpu().data.numpy().astype(np.uint8))

-
-            gt = im_reader(self.dataset["gt_path"][i])
+            gt = np.zeros(im.shape[0:2])
+            if len(self.dataset["gt_path"])!=0:
+                gt = im_reader(self.dataset["gt_path"][i])
             gt, gt_shp = gt_preprocess(gt,self.cache_size)
             gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
             torch.save(gt,gt_cache_file)
-            cached_dataset["gt_path"][i] = gt_cache_file
+            # cached_dataset["gt_path"][i] = gt_cache_file
+            cached_dataset["gt_path"].append(gt_cache_file)
             if(self.cache_boost):
                 gts_pt_list.append(torch.unsqueeze(gt,0))
             # gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
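This hunk is the core of the commit: when dataset["gt_path"] is empty, an all-zero mask of the image's spatial size stands in for the ground truth, and the cached gt path is appended (the list starts empty for gt-free datasets) instead of assigned by index. A hedged sketch of the same fallback with a hypothetical helper name:

import numpy as np

def load_gt_or_dummy(im, gt_paths, i, reader):
    # With no ground-truth paths, use an all-zero H x W mask so the rest
    # of the caching pipeline stays shape-consistent.
    gt = np.zeros(im.shape[0:2])   # dummy mask
    if len(gt_paths) != 0:         # real mask available for index i
        gt = reader(gt_paths[i])
    return gt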
@@ -474,7 +474,10 @@ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
             mi = torch.min(pred_val)
             pred_val = (pred_val-mi)/(ma-mi) # max = 1

-            gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
+            if len(valid_dataset.dataset["ori_gt_path"]) != 0:
+                gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
+            else:
+                gt = np.zeros((shapes_val[t][0],shapes_val[t][1]))
             with torch.no_grad():
                 gt = torch.tensor(gt).cuda()

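In valid(), each predicted map is min-max normalized to [0,1] before being compared against the (possibly all-zero dummy) ground truth; for gt-free inference the resulting "accuracy" is a placeholder and only the saved maps are meaningful. For reference, a stand-alone sketch of the normalization step; the 1e-8 guard against a constant map is an assumption added here, not in the repo:

import torch

def minmax_norm(pred: torch.Tensor) -> torch.Tensor:
    # Rescale predictions to [0, 1], mirroring (pred - mi) / (ma - mi).
    ma, mi = torch.max(pred), torch.min(pred)
    return (pred - mi) / (ma - mi + 1e-8)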
@@ -648,10 +651,17 @@ if __name__ == "__main__":
                   "im_ext": ".jpg",
                   "gt_ext": ".png",
                   "cache_dir":"../DIS5K-Cache/DIS-TE4"}
+    ### test your own dataset
+    dataset_demo = {"name": "your-dataset",
+                 "im_dir": "../your-dataset/im",
+                 "gt_dir": "",
+                 "im_ext": ".jpg",
+                 "gt_ext": "",
+                 "cache_dir":"../your-dataset/cache"}

     train_datasets = [dataset_tr] ## users can create mutiple dictionary for setting a list of datasets as training set
     # valid_datasets = [dataset_vd] ## users can create mutiple dictionary for setting a list of datasets as vaidation sets or inference sets
-    valid_datasets = [dataset_vd] #, dataset_te1, dataset_te2, dataset_te3, dataset_te4] # and hypar["mode"] = "valid" for inference,
+    valid_datasets = [dataset_vd] # dataset_vd, dataset_te1, dataset_te2, dataset_te3, dataset_te4] # and hypar["mode"] = "valid" for inference,

     ### --------------- STEP 2: Configuring the hyperparamters for Training, validation and inferencing ---------------
     hypar = {}
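Put together, the added dataset_demo shows the intended gt-free configuration: empty "gt_dir"/"gt_ext" route caching through the dummy-mask path above. As plain Python, with the placeholder paths from the commit:

dataset_demo = {"name": "your-dataset",
                "im_dir": "../your-dataset/im",
                "gt_dir": "",
                "im_ext": ".jpg",
                "gt_ext": "",
                "cache_dir": "../your-dataset/cache"}
valid_datasets = [dataset_demo]  # pair with hypar["mode"] = "valid" for inference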
@@ -662,7 +672,7 @@ if __name__ == "__main__":
     ## "valid": for validation and inferening,
     ## in "valid" mode, it will calculate the accuracy as well as save the prediciton results into the "hypar["valid_out_dir"]", which shouldn't be ""
     ## otherwise only accuracy will be calculated and no predictions will be saved
-    hypar["interm_sup"] = False ## indicate if activate intermediate feature supervision
+    hypar["interm_sup"] = False ## in-dicate if activate intermediate feature supervision

     if hypar["mode"] == "train":
         hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory
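A hedged sketch of the valid_out_dir contract described in the comments above: accuracy is always computed, but prediction maps are written only when the directory is non-empty. The function name and arguments are hypothetical, not from the repo:

import os
import numpy as np
from skimage import io

def maybe_save_prediction(valid_out_dir, im_name, pred01):
    # pred01 is a [0, 1] float mask; skip writing when no output dir is set.
    if valid_out_dir != "":
        os.makedirs(valid_out_dir, exist_ok=True)
        io.imsave(os.path.join(valid_out_dir, im_name + ".png"),
                  (pred01 * 255).astype(np.uint8))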
@@ -671,9 +681,9 @@ if __name__ == "__main__":
         hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process
         hypar["gt_encoder_model"] = ""
     else: ## configure the segmentation output path and the to-be-used model weights path
-        hypar["valid_out_dir"] = "../DIS5K-Results-test" ## output inferenced segmentation maps into this fold
-        hypar["model_path"] ="../saved_models/IS-Net" ## load trained weights from this path
-        hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
+        hypar["valid_out_dir"] = "../your-results/"##"../DIS5K-Results-test" ## output inferenced segmentation maps into this fold
+        hypar["model_path"] = "../saved_models/IS-Net" ## load trained weights from this path
+        hypar["restore_model"] = "isnet.pth"##"isnet.pth" ## name of the to-be-loaded weights

     # if hypar["restore_model"]!="":
     #     hypar["start_ite"] = int(hypar["restore_model"].split("_")[2])
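For orientation, a sketch of how the inference-side settings above combine: model_path names the folder holding trained weights and restore_model the file inside it; joining them this way is an assumption matching that convention, not code from the commit:

import os

hypar = {"valid_out_dir": "../your-results/",     # where segmentation maps are written
         "model_path": "../saved_models/IS-Net",  # folder holding trained weights
         "restore_model": "isnet.pth"}            # weight file to load
weights_file = os.path.join(hypar["model_path"], hypar["restore_model"])
print(weights_file)  # ../saved_models/IS-Net/isnet.pth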
BIN saved_models/.DS_Store (vendored): binary file not shown.