
Expose callback during sampling #166

Open
wd60622 opened this issue Jan 18, 2025 · 4 comments

Comments

@wd60622

wd60622 commented Jan 18, 2025

I want to be able to run a Python callback on each sample. The signature would be different from this one:

https://github.com/wd60622/nutpie/blob/5b881638d658fb8236d4b627209d7f034ee02050/python/nutpie/sample.py#L361-L363

@aseyboldt
Member

There isn't currently any mechanism to call a function for each draw. The progress callback is called about once every 50 ms.
I think that might be better for the MLflow use case as well? That way we wouldn't slow things down much when the sampler is fast.

Either option wouldn't be too difficult to add, though.
What info would you want to have for each update call?
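To make the trade-off concrete, here is a toy sketch of a sampler loop that reports progress on a wall-clock interval instead of on every draw. It is not how nutpie or nuts-rs implement this; `run_chain`, `draw_fn`, and the `finished_draws`/`total_draws` keys are hypothetical. The per-draw cost is a single clock read, so a slow callback costs at most one call per interval rather than one call per draw.

```python
import time
from typing import Callable


def run_chain(
    n_draws: int,
    draw_fn: Callable[[int], float],
    callback: Callable[[dict], None],
    interval: float = 0.05,  # roughly the 50 ms cadence mentioned above
    clock: Callable[[], float] = time.monotonic,
) -> list:
    """Toy sampler loop (hypothetical): run `draw_fn` n_draws times and
    report progress to `callback` at most once per `interval` seconds."""
    draws = []
    reported = 0
    last_report = clock()
    for i in range(n_draws):
        draws.append(draw_fn(i))
        now = clock()
        # The only per-draw overhead is this cheap clock comparison.
        if now - last_report >= interval:
            last_report = now
            reported = i + 1
            callback({"finished_draws": reported, "total_draws": n_draws})
    # Send a final update if the last interval did not land on the final draw.
    if reported != n_draws:
        callback({"finished_draws": n_draws, "total_draws": n_draws})
    return draws
```

The `clock` parameter is only there so the loop can be driven by a fake timer in tests.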

@wd60622
Author

wd60622 commented Jan 20, 2025

That would do the trick! Yeah, no need to slow it down.

So I'm hearing that the progress bar HTML is updated every 50 ms or so. Is there much overhead in that?

What info would you want to have for each update call?

Can you point me to what would be possible to get? Are there existing objects that could be used?

@aseyboldt
Member

If we want to push progress info to MLflow, I think once every second or so would probably be enough? We don't want to spam the poor thing.
I can easily expose anything from here for each chain: https://github.com/pymc-devs/nuts-rs/blob/main/src/sampler.rs#L386
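One way the MLflow side could look, as a sketch: aggregate per-chain progress into a few scalar metrics and hand them to a logging function. The per-chain fields (`finished_draws`, `total_draws`, `divergences`) are hypothetical stand-ins for whatever nuts-rs ends up exposing, and `log_progress` is a hypothetical helper. `log_metric` is injected so the function can be tested without MLflow; in practice one could pass `mlflow.log_metric`, which accepts `(key, value, step=...)`.

```python
from typing import Callable, Iterable, Mapping


def log_progress(
    chains: Iterable[Mapping[str, int]],
    step: int,
    log_metric: Callable[..., None],
) -> None:
    """Aggregate hypothetical per-chain progress records into scalar metrics."""
    chains = list(chains)
    finished = sum(c["finished_draws"] for c in chains)
    total = sum(c["total_draws"] for c in chains)
    divergences = sum(c["divergences"] for c in chains)
    log_metric("finished_draws", finished, step=step)
    # Guard against division by zero before any draws are scheduled.
    log_metric("fraction_complete", finished / total if total else 0.0, step=step)
    log_metric("divergences", divergences, step=step)
```

Called about once per second, `step` could simply count how many updates have been logged so far.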

@wd60622
Author

wd60622 commented Jan 20, 2025

Thanks for sharing that link.

Totally. Would that be up to the user, though? It's something that could be defined in the callback:

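import time

def create_callback(update_frequency):
    # Only act on the sampler state once every `update_frequency` seconds.
    step = None

    def callback(state):
        # `nonlocal` is required because `step` is reassigned here;
        # without it Python raises UnboundLocalError.
        nonlocal step
        now = time.monotonic()  # monotonic clock is safer for measuring intervals
        if step is None:
            step = now

        if now - step > update_frequency:
            step = now
            # do something with state

    return callback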
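An alternative sketch of the same user-side throttling, written as a small class so the clock can be swapped out in tests. `ThrottledCallback` is a hypothetical helper, not part of nutpie. Unlike the closure above, which waits a full interval before its first update, it forwards the very first state immediately, so the consumer sees something as soon as sampling starts.

```python
import time
from typing import Any, Callable, Optional


class ThrottledCallback:
    """Forward `state` to `fn` at most once every `min_interval` seconds."""

    def __init__(
        self,
        fn: Callable[[Any], None],
        min_interval: float,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self.fn = fn
        self.min_interval = min_interval
        self.clock = clock
        self._last_call: Optional[float] = None

    def __call__(self, state: Any) -> None:
        now = self.clock()
        # Forward the first state immediately, then rate-limit the rest.
        if self._last_call is None or now - self._last_call >= self.min_interval:
            self._last_call = now
            self.fn(state)
```

Passed to a sampler that already reports every ~50 ms, `ThrottledCallback(push_to_mlflow, min_interval=1.0)` would thin updates to roughly one per second; `push_to_mlflow` stands for a hypothetical user function.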

wd60622 changed the title from "Expose callback for each sample" to "Expose callback during sampling" on Jan 21, 2025
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants