[Feature]: Support signal-safe wandb.finish() #7590

Open
collinmccarthy opened this issue May 7, 2024 · 4 comments
Labels
a:cli Area: Client

Comments

@collinmccarthy

Description

I'm trying to call wandb.finish(exit_code=0, quiet=True) from within a SIGTERM handler. This lets me choose my exit code and ensure that checkpoints finish uploading when the run is preempted or requeued, e.g. by a Slurm timeout. Calling wandb.finish() from a SIGTERM handler works most of the time, unless one of the following happens (there may be other cases; I'm not an expert here):

(1) If SIGTERM arrives while the main thread is inside a print/logging call, calling logger.info() from the handler raises a "reentrant" RuntimeError. Other console output such as wandb: Synced N W&B file(s) may raise the same error, but I haven't verified that. Over the course of preempting many runs, this can definitely happen.

(2) If PyTorch DataLoader workers are still running in the background and have been terminated by SIGTERM, calling wandb.finish() raises something like the following:

Traceback (most recent call last):
  File "/home/cmccarth/.conda/envs/custom_env/lib/python3.11/site-packages/wandb/sdk/wandb_run.py", line 2313, in _atexit_cleanup
    self._on_finish()
  File "/home/cmccarth/.conda/envs/custom_env/lib/python3.11/site-packages/wandb/sdk/wandb_run.py", line 2561, in _on_finish
    _ = exit_handle.wait(timeout=-1, on_progress=self._on_progress_exit)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cmccarth/.conda/envs/custom_env/lib/python3.11/site-packages/wandb/sdk/lib/mailbox.py", line 283, in wait
    found, abandoned = self._slot._get_and_clear(timeout=wait_timeout)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cmccarth/.conda/envs/custom_env/lib/python3.11/site-packages/wandb/sdk/lib/mailbox.py", line 130, in _get_and_clear
    if self._wait(timeout=timeout):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cmccarth/.conda/envs/custom_env/lib/python3.11/site-packages/wandb/sdk/lib/mailbox.py", line 126, in _wait
    return self._event.wait(timeout=timeout)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cmccarth/.conda/envs/custom_env/lib/python3.11/threading.py", line 629, in wait
    signaled = self._cond.wait(timeout)
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cmccarth/.conda/envs/custom_env/lib/python3.11/threading.py", line 331, in wait
    gotit = waiter.acquire(True, timeout)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cmccarth/.conda/envs/custom_env/lib/python3.11/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 3477323) is killed by signal: Terminated. 

This is with wandb version 0.16.7.dev1 (commit eea39ec from PR 7322).
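For reference, here is a minimal sketch of the handler pattern described at the top of this issue. It assumes `finish` is a callable accepting the same keyword arguments as `wandb.finish`; it is injected (rather than importing wandb) so the sketch runs on its own, and `install_sigterm_finish` is a hypothetical helper name. Because the handler can interrupt arbitrary Python code, it is exposed to both failure modes above.

```python
import signal
import sys
from typing import Callable


def install_sigterm_finish(finish: Callable[..., None], exit_code: int = 0) -> None:
    """Register a SIGTERM handler that finishes the run, then exits.

    `finish` is assumed to behave like wandb.finish(exit_code=..., quiet=...);
    it is passed in so this sketch does not depend on wandb being installed.
    """

    def handler(signum, frame):
        # Runs in the main thread, interrupting whatever Python code was executing,
        # including a print() or logging call that is only partly done.
        finish(exit_code=exit_code, quiet=True)
        # Exit with the chosen status once the run has been closed.
        sys.exit(exit_code)

    signal.signal(signal.SIGTERM, handler)
```

In a real script this would be called once, as `install_sigterm_finish(wandb.finish)`, right after `wandb.init(...)`.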

Suggested Solution

Issue 1 could be solved by allowing wandb to be silenced completely after wandb.init() has been called. I tried setting os.environ["WANDB_SILENT"] = "true", but this doesn't appear to take effect once init has run.
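Independently of a silent flag, a common Python pattern may sidestep issue 1 on the user side (this is an assumed workaround, not something wandb provides): keep the handler trivial, so it only records that SIGTERM arrived, and call `finish` from the main loop at a safe point where no print/logging call is half-finished. `run_training`, `train_step`, and `finish` are hypothetical placeholders. It probably does not help with issue 2, and cleanup waits until the current step returns, which has to fit inside the scheduler's grace period.

```python
import signal
from typing import Callable, Tuple


def run_training(train_step: Callable[[int], None],
                 finish: Callable[..., None],
                 max_steps: int) -> Tuple[int, bool]:
    """Run up to max_steps steps, deferring cleanup on SIGTERM to a safe point.

    Returns (steps_completed, preempted).
    """
    stop_requested = False

    def request_stop(signum, frame):
        nonlocal stop_requested
        # Keep the handler trivial: no printing, logging, or network I/O here.
        stop_requested = True

    previous = signal.signal(signal.SIGTERM, request_stop)
    try:
        step = 0
        while step < max_steps and not stop_requested:
            train_step(step)
            step += 1
        # Safe point: no print/logging call is in progress, so finish() may log freely.
        finish(exit_code=0, quiet=True)
        return step, stop_requested
    finally:
        # Restore whatever handler was installed before.
        signal.signal(signal.SIGTERM, previous)
```

When it returns with `preempted=True`, the script can call `sys.exit(...)` with whatever status the requeue logic expects.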

Issue 2 probably just needs wandb to ignore the exception raised while calling self._event.wait(timeout=timeout), but I'm not sure that's possible, because I don't fully understand why wait() raises it in the first place. Maybe with the right start_method setting this issue wouldn't occur at all? I'm currently using the default settings (along with wandb.require('service')), and I tried manually setting start_method to fork, spawn, and forkserver with no luck.

Alternatives

My current workaround is to:

  1. Override the default "wandb" logger by declaring my own "wandb" logger that doesn't use a stream handler. This may prevent logging to debug.log, though; I'm not sure.
  2. Call PyTorch's DataLoader._iterator._shutdown_workers(), which works great.

Combined, these seem to work, except that the wandb: Synced N W&B file(s) output is still printed to the console. However, I think that if we crash at that point, the run has already been closed successfully.
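The two workaround steps could be sketched as follows, assuming Python's standard `logging` module and the private PyTorch attributes named above. `silence_logger` and `shutdown_dataloader_workers` are hypothetical helper names, and this is not necessarily how the workaround was implemented here: instead of declaring a replacement "wandb" logger, it strips the existing logger's console handlers in place and keeps any `FileHandler` (on the assumption that this is how debug.log is written). `DataLoader._iterator` is only populated when `persistent_workers=True` and `num_workers > 0`, and both attributes are private PyTorch internals that may change between releases.

```python
import logging
from typing import Any


def silence_logger(name: str = "wandb") -> logging.Logger:
    """Drop console (stream) handlers from a logger so it never writes to stdout/stderr.

    FileHandler subclasses StreamHandler, so file handlers are kept explicitly.
    """
    logger = logging.getLogger(name)
    for handler in list(logger.handlers):
        is_console = isinstance(handler, logging.StreamHandler) and not isinstance(
            handler, logging.FileHandler
        )
        if is_console:
            logger.removeHandler(handler)
    # Stop records from bubbling up to the root logger's console handler.
    logger.propagate = False
    if not logger.handlers:
        logger.addHandler(logging.NullHandler())
    return logger


def shutdown_dataloader_workers(loader: Any) -> bool:
    """Call the private loader._iterator._shutdown_workers() if an iterator exists.

    Returns True if a shutdown was requested, False if there was nothing to stop.
    """
    iterator = getattr(loader, "_iterator", None)
    shutdown = getattr(iterator, "_shutdown_workers", None)
    if callable(shutdown):
        shutdown()
        return True
    return False
```

Both helpers would run inside the SIGTERM handler before `wandb.finish()`, e.g. `shutdown_dataloader_workers(val_loader)` for a hypothetical `val_loader`.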

Regardless, I think most people won't go this far down the rabbit hole to implement these workarounds, so it would be amazing if this could all be handled behind the scenes, e.g. with a silent flag for wandb.finish() and signal-safe cleanup of other processes.

Additional Context

Also, if I'm totally off the mark here on how to force wandb to finish even when it gets SIGTERM, please let me know. Being able to get my checkpoints from wandb directly is amazingly convenient, so it's very important to me that they finish uploading in all scenarios (within reason).

@luisbergua
Contributor

Hey @collinmccarthy, thanks for the detailed report and for sharing your insights and workarounds. We appreciate the time you've taken to outline the issues you've encountered with wandb.finish() during SIGTERM handling, especially when involving logging functions and PyTorch DataLoader workers.

  • Issue with Logging: The reentrant RuntimeError you've described is an interesting scenario that highlights a potential area for improvement in how wandb handles logging during termination. We'll investigate whether there's a possibility to fully silence wandb post-initialization to avoid such errors, or alternatively, to handle log operations safely during termination.
  • PyTorch DataLoader and Cleanup: Your suggestion to ignore exceptions during self._event.wait() is noted. We need to evaluate if this could lead to other unintended consequences or if there's a more robust way to ensure a clean shutdown.

If you're interested, we welcome contributions to our codebase; take a look here if you want to know more.

@luisbergua
Contributor

Hey @collinmccarthy! I checked internally, and we were wondering whether you could share some reproduction code so we can test on our side. Thanks in advance!

@collinmccarthy
Author

collinmccarthy commented May 9, 2024

Hey @luisbergua, the repo that I'm using is built on top of mmdetection with a lot of abstractions. I think it's too complicated to reproduce the issue with.

I can work on a script to demonstrate it, but it may take me a while to get to it. Here's what I would do:

  • Take a CIFAR-10 PyTorch example
  • Run validation (or training, if easier) with persistent workers, logging to stdout every iteration for maximum verbosity
  • Add a SIGTERM handler that calls wandb.finish()
  • Launch validation and either (a) run kill from a different terminal, or (b) have the same script raise SIGTERM after N iterations, from within the logger

If we raise SIGTERM from within the logger (or perhaps from within our own override of print()), I think it will demonstrate both issues reproducibly. I can try to work on this script over the next couple of weeks, but no guarantees on my side. The "solution" I posted above seems to be working (so far) for the dozen or so preempted runs I launch per day. I'll also update this issue if it fails at some point.
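Option (b) could be sketched with the standard library alone, assuming an injected `finish` callable in place of `wandb.finish` and a plain loop in place of the CIFAR-10 model; `make_signalling_print` and `run_repro` are hypothetical names. `signal.raise_signal` (Python 3.8+) runs the Python-level handler before returning, so the handler reliably executes while the logging call is still on the stack. Because the signal fires before the underlying buffered write begins, though, this may not trigger the exact "reentrant" RuntimeError; sending SIGTERM from another terminal during heavy output (option a) is closer to the real race.

```python
import signal
import sys
from typing import Callable, Dict, Tuple


def make_signalling_print(after_n_calls: int) -> Tuple[Callable[..., None], Dict]:
    """Return a print() wrapper that raises SIGTERM from inside its Nth call."""
    state = {"calls": 0, "inside": False}

    def signalling_print(*args, **kwargs):
        state["calls"] += 1
        state["inside"] = True
        try:
            if state["calls"] == after_n_calls:
                # raise_signal() runs the Python handler before it returns,
                # so the handler executes while we are still inside this call.
                signal.raise_signal(signal.SIGTERM)
            print(*args, **kwargs)
        finally:
            state["inside"] = False

    return signalling_print, state


def run_repro(n_iters: int, signal_at: int, finish: Callable[..., None]) -> Dict:
    log, state = make_signalling_print(signal_at)
    seen = {"inside_print": None, "last_iter": None}

    def handler(signum, frame):
        # In the real script this would be wandb.finish(...); here it is injected.
        seen["inside_print"] = state["inside"]
        finish(exit_code=0, quiet=True)
        sys.exit(0)

    previous = signal.signal(signal.SIGTERM, handler)
    try:
        for i in range(n_iters):
            seen["last_iter"] = i
            log(f"iter {i}")  # stands in for per-iteration logging to stdout
    except SystemExit:
        pass  # a real repro would let the process exit; caught here for inspection
    finally:
        signal.signal(signal.SIGTERM, previous)
    return seen
```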

kptkin added the a:cli Area: Client label May 10, 2024
@luisbergua
Contributor

Hey @collinmccarthy, thanks for sharing these details! We'll give it a try, but we'd appreciate it if you could provide the code example once it's ready.
