Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(artifacts): add save() method on ArtifactCollection to allow persisting changes #7555

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Please add to the relevant subsections under Unreleased below on every PR where
* `wandb-core` now supports Artifact file caching by @moredatarequired in https://github.com/wandb/wandb/pull/7364 and https://github.com/wandb/wandb/pull/7366
* Added artifact_exists() and artifact_collection_exists() methods to Api to check if an artifact or collection exists by @amusipatla-wandb in https://github.com/wandb/wandb/pull/7483
* `wandb launch -u <git-uri | local-path> ` creates and launches a job from the given source code by @bcsherma in https://github.com/wandb/wandb/pull/7485
* Added ArtifactCollection.save() by @amusipatla-wandb in https://github.com/wandb/wandb/pull/7555
amusipatla-wandb marked this conversation as resolved.
Show resolved Hide resolved

### Fixed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,66 @@ def test_change_artifact_collection_type(monkeypatch, wandb_init):
assert artifact.type == "lucas_type"


def test_save_artifact_sequence(monkeypatch, wandb_init):
    """Edits to a sequence's name, type, description and tags persist via save()."""
    # Seed a sequence by logging one artifact under it.
    with wandb_init() as run:
        run.log_artifact(wandb.Artifact("sequence_name", "data"))

    with wandb_init() as run:
        # Stage every editable field locally, then persist them in one call.
        sequence = run.use_artifact("sequence_name:latest").collection
        sequence.description = "new description"
        sequence.name = "new_name"
        sequence.type = "new_type"
        sequence.tags = ["tag"]
        sequence.save()

        # Re-fetch under the new name to verify what was actually stored.
        renamed = run.use_artifact("new_name:latest")
        assert renamed.type == "new_type"
        sequence = renamed.collection
        assert sequence.type == "new_type"
        assert sequence.name == "new_name"
        assert sequence.description == "new description"
        assert list(sequence.tags) == ["tag"]

        # Replacing the tag list must add the new tag and remove the old one.
        sequence.tags = ["new_tag"]
        sequence.save()

        refreshed = run.use_artifact("new_name:latest").collection
        assert list(refreshed.tags) == ["new_tag"]


def test_save_artifact_portfolio(monkeypatch, wandb_init):
    """Edits to a portfolio persist via save(); changing its type is rejected."""
    # Seed a portfolio by linking a freshly logged artifact into it.
    with wandb_init() as run:
        source = wandb.Artifact("image_data", "data")
        run.log_artifact(source)
        source.link("portfolio_name")

    with wandb_init() as run:
        linked = run.use_artifact("portfolio_name:v0")
        collection = linked.collection
        collection.description = "new description"
        collection.name = "new_name"
        # Only sequences may change type, so the setter must refuse here.
        with pytest.raises(ValueError):
            collection.type = "new_type"
        collection.tags = ["tag"]
        collection.save()

        # Re-fetch under the new name to verify what was actually stored.
        linked = run.use_artifact("new_name:v0")
        collection = linked.collection
        assert collection.name == "new_name"
        assert collection.description == "new description"
        assert list(collection.tags) == ["tag"]

        # Replacing the tag list must add the new tag and remove the old one.
        collection.tags = ["new_tag"]
        collection.save()

        refreshed = run.use_artifact("new_name:latest").collection
        assert list(refreshed.tags) == ["new_tag"]


def test_s3_storage_handler_load_path_missing_reference_allowed(
monkeypatch, wandb_init, capsys
):
Expand Down
197 changes: 196 additions & 1 deletion wandb/apis/public/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
from copy import copy
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence

from wandb_gql import Client, gql

Expand Down Expand Up @@ -508,21 +508,216 @@
"""A description of the artifact collection."""
return self._description

@description.setter
def description(self, description: Optional[str]) -> None:
    """Stage a new description locally; `save()` sends it to the server."""
    self._description = description

@property
def tags(self):
    """The tags associated with the artifact collection.

    Returns the locally held value, which includes any tags assigned
    through the setter that have not yet been persisted with `save()`.
    """
    return self._tags

@tags.setter
def tags(self, tags: List[str]) -> None:
    """Stage a replacement tag list for the collection.

    The new tags are only held locally; `save()` reconciles them with the
    previously saved tags.

    Args:
        tags: The complete set of tags the collection should have.

    Raises:
        ValueError: If any tag contains "/" or ":".
    """
    forbidden_chars = ("/", ":")
    for candidate in tags:
        for forbidden in forbidden_chars:
            if forbidden in candidate:
                raise ValueError(
                    "Tags must not contain any of the following characters: /, :"
                )
    self._tags = tags

@property
def name(self):
    """The name of the artifact collection.

    Returns the locally held value, which includes a name assigned through
    the setter that has not yet been persisted with `save()`.
    """
    return self._name

@name.setter
def name(self, name: str) -> None:
    """Stage a new name for the collection.

    The rename is only held locally; `save()` sends it to the server.

    Args:
        name: The new collection name (a single string, sent as the
            GraphQL `String` variable `$name` by `save()`).
    """
    self._name = name

@property
def type(self):
    """The type of the artifact collection.

    Returns the locally held value, which includes a type assigned through
    the setter that has not yet been persisted with `save()`.
    """
    return self._type

