Skip to content

Commit

Permalink
expand expire write optimization to more operations (fixes #1320)
Browse files Browse the repository at this point in the history
  • Loading branch information
ben-manes committed Jan 16, 2025
1 parent 9b65365 commit 8022a16
Show file tree
Hide file tree
Showing 6 changed files with 397 additions and 107 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Copyright 2025 Ben Manes. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.benmanes.caffeine.cache;

import static com.github.benmanes.caffeine.cache.BoundedLocalCache.EXPIRE_WRITE_TOLERANCE;

import java.time.Duration;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Threads;

import com.google.common.cache.CacheBuilder;
import com.google.common.testing.FakeTicker;

import site.ycsb.generator.NumberGenerator;
import site.ycsb.generator.ScrambledZipfianGenerator;

/**
 * A benchmark for the {@link BoundedLocalCache#EXPIRE_WRITE_TOLERANCE} optimization, which skips
 * reordering an entry on the write-order queue when its write time changed by less than the
 * tolerance. The cache variants are compared against {@link ConcurrentHashMap} (no expiration)
 * and Guava's cache as baselines.
 *
 * @author [email protected] (Ben Manes)
 */
@State(Scope.Benchmark)
@SuppressWarnings({"CanonicalAnnotationSyntax", "LexicographicalAnnotationAttributeListing"})
public class ExpireWriteToleranceBenchmark {
  // The number of pre-generated keys; a power of two so that (index & MASK) replaces modulo
  static final int SIZE = (2 << 14);
  static final int MASK = SIZE - 1;
  static final int NUM_THREADS = 8;
  // Fewer distinct items than slots so the Zipfian keys repeat (exercising updates, not inserts)
  static final int ITEMS = SIZE / 3;
  static final int INITIAL_CAPACITY = 10_000;

  @Param({
    "Caffeine w/o tolerance",
    "Caffeine w/ tolerance",
    "ConcurrentHashMap",
    "Guava"
  })
  String mapType;

  Map<Integer, Integer> map;
  Integer[] ints;

  /** Per-thread cursor into the shared key array, randomized so threads start at distinct keys. */
  @State(Scope.Thread)
  public static class ThreadState {
    static final Random random = new Random();
    int index = random.nextInt();
  }

  /**
   * Creates the map under test and pre-populates it with a Zipfian key distribution.
   *
   * @throws AssertionError if {@code mapType} is not one of the declared {@code @Param} values
   */
  @Setup
  @SuppressWarnings("ReturnValueIgnored")
  public void setup() {
    if (mapType.equals("ConcurrentHashMap")) {
      map = new ConcurrentHashMap<>(INITIAL_CAPACITY);
    } else if (mapType.equals("Caffeine w/ tolerance")) {
      // A one-day bound far exceeds the tolerance, so repeated writes within the tolerance
      // window can skip the write-order queue reordering (the optimized path)
      Cache<Integer, Integer> cache = Caffeine.newBuilder()
          .expireAfterWrite(Duration.ofDays(1))
          .initialCapacity(INITIAL_CAPACITY)
          .ticker(new FakeTicker()::read)
          .build();
      map = cache.asMap();
    } else if (mapType.equals("Caffeine w/o tolerance")) {
      // A bound smaller than the tolerance forces every write to be treated as exceeding it,
      // disabling the optimization so its cost can be measured
      Cache<Integer, Integer> cache = Caffeine.newBuilder()
          .expireAfterWrite(Duration.ofNanos(EXPIRE_WRITE_TOLERANCE / 2))
          .initialCapacity(INITIAL_CAPACITY)
          .ticker(new FakeTicker()::read)
          .build();
      map = cache.asMap();
    } else if (mapType.equals("Guava")) {
      com.google.common.cache.Cache<Integer, Integer> cache = CacheBuilder.newBuilder()
          .expireAfterWrite(Duration.ofDays(1))
          .initialCapacity(INITIAL_CAPACITY)
          .ticker(new FakeTicker())
          .build();
      map = cache.asMap();
    } else {
      // Fixed: the message previously referenced "computingType", but the parameter is mapType
      throw new AssertionError("Unknown mapType: " + mapType);
    }

    ints = new Integer[SIZE];
    NumberGenerator generator = new ScrambledZipfianGenerator(ITEMS);
    for (int i = 0; i < SIZE; i++) {
      ints[i] = generator.nextValue().intValue();
      map.put(ints[i], ints[i]);
    }
  }

  @Benchmark @Threads(NUM_THREADS)
  public Integer put(ThreadState threadState) {
    var key = ints[threadState.index++ & MASK];
    return map.put(key, key);
  }

  @Benchmark @Threads(NUM_THREADS)
  public Integer replace(ThreadState threadState) {
    var key = ints[threadState.index++ & MASK];
    return map.replace(key, key);
  }

  @Benchmark @Threads(NUM_THREADS)
  public Boolean replaceConditionally(ThreadState threadState) {
    var key = ints[threadState.index++ & MASK];
    return map.replace(key, key, key);
  }

  @Benchmark @Threads(NUM_THREADS)
  public Integer compute(ThreadState threadState) {
    return map.compute(ints[threadState.index++ & MASK], (k, v) -> k);
  }

  @Benchmark @Threads(NUM_THREADS)
  public Integer computeIfPresent(ThreadState threadState) {
    return map.computeIfPresent(ints[threadState.index++ & MASK], (k, v) -> k);
  }

  @Benchmark @Threads(NUM_THREADS)
  public Integer merge(ThreadState threadState) {
    var key = ints[threadState.index++ & MASK];
    return map.merge(key, key, (k, v) -> k);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,19 @@ void setAccessTime(Node<K, V> node, long now) {
}
}

/**
 * Returns if the entry's write time would exceed the minimum expiration reorder threshold.
 * When within the tolerance, callers may skip updating the write time and the write-order
 * queue reordering for this entry.
 *
 * <p>For the fixed-duration policies (expireAfterWrite, refreshAfterWrite), a bound that is
 * itself no larger than the tolerance disables the optimization entirely (the first clause of
 * each disjunct), since skipping the update could then delay expiration past the bound.
 * Otherwise the elapsed time since the last recorded write is compared to the tolerance.
 * For variable expiration, the new and current variable times are compared directly.
 *
 * @param node the entry being written to
 * @param varTime the new variable expiration time, when variable expiration is used
 * @param now the current time, in nanoseconds (NOTE(review): callers appear to sometimes pass
 *     an adjusted expiration time here rather than the raw clock reading — confirm)
 * @return whether the change in write time exceeds the tolerance for any active write policy
 */
boolean exceedsWriteTimeTolerance(Node<K, V> node, long varTime, long now) {
long variableTime = node.getVariableTime();
long tolerance = EXPIRE_WRITE_TOLERANCE;
long writeTime = node.getWriteTime();
return
// expireAfterWrite: always exceed when the bound fits inside the tolerance, else compare elapsed
(expiresAfterWrite()
&& ((expiresAfterWriteNanos() <= tolerance) || (Math.abs(now - writeTime) > tolerance)))
// refreshAfterWrite: same structure as above, using the refresh duration
|| (refreshAfterWrite()
&& ((refreshAfterWriteNanos() <= tolerance) || (Math.abs(now - writeTime) > tolerance)))
// variable expiration: compare the new variable time against the currently recorded one
|| (expiresVariable() && (Math.abs(varTime - variableTime) > tolerance));
}

/**
* Performs the post-processing work required after a write.
*
Expand Down Expand Up @@ -2298,8 +2311,8 @@ public void putAll(Map<? extends K, ? extends V> map) {
if (node == null) {
node = nodeFactory.newNode(key, keyReferenceQueue(),
value, valueReferenceQueue(), newWeight, now);
setVariableTime(node, expireAfterCreate(key, value, expiry, now));
long expirationTime = isComputingAsync(value) ? (now + ASYNC_EXPIRY) : now;
setVariableTime(node, expireAfterCreate(key, value, expiry, now));
setAccessTime(node, expirationTime);
setWriteTime(node, expirationTime);
}
Expand Down Expand Up @@ -2384,11 +2397,10 @@ public void putAll(Map<? extends K, ? extends V> map) {

long expirationTime = isComputingAsync(value) ? (now + ASYNC_EXPIRY) : now;
if (mayUpdate) {
exceedsTolerance =
(expiresAfterWrite() && (now - prior.getWriteTime()) > EXPIRE_WRITE_TOLERANCE)
|| (expiresVariable()
&& Math.abs(varTime - prior.getVariableTime()) > EXPIRE_WRITE_TOLERANCE);
setWriteTime(prior, expirationTime);
exceedsTolerance = exceedsWriteTimeTolerance(prior, varTime, now);
if (expired || exceedsTolerance) {
setWriteTime(prior, isComputingAsync(value) ? (now + ASYNC_EXPIRY) : now);
}

prior.setValue(value, valueReferenceQueue());
prior.setWeight(newWeight);
Expand Down Expand Up @@ -2514,8 +2526,9 @@ public boolean remove(Object key, Object value) {
requireNonNull(key);
requireNonNull(value);

long[] now = new long[1];
var now = new long[1];
var oldWeight = new int[1];
var exceedsTolerance = new boolean[1];
@SuppressWarnings({"unchecked", "Varifier"})
@Nullable K[] nodeKey = (K[]) new Object[1];
@SuppressWarnings({"unchecked", "Varifier"})
Expand All @@ -2538,8 +2551,11 @@ public boolean remove(Object key, Object value) {
n.setWeight(weight);

long expirationTime = isComputingAsync(value) ? (now[0] + ASYNC_EXPIRY) : now[0];
exceedsTolerance[0] = exceedsWriteTimeTolerance(n, varTime, expirationTime);
if (exceedsTolerance[0]) {
setWriteTime(n, expirationTime);
}
setAccessTime(n, expirationTime);
setWriteTime(n, expirationTime);
setVariableTime(n, varTime);

discardRefresh(k);
Expand All @@ -2552,7 +2568,7 @@ public boolean remove(Object key, Object value) {
}

int weightedDifference = (weight - oldWeight[0]);
if (expiresAfterWrite() || (weightedDifference != 0)) {
if (exceedsTolerance[0] || (weightedDifference != 0)) {
afterWrite(new UpdateTask(node, weightedDifference));
} else {
afterRead(node, now[0], /* recordHit= */ false);
Expand All @@ -2573,13 +2589,15 @@ public boolean replace(K key, V oldValue, V newValue, boolean shouldDiscardRefre
requireNonNull(oldValue);
requireNonNull(newValue);

int weight = weigher.weigh(key, newValue);
var now = new long[1];
var oldWeight = new int[1];
var exceedsTolerance = new boolean[1];
@SuppressWarnings({"unchecked", "Varifier"})
@Nullable K[] nodeKey = (K[]) new Object[1];
@SuppressWarnings({"unchecked", "Varifier"})
@Nullable V[] prevValue = (V[]) new Object[1];
int[] oldWeight = new int[1];
long[] now = new long[1];

int weight = weigher.weigh(key, newValue);
Node<K, V> node = data.computeIfPresent(nodeFactory.newLookupKey(key), (k, n) -> {
synchronized (n) {
requireIsAlive(key, n);
Expand All @@ -2597,8 +2615,11 @@ public boolean replace(K key, V oldValue, V newValue, boolean shouldDiscardRefre
n.setWeight(weight);

long expirationTime = isComputingAsync(newValue) ? (now[0] + ASYNC_EXPIRY) : now[0];
exceedsTolerance[0] = exceedsWriteTimeTolerance(n, varTime, expirationTime);
if (exceedsTolerance[0]) {
setWriteTime(n, expirationTime);
}
setAccessTime(n, expirationTime);
setWriteTime(n, expirationTime);
setVariableTime(n, varTime);

if (shouldDiscardRefresh) {
Expand All @@ -2613,7 +2634,7 @@ public boolean replace(K key, V oldValue, V newValue, boolean shouldDiscardRefre
}

int weightedDifference = (weight - oldWeight[0]);
if (expiresAfterWrite() || (weightedDifference != 0)) {
if (exceedsTolerance[0] || (weightedDifference != 0)) {
afterWrite(new UpdateTask(node, weightedDifference));
} else {
afterRead(node, now[0], /* recordHit= */ false);
Expand Down Expand Up @@ -2688,8 +2709,8 @@ public void replaceAll(BiFunction<? super K, ? super V, ? extends V> function) {
weight[1] = weigher.weigh(key, newValue[0]);
var created = nodeFactory.newNode(key, keyReferenceQueue(),
newValue[0], valueReferenceQueue(), weight[1], now[0]);
setVariableTime(created, expireAfterCreate(key, newValue[0], expiry(), now[0]));
long expirationTime = isComputingAsync(newValue[0]) ? (now[0] + ASYNC_EXPIRY) : now[0];
setVariableTime(created, expireAfterCreate(key, newValue[0], expiry(), now[0]));
setAccessTime(created, expirationTime);
setWriteTime(created, expirationTime);
return created;
Expand Down Expand Up @@ -2724,15 +2745,11 @@ public void replaceAll(BiFunction<? super K, ? super V, ? extends V> function) {
n.setValue(newValue[0], valueReferenceQueue());
n.setWeight(weight[1]);

long expirationTime = isComputingAsync(newValue[0]) ? (now[0] + ASYNC_EXPIRY) : now[0];
setAccessTime(n, expirationTime);
setWriteTime(n, expirationTime);
setVariableTime(n, varTime);
if (isComputingAsync(newValue[0])) {
long expirationTime = now[0] + ASYNC_EXPIRY;
setAccessTime(n, expirationTime);
setWriteTime(n, expirationTime);
} else {
setAccessTime(n, now[0]);
setWriteTime(n, now[0]);
}

discardRefresh(k);
return n;
}
Expand Down Expand Up @@ -2853,6 +2870,7 @@ public void replaceAll(BiFunction<? super K, ? super V, ? extends V> function) {

var weight = new int[2]; // old, new
var cause = new RemovalCause[1];
var exceedsTolerance = new boolean[1];

Node<K, V> node = data.compute(keyRef, (kr, n) -> {
if (n == null) {
Expand Down Expand Up @@ -2925,8 +2943,11 @@ public void replaceAll(BiFunction<? super K, ? super V, ? extends V> function) {
n.setWeight(weight[1]);

long expirationTime = isComputingAsync(newValue[0]) ? (now[0] + ASYNC_EXPIRY) : now[0];
exceedsTolerance[0] = exceedsWriteTimeTolerance(n, varTime, expirationTime);
if (((cause[0] != null) && cause[0].wasEvicted()) || exceedsTolerance[0]) {
setWriteTime(n, expirationTime);
}
setAccessTime(n, expirationTime);
setWriteTime(n, expirationTime);
setVariableTime(n, varTime);

discardRefresh(kr);
Expand Down Expand Up @@ -2954,7 +2975,7 @@ public void replaceAll(BiFunction<? super K, ? super V, ? extends V> function) {
afterWrite(new AddTask(node, weight[1]));
} else {
int weightedDifference = weight[1] - weight[0];
if (expiresAfterWrite() || (weightedDifference != 0)) {
if (exceedsTolerance[0] || (weightedDifference != 0)) {
afterWrite(new UpdateTask(node, weightedDifference));
} else {
afterRead(node, now[0], /* recordHit= */ false);
Expand Down
Loading

0 comments on commit 8022a16

Please sign in to comment.