Use ATen ops on meta tensor to compute output shapes and strides for ExprEval segments #5082
!test
Fixes #4888
Review updated until commit 5f89380
In preparation for fixing #4888, I am working on #5082, which uses `Expr::evaluate` on meta tensors to infer the shapes and strides of fusion segments that are scheduled by the `ExprEval` scheduler. As a consequence, every `Expr::evaluate` function must support the meta device, and the shape and strides of the returned output tensor must match those produced on a CUDA device.

According to https://docs.pytorch.org/docs/stable/meta.html:

> In some cases, not all device types (e.g., CPU and CUDA) have exactly the same output metadata for an operation; we typically prefer representing the CUDA behavior faithfully in this situation.

So it is generally safe to assume that the meta device can be used to infer the shapes and strides that an operation would produce on CUDA. Unfortunately, not all operators have a meta implementation, and `at::_scaled_dot_product_flash_attention` is one example. In this PR, I add my own meta-device implementation of `at::_scaled_dot_product_flash_attention`.
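As a rough illustration of why metadata-only evaluation is enough to infer output shapes and strides, here is a minimal pure-Python sketch. It is not ATen or nvFuser code: `MetaTensor`, `contiguous`, `permute`, and `matmul` are hypothetical names invented for this example, and the stride rules shown are the usual row-major conventions, stated here as assumptions.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class MetaTensor:
    """Hypothetical stand-in for a meta tensor: metadata only, no storage."""
    shape: Tuple[int, ...]
    strides: Tuple[int, ...]


def contiguous(shape: Tuple[int, ...]) -> MetaTensor:
    """Build row-major strides for `shape` without allocating any data."""
    strides = []
    running = 1
    for size in reversed(shape):
        strides.append(running)
        # Assumption: a size-0 dimension is treated as size 1 for striding.
        running *= max(size, 1)
    return MetaTensor(tuple(shape), tuple(reversed(strides)))


def permute(t: MetaTensor, dims: Tuple[int, ...]) -> MetaTensor:
    """A permute is a view: sizes and strides are reordered, nothing is copied."""
    return MetaTensor(
        tuple(t.shape[d] for d in dims),
        tuple(t.strides[d] for d in dims),
    )


def matmul(a: MetaTensor, b: MetaTensor) -> MetaTensor:
    """2-D matmul metadata: validate inner dims, return a fresh contiguous output."""
    if len(a.shape) != 2 or len(b.shape) != 2:
        raise ValueError("this sketch only handles 2-D operands")
    if a.shape[1] != b.shape[0]:
        raise ValueError("inner dimensions do not match")
    return contiguous((a.shape[0], b.shape[1]))


x = contiguous((2, 3, 4))
y = permute(x, (2, 0, 1))
z = matmul(contiguous((2, 3)), contiguous((3, 5)))
print(x, y, z, sep="\n")
```

Each function reads only sizes and strides, so the output layout of a whole chain of operations can be derived without launching a kernel. An operator that lacks such a metadata-only rule, as described above for `at::_scaled_dot_product_flash_attention`, breaks the chain unless one is supplied.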
!test