Skip to content

Commit

Permalink
DummyEmbeddingModel: add a parameter and matching test
Browse files Browse the repository at this point in the history
  • Loading branch information
cpoerschke committed Nov 7, 2024
1 parent b89e08c commit 4653837
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import java.lang.reflect.Method;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.util.Accountable;
Expand Down Expand Up @@ -78,10 +80,20 @@ public static SolrEmbeddingModel getInstance(
.invoke(builder, ((Long) params.get(paramName)).intValue());
break;
default:
builder
.getClass()
.getMethod(paramName, String.class)
.invoke(builder, params.get(paramName));
ArrayList<Method> methods = new ArrayList<>();
for (var method : builder.getClass().getMethods()) {
if (paramName.equals(method.getName()) && method.getParameterCount() == 1) {
methods.add(method);
}
}
if (methods.size() == 1) {
methods.get(0).invoke(builder, params.get(paramName));
} else {
builder
.getClass()
.getMethod(paramName, String.class)
.invoke(builder, params.get(paramName));
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
{
"class": "org.apache.solr.llm.embedding.DummyEmbeddingModel",
"name": "dummy-1"
"name": "dummy-1",
"params": {
"embedding": 1234
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,24 @@
import java.util.List;

public class DummyEmbeddingModel implements EmbeddingModel {
public DummyEmbeddingModel() {}
final float[] embedding;

public DummyEmbeddingModel(int embedding) {
this.embedding =
new float[] {
(embedding / 1000) % 10, (embedding / 100) % 10, (embedding / 10) % 10, embedding % 10
};
}

@Override
public Response<Embedding> embed(String text) {
Embedding dummy = new Embedding(new float[] {1.0f, 2.0f, 3.0f, 4.0f});
Embedding dummy = new Embedding(this.embedding);
return new Response<Embedding>(dummy);
}

@Override
public Response<Embedding> embed(TextSegment textSegment) {
Embedding dummy = new Embedding(new float[] {1.0f, 2.0f, 3.0f, 4.0f});
Embedding dummy = new Embedding(this.embedding);
return new Response<Embedding>(dummy);
}

Expand All @@ -44,18 +51,25 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {

@Override
public int dimension() {
return 4;
return embedding.length;
}

public static DummyEmbeddingModelBuilder builder() {
return new DummyEmbeddingModelBuilder();
}

public static class DummyEmbeddingModelBuilder {
private int embedding = 0;

public DummyEmbeddingModelBuilder() {}

public DummyEmbeddingModelBuilder embedding(Long embedding) {
this.embedding = embedding.intValue();
return this;
}

public DummyEmbeddingModel build() {
return new DummyEmbeddingModel();
return new DummyEmbeddingModel(this.embedding);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.llm.embedding;

import org.apache.solr.SolrTestCase;
import org.junit.Test;

public class DummyEmbeddingModelTest extends SolrTestCase {

@Test
public void constructAndEmbed() throws Exception {
assertEquals(
"[1.0, 2.0, 3.0, 4.0]",
new DummyEmbeddingModel(1234).embed("hello").content().vectorAsList().toString());
assertEquals(
"[8.0, 7.0, 6.0, 5.0]",
new DummyEmbeddingModel(98765).embed("world").content().vectorAsList().toString());
assertEquals(
"[0.0, 0.0, 4.0, 2.0]",
new DummyEmbeddingModel(42).embed("answer").content().vectorAsList().toString());
}
}

0 comments on commit 4653837

Please sign in to comment.