Skip to content

Commit 25c9633

Browse files
committed
type annot
Signed-off-by: Linkun Chen <[email protected]>
1 parent 71baf85 commit 25c9633

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 20 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,8 @@
3535
from vllm.v1.request import Request
3636

3737
Transfer = tuple[int, float] # (xfer_handle, start_time)
38+
EngineId = str
39+
ReqId = str
3840
GET_META_MSG = b"get_meta_msg"
3941

4042
logger = init_logger(__name__)
@@ -74,7 +76,7 @@ class ReqMeta:
7476
class NixlConnectorMetadata(KVConnectorMetadata):
7577

7678
def __init__(self):
77-
self.requests: dict[str, ReqMeta] = {}
79+
self.requests: dict[ReqId, ReqMeta] = {}
7880

7981
def add_new_req(
8082
self,
@@ -95,16 +97,17 @@ class NixlConnector(KVConnectorBase_V1):
9597

9698
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
9799
assert vllm_config.kv_transfer_config is not None
98-
self.engine_id = vllm_config.kv_transfer_config.engine_id
100+
assert vllm_config.kv_transfer_config.engine_id is not None
101+
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
99102

100103
if role == KVConnectorRole.SCHEDULER:
101104
self.connector_scheduler : Optional[NixlConnectorScheduler] = \
102-
NixlConnectorScheduler(vllm_config, str(self.engine_id))
105+
NixlConnectorScheduler(vllm_config, self.engine_id)
103106
self.connector_worker: Optional[NixlConnectorWorker] = None
104107
elif role == KVConnectorRole.WORKER:
105108
self.connector_scheduler = None
106109
self.connector_worker = NixlConnectorWorker(
107-
vllm_config, str(self.engine_id))
110+
vllm_config, self.engine_id)
108111

109112
############################################################
110113
# Scheduler Side Methods
@@ -178,7 +181,7 @@ class NixlConnectorScheduler:
178181
def __init__(self, vllm_config: VllmConfig, engine_id: str):
179182
self.vllm_config = vllm_config
180183
self.block_size = vllm_config.cache_config.block_size
181-
self.engine_id = engine_id
184+
self.engine_id: EngineId = engine_id
182185
self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
183186
self.side_channel_port = (
184187
envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
@@ -189,7 +192,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
189192
# Requests that need to start recv.
190193
# New requests are added by update_state_after_alloc in
191194
# the scheduler. Used to make metadata passed to Worker.
192-
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
195+
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
193196

194197
def get_num_new_matched_tokens(
195198
self, request: "Request",
@@ -331,19 +334,19 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
331334
# Agent.
332335
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
333336
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
334-
self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
337+
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
335338

336339
# NIXL handshake port.
337340
# NOTE(rob): Within a DP group, each DP rank gets its own
338341
# base port (which is sent in the KVTransferParams).
339342
# Each TP rank listens/queries on the base_port + tp_rank.
340-
self.side_channel_port = (
343+
self.side_channel_port: int = (
341344
envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
342345
vllm_config.parallel_config.data_parallel_rank_local *
343346
vllm_config.parallel_config.tensor_parallel_size)
344347

345348
# Metadata.
346-
self.engine_id = engine_id
349+
self.engine_id: EngineId = engine_id
347350
self.tp_rank = get_tensor_model_parallel_rank()
348351
self.world_size = get_tensor_model_parallel_world_size()
349352
self.tp_group = get_tp_group()
@@ -353,7 +356,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
353356

354357
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
355358
# rank will still only pull from a single remote TP worker.
356-
self.kv_caches_base_addr: dict[str, list[int]] = {}
359+
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
357360

358361
# Number of NIXL regions. Currently one region per cache
359362
# (so 1 per layer for MLA, otherwise 2 per layer)
@@ -363,23 +366,23 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
363366
# nixl_prepped_dlist_handle.
364367
self.src_xfer_side_handle: int = 0
365368
# Map of engine_id -> nixl_prepped_dlist_handle (int).
366-
self.dst_xfer_side_handles: dict[str, int] = {}
369+
self.dst_xfer_side_handles: dict[EngineId, int] = {}
367370

368371
# Map of engine_id -> num_blocks. All ranks in the same deployment will
369372
# have the same number of blocks.
370-
self.dst_num_blocks: dict[str, int] = {}
373+
self.dst_num_blocks: dict[EngineId, int] = {}
371374
self._registered_descs: list[Any] = []
372375

373376
# In progress transfers.
374377
# [req_id -> list[handle]]
375-
self._recving_transfers = defaultdict[str, list[Transfer]](list)
378+
self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
376379

377380
# Complete transfer tracker. Used by rank 0 to track finished
378381
# transactions on ranks 1 to N-1.
379382
# [req_id -> count]
380-
self._done_recving_count: defaultdict[str,
383+
self._done_recving_count: defaultdict[ReqId,
381384
int] = defaultdict(lambda: 0)
382-
self._done_sending_count: defaultdict[str,
385+
self._done_sending_count: defaultdict[ReqId,
383386
int] = defaultdict(lambda: 0)
384387

385388
# Background thread for establishing new connections.
@@ -407,10 +410,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
407410
self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
408411
logger.debug("Detected attention backend %s", self.backend_name)
409412

410-
self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
413+
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
411414
# With heterogeneous TP, P must wait for all assigned D TP workers to
412415
# finish reading before safely freeing the blocks.
413-
self.consumer_notification_counts_by_req = defaultdict[str, int](int)
416+
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
414417

415418
@staticmethod
416419
def _nixl_handshake_listener(metadata: NixlAgentMetadata,

0 commit comments

Comments
 (0)