mirror of
https://github.com/xuebinqin/DIS.git
synced 2024-11-26 08:43:17 +01:00
correct the error that F1 always equal to 0.
This commit is contained in:
parent
7df2a36362
commit
10657f42e5
@ -243,6 +243,8 @@ def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
|||||||
# pred_val = normPRED(pred_val)
|
# pred_val = normPRED(pred_val)
|
||||||
|
|
||||||
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
||||||
|
if gt.max()==1:
|
||||||
|
gt=gt*255
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
gt = torch.tensor(gt).to(device)
|
gt = torch.tensor(gt).to(device)
|
||||||
|
|
||||||
@ -480,6 +482,8 @@ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
|||||||
|
|
||||||
if len(valid_dataset.dataset["ori_gt_path"]) != 0:
|
if len(valid_dataset.dataset["ori_gt_path"]) != 0:
|
||||||
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
||||||
|
if gt.max()==1:
|
||||||
|
gt=gt*255
|
||||||
else:
|
else:
|
||||||
gt = np.zeros((shapes_val[t][0],shapes_val[t][1]))
|
gt = np.zeros((shapes_val[t][0],shapes_val[t][1]))
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -724,4 +728,4 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
main(train_datasets,
|
main(train_datasets,
|
||||||
valid_datasets,
|
valid_datasets,
|
||||||
hypar=hypar)
|
hypar=hypar)
|
Loading…
Reference in New Issue
Block a user