Commit f93c1b1

feat: enable grammar init in turbomind

1 parent 5b846a6

6 files changed: +67 −33 lines

lmdeploy/turbomind/turbomind.py

Lines changed: 24 additions & 2 deletions

@@ -32,6 +32,7 @@
 lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
 sys.path.append(osp.join(lmdeploy_dir, 'lib'))
 import _turbomind as _tm  # noqa: E402
+import _xgrammar as _xgr  # noqa: E402

 logger = get_logger('lmdeploy')

@@ -125,6 +126,11 @@ def __init__(self,
                  model_name: str = None,
                  chat_template_name: str = None,
                  engine_config: TurbomindEngineConfig = None,
+                 decode_grammar: Optional[str] = None,
+                 decode_grammar_type: str = 'json_schema',
+                 decode_grammar_threads: int = 4,
+                 decode_grammar_vocab_size: Optional[int] = None,
+                 decode_grammar_extra: Dict[str, Any] = {},
                  **kwargs):
         self.model_name = model_name
         self.chat_template_name = chat_template_name
@@ -156,12 +162,25 @@ def __init__(self,

         self.session_len = self.config.session_len

+        if decode_grammar is not None:
+            tokenizer_info = _xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=decode_grammar_vocab_size)
+            compiler = _xgr.GrammarCompiler(tokenizer_info, max_threads=decode_grammar_threads)
+
+            if decode_grammar_type == 'json_schema':
+                grammar = compiler.compile_json_schema(decode_grammar, **decode_grammar_extra)
+            elif decode_grammar_type == 'regex':
+                grammar = compiler.compile_regex(decode_grammar)
+            else:
+                assert False, f'Decode grammar type {decode_grammar_type} should be in ["json_schema", "regex"]'
+
+            self.grammar = grammar
+
     def _check_unloaded_tm_params(self):
         tm_params = self._tm_model.tm_params
         if len(tm_params) > 0:
             uninitialized = list(tm_params.keys())
             logger.warning('the model may not be loaded successfully '
-                           f'with {len(tm_params)} uninitialized params:\n{uninitialized}')
+                           f'with {len(tm_params)} uninitialized params:\n{uninitialized}')  # noqa: E231

@@ -252,7 +271,7 @@ def _postprocess_config(self, tm_config: TurbomindModelConfig, engine_config: TurbomindEngineConfig)
         # pack `self.config` and `self.engine_config` into a dict
         self.config_dict = self.config.to_dict()
         self.config_dict.update(dict(engine_config=asdict(self.engine_config)))
-        logger.info(f'turbomind model config:\n\n'
+        logger.info(f'turbomind model config:\n\n'  # noqa: E231
                     f'{json.dumps(self.config_dict, indent=2)}')

     def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig):
@@ -550,6 +569,9 @@ def model_inst(self)

     def _create_model_instance(self, device_id):
         model_inst = self.tm_model.model_comm.create_model_instance(device_id)
+        if hasattr(self.tm_model, 'grammar'):
+            model_inst.set_grammar(self.tm_model.grammar)
+
         return model_inst

     def _get_extra_output_processors(self, outputs: Dict[str, torch.Tensor], gen_config: GenerationConfig,

src/turbomind/engine/model_request.cc

Lines changed: 7 additions & 2 deletions

@@ -127,8 +127,8 @@ auto ModelRequest::Forward(InputParam param, std::function<void()> cb) -> OutputParam
     r->output_ids      = outputs_->at("output_ids");
     r->sequence_length = outputs_->at("sequence_length");

-    if (compiled_grammar_) {
-        r->matcher = std::make_shared<xgrammar::GrammarMatcher>(*compiled_grammar_);
+    if (grammar_) {
+        r->matcher = std::make_shared<xgrammar::GrammarMatcher>(*grammar_);
     }

     // Keep a weak reference for canceling the request
@@ -139,4 +139,9 @@
     return OutputParam{outputs_, state, metrics};
 }

+void ModelRequest::setGrammar(std::shared_ptr<xgrammar::CompiledGrammar> grammar)
+{
+    grammar_ = grammar;
+}
+
 }  // namespace turbomind
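Note that the matcher is built per request from the shared CompiledGrammar: compilation is expensive and reusable, while matcher state advances token by token and cannot be shared. What that per-request matcher does during decoding is sketched below against the upstream xgrammar Python API (the helper function is illustrative, not part of this commit):

    import torch
    import xgrammar as xgr  # upstream pip package mirroring the C++ API used here

    def constrained_decode_step(matcher: xgr.GrammarMatcher, logits: torch.Tensor,
                                vocab_size: int) -> int:
        # Zero out every token the grammar cannot accept at this position.
        bitmask = xgr.allocate_token_bitmask(1, vocab_size)
        matcher.fill_next_token_bitmask(bitmask)
        xgr.apply_token_bitmask_inplace(logits, bitmask.to(logits.device))
        # Sample (greedily here) and advance the matcher state.
        token_id = int(torch.argmax(logits, dim=-1))
        matcher.accept_token(token_id)
        return token_id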

src/turbomind/engine/model_request.h

Lines changed: 2 additions & 1 deletion

@@ -40,6 +40,7 @@ class ModelRequest {
     };

     OutputParam Forward(InputParam param, std::function<void()> cb);
+    void setGrammar(std::shared_ptr<xgrammar::CompiledGrammar> grammar);

 protected:
     Gateway* const gateway_;
@@ -56,7 +57,7 @@

     std::shared_ptr<TensorMap> inputs_;
     std::shared_ptr<TensorMap> outputs_;
-    std::shared_ptr<xgrammar::CompiledGrammar> compiled_grammar_;
+    std::shared_ptr<xgrammar::CompiledGrammar> grammar_;
 };

 }  // namespace turbomind

src/turbomind/python/bind.cpp

Lines changed: 20 additions & 27 deletions

@@ -588,6 +588,12 @@ PYBIND11_MODULE(_turbomind, m)
              py::call_guard<py::gil_scoped_release>(),
              "device_id"_a,
              "tags"_a)
+        .def("set_grammar",
+             &LlamaTritonModel::setGrammar,
+             py::call_guard<py::gil_scoped_release>(),
+             "grammar"_a)
         .def("__str__", &LlamaTritonModel::toString)
         .def("__repr__", &LlamaTritonModel::toString)
         .def("get_tensor_para_size", &LlamaTritonModel::getTensorParaSize)
@@ -697,31 +703,18 @@ PYBIND11_MODULE(_xgrammar, m)
             return TokenizerInfo::DeserializeJSON(str, CommonEncodedVocabType(encoded_vocab));
         });

-    py::class_<Grammar> pyGrammar(m, "Grammar");
-    pyGrammar
-        .def("to_string", &Grammar::ToString)
-
-        .def_static("from_ebnf", &Grammar::FromEBNF)
-
-        .def_static("from_json_schema",
-                    &Grammar::FromJSONSchema,
-                    py::arg("schema"),
-                    py::arg("any_whitespace"),
-                    py::arg("indent")     = py::none(),
-                    py::arg("separators") = py::none(),
-                    py::arg("strict_mode"),
-                    py::arg("print_converted_ebnf"),
-                    py::call_guard<py::gil_scoped_release>())
-
-        .def_static("from_regex", &Grammar::FromRegex, py::call_guard<py::gil_scoped_release>())
-
-        .def_static("builtin_json_grammar", &Grammar::BuiltinJSONGrammar)
-
-        .def_static("union", &Grammar::Union, py::call_guard<py::gil_scoped_release>())
-
-        .def_static("concat", &Grammar::Concat, py::call_guard<py::gil_scoped_release>())
-
-        .def("serialize_json", &Grammar::SerializeJSON)
-
-        .def_static("deserialize_json", &Grammar::DeserializeJSON);
+    py::class_<GrammarCompiler> pyGrammarCompiler(m, "GrammarCompiler");
+    pyGrammarCompiler
+        // Named ctor args so Python can pass max_threads= as a keyword;
+        // defaults assumed to match xgrammar's GrammarCompiler.
+        .def(py::init<const TokenizerInfo&, int, bool, int64_t>(),
+             py::arg("tokenizer_info"),
+             py::arg("max_threads")      = 8,
+             py::arg("cache_enabled")    = true,
+             py::arg("max_memory_bytes") = -1)
+        .def("compile_json_schema",
+             &GrammarCompiler::CompileJSONSchema,
+             py::call_guard<py::gil_scoped_release>(),
+             py::arg("schema"),
+             py::arg("any_whitespace") = false,
+             py::arg("indent")         = py::none(),
+             py::arg("separators")     = py::none(),
+             py::arg("strict_mode")    = true)
+        .def("compile_regex",
+             &GrammarCompiler::CompileRegex,
+             py::call_guard<py::gil_scoped_release>(),
+             py::arg("regex"));
 }
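With the standalone Grammar class dropped, the _xgrammar module's surface shrinks to what the engine actually calls: build a TokenizerInfo, construct a GrammarCompiler, compile a schema or a regex. A quick sketch of the resulting API (tokenizer id illustrative; constructor keyword defaults as assumed above):

    import _xgrammar as _xgr
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('internlm/internlm2_5-7b-chat')  # illustrative
    info = _xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=None)
    compiler = _xgr.GrammarCompiler(info, max_threads=4)

    compiled_schema = compiler.compile_json_schema('{"type": "string"}')
    compiled_regex = compiler.compile_regex(r'[0-9]{4}-[0-9]{2}-[0-9]{2}')  # e.g. ISO dates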

src/turbomind/triton_backend/llama/LlamaTritonModel.cc

Lines changed: 9 additions & 1 deletion

@@ -454,8 +454,12 @@ std::unique_ptr<ModelRequest> LlamaTritonModel::createModelInstance(int device_id)
 {
     FT_CHECK(engines_[device_id] != nullptr);

-    return std::make_unique<ModelRequest>(
+    auto model_inst = std::make_unique<ModelRequest>(
         gateway_.get(), dtype_, engine_param_.session_len, model_param_.vocab_size, model_param_.hidden_units);
+    if (grammar_) {
+        model_inst->setGrammar(grammar_);
+    }
+    return model_inst;
 }

 void LlamaTritonModel::createSharedWeights(int device_id, int rank)
@@ -666,4 +670,8 @@ int LlamaTritonModel::getPipelineParaSize()
     return 1;
 }

+void LlamaTritonModel::setGrammar(const xgrammar::CompiledGrammar& grammar)
+{
+    grammar_ = std::make_shared<xgrammar::CompiledGrammar>(grammar);
+}
+
 }  // namespace turbomind
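setGrammar copies the CompiledGrammar into a shared_ptr, so one compiled grammar is shared across all instances and requests while each request still builds its own matcher. Condensed from the pieces above, the hand-off this commit wires up (Python-side names taken from turbomind.py):

    # One grammar, compiled once (TurboMind.__init__):
    compiled = compiler.compile_json_schema(schema)
    # Attached when an instance is created (_create_model_instance):
    model_inst = model_comm.create_model_instance(device_id)
    model_inst.set_grammar(compiled)
    # Per request, ModelRequest::Forward() then builds a fresh
    # xgrammar::GrammarMatcher from the shared grammar.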

src/turbomind/triton_backend/llama/LlamaTritonModel.h

Lines changed: 5 additions & 0 deletions

@@ -24,6 +24,8 @@
 #include <string>
 #include <unordered_map>

+#include <xgrammar/xgrammar.h>
+
 #include "src/turbomind/comm/device_comm.h"

 #include "src/turbomind/engine/gateway.h"
@@ -59,6 +61,8 @@ class LlamaTritonModel {

     void wakeup(int device_id, const std::vector<std::string>& tags);

+    void setGrammar(const xgrammar::CompiledGrammar& grammar);
+
     std::string toString();

     int getTensorParaSize();
@@ -96,6 +100,7 @@ class LlamaTritonModel {

     std::string model_name_;
     std::string model_dir_;
+    std::shared_ptr<xgrammar::CompiledGrammar> grammar_;
 };

 }  // namespace turbomind
