@@ -60,10 +60,10 @@ public:
60
60
std::stringstream out;
61
61
size_t size = fShape .size ();
62
62
auto length_str = ConvertDimShapeToLength (fShape );
63
- int axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis ;
63
+ size_t axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis ;
64
64
65
65
// Check if this is the special case where memory is contiguous.
66
- if (axis == static_cast < int >( size - 1 ) ) {
66
+ if (axis == size - 1 ) {
67
67
std::string axis_size = fShape [axis].GetVal ();
68
68
std::string num_rows;
69
69
if (IsInteger (length_str) && IsInteger (axis_size)) {
@@ -73,24 +73,24 @@ public:
73
73
}
74
74
75
75
out << " \n " << SP << " //------ SOFTMAX - " << size << " " << length_str << " " << axis << " \n " ;
76
- out << SP << " for (size_t i = 0; i < " << num_rows << " ; ++i) {\n " ;
76
+ out << SP << " for (int i = 0; i < " << num_rows << " ; ++i) {\n " ;
77
77
out << SP << SP << " size_t offset = i * " << axis_size << " ;\n " ;
78
78
out << SP << SP << fType << " const * x_ptr = &tensor_" << fNX << " [offset];\n " ;
79
79
out << SP << SP << fType << " * y_ptr = &tensor_" << fNY << " [offset];\n " ;
80
80
81
81
out << SP << SP << fType << " vmax = x_ptr[0];\n " ;
82
- out << SP << SP << " for (size_t j = 1; j < " << axis_size << " ; ++j) {\n " ;
82
+ out << SP << SP << " for (int j = 1; j < " << axis_size << " ; ++j) {\n " ;
83
83
out << SP << SP << SP << " if (x_ptr[j] > vmax) vmax = x_ptr[j];\n " ;
84
84
out << SP << SP << " }\n " ;
85
85
86
86
out << SP << SP << fType << " sum = 0.0;\n " ;
87
- out << SP << SP << " for (size_t j = 0; j < " << axis_size << " ; ++j) {\n " ;
87
+ out << SP << SP << " for (int j = 0; j < " << axis_size << " ; ++j) {\n " ;
88
88
out << SP << SP << SP << " y_ptr[j] = std::exp(x_ptr[j] - vmax);\n " ;
89
89
out << SP << SP << SP << " sum += y_ptr[j];\n " ;
90
90
out << SP << SP << " }\n " ;
91
91
92
92
out << SP << SP << fType << " inv_sum = 1.0f / sum;\n " ;
93
- out << SP << SP << " for (size_t j = 0; j < " << axis_size << " ; ++j) {\n " ;
93
+ out << SP << SP << " for (int j = 0; j < " << axis_size << " ; ++j) {\n " ;
94
94
out << SP << SP << SP << " y_ptr[j] *= inv_sum;\n " ;
95
95
out << SP << SP << " }\n " ;
96
96
out << SP << " }\n " ;
@@ -100,7 +100,7 @@ public:
100
100
size_t k = 0 ;
101
101
std::vector<std::string> l (size);
102
102
for (size_t i = 0 ; i < size; i++) {
103
- if (static_cast < int >(i) != axis) {
103
+ if (i != axis) {
104
104
for (size_t j = 0 ; j < k; j++) out << SP;
105
105
l[i] = std::string (" i" ) + std::to_string (i);
106
106
out << " for (int " << l[i] << " = 0; " << l[i] << " < " << fShape [i] << " ; " << l[i] << " ++) {\n " ;
@@ -113,7 +113,7 @@ public:
113
113
out << " size_t index = " ;
114
114
bool first = true ;
115
115
for (size_t i = 0 ; i < size; i++) {
116
- if (static_cast < int >(i) == axis) continue ;
116
+ if (i == axis) continue ;
117
117
if (!first) out << " + " ;
118
118
if (stride[i].GetVal () != " 1" )
119
119
out << stride[i] << " *" ;
0 commit comments