Merge pull request #637 from joddiy/dev
Merge master into dev
nudles authored Mar 25, 2020
2 parents 00b4743 + a8a4483 commit 3e0c036
Showing 9 changed files with 224 additions and 15 deletions.
21 changes: 21 additions & 0 deletions .asf.yaml
@@ -0,0 +1,21 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

github:
description: a distributed deep learning platform
labels:
- deep-learning
5 changes: 5 additions & 0 deletions README.md
@@ -16,6 +16,9 @@
specific language governing permissions and limitations
under the License.
-->

![Logo](doc/_static/singa.png)

# Apache SINGA

[![Build Status](https://travis-ci.org/apache/singa.png)](https://travis-ci.org/apache/singa)
@@ -41,6 +44,8 @@ Distributed deep learning system
![LGTM C++ Grade](https://img.shields.io/lgtm/grade/cpp/github/apache/incubator-singa)
![LGTM Python Grade](https://img.shields.io/lgtm/grade/python/github/apache/incubator-singa)

[![Stargazers over time](https://starchart.cc/apache/singa.svg)](https://starchart.cc/apache/singa)

## Mailing Lists

* [Development Mailing List](mailto:[email protected]) ([Archive](http://mail-archives.apache.org/mod_mbox/singa-dev/))
19 changes: 19 additions & 0 deletions SECURITY.md
@@ -1,3 +1,22 @@
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->

# Security Policy

## Reporting a Vulnerability
8 changes: 5 additions & 3 deletions doc/en/community/team-list.rst
@@ -43,8 +43,12 @@ PMC
+--------------------+--------------------------------+-----------------------------------------------+
| Kian-Lee Tan | [email protected] | National University of Singapore |
+--------------------+--------------------------------+-----------------------------------------------+
| Meihui Zhang | [email protected] | Beijing Institute of Technology |
+--------------------+--------------------------------+-----------------------------------------------+
| Moaz Reyad | [email protected] | Université Grenoble Alpes |
+--------------------+--------------------------------+-----------------------------------------------+
| Sheng Wang | [email protected] | DAMO Academy, Alibaba Group |
+--------------------+--------------------------------+-----------------------------------------------+
| Ted Dunning | [email protected] | Apache Software Foundation |
+--------------------+--------------------------------+-----------------------------------------------+
| Thejas Nair | [email protected] | Apache Software Foundation |
@@ -64,7 +68,7 @@ Committers
+====================+================================+===============================================+
| Chonho Lee | [email protected] | Osaka University |
+--------------------+--------------------------------+-----------------------------------------------+
| Sheng Wang | wangsh@apache.org | DAMO Academy, Alibaba Group |
| Chris Yeung | chrishkchris@apache.org | National University of Singapore |
+--------------------+--------------------------------+-----------------------------------------------+
| Wanqi Xue | [email protected] | National University of Singapore |
+--------------------+--------------------------------+-----------------------------------------------+
@@ -89,8 +93,6 @@ Contributors
+--------------------+--------------------------------+-----------------------------------------------+
| Wenfeng Wu | [email protected] | Freelancer, China |
+--------------------+--------------------------------+-----------------------------------------------+
| Meihui Zhang | [email protected] | Beijing Institute of Technology |
+--------------------+--------------------------------+-----------------------------------------------+
| Chang Yao | [email protected] | Hangzhou MZH Technologies |
+--------------------+--------------------------------+-----------------------------------------------+

77 changes: 74 additions & 3 deletions python/singa/autograd.py
@@ -495,7 +495,7 @@ def backward(self, dy):
return singa.__mul__(dy, self.mask)


def clip(x, min, max):
def clip(x, min=None, max=None):
return Clip(min, max)(x)[0]


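With min and max now defaulting to None, clip can be used as a one-sided clamp. A minimal usage sketch, assuming the underlying Clip operator treats an omitted bound as unbounded; the values below are illustrative:

import numpy as np
from singa import autograd, tensor

# Illustrative: clamp to [-1, 1], or apply only one of the two bounds.
x = tensor.from_numpy(np.array([-2.0, 0.5, 3.0], dtype=np.float32))
y_both = autograd.clip(x, min=-1.0, max=1.0)  # clamp to [-1, 1]
y_low = autograd.clip(x, min=0.0)             # lower bound only
y_high = autograd.clip(x, max=1.0)            # upper bound only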
@@ -3187,6 +3187,78 @@ def gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
return Gemm(alpha, beta, transA, transB)(A, B, C)[0]


class GlobalAveragePool(Operation):

def __init__(self, data_format='channels_first'):
"""
Init a GlobalAveragePool operator.
Args:
data_format: a string; two formats are supported, channels_first and channels_last (default: channels_first).
channels_first means the input format is (N x C x H x W);
channels_last means the input format is (N x H x W x C).
"""
super(GlobalAveragePool, self).__init__()
self.data_format = data_format

def forward(self, x):
"""
Forward propagation of GlobalAveragePool.
Args:
x: the input tensor.
Returns:
tensor, the output.
"""
if training:
self.mask = singa.Tensor(x.shape(), x.device())

shape = list(x.shape())

# (N x C x H x W) for channels_first
if self.data_format == 'channels_first':
axes = tuple(i for i in range(2, len(shape)))
self.shape_divisor = 1 / np.prod(shape[2:])
else: # (N x H x W x C) for channels_last
axes = tuple(i for i in range(1, len(shape) - 1))
self.shape_divisor = 1 / np.prod(shape[1:-1])

# output shape
# (N x C x 1 x 1) for channels_first
# (N x 1 x 1 x C) for channels_last
for i in axes:
shape[i] = 1

x = tensor.from_raw_tensor(x)
x = tensor.sum(x, axis=axes)
x = tensor.reshape(x, shape)
return singa.MultFloat(x.data, self.shape_divisor)

def backward(self, dy):
"""
Backward propagation of GlobalAveragePool.
Args:
dy: the gradient tensor from upper operations.
Returns:
tensor, the gradient over the input.
"""
self.mask.SetFloatValue(self.shape_divisor)
return singa.__mul__(self.mask, dy)


def globalaveragepool(x, data_format='channels_first'):
"""
GlobalAveragePool operator.
Args:
x: the input tensor.
data_format: a string; two formats are supported, channels_first and channels_last (default: channels_first).
channels_first means the input format is (N x C x H x W);
channels_last means the input format is (N x H x W x C).
Returns:
tensor, the output.
"""
return GlobalAveragePool(data_format)(x)[0]


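For reference, the computation above can be checked against plain NumPy. The helper below is an illustrative sketch, not part of the SINGA API: it averages the spatial axes while keeping them as size-1 dimensions, matching the (N x C x 1 x 1) / (N x 1 x 1 x C) output shapes described in the docstrings, and its gradient is dy spread uniformly over the averaged axes.

import numpy as np

def global_average_pool_ref(x, data_format='channels_first'):
    # Average the spatial axes: (2, ..., ndim-1) for channels_first,
    # (1, ..., ndim-2) for channels_last; keep them as size-1 dims.
    if data_format == 'channels_first':
        axes = tuple(range(2, x.ndim))
    else:
        axes = tuple(range(1, x.ndim - 1))
    return x.mean(axis=axes, keepdims=True)

x = np.arange(1, 10, dtype=np.float32).reshape(1, 1, 3, 3)
print(global_average_pool_ref(x))  # [[[[5.]]]]
# Backward: every input element receives dy / (H * W), here dy / 9.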
class ConstantOfShape(Operation):

def __init__(self, value=0):
@@ -3911,7 +3983,6 @@ def forward(self, x):
x.AsType(self.to)
return x


def backward(self, dy):
"""
backward of Cast
@@ -3934,4 +4005,4 @@ def cast(x, to):
Returns:
the output CTensor.
"""
return Cast(to)(x)[0]
return Cast(to)(x)[0]
10 changes: 7 additions & 3 deletions python/singa/sonnx.py
@@ -206,6 +206,7 @@ class SingaFrontend(object):
'Not': 'Not',
'Negative': 'Neg',
'Reciprocal': 'Reciprocal',
'GlobalAveragePool': 'GlobalAveragePool',
}

# this dict indicates the operators that need extra handle
@@ -822,6 +823,7 @@ class SingaBackend(Backend):
'Not': '_not',
'Neg': 'negative',
'Reciprocal': 'reciprocal',
'GlobalAveragePool': 'globalaveragepool',
}

# this dict indicates the operators that need extra handle
@@ -1001,10 +1003,12 @@ def _create_conv(cls, onnx_node, inputs, opset_version):
"""
kernel = tuple(onnx_node.attrs["kernel_shape"])
# todo: we only support the padding with tuple
padding = tuple(
onnx_node.attrs["pads"][0:2]) if "pads" in onnx_node.attrs else (0,
0)
stride = tuple(onnx_node.getattr('strides', (1, 1)))
padding = tuple(onnx_node.attrs["pads"][0:2]) if "pads" in onnx_node.attrs else (0, 0)
if "auto_pad" in onnx_node.attrs:
auto_pad = force_unicode(onnx_node.attrs['auto_pad'])
out_shape = get_output_shape(auto_pad, inputs[0].shape[2:], kernel, stride)
padding = get_pad_shape(auto_pad, inputs[0].shape[2:], kernel, stride, out_shape)
dilation = onnx_node.getattr('dilations', 1)
group = onnx_node.getattr('group', 1)

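The auto_pad branch above delegates to get_output_shape and get_pad_shape. Assuming those helpers follow the standard ONNX SAME_UPPER / SAME_LOWER arithmetic, the per-dimension computation looks roughly like this sketch (illustrative only, not the actual sonnx helpers):

import math

def same_pad_1d(auto_pad, in_size, kernel, stride):
    # Standard ONNX auto_pad arithmetic for one spatial dimension:
    # the output size rounds up, and the total padding needed to reach
    # it is split between the two ends of the axis.
    out_size = math.ceil(in_size / stride)
    total = max((out_size - 1) * stride + kernel - in_size, 0)
    half = total // 2
    if auto_pad == 'SAME_UPPER':      # odd extra padding goes at the end
        return out_size, (half, total - half)
    if auto_pad == 'SAME_LOWER':      # odd extra padding goes at the start
        return out_size, (total - half, half)
    return out_size, (0, 0)           # VALID / NOTSET

print(same_pad_1d('SAME_UPPER', in_size=7, kernel=3, stride=2))  # (4, (1, 1))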
18 changes: 18 additions & 0 deletions test/python/test_onnx.py
@@ -1252,6 +1252,24 @@ def test_transfer_learning(self):
sgd.update(p, gp)
sgd.step()

def test_globalaveragepool(self):
X = np.array([[[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
]]]).astype(np.float32)

x = tensor.from_numpy(X)
x.to_device(gpu_dev)
y = autograd.globalaveragepool(x)

# frontend
model = sonnx.to_onnx([x], [y])
# backend
sg_ir = sonnx.prepare(model, device=gpu_dev)
y_t = sg_ir.run([x])

np.testing.assert_array_almost_equal(tensor.to_numpy(y), tensor.to_numpy(y_t[0]), decimal=5)

if __name__ == '__main__':
unittest.main()
28 changes: 22 additions & 6 deletions test/python/test_onnx_backend.py
@@ -46,12 +46,15 @@ def expect(node, inputs, outputs, name, opset_version=_default_opset_version):
input_labels = [x for x in onnx_node.inputs if x != ""]
# prepare input tensors
for key, val in zip(input_labels, inputs):
# very important! must be float
if not isinstance(val, np.ndarray) or len(val.shape) == 0:
val = np.array([val])
x = tensor.from_numpy(val.astype(np.float32))
x.to_device(gpu_dev)
input_tensors[key] = x
if node.op_type == "Clip" and key in ("min", "max"):
input_tensors[key] = val.item()
else:
# very important! must be float
if not isinstance(val, np.ndarray) or len(val.shape) == 0:
val = np.array([val])
x = tensor.from_numpy(val.astype(np.float32))
x.to_device(gpu_dev)
input_tensors[key] = x
outputs_dict = sonnx.run_node(onnx_node, input_tensors, opset_version)
for out1, out2 in zip(outputs, outputs_dict.values()):
np.testing.assert_array_almost_equal(out1,
@@ -2277,6 +2280,19 @@ def pool(
y[shape] = f(window_vals[np.where(~np.isnan(window_vals))])
return y.astype(np.float32)

def test_globalaveragepool(self):
node = onnx.helper.make_node(
'GlobalAveragePool',
inputs=['x'],
outputs=['y'],
)
x = np.array([[[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
]]]).astype(np.float32)
y = np.array([[[[5]]]]).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name='test_globalaveragepool_precomputed')

if __name__ == '__main__':
unittest.main()
53 changes: 53 additions & 0 deletions test/python/test_operation.py
@@ -3592,6 +3592,59 @@ def test_gemm_cpu(self):
def test_gemm_gpu(self):
self.gemm_test(gpu_dev)

def globalaveragepool_channel_first(self, dev):
X = np.array([[[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
]]]).astype(np.float32)
XT = np.array([[[[5]]]]).astype(np.float32)
DY = np.ones((1, 1, 1, 1), dtype=np.float32)

x = tensor.from_numpy(X)
x.to_device(dev)
dy = tensor.from_numpy(DY)
dy.to_device(dev)

result = autograd.globalaveragepool(x)
dx = result.creator.backward(dy.data)

DX = np.ones(X.shape, dtype=np.float32)
DX = np.multiply(DX, DY) / np.prod(X.shape[2:])

np.testing.assert_array_almost_equal(tensor.to_numpy(result), XT, decimal=5)
np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx)), DX, decimal=5)

def globalaveragepool_channel_last(self, dev):
X = np.array([[
[[1], [2], [3]],
[[4], [5], [6]],
[[7], [8], [9]],
]]).astype(np.float32)
XT = np.array([[[[5]]]]).astype(np.float32)
DY = np.ones((1, 1, 1, 1), dtype=np.float32)

x = tensor.from_numpy(X)
x.to_device(dev)
dy = tensor.from_numpy(DY)
dy.to_device(dev)

result = autograd.globalaveragepool(x, 'channels_last')
dx = result.creator.backward(dy.data)

DX = np.ones(X.shape, dtype=np.float32)
DX = np.multiply(DX, DY) / np.prod(X.shape[1:-1])

np.testing.assert_array_almost_equal(tensor.to_numpy(result), XT, decimal=5)
np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx)), DX, decimal=5)

def test_globalaveragepool_cpu(self):
self.globalaveragepool_channel_first(cpu_dev)
self.globalaveragepool_channel_last(cpu_dev)

def test_globalaveragepool_gpu(self):
self.globalaveragepool_channel_first(gpu_dev)
self.globalaveragepool_channel_last(gpu_dev)

def constantOfShape_test(self, dev):
# float_ones
X = np.array([4, 3, 2]).astype(np.int64)
