Skip to content

Commit

Permalink
perf(functions): Optimize remaps in join() operations. (#1242)
Browse files Browse the repository at this point in the history
- Optimizes join() to use remapAttribute / remapIndices 
- Reduce allocations
- Add benchmarks
  • Loading branch information
donmccurdy authored Jan 27, 2024
1 parent 4adbe34 commit e9e618c
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 57 deletions.
3 changes: 2 additions & 1 deletion benchmarks/tasks/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { Task } from '../constants.js';
import { tasks as createTasks } from './clone.bench.js';
import { tasks as cloneTasks } from './create.bench.js';
import { tasks as disposeTasks } from './dispose.bench.js';
import { tasks as joinTasks } from './join.bench.js';
import { tasks as weldTasks } from './weld.bench.js';

export const tasks: Task[] = [...createTasks, ...cloneTasks, ...disposeTasks, ...weldTasks];
export const tasks: Task[] = [...createTasks, ...cloneTasks, ...disposeTasks, ...joinTasks, ...weldTasks];
37 changes: 37 additions & 0 deletions benchmarks/tasks/join.bench.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { Document } from '@gltf-transform/core';
import { join } from '@gltf-transform/functions';
import { Task } from '../constants';
import { LOGGER, createTorusKnotPrimitive } from '../utils';

let _document: Document;

export const tasks: Task[] = [
[
'join::sm',
async () => {
await _document.transform(join());
},
{ beforeEach: () => void (_document = createDocument(10, 64, 64)) }, // ~4000 vertices / prim
],
[
'join::md',
async () => {
await _document.transform(join());
},
{ beforeEach: () => void (_document = createDocument(4, 512, 512)) }, // ~250,000 vertices / prim
],
];

function createDocument(primCount: number, radialSegments: number, tubularSegments: number): Document {
const document = new Document().setLogger(LOGGER);

const scene = document.createScene();
for (let i = 0; i < primCount; i++) {
const prim = createTorusKnotPrimitive(document, { radialSegments, tubularSegments });
const mesh = document.createMesh().addPrimitive(prim);
const node = document.createNode().setMesh(mesh);
scene.addChild(node);
}

return document;
}
4 changes: 2 additions & 2 deletions benchmarks/tasks/weld.bench.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Document } from '@gltf-transform/core';
import { weld } from '@gltf-transform/functions';
import { Task } from '../constants';
import { createTorusKnotPrimitive } from '../utils';
import { LOGGER, createTorusKnotPrimitive } from '../utils';

let _document: Document;

Expand All @@ -23,7 +23,7 @@ export const tasks: Task[] = [
];

function createTorusKnotDocument(radialSegments: number, tubularSegments: number): Document {
const document = new Document();
const document = new Document().setLogger(LOGGER);
const prim = createTorusKnotPrimitive(document, { radialSegments, tubularSegments });
const mesh = document.createMesh().addPrimitive(prim);
const node = document.createNode().setMesh(mesh);
Expand Down
6 changes: 4 additions & 2 deletions benchmarks/utils.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import { Document, Mesh, Node, Primitive, Scene, vec3 } from '@gltf-transform/core';
import { Document, Logger, Mesh, Node, Primitive, Scene, vec3 } from '@gltf-transform/core';
import { vec3 as glvec3 } from 'gl-matrix';

export const LOGGER = new Logger(Logger.Verbosity.SILENT);

/******************************************************************************
* PROPERTY CONSTRUCTORS
*/

export function createLargeDocument(rootNodeCount: number): Document {
const document = new Document();
const document = new Document().setLogger(LOGGER);
createSubtree(document, document.createScene('Scene'), rootNodeCount);
return document;
}
Expand Down
81 changes: 44 additions & 37 deletions packages/functions/src/join-primitives.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Document, Primitive, ComponentTypeToTypedArray } from '@gltf-transform/core';
import { createIndices, createPrimGroupKey, shallowCloneAccessor } from './utils.js';
import { createIndices, createPrimGroupKey, remapAttribute, remapIndices, shallowCloneAccessor } from './utils.js';

