Skip to content

Latest commit

 

History

History
479 lines (407 loc) · 12 KB

game-of-life.md

File metadata and controls

479 lines (407 loc) · 12 KB
title head
WebGPU: Game of Life
<style> body { margin: 0; overflow: hidden; } p:has(> canvas) { margin: 0; } canvas { height: 100vh; width: 100vw; } </style>

// @run type="module"
// Device pixel ratio captured once at module load.
const { devicePixelRatio } = window;

// Tunable simulation settings. The grid is half the window size in each
// dimension, measured when the module is first evaluated.
const GameOptions = {
  width: Math.floor(0.5 * window.innerWidth),
  height: Math.floor(0.5 * window.innerHeight),
  // Passed to the compute shader as its workgroup edge length.
  workgroupSize: 8,
  // Number of animation frames between simulation steps.
  timestep: 1,
  // Whether the loop keeps running while the window is blurred.
  background: true,
};

// WGSL compute shader that advances the Game of Life by one generation.
// Bindings (group 0): 0 = grid size (vec2u), 1 = current cells (read-only),
// 2 = next cells (read/write). The grid wraps around on both axes.
const computeShaderCode = `
@binding(0) @group(0) var<storage, read> size: vec2u;
@binding(1) @group(0) var<storage, read> current: array<u32>;
@binding(2) @group(0) var<storage, read_write> next: array<u32>;

override blockSize = 8;

fn getIndex(x: u32, y: u32) -> u32 {
  let h = size.y;
  let w = size.x;

  return (y % h) * w + (x % w);
}

fn getCell(x: u32, y: u32) -> u32 {
  return current[getIndex(x, y)];
}

fn countNeighbors(x: u32, y: u32) -> u32 {
  // u32 arithmetic wraps modulo 2^32, so stepping back from coordinate 0 by
  // subtraction would land on 4294967295, whose remainder is not the last
  // column/row unless the grid size is a power of two. Stepping forward by
  // (size - 1) is equivalent modulo the grid size and never underflows.
  let left = x + size.x - 1u;
  let right = x + 1u;
  let up = y + size.y - 1u;
  let down = y + 1u;

  return getCell(left, up)   + getCell(x, up)   + getCell(right, up) +
         getCell(left, y)    +                    getCell(right, y) +
         getCell(left, down) + getCell(x, down) + getCell(right, down);
}

@compute @workgroup_size(blockSize, blockSize)
fn main(@builtin(global_invocation_id) grid: vec3u) {
  let x = grid.x;
  let y = grid.y;

  // Invocations outside the grid (possible when the dispatch is rounded up
  // to whole workgroups) have nothing to compute.
  if (x >= size.x || y >= size.y) {
    return;
  }

  let n = countNeighbors(x, y);
  // Live cell survives with 2 or 3 neighbours; dead cell is born with exactly 3.
  next[getIndex(x, y)] = select(u32(n == 3u), u32(n == 2u || n == 3u), getCell(x, y) == 1u);
}`;

// WGSL vertex shader source.
// Inputs: the instance index `i`, a u32 `cell` value at location 0 and a
// vec2u corner offset `pos` at location 1; the grid size comes from a uniform
// at binding 0 / group 0.
// For each vertex it derives the cell's column (i % w) and row from the
// instance index, adds the corner offset, and maps the result to clip space
// (x in [-1, 1], y flipped by the -2.0 factor). It forwards the cell value as
// f32 and a per-cell uv (column / w, row / h) to the fragment stage.
const vertexShaderCode = `
      struct Out {
  @builtin(position) pos: vec4f,
  @location(0) cell: f32,
  @location(1) uv: vec2f,
}

@binding(0) @group(0) var<uniform> size: vec2u;

@vertex
fn main(@builtin(instance_index) i: u32, @location(0) cell: u32, @location(1) pos: vec2u) -> Out {
  let w = size.x;
  let h = size.y;
  let x = (f32(i % w + pos.x) / f32(w) - 0.5) * 2.0;
  let y = (f32((i - (i % w)) / w + pos.y) / f32(h) - 0.5) * -2.0;
  
  let uv = vec2f(f32(i % w) / f32(w), f32(i / w) / f32(h));

  return Out(vec4f(x, y, 0., 1.), f32(cell), uv);
}`;

