Conversation
sunildkumar
left a comment
There was a problem hiding this comment.
nice! left some questions.
trl/trainer/qwen_grpo_trainer.py
Outdated
| # Check if the stop string is in the completions | ||
| # We need to convert the tensor to a string. | ||
| if self.tool_defn.completion_has_tool_call(prompt_completion_str): | ||
| tool_response_str = self.tool_defn.call_tool(prompt_completion_str) |
There was a problem hiding this comment.
why doesn't the dataclass have call_tool?
| # We need to convert the tensor to a string. | ||
| if self.tool_defn.completion_has_tool_call(prompt_completion_str): | ||
| tool_response_str = self.tool_defn.call_tool(prompt_completion_str) | ||
| tool_response_ids_list = self.processing_class.tokenizer.encode(tool_response_str, add_special_tokens=False) |
There was a problem hiding this comment.
I'm assuming this doesn't add an extra BOS token?
| return inputs | ||
|
|
||
| def _generate_completion( | ||
| self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor] |
There was a problem hiding this comment.
I think prompt_inputs might be a BatchFeature: https://huggingface.co/docs/transformers/en/main_classes/feature_extractor#transformers.BatchFeature
There was a problem hiding this comment.
Also, thanks for making this a function, clearly the right move.
| return prompt_completion_ids | ||
|
|
||
| def _generate_single_completion_with_tools( | ||
| self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor], max_steps: int = 10 |
There was a problem hiding this comment.
same nit here - BatchFeature
| (Note that 46*44 is 2024). | ||
| """ | ||
| conv = SingleConversationWithTools(prompt_inputs, self.tool_defn, self.processing_class) | ||
| # Loop until tool isn't called, of we max out |
There was a problem hiding this comment.
| # Loop until tool isn't called, of we max out | |
| # Loop until tool isn't called, or we max out |
| - input_ids: [1, 710] ints. Some stuff at the beginning and the end, the middle full of 151655 | ||
| - attention_mask: [1, 710] ints. All 1 | ||
| - pixel_values: 2024x1176 floats. The image. | ||
| - image_grid_thw: a 1x3 tensor with values: [1, 46, 44]. |
There was a problem hiding this comment.
nit: maybe add a short comment about what max _steps is.
My understanding: The generation will stop once a tool is called, then this code processes the tool call. max_steps is the maximum number of tools we're willing to process for a single completion?
No description provided.