
[Question] Is vLLMRollout.generate_sequences the right place to implement tool calling? #176

Open
accupham opened this issue Jan 31, 2025 · 6 comments
Labels: enhancement, question, vllm related

@accupham

Hi, I am trying to understand the code. I would like to try RL training on tool calling in an interactive environment.

As I understand it, the reward is calculated by some custom reward function for a particular dataset. In other words, the flow of data during PPO is like this:

graph TD
   DatasetExample --> InferenceRollout --> RewardFunction --> UpdateGradients

But the rollout inference step here is a one-shot input/output function. If online tool calling were desired, we'd have to hook the llm.generate call here, right?

https://github.com/volcengine/verl/blob/main/verl/workers/rollout/vllm_rollout/vllm_rollout.py#L181

Then we could inject function calling there. But I'm confused, because the inference engine is not an ordinary vLLM LLM class, but a subclass that monkey-patches the output to return tensors instead of the normal vLLM output format.

So what would be the best way to add dynamic function calling? Hook the generate method of vLLM's LLM class, then call LLM._post_process_output to convert the token_ids and logprobs from vLLM into torch tensors at the very end?

Or is there a more obvious place to add this feature?

@accupham (Author)

Actually, I thought about it a bit more. Perhaps the best way is to implement a custom LogitsProcessor for vLLM, which performs the function calling by hijacking the logits at each step to detect function calls and force-inject the function's output tokens. Then it should interface perfectly with this library, or any other that uses vLLM for inference, and make the resulting model production-ready.

@PeterSH6 (Collaborator) commented Jan 31, 2025

Hi @accupham, thanks for your questions!

> So what would be the best way to add dynamic function calling? Hook the generate method of vLLM's LLM class, then call LLM._post_process_output to convert the token_ids and logprobs from vLLM into torch tensors at the very end?

Yes, adding dynamic function calling by hooking into the generate method is a good way. The _post_process_output does not have to be a method of LLM; it can be moved into vllm_rollout to run after all the results are ready and convert them into tensors.
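
For illustration, here is a minimal sketch of what that post-processing could look like once it lives in vllm_rollout, assuming standard vLLM RequestOutput objects (the function name and pad_token_id argument are placeholders, not existing verl code):

# Minimal sketch (not verl's actual _post_process_output): after the hooked
# generate() returns vLLM RequestOutput objects, collect the response token ids
# and right-pad them into a single [batch, response_len] tensor.
import torch
from torch.nn.utils.rnn import pad_sequence

def outputs_to_tensors(request_outputs, pad_token_id: int) -> torch.Tensor:
    response_ids = [
        torch.tensor(out.outputs[0].token_ids, dtype=torch.long)
        for out in request_outputs  # one RequestOutput per prompt
    ]
    # Right-pad to the longest response in the batch.
    return pad_sequence(response_ids, batch_first=True, padding_value=pad_token_id)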

@PeterSH6 (Collaborator)

> Actually, I thought about it a bit more. Perhaps the best way is to implement a custom LogitsProcessor for vLLM, which performs the function calling by hijacking the logits at each step to detect function calls and force-inject the function's output tokens. Then it should interface perfectly with this library, or any other that uses vLLM for inference, and make the resulting model production-ready.

I agree. Using a custom LogitsProcessor can help detect the function calls. This can already be implemented in current vLLM by assigning your custom LogitsProcessor functions to the SamplingParams. I believe you can implement this by modifying the code in vllm_rollout.py. But it would be better if the custom function could be passed in through the config file, so that users don't need to modify vllm_rollout.py. Are you interested in contributing this feature?
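
A rough sketch of that config-driven wiring (the config key and the dotted path below are hypothetical, not part of verl today): resolve a dotted path from the rollout config and attach the resulting callable to vLLM's SamplingParams.

import importlib

from vllm import SamplingParams

def load_logits_processor(dotted_path: str):
    # Import e.g. "my_pkg.tools.my_tool_processor" and return the referenced object.
    module_name, attr_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr_name)

# e.g. config.rollout.logits_processor = "my_pkg.tools.my_tool_processor"  (hypothetical key)
processor = load_logits_processor("my_pkg.tools.my_tool_processor")
sampling_params = SamplingParams(max_tokens=256, logits_processors=[processor])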

Moreover, I'm not sure a customized LogitsProcessor is general enough to cover "all" function calling scenarios. Some function calls may need to detokenize the token_ids first, and a customized LogitsProcessor may not achieve optimal throughput.

@accupham (Author) commented Jan 31, 2025

Yes, I would be interested in contributing. Traditional function calling is usually done with vLLM's LLM.chat() calling semantics.

But we could leave this up to the user by letting them implement a pluggable function, which produces the final outputs to pass on to the rest of the pipeline.

So we could take this from vllm_rollout.py:

        with self.update_sampling_params(**kwargs):
            output = self.inference_engine.generate(
                prompts=None,  # because we have already convert it to prompt token id
                sampling_params=self.sampling_params,
                prompt_token_ids=idx_list,
                use_tqdm=False)

And instead have this RolloutSampler as the pluggable module, which can be passed in when initializing the vLLMRollout class:

from typing import Protocol
from vllm.outputs import RequestOutput

class RolloutSampler(Protocol):
    def __call__(self, llm, prompts, sampling_params) -> list[RequestOutput]:
       ...

# impl default RolloutSampler
class OneShotRolloutSampler:
    def __call__(self, llm, prompts, sampling_params) -> list[RequestOutput]:
        return llm.generate(
                prompts=prompts,  # pass in prompts instead of token_ids to make it user-friendly
                sampling_params=sampling_params,
                use_tqdm=False)

# impl RolloutSampler
class FnCallRolloutSampler:
    @property
    def tools(self) -> list[dict]:
        ...

    def __call__(self, llm, prompts, sampling_params) -> list[RequestOutput]:
        r1 = llm.chat(
            messages=[{"role": "user", "content": "blah blah blah please do tool calling"}],
            sampling_params=sampling_params,
            use_tqdm=False,
            tools=self.tools,
        )

        ...
        # execute tool calls from r1 and pass the results back
        r2 = llm.chat(...)
        # get final RequestOutput response from vllm
        r3 = llm.chat(...)
        # (etc etc)
        return r3  # list[RequestOutput] from vllm

So we modify the init function signature of vLLMRollout as follows:

class vLLMRollout(BaseRollout):
    def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config,
                 rollout_sampler: RolloutSampler = OneShotRolloutSampler(), **kwargs):
        ...
        self.rollout_sampler = rollout_sampler
        ...

Now users can pass in their own sampler implementation according to their needs by writing familiar vLLM API code, or use the default one-shot sampler to keep the old behavior.
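
For example, a hypothetical usage following the proposed signature above (the argument values are just for illustration):

# Hypothetical usage of the proposed signature above: swap in the
# function-calling sampler, or omit the argument to keep one-shot rollout.
rollout = vLLMRollout(
    actor_module=actor_module,
    config=config,
    tokenizer=tokenizer,
    model_hf_config=model_hf_config,
    rollout_sampler=FnCallRolloutSampler(),  # or OneShotRolloutSampler()
)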


I think for function calling we will always need to detokenize the token_ids first, so we should pass prompts, not token_ids, into vLLM. It's not user-friendly to deal with tokens, because parsing function calls is string-oriented and we can't expect there to be a dedicated token for function calling. Why don't we just let vLLM handle the tokenization instead?

I thought about it some more, and a custom LogitsProcessor may bottleneck the entire batch if a single function call has excess latency. Better to just do function calling the traditional way instead of reinventing the wheel.

What do you think of this proposal?

@PeterSH6 (Collaborator) commented Feb 1, 2025

@accupham The API design is really nice from my perspective. However, it seems that it relies on vLLM 0.7.0 for the chat API. We're working on integrating that in #116. Will let you know once it's merged!

As the key challenge for generation in RL training is throughput rather than latency, I have a few questions/concerns about this proposal:

  1. Does the chat API support batch processing in both prefill/decode and function calling? If the function call can also be batched and parallelized, the throughput would be acceptable.

  2. With such a design, how do we overlap prefill/decode computation with function calling, given that the function call may be a remote function and could be quite time-consuming?

> I think for function calling we will always need to detokenize the token_ids first. Why don't we just let vLLM handle the tokenization instead?

  1. In our experience, tokenize/detokenize is quite time-consuming in vLLM, and veRL already tokenizes the prompts into token_ids in the dataloader. So, currently, if users want the original strings, they can simply call tokenizer.batch_decode (see the sketch after this list), and this operation is more efficient. What's your experience with tokenize/detokenize in vLLM?

  2. I would like to raise another point that this proposal doesn't mention and that also relates to point 2 above. It is not directly relevant to your API design, but it is very important in veRL.
    Currently, with the Orca-style scheduler in vLLM, some prompts can finish early and exit, but LLMEngine will not return until all the prompts are finished. I believe that if we can fetch these early-exit sequences, it is possible to overlap prefill/decode computation with function calls. However, this can only be implemented with vLLM support or by hijacking the scheduler code.
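
Regarding point 1, a minimal sketch of recovering strings from already-tokenized prompts with tokenizer.batch_decode (the model name and token ids below are illustrative placeholders):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # any HF tokenizer works here
idx_list = [[1, 2, 3], [4, 5]]  # stands in for the prompt token_id lists from the dataloader
prompt_strings = tokenizer.batch_decode(idx_list, skip_special_tokens=False)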

PeterSH6 added the enhancement, question, and vllm related labels on Feb 1, 2025
@accupham (Author) commented Feb 1, 2025

After doing a bit of digging, I think this API design for multi-turn tool-calling interaction may not be feasible from a performance perspective. Here's why:

> Does the chat API support batch processing in both prefill/decode and function calling? If the function call can also be batched and parallelized, the throughput would be acceptable.

  • The vLLM chat API does support batch processing. All this method does is apply a chat template and then call the generate API internally, which, as expected, will do prefill and other vLLM optimizations if enabled.
  • However, the function call part has issues. The tool call execution instructions will arrive in batches, but we would have to execute them serially, wait for all calls to resolve, then pass the results back in as part of that batch group. Even then, some items may need multiple back-and-forth calls to complete while others finish early, leaving the batch partially empty toward the end.
    • In other words, traditional multi-turn tool calling gives us head-of-line blocking, which hurts throughput.

> In our experience, tokenize/detokenize is quite time-consuming in vLLM, and veRL already tokenizes the prompts into token_ids in the dataloader. So, currently, if users want the original strings, they can simply call tokenizer.batch_decode, and this operation is more efficient. What's your experience with tokenize/detokenize in vLLM?

  • I've never benchmarked tokenization, so I'll have to take your word on it. If it really is time-consuming, then using the chat API will be terrible for performance. We go from not having to tokenize/detokenize at all in the original implementation to doing roughly 4 * batch_size * n_turns extra tokenize/detokenize operations: the chat API needs to tokenize repeatedly to apply the chat template, then decode/encode to parse each function call. Not optimal.

> Currently, with the Orca-style scheduler in vLLM, some prompts can finish early and exit, but LLMEngine will not return until all the prompts are finished. I believe that if we can fetch these early-exit sequences, it is possible to overlap prefill/decode computation with function calls. However, this can only be implemented with vLLM support or by hijacking the scheduler code.

  • That sounds too complicated to implement in a user-friendly way; multi-turn tool calling might be too hard to do well if throughput is desired.

Alternative Proposal

  • Allow the user to pass a custom LogitsProcessor into vLLM.

Pros

  • This preserves all batching optimizations afforded by the subcomponents of the system, so throughput is maintained.
  • Each LogitsProcessor is cloned on vLLM's side, so if function calling is slow, it only affects that generation instance and not others, making scheduling easy.
  • A LogitsProcessor can easily be used in production by passing it into vLLM.
  • The style of inline function calling is very fluid and lends itself to the reasoning style of R1-like models.

Cons

  • Very untraditional approach (I've never seen this done before)
  • Token boundary issues may occur unless you train with special tokens
  • The implementation is not as user-friendly because it deals with raw tokens and modifies logits manually.
    • See proof of concept... It's hacky AF.

PoC tool-calling LogitsProcessor

from typing import Callable, Dict, List

import torch
from transformers import PreTrainedTokenizer

class FunctionProcessor:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        function_map: Dict[str, Callable],
        start_tag: str = " <function>",
        end_tag: str = " </function>",
        result_start: str = "<results>",
        result_end: str = "</results>"
    ):
        self.tokenizer = tokenizer
        self.function_map = function_map
        self.buffer = []
        self.in_function = False
        self.current_function = []
        
        # Pre-tokenize markers 
        self.start_marker = tokenizer.encode(start_tag, add_special_tokens=False)
        self.end_marker = tokenizer.encode(end_tag, add_special_tokens=False)
        self.result_start = tokenizer.encode(result_start, add_special_tokens=False)
        self.result_end = tokenizer.encode(result_end, add_special_tokens=False)
        self.max_marker_len = max(
                len(self.start_marker), 
                len(self.end_marker),
                len(self.result_start),
                len(self.result_end)
            )
        
        self.result_tokens = []
        
    def evaluate_expression(self, expr: str) -> str:
        # Strip the function markers
        expr = expr.replace("<function>", "").replace("</function>", "").strip()
        
        # Parse function call
        func_name = expr.split("(")[0]
        args_str = expr.split("(")[1].rstrip(")")
        
        # Get the function from our map
        if func_name not in self.function_map:
            return f"Error: Unknown function {func_name}"
            
        func = self.function_map[func_name]
        
        try:
            # Parse args - this could be made more sophisticated
            args = [float(arg.strip()) for arg in args_str.split(",")]
            result = func(*args)
            return str(result)
        except Exception as e:
            return f"Error: {str(e)}"
        
    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        try:
            self.buffer.extend(input_ids[-1:])

            # If tool-result tokens are still queued, force the next one out by
            # boosting its logit so the injected result is emitted token by token.
            if self.result_tokens:
                #scores.fill_(-float('inf'))
                scores[self.result_tokens.pop()] = 100
                return scores
            
            self.buffer = self.buffer[-self.max_marker_len*2:]
            #print(self.tokenizer.decode(self.buffer))
            if not self.in_function and self.check_marker(self.start_marker):
                self.in_function = True
                self.current_function = []
                return scores
                
            if self.in_function:
                self.current_function.extend(input_ids[-1:])
                
                if self.check_marker(self.end_marker):
                    self.in_function = False
                    func_text = self.tokenizer.decode(self.current_function)
                    result = self.evaluate_expression(func_text)
                    
                    self.result_tokens = list(reversed(
                        self.result_start +
                        self.tokenizer.encode(result) +
                        self.result_end
                    ))
                    scores[self.result_tokens.pop()] = 100

            return scores
            
        except Exception as e:
            print(f"Error in processor: {e}")
            return scores
            
    def check_marker(self, marker: List[int]) -> bool:
        #print(marker, self.buffer)
        marker_len = len(marker)
        buffer_len = len(self.buffer)
        
        if buffer_len < marker_len:
            return False
            
        # Only need to check the last possible positions where marker could fit
        start_pos = max(0, buffer_len - marker_len * 2)
        
        for i in range(start_pos, buffer_len - marker_len + 1):
            if self.buffer[i:i + marker_len] == marker:
                return True
                
        return False

PoC Usage

from vllm import SamplingParams

def add(x, y):
    return x + y

def multiply(x, y):
    return x * y

# Create function map
function_map = {
    "add": add,
    "multiply": multiply
}

my_tool_processor = FunctionProcessor(tokenizer, function_map)  # tokenizer is the model's HF tokenizer

prompts = [
    "Hello world. please say <function> multiply(3, 302) </function> \nthis is a test",
    "Hello world. please say <function> add(3, 302) </function> ",
]
r = llm.generate( # llm is a vllm LLM instance
    prompts,
    SamplingParams(
        logits_processors=[my_tool_processor],
        max_tokens=200,
    ))

#print(r)
for rr in r:
    print(rr.outputs[0].text)
    print("----")

Output:

<function> multiply(3, 302) </function><result>906.0</result>. Can I assist with anything else?
---
<function> add(3, 302) </function><result>305.0</result>. Is there anything else I can help with?
