@@ -64,16 +64,6 @@ def linear_requires_sync(linear_type: LinearType):
     return linear_type in REQUIRES_SYNC


-def _update_history_with_new_amax(new_amax, amax_history):
-    """
-    Updates `amax_history` (the last N cur_amax values) inplace with the value
-    of `new_amax`.
-    """
-    new_amax_history = torch.roll(amax_history, 1)
-    new_amax_history[0] = new_amax
-    amax_history.copy_(new_amax_history)
-
-
 def _update_history_stack(
     new_amax: torch.Tensor, amax_history_stack: torch.Tensor
 ) -> torch.Tensor:
@@ -85,10 +75,12 @@ def _update_history_stack(
         new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1)
         amax_history_stack (torch.Tensor): The history of amax values. (n_amaxes, history_length)
     """
-    assert amax_history_stack.dim() == 2, "amax_history_stack must be 2D"
+    assert (
+        amax_history_stack.dim() == 2
+    ), f"Expected amax_history_stack to be 2D, got {amax_history_stack.shape}"
     assert new_amax.size(0) == amax_history_stack.size(
         0
-    ), "new_amax must have the same size as the second dimension of amax_history_stack"
+    ), f"Expected new_amax to have the same size as the first dimension of amax_history_stack, got {new_amax.size(0)} and {amax_history_stack.size(0)}"
     new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1)
     new_amax_history_stack[:, 0] = new_amax.squeeze(-1)
     amax_history_stack.copy_(new_amax_history_stack)
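
For readers skimming the diff, the update above is a simple roll-and-insert over a 2D buffer: each row holds one tensor's amax history, and column 0 always receives the newest value. A minimal standalone sketch of that pattern (illustrative only, not the library code):

```python
import torch

# Standalone sketch of the roll-and-insert update performed by _update_history_stack:
# each row is one tensor's amax history; column 0 always holds the newest amax.
history = torch.tensor([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0]])  # (n_amaxes, history_length)
new_amax = torch.tensor([[7.0], [8.0]])    # (n_amaxes, 1)

rolled = torch.roll(history, 1, dims=1)    # shift each row right by one slot
rolled[:, 0] = new_amax.squeeze(-1)        # overwrite slot 0 with the new amax
history.copy_(rolled)                      # write back in place

print(history)
# tensor([[7., 1., 2.],
#         [8., 4., 5.]])
```
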
@@ -155,9 +147,7 @@ def get_float8_layers(model: torch.nn.Module):
155147 """
156148
157149 # Get all fp8 layers and tensors
158- fp8_layers = [
159- child for _ , child in model .named_modules () if isinstance (child , Float8Linear )
160- ]
150+ fp8_layers = [child for child in model .modules () if isinstance (child , Float8Linear )]
161151
162152 return fp8_layers
163153
@@ -176,7 +166,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
     TODO(future): design the UX for this (context manager, etc)

     PERFORMANCE NOTE:
-        When you can it is much more efficient to call te get_float8_layers once a
+        When you can, it is much more efficient to call get_float8_layers once at
         the beginning of the training loop and pass the result to this function.
         Because of how this interacts with torch.compile

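
As the performance note recommends, the fp8 layer list can be computed once and reused so the model is not re-traversed (and re-traced under torch.compile) on every step. A hedged usage sketch; the import path and the `model` / `dataloader` / `optimizer` objects are assumed placeholders:

```python
from float8_experimental.float8_linear_utils import (  # assumed import path
    get_float8_layers,
    sync_float8_amax_and_scale_history,
)

# Collect the Float8Linear modules once, outside the training loop.
fp8_layers = get_float8_layers(model)

for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch).sum()  # placeholder loss computation
    loss.backward()
    # Reuse the precomputed list instead of re-scanning the model each step.
    sync_float8_amax_and_scale_history(model, fp8_layers=fp8_layers)
    optimizer.step()
```
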
@@ -249,13 +239,12 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         reduced_fp8_amax_dL_dY_tensor,
     ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))

-    # TODO foreach is not supported with AsyncCollectiveTensor
     for idx, child in enumerate(fp8_layers):
         child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
         child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
         child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])

-    # We create two stacked tensors, one for the amax history and one for the current scales
+    # We create two stacked tensor groups, one for the amax history and one for the current scales
     fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
     fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
     fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)
@@ -264,11 +253,12 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
     fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
     fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)

+    # Update the history stacks with the new amax values
     _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
     _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
     _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)

-    # We are not reading the
+    # Calculate the new scales from the updated history stacks
     new_x_scales = amax_history_to_scale_stack(
         fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
     )
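
For context on what `amax_history_to_scale_stack` produces: with a "max"-style recipe, each history row is reduced to its largest observed amax, and the scale is roughly the fp8 dtype's maximum representable value divided by that amax. A rough standalone sketch of the idea (not the library's implementation; `EPS` is an assumed clamp to avoid division by zero):

```python
import torch

EPS = 1e-12  # assumed small clamp value

def amax_history_to_scale_sketch(amax_history_stack, float8_dtype):
    # "max" recipe: reduce each row of the history to its largest observed amax,
    # then map it to a scale of fp8_max / amax.
    amax = torch.max(amax_history_stack, dim=1).values
    fp8_max = torch.finfo(float8_dtype).max
    return fp8_max / torch.clamp(amax, min=EPS)

history = torch.tensor([[0.5, 2.0, 1.0]])
print(amax_history_to_scale_sketch(history, torch.float8_e4m3fn))
# tensor([224.]) since the e4m3 max (448) divided by the row amax (2.0) is 224
```
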