Skip to content

Commit

Permalink
TLAS in progress; still needs inst id for hit record.
Browse files Browse the repository at this point in the history
  • Loading branch information
jbikker committed Jan 11, 2025
1 parent a085e2c commit 6db12ae
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 23 deletions.
166 changes: 146 additions & 20 deletions tiny_bvh.h
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ class BVH : public BVHBase
void Build( const bvhvec4slice& vertices );
void Build( const bvhvec4* vertices, const uint32_t* indices, const uint32_t primCount );
void Build( const bvhvec4slice& vertices, const uint32_t* indices, const uint32_t primCount );
void Build( const BLASInstance* bvhs, const uint32_t instCount );
void BuildHQ( const bvhvec4* vertices, const uint32_t primCount );
void BuildHQ( const bvhvec4slice& vertices );
void BuildHQ( const bvhvec4* vertices, const uint32_t* indices, const uint32_t primCount );
Expand All @@ -595,17 +596,16 @@ class BVH : public BVHBase
void BuildNEON( const bvhvec4* vertices, const uint32_t primCount );
void BuildNEON( const bvhvec4slice& vertices );
#endif
void BuildTLAS( const bvhaabb* aabbs, const uint32_t aabbCount );
void BuildTLAS( const BLASInstance* bvhs, const uint32_t instCount );
void Refit( const uint32_t nodeIdx = 0 );
int32_t Intersect( Ray& ray ) const;
int32_t IntersectTLAS( Ray& ray ) const;
bool IsOccluded( const Ray& ray ) const;
void Intersect256Rays( Ray* first ) const;
void Intersect256RaysSSE( Ray* packet ) const; // requires BVH_USEAVX
private:
void PrepareBuild( const bvhvec4slice& vertices, const uint32_t* indices, const uint32_t primCount );
void Build();
bool IsOccludedTLAS( const Ray& ray ) const;
int32_t IntersectTLAS( Ray& ray ) const;
void PrepareAVXBuild( const bvhvec4slice& vertices, const uint32_t* indices, const uint32_t primCount );
void BuildAVX();
void PrepareHQBuild( const bvhvec4slice& vertices, const uint32_t* indices, const uint32_t prims );
Expand All @@ -622,10 +622,16 @@ class BVH : public BVHBase
void BuildDefault( const bvhvec4* vertices, const uint32_t* indices, const uint32_t primCount );
void BuildDefault( const bvhvec4slice& vertices, const uint32_t* indices, const uint32_t primCount );
public:
// BVH type identification
bool isTLAS() const { return instList != 0; }
bool isBLAS() const { return instList == 0; }
bool isIndexed() const { return vertIdx != 0; }
bool hasCustomGeom() const { return customIntersect != 0; }
// Basic BVH data
bvhvec4slice verts = {}; // pointer to input primitive array: 3x16 bytes per tri.
uint32_t* vertIdx = 0; // vertex indices, only used in case the BVH is built over indexed prims.
uint32_t* triIdx = 0; // primitive index array.
BLASInstance* instList = 0; // instance array, for top-level acceleration structure.
BVHNode* bvhNode = 0; // BVH node pool, Wald 32-byte format. Root is always in node 0.
uint32_t newNodePtr = 0; // used during build to keep track of next free node in pool.
Fragment* fragment = 0; // input primitive bounding boxes.
Expand Down Expand Up @@ -926,9 +932,11 @@ class BLASInstance
BLASInstance() = default;
BLASInstance( BVH* bvh ) : blas( bvh ) {}
float transform[16] = { 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1 }; // identity
float invTransform[16] = { 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1 }; // identity
BVH* blas = 0; // Bottom-level acceleration structure.
bvhvec3 TransformPoint( const bvhvec3& v ) const;
bvhvec3 TransformVector( const bvhvec3& v ) const;
bvhvec3 TransformPoint( const bvhvec3& v, const float* T ) const;
bvhvec3 TransformVector( const bvhvec3& v, const float* T ) const;
void InvertTransform();
};

