From 3f89a6e2bd28bf022af49f01c87c081624bac6ba Mon Sep 17 00:00:00 2001 From: Trevin Lee Date: Mon, 25 Aug 2025 08:37:25 +0200 Subject: [PATCH] SonicTriton: RetryActionDiffServer using TritonService::fallback; extend tests; add Catch2 unit test; update tritonTest_cfg.py for retry policy logging; ported to CMSSW_15_1_0_pre3 base --- HeterogeneousCore/SonicCore/BuildFile.xml | 1 + .../SonicCore/interface/RetryActionBase.h | 35 +++ .../SonicCore/interface/SonicClientBase.h | 14 +- .../SonicCore/plugins/BuildFile.xml | 6 + .../plugins/RetrySameServerAction.cc | 30 +++ .../SonicCore/src/RetryActionBase.cc | 15 ++ .../SonicCore/src/SonicClientBase.cc | 72 ++++-- .../SonicCore/test/BuildFile.xml | 4 +- .../SonicCore/test/DummyClient.h | 2 +- .../SonicCore/test/sonicTestAna_cfg.py | 1 + .../SonicCore/test/sonicTest_cfg.py | 57 ++++- HeterogeneousCore/SonicTriton/BuildFile.xml | 3 +- HeterogeneousCore/SonicTriton/README.md | 2 +- .../interface/RetryActionDiffServer.h | 32 +++ .../SonicTriton/interface/TritonClient.h | 209 +++++++++--------- .../SonicTriton/interface/TritonData.h | 6 +- .../SonicTriton/interface/TritonService.h | 3 +- .../SonicTriton/interface/triton_utils.h | 5 +- .../SonicTriton/src/RetryActionDiffServer.cc | 36 +++ .../SonicTriton/src/TritonClient.cc | 79 +++++-- .../SonicTriton/src/TritonData.cc | 2 +- .../SonicTriton/src/TritonService.cc | 10 +- .../SonicTriton/src/triton_utils.cc | 3 +- .../SonicTriton/test/BuildFile.xml | 13 +- .../test/retry_action_diff_log_test.sh | 22 ++ .../test/test_RetryActionDiffServer.cc | 73 ++++++ .../SonicTriton/test/tritonTest_cfg.py | 109 ++++++++- 27 files changed, 666 insertions(+), 178 deletions(-) create mode 100644 HeterogeneousCore/SonicCore/interface/RetryActionBase.h create mode 100644 HeterogeneousCore/SonicCore/plugins/BuildFile.xml create mode 100644 HeterogeneousCore/SonicCore/plugins/RetrySameServerAction.cc create mode 100644 HeterogeneousCore/SonicCore/src/RetryActionBase.cc create mode 100644 HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h create mode 100644 HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc create mode 100755 HeterogeneousCore/SonicTriton/test/retry_action_diff_log_test.sh create mode 100644 HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc diff --git a/HeterogeneousCore/SonicCore/BuildFile.xml b/HeterogeneousCore/SonicCore/BuildFile.xml index b0d5e2a08b98f..5208c91638f37 100644 --- a/HeterogeneousCore/SonicCore/BuildFile.xml +++ b/HeterogeneousCore/SonicCore/BuildFile.xml @@ -2,6 +2,7 @@ + diff --git a/HeterogeneousCore/SonicCore/interface/RetryActionBase.h b/HeterogeneousCore/SonicCore/interface/RetryActionBase.h new file mode 100644 index 0000000000000..e3fc0bbb8af9a --- /dev/null +++ b/HeterogeneousCore/SonicCore/interface/RetryActionBase.h @@ -0,0 +1,35 @@ +#ifndef HeterogeneousCore_SonicCore_RetryActionBase +#define HeterogeneousCore_SonicCore_RetryActionBase + +#include "FWCore/PluginManager/interface/PluginFactory.h" +#include "FWCore/ParameterSet/interface/ParameterSet.h" +#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h" +#include +#include + +// Base class for retry actions +class RetryActionBase { +public: + RetryActionBase(const edm::ParameterSet& conf, SonicClientBase* client); + virtual ~RetryActionBase() = default; + + bool shouldRetry() const { return shouldRetry_; } // Getter for shouldRetry_ + + virtual void retry() = 0; // Pure virtual function for execution logic + virtual void start() = 0; // Pure virtual function for execution logic for initialization + +protected: + void eval(); // interface for calling evaluate in client + +protected: + SonicClientBase* client_; + bool shouldRetry_; // Flag to track if further retries should happen +}; + +// Define the factory for creating retry actions +using RetryActionFactory = + edmplugin::PluginFactory; + +#endif + +#define DEFINE_RETRY_ACTION(type) DEFINE_EDM_PLUGIN(RetryActionFactory, type, #type); diff --git a/HeterogeneousCore/SonicCore/interface/SonicClientBase.h b/HeterogeneousCore/SonicCore/interface/SonicClientBase.h index 47caaae8b2052..45a089701ed12 100644 --- a/HeterogeneousCore/SonicCore/interface/SonicClientBase.h +++ b/HeterogeneousCore/SonicCore/interface/SonicClientBase.h @@ -9,12 +9,15 @@ #include "HeterogeneousCore/SonicCore/interface/SonicDispatcherPseudoAsync.h" #include +#include #include #include #include enum class SonicMode { Sync = 1, Async = 2, PseudoAsync = 3 }; +class RetryActionBase; + class SonicClientBase { public: //constructor @@ -54,14 +57,23 @@ class SonicClientBase { SonicMode mode_; bool verbose_; std::unique_ptr dispatcher_; - unsigned allowedTries_, tries_; + unsigned totalTries_; std::optional holder_; + // Use a unique_ptr with a custom deleter to avoid incomplete type issues + struct RetryDeleter { + void operator()(RetryActionBase* ptr) const; + }; + + using RetryActionPtr = std::unique_ptr; + std::vector retryActions_; + //for logging/debugging std::string debugName_, clientName_, fullDebugName_; friend class SonicDispatcher; friend class SonicDispatcherPseudoAsync; + friend class RetryActionBase; }; #endif diff --git a/HeterogeneousCore/SonicCore/plugins/BuildFile.xml b/HeterogeneousCore/SonicCore/plugins/BuildFile.xml new file mode 100644 index 0000000000000..0ecf2187a0f82 --- /dev/null +++ b/HeterogeneousCore/SonicCore/plugins/BuildFile.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/HeterogeneousCore/SonicCore/plugins/RetrySameServerAction.cc b/HeterogeneousCore/SonicCore/plugins/RetrySameServerAction.cc new file mode 100644 index 0000000000000..9877013b93d5b --- /dev/null +++ b/HeterogeneousCore/SonicCore/plugins/RetrySameServerAction.cc @@ -0,0 +1,30 @@ +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" +#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h" + +class RetrySameServerAction : public RetryActionBase { +public: + RetrySameServerAction(const edm::ParameterSet& pset, SonicClientBase* client) + : RetryActionBase(pset, client), allowedTries_(pset.getUntrackedParameter("allowedTries", 0)) {} + + void start() override { tries_ = 0; }; + +protected: + void retry() override; + +private: + unsigned allowedTries_, tries_; +}; + +void RetrySameServerAction::retry() { + ++tries_; + //if max retries has not been exceeded, call evaluate again + if (tries_ < allowedTries_) { + eval(); + return; + } else { + shouldRetry_ = false; // Flip flag when max retries are reached + edm::LogInfo("RetrySameServerAction") << "Max retry attempts reached. No further retries."; + } +} + +DEFINE_RETRY_ACTION(RetrySameServerAction) diff --git a/HeterogeneousCore/SonicCore/src/RetryActionBase.cc b/HeterogeneousCore/SonicCore/src/RetryActionBase.cc new file mode 100644 index 0000000000000..41b9a6186da2b --- /dev/null +++ b/HeterogeneousCore/SonicCore/src/RetryActionBase.cc @@ -0,0 +1,15 @@ +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" + +// Constructor implementation +RetryActionBase::RetryActionBase(const edm::ParameterSet& conf, SonicClientBase* client) + : client_(client), shouldRetry_(true) {} + +void RetryActionBase::eval() { + if (client_) { + client_->evaluate(); + } else { + edm::LogError("RetryActionBase") << "Client pointer is null, cannot evaluate."; + } +} + +EDM_REGISTER_PLUGINFACTORY(RetryActionFactory, "RetryActionFactory"); diff --git a/HeterogeneousCore/SonicCore/src/SonicClientBase.cc b/HeterogeneousCore/SonicCore/src/SonicClientBase.cc index 745c51f17aaf3..9949d9d1f2ea2 100644 --- a/HeterogeneousCore/SonicCore/src/SonicClientBase.cc +++ b/HeterogeneousCore/SonicCore/src/SonicClientBase.cc @@ -1,18 +1,33 @@ #include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h" +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" #include "FWCore/Utilities/interface/Exception.h" #include "FWCore/ParameterSet/interface/allowedValues.h" +// Custom deleter implementation +void SonicClientBase::RetryDeleter::operator()(RetryActionBase* ptr) const { delete ptr; } + SonicClientBase::SonicClientBase(const edm::ParameterSet& params, const std::string& debugName, const std::string& clientName) - : allowedTries_(params.getUntrackedParameter("allowedTries", 0)), - debugName_(debugName), - clientName_(clientName), - fullDebugName_(debugName_) { + : debugName_(debugName), clientName_(clientName), fullDebugName_(debugName_) { if (!clientName_.empty()) fullDebugName_ += ":" + clientName_; + const auto& retryPSetList = params.getParameter>("Retry"); std::string modeName(params.getParameter("mode")); + + for (const auto& retryPSet : retryPSetList) { + const std::string& actionType = retryPSet.getParameter("retryType"); + + auto retryAction = RetryActionFactory::get()->create(actionType, retryPSet, this); + if (retryAction) { + //Convert to RetryActionPtr Type from raw pointer of retryAction + retryActions_.emplace_back(RetryActionPtr(retryAction.release())); + } else { + throw cms::Exception("Configuration") << "Unknown Retry type " << actionType << " for SonicClient: " << modeName; + } + } + if (modeName == "Sync") setMode(SonicMode::Sync); else if (modeName == "Async") @@ -40,24 +55,30 @@ void SonicClientBase::start(edm::WaitingTaskWithArenaHolder holder) { holder_ = std::move(holder); } -void SonicClientBase::start() { tries_ = 0; } +void SonicClientBase::start() { + totalTries_ = 0; + // initialize all actions + for (const auto& action : retryActions_) { + action->start(); + } +} void SonicClientBase::finish(bool success, std::exception_ptr eptr) { //retries are only allowed if no exception was raised if (!success and !eptr) { - ++tries_; - //if max retries has not been exceeded, call evaluate again - if (tries_ < allowedTries_) { - evaluate(); - //avoid calling doneWaiting() twice - return; - } - //prepare an exception if exceeded - else { - edm::Exception ex(edm::errors::ExternalFailure); - ex << "SonicCallFailed: call failed after max " << tries_ << " tries"; - eptr = make_exception_ptr(ex); + ++totalTries_; + for (const auto& action : retryActions_) { + if (action->shouldRetry()) { + action->retry(); // Call retry only if shouldRetry_ is true + return; + } } + //prepare an exception if no more retries left + edm::LogInfo("SonicClientBase") << "SonicCallFailed: call failed, no retry actions available after " << totalTries_ + << " tries."; + edm::Exception ex(edm::errors::ExternalFailure); + ex << "SonicCallFailed: call failed, no retry actions available after " << totalTries_ << " tries."; + eptr = make_exception_ptr(ex); } if (holder_) { holder_->doneWaiting(eptr); @@ -74,7 +95,20 @@ void SonicClientBase::fillBasePSetDescription(edm::ParameterSetDescription& desc //restrict allowed values desc.ifValue(edm::ParameterDescription("mode", "PseudoAsync", true), edm::allowedValues("Sync", "Async", "PseudoAsync")); - if (allowRetry) - desc.addUntracked("allowedTries", 0); + if (allowRetry) { + // Defines the structure of each entry in the VPSet + edm::ParameterSetDescription retryDesc; + retryDesc.add("retryType", "RetrySameServerAction"); + retryDesc.addUntracked("allowedTries", 0); + + // Define a default retry action + edm::ParameterSet defaultRetry; + defaultRetry.addParameter("retryType", "RetrySameServerAction"); + defaultRetry.addUntrackedParameter("allowedTries", 0); + + // Add the VPSet with the default retry action + desc.addVPSet("Retry", retryDesc, {defaultRetry}); + } + desc.add("sonicClientBase", desc); desc.addUntracked("verbose", false); } diff --git a/HeterogeneousCore/SonicCore/test/BuildFile.xml b/HeterogeneousCore/SonicCore/test/BuildFile.xml index 04b2bcb20df2f..11e8f860b5818 100644 --- a/HeterogeneousCore/SonicCore/test/BuildFile.xml +++ b/HeterogeneousCore/SonicCore/test/BuildFile.xml @@ -1,6 +1,6 @@ - - + + diff --git a/HeterogeneousCore/SonicCore/test/DummyClient.h b/HeterogeneousCore/SonicCore/test/DummyClient.h index ccef888ad9f7d..6504843926c0a 100644 --- a/HeterogeneousCore/SonicCore/test/DummyClient.h +++ b/HeterogeneousCore/SonicCore/test/DummyClient.h @@ -36,7 +36,7 @@ class DummyClient : public SonicClient { this->output_ = this->input_ * factor_; //simulate a failure - if (this->tries_ < fails_) + if (this->totalTries_ < fails_) this->finish(false); else this->finish(true); diff --git a/HeterogeneousCore/SonicCore/test/sonicTestAna_cfg.py b/HeterogeneousCore/SonicCore/test/sonicTestAna_cfg.py index 11c23c6cdfcc9..b8b66db34abd9 100644 --- a/HeterogeneousCore/SonicCore/test/sonicTestAna_cfg.py +++ b/HeterogeneousCore/SonicCore/test/sonicTestAna_cfg.py @@ -1,4 +1,5 @@ import FWCore.ParameterSet.Config as cms +from FWCore.ParameterSet.VarParsing import VarParsing process = cms.Process("Test") diff --git a/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py b/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py index 614297d86e3bb..bf7b44cb01519 100644 --- a/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py +++ b/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py @@ -1,11 +1,13 @@ import FWCore.ParameterSet.Config as cms -from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter +from FWCore.ParameterSet.VarParsing import VarParsing -_allowedModuleTypes = ["Producer","Filter"] -parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) -parser.add_argument("--moduleType", type=str, required=True, choices=_allowedModuleTypes, help="Type of module to test") -options = parser.parse_args() +options = VarParsing() +options.register("moduleType","", VarParsing.multiplicity.singleton, VarParsing.varType.string) +options.parseArguments() +_allowedModuleTypes = ["Producer","Filter"] +if options.moduleType not in ["Producer","Filter"]: + raise ValueError("Unknown module type: {} (allowed: {})".format(options.moduleType,_allowedModuleTypes)) _moduleName = "SonicDummy"+options.moduleType _moduleClass = getattr(cms,"ED"+options.moduleType) @@ -17,15 +19,19 @@ process.options.numberOfThreads = 2 process.options.numberOfStreams = 0 - process.dummySync = _moduleClass(_moduleName, input = cms.int32(1), Client = cms.PSet( mode = cms.string("Sync"), factor = cms.int32(-1), wait = cms.int32(10), - allowedTries = cms.untracked.uint32(0), fails = cms.uint32(0), + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(0) + ) + ) ), ) @@ -35,8 +41,14 @@ mode = cms.string("PseudoAsync"), factor = cms.int32(2), wait = cms.int32(10), - allowedTries = cms.untracked.uint32(0), fails = cms.uint32(0), + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(0) + ) + ) + ), ) @@ -46,32 +58,53 @@ mode = cms.string("Async"), factor = cms.int32(5), wait = cms.int32(10), - allowedTries = cms.untracked.uint32(0), fails = cms.uint32(0), + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(0) + ) + ) ), ) process.dummySyncRetry = process.dummySync.clone( Client = dict( wait = 2, - allowedTries = 2, fails = 1, + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(2) + ) + ) + ) ) process.dummyPseudoAsyncRetry = process.dummyPseudoAsync.clone( Client = dict( wait = 2, - allowedTries = 2, fails = 1, + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(2) + ) + ) ) ) process.dummyAsyncRetry = process.dummyAsync.clone( Client = dict( wait = 2, - allowedTries = 2, fails = 1, + Retry = cms.VPSet( + cms.PSet( + allowedTries = cms.untracked.uint32(2), + retryType = cms.string('RetrySameServerAction') + ) + ) ) ) diff --git a/HeterogeneousCore/SonicTriton/BuildFile.xml b/HeterogeneousCore/SonicTriton/BuildFile.xml index b93d51e711e87..4af38d69d89e9 100644 --- a/HeterogeneousCore/SonicTriton/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/BuildFile.xml @@ -10,6 +10,7 @@ + - + diff --git a/HeterogeneousCore/SonicTriton/README.md b/HeterogeneousCore/SonicTriton/README.md index 88058ed88289b..488566c937caf 100644 --- a/HeterogeneousCore/SonicTriton/README.md +++ b/HeterogeneousCore/SonicTriton/README.md @@ -71,7 +71,7 @@ There are specific local input and output containers that should be used in prod Here, `T` is a primitive type, and the two aliases listed below are passed to `TritonInputData::toServer()` and returned by `TritonOutputData::fromServer()`, respectively: * `TritonInputContainer = std::shared_ptr> = std::shared_ptr>>` -* `TritonOutput = std::vector>` +* `TritonOutput = std::vector>` The `TritonInputContainer` object should be created using the helper function described below. It expects one vector per batch entry (i.e. the size of the outer vector is the batch size (rectangular case) or number of entries (ragged case)). diff --git a/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h new file mode 100644 index 0000000000000..af7720b90cb0b --- /dev/null +++ b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h @@ -0,0 +1,32 @@ +#ifndef HeterogeneousCore_SonicTriton_RetryActionDiffServer_h +#define HeterogeneousCore_SonicTriton_RetryActionDiffServer_h + +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" + +/** + * @class RetryActionDiffServer + * @brief A concrete implementation of RetryActionBase that attempts to retry an inference + * request on a different Triton server. + * + * This class provides a fallback mechanism. If an initial inference request fails + * (e.g., due to server unavailability or a model-specific error), this action will be + * triggered. It queries the central TritonService to select an alternative server (e.g., + * the fallback server when available) and instructs the TritonClient to reconnect to + * that server for the retry attempt. This action is designed for one-time use per + * inference call; after the retry attempt, it disables itself until the next `start()` + * call. + */ + +class RetryActionDiffServer : public RetryActionBase { +public: + RetryActionDiffServer(const edm::ParameterSet& conf, SonicClientBase* client); + ~RetryActionDiffServer() override = default; + + void retry() override; + void start() override; + +private: +}; + +#endif + diff --git a/HeterogeneousCore/SonicTriton/interface/TritonClient.h b/HeterogeneousCore/SonicTriton/interface/TritonClient.h index df8f9b559427c..9e21b646508e9 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonClient.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonClient.h @@ -1,102 +1,107 @@ -#ifndef HeterogeneousCore_SonicTriton_TritonClient -#define HeterogeneousCore_SonicTriton_TritonClient - -#include "FWCore/ParameterSet/interface/ParameterSet.h" -#include "FWCore/ParameterSet/interface/ParameterSetDescription.h" -#include "FWCore/ServiceRegistry/interface/ServiceToken.h" -#include "HeterogeneousCore/SonicCore/interface/SonicClient.h" -#include "HeterogeneousCore/SonicTriton/interface/TritonData.h" -#include "HeterogeneousCore/SonicTriton/interface/TritonService.h" - -#include -#include -#include -#include -#include - -#include "grpc_client.h" -#include "grpc_service.pb.h" - -enum class TritonBatchMode { Rectangular = 1, Ragged = 2 }; - -class TritonClient : public SonicClient { -public: - struct ServerSideStats { - uint64_t inference_count_; - uint64_t execution_count_; - uint64_t success_count_; - uint64_t cumm_time_ns_; - uint64_t queue_time_ns_; - uint64_t compute_input_time_ns_; - uint64_t compute_infer_time_ns_; - uint64_t compute_output_time_ns_; - }; - - //constructor - TritonClient(const edm::ParameterSet& params, const std::string& debugName); - - //destructor - ~TritonClient() override; - - //accessors - unsigned batchSize() const; - TritonBatchMode batchMode() const { return batchMode_; } - bool verbose() const { return verbose_; } - bool useSharedMemory() const { return useSharedMemory_; } - void setUseSharedMemory(bool useShm) { useSharedMemory_ = useShm; } - bool setBatchSize(unsigned bsize); - void setBatchMode(TritonBatchMode batchMode); - void resetBatchMode(); - void reset() override; - TritonServerType serverType() const { return serverType_; } - bool isLocal() const { return isLocal_; } - - //for fillDescriptions - static void fillPSetDescription(edm::ParameterSetDescription& iDesc); - -protected: - //helpers - bool noOuterDim() const { return noOuterDim_; } - unsigned outerDim() const { return outerDim_; } - unsigned nEntries() const; - void getResults(const std::vector>& results); - void evaluate() override; - template - bool handle_exception(F&& call); - - void reportServerSideStats(const ServerSideStats& stats) const; - ServerSideStats summarizeServerStats(const inference::ModelStatistics& start_status, - const inference::ModelStatistics& end_status) const; - - inference::ModelStatistics getServerSideStatus() const; - - //members - unsigned maxOuterDim_; - unsigned outerDim_; - bool noOuterDim_; - unsigned nEntries_; - TritonBatchMode batchMode_; - bool manualBatchMode_; - bool verbose_; - bool useSharedMemory_; - TritonServerType serverType_; - bool isLocal_; - grpc_compression_algorithm compressionAlgo_; - triton::client::Headers headers_; - - std::unique_ptr client_; - //stores timeout, model name and version - std::vector options_; - edm::ServiceToken token_; - -private: - friend TritonInputData; - friend TritonOutputData; - - //private accessors only used by data - auto client() { return client_.get(); } - void addEntry(unsigned entry); - void resizeEntries(unsigned entry); -}; - -#endif +#ifndef HeterogeneousCore_SonicTriton_TritonClient +#define HeterogeneousCore_SonicTriton_TritonClient + +#include "FWCore/ParameterSet/interface/ParameterSet.h" +#include "FWCore/ParameterSet/interface/ParameterSetDescription.h" +#include "FWCore/ServiceRegistry/interface/ServiceToken.h" +#include "HeterogeneousCore/SonicCore/interface/SonicClient.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonData.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonService.h" + +#include +#include +#include +#include +#include + +#include "grpc_client.h" +#include "grpc_service.pb.h" + +enum class TritonBatchMode { Rectangular = 1, Ragged = 2 }; + +class TritonClient : public SonicClient { +public: + struct ServerSideStats { + uint64_t inference_count_; + uint64_t execution_count_; + uint64_t success_count_; + uint64_t cumm_time_ns_; + uint64_t queue_time_ns_; + uint64_t compute_input_time_ns_; + uint64_t compute_infer_time_ns_; + uint64_t compute_output_time_ns_; + }; + + //constructor + TritonClient(const edm::ParameterSet& params, const std::string& debugName); + + //destructor + ~TritonClient() override; + + //accessors + unsigned batchSize() const; + TritonBatchMode batchMode() const { return batchMode_; } + bool verbose() const { return verbose_; } + bool useSharedMemory() const { return useSharedMemory_; } + void setUseSharedMemory(bool useShm) { useSharedMemory_ = useShm; } + bool setBatchSize(unsigned bsize); + void setBatchMode(TritonBatchMode batchMode); + void resetBatchMode(); + void reset() override; + TritonServerType serverType() const { return serverType_; } + bool isLocal() const { return isLocal_; } + virtual void connectToServer(const std::string& url); + virtual void updateServer(std::string serverName); + + //for fillDescriptions + static void fillPSetDescription(edm::ParameterSetDescription& iDesc); + +protected: + // Protected default constructor for unit testing (no framework services) + TritonClient(); + + //helpers + bool noOuterDim() const { return noOuterDim_; } + unsigned outerDim() const { return outerDim_; } + unsigned nEntries() const; + void getResults(const std::vector>& results); + virtual void evaluate() override; + template + bool handle_exception(F&& call); + + void reportServerSideStats(const ServerSideStats& stats) const; + ServerSideStats summarizeServerStats(const inference::ModelStatistics& start_status, + const inference::ModelStatistics& end_status) const; + + inference::ModelStatistics getServerSideStatus() const; + + //members + unsigned maxOuterDim_; + unsigned outerDim_; + bool noOuterDim_; + unsigned nEntries_; + TritonBatchMode batchMode_; + bool manualBatchMode_; + bool verbose_; + bool useSharedMemory_; + TritonServerType serverType_; + bool isLocal_; + grpc_compression_algorithm compressionAlgo_; + triton::client::Headers headers_; + + std::unique_ptr client_; + //stores timeout, model name and version + std::vector options_; + edm::ServiceToken token_; + +private: + friend TritonInputData; + friend TritonOutputData; + + //private accessors only used by data + auto client() { return client_.get(); } + void addEntry(unsigned entry); + void resizeEntries(unsigned entry); +}; + +#endif diff --git a/HeterogeneousCore/SonicTriton/interface/TritonData.h b/HeterogeneousCore/SonicTriton/interface/TritonData.h index 026783be30547..a6703811b6257 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonData.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonData.h @@ -2,10 +2,10 @@ #define HeterogeneousCore_SonicTriton_TritonData #include "FWCore/Utilities/interface/Exception.h" +#include "FWCore/Utilities/interface/Span.h" #include "HeterogeneousCore/SonicTriton/interface/triton_utils.h" #include -#include #include #include #include @@ -34,7 +34,7 @@ class TritonGpuShmResource; template using TritonInput = std::vector>; template -using TritonOutput = std::vector>; +using TritonOutput = std::vector>; //other useful typdefs template @@ -49,7 +49,7 @@ class TritonData { using Result = triton::client::InferResult; using TensorMetadata = inference::ModelMetadataResponse_TensorMetadata; using ShapeType = std::vector; - using ShapeView = std::span; + using ShapeView = edm::Span; //constructor TritonData(const std::string& name, const TensorMetadata& model_info, TritonClient* client, const std::string& pid); diff --git a/HeterogeneousCore/SonicTriton/interface/TritonService.h b/HeterogeneousCore/SonicTriton/interface/TritonService.h index 8f36f73566e06..470c6ad76b436 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonService.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonService.h @@ -18,6 +18,7 @@ namespace edm { class ActivityRegistry; class ConfigurationDescriptions; + class PathsAndConsumesOfModulesBase; class ProcessContext; class ModuleDescription; namespace service { @@ -122,7 +123,7 @@ class TritonService { void preModuleConstruction(edm::ModuleDescription const&); void postModuleConstruction(edm::ModuleDescription const&); void preModuleDestruction(edm::ModuleDescription const&); - void preBeginJob(edm::ProcessContext const&); + void preBeginJob(edm::PathsAndConsumesOfModulesBase const&, edm::ProcessContext const&); void postEndJob(); //helper diff --git a/HeterogeneousCore/SonicTriton/interface/triton_utils.h b/HeterogeneousCore/SonicTriton/interface/triton_utils.h index c414621717301..01ac1ef0f0b0f 100644 --- a/HeterogeneousCore/SonicTriton/interface/triton_utils.h +++ b/HeterogeneousCore/SonicTriton/interface/triton_utils.h @@ -2,11 +2,11 @@ #define HeterogeneousCore_SonicTriton_triton_utils #include "FWCore/Utilities/interface/Exception.h" +#include "FWCore/Utilities/interface/Span.h" #include "HeterogeneousCore/SonicTriton/interface/TritonException.h" #include #include -#include #include #include @@ -82,7 +82,8 @@ inline bool triton_utils::checkType(inference::DataType dtype) { throw TritonException("TritonFailure", NOTIFY) << (MSG) << (err.Message().empty() ? "" : ": " + err.Message()); \ } -extern template std::string triton_utils::printColl(const std::span& coll, const std::string& delim); +extern template std::string triton_utils::printColl(const edm::Span::const_iterator>& coll, + const std::string& delim); extern template std::string triton_utils::printColl(const std::vector& coll, const std::string& delim); extern template std::string triton_utils::printColl(const std::vector& coll, const std::string& delim); extern template std::string triton_utils::printColl(const std::vector& coll, const std::string& delim); diff --git a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc new file mode 100644 index 0000000000000..c3f931ae567ba --- /dev/null +++ b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc @@ -0,0 +1,36 @@ +#include "HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonService.h" +#include "FWCore/MessageLogger/interface/MessageLogger.h" +#include "FWCore/ServiceRegistry/interface/Service.h" + +RetryActionDiffServer::RetryActionDiffServer( + const edm::ParameterSet& conf, + SonicClientBase* client +): RetryActionBase(conf, client) {} + +void RetryActionDiffServer::start() { + this->shouldRetry_ = true; +} + +void RetryActionDiffServer::retry() { + if (!this->shouldRetry_) { + this->shouldRetry_ = false; + edm::LogInfo("RetryActionDiffServer") << "Retry not armed; skipping."; + return; + } + + try { + auto* tritonClient = static_cast(client_); + edm::LogInfo("RetryActionDiffServer") << "Attempting retry by switching to fallback server"; + tritonClient->updateServer(TritonService::Server::fallbackName); + eval(); + } catch (const std::exception& e) { + edm::LogError("RetryActionDiffServer") + << "Failed to retry with alternative server: " + << e.what(); + } + this->shouldRetry_ = false; +} + +DEFINE_RETRY_ACTION(RetryActionDiffServer); \ No newline at end of file diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index ddcdff83448d0..404c84455e87d 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -28,6 +28,19 @@ namespace tc = triton::client; namespace { + // Minimal ParameterSet to satisfy SonicClientBase requirements during unit tests + edm::ParameterSet makeMinimalSonicParamsForTest() { + edm::ParameterSet params; + params.addParameter("mode", "PseudoAsync"); + + edm::ParameterSet defaultRetry; + defaultRetry.addParameter("retryType", "RetrySameServerAction"); + defaultRetry.addUntrackedParameter("allowedTries", 0u); + std::vector retryVec{defaultRetry}; + params.addParameter>("Retry", retryVec); + + return params; + } grpc_compression_algorithm getCompressionAlgo(const std::string& name) { if (name.empty() or name.compare("none") == 0) return grpc_compression_algorithm::GRPC_COMPRESS_NONE; @@ -61,7 +74,7 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d useSharedMemory_(params.getUntrackedParameter("useSharedMemory")), compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter("compression"))) { options_.emplace_back(params.getParameter("modelName")); - //get appropriate server for this model + edm::Service ts; // We save the token to be able to notify the service in case of an exception in the evaluate method. @@ -70,22 +83,9 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d // create the context. token_ = edm::ServiceRegistry::instance().presentToken(); - const auto& server = - ts->serverInfo(options_[0].model_name_, params.getUntrackedParameter("preferredServer")); - serverType_ = server.type; - edm::LogInfo("TritonDiscovery") << debugName_ << " assigned server: " << server.url; - //enforce sync mode for fallback CPU server to avoid contention - //todo: could enforce async mode otherwise (unless mode was specified by user?) - if (serverType_ == TritonServerType::LocalCPU) - setMode(SonicMode::Sync); - isLocal_ = serverType_ == TritonServerType::LocalCPU or serverType_ == TritonServerType::LocalGPU; - - //connect to the server - TRITON_THROW_IF_ERROR( - tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions), - "TritonClient(): unable to create inference context", - isLocal_); - + //Connect to server + updateServer(params.getUntrackedParameter("preferredServer")); + //set options options_[0].model_version_ = params.getParameter("modelVersion"); options_[0].client_timeout_ = params.getUntrackedParameter("timeout"); @@ -369,7 +369,7 @@ void TritonClient::getResults(const std::vector //default case for sync and pseudo async void TritonClient::evaluate() { //undo previous signal from TritonException - if (tries_ > 0) { + if (totalTries_ > 0) { // If we are retrying then the evaluate method is called outside the frameworks TBB thread pool. // So we need to setup the service token for the current thread to access the service registry. edm::ServiceRegistry::Operate op(token_); @@ -574,6 +574,26 @@ inference::ModelStatistics TritonClient::getServerSideStatus() const { return inference::ModelStatistics{}; } +void TritonClient::updateServer(std::string serverName){ + //get appropriate server for this model + edm::Service ts; + + const auto& server = ts->serverInfo(options_[0].model_name_, serverName); + serverType_ = server.type; + edm::LogInfo("TritonDiscovery") << debugName_ << " assigned server: " << server.url; + //enforce sync mode for fallback CPU server to avoid contention + //todo: could enforce async mode otherwise (unless mode was specified by user?) + if (serverType_ == TritonServerType::LocalCPU) + setMode(SonicMode::Sync); + isLocal_ = serverType_ == TritonServerType::LocalCPU or serverType_ == TritonServerType::LocalGPU; + + //connect to the server + TRITON_THROW_IF_ERROR( + tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions), + "TritonClient(): unable to create inference context", + isLocal_); +} + //for fillDescriptions void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) { edm::ParameterSetDescription descClient; @@ -591,3 +611,26 @@ void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) { descClient.addUntracked>("outputs", {}); iDesc.add("Client", descClient); } + +void TritonClient::connectToServer(const std::string& url) { + // Update client state for a generic remote server + serverType_ = TritonServerType::Remote; + isLocal_ = false; + + edm::LogInfo("TritonDiscovery") << debugName_ << " connecting to server: " << url; + + // Use default SSL options + triton::client::SslOptions sslOptions; + bool useSsl = false; // Assuming no SSL for direct URL connection + + // Connect to the server + TRITON_THROW_IF_ERROR( + triton::client::InferenceServerGrpcClient::Create(&client_, url, false, useSsl, sslOptions), + "TritonClient::connectToServer(): unable to create inference context", + false // isLocal is false + ); +} + +//constructor for testing +TritonClient::TritonClient() : SonicClient(makeMinimalSonicParamsForTest(), "TritonClient_test", "TritonClient") {} + diff --git a/HeterogeneousCore/SonicTriton/src/TritonData.cc b/HeterogeneousCore/SonicTriton/src/TritonData.cc index 0a462ce43c6cd..d8fc506d6e99a 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonData.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonData.cc @@ -264,7 +264,7 @@ TritonOutput
TritonOutputData::fromServer() const { for (unsigned i0 = 0; i0 < outerDim; ++i0) { auto offset = i0 * entry.sizeShape_; - dataOut.emplace_back(r1 + offset, entry.sizeShape_); + dataOut.emplace_back(r1 + offset, r1 + offset + entry.sizeShape_); } } diff --git a/HeterogeneousCore/SonicTriton/src/TritonService.cc b/HeterogeneousCore/SonicTriton/src/TritonService.cc index d0d82bbaa9efc..ca5aa9c7c65e7 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonService.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonService.cc @@ -161,13 +161,11 @@ TritonService::TritonService(const edm::ParameterSet& pset, edm::ActivityRegistr msg += modelName + ", "; } } else { - const std::string& baseMsg = "unable to get repository index"; - const std::string& extraMsg = err.Message().empty() ? "" : ": " + err.Message(); if (verbose_) - msg += baseMsg + extraMsg; + msg += "unable to get repository index"; else - edm::LogWarning("TritonFailure") << "TritonService(): " << baseMsg << " for " << serverName << " (" - << server.url << ")" << extraMsg; + edm::LogWarning("TritonFailure") << "TritonService(): unable to get repository index for " + serverName + " (" + + server.url + ")"; } if (verbose_) msg += "\n"; @@ -245,7 +243,7 @@ TritonService::Server TritonService::serverInfo(const std::string& model, const return server; } -void TritonService::preBeginJob(edm::ProcessContext const&) { +void TritonService::preBeginJob(edm::PathsAndConsumesOfModulesBase const&, edm::ProcessContext const&) { //only need fallback if there are unserved models if (!fallbackOpts_.enable or unservedModels_.empty()) return; diff --git a/HeterogeneousCore/SonicTriton/src/triton_utils.cc b/HeterogeneousCore/SonicTriton/src/triton_utils.cc index 121b605bec63a..322dddf133381 100644 --- a/HeterogeneousCore/SonicTriton/src/triton_utils.cc +++ b/HeterogeneousCore/SonicTriton/src/triton_utils.cc @@ -19,7 +19,8 @@ namespace triton_utils { void convertToWarning(const cms::Exception& e) { edm::LogWarning(e.category()) << e.explainSelf(); } } // namespace triton_utils -template std::string triton_utils::printColl(const std::span& coll, const std::string& delim); +template std::string triton_utils::printColl(const edm::Span::const_iterator>& coll, + const std::string& delim); template std::string triton_utils::printColl(const std::vector& coll, const std::string& delim); template std::string triton_utils::printColl(const std::vector& coll, const std::string& delim); template std::string triton_utils::printColl(const std::vector& coll, const std::string& delim); diff --git a/HeterogeneousCore/SonicTriton/test/BuildFile.xml b/HeterogeneousCore/SonicTriton/test/BuildFile.xml index e4ff7a0bb56f3..f6ab9108035f2 100644 --- a/HeterogeneousCore/SonicTriton/test/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/test/BuildFile.xml @@ -1,11 +1,22 @@ + + - + + + + + + + + + + diff --git a/HeterogeneousCore/SonicTriton/test/retry_action_diff_log_test.sh b/HeterogeneousCore/SonicTriton/test/retry_action_diff_log_test.sh new file mode 100755 index 0000000000000..00e1cc1dce90a --- /dev/null +++ b/HeterogeneousCore/SonicTriton/test/retry_action_diff_log_test.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +LOCALTOP=$1 + +tmpFile=$(mktemp -p ${LOCALTOP} RetryActionDiffLogXXXXXXXX.log) +cmsRun ${LOCALTOP}/src/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py \ + --modules TritonGraphProducer --models gat_test \ + --maxEvents 2 --unittest --device cpu --retryAction diff --verbose \ + > "$tmpFile" 2>&1 +status=$? + +if ! grep -q "Retry type: RetryActionDiffServer" "$tmpFile"; then + echo "Expected retry type log line not found" >&2 + cat "$tmpFile" + rm -f "$tmpFile" + exit 1 +fi + +rm -f "$tmpFile" +exit $status + + diff --git a/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc new file mode 100644 index 0000000000000..48305dd083216 --- /dev/null +++ b/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc @@ -0,0 +1,73 @@ +#include "catch.hpp" + +#include "HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonService.h" +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" + +#include "FWCore/ParameterSet/interface/ParameterSet.h" + +#include + +// Test double for TritonClient to observe updateServer calls without framework/services +class TestTritonClient : public TritonClient { +public: + TestTritonClient() : TritonClient() {} + + void connectToServer(const std::string& url) override { lastConnectedUrl = url; } + + void updateServer(std::string serverName) override { + lastUpdatedServerName = std::move(serverName); + } + + const std::string& lastUrl() const { return lastConnectedUrl; } + const std::string& lastServerName() const { return lastUpdatedServerName; } + +protected: + void evaluate() override {} + +private: + std::string lastConnectedUrl; + std::string lastUpdatedServerName; +}; + +TEST_CASE("RetryActionDiffServer switches to fallback via updateServer", "[RetryActionDiffServer]") { + edm::ParameterSet empty; + TestTritonClient client; + + RetryActionDiffServer action(empty, static_cast(&client)); + + // start should arm the action + action.start(); + REQUIRE(action.shouldRetry()); + + // retry should call updateServer with fallback name then disarm + action.retry(); + REQUIRE(client.lastServerName() == TritonService::Server::fallbackName); + + // second retry without re-arming should be a no-op: lastServerName unchanged + std::string afterFirst = client.lastServerName(); + action.retry(); + REQUIRE(client.lastServerName() == afterFirst); +} + +// A client that throws during updateServer to exercise error handling path +class ThrowingTritonClient : public TritonClient { +public: + ThrowingTritonClient() : TritonClient() {} + void updateServer(std::string) override { throw std::runtime_error("updateServer failure"); } +protected: + void evaluate() override {} +}; + +TEST_CASE("RetryActionDiffServer catches exceptions from updateServer", "[RetryActionDiffServer]") { + edm::ParameterSet empty; + ThrowingTritonClient client; + RetryActionDiffServer action(empty, static_cast(&client)); + action.start(); + + // Should not throw despite client throwing internally; action disarms afterward + REQUIRE_NOTHROW(action.retry()); +} + + diff --git a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py index 33d6a9c60aad4..b5915ea1a05ce 100644 --- a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py +++ b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py @@ -1,6 +1,6 @@ import FWCore.ParameterSet.Config as cms import os, sys, json -from HeterogeneousCore.SonicTriton.customize import getDefaultClientPSet, getParser, getOptions, applyOptions, applyClientOptions +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter # module/model correspondence models = { @@ -13,16 +13,47 @@ # other choices allowed_modes = ["Async","PseudoAsync","Sync"] +allowed_compression = ["none","deflate","gzip"] +allowed_devices = ["auto","cpu","gpu"] +allowed_containers = ["apptainer","docker","podman","podman-hpc"] -parser = getParser() +parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) +parser.add_argument("--maxEvents", default=-1, type=int, help="Number of events to process (-1 for all)") +parser.add_argument("--serverName", default="default", type=str, help="name for server (used internally)") +parser.add_argument("--address", default="", type=str, help="server address") +parser.add_argument("--port", default=8001, type=int, help="server port") +parser.add_argument("--timeout", default=30, type=int, help="timeout for requests") +parser.add_argument("--timeoutUnit", default="seconds", type=str, help="unit for timeout") +parser.add_argument("--params", default="", type=str, help="json file containing server address/port") +parser.add_argument("--threads", default=1, type=int, help="number of threads") +parser.add_argument("--streams", default=0, type=int, help="number of streams") parser.add_argument("--modules", metavar=("MODULES"), default=["TritonGraphProducer"], nargs='+', type=str, choices=list(models), help="list of modules to run (choices: %(choices)s)") parser.add_argument("--models", default=["gat_test"], nargs='+', type=str, help="list of models (same length as modules, or just 1 entry if all modules use same model)") parser.add_argument("--mode", default="Async", type=str, choices=allowed_modes, help="mode for client") +parser.add_argument("--verbose", default=False, action="store_true", help="enable all verbose output") +parser.add_argument("--verboseClient", default=False, action="store_true", help="enable verbose output for clients") +parser.add_argument("--verboseServer", default=False, action="store_true", help="enable verbose output for server") +parser.add_argument("--verboseService", default=False, action="store_true", help="enable verbose output for TritonService") +parser.add_argument("--verboseDiscovery", default=False, action="store_true", help="enable verbose output just for server discovery in TritonService") parser.add_argument("--brief", default=False, action="store_true", help="briefer output for graph modules") +parser.add_argument("--fallbackName", default="", type=str, help="name for fallback server") parser.add_argument("--unittest", default=False, action="store_true", help="unit test mode: reduce input sizes") parser.add_argument("--testother", default=False, action="store_true", help="also test gRPC communication if shared memory enabled, or vice versa") +parser.add_argument("--noShm", default=False, action="store_true", help="disable shared memory") +parser.add_argument("--compression", default="", type=str, choices=allowed_compression, help="enable I/O compression") +parser.add_argument("--ssl", default=False, action="store_true", help="enable SSL authentication for server communication") +parser.add_argument("--device", default="auto", type=str.lower, choices=allowed_devices, help="specify device for fallback server") +parser.add_argument("--container", default="apptainer", type=str.lower, choices=allowed_containers, help="specify container for fallback server") +parser.add_argument("--tries", default=0, type=int, help="number of retries for failed request") +parser.add_argument("--retryAction", default="same", type=str, choices=["same","diff"], help="retry policy: same server or different server") +options = parser.parse_args() -options = getOptions(parser, verbose=True) +if len(options.params)>0: + with open(options.params,'r') as pfile: + pdict = json.load(pfile) + options.address = pdict["address"] + options.port = int(pdict["port"]) + print("server = "+options.address+":"+str(options.port)) # check models and modules if len(options.modules)!=len(options.models): @@ -38,8 +69,30 @@ process = cms.Process('tritonTest',enableSonicTriton) process.load("HeterogeneousCore.SonicTriton.TritonService_cff") + +process.maxEvents = cms.untracked.PSet( input = cms.untracked.int32(options.maxEvents) ) + process.source = cms.Source("EmptySource") +process.TritonService.verbose = options.verbose or options.verboseService or options.verboseDiscovery +process.TritonService.fallback.verbose = options.verbose or options.verboseServer +process.TritonService.fallback.container = options.container +process.TritonService.fallback.device = options.device +if len(options.fallbackName)>0: + process.TritonService.fallback.instanceBaseName = options.fallbackName +if len(options.address)>0: + process.TritonService.servers.append( + cms.PSet( + name = cms.untracked.string(options.serverName), + address = cms.untracked.string(options.address), + port = cms.untracked.uint32(options.port), + useSsl = cms.untracked.bool(options.ssl), + rootCertificates = cms.untracked.string(""), + privateKey = cms.untracked.string(""), + certificateChain = cms.untracked.string(""), + ) + ) + # Let it run process.p = cms.Path() @@ -49,19 +102,45 @@ "Analyzer": cms.EDAnalyzer, } -defaultClient = applyClientOptions(getDefaultClientPSet().clone(), options) +keepMsgs = [] +if options.verbose or options.verboseDiscovery: + keepMsgs.append('TritonDiscovery') +if options.verbose or options.verboseClient: + keepMsgs.append('TritonClient') +if options.verbose or options.verboseService: + keepMsgs.append('TritonService') +if options.verbose: + # ensure RetryActionDiffServer messages are not suppressed if emitted + keepMsgs.append('RetryActionDiffServer') for im,module in enumerate(options.modules): model = options.models[im] Module = [obj for name,obj in modules.items() if name in module][0] setattr(process, module, Module(module, - Client = defaultClient.clone( + Client = cms.PSet( mode = cms.string(options.mode), preferredServer = cms.untracked.string(""), + timeout = cms.untracked.uint32(options.timeout), + timeoutUnit = cms.untracked.string(options.timeoutUnit), modelName = cms.string(model), modelVersion = cms.string(""), modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)), + verbose = cms.untracked.bool(options.verbose or options.verboseClient), + useSharedMemory = cms.untracked.bool(not options.noShm), + compression = cms.untracked.string(options.compression), + Retry = ( + cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(options.tries) + ) + ) if options.retryAction == 'same' else cms.VPSet( + cms.PSet( + retryType = cms.string('RetryActionDiffServer') + ) + ) + ) ) ) ) @@ -84,6 +163,10 @@ processModule.edgeMax = cms.uint32(15000) processModule.brief = cms.bool(options.brief) process.p += processModule + if options.verbose: + print("Retry type:", ('RetrySameServerAction' if options.retryAction == 'same' else 'RetryActionDiffServer')) + if options.verbose or options.verboseClient: + keepMsgs.extend([module,module+':TritonClient']) if options.testother: # clone modules to test both gRPC and shared memory _module2 = module+"GRPC" if processModule.Client.useSharedMemory else "SHM" @@ -94,5 +177,19 @@ ) processModule2 = getattr(process, _module2) process.p += processModule2 + if options.verbose or options.verboseClient: + keepMsgs.extend([_module2,_module2+':TritonClient']) + +process.load('FWCore/MessageService/MessageLogger_cfi') +process.MessageLogger.cerr.FwkReport.reportEvery = 500 +for msg in keepMsgs: + setattr(process.MessageLogger.cerr,msg, + cms.untracked.PSet( + limit = cms.untracked.int32(10000000), + ) + ) + +if options.threads>0: + process.options.numberOfThreads = options.threads + process.options.numberOfStreams = options.streams -process = applyOptions(process, options)