[WIP] Depth nerfacto with visibility loss #2982

Open · wants to merge 11 commits into base: main
2 changes: 1 addition & 1 deletion nerfstudio/cameras/cameras.py
@@ -975,7 +975,7 @@ def get_intrinsics_matrices(self) -> Float[Tensor, "*num_cameras 3 3"]:
Returns:
Pinhole camera intrinsics matrices
"""
K = torch.zeros((*self.shape, 3, 3), dtype=torch.float32)
K = torch.zeros((*self.shape, 3, 3), dtype=torch.float32, device=self.device)
K[..., 0, 0] = self.fx.squeeze(-1)
K[..., 1, 1] = self.fy.squeeze(-1)
K[..., 0, 2] = self.cx.squeeze(-1)
90 changes: 90 additions & 0 deletions nerfstudio/fields/visibility_field.py
@@ -0,0 +1,90 @@
# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Visibility Field"""

import torch
from torch import nn

from nerfstudio.cameras.cameras import Cameras
from nerfstudio.cameras.rays import RaySamples


class VisibilityField(nn.Module):
"""Visibility Field"""

def __init__(self, cameras: Cameras) -> None:
super().__init__()
# training camera transforms
# TODO: use optimized cameras
self.c2ws = cameras.camera_to_worlds  # [num_cameras, 3, 4]
# append a [0, 0, 0, 1] row so the transforms are invertible 4x4 matrices
self.c2whs = torch.cat([self.c2ws, torch.zeros_like(self.c2ws[:, :1, :])], dim=1)
self.c2whs[:, 3, 3] = 1.0
self.w2chs = torch.inverse(self.c2whs)  # world-to-camera transforms
self.K = cameras.get_intrinsics_matrices()  # [num_cameras, 3, 3]
self.image_height = cameras.height
self.image_width = cameras.width

@torch.no_grad()
def forward(self, ray_samples: RaySamples, camera_chunk_size=50, ray_chunk_size=4096) -> torch.Tensor:
Collaborator comment:
does it maybe make sense to take in positions (Nx3) directly instead of ray_samples? this would make it compatible with other methods which might want to use visibility like splatting.
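
A rough sketch of what that suggestion could look like; everything below (the standalone `count_views` helper, scalar `width`/`height`) is hypothetical and not part of this diff:

```python
import torch

@torch.no_grad()
def count_views(
    positions: torch.Tensor,  # [N, 3] world-space points
    w2chs: torch.Tensor,      # [B, 4, 4] world-to-camera transforms
    K: torch.Tensor,          # [B, 3, 3] intrinsics
    width: int,
    height: int,
) -> torch.Tensor:
    """Return, for each point, the number of cameras whose frustum contains it: [N, 1]."""
    homog = torch.cat([positions, torch.ones_like(positions[:, :1])], dim=-1)  # [N, 4]
    cam = torch.einsum("bij,nj->bni", w2chs[:, :3, :], homog)  # [B, N, 3]
    cam[..., 1:] *= -1  # flip y and z axes, matching the convention in this PR
    in_front = cam[..., 2] > 0  # [B, N]
    pix = torch.einsum("bij,bnj->bni", K, cam / cam[..., 2:3].clamp(min=1e-8))
    x, y = pix[..., 0], pix[..., 1]
    in_image = (x >= 0) & (x < width) & (y >= 0) & (y < height)
    return (in_front & in_image).sum(dim=0).unsqueeze(-1)  # [N, 1]
```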

"""
Args:
ray_samples: Ray samples.
camera_chunk_size: Number of cameras to process at once to avoid memory issues.
ray_chunk_size: Number of rays to process at once to avoid memory issues.
Returns:
    Number of training cameras whose frustum contains each sample point, shape [N, S, 1].
"""
# get positions
positions = ray_samples.frustums.get_positions() # [N, S, 3]
# project positions into each camera
# move to homogeneous coordinates
positions = torch.cat([positions, torch.ones_like(positions[..., :1])], dim=-1)
N, S, _ = positions.shape
B = self.w2chs.shape[0] # num cameras
p = positions.view(N * S, 4).transpose(0, 1).unsqueeze(0) # [1, 4, N*S]
p = p.expand(B, *p.shape[1:]) # [B, 4, N*S]

num_views = torch.zeros([N, S, 1], device=positions.device)
for i in range(0, B, camera_chunk_size):
ccs = min(camera_chunk_size, B - i)
for j in range(0, N, ray_chunk_size):
rcs = min(ray_chunk_size, N - j)

ptemp = p.reshape(B, 4, N, S)[i : i + ccs, :, j : j + rcs, :].reshape(ccs, 4, rcs * S)
cam_coords = torch.bmm(self.w2chs[i : i + ccs, :], ptemp)

# flip y and z axes
cam_coords[:, 1, :] *= -1
cam_coords[:, 2, :] *= -1

z = cam_coords[:, 2:3, :].transpose(1, 2).view(ccs, rcs, S, 1)  # [ccs, rcs, S, 1]
mask_z = z[..., 0] > 0

# divide by z
cam_coords = cam_coords[:, :3, :] / cam_coords[:, 2:3, :]

cam_points = torch.bmm(self.K[i : i + ccs], cam_coords)

pixel_coords = cam_points[:, :2, :].transpose(1, 2).view(ccs, rcs, S, 2)  # [ccs, rcs, S, 2]
x = pixel_coords[..., 0]
y = pixel_coords[..., 1]
mask_x = (x >= 0) & (x < self.image_width.view(B, 1, 1)[i : i + ccs])
mask_y = (y >= 0) & (y < self.image_height.view(B, 1, 1)[i : i + ccs])
mask = mask_x & mask_y & mask_z
# sum over the camera dimension
nv = mask.sum(dim=0).unsqueeze(-1)
# nv is [rcs, S, 1]: the number of camera frustums that each sample point
# falls inside, for this particular chunk of cameras
num_views[j : j + rcs] += nv
return num_views
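
For orientation, a minimal usage sketch; `train_cameras` and `ray_samples` are assumed to come from an existing nerfstudio datamanager and sampler, and are not defined here:

```python
from nerfstudio.fields.visibility_field import VisibilityField

visibility_field = VisibilityField(train_cameras)
num_views = visibility_field(ray_samples)  # [num_rays, num_samples, 1] view counts
# Normalize by the number of training cameras, as the nerfacto integration below does.
frac_visible = num_views / visibility_field.c2ws.shape[0]  # in [0, 1]
```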
6 changes: 4 additions & 2 deletions nerfstudio/model_components/ray_samplers.py
@@ -578,10 +578,11 @@ def generate_ray_samples(
self,
ray_bundle: Optional[RayBundle] = None,
density_fns: Optional[List[Callable]] = None,
) -> Tuple[RaySamples, List, List]:
) -> Tuple[RaySamples, List, List, List]:
assert ray_bundle is not None
assert density_fns is not None

densities_list = []
weights_list = []
ray_samples_list = []

@@ -608,14 +609,15 @@
else:
with torch.no_grad():
density = density_fns[i_level](ray_samples.frustums.get_positions())
densities_list.append(density)
weights = ray_samples.get_weights(density)
weights_list.append(weights) # (num_rays, num_samples)
ray_samples_list.append(ray_samples)
if updated:
self._steps_since_update = 0

assert ray_samples is not None
return ray_samples, weights_list, ray_samples_list
return ray_samples, densities_list, weights_list, ray_samples_list


class NeuSSampler(Sampler):
76 changes: 66 additions & 10 deletions nerfstudio/model_components/renderers.py
@@ -55,6 +55,24 @@ def background_color_override_context(mode: Float[Tensor, "3"]) -> Generator[Non
BACKGROUND_COLOR_OVERRIDE = old_background_color


def compute_median(quantity: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
"""
Compute the median of a quantity along rays.

Args:
    quantity: Quantity to compute the median of, shape [..., num_samples, 1].
    weights: Weights along the rays, shape [..., num_samples, 1].

Returns:
    The quantity value at the first sample where the accumulated weight reaches 0.5, shape [..., 1].
"""

cumulative_weights = torch.cumsum(weights[..., 0], dim=-1) # [..., num_samples]
split = torch.ones((*weights.shape[:-2], 1), device=weights.device) * 0.5 # [..., 1]
median_index = torch.searchsorted(cumulative_weights, split, side="left") # [..., 1]
median_index = torch.clamp(median_index, 0, quantity.shape[-2] - 1) # [..., 1]
median_quantity = torch.gather(quantity[..., 0], dim=-1, index=median_index) # [..., 1]

return median_quantity
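
A tiny numeric check of `compute_median`, assuming it is imported from this module once the PR lands:

```python
import torch
from nerfstudio.model_components.renderers import compute_median

steps = torch.tensor([[[1.0], [2.0], [3.0], [4.0]]])    # [1, num_samples, 1]
weights = torch.tensor([[[0.1], [0.2], [0.3], [0.4]]])  # cumsum: 0.1, 0.3, 0.6, 1.0
# The cumulative weight first reaches 0.5 at index 2, so the third sample is selected.
print(compute_median(steps, weights))  # tensor([[3.]])
```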


class RGBRenderer(nn.Module):
"""Standard volumetric rendering.

@@ -74,7 +92,7 @@ def combine_rgb(
background_color: BackgroundColor = "random",
ray_indices: Optional[Int[Tensor, "num_samples"]] = None,
num_rays: Optional[int] = None,
) -> Float[Tensor, "*bs 3"]:
) -> Float[Tensor, "*batch 3"]:
"""Composite samples along ray and render color image.
If background color is random, no BG color is added - as if the background was black!

@@ -336,7 +354,7 @@ def forward(
ray_samples: RaySamples,
ray_indices: Optional[Int[Tensor, "num_samples"]] = None,
num_rays: Optional[int] = None,
) -> Float[Tensor, "*batch 1"]:
) -> Float[Tensor, "*bs 1"]:
"""Composite samples along ray and calculate depths.

Args:
@@ -349,20 +367,15 @@
Outputs of depth values.
"""

if self.method == "median":
steps = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2
steps = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2

if self.method == "median":
if ray_indices is not None and num_rays is not None:
raise NotImplementedError("Median depth calculation is not implemented for packed samples.")
cumulative_weights = torch.cumsum(weights[..., 0], dim=-1) # [..., num_samples]
split = torch.ones((*weights.shape[:-2], 1), device=weights.device) * 0.5 # [..., 1]
median_index = torch.searchsorted(cumulative_weights, split, side="left") # [..., 1]
median_index = torch.clamp(median_index, 0, steps.shape[-2] - 1) # [..., 1]
median_depth = torch.gather(steps[..., 0], dim=-1, index=median_index) # [..., 1]
median_depth = compute_median(steps, weights)
return median_depth
if self.method == "expected":
eps = 1e-10
steps = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2

if ray_indices is not None and num_rays is not None:
# Necessary for packed samples from volumetric ray sampler
@@ -383,6 +396,49 @@
raise NotImplementedError(f"Method {self.method} not implemented")


class VisibilityRenderer(nn.Module):
"""Calculate visibility along ray.

Visibility Method:
- median: Visibility is set to the visibility_samples value at the sample where the accumulated weight reaches 0.5.

Args:
method: Visibility calculation method.
"""

def __init__(self, method: Literal["median"] = "median") -> None:
super().__init__()
self.method = method

def forward(
self,
visibility_samples: Float[Tensor, "*batch num_samples 1"],
weights: Float[Tensor, "*batch num_samples 1"],
ray_indices: Optional[Int[Tensor, "num_samples"]] = None,
num_rays: Optional[int] = None,
) -> Float[Tensor, "*bs 1"]:
"""Composite samples along ray and calculate depths.

Args:
visibility_samples: Number of views each point is visible in.
weights: Weights for each sample.
ray_indices: Ray index for each sample, used when samples are packed.
num_rays: Number of rays, used when samples are packed.

Returns:
Outputs of visibility values.
"""

if self.method == "median":
if ray_indices is not None and num_rays is not None:
raise NotImplementedError("Median visibility calculation is not implemented for packed samples.")
median_visibility = compute_median(visibility_samples, weights)
return median_visibility
else:
raise NotImplementedError(f"Method {self.method} not implemented")


class UncertaintyRenderer(nn.Module):
"""Calculate uncertainty along the ray."""

5 changes: 5 additions & 0 deletions nerfstudio/models/depth_nerfacto.py
@@ -16,22 +16,23 @@
Nerfacto augmented with depth supervision.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, Tuple, Type

import numpy as np
import torch

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.model_components import losses
from nerfstudio.model_components.losses import DepthLossType, depth_loss, depth_ranking_loss
from nerfstudio.models.nerfacto import NerfactoModel, NerfactoModelConfig
from nerfstudio.utils import colormaps
from nerfstudio.model_components.scene_colliders import AABBBoxCollider


@dataclass

CI check failure (GitHub Actions / build): Ruff (I001) nerfstudio/models/depth_nerfacto.py:19:1: Import block is un-sorted or un-formatted
class DepthNerfactoModelConfig(NerfactoModelConfig):
"""Additional parameters for depth supervision."""

@@ -70,6 +71,10 @@
else:
self.depth_sigma = torch.tensor([self.config.depth_sigma])

# Use an AABB collider so rays are clipped to the scene box
# (instead of nerfacto's default NearFarCollider).
self.collider = AABBBoxCollider(scene_box=self.scene_box)

def get_outputs(self, ray_bundle: RayBundle):
outputs = super().get_outputs(ray_bundle)
if ray_bundle.metadata is not None and "directions_norm" in ray_bundle.metadata:
32 changes: 27 additions & 5 deletions nerfstudio/models/nerfacto.py
@@ -19,7 +19,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Literal, Tuple, Type
from typing import Dict, List, Literal, Optional, Tuple, Type

import numpy as np
import torch
@@ -32,6 +32,7 @@
from nerfstudio.field_components.spatial_distortions import SceneContraction
from nerfstudio.fields.density_fields import HashMLPDensityField
from nerfstudio.fields.nerfacto_field import NerfactoField
from nerfstudio.fields.visibility_field import VisibilityField
from nerfstudio.model_components.losses import (
MSELoss,
distortion_loss,
@@ -41,7 +42,13 @@
scale_gradients_by_distance_squared,
)
from nerfstudio.model_components.ray_samplers import ProposalNetworkSampler, UniformSampler
from nerfstudio.model_components.renderers import AccumulationRenderer, DepthRenderer, NormalsRenderer, RGBRenderer
from nerfstudio.model_components.renderers import (
AccumulationRenderer,
DepthRenderer,
NormalsRenderer,
RGBRenderer,
VisibilityRenderer,
)
from nerfstudio.model_components.scene_colliders import NearFarCollider
from nerfstudio.model_components.shaders import NormalsShader
from nerfstudio.models.base_model import Model, ModelConfig
@@ -171,10 +178,11 @@
average_init_density=self.config.average_init_density,
implementation=self.config.implementation,
)

self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup(
num_cameras=self.num_train_data, device="cpu"
)
# This can be set by the pipeline; if set, a visibility output is added in get_outputs().
self.visibility_field: Optional[VisibilityField] = None

CI check failure (GitHub Actions / build): Ruff (E999) nerfstudio/models/nerfacto.py:185:9: SyntaxError: Unexpected token 'self'
self.density_fns = []
num_prop_nets = self.config.num_proposal_iterations
# Build the proposal network(s)
@@ -234,6 +242,7 @@
self.renderer_accumulation = AccumulationRenderer()
self.renderer_depth = DepthRenderer(method="median")
self.renderer_expected_depth = DepthRenderer(method="expected")
self.renderer_visibility = VisibilityRenderer(method="median")
self.renderer_normals = NormalsRenderer()

# shaders
@@ -300,11 +309,14 @@
if self.training:
self.camera_optimizer.apply_to_raybundle(ray_bundle)
ray_samples: RaySamples
ray_samples, weights_list, ray_samples_list = self.proposal_sampler(ray_bundle, density_fns=self.density_fns)
ray_samples, densities_list, weights_list, ray_samples_list = self.proposal_sampler(
ray_bundle, density_fns=self.density_fns
)
field_outputs = self.field.forward(ray_samples, compute_normals=self.config.predict_normals)
if self.config.use_gradient_scaling:
field_outputs = scale_gradients_by_distance_squared(field_outputs, ray_samples)

densities_list.append(field_outputs[FieldHeadNames.DENSITY])
weights = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
weights_list.append(weights)
ray_samples_list.append(ray_samples)
@@ -322,13 +334,23 @@
"expected_depth": expected_depth,
}

if self.visibility_field is not None:
assert isinstance(self.visibility_field, VisibilityField), "self.visibility_field must be a VisibilityField"
visibility_samples = self.visibility_field.forward(ray_samples_list[-1])
visibility = self.renderer_visibility(weights=weights, visibility_samples=visibility_samples)
visibility = visibility.float() / len(
self.visibility_field.c2ws
) # range [0, 1] where 1 is seen by all cameras
outputs["visibility"] = visibility

if self.config.predict_normals:
normals = self.renderer_normals(normals=field_outputs[FieldHeadNames.NORMALS], weights=weights)
pred_normals = self.renderer_normals(field_outputs[FieldHeadNames.PRED_NORMALS], weights=weights)
outputs["normals"] = self.normals_shader(normals)
outputs["pred_normals"] = self.normals_shader(pred_normals)
# These use a lot of GPU memory, so we avoid storing them for eval.
if self.training:
outputs["densities_list"] = densities_list
outputs["weights_list"] = weights_list
outputs["ray_samples_list"] = ray_samples_list
