From 6e807f8d109fb9477faf2ead72ffeedbf65bca20 Mon Sep 17 00:00:00 2001
From: PiggyJerry <1261686183@qq.com>
Date: Wed, 27 Jul 2022 23:42:02 +0400
Subject: [PATCH] update data_loader_cache.py

---
 IS-Net/data_loader_cache.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/IS-Net/data_loader_cache.py b/IS-Net/data_loader_cache.py
index e7e3a13..6ba9dd5 100644
--- a/IS-Net/data_loader_cache.py
+++ b/IS-Net/data_loader_cache.py
@@ -34,8 +34,10 @@ def get_im_gt_name_dict(datasets, flag='valid'):
 
         if(datasets[i]["gt_dir"]==""):
             print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found')
+            tmp_gt_list = []
         else:
             tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list]
+
             # lbl_name_dict[im_dirs[i][0]] = tmp_gt_list
             print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list))
 
@@ -105,7 +107,7 @@ def im_preprocess(im,size):
         im = im[:, :, np.newaxis]
     if im.shape[2] == 1:
         im = np.repeat(im, 3, axis=2)
-    im_tensor = torch.tensor(im, dtype=torch.float32)
+    im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
     im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
     if(len(size)<2):
         return im_tensor, im.shape[0:2]
@@ -256,7 +258,7 @@ class GOSDatasetCache(Dataset):
 
     def manage_cache(self,dataset_names):
         if not os.path.exists(self.cache_path): # create the folder for cache
-            os.mkdir(self.cache_path)
+            os.makedirs(self.cache_path)
         cache_folder = os.path.join(self.cache_path, "_".join(dataset_names)+"_"+"x".join([str(x) for x in self.cache_size]))
         if not os.path.exists(cache_folder): # check if the cache files are there, if not then cache
             return self.cache(cache_folder)
@@ -272,7 +274,6 @@ class GOSDatasetCache(Dataset):
         gts_pt_list = []
         for i, im_path in tqdm(enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])):
 
-            im_id = cached_dataset["im_name"][i]
             print("im_path: ", im_path)
             im = im_reader(im_path)
 
@@ -291,8 +292,10 @@ class GOSDatasetCache(Dataset):
             gt, gt_shp = gt_preprocess(gt,self.cache_size)
             gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
             torch.save(gt,gt_cache_file)
-            cached_dataset["gt_path"][i] = gt_cache_file
-            #cached_dataset["gt_path"].append(gt_cache_file)
+            if len(self.dataset["gt_path"])>0:
+                cached_dataset["gt_path"][i] = gt_cache_file
+            else:
+                cached_dataset["gt_path"].append(gt_cache_file)
             if(self.cache_boost):
                 gts_pt_list.append(torch.unsqueeze(gt,0))
                 # gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
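
For reviewers, a minimal illustrative sketch of the guarded gt_path update that the last hunk introduces (the helper name and dictionaries below are stand-ins, not code from the repository): when a dataset is loaded without a ground-truth directory, dataset["gt_path"] stays an empty list, so the old indexed assignment cached_dataset["gt_path"][i] = ... would raise an IndexError; the patched branch appends in that case and only overwrites in place when an entry already exists.

    # Illustrative sketch only; the dicts below are stand-ins for the
    # GOSDatasetCache state, not the actual class attributes.
    from copy import deepcopy

    def record_gt_cache_path(dataset, cached_dataset, i, gt_cache_file):
        # "gt_path" is empty when no ground-truth directory was given, so
        # writing cached_dataset["gt_path"][i] would raise IndexError;
        # append instead, mirroring the patched branch.
        if len(dataset["gt_path"]) > 0:
            cached_dataset["gt_path"][i] = gt_cache_file
        else:
            cached_dataset["gt_path"].append(gt_cache_file)

    dataset = {"gt_path": []}          # e.g. an inference-only dataset without labels
    cached_dataset = deepcopy(dataset)
    record_gt_cache_path(dataset, cached_dataset, 0, "cache/demo_0_gt.pt")
    print(cached_dataset["gt_path"])   # ['cache/demo_0_gt.pt']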