Commit 8fd6d05

fix: fix some bug and add initial tests
1 parent f93c1b1 commit 8fd6d05

10 files changed: +441 additions, -149 deletions

CMakeLists.txt (2 additions, 0 deletions)

@@ -292,7 +292,9 @@ add_subdirectory(src)
 if (BUILD_PY_FFI)
   if (CALL_FROM_SETUP_PY)
     install(TARGETS _turbomind DESTINATION ${CMAKE_INSTALL_PREFIX})
+    install(TARGETS _xgrammar DESTINATION ${CMAKE_INSTALL_PREFIX})
   else()
     install(TARGETS _turbomind DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/lib)
+    install(TARGETS _xgrammar DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/lib)
   endif()
 endif ()
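
As a quick post-build sanity check, one might verify that the new extension ends up importable next to _turbomind. This is a sketch under assumptions: it presumes the in-tree install destination from the else() branch above and that lmdeploy/lib is an importable package exposing both extensions; the module paths are not stated in the commit itself.

    # Hypothetical check; module paths are assumptions based on the
    # install destination ${CMAKE_SOURCE_DIR}/lmdeploy/lib above.
    import importlib

    for name in ('lmdeploy.lib._turbomind', 'lmdeploy.lib._xgrammar'):
        mod = importlib.import_module(name)
        print(name, '->', mod.__file__)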

lmdeploy/turbomind/turbomind.py (196 additions, 3 deletions)

@@ -10,10 +10,11 @@
 from collections.abc import Sequence
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict
+from enum import Enum
 from functools import partial
 from multiprocessing.reduction import ForkingPickler
 from queue import Queue
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 import numpy as np
 import torch

@@ -105,6 +106,39 @@ def update_parallel_config(cfg: TurbomindEngineConfig):
     cfg.devices = cfg.devices or list(range(cfg.device_num))


+# Borrowed from xgrammar's TokenizerInfo.VocabType
+class VocabType(Enum):
+    """The type of the vocabulary.
+
+    Used in TokenizerInfo. XGrammar supports three types of
+    vocabularies: RAW, BYTE_FALLBACK, BYTE_LEVEL.
+    """
+
+    RAW = 0
+    """The vocabulary is in the raw format.
+
+    The tokens in the vocabulary are kept in their original form without any processing. This kind of tokenizer includes
+    the tiktoken tokenizer, e.g. microsoft/Phi-3-small-8k-instruct, Qwen/Qwen-7B-Chat, etc.
+    """
+
+    BYTE_FALLBACK = 1
+    """The vocabulary used in the byte fallback BPE tokenizer.
+
+    The tokens are encoded through the byte-fallback conversion. E.g. "\u001b" -> "<0x1B>", " apple" -> "▁apple". This
+    kind of tokenizer includes meta-llama/Llama-2-7b-chat, microsoft/Phi-3.5-mini-instruct, etc.
+    """
+
+    BYTE_LEVEL = 2
+    """The vocabulary used in the byte level BPE tokenizer.
+
+    The tokens are encoded through the byte-to-unicode conversion, as in
+    https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59
+
+    This kind of tokenizer includes meta-llama/Meta-Llama-3-8B-Instruct,
+    meta-llama/Meta-Llama-3.1-8B-Instruct, etc.
+    """
+
+
 class TurboMind:
     """LMDeploy's inference engine.

@@ -163,18 +197,177 @@ def __init__(self,
         self.session_len = self.config.session_len

         if decode_grammar is not None:
-            tokenizer_info = _xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=decode_grammar_vocab_size)
+            tokenizer_info = self._get_xgrammar_tokenizer_info(tokenizer, vocab_size=decode_grammar_vocab_size)
             compiler = _xgr.GrammarCompiler(tokenizer_info, max_threads=decode_grammar_threads)

             if decode_grammar_type == 'json_schema':
                 grammar = compiler.compile_json_schema(decode_grammar, **decode_grammar_extra)
             elif decode_grammar_type == 'regex':
-                grammar = compiler.from_regex(decode_grammar)
+                grammar = compiler.compile_regex(decode_grammar)
             else:
                 assert False, f'Decode grammar type {decode_grammar_type} should be in ["json_schema", "regex"]'

             self.grammar = grammar

+    # Borrowed from xgrammar's TokenizerInfo.from_huggingface
+    def _get_xgrammar_tokenizer_info(
+        self,
+        tokenizer: 'PreTrainedTokenizerBase',  # noqa: F821
+        *,
+        vocab_size: Optional[int] = None,
+        stop_token_ids: Optional[Union[List[int], int]] = None,
+    ) -> 'TokenizerInfo':  # noqa: F821
+        """Construct the tokenizer info from the huggingface tokenizer. This
+        constructor supports various tokenizer backends, including the
+        huggingface fast tokenizer and tiktoken tokenizer. Necessary
+        information is automatically detected from the tokenizer.
+
+        The vocab_size parameter is introduced to handle the misalignment between the model's
+        vocab_size and the tokenizer's vocabulary size. User should pass the model's vocab_size
+        (could be defined in the model config) here. See docs of vocab_size for more details.
+
+        The stop token ids is by default the eos_token_id of the tokenizer. If there are other
+        stop tokens, you can specify them manually.
+
+        Parameters
+        ----------
+        tokenizer : PreTrainedTokenizerBase
+            The huggingface tokenizer.
+
+        vocab_size : Optional[int], default: None
+            The vocabulary size **defined by the model** (**not the tokenizer**). This equals to the
+            vocab dimension of the model's lm_head. This is the size of the token mask.
+
+            It can be:
+
+            1. the same as the tokenizer's vocabulary size. This is the most common case.
+            2. larger than the tokenizer's vocabulary size. This happens when the model has padding
+               to lm_head, possibly due to aligning lm_head to the power of 2.
+               E.g. Phi-3 and Deepseek-V2.
+            3. smaller than the tokenizer's vocabulary size. This happens when the tokenizer has
+               some added tokens that will not supported by the model. E.g.
+               Llama-3.2 Vision and Molmo-72B-0924 has padded `<|image|>` tokens, but they will not
+               be considered in lm_head or generated by the model.
+
+            model_vocab_size need to be provided for case 2 and 3. If not provided, it will be
+            set to the tokenizer's vocabulary size.
+
+        stop_token_ids : Optional[List[int]], default: None
+            The stop token ids. If not provided, the eos_token_id of the tokenizer will be used.
+
+        Returns
+        -------
+        tokenizer_info : TokenizerInfo
+            The tokenizer info.
+        """
+        from transformers import PreTrainedTokenizerFast
+
+        if isinstance(stop_token_ids, int):
+            stop_token_ids = [stop_token_ids]
+        if isinstance(stop_token_ids, list) and len(stop_token_ids) == 0:
+            raise ValueError('stop_token_ids cannot be empty')
+
+        try:
+            vocab_dict = tokenizer.get_vocab()
+        except AttributeError as e:
+            msg = (f'Cannot get the vocabulary of the tokenizer {type(tokenizer)}. The tokenizer '
+                   'should have a get_vocab method.')
+            raise ValueError(msg) from e
+
+        # Some tokenizer don't have token id 0 or 1 or 2. So the max_id could be larger than the
+        # number of tokens.
+        max_id = max(vocab_dict.values())
+        tokenizer_vocab_size = max(len(vocab_dict), max_id + 1)
+
+        vocab_size = vocab_size or tokenizer_vocab_size
+
+        # maintain tokenizer's indexing
+        encoded_vocab = [''] * vocab_size
+        for token, idx in vocab_dict.items():
+            if idx < vocab_size:
+                encoded_vocab[idx] = token
+
+        if isinstance(tokenizer, PreTrainedTokenizerFast):
+            # huggingface fast tokenizer
+            # - the vocabulary is directly obtained from tokenizer.get_vocab()
+            #   (tokenizer.backend_tokenizer.to_str() may not contain the full vocab, special
+            #   tokens may be omitted)
+            # - the vocab size is obtained from len(tokenizer.get_vocab()) or provided by user
+            # - the vocab type and add_prefix_space are obtained from
+            #   tokenizer.backend_tokenizer.to_str()
+            # - stop token id is provided by user, or auto detected.
+            backend_str = tokenizer.backend_tokenizer.to_str()
+            if stop_token_ids is None:
+                if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None:
+                    stop_token_ids = [tokenizer.eos_token_id]
+                else:
+                    logger.warning('When constructing TokenizerInfo from a huggingface tokenizer, '
+                                   'stop_token_ids is neither provided by user nor found from the tokenizer. '
+                                   'It will be automatically detected.')
+            metadata = json.loads(_xgr.TokenizerInfo._detect_metadata_from_hf(backend_str))
+            return _xgr.TokenizerInfo(
+                encoded_vocab,
+                vocab_type=metadata['vocab_type'],
+                vocab_size=vocab_size,
+                stop_token_ids=stop_token_ids,
+                add_prefix_space=metadata['add_prefix_space'],
+            )
+
+        elif _xgr.TokenizerInfo._is_tiktoken_tokenizer(tokenizer):
+            # tiktoken tokenizer
+            # e.g. Phi-3-small-8k-instruct, Qwen-7B-Chat, stablelm-2-12b-chat (previously)
+            if stop_token_ids is None:
+                if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None:
+                    stop_token_ids = [tokenizer.eos_token_id]
+                else:
+                    logger.warning('When constructing TokenizerInfo from a huggingface tokenizer, '
+                                   'stop_token_ids is neither provided by user nor found from the tokenizer. '
+                                   'It will be automatically detected.')
+            return _xgr.TokenizerInfo(
+                encoded_vocab,
+                VocabType.RAW,
+                vocab_size=vocab_size,
+                stop_token_ids=stop_token_ids,
+                add_prefix_space=False,
+            )
+
+        elif _xgr.TokenizerInfo._is_sentencepiece_tokenizer(tokenizer):
+            # sentencepiece tokenizer
+            # e.g. Chatglm3-6b
+            if hasattr(tokenizer, 'sp_model'):
+                sp_model = tokenizer.sp_model
+            elif hasattr(tokenizer, 'tokenizer') and hasattr(tokenizer.tokenizer, 'sp_model'):
+                sp_model = tokenizer.tokenizer.sp_model
+
+            if stop_token_ids is None:
+                if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None:
+                    stop_token_ids = [tokenizer.eos_token_id]
+                else:
+                    eos_id = sp_model.eos_id()
+                    if eos_id != -1:
+                        stop_token_ids = [eos_id]
+                    else:
+                        logger.warning('When constructing TokenizerInfo from a huggingface tokenizer, '
+                                       'stop_token_ids is neither provided by user nor found from the tokenizer. '
+                                       'It will be automatically detected.')
+            # detect vocab_type of tokenizer
+            if '<0x0A>' in vocab_dict:
+                vocab_type = VocabType.BYTE_FALLBACK
+            else:
+                vocab_type = VocabType.RAW
+
+            return _xgr.TokenizerInfo(
+                encoded_vocab,
+                vocab_type=vocab_type,
+                vocab_size=vocab_size,
+                stop_token_ids=stop_token_ids,
+                add_prefix_space=True,
+            )
+
+        else:
+            # TODO(yixin): unsupported tokenizer
+            raise ValueError(f'Unsupported tokenizer type: {type(tokenizer)}')
+
     def _check_unloaded_tm_params(self):
         tm_params = self._tm_model.tm_params
         if len(tm_params) > 0:
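
For context, a minimal usage sketch of the grammar-compilation path this hunk fixes (the regex branch previously called the non-existent compiler.from_regex). It is written against the upstream xgrammar Python package rather than the vendored _xgrammar binding; the tokenizer name, vocab size, thread count, and patterns are illustrative values, not taken from the commit.

    import xgrammar as xgr
    from transformers import AutoTokenizer

    # Assumed setup: a HuggingFace fast tokenizer; vocab_size should be the
    # model's lm_head size, which may differ from the tokenizer's vocabulary
    # size (see the vocab_size docstring in the diff above).
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=50257)
    compiler = xgr.GrammarCompiler(info, max_threads=4)

    # 'json_schema' branch
    json_grammar = compiler.compile_json_schema('{"type": "object"}')
    # 'regex' branch: the method is compile_regex, not from_regex
    regex_grammar = compiler.compile_regex(r'\d{4}-\d{2}-\d{2}')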

src/turbomind/engine/model_request.cc (2 additions, 2 deletions)

@@ -139,9 +139,9 @@ auto ModelRequest::Forward(InputParam param, std::function<void()> cb) -> Output
     return OutputParam{outputs_, state, metrics};
 }

-void ModelRequest::setGrammar(std::shared_ptr<xgrammar::CompiledGrammar> grammar)
+void ModelRequest::setGrammar(const xgrammar::CompiledGrammar& grammar)
 {
-    grammar_ = grammar;
+    grammar_ = std::make_shared<xgrammar::CompiledGrammar>(grammar);
 }

 }  // namespace turbomind

src/turbomind/engine/model_request.h (3 additions, 3 deletions)

@@ -40,7 +40,7 @@ class ModelRequest {
     };

     OutputParam Forward(InputParam param, std::function<void()> cb);
-    void setGrammar(std::shared_ptr<xgrammar::CompiledGrammar> grammar);
+    void setGrammar(const xgrammar::CompiledGrammar& grammar);

 protected:
     Gateway* const gateway_;

@@ -55,8 +55,8 @@ class ModelRequest {

     std::weak_ptr<Request> request_;

-    std::shared_ptr<TensorMap> inputs_;
-    std::shared_ptr<TensorMap> outputs_;
+    std::shared_ptr<TensorMap>                 inputs_;
+    std::shared_ptr<TensorMap>                 outputs_;
     std::shared_ptr<xgrammar::CompiledGrammar> grammar_;
 };

src/turbomind/python/CMakeLists.txt (5 additions, 1 deletion)

@@ -13,9 +13,13 @@ if(NOT pybind11_FOUND)
 endif()

 pybind11_add_module(${PROJECT_NAME} bind.cpp)
-target_link_libraries(${PROJECT_NAME} PRIVATE LlamaTritonBackend xgrammar)
+target_link_libraries(${PROJECT_NAME} PRIVATE LlamaTritonBackend)
 target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_14)

+pybind11_add_module(_xgrammar xgrammar_bind.cpp)
+target_link_libraries(_xgrammar PRIVATE core xgrammar)
+target_compile_features(_xgrammar PRIVATE cxx_std_14)
+
 if (CALL_FROM_SETUP_PY)
   set(_INSTALL_CUDA_RPATH
     "\$ORIGIN"

0 commit comments
