Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix optimize_for backend_opts to be empty dictionary instead of None (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
samskalicky authored Feb 3, 2021
1 parent f9d90c9 commit c723ae2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,7 +1397,7 @@ def optimize_for(self, x, *args, backend=None, clear=False,
self._first_forward = True
# clear the backend
self._backend = None
self._backend_opts = None
self._backend_opts = {}

def _clear_cached_op(self):
self._cached_graph = ()
Expand Down
10 changes: 8 additions & 2 deletions tests/python/unittest/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,17 @@ def test_subgraph():
sym_filename, params_filename = sym_block3.export('optimized')
assert sym_filename == 'optimized-symbol.json'
assert params_filename is None

# Test with additional input to subgraph op
sym_block3.optimize_for(a_data, b_data, backend="addInputPass")
out5 = sym_block3(a_data, b_data)

# Reload exported block
sym_block4 = nn.SymbolBlock.imports(sym_filename, ['a','b'], params_filename)

out5 = sym_block4(a_data, b_data)
out6 = sym_block4(a_data, b_data)
# check that result matches one executed by MXNet
assert_almost_equal(out[0].asnumpy(), out5[0].asnumpy(), rtol=1e-3, atol=1e-3)
assert_almost_equal(out[0].asnumpy(), out6[0].asnumpy(), rtol=1e-3, atol=1e-3)

@pytest.mark.skipif(check_platform(['x86_64']), reason="not all machine types supported")
@pytest.mark.skipif(is_cd_run(), reason="continuous delivery run - ignoring test")
Expand Down

0 comments on commit c723ae2

Please sign in to comment.