// Experimental & 'under construction' structs
Expand Down Expand Up @@ -1322,9 +1330,9 @@ void BVH::Build( const bvhvec4slice& vertices, const uint32_t* indices, uint32_t
bvh_over_indices = true;
}

void BVH::BuildTLAS( const BLASInstance* bvhs, const uint32_t instCount )
void BVH::Build( const BLASInstance* bvhs, const uint32_t instCount )
{
FATAL_ERROR_IF( instCount == 0, "BVH::BuildTLAS( .. ), instCount == 0." );
FATAL_ERROR_IF( instCount == 0, "BVH::Build( BLASInstance*, instCount ), instCount == 0." );
triCount = idxCount = instCount;
const uint32_t spaceNeeded = instCount * 2; // upper limit
if (allocatedNodes < spaceNeeded)
Expand All @@ -1338,19 +1346,22 @@ void BVH::BuildTLAS( const BLASInstance* bvhs, const uint32_t instCount )
triIdx = (uint32_t*)AlignedAlloc( instCount * sizeof( uint32_t ) );
fragment = (Fragment*)AlignedAlloc( instCount * sizeof( Fragment ) );
}
instList = (BLASInstance*)bvhs;
// copy relevant data from instance array
BVHNode& root = bvhNode[0];
root.leftFirst = 0, root.triCount = instCount, root.aabbMin = bvhvec3( BVH_FAR ), root.aabbMax = bvhvec3( -BVH_FAR );
for (uint32_t i = 0; i < instCount; i++)
{
FATAL_ERROR_IF( instList[i].blas->bvhNode == 0, "BVH::Build( BLASInstance*, .. ), BLAS not built." );
// transform the eight corners of the root node aabb using the instance
// transform and calculate the worldspace aabb over those.
instList[i].InvertTransform(); // TODO: done unconditionally; for a big TLAS this may be wasteful. Detect changes automatically?
bvhvec3 minBounds = bvhvec3( BVH_FAR ), maxBounds = bvhvec3( -BVH_FAR );
bvhvec3 bmin = bvhs[i].blas->bvhNode[0].aabbMin, bmax = bvhs[i].blas->bvhNode[0].aabbMax;
bvhvec3 bmin = instList[i].blas->bvhNode[0].aabbMin, bmax = instList[i].blas->bvhNode[0].aabbMax;
for (int32_t j = 0; j < 8; j++)
{
const bvhvec3 p( j & 1 ? bmax.x : bmin.x, j & 2 ? bmax.y : bmin.y, j & 4 ? bmax.z : bmin.z );
const bvhvec3 t = bvhs[i].TransformPoint( p );
const bvhvec3 t = instList[i].TransformPoint( p, instList[i].transform );
minBounds = tinybvh_min( minBounds, t ), maxBounds = tinybvh_max( maxBounds, t );
}
fragment[i].bmin = minBounds, fragment[i].primIdx = i;
Expand All @@ -1360,7 +1371,7 @@ void BVH::BuildTLAS( const BLASInstance* bvhs, const uint32_t instCount )
}
// start build
newNodePtr = 2;
Build();
Build(); // or BuildAVX, for large TLAS.
}

void BVH::PrepareBuild( const bvhvec4slice& vertices, const uint32_t* indices, const uint32_t prims )
Expand Down Expand Up @@ -1909,6 +1920,7 @@ void BVH::Refit( const uint32_t nodeIdx )

int32_t BVH::Intersect( Ray& ray ) const
{
if (isTLAS()) return IntersectTLAS( ray );
BVHNode* node = &bvhNode[0], * stack[64];
uint32_t stackPtr = 0, cost = 0;
while (1)
Expand Down Expand Up @@ -1942,8 +1954,54 @@ int32_t BVH::Intersect( Ray& ray ) const
return cost;
}

