mirror of
https://github.com/xuebinqin/DIS.git
synced 2024-11-26 08:43:17 +01:00
correct the error that F1 always equal to 0.
This commit is contained in:
parent
7df2a36362
commit
10657f42e5
@ -243,6 +243,8 @@ def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
||||
# pred_val = normPRED(pred_val)
|
||||
|
||||
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
||||
if gt.max()==1:
|
||||
gt=gt*255
|
||||
with torch.no_grad():
|
||||
gt = torch.tensor(gt).to(device)
|
||||
|
||||
@ -480,6 +482,8 @@ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
||||
|
||||
if len(valid_dataset.dataset["ori_gt_path"]) != 0:
|
||||
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
||||
if gt.max()==1:
|
||||
gt=gt*255
|
||||
else:
|
||||
gt = np.zeros((shapes_val[t][0],shapes_val[t][1]))
|
||||
with torch.no_grad():
|
||||
@ -724,4 +728,4 @@ if __name__ == "__main__":
|
||||
|
||||
main(train_datasets,
|
||||
valid_datasets,
|
||||
hypar=hypar)
|
||||
hypar=hypar)
|
Loading…
Reference in New Issue
Block a user