From d09fa8cd10eedc837debef45b72b499b9e483713 Mon Sep 17 00:00:00 2001 From: Donovan Hutchence Date: Mon, 20 May 2024 10:55:01 +0100 Subject: [PATCH] Render compressed GS data (#6371) --- src/framework/parsers/gsplat-resource.js | 37 +- src/framework/parsers/ply.js | 10 +- .../gsplat/gsplat-compressed-material.js | 361 +++++++++++ src/scene/gsplat/gsplat-compressed.js | 155 +++++ src/scene/gsplat/gsplat-data.js | 582 +++++++++++------- src/scene/gsplat/gsplat-instance.js | 4 +- src/scene/gsplat/gsplat-material.js | 8 +- src/scene/gsplat/gsplat.js | 137 ++--- src/scene/gsplat/shader-generator-gsplat.js | 55 +- 9 files changed, 982 insertions(+), 367 deletions(-) create mode 100644 src/scene/gsplat/gsplat-compressed-material.js create mode 100644 src/scene/gsplat/gsplat-compressed.js diff --git a/src/framework/parsers/gsplat-resource.js b/src/framework/parsers/gsplat-resource.js index 9933f59b7c5..c91b249cdde 100644 --- a/src/framework/parsers/gsplat-resource.js +++ b/src/framework/parsers/gsplat-resource.js @@ -1,7 +1,7 @@ -import { BoundingBox } from '../../core/shape/bounding-box.js'; import { Entity } from '../entity.js'; import { GSplatInstance } from '../../scene/gsplat/gsplat-instance.js'; import { GSplat } from '../../scene/gsplat/gsplat.js'; +import { GSplatCompressed } from '../../scene/gsplat/gsplat-compressed.js'; /** * The resource for the gsplat asset type. @@ -22,7 +22,7 @@ class GSplatResource { splatData; /** - * @type {GSplat | null} + * @type {GSplat | GSplatCompressed | null} * @ignore */ splat = null; @@ -34,7 +34,7 @@ class GSplatResource { */ constructor(device, splatData) { this.device = device; - this.splatData = splatData.isCompressed ? splatData.decompress() : splatData; + this.splatData = splatData; } destroy() { @@ -46,37 +46,8 @@ class GSplatResource { createSplat() { if (!this.splat) { - - const splatData = this.splatData; - - const aabb = new BoundingBox(); - this.splatData.calcAabb(aabb); - - const splat = new GSplat(this.device, splatData.numSplats, aabb); - this.splat = splat; - - // texture data - splat.updateColorData(splatData.getProp('f_dc_0'), splatData.getProp('f_dc_1'), splatData.getProp('f_dc_2'), splatData.getProp('opacity')); - splat.updateTransformData( - splatData.getProp('x'), splatData.getProp('y'), splatData.getProp('z'), - splatData.getProp('rot_0'), splatData.getProp('rot_1'), splatData.getProp('rot_2'), splatData.getProp('rot_3'), - splatData.getProp('scale_0'), splatData.getProp('scale_1'), splatData.getProp('scale_2') - ); - - // centers - constant buffer that is sent to the worker - const x = splatData.getProp('x'); - const y = splatData.getProp('y'); - const z = splatData.getProp('z'); - - const centers = new Float32Array(this.splatData.numSplats * 3); - for (let i = 0; i < this.splatData.numSplats; ++i) { - centers[i * 3 + 0] = x[i]; - centers[i * 3 + 1] = y[i]; - centers[i * 3 + 2] = z[i]; - } - splat.centers = centers; + this.splat = this.splatData.isCompressed ? new GSplatCompressed(this.device, this.splatData) : new GSplat(this.device, this.splatData); } - return this.splat; } diff --git a/src/framework/parsers/ply.js b/src/framework/parsers/ply.js index 8441fd962bd..18822f21293 100644 --- a/src/framework/parsers/ply.js +++ b/src/framework/parsers/ply.js @@ -294,11 +294,19 @@ class PlyParser { } else { readPly(response.body.getReader(), asset.data.elementFilter ?? defaultElementFilter) .then((response) => { + // construct the GSplatData object const gsplatData = new GSplatData(response, { performZScale: asset.data.performZScale, reorder: asset.data.reorder }); - callback(null, new GSplatResource(this.device, gsplatData)); + + // construct the resource + const resource = new GSplatResource( + this.device, + gsplatData.isCompressed && asset.data.decompress ? gsplatData.decompress() : gsplatData + ); + + callback(null, resource); }) .catch((err) => { callback(err, null); diff --git a/src/scene/gsplat/gsplat-compressed-material.js b/src/scene/gsplat/gsplat-compressed-material.js new file mode 100644 index 00000000000..0b781136226 --- /dev/null +++ b/src/scene/gsplat/gsplat-compressed-material.js @@ -0,0 +1,361 @@ +import { CULLFACE_NONE, SEMANTIC_ATTR13, SEMANTIC_POSITION } from "../../platform/graphics/constants.js"; +import { ShaderProcessorOptions } from "../../platform/graphics/shader-processor-options.js"; +import { BLEND_NONE, BLEND_NORMAL, DITHER_NONE, GAMMA_NONE, GAMMA_SRGBHDR, SHADER_FORWARDHDR, TONEMAP_LINEAR } from "../constants.js"; +import { Material } from "../materials/material.js"; +import { getProgramLibrary } from "../shader-lib/get-program-library.js"; + +import { hashCode } from "../../core/hash.js"; +import { ShaderUtils } from "../../platform/graphics/shader-utils.js"; +import { shaderChunks } from "../shader-lib/chunks/chunks.js"; +import { ShaderGenerator } from "../shader-lib/programs/shader-generator.js"; +import { ShaderPass } from "../shader-pass.js"; + +const splatCoreVS = /* glsl */ ` + +uniform mat4 matrix_model; +uniform mat4 matrix_view; +uniform mat4 matrix_projection; + +attribute vec3 vertex_position; +attribute uint vertex_id_attrib; + +varying vec2 texCoord; +varying vec4 color; + +#ifndef DITHER_NONE + varying float id; +#endif + +uniform vec2 viewport; +uniform vec4 bufferWidths; +uniform highp usampler2D splatOrder; +uniform highp usampler2D packedTexture; +uniform highp sampler2D chunkTexture; + +uint splatId; +ivec2 splatUV; +ivec2 chunkUV; + +uvec4 packedData; +vec4 chunkDataA; +vec4 chunkDataB; +vec4 chunkDataC; + +void calcUV() { + int packedWidth = int(bufferWidths.x); + int chunkWidth = int(bufferWidths.y); + + // sample order texture + uint orderId = vertex_id_attrib + uint(vertex_position.z); + ivec2 orderUV = ivec2( + int(orderId) % packedWidth, + int(orderId) / packedWidth + ); + + // calculate splatUV + splatId = texelFetch(splatOrder, orderUV, 0).r; + splatUV = ivec2( + int(splatId) % packedWidth, + int(splatId) / packedWidth + ); + + // calculate chunkUV + int chunkId = int(splatId / 256u); + chunkUV = ivec2( + (chunkId % chunkWidth) * 3, + chunkId / chunkWidth + ); +} + +vec3 unpack111011(uint bits) { + return vec3( + float(bits >> 21u) / 2047.0, + float((bits >> 11u) & 0x3ffu) / 1023.0, + float(bits & 0x7ffu) / 2047.0 + ); +} + +vec4 unpack8888(uint bits) { + return vec4( + float(bits >> 24u) / 255.0, + float((bits >> 16u) & 0xffu) / 255.0, + float((bits >> 8u) & 0xffu) / 255.0, + float(bits & 0xffu) / 255.0 + ); +} + +float norm = 1.0 / (sqrt(2.0) * 0.5); + +vec4 unpackRotation(uint bits) { + float a = (float((bits >> 20u) & 0x3ffu) / 1023.0 - 0.5) * norm; + float b = (float((bits >> 10u) & 0x3ffu) / 1023.0 - 0.5) * norm; + float c = (float(bits & 0x3ffu) / 1023.0 - 0.5) * norm; + float m = sqrt(1.0 - (a * a + b * b + c * c)); + + uint mode = bits >> 30u; + if (mode == 0u) return vec4(m, a, b, c); + if (mode == 1u) return vec4(a, m, b, c); + if (mode == 2u) return vec4(a, b, m, c); + return vec4(a, b, c, m); +} + +vec3 getPosition() { + return mix(chunkDataA.xyz, vec3(chunkDataA.w, chunkDataB.xy), unpack111011(packedData.x)); +} + +vec4 getRotation() { + return unpackRotation(packedData.y); +} + +vec3 getScale() { + return exp(mix(vec3(chunkDataB.zw, chunkDataC.x), chunkDataC.yzw, unpack111011(packedData.z))); +} + +vec4 getColor() { + return unpack8888(packedData.w); +} + +mat3 quatToMat3(vec4 R) { + float x = R.x; + float y = R.y; + float z = R.z; + float w = R.w; + return mat3( + 1.0 - 2.0 * (z * z + w * w), + 2.0 * (y * z + x * w), + 2.0 * (y * w - x * z), + 2.0 * (y * z - x * w), + 1.0 - 2.0 * (y * y + w * w), + 2.0 * (z * w + x * y), + 2.0 * (y * w + x * z), + 2.0 * (z * w - x * y), + 1.0 - 2.0 * (y * y + z * z) + ); +} + +// Given a rotation matrix and scale vector, compute 3d covariance A and B +void calcCov3d(mat3 rot, vec3 scale, out vec3 covA, out vec3 covB) { + // M = S * R + mat3 M = transpose(mat3( + scale.x * rot[0], + scale.y * rot[1], + scale.z * rot[2] + )); + covA = vec3(dot(M[0], M[0]), dot(M[0], M[1]), dot(M[0], M[2])); + covB = vec3(dot(M[1], M[1]), dot(M[1], M[2]), dot(M[2], M[2])); +} + +// given the splat center (view space) and covariance A and B vectors, calculate +// the v1 and v2 vectors for this view. +vec4 calcV1V2(vec3 centerView, vec3 covA, vec3 covB, float focal, mat3 W) { + + mat3 Vrk = mat3( + covA.x, covA.y, covA.z, + covA.y, covB.x, covB.y, + covA.z, covB.y, covB.z + ); + + float J1 = focal / centerView.z; + vec2 J2 = -J1 / centerView.z * centerView.xy; + mat3 J = mat3( + J1, 0.0, J2.x, + 0.0, J1, J2.y, + 0.0, 0.0, 0.0 + ); + + mat3 T = W * J; + mat3 cov = transpose(T) * Vrk * T; + + float diagonal1 = cov[0][0] + 0.3; + float offDiagonal = cov[0][1]; + float diagonal2 = cov[1][1] + 0.3; + + float mid = 0.5 * (diagonal1 + diagonal2); + float radius = length(vec2((diagonal1 - diagonal2) / 2.0, offDiagonal)); + float lambda1 = mid + radius; + float lambda2 = max(mid - radius, 0.1); + vec2 diagonalVector = normalize(vec2(offDiagonal, lambda1 - diagonal1)); + + vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector; + vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x); + + return vec4(v1, v2); +} + +vec4 evalSplat() { + // calculate source UVs + calcUV(); + + // read raw data + packedData = texelFetch(packedTexture, splatUV, 0); + chunkDataA = texelFetch(chunkTexture, chunkUV, 0); + chunkDataB = texelFetch(chunkTexture, ivec2(chunkUV.x + 1, chunkUV.y), 0); + chunkDataC = texelFetch(chunkTexture, ivec2(chunkUV.x + 2, chunkUV.y), 0); + + mat4 modelView = matrix_view * matrix_model; + vec4 centerView = modelView * vec4(getPosition(), 1.0); + vec4 centerClip = matrix_projection * centerView; + + // cull behind camera + if (centerClip.z < -centerClip.w) { + return vec4(0.0, 0.0, 2.0, 1.0); + } + + // calculate the 3d covariance vectors from rotation and scale + vec3 covA, covB; + calcCov3d(quatToMat3(getRotation()), getScale(), covA, covB); + + vec4 v1v2 = calcV1V2(centerView.xyz, covA, covB, viewport.x * matrix_projection[0][0], transpose(mat3(modelView))); + + // early out tiny splats + // TODO: figure out length units and expose as uniform parameter + // TODO: perhaps make this a shader compile-time option + if (dot(v1v2.xy, v1v2.xy) < 4.0 && dot(v1v2.zw, v1v2.zw) < 4.0) { + return vec4(0.0, 0.0, 2.0, 1.0); + } + + texCoord = vertex_position.xy; + color = getColor(); + + #ifndef DITHER_NONE + id = float(splatId); + #endif + + return centerClip + vec4((texCoord.x * v1v2.xy + texCoord.y * v1v2.zw) / viewport * centerClip.w, 0, 0); +} +`; + +const splatCoreFS = /* glsl */ ` + +varying vec2 texCoord; +varying vec4 color; + +#ifndef DITHER_NONE + varying float id; +#endif + +#ifdef PICK_PASS + uniform vec4 uColor; +#endif + +vec4 evalSplat() { + + float A = -dot(texCoord, texCoord); + if (A < -4.0) discard; + float B = exp(A) * color.a; + + #ifdef PICK_PASS + if (B < 0.3) discard; + return(uColor); + #endif + + #ifndef DITHER_NONE + opacityDither(B, id * 0.013); + #endif + + #ifdef TONEMAP_ENABLED + color.rgb = gammaCorrectOutput(toneMap(decodeGamma(color.rgb))); + #endif + + return vec4(color.rgb, B); +} +`; + +class GSplatCompressedShaderGenerator { + generateKey(options) { + const vsHash = hashCode(options.vertex); + const fsHash = hashCode(options.fragment); + return `splat-${options.pass}-${options.gamma}-${options.toneMapping}-${vsHash}-${fsHash}-${options.dither}}`; + } + + createShaderDefinition(device, options) { + + const shaderPassInfo = ShaderPass.get(device).getByIndex(options.pass); + const shaderPassDefines = shaderPassInfo.shaderDefines; + + const defines = + shaderPassDefines + + `#define DITHER_${options.dither.toUpperCase()}\n` + + `#define TONEMAP_${options.toneMapping === TONEMAP_LINEAR ? 'DISABLED' : 'ENABLED'}\n`; + + const vs = defines + splatCoreVS + options.vertex; + const fs = defines + shaderChunks.decodePS + + (options.dither === DITHER_NONE ? '' : shaderChunks.bayerPS + shaderChunks.opacityDitherPS) + + ShaderGenerator.tonemapCode(options.toneMapping) + + ShaderGenerator.gammaCode(options.gamma) + + splatCoreFS + options.fragment; + + return ShaderUtils.createDefinition(device, { + name: 'SplatShader', + attributes: { + vertex_position: SEMANTIC_POSITION, + vertex_id_attrib: SEMANTIC_ATTR13 + }, + vertexCode: vs, + fragmentCode: fs + }); + } +} + +const gsplatCompressed = new GSplatCompressedShaderGenerator(); + +const splatMainVS = ` + void main(void) + { + gl_Position = evalSplat(); + } +`; + +const splatMainFS = ` + void main(void) + { + gl_FragColor = evalSplat(); + } +`; + +/** + * @typedef {object} SplatMaterialOptions - The options. + * @property {string} [vertex] - Custom vertex shader, see SPLAT MANY example. + * @property {string} [fragment] - Custom fragment shader, see SPLAT MANY example. + * @property {string} [dither] - Opacity dithering enum. + */ + +/** + * @param {SplatMaterialOptions} [options] - The options. + * @returns {Material} The GS material. + */ +const createGSplatCompressedMaterial = (options = {}) => { + + const ditherEnum = options.dither ?? DITHER_NONE; + const dither = ditherEnum !== DITHER_NONE; + + const material = new Material(); + material.name = 'compressedSplatMaterial'; + material.cull = CULLFACE_NONE; + material.blendType = dither ? BLEND_NONE : BLEND_NORMAL; + material.depthWrite = dither; + + material.getShaderVariant = function (device, scene, defs, unused, pass, sortedLights, viewUniformFormat, viewBindGroupFormat) { + + const programOptions = { + pass: pass, + gamma: (pass === SHADER_FORWARDHDR ? (scene.gammaCorrection ? GAMMA_SRGBHDR : GAMMA_NONE) : scene.gammaCorrection), + toneMapping: (pass === SHADER_FORWARDHDR ? TONEMAP_LINEAR : scene.toneMapping), + vertex: options.vertex ?? splatMainVS, + fragment: options.fragment ?? splatMainFS, + dither: ditherEnum + }; + + const processingOptions = new ShaderProcessorOptions(viewUniformFormat, viewBindGroupFormat); + + const library = getProgramLibrary(device); + library.register('splat-compressed', gsplatCompressed); + return library.getProgram('splat-compressed', programOptions, processingOptions); + }; + + material.update(); + + return material; +}; + +export { createGSplatCompressedMaterial }; diff --git a/src/scene/gsplat/gsplat-compressed.js b/src/scene/gsplat/gsplat-compressed.js new file mode 100644 index 00000000000..000b41fcf01 --- /dev/null +++ b/src/scene/gsplat/gsplat-compressed.js @@ -0,0 +1,155 @@ +import { Vec2 } from '../../core/math/vec2.js'; +import { + ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_RGBA32F, PIXELFORMAT_RGBA32U +} from '../../platform/graphics/constants.js'; +import { Texture } from '../../platform/graphics/texture.js'; +import { BoundingBox } from '../../core/shape/bounding-box.js'; +import { createGSplatCompressedMaterial } from './gsplat-compressed-material.js'; + +/** @ignore */ +class GSplatCompressed { + device; + + numSplats; + + /** @type {import('../../core/shape/bounding-box.js').BoundingBox} */ + aabb; + + /** @type {Float32Array} */ + centers; + + /** @type {Texture} */ + packedTexture; + + /** @type {Texture} */ + chunkTexture; + + /** + * @param {import('../../platform/graphics/graphics-device.js').GraphicsDevice} device - The graphics device. + * @param {import('./gsplat-data.js').GSplatData} gsplatData - The splat data. + */ + constructor(device, gsplatData) { + const numSplats = gsplatData.numSplats; + const numChunks = Math.ceil(numSplats / 256); + + this.device = device; + this.numSplats = numSplats; + + // initialize aabb + this.aabb = new BoundingBox(); + gsplatData.calcAabb(this.aabb); + + // initialize centers + this.centers = new Float32Array(gsplatData.numSplats * 3); + gsplatData.getCenters(this.centers); + + // initialize packed data + this.packedTexture = this.createTexture('packedData', PIXELFORMAT_RGBA32U, this.evalTextureSize(numSplats)); + + const position = gsplatData.getProp('packed_position'); + const rotation = gsplatData.getProp('packed_rotation'); + const scale = gsplatData.getProp('packed_scale'); + const color = gsplatData.getProp('packed_color'); + + const packedData = this.packedTexture.lock(); + for (let i = 0; i < numSplats; ++i) { + packedData[i * 4 + 0] = position[i]; + packedData[i * 4 + 1] = rotation[i]; + packedData[i * 4 + 2] = scale[i]; + packedData[i * 4 + 3] = color[i]; + } + this.packedTexture.unlock(); + + // initialize chunk data + const chunkSize = this.evalTextureSize(numChunks); + chunkSize.x *= 3; + + this.chunkTexture = this.createTexture('chunkData', PIXELFORMAT_RGBA32F, chunkSize); + + const minX = gsplatData.getProp('min_x', 'chunk'); + const minY = gsplatData.getProp('min_y', 'chunk'); + const minZ = gsplatData.getProp('min_z', 'chunk'); + const maxX = gsplatData.getProp('max_x', 'chunk'); + const maxY = gsplatData.getProp('max_y', 'chunk'); + const maxZ = gsplatData.getProp('max_z', 'chunk'); + const minScaleX = gsplatData.getProp('min_scale_x', 'chunk'); + const minScaleY = gsplatData.getProp('min_scale_y', 'chunk'); + const minScaleZ = gsplatData.getProp('min_scale_z', 'chunk'); + const maxScaleX = gsplatData.getProp('max_scale_x', 'chunk'); + const maxScaleY = gsplatData.getProp('max_scale_y', 'chunk'); + const maxScaleZ = gsplatData.getProp('max_scale_z', 'chunk'); + + const chunkData = this.chunkTexture.lock(); + for (let i = 0; i < numChunks; ++i) { + chunkData[i * 12 + 0] = minX[i]; + chunkData[i * 12 + 1] = minY[i]; + chunkData[i * 12 + 2] = minZ[i]; + chunkData[i * 12 + 3] = maxX[i]; + chunkData[i * 12 + 4] = maxY[i]; + chunkData[i * 12 + 5] = maxZ[i]; + chunkData[i * 12 + 6] = minScaleX[i]; + chunkData[i * 12 + 7] = minScaleY[i]; + chunkData[i * 12 + 8] = minScaleZ[i]; + chunkData[i * 12 + 9] = maxScaleX[i]; + chunkData[i * 12 + 10] = maxScaleY[i]; + chunkData[i * 12 + 11] = maxScaleZ[i]; + } + this.chunkTexture.unlock(); + } + + destroy() { + this.packedTexture?.destroy(); + this.chunkTexture?.destroy(); + } + + /** + * @returns {import('../materials/material.js').Material} material - The material to set up for + * the splat rendering. + */ + createMaterial(options) { + const result = createGSplatCompressedMaterial(options); + result.setParameter('packedTexture', this.packedTexture); + result.setParameter('chunkTexture', this.chunkTexture); + result.setParameter('bufferWidths', new Float32Array([this.packedTexture.width, this.chunkTexture.width / 3, 0, 0])); + return result; + } + + /** + * Evaluates the texture size needed to store a given number of elements. + * The function calculates a width and height that is close to a square + * that can contain 'count' elements. + * + * @param {number} count - The number of elements to store in the texture. + * @returns {Vec2} The width and height of the texture. + */ + evalTextureSize(count) { + const width = Math.ceil(Math.sqrt(count)); + const height = Math.ceil(count / width); + return new Vec2(width, height); + } + + /** + * Creates a new texture with the specified parameters. + * + * @param {string} name - The name of the texture to be created. + * @param {number} format - The pixel format of the texture. + * @param {Vec2} size - The width and height of the texture. + * @returns {Texture} The created texture instance. + */ + createTexture(name, format, size) { + return new Texture(this.device, { + name: name, + width: size.x, + height: size.y, + format: format, + cubemap: false, + mipmaps: false, + minFilter: FILTER_NEAREST, + magFilter: FILTER_NEAREST, + addressU: ADDRESS_CLAMP_TO_EDGE, + addressV: ADDRESS_CLAMP_TO_EDGE + }); + } +} + +export { GSplatCompressed }; diff --git a/src/scene/gsplat/gsplat-data.js b/src/scene/gsplat/gsplat-data.js index a01e72f1a3c..6fd6d412bf0 100644 --- a/src/scene/gsplat/gsplat-data.js +++ b/src/scene/gsplat/gsplat-data.js @@ -12,70 +12,173 @@ const quat2 = new Quat(); const aabb = new BoundingBox(); const aabb2 = new BoundingBox(); -const debugPoints = [new Vec3(), new Vec3(), new Vec3(), new Vec3(), new Vec3(), new Vec3(), new Vec3(), new Vec3()]; -const debugLines = [ - debugPoints[0], debugPoints[1], debugPoints[1], debugPoints[3], debugPoints[3], debugPoints[2], debugPoints[2], debugPoints[0], - debugPoints[4], debugPoints[5], debugPoints[5], debugPoints[7], debugPoints[7], debugPoints[6], debugPoints[6], debugPoints[4], - debugPoints[0], debugPoints[4], debugPoints[1], debugPoints[5], debugPoints[2], debugPoints[6], debugPoints[3], debugPoints[7] -]; const debugColor = new Color(1, 1, 0, 0.4); +const SH_C0 = 0.28209479177387814; -/** - * Defines the shape of a SplatTRS. - * @typedef {object} SplatTRS - Represents a splat object with position, rotation, and scale. - * @property {number} x - The x-coordinate of the position. - * @property {number} y - The y-coordinate of the position. - * @property {number} z - The z-coordinate of the position. - * @property {number} rx - The x-component of the quaternion rotation. - * @property {number} ry - The y-component of the quaternion rotation. - * @property {number} rz - The z-component of the quaternion rotation. - * @property {number} rw - The w-component of the quaternion rotation. - * @property {number} sx - The scale factor in the x-direction. - * @property {number} sy - The scale factor in the y-direction. - * @property {number} sz - The scale factor in the z-direction. - */ +// iterator for accessing compressed splat data +class SplatCompressedIterator { + constructor(gsplatData, p, r, s, c) { + const unpackUnorm = (value, bits) => { + const t = (1 << bits) - 1; + return (value & t) / t; + }; + + const unpack111011 = (result, value) => { + result.x = unpackUnorm(value >>> 21, 11); + result.y = unpackUnorm(value >>> 11, 10); + result.z = unpackUnorm(value, 11); + }; + + const unpack8888 = (result, value) => { + result.x = unpackUnorm(value >>> 24, 8); + result.y = unpackUnorm(value >>> 16, 8); + result.z = unpackUnorm(value >>> 8, 8); + result.w = unpackUnorm(value, 8); + }; + + // unpack quaternion with 2,10,10,10 format (largest element, 3x10bit element) + const unpackRot = (result, value) => { + const norm = 1.0 / (Math.sqrt(2) * 0.5); + const a = (unpackUnorm(value >>> 20, 10) - 0.5) * norm; + const b = (unpackUnorm(value >>> 10, 10) - 0.5) * norm; + const c = (unpackUnorm(value, 10) - 0.5) * norm; + const m = Math.sqrt(1.0 - (a * a + b * b + c * c)); + + switch (value >>> 30) { + case 0: result.set(m, a, b, c); break; + case 1: result.set(a, m, b, c); break; + case 2: result.set(a, b, m, c); break; + case 3: result.set(a, b, c, m); break; + } + }; + + const lerp = (a, b, t) => a * (1 - t) + b * t; + + const min_x = gsplatData.getProp('min_x', 'chunk'); + const min_y = gsplatData.getProp('min_y', 'chunk'); + const min_z = gsplatData.getProp('min_z', 'chunk'); + const max_x = gsplatData.getProp('max_x', 'chunk'); + const max_y = gsplatData.getProp('max_y', 'chunk'); + const max_z = gsplatData.getProp('max_z', 'chunk'); + const min_scale_x = gsplatData.getProp('min_scale_x', 'chunk'); + const min_scale_y = gsplatData.getProp('min_scale_y', 'chunk'); + const min_scale_z = gsplatData.getProp('min_scale_z', 'chunk'); + const max_scale_x = gsplatData.getProp('max_scale_x', 'chunk'); + const max_scale_y = gsplatData.getProp('max_scale_y', 'chunk'); + const max_scale_z = gsplatData.getProp('max_scale_z', 'chunk'); + + const position = gsplatData.getProp('packed_position'); + const rotation = gsplatData.getProp('packed_rotation'); + const scale = gsplatData.getProp('packed_scale'); + const color = gsplatData.getProp('packed_color'); + + this.read = (i) => { + const ci = Math.floor(i / 256); + + if (p) { + unpack111011(p, position[i]); + p.x = lerp(min_x[ci], max_x[ci], p.x); + p.y = lerp(min_y[ci], max_y[ci], p.y); + p.z = lerp(min_z[ci], max_z[ci], p.z); + } + + if (r) { + unpackRot(r, rotation[i]); + } + + if (s) { + unpack111011(s, scale[i]); + s.x = lerp(min_scale_x[ci], max_scale_x[ci], s.x); + s.y = lerp(min_scale_y[ci], max_scale_y[ci], s.y); + s.z = lerp(min_scale_z[ci], max_scale_z[ci], s.z); + } + + if (c) { + unpack8888(c, color[i]); + } + }; + } +} + +// iterator for accessing uncompressed splat data +class SplatIterator { + constructor(gsplatData, p, r, s, c) { + const x = gsplatData.getProp('x'); + const y = gsplatData.getProp('y'); + const z = gsplatData.getProp('z'); + + const rx = gsplatData.getProp('rot_1'); + const ry = gsplatData.getProp('rot_2'); + const rz = gsplatData.getProp('rot_3'); + const rw = gsplatData.getProp('rot_0'); + + const sx = gsplatData.getProp('scale_0'); + const sy = gsplatData.getProp('scale_1'); + const sz = gsplatData.getProp('scale_2'); + + const cr = gsplatData.getProp('f_dc_0'); + const cg = gsplatData.getProp('f_dc_1'); + const cb = gsplatData.getProp('f_dc_2'); + const ca = gsplatData.getProp('opacity'); + + /** + * Calculates the sigmoid of a given value. + * + * @param {number} v - The value for which to compute the sigmoid function. + * @returns {number} The result of the sigmoid function. + */ + const sigmoid = (v) => { + if (v > 0) { + return 1 / (1 + Math.exp(-v)); + } + + const t = Math.exp(v); + return t / (1 + t); + }; + + this.read = (i) => { + if (p) { + p.x = x[i]; + p.y = y[i]; + p.z = z[i]; + } + + if (r) { + r.set(rx[i], ry[i], rz[i], rw[i]); + } + + if (s) { + s.set(Math.exp(sx[i]), Math.exp(sy[i]), Math.exp(sz[i])); + } + + if (c) { + c.set( + 0.5 + cr[i] * SH_C0, + 0.5 + cg[i] * SH_C0, + 0.5 + cb[i] * SH_C0, + sigmoid(ca[i]) + ); + } + }; + } +} /** + * Calculate a splat orientation matrix from its position and rotation. * @param {Mat4} result - Mat4 instance holding calculated rotation matrix. - * @param {SplatTRS} data - The splat TRS object. + * @param {Vec3} p - The splat position + * @param {Quat} r - The splat rotation */ -const calcSplatMat = (result, data) => { - const px = data.x; - const py = data.y; - const pz = data.z; - const d = Math.sqrt(data.rx * data.rx + data.ry * data.ry + data.rz * data.rz + data.rw * data.rw); - const x = data.rx / d; - const y = data.ry / d; - const z = data.rz / d; - const w = data.rw / d; - - // build rotation matrix - result.data.set([ - 1.0 - 2.0 * (z * z + w * w), - 2.0 * (y * z + x * w), - 2.0 * (y * w - x * z), - 0, - - 2.0 * (y * z - x * w), - 1.0 - 2.0 * (y * y + w * w), - 2.0 * (z * w + x * y), - 0, - - 2.0 * (y * w + x * z), - 2.0 * (z * w - x * y), - 1.0 - 2.0 * (y * y + z * z), - 0, - - px, py, pz, 1 - ]); +const calcSplatMat = (result, p, r) => { + quat.set(r.x, r.y, r.z, r.w).normalize(); + result.setTRS(p, quat, Vec3.ONE); }; class GSplatData { // /** @type {import('./ply-reader').PlyElement[]} */ elements; - // /** @type {import('./ply-reader').PlyElement} */ - vertexElement; + numSplats; // /** // * @param {import('./ply-reader').PlyElement[]} elements - The elements. @@ -86,7 +189,8 @@ class GSplatData { // */ constructor(elements, options = {}) { this.elements = elements; - this.vertexElement = elements.find(element => element.name === 'vertex'); + + this.numSplats = this.getElement('vertex').count; if (!this.isCompressed) { if (options.performZScale ?? true) { @@ -102,18 +206,16 @@ class GSplatData { } } - get numSplats() { - return this.vertexElement.count; - } - /** * @param {BoundingBox} result - Bounding box instance holding calculated result. - * @param {SplatTRS} data - The splat TRS object. + * @param {Vec3} p - The splat position + * @param {Quat} r - The splat rotation + * @param {Vec3} s - The splat scale */ - static calcSplatAabb(result, data) { - calcSplatMat(mat4, data); + static calcSplatAabb(result, p, r, s) { + calcSplatMat(mat4, p, r); aabb.center.set(0, 0, 0); - aabb.halfExtents.set(data.sx * 2, data.sy * 2, data.sz * 2); + aabb.halfExtents.set(s.x * 2, s.y * 2, s.z * 2); result.setFromTransformedAabb(aabb, mat4); } @@ -121,8 +223,13 @@ class GSplatData { * Transform splat data by the given matrix. * * @param {Mat4} mat - The matrix. + * @returns {boolean} True if the transformation was successful, false if the data is compressed. */ transform(mat) { + if (this.isCompressed) { + return false; + } + const x = this.getProp('x'); const y = this.getProp('y'); const z = this.getProp('z'); @@ -151,16 +258,23 @@ class GSplatData { // TODO: transform SH } + + return true; } // access a named property - getProp(name) { - return this.vertexElement.properties.find(property => property.name === name && property.storage)?.storage; + getProp(name, elementName = 'vertex') { + return this.getElement(elementName)?.properties.find(p => p.name === name)?.storage; + } + + // access the named element + getElement(name) { + return this.elements.find(e => e.name === name); } // add a new property addProp(name, storage) { - this.vertexElement.properties.push({ + this.getElement('vertex').properties.push({ type: 'float', name, storage, @@ -168,24 +282,113 @@ class GSplatData { }); } - // calculate scene aabb taking into account splat size + /** + * Create an iterator for accessing splat data + * + * @param {Vec3|null} [p] - the vector to receive splat position + * @param {Quat|null} [r] - the quaternion to receive splat rotation + * @param {Vec3|null} [s] - the vector to receive splat scale + * @param {Vec4|null} [c] - the vector to receive splat color + * @returns {SplatIterator | SplatCompressedIterator} - The iterator + */ + createIter(p, r, s, c) { + return this.isCompressed ? new SplatCompressedIterator(this, p, r, s, c) : new SplatIterator(this, p, r, s, c); + } + + // calculafte a pessimistic aabb, which is faster than calculating an exact aabb calcAabb(result, pred) { - const x = this.getProp('x'); - const y = this.getProp('y'); - const z = this.getProp('z'); + let mx, my, mz, Mx, My, Mz; + let first = true; - const rx = this.getProp('rot_0'); - const ry = this.getProp('rot_1'); - const rz = this.getProp('rot_2'); - const rw = this.getProp('rot_3'); + if (this.isCompressed && !pred && this.numSplats) { + // fast bounds calc using chunk data + const numChunks = Math.ceil(this.numSplats / 256); + + const min_x = this.getProp('min_x', 'chunk'); + const min_y = this.getProp('min_y', 'chunk'); + const min_z = this.getProp('min_z', 'chunk'); + const max_x = this.getProp('max_x', 'chunk'); + const max_y = this.getProp('max_y', 'chunk'); + const max_z = this.getProp('max_z', 'chunk'); + const max_scale_x = this.getProp('max_scale_x', 'chunk'); + const max_scale_y = this.getProp('max_scale_y', 'chunk'); + const max_scale_z = this.getProp('max_scale_z', 'chunk'); + + let s = Math.exp(Math.max(max_scale_x[0], max_scale_y[0], max_scale_z[0])); + mx = min_x[0] - s; + my = min_y[0] - s; + mz = min_z[0] - s; + Mx = max_x[0] + s; + My = max_y[0] + s; + Mz = max_z[0] + s; + + for (let i = 1; i < numChunks; ++i) { + s = Math.exp(Math.max(max_scale_x[i], max_scale_y[i], max_scale_z[i])); + mx = Math.min(mx, min_x[i] - s); + my = Math.min(my, min_y[i] - s); + mz = Math.min(mz, min_z[i] - s); + Mx = Math.max(Mx, max_x[i] + s); + My = Math.max(My, max_y[i] + s); + Mz = Math.max(Mz, max_z[i] + s); + } - const sx = this.getProp('scale_0'); - const sy = this.getProp('scale_1'); - const sz = this.getProp('scale_2'); + first = false; + } else { + const p = new Vec3(); + const s = new Vec3(); + + const iter = this.createIter(p, null, s); + + for (let i = 0; i < this.numSplats; ++i) { + if (pred && !pred(i)) { + continue; + } + + iter.read(i); + + const scaleVal = 2.0 * Math.max(s.x, s.y, s.z); + + if (first) { + first = false; + mx = p.x - scaleVal; + my = p.y - scaleVal; + mz = p.z - scaleVal; + Mx = p.x + scaleVal; + My = p.y + scaleVal; + Mz = p.z + scaleVal; + } else { + mx = Math.min(mx, p.x - scaleVal); + my = Math.min(my, p.y - scaleVal); + mz = Math.min(mz, p.z - scaleVal); + Mx = Math.max(Mx, p.x + scaleVal); + My = Math.max(My, p.y + scaleVal); + Mz = Math.max(Mz, p.z + scaleVal); + } + } + } - const splat = { - x: 0, y: 0, z: 0, rx: 0, ry: 0, rz: 0, rw: 0, sx: 0, sy: 0, sz: 0 - }; + if (!first) { + result.center.set((mx + Mx) * 0.5, (my + My) * 0.5, (mz + Mz) * 0.5); + result.halfExtents.set((Mx - mx) * 0.5, (My - my) * 0.5, (Mz - mz) * 0.5); + } + + return !first; + } + + /** + * Calculate exact scene aabb taking into account splat size + * + * @param {BoundingBox} result - Where to store the resulting bounding box. + * @param {(i) => boolean} [pred] - Optional predicate function to filter splats. + * @returns {boolean} - Whether the calculation was successful. + */ + calcAabbExact(result, pred) { + + const p = new Vec3(); + const r = new Quat(); + const s = new Vec3(); + + const iter = this.createIter(p, r, s); let first = true; @@ -194,22 +397,13 @@ class GSplatData { continue; } - splat.x = x[i]; - splat.y = y[i]; - splat.z = z[i]; - splat.rx = rx[i]; - splat.ry = ry[i]; - splat.rz = rz[i]; - splat.rw = rw[i]; - splat.sx = Math.exp(sx[i]); - splat.sy = Math.exp(sy[i]); - splat.sz = Math.exp(sz[i]); + iter.read(i); if (first) { first = false; - GSplatData.calcSplatAabb(result, splat); + GSplatData.calcSplatAabb(result, p, r, s); } else { - GSplatData.calcSplatAabb(aabb2, splat); + GSplatData.calcSplatAabb(aabb2, p, r, s); result.add(aabb2); } } @@ -217,18 +411,65 @@ class GSplatData { return !first; } + /** + * @param {Float32Array} result - Array containing the centers. + */ + getCenters(result) { + if (this.isCompressed) { + // optimised centers extraction for centers + const position = this.getProp('packed_position'); + const min_x = this.getProp('min_x', 'chunk'); + const min_y = this.getProp('min_y', 'chunk'); + const min_z = this.getProp('min_z', 'chunk'); + const max_x = this.getProp('max_x', 'chunk'); + const max_y = this.getProp('max_y', 'chunk'); + const max_z = this.getProp('max_z', 'chunk'); + + const numChunks = Math.ceil(this.numSplats / 256); + + let mx, my, mz, Mx, My, Mz; + + for (let c = 0; c < numChunks; ++c) { + mx = min_x[c]; + my = min_y[c]; + mz = min_z[c]; + Mx = max_x[c]; + My = max_y[c]; + Mz = max_z[c]; + + const end = Math.min(this.numSplats, (c + 1) * 256); + for (let i = c * 256; i < end; ++i) { + const p = position[i]; + const px = (p >>> 21) / 2047; + const py = ((p >>> 11) & 0x3ff) / 1023; + const pz = (p & 0x7ff) / 2047; + result[i * 3 + 0] = (1 - px) * mx + px * Mx; + result[i * 3 + 1] = (1 - py) * my + py * My; + result[i * 3 + 2] = (1 - pz) * mz + pz * Mz; + } + } + } else { + const p = new Vec3(); + const iter = this.createIter(p); + + for (let i = 0; i < this.numSplats; ++i) { + iter.read(i); + + result[i * 3 + 0] = p.x; + result[i * 3 + 1] = p.y; + result[i * 3 + 2] = p.z; + } + } + } + /** * @param {Vec3} result - The result. * @param {Function} pred - Predicate given index for skipping. */ calcFocalPoint(result, pred) { - const x = this.getProp('x'); - const y = this.getProp('y'); - const z = this.getProp('z'); - - const sx = this.getProp('scale_0'); - const sy = this.getProp('scale_1'); - const sz = this.getProp('scale_2'); + const p = new Vec3(); + const s = new Vec3(); + const iter = this.createIter(p, null, s, null); result.x = 0; result.y = 0; @@ -239,10 +480,13 @@ class GSplatData { if (pred && !pred(i)) { continue; } - const weight = 1.0 / (1.0 + Math.exp(Math.max(sx[i], sy[i], sz[i]))); - result.x += x[i] * weight; - result.y += y[i] * weight; - result.z += z[i] * weight; + + iter.read(i); + + const weight = 1.0 / (1.0 + Math.max(s.x, s.y, s.z)); + result.x += p.x * weight; + result.y += p.y * weight; + result.z += p.z * weight; sum += weight; } result.mulScalar(1 / sum); @@ -253,48 +497,26 @@ class GSplatData { * @param {Mat4} worldMat - The world matrix. */ renderWireframeBounds(scene, worldMat) { - const x = this.getProp('x'); - const y = this.getProp('y'); - const z = this.getProp('z'); - - const rx = this.getProp('rot_0'); - const ry = this.getProp('rot_1'); - const rz = this.getProp('rot_2'); - const rw = this.getProp('rot_3'); + const p = new Vec3(); + const r = new Quat(); + const s = new Vec3(); - const sx = this.getProp('scale_0'); - const sy = this.getProp('scale_1'); - const sz = this.getProp('scale_2'); + const min = new Vec3(); + const max = new Vec3(); - const splat = { - x: 0, y: 0, z: 0, rx: 0, ry: 0, rz: 0, rw: 0, sx: 0, sy: 0, sz: 0 - }; + const iter = this.createIter(p, r, s); for (let i = 0; i < this.numSplats; ++i) { - splat.x = x[i]; - splat.y = y[i]; - splat.z = z[i]; - splat.rx = rx[i]; - splat.ry = ry[i]; - splat.rz = rz[i]; - splat.rw = rw[i]; - splat.sx = Math.exp(sx[i]); - splat.sy = Math.exp(sy[i]); - splat.sz = Math.exp(sz[i]); - - calcSplatMat(mat4, splat); + iter.read(i); + + calcSplatMat(mat4, p, r); mat4.mul2(worldMat, mat4); - for (let j = 0; j < 8; ++j) { - vec3.set( - splat.sx * 2 * ((j & 1) ? 1 : -1), - splat.sy * 2 * ((j & 2) ? 1 : -1), - splat.sz * 2 * ((j & 4) ? 1 : -1) - ); - mat4.transformPoint(vec3, debugPoints[j]); - } + min.set(s.x * -2.0, s.y * -2.0, s.z * -2.0); + max.set(s.x * 2.0, s.y * 2.0, s.z * 2.0); - scene.drawLineArrays(debugLines, debugColor); + // @ts-ignore + scene.immediate.drawWireAlignedBox(min, max, debugColor, true, scene.defaultDrawLayer, mat4); } } @@ -304,100 +526,38 @@ class GSplatData { ['packed_position', 'packed_rotation', 'packed_scale', 'packed_color'].every(name => this.getProp(name)); } + // decompress data into uncompressed splat format and return a new GSplatData instance decompress() { const members = ['x', 'y', 'z', 'f_dc_0', 'f_dc_1', 'f_dc_2', 'opacity', 'rot_0', 'rot_1', 'rot_2', 'rot_3', 'scale_0', 'scale_1', 'scale_2']; - const chunks = this.elements.find(e => e.name === 'chunk'); - const vertices = this.vertexElement; // allocate uncompressed data const data = {}; members.forEach((name) => { - data[name] = new Float32Array(vertices.count); + data[name] = new Float32Array(this.numSplats); }); - const getChunkProp = (name) => { - return chunks.properties.find(p => p.name === name && p.storage)?.storage; - }; - - const min_x = getChunkProp('min_x'); - const min_y = getChunkProp('min_y'); - const min_z = getChunkProp('min_z'); - const max_x = getChunkProp('max_x'); - const max_y = getChunkProp('max_y'); - const max_z = getChunkProp('max_z'); - const min_scale_x = getChunkProp('min_scale_x'); - const min_scale_y = getChunkProp('min_scale_y'); - const min_scale_z = getChunkProp('min_scale_z'); - const max_scale_x = getChunkProp('max_scale_x'); - const max_scale_y = getChunkProp('max_scale_y'); - const max_scale_z = getChunkProp('max_scale_z'); - - const position = this.getProp('packed_position'); - const rotation = this.getProp('packed_rotation'); - const scale = this.getProp('packed_scale'); - const color = this.getProp('packed_color'); - - const unpackUnorm = (value, bits) => { - const t = (1 << bits) - 1; - return (value & t) / t; - }; - - const unpack111011 = (result, value) => { - result.x = unpackUnorm(value >>> 21, 11); - result.y = unpackUnorm(value >>> 11, 10); - result.z = unpackUnorm(value, 11); - }; - - const unpack8888 = (result, value) => { - result.x = unpackUnorm(value >>> 24, 8); - result.y = unpackUnorm(value >>> 16, 8); - result.z = unpackUnorm(value >>> 8, 8); - result.w = unpackUnorm(value, 8); - }; - - // unpack quaternion with 2,10,10,10 format (largest element, 3x10bit element) - const unpackRot = (result, value) => { - const norm = 1.0 / (Math.sqrt(2) * 0.5); - const a = (unpackUnorm(value >>> 20, 10) - 0.5) * norm; - const b = (unpackUnorm(value >>> 10, 10) - 0.5) * norm; - const c = (unpackUnorm(value, 10) - 0.5) * norm; - const m = Math.sqrt(1.0 - (a * a + b * b + c * c)); - - switch (value >>> 30) { - case 0: result.set(m, a, b, c); break; - case 1: result.set(a, m, b, c); break; - case 2: result.set(a, b, m, c); break; - case 3: result.set(a, b, c, m); break; - } - }; - - const lerp = (a, b, t) => a * (1 - t) + b * t; - const p = new Vec3(); const r = new Quat(); const s = new Vec3(); const c = new Vec4(); - for (let i = 0; i < vertices.count; ++i) { - const ci = Math.floor(i / 256); + const iter = this.createIter(p, r, s, c); - unpack111011(p, position[i]); - unpackRot(r, rotation[i]); - unpack111011(s, scale[i]); - unpack8888(c, color[i]); + for (let i = 0; i < this.numSplats; ++i) { + iter.read(i); - data.x[i] = lerp(min_x[ci], max_x[ci], p.x); - data.y[i] = lerp(min_y[ci], max_y[ci], p.y); - data.z[i] = lerp(min_z[ci], max_z[ci], p.z); + data.x[i] = p.x; + data.y[i] = p.y; + data.z[i] = p.z; data.rot_0[i] = r.x; data.rot_1[i] = r.y; data.rot_2[i] = r.z; data.rot_3[i] = r.w; - data.scale_0[i] = lerp(min_scale_x[ci], max_scale_x[ci], s.x); - data.scale_1[i] = lerp(min_scale_y[ci], max_scale_y[ci], s.y); - data.scale_2[i] = lerp(min_scale_z[ci], max_scale_z[ci], s.z); + data.scale_0[i] = s.x; + data.scale_1[i] = s.y; + data.scale_2[i] = s.z; const SH_C0 = 0.28209479177387814; data.f_dc_0[i] = (c.x - 0.5) / SH_C0; @@ -408,7 +568,7 @@ class GSplatData { return new GSplatData([{ name: 'vertex', - count: vertices.count, + count: this.numSplats, properties: members.map((name) => { return { name: name, @@ -479,8 +639,8 @@ class GSplatData { return indices; } + // reorder the splat data to aid in better gpu memory access at render time reorderData() { - // calculate splat morton order const order = this.calcMortonOrder(); const reorder = (data) => { diff --git a/src/scene/gsplat/gsplat-instance.js b/src/scene/gsplat/gsplat-instance.js index b6da74aa0e0..87f64c5513f 100644 --- a/src/scene/gsplat/gsplat-instance.js +++ b/src/scene/gsplat/gsplat-instance.js @@ -4,7 +4,6 @@ import { BUFFER_STATIC, PIXELFORMAT_R32U, SEMANTIC_ATTR13, TYPE_UINT32 } from '. import { DITHER_NONE } from '../constants.js'; import { MeshInstance } from '../mesh-instance.js'; import { Mesh } from '../mesh.js'; -import { createGSplatMaterial } from './gsplat-material.js'; import { GSplatSorter } from './gsplat-sorter.js'; import { VertexFormat } from '../../platform/graphics/vertex-format.js'; import { VertexBuffer } from '../../platform/graphics/vertex-buffer.js'; @@ -149,9 +148,8 @@ class GSplatInstance { } createMaterial(options) { - this.material = createGSplatMaterial(options); + this.material = this.splat.createMaterial(options); this.material.setParameter('splatOrder', this.orderTexture); - this.splat.setupMaterial(this.material); if (this.meshInstance) { this.meshInstance.material = this.material; } diff --git a/src/scene/gsplat/gsplat-material.js b/src/scene/gsplat/gsplat-material.js index 23ae17e9237..9a12632cb6a 100644 --- a/src/scene/gsplat/gsplat-material.js +++ b/src/scene/gsplat/gsplat-material.js @@ -1,4 +1,4 @@ -import { CULLFACE_BACK, CULLFACE_NONE } from "../../platform/graphics/constants.js"; +import { CULLFACE_NONE } from "../../platform/graphics/constants.js"; import { ShaderProcessorOptions } from "../../platform/graphics/shader-processor-options.js"; import { BLEND_NONE, BLEND_NORMAL, DITHER_NONE, GAMMA_NONE, GAMMA_SRGBHDR, SHADER_FORWARDHDR, TONEMAP_LINEAR } from "../constants.js"; import { Material } from "../materials/material.js"; @@ -24,7 +24,6 @@ const splatMainFS = ` /** * @typedef {object} SplatMaterialOptions - The options. - * @property {boolean} [debugRender] - Adds #define DEBUG_RENDER for shader. * @property {string} [vertex] - Custom vertex shader, see SPLAT MANY example. * @property {string} [fragment] - Custom fragment shader, see SPLAT MANY example. * @property {string} [dither] - Opacity dithering enum. @@ -36,14 +35,12 @@ const splatMainFS = ` */ const createGSplatMaterial = (options = {}) => { - const { debugRender } = options; - const ditherEnum = options.dither ?? DITHER_NONE; const dither = ditherEnum !== DITHER_NONE; const material = new Material(); material.name = 'splatMaterial'; - material.cull = debugRender ? CULLFACE_BACK : CULLFACE_NONE; + material.cull = CULLFACE_NONE; material.blendType = dither ? BLEND_NONE : BLEND_NORMAL; material.depthWrite = dither; @@ -55,7 +52,6 @@ const createGSplatMaterial = (options = {}) => { toneMapping: (pass === SHADER_FORWARDHDR ? TONEMAP_LINEAR : scene.toneMapping), vertex: options.vertex ?? splatMainVS, fragment: options.fragment ?? splatMainFS, - debugRender: debugRender, dither: ditherEnum }; diff --git a/src/scene/gsplat/gsplat.js b/src/scene/gsplat/gsplat.js index 0d2c83033bd..80ea7dc3836 100644 --- a/src/scene/gsplat/gsplat.js +++ b/src/scene/gsplat/gsplat.js @@ -2,13 +2,16 @@ import { FloatPacking } from '../../core/math/float-packing.js'; import { math } from '../../core/math/math.js'; import { Quat } from '../../core/math/quat.js'; import { Vec2 } from '../../core/math/vec2.js'; +import { Vec3 } from '../../core/math/vec3.js'; +import { Vec4 } from '../../core/math/vec4.js'; import { Mat3 } from '../../core/math/mat3.js'; import { ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_R16F, PIXELFORMAT_RGBA16F, PIXELFORMAT_RGBA32F, PIXELFORMAT_RGBA8 } from '../../platform/graphics/constants.js'; import { Texture } from '../../platform/graphics/texture.js'; -import { Vec3 } from '../../core/math/vec3.js'; +import { BoundingBox } from '../../core/shape/bounding-box.js'; +import { createGSplatMaterial } from './gsplat-material.js'; const _tmpVecA = new Vec3(); const _tmpVecB = new Vec3(); @@ -16,7 +19,6 @@ const _tmpVecC = new Vec3(); const _m0 = new Vec3(); const _m1 = new Vec3(); const _m2 = new Vec3(); -const _s = new Vec3(); /** @ignore */ class GSplat { @@ -24,6 +26,12 @@ class GSplat { numSplats; + /** @type {Float32Array} */ + centers; + + /** @type {import('../../core/shape/bounding-box.js').BoundingBox} */ + aabb; + /** @type {Texture} */ colorTexture; @@ -36,27 +44,31 @@ class GSplat { /** @type {Texture} */ transformCTexture; - /** @type {Float32Array} */ - centers; - - /** @type {import('../../core/shape/bounding-box.js').BoundingBox} */ - aabb; - /** * @param {import('../../platform/graphics/graphics-device.js').GraphicsDevice} device - The graphics device. - * @param {number} numSplats - Number of splats. - * @param {import('../../core/shape/bounding-box.js').BoundingBox} aabb - The bounding box. + * @param {import('./gsplat-data.js').GSplatData} gsplatData - The splat data. */ - constructor(device, numSplats, aabb) { + constructor(device, gsplatData) { + const numSplats = gsplatData.numSplats; + this.device = device; this.numSplats = numSplats; - this.aabb = aabb; + + this.centers = new Float32Array(gsplatData.numSplats * 3); + gsplatData.getCenters(this.centers); + + this.aabb = new BoundingBox(); + gsplatData.calcAabb(this.aabb); const size = this.evalTextureSize(numSplats); this.colorTexture = this.createTexture('splatColor', PIXELFORMAT_RGBA8, size); this.transformATexture = this.createTexture('transformA', PIXELFORMAT_RGBA32F, size); this.transformBTexture = this.createTexture('transformB', PIXELFORMAT_RGBA16F, size); this.transformCTexture = this.createTexture('transformC', PIXELFORMAT_R16F, size); + + // write texture data + this.updateColorData(gsplatData); + this.updateTransformData(gsplatData); } destroy() { @@ -67,20 +79,19 @@ class GSplat { } /** - * @param {import('../materials/material.js').Material} material - The material to set up for + * @returns {import('../materials/material.js').Material} material - The material to set up for * the splat rendering. */ - setupMaterial(material) { - - if (this.colorTexture) { - material.setParameter('splatColor', this.colorTexture); - material.setParameter('transformA', this.transformATexture); - material.setParameter('transformB', this.transformBTexture); - material.setParameter('transformC', this.transformCTexture); - - const { width, height } = this.colorTexture; - material.setParameter('tex_params', new Float32Array([width, height, 1 / width, 1 / height])); - } + createMaterial(options) { + const result = createGSplatMaterial(options); + result.setParameter('splatColor', this.colorTexture); + result.setParameter('transformA', this.transformATexture); + result.setParameter('transformB', this.transformBTexture); + result.setParameter('transformC', this.transformCTexture); + + const { width, height } = this.colorTexture; + result.setParameter('tex_params', new Float32Array([width, height, 1 / width, 1 / height])); + return result; } /** @@ -125,62 +136,33 @@ class GSplat { * Assumes that the texture is using an RGBA format where RGB are color components influenced * by SH spherical harmonics and A is opacity after a sigmoid transformation. * - * @param {Float32Array} c0 - The first color component SH coefficients. - * @param {Float32Array} c1 - The second color component SH coefficients. - * @param {Float32Array} c2 - The third color component SH coefficients. - * @param {Float32Array} opacity - The opacity values to be transformed using a sigmoid function. + * @param {import('./gsplat-data.js').GSplatData} gsplatData - The source data */ - updateColorData(c0, c1, c2, opacity) { - const SH_C0 = 0.28209479177387814; + updateColorData(gsplatData) { const texture = this.colorTexture; if (!texture) return; const data = texture.lock(); - /** - * Calculates the sigmoid of a given value. - * - * @param {number} v - The value for which to compute the sigmoid function. - * @returns {number} The result of the sigmoid function. - */ - const sigmoid = (v) => { - if (v > 0) { - return 1 / (1 + Math.exp(-v)); - } - - const t = Math.exp(v); - return t / (1 + t); - }; + const c = new Vec4(); + const iter = gsplatData.createIter(null, null, null, c); for (let i = 0; i < this.numSplats; ++i) { + iter.read(i); - // colors - if (c0 && c1 && c2) { - data[i * 4 + 0] = math.clamp((0.5 + SH_C0 * c0[i]) * 255, 0, 255); - data[i * 4 + 1] = math.clamp((0.5 + SH_C0 * c1[i]) * 255, 0, 255); - data[i * 4 + 2] = math.clamp((0.5 + SH_C0 * c2[i]) * 255, 0, 255); - } - - // opacity - data[i * 4 + 3] = opacity ? math.clamp(sigmoid(opacity[i]) * 255, 0, 255) : 255; + data[i * 4 + 0] = math.clamp(c.x * 255, 0, 255); + data[i * 4 + 1] = math.clamp(c.y * 255, 0, 255); + data[i * 4 + 2] = math.clamp(c.z * 255, 0, 255); + data[i * 4 + 3] = math.clamp(c.w * 255, 0, 255); } texture.unlock(); } /** - * @param {Float32Array} x - The array containing the 'x' component of the center points. - * @param {Float32Array} y - The array containing the 'y' component of the center points. - * @param {Float32Array} z - The array containing the 'z' component of the center points. - * @param {Float32Array} rot0 - The array containing the 'w' component of quaternion rotations. - * @param {Float32Array} rot1 - The array containing the 'x' component of quaternion rotations. - * @param {Float32Array} rot2 - The array containing the 'y' component of quaternion rotations. - * @param {Float32Array} rot3 - The array containing the 'z' component of quaternion rotations. - * @param {Float32Array} scale0 - The first scale component associated with the x-dimension. - * @param {Float32Array} scale1 - The second scale component associated with the y-dimension. - * @param {Float32Array} scale2 - The third scale component associated with the z-dimension. + * @param {import('./gsplat-data.js').GSplatData} gsplatData - The source data */ - updateTransformData(x, y, z, rot0, rot1, rot2, rot3, scale0, scale1, scale2) { + updateTransformData(gsplatData) { const float2Half = FloatPacking.float2Half; @@ -191,29 +173,26 @@ class GSplat { const dataB = this.transformBTexture.lock(); const dataC = this.transformCTexture.lock(); - const quat = new Quat(); + const p = new Vec3(); + const r = new Quat(); + const s = new Vec3(); + const iter = gsplatData.createIter(p, r, s); + const mat = new Mat3(); const cA = new Vec3(); const cB = new Vec3(); for (let i = 0; i < this.numSplats; i++) { + iter.read(i); - // rotation - quat.set(rot1[i], rot2[i], rot3[i], rot0[i]).normalize(); - mat.setFromQuat(quat); - - // scale - _s.set( - Math.exp(scale0[i]), - Math.exp(scale1[i]), - Math.exp(scale2[i]) - ); + r.normalize(); + mat.setFromQuat(r); - this.computeCov3d(mat, _s, cA, cB); + this.computeCov3d(mat, s, cA, cB); - dataA[i * 4 + 0] = x[i]; - dataA[i * 4 + 1] = y[i]; - dataA[i * 4 + 2] = z[i]; + dataA[i * 4 + 0] = p.x; + dataA[i * 4 + 1] = p.y; + dataA[i * 4 + 2] = p.z; dataA[i * 4 + 3] = cB.x; dataB[i * 4 + 0] = float2Half(cA.x); diff --git a/src/scene/gsplat/shader-generator-gsplat.js b/src/scene/gsplat/shader-generator-gsplat.js index 33c83a91645..a4121f66a83 100644 --- a/src/scene/gsplat/shader-generator-gsplat.js +++ b/src/scene/gsplat/shader-generator-gsplat.js @@ -1,12 +1,12 @@ import { hashCode } from "../../core/hash.js"; import { SEMANTIC_ATTR13, SEMANTIC_POSITION } from "../../platform/graphics/constants.js"; import { ShaderUtils } from "../../platform/graphics/shader-utils.js"; -import { DITHER_NONE } from "../constants.js"; +import { DITHER_NONE, TONEMAP_LINEAR } from "../constants.js"; import { shaderChunks } from "../shader-lib/chunks/chunks.js"; import { ShaderGenerator } from "../shader-lib/programs/shader-generator.js"; import { ShaderPass } from "../shader-pass.js"; -const splatCoreVS = ` +const splatCoreVS = /* glsl */ ` uniform mat4 matrix_model; uniform mat4 matrix_view; uniform mat4 matrix_projection; @@ -146,45 +146,32 @@ const splatCoreFS = /* glsl_ */ ` #endif vec4 evalSplat() { + float A = -dot(texCoord, texCoord); + if (A < -4.0) discard; + float B = exp(A) * color.a; - #ifdef DEBUG_RENDER - - if (color.a < 0.2) discard; - return color; - - #else - - float A = -dot(texCoord, texCoord); - if (A < -4.0) discard; - float B = exp(A) * color.a; - - #ifdef PICK_PASS - if (B < 0.3) discard; - return(uColor); - #endif - - #ifndef DITHER_NONE - opacityDither(B, id * 0.013); - #endif - - // the color here is in gamma space, so bring it to linear - vec3 diffuse = decodeGamma(color.rgb); - - // apply tone-mapping and gamma correction as needed - diffuse = toneMap(diffuse); - diffuse = gammaCorrectOutput(diffuse); + #ifdef PICK_PASS + if (B < 0.3) discard; + return(uColor); + #endif - return vec4(diffuse, B); + #ifndef DITHER_NONE + opacityDither(B, id * 0.013); + #endif + #ifdef TONEMAP_ENABLED + color.rgb = gammaCorrectOutput(toneMap(decodeGamma(color.rgb))); #endif + + return vec4(color.rgb, B); } `; -class GShaderGeneratorSplat { +class GSplatShaderGenerator { generateKey(options) { const vsHash = hashCode(options.vertex); const fsHash = hashCode(options.fragment); - return `splat-${options.pass}-${options.gamma}-${options.toneMapping}-${vsHash}-${fsHash}-${options.debugRender}-${options.dither}}`; + return `splat-${options.pass}-${options.gamma}-${options.toneMapping}-${vsHash}-${fsHash}-${options.dither}}`; } createShaderDefinition(device, options) { @@ -194,8 +181,8 @@ class GShaderGeneratorSplat { const defines = shaderPassDefines + - (options.debugRender ? '#define DEBUG_RENDER\n' : '') + - `#define DITHER_${options.dither.toUpperCase()}\n`; + `#define DITHER_${options.dither.toUpperCase()}\n` + + `#define TONEMAP_${options.toneMapping === TONEMAP_LINEAR ? 'DISABLED' : 'ENABLED'}\n`; const vs = defines + splatCoreVS + options.vertex; const fs = defines + shaderChunks.decodePS + @@ -216,6 +203,6 @@ class GShaderGeneratorSplat { } } -const gsplat = new GShaderGeneratorSplat(); +const gsplat = new GSplatShaderGenerator(); export { gsplat };