Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/editor/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
import * as THREE from "three";
import { OrbitControls } from "three/addons/controls/OrbitControls.js";
import { GUI } from "lil-gui";
import { constructGrid, SparkControls, SparkRenderer, SplatMesh, textSplats, dyno, transcodeSpz, isMobile, isPcSogs, LN_SCALE_MIN, LN_SCALE_MAX } from "@sparkjsdev/spark";
import { constructGrid, SparkControls, SparkRenderer, SplatMesh, textSplats, dyno, transcodeSpz, isMobile, LN_SCALE_MIN, LN_SCALE_MAX } from "@sparkjsdev/spark";
import { getAssetFileURL } from "../js/get-asset-url.js";

const scene = new THREE.Scene();
Expand Down
42 changes: 8 additions & 34 deletions examples/splat-painter/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
RgbaArray,
readRgbaArray,
SparkControls,
SpzWriter,
writeSpz,
unpackSplat,
PackedSplats,
} from "@sparkjsdev/spark";
Expand Down Expand Up @@ -405,7 +405,7 @@
alert("No splat mesh loaded for export.");
return;
}

try {
console.log("Starting SPZ export with painted changes...");

Expand All @@ -416,7 +416,7 @@
});
currentSplatMesh.updateGenerator();
}

const rgba = new RgbaArray();
rgba.render({
renderer,
Expand Down Expand Up @@ -470,7 +470,7 @@
unpacked.color.g = rgbaBytes[rgbaOffset + 1] / 255;
unpacked.color.b = rgbaBytes[rgbaOffset + 2] / 255;
unpacked.opacity = opacity;

// Push to new PackedSplats
newPackedSplats.pushSplat(
unpacked.center,
Expand All @@ -482,39 +482,13 @@

processedCount++;
}

console.log(`Processed ${processedCount} splats`);

// Now export the PackedSplats to SPZ
console.log("Creating SPZ writer...");
const maxSh = ioOptions.maxSh;
const spzWriter = new SpzWriter({
numSplats: nonZeroCount,
shDegree: maxSh,
fractionalBits: ioOptions.fractionalBits,
flagAntiAlias: true,
});

console.log("Writing splats to SPZ...");
// Iterate through the new packed array
for (let i = 0; i < nonZeroCount; i++) {
const unpacked = unpackSplat(
newPackedSplats.packedArray,
i,
newPackedSplats.splatEncoding
);

spzWriter.setCenter(i, unpacked.center.x, unpacked.center.y, unpacked.center.z);
spzWriter.setScale(i, unpacked.scales.x, unpacked.scales.y, unpacked.scales.z);
spzWriter.setQuat(i, unpacked.quaternion.x, unpacked.quaternion.y, unpacked.quaternion.z, unpacked.quaternion.w);
spzWriter.setAlpha(i, unpacked.opacity);
spzWriter.setRgb(i, unpacked.color.r, unpacked.color.g, unpacked.color.b);
}
const spzBytes = await spzWriter.finalize();
if (spzWriter.clippedCount > 0) {
console.log(`Clipped ${spzWriter.clippedCount} splats. Consider decreasing fractional-bits from ${ioOptions.fractionalBits} to reduce clipping.`);
}

const { fileBytes: spzBytes } = writeSpz(newPackedSplats, ioOptions.maxSh, ioOptions.fractionalBits);

console.log("Creating download...");
const blob = new Blob([spzBytes], { type: "application/octet-stream" });
const url = URL.createObjectURL(blob);
Expand All @@ -531,7 +505,7 @@
}
},
};

const ioFolder = gui.addFolder("I/O");
ioFolder.add(ioOptions, "loadFile").name("Load Splats (SPZ/PLY)");
ioFolder.add(ioOptions, "saveToSpz").name("Save Splats (SPZ)");
Expand Down
1 change: 1 addition & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/spark-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@ web-sys = { workspace = true, features = ["Window", "Performance"] }
spark-lib = { path = "../spark-lib" }
serde-wasm-bindgen.workspace = true
serde_json.workspace = true
serde.workspace = true
itertools.workspace = true
console_error_panic_hook = "0.1"
32 changes: 32 additions & 0 deletions rust/spark-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
use std::cell::RefCell;
use js_sys::{Array, Float32Array, Object, Reflect, Uint8Array, Uint16Array, Uint32Array};
use spark_lib::decoder::{ChunkReceiver, MultiDecoder, SplatEncoding, SplatFileType, SplatGetter};
use spark_lib::spz::SpzEncoder;
use spark_lib::gsplat::{GsplatSH1,GsplatSH2,GsplatSH3};
use spark_lib::gsplat::GsplatArray as GsplatArrayInner;
use spark_lib::csplat::CsplatArray as CsplatArrayInner;
use spark_lib::tsplat::TsplatArray;
Expand All @@ -16,6 +18,9 @@ use raycast::{raycast_packed_ellipsoids, raycast_ext_ellipsoids};
mod sort;
use sort::{sort_internal, SortBuffers, sort32_internal, Sort32Buffers};

mod transform;
use transform::{transform_gsplatarray, TransformOptions};

mod decoder;
mod packed_splats;
mod ext_splats;
Expand Down Expand Up @@ -179,6 +184,7 @@ impl GsplatArray {
}
}


