Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support splat export in original dataset coordinates #2951

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions nerfstudio/cameras/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,8 @@ def viewmatrix(lookat: torch.Tensor, up: torch.Tensor, pos: torch.Tensor) -> Flo
"""
vec2 = normalize(lookat)
vec1_avg = normalize(up)
vec0 = normalize(torch.cross(vec1_avg, vec2))
vec1 = normalize(torch.cross(vec2, vec0))
vec0 = normalize(torch.linalg.cross(vec1_avg, vec2))
vec1 = normalize(torch.linalg.cross(vec2, vec0))
m = torch.stack([vec0, vec1, vec2, pos], 1)
return m

Expand Down
50 changes: 46 additions & 4 deletions nerfstudio/scripts/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
import tyro
from typing_extensions import Annotated, Literal

from nerfstudio.cameras.camera_utils import quaternion_from_matrix
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanager
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager
from nerfstudio.data.scene_box import OrientedBox
from nerfstudio.exporter import texture_utils, tsdf_utils
Expand Down Expand Up @@ -121,7 +123,7 @@ class ExportPointCloud(Exporter):
"""Number of rays to evaluate per batch. Decrease if you run out of memory."""
std_ratio: float = 10.0
"""Threshold based on STD of the average distances across the point cloud to remove outliers."""
save_world_frame: bool = False
save_world_frame: bool = True
"""If set, saves the point cloud in the same frame as the original dataset. Otherwise, uses the
scaled and reoriented coordinate space expected by the NeRF models."""

Expand Down Expand Up @@ -482,6 +484,11 @@ class ExportGaussianSplat(Exporter):
Export 3D Gaussian Splatting model to a .ply
"""

save_world_frame: bool = True
"""If set, saves the splat in the same frame as the original dataset.
Otherwise, uses the scaled and reoriented coordinate space produced
internally by Nerfstudio."""

def main(self) -> None:
if not self.output_dir.exists():
self.output_dir.mkdir(parents=True)
Expand All @@ -497,7 +504,26 @@ def main(self) -> None:
map_to_tensors = {}

with torch.no_grad():
positions = model.means.cpu().numpy()
if self.save_world_frame:
assert isinstance(pipeline.datamanager, FullImageDatamanager)
dataparser_outputs = pipeline.datamanager.train_dataparser_outputs
dataparser_scale = dataparser_outputs.dataparser_scale
dataparser_transform = dataparser_outputs.dataparser_transform.numpy(force=True)

output_scale = 1 / dataparser_scale
output_transform = np.zeros((3, 4))
output_transform[:3, :3] = dataparser_transform[:3, :3].T
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty please don't do transform math w/out at least comments, this sort of code is 110% likely to put a future reader in transform hell

also pretty please use pipeline.datamanager.train_dataparser_outputs.transform_poses_to_original_space() because
(1) that's what's used elsewhere in this file
(2) using that function ensures future refactors won't break things, and most past nerfstudio refactors have indeed broken lots of things

output_transform[:3, 3] = -dataparser_transform[:3, :3].T @ dataparser_transform[:3, 3]
else:
output_scale = 1
output_transform = np.zeros((3, 4))
output_transform[:3, :3] = np.eye(3)
inv_dataparser_quat = quaternion_from_matrix(output_transform[:3, :3])

positions = (
np.einsum("ij,bj->bi", output_transform[:3, :3], model.means.cpu().numpy() * output_scale)
+ output_transform[None, :3, 3]
Comment on lines +524 to +525
Copy link
Contributor

@pwais pwais Mar 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please don't do this, ESPECIALLY w/out comments. I have lost a lot of time reading nerfstudio code that's like this. instead consider:

positions = model.means.cpu().numpy()

poses = np.eye(4, dtype=np.float32)[None, ...].repeat(positions.shape[0], axis=0)[:, :3, :]
poses[:, :3, 3] = positions
poses = pipeline.datamanager.train_dataparser_outputs.transform_poses_to_original_space(
    torch.from_numpy(poses)
)

)
n = positions.shape[0]
map_to_tensors["positions"] = positions
map_to_tensors["normals"] = np.zeros_like(positions, dtype=np.float32)
Expand All @@ -518,11 +544,27 @@ def main(self) -> None:

map_to_tensors["opacity"] = model.opacities.data.cpu().numpy()

