From 1923b0b58297c0799644446e72e0a60fe07b08a3 Mon Sep 17 00:00:00 2001 From: Richard Zowalla Date: Tue, 8 Oct 2024 11:26:58 +0200 Subject: [PATCH] OPENNLP-1618 - AbstractDL does not release Ort Resources --- .../src/main/java/opennlp/dl/AbstractDL.java | 18 +- .../dl/doccat/DocumentCategorizerDLEval.java | 273 +++++++++--------- .../dl/namefinder/NameFinderDLEval.java | 117 ++++---- .../dl/vectors/SentenceVectorsDLEval.java | 17 +- 4 files changed, 228 insertions(+), 197 deletions(-) diff --git a/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java b/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java index 7d891bea4..b1b0bf67a 100644 --- a/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java +++ b/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java @@ -34,7 +34,7 @@ /** * Base class for OpenNLP deep-learning classes using ONNX Runtime. */ -public abstract class AbstractDL { +public abstract class AbstractDL implements AutoCloseable { public static final String INPUT_IDS = "input_ids"; public static final String ATTENTION_MASK = "attention_mask"; @@ -50,7 +50,6 @@ public abstract class AbstractDL { * * @param vocabFile The vocabulary file. * @return A map of vocabulary words to integer IDs. - * * @throws IOException Thrown if the vocabulary file cannot be opened or read. */ public Map loadVocab(final File vocabFile) throws IOException { @@ -66,4 +65,19 @@ public Map loadVocab(final File vocabFile) throws IOException { return vocab; } + /** + * Closes this resource, relinquishing any underlying resources. + * + * @throws Exception If it failed to close. + */ + @Override + public void close() throws Exception { + if (session != null) { + session.close(); + } + if (env != null) { + env.close(); + } + } + } diff --git a/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java b/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java index 506b0bc51..6f86e8f97 100644 --- a/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java +++ b/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java @@ -18,7 +18,6 @@ package opennlp.dl.doccat; import java.io.File; -import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -26,7 +25,6 @@ import java.util.Map; import java.util.Set; -import ai.onnxruntime.OrtException; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -42,97 +40,99 @@ public class DocumentCategorizerDLEval extends AbstractDLTest { private static final Logger logger = LoggerFactory.getLogger(DocumentCategorizerDLEval.class); final String text = "We try hard to identify the sources and licenses of all media such as text," + - " images or sounds used in our encyclopedia articles. Still, we cannot guarantee that all " + - "media are used or marked correctly: for example, if an image description page states " + - "that an image was in the public domain, you should still check yourself whether that claim " + - "appears correct and decide for yourself whether your use of the image would be fine under " + - "the laws applicable to you. Wikipedia is primarily subject to U.S. law; re-users outside " + - "the U.S. should be aware that they are subject to the laws of their country, which almost " + - "certainly are different. Images published under the GFDL or one of the Creative Commons " + - "Licenses are unlikely to pose problems, as these are specific licenses with precise terms " + - "worldwide. 
Public domain images may need to be re-evaluated by a re-user because it depends " + - "on each country's copyright laws what is in the public domain there. There is no guarantee " + - "that something in the public domain in the U.S. was also in the public domain in your country."; + " images or sounds used in our encyclopedia articles. Still, we cannot guarantee that all " + + "media are used or marked correctly: for example, if an image description page states " + + "that an image was in the public domain, you should still check yourself whether that claim " + + "appears correct and decide for yourself whether your use of the image would be fine under " + + "the laws applicable to you. Wikipedia is primarily subject to U.S. law; re-users outside " + + "the U.S. should be aware that they are subject to the laws of their country, which almost " + + "certainly are different. Images published under the GFDL or one of the Creative Commons " + + "Licenses are unlikely to pose problems, as these are specific licenses with precise terms " + + "worldwide. Public domain images may need to be re-evaluated by a re-user because it depends " + + "on each country's copyright laws what is in the public domain there. There is no guarantee " + + "that something in the public domain in the U.S. was also in the public domain in your country."; @Test - public void categorize() throws IOException, OrtException { + public void categorize() throws Exception { final File model = new File(getOpennlpDataDir(), "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.onnx"); final File vocab = new File(getOpennlpDataDir(), "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.vocab"); - final DocumentCategorizerDL documentCategorizerDL = - new DocumentCategorizerDL(model, vocab, getCategories(), - new AverageClassificationScoringStrategy(), - new InferenceOptions()); + try (final DocumentCategorizerDL documentCategorizerDL = + new DocumentCategorizerDL(model, vocab, getCategories(), + new AverageClassificationScoringStrategy(), + new InferenceOptions())) { - final double[] result = documentCategorizerDL.categorize(new String[]{text}); + final double[] result = documentCategorizerDL.categorize(new String[] {text}); - // Sort the result for easier comparison. - final double[] sortedResult = Arrays.stream(result) - .boxed() - .sorted(Collections.reverseOrder()).mapToDouble(Double::doubleValue).toArray(); + // Sort the result for easier comparison. 
+ final double[] sortedResult = Arrays.stream(result) + .boxed() + .sorted(Collections.reverseOrder()).mapToDouble(Double::doubleValue).toArray(); - final double[] expected = new double[] - {0.3391093313694, - 0.2611352801322937, - 0.24420668184757233, - 0.11939861625432968, - 0.03615010157227516}; + final double[] expected = new double[] + {0.3391093313694, + 0.2611352801322937, + 0.24420668184757233, + 0.11939861625432968, + 0.03615010157227516}; - logger.debug("Actual: {}", Arrays.toString(sortedResult)); - logger.debug("Expected: {}", Arrays.toString(expected)); + logger.debug("Actual: {}", Arrays.toString(sortedResult)); + logger.debug("Expected: {}", Arrays.toString(expected)); - Assertions.assertArrayEquals(expected, sortedResult, 0.000001); - Assertions.assertEquals(5, result.length); + Assertions.assertArrayEquals(expected, sortedResult, 0.000001); + Assertions.assertEquals(5, result.length); - final String category = documentCategorizerDL.getBestCategory(result); - Assertions.assertEquals("bad", category); + final String category = documentCategorizerDL.getBestCategory(result); + Assertions.assertEquals("bad", category); + } } @Test - public void categorizeWithAutomaticLabels() throws IOException, OrtException { + public void categorizeWithAutomaticLabels() throws Exception { final File model = new File(getOpennlpDataDir(), - "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.onnx"); + "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.onnx"); final File vocab = new File(getOpennlpDataDir(), - "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.vocab"); + "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.vocab"); final File config = new File(getOpennlpDataDir(), - "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.json"); + "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.json"); - final DocumentCategorizerDL documentCategorizerDL = - new DocumentCategorizerDL(model, vocab, config, - new AverageClassificationScoringStrategy(), - new InferenceOptions()); + try (final DocumentCategorizerDL documentCategorizerDL = + new DocumentCategorizerDL(model, vocab, config, + new AverageClassificationScoringStrategy(), + new InferenceOptions())) { - final double[] result = documentCategorizerDL.categorize(new String[]{text}); + final double[] result = documentCategorizerDL.categorize(new String[] {text}); - // Sort the result for easier comparison. - final double[] sortedResult = Arrays.stream(result) - .boxed() - .sorted(Collections.reverseOrder()).mapToDouble(Double::doubleValue).toArray(); + // Sort the result for easier comparison. 
+ final double[] sortedResult = Arrays.stream(result) + .boxed() + .sorted(Collections.reverseOrder()).mapToDouble(Double::doubleValue).toArray(); - final double[] expected = new double[] - {0.3391093313694, - 0.2611352801322937, - 0.24420668184757233, - 0.11939861625432968, - 0.03615010157227516}; + final double[] expected = new double[] + {0.3391093313694, + 0.2611352801322937, + 0.24420668184757233, + 0.11939861625432968, + 0.03615010157227516}; - logger.debug("Actual: {}", Arrays.toString(sortedResult)); - logger.debug("Expected: {}", Arrays.toString(expected)); + logger.debug("Actual: {}", Arrays.toString(sortedResult)); + logger.debug("Expected: {}", Arrays.toString(expected)); - Assertions.assertArrayEquals(expected, sortedResult, 0.000001); - Assertions.assertEquals(5, result.length); + Assertions.assertArrayEquals(expected, sortedResult, 0.000001); + Assertions.assertEquals(5, result.length); - final String category = documentCategorizerDL.getBestCategory(result); - Assertions.assertEquals("2 stars", category); + final String category = documentCategorizerDL.getBestCategory(result); + Assertions.assertEquals("2 stars", category); + } } - @Disabled("This test will should only be run if a GPU device is present.") + @Disabled("This test should only be run if a GPU device is present.") @Test public void categorizeWithGpu() throws Exception { @@ -145,26 +145,27 @@ public void categorizeWithGpu() throws Exception { inferenceOptions.setGpu(true); inferenceOptions.setGpuDeviceId(0); - final DocumentCategorizerDL documentCategorizerDL = - new DocumentCategorizerDL(model, vocab, getCategories(), - new AverageClassificationScoringStrategy(), - new InferenceOptions()); + try (final DocumentCategorizerDL documentCategorizerDL = + new DocumentCategorizerDL(model, vocab, getCategories(), + new AverageClassificationScoringStrategy(), + new InferenceOptions())) { - final double[] result = documentCategorizerDL.categorize(new String[]{"I am happy"}); - logger.debug(Arrays.toString(result)); + final double[] result = documentCategorizerDL.categorize(new String[] {"I am happy"}); + logger.debug(Arrays.toString(result)); - final double[] expected = new double[] - {0.007819971069693565, - 0.006593209225684404, - 0.04995147883892059, - 0.3003573715686798, - 0.6352779865264893}; + final double[] expected = new double[] + {0.007819971069693565, + 0.006593209225684404, + 0.04995147883892059, + 0.3003573715686798, + 0.6352779865264893}; - Assertions.assertArrayEquals(expected, result, 0.000001); - Assertions.assertEquals(5, result.length); + Assertions.assertArrayEquals(expected, result, 0.000001); + Assertions.assertEquals(5, result.length); - final String category = documentCategorizerDL.getBestCategory(result); - Assertions.assertEquals("very good", category); + final String category = documentCategorizerDL.getBestCategory(result); + Assertions.assertEquals("very good", category); + } } @@ -183,21 +184,21 @@ public void categorizeWithInferenceOptions() throws Exception { categories.put(0, "negative"); categories.put(1, "positive"); - final DocumentCategorizerDL documentCategorizerDL = - new DocumentCategorizerDL(model, vocab, categories, - new AverageClassificationScoringStrategy(), - inferenceOptions); - - final double[] result = documentCategorizerDL.categorize(new String[]{"I am angry"}); + try (final DocumentCategorizerDL documentCategorizerDL = + new DocumentCategorizerDL(model, vocab, categories, + new AverageClassificationScoringStrategy(), + inferenceOptions)) { - final double[] expected = new 
double[]{0.8851314783096313, 0.11486853659152985}; + final double[] result = documentCategorizerDL.categorize(new String[] {"I am angry"}); - Assertions.assertArrayEquals(expected, result, 0.000001); - Assertions.assertEquals(2, result.length); + final double[] expected = new double[] {0.8851314783096313, 0.11486853659152985}; - final String category = documentCategorizerDL.getBestCategory(result); - Assertions.assertEquals("negative", category); + Assertions.assertArrayEquals(expected, result, 0.000001); + Assertions.assertEquals(2, result.length); + final String category = documentCategorizerDL.getBestCategory(result); + Assertions.assertEquals("negative", category); + } } @Test @@ -208,85 +209,89 @@ public void scoreMap() throws Exception { final File vocab = new File(getOpennlpDataDir(), "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.vocab"); - final DocumentCategorizerDL documentCategorizerDL = - new DocumentCategorizerDL(model, vocab, getCategories(), - new AverageClassificationScoringStrategy(), - new InferenceOptions()); + try (final DocumentCategorizerDL documentCategorizerDL = + new DocumentCategorizerDL(model, vocab, getCategories(), + new AverageClassificationScoringStrategy(), + new InferenceOptions())) { - final Map result = documentCategorizerDL.scoreMap(new String[]{"I am happy"}); + final Map result = documentCategorizerDL.scoreMap(new String[] {"I am happy"}); - Assertions.assertEquals(0.6352779865264893, result.get("very good"), 0.000001); - Assertions.assertEquals(0.3003573715686798, result.get("good"), 0.000001); - Assertions.assertEquals(0.04995147883892059, result.get("neutral"), 0.000001); - Assertions.assertEquals(0.006593209225684404, result.get("bad"), 0.000001); - Assertions.assertEquals(0.007819971069693565, result.get("very bad"), 0.000001); + Assertions.assertEquals(0.6352779865264893, result.get("very good"), 0.000001); + Assertions.assertEquals(0.3003573715686798, result.get("good"), 0.000001); + Assertions.assertEquals(0.04995147883892059, result.get("neutral"), 0.000001); + Assertions.assertEquals(0.006593209225684404, result.get("bad"), 0.000001); + Assertions.assertEquals(0.007819971069693565, result.get("very bad"), 0.000001); + } } @Test - public void sortedScoreMap() throws IOException, OrtException { + public void sortedScoreMap() throws Exception { final File model = new File(getOpennlpDataDir(), "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.onnx"); final File vocab = new File(getOpennlpDataDir(), "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.vocab"); - final DocumentCategorizerDL documentCategorizerDL = - new DocumentCategorizerDL(model, vocab, getCategories(), - new AverageClassificationScoringStrategy(), - new InferenceOptions()); + try (final DocumentCategorizerDL documentCategorizerDL = + new DocumentCategorizerDL(model, vocab, getCategories(), + new AverageClassificationScoringStrategy(), + new InferenceOptions())) { - final Map> result = documentCategorizerDL.sortedScoreMap(new String[]{"I am happy"}); + final Map> result = + documentCategorizerDL.sortedScoreMap(new String[] {"I am happy"}); - Assertions.assertNotNull(result, "Result must not be NULL."); - Assertions.assertEquals(5, result.size()); + Assertions.assertNotNull(result, "Result must not be NULL."); + Assertions.assertEquals(5, result.size()); - final Iterator>> it = result.entrySet().iterator(); + final Iterator>> it = result.entrySet().iterator(); - // we assume a sorted map here, so lets check in sorted order (lower values first). 
- Map.Entry> e = it.next(); - Assertions.assertEquals(0.006593209225684404, e.getKey(), 0.000001); - Assertions.assertEquals(e.getValue().size(), 1); + // we assume a sorted map here, so lets check in sorted order (lower values first). + Map.Entry> e = it.next(); + Assertions.assertEquals(0.006593209225684404, e.getKey(), 0.000001); + Assertions.assertEquals(e.getValue().size(), 1); - e = it.next(); - Assertions.assertEquals(0.007819971069693565, e.getKey(), 0.000001); - Assertions.assertEquals(e.getValue().size(), 1); + e = it.next(); + Assertions.assertEquals(0.007819971069693565, e.getKey(), 0.000001); + Assertions.assertEquals(e.getValue().size(), 1); - e = it.next(); - Assertions.assertEquals(0.04995147883892059, e.getKey(), 0.000001); - Assertions.assertEquals(e.getValue().size(), 1); + e = it.next(); + Assertions.assertEquals(0.04995147883892059, e.getKey(), 0.000001); + Assertions.assertEquals(e.getValue().size(), 1); - e = it.next(); - Assertions.assertEquals(0.3003573715686798, e.getKey(), 0.000001); - Assertions.assertEquals(e.getValue().size(), 1); + e = it.next(); + Assertions.assertEquals(0.3003573715686798, e.getKey(), 0.000001); + Assertions.assertEquals(e.getValue().size(), 1); - e = it.next(); - Assertions.assertEquals(0.6352779865264893, e.getKey(), 0.000001); - Assertions.assertEquals(e.getValue().size(), 1); + e = it.next(); + Assertions.assertEquals(0.6352779865264893, e.getKey(), 0.000001); + Assertions.assertEquals(e.getValue().size(), 1); + } } @Test - public void doccat() throws IOException, OrtException { + public void doccat() throws Exception { final File model = new File(getOpennlpDataDir(), "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.onnx"); final File vocab = new File(getOpennlpDataDir(), "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.vocab"); - final DocumentCategorizerDL documentCategorizerDL = - new DocumentCategorizerDL(model, vocab, getCategories(), - new AverageClassificationScoringStrategy(), - new InferenceOptions()); + try (final DocumentCategorizerDL documentCategorizerDL = + new DocumentCategorizerDL(model, vocab, getCategories(), + new AverageClassificationScoringStrategy(), + new InferenceOptions())) { - final int index = documentCategorizerDL.getIndex("bad"); - Assertions.assertEquals(1, index); + final int index = documentCategorizerDL.getIndex("bad"); + Assertions.assertEquals(1, index); - final String category = documentCategorizerDL.getCategory(3); - Assertions.assertEquals("good", category); + final String category = documentCategorizerDL.getCategory(3); + Assertions.assertEquals("good", category); - final int number = documentCategorizerDL.getNumberOfCategories(); - Assertions.assertEquals(5, number); + final int number = documentCategorizerDL.getNumberOfCategories(); + Assertions.assertEquals(5, number); + } } diff --git a/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java b/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java index 513765dc2..f8febbd0a 100644 --- a/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java +++ b/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java @@ -36,7 +36,7 @@ public class NameFinderDLEval extends AbstractDLTest { private static final Logger logger = LoggerFactory.getLogger(NameFinderDLEval.class); - private final SentenceDetector sentenceDetector ; + private final SentenceDetector sentenceDetector; public NameFinderDLEval() throws IOException { this.sentenceDetector = new SentenceDetectorME("en"); @@ 
-54,18 +54,21 @@ public void tokenNameFinder1Test() throws Exception { final String[] tokens = new String[] {"George", "Washington", "was", "president", "of", "the", "United", "States", "."}; - final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(), sentenceDetector); - final Span[] spans = nameFinderDL.find(tokens); + try (final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(), + sentenceDetector)) { + final Span[] spans = nameFinderDL.find(tokens); - for (Span span : spans) { - logger.debug(span.toString()); - } + for (Span span : spans) { + logger.debug(span.toString()); + } - Assertions.assertEquals(1, spans.length); - Assertions.assertEquals(0, spans[0].getStart()); - Assertions.assertEquals(17, spans[0].getEnd()); - Assertions.assertEquals(8.251646041870117, spans[0].getProb(), 0.00001); - Assertions.assertEquals("George Washington", spans[0].getCoveredText(String.join(" ", tokens))); + Assertions.assertEquals(1, spans.length); + Assertions.assertEquals(0, spans[0].getStart()); + Assertions.assertEquals(17, spans[0].getEnd()); + Assertions.assertEquals(8.251646041870117, spans[0].getProb(), 0.00001); + Assertions.assertEquals("George Washington", + spans[0].getCoveredText(String.join(" ", tokens))); + } } @@ -78,19 +81,20 @@ public void tokenNameFinder2Test() throws Exception { final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx"); final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt"); - final String[] tokens = new String[]{"His", "name", "was", "George", "Washington"}; + final String[] tokens = new String[] {"His", "name", "was", "George", "Washington"}; - final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(), sentenceDetector); - final Span[] spans = nameFinderDL.find(tokens); - - for (Span span : spans) { - logger.debug(span.toString()); - } + try (final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(), + sentenceDetector)) { + final Span[] spans = nameFinderDL.find(tokens); - Assertions.assertEquals(1, spans.length); - Assertions.assertEquals(13, spans[0].getStart()); - Assertions.assertEquals(30, spans[0].getEnd()); + for (Span span : spans) { + logger.debug(span.toString()); + } + Assertions.assertEquals(1, spans.length); + Assertions.assertEquals(13, spans[0].getStart()); + Assertions.assertEquals(30, spans[0].getEnd()); + } } @Test @@ -102,19 +106,20 @@ public void tokenNameFinder3Test() throws Exception { final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx"); final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt"); - final String[] tokens = new String[]{"His", "name", "was", "George"}; - - final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(), sentenceDetector); - final Span[] spans = nameFinderDL.find(tokens); + final String[] tokens = new String[] {"His", "name", "was", "George"}; - for (Span span : spans) { - logger.debug(span.toString()); - } + try (final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(), + sentenceDetector)) { + final Span[] spans = nameFinderDL.find(tokens); - Assertions.assertEquals(1, spans.length); - Assertions.assertEquals(13, spans[0].getStart()); - Assertions.assertEquals(19, spans[0].getEnd()); + for (Span span : spans) { + logger.debug(span.toString()); + } + Assertions.assertEquals(1, spans.length); + Assertions.assertEquals(13, spans[0].getStart()); + Assertions.assertEquals(19, 
spans[0].getEnd()); + } } @Test @@ -126,13 +131,14 @@ public void tokenNameFinderNoInputTest() throws Exception { final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx"); final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt"); - final String[] tokens = new String[]{}; - - final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(), sentenceDetector); - final Span[] spans = nameFinderDL.find(tokens); + final String[] tokens = new String[] {}; - Assertions.assertEquals(0, spans.length); + try (final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(), + sentenceDetector)) { + final Span[] spans = nameFinderDL.find(tokens); + Assertions.assertEquals(0, spans.length); + } } @Test @@ -144,12 +150,14 @@ public void tokenNameFinderNoEntitiesTest() throws Exception { final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx"); final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt"); - final String[] tokens = new String[]{"I", "went", "to", "the", "park"}; + final String[] tokens = new String[] {"I", "went", "to", "the", "park"}; - final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(), sentenceDetector); - final Span[] spans = nameFinderDL.find(tokens); + try (final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(), + sentenceDetector)) { + final Span[] spans = nameFinderDL.find(tokens); - Assertions.assertEquals(0, spans.length); + Assertions.assertEquals(0, spans.length); + } } @@ -162,21 +170,24 @@ public void tokenNameFinderMultipleEntitiesTest() throws Exception { final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx"); final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt"); - final String[] tokens = new String[]{"George", "Washington", "and", "Abraham", "Lincoln", + final String[] tokens = new String[] {"George", "Washington", "and", "Abraham", "Lincoln", "were", "presidents"}; - final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(), sentenceDetector); - final Span[] spans = nameFinderDL.find(tokens); + try (final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(), + sentenceDetector)) { + final Span[] spans = nameFinderDL.find(tokens); - for (Span span : spans) { - logger.debug(span.toString()); - } + for (Span span : spans) { + logger.debug(span.toString()); + } + + Assertions.assertEquals(2, spans.length); + Assertions.assertEquals(0, spans[0].getStart()); + Assertions.assertEquals(17, spans[0].getEnd()); + Assertions.assertEquals(22, spans[1].getStart()); + Assertions.assertEquals(37, spans[1].getEnd()); - Assertions.assertEquals(2, spans.length); - Assertions.assertEquals(0, spans[0].getStart()); - Assertions.assertEquals(17, spans[0].getEnd()); - Assertions.assertEquals(22, spans[1].getStart()); - Assertions.assertEquals(37, spans[1].getEnd()); + } } @@ -190,7 +201,9 @@ public void invalidModel() { final File model = new File("invalid.onnx"); final File vocab = new File("vocab.txt"); - new NameFinderDL(model, vocab, getIds2Labels(), sentenceDetector); + try (final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(), + sentenceDetector)) { + } }); } diff --git a/opennlp-dl/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java b/opennlp-dl/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java index f63fa3f40..b7c92e1c7 100644 --- 
a/opennlp-dl/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java +++ b/opennlp-dl/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java @@ -18,9 +18,7 @@ package opennlp.dl.vectors; import java.io.File; -import java.io.IOException; -import ai.onnxruntime.OrtException; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -29,21 +27,22 @@ public class SentenceVectorsDLEval extends AbstractDLTest { @Test - public void generateVectorsTest() throws IOException, OrtException { + public void generateVectorsTest() throws Exception { final File MODEL_FILE_NAME = new File(getOpennlpDataDir(), "onnx/sentence-transformers/model.onnx"); final File VOCAB_FILE_NAME = new File(getOpennlpDataDir(), "onnx/sentence-transformers/vocab.txt"); final String sentence = "george washington was president"; - final SentenceVectorsDL sv = new SentenceVectorsDL(MODEL_FILE_NAME, VOCAB_FILE_NAME); + try (final SentenceVectorsDL sv = new SentenceVectorsDL(MODEL_FILE_NAME, VOCAB_FILE_NAME)) { - final float[] vectors = sv.getVectors(sentence); + final float[] vectors = sv.getVectors(sentence); - Assertions.assertEquals(vectors[0], 0.39994872, 0.00001); - Assertions.assertEquals(vectors[1], -0.055101186, 0.00001); - Assertions.assertEquals(vectors[2], 0.2817594, 0.00001); - Assertions.assertEquals(vectors.length, 384); + Assertions.assertEquals(vectors[0], 0.39994872, 0.00001); + Assertions.assertEquals(vectors[1], -0.055101186, 0.00001); + Assertions.assertEquals(vectors[2], 0.2817594, 0.00001); + Assertions.assertEquals(vectors.length, 384); + } }
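Note for reviewers: because AbstractDL now implements AutoCloseable, its close() method releases the underlying OrtSession and OrtEnvironment, and callers can manage the lifetime with try-with-resources exactly as the updated tests above do. The following is a minimal sketch of that usage pattern, not part of the patch: the class name, the model/vocab file paths, and the hard-coded category map are illustrative placeholders, while the constructor arguments, category labels, and package names are taken from the tests and imports shown above.

import java.io.File;
import java.util.HashMap;
import java.util.Map;

import opennlp.dl.InferenceOptions;
import opennlp.dl.doccat.DocumentCategorizerDL;
import opennlp.dl.doccat.scoring.AverageClassificationScoringStrategy;

public class DoccatCloseExample {

  public static void main(String[] args) throws Exception {

    // Illustrative local paths; substitute your own exported ONNX model and vocabulary.
    final File model = new File("nlptown_bert-base-multilingual-uncased-sentiment.onnx");
    final File vocab = new File("nlptown_bert-base-multilingual-uncased-sentiment.vocab");

    // Category labels matching the sentiment model used in the tests above.
    final Map<Integer, String> categories = new HashMap<>();
    categories.put(0, "very bad");
    categories.put(1, "bad");
    categories.put(2, "neutral");
    categories.put(3, "good");
    categories.put(4, "very good");

    // try-with-resources is possible because AbstractDL is now AutoCloseable;
    // close() releases the OrtSession and OrtEnvironment when the block exits.
    try (DocumentCategorizerDL categorizer =
             new DocumentCategorizerDL(model, vocab, categories,
                 new AverageClassificationScoringStrategy(),
                 new InferenceOptions())) {

      final double[] scores = categorizer.categorize(new String[] {"I am happy"});
      System.out.println(categorizer.getBestCategory(scores));
    }
  }
}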