mirror of
https://github.com/xuebinqin/DIS.git
synced 2024-11-29 18:23:08 +01:00
correct the error that F1 always equal to 0.
This commit is contained in:
parent
7df2a36362
commit
10657f42e5
@ -243,6 +243,8 @@ def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
|||||||
# pred_val = normPRED(pred_val)
|
# pred_val = normPRED(pred_val)
|
||||||
|
|
||||||
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
||||||
|
if gt.max()==1:
|
||||||
|
gt=gt*255
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
gt = torch.tensor(gt).to(device)
|
gt = torch.tensor(gt).to(device)
|
||||||
|
|
||||||
@ -480,6 +482,8 @@ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
|||||||
|
|
||||||
if len(valid_dataset.dataset["ori_gt_path"]) != 0:
|
if len(valid_dataset.dataset["ori_gt_path"]) != 0:
|
||||||
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
||||||
|
if gt.max()==1:
|
||||||
|
gt=gt*255
|
||||||
else:
|
else:
|
||||||
gt = np.zeros((shapes_val[t][0],shapes_val[t][1]))
|
gt = np.zeros((shapes_val[t][0],shapes_val[t][1]))
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
Loading…
Reference in New Issue
Block a user