// WGSL fragment shader source.
// Receives the interpolated `cell` value (location 0) and `uv` (location 1).
// When `cell` is above 0.5 the fragment is tinted from the uv coordinates
// (red = uv.x, green = uv.y, blue = 1 - average of the two); otherwise it is
// opaque black.
const fragmentShaderCode = `
@fragment
fn main(@location(0) cell: f32, @location(1) uv: vec2f) -> @location(0) vec4f {
  if (cell > 0.5) {
    // Use UV coordinates for coloring
    return vec4f(uv.x, uv.y, 1.0 - (uv.x + uv.y) * 0.5, 1.0);
  } else {
    return vec4f(0.0, 0.0, 0.0, 1.0);
  }
}
`;

/**
 * Small convenience wrapper around a WebGPU device, a canvas context and a
 * reusable command encoder. Instances are plain property bags: whatever is
 * passed to the constructor is copied onto `this`. Use `LazyWebGPU.init()` to
 * build a fully configured instance.
 */
class LazyWebGPU {
  /**
   * @param {object} props - Properties copied verbatim onto the instance
   *   (`init` supplies canvas, adapter, context, device, presentationFormat
   *   and commandEncoder).
   */
  constructor(props) {
    Object.assign(this, props);
  }

  /**
   * Looks up the canvas, requests an adapter and device, sizes the canvas
   * backing store to its CSS size times the device pixel ratio, and
   * configures the canvas context.
   *
   * @param {string} selector - CSS selector of the target canvas.
   * @param {boolean} [lowPower=false] - Request the "low-power" adapter
   *   instead of the "high-performance" one.
   * @returns {Promise<LazyWebGPU>}
   * @throws {Error} If the canvas is not found or WebGPU is unavailable.
   */
  static async init(selector, lowPower = false) {
    const canvas = document.querySelector(selector);
    if (!canvas) {
      // Fail early with a clear message instead of a TypeError further down.
      throw new Error(`No element matches selector "${selector}"`);
    }
    if (!navigator.gpu) {
      throw new Error("No WebGPU support");
    }

    const adapter = await navigator.gpu.requestAdapter({
      powerPreference: lowPower ? "low-power" : "high-performance"
    });

    if (!adapter) {
      throw new Error("No WebGPU support");
    }

    const device = await adapter.requestDevice();
    const context = canvas.getContext('webgpu');

    canvas.width = canvas.clientWidth * devicePixelRatio;
    canvas.height = canvas.clientHeight * devicePixelRatio;

    const presentationFormat = navigator.gpu.getPreferredCanvasFormat();

    context.configure({
      device: device,
      format: presentationFormat,
      alphaMode: 'premultiplied',
    });

    return new LazyWebGPU({
      canvas,
      adapter,
      context,
      device,
      presentationFormat,
      commandEncoder: device.createCommandEncoder(),
    });
  }

  /**
   * Creates a bind group layout made only of buffer bindings.
   *
   * @param {Array<{visibility: number, bufferType: string}>} entries
   * @param {number} [offset=0] - Binding index of the first entry.
   * @returns {GPUBindGroupLayout}
   */
  bindGroupLayout(entries, offset = 0) {
    return this.device.createBindGroupLayout({
      entries: entries.map((entry, index) => ({
        binding: offset + index,
        visibility: entry.visibility,
        buffer: {
          type: entry.bufferType,
        },
      })),
    });
  }

  /**
   * Same as `bindGroupLayout`, with one shared visibility for every entry.
   *
   * @param {number} visibility - Shader stage flags applied to all entries.
   * @param {string[]} bufferTypes - Buffer binding type per entry.
   * @param {number} [offset=0] - Binding index of the first entry.
   * @returns {GPUBindGroupLayout}
   */
  bindGroupLayoutOfVisibility(visibility, bufferTypes, offset = 0) {
    return this.bindGroupLayout(
      bufferTypes.map((bufferType) => ({
        visibility,
        bufferType,
      })),
      offset
    );
  }

