Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jerrinot's improvement - everyone is hunting ILP? let's do less! #652

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 21 additions & 68 deletions src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class CalculateAverage_jerrinot {
// todo: with hyper-threading enabled we would be better off with availableProcessors / 2;
// todo: validate the testing env. params.
private static final int THREAD_COUNT = Runtime.getRuntime().availableProcessors();
// private static final int THREAD_COUNT = 4;
// private static final int THREAD_COUNT = 1;

private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL;

Expand Down Expand Up @@ -82,7 +82,7 @@ static void calculate() throws Exception {
final File file = new File(MEASUREMENTS_TXT);
final long length = file.length();
// final int chunkCount = Runtime.getRuntime().availableProcessors();
int chunkPerThread = 3;
int chunkPerThread = 2;
final int chunkCount = THREAD_COUNT * chunkPerThread;
final var chunkStartOffsets = new long[chunkCount + 1];
try (var raf = new RandomAccessFile(file, "r")) {
Expand All @@ -107,10 +107,8 @@ static void calculate() throws Exception {
long endA = chunkStartOffsets[i * chunkPerThread + 1];
long startB = chunkStartOffsets[i * chunkPerThread + 1];
long endB = chunkStartOffsets[i * chunkPerThread + 2];
long startC = chunkStartOffsets[i * chunkPerThread + 2];
long endC = chunkStartOffsets[i * chunkPerThread + 3];

Processor processor = new Processor(startA, endA, startB, endB, startC, endC);
Processor processor = new Processor(startA, endA, startB, endB);
processors[i] = processor;
Thread thread = new Thread(processor);
threads[i] = thread;
Expand All @@ -122,9 +120,7 @@ static void calculate() throws Exception {
long endA = chunkStartOffsets[ownIndex * chunkPerThread + 1];
long startB = chunkStartOffsets[ownIndex * chunkPerThread + 1];
long endB = chunkStartOffsets[ownIndex * chunkPerThread + 2];
long startC = chunkStartOffsets[ownIndex * chunkPerThread + 2];
long endC = chunkStartOffsets[ownIndex * chunkPerThread + 3];
Processor processor = new Processor(startA, endA, startB, endB, startC, endC);
Processor processor = new Processor(startA, endA, startB, endB);
processor.run();

var accumulator = new TreeMap<String, Processor.StationStats>();
Expand Down Expand Up @@ -219,8 +215,6 @@ private static class Processor implements Runnable {
private long endA;
private long cursorB;
private long endB;
private long cursorC;
private long endC;
private HashMap<String, StationStats> stats = new HashMap<>(1000);

// private long maxClusterLen;
Expand Down Expand Up @@ -287,19 +281,16 @@ void accumulateStatus(TreeMap<String, StationStats> accumulator) {
}
}

Processor(long startA, long endA, long startB, long endB, long startC, long endC) {
Processor(long startA, long endA, long startB, long endB) {
this.cursorA = startA;
this.cursorB = startB;
this.cursorC = startC;
this.endA = endA;
this.endB = endB;
this.endC = endC;
}

private void doTail(long fastMAp) {
doOne(cursorA, endA);
doOne(cursorB, endB);
doOne(cursorC, endC);

transferToHeap(fastMAp);
// UNSAFE.freeMemory(fastMap);
Expand Down Expand Up @@ -402,71 +393,55 @@ public void run() {
UNSAFE.setMemory(fastMap, FAST_MAP_SIZE_BYTES, (byte) 0);
UNSAFE.setMemory(slowMapNamesPtr, SLOW_MAP_MAP_NAMES_BYTES, (byte) 0);

while (cursorA < endA && cursorB < endB && cursorC < endC) {
while (cursorA < endA && cursorB < endB) {
long currentWordA = UNSAFE.getLong(cursorA);
long currentWordB = UNSAFE.getLong(cursorB);
long currentWordC = UNSAFE.getLong(cursorC);

long candidateWordA = UNSAFE.getLong(cursorA + 8);
long candidateWordB = UNSAFE.getLong(cursorB + 8);

long startA = cursorA;
long startB = cursorB;
long startC = cursorC;

long maskA = getDelimiterMask(currentWordA);
long maskB = getDelimiterMask(currentWordB);
long maskC = getDelimiterMask(currentWordC);

long maskComplementA = -maskA;
long maskComplementB = -maskB;
long maskComplementC = -maskC;
long newExpA = (1L << (Long.numberOfTrailingZeros(maskA) - 1)) >> 63;
long newExpB = (1L << (Long.numberOfTrailingZeros(maskB) - 1)) >> 63;

long maskWithDelimiterA = (maskA ^ (maskA - 1));
long maskWithDelimiterB = (maskB ^ (maskB - 1));
long maskWithDelimiterC = (maskC ^ (maskC - 1));

long isMaskZeroA = (((maskA | maskComplementA) >>> 63) ^ 1);
long isMaskZeroB = (((maskB | maskComplementB) >>> 63) ^ 1);
long isMaskZeroC = (((maskC | maskComplementC) >>> 63) ^ 1);

cursorA += isMaskZeroA << 3;
cursorB += isMaskZeroB << 3;
cursorC += isMaskZeroC << 3;
cursorA += newExpA & 8;
cursorB += newExpB & 8;

long nextWordA = UNSAFE.getLong(cursorA);
long nextWordB = UNSAFE.getLong(cursorB);
long nextWordC = UNSAFE.getLong(cursorC);
long nextWordA = (newExpA & candidateWordA) | (~newExpA & currentWordA);
long nextWordB = (newExpB & candidateWordB) | (~newExpB & currentWordB);

long firstWordMaskA = maskWithDelimiterA >>> 8;
long firstWordMaskB = maskWithDelimiterB >>> 8;
long firstWordMaskC = maskWithDelimiterC >>> 8;

long firstWordFinalMaskA = (newExpA | firstWordMaskA);
long firstWordFinalMaskB = (newExpB | firstWordMaskB);

long nextMaskA = getDelimiterMask(nextWordA);
long nextMaskB = getDelimiterMask(nextWordB);
long nextMaskC = getDelimiterMask(nextWordC);

boolean slowA = nextMaskA == 0;
boolean slowB = nextMaskB == 0;
boolean slowC = nextMaskC == 0;
boolean slowSome = (slowA || slowB || slowC);

long extA = -isMaskZeroA;
long extB = -isMaskZeroB;
long extC = -isMaskZeroC;
boolean slowSome = (slowA || slowB);

long maskedFirstWordA = (extA | firstWordMaskA) & currentWordA;
long maskedFirstWordB = (extB | firstWordMaskB) & currentWordB;
long maskedFirstWordC = (extC | firstWordMaskC) & currentWordC;
long maskedFirstWordA = firstWordFinalMaskA & currentWordA;
long maskedFirstWordB = firstWordFinalMaskB & currentWordB;

int hashA = hash(maskedFirstWordA);
int hashB = hash(maskedFirstWordB);
int hashC = hash(maskedFirstWordC);

currentWordA = nextWordA;
currentWordB = nextWordB;
currentWordC = nextWordC;

maskA = nextMaskA;
maskB = nextMaskB;
maskC = nextMaskC;
if (slowSome) {
while (maskA == 0) {
cursorA += 8;
Expand All @@ -479,48 +454,34 @@ public void run() {
currentWordB = UNSAFE.getLong(cursorB);
maskB = getDelimiterMask(currentWordB);
}
while (maskC == 0) {
cursorC += 8;
currentWordC = UNSAFE.getLong(cursorC);
maskC = getDelimiterMask(currentWordC);
}
}

final int delimiterByteA = Long.numberOfTrailingZeros(maskA);
final int delimiterByteB = Long.numberOfTrailingZeros(maskB);
final int delimiterByteC = Long.numberOfTrailingZeros(maskC);

final long semicolonA = cursorA + (delimiterByteA >> 3);
final long semicolonB = cursorB + (delimiterByteB >> 3);
final long semicolonC = cursorC + (delimiterByteC >> 3);

long digitStartA = semicolonA + 1;
long digitStartB = semicolonB + 1;
long digitStartC = semicolonC + 1;

long temperatureWordA = UNSAFE.getLong(digitStartA);
long temperatureWordB = UNSAFE.getLong(digitStartB);
long temperatureWordC = UNSAFE.getLong(digitStartC);

long lastWordMaskA = ((maskA - 1) ^ maskA) >>> 8;
long lastWordMaskB = ((maskB - 1) ^ maskB) >>> 8;
long lastWordMaskC = ((maskC - 1) ^ maskC) >>> 8;

final long maskedLastWordA = currentWordA & lastWordMaskA;
final long maskedLastWordB = currentWordB & lastWordMaskB;
final long maskedLastWordC = currentWordC & lastWordMaskC;

int lenA = (int) (semicolonA - startA);
int lenB = (int) (semicolonB - startB);
int lenC = (int) (semicolonC - startC);

int mapIndexA = hashA & MAP_MASK;
int mapIndexB = hashB & MAP_MASK;
int mapIndexC = hashC & MAP_MASK;

long baseEntryPtrA;
long baseEntryPtrB;
long baseEntryPtrC;

if (slowSome) {
if (slowA) {
Expand All @@ -537,22 +498,14 @@ public void run() {
baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB, fastMap);
}

if (slowC) {
baseEntryPtrC = getOrCreateEntryBaseOffsetSlow(lenC, startC, hashC, maskedLastWordC);
}
else {
baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC, fastMap);
}
}
else {
baseEntryPtrA = getOrCreateEntryBaseOffsetFast(mapIndexA, lenA, maskedLastWordA, maskedFirstWordA, fastMap);
baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB, fastMap);
baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC, fastMap);
}

cursorA = parseAndStoreTemperature(digitStartA, baseEntryPtrA, temperatureWordA);
cursorB = parseAndStoreTemperature(digitStartB, baseEntryPtrB, temperatureWordB);
cursorC = parseAndStoreTemperature(digitStartC, baseEntryPtrC, temperatureWordC);
}
doTail(fastMap);
// System.out.println("Longest chain: " + longestChain);
Expand Down
Loading