Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[webgpu] Implement Split operator #23198

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/split.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/tensor/split.h"

#include <string>
#include <vector>

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

namespace {

// Helper function to calculate the output index based on the input index and the sizes of the splits.
// Emits a WGSL helper `calculate_output_index` that maps an index along the
// split axis to the output tensor it belongs to, by linearly scanning the
// cumulative split sizes held in `uniforms.sizes_in_split_axis`.
// `output_count` is returned as a fallback when the index is past all splits.
void CalculateOutputIndex(std::ostream& os, size_t output_count) {
  const auto cumulative_size = GetElementAt("uniforms.sizes_in_split_axis", "i", output_count);
  os << "fn calculate_output_index(index: u32) -> u32 {\n";
  os << " for (var i: u32 = 0u; i < " << output_count << "u; i += 1u ) {\n";
  os << " if (index < " << cumulative_size << ") {\n";
  os << " return i;\n";
  os << " }\n";
  os << " }\n";
  os << " return " << output_count << "u;\n";
  os << "}\n";
}

// Helper function to write the buffer data for each output.
// Emits a WGSL helper `write_buffer_data` that copies the input element at
// `global_idx` into the output selected by `output_number`, at `indices`.
// For multiple outputs an if/else-if/else chain dispatches on output_number;
// for a single output the write is unconditional.
//
// Bug fix: the original emitted the chain-closing " }\n" unconditionally, so
// with exactly one output (no `if` ever opened) the generated function had an
// extra closing brace and was not valid WGSL.
void WriteBufferData(std::ostream& os, const ShaderVariableHelper& input,
                     gsl::span<const ShaderVariableHelper*> outputs) {
  os << "fn write_buffer_data(output_number: u32, global_idx: u32, indices: output_0_indices_t) {\n";
  for (size_t i = 0; i < outputs.size(); ++i) {
    const auto buffer_write = outputs[i]->SetByIndices("indices", input.GetByOffset("global_idx"));
    if (outputs.size() == 1) {
      // Single output: no branch is needed (and none must be closed below).
      os << " " << buffer_write << "\n";
    } else if (i == 0) {
      os << " if (output_number == 0u) {\n"
         << " " << buffer_write << "\n";
    } else if (i == outputs.size() - 1) {
      // The last output is the fallback branch of the chain.
      os << " } else {\n"
         << " " << buffer_write << "\n";
    } else {
      os << " } else if (output_number == " << i << "u) {\n"
         << " " << buffer_write << "\n";
    }
  }
  if (outputs.size() > 1) {
    os << " }\n";  // close the if/else chain
  }
  os << "}\n";
}

} // namespace

// Builds the WGSL shader for Split: declares the input and one variable per
// output, emits the two helper functions, then writes a main body that maps
// each input element to the right output via calculate_output_index.
Status SplitProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);

  const size_t output_count = Outputs().size();
  std::vector<const ShaderVariableHelper*> output_vars;
  output_vars.reserve(output_count);
  for (size_t i = 0; i < output_count; ++i) {
    const auto& out_var =
        shader.AddOutput("output_" + std::to_string(i), ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
    output_vars.push_back(&out_var);
  }

  // Emit the helper functions used by the main body.
  CalculateOutputIndex(shader.AdditionalImplementation(), output_count);
  WriteBufferData(shader.AdditionalImplementation(), input, output_vars);

  auto& body = shader.MainFunctionBody();
  body << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size");
  body << " var indices = " << input.OffsetToIndices("global_idx") << ";\n";
  body << " var index = indices[" << axis_ << "];\n";
  body << " let output_number = calculate_output_index(index);\n";
  // Rebase the split-axis index into the selected output's coordinate space.
  body << " if (output_number != 0u) {\n";
  body << " index -= uniforms.sizes_in_split_axis[output_number - 1u];\n";
  body << " indices[" << axis_ << "] = index;\n";
  body << " }\n";
  body << " write_buffer_data(output_number, global_idx, indices);\n";

  return Status::OK();
}

