Skip to content

Commit

Permalink
fix loss logging issue
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed May 23, 2023
1 parent a993949 commit 4a85b17
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions lavis/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,12 @@ def build_datasets(self, cfg):
return datasets

def train_step(self, model, samples):
loss_dict = model(samples)
loss = loss_dict["loss"]
return loss, loss_dict
output = model(samples)
loss_dict = {}
for k,v in output.items():
if "loss" in k:
loss_dict[k] = v
return output["loss"], loss_dict

def valid_step(self, model, samples):
raise NotImplementedError
Expand Down

0 comments on commit 4a85b17

Please sign in to comment.