From e7d9e63651d33256c0dc62ed18f38bbab72b0332 Mon Sep 17 00:00:00 2001 From: Christine Poerschke Date: Fri, 8 Nov 2024 11:05:56 +0000 Subject: [PATCH] DummyEmbeddingModel: support float values and dimensions other than 4 --- .../src/test-files/modelExamples/dummy-model.json | 2 +- .../solr/llm/embedding/DummyEmbeddingModel.java | 12 ++++++------ .../solr/llm/embedding/DummyEmbeddingModelTest.java | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/solr/modules/llm/src/test-files/modelExamples/dummy-model.json b/solr/modules/llm/src/test-files/modelExamples/dummy-model.json index 8e850643456..527b86db5f2 100644 --- a/solr/modules/llm/src/test-files/modelExamples/dummy-model.json +++ b/solr/modules/llm/src/test-files/modelExamples/dummy-model.json @@ -2,6 +2,6 @@ "class": "org.apache.solr.llm.embedding.DummyEmbeddingModel", "name": "dummy-1", "params": { - "embedding": [1,2,3,4] + "embedding": [1.0, 2.0, 3.0, 4.0] } } diff --git a/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModel.java b/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModel.java index d2effb4393a..1076f5d0bbf 100644 --- a/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModel.java +++ b/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModel.java @@ -26,8 +26,8 @@ public class DummyEmbeddingModel implements EmbeddingModel { final float[] embedding; - public DummyEmbeddingModel(int[] embedding) { - this.embedding = new float[] {embedding[0], embedding[1], embedding[2], embedding[3]}; + public DummyEmbeddingModel(float[] embedding) { + this.embedding = embedding; } @Override @@ -57,14 +57,14 @@ public static DummyEmbeddingModelBuilder builder() { } public static class DummyEmbeddingModelBuilder { - private int[] builderEmbeddings; + private float[] builderEmbeddings; public DummyEmbeddingModelBuilder() {} - public DummyEmbeddingModelBuilder embedding(ArrayList embeddings) { - this.builderEmbeddings = new int[embeddings.size()]; + public DummyEmbeddingModelBuilder embedding(ArrayList embeddings) { + this.builderEmbeddings = new float[embeddings.size()]; for (int i = 0; i < embeddings.size(); i++) { - this.builderEmbeddings[i] = embeddings.get(i).intValue(); + this.builderEmbeddings[i] = embeddings.get(i).floatValue(); } return this; } diff --git a/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModelTest.java b/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModelTest.java index d020d50d430..3aca97d7e42 100644 --- a/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModelTest.java +++ b/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModelTest.java @@ -25,21 +25,21 @@ public class DummyEmbeddingModelTest extends SolrTestCase { public void constructAndEmbed() throws Exception { assertEquals( "[1.0, 2.0, 3.0, 4.0]", - new DummyEmbeddingModel(new int[] {1, 2, 3, 4}) + new DummyEmbeddingModel(new float[] {1, 2, 3, 4}) .embed("hello") .content() .vectorAsList() .toString()); assertEquals( "[8.0, 7.0, 6.0, 5.0]", - new DummyEmbeddingModel(new int[] {8, 7, 6, 5}) + new DummyEmbeddingModel(new float[] {8, 7, 6, 5}) .embed("world") .content() .vectorAsList() .toString()); assertEquals( "[0.0, 0.0, 4.0, 2.0]", - new DummyEmbeddingModel(new int[] {0, 0, 4, 2}) + new DummyEmbeddingModel(new float[] {0, 0, 4, 2}) .embed("answer") .content() .vectorAsList()