Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dramatically speed up sampling compilation time #4573

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions examples/gemma/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import modules
import sow_lib
import transformer as transformer_lib
from flax.nnx import graph
from flax.nnx import statelib
import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -128,17 +130,28 @@ def __init__(
vocab: vocabulary of the given model.
cache_size: size of the cache for the transformer.
"""
self.transformer = transformer
self.vocab = vocab
self.cache_size = cache_size
graphdef, state = nnx.split(transformer)
self._transformer_graphdef: graph.NodeDef = graphdef
self._transformer_state: statelib.State = state
# We separate the state from the graphdef so that the state can be passed as
# an argument to _sample_fn and is therefore not treated as a static arg.
# This greatly reduces the size of the HLO and the compile time.
self._compiled_sample_fn = jax.jit(self._sample_fn)

@property
def transformer(self) -> transformer_lib.Transformer:
# Reassembles the module from the stored graphdef and state via nnx.merge;
# the merge is performed on every access of this property.
return nnx.merge(self._transformer_graphdef, self._transformer_state)

@property
def dtype(self) -> jnp.dtype:
# Collects only the nnx.Param state of the transformer, flattens it, and
# reports the dtype of the first leaf.
# NOTE(review): presumably all params share one dtype — confirm.
params_state = nnx.state(self.transformer, nnx.Param)
return jax.tree_util.tree_leaves(nnx.to_flat_state(params_state))[0].dtype

def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
def _sample_step(
self, params: statelib.State, sampler_state: _SamplingState
) -> _SamplingState:
"""Performs a single sampling step."""
batch_size = sampler_state.token_buffer.shape[0]
decoding_step = jnp.asarray(sampler_state.decoding_step, dtype=jnp.int32)
Expand All @@ -152,7 +165,8 @@ def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
)
last_token = last_token.reshape((batch_size, 1))

logits, cache = self.transformer(
transformer = nnx.merge(self._transformer_graphdef, params)
logits, cache = transformer(
last_token,
step_positions,
sampler_state.cache,
Expand Down Expand Up @@ -287,12 +301,13 @@ def mask_tokens_after_eos_ids(self, token_buffer):

def _sample_fn(
self,
params: statelib.State,
initial_sampling_state: _SamplingState,
) -> _SamplingState:
"""Internal sampling function (to be jitted)."""

def sample_with_params(sampler_state: _SamplingState):
return self._sample_step(sampler_state)
return self._sample_step(params, sampler_state)

def cond_fn(sampler_state: _SamplingState):
return (
Expand Down Expand Up @@ -346,7 +361,9 @@ def __call__(
forbidden_token_ids=forbidden_token_ids,
)

sampling_state = self._compiled_sample_fn(initial_sampling_state)
sampling_state = self._compiled_sample_fn(
self._transformer_state, initial_sampling_state
)

masked_token_buffer = self.mask_tokens_after_eos_ids(
sampling_state.token_buffer
Expand Down
Loading