feat: Ensemble async callback execution (rework) #438

base: main
@@ -1,4 +1,4 @@
-// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -45,17 +45,45 @@ class EnsembleContext;

 using IterationCount = size_t;

+// Check if the model is configured to preserve the order of responses.
+// This is critical for async execution of ResponseComplete callbacks.
+inline bool
+preserve_responses_order(const inference::ModelConfig& config)
+{
+  uint64_t total_instance_groups = 0;
+  for (const auto& group : config.instance_group()) {
+    total_instance_groups += group.count();
+  }
+
+  // Case 1: Sequence batching is enabled
+  // Case 2: Dynamic batching is disabled and there is only one instance group
+  // Case 3: Dynamic batching is enabled and preserve_ordering is true
+  // Case 4: Model transaction policy is decoupled (breaks RequestTracker
+  // lifecycle)
+  // Note: Although decoupled models do not preserve the order of responses,
+  // if the final response callback is not executed in the last step, the
+  // RequestTracker object will be freed prematurely, leading to a
+  // segmentation fault.
+  return config.has_sequence_batching() ||
+         (!config.has_dynamic_batching() && total_instance_groups <= 1) ||
+         (config.has_dynamic_batching() &&
+          config.dynamic_batching().preserve_ordering()) ||
+         config.model_transaction_policy().decoupled();
+}
+
 // Request tracker is passed as 'userp' in RequestRelease function and used
 // to manage the lifecycle of the ensemble request
 class RequestTracker {
  public:
   explicit RequestTracker(
       std::unique_ptr<InferenceRequest>&& request, uint64_t compute_start_ns,
       MetricModelReporter* metric_reporter,
-      InferenceStatsAggregator* stats_aggregator)
+      InferenceStatsAggregator* stats_aggregator,
+      triton::common::ThreadPool* callback_pool)
       : inflight_request_counter_(1), request_(std::move(request)),
         compute_start_ns_(compute_start_ns), metric_reporter_(metric_reporter),
-        stats_aggregator_(stats_aggregator), status_(Status::Success)
+        stats_aggregator_(stats_aggregator), status_(Status::Success),
+        callback_pool_(callback_pool)
   {
   }

Review thread on the decoupled-models note (Case 4):
- Reviewer: decoupled models "should preserve" the order of responses?
- Reply: I found this from
- Reply: That is referring to responses among different requests within the
  same batch; the responses of the same request are still preserved, i.e. a
  batch can contain req1 and req2 and respond as req2res1, req1res1, req1res2,
  req2res2.
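To make the four cases concrete, here is a minimal sketch, not part of the diff, that exercises the predicate through the generated `inference::ModelConfig` protobuf API. The `main` harness and the `model_config.pb.h` include path are illustrative assumptions:

```cpp
// Hypothetical standalone harness; assumes the generated model_config.pb.h
// and the preserve_responses_order() helper from this diff are visible.
#include <cassert>

#include "model_config.pb.h"

int
main()
{
  // Case 3: dynamic batching with preserve_ordering -> order must be kept.
  inference::ModelConfig ordered;
  ordered.mutable_dynamic_batching()->set_preserve_ordering(true);
  assert(preserve_responses_order(ordered));

  // Two instances, no dynamic/sequence batching, not decoupled: none of the
  // four cases apply, so callbacks may run out of order on the pool.
  inference::ModelConfig unordered;
  unordered.add_instance_group()->set_count(2);
  assert(!preserve_responses_order(unordered));
  return 0;
}
```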
@@ -70,6 +98,8 @@ class RequestTracker {
     return context_stats_aggregator_;
   }

+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
   void IncrementCounter()
   {
     std::lock_guard<std::mutex> lk(mtx_);
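The counter-based lifetime management used here follows a plain mutex-guarded refcount idiom. A simplified sketch of the contract, with a hypothetical class name standing in for RequestTracker:

```cpp
#include <cstdint>
#include <mutex>

// Each in-flight use holds one count; whichever caller observes the count
// reach zero is responsible for deleting the object.
class TrackerLifetimeSketch {
 public:
  TrackerLifetimeSketch() : inflight_request_counter_(1) {}

  void IncrementCounter()
  {
    std::lock_guard<std::mutex> lk(mtx_);
    ++inflight_request_counter_;
  }

  // Returns true when the caller is the last owner and must delete the object.
  bool DecrementCounter()
  {
    std::lock_guard<std::mutex> lk(mtx_);
    return (--inflight_request_counter_ == 0);
  }

 private:
  std::mutex mtx_;
  uint32_t inflight_request_counter_;
};
```

This is why the "Case 4" note above matters: if the callback that drops the last count runs before other steps are done with the tracker, the object is freed while still referenced.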
@@ -120,6 +150,7 @@ class RequestTracker {
   InferenceStatsAggregator* stats_aggregator_;
   InferenceStatsAggregator context_stats_aggregator_;
   Status status_;
+  triton::common::ThreadPool* const callback_pool_;
 };

 // Step is used as 'userp' and keeps ensemble context alive
@@ -129,9 +160,9 @@ class RequestTracker {
 struct Step {
   Step(
       size_t step_idx, const InferenceRequest::SequenceId& correlation_id,
-      uint32_t flags)
+      uint32_t flags, bool preserve_responses_order)
       : correlation_id_(correlation_id), flags_(flags), response_flags_(0),
-        step_idx_(step_idx)
+        preserve_responses_order_(preserve_responses_order), step_idx_(step_idx)
   {
   }
@@ -154,7 +185,7 @@ struct Step {
   // returning from the callback.
   uint32_t response_flags_;
   TRITONSERVER_InferenceResponse* response_;

+  const bool preserve_responses_order_;
+
   size_t step_idx_;
 };
@@ -237,7 +268,7 @@ class EnsembleContext {
       MetricModelReporter* metric_reporter,
       InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
       EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-      cudaStream_t stream);
+      cudaStream_t stream, triton::common::ThreadPool* callback_pool);

   // Perform transition on 'context' state given the information of
   // 'completed_step'
@@ -326,6 +357,8 @@ class EnsembleContext {
   void CacheEnsembleTopLevelRequest(
       std::unique_ptr<InferenceResponse>& response);

+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
   InferenceServer* is_;

   EnsembleInfo* info_;
@@ -375,20 +408,26 @@ class EnsembleContext {
       TRITONSERVER_ResponseAllocator,
       decltype(&TRITONSERVER_ResponseAllocatorDelete)>
       allocator_;
+
+  // The thread pool used to execute ensemble callbacks and reduce e2e
+  // latency. The thread pool is managed by InferenceServer.
+  triton::common::ThreadPool* const callback_pool_;
 };

 EnsembleContext::EnsembleContext(
     MetricModelReporter* metric_reporter,
     InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
     EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-    cudaStream_t stream)
+    cudaStream_t stream, triton::common::ThreadPool* callback_pool)
     : is_(is), info_(info), stream_(stream), inflight_step_counter_(0),
-      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete)
+      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete),
+      callback_pool_(callback_pool)
 {
   uint64_t compute_start_ns = 0;
   INFER_STATS_SET_TIMESTAMP(compute_start_ns);
   request_tracker_ = new RequestTracker(
-      std::move(request), compute_start_ns, metric_reporter, stats_aggregator);
+      std::move(request), compute_start_ns, metric_reporter, stats_aggregator,
+      callback_pool);

   auto& lrequest = request_tracker_->Request();
@@ -603,29 +642,57 @@ void
 EnsembleContext::RequestComplete(
     TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
 {
-  if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
-    LOG_TRITONSERVER_ERROR(
-        TRITONSERVER_InferenceRequestDelete(request),
-        "deleting ensemble inference request");
-    auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
-    if (request_tracker->DecrementCounter()) {
-      delete request_tracker;
-    }
-  }
+  auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
+  auto pool = request_tracker->CallbackPool();
+  auto fn = [request, flags, request_tracker]() {
+    if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
+      LOG_TRITONSERVER_ERROR(
+          TRITONSERVER_InferenceRequestDelete(request),
+          "deleting ensemble inference request");
+      if (request_tracker->DecrementCounter()) {
+        delete request_tracker;
+      }
+    }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and the queue is
+  // at capacity, execute the callback immediately in the current thread.
+  if (pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
+  }
 }

Review thread on the enqueue condition:
- Reviewer: Is this correct?
- Author: Correct. But consider a case where N busy workers are almost
  finished. Then, as long as TaskQueueSize <= N, the pending tasks can execute
  almost immediately. The maximum of N is 8. In fact, I did compare
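The enqueue-or-inline policy above is self-throttling: a callback is offloaded only while the pending queue is shallower than the worker count; otherwise it degrades to synchronous execution, so the queue depth stays bounded. A self-contained sketch of the pattern, using a hypothetical `SketchPool` rather than the actual triton::common::ThreadPool implementation:

```cpp
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Hypothetical minimal pool exposing the same three calls the diff relies on:
// Size(), TaskQueueSize(), and Enqueue().
class SketchPool {
 public:
  explicit SketchPool(size_t n)
  {
    for (size_t i = 0; i < n; ++i) {
      workers_.emplace_back([this] {
        for (;;) {
          std::function<void()> task;
          {
            std::unique_lock<std::mutex> lk(mtx_);
            cv_.wait(lk, [this] { return stop_ || !tasks_.empty(); });
            if (stop_ && tasks_.empty()) {
              return;
            }
            task = std::move(tasks_.front());
            tasks_.pop();
          }
          task();  // Run outside the lock so workers do not serialize.
        }
      });
    }
  }

  ~SketchPool()
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      stop_ = true;
    }
    cv_.notify_all();
    for (auto& w : workers_) {
      w.join();
    }
  }

  size_t Size() const { return workers_.size(); }

  size_t TaskQueueSize()
  {
    std::lock_guard<std::mutex> lk(mtx_);
    return tasks_.size();
  }

  void Enqueue(std::function<void()> fn)
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      tasks_.push(std::move(fn));
    }
    cv_.notify_one();
  }

 private:
  std::vector<std::thread> workers_;
  std::queue<std::function<void()>> tasks_;
  std::mutex mtx_;
  std::condition_variable cv_;
  bool stop_ = false;
};

// The policy from the diff: offload only while there is headroom, otherwise
// run the callback synchronously so queue depth stays bounded by Size().
void
RunCallback(SketchPool& pool, std::function<void()> fn)
{
  if (pool.TaskQueueSize() < pool.Size()) {
    pool.Enqueue(std::move(fn));
  } else {
    fn();
  }
}
```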
 void
 EnsembleContext::ResponseComplete(
     TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
 {
-  auto step_ptr = std::unique_ptr<Step>(reinterpret_cast<Step*>(userp));
-  step_ptr->response_flags_ = flags;
-  step_ptr->response_ = response;
-
-  EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
-  // Expecting more responses
-  if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
-    step_ptr.release();
-  }
+  auto step_raw_ptr = reinterpret_cast<Step*>(userp);
+  auto pool = step_raw_ptr->ctx_->CallbackPool();
+  auto fn = [response, flags, step_raw_ptr]() {
+    auto step_ptr = std::unique_ptr<Step>(step_raw_ptr);
+    step_ptr->response_flags_ = flags;
+    step_ptr->response_ = response;
+
+    EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
+    // Expecting more responses
+    if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
+      step_ptr.release();
+    }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and the queue is
+  // at capacity, execute the callback immediately in the current thread.
+  // Note: The async callback optimization does not guarantee the order of
+  // responses and exploits cases where responses can be out of order. For
+  // models required to preserve the order of responses, the response
+  // callbacks must be executed synchronously in the same thread.
+  if (!step_raw_ptr->preserve_responses_order_ &&
+      pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
+  }
 }
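The lambda above uses a conditional-ownership idiom worth calling out: it adopts the raw Step pointer into a unique_ptr, then releases it whenever more callbacks for the same Step are still expected. A distilled sketch with hypothetical types (the flag constant stands in for TRITONSERVER_RESPONSE_COMPLETE_FINAL):

```cpp
#include <cstdint>
#include <memory>

struct StepSketch {
  uint32_t response_flags_ = 0;
};

// Stand-in for TRITONSERVER_RESPONSE_COMPLETE_FINAL.
constexpr uint32_t kResponseCompleteFinal = 1u;

void
HandleResponse(StepSketch* raw, uint32_t flags)
{
  // Adopt ownership for the duration of this callback.
  auto owned = std::unique_ptr<StepSketch>(raw);
  owned->response_flags_ = flags;
  // ... process the response ...
  if ((flags & kResponseCompleteFinal) == 0) {
    // Not the final response: the same Step will be passed to a later
    // callback, so relinquish ownership instead of deleting it.
    owned.release();
  }
  // On the final response, the unique_ptr destructor frees the Step here.
}
```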
@@ -971,8 +1038,8 @@ EnsembleContext::InitStep(
   for (const auto& pair : istep.output_to_tensor_) {
     irequest->AddOriginalRequestedOutput(pair.first);
   }

-  step->reset(new Step(step_idx, correlation_id, flags));
+  const bool preserve_order = preserve_responses_order(model->Config());
+  step->reset(new Step(step_idx, correlation_id, flags, preserve_order));

   irequest->SetId(request_id_);
   irequest->SetCorrelationId(correlation_id);
@@ -1448,7 +1515,7 @@ EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
   RETURN_IF_ERROR(request->SetState(InferenceRequest::State::EXECUTING));
   std::shared_ptr<EnsembleContext> context(new EnsembleContext(
       metric_reporter_.get(), stats_aggregator_, is_, info_.get(), request,
-      stream_));
+      stream_, callback_pool_));
   EnsembleContext::Proceed(context);
   return Status::Success;
 }
@@ -1537,6 +1604,7 @@ EnsembleScheduler::EnsembleScheduler(
       info_->tensor_to_prev_step_.emplace(pair.second, step_idx);
     }
   }
+  callback_pool_ = is_->EnsembleCallbackPool();
 }

 EnsembleScheduler::~EnsembleScheduler()
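Every EnsembleScheduler shares the single callback pool owned by InferenceServer. A hedged sketch of a plausible accessor shape; the actual EnsembleCallbackPool implementation is not shown in this diff and may differ, and the 8-thread size merely echoes the "maximum of N is 8" remark in the review discussion above:

```cpp
#include <memory>

#include "triton/common/thread_pool.h"

// Hypothetical outline, not the real InferenceServer.
class InferenceServerSketch {
 public:
  // Shared by every EnsembleScheduler; returning a raw pointer keeps the
  // server as the sole owner of the pool's lifetime.
  triton::common::ThreadPool* EnsembleCallbackPool() const
  {
    return ensemble_cb_pool_.get();
  }

 private:
  // A small fixed size bounds callback concurrency; when the pool is
  // saturated, callbacks fall back to inline execution (see above).
  std::unique_ptr<triton::common::ThreadPool> ensemble_cb_pool_{
      new triton::common::ThreadPool(8 /* thread_count, assumed ctor */)};
};
```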
Review thread on preserving responses order:
- Reviewer: I don't understand why the order needs to be preserved in this
  case.
- Author: In the gRPC streaming case, the client would expect the response
  order to match the request order.
- Reviewer: I don't think the scheduler needs to care whether the request is
  received via gRPC streaming or not. That is an outer-layer requirement. The
  scheduler only cares whether it itself needs to preserve ordering at the
  model-instance level (i.e. whether preserve_ordering is set / sequence
  batching is used).