Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Now using strict shape/dtype promotion rules. #349

Merged
merged 1 commit into from
Jan 8, 2024

Conversation

patrick-kidger
Copy link
Owner

This means that:

  1. Tests now pass using JAX_NUMPY_DTYPE_PROMOTION=strict and JAX_NUMPY_RANK_PROMOTION=raise, and these are enabled in tests by default.
  2. The values passed to diffeqsolve now more carefully determine the dtype used in the integration (previously things were mostly just left to behave in ad-hoc fashion; whatever the various interacting arrays promoted their dtypes to):
    a. The dtype of timelike values is the jnp.result_type of t0, t1, dt0, and SaveAt(ts=...). If any of these are complex an error is raised. If these are all integers we use the default floating-point dtype.
    b. The jnp.result_type of the time dtype, and each leaf of y0, is the dtype of that leaf.
  3. Of course, diffeqsolve accepts user-specified functions (e.g. the vector field of an ODETerm), and these could potentially return arrays with dtypes that do not match the ones we have selected above, which might cause further upcasting. For the sake of backward compatibility we don't try to prohibit this -- a user who feels strongly about this should enable JAX_NUMPY_DTYPE_PROMOTION=strict and fix their vector fields appropriately. (And can then be assured that the dtypes of these quantities are exactly as specified by the rules above.) So the key thing this commit enables is that using this flag to enforce this is now possible, without any false positives from Diffrax itself!

This means that:

1. Tests now pass using `JAX_NUMPY_DTYPE_PROMOTION=strict` and `JAX_NUMPY_RANK_PROMOTION=raise`, and these are enabled in tests by default.
2. The values passed to `diffeqsolve` now more carefully determine the dtype used in the integration (previously things were mostly just left to behave in ad-hoc fashion; whatever the various interacting arrays promoted their dtypes to):
    a. The dtype of timelike values is the `jnp.result_type` of `t0`, `t1`, `dt0`, and `SaveAt(ts=...)`. If any of these are complex an error is raised. If these are all integers we use the default floating-point dtype.
    b. The `jnp.result_type` of the time dtype, and each leaf of `y0`, is the dtype of that leaf.
3. Of course, `diffeqsolve` accepts user-specified functions (e.g. the vector field of an `ODETerm`), and these could potentially return arrays with dtypes that do not match the ones we have selected above, which might cause further upcasting. For the sake of backward compatibility we don't try to prohibit this -- a user who feels strongly about this should enable `JAX_NUMPY_DTYPE_PROMOTION=strict` and fix their vector fields appropriately. (And can then be assured that the dtypes of these quantities are exactly as specified by the rules above.) So the key thing this commit enables is that using this flag to enforce this is now possible, without any false positives from Diffrax itself!
@patrick-kidger patrick-kidger merged commit a408b4e into dev Jan 8, 2024
2 checks passed
@patrick-kidger patrick-kidger deleted the strict-promotion branch January 8, 2024 18:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant