Skip to content

Commit

Permalink
Improve PCT algorithm and make replay more deterministic (#94)
Browse files Browse the repository at this point in the history
* Improve PCT algorithm.

* Improve PCT algorithm and make replay more deterministic.
  • Loading branch information
aoli-al authored Jan 24, 2025
1 parent 2c16633 commit dab5161
Show file tree
Hide file tree
Showing 13 changed files with 159 additions and 35 deletions.
24 changes: 20 additions & 4 deletions core/src/main/kotlin/org/pastalab/fray/core/RunContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.pastalab.fray.core
import java.io.PrintWriter
import java.io.StringWriter
import java.lang.Thread.UncaughtExceptionHandler
import java.time.Instant
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
Expand Down Expand Up @@ -48,7 +49,7 @@ class RunContext(val config: Configuration) {
var mainThreadId: Long = -1
var bugFound: Throwable? = null
var mainExiting = false
var nanoTime = System.nanoTime()
var nanoTime = TimeUnit.SECONDS.toNanos(1577768400)
val hashCodeMapper = ReferencedContextManager<Int>({ config.randomnessProvider.nextInt() })
var forkJoinPool: ForkJoinPool? = null
private val semaphoreManager = ReferencedContextManager {
Expand Down Expand Up @@ -225,6 +226,7 @@ class RunContext(val config: Configuration) {
registeredThreads.clear()
config.scheduleObservers.forEach { it.onExecutionDone() }
hashCodeMapper.done(false)
nanoTime = TimeUnit.SECONDS.toNanos(1577768400)
}

fun shutDown() {
Expand Down Expand Up @@ -990,8 +992,14 @@ class RunContext(val config: Configuration) {
}

val nextThread =
config.scheduler.scheduleNextOperation(enabledOperations, registeredThreads.values.toList())
config.scheduleObservers.forEach { it.onNewSchedule(enabledOperations, nextThread) }
if (enabledOperations.size == 1) {
enabledOperations.first()
} else {
val thread = config.scheduler.scheduleNextOperation(enabledOperations, enabledOperations)
config.scheduleObservers.forEach { it.onNewSchedule(enabledOperations, thread) }
thread
}

currentThreadId = nextThread.thread.id
nextThread.state = ThreadState.Running
runThread(currentThread, nextThread)
Expand Down Expand Up @@ -1029,10 +1037,18 @@ class RunContext(val config: Configuration) {
}

fun nanoTime(): Long {
nanoTime += TimeUnit.MILLISECONDS.convert(10000, TimeUnit.NANOSECONDS)
nanoTime += TimeUnit.MILLISECONDS.toNanos(10000)
return nanoTime
}

fun currentTimeMillis(): Long {
return nanoTime() / 1000000
}

fun instantNow(): Instant {
return Instant.ofEpochMilli(currentTimeMillis())
}

fun getForkJoinPoolCommon(): ForkJoinPool {
if (forkJoinPool == null) {
forkJoinPool = ForkJoinPool()
Expand Down
20 changes: 19 additions & 1 deletion core/src/main/kotlin/org/pastalab/fray/core/RuntimeDelegate.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.pastalab.fray.core

import java.time.Duration
import java.time.Instant
import java.util.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ForkJoinPool
Expand Down Expand Up @@ -839,7 +840,24 @@ class RuntimeDelegate(val context: RunContext) : org.pastalab.fray.runtime.Deleg
}

override fun onNanoTime(): Long {
return context.nanoTime()
if (checkEntered()) return System.nanoTime()
val value = context.nanoTime()
entered.set(false)
return value
}

override fun onCurrentTimeMillis(): Long {
if (checkEntered()) return System.currentTimeMillis()
val value = context.currentTimeMillis()
entered.set(false)
return value
}

override fun onInstantNow(): Instant {
if (checkEntered()) return Instant.now()
val instant = context.instantNow()
entered.set(false)
return instant
}

override fun onObjectHashCode(t: Any): Int {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ data class Configuration(
val noFray: Boolean,
val dummyRun: Boolean,
) {
var scheduleObservers = mutableListOf<ScheduleObserver>()
val scheduleObservers = mutableListOf<ScheduleObserver>()
var nextSavedIndex = 0
var currentIteration = 0
val startTime = TimeSource.Monotonic.markNow()
Expand All @@ -248,7 +248,7 @@ data class Configuration(
if (!isReplay || !Paths.get(report).exists()) {
prepareReportPath(report)
}
if (System.getProperty("fray.recordSchedule", "false").toBoolean()) {
if (System.getProperty("fray.recordSchedule", "true").toBoolean()) {
scheduleObservers.add(ScheduleRecorder())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@ class ScheduleRecorder : ScheduleObserver {
}

override fun onNewSchedule(enabledSchedules: List<ThreadContext>, scheduled: ThreadContext) {
val operation = scheduled.pendingOperation.toString()
var operation = scheduled.pendingOperation.toString()
var count = 0
for (st in Thread.currentThread().stackTrace.drop(1)) {
if (st.className.startsWith("org.pastalab.fray")) {
continue
}
operation += "@${st.className}.${st.methodName},"
count += 1
if (count == 3) break
}
val enabled = enabledSchedules.map { it.index }.toList()
val scheduledIndex = scheduled.index
val recording = ScheduleRecording(scheduledIndex, enabled, operation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,16 @@ class ScheduleVerifier(val schedules: List<ScheduleRecording>) : ScheduleObserve
val recording = schedules[index]
val scheduledIndex = scheduled.index
val enabled = enabledSchedules.map { it.index }.toList()
val operation = scheduled.pendingOperation.toString()
var operation = scheduled.pendingOperation.toString()
var count = 0
for (st in Thread.currentThread().stackTrace.drop(1)) {
if (st.className.startsWith("org.pastalab.fray")) {
continue
}
operation += "@${st.className}.${st.methodName},"
count += 1
if (count == 3) break
}
if (recording.scheduled != scheduledIndex) {
throw IllegalStateException(
"Scheduled index mismatch: expected ${recording.scheduled}, got $scheduledIndex")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,20 @@ import kotlinx.serialization.Serializable
import kotlinx.serialization.Transient
import org.pastalab.fray.core.ThreadContext
import org.pastalab.fray.core.randomness.ControlledRandom
import org.pastalab.fray.core.utils.Utils

@Serializable
class PCTScheduler(val rand: ControlledRandom, val numSwitchPoints: Int, var maxStep: Int) :
Scheduler {
constructor() : this(ControlledRandom(), 3, 0) {}

@Transient var currentStep = 0
@Transient var nextSwitchPoint = 0
@Transient var numSwitchPointLeft = numSwitchPoints
@Transient val threadPriorityQueue = mutableListOf<ThreadContext>()
@Transient val priorityChangePoints = mutableSetOf<Int>()

fun updateNextSwitchPoint() {
numSwitchPointLeft -= 1
if (numSwitchPoints == 0) return
val switchPointProbability =
if (maxStep == 0) {
0.1
} else {
1.0 * numSwitchPointLeft / maxStep
}
nextSwitchPoint += Utils.sampleGeometric(switchPointProbability, rand.nextDouble())
init {
if (maxStep != 0) {
preparePriorityChangePoints()
}
}

override fun scheduleNextOperation(
Expand All @@ -44,15 +36,24 @@ class PCTScheduler(val rand: ControlledRandom, val numSwitchPoints: Int, var max
}
}
val next = threadPriorityQueue.first { threads.contains(it) }
if (currentStep == nextSwitchPoint) {
if (priorityChangePoints.contains(currentStep)) {
threadPriorityQueue.remove(next)
threadPriorityQueue.add(next)
updateNextSwitchPoint()
}
return next
}

override fun nextIteration(): Scheduler {
return PCTScheduler(ControlledRandom(), numSwitchPoints, maxStep.coerceAtLeast(currentStep))
}

private fun preparePriorityChangePoints() {
val listOfInts = (1..maxStep).toMutableList()
for (i in 0 ..< numSwitchPoints) {
val index = rand.nextInt() % (listOfInts.size)
priorityChangePoints.add(listOfInts[index])
listOfInts.removeAt(index)
if (listOfInts.isEmpty()) break
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class ApplicationCodeTransformer : ClassFileTransformer {
cv = SleepInstrumenter(cv)
cv = TimeInstrumenter(cv)
cv = SkipMethodInstrumenter(cv)
// cv = ObjectHashCodeInstrumenter(cv, false)
cv = ObjectHashCodeInstrumenter(cv, false)
cv = AtomicGetInstrumenter(cv)
// cv = ToStringInstrumenter(cv)
val classVersionInstrumenter = ClassVersionInstrumenter(cv)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.pastalab.fray.instrumentation.base.visitors
import org.objectweb.asm.ClassVisitor
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes.ASM9
import org.objectweb.asm.Type
import org.objectweb.asm.commons.AdviceAdapter
import org.pastalab.fray.runtime.Runtime

Expand All @@ -25,8 +26,20 @@ class TimeInstrumenter(cv: ClassVisitor) : ClassVisitor(ASM9, cv) {
isInterface: Boolean
) {
if (owner == "java/lang/System" && name == "nanoTime") {
visitMethodInsn(
INVOKESTATIC, Runtime::class.java.name.replace(".", "/"), "onNanoTime", "()J", false)
invokeStatic(
Type.getObjectType(Runtime::class.java.name.replace(".", "/")),
Utils.kFunctionToASMMethod(Runtime::onNanoTime),
)
} else if (owner == "java/lang/System" && name == "currentTimeMillis") {
invokeStatic(
Type.getObjectType(Runtime::class.java.name.replace(".", "/")),
Utils.kFunctionToASMMethod(Runtime::onCurrentTimeMillis),
)
} else if (owner == "java/time/Instant" && name == "now") {
invokeStatic(
Type.getObjectType(Runtime::class.java.name.replace(".", "/")),
Utils.kFunctionToASMMethod(Runtime::onInstantNow),
)
} else {
super.visitMethodInsn(opcode, owner, name, descriptor, isInterface)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.pastalab.fray.test.success.time;

import java.time.Instant;

public class TestTime {
public static void main(String[] args) {
long t1 = System.nanoTime();
long t2 = System.nanoTime();
assert(t2 - t1 == 10000L * 1000000L);

long t3 = System.currentTimeMillis();
long t4 = System.currentTimeMillis();
assert(t4 - t3 == 10000);

Instant t5 = Instant.now();
Instant t6 = Instant.now();
assert(t6.toEpochMilli() - t5.toEpochMilli() == 10000);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@
import org.pastalab.fray.core.command.LambdaExecutor;
import org.pastalab.fray.core.command.MethodExecutor;
import org.pastalab.fray.core.randomness.ControlledRandom;
import org.pastalab.fray.core.scheduler.PCTScheduler;
import org.pastalab.fray.core.scheduler.POSScheduler;
import org.pastalab.fray.core.scheduler.RandomScheduler;
import org.pastalab.fray.test.fail.cdl.CountDownLatchDeadlockUnblockMultiThread;
import org.pastalab.fray.test.fail.park.ParkDeadlock;
import org.pastalab.fray.test.fail.rwlock.ReentrantReadWriteLockDeadlock;
import org.pastalab.fray.test.fail.wait.NotifyOrder;
import org.pastalab.fray.test.success.condition.ConditionAwaitTimeoutNotifyInterrupt;
import org.pastalab.fray.test.success.rwlock.ReentrantReadWriteLockDowngradingNoDeadlock;
import org.pastalab.fray.test.success.rwlock.ReentrantReadWriteLockNoDeadlock;
import org.pastalab.fray.test.success.stampedlock.StampedLockTryLockNoDeadlock;
import org.pastalab.fray.test.success.thread.ThreadInterruptionWithoutStart;
import org.pastalab.fray.test.success.time.TestTime;

import java.util.*;

Expand All @@ -43,7 +47,7 @@ private DynamicTest populateTest(String className, boolean testShouldFail) {
"/tmp/report",
10000,
60,
new RandomScheduler(),
new PCTScheduler(),
new ControlledRandom(),
true,
false,
Expand All @@ -70,7 +74,7 @@ public void testOne() throws Throwable {
new ExecutionInfo(
new LambdaExecutor(() -> {
try {
ConditionAwaitTimeoutNotifyInterrupt.main(new String[]{});
TestTime.main(new String[]{});
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand All @@ -83,7 +87,7 @@ public void testOne() throws Throwable {
"/tmp/report2",
200,
60,
new RandomScheduler(),
new PCTScheduler(),
new ControlledRandom(),
true,
false,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
package org.pastalab.fray.junit.plain

import java.io.File
import kotlin.io.path.Path
import kotlin.io.path.absolutePathString
import kotlin.io.path.exists
import kotlinx.serialization.json.Json
import org.pastalab.fray.core.TestRunner
import org.pastalab.fray.core.command.Configuration
import org.pastalab.fray.core.command.ExecutionInfo
import org.pastalab.fray.core.command.LambdaExecutor
import org.pastalab.fray.core.observers.ScheduleVerifier
import org.pastalab.fray.core.randomness.ControlledRandom
import org.pastalab.fray.core.scheduler.PCTScheduler
import org.pastalab.fray.core.scheduler.Scheduler
import org.pastalab.fray.junit.Common.WORK_DIR

object FrayInTestLauncher {

fun launchFray(runnable: Runnable, scheduler: Scheduler, randomnessProvider: ControlledRandom) {
fun launchFray(
runnable: Runnable,
scheduler: Scheduler,
randomnessProvider: ControlledRandom,
iteration: Int,
timeout: Int,
additionalConfigs: (Configuration) -> Unit = { _ -> }
) {
val config =
Configuration(
ExecutionInfo(
Expand All @@ -25,8 +35,8 @@ object FrayInTestLauncher {
false,
-1),
WORK_DIR.absolutePathString(),
10000,
60,
iteration,
timeout,
scheduler,
randomnessProvider,
true,
Expand All @@ -36,19 +46,26 @@ object FrayInTestLauncher {
false,
false,
)
additionalConfigs(config)
val runner = TestRunner(config)
runner.run()?.let { throw it }
}

fun launchFrayTest(test: Runnable) {
launchFray(test, PCTScheduler(ControlledRandom(), 15, 0), ControlledRandom())
launchFray(test, PCTScheduler(ControlledRandom(), 15, 0), ControlledRandom(), 10000, 120)
}

fun launchFrayReplay(test: Runnable, path: String) {
val randomPath = "${path}/random.json"
val schedulerPath = "${path}/schedule.json"
val randomnessProvider = Json.decodeFromString<ControlledRandom>(File(randomPath).readText())
val scheduler = Json.decodeFromString<Scheduler>(File(schedulerPath).readText())
launchFray(test, scheduler, randomnessProvider)
launchFray(test, scheduler, randomnessProvider, 1, 10000) {
val recording = Path("${path}/recording.json")
if (recording.exists()) {
val verifier = ScheduleVerifier(recording.absolutePathString())
it.scheduleObservers.add(verifier)
}
}
}
}
Loading

0 comments on commit dab5161

Please sign in to comment.