diff --git a/benchmarks/bench_linear_float8.py b/benchmarks/bench_linear_float8.py index 4736cd29..265fa359 100644 --- a/benchmarks/bench_linear_float8.py +++ b/benchmarks/bench_linear_float8.py @@ -56,6 +56,7 @@ class Experiment: dtype: torch.dtype compiled: bool = False float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn + recompute_weight_cast: bool = False # 3 Times since we are calculating forward backward @property @@ -95,9 +96,14 @@ def main( } input_bias = False ref_dtypes = [torch.bfloat16, torch.float16] + recompute_weight_casts = [True, False] experiment_list: List[Experiment] = [] - for idx, (dtype, (name, (K, N))) in enumerate( - tqdm(list(product(ref_dtypes, name_to_shapes_70b.items()))) + for idx, (dtype, (name, (K, N)), recompute_weight_cast) in enumerate( + tqdm( + list( + product(ref_dtypes, name_to_shapes_70b.items(), recompute_weight_casts) + ) + ) ): if n_limit is not None and idx >= n_limit: break @@ -106,7 +112,9 @@ def main( ) linear_float8 = Float8Linear.from_float( - copy.deepcopy(linear_ref), emulate=False + copy.deepcopy(linear_ref), + emulate=False, + recompute_weight_cast=recompute_weight_cast, ) bsz, seq_len = 4, 4096 @@ -155,6 +163,7 @@ def wrapper(*args, **kwargs): float8_time, dtype, compile, + recompute_weight_cast=recompute_weight_cast, ) print(experiment) print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec) @@ -169,6 +178,7 @@ def wrapper(*args, **kwargs): "ref_dtype", "compiled", "fp8_dtype", + "recompute_weight_cast", "ref_time_sec", "pt_fp8_time_sec", "ref_tops_sec", @@ -187,6 +197,7 @@ def wrapper(*args, **kwargs): experiment.dtype, experiment.compiled, experiment.float_8_dtype, + experiment.recompute_weight_cast, experiment.ref_time_sec, experiment.float8_time_sec, experiment.ref_tops_sec, @@ -214,6 +225,7 @@ def wrapper(*args, **kwargs): "shape", "ref_dtype", "compiled", + "recompute_weight_cast", "ref_time_sec", "pt_fp8_time_sec", "pt_fp8_speedup", diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py index b62020b1..4b4e92e8 100644 --- a/benchmarks/profile_linear_float8.py +++ b/benchmarks/profile_linear_float8.py @@ -87,7 +87,9 @@ class LinearParams: torch_compile: Optional[bool] = False -def main(profile_path: Path, compile: bool, linear_type: str): +def main( + profile_path: Path, compile: bool, linear_type: str, recompute_weight_cast: bool +): profile_path = Path(profile_path) assert profile_path.is_dir(), f"Path {profile_path} must be a directory" params = LinearParams( @@ -110,7 +112,9 @@ def main(profile_path: Path, compile: bool, linear_type: str): dtype=params.ref_dtype, ) linear_type = LinearType[linear_type.upper()] - linear_float8 = get_float8_linear(linear_type, linear_ref) + linear_float8 = get_float8_linear( + linear_type, linear_ref, recompute_weight_cast=recompute_weight_cast + ) input_tensor = torch.randn( params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index 58e352da..de26b576 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -8,9 +8,14 @@ """ import torch +from float8_experimental.float8_ops import float8_linear from float8_experimental.float8_tensor import Float8Tensor -from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated +from float8_experimental.float8_utils import ( + get_maybe_autocast_inputs, + tensor_to_scale, + to_fp8_saturated, +) 
@torch._dynamo.allow_in_graph @@ -47,14 +52,31 @@ class Float8DynamicLinear(torch.nn.Linear): """ def forward(self, x): - x_fp8 = self.cast_to_float8(x) - w_fp8 = self.cast_to_float8(self.weight) - - y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias) - + # Tried to do this with @custom_fwd/bwd but it didn't work + temp_x, temp_weight, temp_bias = get_maybe_autocast_inputs( + x, self.weight, self.bias + ) + x_fp8 = self.cast_to_float8(temp_x) + weight_scale = tensor_to_scale(temp_weight, torch.float8_e4m3fn) + y = float8_linear( + x_fp8, + temp_weight, + None, # bias + weight_scale, + None, + self.emulate, + self.recompute_weight_cast, + ) # Cast gradY to float8_e5m2 during backward y = self.cast_to_float8e5m2_bw(y) + # TODO We should use addmm above, but this fails the single fsdp test: + # FAILED: _orig_mod.0.fp8_amax_w, 0.2197265625, 0.21875 + # It is not immediately clear why fusing in the bias would only affect the numerics + # of the weight. + if temp_bias is not None: + y = y + temp_bias + return y def cast_to_float8(self, inpt_tensor): @@ -67,17 +89,22 @@ def cast_to_float8e5m2_bw(self, gradY): return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate) @classmethod - def from_float(cls, mod, emulate: bool = False): + def from_float( + cls, mod, emulate: bool = False, recompute_weight_cast: bool = False + ): """ Create an nn.Linear with fp8 compute from a regular nn.Linear Args: mod (torch.nn.Linear): nn.Linear to convert emulate (bool): whether to emulate fp8 matmul logic in float32 + recompute_weight_cast (bool): whether to recompute the weight cast on every + backward pass """ with torch.device("meta"): new_mod = cls(mod.in_features, mod.out_features, bias=False) new_mod.weight = mod.weight new_mod.bias = mod.bias new_mod.emulate = emulate + new_mod.recompute_weight_cast = recompute_weight_cast return new_mod diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 59285852..44c86b36 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -19,6 +19,7 @@ import float8_experimental.config as config import torch +from float8_experimental.float8_ops import float8_linear from float8_experimental.float8_tensor import Float8Tensor @@ -26,6 +27,7 @@ amax_history_to_scale, E4M3_MAX_POS, E5M2_MAX_POS, + get_maybe_autocast_inputs, tensor_to_amax, to_fp8_saturated, ) @@ -172,6 +174,13 @@ def __init__(self, *args, **kwargs): # and torch.compile, this option can disable them self.enable_pre_and_post_forward = config.enable_pre_and_post_forward + # This flag controls what gets saved for the backward pass. The default, False, saves the + # casted weight for backward, which typically increases memory usage because both the weight + # parameter and the casted weight are kept on device. If set to True, only the weight parameter + # is saved and it is re-cast to fp8 during the backward pass. For traditional FSDP this should + # be set to True so that the un-sharded weight is not saved for backward.
+ self.recompute_weight_cast = False + def register_always_float32_buffer( self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True ) -> None: @@ -191,14 +200,6 @@ def convert_amax_buffer_to_float32(self): def cast_x_to_float8( self, x: torch.Tensor, is_amax_initialized: bool ) -> torch.Tensor: - # Duplicate the autocast logic for F.linear, so that the output - # of our module has the right original precision - if torch.is_autocast_enabled(): - # For now, hardcode to GPU's autocast dtype - # if we need CPU support in the future, we can add it - autocast_dtype = torch.get_autocast_gpu_dtype() - x = x.to(autocast_dtype) - scale_fn_name = self.recipe.scale_fn_name _maybe_initialize_amaxes_scales_for_float8_cast( x, @@ -214,6 +215,20 @@ ) return x_fp8 + def _maybe_init_amaxes_scales_weight( + self, w: torch.Tensor, is_amax_initialized: bool + ): + scale_fn_name = self.recipe.scale_fn_name + _maybe_initialize_amaxes_scales_for_float8_cast( + w, + self.fp8_amax_w, + self.fp8_amax_history_w, + self.fp8_scale_w, + scale_fn_name, + torch.float8_e4m3fn, + is_amax_initialized, + ) + def cast_w_to_float8( self, w: torch.Tensor, is_amax_initialized: bool ) -> torch.Tensor: @@ -281,30 +296,48 @@ class Float8Linear(Float8LinearMixin, torch.nn.Linear): """ def forward(self, x): + temp_x, temp_weight, temp_bias = get_maybe_autocast_inputs( + x, self.weight, self.bias + ) self.float8_pre_forward(x) - x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized) - w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized) + x_fp8 = self.cast_x_to_float8(temp_x, self.is_amax_initialized) + self._maybe_init_amaxes_scales_weight(self.weight, self.is_amax_initialized) - y = torch.matmul(x_fp8, w_fp8.t()) + y = float8_linear( + x_fp8, + temp_weight, + None, # bias + self.fp8_scale_w, + self.fp8_amax_w, + self.emulate, + self.recompute_weight_cast, + ) # Cast gradY to float8_e5m2 during backward y = self.cast_y_to_float8_in_bw(y, self.emulate) - if self.bias is not None: - y = y + self.bias.to(y.dtype) + # TODO We should use addmm above, but this fails the single fsdp test: + # FAILED: _orig_mod.0.fp8_amax_w, 0.2197265625, 0.21875 + # It is not immediately clear why fusing in the bias would only affect the numerics + # of the weight. + if temp_bias is not None: + y = y + temp_bias self.float8_post_forward() return y @classmethod - def from_float(cls, mod, emulate: bool = False): + def from_float( + cls, mod, emulate: bool = False, recompute_weight_cast: bool = False + ): """ Create an nn.Linear with fp8 compute from a regular nn.Linear Args: mod (torch.nn.Linear): nn.Linear to convert emulate (bool): whether to emulate fp8 matmul logic in float32 + recompute_weight_cast (bool): whether to recompute the casted weight for the backward pass """ # TODO Follow up! This is a great idea but we need the mixin base to create real # Tensors and the Linear base to create empty params @@ -313,6 +346,7 @@ def from_float(cls, mod, emulate: bool = False): new_mod.weight = mod.weight new_mod.bias = mod.bias new_mod.emulate = emulate + new_mod.recompute_weight_cast = recompute_weight_cast # I think its okay to send all params and buffers to device new_mod.to(mod.weight.device) return new_mod diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index c6152e8d..40ee292f 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree.
import copy from enum import auto, Enum +from typing import List, Optional, Union import torch import torch.distributed as dist @@ -23,13 +24,17 @@ class LinearType(Enum): def get_float8_linear( - linear_type: LinearType, linear_ref: torch.nn.Linear, emulate: bool = False + linear_type: LinearType, + linear_ref: torch.nn.Linear, + emulate: bool = False, + recompute_weight_cast: bool = False, ): """Returns a Float8Linear module of the given type, initialized from linear_ref. Args: linear_type: The type of Float8Linear to return. linear_ref: The linear module to initialize from. emulate: Whether to emulate the fp8 matmul logic in float32. + recompute_weight_cast: Whether to recompute the weight cast in the backwards pass. """ LINEAR_TYPE_MAP = { LinearType.DELAYED: Float8Linear, @@ -39,7 +44,9 @@ def get_float8_linear( raise ValueError(f"linear_type must be one of {LINEAR_TYPE_MAP.keys()}") return LINEAR_TYPE_MAP[linear_type].from_float( - copy.deepcopy(linear_ref), emulate=emulate + copy.deepcopy(linear_ref), + emulate=emulate, + recompute_weight_cast=recompute_weight_cast, ) @@ -59,11 +66,12 @@ def _update_history_with_new_amax(new_amax, amax_history): def swap_linear_with_float8_linear( - model, - module, - emulate=False, - skip_fqn_list=None, - cur_fqn="", + model: torch.nn.Module, + module: Union[Float8Linear, Float8DynamicLinear], + emulate: bool = False, + skip_fqn_list: Optional[List[str]] = None, + cur_fqn: str = "", + recompute_weight_cast: bool = False, ): """ Replaces all instances of torch.nn.Linear in the given model with module. @@ -74,17 +82,19 @@ def swap_linear_with_float8_linear( emulate (bool, optional): Whether to emulate the fp8 matmul logic in float32. skip_fqn_list (List[str], optional): If specified, a list of FQNs to skip cur_fqn (str, optional): Current fqn, used to implement skip_fqn_list + recompute_weight_cast (bool, optional): Whether to recompute the weight cast in the backwards pass. """ + args = (module, emulate, skip_fqn_list, cur_fqn, recompute_weight_cast) name_to_child = dict(model.named_children()) for name, child in name_to_child.items(): new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}" if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and isinstance( child, torch.nn.Linear ): - new_child = module.from_float(child, emulate) + new_child = module.from_float(child, emulate, recompute_weight_cast) setattr(model, name, new_child) else: - swap_linear_with_float8_linear(child, module, emulate) + swap_linear_with_float8_linear(child, *args) def get_float8_layers(model: torch.nn.Module, fp8_classes=None): diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 392358f2..72ba6afc 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -3,12 +3,12 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Any, Dict +from typing import Any, Dict, Optional import torch from float8_experimental.float8_python_api import addmm_float8_unwrapped -from float8_experimental.float8_tensor import Float8Tensor +from float8_experimental.float8_tensor import Float8Tensor, re_construct_float8_weight from float8_experimental.float8_utils import is_row_major from torch.utils._pytree import tree_map @@ -75,24 +75,45 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): return a_data, a_scale, b_data, b_scale -@implements([aten.mm.default]) -def float8_mm(aten_op, args, kwargs=None): - assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor) - a = args[0] - b = args[1] +def float8_addmm_helper( + a: Float8Tensor, b: Float8Tensor, bias: Optional[torch.Tensor] +) -> torch.Tensor: + """Shared helper for float8_mm and float8_addmm. + Args: + a: The first matrix multiplication term. + b: The second matrix multiplication term. + bias: The optional bias term. + Returns: + torch.Tensor: The result of the matrix multiplication. + """ a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) output_dtype = a._orig_dtype + if bias is not None: + assert ( + bias.dtype == output_dtype + ), f"bias dtype {bias.dtype} != output_dtype {output_dtype}" if a._emulate: assert a._emulate == b._emulate - return torch.ops.aten.mm_float8_emulated( + out = torch.ops.aten.mm_float8_emulated( a._data, a._scale, b._data, b._scale, output_dtype )[0] + if bias is not None: + return out + bias + return out tensor_out, amax = addmm_float8_unwrapped( - a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=None + a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=bias ) return tensor_out +@implements([aten.mm.default]) +def float8_mm(aten_op, args, kwargs=None): + assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor) + a = args[0] + b = args[1] + return float8_addmm_helper(a, b, None) + + @implements([aten.addmm.default]) def float8_addmm(aten_op, args, kwargs=None): assert ( @@ -103,19 +124,7 @@ def float8_addmm(aten_op, args, kwargs=None): bias = args[0] a = args[1] b = args[2] - a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) - output_dtype = a._orig_dtype - assert bias.dtype == output_dtype, "bias dtype must match output dtype" - if a._emulate: - assert a._emulate == b._emulate - out = torch.ops.aten.mm_float8_emulated( - a._data, a._scale, b._data, b._scale, output_dtype - )[0] - return out + bias - tensor_out, amax = addmm_float8_unwrapped( - a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=bias - ) - return tensor_out + return float8_addmm_helper(a, b, bias) @implements([aten.is_same_size.default]) @@ -140,3 +149,106 @@ def autocast_to_copy(aten_op, args, kwargs=None): return Float8Tensor( args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._emulate ) + + +class _float8_linear(torch.autograd.Function): + """Custom autograd function for computing torch.nn.Linear on Float8Tensor. + + This is needed because we want fine-grained control over the + recomputation of casted values for the backward pass.
+ """ + + @staticmethod + def forward( + ctx, + x_fp8: Float8Tensor, + original_weight: torch.Tensor, + bias: Optional[torch.Tensor], + weight_scale: torch.Tensor, + weight_amax_buffer: Optional[torch.Tensor], + emulate: bool, + recompute_float8_weight: bool, + ): + w_fp8 = Float8Tensor.to_float8( + original_weight, + weight_scale, + torch.float8_e4m3fn, + weight_amax_buffer, + emulate=emulate, + ) + if recompute_float8_weight: + # This should be set to True when using traditional fsdp to avoid + # saving the unsharded weight for backwards + ctx.save_for_backward(x_fp8, original_weight, weight_scale) + else: + # Does this interact properly with activation checkpointing? + ctx.save_for_backward(x_fp8, w_fp8) + + ctx.recompute_float8_weight = recompute_float8_weight + ctx.emulate = emulate + x_fp8_reshaped = x_fp8.reshape(-1, x_fp8.size(-1)) + + w_fp8_t = w_fp8.t() + + res_bits = float8_addmm_helper(x_fp8_reshaped, w_fp8_t, bias) + res_bits = res_bits.reshape(*x_fp8.shape[:-1], res_bits.size(-1)) + return res_bits + + @staticmethod + def backward(ctx, go_fp8: torch.Tensor): + if ctx.recompute_float8_weight: + x_fp8, original_weight, weight_scale = ctx.saved_tensors + w_fp8 = re_construct_float8_weight( + original_weight, weight_scale, torch.float8_e4m3fn, emulate=ctx.emulate + ) + else: + x_fp8, w_fp8 = ctx.saved_tensors + + # calculate dL/dX + go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1)) + w_fp8_t_c_t = w_fp8.t().contiguous().t() + dL_dX = float8_addmm_helper(go_fp8_reshaped, w_fp8_t_c_t, None) + dL_dX = dL_dX.view(*go_fp8.shape[:-1], dL_dX.size(-1)) + + # calculate dL/dW + x_fp8_reshaped_t_c = x_fp8.view(-1, x_fp8.size(-1)).t().contiguous() + go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t() + + dL_dW = float8_addmm_helper(x_fp8_reshaped_t_c, go_fp8_reshaped_t_c_t, None) + # The contiguous call is not needed for correctness, but allows for a faster backward + # pass in conjunction with compile for both single-gpu and fsdp. + dL_dW = dL_dW.t().contiguous() + + if ctx.needs_input_grad[2]: + dL_dBias = go_fp8.to_original_precision() + else: + dL_dBias = None + + empty_grads = None, None, None, None, None, None, None, None, None + return dL_dX, dL_dW, dL_dBias, *empty_grads + + +# Need to allow_in_graph because: +# (1) the forward returns a plain tensor +# (2) the backward accepts a Float8Tensor subclass +# dynamo has no good way to be told what the type of +# the grad_out is today, so it (incorrectly) assumes it is also a plain tensor. 
+@torch._dynamo.allow_in_graph +def float8_linear( + x_fp8: torch.Tensor, + original_weight: torch.Tensor, + bias: Optional[torch.Tensor], + weight_scale: torch.Tensor, + weight_amax_buffer: Optional[torch.Tensor], + emulate: bool, + recompute_float8_weight: bool, +): + return _float8_linear.apply( + x_fp8, + original_weight, + bias, + weight_scale, + weight_amax_buffer, + emulate, + recompute_float8_weight, + ) diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 4450fce8..ff879814 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -42,6 +42,26 @@ def backward(ctx, g): return g, None, None, None, None +@torch._dynamo.allow_in_graph +def re_construct_float8_weight( + tensor: torch.Tensor, scale: torch.Tensor, float8_dtype, emulate: bool = False +): + """In the backward pass of float8_linear we do not need to fill the amax buffer + for the weight tensor, since that was done during the forward; we only need to + re-cast the original-precision tensor using the scale from the forward. + + Args: + tensor: the tensor to convert + scale: the scale to use to convert the tensor, computed in the forward + float8_dtype: the float8 dtype to use + emulate: if True, use fp32 emulation for the matmuls; helpful + if you don't have access to H100 hardware. + """ + tensor_scaled = tensor * scale + bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype) + return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate) + + @torch._dynamo.allow_in_graph class FromFloat8ConstrFunc(torch.autograd.Function): """ diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 4e65ef99..581a40fd 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -3,7 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable +from typing import Optional import torch import torch.distributed as dist @@ -94,3 +94,31 @@ def compute_error(x, y): def is_row_major(stride): assert len(stride) == 2, "is_row_major only supports 2D tensors" return stride[0] > stride[1] and stride[1] == 1 + + +def get_maybe_autocast_inputs( + x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] +): + """If autocast is enabled, cast the inputs to the autocast dtype. Otherwise, return the inputs as is. + + Why do we need this? Implementing this with the custom_fwd and custom_bwd + decorators from the amp docs did not work, so the autocast logic is duplicated here. + Args: + x: The input tensor. + weight: The weight tensor. + bias: The bias tensor.
+ + """ + if torch.is_autocast_enabled(): + autocast_dtype = torch.get_autocast_gpu_dtype() + temp_x = x.to(autocast_dtype) if x.dtype != autocast_dtype else x + temp_weight = ( + weight.to(autocast_dtype) if weight.dtype != autocast_dtype else weight + ) + if bias is not None and bias.dtype != autocast_dtype: + temp_bias = bias.to(autocast_dtype) + else: + temp_bias = bias + return temp_x, temp_weight, temp_bias + else: + return x, weight, bias diff --git a/test/test_base.py b/test/test_base.py index ba1f6662..b965fdb2 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -50,8 +50,15 @@ def test_preserves_dtype(self) -> None: class TestFloat8Linear: - def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool): - m_fp8 = get_float8_linear(linear_type, m_ref, emulate) + def _test_linear_impl( + self, + x, + m_ref, + linear_type: LinearType, + emulate: bool, + recompute_weight_cast: bool, + ): + m_fp8 = get_float8_linear(linear_type, m_ref, emulate, recompute_weight_cast) for _ in range(2): if linear_requires_sync(linear_type): sync_float8_amax_and_scale_history(m_fp8) @@ -112,7 +119,14 @@ def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool): @pytest.mark.parametrize("emulate", [True, False]) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) - def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool): + @pytest.mark.parametrize("recompute_weight_cast", [True, False]) + def test_linear_nobias( + self, + x_shape, + linear_type: LinearType, + emulate: bool, + recompute_weight_cast: bool, + ): if not emulate: if not torch.cuda.is_available(): warnings.warn("CUDA not available") @@ -125,7 +139,7 @@ def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool): x = torch.randn(*x_shape, device="cuda") m_ref = nn.Linear(16, 32, bias=False, device="cuda") - self._test_linear_impl(x, m_ref, linear_type, emulate) + self._test_linear_impl(x, m_ref, linear_type, emulate, recompute_weight_cast) @pytest.mark.parametrize("emulate", [True, False]) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @@ -133,8 +147,14 @@ def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool): @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) + @pytest.mark.parametrize("recompute_weight_cast", [True, False]) def test_linear_bias( - self, x_shape, linear_type: LinearType, emulate: bool, linear_dtype: torch.dtype + self, + x_shape, + linear_type: LinearType, + emulate: bool, + linear_dtype: torch.dtype, + recompute_weight_cast: bool, ): if not emulate: if not torch.cuda.is_available(): @@ -148,20 +168,22 @@ def test_linear_bias( x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - self._test_linear_impl(x, m_ref, linear_type, emulate) + self._test_linear_impl(x, m_ref, linear_type, emulate, recompute_weight_cast) - m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) - m = Float8Linear.from_float(m, emulate) + m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) + m = get_float8_linear(linear_type, m_ref, emulate, recompute_weight_cast) # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) - sync_float8_amax_and_scale_history(m) + if linear_requires_sync(linear_type): + sync_float8_amax_and_scale_history(m) y = m(x) assert y.dtype == 
linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" # autocast on with torch.autocast("cuda"): - sync_float8_amax_and_scale_history(m) + if linear_requires_sync(linear_type): + sync_float8_amax_and_scale_history(m) y = m(x) assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" @@ -172,20 +194,13 @@ def test_linear_bias( y.dtype == torch.bfloat16 ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}" - @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) - def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype): + def test_type_cast(self, linear_dtype: torch.dtype): emulate = ( not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0) ) - x_shape = (16, 16) - - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) - m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - self._test_linear_impl(x, m_ref, linear_type, emulate) - m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) m = Float8Linear.from_float(m, emulate) diff --git a/test/test_compile.py b/test/test_compile.py index 9b88811a..f345c4b6 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -22,6 +22,7 @@ def _test_compile_base( emulate: bool, linear_type: LinearType, dtype: torch.dtype, + recompute_weight_cast: bool, ): random.seed(0) torch.manual_seed(0) @@ -31,7 +32,9 @@ def _test_compile_base( x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - m_fp8 = get_float8_linear(linear_type, m_ref, emulate=emulate) + m_fp8 = get_float8_linear( + linear_type, m_ref, emulate=emulate, recompute_weight_cast=recompute_weight_cast + ) m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=fullgraph) m_ref = torch.compile(m_ref, backend=backend, fullgraph=fullgraph) @@ -50,30 +53,57 @@ def _test_compile_base( @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("recompute_weight_cast", [False, True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") -def test_eager_only(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype): +def test_eager_only( + fullgraph, + emulate: bool, + linear_type: bool, + dtype: torch.dtype, + recompute_weight_cast: bool, +): torch._dynamo.reset() - _test_compile_base("eager", fullgraph, emulate, linear_type, dtype) + _test_compile_base( + "eager", fullgraph, emulate, linear_type, dtype, recompute_weight_cast + ) @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("recompute_weight_cast", [False, True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") -def test_aot_eager(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype): +def test_aot_eager( + fullgraph, + emulate: bool, + linear_type: bool, + dtype: torch.dtype, + recompute_weight_cast: bool, +): torch._dynamo.reset() - _test_compile_base("aot_eager", fullgraph, emulate, linear_type, dtype) + _test_compile_base( + "aot_eager", 
fullgraph, emulate, linear_type, dtype, recompute_weight_cast + ) @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False]) @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) +@pytest.mark.parametrize("recompute_weight_cast", [False, True]) @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) -def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype): +def test_inductor( + fullgraph, + emulate: bool, + linear_type: bool, + dtype: torch.dtype, + recompute_weight_cast: bool, +): torch._dynamo.reset() - _test_compile_base("inductor", fullgraph, emulate, linear_type, dtype) + _test_compile_base( + "inductor", fullgraph, emulate, linear_type, dtype, recompute_weight_cast + ) if __name__ == "__main__": diff --git a/test/test_fsdp.py b/test/test_fsdp.py index 864412d8..8aea77c5 100644 --- a/test/test_fsdp.py +++ b/test/test_fsdp.py @@ -61,7 +61,9 @@ def cleanup(): dist.destroy_process_group() -def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): +def get_model( + K, N, is_fp8, emulate, base_dtype=torch.float32, recompute_weight_cast: bool = False +): m = nn.Sequential( nn.Linear(K, N, dtype=base_dtype), nn.ReLU(), @@ -69,7 +71,12 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): nn.ReLU(), ) if is_fp8: - swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate) + swap_linear_with_float8_linear( + m, + Float8Linear, + emulate=emulate, + recompute_weight_cast=recompute_weight_cast, + ) return m @@ -81,10 +88,15 @@ def fsdp_main(rank, world_size, args): # TODO: We set fullgraph as an option. However, it currently doesn't work for fullgraph compile. # We can investigate and fix it later. 
- is_fp8, emulate, base_dtype, compile, fullgraph = args - model = get_model(K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype).to( - rank - ) + is_fp8, emulate, base_dtype, compile, fullgraph, recompute_weight_cast = args + model = get_model( + K, + N, + is_fp8=is_fp8, + emulate=emulate, + base_dtype=base_dtype, + recompute_weight_cast=recompute_weight_cast, + ).to(rank) model.load_state_dict(torch.load(sd_in_fname)) # To compile FSDP, we need use_orig_params to True model = FSDP(model, use_orig_params=True) @@ -148,7 +160,13 @@ def forward_backward(model): cleanup() -def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = False): +def run( + mode: str, + is_fp8: bool, + compile_fsdp: bool = False, + fullgraph: bool = False, + recompute_weight_cast: bool = False, +): print(f"Mode: {mode}".center(100, "-")) base_dtype = torch.bfloat16 if not os.path.exists(data_dir): @@ -169,7 +187,12 @@ def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = F # generate reference input ref_input = torch.randn(B, M, K).cuda().to(base_dtype) model = get_model( - K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype + K, + N, + is_fp8=is_fp8, + emulate=emulate, + base_dtype=base_dtype, + recompute_weight_cast=recompute_weight_cast, ).cuda() torch.save(ref_input, input_fname) torch.save(model.state_dict(), sd_in_fname) @@ -177,7 +200,12 @@ def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = F elif mode == "single_gpu": ref_input = torch.load(input_fname).to(base_dtype) model = get_model( - K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype + K, + N, + is_fp8=is_fp8, + emulate=emulate, + base_dtype=base_dtype, + recompute_weight_cast=recompute_weight_cast, ).cuda() model.load_state_dict(torch.load(sd_in_fname)) optimizer = torch.optim.SGD(model.parameters(), lr=lr) @@ -199,7 +227,14 @@ def forward_backward(): elif mode == "fsdp": WORLD_SIZE = torch.cuda.device_count() # We only compile for fsdp, and compare the numerics with signle-gpu no-compile - args = (is_fp8, emulate, base_dtype, compile_fsdp, fullgraph) + args = ( + is_fp8, + emulate, + base_dtype, + compile_fsdp, + fullgraph, + recompute_weight_cast, + ) mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) elif mode == "analyze": diff --git a/test/test_fsdp.sh b/test/test_fsdp.sh index 624b3969..44770f29 100755 --- a/test/test_fsdp.sh +++ b/test/test_fsdp.sh @@ -4,14 +4,18 @@ set -e launch() { - echo "launching IS_FP8 $IS_FP8, compile_fsdp $COMPILE, fullgraph $FULLGRAPH" + echo "Launching test with the following configuration:" + echo "IS_FP8: $IS_FP8" + echo "compile_fsdp: $COMPILE" + echo "fullgraph: $FULLGRAPH" + echo "recompute_weight_cast: $RECOMPUTE" # generate the test data - python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH + python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE echo "Success: ✅" # generate single GPU model output and updated state dict - python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH + python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE echo "Success: ✅" # generate FSDP model output and updated state dict @@ -20,19 +24,32 @@ launch() { # the NCCL_NET setting is to work around transient issues on a # specific host 
(`devgpu001.nha2`) NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 NCCL_NET=SOCKET python test/test_fsdp.py \ - --mode fsdp --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH + --mode fsdp --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE # compare the outputs and state dicts and verify equivalence - python test/test_fsdp.py --mode analyze --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH + python test/test_fsdp.py --mode analyze --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE echo "Success: ✅" echo "✅ All Tests Passed ✅" } -# IS_FP8, COMPILE, FULLGRAPH -for i in False,False,False True,False,False True,True,False +# Loop over different combinations of settings +for i in False,False,False,False \ + True,False,False,False \ + True,True,False,False \ + True,False,False,True \ + True,True,False,True do - IFS=","; set -- $i; - IS_FP8=$1; COMPILE=$2; FULLGRAPH=$3 + # Split the string into variables + IFS="," + set -- $i + + # Assign each variable to a more descriptive name + IS_FP8=$1 + COMPILE=$2 + FULLGRAPH=$3 + RECOMPUTE=$4 + + # Launch the test with the current settings launch done diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py index 9dc77434..384b1ee0 100644 --- a/test/test_fsdp_compile.py +++ b/test/test_fsdp_compile.py @@ -48,12 +48,16 @@ def cleanup(): dist.destroy_process_group() -def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): +def get_model( + K, N, is_fp8, emulate, base_dtype=torch.float32, recompute_weight_cast: bool = False +): m = nn.Sequential( nn.Linear(K, N, dtype=base_dtype), nn.ReLU(), ) - swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate) + swap_linear_with_float8_linear( + m, Float8Linear, emulate=emulate, recompute_weight_cast=recompute_weight_cast + ) return m @@ -63,7 +67,7 @@ def fsdp_main(rank, world_size, args): setup(rank, world_size) torch.cuda.set_device(rank) - (emulate,) = args + (emulate, recompute_weight_cast) = args # composability of torch.compile + FSDP + autocast + Float8Linear # as fo 2023-12-30 @@ -81,9 +85,14 @@ def fsdp_main(rank, world_size, args): # things work e2e. Note that FSDP does not support full-graph compile # regardless of float8. - model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=torch.bfloat16).to( - rank - ) + model = get_model( + K, + N, + is_fp8=True, + emulate=emulate, + base_dtype=torch.bfloat16, + recompute_weight_cast=recompute_weight_cast, + ).to(rank) # To compile FSDP, we need use_orig_params to True model = FSDP(model, use_orig_params=True) @@ -102,7 +111,8 @@ def fsdp_main(rank, world_size, args): sync_float8_func(model) optimizer.step() - print("done!") + if rank == 0: + print("Success: ✅") cleanup() @@ -119,7 +129,8 @@ def run(): emulate = True WORLD_SIZE = torch.cuda.device_count() - args = (emulate,) + recompute_weight_cast = True + args = (emulate, recompute_weight_cast) mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
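A minimal usage sketch of the new recompute_weight_cast flag (not part of the patch): it uses only APIs touched in the diff above (swap_linear_with_float8_linear, Float8Linear, sync_float8_amax_and_scale_history from float8_linear_utils); layer sizes are illustrative, and an H100 is assumed for the non-emulated matmul path (pass emulate=True otherwise). Under FSDP the same swap call is used before wrapping the model, as in test_fsdp.py.

import torch
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

# Illustrative model; any stack of nn.Linear modules works.
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)).cuda()

# recompute_weight_cast=True saves only the original weight for backward and
# re-casts it to fp8 during the backward pass; with FSDP this avoids keeping
# the un-sharded casted weight alive between forward and backward.
swap_linear_with_float8_linear(
    model, Float8Linear, emulate=False, recompute_weight_cast=True
)

x = torch.randn(16, 4096, device="cuda", requires_grad=True)
sync_float8_amax_and_scale_history(model)  # delayed scaling needs a sync step
model(x).sum().backward()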