mirror of
https://github.com/xuebinqin/DIS.git
synced 2024-12-02 03:03:56 +01:00
compute_loss
This commit is contained in:
parent
a802503bee
commit
c1480958ac
@ -350,7 +350,7 @@ def train(net, optimizer, train_dataloaders, train_datasets, valid_dataloaders,
|
|||||||
else:
|
else:
|
||||||
# forward + backward + optimize
|
# forward + backward + optimize
|
||||||
ds,_ = net(inputs_v)
|
ds,_ = net(inputs_v)
|
||||||
loss2, loss = muti_loss_fusion(ds, labels_v)
|
loss2, loss = net.compute_loss(ds, labels_v)
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
Loading…
Reference in New Issue
Block a user