1818
import json
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union

import paddle
import paddle.distributed as dist
from paddleformers.transformers.configuration_utils import PretrainedConfig
from typing_extensions import assert_never

import fastdeploy
from fastdeploy import envs
from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
from fastdeploy.multimodal.registry import MultimodalRegistry
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.transformer_utils.config import get_pooling_config
from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger
3538
# Module-level logger; messages go to "config.log" via the project helper.
logger = get_logger("config", "config.log")
3740
38- TaskOption = Literal ["generate" ]
41+ TaskOption = Literal ["auto" , "generate" , "embedding" , "embed" ]
42+
43+ RunnerType = Literal ["generate" , "pooling" ]
44+
45+ RunnerOption = Literal ["auto" , "generate" , "pooling" ]
46+
47+ ConvertOption = Literal ["auto" , "none" , "embed" ]
48+
49+ ConvertType = Literal ["none" , "embed" ]
50+
51+ _ResolvedTask = Literal ["generate" , "encode" , "embed" ]
52+
53+ _RUNNER_CONVERTS : dict [RunnerType , list [ConvertType ]] = {
54+ "generate" : [],
55+ "pooling" : ["embed" ],
56+ }
57+
58+ # Some model suffixes are based on auto classes from Transformers:
59+ # https://huggingface.co/docs/transformers/en/model_doc/auto
60+ # NOTE: Items higher on this list priority over lower ones
61+ _SUFFIX_TO_DEFAULTS : list [tuple [str , tuple [RunnerType , ConvertType ]]] = [
62+ ("ForCausalLM" , ("generate" , "none" )),
63+ ("ForConditionalGeneration" , ("generate" , "none" )),
64+ ("ChatModel" , ("generate" , "none" )),
65+ ("LMHeadModel" , ("generate" , "none" )),
66+ ("ForTextEncoding" , ("pooling" , "embed" )),
67+ ("EmbeddingModel" , ("pooling" , "embed" )),
68+ ("ForSequenceClassification" , ("pooling" , "classify" )),
69+ ("ForAudioClassification" , ("pooling" , "classify" )),
70+ ("ForImageClassification" , ("pooling" , "classify" )),
71+ ("ForVideoClassification" , ("pooling" , "classify" )),
72+ ("ClassificationModel" , ("pooling" , "classify" )),
73+ ("ForRewardModeling" , ("pooling" , "reward" )),
74+ ("RewardModel" , ("pooling" , "reward" )),
75+ # Let other `*Model`s take priority
76+ ("Model" , ("pooling" , "embed" )),
77+ ]
78+
79+
80+ def iter_architecture_defaults ():
81+ yield from _SUFFIX_TO_DEFAULTS
82+
83+
84+ def try_match_architecture_defaults (
85+ architecture : str ,
86+ * ,
87+ runner_type : Optional [RunnerType ] = None ,
88+ convert_type : Optional [ConvertType ] = None ,
89+ ):
90+ for suffix , (default_runner_type , default_convert_type ) in iter_architecture_defaults ():
91+ if (
92+ (runner_type is None or runner_type == default_runner_type )
93+ and (convert_type is None or convert_type == default_convert_type )
94+ and architecture .endswith (suffix )
95+ ):
96+ return suffix , (default_runner_type , default_convert_type )
97+ return None
3998
4099
41100class MoEPhase :
@@ -133,6 +192,12 @@ def __init__(
133192 self .eos_tokens_lens : int = 2
134193 self .lm_head_fp32 : bool = False
135194 self .model_format = "auto"
195+ self .runner = "auto"
196+ self .convert = "auto"
197+ self .pooler_config : Optional ["PoolerConfig" ] = field (init = False )
198+ self .override_pooler_config : Optional [Union [dict , "PoolerConfig" ]] = None
199+ self .revision = None
200+
136201 self .partial_rotary_factor : float = 1.0
137202 self .num_nextn_predict_layers = 0
138203 for key , value in args .items ():
@@ -161,6 +226,7 @@ def __init__(
161226 self .ori_vocab_size = args .get ("ori_vocab_size" , self .vocab_size )
162227
163228 architectures = self .architectures [0 ]
229+
164230 if MultimodalRegistry .contains_model (architectures ):
165231 self .enable_mm = True
166232 else :
@@ -171,6 +237,43 @@ def __init__(
171237 self .override_name_from_config ()
172238 self .read_from_env ()
173239 self .read_model_config ()
240+ self .runner_type = self ._get_runner_type (self .architectures , self .runner )
241+ self .convert_type = self ._get_convert_type (self .architectures , self .runner_type , self .convert )
242+
243+ registry = self .registry
244+ is_generative_model = registry .is_text_generation_model (self .architectures , self )
245+ is_pooling_model = registry .is_pooling_model (self .architectures , self )
246+ is_multimodal_model = registry .is_multimodal_model (self .architectures , self )
247+
248+ if self .runner_type == "generate" and not is_generative_model :
249+ if is_multimodal_model :
250+ pass
251+ else :
252+ generate_converts = _RUNNER_CONVERTS ["generate" ]
253+ if self .convert_type not in generate_converts :
254+ raise ValueError ("This model does not support '--runner generate." )
255+ if self .runner_type == "pooling" and not is_pooling_model :
256+ pooling_converts = _RUNNER_CONVERTS ["pooling" ]
257+ if self .convert_type not in pooling_converts :
258+ convert_option = "<" + "|" .join (pooling_converts ) + ">"
259+ raise ValueError (
260+ "This model does not support `--runner pooling`. "
261+ f"You can pass `--convert { convert_option } to adapt "
262+ "it into a pooling model."
263+ )
264+
265+ self .supported_tasks = self ._get_supported_tasks (self .architectures , self .runner_type , self .convert_type )
266+ model_info , arch = registry .inspect_model_cls (self .architectures , self )
267+ self ._model_info = model_info
268+ self ._architecture = arch
269+
270+ self .pooler_config = self ._init_pooler_config ()
271+
    @property
    def registry(self):
        """Return a fresh model-registry instance.

        The import is done inside the property, presumably to avoid a
        circular import between this config module and the model registry —
        TODO confirm.
        """
        from fastdeploy.model_executor.models.model_base import ModelRegistry

        return ModelRegistry()
174277
175278 def override_name_from_config (self ):
176279 """
@@ -194,7 +297,6 @@ def override_name_from_config(self):
194297 def read_from_env (self ):
195298 """
196299 Read configuration information from environment variables and update the object's attributes.
197-
198300 If an attribute is not present or is an empty string in the environment variables, use the default value.
199301 """
200302 self .max_stop_seqs_num = int (envs .FD_MAX_STOP_SEQS_NUM )
@@ -235,6 +337,165 @@ def read_model_config(self):
235337 f"Config file path: { config_path } "
236338 )
237339
340+ def _get_default_runner_type (
341+ self ,
342+ architectures : list [str ],
343+ ) -> RunnerType :
344+ registry = self .registry
345+ if get_pooling_config (self .model , self .revision ):
346+ return "pooling"
347+ for arch in architectures :
348+ if arch in registry .get_supported_archs ():
349+ if registry .is_pooling_model (architectures , self ):
350+ return "pooling"
351+ if registry .is_text_generation_model (architectures , self ):
352+ return "generate"
353+ match = try_match_architecture_defaults (arch )
354+ if match :
355+ _ , (runner_type , _ ) = match
356+ return runner_type
357+ return "generate"
358+
359+ def _get_default_convert_type (
360+ self ,
361+ architectures : list [str ],
362+ runner_type : RunnerType ,
363+ ) -> ConvertType :
364+ registry = self .registry
365+
366+ for arch in architectures :
367+ if arch in registry .get_supported_archs ():
368+ if runner_type == "generate" and registry .is_text_generation_model (architectures , self ):
369+ return "none"
370+ if runner_type == "pooling" and registry .is_pooling_model (architectures , self ):
371+ return "none"
372+ match = try_match_architecture_defaults (arch , runner_type = runner_type )
373+ if match :
374+ _ , (_ , convert_type ) = match
375+ return convert_type
376+
377+ # This is to handle Sentence Transformers models that use *ForCausalLM
378+ # and also multi-modal pooling models which are not defined as
379+ # Sentence Transformers models
380+ if runner_type == "pooling" :
381+ return "embed"
382+
383+ return "none"
384+
385+ def _get_runner_type (
386+ self ,
387+ architectures : list [str ],
388+ runner : RunnerOption ,
389+ ) -> RunnerType :
390+ if runner != "auto" :
391+ return runner
392+
393+ runner_type = self ._get_default_runner_type (architectures )
394+ if runner_type != "generate" :
395+ logger .info (
396+ "Resolved `--runner auto` to `--runner %s`. " "Pass the value explicitly to silence this message." ,
397+ runner_type ,
398+ )
399+
400+ return runner_type
401+
402+ def _get_convert_type (
403+ self ,
404+ architectures : list [str ],
405+ runner_type : RunnerType ,
406+ convert : ConvertOption ,
407+ ) -> ConvertType :
408+ if convert != "auto" :
409+ return convert
410+
411+ convert_type = self ._get_default_convert_type (architectures , runner_type )
412+
413+ if convert_type != "none" :
414+ logger .info (
415+ "Resolved `--convert auto` to `--convert %s`. " "Pass the value explicitly to silence this message." ,
416+ convert_type ,
417+ )
418+
419+ return convert_type
420+
421+ def _get_supported_generation_tasks (
422+ self ,
423+ architectures : list [str ],
424+ convert_type : ConvertType ,
425+ ) -> list [_ResolvedTask ]:
426+ registry = self .registry
427+
428+ supported_tasks = list [_ResolvedTask ]()
429+ if registry .is_text_generation_model (architectures , self ) or convert_type in _RUNNER_CONVERTS ["generate" ]:
430+ supported_tasks .append ("generate" )
431+
432+ # TODO:Temporarily does not support transcription.
433+ return supported_tasks
434+
    def _get_default_pooling_task(
        self,
        architectures: list[str],
    ) -> ConvertType:
        """Pick the default pooling task from the architecture-suffix defaults.

        Falls back to "embed" when no suffix matches.

        NOTE(review): the return annotation was ``Literal["embed"]``, but the
        defaults table can also yield "classify" or "reward" for pooling
        architectures, so the annotation was widened; behavior is unchanged.
        """
        # Temporarily does not support classification and reward.
        for arch in architectures:
            match = try_match_architecture_defaults(arch, runner_type="pooling")
            if match:
                _, (_, convert_type) = match
                # Pooling entries in the table never carry "none".
                assert convert_type != "none"
                return convert_type

        return "embed"
448+
449+ def _get_supported_pooling_tasks (
450+ self ,
451+ architectures : list [str ],
452+ convert_type : ConvertType ,
453+ ) -> list [_ResolvedTask ]:
454+ registry = self .registry
455+
456+ supported_tasks = list [_ResolvedTask ]()
457+ if registry .is_pooling_model (architectures , self ) or convert_type in _RUNNER_CONVERTS ["pooling" ]:
458+ supported_tasks .append ("encode" )
459+
460+ extra_task = self ._get_default_pooling_task (architectures ) if convert_type == "none" else convert_type
461+ supported_tasks .append (extra_task )
462+
463+ return supported_tasks
464+
465+ def _get_supported_tasks (
466+ self ,
467+ architectures : list [str ],
468+ runner_type : RunnerType ,
469+ convert_type : ConvertType ,
470+ ) -> list [_ResolvedTask ]:
471+ if runner_type == "generate" :
472+ return self ._get_supported_generation_tasks (architectures , convert_type )
473+ if runner_type == "pooling" :
474+ return self ._get_supported_pooling_tasks (architectures , convert_type )
475+
476+ assert_never (runner_type )
477+
    def _init_pooler_config(self) -> Optional["PoolerConfig"]:
        """Build the pooler configuration for pooling runners.

        Returns:
            A fully-populated PoolerConfig when ``runner_type == "pooling"``,
            otherwise None. Per-field precedence: user override > on-disk
            pooling config > the model's default pooling type.
        """
        if self.runner_type == "pooling":
            # Normalize a dict override into a PoolerConfig instance.
            # NOTE: this mutates self.override_pooler_config in place.
            if isinstance(self.override_pooler_config, dict):
                self.override_pooler_config = PoolerConfig(**self.override_pooler_config)

            pooler_config = self.override_pooler_config or PoolerConfig()

            # Fill fields the user left unset from the repo's pooling config,
            # if one exists (presumably sentence-transformers style — TODO confirm).
            base_config = get_pooling_config(self.model, self.revision)
            if base_config is not None:
                for k, v in base_config.items():
                    if getattr(pooler_config, k) is None:
                        setattr(pooler_config, k, v)

            # Last resort: the model implementation's default pooling type.
            default_pooling_type = self._model_info.default_pooling_type
            if pooler_config.pooling_type is None:
                pooler_config.pooling_type = default_pooling_type

            return pooler_config

        return None
498+
    def _get_download_model(self, model_name, model_type="default"):
        """Placeholder for model self-download; currently a no-op."""
        # TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
        pass
@@ -846,6 +1107,41 @@ def __init__(
8461107 setattr (self , key , value )
8471108
8481109
@dataclass
class PoolerConfig:
    """Controls the behavior of output pooling in pooling models.

    NOTE: this must be a dataclass — callers construct it with keyword
    arguments (``PoolerConfig(**override_dict)``); as a plain class that
    call raised TypeError.
    """

    # The pooling method of the pooling model.
    pooling_type: Optional[str] = None

    # --- options for embedding models ---

    # Whether to normalize the embeddings outputs. Defaults to True.
    normalize: Optional[bool] = None

    # Reduce the dimensions of embeddings if the model supports matryoshka
    # representation. Defaults to None.
    dimensions: Optional[int] = None

    # Whether to enable chunked processing for long inputs that exceed the
    # model's maximum position embeddings. When enabled, long inputs are
    # split into chunks, processed separately, then aggregated using weighted
    # averaging, letting embedding models handle arbitrarily long text
    # without CUDA errors. Defaults to False.
    enable_chunked_processing: Optional[bool] = None

    # Maximum input length allowed for embedding generation. When set,
    # inputs longer than max_embed_len are accepted for embedding models;
    # an input exceeding max_embed_len is handled by the original
    # max_model_len validation logic. Defaults to None (i.e. max_model_len).
    max_embed_len: Optional[int] = None
1144+
8491145class LoRAConfig :
8501146 """LoRA Config"""
8511147
0 commit comments