Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gaussian Splats 0.1 #4991

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

216 changes: 205 additions & 11 deletions crates/re_renderer/shader/point_cloud.wgsl
Expand Up @@ -8,9 +8,19 @@

@group(1) @binding(0)
var position_data_texture: texture_2d<f32>;

@group(1) @binding(1)
var color_texture: texture_2d<f32>;

/// 3D scale of point splats.
@group(1) @binding(2)
var scale_texture: texture_2d<f32>;

/// XYZW quaternion of point splats.
@group(1) @binding(3)
var rotation_texture: texture_2d<f32>;

@group(1) @binding(4)
var picking_instance_id_texture: texture_2d<u32>;

struct DrawDataUniformBuffer {
Expand All @@ -21,7 +31,7 @@ struct DrawDataUniformBuffer {
// if we wouldn't add padding here, which isn't available on WebGL.
_padding: vec4f,
};
@group(1) @binding(3)
@group(1) @binding(5)
var<uniform> draw_data: DrawDataUniformBuffer;

struct BatchUniformBuffer {
Expand Down Expand Up @@ -64,12 +74,18 @@ struct VertexOut {

@location(4) @interpolate(flat)
picking_instance_id: vec2u,

// [-2, +2] coordinates on the point splat
@location(5) @interpolate(perspective)
vpos: vec2f,
};

struct PointData {
pos: vec3f,
unresolved_radius: f32,
color: vec4f,
scale: vec3f,
rotation_quat_xyzw: vec4f,
picking_instance_id: vec2u,
}

Expand All @@ -83,6 +99,14 @@ fn read_data(idx: u32) -> PointData {
let color = textureLoad(color_texture,
vec2u(idx % color_texture_size.x, idx / color_texture_size.x), 0);

let scale_texture_size = textureDimensions(scale_texture);
let scale = textureLoad(scale_texture,
vec2u(idx % scale_texture_size.x, idx / scale_texture_size.x), 0).xyz;

let rotation_texture_size = textureDimensions(rotation_texture);
let rotation = textureLoad(rotation_texture,
vec2u(idx % rotation_texture_size.x, idx / rotation_texture_size.x), 0);

let picking_instance_id_texture_size = textureDimensions(picking_instance_id_texture);
let picking_instance_id = textureLoad(picking_instance_id_texture,
vec2u(idx % picking_instance_id_texture_size.x, idx / picking_instance_id_texture_size.x), 0).xy;
Expand All @@ -92,6 +116,8 @@ fn read_data(idx: u32) -> PointData {
data.pos = pos_4d.xyz / pos_4d.w;
data.unresolved_radius = position_data.w;
data.color = color;
data.scale = scale;
data.rotation_quat_xyzw = rotation;
data.picking_instance_id = picking_instance_id;
return data;
}
Expand All @@ -103,24 +129,154 @@ fn vs_main(@builtin(vertex_index) vertex_idx: u32) -> VertexOut {
// Read point data (valid for the entire quad)
let point_data = read_data(quad_idx);

// Span quad
let camera_distance = distance(frame.camera_position, point_data.pos);
let world_scale_factor = average_scale_from_transform(batch.world_from_obj); // TODO(andreas): somewhat costly, should precompute this
let world_radius = unresolved_size_to_world(point_data.unresolved_radius, camera_distance,
frame.auto_size_points, world_scale_factor) +
world_size_from_point_size(draw_data.radius_boost_in_ui_points, camera_distance);
let quad = sphere_or_circle_quad_span(vertex_idx, point_data.pos, world_radius,
has_any_flag(batch.flags, FLAG_DRAW_AS_CIRCLES));

// Output, transform to projection space and done.
let splat = true; // TODO: use batch.flags

var out: VertexOut;
out.position = apply_depth_offset(frame.projection_from_world * vec4f(quad.pos_in_world, 1.0), batch.depth_offset);
out.color = point_data.color;
out.radius = quad.point_resolved_radius;
out.world_position = quad.pos_in_world;
out.point_center = point_data.pos;
out.picking_instance_id = point_data.picking_instance_id;

if splat {
// Gaussian splatting code based on several (sometimes similar) sources:
// * https://github.com/antimatter15/splat/blob/main/main.js
// * https://github.com/aras-p/UnityGaussianSplatting/blob/main/package/Shaders/GaussianSplatting.hlsl
// * https://github.com/BladeTransformerLLC/gauzilla/blob/cef36adf71835eb60d1c6e8b2a2b34af3790c828/src/gsplat.vert
// * https://github.com/cvlab-epfl/gaussian-splatting-web // TODO: read this one!
// * https://github.com/huggingface/gsplat.js/blob/4933ddfecf1a859da473013e63b9d883e298cd15/src/renderers/webgl/programs/RenderProgram.ts
//
// TODO: read https://towardsdatascience.com/a-comprehensive-overview-of-gaussian-splatting-e7d570081362
// How to create your own splats: https://docs.spline.design/e17b7c105ef0433f8c5d2b39d512614e
//
// with references to equations in https://www.cs.umd.edu/~zwicker/publications/EWASplatting-TVCG02.pdf
// and https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/3d_gaussian_splatting_low.pdf
//
// TODO(emilk):
// * View-dependent color based on spherical-harmonics
// * Tranparency
let splat_scale = 1.0; // TODO: let user control it

let pos2d = frame.projection_from_world * vec4f(out.point_center, 1.0);

let clip = 1.2 * pos2d.w;
if (pos2d.z < -clip || pos2d.x < -clip || pos2d.x > clip || pos2d.y < -clip || pos2d.y > clip) {
// Discard
out.position = vec4(0.0, 0.0, 0.0, 0.0);
out.color = vec4f();
return out;
}

// Convert rotation to 3x3 rotation matrix:
let rot: vec4f = point_data.rotation_quat_xyzw;
let qx = rot.x;
let qy = rot.y;
let qz = rot.z;
let qw = rot.w;
let r = mat3x3f(
1.0 - 2.0 * (qy * qy + qz * qz),
2.0 * (qx * qy + qw * qz),
2.0 * (qx * qz - qw * qy),

2.0 * (qx * qy - qw * qz),
1.0 - 2.0 * (qx * qx + qz * qz),
2.0 * (qy * qz + qw * qx),

2.0 * (qx * qz + qw * qy),
2.0 * (qy * qz - qw * qx),
1.0 - 2.0 * (qx * qx + qy * qy),
);

// Scale matrix:
let s = mat3x3f(
point_data.scale.x, 0.0, 0.0,
0.0, point_data.scale.y, 0.0,
0.0, 0.0, point_data.scale.z,
);

// world-space covariance matrix (called "Vrk" in other sources).
let cov3d_in_world = r * s * transpose(s) * transpose(r);

let pos_in_cam: vec3f = frame.view_from_world * vec4f(point_data.pos, 1.0);

// Project to 2D screen space and clamp:
let limx: f32 = 1.3 * frame.tan_half_fov.x;
let limy: f32 = 1.3 * frame.tan_half_fov.y;
let pos_in_2d = vec3f(
clamp(pos_in_cam.x / pos_in_cam.z, -limx, limx) * pos_in_cam.z,
clamp(pos_in_cam.y / pos_in_cam.z, -limy, limy) * pos_in_cam.z,
pos_in_cam.z,
);

// Crate Jacobian for the Taylor approximation of the nonlinear view_from_camera transformation:
let z_sqr = pos_in_2d.z * pos_in_2d.z;
let t = pos_in_2d;
let l = length(t);
let aspect_ratio = frame.projection_from_view[1][1] / frame.projection_from_view[0][0];
// EWA Splatting eq.29, with some modifications.
// I'm not sure how correct this is. The transpose of it also seems to work okish.
let j: mat3x3f = mat3x3f(
1.0 / (aspect_ratio * t.z), 0.0, -t.x / (t.z * t.z),
0.0, 1.0 / t.z, -t.y / (t.z * t.z),
0.0, 0.0, 1.0,
);

let view3 = mat3x3f(frame.view_from_world.x, frame.view_from_world.y, frame.view_from_world.z);

// eq.5 in https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/3d_gaussian_splatting_low.pdf
let jw: mat3x3f = j * view3;

// covariance matrix in view space
let cov2d: mat3x3f = jw * cov3d_in_world * transpose(jw);

// Find eigen-values of the covariance matrix:
let mid: f32 = 0.5 * (cov2d[0][0] + cov2d[1][1]);
let radius: f32 = length(vec2(0.5 * (cov2d[0][0] - cov2d[1][1]), cov2d[0][1]));
let lambda1: f32 = mid + radius;
let lambda2: f32 = mid - radius;

if (lambda2 < 0.0) {
// Discard
out.position = vec4(0.0, 0.0, 0.0, 0.0);
out.color = vec4f();
return out;
}

// Find eigen-vectors of the covariance matrix (the major and minor axes of the ellipsis):
let diagonal_vector: vec2f = normalize(vec2f(cov2d[0][1], lambda1 - cov2d[0][0]));
let major_axis: vec2f = min(sqrt(2.0 * lambda1), 1024.0) * diagonal_vector;
let minor_axis: vec2f = min(sqrt(2.0 * lambda2), 1024.0) * vec2f(diagonal_vector.y, -diagonal_vector.x);

let local_idx = vertex_idx % 6u;
let top_bottom = f32(local_idx <= 1u || local_idx == 5u) * 2.0 - 1.0; // 1 for a top vertex, -1 for a bottom vertex.
let left_right = f32(vertex_idx % 2u) * 2.0 - 1.0; // 1 for a right vertex, -1 for a left vertex.
let vpos = 2.0 * vec2f(left_right, top_bottom);

let major: vec2f = (vpos.x * major_axis);
let minor: vec2f = (vpos.y * minor_axis);

// Use a correctish Z so we don't need depth-sorting for opaque splats:
let proj = frame.projection_from_world * vec4f(out.point_center, 1.0);
let z = proj.z / proj.w;

out.vpos = vpos;
out.position = vec4(pos2d.xy / pos2d.w + splat_scale * (major + minor), z, 1.0);
out.color *= clamp(pos2d.z / pos2d.w + 1.0, 0.0, 1.0);
out.radius = -666.0; // signal that this is a splat
} else {
let world_radius = unresolved_size_to_world(point_data.unresolved_radius, camera_distance,
frame.auto_size_points, world_scale_factor) +
world_size_from_point_size(draw_data.radius_boost_in_ui_points, camera_distance);
let quad = sphere_or_circle_quad_span(vertex_idx, point_data.pos, world_radius,
has_any_flag(batch.flags, FLAG_DRAW_AS_CIRCLES));

// Output, transform to projection space and done.
out.position = apply_depth_offset(frame.projection_from_world * vec4f(quad.pos_in_world, 1.0), batch.depth_offset);
out.radius = quad.point_resolved_radius;
out.world_position = quad.pos_in_world;

}
return out;
}

Expand Down Expand Up @@ -148,6 +304,44 @@ fn coverage(world_position: vec3f, radius: f32, point_center: vec3f) -> f32 {

@fragment
fn fs_main(in: VertexOut) -> @location(0) vec4f {
let splat = in.radius == -666.0;
if splat {
let A = -dot(in.vpos, in.vpos);
if (A < -4.0) {
discard; // Outside the ellipsis
}

var alpha = 1.0;

// The input color comes with premultiplied alpha.
// Despite not sorting the splats, this looks okish.
// (re_renderer has premul alpha-blending turned on,
// despite not sorting, or having any order-independent transparency).
var rgba = in.color.rgba;


if false {
// In some scenes the alpha seems off, and this seems to help.
var a = rgba.a;
rgba /= a; // undo premul
a = pow(a, 0.1); // change alpha
rgba *= a; // redo premul
}

// TODO(#1611): transparency in rerun
if false {
// Fade the gaussian at the edges.
// This makes the splats way too transparent atm,
// but I think that is an artifact of use not sorting the splats,
// but having the Z buffer on, so a transparent edge will occlude other splats.
// Maybe the splats are too small too?
// If you turn this on, increase splat_scale to 3-4 or so.
rgba *= exp(A);
}

return rgba;
}

let coverage = coverage(in.world_position, in.radius, in.point_center);
if coverage < 0.001 {
discard;
Expand Down
54 changes: 54 additions & 0 deletions crates/re_renderer/src/point_cloud_builder.rs
Expand Up @@ -19,6 +19,8 @@ pub struct PointCloudBuilder {
pub vertices: Vec<PositionRadius>,

pub(crate) color_buffer: CpuWriteGpuReadBuffer<Color32>,
pub(crate) scale_buffer: CpuWriteGpuReadBuffer<glam::Vec4>, // TODO: optional
pub(crate) rotation_buffer: CpuWriteGpuReadBuffer<glam::Quat>, // TODO: optional
pub(crate) picking_instance_ids_buffer: CpuWriteGpuReadBuffer<PickingLayerInstanceId>,

pub(crate) batches: Vec<PointCloudBatchInfo>,
Expand Down Expand Up @@ -46,6 +48,34 @@ impl PointCloudBuilder {
)
.expect("Failed to allocate color buffer"); // TODO(#3408): Should never happen but should propagate error anyways

let scale_buffer = ctx
.cpu_write_gpu_read_belt
.lock()
.allocate::<glam::Vec4>(
&ctx.device,
&ctx.gpu_resources.buffers,
data_texture_source_buffer_element_count(
PointCloudDrawData::SCALE_TEXTURE_FORMAT,
max_num_points,
max_texture_dimension_2d,
),
)
.expect("Failed to allocate scale buffer"); // TODO(#3408): Should never happen but should propagate error anyways

let rotation_buffer = ctx
.cpu_write_gpu_read_belt
.lock()
.allocate::<glam::Quat>(
&ctx.device,
&ctx.gpu_resources.buffers,
data_texture_source_buffer_element_count(
PointCloudDrawData::ROTATION_TEXTURE_FORMAT,
max_num_points,
max_texture_dimension_2d,
),
)
.expect("Failed to allocate rotation buffer"); // TODO(#3408): Should never happen but should propagate error anyways

let picking_instance_ids_buffer = ctx
.cpu_write_gpu_read_belt
.lock()
Expand All @@ -63,6 +93,8 @@ impl PointCloudBuilder {
Self {
vertices: Vec::with_capacity(max_num_points as usize),
color_buffer,
scale_buffer,
rotation_buffer,
picking_instance_ids_buffer,
batches: Vec::with_capacity(16),
radius_boost_in_ui_points_for_outlines: 0.0,
Expand Down Expand Up @@ -290,6 +322,28 @@ impl<'a> PointCloudBatchBuilder<'a> {
self
}

pub fn push_scales3(&mut self, scales: &[glam::Vec3]) {
// TODO: handle only some point clouds having scales
re_tracing::profile_function!();
let scales4 = scales
.iter()
.copied()
.map(|s| glam::Vec4::new(s.x, s.y, s.z, 1.0));
self.0
.scale_buffer
.extend(scales4.into_iter())
.unwrap_debug_or_log_error();
}

pub fn push_rotations(&mut self, rotations: &[glam::Quat]) {
// TODO: handle only some point clouds having rotations
re_tracing::profile_function!();
self.0
.rotation_buffer
.extend_from_slice(rotations)
.unwrap_debug_or_log_error();
}

/// Pushes additional outline mask ids for a specific range of points.
/// The range is relative to this batch.
///
Expand Down