
[Feature]: Allow logging artifact without creating new versions #7413

Open
buoyancy99 opened this issue Apr 18, 2024 · 4 comments
Labels
c:artifacts (Candidate for artifact branch), ty:feature_request (type of the issue is a feature request)

Comments

@buoyancy99

Description

Now that wandb has tightened its storage policy, we have found it very space-inefficient to store every version of an artifact. For example, to save space when storing model weights, I may only want to keep the latest version. However, I found no good way to overwrite an artifact with the same name without creating a new version.

Note: pytorch_lightning's ModelCheckpoint class has options like save_top_k or enable_version_counter, but these only affect local files, not wandb. Wandb creates a new version anyway, so a lot of storage is wasted.
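
For reference, the only way I know to reclaim space today is to delete superseded versions after the fact via the public Api. A minimal cleanup sketch (entity, project, and artifact names are placeholders; only versions that no longer carry an alias such as "latest" are removed):

import wandb

api = wandb.Api()
# Iterate over every version of the model artifact and delete the ones that
# are no longer referenced by an alias such as "latest" or "best".
for version in api.artifact_versions("model", "my-entity/my-project/model-abc123"):
    if len(version.aliases) == 0:
        version.delete()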

Suggested Solution

Add an overwrite option to run.log_artifact so that when overwrite=True, log_artifact does not create a new version but instead overwrites the existing artifact with the same name, if one exists, as sketched below.
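
A hypothetical sketch of how the proposed call might look (the overwrite parameter does not exist in the current SDK; project and file names are placeholders):

import wandb

run = wandb.init(project="my-project")

artifact = wandb.Artifact(name="model-weights", type="model")
artifact.add_file("checkpoints/last.ckpt")

# Proposed (not yet existing) flag: replace the existing artifact contents
# in place instead of creating a new version.
run.log_artifact(artifact, overwrite=True)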

Alternatives

No response

Additional Context

No response

@buoyancy99
Author

buoyancy99 commented Apr 18, 2024

Related: #6278
I don't think people take this seriously until they run out of storage!

@kptkin added the c:artifacts (Candidate for artifact branch) and ty:feature_request (type of the issue is a feature request) labels on Apr 18, 2024
@ArtsiomWB
Contributor

Hi @buoyancy99, thank you for writing in! I have submitted your feature request to our engineering team.

@sdascoli

sdascoli commented May 3, 2024

+1, this is a huge problem!!!
I only realized a few days ago that I had accumulated over 50 TB of data without even knowing about it.
I find it very strange that when one specifies a dirpath in ModelCheckpoint, wandb both saves to the dirpath AND silently creates an artifact in an obscure folder buried inside ~/.cache/wandb.

@buoyancy99
Author

If you are using pytorch lightning, I have a workaround here (it could be useful for the team too):

from pathlib import Path
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional, Union
from typing_extensions import override
from functools import wraps
import os
import time
from lightning.pytorch.loggers.wandb import WandbLogger, _scan_checkpoints, ModelCheckpoint, Tensor
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from lightning.fabric.utilities.types import _PATH


if TYPE_CHECKING:
    from wandb.sdk.lib import RunDisabled
    from wandb.wandb_run import Run


class SpaceEfficientWandbLogger(WandbLogger):
    """
    A wandb logger that, by default, overwrites artifacts to save space instead of creating new versions.
    The expiration_days argument controls how long older versions of artifacts are kept.
    By default the latest version is kept indefinitely, while older versions are kept for 5 days.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        save_dir: _PATH = ".",
        version: Optional[str] = None,
        offline: bool = False,
        dir: Optional[_PATH] = None,
        id: Optional[str] = None,
        anonymous: Optional[bool] = None,
        project: Optional[str] = None,
        log_model: Union[Literal["all"], bool] = False,
        experiment: Union["Run", "RunDisabled", None] = None,
        prefix: str = "",
        checkpoint_name: Optional[str] = None,
        expiration_days: Optional[int] = 5,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            name=name,
            save_dir=save_dir,
            version=version,
            offline=offline,
            dir=dir,
            id=id,
            anonymous=anonymous,
            project=project,
            log_model=log_model,
            experiment=experiment,
            prefix=prefix,
            checkpoint_name=checkpoint_name,
            **kwargs,
        )
        self.expiration_days = expiration_days
        self._last_artifacts = []

    def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
        import wandb

        # get checkpoints to be saved with associated score
        checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

        # log iteratively all new checkpoints
        artifacts = []
        for t, p, s, tag in checkpoints:
            metadata = {
                "score": s.item() if isinstance(s, Tensor) else s,
                "original_filename": Path(p).name,
                checkpoint_callback.__class__.__name__: {
                    k: getattr(checkpoint_callback, k)
                    for k in [
                        "monitor",
                        "mode",
                        "save_last",
                        "save_top_k",
                        "save_weights_only",
                        "_every_n_train_steps",
                    ]
                    # ensure it does not break if `ModelCheckpoint` args change
                    if hasattr(checkpoint_callback, k)
                },
            }
            if not self._checkpoint_name:
                self._checkpoint_name = f"model-{self.experiment.id}"

            artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
            artifact.add_file(p, name="model.ckpt")
            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
            self.experiment.log_artifact(artifact, aliases=aliases)
            # remember logged models - timestamp needed in case the filename didn't change (last.ckpt or custom name)
            self._logged_model_time[p] = t
            artifacts.append(artifact)

        # Set a TTL on the artifacts logged during the previous scan so that superseded
        # checkpoint versions expire after `expiration_days` instead of accumulating.
        # The versions logged in this scan get no TTL until the next scan, so the latest
        # checkpoint is kept indefinitely.
        for artifact in self._last_artifacts:
            if not self._offline:
                artifact.wait()
            artifact.ttl = timedelta(days=self.expiration_days)
            artifact.save()
        self._last_artifacts = artifacts

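A minimal usage sketch for the workaround above (MyModel stands in for your own LightningModule; project name and hyperparameters are illustrative):

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint

# Overwritten checkpoint versions expire after 5 days; the latest version is kept.
logger = SpaceEfficientWandbLogger(project="my-project", log_model="all", expiration_days=5)
checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1)

trainer = L.Trainer(logger=logger, callbacks=[checkpoint_callback], max_epochs=10)
trainer.fit(MyModel())  # MyModel: your LightningModule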