Now using strict shape/dtype promotion rules. #349
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This means that:
JAX_NUMPY_DTYPE_PROMOTION=strict
andJAX_NUMPY_RANK_PROMOTION=raise
, and these are enabled in tests by default.diffeqsolve
now more carefully determine the dtype used in the integration (previously things were mostly just left to behave in ad-hoc fashion; whatever the various interacting arrays promoted their dtypes to):a. The dtype of timelike values is the
jnp.result_type
oft0
,t1
,dt0
, andSaveAt(ts=...)
. If any of these are complex an error is raised. If these are all integers we use the default floating-point dtype.b. The
jnp.result_type
of the time dtype, and each leaf ofy0
, is the dtype of that leaf.diffeqsolve
accepts user-specified functions (e.g. the vector field of anODETerm
), and these could potentially return arrays with dtypes that do not match the ones we have selected above, which might cause further upcasting. For the sake of backward compatibility we don't try to prohibit this -- a user who feels strongly about this should enableJAX_NUMPY_DTYPE_PROMOTION=strict
and fix their vector fields appropriately. (And can then be assured that the dtypes of these quantities are exactly as specified by the rules above.) So the key thing this commit enables is that using this flag to enforce this is now possible, without any false positives from Diffrax itself!