// Computes Split on WebGPU: partitions `input` along the (normalized) split
// axis into num_outputs outputs and dispatches one shader that scatters every
// input element to its destination output.
Status Split::ComputeInternal(ComputeContext& context) const {
  const Tensor* input = context.Input<Tensor>(0);
  auto& input_shape = input->Shape();
  auto num_outputs = context.OutputCount();

  int64_t axis = axis_;
  std::vector<int64_t> split_sizes;

  split_sizes.assign(split_sizes_.begin(), split_sizes_.end());
  // If no split sizes came from the attribute, read them from the optional
  // 'split' input tensor (registered as CPU-resident, see InputMemoryType below).
  if (split_sizes_.size() == 0 && context.InputCount() > 1) {
    const Tensor* split_tensor = context.Input<Tensor>(1);
    // The 'split' input is optional; only use it when it is provided.
    if (split_tensor != nullptr) {
      ORT_ENFORCE(split_tensor->Shape().NumDimensions() == 1, "The split tensor must be a vector tensor.");
      // Get split_sizes from the input tensor.
      auto nDims = static_cast<size_t>(split_tensor->Shape()[0]);
      const auto* data = split_tensor->Data<int64_t>();
      split_sizes.assign(data, data + nDims);
    }
  }

  // The variables below are not actually used in the current implementation.
  int before_dims = 0;
  int after_dims_including_split_axis = 0;
  int after_dims_excluding_split = 0;
  // PrepareForCompute normalizes a negative axis and, when split_sizes is
  // empty, splits the outputs evenly according to num_outputs.
  ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis,
                                        after_dims_excluding_split, split_sizes));

  SplitProgram program{gsl::narrow_cast<uint32_t>(axis)};
  program.AddInput({input, ProgramTensorMetadataDependency::TypeAndRank});

  auto output_dimensions = input_shape.AsShapeVector();
  for (int i = 0; i < num_outputs; ++i) {
    // Update the size of the dimension for the axis we're splitting on.
    auto split_size = narrow<int>(split_sizes[i]);
    output_dimensions[narrow<size_t>(axis)] = split_size;

    Tensor* output = context.Output(i, TensorShape{output_dimensions});
    program.AddOutput({output, ProgramTensorMetadataDependency::Rank});
  }

  uint32_t input_size = gsl::narrow<uint32_t>(input_shape.Size());
  // Early return if the input tensor is empty. The outputs were already
  // allocated above, so there is nothing left to do.
  if (input_size == 0) {
    return Status::OK();
  }

  // sizes_in_split_axis holds the cumulative (prefix-sum) sizes of the splits
  // along the split axis; the shader uses it to locate the target output.
  uint32_t previous_sum = 0;
  std::vector<uint32_t> sizes_in_split_axis;
  sizes_in_split_axis.reserve(split_sizes.size());
  for (auto split_size : split_sizes) {
    previous_sum += gsl::narrow<uint32_t>(split_size);
    sizes_in_split_axis.push_back(previous_sum);
  }

  program
      .SetDispatchGroupSize((input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .CacheHint(std::to_string(axis))
      .AddUniformVariables(
          {input_size, gsl::span<const uint32_t>(sizes_in_split_axis.data(), sizes_in_split_axis.size())});
  return context.RunProgram(program);
}

// Registers a Split kernel for a single opset VERSION. The 'split' input
// (input index 1) is pinned to CPU memory so ComputeInternal can read it.
#define WEBGPU_SPLIT_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \
ONNX_OPERATOR_KERNEL_EX(OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \
KERNEL_CLASS);

// Same as above for an opset range [VERSION_FROM, VERSION_TO].
#define WEBGPU_SPLIT_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX(OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \
KERNEL_CLASS);

// One registration per opset window of the ONNX Split operator.
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 1, 1, Split_1, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 2, 10, Split_2_10, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 11, 12, Split_11_12, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 13, 17, Split_13_17, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_KERNEL(Split, 18, Split_18, WebGpuSupportedNumberTypes());

} // namespace webgpu
} // namespace onnxruntime
61 changes: 61 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/split.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/common.h"
#include "core/providers/cpu/tensor/split.h"

namespace onnxruntime {
namespace webgpu {

// Program that generates and parameterizes the WGSL shader for Split.
class SplitProgram final : public Program<SplitProgram> {
 public:
  // axis: split axis, already normalized to a non-negative value by the caller.
  // `explicit` prevents accidental implicit uint32_t -> SplitProgram conversion;
  // `const` on a by-value parameter is meaningless in a declaration and dropped.
  explicit SplitProgram(uint32_t axis) : Program{"Split"}, axis_{axis} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32},
                                          {"sizes_in_split_axis", ProgramUniformVariableDataType::Uint32});

 private:
  uint32_t axis_;  // split axis baked into the generated shader
};

// WebGPU Split kernel. SplitBase supplies the opset-dependent attribute
// handling (axis_, split_sizes_); opset selects which behavior applies.
class Split : public WebGpuKernel, public SplitBase {
 public:
  Split(const OpKernelInfo& info, uint32_t opset) : WebGpuKernel(info), SplitBase(info, opset) {}

 protected:
  Status ComputeInternal(ComputeContext& context) const override;
};

// Split bound to opset 1. `explicit` avoids implicit OpKernelInfo conversion.
class Split_1 final : public Split {
 public:
  explicit Split_1(const OpKernelInfo& info) : Split(info, 1) {}
};

// Split bound to opsets 2-10. `explicit` avoids implicit OpKernelInfo conversion.
class Split_2_10 final : public Split {
 public:
  explicit Split_2_10(const OpKernelInfo& info) : Split(info, 2) {}
};

// Split bound to opsets 11-12. `explicit` avoids implicit OpKernelInfo conversion.
class Split_11_12 final : public Split {
 public:
  explicit Split_11_12(const OpKernelInfo& info) : Split(info, 11) {}
};

// Split bound to opsets 13-17. `explicit` avoids implicit OpKernelInfo conversion.
class Split_13_17 final : public Split {
 public:
  explicit Split_13_17(const OpKernelInfo& info) : Split(info, 13) {}
};

// Split bound to opset 18. `explicit` avoids implicit OpKernelInfo conversion.
class Split_18 final : public Split {
 public:
  explicit Split_18(const OpKernelInfo& info) : Split(info, 18) {}
};

} // namespace webgpu
} // namespace onnxruntime
11 changes: 6 additions & 5 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -637,11 +637,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Concat)>,

// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 1, Split)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Split)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Split)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Split)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 1, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, Split)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 12, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Expand)>,

Expand Down
Loading