From 6004c23256bfe621ad9c8aacb745411818811d95 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Wed, 29 May 2024 13:47:19 -0700
Subject: [PATCH] [not for land] enumerate breakages with module hooks + compile

Summary:

This PR rewrites Float8DynamicLinear to use module hooks, as we think this is
more composable with other PyTorch features in the long term. There is no plan
to land this for now; it only reproduces and shares what breaks when we try
this today.

Test Plan:

```
// note: all tests pass without this PR

// eager mode is fine
> pytest -s test/test_base.py | with-proxy gh gist create
https://gist.github.com/vkuzo/aded224af91092c8326becc855b125c9

// compile has some errors in the aot_eager backend
> pytest -s test/test_compile.py | with-proxy gh gist create
https://gist.github.com/vkuzo/cab55b11a2c3cee0d1ff94169131b171

// dtensor + float8 has numeric issues
> ./test/test_dtensor.sh | with-proxy gh gist create
https://gist.github.com/vkuzo/d1035200db22f2e3357438824cd3594f
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 float8_experimental/float8_dynamic_linear.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index caeb31c..2ebdeaf 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -52,6 +52,13 @@ def backward(ctx, gradY):
         )
         return fp8_tensor, None
 
+def forward_pre_hook(mod, x):
+    x = cast_to_float8_e4m3fn(x[0], mod.forward_config)
+    return x
+
+def forward_post_hook(mod, x, y):
+    y = cast_to_float8_e5m2_bw(y, mod.backward_config)
+    return y
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
@@ -62,14 +69,14 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)
 
-    def forward(self, x):
-        x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
+    def forward(self, x_fp8):
+        # x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
             w_fp8 = cast_to_float8_e4m3fn(self.weight, self.forward_config)
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
-        y = cast_to_float8_e5m2_bw(y, self.backward_config)
+        # y = cast_to_float8_e5m2_bw(y, self.backward_config)
         return y
 
     @classmethod
@@ -97,6 +104,8 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
         else:
             new_mod.weight = mod.weight
         new_mod.bias = mod.bias
+        new_mod.register_forward_pre_hook(forward_pre_hook)
+        new_mod.register_forward_hook(forward_post_hook)
         return new_mod
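
Below is a minimal, standalone sketch (not part of the patch) of the forward
pre/post hook pattern the diff above moves to. Plain dtype casts stand in for
cast_to_float8_e4m3fn / cast_to_float8_e5m2_bw, and all names in the sketch
are illustrative rather than part of float8_experimental.

```python
import torch
import torch.nn as nn

def pre_hook(mod, args):
    # runs before Linear.forward; a returned tuple replaces the positional args
    return (args[0].to(torch.bfloat16),)

def post_hook(mod, args, output):
    # runs after Linear.forward; a non-None return value replaces the output
    return output.to(torch.float32)

lin = nn.Linear(8, 8).to(torch.bfloat16)
lin.register_forward_pre_hook(pre_hook)
lin.register_forward_hook(post_hook)

x = torch.randn(2, 8)  # fp32 input; the pre-hook casts it to bf16
y = lin(x)             # bf16 matmul; the post-hook casts the output to fp32
assert y.dtype == torch.float32
```

This mirrors the change in the diff: the input and output casts move out of
Float8DynamicLinear.forward and into hooks registered in from_float.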