@type.setter
def type(self, type: str) -> None:
    """Stage a new artifact type for the collection.

    The change is only held locally; `save()` moves the sequence to the
    new type on the server.

    Args:
        type: Name of the destination artifact type (a single string, sent
            as the GraphQL `String!` variable by `save()`).

    Raises:
        ValueError: If the collection is not a sequence.
    """
    # NOTE: the parameter shadows the builtin `type`; the name is kept so the
    # setter's signature stays unchanged.
    if not self.is_sequence():
        raise ValueError(
            "Type can only be changed if the artifact collection is a sequence."
        )
    self._type = type

def save(self) -> None:
    """Persist any changes made to the artifact collection.

    Sends the locally staged name and description to the server, moves a
    sequence to a different artifact type when its type was changed, and
    then reconciles tag assignments by adding and removing only the tags
    that differ from the last saved state.

    Each GraphQL operation lives in its own private helper; this method
    only fixes the order in which they run. Any exception raised by
    `self.client.execute` propagates to the caller, in which case later
    steps are not attempted.
    """
    if self.is_sequence():
        self._save_sequence_metadata()
        if self._saved_type != self._type:
            self._save_sequence_type()
    else:
        self._save_portfolio_metadata()
    # Tag mutations address the collection by name, so they run only after
    # a rename has been applied and recorded in `_saved_name`.
    self._save_tags()

def _save_sequence_metadata(self) -> None:
    """Push the staged name and description of a sequence to the server."""
    mutation = gql(
        """
        mutation UpdateArtifactCollection(
            $artifactSequenceID: ID!
            $name: String
            $description: String
        ) {
            updateArtifactSequence(
                input: {
                    artifactSequenceID: $artifactSequenceID
                    name: $name
                    description: $description
                }
            ) {
                artifactCollection {
                    id
                    name
                    description
                }
            }
        }
        """
    )
    self.client.execute(
        mutation,
        variable_values={
            "artifactSequenceID": self.id,
            "name": self._name,
            "description": self.description,
        },
    )
    self._saved_name = self._name

def _save_sequence_type(self) -> None:
    """Move a sequence to the staged artifact type on the server."""
    mutation = gql(
        """
        mutation MoveArtifactCollection(
            $artifactSequenceID: ID!
            $destinationArtifactTypeName: String!
        ) {
            moveArtifactSequence(
                input: {
                    artifactSequenceID: $artifactSequenceID
                    destinationArtifactTypeName: $destinationArtifactTypeName
                }
            ) {
                artifactCollection {
                    id
                    name
                    description
                    __typename
                }
            }
        }
        """
    )
    self.client.execute(
        mutation,
        variable_values={
            "artifactSequenceID": self.id,
            "destinationArtifactTypeName": self._type,
        },
    )
    self._saved_type = self._type

def _save_portfolio_metadata(self) -> None:
    """Push the staged name and description of a portfolio to the server."""
    mutation = gql(
        """
        mutation UpdateArtifactPortfolio(
            $artifactPortfolioID: ID!
            $name: String
            $description: String
        ) {
            updateArtifactPortfolio(
                input: {
                    artifactPortfolioID: $artifactPortfolioID
                    name: $name
                    description: $description
                }
            ) {
                artifactCollection {
                    id
                    name
                    description
                }
            }
        }
        """
    )
    self.client.execute(
        mutation,
        variable_values={
            "artifactPortfolioID": self.id,
            "name": self._name,
            "description": self.description,
        },
    )
    self._saved_name = self._name

def _save_tags(self) -> None:
    """Add and remove tag assignments so the server matches the staged tags."""
    staged_tags = set(self._tags)
    saved_tags = set(self._saved_tags)
    tags_to_add = staged_tags - saved_tags
    tags_to_delete = saved_tags - staged_tags

    if tags_to_add:
        add_mutation = gql(
            """
            mutation CreateArtifactCollectionTagAssignments(
                $entityName: String!
                $projectName: String!
                $artifactCollectionName: String!
                $tags: [TagInput!]!
            ) {
                createArtifactCollectionTagAssignments(
                    input: {
                        entityName: $entityName
                        projectName: $projectName
                        artifactCollectionName: $artifactCollectionName
                        tags: $tags
                    }
                ) {
                    tags {
                        id
                        name
                        tagCategoryName
                    }
                }
            }
            """
        )
        self._execute_tag_assignment_mutation(add_mutation, tags_to_add)

    if tags_to_delete:
        delete_mutation = gql(
            """
            mutation DeleteArtifactCollectionTagAssignments(
                $entityName: String!
                $projectName: String!
                $artifactCollectionName: String!
                $tags: [TagInput!]!
            ) {
                deleteArtifactCollectionTagAssignments(
                    input: {
                        entityName: $entityName
                        projectName: $projectName
                        artifactCollectionName: $artifactCollectionName
                        tags: $tags
                    }
                ) {
                    success
                }
            }
            """
        )
        self._execute_tag_assignment_mutation(delete_mutation, tags_to_delete)

    # Snapshot a copy so later changes to the staged tags are diffed against
    # what was just saved rather than against the same object.
    self._saved_tags = copy(self._tags)

def _execute_tag_assignment_mutation(self, mutation: Any, tags) -> None:
    """Run a create/delete tag-assignment mutation for the given tag names.

    Args:
        mutation: The parsed GraphQL document to execute.
        tags: Iterable of tag names to send as `TagInput` objects.
    """
    self.client.execute(
        mutation,
        variable_values={
            "entityName": self.entity,
            "projectName": self.project,
            "artifactCollectionName": self._saved_name,
            "tags": [{"tagName": tag} for tag in tags],
        },
    )

def __repr__(self):
    """Return a debug string showing the collection's current name and type."""
    return f"<ArtifactCollection {self._name} ({self._type})>"

Expand Down