From 32fbe21adb8fee54c52ab4068820b22b48fd6cd4 Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 12 Jan 2024 18:09:45 -0800 Subject: [PATCH 01/12] existing eager tests pass --- float8_experimental/float8_dynamic_linear.py | 11 +- float8_experimental/float8_ops.py | 136 ++++++++++++++++++- 2 files changed, 143 insertions(+), 4 deletions(-) diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index 58e352da..dd6067ff 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -8,6 +8,7 @@ """ import torch +from float8_experimental.float8_ops import float8_linear from float8_experimental.float8_tensor import Float8Tensor from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated @@ -48,12 +49,16 @@ class Float8DynamicLinear(torch.nn.Linear): def forward(self, x): x_fp8 = self.cast_to_float8(x) - w_fp8 = self.cast_to_float8(self.weight) - - y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias) + # w_fp8 = self.cast_to_float8(self.weight) + # y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias) + weight_scale = tensor_to_scale(self.weight, torch.float8_e4m3fn) + y = float8_linear.apply( + x_fp8, self.weight, weight_scale, None, self.emulate, False + ) # Cast gradY to float8_e5m2 during backward y = self.cast_to_float8e5m2_bw(y) + y = y + self.bias if self.bias is not None else y return y diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 392358f2..29b181e6 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -3,7 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict +from typing import Any, Dict, Optional import torch @@ -75,11 +75,33 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): return a_data, a_scale, b_data, b_scale +def float8_mm_helper(a: Float8Tensor, b: Float8Tensor) -> torch.Tensor: + """This is a helper function for float8_mm + Args: + a: The first matrix multiplication term. + b: The second matrix multiplication term. + Returns: + torch.Tensor: The result of the matrix multiplication. + """ + a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) + output_dtype = a._orig_dtype + if a._emulate: + assert a._emulate == b._emulate + return torch.ops.aten.mm_float8_emulated( + a._data, a._scale, b._data, b._scale, output_dtype + )[0] + tensor_out, amax = addmm_float8_unwrapped( + a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=None + ) + return tensor_out + + @implements([aten.mm.default]) def float8_mm(aten_op, args, kwargs=None): assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor) a = args[0] b = args[1] + return float8_mm_helper(a, b) a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) output_dtype = a._orig_dtype if a._emulate: @@ -140,3 +162,115 @@ def autocast_to_copy(aten_op, args, kwargs=None): return Float8Tensor( args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._emulate ) + + +class float8_linear(torch.autograd.Function): + """Custom autograd function for computing torch.nn.Linear on Float8Tensor. + + This is needed for a couple reasons, we want to have fine grained control over the + recomputation of casted values for backward. 
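+    Concretely (as implemented below): the forward casts ``original_weight`` to
+    float8 using the provided ``weight_scale`` and performs the float8 matmul,
+    while the backward either reuses the saved casted weight or, when
+    ``recompute_float8_weight`` is True, re-casts the original weight so that the
+    unsharded weight does not have to be saved for backward (the intended setting
+    for traditional FSDP).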
+ """ + + @staticmethod + def forward( + ctx, + x_fp8: torch.Tensor, + original_weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_amax_buffer: Optional[torch.Tensor], + emulate: bool, + recompute_float8_weight: bool, + ): + ctx.save_for_backward(x_fp8) + w_fp8 = Float8Tensor.to_float8( + original_weight, + weight_scale, + torch.float8_e4m3fn, + weight_amax_buffer, + emulate=emulate, + ) + if recompute_float8_weight: + # This should be set to True when using traditional fsdp to avoid saving + # saving the unsharded weight for + ctx.save_for_backward( + x_fp8, original_weight, weight_scale, weight_amax_buffer + ) + else: + # Does this interact properly with activation checkpointing? + ctx.save_for_backward(x_fp8, w_fp8) + + ctx.recompute_float8_weight = recompute_float8_weight + ctx.emulate = emulate + orig_shape = x_fp8._data.shape + x_fp8_reshaped = Float8Tensor( + x_fp8._data.reshape(-1, orig_shape[-1]), + x_fp8._scale, + x_fp8._orig_dtype, + emulate=emulate, + ) + + w_fp8_t = Float8Tensor( + w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype, emulate=emulate + ) + + res_bits = float8_mm_helper(x_fp8_reshaped, w_fp8_t) + + res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1]) + return res_bits + + @staticmethod + def backward(ctx, go_fp8: torch.Tensor): + if ctx.recompute_float8_weight: + x_fp8, original_weight, weight_scale, weight_amax_buffer = ctx.saved_tensors + w_fp8 = Float8Tensor.to_float8( + original_weight, + weight_scale, + torch.float8_e4m3fn, + weight_amax_buffer, + emulate=emulate, + ) + else: + x_fp8, w_fp8 = ctx.saved_tensors + + emulate = ctx.emulate + + go_fp8_orig_shape = go_fp8._data.shape + go_fp8_reshaped = Float8Tensor( + go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]), + go_fp8._scale, + go_fp8._orig_dtype, + emulate=emulate, + ) + + w_fp8_t_c_t = Float8Tensor( + w_fp8._data.t().contiguous().t(), + w_fp8._scale, + w_fp8._orig_dtype, + emulate=emulate, + ) + + # calculate dL/dX + dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t) + dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1]) + + x_fp8_orig_shape = x_fp8._data.shape + x_fp8_reshaped_t_c = Float8Tensor( + x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(), + x_fp8._scale, + x_fp8._orig_dtype, + emulate=emulate, + ) + + go_fp8_reshaped_t_c_t = Float8Tensor( + go_fp8_reshaped._data.t().contiguous().t(), + go_fp8_reshaped._scale, + go_fp8_reshaped._orig_dtype, + emulate=emulate, + ) + + # calculate dL/dW + dL_dW = float8_mm_helper(x_fp8_reshaped_t_c, go_fp8_reshaped_t_c_t) + dL_dW = dL_dW.t() + + empty_grads = None, None, None, None, None, None, None, None, None + return dL_dX, dL_dW, *empty_grads From 67dca363dce9a562f06cc4ee202d6134f25e0af4 Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 12 Jan 2024 18:50:16 -0800 Subject: [PATCH 02/12] gotta figure out a way to call shape that is compile friendly --- float8_experimental/float8_ops.py | 55 ++++++------------------------- 1 file changed, 10 insertions(+), 45 deletions(-) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 29b181e6..71acf613 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -201,21 +201,13 @@ def forward( ctx.recompute_float8_weight = recompute_float8_weight ctx.emulate = emulate - orig_shape = x_fp8._data.shape - x_fp8_reshaped = Float8Tensor( - x_fp8._data.reshape(-1, orig_shape[-1]), - x_fp8._scale, - x_fp8._orig_dtype, - emulate=emulate, - ) + x_fp8_reshaped = x_fp8.reshape(-1, x_fp8.size(-1)) - w_fp8_t = 
Float8Tensor( - w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype, emulate=emulate - ) + w_fp8_t = w_fp8.t() res_bits = float8_mm_helper(x_fp8_reshaped, w_fp8_t) - res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1]) + res_bits = res_bits.reshape(*x_fp8.shape[:-1], res_bits.size(-1)) return res_bits @staticmethod @@ -227,48 +219,21 @@ def backward(ctx, go_fp8: torch.Tensor): weight_scale, torch.float8_e4m3fn, weight_amax_buffer, - emulate=emulate, + emulate=ctx.emulate, ) else: x_fp8, w_fp8 = ctx.saved_tensors - emulate = ctx.emulate - - go_fp8_orig_shape = go_fp8._data.shape - go_fp8_reshaped = Float8Tensor( - go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]), - go_fp8._scale, - go_fp8._orig_dtype, - emulate=emulate, - ) - - w_fp8_t_c_t = Float8Tensor( - w_fp8._data.t().contiguous().t(), - w_fp8._scale, - w_fp8._orig_dtype, - emulate=emulate, - ) - # calculate dL/dX + go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1)) + w_fp8_t_c_t = w_fp8.t().contiguous().t() dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t) - dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1]) - - x_fp8_orig_shape = x_fp8._data.shape - x_fp8_reshaped_t_c = Float8Tensor( - x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(), - x_fp8._scale, - x_fp8._orig_dtype, - emulate=emulate, - ) - - go_fp8_reshaped_t_c_t = Float8Tensor( - go_fp8_reshaped._data.t().contiguous().t(), - go_fp8_reshaped._scale, - go_fp8_reshaped._orig_dtype, - emulate=emulate, - ) + dL_dX = dL_dX.reshape(*go_fp8.shape[:-1], dL_dX.size(-1)) # calculate dL/dW + x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8.size(-1)).t().contiguous() + go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t() + dL_dW = float8_mm_helper(x_fp8_reshaped_t_c, go_fp8_reshaped_t_c_t) dL_dW = dL_dW.t() From 27a6a7f0a7e04c00e8cd29eb3ae1637b7fe357be Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 17 Jan 2024 10:36:21 -0800 Subject: [PATCH 03/12] fix multiple calls to save_for_backward --- float8_experimental/float8_ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 71acf613..467b33ec 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -181,7 +181,6 @@ def forward( emulate: bool, recompute_float8_weight: bool, ): - ctx.save_for_backward(x_fp8) w_fp8 = Float8Tensor.to_float8( original_weight, weight_scale, @@ -191,7 +190,7 @@ def forward( ) if recompute_float8_weight: # This should be set to True when using traditional fsdp to avoid saving - # saving the unsharded weight for + # saving the unsharded weight for backwards ctx.save_for_backward( x_fp8, original_weight, weight_scale, weight_amax_buffer ) From 682c2e8961625fcb384d94d26dca346f3fecdb62 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 17 Jan 2024 11:48:13 -0800 Subject: [PATCH 04/12] add to delayed linear --- float8_experimental/float8_dynamic_linear.py | 16 +++++++-- float8_experimental/float8_linear.py | 35 ++++++++++++++++++-- float8_experimental/float8_ops.py | 26 ++++++++++++++- 3 files changed, 71 insertions(+), 6 deletions(-) diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index dd6067ff..eb23bcc1 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -53,8 +53,13 @@ def forward(self, x): # y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias) weight_scale = tensor_to_scale(self.weight, torch.float8_e4m3fn) - y = 
float8_linear.apply( - x_fp8, self.weight, weight_scale, None, self.emulate, False + y = float8_linear( + x_fp8, + self.weight, + weight_scale, + None, + self.emulate, + self.recompute_weight_cast, ) # Cast gradY to float8_e5m2 during backward y = self.cast_to_float8e5m2_bw(y) @@ -72,17 +77,22 @@ def cast_to_float8e5m2_bw(self, gradY): return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate) @classmethod - def from_float(cls, mod, emulate: bool = False): + def from_float( + cls, mod, emulate: bool = False, recompute_weight_cast: bool = False + ): """ Create an nn.Linear with fp8 compute from a regular nn.Linear Args: mod (torch.nn.Linear): nn.Linear to convert emulate (bool): whether to emulate fp8 matmul logic in float32 + recompute_weight_cast (bool): whether to recompute the weight cast on every + backwards pass """ with torch.device("meta"): new_mod = cls(mod.in_features, mod.out_features, bias=False) new_mod.weight = mod.weight new_mod.bias = mod.bias new_mod.emulate = emulate + new_mod.recompute_weight_cast = recompute_weight_cast return new_mod diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 59285852..5171086e 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -19,6 +19,7 @@ import float8_experimental.config as config import torch +from float8_experimental.float8_ops import float8_linear from float8_experimental.float8_tensor import Float8Tensor @@ -172,6 +173,13 @@ def __init__(self, *args, **kwargs): # and torch.compile, this option can disable them self.enable_pre_and_post_forward = config.enable_pre_and_post_forward + # This flag is used to modify what gets saved for backwards. Its default value + # is False, this saves the casted weight for backwards. Note that this typically increases memory usage + # Because both the weight parameter and the casted weight are saved on device. If set to true + # this will only save the weight parameter and during the backwards pass it will re-cast this weight to fp8. + # For traditional FSDP this should be set to True in order to not save the un-sharded weight for backwards. 
+ self.recompute_weight_cast = False + def register_always_float32_buffer( self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True ) -> None: @@ -214,6 +222,20 @@ def cast_x_to_float8( ) return x_fp8 + def _maybe_init_amaxes_scales_weight( + self, w: torch.Tensor, is_amax_initialized: bool + ): + scale_fn_name = self.recipe.scale_fn_name + _maybe_initialize_amaxes_scales_for_float8_cast( + w, + self.fp8_amax_w, + self.fp8_amax_history_w, + self.fp8_scale_w, + scale_fn_name, + torch.float8_e4m3fn, + is_amax_initialized, + ) + def cast_w_to_float8( self, w: torch.Tensor, is_amax_initialized: bool ) -> torch.Tensor: @@ -284,9 +306,18 @@ def forward(self, x): self.float8_pre_forward(x) x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized) - w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized) + # w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized) + self._maybe_init_amaxes_scales_weight(self.weight, self.is_amax_initialized) - y = torch.matmul(x_fp8, w_fp8.t()) + y = float8_linear( + x_fp8, + self.weight, + self.fp8_scale_w, + self.fp8_amax_w, + self.emulate, + self.recompute_weight_cast, + ) + # y = torch.matmul(x_fp8, w_fp8.t()) # Cast gradY to float8_e5m2 during backward y = self.cast_y_to_float8_in_bw(y, self.emulate) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 467b33ec..f445f1e3 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -164,7 +164,7 @@ def autocast_to_copy(aten_op, args, kwargs=None): ) -class float8_linear(torch.autograd.Function): +class _float8_linear(torch.autograd.Function): """Custom autograd function for computing torch.nn.Linear on Float8Tensor. This is needed for a couple reasons, we want to have fine grained control over the @@ -238,3 +238,27 @@ def backward(ctx, go_fp8: torch.Tensor): empty_grads = None, None, None, None, None, None, None, None, None return dL_dX, dL_dW, *empty_grads + + +# Need to allow_in_graph because: +# (1) the forward returns a plain tensor +# (2) the backward accepts a Float8Tensor subclass +# dynamo has no good way to be told what the type of +# the grad_out is today, so it (incorrectly) assumes it is also a plain tensor. 
+@torch._dynamo.allow_in_graph +def float8_linear( + x_fp8: torch.Tensor, + original_weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_amax_buffer: Optional[torch.Tensor], + emulate: bool, + recompute_float8_weight: bool, +): + return _float8_linear.apply( + x_fp8, + original_weight, + weight_scale, + weight_amax_buffer, + emulate, + recompute_float8_weight, + ) From 2ad61841bcefeeced8db319161f8c81c3539e0c9 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 17 Jan 2024 14:18:51 -0800 Subject: [PATCH 05/12] add recompute option --- float8_experimental/float8_linear.py | 6 +++++- float8_experimental/float8_linear_utils.py | 18 +++++++++++------- float8_experimental/float8_ops.py | 1 - 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 5171086e..8a35ddc5 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -329,13 +329,16 @@ def forward(self, x): return y @classmethod - def from_float(cls, mod, emulate: bool = False): + def from_float( + cls, mod, emulate: bool = False, recompute_weight_cast: bool = False + ): """ Create an nn.Linear with fp8 compute from a regular nn.Linear Args: mod (torch.nn.Linear): nn.Linear to convert emulate (bool): whether to emulate fp8 matmul logic in float32 + recompute_weight_cast (bool): whether to recompute the casted weight for backwards """ # TODO Follow up! This is a great idea but we need the mixin base to create real # Tensors and the Linear base to create empty params @@ -344,6 +347,7 @@ def from_float(cls, mod, emulate: bool = False): new_mod.weight = mod.weight new_mod.bias = mod.bias new_mod.emulate = emulate + new_mod.recompute_weight_cast = recompute_weight_cast # I think its okay to send all params and buffers to device new_mod.to(mod.weight.device) return new_mod diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index c6152e8d..0f3eebe3 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import copy from enum import auto, Enum +from typing import List, Optional, Union import torch import torch.distributed as dist @@ -59,11 +60,12 @@ def _update_history_with_new_amax(new_amax, amax_history): def swap_linear_with_float8_linear( - model, - module, - emulate=False, - skip_fqn_list=None, - cur_fqn="", + model: torch.nn.Module, + module: Union[Float8Linear, Float8DynamicLinear], + emulate: bool = False, + skip_fqn_list: Optional[List[str]] = None, + cur_fqn: str = "", + recompute_weight_cast: bool = False, ): """ Replaces all instances of torch.nn.Linear in the given model with module. @@ -74,17 +76,19 @@ def swap_linear_with_float8_linear( emulate (bool, optional): Whether to emulate the fp8 matmul logic in float32. skip_fqn_list (List[str], optional): If specified, a list of FQNs to skip cur_fqn (str, optional): Current fqn, used to implement skip_fqn_list + recompute_weight_cast (bool, optional): Whether to recompute the weight cast in the backwards pass. 
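+
+    Example (illustrative; mirrors the call sites in test_fsdp.py later in this series):
+        swap_linear_with_float8_linear(model, Float8Linear, emulate=False,
+                                       recompute_weight_cast=True)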
""" + args = (module, emulate, skip_fqn_list, cur_fqn, recompute_weight_cast) name_to_child = dict(model.named_children()) for name, child in name_to_child.items(): new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}" if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and isinstance( child, torch.nn.Linear ): - new_child = module.from_float(child, emulate) + new_child = module.from_float(child, emulate, recompute_weight_cast) setattr(model, name, new_child) else: - swap_linear_with_float8_linear(child, module, emulate) + swap_linear_with_float8_linear(child, *args) def get_float8_layers(model: torch.nn.Module, fp8_classes=None): diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index f445f1e3..e244724f 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -205,7 +205,6 @@ def forward( w_fp8_t = w_fp8.t() res_bits = float8_mm_helper(x_fp8_reshaped, w_fp8_t) - res_bits = res_bits.reshape(*x_fp8.shape[:-1], res_bits.size(-1)) return res_bits From 177173a6f60fde16c805a6052376dc2d3c47cef0 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 17 Jan 2024 15:05:33 -0800 Subject: [PATCH 06/12] parmetrize compile and base tests --- float8_experimental/float8_linear_utils.py | 10 ++++- test/test_base.py | 36 ++++++++++++++---- test/test_compile.py | 44 ++++++++++++++++++---- 3 files changed, 73 insertions(+), 17 deletions(-) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 0f3eebe3..40ee292f 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -24,13 +24,17 @@ class LinearType(Enum): def get_float8_linear( - linear_type: LinearType, linear_ref: torch.nn.Linear, emulate: bool = False + linear_type: LinearType, + linear_ref: torch.nn.Linear, + emulate: bool = False, + recompute_weight_cast: bool = False, ): """Returns a Float8Linear module of the given type, initialized from linear_ref. Args: linear_type: The type of Float8Linear to return. linear_ref: The linear module to initialize from. emulate: Whether to emulate the fp8 matmul logic in float32. + recompute_weight_cast: Whether to recompute the weight cast in the backwards pass. 
""" LINEAR_TYPE_MAP = { LinearType.DELAYED: Float8Linear, @@ -40,7 +44,9 @@ def get_float8_linear( raise ValueError(f"linear_type must be one of {LINEAR_TYPE_MAP.keys()}") return LINEAR_TYPE_MAP[linear_type].from_float( - copy.deepcopy(linear_ref), emulate=emulate + copy.deepcopy(linear_ref), + emulate=emulate, + recompute_weight_cast=recompute_weight_cast, ) diff --git a/test/test_base.py b/test/test_base.py index ba1f6662..b3875f5e 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -50,8 +50,15 @@ def test_preserves_dtype(self) -> None: class TestFloat8Linear: - def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool): - m_fp8 = get_float8_linear(linear_type, m_ref, emulate) + def _test_linear_impl( + self, + x, + m_ref, + linear_type: LinearType, + emulate: bool, + recompute_weight_cast: bool, + ): + m_fp8 = get_float8_linear(linear_type, m_ref, emulate, recompute_weight_cast) for _ in range(2): if linear_requires_sync(linear_type): sync_float8_amax_and_scale_history(m_fp8) @@ -112,7 +119,14 @@ def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool): @pytest.mark.parametrize("emulate", [True, False]) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) - def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool): + @pytest.mark.parametrize("recompute_weight_cast", [True, False]) + def test_linear_nobias( + self, + x_shape, + linear_type: LinearType, + emulate: bool, + recompute_weight_cast: bool, + ): if not emulate: if not torch.cuda.is_available(): warnings.warn("CUDA not available") @@ -125,7 +139,7 @@ def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool): x = torch.randn(*x_shape, device="cuda") m_ref = nn.Linear(16, 32, bias=False, device="cuda") - self._test_linear_impl(x, m_ref, linear_type, emulate) + self._test_linear_impl(x, m_ref, linear_type, emulate, recompute_weight_cast) @pytest.mark.parametrize("emulate", [True, False]) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @@ -133,8 +147,14 @@ def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool): @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) + @pytest.mark.parametrize("recompute_weight_cast", [True, False]) def test_linear_bias( - self, x_shape, linear_type: LinearType, emulate: bool, linear_dtype: torch.dtype + self, + x_shape, + linear_type: LinearType, + emulate: bool, + linear_dtype: torch.dtype, + recompute_weight_cast: bool, ): if not emulate: if not torch.cuda.is_available(): @@ -148,10 +168,10 @@ def test_linear_bias( x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - self._test_linear_impl(x, m_ref, linear_type, emulate) + self._test_linear_impl(x, m_ref, linear_type, emulate, recompute_weight_cast) m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) - m = Float8Linear.from_float(m, emulate) + m = Float8Linear.from_float(m, emulate, recompute_weight_cast) # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) @@ -184,7 +204,7 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype): x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - self._test_linear_impl(x, m_ref, linear_type, emulate) + self._test_linear_impl(x, 
m_ref, linear_type, emulate, False) m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) m = Float8Linear.from_float(m, emulate) diff --git a/test/test_compile.py b/test/test_compile.py index 9b88811a..f345c4b6 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -22,6 +22,7 @@ def _test_compile_base( emulate: bool, linear_type: LinearType, dtype: torch.dtype, + recompute_weight_cast: bool, ): random.seed(0) torch.manual_seed(0) @@ -31,7 +32,9 @@ def _test_compile_base( x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - m_fp8 = get_float8_linear(linear_type, m_ref, emulate=emulate) + m_fp8 = get_float8_linear( + linear_type, m_ref, emulate=emulate, recompute_weight_cast=recompute_weight_cast + ) m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=fullgraph) m_ref = torch.compile(m_ref, backend=backend, fullgraph=fullgraph) @@ -50,30 +53,57 @@ def _test_compile_base( @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("recompute_weight_cast", [False, True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") -def test_eager_only(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype): +def test_eager_only( + fullgraph, + emulate: bool, + linear_type: bool, + dtype: torch.dtype, + recompute_weight_cast: bool, +): torch._dynamo.reset() - _test_compile_base("eager", fullgraph, emulate, linear_type, dtype) + _test_compile_base( + "eager", fullgraph, emulate, linear_type, dtype, recompute_weight_cast + ) @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("recompute_weight_cast", [False, True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") -def test_aot_eager(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype): +def test_aot_eager( + fullgraph, + emulate: bool, + linear_type: bool, + dtype: torch.dtype, + recompute_weight_cast: bool, +): torch._dynamo.reset() - _test_compile_base("aot_eager", fullgraph, emulate, linear_type, dtype) + _test_compile_base( + "aot_eager", fullgraph, emulate, linear_type, dtype, recompute_weight_cast + ) @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False]) @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) +@pytest.mark.parametrize("recompute_weight_cast", [False, True]) @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) -def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype): +def test_inductor( + fullgraph, + emulate: bool, + linear_type: bool, + dtype: torch.dtype, + recompute_weight_cast: bool, +): torch._dynamo.reset() - _test_compile_base("inductor", fullgraph, emulate, linear_type, dtype) + _test_compile_base( + "inductor", fullgraph, emulate, linear_type, dtype, recompute_weight_cast + ) if __name__ == "__main__": From 2ffcbe99cdc217cab0df4f92b4870259af9d9a37 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 17 Jan 2024 
16:25:01 -0800 Subject: [PATCH 07/12] add test to fsdp --- float8_experimental/float8_ops.py | 2 +- test/test_fsdp.py | 55 +++++++++++++++++++++++++------ test/test_fsdp.sh | 35 +++++++++++++++----- test/test_fsdp_compile.py | 27 ++++++++++----- 4 files changed, 91 insertions(+), 28 deletions(-) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index e244724f..1d7aaa49 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -189,7 +189,7 @@ def forward( emulate=emulate, ) if recompute_float8_weight: - # This should be set to True when using traditional fsdp to avoid saving + # This should be set to True when using traditional fsdp to avoid # saving the unsharded weight for backwards ctx.save_for_backward( x_fp8, original_weight, weight_scale, weight_amax_buffer diff --git a/test/test_fsdp.py b/test/test_fsdp.py index 864412d8..8aea77c5 100644 --- a/test/test_fsdp.py +++ b/test/test_fsdp.py @@ -61,7 +61,9 @@ def cleanup(): dist.destroy_process_group() -def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): +def get_model( + K, N, is_fp8, emulate, base_dtype=torch.float32, recompute_weight_cast: bool = False +): m = nn.Sequential( nn.Linear(K, N, dtype=base_dtype), nn.ReLU(), @@ -69,7 +71,12 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): nn.ReLU(), ) if is_fp8: - swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate) + swap_linear_with_float8_linear( + m, + Float8Linear, + emulate=emulate, + recompute_weight_cast=recompute_weight_cast, + ) return m @@ -81,10 +88,15 @@ def fsdp_main(rank, world_size, args): # TODO: We set fullgraph as an option. However, it currently doesn't work for fullgraph compile. # We can investigate and fix it later. - is_fp8, emulate, base_dtype, compile, fullgraph = args - model = get_model(K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype).to( - rank - ) + is_fp8, emulate, base_dtype, compile, fullgraph, recompute_weight_cast = args + model = get_model( + K, + N, + is_fp8=is_fp8, + emulate=emulate, + base_dtype=base_dtype, + recompute_weight_cast=recompute_weight_cast, + ).to(rank) model.load_state_dict(torch.load(sd_in_fname)) # To compile FSDP, we need use_orig_params to True model = FSDP(model, use_orig_params=True) @@ -148,7 +160,13 @@ def forward_backward(model): cleanup() -def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = False): +def run( + mode: str, + is_fp8: bool, + compile_fsdp: bool = False, + fullgraph: bool = False, + recompute_weight_cast: bool = False, +): print(f"Mode: {mode}".center(100, "-")) base_dtype = torch.bfloat16 if not os.path.exists(data_dir): @@ -169,7 +187,12 @@ def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = F # generate reference input ref_input = torch.randn(B, M, K).cuda().to(base_dtype) model = get_model( - K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype + K, + N, + is_fp8=is_fp8, + emulate=emulate, + base_dtype=base_dtype, + recompute_weight_cast=recompute_weight_cast, ).cuda() torch.save(ref_input, input_fname) torch.save(model.state_dict(), sd_in_fname) @@ -177,7 +200,12 @@ def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = F elif mode == "single_gpu": ref_input = torch.load(input_fname).to(base_dtype) model = get_model( - K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype + K, + N, + is_fp8=is_fp8, + emulate=emulate, + base_dtype=base_dtype, + recompute_weight_cast=recompute_weight_cast, ).cuda() 
model.load_state_dict(torch.load(sd_in_fname)) optimizer = torch.optim.SGD(model.parameters(), lr=lr) @@ -199,7 +227,14 @@ def forward_backward(): elif mode == "fsdp": WORLD_SIZE = torch.cuda.device_count() # We only compile for fsdp, and compare the numerics with signle-gpu no-compile - args = (is_fp8, emulate, base_dtype, compile_fsdp, fullgraph) + args = ( + is_fp8, + emulate, + base_dtype, + compile_fsdp, + fullgraph, + recompute_weight_cast, + ) mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) elif mode == "analyze": diff --git a/test/test_fsdp.sh b/test/test_fsdp.sh index 624b3969..44770f29 100755 --- a/test/test_fsdp.sh +++ b/test/test_fsdp.sh @@ -4,14 +4,18 @@ set -e launch() { - echo "launching IS_FP8 $IS_FP8, compile_fsdp $COMPILE, fullgraph $FULLGRAPH" + echo "Launching test with the following configuration:" + echo "IS_FP8: $IS_FP8" + echo "compile_fsdp: $COMPILE" + echo "fullgraph: $FULLGRAPH" + echo "recompute_weight_cast: $RECOMPUTE" # generate the test data - python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH + python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE echo "Success: ✅" # generate single GPU model output and updated state dict - python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH + python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE echo "Success: ✅" # generate FSDP model output and updated state dict @@ -20,19 +24,32 @@ launch() { # the NCCL_NET setting is to work around transient issues on a # specific host (`devgpu001.nha2`) NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 NCCL_NET=SOCKET python test/test_fsdp.py \ - --mode fsdp --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH + --mode fsdp --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE # compare the outputs and state dicts and verify equivalence - python test/test_fsdp.py --mode analyze --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH + python test/test_fsdp.py --mode analyze --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE echo "Success: ✅" echo "✅ All Tests Passed ✅" } -# IS_FP8, COMPILE, FULLGRAPH -for i in False,False,False True,False,False True,True,False +# Loop over different combinations of settings +for i in False,False,False,False \ + True,False,False,False \ + True,True,False,False \ + True,False,False,True \ + True,True,False,True do - IFS=","; set -- $i; - IS_FP8=$1; COMPILE=$2; FULLGRAPH=$3 + # Split the string into variables + IFS="," + set -- $i + + # Assign each variable to a more descriptive name + IS_FP8=$1 + COMPILE=$2 + FULLGRAPH=$3 + RECOMPUTE=$4 + + # Launch the test with the current settings launch done diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py index 9dc77434..384b1ee0 100644 --- a/test/test_fsdp_compile.py +++ b/test/test_fsdp_compile.py @@ -48,12 +48,16 @@ def cleanup(): dist.destroy_process_group() -def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): +def get_model( + K, N, is_fp8, emulate, base_dtype=torch.float32, recompute_weight_cast: bool = False +): m = nn.Sequential( nn.Linear(K, N, dtype=base_dtype), nn.ReLU(), ) - swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate) + 
swap_linear_with_float8_linear( + m, Float8Linear, emulate=emulate, recompute_weight_cast=recompute_weight_cast + ) return m @@ -63,7 +67,7 @@ def fsdp_main(rank, world_size, args): setup(rank, world_size) torch.cuda.set_device(rank) - (emulate,) = args + (emulate, recompute_weight_cast) = args # composability of torch.compile + FSDP + autocast + Float8Linear # as fo 2023-12-30 @@ -81,9 +85,14 @@ def fsdp_main(rank, world_size, args): # things work e2e. Note that FSDP does not support full-graph compile # regardless of float8. - model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=torch.bfloat16).to( - rank - ) + model = get_model( + K, + N, + is_fp8=True, + emulate=emulate, + base_dtype=torch.bfloat16, + recompute_weight_cast=recompute_weight_cast, + ).to(rank) # To compile FSDP, we need use_orig_params to True model = FSDP(model, use_orig_params=True) @@ -102,7 +111,8 @@ def fsdp_main(rank, world_size, args): sync_float8_func(model) optimizer.step() - print("done!") + if rank == 0: + print("Success: ✅") cleanup() @@ -119,7 +129,8 @@ def run(): emulate = True WORLD_SIZE = torch.cuda.device_count() - args = (emulate,) + recompute_weight_cast = True + args = (emulate, recompute_weight_cast) mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) From 48f21f6b9bef8b356fc9a47cbc0259c3585bb6e0 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 17 Jan 2024 17:38:15 -0800 Subject: [PATCH 08/12] update linear bench --- benchmarks/bench_linear_float8.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/benchmarks/bench_linear_float8.py b/benchmarks/bench_linear_float8.py index 4736cd29..265fa359 100644 --- a/benchmarks/bench_linear_float8.py +++ b/benchmarks/bench_linear_float8.py @@ -56,6 +56,7 @@ class Experiment: dtype: torch.dtype compiled: bool = False float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn + recompute_weight_cast: bool = False # 3 Times since we are calculating forward backward @property @@ -95,9 +96,14 @@ def main( } input_bias = False ref_dtypes = [torch.bfloat16, torch.float16] + recompute_weight_casts = [True, False] experiment_list: List[Experiment] = [] - for idx, (dtype, (name, (K, N))) in enumerate( - tqdm(list(product(ref_dtypes, name_to_shapes_70b.items()))) + for idx, (dtype, (name, (K, N)), recompute_weight_cast) in enumerate( + tqdm( + list( + product(ref_dtypes, name_to_shapes_70b.items(), recompute_weight_casts) + ) + ) ): if n_limit is not None and idx >= n_limit: break @@ -106,7 +112,9 @@ def main( ) linear_float8 = Float8Linear.from_float( - copy.deepcopy(linear_ref), emulate=False + copy.deepcopy(linear_ref), + emulate=False, + recompute_weight_cast=recompute_weight_cast, ) bsz, seq_len = 4, 4096 @@ -155,6 +163,7 @@ def wrapper(*args, **kwargs): float8_time, dtype, compile, + recompute_weight_cast=recompute_weight_cast, ) print(experiment) print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec) @@ -169,6 +178,7 @@ def wrapper(*args, **kwargs): "ref_dtype", "compiled", "fp8_dtype", + "recompute_weight_cast", "ref_time_sec", "pt_fp8_time_sec", "ref_tops_sec", @@ -187,6 +197,7 @@ def wrapper(*args, **kwargs): experiment.dtype, experiment.compiled, experiment.float_8_dtype, + experiment.recompute_weight_cast, experiment.ref_time_sec, experiment.float8_time_sec, experiment.ref_tops_sec, @@ -214,6 +225,7 @@ def wrapper(*args, **kwargs): "shape", "ref_dtype", "compiled", + "recompute_weight_cast", "ref_time_sec", "pt_fp8_time_sec", "pt_fp8_speedup", From 
20da1c090183709603bfcc1efbbdf25f05ff4a63 Mon Sep 17 00:00:00 2001 From: drisspg Date: Thu, 18 Jan 2024 19:25:42 -0800 Subject: [PATCH 09/12] performance boooooooooooost --- float8_experimental/float8_ops.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 1d7aaa49..8093c271 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -223,17 +223,19 @@ def backward(ctx, go_fp8: torch.Tensor): x_fp8, w_fp8 = ctx.saved_tensors # calculate dL/dX - go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1)) + go_fp8_reshaped = go_fp8.view(-1, go_fp8.size(-1)) w_fp8_t_c_t = w_fp8.t().contiguous().t() dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t) - dL_dX = dL_dX.reshape(*go_fp8.shape[:-1], dL_dX.size(-1)) + dL_dX = dL_dX.view(*go_fp8.shape[:-1], dL_dX.size(-1)) # calculate dL/dW - x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8.size(-1)).t().contiguous() + x_fp8_reshaped_t_c = x_fp8.view(-1, x_fp8.size(-1)).t().contiguous() go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t() dL_dW = float8_mm_helper(x_fp8_reshaped_t_c, go_fp8_reshaped_t_c_t) - dL_dW = dL_dW.t() + # The contiguous call is not needed for correctness, but allows for a faster backward + # pass in conjunction with compile for both single-gpu and fsdp. + dL_dW = dL_dW.t().contiguous() empty_grads = None, None, None, None, None, None, None, None, None return dL_dX, dL_dW, *empty_grads From d03d16b68aff005b70e3351e8dadb50c896b4bfd Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 19 Jan 2024 10:40:46 -0800 Subject: [PATCH 10/12] do less in the backwards --- float8_experimental/float8_ops.py | 18 ++++++------------ float8_experimental/float8_tensor.py | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 8093c271..d9754ca9 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -8,7 +8,7 @@ import torch from float8_experimental.float8_python_api import addmm_float8_unwrapped -from float8_experimental.float8_tensor import Float8Tensor +from float8_experimental.float8_tensor import Float8Tensor, re_construct_float8_weight from float8_experimental.float8_utils import is_row_major from torch.utils._pytree import tree_map @@ -191,9 +191,7 @@ def forward( if recompute_float8_weight: # This should be set to True when using traditional fsdp to avoid # saving the unsharded weight for backwards - ctx.save_for_backward( - x_fp8, original_weight, weight_scale, weight_amax_buffer - ) + ctx.save_for_backward(x_fp8, original_weight, weight_scale) else: # Does this interact properly with activation checkpointing? 
ctx.save_for_backward(x_fp8, w_fp8) @@ -211,19 +209,15 @@ def forward( @staticmethod def backward(ctx, go_fp8: torch.Tensor): if ctx.recompute_float8_weight: - x_fp8, original_weight, weight_scale, weight_amax_buffer = ctx.saved_tensors - w_fp8 = Float8Tensor.to_float8( - original_weight, - weight_scale, - torch.float8_e4m3fn, - weight_amax_buffer, - emulate=ctx.emulate, + x_fp8, original_weight, weight_scale = ctx.saved_tensors + w_fp8 = re_construct_float8_weight( + original_weight, weight_scale, torch.float8_e4m3fn, emulate=ctx.emulate ) else: x_fp8, w_fp8 = ctx.saved_tensors # calculate dL/dX - go_fp8_reshaped = go_fp8.view(-1, go_fp8.size(-1)) + go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1)) w_fp8_t_c_t = w_fp8.t().contiguous().t() dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t) dL_dX = dL_dX.view(*go_fp8.shape[:-1], dL_dX.size(-1)) diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 4450fce8..ff879814 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -42,6 +42,26 @@ def backward(ctx, g): return g, None, None, None, None +@torch._dynamo.allow_in_graph +def re_construct_float8_weight( + tensor: torch.Tensor, scale: torch.Tensor, float8_dtype, emulate: bool = False +): + """In the backwards of float8_linear we don't need to fill the amax buffer + for the weight tensor since that was done during the forward and we just need to + recast the orignal precision tensor using the scale from the forward + + Args: + tensor: the tensor to convert + scale: the scale to use to convert the tensor, from the forward + float8_dtype: the float8 dtype to use + emulate: if true using fp32 emulation for the matmuls, helpful + if you don't have access to h100 hardware. 
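+    Returns:
+        A Float8Tensor wrapping the saturated float8 data, reusing the scale
+        computed in the forward pass (no amax buffer is filled here).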
+ """ + tensor_scaled = tensor * scale + bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype) + return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate) + + @torch._dynamo.allow_in_graph class FromFloat8ConstrFunc(torch.autograd.Function): """ From d2da1ad28fa8d5d8a8fd3e8fe8d079fc25052e6f Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 19 Jan 2024 14:03:38 -0800 Subject: [PATCH 11/12] update profile --- benchmarks/profile_linear_float8.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py index b62020b1..4b4e92e8 100644 --- a/benchmarks/profile_linear_float8.py +++ b/benchmarks/profile_linear_float8.py @@ -87,7 +87,9 @@ class LinearParams: torch_compile: Optional[bool] = False -def main(profile_path: Path, compile: bool, linear_type: str): +def main( + profile_path: Path, compile: bool, linear_type: str, recompute_weight_cast: bool +): profile_path = Path(profile_path) assert profile_path.is_dir(), f"Path {profile_path} must be a directory" params = LinearParams( @@ -110,7 +112,9 @@ def main(profile_path: Path, compile: bool, linear_type: str): dtype=params.ref_dtype, ) linear_type = LinearType[linear_type.upper()] - linear_float8 = get_float8_linear(linear_type, linear_ref) + linear_float8 = get_float8_linear( + linear_type, linear_ref, recompute_weight_cast=recompute_weight_cast + ) input_tensor = torch.randn( params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True From 941d2f3ae3ac81ed5f4937067087e7474dc15c90 Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 19 Jan 2024 23:24:17 -0800 Subject: [PATCH 12/12] use addmm for bias (hopefully), fix autocasting --- float8_experimental/float8_dynamic_linear.py | 28 ++++++--- float8_experimental/float8_linear.py | 27 +++++---- float8_experimental/float8_ops.py | 61 +++++++++----------- float8_experimental/float8_utils.py | 30 +++++++++- test/test_base.py | 19 +++--- 5 files changed, 97 insertions(+), 68 deletions(-) diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index eb23bcc1..de26b576 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -11,7 +11,11 @@ from float8_experimental.float8_ops import float8_linear from float8_experimental.float8_tensor import Float8Tensor -from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated +from float8_experimental.float8_utils import ( + get_maybe_autocast_inputs, + tensor_to_scale, + to_fp8_saturated, +) @torch._dynamo.allow_in_graph @@ -48,14 +52,16 @@ class Float8DynamicLinear(torch.nn.Linear): """ def forward(self, x): - x_fp8 = self.cast_to_float8(x) - # w_fp8 = self.cast_to_float8(self.weight) - - # y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias) - weight_scale = tensor_to_scale(self.weight, torch.float8_e4m3fn) + # Tried to do this with @custom_fwd/bwd but it didn't work + temp_x, temp_weight, temp_bias = get_maybe_autocast_inputs( + x, self.weight, self.bias + ) + x_fp8 = self.cast_to_float8(temp_x) + weight_scale = tensor_to_scale(temp_weight, torch.float8_e4m3fn) y = float8_linear( x_fp8, - self.weight, + temp_weight, + None, # bias weight_scale, None, self.emulate, @@ -63,7 +69,13 @@ def forward(self, x): ) # Cast gradY to float8_e5m2 during backward y = self.cast_to_float8e5m2_bw(y) - y = y + self.bias if self.bias is not None else y + + # TODO We should use addmm above but this fails the single fsdp test: + # FAILED: 
_orig_mod.0.fp8_amax_w, 0.2197265625, 0.21875 + # Not immediately clear why the bias being fused in would only effect the numerics + # for the weight.... + if temp_bias is not None: + y = y + temp_bias return y diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 8a35ddc5..44c86b36 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -27,6 +27,7 @@ amax_history_to_scale, E4M3_MAX_POS, E5M2_MAX_POS, + get_maybe_autocast_inputs, tensor_to_amax, to_fp8_saturated, ) @@ -199,14 +200,6 @@ def convert_amax_buffer_to_float32(self): def cast_x_to_float8( self, x: torch.Tensor, is_amax_initialized: bool ) -> torch.Tensor: - # Duplicate the autocast logic for F.linear, so that the output - # of our module has the right original precision - if torch.is_autocast_enabled(): - # For now, hardcode to GPU's autocast dtype - # if we need CPU support in the future, we can add it - autocast_dtype = torch.get_autocast_gpu_dtype() - x = x.to(autocast_dtype) - scale_fn_name = self.recipe.scale_fn_name _maybe_initialize_amaxes_scales_for_float8_cast( x, @@ -303,27 +296,33 @@ class Float8Linear(Float8LinearMixin, torch.nn.Linear): """ def forward(self, x): + temp_x, temp_weight, temp_bias = get_maybe_autocast_inputs( + x, self.weight, self.bias + ) self.float8_pre_forward(x) - x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized) - # w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized) + x_fp8 = self.cast_x_to_float8(temp_x, self.is_amax_initialized) self._maybe_init_amaxes_scales_weight(self.weight, self.is_amax_initialized) y = float8_linear( x_fp8, - self.weight, + temp_weight, + None, # bias self.fp8_scale_w, self.fp8_amax_w, self.emulate, self.recompute_weight_cast, ) - # y = torch.matmul(x_fp8, w_fp8.t()) # Cast gradY to float8_e5m2 during backward y = self.cast_y_to_float8_in_bw(y, self.emulate) - if self.bias is not None: - y = y + self.bias.to(y.dtype) + # TODO We should use addmm above but this fails the single fsdp test: + # FAILED: _orig_mod.0.fp8_amax_w, 0.2197265625, 0.21875 + # Not immediately clear why the bias being fused in would only effect the numerics + # for the weight.... + if temp_bias is not None: + y = y + temp_bias self.float8_post_forward() return y diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index d9754ca9..72ba6afc 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -75,23 +75,33 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): return a_data, a_scale, b_data, b_scale -def float8_mm_helper(a: Float8Tensor, b: Float8Tensor) -> torch.Tensor: +def float8_addmm_helper( + a: Float8Tensor, b: Float8Tensor, bias: Optional[torch.Tensor] +) -> torch.Tensor: """This is a helper function for float8_mm Args: a: The first matrix multiplication term. b: The second matrix multiplication term. + bias: The bias term. Returns: torch.Tensor: The result of the matrix multiplication. 
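+            When ``bias`` is provided it must already be in the output dtype
+            (asserted below); in the non-emulated path it is passed through to
+            the fused addmm.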
""" a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) output_dtype = a._orig_dtype + if bias is not None: + assert ( + bias.dtype == output_dtype + ), f"bias dtype {bias.dtype} != output_dtype {output_dtype}" if a._emulate: assert a._emulate == b._emulate - return torch.ops.aten.mm_float8_emulated( + out = torch.ops.aten.mm_float8_emulated( a._data, a._scale, b._data, b._scale, output_dtype )[0] + if bias is not None: + return out + bias + return out tensor_out, amax = addmm_float8_unwrapped( - a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=None + a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=bias ) return tensor_out @@ -101,18 +111,7 @@ def float8_mm(aten_op, args, kwargs=None): assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor) a = args[0] b = args[1] - return float8_mm_helper(a, b) - a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) - output_dtype = a._orig_dtype - if a._emulate: - assert a._emulate == b._emulate - return torch.ops.aten.mm_float8_emulated( - a._data, a._scale, b._data, b._scale, output_dtype - )[0] - tensor_out, amax = addmm_float8_unwrapped( - a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=None - ) - return tensor_out + return float8_addmm_helper(a, b, None) @implements([aten.addmm.default]) @@ -125,19 +124,7 @@ def float8_addmm(aten_op, args, kwargs=None): bias = args[0] a = args[1] b = args[2] - a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) - output_dtype = a._orig_dtype - assert bias.dtype == output_dtype, "bias dtype must match output dtype" - if a._emulate: - assert a._emulate == b._emulate - out = torch.ops.aten.mm_float8_emulated( - a._data, a._scale, b._data, b._scale, output_dtype - )[0] - return out + bias - tensor_out, amax = addmm_float8_unwrapped( - a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=bias - ) - return tensor_out + return float8_addmm_helper(a, b, bias) @implements([aten.is_same_size.default]) @@ -174,8 +161,9 @@ class _float8_linear(torch.autograd.Function): @staticmethod def forward( ctx, - x_fp8: torch.Tensor, + x_fp8: Float8Tensor, original_weight: torch.Tensor, + bias: Optional[torch.Tensor], weight_scale: torch.Tensor, weight_amax_buffer: Optional[torch.Tensor], emulate: bool, @@ -202,7 +190,7 @@ def forward( w_fp8_t = w_fp8.t() - res_bits = float8_mm_helper(x_fp8_reshaped, w_fp8_t) + res_bits = float8_addmm_helper(x_fp8_reshaped, w_fp8_t, bias) res_bits = res_bits.reshape(*x_fp8.shape[:-1], res_bits.size(-1)) return res_bits @@ -219,20 +207,25 @@ def backward(ctx, go_fp8: torch.Tensor): # calculate dL/dX go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1)) w_fp8_t_c_t = w_fp8.t().contiguous().t() - dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t) + dL_dX = float8_addmm_helper(go_fp8_reshaped, w_fp8_t_c_t, None) dL_dX = dL_dX.view(*go_fp8.shape[:-1], dL_dX.size(-1)) # calculate dL/dW x_fp8_reshaped_t_c = x_fp8.view(-1, x_fp8.size(-1)).t().contiguous() go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t() - dL_dW = float8_mm_helper(x_fp8_reshaped_t_c, go_fp8_reshaped_t_c_t) + dL_dW = float8_addmm_helper(x_fp8_reshaped_t_c, go_fp8_reshaped_t_c_t, None) # The contiguous call is not needed for correctness, but allows for a faster backward # pass in conjunction with compile for both single-gpu and fsdp. 
dL_dW = dL_dW.t().contiguous() + if ctx.needs_input_grad[2]: + dL_dBias = go_fp8.to_original_precision() + else: + dL_dBias = None + empty_grads = None, None, None, None, None, None, None, None, None - return dL_dX, dL_dW, *empty_grads + return dL_dX, dL_dW, dL_dBias, *empty_grads # Need to allow_in_graph because: @@ -244,6 +237,7 @@ def backward(ctx, go_fp8: torch.Tensor): def float8_linear( x_fp8: torch.Tensor, original_weight: torch.Tensor, + bias: Optional[torch.Tensor], weight_scale: torch.Tensor, weight_amax_buffer: Optional[torch.Tensor], emulate: bool, @@ -252,6 +246,7 @@ def float8_linear( return _float8_linear.apply( x_fp8, original_weight, + bias, weight_scale, weight_amax_buffer, emulate, diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 4e65ef99..581a40fd 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -3,7 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable +from typing import Optional import torch import torch.distributed as dist @@ -94,3 +94,31 @@ def compute_error(x, y): def is_row_major(stride): assert len(stride) == 2, "is_row_major only supports 2D tensors" return stride[0] > stride[1] and stride[1] == 1 + + +def get_maybe_autocast_inputs( + x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] +): + """If autocast is enabled, cast the inputs to the autocast dtype. Otherwise, return the inputs as is. + + Why do we need this? I tried to implement the autograd function # custom_fwd and custom_bwd + decorators, from the amp docs but that didn't work. + Args: + x: The input tensor. + weight: The weight tensor. + bias: The bias tensor. 
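+
+    Returns:
+        The (possibly autocast-casted) x, weight and bias tensors; the inputs are
+        returned unchanged when autocast is disabled.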
+ + """ + if torch.is_autocast_enabled(): + autocast_dtype = torch.get_autocast_gpu_dtype() + temp_x = x.to(autocast_dtype) if x.dtype != autocast_dtype else x + temp_weight = ( + weight.to(autocast_dtype) if weight.dtype != autocast_dtype else weight + ) + if bias is not None and bias.dtype != autocast_dtype: + temp_bias = bias.to(autocast_dtype) + else: + temp_bias = bias + return temp_x, temp_weight, temp_bias + else: + return x, weight, bias diff --git a/test/test_base.py b/test/test_base.py index b3875f5e..b965fdb2 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -170,18 +170,20 @@ def test_linear_bias( m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) self._test_linear_impl(x, m_ref, linear_type, emulate, recompute_weight_cast) - m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) - m = Float8Linear.from_float(m, emulate, recompute_weight_cast) + m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) + m = get_float8_linear(linear_type, m_ref, emulate, recompute_weight_cast) # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) - sync_float8_amax_and_scale_history(m) + if linear_requires_sync(linear_type): + sync_float8_amax_and_scale_history(m) y = m(x) assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" # autocast on with torch.autocast("cuda"): - sync_float8_amax_and_scale_history(m) + if linear_requires_sync(linear_type): + sync_float8_amax_and_scale_history(m) y = m(x) assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" @@ -192,20 +194,13 @@ def test_linear_bias( y.dtype == torch.bfloat16 ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}" - @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) - def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype): + def test_type_cast(self, linear_dtype: torch.dtype): emulate = ( not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0) ) - x_shape = (16, 16) - - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) - m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - self._test_linear_impl(x, m_ref, linear_type, emulate, False) - m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) m = Float8Linear.from_float(m, emulate)