Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Process 2 consecutive lines at a time in each thread #651

Merged
merged 1 commit into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions calculate_average_ianopolousfast.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@
#

JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector"
#-Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 -XX:-UseTransparentHugePages"

java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ianopolousfast
150 changes: 63 additions & 87 deletions src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.ByteOrder;
Expand All @@ -39,10 +38,7 @@
* * read chunks in parallel
* * minimise allocation
* * no unsafe
*
* Timings on 4 core i7-7500U CPU @ 2.70GHz:
* average_baseline: 4m48s
* ianopolous: 13.8s
* * process multiple lines in each thread for better ILP
*/
public class CalculateAverage_ianopolousfast {

Expand Down Expand Up @@ -91,11 +87,22 @@ public static boolean matchingStationBytes(long start, long end, MemorySegment b
return true;
}

private static int hashToIndex(long hash, int len) {
// From Thomas Wuerthinger's entry
int hashAsInt = (int) (hash ^ (hash >>> 28));
int finalHash = (hashAsInt ^ (hashAsInt >>> 15));
return (finalHash & (len - 1));
private static final int GOLDEN_RATIO = 0x9E3779B9;
private static final int HASH_LROTATE = 5;

// hash from giovannicuccu
private static int hash(MemorySegment memorySegment, long start, int len) {
int x;
int y;
if (len >= Integer.BYTES) {
x = memorySegment.get(JAVA_INT_UNALIGNED, start);
y = memorySegment.get(JAVA_INT_UNALIGNED, start + len - Integer.BYTES);
}
else {
x = memorySegment.get(JAVA_BYTE, start);
y = memorySegment.get(JAVA_BYTE, start + len - Byte.BYTES);
}
return (Integer.rotateLeft(x * GOLDEN_RATIO, HASH_LROTATE) ^ y) * GOLDEN_RATIO;
}

public static Stat createStation(long start, long end, MemorySegment buffer) {
Expand All @@ -105,8 +112,9 @@ public static Stat createStation(long start, long end, MemorySegment buffer) {
return new Stat(stationBuffer);
}

public static Stat dedupeStation(long start, long end, long hash, MemorySegment buffer, Stat[] stations) {
int index = hashToIndex(hash, MAX_STATIONS);
public static Stat dedupeStation(long start, long end, MemorySegment buffer, Stat[] stations) {
int hash = hash(buffer, start, (int) (end - start));
int index = hash & (MAX_STATIONS - 1);
Stat match = stations[index];
while (match != null) {
if (matchingStationBytes(start, end, buffer, match))
Expand All @@ -119,37 +127,11 @@ public static Stat dedupeStation(long start, long end, long hash, MemorySegment
return res;
}

static long maskHighBytes(long d, int nbytes) {
return d & (-1L << ((8 - nbytes) * 8));
}

public static Stat parseStation(long lineStart, MemorySegment buffer, Stat[] stations) {
ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder());
int keySize = line.compare(VectorOperators.EQ, ';').firstTrue();

long first8 = buffer.get(LONG_LAYOUT, lineStart);
long second8 = 0;
if (keySize <= 8) {
first8 = maskHighBytes(first8, keySize & 0x07);
}
else if (keySize < 16) {
second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
}
else if (keySize == BYTE_SPECIES.vectorByteSize()) {
while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') {
keySize++;
}
second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
}
long hash = first8 ^ second8; // todo include later bytes
return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations);
}

public static short getMinus(long d) {
return ((d & 0xff00000000000000L) ^ 0x2d00000000000000L) != 0 ? 0 : (short) -1;
}

public static long processTemperature(long lineSplit, int size, MemorySegment buffer, Stat station) {
public static void processTemperature(long lineSplit, int size, MemorySegment buffer, Stat station) {
long d = buffer.get(LONG_LAYOUT, lineSplit);
// negative is either 0 or -1
short negative = getMinus(d);
Expand All @@ -162,10 +144,9 @@ public static long processTemperature(long lineSplit, int size, MemorySegment bu
100 * (((byte) (d >> 24)) - '0'));
temperature = (short) ((temperature ^ negative) - negative); // negative treatment inspired by merkitty
station.add(temperature);
return lineSplit + size + 1;
}

private static long parseLine(long lineStart, MemorySegment buffer, Stat[] stations) {
private static int lineSize(long lineStart, MemorySegment buffer) {
ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder());
int lineSize = line.compare(VectorOperators.EQ, '\n').firstTrue();
int index = lineSize;
Expand All @@ -174,33 +155,19 @@ private static long parseLine(long lineStart, MemorySegment buffer, Stat[] stati
ByteOrder.nativeOrder()).compare(VectorOperators.EQ, '\n').firstTrue();
lineSize += index;
}
int keySize = lineSize - 6 + ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart + lineSize - 6,
ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue();
return lineSize;
}

long first8 = buffer.get(LONG_LAYOUT, lineStart);
long second8 = 0;
if (keySize <= 8) {
first8 = maskHighBytes(first8, keySize & 0x07);
}
else if (keySize < 16) {
second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
}
else if (keySize == BYTE_SPECIES.vectorByteSize()) {
while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') {
keySize++;
}
second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
}
long hash = first8 ^ second8; // todo include later bytes
Stat station = dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations);
return processTemperature(lineStart + keySize + 1, lineSize - keySize - 1, buffer, station);
private static int keySize(int lineSize, long lineStart, MemorySegment buffer) {
return lineSize - 6 + ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart + lineSize - 6,
ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue();
}

public static Stat[] parseStats(long startByte, long endByte, MemorySegment buffer) {
public static Stat[] parseStats(long start1, long end2, MemorySegment buffer) {
// read first partial line
if (startByte > 0) {
if (start1 > 0) {
for (int i = 0; i < MAX_LINE_LENGTH; i++) {
byte b = buffer.get(JAVA_BYTE, startByte++);
byte b = buffer.get(JAVA_BYTE, start1++);
if (b == '\n') {
break;
}
Expand All @@ -213,38 +180,47 @@ public static Stat[] parseStats(long startByte, long endByte, MemorySegment buff
// this allows us to not worry about reading beyond the end
// in the inner loop (reducing branches)
// We need at least the vector lane size bytes back
if (endByte == buffer.byteSize()) {
if (end2 == buffer.byteSize()) {
// reverse at least vector lane width
endByte = Math.max(buffer.byteSize() - BYTE_SPECIES.vectorByteSize(), 0);
while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n')
endByte--;
end2 = Math.max(buffer.byteSize() - 2 * BYTE_SPECIES.vectorByteSize(), 0);
while (end2 > 0 && buffer.get(JAVA_BYTE, end2) != '\n')
end2--;

if (endByte > 0)
endByte++;
if (end2 > 0)
end2++;
// copy into a larger buffer to avoid reading off end
MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + BYTE_SPECIES.vectorByteSize());
for (long i = endByte; i < buffer.byteSize(); i++)
end.set(JAVA_BYTE, i - endByte, buffer.get(JAVA_BYTE, i));
MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 2 * BYTE_SPECIES.vectorByteSize());
for (long i = end2; i < buffer.byteSize(); i++)
end.set(JAVA_BYTE, i - end2, buffer.get(JAVA_BYTE, i));
int index = 0;
while (endByte + index < buffer.byteSize()) {
Stat station = parseStation(index, end, stations);
int tempSize = 3;
if (end.get(JAVA_BYTE, index + station.namelen + 5) == '\n')
tempSize = 4;
if (end.get(JAVA_BYTE, index + station.namelen + 6) == '\n')
tempSize = 5;
index = (int) processTemperature(index + station.namelen + 1, tempSize, end, station);
while (end2 + index < buffer.byteSize()) {
int lineSize1 = lineSize(index, end);
int semiSearchStart = index + Math.max(0, lineSize1 - 6);
int keySize1 = semiSearchStart - index + ByteVector.fromMemorySegment(BYTE_SPECIES, end, semiSearchStart,
ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue();
Stat station1 = dedupeStation(index, index + keySize1, end, stations);
processTemperature(index + keySize1 + 1, lineSize1 - keySize1 - 1, end, station1);
index += lineSize1 + 1;
}
}

innerloop(startByte, endByte, buffer, stations);
return stations;
}

private static void innerloop(long startByte, long endByte, MemorySegment buffer, Stat[] stations) {
while (startByte < endByte) {
startByte = parseLine(startByte, buffer, stations);
while (start1 < end2) {
int lineSize1 = lineSize(start1, buffer);
long start2 = start1 + lineSize1 + 1;
int lineSize2 = start2 < end2 ? lineSize(start2, buffer) : 0;
int keySize1 = keySize(lineSize1, start1, buffer);
int keySize2 = keySize(lineSize2, start2, buffer);
Stat station1 = dedupeStation(start1, start1 + keySize1, buffer, stations);
processTemperature(start1 + keySize1 + 1, lineSize1 - keySize1 - 1, buffer, station1);
if (start2 < end2) {
Stat station2 = dedupeStation(start2, start2 + keySize2, buffer, stations);
processTemperature(start2 + keySize2 + 1, lineSize2 - keySize2 - 1, buffer, station2);
start1 = start2 + lineSize2 + 1;
}
else
start1 += lineSize1 + 1;
}
return stations;
}

public static class Stat {
Expand Down
Loading