interface JoinPrimitiveOptions {
skipValidation?: boolean;
Expand All @@ -9,6 +9,8 @@ const JOIN_PRIMITIVE_DEFAULTS: Required<JoinPrimitiveOptions> = {
skipValidation: false,
};

const EMPTY_U32 = 2 ** 32 - 1;

/**
* Given a list of compatible Mesh {@link Primitive Primitives}, returns new Primitive
* containing their vertex data. Compatibility requires that all Primitives share the
Expand Down Expand Up @@ -44,29 +46,32 @@ export function joinPrimitives(prims: Primitive[], options: JoinPrimitiveOptions
);
}

const remapList = [] as Uint32Array[]; // remap[srcIndex] → dstIndex, by prim
const countList = [] as number[]; // vertex count, by prim
const indicesList = [] as (Uint32Array | Uint16Array)[]; // indices, by prim
const primRemaps = [] as Uint32Array[]; // remap[srcIndex] → dstIndex, by prim
const primVertexCounts = new Uint32Array(prims.length); // vertex count, by prim

let dstVertexCount = 0;
let dstIndicesCount = 0;

// (2) Build remap lists.
for (const srcPrim of prims) {
const indices = _getOrCreateIndices(srcPrim);
const remap = [];
let count = 0;
for (let i = 0; i < indices.length; i++) {
const index = indices[i];
if (remap[index] === undefined) {
for (let primIndex = 0; primIndex < prims.length; primIndex++) {
const srcPrim = prims[primIndex];
const srcIndices = srcPrim.getIndices();
const srcVertexCount = srcPrim.getAttribute('POSITION')!.getCount();
const srcIndicesArray = srcIndices ? srcIndices.getArray() : null;
const srcIndicesCount = srcIndices ? srcIndices.getCount() : srcVertexCount;

const remap = new Uint32Array(getIndicesMax(srcPrim) + 1).fill(EMPTY_U32);

for (let i = 0; i < srcIndicesCount; i++) {
const index = srcIndicesArray ? srcIndicesArray[i] : i;
if (remap[index] === EMPTY_U32) {
remap[index] = dstVertexCount++;
count++;
primVertexCounts[primIndex]++;
}
dstIndicesCount++;
}
remapList.push(new Uint32Array(remap));
countList.push(count);
indicesList.push(indices);

primRemaps.push(new Uint32Array(remap));
dstIndicesCount += srcIndicesCount;
}

// (3) Allocate joined attributes.
Expand All @@ -88,40 +93,42 @@ export function joinPrimitives(prims: Primitive[], options: JoinPrimitiveOptions
dstPrim.setIndices(dstIndices);

// (5) Remap attributes into joined Primitive.
let dstNextIndex = 0;
for (let primIndex = 0; primIndex < remapList.length; primIndex++) {
let dstIndicesOffset = 0;
for (let primIndex = 0; primIndex < primRemaps.length; primIndex++) {
const srcPrim = prims[primIndex];
const remap = remapList[primIndex];
const indicesArray = indicesList[primIndex];
const srcVertexCount = srcPrim.getAttribute('POSITION')!.getCount();
const srcIndices = srcPrim.getIndices();
const srcIndicesCount = srcIndices ? srcIndices.getCount() : -1;

const primStartIndex = dstNextIndex;
let primNextIndex = primStartIndex;
const remap = primRemaps[primIndex];

if (srcIndices && dstIndices) {
remapIndices(srcIndices, remap, dstIndicesOffset, srcIndicesCount, dstIndices);
}

for (const semantic of dstPrim.listSemantics()) {
const srcAttribute = srcPrim.getAttribute(semantic)!;
const dstAttribute = dstPrim.getAttribute(semantic)!;
const el = [] as number[];

primNextIndex = primStartIndex;
for (let i = 0; i < indicesArray.length; i++) {
const index = indicesArray[i];
srcAttribute.getElement(index, el);
dstAttribute.setElement(remap[index], el);
if (dstIndices) {
dstIndices.setScalar(primNextIndex++, remap[index]);
}
}
remapAttribute(srcAttribute, remap, srcVertexCount, dstAttribute);
}

dstNextIndex = primNextIndex;
dstIndicesOffset += srcIndicesCount;
}

return dstPrim;
}

function _getOrCreateIndices(prim: Primitive): Uint16Array | Uint32Array {
function getIndicesMax(prim: Primitive): number {
const indices = prim.getIndices();
if (indices) return indices.getArray() as Uint32Array | Uint16Array;
const position = prim.getAttribute('POSITION')!;
return createIndices(position.getCount());
if (!indices) return position.getCount() - 1;

const indicesArray = indices.getArray()!;
const indicesCount = indices.getCount();

let indicesMax = -1;
for (let i = 0; i < indicesCount; i++) {
indicesMax = Math.max(indicesMax, indicesArray[i]);
}
return indicesMax;
}
41 changes: 35 additions & 6 deletions packages/functions/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,19 @@ export function remapPrimitive(prim: Primitive, remap: TypedArray, dstVertexCoun
}

/** @hidden */
export function remapAttribute(attribute: Accessor, remap: TypedArray, dstCount: number): Accessor {
const elementSize = attribute.getElementSize();
const srcCount = attribute.getCount();
const srcArray = attribute.getArray()!;
const dstArray = srcArray.slice(0, dstCount * elementSize);
export function remapAttribute(
srcAttribute: Accessor,
remap: TypedArray,
dstCount: number,
dstAttribute = srcAttribute,
): Accessor {
const elementSize = srcAttribute.getElementSize();
const srcCount = srcAttribute.getCount();
const srcArray = srcAttribute.getArray()!;
// prettier-ignore
const dstArray = dstAttribute === srcAttribute
? srcArray.slice(0, dstCount * elementSize)
: dstAttribute.getArray()!;
const done = new Uint8Array(dstCount);

for (let srcIndex = 0; srcIndex < srcCount; srcIndex++) {
Expand All @@ -253,7 +261,28 @@ export function remapAttribute(attribute: Accessor, remap: TypedArray, dstCount:
done[dstIndex] = 1;
}

return attribute.setArray(dstArray);
return dstAttribute.setArray(dstArray);
}

/** @hidden */
export function remapIndices(
srcIndices: Accessor,
remap: TypedArray,
dstOffset: number,
dstCount: number,
dstIndices = srcIndices,
): Accessor {
const srcCount = srcIndices.getCount();
const srcArray = srcIndices.getArray()!;
const dstArray = dstIndices === srcIndices ? srcArray.slice(0, dstCount) : dstIndices.getArray()!;

for (let i = 0; i < srcCount; i++) {
const srcIndex = srcArray[i];
const dstIndex = remap[srcIndex];
dstArray[dstOffset + i] = dstIndex;
}

return dstIndices.setArray(dstArray);
}

/** @hidden */
Expand Down
16 changes: 8 additions & 8 deletions packages/functions/src/weld.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ import {
const NAME = 'weld';

/** Flags 'empty' values in a Uint32Array index. */
const EMPTY = 2 ** 32 - 1;
const EMPTY_U32 = 2 ** 32 - 1;

const Tolerance = {
DEFAULT: 0,
Expand Down Expand Up @@ -199,21 +199,21 @@ function _weldPrimitiveStrict(document: Document, prim: Primitive): void {

const hash = new HashTable(prim);
const tableSize = ceilPowerOfTwo(srcVertexCount + srcVertexCount / 4);
const table = new Uint32Array(tableSize).fill(EMPTY);
const writeMap = new Uint32Array(srcVertexCount).fill(EMPTY); // oldIndex → newIndex
const table = new Uint32Array(tableSize).fill(EMPTY_U32);
const writeMap = new Uint32Array(srcVertexCount).fill(EMPTY_U32); // oldIndex → newIndex

// (1) Compare and identify indices to weld.

let dstVertexCount = 0;

for (let i = 0; i < srcIndicesCount; i++) {
const srcIndex = srcIndicesArray ? srcIndicesArray[i] : i;
if (writeMap[srcIndex] !== EMPTY) continue;
if (writeMap[srcIndex] !== EMPTY_U32) continue;

const hashIndex = hashLookup(table, tableSize, hash, srcIndex, EMPTY);
const hashIndex = hashLookup(table, tableSize, hash, srcIndex, EMPTY_U32);
const dstIndex = table[hashIndex];

if (dstIndex === EMPTY) {
if (dstIndex === EMPTY_U32) {
table[hashIndex] = srcIndex;
writeMap[srcIndex] = dstVertexCount++;
} else {
Expand Down Expand Up @@ -263,7 +263,7 @@ function _weldPrimitive(document: Document, prim: Primitive, options: Required<W

const srcMaxIndex = uniqueIndices[uniqueIndices.length - 1];
const weldMap = createIndices(srcMaxIndex + 1); // oldIndex → oldCommonIndex
const writeMap = new Uint32Array(uniqueIndices.length).fill(EMPTY); // oldIndex → newIndex
const writeMap = new Uint32Array(uniqueIndices.length).fill(EMPTY_U32); // oldIndex → newIndex

const srcVertexCount = srcPosition.getCount();
let dstVertexCount = 0;
Expand Down Expand Up @@ -498,7 +498,7 @@ export function murmurHash2(h: number, key: Uint32Array): number {
return h;
}

function hashLookup(table: Uint32Array, buckets: number, hash: HashTable, key: number, empty = EMPTY): number {
function hashLookup(table: Uint32Array, buckets: number, hash: HashTable, key: number, empty = EMPTY_U32): number {
const hashmod = buckets - 1;
const hashval = hash.hash(key);
let bucket = hashval & hashmod;
Expand Down
2 changes: 1 addition & 1 deletion packages/functions/test/join-primitives.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ test('indexed', async (t) => {
0, 0, 0, 0,
], 'position data');

t.is(primAB.getIndices().getCount(), 6, 'indices data');
t.deepEqual(Array.from(primAB.getIndices().getArray()), [0, 1, 2, 3, 4, 5], 'indices data');
});

function createPrimA(document: Document): [Primitive, Accessor, Accessor] {
Expand Down

0 comments on commit e9e618c

Please sign in to comment.