#[wasm_bindgen]
impl GsplatArray {
pub fn len(&self) -> usize {
Expand Down Expand Up @@ -257,6 +263,32 @@ impl GsplatArray {
pub fn inject_rgba8(&mut self, rgba: Uint8Array) {
self.inner.inject_rgba8(&rgba.to_vec());
}

pub fn transform(&mut self, transform: JsValue) -> Result<(), JsValue> {
let transform_options: TransformOptions = serde_wasm_bindgen::from_value(transform)?;
transform_gsplatarray(&mut self.inner, transform_options);
Ok(())
}

pub fn concat(&mut self, other: &mut GsplatArray) -> Result<(), JsValue> {
for i in 0..other.inner.len() {
let sh1 = if other.maxShDegree >= 1 { other.inner.sh1[i].clone() } else { GsplatSH1::default() };
let sh2 = if other.maxShDegree >= 2 { other.inner.sh2[i].clone() } else { GsplatSH2::default() };
let sh3 = if other.maxShDegree >= 3 { other.inner.sh3[i].clone() } else { GsplatSH3::default() };
self.inner.push_splat(other.inner.get(i).clone(), Some(sh1), Some(sh2), Some(sh3));
}
Ok(())
}

pub fn encode_to_spz(mut self, max_sh: u32, fractional_bits: u8) -> Result<Uint8Array, JsValue> {
self.inner.clamp_sh_degree(max_sh as usize);
self.maxShDegree = self.inner.max_sh_degree;
let encoded = match SpzEncoder::new(self.inner).with_max_sh(max_sh as usize).with_fractional_bits(fractional_bits).encode() {
Err(err) => { return Err(JsValue::from(err.to_string())); },
Ok(encoded) => encoded
};
Ok(Uint8Array::from(encoded.as_slice()))
}
}

#[wasm_bindgen]
Expand Down
73 changes: 73 additions & 0 deletions rust/spark-rs/src/transform.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use glam::{Vec3A, Quat};

use spark_lib::gsplat::GsplatArray;
use spark_lib::tsplat::TsplatArray;
use spark_lib::tsplat::Tsplat;
use serde::{Deserialize, Serialize};

use spark_lib::decoder::SplatReceiver;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformOptions {
pub translation: [f32; 3],
pub rotation: [f32; 4],
pub scale: f32,
pub clip: Option<[f32; 6]>,
#[serde(rename = "opacityThreshold")]
pub opacity_threshold: f32,
}

pub fn transform_gsplatarray(gsplats: &mut GsplatArray, transform_options: TransformOptions) {
let translation = Vec3A::from_array(transform_options.translation);
let quaternion = Quat::from_array(transform_options.rotation);
let scale = Vec3A::splat(transform_options.scale);

let clip = transform_options.clip.map(|clip| (Vec3A::from_slice(&clip[..3]), Vec3A::from_slice(&clip[3..])));

let mut out_index = 0;
for splat_index in 0..gsplats.splats.len() {
let in_splat = gsplats.get(splat_index);

let mut center = in_splat.center();
// Transform center
center = quaternion * (center * scale) + translation;

// Check clip box
let clipped = match clip {
Some((min, max)) => (center.cmplt(min)).any() || (center.cmpgt(max)).any(),
None => false
};
if clipped {
continue;
}

// Check opacity threshold
let opacity = in_splat.opacity();
if opacity < transform_options.opacity_threshold {
continue;
}

let mut scales = in_splat.scales();
let mut quat = in_splat.quaternion();
let rgb = in_splat.rgb();

gsplats.set_center(out_index, 1, &center.to_array());

scales *= scale;
gsplats.set_scale(out_index, 1, &scales.to_array());

quat *= quaternion;
gsplats.set_quat(out_index, 1, &quat.to_array());

gsplats.set_rgb(out_index, 1, &rgb.to_array());
gsplats.set_opacity(out_index, 1, &[opacity]);

gsplats.set_sh1(out_index, 1, gsplats.get_sh1(splat_index).as_slice());
gsplats.set_sh2(out_index, 1, gsplats.get_sh2(splat_index).as_slice());
gsplats.set_sh3(out_index, 1, gsplats.get_sh3(splat_index).as_slice());

out_index += 1;
}

gsplats.truncate(out_index);
}
60 changes: 0 additions & 60 deletions src/SplatLoader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import { PackedSplats, type PackedSplatsOptions } from "./PackedSplats";
import { SplatMesh } from "./SplatMesh";
import { workerPool } from "./SplatWorker";
import { type SplatEncoding, SplatFileType } from "./defines";
import { PlyReader } from "./ply";
import { decompressPartialGzip, getTextureSize } from "./utils";

// SplatLoader implements the THREE.Loader interface and supports loading a variety
Expand Down Expand Up @@ -340,65 +339,6 @@ export class SplatLoader extends Loader {
}
}

async function fetchWithProgress(
request: Request,
onProgress?: (event: ProgressEvent) => void,
) {
const response = await fetch(request);
if (!response.ok) {
throw new Error(
`${response.status} "${response.statusText}" fetching URL: ${request.url}`,
);
}
if (!response.body) {
throw new Error(`Response body is null for URL: ${request.url}`);
}

const reader = response.body.getReader();
let loaded = 0;
const chunks: Uint8Array[] = [];
try {
const contentLength = Number.parseInt(
response.headers.get("Content-Length") || "0",
);
const total = Number.isNaN(contentLength) ? 0 : contentLength;

while (true) {
const { done, value } = await reader.read();
if (done) {
break;
}
chunks.push(value);
loaded += value.length;

if (onProgress) {
onProgress(
new ProgressEvent("progress", {
lengthComputable: total !== 0,
loaded,
total,
}),
);
}
}
} catch (err) {
try {
const reason = err instanceof Error ? err.message : "Unknown error";
await reader.cancel(reason);
} catch {}
throw err;
}

// Combine chunks into a single buffer
const bytes = new Uint8Array(loaded);
let offset = 0;
for (const chunk of chunks) {
bytes.set(chunk, offset);
offset += chunk.length;
}
return bytes.buffer;
}

export function getSplatFileType(
fileBytes: Uint8Array,
): SplatFileType | undefined {
Expand Down
Loading
Loading