Skip to content

Commit

Permalink
Inlined the hash function, runs locally in 2.4sec now, hopefully endi…
Browse files Browse the repository at this point in the history
…an issues fix
  • Loading branch information
royvanrijn committed Jan 4, 2024
1 parent 09c4e60 commit d741175
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 39 deletions.
2 changes: 1 addition & 1 deletion calculate_average_royvanrijn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#


sdk use java 21.0.1-graal
# Added for fun, doesn't seem to be making a difference...
if [ -f "target/calculate_average_royvanrijn.jsa" ]; then
JAVA_OPTS="-XX:SharedArchiveFile=target/calculate_average_royvanrijn.jsa -Xshare:on"
Expand Down
88 changes: 50 additions & 38 deletions src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
* Improved String skip: 3250 ms
* Segmenting files: 3150 ms (based on spullara's code)
* Not using SWAR for EOL: 2850 ms
* Inlining hash calculation: 2450 ms
*
* Best performing JVM on MacBook M2 Pro: 21.0.1-graal
* `sdk use java 21.0.1-graal`
Expand Down Expand Up @@ -88,7 +89,7 @@ private double round(double value) {
}
}

public static final void main(String[] args) throws Exception {
public static void main(String[] args) throws Exception {
new CalculateAverage_royvanrijn().run();
}

Expand All @@ -100,18 +101,18 @@ private void run() throws Exception {
try (var fileChannel = (FileChannel) Files.newByteChannel(Path.of(FILE), StandardOpenOption.READ)) {
var bb = fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segmentEnd - segment.start());
var buffer = new byte[64];
var pointerAndHash = new int[2];

// Force little endian:
bb.order(ByteOrder.LITTLE_ENDIAN);

final boolean bufferIsBigEndian = bb.order().equals(ByteOrder.BIG_ENDIAN);
BitTwiddledMap measurements = new BitTwiddledMap();

int startPointer;
int limit = bb.limit();
while ((startPointer = bb.position()) < limit) {

// SWAR is faster for ';'
int separatorPointer = findNextSWAR(bb, SEPARATOR_PATTERN, startPointer + 3, limit);
findNextDelimiterAndCalculateHash(bb, SEPARATOR_PATTERN, startPointer, limit, pointerAndHash, bufferIsBigEndian);
int separatorPointer = pointerAndHash[0];

// Simple is faster for '\n' (just three options)
int endPointer;
Expand All @@ -133,17 +134,16 @@ else if (bb.get(separatorPointer + 5) == '\n') {
final int nameLength = separatorPointer - startPointer;
final int valueLength = endPointer - separatorPointer - 1;
final int measured = branchlessParseInt(buffer, nameLength + 1, valueLength);
measurements.getOrCreate(buffer, nameLength).updateWith(measured);
measurements.getOrCreate(buffer, nameLength, pointerAndHash[1]).updateWith(measured);
}
return measurements;
}
catch (IOException e) {
throw new RuntimeException(e);
}
}).parallel().flatMap(v -> v.values.stream())
.collect(Collectors.toMap(e -> new String(e.key), BitTwiddledMap.Entry::measurement, (m1, m2) -> m1.updateWith(m2), TreeMap::new));
.collect(Collectors.toMap(e -> new String(e.key), BitTwiddledMap.Entry::measurement, Measurement::updateWith, TreeMap::new));

// Seems to perform better than actually using a TreeMap:
System.out.println(results);
}

