compute_loss

This commit is contained in:
Xuebin Qin 2022-07-17 11:00:12 -07:00
parent a802503bee
commit c1480958ac

View File

@ -350,7 +350,7 @@ def train(net, optimizer, train_dataloaders, train_datasets, valid_dataloaders,
else: else:
# forward + backward + optimize # forward + backward + optimize
ds,_ = net(inputs_v) ds,_ = net(inputs_v)
loss2, loss = muti_loss_fusion(ds, labels_v) loss2, loss = net.compute_loss(ds, labels_v)
loss.backward() loss.backward()
optimizer.step() optimizer.step()