Skip to content

Commit 4d743c7

Browse files
Olialmoneta
authored andcommitted
comments solved
1 parent fb33746 commit 4d743c7

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tmva/sofie/inc/TMVA/ROperator_Softmax.hxx

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ public:
6060
std::stringstream out;
6161
size_t size = fShape.size();
6262
auto length_str = ConvertDimShapeToLength(fShape);
63-
int axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis;
63+
size_t axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis;
6464

6565
// Check if this is the special case where memory is contiguous.
66-
if (axis == static_cast<int>(size - 1)) {
66+
if (axis == size - 1) {
6767
std::string axis_size = fShape[axis].GetVal();
6868
std::string num_rows;
6969
if (IsInteger(length_str) && IsInteger(axis_size)) {
@@ -73,24 +73,24 @@ public:
7373
}
7474

7575
out << "\n" << SP << "//------ SOFTMAX - " << size << " " << length_str << " " << axis << "\n";
76-
out << SP << "for (size_t i = 0; i < " << num_rows << "; ++i) {\n";
76+
out << SP << "for (int i = 0; i < " << num_rows << "; ++i) {\n";
7777
out << SP << SP << "size_t offset = i * " << axis_size << ";\n";
7878
out << SP << SP << fType << " const * x_ptr = &tensor_" << fNX << "[offset];\n";
7979
out << SP << SP << fType << " * y_ptr = &tensor_" << fNY << "[offset];\n";
8080

8181
out << SP << SP << fType << " vmax = x_ptr[0];\n";
82-
out << SP << SP << "for (size_t j = 1; j < " << axis_size << "; ++j) {\n";
82+
out << SP << SP << "for (int j = 1; j < " << axis_size << "; ++j) {\n";
8383
out << SP << SP << SP << "if (x_ptr[j] > vmax) vmax = x_ptr[j];\n";
8484
out << SP << SP << "}\n";
8585

8686
out << SP << SP << fType << " sum = 0.0;\n";
87-
out << SP << SP << "for (size_t j = 0; j < " << axis_size << "; ++j) {\n";
87+
out << SP << SP << "for (int j = 0; j < " << axis_size << "; ++j) {\n";
8888
out << SP << SP << SP << "y_ptr[j] = std::exp(x_ptr[j] - vmax);\n";
8989
out << SP << SP << SP << "sum += y_ptr[j];\n";
9090
out << SP << SP << "}\n";
9191

9292
out << SP << SP << fType << " inv_sum = 1.0f / sum;\n";
93-
out << SP << SP << "for (size_t j = 0; j < " << axis_size << "; ++j) {\n";
93+
out << SP << SP << "for (int j = 0; j < " << axis_size << "; ++j) {\n";
9494
out << SP << SP << SP << "y_ptr[j] *= inv_sum;\n";
9595
out << SP << SP << "}\n";
9696
out << SP << "}\n";
@@ -100,7 +100,7 @@ public:
100100
size_t k = 0;
101101
std::vector<std::string> l(size);
102102
for (size_t i = 0; i < size; i++) {
103-
if (static_cast<int>(i) != axis) {
103+
if (i != axis) {
104104
for (size_t j = 0; j < k; j++) out << SP;
105105
l[i] = std::string("i") + std::to_string(i);
106106
out << "for (int " << l[i] << " = 0; " << l[i] << " < " << fShape[i] << "; " << l[i] << "++) {\n";
@@ -113,7 +113,7 @@ public:
113113
out << "size_t index = ";
114114
bool first = true;
115115
for (size_t i = 0; i < size; i++) {
116-
if (static_cast<int>(i) == axis) continue;
116+
if (i == axis) continue;
117117
if (!first) out << " + ";
118118
if (stride[i].GetVal() != "1")
119119
out << stride[i] << "*";

0 commit comments

Comments
 (0)