Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sample pixels from all data in variable-resolution batches #2772

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 11 additions & 10 deletions nerfstudio/data/pixel_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@
Code for sampling pixels.
"""

from collections import defaultdict
import random
import warnings
from dataclasses import dataclass, field
from typing import Dict, Optional, Type, Union

import torch
from jaxtyping import Int
from torch import Tensor

from nerfstudio.configs.base_config import InstantiateConfig
from nerfstudio.data.utils.pixel_sampling_utils import erode_mask


@dataclass

Check failure on line 33 in nerfstudio/data/pixel_samplers.py

View workflow job for this annotation

GitHub Actions / build

Ruff (I001)

nerfstudio/data/pixel_samplers.py:19:1: I001 Import block is un-sorted or un-formatted
class PixelSamplerConfig(InstantiateConfig):
"""Configuration for pixel sampler instantiation."""

Expand Down Expand Up @@ -271,7 +272,7 @@

# only sample within the mask, if the mask is in the batch
all_indices = []
all_images = []
all_images = defaultdict(list)

if "mask" in batch:
num_rays_in_batch = num_rays_per_batch // num_images
Expand All @@ -286,7 +287,10 @@
)
indices[:, 0] = i
all_indices.append(indices)
all_images.append(batch["image"][i][indices[:, 1], indices[:, 2]])
for key, value in batch.items():
if key in ["image_idx", "mask"]:
continue
all_images[key].append(batch[key][i][indices[:, 1], indices[:, 2]])

else:
num_rays_in_batch = num_rays_per_batch // num_images
Expand All @@ -302,18 +306,15 @@
indices = self.sample_method(num_rays_in_batch, 1, image_height, image_width, device=device)
indices[:, 0] = i
all_indices.append(indices)
all_images.append(batch["image"][i][indices[:, 1], indices[:, 2]])
for key, value in batch.items():
if key in ["image_idx", "mask"]:
continue
all_images[key].append(batch[key][i][indices[:, 1], indices[:, 2]])

indices = torch.cat(all_indices, dim=0)

c, y, x = (i.flatten() for i in torch.split(indices, 1, dim=-1))
collated_batch = {
key: value[c, y, x]
for key, value in batch.items()
if key != "image_idx" and key != "image" and key != "mask" and value is not None
}

collated_batch["image"] = torch.cat(all_images, dim=0)
collated_batch = {key: torch.cat(all_images[key], dim=0) for key in all_images}

assert collated_batch["image"].shape[0] == num_rays_per_batch

Expand Down