diff --git a/hail/src/main/scala/is/hail/backend/BackendServer.scala b/hail/src/main/scala/is/hail/backend/BackendServer.scala index ad6fb5f0a75..2ce6c7dc263 100644 --- a/hail/src/main/scala/is/hail/backend/BackendServer.scala +++ b/hail/src/main/scala/is/hail/backend/BackendServer.scala @@ -27,23 +27,42 @@ class BackendServer(backend: Backend) { // 0 => let the OS pick an available port private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10) private[this] val handler = new BackendHttpHandler(backend) - private[this] val executor = Executors.newFixedThreadPool(1, - new ThreadFactory() { - private[this] val childFactory = Executors.defaultThreadFactory() - - def newThread(r: Runnable): Thread = { - val t = childFactory.newThread(r) - t.setDaemon(true) - t + private[this] val thread = { + // This HTTP server *must not* start non-daemon threads because such threads keep the JVM + // alive. A living JVM indicates to Spark that the job is incomplete. This does not manifest + // when you run jobs in a local pyspark (because you'll Ctrl-C out of Python regardless of the + // JVM's state) nor does it manifest in a Notebook (again, you'll kill the Notebook kernel + // explicitly regardless of the JVM). It *does* manifest when submitting jobs with + // + // gcloud dataproc submit ... + // + // or + // + // spark-submit + // + // setExecutor(null) ensures the server creates no new threads: + // + // > If this method is not called (before start()) or if it is called with a null Executor, then + // > a default implementation is used, which uses the thread which was created by the start() + // > method. + // + // Source: https://docs.oracle.com/javase/8/docs/jre/api/net/httpserver/spec/com/sun/net/httpserver/HttpServer.html#setExecutor-java.util.concurrent.Executor- + // + httpServer.createContext("/", handler) + httpServer.setExecutor(null) + val t = Executors.defaultThreadFactory().newThread(new Runnable() { + def run(): Unit = { + httpServer.start() } }) + t.setDaemon(true) + t + } def port = httpServer.getAddress.getPort def start(): Unit = { - httpServer.createContext("/", handler) - httpServer.setExecutor(executor) - httpServer.start() + thread.start() } def stop(): Unit = {