// IntersectTLAS - traverse a top-level BVH. The leaves reference BLAS instances;
// each instance is intersected with a copy of the ray that has been transformed
// by the inverse of the instance transform. Returns the accumulated traversal cost.
int32_t BVH::IntersectTLAS( Ray& ray ) const
{
	BVHNode* todo[64];
	uint32_t todoCount = 0, steps = 0;
	BVHNode* current = &bvhNode[0];
	for (;;)
	{
		steps += C_TRAV;
		if (!current->isLeaf())
		{
			// interior node: test both children and order them by distance.
			BVHNode* nearChild = &bvhNode[current->leftFirst];
			BVHNode* farChild = &bvhNode[current->leftFirst + 1];
			float nearDist = nearChild->Intersect( ray );
			float farDist = farChild->Intersect( ray );
			if (nearDist > farDist)
			{
				tinybvh_swap( nearDist, farDist );
				tinybvh_swap( nearChild, farChild );
			}
			if (nearDist != BVH_FAR)
			{
				// hit at least one child: continue with the nearest, postpone the other.
				current = nearChild;
				if (farDist != BVH_FAR) todo[todoCount++] = farChild;
				continue;
			}
			// missed both children: fall through and pop the next node.
		}
		else
		{
			// leaf node: visit every instance referenced by this leaf.
			Ray local;
			for (uint32_t k = 0; k < current->triCount; k++)
			{
				BLASInstance& instance = instList[triIdx[current->leftFirst + k]];
				// move the ray into the space of the BLAS, using the inverse instance transform.
				local.O = instance.TransformPoint( ray.O, instance.invTransform );
				local.D = instance.TransformVector( ray.D, instance.invTransform );
				local.rD = tinybvh_safercp( local.D );
				local.hit = ray.hit;
				// trace the bottom-level structure and add its cost to ours.
				steps += instance.blas->Intersect( local );
				// copy the (possibly updated) hit record back to the original ray.
				ray.hit = local.hit;
			}
		}
		// nothing left to descend into here: take the next postponed node, if any.
		if (todoCount == 0) break;
		current = todo[--todoCount];
	}
	return steps;
}

bool BVH::IsOccluded( const Ray& ray ) const
{
if (isTLAS()) return IsOccludedTLAS( ray );
BVHNode* node = &bvhNode[0], * stack[64];
uint32_t stackPtr = 0;
while (1)
Expand Down Expand Up @@ -1985,6 +2043,48 @@ bool BVH::IsOccluded( const Ray& ray ) const
return false;
}

bool BVH::IsOccludedTLAS( const Ray& ray ) const
{
BVHNode* node = &bvhNode[0], * stack[64];
uint32_t stackPtr = 0;
Ray tmp;
tmp.hit = ray.hit;
while (1)
{
if (node->isLeaf())
{
for (uint32_t i = 0; i < node->triCount; i++)
{
// BLAS traversal
BLASInstance& inst = instList[triIdx[node->leftFirst + i]];
BVH* blas = inst.blas;
// 1. Transform ray with the inverse of the instance transform
tmp.O = inst.TransformPoint( ray.O, inst.invTransform );
tmp.D = inst.TransformVector( ray.D, inst.invTransform );
tmp.rD = tinybvh_safercp( tmp.D );
// 2. Traverse BLAS with the transformed ray
if (blas->IsOccluded( tmp )) return true;
}
if (stackPtr == 0) break; else node = stack[--stackPtr];
continue;
}
BVHNode* child1 = &bvhNode[node->leftFirst];
BVHNode* child2 = &bvhNode[node->leftFirst + 1];
float dist1 = child1->Intersect( ray ), dist2 = child2->Intersect( ray );
if (dist1 > dist2) { tinybvh_swap( dist1, dist2 ); tinybvh_swap( child1, child2 ); }
if (dist1 == BVH_FAR /* missed both child nodes */)
{
if (stackPtr == 0) break; else node = stack[--stackPtr];
}
else /* hit at least one node */
{
node = child1; /* continue with the nearest */
if (dist2 != BVH_FAR) stack[stackPtr++] = child2; /* push far child */
}
}
return false;
}

// Intersect a WALD_32BYTE BVH with a ray packet.
// The 256 rays travel together to better utilize the caches and to amortize the cost
// of memory transfers over the rays in the bundle.
Expand Down Expand Up @@ -5687,22 +5787,48 @@ double BVH_Double::BVHNode::Intersect( const RayEx& ray ) const
// ============================================================================

