Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sdk): wandb.Image breaks plt.imshow #7279

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 24 additions & 18 deletions wandb/sdk/data_types/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,14 @@
# self._masks = wbimage._masks

def _initialize_from_path(self, path: str) -> None:
pil_image = util.get_module(
"PIL.Image",
required='wandb.Image needs the PIL package. To get it, run "pip install pillow".',
)
try:
import PIL.Image as PILImage
except ImportError:
raise wandb.Error(

Check warning on line 274 in wandb/sdk/data_types/image.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/data_types/image.py#L273-L274

Added lines #L273 - L274 were not covered by tests
'wandb.Image needs the PIL package. To get it, run "pip install pillow".'
)
self._set_file(path, is_tmp=False)
self._image = pil_image.open(path)
self._image = PILImage.open(path)
assert self._image is not None
self._image.load()
ext = os.path.splitext(path)[1][1:]
Expand All @@ -293,15 +295,17 @@
mode: Optional[str] = None,
file_type: Optional[str] = None,
) -> None:
pil_image = util.get_module(
"PIL.Image",
required='wandb.Image needs the PIL package. To get it, run "pip install pillow".',
)
try:
import PIL.Image as PILImage
except ImportError:
raise wandb.Error(

Check warning on line 301 in wandb/sdk/data_types/image.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/data_types/image.py#L300-L301

Added lines #L300 - L301 were not covered by tests
'wandb.Image needs the PIL package. To get it, run "pip install pillow".'
)
if util.is_matplotlib_typename(util.get_full_typename(data)):
buf = BytesIO()
util.ensure_matplotlib_figure(data).savefig(buf, format="png")
self._image = pil_image.open(buf, formats=["PNG"])
elif isinstance(data, pil_image.Image):
self._image = PILImage.open(buf, formats=["PNG"])
elif isinstance(data, PILImage.Image):
self._image = data
elif util.is_pytorch_tensor_typename(util.get_full_typename(data)):
vis_util = util.get_module(
Expand All @@ -312,15 +316,15 @@
if hasattr(data, "dtype") and str(data.dtype) == "torch.uint8":
data = data.to(float)
data = vis_util.make_grid(data, normalize=True)
self._image = pil_image.fromarray(
self._image = PILImage.fromarray(
data.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
)
else:
if hasattr(data, "numpy"): # TF data eager tensors
data = data.numpy()
if data.ndim > 2:
data = data.squeeze() # get rid of trivial dimensions as a convenience
self._image = pil_image.fromarray(
self._image = PILImage.fromarray(
self.to_uint8(data), mode=mode or self.guess_mode(data)
)
accepted_formats = ["png", "jpg", "jpeg", "bmp"]
Expand Down Expand Up @@ -678,10 +682,12 @@
def image(self) -> Optional["PILImage"]:
if self._image is None:
if self._path is not None and not self.path_is_reference(self._path):
pil_image = util.get_module(
"PIL.Image",
required='wandb.Image needs the PIL package. To get it, run "pip install pillow".',
)
self._image = pil_image.open(self._path)
try:
import PIL.Image as PILImage
except ImportError:
raise wandb.Error(

Check warning on line 688 in wandb/sdk/data_types/image.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/data_types/image.py#L687-L688

Added lines #L687 - L688 were not covered by tests
'wandb.Image needs the PIL package. To get it, run "pip install pillow".'
)
self._image = PILImage.open(self._path)
self._image.load()
return self._image