Skip to content

Commit

Permalink
Hack in a splat renderer
Browse files Browse the repository at this point in the history
  • Loading branch information
emilk committed Feb 15, 2024
1 parent 5cfec60 commit ccd7450
Show file tree
Hide file tree
Showing 8 changed files with 511 additions and 46 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

211 changes: 198 additions & 13 deletions crates/re_renderer/shader/point_cloud.wgsl
Expand Up @@ -8,9 +8,19 @@

@group(1) @binding(0)
var position_data_texture: texture_2d<f32>;

@group(1) @binding(1)
var color_texture: texture_2d<f32>;

/// 3D scale of point splats.
@group(1) @binding(2)
var scale_texture: texture_2d<f32>;

/// XYZW quaternion of point splats.
@group(1) @binding(3)
var rotation_texture: texture_2d<f32>;

@group(1) @binding(4)
var picking_instance_id_texture: texture_2d<u32>;

struct DrawDataUniformBuffer {
Expand All @@ -21,7 +31,7 @@ struct DrawDataUniformBuffer {
// if we wouldn't add padding here, which isn't available on WebGL.
_padding: vec4f,
};
@group(1) @binding(3)
@group(1) @binding(5)
var<uniform> draw_data: DrawDataUniformBuffer;

struct BatchUniformBuffer {
Expand Down Expand Up @@ -66,26 +76,33 @@ struct VertexOut {

@location(4) @interpolate(flat)
picking_instance_id: vec2u,

// [-2, +2] coordinates on the point splat
@location(5) @interpolate(perspective)
vpos: vec2f,
};

struct PointData {
pos: vec3f,
unresolved_radius: f32,
color: vec4f,
scale: vec3f,
rotation_quat_xyzw: vec4f,
picking_instance_id: vec2u,
}

// Read and unpack data at a given location
fn read_data(idx: u32) -> PointData {
let coord = vec2u(idx % TEXTURE_SIZE, idx / TEXTURE_SIZE);
let position_data = textureLoad(position_data_texture, coord, 0);
let color = textureLoad(color_texture, coord, 0);

var data: PointData;
let pos_4d = batch.world_from_obj * vec4f(position_data.xyz, 1.0);
data.pos = pos_4d.xyz / pos_4d.w;
data.unresolved_radius = position_data.w;
data.color = color;
data.color = textureLoad(color_texture, coord, 0);
data.scale = textureLoad(scale_texture, coord, 0).xyz;
data.rotation_quat_xyzw = textureLoad(rotation_texture, coord, 0);
data.picking_instance_id = textureLoad(picking_instance_id_texture, coord, 0).rg;
return data;
}
Expand All @@ -97,24 +114,154 @@ fn vs_main(@builtin(vertex_index) vertex_idx: u32) -> VertexOut {
// Read point data (valid for the entire quad)
let point_data = read_data(quad_idx);

// Span quad
let camera_distance = distance(frame.camera_position, point_data.pos);
let world_scale_factor = average_scale_from_transform(batch.world_from_obj); // TODO(andreas): somewhat costly, should precompute this
let world_radius = unresolved_size_to_world(point_data.unresolved_radius, camera_distance,
frame.auto_size_points, world_scale_factor) +
world_size_from_point_size(draw_data.radius_boost_in_ui_points, camera_distance);
let quad = sphere_or_circle_quad_span(vertex_idx, point_data.pos, world_radius,
has_any_flag(batch.flags, FLAG_DRAW_AS_CIRCLES));

// Output, transform to projection space and done.
let splat = true; // TODO: use batch.flags

var out: VertexOut;
out.position = apply_depth_offset(frame.projection_from_world * vec4f(quad.pos_in_world, 1.0), batch.depth_offset);
out.color = point_data.color;
out.radius = quad.point_resolved_radius;
out.world_position = quad.pos_in_world;
out.point_center = point_data.pos;
out.picking_instance_id = point_data.picking_instance_id;

if splat {
// Gaussian splatting code based on several (sometimes similar) sources:
// * https://github.com/antimatter15/splat/blob/main/main.js
// * https://github.com/aras-p/UnityGaussianSplatting/blob/main/package/Shaders/GaussianSplatting.hlsl
// * https://github.com/BladeTransformerLLC/gauzilla/blob/cef36adf71835eb60d1c6e8b2a2b34af3790c828/src/gsplat.vert
// * https://github.com/cvlab-epfl/gaussian-splatting-web // TODO: read this one!
// * https://github.com/huggingface/gsplat.js/blob/4933ddfecf1a859da473013e63b9d883e298cd15/src/renderers/webgl/programs/RenderProgram.ts
//
// TODO: read https://towardsdatascience.com/a-comprehensive-overview-of-gaussian-splatting-e7d570081362
// How to create your own splats: https://docs.spline.design/e17b7c105ef0433f8c5d2b39d512614e
//
// with references to equations in https://www.cs.umd.edu/~zwicker/publications/EWASplatting-TVCG02.pdf
// and https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/3d_gaussian_splatting_low.pdf
//
// TODO(emilk):
// * View-dependent color based on spherical-harmonics
// * Tranparency
let splat_scale = 1.0; // TODO: let user control it

let pos2d = frame.projection_from_world * vec4f(out.point_center, 1.0);

let clip = 1.2 * pos2d.w;
if (pos2d.z < -clip || pos2d.x < -clip || pos2d.x > clip || pos2d.y < -clip || pos2d.y > clip) {
// Discard
out.position = vec4(0.0, 0.0, 0.0, 0.0);
out.color = vec4f();
return out;
}

// Convert rotation to 3x3 rotation matrix:
let rot: vec4f = point_data.rotation_quat_xyzw;
let qx = rot.x;
let qy = rot.y;
let qz = rot.z;
let qw = rot.w;
let r = mat3x3f(
1.0 - 2.0 * (qy * qy + qz * qz),
2.0 * (qx * qy + qw * qz),
2.0 * (qx * qz - qw * qy),

2.0 * (qx * qy - qw * qz),
1.0 - 2.0 * (qx * qx + qz * qz),
2.0 * (qy * qz + qw * qx),

2.0 * (qx * qz + qw * qy),
2.0 * (qy * qz - qw * qx),
1.0 - 2.0 * (qx * qx + qy * qy),
);

// Scale matrix:
let s = mat3x3f(
point_data.scale.x, 0.0, 0.0,
0.0, point_data.scale.y, 0.0,
0.0, 0.0, point_data.scale.z,
);

// world-space covariance matrix (called "Vrk" in other sources).
let cov3d_in_world = r * s * transpose(s) * transpose(r);

let pos_in_cam: vec3f = frame.view_from_world * vec4f(point_data.pos, 1.0);

// Project to 2D screen space and clamp:
let limx: f32 = 1.3 * frame.tan_half_fov.x;
let limy: f32 = 1.3 * frame.tan_half_fov.y;
let pos_in_2d = vec3f(
clamp(pos_in_cam.x / pos_in_cam.z, -limx, limx) * pos_in_cam.z,
clamp(pos_in_cam.y / pos_in_cam.z, -limy, limy) * pos_in_cam.z,
pos_in_cam.z,
);

// Crate Jacobian for the Taylor approximation of the nonlinear view_from_camera transformation:
let z_sqr = pos_in_2d.z * pos_in_2d.z;
let t = pos_in_2d;
let l = length(t);
let aspect_ratio = frame.projection_from_view[1][1] / frame.projection_from_view[0][0];
// EWA Splatting eq.29, with some modifications.
// I'm not sure how correct this is. The transpose of it also seems to work okish.
let j: mat3x3f = mat3x3f(
1.0 / (aspect_ratio * t.z), 0.0, -t.x / (t.z * t.z),
0.0, 1.0 / t.z, -t.y / (t.z * t.z),
0.0, 0.0, 1.0,
);

let view3 = mat3x3f(frame.view_from_world.x, frame.view_from_world.y, frame.view_from_world.z);

// eq.5 in https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/3d_gaussian_splatting_low.pdf
let jw: mat3x3f = j * view3;

// covariance matrix in view space
let cov2d: mat3x3f = jw * cov3d_in_world * transpose(jw);

// Find eigen-values of the covariance matrix:
let mid: f32 = 0.5 * (cov2d[0][0] + cov2d[1][1]);
let radius: f32 = length(vec2(0.5 * (cov2d[0][0] - cov2d[1][1]), cov2d[0][1]));
let lambda1: f32 = mid + radius;
let lambda2: f32 = mid - radius;

if (lambda2 < 0.0) {
// Discard
out.position = vec4(0.0, 0.0, 0.0, 0.0);
out.color = vec4f();
return out;
}

// Find eigen-vectors of the covariance matrix (the major and minor axes of the ellipsis):
let diagonal_vector: vec2f = normalize(vec2f(cov2d[0][1], lambda1 - cov2d[0][0]));
let major_axis: vec2f = min(sqrt(2.0 * lambda1), 1024.0) * diagonal_vector;
let minor_axis: vec2f = min(sqrt(2.0 * lambda2), 1024.0) * vec2f(diagonal_vector.y, -diagonal_vector.x);

let local_idx = vertex_idx % 6u;
let top_bottom = f32(local_idx <= 1u || local_idx == 5u) * 2.0 - 1.0; // 1 for a top vertex, -1 for a bottom vertex.
let left_right = f32(vertex_idx % 2u) * 2.0 - 1.0; // 1 for a right vertex, -1 for a left vertex.
let vpos = 2.0 * vec2f(left_right, top_bottom);

let major: vec2f = (vpos.x * major_axis);
let minor: vec2f = (vpos.y * minor_axis);

// Use a correctish Z so we don't need depth-sorting for opaque splats:
let proj = frame.projection_from_world * vec4f(out.point_center, 1.0);
let z = proj.z / proj.w;

out.vpos = vpos;
out.position = vec4(pos2d.xy / pos2d.w + splat_scale * (major + minor), z, 1.0);
out.color *= clamp(pos2d.z / pos2d.w + 1.0, 0.0, 1.0);
out.radius = -666.0; // signal that this is a splat
} else {
let world_radius = unresolved_size_to_world(point_data.unresolved_radius, camera_distance,
frame.auto_size_points, world_scale_factor) +
world_size_from_point_size(draw_data.radius_boost_in_ui_points, camera_distance);
let quad = sphere_or_circle_quad_span(vertex_idx, point_data.pos, world_radius,
has_any_flag(batch.flags, FLAG_DRAW_AS_CIRCLES));

// Output, transform to projection space and done.
out.position = apply_depth_offset(frame.projection_from_world * vec4f(quad.pos_in_world, 1.0), batch.depth_offset);
out.radius = quad.point_resolved_radius;
out.world_position = quad.pos_in_world;

}
return out;
}

Expand Down Expand Up @@ -142,6 +289,44 @@ fn coverage(world_position: vec3f, radius: f32, point_center: vec3f) -> f32 {

@fragment
fn fs_main(in: VertexOut) -> @location(0) vec4f {
let splat = in.radius == -666.0;
if splat {
let A = -dot(in.vpos, in.vpos);
if (A < -4.0) {
discard; // Outside the ellipsis
}

var alpha = 1.0;

// The input color comes with premultiplied alpha.
// Despite not sorting the splats, this looks okish.
// (re_renderer has premul alpha-blending turned on,
// despite not sorting, or having any order-independent transparency).
var rgba = in.color.rgba;


if false {
// In some scenes the alpha seems off, and this seems to help.
var a = rgba.a;
rgba /= a; // undo premul
a = pow(a, 0.1); // change alpha
rgba *= a; // redo premul
}

// TODO(#1611): transparency in rerun
if false {
// Fade the gaussian at the edges.
// This makes the splats way too transparent atm,
// but I think that is an artifact of use not sorting the splats,
// but having the Z buffer on, so a transparent edge will occlude other splats.
// Maybe the splats are too small too?
// If you turn this on, increase splat_scale to 3-4 or so.
rgba *= exp(A);
}

return rgba;
}

let coverage = coverage(in.world_position, in.radius, in.point_center);
if coverage < 0.001 {
discard;
Expand Down
47 changes: 47 additions & 0 deletions crates/re_renderer/src/point_cloud_builder.rs
Expand Up @@ -19,6 +19,8 @@ pub struct PointCloudBuilder {
pub vertices: Vec<PositionRadius>,

pub(crate) color_buffer: CpuWriteGpuReadBuffer<Color32>,
pub(crate) scale_buffer: CpuWriteGpuReadBuffer<glam::Vec4>, // TODO: optional
pub(crate) rotation_buffer: CpuWriteGpuReadBuffer<glam::Quat>, // TODO: optional
pub(crate) picking_instance_ids_buffer: CpuWriteGpuReadBuffer<PickingLayerInstanceId>,

pub(crate) batches: Vec<PointCloudBatchInfo>,
Expand All @@ -40,6 +42,27 @@ impl PointCloudBuilder {
PointCloudDrawData::MAX_NUM_POINTS,
)
.expect("Failed to allocate color buffer"); // TODO(#3408): Should never happen but should propagate error anyways

let scale_buffer = ctx
.cpu_write_gpu_read_belt
.lock()
.allocate::<glam::Vec4>(
&ctx.device,
&ctx.gpu_resources.buffers,
PointCloudDrawData::MAX_NUM_POINTS,
)
.expect("Failed to allocate scale buffer"); // TODO(#3408): Should never happen but should propagate error anyways

let rotation_buffer = ctx
.cpu_write_gpu_read_belt
.lock()
.allocate::<glam::Quat>(
&ctx.device,
&ctx.gpu_resources.buffers,
PointCloudDrawData::MAX_NUM_POINTS,
)
.expect("Failed to allocate rotation buffer"); // TODO(#3408): Should never happen but should propagate error anyways

let picking_instance_ids_buffer = ctx
.cpu_write_gpu_read_belt
.lock()
Expand All @@ -53,6 +76,8 @@ impl PointCloudBuilder {
Self {
vertices: Vec::with_capacity(RESERVE_SIZE),
color_buffer,
scale_buffer,
rotation_buffer,
picking_instance_ids_buffer,
batches: Vec::with_capacity(16),
radius_boost_in_ui_points_for_outlines: 0.0,
Expand Down Expand Up @@ -279,6 +304,28 @@ impl<'a> PointCloudBatchBuilder<'a> {
self
}

pub fn push_scales3(&mut self, scales: &[glam::Vec3]) {
// TODO: handle only some point clouds having scales
re_tracing::profile_function!();
let scales4 = scales
.iter()
.copied()
.map(|s| glam::Vec4::new(s.x, s.y, s.z, 1.0));
self.0
.scale_buffer
.extend(scales4.into_iter())
.unwrap_debug_or_log_error();
}

pub fn push_rotations(&mut self, rotations: &[glam::Quat]) {
// TODO: handle only some point clouds having rotations
re_tracing::profile_function!();
self.0
.rotation_buffer
.extend_from_slice(rotations)
.unwrap_debug_or_log_error();
}

/// Pushes additional outline mask ids for a specific range of points.
/// The range is relative to this batch.
///
Expand Down

0 comments on commit ccd7450

Please sign in to comment.