  /**
   * Creates a GPU buffer sized for `items` and, when `mapped` is true, fills
   * it with them through a typed-array view before unmapping.
   *
   * @param {Function} constructor - Typed-array constructor (e.g. Uint32Array).
   * @param {ArrayLike<number>} items - Initial contents; also defines the size.
   * @param {number} usage - GPUBufferUsage flags.
   * @param {boolean} [mapped=true] - Create mapped and upload `items`. When
   *   false the buffer is only allocated and `items` is not written.
   * @returns {GPUBuffer}
   */
  createBuffer(constructor, items, usage, mapped = true) {
    const byteLength = items.length * constructor.BYTES_PER_ELEMENT;
    const buffer = this.device.createBuffer({
      // WebGPU requires the size of a buffer that is mapped at creation to be
      // a multiple of 4 bytes, so round up for 8/16-bit element types.
      size: mapped ? Math.ceil(byteLength / 4) * 4 : byteLength,
      usage,
      mappedAtCreation: mapped,
    });
    if (mapped) {
      new constructor(buffer.getMappedRange()).set(items);
      buffer.unmap();
    }
    return buffer;
  }

  /**
   * Creates a bind group that binds each buffer, in order, to consecutive
   * binding indices starting at `offset`.
   *
   * @param {GPUBindGroupLayout} bindGroupLayout
   * @param {GPUBuffer[]} buffers
   * @param {number} [offset=0]
   * @returns {GPUBindGroup}
   */
  createGroup(bindGroupLayout, buffers, offset = 0) {
    return this.device.createBindGroup({
      layout: bindGroupLayout,
      entries: buffers.map((buffer, i) => (
        {binding: offset + i, resource: {buffer}}
      )),
    });
  }

  /**
   * Builds a layout and a matching bind group in one call.
   *
   * @param {number} visibility - Shader stage flags applied to all entries.
   * @param {Array<{bufferType: string, resource: object}>} entries
   * @param {number} [offset=0] - Binding index of the first entry.
   * @returns {{layout: GPUBindGroupLayout, group: GPUBindGroup}}
   */
  createRenderGroupOfVisibility(visibility, entries, offset = 0) {
    const layout = this.bindGroupLayout(
      entries.map((e) => ({...e, visibility})),
      offset
    );
    const group = this.device.createBindGroup({
      layout,
      entries: entries.map(({resource}, i) => (
        {binding: offset + i, resource}
      )),
    });
    return {layout, group};
  }

  /**
   * Render pass descriptor targeting the canvas' current texture, cleared to
   * mid grey.
   *
   * @returns {GPURenderPassDescriptor}
   */
  basicRenderPass() {
    const view = this.context.getCurrentTexture().createView();
    return {
      colorAttachments: [
        {
          view,
          loadOp: 'clear',
          clearValue: {r: 0.5, g: 0.5, b: 0.5, a: 1.0},
          storeOp: 'store',
        },
      ],
    };
  }

  /**
   * Creates a compute pipeline with a single bind group layout.
   *
   * @param {GPUBindGroupLayout} bindGroupLayoutCompute
   * @param {GPUShaderModule} computeShader
   * @param {Record<string, number>} constants - Pipeline-overridable constants.
   * @returns {GPUComputePipeline}
   */
  computePipeline(bindGroupLayoutCompute, computeShader, constants) {
    return this.device.createComputePipeline({
      layout: this.device.createPipelineLayout({
        bindGroupLayouts: [bindGroupLayoutCompute],
      }),
      compute: {
        module: computeShader,
        constants,
      },
    });
  }

  /**
   * Records a compute pass on the current command encoder.
   *
   * @param {GPUComputePipeline} computePipeline
   * @param {GPUBindGroup[]} bindGroups - Bound at group indices 0..n-1.
   * @param {{x: number, y?: number, z?: number}} workgroup - Workgroup counts.
   */
  computePass(computePipeline, bindGroups, workgroup) {
    const passEncoderCompute = this.commandEncoder.beginComputePass();
    passEncoderCompute.setPipeline(computePipeline);
    bindGroups.forEach((group, i) => {
      passEncoderCompute.setBindGroup(i, group);
    });
    passEncoderCompute.dispatchWorkgroups(
      workgroup.x,
      workgroup.y,
      workgroup.z
    );
    passEncoderCompute.end();
  }

