@@ -32,6 +32,7 @@
 lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
 sys.path.append(osp.join(lmdeploy_dir, 'lib'))
 import _turbomind as _tm  # noqa: E402
+import _xgrammar as _xgr  # noqa: E402

 logger = get_logger('lmdeploy')

@@ -125,6 +126,11 @@ def __init__(self,
                  model_name: str = None,
                  chat_template_name: str = None,
                  engine_config: TurbomindEngineConfig = None,
+                 decode_grammar: Optional[str] = None,
+                 decode_grammar_type: str = 'json_schema',
+                 decode_grammar_threads: int = 4,
+                 decode_grammar_vocab_size: Optional[int] = None,
+                 decode_grammar_extra: Dict[str, Any] = {},
                  **kwargs):
         self.model_name = model_name
         self.chat_template_name = chat_template_name
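
The five new `decode_grammar*` keyword arguments opt a TurboMind engine into grammar-constrained decoding at construction time: `decode_grammar` carries the JSON schema or regex source, `decode_grammar_type` selects how it is interpreted, and the thread, vocab-size, and extra-kwargs knobs tune compilation. A minimal usage sketch, assuming these kwargs are passed straight through to the `TurboMind` constructor (the model path and schema below are illustrative, not from this PR):

    # Hypothetical call site: constrain decoding to a JSON object with a "name" field.
    schema = '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}'
    engine = TurboMind('internlm/internlm2-chat-7b',
                       decode_grammar=schema,
                       decode_grammar_type='json_schema')
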
@@ -156,12 +162,25 @@ def __init__(self,

         self.session_len = self.config.session_len

+        if decode_grammar is not None:
+            tokenizer_info = _xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=decode_grammar_vocab_size)
+            compiler = _xgr.GrammarCompiler(tokenizer_info, max_threads=decode_grammar_threads)
+
+            if decode_grammar_type == 'json_schema':
+                grammar = compiler.compile_json_schema(decode_grammar, **decode_grammar_extra)
+            elif decode_grammar_type == 'regex':
+                grammar = compiler.compile_regex(decode_grammar)
+            else:
+                assert False, f'Decode grammar type {decode_grammar_type} should be in ["json_schema", "regex"]'
+
+            self.grammar = grammar
+
     def _check_unloaded_tm_params(self):
         tm_params = self._tm_model.tm_params
         if len(tm_params) > 0:
             uninitialized = list(tm_params.keys())
             logger.warning('the model may not be loaded successfully '
-                           f'with {len(tm_params)} uninitialized params:\n{uninitialized}')
+                           f'with {len(tm_params)} uninitialized params:\n{uninitialized}')  # noqa: E231

     def _load_weights(self):
         """Load weights."""
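
The compilation block added above follows xgrammar's documented flow: derive a `TokenizerInfo` from the Hugging Face tokenizer, hand it to a `GrammarCompiler`, and compile the schema or regex into a reusable compiled grammar. A self-contained sketch against the public `xgrammar` package (the in-tree `_xgrammar` binding is assumed to mirror this API; the model name is illustrative):

    import xgrammar as xgr
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('internlm/internlm2-chat-7b')
    info = xgr.TokenizerInfo.from_huggingface(tokenizer)  # vocab_size can be pinned explicitly
    compiler = xgr.GrammarCompiler(info, max_threads=4)   # compiles grammars in parallel
    compiled = compiler.compile_json_schema('{"type": "object"}')
    # For the 'regex' branch, the equivalent call is compiler.compile_regex(r'\d{4}-\d{2}-\d{2}').
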
@@ -252,7 +271,7 @@ def _postprocess_config(self, tm_config: TurbomindModelConfig, engine_config: Tu
         # pack `self.config` and `self.engine_config` into a dict
         self.config_dict = self.config.to_dict()
         self.config_dict.update(dict(engine_config=asdict(self.engine_config)))
-        logger.info(f'turbomind model config:\n\n'
+        logger.info(f'turbomind model config:\n\n'  # noqa: E231
                     f'{json.dumps(self.config_dict, indent=2)}')

     def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig):
@@ -550,6 +569,9 @@ def model_inst(self)

     def _create_model_instance(self, device_id):
         model_inst = self.tm_model.model_comm.create_model_instance(device_id)
+        if hasattr(self.tm_model, 'grammar'):
+            model_inst.set_grammar(self.tm_model.grammar)
+
         return model_inst

     def _get_extra_output_processors(self, outputs: Dict[str, torch.Tensor], gen_config: GenerationConfig,
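
Since `_create_model_instance` runs once per device, the hook above attaches the single compiled grammar to every device's model instance, and the `hasattr` guard keeps the path a no-op when no grammar was configured. The wiring pattern in isolation (a sketch; `tm_model` and the device count `device_num` stand in for the engine's real fields):

    # One shared compiled grammar, set on each per-device instance.
    for device_id in range(device_num):
        inst = tm_model.model_comm.create_model_instance(device_id)
        if hasattr(tm_model, 'grammar'):
            inst.set_grammar(tm_model.grammar)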