Skip to content

Commit

Permalink
Small optimization to skip restitution in SIMD if no lanes have resti…
Browse files Browse the repository at this point in the history
…tution

Also fixed bug where contacts with no restitution were affected by
the restitution solver.
This affected the determinism test result.
  • Loading branch information
erincatto committed Feb 26, 2025
1 parent 577a10d commit 82ee95d
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 30 deletions.
1 change: 0 additions & 1 deletion samples/sample_bodies.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,6 @@ class Weeble : public Sample

b2Capsule capsule = { { 0.0f, -1.0f }, { 0.0f, 1.0f }, 1.0f };
b2ShapeDef shapeDef = b2DefaultShapeDef();
shapeDef.density = 1.0f;
b2CreateCapsuleShape( m_weebleId, &shapeDef, &capsule );

float mass = b2Body_GetMass( m_weebleId );
Expand Down
1 change: 1 addition & 0 deletions src/constraint_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ typedef struct b2GraphColor
// This bitset is indexed by bodyId so this is over-sized to encompass static bodies
// however I never traverse these bits or use the bit count for anything
// This bitset is unused on the overflow color.
// todo consider having a uint16_t per body that tracks the graph color membership
b2BitSet bodySet;

// cache friendly arrays
Expand Down
116 changes: 90 additions & 26 deletions src/contact_solver.c
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ static inline b2FloatW b2MaxW( b2FloatW a, b2FloatW b )
static inline b2FloatW b2SymClampW( b2FloatW a, b2FloatW b )
{
b2FloatW nb = _mm256_sub_ps( _mm256_setzero_ps(), b );
return _mm256_max_ps(nb, _mm256_min_ps( a, b ));
return _mm256_max_ps( nb, _mm256_min_ps( a, b ) );
}

static inline b2FloatW b2OrW( b2FloatW a, b2FloatW b )
Expand All @@ -621,6 +621,19 @@ static inline b2FloatW b2EqualsW( b2FloatW a, b2FloatW b )
return _mm256_cmp_ps( a, b, _CMP_EQ_OQ );
}

// Returns true only when all eight lanes of a compare equal to zero.
// The comparison is ordered, so a lane holding NaN makes the result false.
static inline bool b2AllZeroW( b2FloatW a )
{
	// Lanes equal to zero get every bit set, all other lanes are cleared
	const b2FloatW isZero = _mm256_cmp_ps( a, _mm256_setzero_ps(), _CMP_EQ_OQ );

	// movemask packs the sign bit of each lane into the low 8 bits,
	// so 0xFF means every lane passed the comparison
	return _mm256_movemask_ps( isZero ) == 0xFF;
}

// component-wise returns mask ? b : a
static inline b2FloatW b2BlendW( b2FloatW a, b2FloatW b, b2FloatW mask )
{
Expand Down Expand Up @@ -702,6 +715,25 @@ static inline b2FloatW b2EqualsW( b2FloatW a, b2FloatW b )
return vreinterpretq_f32_u32( vceqq_f32( a, b ) );
}

// Returns true only when all four lanes of a compare equal to zero.
// A lane holding NaN compares unequal, so the result is false in that case.
static inline bool b2AllZeroW( b2FloatW a )
{
	// Lanes equal to zero become 0xFFFFFFFF, all other lanes become 0
	uint32x4_t cmp_result = vceqq_f32( a, vdupq_n_f32( 0.0f ) );

#if defined( __aarch64__ )
	// The across-vector minimum is an A64 instruction (it does not depend on SVE).
	// The minimum is non-zero only if every lane passed the comparison.
	return vminvq_u32( cmp_result ) != 0;
#else
	// No across-vector minimum is available here, so test each lane individually
	return vgetq_lane_u32( cmp_result, 0 ) != 0 && vgetq_lane_u32( cmp_result, 1 ) != 0 &&
		   vgetq_lane_u32( cmp_result, 2 ) != 0 && vgetq_lane_u32( cmp_result, 3 ) != 0;
#endif
}

// component-wise returns mask ? b : a
static inline b2FloatW b2BlendW( b2FloatW a, b2FloatW b, b2FloatW mask )
{
Expand Down Expand Up @@ -822,6 +854,19 @@ static inline b2FloatW b2EqualsW( b2FloatW a, b2FloatW b )
return _mm_cmpeq_ps( a, b );
}