// TransformPoint
bvhvec3 BLASInstance::TransformPoint( const bvhvec3& v ) const
bvhvec3 BLASInstance::TransformPoint( const bvhvec3& v, const float* T ) const
{
const bvhvec3 res(
transform[0] * v.x + transform[1] * v.y + transform[2] * v.z + transform[3],
transform[4] * v.x + transform[5] * v.y + transform[6] * v.z + transform[7],
transform[8] * v.x + transform[9] * v.y + transform[10] * v.z + transform[11] );
const float w = transform[12] * v.x + transform[13] * v.y + transform[14] * v.z + transform[15];
T[0] * v.x + T[1] * v.y + T[2] * v.z + T[3],
T[4] * v.x + T[5] * v.y + T[6] * v.z + T[7],
T[8] * v.x + T[9] * v.y + T[10] * v.z + T[11] );
const float w = T[12] * v.x + T[13] * v.y + T[14] * v.z + T[15];
if (w == 1) return res; else return res * (1.f / w);
}

// TransformVector - skips translation. Assumes orthonormal transform, for now.
bvhvec3 BLASInstance::TransformVector( const bvhvec3& v ) const
{
return bvhvec3( transform[0] * v.x + transform[1] * v.y + transform[2] * v.z,
transform[4] * v.x + transform[5] * v.y + transform[6] * v.z,
transform[8] * v.x + transform[9] * v.y + transform[10] * v.z );
bvhvec3 BLASInstance::TransformVector( const bvhvec3& v, const float* T ) const
{
return bvhvec3( T[0] * v.x + T[1] * v.y + T[2] * v.z, T[4] * v.x +
T[5] * v.y + T[6] * v.z, T[8] * v.x + T[9] * v.y + T[10] * v.z );
}

// InvertTransform - calculate the inverse of the 4x4 matrix stored in 'transform'
// and store the result in 'invTransform'. If 'transform' is singular (its
// determinant is zero), 'invTransform' is left unchanged.
void BLASInstance::InvertTransform()
{
	// math from MESA, via http://stackoverflow.com/questions/1148309/inverting-a-4x4-matrix
	const float* T = this->transform;
	// The cofactors are written to a scratch buffer first, so that a failed
	// inversion cannot leave an unscaled, half-finished matrix in invTransform.
	float inv[16];
	inv[0] = T[5] * T[10] * T[15] - T[5] * T[11] * T[14] - T[9] * T[6] * T[15] + T[9] * T[7] * T[14] + T[13] * T[6] * T[11] - T[13] * T[7] * T[10];
	inv[1] = -T[1] * T[10] * T[15] + T[1] * T[11] * T[14] + T[9] * T[2] * T[15] - T[9] * T[3] * T[14] - T[13] * T[2] * T[11] + T[13] * T[3] * T[10];
	inv[2] = T[1] * T[6] * T[15] - T[1] * T[7] * T[14] - T[5] * T[2] * T[15] + T[5] * T[3] * T[14] + T[13] * T[2] * T[7] - T[13] * T[3] * T[6];
	inv[3] = -T[1] * T[6] * T[11] + T[1] * T[7] * T[10] + T[5] * T[2] * T[11] - T[5] * T[3] * T[10] - T[9] * T[2] * T[7] + T[9] * T[3] * T[6];
	inv[4] = -T[4] * T[10] * T[15] + T[4] * T[11] * T[14] + T[8] * T[6] * T[15] - T[8] * T[7] * T[14] - T[12] * T[6] * T[11] + T[12] * T[7] * T[10];
	inv[5] = T[0] * T[10] * T[15] - T[0] * T[11] * T[14] - T[8] * T[2] * T[15] + T[8] * T[3] * T[14] + T[12] * T[2] * T[11] - T[12] * T[3] * T[10];
	inv[6] = -T[0] * T[6] * T[15] + T[0] * T[7] * T[14] + T[4] * T[2] * T[15] - T[4] * T[3] * T[14] - T[12] * T[2] * T[7] + T[12] * T[3] * T[6];
	inv[7] = T[0] * T[6] * T[11] - T[0] * T[7] * T[10] - T[4] * T[2] * T[11] + T[4] * T[3] * T[10] + T[8] * T[2] * T[7] - T[8] * T[3] * T[6];
	inv[8] = T[4] * T[9] * T[15] - T[4] * T[11] * T[13] - T[8] * T[5] * T[15] + T[8] * T[7] * T[13] + T[12] * T[5] * T[11] - T[12] * T[7] * T[9];
	inv[9] = -T[0] * T[9] * T[15] + T[0] * T[11] * T[13] + T[8] * T[1] * T[15] - T[8] * T[3] * T[13] - T[12] * T[1] * T[11] + T[12] * T[3] * T[9];
	inv[10] = T[0] * T[5] * T[15] - T[0] * T[7] * T[13] - T[4] * T[1] * T[15] + T[4] * T[3] * T[13] + T[12] * T[1] * T[7] - T[12] * T[3] * T[5];
	inv[11] = -T[0] * T[5] * T[11] + T[0] * T[7] * T[9] + T[4] * T[1] * T[11] - T[4] * T[3] * T[9] - T[8] * T[1] * T[7] + T[8] * T[3] * T[5];
	inv[12] = -T[4] * T[9] * T[14] + T[4] * T[10] * T[13] + T[8] * T[5] * T[14] - T[8] * T[6] * T[13] - T[12] * T[5] * T[10] + T[12] * T[6] * T[9];
	inv[13] = T[0] * T[9] * T[14] - T[0] * T[10] * T[13] - T[8] * T[1] * T[14] + T[8] * T[2] * T[13] + T[12] * T[1] * T[10] - T[12] * T[2] * T[9];
	inv[14] = -T[0] * T[5] * T[14] + T[0] * T[6] * T[13] + T[4] * T[1] * T[14] - T[4] * T[2] * T[13] - T[12] * T[1] * T[6] + T[12] * T[2] * T[5];
	inv[15] = T[0] * T[5] * T[10] - T[0] * T[6] * T[9] - T[4] * T[1] * T[10] + T[4] * T[2] * T[9] + T[8] * T[1] * T[6] - T[8] * T[2] * T[5];
	// determinant via expansion along the first row of T, reusing the cofactors.
	const float det = T[0] * inv[0] + T[1] * inv[4] + T[2] * inv[8] + T[3] * inv[12];
	if (det == 0) return; // singular matrix: no inverse exists, keep the previous invTransform.
	// inverse = cofactor matrix scaled by 1 / determinant.
	const float invdet = 1.0f / det;
	for (int i = 0; i < 16; i++) invTransform[i] = inv[i] * invdet;
}

// SA
Expand Down
5 changes: 2 additions & 3 deletions tiny_bvh_anim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ void Init()
inst[1].transform[3 /* i.e., x translation */] = 4;
inst[2] = BLASInstance( &blas );
inst[2].transform[3 /* i.e., x translation */] = -4;
tlas.BuildTLAS( inst, 3 );
int w = 0;
tlas.Build( inst, 3 );
}

bool UpdateCamera( float delta_time_s, fenster& f )
Expand Down Expand Up @@ -79,7 +78,7 @@ void Tick( float delta_time_s, fenster& f, uint32_t* buf )
float u = (float)(tx * 4 + x) / SCRWIDTH, v = (float)(ty * 4 + y) / SCRHEIGHT;
bvhvec3 D = normalize( p1 + u * (p2 - p1) + v * (p3 - p1) - eye );
Ray ray( eye, D, 1e30f );
bvh.Intersect( ray );
tlas.Intersect( ray );
if (ray.hit.t < 10000)
{
int pixel_x = tx * 4 + x, pixel_y = ty * 4 + y, primIdx = ray.hit.prim;
Expand Down

0 comments on commit 6db12ae

Please sign in to comment.