
Commit f88d7a9

Fix lora test by removing LoRA extra vocab (#1156)
Signed-off-by: Xiongfei Wei <[email protected]>
1 parent 37afd29 commit f88d7a9


4 files changed (+1, -19 lines)


tests/lora/test_layers.py

Lines changed: 0 additions & 6 deletions

```diff
@@ -91,7 +91,6 @@ def populate_loras(
     index_to_id: list[Optional[int]],
     lora_layer: BaseLayerWithLoRA,
     baselayer_weights: torch.Tensor,
-    generate_embeddings_tensor: int = 0,
     repeats: int = 1,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
     """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
@@ -103,8 +102,6 @@ def populate_loras(
         lora_layer: the LoRAlayer to populate.
         baselayer_weights: the PyTorch tensor containing the layer's
             weights.
-        generate_embeddings_tensor: whether to generate an
-            embeddings tensor for each LoRA.
         repeats: must only be set for column parallel packed
             layers. Indicates the number of loras to compose
             together to create a single lora layer.
@@ -131,7 +128,6 @@ def populate_loras(
                     baselayer_weights.device).init_random_lora(
                         module_name=f"fake_{i}",
                         weight=baselayer_weights,
-                        generate_embeddings_tensor=generate_embeddings_tensor,
                     )
                 sublora.lora_b = sublora.lora_b[(sublora_len *
                                                  i):(sublora_len * (i + 1)), :]
@@ -147,7 +143,6 @@ def populate_loras(
             slot_idx,
             lora_a=lora.lora_a,
             lora_b=lora.lora_b,
-            embeddings_tensor=lora.embeddings_tensor,
         )

         lora_dict[lora_id] = lora
@@ -546,7 +541,6 @@ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
         index_to_id,
         lora_config.max_loras,
         vocab_size=512,
-        extra_vocab_size=lora_config.lora_extra_vocab_size,
     )
     assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
     ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
```
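
With the embeddings tensor gone, the test helper only slices each sub-LoRA's lora_b when composing column-parallel packed layers (the @@ -131 hunk above). A minimal standalone sketch of that slicing pattern, with assumed shapes and an illustrative helper name (not the test's code):

```python
import torch


# Minimal sketch: split a packed lora_b into `repeats` equal sub-LoRA slices along
# the output dimension, mirroring the slicing in the hunk above.
# Assumes lora_b is laid out as (output_dim, rank) and output_dim % repeats == 0.
def split_packed_lora_b(lora_b: torch.Tensor, repeats: int) -> list[torch.Tensor]:
    sublora_len = lora_b.shape[0] // repeats
    return [
        lora_b[sublora_len * i:sublora_len * (i + 1), :] for i in range(repeats)
    ]


packed = torch.rand(12, 8)  # e.g. a packed layer with output_dim=12, rank=8
chunks = split_packed_lora_b(packed, repeats=3)
assert all(chunk.shape == (4, 8) for chunk in chunks)
```

Each slice covers an equal chunk of the packed output dimension, so the output dimension must be divisible by `repeats`.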

tests/lora/utils.py

Lines changed: 0 additions & 8 deletions

```diff
@@ -24,7 +24,6 @@ def init_random_lora(
         module_name: str,
         weight: torch.Tensor,
         rank: int = 8,
-        generate_embeddings_tensor: int = 0,
     ):
         lora = LoRALayerWeights(
             module_name,
@@ -37,13 +36,6 @@ def init_random_lora(
                               dtype=weight.dtype,
                               device=self._device),
         )
-        if generate_embeddings_tensor:
-            lora.embeddings_tensor = torch.rand(
-                5,
-                generate_embeddings_tensor,
-                dtype=weight.dtype,
-                device=self._device,
-            )
         self.set_module_lora(module_name, lora)

         return lora
```
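
After this change, init_random_lora only fills lora_a and lora_b; no embeddings tensor is attached. A rough standalone sketch of that idea, where the (in_features, rank) / (rank, out_features) layout and the make_dummy_lora helper are assumptions for illustration, not the DummyLoRAManager API:

```python
import torch


# Rough sketch only: a dummy LoRA is just two random low-rank factors; nothing like
# an embeddings tensor is created anymore. The factor layout below is assumed.
def make_dummy_lora(weight: torch.Tensor,
                    rank: int = 8) -> tuple[torch.Tensor, torch.Tensor]:
    out_features, in_features = weight.shape
    lora_a = torch.rand(in_features, rank, dtype=weight.dtype)
    lora_b = torch.rand(rank, out_features, dtype=weight.dtype)
    return lora_a, lora_b


base = torch.rand(16, 32)  # fake base-layer weight: out_features=16, in_features=32
lora_a, lora_b = make_dummy_lora(base, rank=4)
delta = (lora_a @ lora_b).T  # (out_features, in_features), same shape as the base weight
assert delta.shape == base.shape
```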

tpu_inference/lora/torch_punica_tpu.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -239,7 +239,6 @@ def _update_base_metadata(
         lora_index_to_id: list[Optional[int]],
         max_loras: int,
         vocab_size: int,
-        extra_vocab_size: int,
     ):
         # Pad the prompt mapping to avoid running into recompiles on the TPU
         # TODO: Should this happen inside mapping internally? If so how can we
@@ -258,7 +257,7 @@ def _update_base_metadata(
             lora_index_to_id,
             max_loras,
             vocab_size,
-            extra_vocab_size,
+            0,  # extra_vocab_size
             "cpu",
         )
         with torchax.default_env():
```
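
On the TPU path, extra_vocab_size is no longer accepted at all; the wrapper hardcodes 0 when delegating to the base metadata update (the `+ 0,  # extra_vocab_size` line above). A sketch of that forwarding pattern with placeholder class names, not the actual PunicaWrapper API:

```python
from typing import Optional


# Placeholder names only; this illustrates the forwarding idea, not the real classes.
class BaseMetadataWrapper:

    def update_metadata(self, mapping, index_to_id: list[Optional[int]],
                        max_loras: int, vocab_size: int, extra_vocab_size: int,
                        device: str):
        self.meta = (mapping, index_to_id, max_loras, vocab_size,
                     extra_vocab_size, device)


class TPUMetadataWrapper(BaseMetadataWrapper):

    # The TPU-side signature drops extra_vocab_size entirely and always forwards 0,
    # since extra LoRA vocabulary is not supported on this path.
    def update_metadata(self, mapping, index_to_id: list[Optional[int]],
                        max_loras: int, vocab_size: int, device: str):
        super().update_metadata(mapping, index_to_id, max_loras, vocab_size, 0,
                                device)


wrapper = TPUMetadataWrapper()
wrapper.update_metadata([0, 1], [None, 7], max_loras=2, vocab_size=512, device="cpu")
assert wrapper.meta[4] == 0  # extra_vocab_size is always 0 now
```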

tpu_inference/runner/tpu_runner.py

Lines changed: 0 additions & 3 deletions

```diff
@@ -472,9 +472,6 @@ def _init_inputs(self) -> None:

         # tensors for structured decoding
         self.vocab_size = self.model_config.get_vocab_size()
-        if self.lora_config is not None:
-            # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
-            self.vocab_size += self.lora_config.lora_extra_vocab_size
         self.grammar_bitmask_cpu = np.zeros(
             (self.max_num_reqs, cdiv(self.vocab_size, 32)),
             dtype=np.int32,
```
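
With the extra LoRA vocabulary dropped from self.vocab_size, the structured-decoding bitmask is sized from the base vocabulary alone. A small worked example of the resulting shape, where the 32,000-token vocabulary and the 256-token extra vocab are assumed values for illustration:

```python
import numpy as np


def cdiv(a: int, b: int) -> int:
    """Ceiling division, as used to size the bitmask in 32-bit words."""
    return -(a // -b)


max_num_reqs = 8
vocab_size = 32_000  # assumed base vocabulary size for this example
grammar_bitmask = np.zeros((max_num_reqs, cdiv(vocab_size, 32)), dtype=np.int32)
assert grammar_bitmask.shape == (8, 1000)  # 32_000 tokens -> 1000 int32 words per request

# Before this change, an extra LoRA vocabulary (e.g. 256 tokens) widened every row:
assert cdiv(vocab_size + 256, 32) == 1008
```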
