Update data_loader_cache.py

temporarily solution for training bugs
This commit is contained in:
Xuebin Qin 2022-07-26 19:55:35 -07:00 committed by GitHub
parent 85a9a1be60
commit 362f4e9699
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -291,8 +291,8 @@ class GOSDatasetCache(Dataset):
gt, gt_shp = gt_preprocess(gt,self.cache_size) gt, gt_shp = gt_preprocess(gt,self.cache_size)
gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt") gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
torch.save(gt,gt_cache_file) torch.save(gt,gt_cache_file)
# cached_dataset["gt_path"][i] = gt_cache_file cached_dataset["gt_path"][i] = gt_cache_file
cached_dataset["gt_path"].append(gt_cache_file) #cached_dataset["gt_path"].append(gt_cache_file)
if(self.cache_boost): if(self.cache_boost):
gts_pt_list.append(torch.unsqueeze(gt,0)) gts_pt_list.append(torch.unsqueeze(gt,0))
# gts_list.append(gt.cpu().data.numpy().astype(np.uint8)) # gts_list.append(gt.cpu().data.numpy().astype(np.uint8))