@@ -438,8 +438,11 @@ def _init_inputs(self) -> None:
 
         self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
         self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
-        self.block_table_cpu = np.zeros(
-            (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
+        self.block_tables_cpu = [
+            np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
+                     dtype=np.int32)
+        ]
+
         self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
                                             dtype=np.int32)
         self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
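This hunk turns the single block_table_cpu buffer into block_tables_cpu, a list meant to hold one CPU buffer per KV cache group; only the first group's buffer is allocated here, so any additional buffers are presumably added once the group layout is known (not shown in this hunk). A minimal sketch, using a hypothetical helper, of how such per-group buffers could be built:

import numpy as np

def build_block_tables_cpu(num_groups: int, max_num_reqs: int,
                           max_num_blocks_per_req: int):
    # One (max_num_reqs, max_num_blocks_per_req) buffer per KV cache group,
    # indexed later as block_tables_cpu[kv_cache_gid].
    return [
        np.zeros((max_num_reqs, max_num_blocks_per_req), dtype=np.int32)
        for _ in range(num_groups)
    ]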
@@ -535,6 +538,7 @@ def get_kv_cache_spec(self):
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.kv_cache_config = kv_cache_config
+        self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []
         self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
         if has_kv_transfer_group():
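The new use_hybrid_kvcache flag simply records whether the config carries more than one KV cache group. A toy illustration with stand-in group objects (FakeKVCacheGroup and the layer names are made up; only the .layer_names attribute mirrors how this diff uses the real groups):

from dataclasses import dataclass
from typing import List

@dataclass
class FakeKVCacheGroup:  # stand-in for the real kv_cache_group objects
    layer_names: List[str]

kv_cache_groups = [
    FakeKVCacheGroup(layer_names=["layers.0.attn", "layers.2.attn"]),
    FakeKVCacheGroup(layer_names=["layers.1.attn", "layers.3.attn"]),
]
use_hybrid_kvcache = len(kv_cache_groups) > 1  # True -> per-layer metadata later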
@@ -701,6 +705,7 @@ def _execute_model(
         # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
         (
             input_ids,
+            input_positions,
             attn_metadata,
             _,
             logits_indices,
@@ -747,6 +752,7 @@ def _execute_model(
             self.kv_caches,
             input_ids,
             attn_metadata,
+            input_positions,
             inputs_embeds,
             tuple(self.layer_name_to_kvcache_index.items()),
             lora_metadata,
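Together with the earlier hunk, the positions returned by _prepare_inputs are now threaded into the model call as input_positions, placed right after attn_metadata. A callee-side sketch; the name model_fn is a stand-in, and any arguments before self.kv_caches are not visible in this diff:

def model_fn(kv_caches, input_ids, attn_metadata, input_positions,
             inputs_embeds, layer_name_to_kvcache_index, lora_metadata):
    # attn_metadata may be a single AttentionMetadata or, for hybrid KV
    # caches, a dict keyed by layer name (see the _prepare_inputs hunks below).
    ...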
@@ -1303,16 +1309,6 @@ def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
 
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        for dp_rank in range(dp_size):
-            req_offset = dp_rank * max_num_reqs_per_dp_rank
-            _num_reqs = num_req_per_dp_rank[dp_rank]
-
-            block_tables[
-                req_offset:req_offset + _num_reqs, :self.
-                max_num_blocks_per_req] = self.input_batch.block_table[
-                    0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                    dp_size]
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
@@ -1354,20 +1350,55 @@ def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
         if self.uses_mrope:
             positions = mrope_positions
 
-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         logits_indices_cpu = logits_indices
         seq_lens_cpu = seq_lens
 
-        (input_ids, positions, block_tables, query_start_loc, seq_lens,
-         logits_indices, request_distribution) = device_array(
+        (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+         request_distribution) = device_array(
             self.mesh,
-            (input_ids, positions, block_tables, query_start_loc, seq_lens,
-             logits_indices, request_distribution),
+            (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+             request_distribution),
             sharding=data_parallel_attn_sharding,
         )
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            for dp_rank in range(dp_size):
+                req_offset = dp_rank * max_num_reqs_per_dp_rank
+                _num_reqs = num_req_per_dp_rank[dp_rank]
+
+                block_tables[
+                    req_offset:req_offset + _num_reqs, :self.
+                    max_num_blocks_per_req] = self.input_batch.block_table[
+                        0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.mesh, (block_tables))
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
+
         # Async scheduling: substitute placeholder tokens for DP
         if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
             # Collect all token indices that need substitution across all DP ranks
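The block-table staging that used to happen once now runs inside the per-KV-cache-group loop: copy the scheduled requests' rows into the group's CPU buffer at each DP rank's offset, flatten to 1D, and upload with device_array. A small NumPy sketch of that staging, with made-up shapes and values:

import numpy as np

max_num_reqs, max_num_blocks_per_req = 4, 8
block_table = np.zeros((max_num_reqs, max_num_blocks_per_req), dtype=np.int32)

# Pretend two requests are scheduled at this rank's offset; copy their rows in.
scheduled = np.arange(2 * max_num_blocks_per_req, dtype=np.int32).reshape(2, -1)
req_offset = 0
block_table[req_offset:req_offset + 2, :max_num_blocks_per_req] = scheduled

# Flatten to 1D on the CPU before the device upload.
flat = block_table.reshape(-1)
assert flat.shape == (max_num_reqs * max_num_blocks_per_req,)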
@@ -1396,20 +1427,13 @@ def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
                 padded_total_num_scheduled_tokens,
             )
 
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
-
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
         return (
             input_ids,
+            positions,
             attention_metadata,
             sampling_metadata,
             logits_indices,
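_prepare_inputs_dp now returns either one shared AttentionMetadata or, when the cache is hybrid, a Dict[str, AttentionMetadata] keyed by layer name. A consumer-side sketch of how attention code could resolve its metadata either way; metadata_for_layer is a hypothetical helper, not part of this commit:

from typing import Dict, Union

def metadata_for_layer(attn_metadata: Union[object, Dict[str, object]],
                       layer_name: str):
    # Hybrid KV cache: the dict was filled from kv_cache_group.layer_names.
    if isinstance(attn_metadata, dict):
        return attn_metadata[layer_name]
    # Uniform cache: every layer shares the same metadata object.
    return attn_metadata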
@@ -1516,9 +1540,6 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
         positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
-            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
 
         # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1547,16 +1568,44 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
             self.mesh, self.input_batch, padded_num_reqs)
         if self.uses_mrope:
             positions = mrope_positions
-
-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         seq_lens_cpu = seq_lens
-        (input_ids, positions, block_tables, query_start_loc, seq_lens,
+
+        (input_ids, positions, query_start_loc, seq_lens,
          logits_indices, request_distribution) = device_array(
-            self.mesh, (input_ids, positions, block_tables, query_start_loc,
-                        seq_lens, logits_indices, request_distribution))
+            self.mesh, (input_ids, positions, query_start_loc, seq_lens,
+                        logits_indices, request_distribution))
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            block_tables[:num_reqs] = (
+                self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
+                [:num_reqs])
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.mesh, (block_tables))
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution)
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
 
         if self.scheduler_config.async_scheduling and len(
                 token_in_tpu_cur_input_indices) > 0:
@@ -1569,19 +1618,13 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
             self.lora_utils.set_active_loras(
                 num_scheduled_tokens_per_req, total_num_scheduled_tokens,
                 padded_total_num_scheduled_tokens)
-
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution)
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
         logits_indices_selector = None
-        return (input_ids, attention_metadata, sampling_metadata,
+
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
+        return (input_ids, positions, attention_metadata, sampling_metadata,
                 logits_indices, spec_decode_metadata, logits_indices_selector,
                 padded_num_reqs)
 
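The non-DP path mirrors the DP one: one AttentionMetadata is built per KV cache group and then fanned out to every layer name in that group, so the per-layer dict holds one distinct object per group rather than per layer. A toy illustration of the resulting mapping, with made-up layer names and plain objects standing in for AttentionMetadata:

full_md, swa_md = object(), object()  # stand-ins for per-group AttentionMetadata
attention_metadata = {
    "layers.0.attn": full_md,
    "layers.1.attn": swa_md,
    "layers.2.attn": full_md,
    "layers.3.attn": swa_md,
}
# Layers from the same KV cache group share the exact same metadata instance.
assert attention_metadata["layers.0.attn"] is attention_metadata["layers.2.attn"]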