Skip to content

Commit 5490b8e

Browse files
author
Ankur Goel
committed
New JMH benchmark method - vdot8s that implements int8 dotProduct in C using Neon intrinsics
1 parent dea5f28 commit 5490b8e

File tree

6 files changed

+54
-23
lines changed

6 files changed

+54
-23
lines changed

gradle/testing/defaults-tests.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ allprojects {
139139
":lucene:test-framework"
140140
] ? 'ALL-UNNAMED' : 'org.apache.lucene.core')
141141

142+
jvmArgs '-Djava.library.path=' + file("${buildDir}/libs/dotProduct/shared").absolutePath
143+
142144
def loggingConfigFile = layout.projectDirectory.file("${resources}/logging.properties")
143145
def tempDir = layout.projectDirectory.dir(testsTmpDir.toString())
144146
jvmArgumentProviders.add(

gradle/testing/randomization/policies/tests.policy

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,7 @@ grant codeBase "file:${gradle.worker.jar}" {
104104
};
105105

106106
grant {
107-
// Allow reading gradle worker JAR.
108-
permission java.io.FilePermission "${gradle.worker.jar}", "read";
109-
// Allow reading from classpath JARs (resources).
110-
permission java.io.FilePermission "${gradle.user.home}${/}-", "read";
107+
// NOTE(review): this grant has no codeBase, so AllPermission applies to all code
// evaluated under this policy and supersedes the two read-only FilePermission
// entries it replaces. Presumably a temporary measure to let tests load the native
// library -- confirm, and narrow the grant before this leaves the experimental branch.
permission java.security.AllPermission;
111108
};
112109

113110
// Grant permissions to certain test-related JARs (https://github.com/apache/lucene/pull/13146)

lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.lang.foreign.Arena;
2020
import java.lang.foreign.MemorySegment;
21+
import java.lang.foreign.ValueLayout;
2122
import java.util.concurrent.ThreadLocalRandom;
2223
import java.util.concurrent.TimeUnit;
2324
import org.apache.lucene.util.VectorUtil;
@@ -93,8 +94,12 @@ public void init() {
9394
}
9495

9596
Arena offHeap = Arena.ofAuto();
96-
nativeBytesA = offHeap.allocate(size);
97-
nativeBytesB = offHeap.allocate(size);
97+
nativeBytesA = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
98+
nativeBytesB = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
99+
for (int i = 0; i < size; ++i) {
100+
nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
101+
// Populate the second operand so the dot product runs over two independently randomized vectors.
nativeBytesB.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
102+
}
98103
}
99104

100105
@Benchmark
@@ -103,6 +108,12 @@ public int vdot8s() {
103108
return VectorUtil.vdot8s(nativeBytesA, nativeBytesB, size);
104109
}
105110

111+
/**
 * JMH benchmark that calls {@code VectorUtil.dot8s} on the two off-heap segments
 * {@code nativeBytesA} and {@code nativeBytesB}, passing {@code size} as the limit.
 *
 * <p>The forked benchmark JVM is started with {@code --add-modules=jdk.incubator.vector}.
 * NOTE(review): {@code VectorUtil.dot8s} goes through a native downcall handle; confirm the
 * incubator module is actually required for this benchmark.
 *
 * @return the value returned by {@code VectorUtil.dot8s}
 */
@Benchmark
112+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
113+
public int dot8s() {
114+
return VectorUtil.dot8s(nativeBytesA, nativeBytesB, size);
115+
}
116+
106117
@Benchmark
107118
public float binaryCosineScalar() {
108119
return VectorUtil.cosine(bytesA, bytesB);

lucene/core/build.gradle

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,13 @@ model {
2626
toolChains {
2727
gcc(Gcc) {
2828
target("linux_aarch64"){
29-
cppCompiler.withArguments { args ->
30-
args << "-O3 --shared"
29+
path '/usr/bin/'
30+
cCompiler.executable 'gcc10-cc'
31+
cCompiler.withArguments { args ->
32+
args << "--shared"
33+
<< "-O3"
34+
<< "-march=armv8.2-a+dotprod"
35+
<< "-funroll-loops"
3136
}
3237
}
3338
}
@@ -52,7 +57,16 @@ model {
5257

5358
}
5459

60+
test.dependsOn 'dotProductSharedLibrary'
61+
5562
dependencies {
5663
moduleTestImplementation project(':lucene:codecs')
5764
moduleTestImplementation project(':lucene:test-framework')
5865
}
66+
67+
// Set the java.library.path system property for the test task to
// ${buildDir}/libs/dotProduct/shared, resolved to an absolute path.
// Presumably this is where the 'dotProductSharedLibrary' task (which `test` depends on)
// writes the native shared library -- confirm against the native build configuration.
test {
68+
systemProperty(
69+
"java.library.path",
70+
file("${buildDir}/libs/dotProduct/shared").absolutePath
71+
)
72+
}

lucene/core/src/c/dotProduct.c

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,17 @@
55
// https://developer.arm.com/architectures/instruction-sets/intrinsics/
66
int vdot8s(char vec1[], char vec2[], int limit) {
77
int result = 0;
8-
int32x4_t acc1 = vdupq_n_s32(0);
9-
int32x4_t acc2 = vdupq_n_s32(0);
8+
int32x4_t acc = vdupq_n_s32(0);
109
int i = 0;
1110

1211
for (; i+16 <= limit; i+=16 ) {
1312
// Read into 8 (bit) x 16 (values) vector
1413
int8x16_t va8 = vld1q_s8((const void*) (vec1 + i));
1514
int8x16_t vb8 = vld1q_s8((const void*) (vec2 + i));
16-
17-
// Signed multiply lower halves and store into 16 (bit) x 8 (values) vector
18-
int16x8_t va16 = vmull_s8(vget_low_s8(va8), vget_low_s8(vb8));
19-
// Signed multiply upper halves and store into 16 (bit) x 8 (values) vector
20-
int16x8_t vb16 = vmull_s8(vget_high_s8(va8), vget_high_s8(vb8));
21-
22-
// Add pair of adjacent 16 (bit) values and accumulate int 32 (bit) x 4 (values) vector
23-
acc1 = vpadalq_s16(acc1, va16);
24-
acc2 = vpadalq_s16(acc2, vb16);
15+
acc = vdotq_s32(acc, va8, vb8);
2516
}
26-
27-
// Add corresponding elements in two accumulators, store in 32 (bit) x 4 (values) vector
28-
acc1 = vaddq_s32(acc1, acc2);
2917
// REDUCE: Add every vector element in target and write result to scalar
30-
result += vaddvq_s32(acc1);
18+
result += vaddvq_s32(acc);
3119

3220
// Scalar tail. TODO: Use FMA
3321
for (; i < limit; i++) {

lucene/core/src/java/org/apache/lucene/util/VectorUtil.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,16 +189,26 @@ public static void add(float[] u, float[] v) {
189189
/**
 * Foreign-function signature used to link the native {@code vdot8s} symbol: result layout
 * {@code JAVA_INT}, argument layouts {@code (POINTER, POINTER, JAVA_INT)}.
 */
static final FunctionDescriptor vdot8sDesc =
190190
FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_INT);
191191

192+
/**
 * Foreign-function signature used to link the native {@code dot8s} symbol: result layout
 * {@code JAVA_INT}, argument layouts {@code (POINTER, POINTER, JAVA_INT)}. Identical in shape
 * to {@code vdot8sDesc}.
 */
static final FunctionDescriptor dot8sDesc =
193+
FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_INT);
194+
192195
/**
 * Downcall handle for the native {@code vdot8s} symbol, linked with {@code vdot8sDesc}.
 * Holds {@code null} when {@code SYMBOL_LOOKUP} cannot find the symbol; callers should go
 * through {@code vdot8s$MH()}, which turns that {@code null} into an error.
 */
static final MethodHandle vdot8sMH =
193196
SYMBOL_LOOKUP
194197
.find("vdot8s")
195198
.map(addr -> LINKER.downcallHandle(addr, vdot8sDesc))
196199
.orElse(null);
197200

201+
/**
 * Downcall handle for the native {@code dot8s} symbol, linked with {@code dot8sDesc}.
 * Holds {@code null} when {@code SYMBOL_LOOKUP} cannot find the symbol; callers should go
 * through {@code dot8s$MH()}, which turns that {@code null} into an error.
 */
static final MethodHandle dot8sMH =
202+
SYMBOL_LOOKUP.find("dot8s").map(addr -> LINKER.downcallHandle(addr, dot8sDesc)).orElse(null);
203+
198204
/**
 * Returns the downcall handle for the native {@code vdot8s} symbol.
 *
 * @return the non-null {@code vdot8sMH} handle
 * @throws UnsatisfiedLinkError if the symbol was not resolved ({@code vdot8sMH} is
 *     {@code null}); raised by {@code requireNonNull}
 */
static final MethodHandle vdot8s$MH() {
199205
return requireNonNull(vdot8sMH, "vdot8s");
200206
}
201207

208+
/**
 * Returns the downcall handle for the native {@code dot8s} symbol.
 *
 * @return the non-null {@code dot8sMH} handle
 * @throws UnsatisfiedLinkError if the symbol was not resolved ({@code dot8sMH} is
 *     {@code null}); raised by {@code requireNonNull}
 */
static final MethodHandle dot8s$MH() {
209+
return requireNonNull(dot8sMH, "dot8s");
210+
}
211+
202212
static <T> T requireNonNull(T obj, String symbolName) {
203213
if (obj == null) {
204214
throw new UnsatisfiedLinkError("unresolved symbol: " + symbolName);
@@ -215,6 +225,15 @@ public static int vdot8s(MemorySegment vec1, MemorySegment vec2, int limit) {
215225
}
216226
}
217227

228+
/**
 * Invokes the native {@code dot8s} function through its downcall handle.
 *
 * @param vec1 memory segment passed as the first pointer argument
 * @param vec2 memory segment passed as the second pointer argument
 * @param limit int passed as the third argument (presumably the number of bytes to process in
 *     each segment -- TODO confirm against the native implementation)
 * @return the int returned by the native function
 * @throws UnsatisfiedLinkError if the {@code dot8s} symbol could not be resolved
 * @throws AssertionError wrapping any {@link Throwable} raised by the handle invocation
 */
public static int dot8s(MemorySegment vec1, MemorySegment vec2, int limit) {
229+
// Resolved outside the try block, so an unresolved symbol surfaces as UnsatisfiedLinkError.
var mh$ = dot8s$MH();
230+
try {
231+
// invokeExact requires the call-site types to match the handle's type exactly.
return (int) mh$.invokeExact(vec1, vec2, limit);
232+
} catch (Throwable ex$) {
233+
throw new AssertionError("should not reach here", ex$);
234+
}
235+
}
236+
218237
/** Ankur: Hacky code end * */
219238

220239
/**

0 commit comments

Comments
 (0)