From b505e7942e446e78106ac58a69f6f681f7a8cd12 Mon Sep 17 00:00:00 2001 From: Matt Dziuban Date: Wed, 29 Jan 2025 16:23:58 -0500 Subject: [PATCH] Get process PID via reflection, use it to `kill -15` it before forcibly destroying it. --- .../scala/sys/process/ProcessWithPid.scala | 48 +++++++++++++++++++ src/main/scala/spray/revolver/Actions.scala | 6 +-- .../scala/spray/revolver/AppProcess.scala | 27 ++++++++--- 3 files changed, 72 insertions(+), 9 deletions(-) create mode 100644 src/main/scala/scala/sys/process/ProcessWithPid.scala diff --git a/src/main/scala/scala/sys/process/ProcessWithPid.scala b/src/main/scala/scala/sys/process/ProcessWithPid.scala new file mode 100644 index 0000000..5856b20 --- /dev/null +++ b/src/main/scala/scala/sys/process/ProcessWithPid.scala @@ -0,0 +1,48 @@ +package scala.sys.process + +import java.lang.{Process => JProcess} +import sbt.{Level, Logger} +import scala.util.control.NonFatal + +case class ProcessWithPid(process: Process, pid: Option[Long]) + +object ProcessWithPid { + private def reflectJProcess(p: Process.SimpleProcess): JProcess = { + val field = p.getClass.getDeclaredField("p") + field.setAccessible(true) + field.get(p).asInstanceOf[JProcess] + } + + // Java 9+ has a `Process#pid()` method, but Java 8 and below have a private `Process#pid` field + // We first try to reflect on the method and then fall back to reflecting on the field + private def reflectJProcessPid(p: JProcess): Long = + try { + val method = classOf[JProcess].getMethod("pid") + method.invoke(p) match { + case pid: java.lang.Long => pid + case pid => throw new RuntimeException(s"Expected process PID ($pid) to be a Long, but it was a ${pid.getClass.getName}") + } + } catch { + case e: NoSuchMethodException => + val field = p.getClass.getDeclaredField("pid") + field.setAccessible(true) + field.getLong(p) + } + + def apply(process: Process, log: Logger): ProcessWithPid = + try { + process match { + case p: Process.SimpleProcess => + val jp = reflectJProcess(p) + val pid = reflectJProcessPid(jp) + ProcessWithPid(process, Some(pid)) + + case p => + throw new RuntimeException(s"Expected app process to be a Process.SimpleProcess but it was a ${p.getClass.getName}") + } + } catch { + case NonFatal(e) => + log.log(Level.Warn, s"Failed to determine process PID: $e") + ProcessWithPid(process, None) + } +} diff --git a/src/main/scala/spray/revolver/Actions.scala b/src/main/scala/spray/revolver/Actions.scala index 35ee891..851830a 100644 --- a/src/main/scala/spray/revolver/Actions.scala +++ b/src/main/scala/spray/revolver/Actions.scala @@ -19,7 +19,7 @@ package spray.revolver import sbt.Keys._ import sbt.{Fork, ForkOptions, LoggedOutput, Logger, Path, ProjectRef, State, complete} import java.io.File -import scala.sys.process.Process +import scala.sys.process.ProcessWithPid object Actions { import Utilities._ @@ -129,13 +129,13 @@ object Actions { def formatAppName(projectName: String, projectColor: String, color: String = "[YELLOW]"): String = "[RESET]%s%s[RESET]%s" format (projectColor, projectName, color) - def forkRun(config: ForkOptions, mainClass: String, classpath: Seq[File], options: Seq[String], log: Logger, extraJvmArgs: Seq[String]): Process = { + def forkRun(config: ForkOptions, mainClass: String, classpath: Seq[File], options: Seq[String], log: Logger, extraJvmArgs: Seq[String]): ProcessWithPid = { log.info(options.mkString("Starting " + mainClass + ".main(", ", ", ")")) val scalaOptions = "-classpath" :: Path.makeString(classpath) :: mainClass :: options.toList val newOptions = config .withOutputStrategy(config.outputStrategy getOrElse LoggedOutput(log)) .withRunJVMOptions(config.runJVMOptions ++ extraJvmArgs) - Fork.java.fork(newOptions, scalaOptions) + ProcessWithPid(Fork.java.fork(newOptions, scalaOptions), log) } } diff --git a/src/main/scala/spray/revolver/AppProcess.scala b/src/main/scala/spray/revolver/AppProcess.scala index 5abd677..d5db90d 100644 --- a/src/main/scala/spray/revolver/AppProcess.scala +++ b/src/main/scala/spray/revolver/AppProcess.scala @@ -17,22 +17,38 @@ package spray.revolver import java.lang.{Runtime => JRuntime} +import java.util.concurrent.TimeUnit import sbt.{Logger, ProjectRef} -import scala.sys.process.Process +import scala.sys.process.ProcessWithPid /** * A token which we put into the SBT state to hold the Process of an application running in the background. */ -case class AppProcess(projectRef: ProjectRef, consoleColor: String, log: Logger)(process: Process) { +case class AppProcess(projectRef: ProjectRef, consoleColor: String, log: Logger)(process: ProcessWithPid) { val shutdownHook = createShutdownHook("... killing ...") + private def destroyProcess(): Unit = process.process.destroy() + + private def killProcess(pid: Long): Unit = { + val exited = try { + JRuntime.getRuntime.exec(s"kill -15 $pid").waitFor(10, TimeUnit.SECONDS) + } catch { case e: InterruptedException => true } + + if (!exited) destroyProcess() + } + + private def stopProcess(): Int = { + process.pid.fold(destroyProcess())(killProcess) + process.process.exitValue() + } + def createShutdownHook(msg: => String) = new Thread(new Runnable { def run() { if (isRunning) { log.info(msg) - process.destroy() + stopProcess() } } }) @@ -42,7 +58,7 @@ case class AppProcess(projectRef: ProjectRef, consoleColor: String, log: Logger) val watchThread = { val thread = new Thread(new Runnable { def run() { - val code = process.exitValue() + val code = process.process.exitValue() finishState = Some(code) log.info("... finished with exit code %d" format code) unregisterShutdownHook() @@ -58,8 +74,7 @@ case class AppProcess(projectRef: ProjectRef, consoleColor: String, log: Logger) def stop() { unregisterShutdownHook() - process.destroy() - process.exitValue() + stopProcess() } def registerShutdownHook() {