  /**
   * Records a render pass on the current command encoder.
   *
   * @param {{renderPipeline: GPURenderPipeline, pass?: GPURenderPassDescriptor}} options
   *   `pass` falls back to `basicRenderPass()` when null or undefined.
   * @param {(encoder: GPURenderPassEncoder) => void} fn - Issues the draw
   *   calls; invoked after the pipeline is set and before the pass ends.
   */
  renderPass({renderPipeline, pass}, fn) {
    const passEncoderRender = this.commandEncoder.beginRenderPass(
      pass ?? this.basicRenderPass()
    );
    passEncoderRender.setPipeline(renderPipeline);
    fn(passEncoderRender);
    passEncoderRender.end();
  }

  /** Replaces the current command encoder with a fresh one. */
  resetCommandEncoder() {
    this.commandEncoder = this.device.createCommandEncoder();
  }

  /** Finishes the current command encoder and submits it to the queue. */
  basicQueueSubmit() {
    this.device.queue.submit([this.commandEncoder.finish()]);
  }

  /**
   * @param {string} code - Shader source.
   * @returns {GPUShaderModule}
   */
  shader(code) {
    return this.device.createShaderModule({code});
  }

  /**
   * Creates a render pipeline.
   *
   * @param {GPUBindGroupLayout[]} bindGroupLayouts
   * @param {GPUVertexState} vertexState
   * @param {GPUFragmentState} fragmentState
   * @param {string} [topology='triangle-strip'] - Primitive topology.
   * @returns {GPURenderPipeline}
   */
  renderPipeline(bindGroupLayouts, vertexState, fragmentState, topology = 'triangle-strip') {
    return this.device.createRenderPipeline({
      layout: this.device.createPipelineLayout({
        bindGroupLayouts,
      }),
      primitive: {
        topology,
      },
      vertex: vertexState,
      fragment: fragmentState,
    });
  }

  /**
   * Builds one vertex buffer layout per entry, numbering shader locations
   * consecutively across all buffers.
   *
   * @param {Array<{stepMode: string, attributes: Array<{type: Function, stride: number}>}>} props
   * @returns {GPUVertexBufferLayout[]}
   */
  createVertexBuffers(props) {
    let shaderLocation = 0;
    const layouts = [];

    for (const {stepMode, attributes} of props) {
      layouts.push(
        this.createVertexBuffer(stepMode, attributes, shaderLocation)
      );
      shaderLocation += attributes.length;
    }

    return layouts;
  }

  /**
   * Builds a single vertex buffer layout with tightly packed attributes.
   *
   * @param {string} stepMode - "vertex" or "instance".
   * @param {Array<{type: Function, stride: number}>} attributes - `type` is a
   *   typed-array constructor, `stride` the number of components.
   * @param {number} [shaderLocationStart=0] - Location of the first attribute.
   * @returns {GPUVertexBufferLayout}
   */
  createVertexBuffer(stepMode, attributes, shaderLocationStart = 0) {
    const finalAttrs = [];
    let offset = 0;
    attributes.forEach(({type, stride}, i) => {
      finalAttrs.push({
        shaderLocation: shaderLocationStart + i,
        offset,
        format: this.convertToGPUVertexFormat(type, stride),
      });
      // Each attribute starts right after the previous one.
      offset += stride * type.BYTES_PER_ELEMENT;
    });

    return {
      arrayStride: offset,
      stepMode,
      attributes: finalAttrs
    };
  }

