diff --git a/calculate_average_royvanrijn.sh b/calculate_average_royvanrijn.sh index ede6451c8f..c1e595ddcf 100755 --- a/calculate_average_royvanrijn.sh +++ b/calculate_average_royvanrijn.sh @@ -15,7 +15,7 @@ # limitations under the License. # - +sdk use java 21.0.1-graal # Added for fun, doesn't seem to be making a difference... if [ -f "target/calculate_average_royvanrijn.jsa" ]; then JAVA_OPTS="-XX:SharedArchiveFile=target/calculate_average_royvanrijn.jsa -Xshare:on" diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java index 5fc38aef31..956290634a 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java @@ -44,6 +44,7 @@ * Improved String skip: 3250 ms * Segmenting files: 3150 ms (based on spullara's code) * Not using SWAR for EOL: 2850 ms + * Inlining hash calculation: 2450 ms * * Best performing JVM on MacBook M2 Pro: 21.0.1-graal * `sdk use java 21.0.1-graal` @@ -88,7 +89,7 @@ private double round(double value) { } } - public static final void main(String[] args) throws Exception { + public static void main(String[] args) throws Exception { new CalculateAverage_royvanrijn().run(); } @@ -100,10 +101,9 @@ private void run() throws Exception { try (var fileChannel = (FileChannel) Files.newByteChannel(Path.of(FILE), StandardOpenOption.READ)) { var bb = fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segmentEnd - segment.start()); var buffer = new byte[64]; + var pointerAndHash = new int[2]; - // Force little endian: - bb.order(ByteOrder.LITTLE_ENDIAN); - + final boolean bufferIsBigEndian = bb.order().equals(ByteOrder.BIG_ENDIAN); BitTwiddledMap measurements = new BitTwiddledMap(); int startPointer; @@ -111,7 +111,8 @@ private void run() throws Exception { while ((startPointer = bb.position()) < limit) { // SWAR is faster for ';' - int separatorPointer = findNextSWAR(bb, SEPARATOR_PATTERN, startPointer + 3, limit); + findNextDelimiterAndCalculateHash(bb, SEPARATOR_PATTERN, startPointer, limit, pointerAndHash, bufferIsBigEndian); + int separatorPointer = pointerAndHash[0]; // Simple is faster for '\n' (just three options) int endPointer; @@ -133,7 +134,7 @@ else if (bb.get(separatorPointer + 5) == '\n') { final int nameLength = separatorPointer - startPointer; final int valueLength = endPointer - separatorPointer - 1; final int measured = branchlessParseInt(buffer, nameLength + 1, valueLength); - measurements.getOrCreate(buffer, nameLength).updateWith(measured); + measurements.getOrCreate(buffer, nameLength, pointerAndHash[1]).updateWith(measured); } return measurements; } @@ -141,9 +142,8 @@ else if (bb.get(separatorPointer + 5) == '\n') { throw new RuntimeException(e); } }).parallel().flatMap(v -> v.values.stream()) - .collect(Collectors.toMap(e -> new String(e.key), BitTwiddledMap.Entry::measurement, (m1, m2) -> m1.updateWith(m2), TreeMap::new)); + .collect(Collectors.toMap(e -> new String(e.key), BitTwiddledMap.Entry::measurement, Measurement::updateWith, TreeMap::new)); - // Seems to perform better than actually using a TreeMap: System.out.println(results); } @@ -152,42 +152,64 @@ else if (bb.get(separatorPointer + 5) == '\n') { */ private static final long SEPARATOR_PATTERN = compilePattern((byte) ';'); - private int findNextSWAR(ByteBuffer bb, long pattern, int start, int limit) { + /** + * Already looping the longs here, lets shoehorn in making a hash + */ + private void findNextDelimiterAndCalculateHash(final ByteBuffer bb, final long pattern, final int start, final int limit, final int[] output, + final boolean bufferBigEndian) { + int hash = 1; int i; for (i = start; i <= limit - 8; i += 8) { + // checks 8x8 bytes long word = bb.getLong(i); - int index = firstAnyPattern(word, pattern); + int index = firstAnyPattern(word, pattern, bufferBigEndian); if (index < Long.BYTES) { - return i + index; + final long mask = ((1L << (index * 8)) - 1); + final long partialHash = (bufferBigEndian ? Long.reverseBytes(word) : word) & mask; + hash = 31 * hash + (int) (partialHash >>> 32); + hash = 31 * hash + (int) partialHash; + output[0] = (i + index); + output[1] = hash; + return; } + hash = 31 * hash + (int) (word >>> 32); + hash = 31 * hash + (int) word; } // Handle remaining bytes + long partialHash = 0; for (; i < limit; i++) { - if (bb.get(i) == (byte) pattern) { - return i; + byte read; + if ((read = bb.get(i)) == (byte) pattern) { + hash = 31 * hash + (int) (partialHash >>> 32); + hash = 31 * hash + (int) partialHash; + output[0] = i; + output[1] = hash; + return; } + partialHash = partialHash << 8 | read; } - return limit; // delimiter not found + output[0] = limit; // delimiter not found + output[1] = hash; } - private static long compilePattern(byte value) { + private static long compilePattern(final byte value) { return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) | ((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value; } - private static int firstAnyPattern(long word, long pattern) { + private static int firstAnyPattern(final long word, final long pattern, final boolean bufferBigEndian) { final long match = word ^ pattern; long mask = match - 0x0101010101010101L; mask &= ~match; mask &= 0x8080808080808080L; - return Long.numberOfTrailingZeros(mask) >>> 3; + return (bufferBigEndian ? Long.numberOfLeadingZeros(mask) : Long.numberOfTrailingZeros(mask)) >>> 3; } record FileSegment(long start, long end) { } /** Using this way to segment the file is much prettier, from spullara */ - private static List getFileSegments(File file) throws IOException { + private static List getFileSegments(final File file) throws IOException { final int numberOfSegments = Runtime.getRuntime().availableProcessors(); final long fileSize = file.length(); final long segmentSize = fileSize / numberOfSegments; @@ -205,7 +227,7 @@ private static List getFileSegments(File file) throws IOException { return segments; } - private static long findSegment(int i, int skipSegment, RandomAccessFile raf, long location, long fileSize) throws IOException { + private static long findSegment(final int i, final int skipSegment, RandomAccessFile raf, long location, final long fileSize) throws IOException { if (i != skipSegment) { raf.seek(location); while (location < fileSize) { @@ -226,7 +248,7 @@ private static long findSegment(int i, int skipSegment, RandomAccessFile raf, lo * @param input * @return int value x10 */ - private static int branchlessParseInt(final byte[] input, int start, int length) { + private static int branchlessParseInt(final byte[] input, final int start, final int length) { // 0 if positive, 1 if negative final int negative = ~(input[start] >> 4) & 1; // 0 if nr length is 3, 1 if length is 4 @@ -265,16 +287,14 @@ class BitTwiddledMap { BitTwiddledMap() { // Optimized fill with -1, fastest method: int len = indices.length; - if (len > 0) { - indices[0] = -1; - } + indices[0] = -1; // Value of i will be [1, 2, 4, 8, 16, 32, ..., len] for (int i = 1; i < len; i += i) { System.arraycopy(indices, 0, indices, i, i); } } - private List values = new ArrayList<>(512); + private final List values = new ArrayList<>(512); record Entry(int hash, byte[] key, Measurement measurement) { @Override @@ -288,12 +308,13 @@ public String toString() { * @param key * @return */ - public Measurement getOrCreate(byte[] key, int length) { - int inHash; - int index = (SIZE - 1) & (inHash = hashCode(key, length)); + public Measurement getOrCreate(byte[] key, int length, int calculatedHash) { + + int index = (SIZE - 1) & calculatedHash; int valueIndex; Entry retrievedEntry = null; - while ((valueIndex = indices[index]) != -1 && (retrievedEntry = values.get(valueIndex)).hash != inHash) { + while ((valueIndex = indices[index]) != -1 && (retrievedEntry = values.get(valueIndex)).hash != calculatedHash) { + index = (index + 1) % SIZE; } if (valueIndex >= 0) { @@ -306,18 +327,9 @@ public Measurement getOrCreate(byte[] key, int length) { byte[] actualKey = new byte[length]; System.arraycopy(key, 0, actualKey, 0, length); - Entry toAdd = new Entry(inHash, actualKey, new Measurement()); + Entry toAdd = new Entry(calculatedHash, actualKey, new Measurement()); values.add(toAdd); return toAdd.measurement; } - - private static int hashCode(byte[] a, int length) { - int result = 1; - for (int i = 0; i < length; i++) { - result = 31 * result + a[i]; - } - return result; - } } - }