[V1] Logits processors extensibility #19912
Conversation
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: Andrew Feldman <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Summary of Changes
Hello @afeldman-nm, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the V1 logits processing capabilities by introducing a new extensible framework. It refactors existing logits-related sampling parameters into a modular system of `LogitsProcessor` classes, managed by a `LogitsProcessorManager`. The changes streamline batch state updates, improve code organization, and lay the groundwork for future custom sampling logic, while also exposing new extensibility points in the API.
Highlights
- **Logits Processor Extensibility:** Introduced a new abstract `LogitsProcessor` class and a `LogitsProcessorManager` to enable custom logits processing logic. This allows for a more modular and extensible way to apply various sampling constraints.
- **Refactored Logits Handling:** Existing functionalities such as `min_p`, `logit_bias`, and `min_tokens` are now implemented as dedicated `LogitsProcessor` subclasses. This centralizes their logic and state management, moving them out of core components like `SamplingMetadata` and `Sampler`.
- **Improved Batch State Management:** A new `BatchUpdate` mechanism has been introduced to efficiently track and apply changes (additions, removals, and moves) to the batch state. This ensures that individual logits processors can accurately update their internal states in response to dynamic batch changes.
- **API Extensibility:** Added `vllm_xargs` to the OpenAI API protocol for both `ChatCompletionRequest` and `CompletionRequest`. This provides a generic way to pass additional, custom parameters, paving the way for future extensions and plugins.
- **Comprehensive Testing:** A new, dedicated test suite (`test_logits_processors.py`) has been added to thoroughly validate the new logits processor framework and its interactions with various batch update scenarios, ensuring robustness and correctness.
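To make the extensibility point concrete, a custom processor might look roughly like the sketch below. This is a hedged illustration only: the `update_state`/`apply` method names mirror the pattern visible in this PR's diffs, but the exact abstract-class signature should be checked against `vllm/v1/sample/logits_processor.py`, and `BanTokensLogitsProcessor` is a made-up example class.

```python
import torch

class BanTokensLogitsProcessor:
    """Hypothetical custom processor: masks a fixed set of token ids
    for every request in the batch."""

    def __init__(self, banned_token_ids: list[int]):
        self.banned = banned_token_ids

    def update_state(self, batch_update) -> None:
        # A stateful processor would consume the per-step BatchUpdate
        # (adds/removes/moves) here; this example is stateless.
        pass

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Set banned columns to -inf so they can never be sampled.
        logits[:, self.banned] = float("-inf")
        return logits

proc = BanTokensLogitsProcessor([3, 5])
out = proc.apply(torch.zeros(2, 8))
```

Because the processor is stateless, `update_state` is a no-op; processors like the min-tokens one would instead track per-request output lengths there.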
Code Review
This pull request introduces a significant and well-designed refactoring to make logits processors in V1 extensible. The core changes involve:

- Introducing the `LogitsProcessor` ABC and concrete implementations (`MinPLogitsProcessor`, `LogitBiasLogitsProcessor`, `MinTokensLogitsProcessor`) in `vllm/v1/sample/logits_processor.py`.
- Adding `LogitsProcessorManager` in `vllm/v1/worker/utils.py` to manage collections of these processors, distinguishing between those compatible with greedy sampling and those that are not.
- Implementing `BatchUpdate` to track changes (adds, removes, moves) to the request batch within a single step. This state is then used to update each `LogitsProcessor`.
- Refactoring `InputBatch` to delegate the management and application of `min_p`, `logit_bias`, and `min_tokens` to the new `LogitsProcessorManager`. The `refresh()` method now ensures logits processor states are updated via `_commit_logit_procs_state_changes`.
- Updating `Sampler` to use the `greedy_list` and `nongreedy_list` from `SamplingMetadata.logitsprocs`.
- Adding comprehensive tests in `tests/v1/sample/test_logits_processors.py` that simulate dynamic batching and validate the behavior of each logits processor under various add/remove/move scenarios.
- Adding `vllm_xargs` to OpenAI protocol requests for better extensibility of custom parameters.

The changes are well-structured, and the new testing framework for logits processors is robust. The state management through `BatchUpdate` and its propagation to individual processors appears sound. This PR significantly improves the modularity and extensibility of logits processing in vLLM.
```python
def _commit_logit_procs_state_changes(self) -> None:
    """Apply batch add/remove/permute to logits procs' states"""
    self.batch_update.batch_size = self.num_reqs
    for logit_proc in self.logitsprocs.all_list:
        logit_proc.update_state(self.batch_update)
    # Clear state change representation to prepare for next step
    self.batch_update.reset()
```
The `_commit_logit_procs_state_changes` method is a crucial addition. It ensures that all registered logits processors are updated with the accumulated batch changes (`batch_update`) before the `SamplingMetadata` is created for the next step. Resetting `batch_update` afterwards prepares it for the next iteration. This is a clean way to manage state propagation.
```python
first_empty_index = batch_update.peek_removed_if_can()
assert first_empty_index is not None
```
The assertion `assert first_empty_index is not None` is good for type safety. However, given the logic flow, `batch_update.removed` being non-empty (checked by the `while` loop) should imply that `peek_removed_if_can()` returns an integer, not `None`. If it could return `None` here, that might indicate a potential issue in `BatchUpdate.peek_removed_if_can()` or the loop condition. This is more of a thought for robustness; the current logic is fine as long as `peek_removed_if_can` guarantees a non-`None` value whenever `has_removed()` is true.
```python
if self.min_p > 0.0 and self.temperature == 0.0:
    raise ValueError("min_p > 0.0 requires random sampling "
                     "but temperature == 0.0")
```
```python
def _sort_removed(self) -> None:
    """Sort removed request indices in descending order.

    Idempotent after first call, until reset.
    """
```
The `_sort_removed` method sorts in descending order. This means `self._removed[-1]` (used in `peek_removed_if_can`) will be the smallest index, and `self._removed.pop()` (used in `pop_removed_if_can`) will also remove and return the smallest index. This behavior is crucial for correctly filling gaps when adding new requests. This seems correct and well thought out.
```python
# Internal representation of per-step batch state changes.
# Should reset each step.
self.batch_update = BatchUpdate()

# Define logits processors. Note that Min-P logitsproc is returned
# both on its own as min_p_logitsproc (to support spec decoding
# compatibility check) and also as part of logits_procs
# TODO(andy): logits processor list should be extensible via engine
# constructor argument; for now the list is fixed.
self.logitsprocs = init_hard_coded_logitsprocs(
    pin_memory_available=pin_memory,
    max_num_reqs=max_num_reqs + 1,
    device=device)
```
```python
def _get_next_add_index(self) -> int:
    if (req_index := self.batch_update.pop_removed_if_can()) is not None:
        # Fill the empty index.
        return req_index
    # Append to end
    return self.num_reqs

def _register_add_request(self, request: "CachedRequestState") -> int:
    """Track add-request operations"""
    req_index = self._get_next_add_index()
    assert req_index < self.max_num_reqs
    self.batch_update.added.append(
        (req_index, request.sampling_params, request.output_token_ids))
    return req_index
```
```diff
 # Find the smallest empty index.
-empty_index = empty_req_indices.pop()
+empty_index = self.batch_update.peek_removed_if_can()
+assert empty_index is not None
 if empty_index >= last_req_index:
     break

-# Swap the states.
+# Move active request down into empty request index.
+self.batch_update.pop_removed_if_can()
+self.batch_update.moved.append((last_req_index, empty_index,
+                                MoveDirectionality.UNIDIRECTIONAL))
```
```diff
+has_removed_requests = self.input_batch.has_step_removed_requests()
+batch_changed = has_removed_requests or len(req_ids_to_add) > 0

 # Add the new or resumed requests to the persistent batch.
-# The smaller empty indices are filled first.
-removed_req_indices.sort(reverse=True)
 for req_id in req_ids_to_add:
     req_state = self.requests[req_id]
-    if removed_req_indices:
-        # Fill the empty index.
-        req_index = removed_req_indices.pop()
-    else:
-        # Append to the end.
-        req_index = None
-    self.input_batch.add_request(req_state, req_index)
+    self.input_batch.add_request(req_state)

 # Condense the batched states if there are empty indices.
-if removed_req_indices:
-    self.input_batch.condense(removed_req_indices)
+if self.input_batch.has_step_removed_requests():
+    self.input_batch.condense()

 batch_reordered = self._may_reorder_batch(scheduler_output)

 if batch_changed or batch_reordered:
-    self.input_batch.refresh_sampling_metadata()
+    self.input_batch.refresh()
```
The logic for updating `InputBatch` has been simplified, since `InputBatch` now manages its internal state changes (like removed requests) via `self.batch_update`:

- `has_step_removed_requests()` is used to check if condensation is needed.
- `add_request` no longer needs `req_index`.
- `condense()` is called without arguments.
- `refresh()` is called to update sampling metadata and logits processor states.

These changes align well with the refactoring in `InputBatch`.
```python
@dataclass
class LogitsProcessorManager:
    """Encapsulates initialized logitsproc objects.

    Each logits processor has a unique id.
    """
    nongreedy: dict[str, LogitsProcessor] = field(
        default_factory=dict)  # id -> nongreedy-sampling-only logitsproc
    greedy: dict[str, LogitsProcessor] = field(
        default_factory=dict)  # id -> greedy-sampling compatible logitsproc

    def __post_init__(self):
        """Guarantee unique ids"""
        if (self.nongreedy.keys() & self.greedy.keys()):
            raise ValueError("Greedy and non-greedy logits "
                             "processors must not share ids")

    def get_logitsproc_by_id(self, id: str) -> Optional[LogitsProcessor]:
        """Find logits processor by id, if it exists"""
        return self.all.get(id, None)

    @property
    def all(self) -> dict[str, LogitsProcessor]:
        """All logits processors"""
        return self.greedy | self.nongreedy

    @property
    def nongreedy_list(self) -> list[LogitsProcessor]:
        return list(self.nongreedy.values())

    @property
    def greedy_list(self) -> list[LogitsProcessor]:
        return list(self.greedy.values())

    @property
    def all_list(self) -> list[LogitsProcessor]:
        """List of all logits processors"""
        return self.nongreedy_list + self.greedy_list
```
The `LogitsProcessorManager` class is a good abstraction for managing different types of logits processors (greedy vs. non-greedy). The properties `all`, `nongreedy_list`, `greedy_list`, and `all_list` provide convenient ways to access the processors. The `__post_init__` check for unique IDs between greedy and non-greedy processors is a good safeguard.
```python
def init_hard_coded_logitsprocs(
        pin_memory_available: bool, max_num_reqs: int,
        device: torch.device) -> LogitsProcessorManager:
    min_tokens_logitproc = MinTokensLogitsProcessor(
        pin_memory=pin_memory_available, device=device)
    logit_bias_logitproc = LogitBiasLogitsProcessor(
        pin_memory=pin_memory_available, device=device)
    min_p_logitproc = MinPLogitsProcessor(
        pin_memory=pin_memory_available,
        device=device,
        # +1 for temporary swap space
        max_num_reqs=max_num_reqs + 1)
    return LogitsProcessorManager(
        greedy={
            STR_MIN_TOKENS_LOGITPROC_ID: min_tokens_logitproc,
            STR_LOGITS_BIAS_LOGITPROC_ID: logit_bias_logitproc
        },
        nongreedy={STR_MIN_P_LOGITPROC_ID: min_p_logitproc},
    )
```
`init_hard_coded_logitsprocs` correctly initializes the built-in logits processors and categorizes them into greedy and non-greedy sets within the `LogitsProcessorManager`. The `max_num_reqs + 1` for `MinPLogitsProcessor`, accommodating a temporary swap slot, is a thoughtful detail for efficient state management during batch updates.
Purpose
Enable V1 logits processor support to be extended with custom logits processors.
Test Plan
(WIP)
Test Result
WIP
(Optional) Documentation Update
WIP
Fixes #17799
Fixes #12678