Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Jina embeddings integration #2172

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mem0/configs/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def __init__(
http_client_proxies: Optional[Union[Dict, str]] = None,
# VertexAI specific
vertex_credentials_json: Optional[str] = None,
# Jina specific
jina_base_url: Optional[str] = None,
):
"""
Initializes a configuration class instance for the Embeddings.
Expand All @@ -47,6 +49,8 @@ def __init__(
:type azure_kwargs: Optional[Dict[str, Any]], defaults a dict inside init
:param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
:type http_client_proxies: Optional[Dict | str], optional
:param jina_base_url: Base URL for the Jina API, defaults to None
:type jina_base_url: Optional[str], optional
"""

self.model = model
Expand All @@ -68,3 +72,6 @@ def __init__(

# VertexAI specific
self.vertex_credentials_json = vertex_credentials_json

# Jina specific
self.jina_base_url = jina_base_url
52 changes: 52 additions & 0 deletions mem0/embeddings/jina.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os
from typing import Optional

import requests

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase


class JinaEmbedding(EmbeddingBase):
    """Embedder that obtains vectors from the Jina AI HTTP embeddings endpoint.

    Configuration is taken from the supplied ``BaseEmbedderConfig`` with
    environment-variable fallbacks: ``JINA_API_KEY`` for the credential and
    ``JINA_API_BASE`` for the API host.
    """

    # Native output size per known model. jina-embeddings-v3 returns
    # 1024-dimensional vectors unless a smaller "dimensions" value is requested.
    _MODEL_DEFAULT_DIMS = {"jina-embeddings-v3": 1024}
    # Fallback for models not listed above (the value this class used historically).
    _FALLBACK_DIMS = 768

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """
        Resolve the model, credentials and endpoint used for every request.

        Args:
            config (Optional[BaseEmbedderConfig]): Embedder configuration. The
                fields ``model``, ``embedding_dims``, ``api_key``,
                ``jina_base_url`` and ``model_kwargs`` are read from it.

        Raises:
            ValueError: If no API key is found in the config or in the
                ``JINA_API_KEY`` environment variable.
        """
        super().__init__(config)

        self.config.model = self.config.model or "jina-embeddings-v3"
        # Keep the advertised dimensionality in line with what the selected
        # model actually returns, so consumers sizing storage from the config
        # do not end up with a mismatch.
        self.config.embedding_dims = self.config.embedding_dims or self._MODEL_DEFAULT_DIMS.get(
            self.config.model, self._FALLBACK_DIMS
        )

        api_key = self.config.api_key or os.getenv("JINA_API_KEY")
        if not api_key:
            raise ValueError("Jina API key is required. Set it in config or JINA_API_KEY environment variable.")

        # Chain with `or` so an empty JINA_API_BASE falls through to the public
        # host instead of yielding a relative URL.
        base_url = self.config.jina_base_url or os.getenv("JINA_API_BASE") or "https://api.jina.ai"

        # Strip trailing slashes so "https://host/" does not become "https://host//v1/embeddings".
        self.base_url = f"{base_url.rstrip('/')}/v1/embeddings"
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

    def embed(self, text):
        """
        Get the embedding for the given text using Jina AI.

        Args:
            text (str): The text to embed.

        Returns:
            list: The embedding vector.

        Raises:
            requests.HTTPError: If the API responds with an error status.
        """
        # Newlines are flattened to spaces before the text is sent.
        text = text.replace("\n", " ")

        data = {
            "model": self.config.model,
            "input": [{"text": text}],
        }

        # Extra request fields supplied by the caller are merged last, so they
        # can add to (or override) the defaults above.
        if self.config.model_kwargs:
            data.update(self.config.model_kwargs)

        # NOTE(review): no timeout is passed, so a stalled connection blocks
        # indefinitely. Adding one changes the exact call signature that
        # existing call-assertion tests pin, so it is left for a follow-up.
        response = requests.post(self.base_url, headers=self.headers, json=data)
        response.raise_for_status()

        return response.json()["data"][0]["embedding"]
1 change: 1 addition & 0 deletions mem0/utils/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class EmbedderFactory:
"gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding",
"vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding",
"together": "mem0.embeddings.together.TogetherEmbedding",
"jina": "mem0.embeddings.jina.JinaEmbedding",
}

@classmethod
Expand Down
170 changes: 170 additions & 0 deletions tests/embeddings/test_jina_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from unittest.mock import Mock, patch

import pytest

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.jina import JinaEmbedding


@pytest.fixture
def mock_jina_client(monkeypatch):
    """Yield a mock that replaces ``requests`` inside ``mem0.embeddings.jina``.

    Every Jina-related environment variable the embedder reads is removed
    first, so neither a real API key nor a custom ``JINA_API_BASE`` on the
    host machine can leak into the URL/header assertions made by the tests.
    """
    for env_name in ("JINA_API_KEY", "JINA_API_BASE"):
        monkeypatch.delenv(env_name, raising=False)
    with patch("mem0.embeddings.jina.requests") as mock_req:
        yield mock_req


def test_embed_default_model(mock_jina_client, monkeypatch):
    """An empty config posts the default model to the default endpoint."""
    monkeypatch.setenv("JINA_API_KEY", "default_key")
    embedder = JinaEmbedding(BaseEmbedderConfig())

    # Canned API reply returned by the patched ``requests.post``.
    fake_response = Mock()
    fake_response.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
    mock_jina_client.post.return_value = fake_response

    vector = embedder.embed("Test embedding")

    expected_headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer default_key",
    }
    expected_payload = {
        "model": "jina-embeddings-v3",
        "input": [{"text": "Test embedding"}],
    }
    mock_jina_client.post.assert_called_once_with(
        "https://api.jina.ai/v1/embeddings",
        headers=expected_headers,
        json=expected_payload,
    )
    assert vector == [0.1, 0.2, 0.3]


