-
Notifications
You must be signed in to change notification settings - Fork 1.4k
[tmva][sofie] Softmax contiguous memory fast path #19743
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
[tmva][sofie] Softmax contiguous memory fast path #19743
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the PR,
Just some corrections for the case the shapes are not fully specified as integer values, but parametrised (dynamic)
auto stride = UTILITY::ComputeStrideFromShape(fShape); | ||
size_t size = fShape.size(); | ||
auto length_str = ConvertDimShapeToLength(fShape); | ||
size_t length = std::stoul(length_str); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you don't need to convert to an integer, since in some case it might fail : e.g length = " N * 32"
// Check if this is the special case where memory is contiguous. | ||
if (axis == static_cast<int>(size - 1)) { | ||
size_t axis_size = std::stoul(fShape[axis].GetVal()); | ||
size_t num_rows = length / axis_size; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In case of length is not an integer , this could be written as following:
std::string axis_size = fShape[axis].GetVal());
std::string num_rows;
if (IsInteger(length) && IsInteger(axis_size)) {
num_rows = std::to_string( std::stoul(length_str)/ std::stoul(fShape[axis].GetVal()));
} else {
num_rows = "(" + length_str + ") /" + fShape[axis].GetVal():
}
c51c2d9
to
5960569
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks for this!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Just a couple of small comments.
Thank you for this improvement!
auto length = ConvertDimShapeToLength(fShape); | ||
auto stride = UTILITY::ComputeStrideFromShape(fShape); | ||
size_t size = fShape.size(); | ||
auto length_str = ConvertDimShapeToLength(fShape); | ||
int axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we could have "size_t" for axis instead of int
out << SP << SP << "tensor_" << fNY << "[i] /= sum;\n"; | ||
|
||
// Check if this is the special case where memory is contiguous. | ||
if (axis == static_cast<int>(size - 1)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we define axis as size_t we don;t need the casting here
} | ||
|
||
out << "\n" << SP << "//------ SOFTMAX - " << size << " " << length_str << " " << axis << "\n"; | ||
out << SP << "for (size_t i = 0; i < " << num_rows << "; ++i) {\n"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we could use here int
in the generated code to save memory, as it is done in the general case below
This Pull request:
An optimized code path has been added for the common case where the Softmax axis is the last dimension (axis == size - 1).