mirror of
https://github.com/xuebinqin/DIS.git
synced 2024-11-26 08:43:17 +01:00
Update data_loader_cache.py
temporarily solution for training bugs
This commit is contained in:
parent
85a9a1be60
commit
362f4e9699
@ -291,8 +291,8 @@ class GOSDatasetCache(Dataset):
|
|||||||
gt, gt_shp = gt_preprocess(gt,self.cache_size)
|
gt, gt_shp = gt_preprocess(gt,self.cache_size)
|
||||||
gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
|
gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
|
||||||
torch.save(gt,gt_cache_file)
|
torch.save(gt,gt_cache_file)
|
||||||
# cached_dataset["gt_path"][i] = gt_cache_file
|
cached_dataset["gt_path"][i] = gt_cache_file
|
||||||
cached_dataset["gt_path"].append(gt_cache_file)
|
#cached_dataset["gt_path"].append(gt_cache_file)
|
||||||
if(self.cache_boost):
|
if(self.cache_boost):
|
||||||
gts_pt_list.append(torch.unsqueeze(gt,0))
|
gts_pt_list.append(torch.unsqueeze(gt,0))
|
||||||
# gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
|
# gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
|
||||||
|
Loading…
Reference in New Issue
Block a user