10 | 10 | from collections.abc import Sequence
11 | 11 | from concurrent.futures import ThreadPoolExecutor
12 | 12 | from dataclasses import asdict
| 13 | +from enum import Enum |
13 | 14 | from functools import partial
14 | 15 | from multiprocessing.reduction import ForkingPickler
15 | 16 | from queue import Queue
16 | | -from typing import Any, Dict, List, Optional |
| 17 | +from typing import Any, Dict, List, Optional, Union |
17 | 18 |
18 | 19 | import numpy as np
19 | 20 | import torch
@@ -105,6 +106,39 @@ def update_parallel_config(cfg: TurbomindEngineConfig):
105 | 106 | cfg.devices = cfg.devices or list(range(cfg.device_num))
106 | 107 |
107 | 108 |
| 109 | +# Borrowed from xgrammar's TokenizerInfo.VocabType |
| 110 | +class VocabType(Enum): |
| 111 | + """The type of the vocabulary. |
| 112 | + |
| 113 | + Used in TokenizerInfo. XGrammar supports three types of |
| 114 | + vocabularies: RAW, BYTE_FALLBACK, BYTE_LEVEL. |
| 115 | + """ |
| 116 | + |
| 117 | + RAW = 0 |
| 118 | + """The vocabulary is in the raw format. |
| 119 | + |
| 120 | + The tokens in the vocabulary are kept in their original form without any processing. This kind of tokenizer includes |
| 121 | + the tiktoken tokenizer, e.g. microsoft/Phi-3-small-8k-instruct, Qwen/Qwen-7B-Chat, etc. |
| 122 | + """ |
| 123 | + |
| 124 | + BYTE_FALLBACK = 1 |
| 125 | + """The vocabulary used in the byte fallback BPE tokenizer. |
| 126 | + |
| 127 | + The tokens are encoded through the byte-fallback conversion. E.g. "\u001b" -> "<0x1B>", " apple" -> "▁apple". This |
| 128 | + kind of tokenizer includes meta-llama/Llama-2-7b-chat, microsoft/Phi-3.5-mini-instruct, etc. |
| 129 | + """ |
| 130 | + |
| 131 | + BYTE_LEVEL = 2 |
| 132 | + """The vocabulary used in the byte level BPE tokenizer. |
| 133 | + |
| 134 | + The tokens are encoded through the byte-to-unicode conversion, as in |
| 135 | + https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59 |
| 136 | + |
| 137 | + This kind of tokenizer includes meta-llama/Meta-Llama-3-8B-Instruct, |
| 138 | + meta-llama/Meta-Llama-3.1-8B-Instruct, etc. |
| 139 | + """ |
| 140 | + |
| 141 | + |
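
For readers unfamiliar with the distinction the enum draws, here is an illustrative check (not part of the patch) of how these formats surface in a tokenizer's vocabulary. It assumes `transformers` is installed and the example checkpoints are accessible; the expected outputs in the comments are indicative only.

```python
from transformers import AutoTokenizer

# Byte-fallback vocab (e.g. Llama-2 style): raw bytes appear as "<0xNN>" entries
# and in-token spaces appear as the "▁" metasymbol.
bf_vocab = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf').get_vocab()
print('<0x0A>' in bf_vocab, '▁apple' in bf_vocab)   # expected: True True

# Byte-level vocab (e.g. GPT-2 / Llama-3 style): bytes are remapped to printable
# unicode, so a leading space shows up as "Ġ" and there are no "<0xNN>" tokens.
bl_vocab = AutoTokenizer.from_pretrained('gpt2').get_vocab()
print('<0x0A>' in bl_vocab, 'Ġapple' in bl_vocab)   # expected: False True

# A RAW vocab (e.g. a tiktoken-based tokenizer such as Qwen/Qwen-7B-Chat) keeps
# tokens in their original form without either convention.
```

The same `'<0x0A>' in vocab` membership test is what the sentencepiece branch of the new helper uses below to tell BYTE_FALLBACK from RAW.
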
108 | 142 | class TurboMind:
109 | 143 | """LMDeploy's inference engine.
110 | 144 |
@@ -163,18 +197,177 @@ def __init__(self,
163 | 197 | self.session_len = self.config.session_len
164 | 198 |
165 | 199 | if decode_grammar is not None:
166 | | - tokenizer_info = _xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=decode_grammar_vocab_size) |
| 200 | + tokenizer_info = self._get_xgrammar_tokenizer_info(tokenizer, vocab_size=decode_grammar_vocab_size) |
167 | 201 | compiler = _xgr.GrammarCompiler(tokenizer_info, max_threads=decode_grammar_threads)
168 | 202 |
169 | 203 | if decode_grammar_type == 'json_schema':
170 | 204 | grammar = compiler.compile_json_schema(decode_grammar, **decode_grammar_extra)
171 | 205 | elif decode_grammar_type == 'regex':
172 | | - grammar = compiler.from_regex(decode_grammar) |
| 206 | + grammar = compiler.compile_regex(decode_grammar) |
173 | 207 | else:
174 | 208 | assert False, f'Decode grammar type {decode_grammar_type} should be in ["json_schema", "regex"]'
175 | 209 |
176 | 210 | self.grammar = grammar
177 | 211 |
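
For reference, the `compile_regex` correction above matches xgrammar's public `GrammarCompiler` API. A minimal standalone sketch of the same two call paths follows; the model name, vocab size, schema, and regex are placeholders, and it assumes `xgrammar` and `transformers` are installed.

```python
import xgrammar as xgr
from transformers import AutoTokenizer

# Placeholder model; any HF model whose lm_head size matches vocab_size works.
tok = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct')
tok_info = xgr.TokenizerInfo.from_huggingface(tok, vocab_size=128256)
compiler = xgr.GrammarCompiler(tok_info, max_threads=4)

# Mirrors the 'json_schema' branch above.
schema = '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}'
json_grammar = compiler.compile_json_schema(schema)

# Mirrors the corrected 'regex' branch above.
date_grammar = compiler.compile_regex(r'[0-9]{4}-[0-9]{2}-[0-9]{2}')
```
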
| 212 | + # Borrowed from xgrammar's TokenizerInfo.from_huggingface |
| 213 | + def _get_xgrammar_tokenizer_info( |
| 214 | + self, |
| 215 | + tokenizer: 'PreTrainedTokenizerBase', # noqa: F821 |
| 216 | + *, |
| 217 | + vocab_size: Optional[int] = None, |
| 218 | + stop_token_ids: Optional[Union[List[int], int]] = None, |
| 219 | + ) -> 'TokenizerInfo': # noqa: F821 |
| 220 | + """Construct the tokenizer info from the huggingface tokenizer. This |
| 221 | + constructor supports various tokenizer backends, including the |
| 222 | + huggingface fast tokenizer and tiktoken tokenizer. Necessary |
| 223 | + information is automatically detected from the tokenizer. |
| 224 | +
|
| 225 | + The vocab_size parameter is introduced to handle the misalignment between the model's |
| 226 | + vocab_size and the tokenizer's vocabulary size. User should pass the model's vocab_size |
| 227 | + (could be defined in the model config) here. See docs of vocab_size for more details. |
| 228 | +
|
| 229 | + The stop token ids is by default the eos_token_id of the tokenizer. If there are other |
| 230 | + stop tokens, you can specify them manually. |
| 231 | +
|
| 232 | + Parameters |
| 233 | + ---------- |
| 234 | + tokenizer : PreTrainedTokenizerBase |
| 235 | + The huggingface tokenizer. |
| 236 | +
|
| 237 | + vocab_size : Optional[int], default: None |
| 238 | + The vocabulary size **defined by the model** (**not the tokenizer**). This equals to the |
| 239 | + vocab dimension of the model's lm_head. This is the size of the token mask. |
| 240 | +
|
| 241 | + It can be: |
| 242 | +
|
| 243 | + 1. the same as the tokenizer's vocabulary size. This is the most common case. |
| 244 | + 2. larger than the tokenizer's vocabulary size. This happens when the model has padding |
| 245 | + to lm_head, possibly due to aligning lm_head to the power of 2. |
| 246 | + E.g. Phi-3 and Deepseek-V2. |
| 247 | + 3. smaller than the tokenizer's vocabulary size. This happens when the tokenizer has |
| 248 | + some added tokens that will not supported by the model. E.g. |
| 249 | + Llama-3.2 Vision and Molmo-72B-0924 has padded `<|image|>` tokens, but they will not |
| 250 | + be considered in lm_head or generated by the model. |
| 251 | +
|
| 252 | + model_vocab_size need to be provided for case 2 and 3. If not provided, it will be |
| 253 | + set to the tokenizer's vocabulary size. |
| 254 | +
|
| 255 | + stop_token_ids : Optional[List[int]], default: None |
| 256 | + The stop token ids. If not provided, the eos_token_id of the tokenizer will be used. |
| 257 | +
|
| 258 | + Returns |
| 259 | + ------- |
| 260 | + tokenizer_info : TokenizerInfo |
| 261 | + The tokenizer info. |
| 262 | + """ |
| 263 | + from transformers import PreTrainedTokenizerFast |
| 264 | + |
| 265 | + if isinstance(stop_token_ids, int): |
| 266 | + stop_token_ids = [stop_token_ids] |
| 267 | + if isinstance(stop_token_ids, list) and len(stop_token_ids) == 0: |
| 268 | + raise ValueError('stop_token_ids cannot be empty') |
| 269 | + |
| 270 | + try: |
| 271 | + vocab_dict = tokenizer.get_vocab() |
| 272 | + except AttributeError as e: |
| 273 | + msg = (f'Cannot get the vocabulary of the tokenizer {type(tokenizer)}. The tokenizer ' |
| 274 | + 'should have a get_vocab method.') |
| 275 | + raise ValueError(msg) from e |
| 276 | + |
| 277 | + # Some tokenizers don't have token ids 0, 1, or 2, so max_id can be larger than the |
| 278 | + # number of tokens. |
| 279 | + max_id = max(vocab_dict.values()) |
| 280 | + tokenizer_vocab_size = max(len(vocab_dict), max_id + 1) |
| 281 | + |
| 282 | + vocab_size = vocab_size or tokenizer_vocab_size |
| 283 | + |
| 284 | + # maintain tokenizer's indexing |
| 285 | + encoded_vocab = [''] * vocab_size |
| 286 | + for token, idx in vocab_dict.items(): |
| 287 | + if idx < vocab_size: |
| 288 | + encoded_vocab[idx] = token |
| 289 | + |
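
The misalignment cases described in the docstring can be observed directly. Below is a hypothetical check (not part of the patch, assumes `transformers` is installed) that mirrors the `max_id`/`vocab_size` computation above for Phi-3, one of the padded-lm_head examples the docstring cites.

```python
from transformers import AutoConfig, AutoTokenizer

model_id = 'microsoft/Phi-3-mini-4k-instruct'  # example only
config = AutoConfig.from_pretrained(model_id)
vocab = AutoTokenizer.from_pretrained(model_id).get_vocab()

# Same computation as above: some ids may be unused, so take the max of both.
tokenizer_vocab_size = max(len(vocab), max(vocab.values()) + 1)

# For a padded lm_head (case 2), the model's vocab_size is the larger of the two,
# and it is the value that should be passed in as `vocab_size`.
print(config.vocab_size, tokenizer_vocab_size)
```
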
| 290 | + if isinstance(tokenizer, PreTrainedTokenizerFast): |
| 291 | + # huggingface fast tokenizer |
| 292 | + # - the vocabulary is directly obtained from tokenizer.get_vocab() |
| 293 | + # (tokenizer.backend_tokenizer.to_str() may not contain the full vocab, special |
| 294 | + # tokens may be omitted) |
| 295 | + # - the vocab size is obtained from len(tokenizer.get_vocab()) or provided by user |
| 296 | + # - the vocab type and add_prefix_space are obtained from |
| 297 | + # tokenizer.backend_tokenizer.to_str() |
| 298 | + # - stop token id is provided by user, or auto detected. |
| 299 | + backend_str = tokenizer.backend_tokenizer.to_str() |
| 300 | + if stop_token_ids is None: |
| 301 | + if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None: |
| 302 | + stop_token_ids = [tokenizer.eos_token_id] |
| 303 | + else: |
| 304 | + logger.warning('When constructing TokenizerInfo from a huggingface tokenizer, ' |
| 305 | + 'stop_token_ids is neither provided by user nor found from the tokenizer. ' |
| 306 | + 'It will be automatically detected.') |
| 307 | + metadata = json.loads(_xgr.TokenizerInfo._detect_metadata_from_hf(backend_str)) |
| 308 | + return _xgr.TokenizerInfo( |
| 309 | + encoded_vocab, |
| 310 | + vocab_type=metadata['vocab_type'], |
| 311 | + vocab_size=vocab_size, |
| 312 | + stop_token_ids=stop_token_ids, |
| 313 | + add_prefix_space=metadata['add_prefix_space'], |
| 314 | + ) |
| 315 | + |
| 316 | + elif _xgr.TokenizerInfo._is_tiktoken_tokenizer(tokenizer): |
| 317 | + # tiktoken tokenizer |
| 318 | + # e.g. Phi-3-small-8k-instruct, Qwen-7B-Chat, stablelm-2-12b-chat (previously) |
| 319 | + if stop_token_ids is None: |
| 320 | + if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None: |
| 321 | + stop_token_ids = [tokenizer.eos_token_id] |
| 322 | + else: |
| 323 | + logger.warning('When constructing TokenizerInfo from a huggingface tokenizer, ' |
| 324 | + 'stop_token_ids is neither provided by user nor found from the tokenizer. ' |
| 325 | + 'It will be automatically detected.') |
| 326 | + return _xgr.TokenizerInfo( |
| 327 | + encoded_vocab, |
| 328 | + VocabType.RAW, |
| 329 | + vocab_size=vocab_size, |
| 330 | + stop_token_ids=stop_token_ids, |
| 331 | + add_prefix_space=False, |
| 332 | + ) |
| 333 | + |
| 334 | + elif _xgr.TokenizerInfo._is_sentencepiece_tokenizer(tokenizer): |
| 335 | + # sentencepiece tokenizer |
| 336 | + # e.g. Chatglm3-6b |
| 337 | + if hasattr(tokenizer, 'sp_model'): |
| 338 | + sp_model = tokenizer.sp_model |
| 339 | + elif hasattr(tokenizer, 'tokenizer') and hasattr(tokenizer.tokenizer, 'sp_model'): |
| 340 | + sp_model = tokenizer.tokenizer.sp_model |
| 341 | + |
| 342 | + if stop_token_ids is None: |
| 343 | + if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None: |
| 344 | + stop_token_ids = [tokenizer.eos_token_id] |
| 345 | + else: |
| 346 | + eos_id = sp_model.eos_id() |
| 347 | + if eos_id != -1: |
| 348 | + stop_token_ids = [eos_id] |
| 349 | + else: |
| 350 | + logger.warning('When constructing TokenizerInfo from a huggingface tokenizer, ' |
| 351 | + 'stop_token_ids is neither provided by user nor found from the tokenizer. ' |
| 352 | + 'It will be automatically detected.') |
| 353 | + # detect vocab_type of tokenizer |
| 354 | + if '<0x0A>' in vocab_dict: |
| 355 | + vocab_type = VocabType.BYTE_FALLBACK |
| 356 | + else: |
| 357 | + vocab_type = VocabType.RAW |
| 358 | + |
| 359 | + return _xgr.TokenizerInfo( |
| 360 | + encoded_vocab, |
| 361 | + vocab_type=vocab_type, |
| 362 | + vocab_size=vocab_size, |
| 363 | + stop_token_ids=stop_token_ids, |
| 364 | + add_prefix_space=True, |
| 365 | + ) |
| 366 | + |
| 367 | + else: |
| 368 | + # TODO(yixin): unsupported tokenizer |
| 369 | + raise ValueError(f'Unsupported tokenizer type: {type(tokenizer)}') |
| 370 | + |
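
As a companion to the fast-tokenizer branch above, the serialized backend state it inspects can be examined directly. A small sketch (illustrative only; assumes `transformers` with any fast tokenizer available) of what `tokenizer.backend_tokenizer.to_str()` exposes for vocab-type detection:

```python
import json

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('gpt2')  # any fast tokenizer serves as an example
backend = json.loads(tok.backend_tokenizer.to_str())

# A byte-level BPE tokenizer advertises a ByteLevel pre-tokenizer in its serialized
# state, which is the kind of signal metadata detection keys off; a byte-fallback
# tokenizer instead carries a ByteFallback step in its decoder.
print(backend['model']['type'])          # e.g. 'BPE'
print(backend['pre_tokenizer']['type'])  # e.g. 'ByteLevel'
```
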
178 | 371 | def _check_unloaded_tm_params(self):
179 | 372 | tm_params = self._tm_model.tm_params
180 | 373 | if len(tm_params) > 0: