Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix NormalSimulator broadcasted output shapes #295

Merged
merged 6 commits into from
Feb 7, 2025
Merged

Conversation

han-ol
Copy link
Contributor

@han-ol han-ol commented Jan 27, 2025

The NormalSimulator used for testing seems to have unintended output shapes.

The mean and std had repeated values, corresponding to the number of observations.

The PR:

  1. undoes the repetition by slicing
  2. flattens the observations

The latter could be avoided if a summary network or adapter transform is chosen.

han-ol and others added 6 commits January 27, 2025 17:03
…low-org#291)

* Splines draft

* update keras requirement

* small improvements to error messages

* add rq spline function

* add spline transform

* update searchsorted utils for jax
also add padd util

* update tests

* add assert_allclose util for improved messages

* parametrize transform for flow tests

* update jacobian, jacobian trace, vjp, jvp, and corresponding usages and tests

* fix imports, remove old jacobian and jvp, fix application in free form flow

* improve logdet computation in free form flows

* Fix comparison for symbolic tensors under tf

* Add splines to twomoons notebook

* improve pad utility

* fix missing left edge in spline

* fix inside mask edge case

* explicitly set bias initializer

* add better expand utility

* small clean up, renaming

* fix indexing, fix inside check

* dump

* fix sign of log jacobian for inverse pass in rq spline

* fix parameter splitting for spline transform

* improve readability

* fix scale and shift trailing dimension

* fix inverse pass return value

* correctly choose bins once for each dimension, even for multi-dimensional inputs

* run formatter

* reduce searchsorted log spam

* log backend used at setup

* remove maximum message cache size

* Improve warning message for jax searchsorted

* Fix spline parameter binning for compiled contexts

* update inverse transform same as forward

* Update TwoMoons notebook with splines WIP [skip ci]

* fix spline inverse call for out of bounds values

* Add working splines

---------

Co-authored-by: stefanradev93 <[email protected]>
@vpratz
Copy link
Collaborator

vpratz commented Feb 7, 2025

Good spot, thanks for the fix! I refactored it a bit to use automatic broadcasting, please take a look and let me know if you are happy with those changes.

@han-ol
Copy link
Contributor Author

han-ol commented Feb 7, 2025

Thanks, the refactor looks good. Happy!

@vpratz vpratz merged commit accf8b4 into bayesflow-org:dev Feb 7, 2025
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants