Skip to content

Commit ddba43e

Browse files
committed
initial fma impl for ArrayMath#innerProduct for experiments
1 parent bbfed25 commit ddba43e

File tree

1 file changed

+44
-4
lines changed

1 file changed

+44
-4
lines changed

opennlp-api/src/main/java/opennlp/tools/ml/ArrayMath.java

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,55 @@
2626
*/
2727
public class ArrayMath {
2828

29+
private static final String OS_NAME = System.getProperty("os.name", "Unknown");
30+
private static final String OS_ARCH = System.getProperty("os.arch", "Unknown");
31+
private static final boolean MAC_OS_X = OS_NAME.startsWith("Mac OS X");
32+
33+
private static boolean hasHWVectorFMA() {
34+
// aarch64 has hw fma, but not on silicon
35+
if (OS_ARCH.equals("aarch64") && !MAC_OS_X) {
36+
return true;
37+
}
38+
// intel et al. support it nowadays
39+
if (OS_ARCH.equals("amd64")) {
40+
return true;
41+
}
42+
// otherwise
43+
return false;
44+
}
45+
2946
public static double innerProduct(double[] vecA, double[] vecB) {
3047
if (vecA == null || vecB == null || vecA.length != vecB.length)
3148
return Double.NaN;
3249

33-
double product = 0.0;
34-
for (int i = 0; i < vecA.length; i++) {
35-
product += vecA[i] * vecB[i];
50+
if (hasHWVectorFMA()) {
51+
double product = 0;
52+
int i = 0;
53+
54+
// unroll, in case the arrays are large enough
55+
if (vecA.length > 32) {
56+
double acc1 = 0, acc2 = 0, acc3 = 0, acc4 = 0;
57+
int upperBound = vecA.length & ~(4 - 1);
58+
for (; i < upperBound; i += 4) {
59+
acc1 = StrictMath.fma(vecA[i], vecB[i], acc1);
60+
acc2 = StrictMath.fma(vecA[i + 1], vecB[i + 1], acc2);
61+
acc3 = StrictMath.fma(vecA[i + 2], vecB[i + 2], acc3);
62+
acc4 = StrictMath.fma(vecA[i + 3], vecB[i + 3], acc4);
63+
}
64+
product += acc1 + acc2 + acc3 + acc4;
65+
}
66+
67+
for (; i < vecA.length; i++) {
68+
product = StrictMath.fma(vecA[i], vecB[i], product);
69+
}
70+
return product;
71+
} else {
72+
double product = 0.0;
73+
for (int i = 0; i < vecA.length; i++) {
74+
product += vecA[i] * vecB[i];
75+
}
76+
return product;
3677
}
37-
return product;
3878
}
3979

4080
/**

0 commit comments

Comments
 (0)