  /**
   * Maps a typed-array constructor and a component count to a vertex format
   * name.
   *
   * NOTE(review): for Uint8Array/Uint16Array the 1- and 3-component rows map
   * to the 2- and 4-component formats, while `createVertexBuffer` advances
   * offsets by `stride * BYTES_PER_ELEMENT` — confirm callers account for the
   * difference before using those combinations.
   *
   * @param {Function} type - Typed-array constructor.
   * @param {number} multiplier - Number of components (1-4).
   * @returns {string} Vertex format name.
   * @throws {Error} If the type / component count combination is not in the table.
   */
  convertToGPUVertexFormat(type, multiplier) {
    const formats = {
      1: {Uint8Array: 'uint8x2', Uint16Array: 'uint16x2', Uint32Array: 'uint32', Float32Array: 'float32'},
      2: {Uint8Array: 'uint8x2', Uint16Array: 'uint16x2', Uint32Array: 'uint32x2', Float32Array: 'float32x2'},
      3: {Uint8Array: 'uint8x4', Uint16Array: 'uint16x4', Uint32Array: 'uint32x3', Float32Array: 'float32x3'},
      4: {Uint8Array: 'uint8x4', Uint16Array: 'uint16x4', Uint32Array: 'uint32x4', Float32Array: 'float32x4'}
    };
    const format = formats[multiplier]?.[type.name];
    if (!format) {
      // Surface a descriptive error rather than an opaque TypeError or an
      // undefined format that only fails later at pipeline creation.
      throw new Error(
        `Unsupported vertex format: ${type.name} with ${multiplier} component(s)`
      );
    }
    return format;
  }
}

// Shared GPU context for the whole demo; the low-power adapter is requested.
const webGPU = await LazyWebGPU.init('canvas', true);

// Compute bindings in order: grid size, current cells (both read-only), next cells.
const groupLayoutCompute = webGPU.bindGroupLayoutOfVisibility(
  GPUShaderStage.COMPUTE,
  ['read-only-storage', 'read-only-storage', 'storage'],
);

// Four (x, y) corner offsets of a unit square.
const squareVerticesItems = [[0, 0], [0, 1], [1, 0], [1, 1]].flat();
const squareBuffer = webGPU.createBuffer(
  Uint32Array,
  squareVerticesItems,
  GPUBufferUsage.VERTEX,
);

// Vertex stage: buffer 0 advances once per instance with a single u32,
// buffer 1 advances once per vertex with two u32 values.
const vertexState = {
  module: webGPU.shader(vertexShaderCode),
  buffers: webGPU.createVertexBuffers(
    [
      ['instance', 1],
      ['vertex', 2],
    ].map(([stepMode, stride]) => ({
      stepMode,
      attributes: [{type: Uint32Array, stride}],
    })),
  ),
};

// Fragment stage writes to a single target in the canvas' preferred format.
const fragmentState = {
  module: webGPU.shader(fragmentShaderCode),
  targets: [{format: webGPU.presentationFormat}],
};

const computeShader = webGPU.shader(computeShaderCode);

// Mutable loop state; `render` is a no-op until resetGameData() replaces it.
const state = {
  loopTimes: 0,
  render() {},
};

/**
 * (Re)creates every GPU resource the simulation needs — compute and render
 * pipelines, the grid-size buffer, two cell buffers with a random initial
 * population, and their bind groups — then installs a fresh `state.render`
 * and resets `state.loopTimes` to 0.
 *
 * `state.render` records one compute pass and one render pass per call.
 * `state.loopTimes` (0 or 1) selects which cell buffer is the compute input;
 * the same buffer is the one drawn in that frame.
 */