// Returns true only when all four lanes of a compare equal to zero.
static inline bool b2AllZeroW( b2FloatW a )
{
	// cmpeq sets every bit of a lane that equals zero; movemask then packs
	// the sign bit of each lane into the low 4 bits of the result
	const int zeroLanes = _mm_movemask_ps( _mm_cmpeq_ps( a, _mm_setzero_ps() ) );

	// 0xF means all four lanes passed the comparison
	return zeroLanes == 0xF;
}

// component-wise returns mask ? b : a
static inline b2FloatW b2BlendW( b2FloatW a, b2FloatW b, b2FloatW mask )
{
Expand Down Expand Up @@ -852,37 +897,37 @@ static inline b2FloatW b2UnpackHiW( b2FloatW a, b2FloatW b )

static inline b2FloatW b2ZeroW()
{
return ( b2FloatW ){ 0.0f, 0.0f, 0.0f, 0.0f };
return (b2FloatW){ 0.0f, 0.0f, 0.0f, 0.0f };
}

static inline b2FloatW b2SplatW( float scalar )
{
return ( b2FloatW ){ scalar, scalar, scalar, scalar };
return (b2FloatW){ scalar, scalar, scalar, scalar };
}

static inline b2FloatW b2AddW( b2FloatW a, b2FloatW b )
{
return ( b2FloatW ){ a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w };
return (b2FloatW){ a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w };
}

static inline b2FloatW b2SubW( b2FloatW a, b2FloatW b )
{
return ( b2FloatW ){ a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w };
return (b2FloatW){ a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w };
}

static inline b2FloatW b2MulW( b2FloatW a, b2FloatW b )
{
return ( b2FloatW ){ a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w };
return (b2FloatW){ a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w };
}

static inline b2FloatW b2MulAddW( b2FloatW a, b2FloatW b, b2FloatW c )
{
return ( b2FloatW ){ a.x + b.x * c.x, a.y + b.y * c.y, a.z + b.z * c.z, a.w + b.w * c.w };
return (b2FloatW){ a.x + b.x * c.x, a.y + b.y * c.y, a.z + b.z * c.z, a.w + b.w * c.w };
}

static inline b2FloatW b2MulSubW( b2FloatW a, b2FloatW b, b2FloatW c )
{
return ( b2FloatW ){ a.x - b.x * c.x, a.y - b.y * c.y, a.z - b.z * c.z, a.w - b.w * c.w };
return (b2FloatW){ a.x - b.x * c.x, a.y - b.y * c.y, a.z - b.z * c.z, a.w - b.w * c.w };
}

static inline b2FloatW b2MinW( b2FloatW a, b2FloatW b )
Expand All @@ -909,10 +954,10 @@ static inline b2FloatW b2MaxW( b2FloatW a, b2FloatW b )
static inline b2FloatW b2SymClampW( b2FloatW a, b2FloatW b )
{
b2FloatW r;
r.x = b2ClampFloat(a.x, -b.x, b.x);
r.y = b2ClampFloat(a.y, -b.y, b.y);
r.z = b2ClampFloat(a.z, -b.z, b.z);
r.w = b2ClampFloat(a.w, -b.w, b.w);
r.x = b2ClampFloat( a.x, -b.x, b.x );
r.y = b2ClampFloat( a.y, -b.y, b.y );
r.z = b2ClampFloat( a.z, -b.z, b.z );
r.w = b2ClampFloat( a.w, -b.w, b.w );
return r;
}

Expand Down Expand Up @@ -946,6 +991,11 @@ static inline b2FloatW b2EqualsW( b2FloatW a, b2FloatW b )
return r;
}

// Returns true only when all four components of a compare equal to zero.
static inline bool b2AllZeroW( b2FloatW a )
{
	// A component that is not equal to zero (including NaN) disqualifies the whole vector
	const bool anyNonZero = a.x != 0.0f || a.y != 0.0f || a.z != 0.0f || a.w != 0.0f;
	return !anyNonZero;
}

