[Question] Is vLLMRollout.generate_sequences the right place to implement tool calling? #176
Actually, I thought about it a bit more. Perhaps the best way is to implement a custom LogitsProcessor for vLLM, which does the function calling by hijacking the logits at each step to detect function calls and force-inject the function's output tokens. Then it should interface perfectly with this library, or any other that uses vLLM for inference, and make the resulting model production-ready.
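Roughly, a vLLM logits processor is just a callable over the generated token ids and the next-token logits, so the forced-injection idea might be sketched like this (illustrative only, not the final PoC below; detecting the call and filling the queue is elided):

```python
# Rough sketch of the idea: once a function call has been detected and its result
# tokenized into `queued`, mask the logits so only the next queued token can be sampled.
import torch

class ForcedInjectionProcessor:
    def __init__(self):
        self.queued: list[int] = []  # token ids of a tool result, filled when a call is detected

    def __call__(self, token_ids: list[int], logits: torch.Tensor) -> torch.Tensor:
        if self.queued:
            logits[:] = float("-inf")
            logits[self.queued.pop(0)] = 0.0  # force the next queued token
        return logits
```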
Hi @accupham, thanks for your questions!
Yes, adding dynamic function calling by hooking into the generate method is a good way to do it.
I agree. Using a custom LogitsProcessor can help detect the function calls. This can already be implemented in the current vLLM by assigning your custom LogitsProcessor functions to the SamplingParams. I believe you can implement this by modifying the code of vllm_rollout.py. But it may be better if the customized function could be passed through the config file, so that users won't need to modify the vllm_rollout.py file. Are you interested in contributing to this feature? Moreover, I'm not sure whether a customized LogitsProcessor is general enough to cover "all" function calling scenarios.
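For instance, a minimal sketch of the config-driven idea (the `logits_processor_path` key and the `module:attr` format are made up for illustration, not an existing verl option):

```python
# Hypothetical sketch: resolve a user-supplied logits processor from a config entry
# such as rollout.logits_processor_path = "my_pkg.tools:FunctionProcessor", so users
# don't have to edit vllm_rollout.py directly.
import importlib

from vllm import SamplingParams

def build_sampling_params(logits_processor_path, **kwargs) -> SamplingParams:
    processors = []
    if logits_processor_path:
        module_name, attr = logits_processor_path.split(":")
        processor_cls = getattr(importlib.import_module(module_name), attr)
        processors.append(processor_cls())  # constructor args could also come from the config
    return SamplingParams(logits_processors=processors, **kwargs)
```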
Yes, I would be interested in contributing. Traditional function calling is usually done with the vLLM `LLM.chat()` calling semantics. But we could leave this up to the user by letting them implement a pluggable function, which produces the final output tensors to pass on to the rest of the pipeline.

So we could take this from vllm_rollout.py:

```python
with self.update_sampling_params(**kwargs):
    output = self.inference_engine.generate(
        prompts=None,  # because we have already converted it to prompt token ids
        sampling_params=self.sampling_params,
        prompt_token_ids=idx_list,
        use_tqdm=False)
```

And instead have this:

```python
from typing import Protocol

from vllm.outputs import RequestOutput


class RolloutSampler(Protocol):
    def __call__(self, llm, prompts, sampling_params) -> list[RequestOutput]:
        ...


# Default RolloutSampler implementation: plain one-shot generation
class OneShotRolloutSampler:
    def __call__(self, llm, prompts, sampling_params) -> list[RequestOutput]:
        return llm.generate(
            prompts=prompts,  # pass in prompts instead of token_ids to make it user-friendly
            sampling_params=sampling_params,
            use_tqdm=False)


# RolloutSampler implementation that does tool calling
class FnCallRolloutSampler:
    @property
    def tools(self) -> list[dict]:
        ...

    def __call__(self, llm, prompts, sampling_params) -> list[RequestOutput]:
        r1 = llm.chat(
            messages=[{"role": "user", "content": "blah blah blah please do tool calling"}],
            sampling_params=sampling_params,
            use_tqdm=False,
            tools=self.tools,
        )
        ...
        # execute tool calls from r1 and return results
        r2 = llm.chat(...)
        # get final RequestOutput response from vllm
        r3 = llm.chat(...)
        # (etc etc)
        return r3  # list[RequestOutput] from vllm
```

So we modify the `__init__` signature of `vLLMRollout` as follows:

```python
class vLLMRollout(BaseRollout):
    def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config,
                 rollout_sampler: RolloutSampler = OneShotRolloutSampler(), **kwargs):
        ...
        self.rollout_sampler = rollout_sampler
        ...
```

Now users can pass in their own sampler implementation according to their needs by writing familiar vLLM API code, or use the default one-shot sampler to keep the old behavior.

I think for function calling we will always need to detokenize the token_ids first, so we should pass prompts, not token_ids, into vLLM. It's not user-friendly to deal with tokens, because parsing function calls is string-oriented and we can't expect there to be a dedicated token for function calling. Why don't we just let vLLM handle the tokenization instead?

I also thought about it some more, and a custom LogitsProcessor may bottleneck the entire batch if a single function call has excess latency. It seems better to do function calling the traditional way instead of reinventing the wheel.

What do you think of this proposal?
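For illustration, a hypothetical caller under this proposal might then look like the following (the constructor arguments are the existing ones; only `rollout_sampler` is new, and `FnCallRolloutSampler` is the sketch above):

```python
# Hypothetical usage of the proposed hook, not existing verl API.
rollout = vLLMRollout(
    actor_module=actor_module,
    config=config,
    tokenizer=tokenizer,
    model_hf_config=model_hf_config,
    rollout_sampler=FnCallRolloutSampler(),  # omit to fall back to OneShotRolloutSampler()
)
```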
@accupham The API design is really nice from my perspective. However, it seems that it relies on vLLM 0.7.0 for the chat API. We're working on integrating that in #116. Will let you know once merged! Since the key challenge for generation in RL training tasks is throughput, not latency, I have a few questions/concerns about this proposal:
After doing a bit of digging, perhaps this API design enabling multi-turn tool-calling interaction is not feasible from a performance perspective. Here's why:

**Alternative Proposal**

**Pros**

**Cons**
**PoC Tool Calling Logits Processor**

```python
from typing import Dict, List, Callable

import torch
from transformers import PreTrainedTokenizer


class FunctionProcessor:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        function_map: Dict[str, Callable],
        start_tag: str = " <function>",
        end_tag: str = " </function>",
        result_start: str = "<results>",
        result_end: str = "</results>"
    ):
        self.tokenizer = tokenizer
        self.function_map = function_map
        self.buffer = []
        self.in_function = False
        self.current_function = []
        # Pre-tokenize markers
        self.start_marker = tokenizer.encode(start_tag, add_special_tokens=False)
        self.end_marker = tokenizer.encode(end_tag, add_special_tokens=False)
        self.result_start = tokenizer.encode(result_start, add_special_tokens=False)
        self.result_end = tokenizer.encode(result_end, add_special_tokens=False)
        self.max_marker_len = max(
            len(self.start_marker),
            len(self.end_marker),
            len(self.result_start),
            len(self.result_end)
        )
        self.result_tokens = []

    def evaluate_expression(self, expr: str) -> str:
        # Strip the function markers
        expr = expr.replace("<function>", "").replace("</function>", "").strip()
        # Parse function call
        func_name = expr.split("(")[0]
        args_str = expr.split("(")[1].rstrip(")")
        # Get the function from our map
        if func_name not in self.function_map:
            return f"Error: Unknown function {func_name}"
        func = self.function_map[func_name]
        try:
            # Parse args - this could be made more sophisticated
            args = [float(arg.strip()) for arg in args_str.split(",")]
            result = func(*args)
            return str(result)
        except Exception as e:
            return f"Error: {str(e)}"

    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        try:
            self.buffer.extend(input_ids[-1:])
            # If we are in the middle of injecting a tool result, force the next result token
            if self.result_tokens:
                # scores.fill_(-float('inf'))
                scores[self.result_tokens.pop()] = 100
                return scores
            self.buffer = self.buffer[-self.max_marker_len * 2:]
            # print(self.tokenizer.decode(self.buffer))
            if not self.in_function and self.check_marker(self.start_marker):
                self.in_function = True
                self.current_function = []
                return scores
            if self.in_function:
                self.current_function.extend(input_ids[-1:])
                if self.check_marker(self.end_marker):
                    self.in_function = False
                    func_text = self.tokenizer.decode(self.current_function)
                    result = self.evaluate_expression(func_text)
                    # Queue the result tokens in reverse so pop() yields them in order
                    self.result_tokens = list(reversed(
                        self.result_start +
                        self.tokenizer.encode(result) +
                        self.result_end
                    ))
                    scores[self.result_tokens.pop()] = 100
            return scores
        except Exception as e:
            print(f"Error in processor: {e}")
            return scores

    def check_marker(self, marker: List[int]) -> bool:
        # print(marker, self.buffer)
        marker_len = len(marker)
        buffer_len = len(self.buffer)
        if buffer_len < marker_len:
            return False
        # Only need to check the last possible positions where marker could fit
        start_pos = max(0, buffer_len - marker_len * 2)
        for i in range(start_pos, buffer_len - marker_len + 1):
            if self.buffer[i:i + marker_len] == marker:
                return True
        return False
```

**PoC Usage**

```python
from vllm import SamplingParams


def add(x, y):
    return x + y


def multiply(x, y):
    return x * y


# Create function map
function_map = {
    "add": add,
    "multiply": multiply
}

my_tool_processor = FunctionProcessor(tokenizer, function_map)

prompts = [
    "Hello world. please say <function> multiply(3, 302) </function> \nthis is a test",
    "Hello world. please say <function> add(3, 302) </function> ",
]

r = llm.generate(  # llm is a vllm LLM instance
    prompts,
    SamplingParams(
        logits_processors=[my_tool_processor],
        max_tokens=200,
    ))

# print(r)
for rr in r:
    print(rr.outputs[0].text)
    print("----")
```
Hi, I am trying to understand the code. I would like to try RL training on tool calling in an interactive environment.
As I understand it, the reward is calculated by some custom reward function for a particular dataset. In other words, the flow of data during PPO is like this:
But the inference step rollout here is a one-shot input/output function. If online tool calling were desired, we'd have to hook the llm.generate function here, right?
https://github.com/volcengine/verl/blob/main/verl/workers/rollout/vllm_rollout/vllm_rollout.py#L181
Then we could inject the function calling. But I'm confused, because the inference engine is not an ordinary vLLM LLM class but a subclass that monkey-patches the output to return tensors instead of the normal vLLM output format.
So what would be the best way to add dynamic function calling? Hook the generate method of vLLM's LLM class, then call LLM._post_process_output to convert the token_ids and logprobs from vLLM into torch tensors at the very end? Or is there a more obvious place to add this feature?
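For concreteness, the kind of hook I have in mind would look roughly like this, written against the plain vLLM API rather than verl's patched class (`parse_tool_call` and `run_tool` are hypothetical helpers):

```python
# Illustrative only: loop around vLLM's LLM.generate, appending tool results to the
# prompt and resuming decoding until the model stops requesting tools.
from vllm import LLM, SamplingParams

def generate_with_tools(llm: LLM, prompt: str, sampling_params: SamplingParams,
                        parse_tool_call, run_tool, max_rounds: int = 4) -> str:
    text = prompt
    for _ in range(max_rounds):
        out = llm.generate([text], sampling_params, use_tqdm=False)[0]
        completion = out.outputs[0].text
        text += completion
        call = parse_tool_call(completion)  # e.g. extract "<function>...</function>"
        if call is None:
            break  # no tool requested; generation is finished
        text += f"<results>{run_tool(call)}</results>"
    return text
```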