function resetGameData() {
  const {width, height, workgroupSize} = GameOptions;

  const computePipeline = webGPU.computePipeline(groupLayoutCompute, computeShader, {
    blockSize: workgroupSize,
  });

  const sizeBuffer = webGPU.createBuffer(Uint32Array, [
    width,
    height,
  ], GPUBufferUsage.STORAGE | GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST | GPUBufferUsage.VERTEX);

  const length = width * height;

  // Roughly a quarter of the cells start alive.
  const cells = Array.from({length}, () => (Math.random() < 0.25 ? 1 : 0));
  const buff0 = webGPU.createBuffer(Uint32Array, cells, GPUBufferUsage.STORAGE | GPUBufferUsage.VERTEX, true);
  const buff1 = webGPU.createBuffer(Uint32Array, cells, GPUBufferUsage.STORAGE | GPUBufferUsage.VERTEX);

  // Same layout, with the two cell buffers swapped between input and output.
  const bindGroup0 = webGPU.createGroup(groupLayoutCompute, [sizeBuffer, buff0, buff1]);
  const bindGroup1 = webGPU.createGroup(groupLayoutCompute, [sizeBuffer, buff1, buff0]);

  const {
    group: uniformBindGroup,
    layout: vertexRenderLayout
  } = webGPU.createRenderGroupOfVisibility(GPUShaderStage.VERTEX, [
    {bufferType: "uniform", resource: {buffer: sizeBuffer}},
  ]);

  const renderPipeline = webGPU.renderPipeline(
    [vertexRenderLayout], vertexState, fragmentState
  );

  // Round up so the dispatch covers the whole grid even when its size is not
  // a multiple of the workgroup size; a fractional count would be truncated
  // and leave the trailing columns/rows without updates.
  const workgroup = {
    x: Math.ceil(width / workgroupSize),
    y: Math.ceil(height / workgroupSize),
  };

  state.loopTimes = 0;
  state.render = () => {
    webGPU.resetCommandEncoder();

    webGPU.computePass(
      computePipeline,
      [state.loopTimes ? bindGroup1 : bindGroup0],
      workgroup
    );
    webGPU.renderPass({renderPipeline}, (encoder) => {
      // Buffer 0 carries one cell value per instance, buffer 1 the quad corners.
      encoder.setVertexBuffer(0, state.loopTimes ? buff1 : buff0);
      encoder.setVertexBuffer(1, squareBuffer);
      encoder.setBindGroup(0, uniformBindGroup);
      encoder.draw(4, length);
    });

    webGPU.basicQueueSubmit();
  };
}

/**
 * Drives `state.render()` from requestAnimationFrame.
 *
 * A render happens every `options.timestep` animation frames. The option is
 * read on every frame, so changing it takes effect immediately; a falsy
 * value (0, undefined, NaN) pauses rendering while the frame loop keeps
 * running. When the window is blurred or the document hidden and
 * `options.background` is falsy, the loop stops scheduling frames and is
 * restarted by the next focus / visibilitychange event.
 *
 * @param {{render: Function}} state - Object whose `render` is invoked.
 * @param {{timestep: number, background: boolean}} options - Live options.
 * @param {Function} [onLoop] - Called after each render.
 */
function wrappedLoop(state, options, onLoop) {
  let elapsedFrames = 0;
  let animationFrameId = null;
  let isTabActive = true;

  function loop() {
    if (!isTabActive && !options.background) {
      if (animationFrameId) {
        cancelAnimationFrame(animationFrameId);
        animationFrameId = null;
      }
      return;
    }

    // Read the option each frame; caching it would make a pause (timestep 0)
    // impossible to leave without a focus event.
    const timestep = options.timestep;
    if (timestep) {
      elapsedFrames++;
      if (elapsedFrames >= timestep) {
        state.render();
        elapsedFrames -= timestep;
        if (onLoop) {
          onLoop();
        }
      }
    }
    animationFrameId = requestAnimationFrame(loop);
  }

  function activateTab() {
    isTabActive = true;
  }

  // Start the game loop
  animationFrameId = requestAnimationFrame(loop);

  // Event listeners for tab visibility
  document.addEventListener('visibilitychange', () => {
    isTabActive = !document.hidden;
    if (isTabActive && !animationFrameId) {
      activateTab();
      animationFrameId = requestAnimationFrame(loop);
    }
  });

  window.addEventListener('blur', () => {
    isTabActive = false;
  });

  window.addEventListener('focus', () => {
    activateTab();
    // Only reschedule when the loop actually stopped, to avoid double loops.
    if (!animationFrameId) {
      animationFrameId = requestAnimationFrame(loop);
    }
  });
}

/**
 * Entry point: builds the GPU resources, then starts the frame loop. After
 * every rendered step `state.loopTimes` toggles between 0 and 1.
 */
function main() {
  const flipLoopTimes = () => {
    state.loopTimes = state.loopTimes ? 0 : 1;
  };

  resetGameData();
  wrappedLoop(state, GameOptions, flipLoopTimes);
}

main();