Skip to content

Commit

Permalink
make dot tjp drop references as possible
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Aug 27, 2017
1 parent 2096c2c commit 23b09bd
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions autograd/numpy/numpy_tjps.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,25 @@
# ----- Trickier grads -----

def tjp_dot_arg0(ans, vs, out_vs, A, B):
if anp.ndim(B) == 0 or anp.ndim(B) == 1 or anp.ndim(A) == 0:
contract_dims = max(0, anp.ndim(B) - (anp.ndim(A) != 0))
return lambda G: anp.tensordot(G, B, contract_dims)
A_ndim, B_ndim = vs.ndim, anp.ndim(B)
if B_ndim == 0 or B_ndim == 1 or A_ndim == 0:
contract_num = max(0, B_ndim - (A_ndim != 0))
return lambda G: anp.tensordot(G, B, contract_num)
else:
return lambda G: anp.tensordot(G, anp.swapaxes(B, -1, -2), anp.ndim(B) - 1)
return lambda G: anp.tensordot(G, anp.swapaxes(B, -1, -2), B_ndim - 1)
deftjp(anp.dot, tjp_dot_arg0)

def tjp_dot_arg1(ans, vs, out_vs, A, B):
needs_transpose = anp.ndim(B) > 1 and anp.ndim(A) != 0
A_ndim, B_ndim = anp.ndim(A), vs.ndim
needs_transpose = B_ndim > 1 and A_ndim != 0
swap = (lambda x: anp.swapaxes(x, -1, -2)) if needs_transpose else (lambda x: x)
if anp.ndim(A) == 0 or anp.ndim(A) == 1 or anp.ndim(B) == 0:
contract_dims = max(0, anp.ndim(A) - (anp.ndim(B) != 0))
if A_ndim == 0 or A_ndim == 1 or B_ndim == 0:
contract_dims = max(0, A_ndim - (B_ndim != 0))
return lambda G: swap(anp.tensordot(G, A, contract_dims))
else:
return lambda G: swap(anp.tensordot(
G, A, [range(-anp.ndim(A) - anp.ndim(B) + 2, -anp.ndim(B) + 1),
range(anp.ndim(A) - 1)]))
G, A, [range(-A_ndim - B_ndim + 2, -B_ndim + 1),
range(A_ndim - 1)]))
deftjp(anp.dot, tjp_dot_arg1, argnum=1)

def tjp_transpose(ans, in_vs, out_vs, x, axes=None):
Expand Down

0 comments on commit 23b09bd

Please sign in to comment.