
Allow overriding simulators in .sample() #312

Open
Kucharssim opened this issue Feb 13, 2025 · 8 comments
@Kucharssim (Collaborator) commented Feb 13, 2025

Say you have a setup like this:

import numpy as np
import bayesflow as bf

def context(batch_size):
    # meta/context variable shared across the batch
    n = np.random.randint(10, 100)
    return dict(n=n)

def prior():
    mu = np.random.normal(0, 1)
    return dict(mu=mu)

def likelihood(mu, n):
    y = np.random.normal(mu, 1, n)
    return dict(y=y)

simulator = bf.make_simulator([prior, likelihood], meta_fn=context)

Sometimes (e.g., during validation), it would be handy to generate a batch of data with something not according to the simulator specification, e.g., generate data with fixed context:

simulator.sample(1000, fixed=dict(n=50))

Of course, one could make a separate simulator to do this, but that is at times a bit cumbersome and leads to code duplication. Looking at the code, this should be relatively straightforward to implement; I am curious to hear whether there are any downsides, @stefanradev93, @LarsKue, @paul-buerkner?
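To make the intent concrete, here is a minimal sketch of the proposed fixed= semantics in plain NumPy (the sample_with_fixed helper is hypothetical, not the BayesFlow API): values in fixed replace the corresponding simulator outputs before downstream functions consume them.

```python
import numpy as np

def sample_with_fixed(batch_size, fixed=None):
    # Hypothetical sketch of the proposed `fixed=` argument:
    # entries in `fixed` override the simulator's own draws.
    fixed = fixed or {}
    n = fixed.get("n", np.random.randint(10, 100))   # context, possibly overridden
    mu = np.random.normal(0, 1, size=batch_size)     # one prior draw per batch element
    y = np.random.normal(mu[:, None], 1, size=(batch_size, n))  # likelihood
    return dict(n=n, mu=mu, y=y)

data = sample_with_fixed(1000, fixed=dict(n=50))
# data["n"] == 50, data["y"].shape == (1000, 50)
```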

Kucharssim added the feature (New feature or request) label on Feb 13, 2025
@LarsKue (Contributor) commented Feb 13, 2025

Agreed that this would be a valuable feature. What do you think about the naming of the argument, @paul-buerkner? I think we could even allow passing it directly, e.g.

simulator.sample(1000, n=50)

On the other hand, your usage @Kucharssim could be dangerous if the value of n is used before it is returned, since we cannot overwrite n inside the function:

def context():
    n = 10
    c = np.zeros(n)  # n is consumed here, before it is returned
    return dict(n=n, c=c)

It might also complicate batching.

@Kucharssim (Collaborator, Author)
could be dangerous if the value of n is used before it is returned

Yes, that is a good point. Though I think we could just make the user responsible for this (and be very clear that we can intercept the values only between each simulator fn, not inside them).
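To illustrate what "between each simulator fn" means, here is a minimal sketch of a simulator chain (plain Python, not the BayesFlow internals; run_chain is a hypothetical helper): fixed values are written into the shared state dict only after each fn returns, so anything a fn computes internally is out of reach.

```python
import inspect
import numpy as np

def run_chain(fns, fixed=None):
    # Hypothetical helper, not BayesFlow internals: each fn's outputs
    # are merged into a shared state dict; `fixed` values are reapplied
    # after each fn returns (i.e. between fns), never inside a fn.
    fixed = dict(fixed or {})
    state = dict(fixed)
    for fn in fns:
        params = inspect.signature(fn).parameters
        kwargs = {k: v for k, v in state.items() if k in params}
        state.update(fn(**kwargs))
        state.update(fixed)  # intercept between fns only
    return state

def context():
    return dict(n=np.random.randint(10, 100))

def prior():
    return dict(mu=np.random.normal(0, 1))

def likelihood(mu, n):
    return dict(y=np.random.normal(mu, 1, n))

data = run_chain([context, prior, likelihood], fixed=dict(n=50))
# data["n"] == 50, data["y"].shape == (50,)
```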

@Kucharssim (Collaborator, Author)
Well, actually, this feature is already available, because kwargs are passed into the simulators during sampling. One just has to explicitly handle the case where the input is supplied manually:

import numpy as np
import bayesflow as bf

def context(batch_size, n=None):
    if n is None:  # only draw n when the user did not supply it
        n = np.random.randint(10, 100)
    return dict(n=n)

def prior():
    mu = np.random.normal(0, 1)
    return dict(mu=mu)

def likelihood(mu, n):
    y = np.random.normal(mu, 1, n)
    return dict(y=y)

simulator = bf.make_simulator([prior, likelihood], meta_fn=context)

data = simulator.sample(100, n=31415)
print(data["n"])        # 31415
print(data["y"].shape)  # (100, 31415)

Now I am not sure how that plays with batched input, but I think for the majority of use cases this is sufficient already.
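One place batched input could bite (a plain-Python illustration, not BayesFlow code; the batch_size parameter and this is-None pattern are assumptions): the kwarg is passed through verbatim, so a scalar override simply replaces what would otherwise be a per-batch array.

```python
import numpy as np

def prior(batch_size, mu=None):
    # Assumed pattern: draw a batched mu unless the user supplies one.
    if mu is None:
        mu = np.random.normal(0, 1, size=batch_size)  # shape (batch_size,)
    return dict(mu=mu)

batched = prior(100)             # mu has shape (100,)
overridden = prior(100, mu=0.0)  # mu stays a bare scalar, no broadcasting
```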

@paul-buerkner (Contributor)
Cool! Which potential issue regarding batching do you see?

@LarsKue (Contributor) commented Feb 15, 2025

@paul-buerkner When users pass a value like this in the already-batched context of the sample method, they might expect their input to be batched similarly to the function output. E.g., what should happen if the user passes mu=0? Currently, this would overwrite a batched value with an integer, which is incorrect and not what the user expects.

We could broadcast any outside input against the simulator output, but this has to be clearly communicated somehow.
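A minimal sketch of that broadcasting step (apply_overrides is a hypothetical helper in plain NumPy, not BayesFlow API): each override is expanded against the shape of the array it replaces, and non-array outputs are skipped.

```python
import numpy as np

def apply_overrides(output, overrides):
    # Hypothetical: broadcast each user-supplied value against the shape
    # of the batched array it replaces; leave non-array outputs alone.
    for key, value in overrides.items():
        if key in output and isinstance(output[key], np.ndarray):
            output[key] = np.broadcast_to(np.asarray(value), output[key].shape)
    return output

batch = dict(mu=np.random.normal(0, 1, size=(1000,)))
batch = apply_overrides(batch, dict(mu=0))
# batch["mu"].shape == (1000,), every entry is 0
```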

@paul-buerkner (Contributor)
I don't think we should auto-broadcast for now, because then the example from above with n being fixed wouldn't work. So if I understand correctly, right now we need (a) users to handle the is None cases in their simulators and (b) users to handle proper broadcasting themselves. That's fine, I think.

To make this feature more convenient (no is None handling and automatic broadcasting), we would need to build a variable graph, so that the simulator knows which variable is output and consumed by which other simulators, and how they are broadcast or treated as meta variables, etc. This may be something for the future of simulators (tagging @daniel-habermann just FYI), but I think it is not something to tackle right now.

What do you think?

@LarsKue (Contributor) commented Feb 17, 2025

@paul-buerkner I think you misunderstand: there is no significant issue with broadcasting here, as we already know the target shape. Since n is simulator output, we can broadcast the outside input against it as soon as we see it in the output, i.e.

if "n" in output:
    output["n"] = broadcast(input["n"], output["n"].shape)

Of course, this pseudo-code only works for array/tensor output, so we would skip this step for other output types, but it highlights the idea.

@paul-buerkner (Contributor)
Okay, so you are saying this would already work with minimal changes? If so, perhaps @Kucharssim would like to work on it?

Kucharssim self-assigned this on Feb 17, 2025