Skip to content

Commit

Permalink
Remove repId and do validation in shader
Browse files Browse the repository at this point in the history
  • Loading branch information
jzm-intel committed Nov 1, 2024
1 parent f5f3456 commit f527cff
Showing 1 changed file with 55 additions and 62 deletions.
117 changes: 55 additions & 62 deletions src/webgpu/shader/execution/shader_io/fragment_builtins.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1578,23 +1578,6 @@ fn vsMain(@builtin(vertex_index) index : u32) -> @builtin(position) vec4f {
const byteLength = bytesPerRow * blocksPerColumn;
const uintLength = byteLength / 4;

const buffer = t.makeBufferWithContents(
new Uint32Array([1]),
GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
);

const bg = t.device.createBindGroup({
layout: pipeline.getBindGroupLayout(0),
entries: [
{
binding: 0,
resource: {
buffer,
},
},
],
});

for (let i = 0; i < 2; i++) {
const framebuffer = t.createTextureTracked({
size: [width, height],
Expand All @@ -1617,7 +1600,6 @@ fn vsMain(@builtin(vertex_index) index : u32) -> @builtin(position) vec4f {
],
});
pass.setPipeline(pipeline);
pass.setBindGroup(0, bg);
// Draw the uperr-left triangle (vertices 0-2) or the lower-right triangle (vertices 3-5)
pass.draw(3, 1, i * 3);
pass.end();
Expand Down Expand Up @@ -1660,15 +1642,11 @@ enable subgroups;
const width = ${t.params.size[0]};
const height = ${t.params.size[1]};
@group(0) @binding(0) var<storage, read_write> for_layout : u32;
@fragment
fn fsMain(
@builtin(position) pos : vec4f,
@builtin(subgroup_size) sg_size : u32,
) -> @location(0) vec4u {
_ = for_layout;
let ballot = countOneBits(subgroupBallot(true));
let ballotSize = ballot.x + ballot.y + ballot.z + ballot.w;
Expand Down Expand Up @@ -1700,6 +1678,10 @@ fn fsMain(
);
});

// A non-zero magic number indicating no expectation error, in order to prevent the false no-error
// result from zero-initialization.
const kSubgroupInvocationIdNoError = 17;

/**
* Checks subgroup_invocation_id value consistency
*
Expand All @@ -1710,8 +1692,9 @@ fn fsMain(
* @param data An array of vec4u that contains (per texel):
* * subgroup_invocation_id
* * subgroup size
* * ballot size
* * non-zero ID unique to each subgroup
* * ballot active invocation number
* * error flag, should be equal to kSubgroupInvocationIdNoError or shader found
* expection failed otherwise.
* @param format The texture format of data
* @param width The width of the framebuffer
* @param height The height of the framebuffer
Expand All @@ -1728,24 +1711,19 @@ function checkSubgroupInvocationIdConsistency(
const uintsPerRow = bytesPerRow / 4;
const uintsPerTexel = (bytesPerBlock ?? 1) / blockWidth / blockHeight / 4;

const invocationIdBitmapOfSubgroups = new Map<number, bigint>();
const ballotSizeRecordOfSubgroups = new Map<
number,
{ ballotSize: number; row: number; col: number }
>();
for (let row = 0; row < height; row++) {
for (let col = 0; col < width; col++) {
const offset = uintsPerRow * row + col * uintsPerTexel;
const id = data[offset];
const sgSize = data[offset + 1];
const ballotSize = data[offset + 2];
const repId = data[offset + 3];
const error = data[offset + 3];

if (repId === 0) {
// repId of 0 indicates inactivate fragment, and all output should be zero.
if ((id !== 0) || (sgSize !== 0) || (ballotSize !== 0)) {
if (error === 0) {
// Inactive fragment get error `0` instead of noError. Check all output being zero.
if (id !== 0 || sgSize !== 0 || ballotSize !== 0) {
return new Error(
`Unexpected zero repId with non-zero outputs for (${row}, ${col}): got output [${id}, ${sgSize}, ${ballotSize}, ${repId}]`
`Unexpected zero error with non-zero outputs for (${row}, ${col}): got output [${id}, ${sgSize}, ${ballotSize}, ${error}]`
);
}
continue;
Expand All @@ -1763,28 +1741,14 @@ function checkSubgroupInvocationIdConsistency(
);
}

const ballotSizeRecord = ballotSizeRecordOfSubgroups.get(repId);
if (ballotSizeRecord === undefined) {
ballotSizeRecordOfSubgroups.set(repId, { ballotSize, row, col });
} else {
if (ballotSize !== ballotSizeRecord.ballotSize) {
return new Error(
`Inconsistent subgroup ballot size within same subgroup
- icoord: (${ballotSizeRecord.row}, ${ballotSizeRecord.col})
- got: ${ballotSizeRecord.ballotSize}
- icoord: (${row}, ${col})
- got: ${ballotSize}`
);
}
}

let invocationIdBitmap = invocationIdBitmapOfSubgroups.get(repId) ?? 0n;
const mask = 1n << BigInt(id);
if ((mask & invocationIdBitmap) !== 0n) {
return new Error(`Multiple invocations with id '${id}' in subgroup '${repId}'`);
if (error !== kSubgroupInvocationIdNoError) {
return new Error(
`Unexpected error value
- icoord: (${row}, ${col})
- expected: noError (${kSubgroupInvocationIdNoError})
- got: ${error}`
);
}
invocationIdBitmap |= mask;
invocationIdBitmapOfSubgroups.set(repId, invocationIdBitmap);
}
}

Expand All @@ -1809,22 +1773,51 @@ enable subgroups;
const width = ${t.params.size[0]};
const height = ${t.params.size[1]};
@group(0) @binding(0) var<storage, read_write> counter : atomic<u32>;
const maxSubgroupSize = 128u;
// A non-zero magic number indicating no expectation error, in order to prevent the
// false no-error result from zero-initialization.
const noError = ${kSubgroupInvocationIdNoError}u;
@fragment
fn fsMain(
@builtin(position) pos : vec4f,
@builtin(subgroup_invocation_id) id : u32,
@builtin(subgroup_size) sg_size : u32,
) -> @location(0) vec4u {
let ballot = countOneBits(subgroupBallot(true));
let ballotSize = ballot.x + ballot.y + ballot.z + ballot.w;
// Generate representative id for this subgroup.
var repId = atomicAdd(&counter, 1);
repId = subgroupBroadcast(repId, 0);
var error: u32 = noError;
// Validate that reported subgroup size is no larger than maxSubgroupSize
if (sg_size > maxSubgroupSize) {
error++;
}
// Validate that reported subgroup invocation id is smaller than subgroup size
if (id >= sg_size) {
error++;
}
// Validate that each subgroup id is assigned to at most one active invocation
// in the subgroup
var countAssignedId: u32 = 0u;
for (var i: u32 = 0; i < maxSubgroupSize; i++) {
let ballotIdEqualsI = countOneBits(subgroupBallot(id == i));
let countInvocationIdEqualsI = ballotIdEqualsI.x + ballotIdEqualsI.y + ballotIdEqualsI.z + ballotIdEqualsI.w;
// Validate an id assigned at most once
error += select(1u, 0u, countInvocationIdEqualsI <= 1);
// Validate id larger than subgroup size will not get balloted
error += select(1u, 0u, (id < sg_size) || (countInvocationIdEqualsI == 0));
// Sum up the assigned invocation number of each id
countAssignedId += countInvocationIdEqualsI;
}
// Validate that all active invocation get counted during the above loop
let ballotActive = countOneBits(subgroupBallot(true));
let activeInvocations = ballotActive.x + ballotActive.y + ballotActive.z + ballotActive.w;
if (activeInvocations != countAssignedId) {
error++;
}
return vec4u(id, sg_size, ballotSize, repId);
return vec4u(id, sg_size, activeInvocations, error);
}`;

await runSubgroupTest(
Expand Down

0 comments on commit f527cff

Please sign in to comment.