Skip to content

Commit

Permalink
Add dimensions option for ZhiPuAi embedding model
Browse files Browse the repository at this point in the history
Signed-off-by: Assassinxc <[email protected]>
  • Loading branch information
Assassinxc committed Jan 21, 2025
1 parent 6a53268 commit 85c1420
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,13 @@ private ZhiPuAiEmbeddingOptions mergeOptions(@Nullable EmbeddingOptions runtimeO

return ZhiPuAiEmbeddingOptions.builder()
.model(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
.dimensions(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getDimensions(),
defaultOptions.getDimensions()))
.build();
}

private ZhiPuAiApi.EmbeddingRequest<String> createEmbeddingRequest(String text, EmbeddingOptions requestOptions) {
return new ZhiPuAiApi.EmbeddingRequest<>(text, requestOptions.getModel());
return new ZhiPuAiApi.EmbeddingRequest<>(text, requestOptions.getModel(), requestOptions.getDimensions());
}

public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ public class ZhiPuAiEmbeddingOptions implements EmbeddingOptions {
* ID of the model to use.
*/
private @JsonProperty("model") String model;
/**
* Dimension value of the model to use.
*/
private @JsonProperty("dimensions") Integer dimensions;
// @formatter:on

public static Builder builder() {
Expand All @@ -54,6 +58,10 @@ public void setModel(String model) {
this.model = model;
}

public void setDimensions(Integer dimensions) {
this.dimensions = dimensions;
}

@Override
@JsonIgnore
public Integer getDimensions() {
Expand All @@ -73,6 +81,11 @@ public Builder model(String model) {
return this;
}

public Builder dimensions(Integer dimensions) {
this.options.setDimensions(dimensions);
return this;
}

public ZhiPuAiEmbeddingOptions build() {
return this.options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,12 @@ public enum EmbeddingModel {
/**
* DIMENSION: 1024
*/
Embedding_2("Embedding-2");
Embedding_2("Embedding-2"),

/**
* DIMENSION: up to 2048
*/
Embedding_3("Embedding-3");

public final String value;

Expand Down Expand Up @@ -956,15 +961,27 @@ public String toString() {
@JsonInclude(Include.NON_NULL)
public record EmbeddingRequest<T>(
@JsonProperty("input") T input,
@JsonProperty("model") String model) {
@JsonProperty("model") String model,
@JsonProperty("dimensions") Integer dimensions) {


/**
* Create an embedding request with the given input. Encoding model is set to 'embedding-2'.
* @param input Input text to embed.
*/
* Create an embedding request with the given input. Encoding model is set to 'embedding-2'.
*
* @param input Input text to embed.
*/
public EmbeddingRequest(T input) {
this(input, DEFAULT_EMBEDDING_MODEL);
this(input, DEFAULT_EMBEDDING_MODEL,null);
}

/**
* Create an embedding request with the given input and model.
*
* @param input
* @param model
*/
public EmbeddingRequest(T input, String model) {
this(input, model,null);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,15 @@ void embeddings() {
assertThat(response.getBody().data().get(0).embedding()).hasSize(1024);
}

@Test
void embeddingsWithDimensions() {
ResponseEntity<EmbeddingList<Embedding>> response = this.zhiPuAiApi
.embeddings(new ZhiPuAiApi.EmbeddingRequest<>("Hello world",
ZhiPuAiApi.EmbeddingModel.Embedding_3.getValue(), 1536));

assertThat(response).isNotNull();
assertThat(Objects.requireNonNull(response.getBody()).data()).hasSize(1);
assertThat(response.getBody().data().get(0).embedding()).hasSize(1536);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ The prefix `spring.ai.zhipuai.embedding` is property prefix that configures the
| spring.ai.zhipuai.embedding.base-url | Optional overrides the spring.ai.zhipuai.base-url to provide embedding specific url | -
| spring.ai.zhipuai.embedding.api-key | Optional overrides the spring.ai.zhipuai.api-key to provide embedding specific api-key | -
| spring.ai.zhipuai.embedding.options.model | The model to use | embedding-2
| spring.ai.zhipuai.embedding.options.dimensions | The number of dimensions | -
|====

NOTE: You can override the common `spring.ai.zhipuai.base-url` and `spring.ai.zhipuai.api-key` for the `ChatModel` and `EmbeddingModel` implementations.
Expand Down Expand Up @@ -185,7 +186,8 @@ var zhiPuAiApi = new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY"));
var embeddingModel = new ZhiPuAiEmbeddingModel(api, MetadataMode.EMBED,
ZhiPuAiEmbeddingOptions.builder()
.model("embedding-2")
.model("embedding-3")
.dimensions(1536)
.build());
EmbeddingResponse embeddingResponse = this.embeddingModel
Expand Down

0 comments on commit 85c1420

Please sign in to comment.