From c0c0eeb5cbe302494a9adf12ab5ab543e5ed1104 Mon Sep 17 00:00:00 2001 From: Charles Dickens Date: Fri, 26 Apr 2024 07:26:16 -0700 Subject: [PATCH] Batched inference output. --- .../org/linqs/psl/config/RuntimeOptions.java | 7 +++++++ .../main/java/org/linqs/psl/runtime/Runtime.java | 16 +++++++++++++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/psl-java/src/main/java/org/linqs/psl/config/RuntimeOptions.java b/psl-java/src/main/java/org/linqs/psl/config/RuntimeOptions.java index c0c37fd82..30e994a54 100644 --- a/psl-java/src/main/java/org/linqs/psl/config/RuntimeOptions.java +++ b/psl-java/src/main/java/org/linqs/psl/config/RuntimeOptions.java @@ -94,6 +94,13 @@ public class RuntimeOptions { "Use the specified InferenceApplication when running inference." ); + public static final Option INFERENCE_OUTPUT_BATCHED_RESULTS = new Option( + "runtime.inference.output.batched.results", + false, + "Whether to output the inferred atoms after inference and organize output by batches." + + " This is useful if the neural component is batching." + ); + public static final Option INFERENCE_OUTPUT_RESULTS = new Option( "runtime.inference.output.results", true, diff --git a/psl-java/src/main/java/org/linqs/psl/runtime/Runtime.java b/psl-java/src/main/java/org/linqs/psl/runtime/Runtime.java index 2b1195a9e..079217444 100644 --- a/psl-java/src/main/java/org/linqs/psl/runtime/Runtime.java +++ b/psl-java/src/main/java/org/linqs/psl/runtime/Runtime.java @@ -431,6 +431,8 @@ protected void runInferenceInternal(RuntimeConfig config, Model model, RuntimeRe // Run inference. boolean runInference = true; + int batch = 0; + String outputDir = RuntimeOptions.INFERENCE_OUTPUT_RESULTS_DIR.getString(); while (runInference) { DeepPredicate.predictAllDeepPredicates(); @@ -439,13 +441,19 @@ protected void runInferenceInternal(RuntimeConfig config, Model model, RuntimeRe log.info("Inference complete."); if (RuntimeOptions.INFERENCE_OUTPUT_RESULTS.getBoolean()) { - String outputDir = RuntimeOptions.INFERENCE_OUTPUT_RESULTS_DIR.getString(); if (outputDir == null) { log.info("Writing inferred predicates to stdout."); targetDatabase.outputRandomVariableAtoms(); } else { - log.info("Writing inferred predicates to directory: " + outputDir); - targetDatabase.outputRandomVariableAtoms(outputDir); + if (RuntimeOptions.INFERENCE_OUTPUT_BATCHED_RESULTS.getBoolean()) { + String batchOutputDir = FileUtils.makePath(outputDir, String.format("batch_%d", batch)); + + log.info("Writing inferred predicates to directory: " + batchOutputDir); + targetDatabase.outputRandomVariableAtoms(batchOutputDir); + } else { + log.info("Writing inferred predicates to directory: " + outputDir); + targetDatabase.outputRandomVariableAtoms(outputDir); + } } } @@ -453,6 +461,8 @@ protected void runInferenceInternal(RuntimeConfig config, Model model, RuntimeRe DeepPredicate.nextBatchAllDeepPredicates(); runInference = !DeepPredicate.isEpochCompleteAllDeepPredicates(); + + batch++; } DeepPredicate.epochEndAllDeepPredicates();