mirror of
https://github.com/xuebinqin/DIS.git
synced 2024-11-26 08:43:17 +01:00
update data_loader_cache.py
This commit is contained in:
parent
362f4e9699
commit
6e807f8d10
@ -34,8 +34,10 @@ def get_im_gt_name_dict(datasets, flag='valid'):
|
|||||||
|
|
||||||
if(datasets[i]["gt_dir"]==""):
|
if(datasets[i]["gt_dir"]==""):
|
||||||
print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found')
|
print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found')
|
||||||
|
tmp_gt_list = []
|
||||||
else:
|
else:
|
||||||
tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list]
|
tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list]
|
||||||
|
|
||||||
# lbl_name_dict[im_dirs[i][0]] = tmp_gt_list
|
# lbl_name_dict[im_dirs[i][0]] = tmp_gt_list
|
||||||
print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list))
|
print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list))
|
||||||
|
|
||||||
@ -105,7 +107,7 @@ def im_preprocess(im,size):
|
|||||||
im = im[:, :, np.newaxis]
|
im = im[:, :, np.newaxis]
|
||||||
if im.shape[2] == 1:
|
if im.shape[2] == 1:
|
||||||
im = np.repeat(im, 3, axis=2)
|
im = np.repeat(im, 3, axis=2)
|
||||||
im_tensor = torch.tensor(im, dtype=torch.float32)
|
im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
|
||||||
im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
|
im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
|
||||||
if(len(size)<2):
|
if(len(size)<2):
|
||||||
return im_tensor, im.shape[0:2]
|
return im_tensor, im.shape[0:2]
|
||||||
@ -256,7 +258,7 @@ class GOSDatasetCache(Dataset):
|
|||||||
|
|
||||||
def manage_cache(self,dataset_names):
|
def manage_cache(self,dataset_names):
|
||||||
if not os.path.exists(self.cache_path): # create the folder for cache
|
if not os.path.exists(self.cache_path): # create the folder for cache
|
||||||
os.mkdir(self.cache_path)
|
os.makedirs(self.cache_path)
|
||||||
cache_folder = os.path.join(self.cache_path, "_".join(dataset_names)+"_"+"x".join([str(x) for x in self.cache_size]))
|
cache_folder = os.path.join(self.cache_path, "_".join(dataset_names)+"_"+"x".join([str(x) for x in self.cache_size]))
|
||||||
if not os.path.exists(cache_folder): # check if the cache files are there, if not then cache
|
if not os.path.exists(cache_folder): # check if the cache files are there, if not then cache
|
||||||
return self.cache(cache_folder)
|
return self.cache(cache_folder)
|
||||||
@ -272,7 +274,6 @@ class GOSDatasetCache(Dataset):
|
|||||||
gts_pt_list = []
|
gts_pt_list = []
|
||||||
for i, im_path in tqdm(enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])):
|
for i, im_path in tqdm(enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])):
|
||||||
|
|
||||||
|
|
||||||
im_id = cached_dataset["im_name"][i]
|
im_id = cached_dataset["im_name"][i]
|
||||||
print("im_path: ", im_path)
|
print("im_path: ", im_path)
|
||||||
im = im_reader(im_path)
|
im = im_reader(im_path)
|
||||||
@ -291,8 +292,10 @@ class GOSDatasetCache(Dataset):
|
|||||||
gt, gt_shp = gt_preprocess(gt,self.cache_size)
|
gt, gt_shp = gt_preprocess(gt,self.cache_size)
|
||||||
gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
|
gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
|
||||||
torch.save(gt,gt_cache_file)
|
torch.save(gt,gt_cache_file)
|
||||||
cached_dataset["gt_path"][i] = gt_cache_file
|
if len(self.dataset["gt_path"])>0:
|
||||||
#cached_dataset["gt_path"].append(gt_cache_file)
|
cached_dataset["gt_path"][i] = gt_cache_file
|
||||||
|
else:
|
||||||
|
cached_dataset["gt_path"].append(gt_cache_file)
|
||||||
if(self.cache_boost):
|
if(self.cache_boost):
|
||||||
gts_pt_list.append(torch.unsqueeze(gt,0))
|
gts_pt_list.append(torch.unsqueeze(gt,0))
|
||||||
# gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
|
# gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
|
||||||
|
Loading…
Reference in New Issue
Block a user