From fd5361a61c5ec28d0bf8ce38df6ab09f41894651 Mon Sep 17 00:00:00 2001 From: Manoj Date: Fri, 28 Aug 2020 04:30:47 -0400 Subject: [PATCH] Adding scater --- lib/backends/cpu/op-resolve-rules.ts | 2 ++ lib/backends/cpu/ops/scatter.ts | 29 +++++++++++++++++++++ lib/ops/scatter.ts | 39 ++++++++++++++++++++++++++++ test/test-suite-whitelist.jsonc | 1 + 4 files changed, 71 insertions(+) create mode 100644 lib/backends/cpu/ops/scatter.ts create mode 100644 lib/ops/scatter.ts diff --git a/lib/backends/cpu/op-resolve-rules.ts b/lib/backends/cpu/op-resolve-rules.ts index 1db7c98c..0e8adb87 100644 --- a/lib/backends/cpu/op-resolve-rules.ts +++ b/lib/backends/cpu/op-resolve-rules.ts @@ -21,6 +21,7 @@ import {CpuPad} from './ops/pad'; import {CpuAveragePool, CpuGlobalAveragePool, CpuGlobalMaxPool, CpuMaxPool} from './ops/pool'; import * as cpuReduce from './ops/reduce'; import {CpuReshape} from './ops/reshape'; +import {CpuScatter} from './ops/scatter'; import {CpuSlice, CpuSliceV10} from './ops/slice'; import {CpuSoftmax} from './ops/softmax'; import {CpuSqueeze} from './ops/squeeze'; @@ -104,4 +105,5 @@ export const CPU_OP_RESOLVE_RULES: ReadonlyArray = [ ['Unsqueeze', '', '1+', () => new CpuUnsqueeze()], ['Upsample', '', '7-8', () => new CpuUpsample()], ['Xor', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 ^ e2))], + ['Scatter', '', '7+', () => new CpuScatter()], ]; diff --git a/lib/backends/cpu/ops/scatter.ts b/lib/backends/cpu/ops/scatter.ts new file mode 100644 index 00000000..26b29099 --- /dev/null +++ b/lib/backends/cpu/ops/scatter.ts @@ -0,0 +1,29 @@ +import {Scatter} from '../../../ops/scatter'; +import {Tensor} from '../../../tensor'; +import {CpuInferenceHandler} from '../inference-handler'; + +export class CpuScatter extends Scatter { + run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[]|Promise { + const output = scatter(inputs[0], inputs[1], inputs[2]); + return [output]; + } +} + +export function scatter(data: Tensor, indices: Tensor, updates: Tensor): Tensor { + const datadims = data.dims; + const indicedims = indices.dims; + const updatedims = updates.dims; + const datanew = new Tensor(datadims, data.type); + const Y = datanew.data; + const X = updates.data; + let flatIndex = 0; + let updateFlatIndex = 0; + for (let i = 0; i < datadims[0]; ++i) { + for (let j = 0; j < indicedims[1]; ++j) { + flatIndex = i * datadims[1] + (indices.data[j] as number); + updateFlatIndex = i * updatedims[1] + j; + Y[flatIndex] = X[updateFlatIndex]; + } + } + return new Tensor(datadims, data.type, undefined, undefined, Y); +} diff --git a/lib/ops/scatter.ts b/lib/ops/scatter.ts new file mode 100644 index 00000000..76e86d27 --- /dev/null +++ b/lib/ops/scatter.ts @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {Attribute} from '../attribute'; +import {InferenceHandler} from '../backend'; +import {Operator} from '../operators'; +import {Tensor} from '../tensor'; + +export abstract class Scatter implements Operator { + // Inputs are {data_tensor->float32/float64, indices_tensor-> int32, update_data_tensor->float32/float64} + abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise; + + initialize(attributes: Attribute): void {} + + checkInputs(inputs: Tensor[]): boolean { + if (!inputs || inputs.length !== 3) { + return false; + } + const tensorRank = inputs[0].dims.length; + if (tensorRank < 1) { + return false; + } + + return this.checkInputTypes(inputs); + } + + protected checkInputTypes(inputs: Tensor[]): boolean { + if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') { + return false; + } + if (inputs[1].type !== 'int32' && inputs[1].type !== 'int16') { + return false; + } + if (inputs[2].type !== 'float32' && inputs[2].type !== 'float64') { + return false; + } + return true; + } +} diff --git a/test/test-suite-whitelist.jsonc b/test/test-suite-whitelist.jsonc index 4a1e86be..3a7f8b65 100644 --- a/test/test-suite-whitelist.jsonc +++ b/test/test-suite-whitelist.jsonc @@ -82,6 +82,7 @@ "test_flatten_default_axis", "test_gather_0", "test_gather_1", + "test_scatter_0", "test_gemm_broadcast", "test_gemm_nobroadcast", "test_globalaveragepool_precomputed",