diff --git a/IS-Net/train_valid_inference_main.py b/IS-Net/train_valid_inference_main.py index 4448809..375bb26 100644 --- a/IS-Net/train_valid_inference_main.py +++ b/IS-Net/train_valid_inference_main.py @@ -243,6 +243,8 @@ def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0): # pred_val = normPRED(pred_val) gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255 + if gt.max()==1: + gt=gt*255 with torch.no_grad(): gt = torch.tensor(gt).to(device) @@ -480,6 +482,8 @@ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0): if len(valid_dataset.dataset["ori_gt_path"]) != 0: gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255 + if gt.max()==1: + gt=gt*255 else: gt = np.zeros((shapes_val[t][0],shapes_val[t][1])) with torch.no_grad(): @@ -724,4 +728,4 @@ if __name__ == "__main__": main(train_datasets, valid_datasets, - hypar=hypar) + hypar=hypar) \ No newline at end of file