@@ -99,6 +99,7 @@ def __init__(
       adapters_dir_path: str,
       hbm_memory_budget: int,
       cpu_memory_budget: int,
+      total_slots: int,
   ):
     """Initializes the AdapterTensorStore."""
     self.engine = engine  # Possibly MaxEngine object
@@ -119,8 +120,27 @@ def __init__(
     self.running_requests: int = (
         0  # Number of async tasks which are in "loading" state
     )
+    self.decoding_adapters_cache: Dict[str, Any] = {}
+
+    # TODO: Make dtype configurable for the scale factor array
+    self.adapters_scale_factor = jnp.empty(1, dtype=jnp.bfloat16)
+
+    self.total_slots = total_slots
     self.lock = asyncio.Lock()  # Use an asyncio Lock for thread safety
 
+  def _get_adapter_scale_factor(self, adapter_id: str):
+    """
+    Internal: Get the LoRA scale_factor using the adapter_id.
+    """
+    adapter_config = self.adapter_registry[adapter_id].config
+    lora_scale_factor = float(1)
+
+    if "r" in adapter_config and "lora_alpha" in adapter_config:
+      lora_rank = int(adapter_config["r"])
+      lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank
+
+    return lora_scale_factor
+
   # --- Unsafe Internal methods which assumes that lock is held ---
   def _unsafe_transfer_to_hbm(self, adapter_id: str):
     """
@@ -207,6 +227,90 @@ def _unsafe_unload_adapter(self, adapter_id: str):
     metadata.size_hbm = 0
     metadata.size_cpu = 0
 
+  def _initialize_decoding_adapters_cache(self, adapter_weights):
+    """
+    Create a new PyTree with zero tensors at the paths corresponding to
+    non-None leaves in the input PyTree. The zero tensors have an added
+    dimension of size `self.total_slots`.
+    Args:
+      adapter_weights: The input PyTree, whose structure will be mirrored.
+    Returns:
+      A new PyTree with zero Tensors or None values, mirroring the structure
+      of the input PyTree.
+    """
+
+    def create_zero_leaf(leaf):
+      if leaf is not None:
+        original_shape = leaf.shape
+        if not original_shape:  # handle scalar case
+          zero_tensor_shape = (self.total_slots,)
+        else:
+          zero_tensor_shape = (
+              self.total_slots,
+          ) + original_shape  # Prepend a new dimension
+
+        return jnp.zeros(zero_tensor_shape, dtype=leaf.dtype)
+      else:
+        return None  # Maintain None structure for None leaves
+
+    self.adapters_scale_factor = jnp.ones(self.total_slots, dtype=jnp.bfloat16)
+    return jax.tree_util.tree_map(create_zero_leaf, adapter_weights)
+
+  def insert_adapter_in_cache(self, adapter_id: str, slot_id: int):
+    """
+    Insert the specific adapter tensors into a slot in the
+    decoding_adapters_cache.
+    Args:
+      adapter_id: The id of the adapter whose tensors will be inserted.
+      slot_id: The id of the slot, which represents the index in the
+        decoding_adapters_cache where the adapter tensors will be inserted.
+    """
+
+    def insert_leaf(dest_leaf, source_leaf):
+      if dest_leaf is not None and source_leaf is not None:
+        return dest_leaf.at[slot_id].set(
+            source_leaf
+        )  # Insert at the specific index
+      elif dest_leaf is not None:
+        return dest_leaf  # If source_leaf is None, keep the zero_leaf as is
+      elif (
+          source_leaf is not None
+      ):  # In this case the adapters have different target modules
+        original_shape = source_leaf.shape
+        if not original_shape:  # Handle scalar case
+          zero_tensor_shape = (self.total_slots,)
+        else:
+          zero_tensor_shape = (self.total_slots,) + original_shape
+        new_dest_leaf = jnp.zeros(zero_tensor_shape, dtype=source_leaf.dtype)
+        return new_dest_leaf.at[slot_id].set(source_leaf)
+      else:
+        return None  # If both are None, return None
+
+    if adapter_id == "":
+      logging.info(
+          "Empty adapter id. No LoRA tensors added to adapter_tensorstore cache"
+      )
+      return
+
+    asyncio.run(self.load_adapter(adapter_id, None, True))
+
+    adapter_weights = self.loaded_adapters_hbm[adapter_id]
+
+    if not self.decoding_adapters_cache:
+      self.decoding_adapters_cache = self._initialize_decoding_adapters_cache(
+          adapter_weights
+      )
+
+    adapter_scale_factor = jnp.bfloat16(
+        self._get_adapter_scale_factor(adapter_id)
+    )
+    self.adapters_scale_factor = self.adapters_scale_factor.at[slot_id].set(
+        adapter_scale_factor
+    )
+    self.decoding_adapters_cache = jax.tree_util.tree_map(
+        insert_leaf, self.decoding_adapters_cache, adapter_weights
+    )
+
   # --- Public Methods (Acquire lock, then call unsafe methods) ---
 
   async def register_adapter(
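Taken together, the new methods keep a stacked, per-slot copy of the LoRA weights for decoding: every leaf of `decoding_adapters_cache` gains a leading dimension of size `total_slots`, and `adapters_scale_factor[slot_id]` holds the matching `lora_alpha / r` value. A rough usage sketch follows; it is not part of this change, `store` is assumed to be an AdapterTensorStore built elsewhere with `total_slots=8`, and the adapter ids and slot assignment are illustrative.

# Hypothetical driver code: pin each active request's adapter into its
# decode slot. An empty adapter id is a no-op, so slot 2 keeps the base model.
store.insert_adapter_in_cache("adapter_a", slot_id=0)
store.insert_adapter_in_cache("adapter_b", slot_id=1)
store.insert_adapter_in_cache("", slot_id=2)

# After insertion, each leaf of store.decoding_adapters_cache has shape
# (total_slots, *original_leaf_shape), so a decode step can gather one
# slot's weights with leaf[slot_id] and its scaling with
# store.adapters_scale_factor[slot_id].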