Expand All @@ -152,42 +152,64 @@ else if (bb.get(separatorPointer + 5) == '\n') {
*/
private static final long SEPARATOR_PATTERN = compilePattern((byte) ';');

private int findNextSWAR(ByteBuffer bb, long pattern, int start, int limit) {
/**
* Already looping the longs here, lets shoehorn in making a hash
*/
private void findNextDelimiterAndCalculateHash(final ByteBuffer bb, final long pattern, final int start, final int limit, final int[] output,
final boolean bufferBigEndian) {
int hash = 1;
int i;
for (i = start; i <= limit - 8; i += 8) {
// checks 8x8 bytes
long word = bb.getLong(i);
int index = firstAnyPattern(word, pattern);
int index = firstAnyPattern(word, pattern, bufferBigEndian);
if (index < Long.BYTES) {
return i + index;
final long mask = ((1L << (index * 8)) - 1);
final long partialHash = (bufferBigEndian ? Long.reverseBytes(word) : word) & mask;
hash = 31 * hash + (int) (partialHash >>> 32);
hash = 31 * hash + (int) partialHash;
output[0] = (i + index);
output[1] = hash;
return;
}
hash = 31 * hash + (int) (word >>> 32);
hash = 31 * hash + (int) word;
}
// Handle remaining bytes
long partialHash = 0;
for (; i < limit; i++) {
if (bb.get(i) == (byte) pattern) {
return i;
byte read;
if ((read = bb.get(i)) == (byte) pattern) {
hash = 31 * hash + (int) (partialHash >>> 32);
hash = 31 * hash + (int) partialHash;
output[0] = i;
output[1] = hash;
return;
}
partialHash = partialHash << 8 | read;
}
return limit; // delimiter not found
output[0] = limit; // delimiter not found
output[1] = hash;
}

private static long compilePattern(byte value) {
private static long compilePattern(final byte value) {
return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) |
((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value;
}

private static int firstAnyPattern(long word, long pattern) {
private static int firstAnyPattern(final long word, final long pattern, final boolean bufferBigEndian) {
final long match = word ^ pattern;
long mask = match - 0x0101010101010101L;
mask &= ~match;
mask &= 0x8080808080808080L;
return Long.numberOfTrailingZeros(mask) >>> 3;
return (bufferBigEndian ? Long.numberOfLeadingZeros(mask) : Long.numberOfTrailingZeros(mask)) >>> 3;
}

record FileSegment(long start, long end) {
}

/** Using this way to segment the file is much prettier, from spullara */
private static List<FileSegment> getFileSegments(File file) throws IOException {
private static List<FileSegment> getFileSegments(final File file) throws IOException {
final int numberOfSegments = Runtime.getRuntime().availableProcessors();
final long fileSize = file.length();
final long segmentSize = fileSize / numberOfSegments;
Expand All @@ -205,7 +227,7 @@ private static List<FileSegment> getFileSegments(File file) throws IOException {
return segments;
}

private static long findSegment(int i, int skipSegment, RandomAccessFile raf, long location, long fileSize) throws IOException {
private static long findSegment(final int i, final int skipSegment, RandomAccessFile raf, long location, final long fileSize) throws IOException {
if (i != skipSegment) {
raf.seek(location);
while (location < fileSize) {
Expand All @@ -226,7 +248,7 @@ private static long findSegment(int i, int skipSegment, RandomAccessFile raf, lo
* @param input
* @return int value x10
*/
private static int branchlessParseInt(final byte[] input, int start, int length) {
private static int branchlessParseInt(final byte[] input, final int start, final int length) {
// 0 if positive, 1 if negative
final int negative = ~(input[start] >> 4) & 1;
// 0 if nr length is 3, 1 if length is 4
Expand Down Expand Up @@ -265,16 +287,14 @@ class BitTwiddledMap {
BitTwiddledMap() {
// Optimized fill with -1, fastest method:
int len = indices.length;
if (len > 0) {
indices[0] = -1;
}
indices[0] = -1;
// Value of i will be [1, 2, 4, 8, 16, 32, ..., len]
for (int i = 1; i < len; i += i) {
System.arraycopy(indices, 0, indices, i, i);
}
}

private List<Entry> values = new ArrayList<>(512);
private final List<Entry> values = new ArrayList<>(512);

record Entry(int hash, byte[] key, Measurement measurement) {
@Override
Expand All @@ -288,12 +308,13 @@ public String toString() {
* @param key
* @return
*/
public Measurement getOrCreate(byte[] key, int length) {
int inHash;
int index = (SIZE - 1) & (inHash = hashCode(key, length));
public Measurement getOrCreate(byte[] key, int length, int calculatedHash) {

int index = (SIZE - 1) & calculatedHash;
int valueIndex;
Entry retrievedEntry = null;
while ((valueIndex = indices[index]) != -1 && (retrievedEntry = values.get(valueIndex)).hash != inHash) {
while ((valueIndex = indices[index]) != -1 && (retrievedEntry = values.get(valueIndex)).hash != calculatedHash) {

index = (index + 1) % SIZE;
}
if (valueIndex >= 0) {
Expand All @@ -306,18 +327,9 @@ public Measurement getOrCreate(byte[] key, int length) {
byte[] actualKey = new byte[length];
System.arraycopy(key, 0, actualKey, 0, length);

Entry toAdd = new Entry(inHash, actualKey, new Measurement());
Entry toAdd = new Entry(calculatedHash, actualKey, new Measurement());
values.add(toAdd);
return toAdd.measurement;
}

private static int hashCode(byte[] a, int length) {
int result = 1;
for (int i = 0; i < length; i++) {
result = 31 * result + a[i];
}
return result;
}
}

}

0 comments on commit d741175

Please sign in to comment.