Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl committed Nov 3, 2023
1 parent 7dd928a commit 3778326
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/examples/operators.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@
"\n",
"- We've already seen some general examples above, like [`lineax.MatrixLinearOperator`][].\n",
"- We've already seen some structured examples above, like [`lineax.TridiagonalLinearOperator`][].\n",
"- Given a function $f \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$ and a point $x \\in \\mathbb{R}^n$, then [`lineax.JacobianLinearOperator`][] represents the Jacobian $\\frac{\\mathrm{d}f}{\\mathrm{d}x}(x) \\in \\mathbb{R}^{n \\t\\imes m}$.\n",
"- Given a function $f \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$ and a point $x \\in \\mathbb{R}^n$, then [`lineax.JacobianLinearOperator`][] represents the Jacobian $\\frac{\\mathrm{d}f}{\\mathrm{d}x}(x) \\in \\mathbb{R}^{n \\times m}$.\n",
"- Given a linear function $g \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$, then [`lineax.FunctionLinearOperator`][] represents the matrix corresponding to this linear function, i.e. the unique matrix $A$ for which $g(x) = Ax$.\n",
"- etc!\n",
"\n",
Expand Down
15 changes: 7 additions & 8 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1937,16 +1937,13 @@ def _(operator):

@conj.register(JacobianLinearOperator)
def _(operator):
fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args))
jvpfn = lambda vec: jax.jvp(fn, (operator.x,), (vec,))[1]
jvpfnc = lambda x: jvpfn(x.conj()).conj()
return FunctionLinearOperator(jvpfnc, operator.in_structure(), operator.tags)
return conj(linearise(operator))


@conj.register(FunctionLinearOperator)
def _(operator):
return FunctionLinearOperator(
lambda vec: operator.mv(vec.conj()).conj(),
lambda vec: jtu.tree_map(jnp.conj, operator.mv(jtu.tree_map(jnp.conj, vec))),
operator.in_structure(),
operator.tags,
)
Expand Down Expand Up @@ -1978,9 +1975,11 @@ def _(operator):

@conj.register(TangentLinearOperator)
def _(operator):
c = lambda operator: conj(operator)
primal_out, tangent_out = eqx.filter_jvp(c, (operator.primal,), (operator.tangent,))
return TangentLinearOperator(primal_out, tangent_out)
# Should be unreachable: TangentLinearOperator is used for a narrow set of
# operations only (mv; transpose) inside the JVP rule linear_solve_p.
raise NotImplementedError(
"Please open a GitHub issue: https://github.com/google/lineax"
)


@conj.register(AddLinearOperator)
Expand Down

0 comments on commit 3778326

Please sign in to comment.