Add SdpaFwdOp::evaluate on meta device (#5086)
Conversation
!test
jjsjann123 left a comment:
LGTM overall. minor nitpick about the shape propagation.
csrc/ir/nodes.cpp (outdated):

```cpp
_scaled_dot_product_flash_attention_meta(const at::Tensor& query) {
  // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
  // Query -> Query (Batch x Q_seq_len x Num_heads x Dim_per_head)
  at::Tensor q_t = query.transpose(1, 2);
```
Super nitpick: the PyTorch kernel implementation uses transposes because it is feeding the tensors to a kernel. I think we can directly retrieve the shape and call an explicit `empty` or `empty_strided`, which is probably easier to read.
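For illustration, here is a minimal sketch of that suggestion (my own hypothetical helper, not code from this PR; it assumes the (Batch, Num_heads, Q_seq_len, Dim_per_head) query layout noted in the comments above):

```cpp
// Hypothetical sketch: allocate the meta output directly with the kernel's
// layout instead of materializing transposes. The flash-attention output is
// logically (B, H, S, D) but laid out in memory as a contiguous (B, S, H, D)
// tensor that was transposed back, hence the explicit strides below.
at::Tensor sdpa_fwd_output_meta(const at::Tensor& query) {
  const int64_t b = query.size(0); // Batch
  const int64_t h = query.size(1); // Num_heads
  const int64_t s = query.size(2); // Q_seq_len
  const int64_t d = query.size(3); // Dim_per_head
  return at::empty_strided(
      /*size=*/{b, h, s, d},
      /*stride=*/{h * s * d, d, h * d, 1},
      query.options().device(at::kMeta));
}
```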
Ha! Good catch! I spent a lot of time simplifying PyTorch's implementation, but didn't realize that it was just naively returning a query-shaped tensor.
!test
In preparation for fixing #4888, I am working on #5082, which requires the use of `Expr::evaluate` on meta tensors to infer the shapes and strides of fusion segments selected to be scheduled by the `ExprEval` scheduler. As a consequence of this change, all the `Expr::evaluate` functions should support the meta device, and the returned output tensor's shape and strides must match those on device type CUDA.

According to https://docs.pytorch.org/docs/stable/meta.html, it is generally safe to assume that we can use device type meta to infer the shapes and strides of device type CUDA. Unfortunately, not all operators implement the meta device, and `at::_scaled_dot_product_flash_attention` is one such example.

In this PR, I am adding my own meta-device implementation of `at::_scaled_dot_product_flash_attention`.
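For context on the meta-device approach, here is a minimal sketch (my own illustration, not code from this PR) of using meta tensors to infer an output's shape and strides without any allocation or computation:

```cpp
#include <ATen/ATen.h>

// Minimal illustration: run an op on the meta device to infer the shape and
// strides the CUDA output would have, without touching any real memory.
at::Tensor infer_output_meta() {
  at::Tensor a = at::empty({2, 3}, at::device(at::kMeta).dtype(at::kFloat));
  at::Tensor b = at::empty({2, 3}, at::device(at::kMeta).dtype(at::kFloat));
  at::Tensor out = at::add(a, b); // no data is computed on meta tensors
  // out.sizes() and out.strides() match what the CUDA kernel would produce,
  // provided the operator implements the meta device.
  return out;
}
```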