Fix jax backend sampling with variable names that are not valid identifiers #135
@@ -127,18 +127,34 @@ def test_det(backend, gradient_backend):
    assert trace.posterior.b.shape[-1] == 2


@parameterize_backends
def test_non_identifier_names(backend, gradient_backend):
What are non-identifier names? What was failing?
When we have a model like
    with pm.Model() as model:
        a = pm.Data("a/b", shape=2)
or a nested model, the variables in the generated logp function get different names (they can't be called a/b, after all).
But I was passing the values to the logp function as kwargs, with the real variable names as keys.
So I switched to positional args, since we know the order.
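The switch to positional arguments can be sketched in plain Python. This is only an illustration, not the code from this PR: `logp`, `names`, and `values` are hypothetical stand-ins for the generated function, the known variable order, and the values keyed by their real names.

```python
# Real variable names, in the order the generated function expects them.
# (Hypothetical example data, not taken from the PR.)
names = ["a/b", "sigma"]


# Stand-in for a generated logp function. Its parameter names had to be
# rewritten into legal identifiers, so they no longer match the real names.
def logp(a_b, sigma):
    return a_b - sigma


# Values keyed by the real names; the dict order need not match `names`.
values = {"sigma": 2.0, "a/b": 5.0}

# Look the values up in the known order and pass them positionally, so the
# parameter names of the generated function never matter.
args = [values[name] for name in names]
result = logp(*args)
```

The design choice here is that only the order has to agree between caller and generated function, which avoids any dependence on how names get rewritten.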
Maybe call the test invalid_kwarg_names?
How do we do it in PyMC when we have the point func with names like this? It seems like it should also fail because we unpack a dict, or does that somehow work?
The shared variables are typically hidden in the pytensor function, and that uses positional args.
And the strange pymc point functions also assume things are in the right order (which has bitten me more than once...)
I called it non-identifier because this happens more or less if name.isidentifier() is false.
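For reference, a small sketch of what `str.isidentifier()` reports for a few names. The example names are made up for illustration; the `::` form is assumed here to resemble what a nested model might produce. The "more or less" matters because reserved words such as `class` pass `isidentifier()` yet still cannot be used as parameter names, which `keyword.iskeyword` detects.

```python
import keyword

# Hypothetical variable names, chosen only to illustrate the check.
names = ["a", "a/b", "outer::inner", "x_1", "1x", "class"]

# Names that str.isidentifier() rejects outright.
non_identifiers = [name for name in names if not name.isidentifier()]

# Names that pass isidentifier() but are reserved words.
reserved = [
    name for name in names
    if name.isidentifier() and keyword.iskeyword(name)
]
```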
I thought the point func just called the underlying function with **state. It's true it doesn't have to interact with shareds, but the name thing should also be an issue there?
Where do we convert it to positional arguments? Inside PyTensor? Because here it's just unpacking the dict: https://github.com/pymc-devs/pymc/blob/main/pymc%2Fpytensorf.py#L617-L624
Anyway, I thought the problem was that you can't unpack non-identifiers in Python.
Nope, no limitation anywhere :-)
It is perfectly legal in Python (and I think also in JAX) to have a function
    def foo(**kwargs):
        pass
and then pass it something like {"a/b": 1}. As long as foo takes the kwargs as **kwargs, that's allowed.
The problem is that the JAX function we generate in dispatch looks something like
    def jax_dispatch(a_b):
        pass
and we can't call that function with jax_dispatch(**{"a/b": 1}), because the name of the parameter and the key in the dictionary don't match.
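Both halves of this can be checked in plain Python, without JAX. In the sketch below, `foo` and `jax_dispatch` are the toy functions from the comment, not the real generated code, and `point` is a hypothetical input.

```python
def foo(**kwargs):
    # Collects arbitrary string keys, identifier or not.
    return kwargs


def jax_dispatch(a_b):
    # Toy stand-in for a generated function with a rewritten parameter name.
    return a_b


point = {"a/b": 1}

# Unpacking a non-identifier key into **kwargs is accepted.
accepted = foo(**point)

# Unpacking the same dict into a named parameter fails, because no
# parameter is called "a/b".
try:
    jax_dispatch(**point)
    error = None
except TypeError as exc:
    error = str(exc)
```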
Hmm, right. I didn't expect those internal variables to ever be called with kwargs.
Or for the names to even respect those from the original graph