// component-wise returns mask ? b : a
static inline b2FloatW b2BlendW( b2FloatW a, b2FloatW b, b2FloatW mask )
{
Expand All @@ -971,7 +1021,7 @@ static inline b2FloatW b2CrossW( b2Vec2W a, b2Vec2W b )

static inline b2Vec2W b2RotateVectorW( b2RotW q, b2Vec2W v )
{
return ( b2Vec2W ){ b2SubW( b2MulW( q.C, v.X ), b2MulW( q.S, v.Y ) ), b2AddW( b2MulW( q.S, v.X ), b2MulW( q.C, v.Y ) ) };
return (b2Vec2W){ b2SubW( b2MulW( q.C, v.X ), b2MulW( q.S, v.Y ) ), b2AddW( b2MulW( q.S, v.X ), b2MulW( q.C, v.Y ) ) };
}

// Soft contact constraints with sub-stepping support
Expand Down Expand Up @@ -1331,21 +1381,23 @@ static b2BodyStateW b2GatherBodies( const b2BodyState* B2_RESTRICT states, int*
b2BodyState s4 = indices[3] == B2_NULL_INDEX ? identity : states[indices[3]];

b2BodyStateW simdBody;
simdBody.v.X = ( b2FloatW ){ s1.linearVelocity.x, s2.linearVelocity.x, s3.linearVelocity.x, s4.linearVelocity.x };
simdBody.v.Y = ( b2FloatW ){ s1.linearVelocity.y, s2.linearVelocity.y, s3.linearVelocity.y, s4.linearVelocity.y };
simdBody.w = ( b2FloatW ){ s1.angularVelocity, s2.angularVelocity, s3.angularVelocity, s4.angularVelocity };
simdBody.flags = ( b2FloatW ){ (float)s1.flags, (float)s2.flags, (float)s3.flags, (float)s4.flags };
simdBody.dp.X = ( b2FloatW ){ s1.deltaPosition.x, s2.deltaPosition.x, s3.deltaPosition.x, s4.deltaPosition.x };
simdBody.dp.Y = ( b2FloatW ){ s1.deltaPosition.y, s2.deltaPosition.y, s3.deltaPosition.y, s4.deltaPosition.y };
simdBody.dq.C = ( b2FloatW ){ s1.deltaRotation.c, s2.deltaRotation.c, s3.deltaRotation.c, s4.deltaRotation.c };
simdBody.dq.S = ( b2FloatW ){ s1.deltaRotation.s, s2.deltaRotation.s, s3.deltaRotation.s, s4.deltaRotation.s };
simdBody.v.X = (b2FloatW){ s1.linearVelocity.x, s2.linearVelocity.x, s3.linearVelocity.x, s4.linearVelocity.x };
simdBody.v.Y = (b2FloatW){ s1.linearVelocity.y, s2.linearVelocity.y, s3.linearVelocity.y, s4.linearVelocity.y };
simdBody.w = (b2FloatW){ s1.angularVelocity, s2.angularVelocity, s3.angularVelocity, s4.angularVelocity };
simdBody.flags = (b2FloatW){ (float)s1.flags, (float)s2.flags, (float)s3.flags, (float)s4.flags };
simdBody.dp.X = (b2FloatW){ s1.deltaPosition.x, s2.deltaPosition.x, s3.deltaPosition.x, s4.deltaPosition.x };
simdBody.dp.Y = (b2FloatW){ s1.deltaPosition.y, s2.deltaPosition.y, s3.deltaPosition.y, s4.deltaPosition.y };
simdBody.dq.C = (b2FloatW){ s1.deltaRotation.c, s2.deltaRotation.c, s3.deltaRotation.c, s4.deltaRotation.c };
simdBody.dq.S = (b2FloatW){ s1.deltaRotation.s, s2.deltaRotation.s, s3.deltaRotation.s, s4.deltaRotation.s };

return simdBody;
}

// This writes only the velocities back to the solver bodies
static void b2ScatterBodies( b2BodyState* B2_RESTRICT states, int* B2_RESTRICT indices, const b2BodyStateW* B2_RESTRICT simdBody )
{
// todo somehow skip writing to kinematic bodies

if ( indices[0] != B2_NULL_INDEX )
{
b2BodyState* state = states + indices[0];
Expand Down Expand Up @@ -1662,8 +1714,8 @@ void b2WarmStartContactsTask( int startIndex, int endIndex, b2StepContext* conte
bB.v.Y = b2MulAddW( bB.v.Y, c->invMassB, P.Y );
}

bA.w = b2MulSubW(bA.w, c->invIA, c->rollingImpulse);
bB.w = b2MulAddW(bB.w, c->invIB, c->rollingImpulse);
bA.w = b2MulSubW( bA.w, c->invIA, c->rollingImpulse );
bB.w = b2MulAddW( bB.w, c->invIB, c->rollingImpulse );

b2ScatterBodies( states, c->indexA, &bA );
b2ScatterBodies( states, c->indexB, &bB );
Expand Down Expand Up @@ -1895,11 +1947,11 @@ void b2SolveContactsTask( int startIndex, int endIndex, b2StepContext* context,

// Rolling resistance
{
b2FloatW deltaLambda = b2MulW( c->rollingMass, b2SubW( bA.w, bB.w ));
b2FloatW deltaLambda = b2MulW( c->rollingMass, b2SubW( bA.w, bB.w ) );
b2FloatW lambda = c->rollingImpulse;
b2FloatW maxLambda = b2MulW( c->rollingResistance, totalNormalImpulse );
c->rollingImpulse = b2SymClampW( b2AddW(lambda, deltaLambda), maxLambda );
deltaLambda = b2SubW(c->rollingImpulse, lambda);
c->rollingImpulse = b2SymClampW( b2AddW( lambda, deltaLambda ), maxLambda );
deltaLambda = b2SubW( c->rollingImpulse, lambda );

bA.w = b2MulSubW( bA.w, c->invIA, deltaLambda );
bB.w = b2MulAddW( bB.w, c->invIB, deltaLambda );
Expand All @@ -1925,6 +1977,16 @@ void b2ApplyRestitutionTask( int startIndex, int endIndex, b2StepContext* contex
{
b2ContactConstraintSIMD* c = constraints + i;

if ( b2AllZeroW( c->restitution ) )
{
// No lanes have restitution. Common case.
continue;
}

// Create a mask based on restitution so that lanes with no restitution are not affected
// by the calculations below.
b2FloatW restitutionMask = b2GreaterThanW( c->restitution, zero );

b2BodyStateW bA = b2GatherBodies( states, c->indexA );
b2BodyStateW bB = b2GatherBodies( states, c->indexB );

Expand All @@ -1935,6 +1997,7 @@ void b2ApplyRestitutionTask( int startIndex, int endIndex, b2StepContext* contex
b2FloatW mask2 = b2EqualsW( c->totalNormalImpulse1, zero );
b2FloatW mask = b2OrW( mask1, mask2 );
b2FloatW mass = b2BlendW( c->normalMass1, zero, mask );
mass = b2BlendW( zero, mass, restitutionMask );

// fixed anchors for Jacobians
b2Vec2W rA = c->anchorA1;
Expand Down Expand Up @@ -1973,6 +2036,7 @@ void b2ApplyRestitutionTask( int startIndex, int endIndex, b2StepContext* contex
b2FloatW mask2 = b2EqualsW( c->totalNormalImpulse2, zero );
b2FloatW mask = b2OrW( mask1, mask2 );
b2FloatW mass = b2BlendW( c->normalMass2, zero, mask );
mass = b2BlendW( zero, mass, restitutionMask );

// fixed anchors for Jacobians
b2Vec2W rA = c->anchorA2;
Expand Down
2 changes: 1 addition & 1 deletion src/geometry.c
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ b2CastOutput b2RayCastSegment( const b2RayCastInput* input, const b2Segment* sha
}

output.fraction = t;
output.point = b2MulAdd( p1, t, d );
output.point = p;
output.normal = normal;
output.hit = true;

Expand Down
2 changes: 2 additions & 0 deletions src/island.c
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ void b2LinkContact( b2World* world, b2Contact* contact )
{
b2AddContactToIsland( world, islandIdB, contact );
}

// todo why not merge the islands right here?
}

// This is called when a contact no longer has contact points or when a contact is destroyed.
Expand Down
1 change: 1 addition & 0 deletions src/island.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ typedef struct b2Island
int jointCount;

// Union find
// todo this could go away if islands are merged immediately with b2LinkJoint and b2LinkContact
int parentIsland;

// Keeps track of how many contacts have been removed from this island.
Expand Down
4 changes: 2 additions & 2 deletions test/test_determinism.c
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,8 @@ static int CrossPlatformTest(void)
}

ENSURE( stepCount < maxSteps );
ENSURE( sleepStep == 383 );
ENSURE( hash == 0xfeb0cd4e );
ENSURE( sleepStep == 304 );
ENSURE( hash == 0xd6ddbc8d );

free( bodies );

Expand Down

0 comments on commit 82ee95d

Please sign in to comment.