@@ -262,6 +262,16 @@ int ModelSession::GetServingSessionId() {
262262}
263263
264264Status ModelSession::Predict (Request& req, Response& resp) {
265+ return InternalPredict (req, resp, GetServingSessionId ());
266+ }
267+
268+ Status ModelSession::Predict (Request& req, Response& resp,
269+ int sess_id) {
270+ return InternalPredict (req, resp, sess_id);
271+ }
272+
273+ Status ModelSession::InternalPredict (Request& req, Response& resp,
274+ int sess_id) {
265275 if (is_local_) {
266276 return Status (error::Code::INTERNAL,
267277 " Local sparse storage, please use LocalPredict." );
@@ -278,17 +288,31 @@ Status ModelSession::Predict(Request& req, Response& resp) {
278288 // TODO: which session selected to run on, add some policy here
279289 status = session_group_->Run (run_options, req.inputs ,
280290 req.output_tensor_names , {}, &resp.outputs ,
281- &run_metadata, GetServingSessionId () );
291+ &run_metadata, sess_id );
282292 Tracer::GetTracer ()->GenTimeline (run_metadata);
283293 } else {
284294 status = session_group_->Run (req.inputs , req.output_tensor_names ,
285- {}, &resp.outputs , GetServingSessionId () );
295+ {}, &resp.outputs , sess_id );
286296 }
287297 --counter_;
288298 return status;
289299}
290300
291- Status ModelSession::LocalPredict (Request& req, Response& resp) {
301+ Status ModelSession::LocalPredict (Request& req,
302+ Response& resp) {
303+ return InternalLocalPredict (req, resp,
304+ GetServingSessionId ());
305+ }
306+
307+ Status ModelSession::LocalPredict (Request& req,
308+ Response& resp,
309+ int sess_id) {
310+ return InternalLocalPredict (req, resp, sess_id);
311+ }
312+
313+ Status ModelSession::InternalLocalPredict (Request& req,
314+ Response& resp,
315+ int sess_id) {
292316 if (!is_local_) {
293317 return Status (error::Code::INTERNAL,
294318 " Remote sparse storage, please use Predict." );
@@ -302,16 +326,31 @@ Status ModelSession::LocalPredict(Request& req, Response& resp) {
302326 // TODO: which session selected to run on, add some policy here
303327 status = session_group_->Run (run_options, req.inputs ,
304328 req.output_tensor_names , {}, &resp.outputs ,
305- &run_metadata, GetServingSessionId () );
329+ &run_metadata, sess_id );
306330 Tracer::GetTracer ()->GenTimeline (run_metadata);
307331 } else {
308332 status = session_group_->Run (req.inputs , req.output_tensor_names ,
309- {}, &resp.outputs , GetServingSessionId () );
333+ {}, &resp.outputs , sess_id );
310334 }
311335 --counter_;
312336 return status;
313337}
314338
339+ Status ModelSession::Warmup (Request& req, Response& resp, bool local) {
340+ int N = session_group_->GetSessionNum ();
341+ for (int i = 0 ; i < N; ++i) {
342+ Status s;
343+ if (local) {
344+ s = LocalPredict (req, resp, i);
345+ } else {
346+ s = Predict (req, resp, i);
347+ }
348+ if (!s.ok ()) return s;
349+ }
350+
351+ return Status::OK ();
352+ }
353+
315354Status ModelSessionMgr::Predict (Request& req, Response& resp) {
316355 return serving_session_->Predict (req, resp);
317356}
@@ -320,6 +359,10 @@ Status ModelSessionMgr::LocalPredict(Request& req, Response& resp) {
320359 return serving_session_->LocalPredict (req, resp);
321360}
322361
362+ Status ModelSessionMgr::Warmup (Request& req, Response& resp, bool local) {
363+ return serving_session_->Warmup (req, resp, local);
364+ }
365+
323366Status ModelSessionMgr::CreateModelSession (
324367 const Version& version, const char * ckpt_name,
325368 IFeatureStoreMgr* sparse_storage, bool is_incr_ckpt,
0 commit comments