def test_embed_custom_model(mock_jina_client, monkeypatch):
    """An explicitly configured model and dimension override the defaults.

    A model name different from the embedder's default is used on purpose;
    otherwise this test could not tell "custom model honoured" apart from
    "default model applied".
    """
    monkeypatch.setenv("JINA_API_KEY", "test_key")
    config = BaseEmbedderConfig(model="jina-embeddings-v2-base-en", embedding_dims=768)
    embedder = JinaEmbedding(config)

    mock_response = Mock()
    mock_response.json.return_value = {"data": [{"embedding": [0.4, 0.5, 0.6]}]}
    mock_jina_client.post.return_value = mock_response

    result = embedder.embed("Test embedding")

    mock_jina_client.post.assert_called_once_with(
        "https://api.jina.ai/v1/embeddings",
        headers={
            "Content-Type": "application/json",
            "Authorization": "Bearer test_key",
        },
        json={
            "model": "jina-embeddings-v2-base-en",
            "input": [{"text": "Test embedding"}],
        },
    )
    assert result == [0.4, 0.5, 0.6]
    # Explicit settings must survive construction untouched.
    assert embedder.config.model == "jina-embeddings-v2-base-en"
    assert embedder.config.embedding_dims == 768


def test_embed_removes_newlines(mock_jina_client, monkeypatch):
    """Newlines in the input are replaced by spaces before the request is sent."""
    monkeypatch.setenv("JINA_API_KEY", "test_key")
    embedder = JinaEmbedding(BaseEmbedderConfig())

    fake_response = Mock()
    fake_response.json.return_value = {"data": [{"embedding": [0.7, 0.8, 0.9]}]}
    mock_jina_client.post.return_value = fake_response

    vector = embedder.embed("Hello\nworld")

    expected_headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer test_key",
    }
    # The payload must carry the flattened text, not the original.
    expected_payload = {
        "model": "jina-embeddings-v3",
        "input": [{"text": "Hello world"}],
    }
    mock_jina_client.post.assert_called_once_with(
        "https://api.jina.ai/v1/embeddings",
        headers=expected_headers,
        json=expected_payload,
    )
    assert vector == [0.7, 0.8, 0.9]


def test_embed_with_model_kwargs(mock_jina_client, monkeypatch):
    """Entries in ``model_kwargs`` are merged into the request body."""
    monkeypatch.setenv("JINA_API_KEY", "test_key")
    extra_fields = {"dimensions": 512, "normalized": True}
    embedder = JinaEmbedding(BaseEmbedderConfig(model_kwargs=extra_fields))

    fake_response = Mock()
    fake_response.json.return_value = {"data": [{"embedding": [1.0, 1.1, 1.2]}]}
    mock_jina_client.post.return_value = fake_response

    vector = embedder.embed("Test with kwargs")

    expected_headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer test_key",
    }
    # Base payload plus the two extra fields supplied through model_kwargs.
    expected_payload = {
        "model": "jina-embeddings-v3",
        "input": [{"text": "Test with kwargs"}],
        "dimensions": 512,
        "normalized": True,
    }
    mock_jina_client.post.assert_called_once_with(
        "https://api.jina.ai/v1/embeddings",
        headers=expected_headers,
        json=expected_payload,
    )
    assert vector == [1.0, 1.1, 1.2]


def test_embed_without_api_key_env_var(mock_jina_client):
    """A key given in the config is used when no environment variable is set."""
    embedder = JinaEmbedding(BaseEmbedderConfig(api_key="test_key"))

    fake_response = Mock()
    fake_response.json.return_value = {"data": [{"embedding": [1.3, 1.4, 1.5]}]}
    mock_jina_client.post.return_value = fake_response

    vector = embedder.embed("Testing API key")

    # The bearer token must come from the config value.
    expected_headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer test_key",
    }
    expected_payload = {
        "model": "jina-embeddings-v3",
        "input": [{"text": "Testing API key"}],
    }
    mock_jina_client.post.assert_called_once_with(
        "https://api.jina.ai/v1/embeddings",
        headers=expected_headers,
        json=expected_payload,
    )
    assert vector == [1.3, 1.4, 1.5]


def test_embed_uses_environment_api_key(mock_jina_client, monkeypatch):
    """Without a config key, the ``JINA_API_KEY`` environment value is sent."""
    monkeypatch.setenv("JINA_API_KEY", "env_key")
    embedder = JinaEmbedding(BaseEmbedderConfig())

    fake_response = Mock()
    fake_response.json.return_value = {"data": [{"embedding": [1.6, 1.7, 1.8]}]}
    mock_jina_client.post.return_value = fake_response

    vector = embedder.embed("Environment key test")

    # The bearer token must come from the environment variable.
    expected_headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer env_key",
    }
    expected_payload = {
        "model": "jina-embeddings-v3",
        "input": [{"text": "Environment key test"}],
    }
    mock_jina_client.post.assert_called_once_with(
        "https://api.jina.ai/v1/embeddings",
        headers=expected_headers,
        json=expected_payload,
    )
    assert vector == [1.6, 1.7, 1.8]


def test_raises_error_without_api_key(monkeypatch):
    """Construction fails with ``ValueError`` when no API key is available.

    The environment variable is removed explicitly; without that, a
    ``JINA_API_KEY`` exported on the host would satisfy the embedder and the
    expected error would never be raised.
    """
    monkeypatch.delenv("JINA_API_KEY", raising=False)
    config = BaseEmbedderConfig()
    with pytest.raises(ValueError, match="Jina API key is required"):
        JinaEmbedding(config)