From 465383748b0f35847cbebbd024bd082cc8e23ab5 Mon Sep 17 00:00:00 2001 From: Christine Poerschke Date: Thu, 7 Nov 2024 14:12:24 +0000 Subject: [PATCH] DummyEmbeddingModel: add a parameter and matching test --- .../llm/embedding/SolrEmbeddingModel.java | 20 ++++++++--- .../test-files/modelExamples/dummy-model.json | 5 ++- .../llm/embedding/DummyEmbeddingModel.java | 24 ++++++++++--- .../embedding/DummyEmbeddingModelTest.java | 36 +++++++++++++++++++ 4 files changed, 75 insertions(+), 10 deletions(-) create mode 100644 solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModelTest.java diff --git a/solr/modules/llm/src/java/org/apache/solr/llm/embedding/SolrEmbeddingModel.java b/solr/modules/llm/src/java/org/apache/solr/llm/embedding/SolrEmbeddingModel.java index 4d07b38197b..c84079d72ea 100644 --- a/solr/modules/llm/src/java/org/apache/solr/llm/embedding/SolrEmbeddingModel.java +++ b/solr/modules/llm/src/java/org/apache/solr/llm/embedding/SolrEmbeddingModel.java @@ -19,7 +19,9 @@ import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; +import java.lang.reflect.Method; import java.time.Duration; +import java.util.ArrayList; import java.util.Map; import java.util.Objects; import org.apache.lucene.util.Accountable; @@ -78,10 +80,20 @@ public static SolrEmbeddingModel getInstance( .invoke(builder, ((Long) params.get(paramName)).intValue()); break; default: - builder - .getClass() - .getMethod(paramName, String.class) - .invoke(builder, params.get(paramName)); + ArrayList methods = new ArrayList<>(); + for (var method : builder.getClass().getMethods()) { + if (paramName.equals(method.getName()) && method.getParameterCount() == 1) { + methods.add(method); + } + } + if (methods.size() == 1) { + methods.get(0).invoke(builder, params.get(paramName)); + } else { + builder + .getClass() + .getMethod(paramName, String.class) + .invoke(builder, params.get(paramName)); + } } } } diff --git a/solr/modules/llm/src/test-files/modelExamples/dummy-model.json b/solr/modules/llm/src/test-files/modelExamples/dummy-model.json index 3de81640d10..e20053d1cfb 100644 --- a/solr/modules/llm/src/test-files/modelExamples/dummy-model.json +++ b/solr/modules/llm/src/test-files/modelExamples/dummy-model.json @@ -1,4 +1,7 @@ { "class": "org.apache.solr.llm.embedding.DummyEmbeddingModel", - "name": "dummy-1" + "name": "dummy-1", + "params": { + "embedding": 1234 + } } diff --git a/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModel.java b/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModel.java index 98ffa9e6d97..f88a6de3527 100644 --- a/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModel.java +++ b/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModel.java @@ -23,17 +23,24 @@ import java.util.List; public class DummyEmbeddingModel implements EmbeddingModel { - public DummyEmbeddingModel() {} + final float[] embedding; + + public DummyEmbeddingModel(int embedding) { + this.embedding = + new float[] { + (embedding / 1000) % 10, (embedding / 100) % 10, (embedding / 10) % 10, embedding % 10 + }; + } @Override public Response embed(String text) { - Embedding dummy = new Embedding(new float[] {1.0f, 2.0f, 3.0f, 4.0f}); + Embedding dummy = new Embedding(this.embedding); return new Response(dummy); } @Override public Response embed(TextSegment textSegment) { - Embedding dummy = new Embedding(new float[] {1.0f, 2.0f, 3.0f, 4.0f}); + Embedding dummy = new Embedding(this.embedding); return new Response(dummy); } @@ -44,7 +51,7 @@ public Response> embedAll(List textSegments) { @Override public int dimension() { - return 4; + return embedding.length; } public static DummyEmbeddingModelBuilder builder() { @@ -52,10 +59,17 @@ public static DummyEmbeddingModelBuilder builder() { } public static class DummyEmbeddingModelBuilder { + private int embedding = 0; + public DummyEmbeddingModelBuilder() {} + public DummyEmbeddingModelBuilder embedding(Long embedding) { + this.embedding = embedding.intValue(); + return this; + } + public DummyEmbeddingModel build() { - return new DummyEmbeddingModel(); + return new DummyEmbeddingModel(this.embedding); } } } diff --git a/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModelTest.java b/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModelTest.java new file mode 100644 index 00000000000..22f53c5713e --- /dev/null +++ b/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModelTest.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.llm.embedding; + +import org.apache.solr.SolrTestCase; +import org.junit.Test; + +public class DummyEmbeddingModelTest extends SolrTestCase { + + @Test + public void constructAndEmbed() throws Exception { + assertEquals( + "[1.0, 2.0, 3.0, 4.0]", + new DummyEmbeddingModel(1234).embed("hello").content().vectorAsList().toString()); + assertEquals( + "[8.0, 7.0, 6.0, 5.0]", + new DummyEmbeddingModel(98765).embed("world").content().vectorAsList().toString()); + assertEquals( + "[0.0, 0.0, 4.0, 2.0]", + new DummyEmbeddingModel(42).embed("answer").content().vectorAsList().toString()); + } +}