
rework metrics logic to support states #2391

Merged

Conversation

tRosenflanz
Contributor

@tRosenflanz tRosenflanz commented May 21, 2024

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Fixes #2390, fixes #2389

Summary

Instead of calling forward/log on the metrics at each step, the module now only updates the metric states at every step and performs the compute/log/reset once at the end of the epoch.
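
A minimal sketch of this pattern with torchmetrics in a PyTorch Lightning module (class, layer sizes, and hooks here are illustrative assumptions, not the actual darts code):

import torch
import torchmetrics
import pytorch_lightning as pl


class SketchModule(pl.LightningModule):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)
        self.train_metrics = torchmetrics.MetricCollection(
            [torchmetrics.MeanAbsoluteError()], prefix="train_"
        )

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        inputs, target = batch
        preds = self(inputs)
        loss = torch.nn.functional.mse_loss(preds, target)
        # per step: only update the metric states, no compute/log here
        if len(self.train_metrics) > 0:
            self.train_metrics.update(preds, target)
        return loss

    def on_train_epoch_end(self):
        # per epoch: compute, log, and reset the accumulated states once
        if len(self.train_metrics) > 0:
            self.log_dict(self.train_metrics.compute(), logger=True)
            self.train_metrics.reset()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)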


codecov bot commented May 22, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 93.74%. Comparing base (000d29d) to head (43a4f24).

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2391      +/-   ##
==========================================
- Coverage   93.75%   93.74%   -0.01%     
==========================================
  Files         138      138              
  Lines       14343    14338       -5     
==========================================
- Hits        13447    13441       -6     
- Misses        896      897       +1     

☔ View full report in Codecov by Sentry.


    def update(self, preds, target) -> None:
        if preds.shape != target.shape:
            raise ValueError("preds and target must have the same shape")


Suggested change
-            raise ValueError("preds and target must have the same shape")
+            raise ValueError(f"preds and target must have the same shape, but got {preds.shape} for preds and {target.shape} for target.")

Contributor Author

Thanks, updated. Note that this is just a test class; this error isn't raised in the training loop itself.

Collaborator

@dennisbader dennisbader left a comment

Thanks for this great PR @tRosenflanz 🚀

It looks really good, just had some minor suggestions.
After those have been addressed, we can merge :)

CHANGELOG.md (outdated, resolved)
        return loss

    def _compute_metrics(self, metrics):
        res = metrics.compute()
Collaborator

Can you move this method to where _update_metrics() is defined?

Also, we can skip early when there are no metrics, as done in _update_metrics():

Suggested change
-        res = metrics.compute()
+        if not len(metrics):
+            return
+        res = metrics.compute()

Contributor Author

Moved and "skipped" when needed. For some reason the linter doesn't add an empty line by default, by the way.
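
For reference, a sketch of what the resulting method could look like after applying the suggestion (the logging and reset details are assumptions based on the PR summary, not the exact darts code):

    def _compute_metrics(self, metrics):
        # skip entirely if no metrics were configured, mirroring _update_metrics()
        if not len(metrics):
            return
        # compute over the accumulated states, log once, then reset for the next epoch
        res = metrics.compute()
        self.log_dict(res, logger=True)
        metrics.reset()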

Comment on lines 88 to 92
        if preds.shape != target.shape:
            raise ValueError(
                "preds and target must have the same shape "
                f"but got {preds.shape} for preds and {target.shape} for target."
            )
Collaborator

Not required for this test

Suggested change
-        if preds.shape != target.shape:
-            raise ValueError(
-                "preds and target must have the same shape "
-                f"but got {preds.shape} for preds and {target.shape} for target."
-            )

Contributor Author

Removed.

Comment on lines 83 to 84
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
Collaborator

Not required for this test

Suggested change
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self):
+        super().__init__()

Contributor Author

Removed. Note that this might fail in the future if you start adding arguments for devices or similar.

            **tfm_kwargs,
        )
        model.fit(self.series)
        assert model.model.trainer.logged_metrics["train_NumsCalled"] != 1
Collaborator

Sorry, just saw this now. We should probably check that it is > 1 since 0 would be incorrect.

Suggested change
-        assert model.model.trainer.logged_metrics["train_NumsCalled"] != 1
+        assert model.model.trainer.logged_metrics["train_NumsCalled"] > 1

Contributor Author

Good call, fixed. This also tests that update_metrics works correctly.
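
For context, a minimal sketch of a counting metric like the NumsCalled used in this test (illustrative only; the real test class may differ):

import torch
from torchmetrics import Metric


class NumsCalled(Metric):
    def __init__(self):
        super().__init__()
        # summed across steps, so the logged value equals the number of update() calls
        self.add_state("counter", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds, target) -> None:
        self.counter += 1

    def compute(self) -> torch.Tensor:
        # a value > 1 after training confirms update() ran on multiple batches before compute()
        return self.counter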

Collaborator

@dennisbader dennisbader left a comment

Looks great, thanks a lot @tRosenflanz, ready to merge 🚀 💯

@dennisbader dennisbader merged commit a0cc279 into unit8co:master May 27, 2024
9 checks passed