correct the error that F1 always equal to 0.

This commit is contained in:
PiggyJerry 2022-11-04 12:10:53 +04:00
parent 7df2a36362
commit 10657f42e5

View File

@ -243,6 +243,8 @@ def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
# pred_val = normPRED(pred_val) # pred_val = normPRED(pred_val)
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255 gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
if gt.max()==1:
gt=gt*255
with torch.no_grad(): with torch.no_grad():
gt = torch.tensor(gt).to(device) gt = torch.tensor(gt).to(device)
@ -480,6 +482,8 @@ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
if len(valid_dataset.dataset["ori_gt_path"]) != 0: if len(valid_dataset.dataset["ori_gt_path"]) != 0:
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255 gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
if gt.max()==1:
gt=gt*255
else: else:
gt = np.zeros((shapes_val[t][0],shapes_val[t][1])) gt = np.zeros((shapes_val[t][0],shapes_val[t][1]))
with torch.no_grad(): with torch.no_grad():
@ -724,4 +728,4 @@ if __name__ == "__main__":
main(train_datasets, main(train_datasets,
valid_datasets, valid_datasets,
hypar=hypar) hypar=hypar)