scales = model.scales.data.cpu().numpy()
# Note that scales are in log space!
scales = model.scales.data.cpu().numpy() + np.log(output_scale)
for i in range(3):
map_to_tensors[f"scale_{i}"] = scales[:, i, None]

quats = model.quats.data.cpu().numpy()
def quaternion_multiply(wxyz0: np.ndarray, wxyz1: np.ndarray) -> np.ndarray:
Copy link
Contributor

@pwais pwais Mar 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

??? First of all, scipy.spatial.transform already has quaternion multiply... at least this code is clear about scalar-first versus scalar-last.

this could be made more concise, but consider instead:

from scipy.spatial.transform import Rotation as ScR
# ns gplat says quaternions are [w,x,y,z] scalar-first format
# scipy is [x, y, z, w] scalar-last format

raw_quats = model.quats.data.cpu().numpy().squeeze()
R_quats = ScR.from_quat(raw_quats[:, [1, 2, 3, 0]])

# apply the inverse dataparser transform to the splat rotations
poses = np.eye(4, dtype=np.float32)[None, ...].repeat(raw_quats.shape[0], axis=0)[:, :3, :]
poses[:, :3, :3] = R_quats.as_matrix()
poses = pipeline.datamanager.train_dataparser_outputs.transform_poses_to_original_space(
    torch.from_numpy(poses)
)
rots_in_input = poses[:, :3, :3].numpy()
quat_in_input = ScR.from_matrix(rots_in_input)
quats = quat_in_input.as_quat()[:, [3, 0, 1, 2], None]

Again, this uses transform_poses_to_original_space(), which might amalgamate several different transforms and scales, who knows? Instead of trying to re-derive the transform as the current PR does. And hopefully transform_poses_to_original_space() gets maintained. But it's really really important to be clear about frames etc, and transform_poses_to_original_space() helps with that a ton.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(note for when we revive this PR, which is planned) for the quaternion multiply if we don't want to deal with the xyzw/wxyz conversion of scipy we can also use (vtf.SO3(wxyz0) @ vtf.SO3(wxyz1)).wxyz with import viser.transforms as vtf where viser>=0.1.30

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

voicing a preference for the use of standard scipy / numpy / torch wherever possible

yes it's unfortunate that there are different quaternion encodings, different camera conventions, different euler angle conventions ....

assert wxyz0.shape[-1] == 4
assert wxyz1.shape[-1] == 4
w0, x0, y0, z0 = np.moveaxis(wxyz0, -1, 0)
w1, x1, y1, z1 = np.moveaxis(wxyz1, -1, 0)
return np.stack(
[
-x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1,
x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1,
-x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1,
x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1,
],
axis=-1,
)

quats = quaternion_multiply(inv_dataparser_quat, model.quats.data.cpu().numpy())
for i in range(4):
map_to_tensors[f"rot_{i}"] = quats[:, i, None]

Expand Down
11 changes: 10 additions & 1 deletion nerfstudio/viewer/export_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def populate_point_cloud_tab(
num_points = server.add_gui_number("# Points", initial_value=1_000_000, min=1, max=None, step=1)
world_frame = server.add_gui_checkbox(
"Save in world frame",
False,
True,
hint=(
"If checked, saves the point cloud in the same frame as the original dataset. Otherwise, uses the "
"scaled and reoriented coordinate space expected by the NeRF models."
Expand Down Expand Up @@ -194,6 +194,14 @@ def populate_splat_tab(
server.add_gui_markdown("<small>Generate ply export of Gaussian Splat</small>")

output_directory = server.add_gui_text("Output Directory", initial_value="exports/splat/")
world_frame = server.add_gui_checkbox(
"Save in world frame",
True,
hint=(
"If checked, saves the splat file in the same frame as the original dataset. Otherwise, uses the "
"scaled and reoriented coordinate space generated by Nerfstudio."
),
)
generate_command = server.add_gui_button("Generate Command", icon=viser.Icon.TERMINAL_2)

@generate_command.on_click
Expand All @@ -204,6 +212,7 @@ def _(event: viser.GuiEvent) -> None:
"ns-export gaussian-splat",
f"--load-config {config_path}",
f"--output-dir {output_directory.value}",
f"--save-world-frame {world_frame.value}",
]
)
show_command_modal(event.client, "splat", command)
Expand Down