diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml
new file mode 100644
index 0000000..ec86d3d
--- /dev/null
+++ b/.github/workflows/mypy.yaml
@@ -0,0 +1,24 @@
+name: Type Checking
+
+on:
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  mypy:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Install uv
+        run: pip install uv
+      - name: Create venv
+        run: uv venv
+      - name: Install dependencies
+        run: |
+          uv sync
+      - name: Run Mypy
+        run: uv run mypy .
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 071432c..ef5c0c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,18 +23,6 @@ dependencies = [
     "numpy>=1,<2",
 ]
 
-[project.optional-dependencies]
-dev = [
-    "ruff>=0.6.9",
-    "pyright>=1.1.384",
-    "pytest>=8.3.3",
-    "langchain>=0.3.3",
-    "langchain-openai>=0.2.2",
-    "langchain-community>=0.3.2",
-    "pandas>=2.2.3",
-    "pytest-asyncio>=0.24.0",
-]
-
 [project.urls]
 repository = "https://github.com/timescale/python-vector"
 documentation = "https://timescale.github.io/python-vector"
@@ -51,36 +39,11 @@ addopts = [
     "--import-mode=importlib",
 ]
 
-[tool.pyright]
-typeCheckingMode = "strict"
-reportImplicitOverride = true
-exclude = [
-    "**/.bzr",
-    "**/.direnv",
-    "**/.eggs",
-    "**/.git",
-    "**/.git-rewrite",
-    "**/.hg",
-    "**/.ipynb_checkpoints",
-    "**/.mypy_cache",
-    "**/.nox",
-    "**/.pants.d",
-    "**/.pyenv",
-    "**/.pytest_cache",
-    "**/.pytype",
-    "**/.ruff_cache",
-    "**/.svn",
-    "**/.tox",
-    "**/.venv",
-    "**/.vscode",
-    "**/__pypackages__",
-    "**/_build",
-    "**/buck-out",
-    "**/dist",
-    "**/node_modules",
-    "**/site-packages",
-    "**/venv",
-]
+
+[tool.mypy]
+strict = true
+ignore_missing_imports = true
+namespace_packages = true
 
 [tool.ruff]
 line-length = 120
@@ -137,4 +100,17 @@ select = [
     "W291",
     "PIE",
     "Q"
-]
\ No newline at end of file
+]
+
+[tool.uv]
+dev-dependencies = [
+    "mypy>=1.12.0",
+    "types-psycopg2>=2.9.21.20240819",
+    "ruff>=0.6.9",
+    "pytest>=8.3.3",
+    "langchain>=0.3.3",
+    "langchain-openai>=0.2.2",
+    "langchain-community>=0.3.2",
+    "pandas>=2.2.3",
+    "pytest-asyncio>=0.24.0",
+]
diff --git a/tests/async_client_test.py b/tests/async_client_test.py
index 98f1085..11939b4 100644
--- a/tests/async_client_test.py
+++ b/tests/async_client_test.py
@@ -306,7 +306,7 @@ async def test_vector(service_url: str, schema: str) -> None:
     assert not await vec.table_is_empty()
 
     # check all the possible ways to specify a date range
-    async def search_date(start_date, end_date, expected):
+    async def search_date(start_date: datetime | str | None, end_date: datetime | str | None, expected: int) -> None:
         # using uuid_time_filter
         rec = await vec.search(
             [1.0, 2.0],
diff --git a/tests/pg_vectorizer_test.py b/tests/pg_vectorizer_test.py
index fc4f5ab..4484f65 100644
--- a/tests/pg_vectorizer_test.py
+++ b/tests/pg_vectorizer_test.py
@@ -1,4 +1,5 @@
 from datetime import timedelta
+from typing import Any
 
 import psycopg2
 import pytest
@@ -11,7 +12,7 @@ from timescale_vector.pgvectorizer import Vectorize
 
 
-def get_document(blog):
+def get_document(blog: dict[str, Any]) -> list[Document]:
     text_splitter = CharacterTextSplitter(
         chunk_size=1000,
         chunk_overlap=200,
@@ -56,7 +57,7 @@ def test_pg_vectorizer(service_url: str) -> None:
             VALUES ('first', 'mat', 'first_post', 'personal', '2021-01-01');
             """)
 
-    def embed_and_write(blog_instances, vectorizer):
+    def embed_and_write(blog_instances: list[Any], vectorizer: Vectorize) -> None:
         TABLE_NAME = 
vectorizer.table_name_unquoted + "_embedding" embedding = OpenAIEmbeddings() vector_store = TimescaleVector( diff --git a/tests/sync_client_test.py b/tests/sync_client_test.py index 19f9f4e..43a0e30 100644 --- a/tests/sync_client_test.py +++ b/tests/sync_client_test.py @@ -136,15 +136,15 @@ def test_sync_client(service_url: str, schema: str) -> None: rec = vec.search([1.0, 2.0], filter={"key_1": "val_1", "key_2": "val_2"}) assert rec[0][SEARCH_RESULT_CONTENTS_IDX] == "the brown fox" - assert rec[0]["contents"] == "the brown fox" + assert rec[0]["contents"] == "the brown fox" # type: ignore assert rec[0][SEARCH_RESULT_METADATA_IDX] == { "key_1": "val_1", "key_2": "val_2", } - assert rec[0]["metadata"] == {"key_1": "val_1", "key_2": "val_2"} + assert rec[0]["metadata"] == {"key_1": "val_1", "key_2": "val_2"} # type: ignore assert isinstance(rec[0][SEARCH_RESULT_METADATA_IDX], dict) assert rec[0][SEARCH_RESULT_DISTANCE_IDX] == 0.0009438353921149556 - assert rec[0]["distance"] == 0.0009438353921149556 + assert rec[0]["distance"] == 0.0009438353921149556 # type: ignore rec = vec.search([1.0, 2.0], limit=4, predicates=Predicates("key", "==", "val2")) assert len(rec) == 1 @@ -218,7 +218,7 @@ def test_sync_client(service_url: str, schema: str) -> None: ] ) - def search_date(start_date, end_date, expected): + def search_date(start_date: datetime | str | None, end_date: datetime | str | None, expected: int) -> None: # using uuid_time_filter rec = vec.search( [1.0, 2.0], diff --git a/timescale_vector/client.py b/timescale_vector/client.py deleted file mode 100644 index 6e1f8e2..0000000 --- a/timescale_vector/client.py +++ /dev/null @@ -1,1503 +0,0 @@ -__all__ = [ - "SEARCH_RESULT_ID_IDX", - "SEARCH_RESULT_METADATA_IDX", - "SEARCH_RESULT_CONTENTS_IDX", - "SEARCH_RESULT_EMBEDDING_IDX", - "SEARCH_RESULT_DISTANCE_IDX", - "uuid_from_time", - "BaseIndex", - "IvfflatIndex", - "HNSWIndex", - "DiskAnnIndex", - "QueryParams", - "DiskAnnIndexParams", - "IvfflatIndexParams", - "HNSWIndexParams", - "UUIDTimeRange", - "Predicates", - "QueryBuilder", - "Async", - "Sync", -] - -import calendar -import json -import math -import random -import re -import uuid -from collections.abc import Callable, Iterable -from contextlib import contextmanager -from datetime import datetime, timedelta, timezone -from typing import Any, Literal, Union - -import asyncpg -import numpy as np -import pgvector.psycopg2 -import psycopg2.extras -import psycopg2.pool -from pgvector.asyncpg import register_vector - - -# copied from Cassandra: https://docs.datastax.com/en/drivers/python/3.2/_modules/cassandra/util.html#uuid_from_time -def uuid_from_time( - time_arg: float | datetime | None = None, node: Any = None, clock_seq: int | None = None -) -> uuid.UUID: - """ - Converts a datetime or timestamp to a type 1 `uuid.UUID`. - - Parameters - ---------- - time_arg - The time to use for the timestamp portion of the UUID. - This can either be a `datetime` object or a timestamp in seconds - (as returned from `time.time()`). - node - Bytes for the UUID (up to 48 bits). If not specified, this - field is randomized. - clock_seq - Clock sequence field for the UUID (up to 14 bits). If not specified, - a random sequence is generated. 
- - Returns - ------- - uuid.UUID: For the given time, node, and clock sequence - """ - if time_arg is None: - return uuid.uuid1(node, clock_seq) - if hasattr(time_arg, "utctimetuple"): - # this is different from the Cassandra version, - # we assume that a naive datetime is in system time and convert it to UTC - # we do this because naive datetimes are interpreted as timestamps (without timezone) in postgres - time_arg_dt: datetime = time_arg - if time_arg_dt.tzinfo is None: - time_arg_dt = time_arg_dt.astimezone(timezone.utc) - seconds = int(calendar.timegm(time_arg_dt.utctimetuple())) - microseconds = (seconds * 1e6) + time_arg_dt.time().microsecond - else: - microseconds = int(float(time_arg) * 1e6) - - # 0x01b21dd213814000 is the number of 100-ns intervals between the - # UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00. - intervals = int(microseconds * 10) + 0x01B21DD213814000 - - time_low = intervals & 0xFFFFFFFF - time_mid = (intervals >> 32) & 0xFFFF - time_hi_version = (intervals >> 48) & 0x0FFF - - if clock_seq is None: - clock_seq = random.getrandbits(14) - else: - if clock_seq > 0x3FFF: - raise ValueError("clock_seq is out of range (need a 14-bit value)") - - clock_seq_low = clock_seq & 0xFF - clock_seq_hi_variant = 0x80 | ((clock_seq >> 8) & 0x3F) - - if node is None: - node = random.getrandbits(48) - - return uuid.UUID( - fields=( - time_low, - time_mid, - time_hi_version, - clock_seq_hi_variant, - clock_seq_low, - node, - ), - version=1, - ) - - -class BaseIndex: - def get_index_method(self, distance_type: str) -> str: - index_method = "invalid" - if distance_type == "<->": - index_method = "vector_l2_ops" - elif distance_type == "<#>": - index_method = "vector_ip_ops" - elif distance_type == "<=>": - index_method = "vector_cosine_ops" - else: - raise ValueError(f"Unknown distance type {distance_type}") - return index_method - - def create_index_query( - self, - table_name_quoted: str, - column_name_quoted: str, - index_name_quoted: str, - distance_type: str, - num_records_callback: Callable[[], int], - ) -> str: - raise NotImplementedError() - - -class IvfflatIndex(BaseIndex): - def __init__(self, num_records: int | None = None, num_lists: int | None = None) -> None: - """ - Pgvector's ivfflat index. - """ - self.num_records: int | None = num_records - self.num_lists: int | None = num_lists - - def get_num_records(self, num_record_callback: Callable[[], int]) -> int: - if self.num_records is not None: - return self.num_records - return num_record_callback() - - def get_num_lists(self, num_records_callback: Callable[[], int]) -> int: - if self.num_lists is not None: - return self.num_lists - - num_records = self.get_num_records(num_records_callback) - num_lists = num_records / 1000 - if num_lists < 10: - num_lists = 10 - if num_records > 1000000: - num_lists = math.sqrt(num_records) - return int(num_lists) - - def create_index_query( - self, - table_name_quoted: str, - column_name_quoted: str, - index_name_quoted: str, - distance_type: str, - num_records_callback: Callable[[], int], - ) -> str: - index_method = self.get_index_method(distance_type) - num_lists = self.get_num_lists(num_records_callback) - - return ( - f"CREATE INDEX {index_name_quoted} ON {table_name_quoted}" - f"USING ivfflat ({column_name_quoted} {index_method}) WITH (lists = {num_lists});" - ) - - -class HNSWIndex(BaseIndex): - def __init__(self, m: int | None = None, ef_construction: int | None = None) -> None: - """ - Pgvector's hnsw index. 
- """ - self.m: int | None = m - self.ef_construction: int | None = ef_construction - - def create_index_query( - self, - table_name_quoted: str, - column_name_quoted: str, - index_name_quoted: str, - distance_type: str, - _num_records_callback: Callable[[], int], - ) -> str: - index_method = self.get_index_method(distance_type) - - with_clauses: list[str] = [] - if self.m is not None: - with_clauses.append(f"m = {self.m}") - if self.ef_construction is not None: - with_clauses.append(f"ef_construction = {self.ef_construction}") - - with_clause = "" - if len(with_clauses) > 0: - with_clause = "WITH (" + ", ".join(with_clauses) + ")" - - return ( - f"CREATE INDEX {index_name_quoted} ON {table_name_quoted}" - f"USING hnsw ({column_name_quoted} {index_method}) {with_clause};" - ) - - -class DiskAnnIndex(BaseIndex): - def __init__( - self, - search_list_size: int | None = None, - num_neighbors: int | None = None, - max_alpha: float | None = None, - storage_layout: str | None = None, - num_dimensions: int | None = None, - num_bits_per_dimension: int | None = None, - ) -> None: - """ - Timescale's vector index. - """ - self.search_list_size: int | None = search_list_size - self.num_neighbors: int | None = num_neighbors - self.max_alpha: float | None = max_alpha - self.storage_layout: str | None = storage_layout - self.num_dimensions: int | None = num_dimensions - self.num_bits_per_dimension: int | None = num_bits_per_dimension - - def create_index_query( - self, - table_name_quoted: str, - column_name_quoted: str, - index_name_quoted: str, - distance_type: str, - _num_records_callback: Callable[[], int], - ) -> str: - if distance_type != "<=>": - raise ValueError( - f"Timescale's vector index only supports cosine distance, but distance_type was {distance_type}" - ) - - with_clauses: list[str] = [] - if self.search_list_size is not None: - with_clauses.append(f"search_list_size = {self.search_list_size}") - if self.num_neighbors is not None: - with_clauses.append(f"num_neighbors = {self.num_neighbors}") - if self.max_alpha is not None: - with_clauses.append(f"max_alpha = {self.max_alpha}") - if self.storage_layout is not None: - with_clauses.append(f"storage_layout = {self.storage_layout}") - if self.num_dimensions is not None: - with_clauses.append(f"num_dimensions = {self.num_dimensions}") - if self.num_bits_per_dimension is not None: - with_clauses.append(f"num_bits_per_dimension = {self.num_bits_per_dimension}") - - with_clause = "" - if len(with_clauses) > 0: - with_clause = "WITH (" + ", ".join(with_clauses) + ")" - - return ( - f"CREATE INDEX {index_name_quoted} ON {table_name_quoted}" - f"USING diskann ({column_name_quoted}) {with_clause};" - ) - - -class QueryParams: - def __init__(self, params: dict[str, Any]) -> None: - self.params: dict[str, Any] = params - - def get_statements(self) -> list[str]: - return ["SET LOCAL " + key + " = " + str(value) for key, value in self.params.items()] - - -class DiskAnnIndexParams(QueryParams): - def __init__(self, search_list_size: int | None = None, rescore: int | None = None) -> None: - params: dict[str, Any] = {} - if search_list_size is not None: - params["diskann.query_search_list_size"] = search_list_size - if rescore is not None: - params["diskann.query_rescore"] = rescore - super().__init__(params) - - -class IvfflatIndexParams(QueryParams): - def __init__(self, probes: int) -> None: - super().__init__({"ivfflat.probes": probes}) - - -class HNSWIndexParams(QueryParams): - def __init__(self, ef_search: int) -> None: - 
super().__init__({"hnsw.ef_search": ef_search}) - - -SEARCH_RESULT_ID_IDX = 0 -SEARCH_RESULT_METADATA_IDX = 1 -SEARCH_RESULT_CONTENTS_IDX = 2 -SEARCH_RESULT_EMBEDDING_IDX = 3 -SEARCH_RESULT_DISTANCE_IDX = 4 - - -class UUIDTimeRange: - @staticmethod - def _parse_datetime(input_datetime: datetime | str | None) -> datetime | None: - """ - Parse a datetime object or string representation of a datetime. - - Args: - input_datetime (datetime or str): Input datetime or string. - - Returns: - datetime: Parsed datetime object. - - Raises: - ValueError: If the input cannot be parsed as a datetime. - """ - if input_datetime is None or input_datetime == "None": - return None - - if isinstance(input_datetime, datetime): - # If input is already a datetime object, return it as is - return input_datetime - - if isinstance(input_datetime, str): - try: - # Attempt to parse the input string into a datetime - return datetime.fromisoformat(input_datetime) - except ValueError: - raise ValueError(f"Invalid datetime string format: {input_datetime}") from None - - raise ValueError("Input must be a datetime object or string") - - def __init__( - self, - start_date: datetime | str | None = None, - end_date: datetime | str | None = None, - time_delta: timedelta | None = None, - start_inclusive: bool = True, - end_inclusive: bool = False, - ): - """ - A UUIDTimeRange is a time range predicate on the UUID Version 1 timestamps. - - Note that naive datetime objects are interpreted as local time on the python client side - and converted to UTC before being sent to the database. - """ - start_date = UUIDTimeRange._parse_datetime(start_date) - end_date = UUIDTimeRange._parse_datetime(end_date) - - if start_date is not None and end_date is not None and start_date > end_date: - raise Exception("start_date must be before end_date") - - if start_date is None and end_date is None: - raise Exception("start_date and end_date cannot both be None") - - if start_date is not None and start_date.tzinfo is None: - start_date = start_date.astimezone(timezone.utc) - - if end_date is not None and end_date.tzinfo is None: - end_date = end_date.astimezone(timezone.utc) - - if time_delta is not None: - if end_date is None: - end_date = start_date + time_delta - elif start_date is None: - start_date = end_date - time_delta - else: - raise Exception("time_delta, start_date and end_date cannot all be specified at the same time") - - self.start_date: datetime | None = start_date - self.end_date: datetime | None = end_date - self.start_inclusive: bool = start_inclusive - self.end_inclusive: bool = end_inclusive - - def __str__(self) -> str: - start_str = f"[{self.start_date}" if self.start_inclusive else f"({self.start_date}" - end_str = f"{self.end_date}]" if self.end_inclusive else f"{self.end_date})" - - return f"UUIDTimeRange {start_str}, {end_str}" - - def build_query(self, params: list[Any]) -> tuple[str, list[Any]]: - column = "uuid_timestamp(id)" - queries: list[str] = [] - if self.start_date is not None: - if self.start_inclusive: - queries.append(f"{column} >= ${len(params)+1}") - else: - queries.append(f"{column} > ${len(params)+1}") - params.append(self.start_date) - if self.end_date is not None: - if self.end_inclusive: - queries.append(f"{column} <= ${len(params)+1}") - else: - queries.append(f"{column} < ${len(params)+1}") - params.append(self.end_date) - return " AND ".join(queries), params - - -class Predicates: - logical_operators: dict[str, str] = { - "AND": "AND", - "OR": "OR", - "NOT": "NOT", - } - - operators_mapping: 
dict[str, str] = { - "=": "=", - "==": "=", - ">=": ">=", - ">": ">", - "<=": "<=", - "<": "<", - "!=": "<>", - "@>": "@>", # array contains - } - - PredicateValue = str | int | float | datetime | list[Any] | tuple[Any, ...] - - def __init__( - self, - *clauses: Union[ - "Predicates", - tuple[str, PredicateValue], - tuple[str, str, PredicateValue], - str, - PredicateValue, - ], - operator: Literal["AND", "OR", "NOT"] = "AND", - ): - """ - Predicates class defines predicates on the object metadata. - Predicates can be combined using logical operators (&, |, and ~). - - Parameters - ---------- - clauses - Predicate clauses. Can be either another Predicates object - or a tuple of the form (field, operator, value) or (field, value). - Operator - Logical operator to use when combining the clauses. - Can be one of 'AND', 'OR', 'NOT'. Defaults to 'AND'. - """ - if operator not in self.logical_operators: - raise ValueError(f"invalid operator: {operator}") - self.operator: str = operator - if isinstance(clauses[0], str): - if len(clauses) != 3 or not (isinstance(clauses[1], str) and isinstance(clauses[2], self.PredicateValue)): - raise ValueError(f"Invalid clause format: {clauses}") - self.clauses: list[ - Predicates - | tuple[str, Predicates.PredicateValue] - | tuple[str, str, Predicates.PredicateValue] - | str - | Predicates.PredicateValue - ] = [clauses] - else: - self.clauses = list(clauses) - - def add_clause( - self, - *clause: Union[ - "Predicates", - tuple[str, PredicateValue], - tuple[str, str, PredicateValue], - str, - PredicateValue, - ], - ) -> None: - """ - Add a clause to the predicates object. - - Parameters - ---------- - clause: 'Predicates' or Tuple[str, str] or Tuple[str, str, str] - Predicate clause. Can be either another Predicates object or a tuple of the form (field, operator, value) - or (field, value). - """ - if isinstance(clause[0], str): - if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)): - raise ValueError(f"Invalid clause format: {clause}") - self.clauses.append(clause) - else: - self.clauses.extend(list(clause)) - - def __and__(self, other: "Predicates") -> "Predicates": - new_predicates = Predicates(self, other, operator="AND") - return new_predicates - - def __or__(self, other: "Predicates") -> "Predicates": - new_predicates = Predicates(self, other, operator="OR") - return new_predicates - - def __invert__(self) -> "Predicates": - new_predicates = Predicates(self, operator="NOT") - return new_predicates - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Predicates): - return False - - return self.operator == other.operator and self.clauses == other.clauses - - def __repr__(self) -> str: - if self.operator: - return f"{self.operator}({', '.join(repr(clause) for clause in self.clauses)})" - else: - return repr(self.clauses) - - def build_query(self, params: list[Any]) -> tuple[str, list[Any]]: - """ - Build the SQL query string and parameters for the predicates object. 
- """ - if not self.clauses: - return "", [] - - where_conditions: list[str] = [] - - for clause in self.clauses: - if isinstance(clause, Predicates): - child_where_clause, params = clause.build_query(params) - where_conditions.append(f"({child_where_clause})") - elif isinstance(clause, tuple): - if len(clause) == 2: - field, value = clause - operator = "=" # Default operator - elif len(clause) == 3: - field, operator, value = clause - if operator not in self.operators_mapping: - raise ValueError(f"Invalid operator: {operator}") - operator = self.operators_mapping[operator] - else: - raise ValueError("Invalid clause format") - - index = len(params) + 1 - param_name = f"${index}" - - if field == "__uuid_timestamp": - # convert str to timestamp in the database, it's better at it than python - if isinstance(value, str): - where_conditions.append(f"uuid_timestamp(id) {operator} ({param_name}::text)::timestamptz") - else: - where_conditions.append(f"uuid_timestamp(id) {operator} {param_name}") - params.append(value) - - elif operator == "@>" and isinstance(value, list | tuple): - if len(value) == 0: - raise ValueError("Invalid value. Empty lists and empty tuples are not supported.") - json_value = json.dumps(value) - where_conditions.append(f"metadata @> jsonb_build_object('{field}', {param_name}::jsonb)") - params.append(json_value) - - else: - field_cast = "" - if isinstance(value, int): - field_cast = "::int" - elif isinstance(value, float): - field_cast = "::numeric" - elif isinstance(value, datetime): - field_cast = "::timestamptz" - where_conditions.append(f"(metadata->>'{field}'){field_cast} {operator} {param_name}") - params.append(value) - - if self.operator == "NOT": - or_clauses = " OR ".join(where_conditions) - # use IS DISTINCT FROM to treat all-null clauses as False and pass the filter - where_clause = f"TRUE IS DISTINCT FROM ({or_clauses})" - else: - where_clause = (" " + self.operator + " ").join(where_conditions) - return where_clause, params - - -class QueryBuilder: - def __init__( - self, - table_name: str, - num_dimensions: int, - distance_type: str, - id_type: str, - time_partition_interval: timedelta | None, - infer_filters: bool, - schema_name: str | None, - ) -> None: - """ - Initializes a base Vector object to generate queries for vector clients. - - Parameters - ---------- - table_name - The name of the table. - num_dimensions - The number of dimensions for the embedding vector. - distance_type - The distance type for indexing. - id_type - The type of the id column. Can be either 'UUID' or 'TEXT'. - time_partition_interval - The time interval for partitioning the table (optional). - infer_filters - Whether to infer start and end times from the special __start_date and __end_date filters. - schema_name - The schema name for the table (optional, uses the database's default schema if not specified). 
- """ - self.table_name: str = table_name - self.schema_name: str | None = schema_name - self.num_dimensions: int = num_dimensions - if distance_type == "cosine" or distance_type == "<=>": - self.distance_type: str = "<=>" - elif distance_type == "euclidean" or distance_type == "<->" or distance_type == "l2": - self.distance_type = "<->" - else: - raise ValueError(f"unrecognized distance_type {distance_type}") - - if id_type.lower() != "uuid" and id_type.lower() != "text": - raise ValueError(f"unrecognized id_type {id_type}") - - if time_partition_interval is not None and id_type.lower() != "uuid": - raise ValueError("time partitioning is only supported for uuid id_type") - - self.id_type: str = id_type.lower() - self.time_partition_interval: timedelta | None = time_partition_interval - self.infer_filters: bool = infer_filters - - @staticmethod - def _quote_ident(ident: str) -> str: - """ - Quotes an identifier to prevent SQL injection. - - Parameters - ---------- - ident - The identifier to be quoted. - - Returns - ------- - str: The quoted identifier. - """ - return '"{}"'.format(ident.replace('"', '""')) - - def _quoted_table_name(self) -> str: - if self.schema_name is not None: - return self._quote_ident(self.schema_name) + "." + self._quote_ident(self.table_name) - else: - return self._quote_ident(self.table_name) - - def get_row_exists_query(self) -> str: - """ - Generates a query to check if any rows exist in the table. - - Returns - ------- - str: The query to check for row existence. - """ - return f"SELECT 1 FROM {self._quoted_table_name()} LIMIT 1" - - def get_upsert_query(self) -> str: - """ - Generates an upsert query. - - Returns - ------- - str: The upsert query. - """ - return ( - f"INSERT INTO {self._quoted_table_name()} (id, metadata, contents, embedding) " - f"VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING" - ) - - def get_approx_count_query(self) -> str: - """ - Generate a query to find the approximate count of records in the table. - - Returns - ------- - str: the query. - """ - # todo optimize with approx - return f"SELECT COUNT(*) as cnt FROM {self._quoted_table_name()}" - - def get_create_query(self) -> str: - """ - Generates a query to create the tables, indexes, and extensions needed to store the vector data. - - Returns - ------- - str: The create table query. 
- """ - hypertable_sql = "" - if self.time_partition_interval is not None: - hypertable_sql = f""" - CREATE EXTENSION IF NOT EXISTS timescaledb; - - CREATE OR REPLACE FUNCTION public.uuid_timestamp(uuid UUID) RETURNS TIMESTAMPTZ AS $$ - DECLARE - bytes bytea; - BEGIN - bytes := uuid_send(uuid); - if (get_byte(bytes, 6) >> 4)::int2 != 1 then - RAISE EXCEPTION 'UUID version is not 1'; - end if; - RETURN to_timestamp( - ( - ( - (get_byte(bytes, 0)::bigint << 24) | - (get_byte(bytes, 1)::bigint << 16) | - (get_byte(bytes, 2)::bigint << 8) | - (get_byte(bytes, 3)::bigint << 0) - ) + ( - ((get_byte(bytes, 4)::bigint << 8 | - get_byte(bytes, 5)::bigint)) << 32 - ) + ( - (((get_byte(bytes, 6)::bigint & 15) << 8 | get_byte(bytes, 7)::bigint) & 4095) << 48 - ) - 122192928000000000 - ) / 10000 / 1000::double precision - ); - END - $$ LANGUAGE plpgsql - IMMUTABLE PARALLEL SAFE - RETURNS NULL ON NULL INPUT; - - SELECT create_hypertable('{self._quoted_table_name()}', - 'id', - if_not_exists=> true, - time_partitioning_func=>'public.uuid_timestamp', - chunk_time_interval => '{str(self.time_partition_interval.total_seconds())} seconds'::interval); - """ - return f""" -CREATE EXTENSION IF NOT EXISTS vector; -CREATE EXTENSION IF NOT EXISTS vectorscale; - - -CREATE TABLE IF NOT EXISTS {self._quoted_table_name()} ( - id {self.id_type} PRIMARY KEY, - metadata JSONB, - contents TEXT, - embedding VECTOR({self.num_dimensions}) -); - -CREATE INDEX IF NOT EXISTS {self._quote_ident(self.table_name + "_meta_idx")} ON {self._quoted_table_name()} -USING GIN(metadata jsonb_path_ops); - -{hypertable_sql} -""" - - def _get_embedding_index_name_quoted(self) -> str: - return self._quote_ident(self.table_name + "_embedding_idx") - - def _get_schema_qualified_embedding_index_name_quoted(self) -> str: - if self.schema_name is not None: - return self._quote_ident(self.schema_name) + "." + self._get_embedding_index_name_quoted() - else: - return self._get_embedding_index_name_quoted() - - def drop_embedding_index_query(self) -> str: - return f"DROP INDEX IF EXISTS {self._get_schema_qualified_embedding_index_name_quoted()};" - - def delete_all_query(self) -> str: - return f"TRUNCATE {self._quoted_table_name()};" - - def delete_by_ids_query(self, ids: list[uuid.UUID] | list[str]) -> tuple[str, list[Any]]: - query = f"DELETE FROM {self._quoted_table_name()} WHERE id = ANY($1::{self.id_type}[]);" - return (query, [ids]) - - def delete_by_metadata_query(self, filter: dict[str, str] | list[dict[str, str]]) -> tuple[str, list[Any]]: - params: list[Any] = [] - (where, params) = self._where_clause_for_filter(params, filter) - query = f"DELETE FROM {self._quoted_table_name()} WHERE {where};" - return (query, params) - - def drop_table_query(self) -> str: - return f"DROP TABLE IF EXISTS {self._quoted_table_name()};" - - def default_max_db_connection_query(self) -> str: - """ - Generates a query to get the default max db connections. This uses a heuristic to determine the max connections - based on the max_connections setting in postgres - and the number of currently used connections. This heuristic leaves 4 connections in reserve. - """ - return ( - "SELECT greatest(1, ((SELECT setting::int FROM pg_settings " - "WHERE name='max_connections')-(SELECT count(*) FROM pg_stat_activity) - 4)::int)" - ) - - def create_embedding_index_query(self, index: BaseIndex, num_records_callback: Callable[[], int]) -> str: - """ - Generates an embedding index creation query. - - Parameters - ---------- - index - The index to create. 
- num_records_callback - A callback function to get the number of records in the table. - - Returns - ------- - str: The index creation query. - """ - column_name = "embedding" - index_name_quoted = self._get_embedding_index_name_quoted() - query = index.create_index_query( - self._quoted_table_name(), - self._quote_ident(column_name), - index_name_quoted, - self.distance_type, - num_records_callback, - ) - return query - - def _where_clause_for_filter( - self, params: list[Any], filter: dict[str, str] | list[dict[str, str]] | None - ) -> tuple[str, list[Any]]: - if filter is None: - return "TRUE", params - - if isinstance(filter, dict): - where = f"metadata @> ${len(params)+1}" - json_object = json.dumps(filter) - params = params + [json_object] - elif isinstance(filter, list): - any_params = [] - for _idx, filter_dict in enumerate(filter, start=len(params) + 1): - any_params.append(json.dumps(filter_dict)) - where = f"metadata @> ANY(${len(params) + 1}::jsonb[])" - params = params + [any_params] - else: - raise ValueError(f"Unknown filter type: {type(filter)}") - - return where, params - - def search_query( - self, - query_embedding: list[float] | np.ndarray | None, - limit: int = 10, - filter: dict[str, str] | list[dict[str, str]] | None = None, - predicates: Predicates | None = None, - uuid_time_filter: UUIDTimeRange | None = None, - ) -> tuple[str, list[Any]]: - """ - Generates a similarity query. - - Returns: - Tuple[str, List]: A tuple containing the query and parameters. - """ - params: list[Any] = [] - if query_embedding is not None: - distance = f"embedding {self.distance_type} ${len(params)+1}" - params = params + [query_embedding] - order_by_clause = f"ORDER BY {distance} ASC" - else: - distance = "-1.0" - order_by_clause = "" - - if ( - self.infer_filters - and uuid_time_filter is None - and isinstance(filter, dict) - and ("__start_date" in filter or "__end_date" in filter) - ): - start_date = UUIDTimeRange._parse_datetime(filter.get("__start_date")) - end_date = UUIDTimeRange._parse_datetime(filter.get("__end_date")) - - uuid_time_filter = UUIDTimeRange(start_date, end_date) - - if start_date is not None: - del filter["__start_date"] - if end_date is not None: - del filter["__end_date"] - - where_clauses = [] - if filter is not None: - (where_filter, params) = self._where_clause_for_filter(params, filter) - where_clauses.append(where_filter) - - if predicates is not None: - (where_predicates, params) = predicates.build_query(params) - where_clauses.append(where_predicates) - - if uuid_time_filter is not None: - (where_time, params) = uuid_time_filter.build_query(params) - where_clauses.append(where_time) - - where = " AND ".join(where_clauses) if len(where_clauses) > 0 else "TRUE" - - query = f""" - SELECT - id, metadata, contents, embedding, {distance} as distance - FROM - {self._quoted_table_name()} - WHERE - {where} - {order_by_clause} - LIMIT {limit} - """ - return query, params - - -class Async(QueryBuilder): - def __init__( - self, - service_url: str, - table_name: str, - num_dimensions: int, - distance_type: str = "cosine", - id_type: Literal["UUID"] | Literal["TEXT"] = "UUID", - time_partition_interval: timedelta | None = None, - max_db_connections: int | None = None, - infer_filters: bool = True, - schema_name: str | None = None, - ) -> None: - """ - Initializes a async client for storing vector data. - - Parameters - ---------- - service_url - The connection string for the database. - table_name - The name of the table. 
- num_dimensions - The number of dimensions for the embedding vector. - distance_type - The distance type for indexing. - id_type - The type of the id column. Can be either 'UUID' or 'TEXT'. - time_partition_interval - The time interval for partitioning the table (optional). - infer_filters - Whether to infer start and end times from the special __start_date and __end_date filters. - schema_name - The schema name for the table (optional, uses the database's default schema if not specified). - """ - self.builder = QueryBuilder( - table_name, - num_dimensions, - distance_type, - id_type, - time_partition_interval, - infer_filters, - schema_name, - ) - self.service_url: str = service_url - self.pool: asyncpg.Pool | None = None - self.max_db_connections: int | None = max_db_connections - self.time_partition_interval: timedelta | None = time_partition_interval - - async def _default_max_db_connections(self) -> int: - """ - Gets a default value for the number of max db connections to use. - - Returns - ------- - None - """ - query = self.builder.default_max_db_connection_query() - conn = await asyncpg.connect(dsn=self.service_url) - num_connections = await conn.fetchval(query) - await conn.close() - return num_connections - - async def connect(self) -> asyncpg.pool.PoolAcquireContext: - """ - Establishes a connection to a PostgreSQL database using asyncpg. - - Returns - ------- - asyncpg.Connection: The established database connection. - """ - if self.pool is None: - if self.max_db_connections is None: - self.max_db_connections = await self._default_max_db_connections() - - async def init(conn: asyncpg.Connection) -> None: - await register_vector(conn) - # decode to a dict, but accept a string as input in upsert - await conn.set_type_codec("jsonb", encoder=str, decoder=json.loads, schema="pg_catalog") - - self.pool = await asyncpg.create_pool( - dsn=self.service_url, - init=init, - min_size=1, - max_size=self.max_db_connections, - ) - return self.pool.acquire() - - async def close(self) -> None: - if self.pool is not None: - await self.pool.close() - - async def table_is_empty(self) -> bool: - """ - Checks if the table is empty. - - Returns - ------- - bool: True if the table is empty, False otherwise. - """ - query = self.builder.get_row_exists_query() - async with await self.connect() as pool: - rec = await pool.fetchrow(query) - return rec is None - - def munge_record(self, records: list[tuple[Any, ...]]) -> Iterable[tuple[uuid.UUID, str, str, list[float]]]: - metadata_is_dict = isinstance(records[0][1], dict) - if metadata_is_dict: - records = map(lambda item: Async._convert_record_meta_to_json(item), records) - - return records - - @staticmethod - def _convert_record_meta_to_json(item: tuple[Any, ...]) -> tuple[uuid.UUID, str, str, list[float]]: - if not isinstance(item[1], dict): - raise ValueError("Cannot mix dictionary and string metadata fields in the same upsert") - return item[0], json.dumps(item[1]), item[2], item[3] - - async def upsert(self, records: list[tuple[Any, ...]]) -> None: - """ - Performs upsert operation for multiple records. - - Parameters - ---------- - records - List of records to upsert. Each record is a tuple of the form (id, metadata, contents, embedding). - - Returns - ------- - None - """ - records = self.munge_record(records) - query = self.builder.get_upsert_query() - async with await self.connect() as pool: - await pool.executemany(query, records) - - async def create_tables(self) -> None: - """ - Creates necessary tables. 
- - Returns - ------- - None - """ - query = self.builder.get_create_query() - # don't use a connection pool for this because the vector extension may not be installed yet - # and if it's not installed, register_vector will fail. - conn = await asyncpg.connect(dsn=self.service_url) - await conn.execute(query) - await conn.close() - - async def delete_all(self, drop_index: bool = True) -> None: - """ - Deletes all data. Also drops the index if `drop_index` is true. - - Returns - ------- - None - """ - if drop_index: - await self.drop_embedding_index() - query = self.builder.delete_all_query() - async with await self.connect() as pool: - await pool.execute(query) - - async def delete_by_ids(self, ids: list[uuid.UUID] | list[str]) -> list[asyncpg.Record]: - """ - Delete records by id. - """ - (query, params) = self.builder.delete_by_ids_query(ids) - async with await self.connect() as pool: - return await pool.fetch(query, *params) - - async def delete_by_metadata(self, filter: dict[str, str] | list[dict[str, str]]) -> list[asyncpg.Record]: - """ - Delete records by metadata filters. - """ - (query, params) = self.builder.delete_by_metadata_query(filter) - async with await self.connect() as pool: - return await pool.fetch(query, *params) - - async def drop_table(self) -> None: - """ - Drops the table - - Returns - ------- - None - """ - query = self.builder.drop_table_query() - async with await self.connect() as pool: - await pool.execute(query) - - async def _get_approx_count(self) -> int: - """ - Retrieves an approximate count of records in the table. - - Returns - ------- - int: Approximate count of records. - """ - query = self.builder.get_approx_count_query() - async with await self.connect() as pool: - rec = await pool.fetchrow(query) - return rec[0] - - async def drop_embedding_index(self) -> None: - """ - Drop any index on the emedding - - Returns - ------- - None - """ - query = self.builder.drop_embedding_index_query() - async with await self.connect() as pool: - await pool.execute(query) - - async def create_embedding_index(self, index: BaseIndex) -> None: - """ - Creates an index for the table. - - Parameters - ---------- - index - The index to create. - - Returns - ------- - None - """ - # todo: can we make geting the records lazy? - num_records = await self._get_approx_count() - query = self.builder.create_embedding_index_query(index, lambda: num_records) - - async with await self.connect() as pool: - await pool.execute(query) - - async def search( - self, - query_embedding: list[float] | None = None, - limit: int = 10, - filter: dict[str, str] | list[dict[str, str]] | None = None, - predicates: Predicates | None = None, - uuid_time_filter: UUIDTimeRange | None = None, - query_params: QueryParams | None = None, - ) -> list[asyncpg.Record]: - """ - Retrieves similar records using a similarity query. - - Parameters - ---------- - query_embedding - The query embedding vector. - limit - The number of nearest neighbors to retrieve. - filter - A filter for metadata. Should be specified as a key-value object or a list of key-value objects - (where any objects in the list are matched). - predicates - A Predicates object to filter the results. Predicates support more complex queries than the filter - parameter. Predicates can be combined using logical operators (&, |, and ~). - uuid_time_filter - A UUIDTimeRange object to filter the results by time using the id column. - query_params - - Returns - ------- - List: List of similar records. 
- """ - (query, params) = self.builder.search_query(query_embedding, limit, filter, predicates, uuid_time_filter) - if query_params is not None: - async with await self.connect() as pool, pool.transaction(): - # Looks like there is no way to pipeline this: https://github.com/MagicStack/asyncpg/issues/588 - statements = query_params.get_statements() - for statement in statements: - await pool.execute(statement) - return await pool.fetch(query, *params) - else: - async with await self.connect() as pool: - return await pool.fetch(query, *params) - - -class Sync: - translated_queries: dict[str, str] = {} - - def __init__( - self, - service_url: str, - table_name: str, - num_dimensions: int, - distance_type: str = "cosine", - id_type: Literal["UUID"] | Literal["TEXT"] = "UUID", - time_partition_interval: timedelta | None = None, - max_db_connections: int | None = None, - infer_filters: bool = True, - schema_name: str | None = None, - ) -> None: - """ - Initializes a sync client for storing vector data. - - Parameters - ---------- - service_url - The connection string for the database. - table_name - The name of the table. - num_dimensions - The number of dimensions for the embedding vector. - distance_type - The distance type for indexing. - id_type - The type of the primary id column. Can be either 'UUID' or 'TEXT'. - time_partition_interval - The time interval for partitioning the table (optional). - infer_filters - Whether to infer start and end times from the special __start_date and __end_date filters. - schema_name - The schema name for the table (optional, uses the database's default schema if not specified). - """ - self.builder = QueryBuilder( - table_name, - num_dimensions, - distance_type, - id_type, - time_partition_interval, - infer_filters, - schema_name, - ) - self.service_url: str = service_url - self.pool: psycopg2.pool.SimpleConnectionPool | None = None - self.max_db_connections: int | None = max_db_connections - self.time_partition_interval: timedelta | None = time_partition_interval - psycopg2.extras.register_uuid() - - def default_max_db_connections(self) -> int: - """ - Gets a default value for the number of max db connections to use. - """ - query = self.builder.default_max_db_connection_query() - conn = psycopg2.connect(dsn=self.service_url) - with conn.cursor() as cur: - cur.execute(query) - num_connections = cur.fetchone() - conn.close() - return num_connections[0] - - @contextmanager - def connect(self) -> Iterable[psycopg2.extensions.connection]: - """ - Establishes a connection to a PostgreSQL database using psycopg2 and allows it's - use in a context manager. - """ - if self.pool is None: - if self.max_db_connections is None: - self.max_db_connections = self.default_max_db_connections() - - self.pool = psycopg2.pool.SimpleConnectionPool( - 1, - self.max_db_connections, - dsn=self.service_url, - cursor_factory=psycopg2.extras.DictCursor, - ) - - connection = self.pool.getconn() - pgvector.psycopg2.register_vector(connection) - try: - yield connection - connection.commit() - finally: - self.pool.putconn(connection) - - def close(self) -> None: - if self.pool is not None: - self.pool.closeall() - - def _translate_to_pyformat(self, query_string: str, params: list[Any] | None) -> tuple[str, dict[str, Any]]: - """ - Translates dollar sign number parameters and list parameters to pyformat strings. - - Args: - query_string (str): The query string with parameters. - params (list|None): List of parameter values. 
- - Returns: - str: The query string with translated pyformat parameters. - dict: A dictionary mapping parameter numbers to their values. - """ - - translated_params: dict[str, Any] = {} - if params is not None: - for idx, param in enumerate(params): - translated_params[str(idx + 1)] = param - - if query_string in self.translated_queries: - return self.translated_queries[query_string], translated_params - - dollar_params = re.findall(r"\$[0-9]+", query_string) - translated_string = query_string - for dollar_param in dollar_params: - # Extract the number after the $ - param_number = int(dollar_param[1:]) - pyformat_param = ("%s" if param_number == 0 else f"%({param_number})s") if params is not None else "%s" - translated_string = translated_string.replace(dollar_param, pyformat_param) - - self.translated_queries[query_string] = translated_string - return self.translated_queries[query_string], translated_params - - def table_is_empty(self) -> bool: - """ - Checks if the table is empty. - - Returns - ------- - bool: True if the table is empty, False otherwise. - """ - query = self.builder.get_row_exists_query() - with self.connect() as conn, conn.cursor() as cur: - cur.execute(query) - rec = cur.fetchone() - return rec is None - - def munge_record(self, records: list[tuple[Any, ...]]) -> Iterable[tuple[uuid.UUID, str, str, list[float]]]: - metadata_is_dict = isinstance(records[0][1], dict) - if metadata_is_dict: - records = map(lambda item: Sync._convert_record_meta_to_json(item), records) - - return records - - @staticmethod - def _convert_record_meta_to_json(item: tuple[Any, ...]) -> tuple[uuid.UUID, str, str, list[float]]: - if not isinstance(item[1], dict): - raise ValueError("Cannot mix dictionary and string metadata fields in the same upsert") - return item[0], json.dumps(item[1]), item[2], item[3] - - def upsert(self, records: list[tuple[Any, ...]]) -> None: - """ - Performs upsert operation for multiple records. - - Parameters - ---------- - records - Records to upsert. - - Returns - ------- - None - """ - records = self.munge_record(records) - query = self.builder.get_upsert_query() - query, _ = self._translate_to_pyformat(query, None) - with self.connect() as conn, conn.cursor() as cur: - cur.executemany(query, records) - - def create_tables(self) -> None: - """ - Creates necessary tables. - - Returns - ------- - None - """ - query = self.builder.get_create_query() - # don't use a connection pool for this because the vector extension may not be installed yet - # and if it's not installed, register_vector will fail. - conn = psycopg2.connect(dsn=self.service_url) - with conn.cursor() as cur: - cur.execute(query) - conn.commit() - conn.close() - - def delete_all(self, drop_index: bool = True) -> None: - """ - Deletes all data. Also drops the index if `drop_index` is true. - - Returns - ------- - None - """ - if drop_index: - self.drop_embedding_index() - query = self.builder.delete_all_query() - with self.connect() as conn, conn.cursor() as cur: - cur.execute(query) - - def delete_by_ids(self, ids: list[uuid.UUID] | list[str]) -> None: - """ - Delete records by id. - - Parameters - ---------- - ids - List of ids to delete. - """ - (query, params) = self.builder.delete_by_ids_query(ids) - query, params = self._translate_to_pyformat(query, params) - with self.connect() as conn, conn.cursor() as cur: - cur.execute(query, params) - - def delete_by_metadata(self, filter: dict[str, str] | list[dict[str, str]]) -> None: - """ - Delete records by metadata filters. 
- """ - (query, params) = self.builder.delete_by_metadata_query(filter) - query, params = self._translate_to_pyformat(query, params) - with self.connect() as conn, conn.cursor() as cur: - cur.execute(query, params) - - def drop_table(self) -> None: - """ - Drops the table - - Returns - ------- - None - """ - query = self.builder.drop_table_query() - with self.connect() as conn, conn.cursor() as cur: - cur.execute(query) - - def _get_approx_count(self) -> int: - """ - Retrieves an approximate count of records in the table. - - Returns - ------- - int: Approximate count of records. - """ - query = self.builder.get_approx_count_query() - with self.connect() as conn, conn.cursor() as cur: - cur.execute(query) - rec = cur.fetchone() - return rec[0] - - def drop_embedding_index(self) -> None: - """ - Drop any index on the emedding - - Returns - -------- - None - """ - query = self.builder.drop_embedding_index_query() - with self.connect() as conn, conn.cursor() as cur: - cur.execute(query) - - def create_embedding_index(self, index: BaseIndex) -> None: - """ - Creates an index on the embedding for the table. - - Parameters - ---------- - index - The index to create. - - Returns - -------- - None - """ - query = self.builder.create_embedding_index_query(index, lambda: self._get_approx_count()) - with self.connect() as conn, conn.cursor() as cur: - cur.execute(query) - - def search( - self, - query_embedding: list[float] | None = None, - limit: int = 10, - filter: dict[str, str] | list[dict[str, str]] | None = None, - predicates: Predicates | None = None, - uuid_time_filter: UUIDTimeRange | None = None, - query_params: QueryParams | None = None, - ) -> list[dict[str, Any]]: - """ - Retrieves similar records using a similarity query. - - Parameters - ---------- - query_embedding - The query embedding vector. - limit - The number of nearest neighbors to retrieve. - filter - A filter for metadata. Should be specified as a key-value object or a list of key-value objects - (where any objects in the list are matched). - predicates - A Predicates object to filter the results. Predicates support more complex queries - than the filter parameter. Predicates can be combined using logical operators (&, |, and ~). - - Returns - -------- - List: List of similar records. 
- """ - query_embedding_np = np.array(query_embedding) if query_embedding is not None else None - - (query, params) = self.builder.search_query(query_embedding_np, limit, filter, predicates, uuid_time_filter) - query, params = self._translate_to_pyformat(query, params) - - if query_params is not None: - prefix = "; ".join(query_params.get_statements()) - query = f"{prefix}; {query}" - - with self.connect() as conn, conn.cursor() as cur: - cur.execute(query, params) - return cur.fetchall() diff --git a/timescale_vector/client/__init__.py b/timescale_vector/client/__init__.py new file mode 100644 index 0000000..1758782 --- /dev/null +++ b/timescale_vector/client/__init__.py @@ -0,0 +1,44 @@ +__all__ = [ + "SEARCH_RESULT_ID_IDX", + "SEARCH_RESULT_METADATA_IDX", + "SEARCH_RESULT_CONTENTS_IDX", + "SEARCH_RESULT_EMBEDDING_IDX", + "SEARCH_RESULT_DISTANCE_IDX", + "uuid_from_time", + "BaseIndex", + "IvfflatIndex", + "HNSWIndex", + "DiskAnnIndex", + "QueryParams", + "DiskAnnIndexParams", + "IvfflatIndexParams", + "HNSWIndexParams", + "UUIDTimeRange", + "Predicates", + "QueryBuilder", + "Async", + "Sync", +] + +from timescale_vector.client.async_client import Async +from timescale_vector.client.index import ( + BaseIndex, + DiskAnnIndex, + DiskAnnIndexParams, + HNSWIndex, + HNSWIndexParams, + IvfflatIndex, + IvfflatIndexParams, + QueryParams, +) +from timescale_vector.client.predicates import Predicates +from timescale_vector.client.query_builder import QueryBuilder +from timescale_vector.client.sync_client import Sync +from timescale_vector.client.utils import uuid_from_time +from timescale_vector.client.uuid_time_range import UUIDTimeRange + +SEARCH_RESULT_ID_IDX = 0 +SEARCH_RESULT_METADATA_IDX = 1 +SEARCH_RESULT_CONTENTS_IDX = 2 +SEARCH_RESULT_EMBEDDING_IDX = 3 +SEARCH_RESULT_DISTANCE_IDX = 4 diff --git a/timescale_vector/client/async_client.py b/timescale_vector/client/async_client.py new file mode 100644 index 0000000..fb8d51d --- /dev/null +++ b/timescale_vector/client/async_client.py @@ -0,0 +1,300 @@ +import json +import uuid +from collections.abc import Iterable, Mapping +from datetime import datetime, timedelta +from typing import Any, Literal + +from asyncpg import Connection, Pool, Record, connect, create_pool +from asyncpg.pool import PoolAcquireContext +from pgvector.asyncpg import register_vector + +from timescale_vector.client.index import BaseIndex, QueryParams +from timescale_vector.client.predicates import Predicates +from timescale_vector.client.query_builder import QueryBuilder +from timescale_vector.client.uuid_time_range import UUIDTimeRange + + +class Async(QueryBuilder): + def __init__( + self, + service_url: str, + table_name: str, + num_dimensions: int, + distance_type: str = "cosine", + id_type: Literal["UUID"] | Literal["TEXT"] = "UUID", + time_partition_interval: timedelta | None = None, + max_db_connections: int | None = None, + infer_filters: bool = True, + schema_name: str | None = None, + ) -> None: + """ + Initializes a async client for storing vector data. + + Parameters + ---------- + service_url + The connection string for the database. + table_name + The name of the table. + num_dimensions + The number of dimensions for the embedding vector. + distance_type + The distance type for indexing. + id_type + The type of the id column. Can be either 'UUID' or 'TEXT'. + time_partition_interval + The time interval for partitioning the table (optional). + infer_filters + Whether to infer start and end times from the special __start_date and __end_date filters. 
+ schema_name + The schema name for the table (optional, uses the database's default schema if not specified). + """ + self.builder = QueryBuilder( + table_name, + num_dimensions, + distance_type, + id_type, + time_partition_interval, + infer_filters, + schema_name, + ) + self.service_url: str = service_url + self.pool: Pool | None = None + self.max_db_connections: int | None = max_db_connections + self.time_partition_interval: timedelta | None = time_partition_interval + + async def _default_max_db_connections(self) -> int: + """ + Gets a default value for the number of max db connections to use. + + Returns + ------- + int + """ + query = self.builder.default_max_db_connection_query() + conn: Connection = await connect(dsn=self.service_url) + num_connections = await conn.fetchval(query) + await conn.close() + if num_connections is None: + return 10 + return num_connections # type: ignore + + async def connect(self) -> PoolAcquireContext: + """ + Establishes a connection to a PostgreSQL database using asyncpg. + + Returns + ------- + asyncpg.Connection: The established database connection. + """ + if self.pool is None: + if self.max_db_connections is None: + self.max_db_connections = await self._default_max_db_connections() + + async def init(conn: Connection) -> None: + await register_vector(conn) + # decode to a dict, but accept a string as input in upsert + await conn.set_type_codec("jsonb", encoder=str, decoder=json.loads, schema="pg_catalog") + + self.pool = await create_pool( + dsn=self.service_url, + init=init, + min_size=1, + max_size=self.max_db_connections, + ) + + return self.pool.acquire() + + async def close(self) -> None: + if self.pool is not None: + await self.pool.close() + + async def table_is_empty(self) -> bool: + """ + Checks if the table is empty. + + Returns + ------- + bool: True if the table is empty, False otherwise. + """ + query = self.builder.get_row_exists_query() + async with await self.connect() as pool: + rec = await pool.fetchrow(query) + return rec is None + + def munge_record(self, records: list[tuple[Any, ...]]) -> Iterable[tuple[uuid.UUID, str, str, list[float]]]: + metadata_is_dict = isinstance(records[0][1], dict) + if metadata_is_dict: + munged_records = map(lambda item: Async._convert_record_meta_to_json(item), records) + + return munged_records if metadata_is_dict else records + + @staticmethod + def _convert_record_meta_to_json(item: tuple[Any, ...]) -> tuple[uuid.UUID, str, str, list[float]]: + if not isinstance(item[1], dict): + raise ValueError("Cannot mix dictionary and string metadata fields in the same upsert") + return item[0], json.dumps(item[1]), item[2], item[3] + + async def upsert(self, records: list[tuple[Any, ...]]) -> None: + """ + Performs upsert operation for multiple records. + + Parameters + ---------- + records + List of records to upsert. Each record is a tuple of the form (id, metadata, contents, embedding). + + Returns + ------- + None + """ + munged_records = self.munge_record(records) + query = self.builder.get_upsert_query() + async with await self.connect() as pool: + await pool.executemany(query, munged_records) + + async def create_tables(self) -> None: + """ + Creates necessary tables. + + Returns + ------- + None + """ + query = self.builder.get_create_query() + # don't use a connection pool for this because the vector extension may not be installed yet + # and if it's not installed, register_vector will fail. 
+ conn = await connect(dsn=self.service_url) + await conn.execute(query) + await conn.close() + + async def delete_all(self, drop_index: bool = True) -> None: + """ + Deletes all data. Also drops the index if `drop_index` is true. + + Returns + ------- + None + """ + if drop_index: + await self.drop_embedding_index() + query = self.builder.delete_all_query() + async with await self.connect() as pool: + await pool.execute(query) + + async def delete_by_ids(self, ids: list[uuid.UUID] | list[str]) -> list[Record]: + """ + Delete records by id. + """ + (query, params) = self.builder.delete_by_ids_query(ids) + async with await self.connect() as pool: + return await pool.fetch(query, *params) # type: ignore + + async def delete_by_metadata(self, filter: dict[str, str] | list[dict[str, str]]) -> list[Record]: + """ + Delete records by metadata filters. + """ + (query, params) = self.builder.delete_by_metadata_query(filter) + async with await self.connect() as pool: + return await pool.fetch(query, *params) # type: ignore + + async def drop_table(self) -> None: + """ + Drops the table + + Returns + ------- + None + """ + query = self.builder.drop_table_query() + async with await self.connect() as pool: + await pool.execute(query) + + async def _get_approx_count(self) -> int: + """ + Retrieves an approximate count of records in the table. + + Returns + ------- + int: Approximate count of records. + """ + query = self.builder.get_approx_count_query() + async with await self.connect() as pool: + rec = await pool.fetchrow(query) + return rec[0] if rec is not None else 0 + + async def drop_embedding_index(self) -> None: + """ + Drop any index on the emedding + + Returns + ------- + None + """ + query = self.builder.drop_embedding_index_query() + async with await self.connect() as pool: + await pool.execute(query) + + async def create_embedding_index(self, index: BaseIndex) -> None: + """ + Creates an index for the table. + + Parameters + ---------- + index + The index to create. + + Returns + ------- + None + """ + # todo: can we make geting the records lazy? + num_records = await self._get_approx_count() + query = self.builder.create_embedding_index_query(index, lambda: num_records) + + async with await self.connect() as pool: + await pool.execute(query) + + async def search( + self, + query_embedding: list[float] | None = None, + limit: int = 10, + filter: Mapping[str, datetime | str] | list[dict[str, str]] | None = None, + predicates: Predicates | None = None, + uuid_time_filter: UUIDTimeRange | None = None, + query_params: QueryParams | None = None, + ) -> list[Record]: + """ + Retrieves similar records using a similarity query. + + Parameters + ---------- + query_embedding + The query embedding vector. + limit + The number of nearest neighbors to retrieve. + filter + A filter for metadata. Should be specified as a key-value object or a list of key-value objects + (where any objects in the list are matched). + predicates + A Predicates object to filter the results. Predicates support more complex queries than the filter + parameter. Predicates can be combined using logical operators (&, |, and ~). + uuid_time_filter + A UUIDTimeRange object to filter the results by time using the id column. + query_params + + Returns + ------- + List: List of similar records. 
+ """ + (query, params) = self.builder.search_query(query_embedding, limit, filter, predicates, uuid_time_filter) + if query_params is not None: + async with await self.connect() as pool, pool.transaction(): + # Looks like there is no way to pipeline this: https://github.com/MagicStack/asyncpg/issues/588 + statements = query_params.get_statements() + for statement in statements: + await pool.execute(statement) + return await pool.fetch(query, *params) # type: ignore + else: + async with await self.connect() as pool: + return await pool.fetch(query, *params) # type: ignore diff --git a/timescale_vector/client/index.py b/timescale_vector/client/index.py new file mode 100644 index 0000000..c199301 --- /dev/null +++ b/timescale_vector/client/index.py @@ -0,0 +1,192 @@ +import math +from collections.abc import Callable +from typing import Any + +from typing_extensions import override + + +class BaseIndex: + def get_index_method(self, distance_type: str) -> str: + index_method = "invalid" + if distance_type == "<->": + index_method = "vector_l2_ops" + elif distance_type == "<#>": + index_method = "vector_ip_ops" + elif distance_type == "<=>": + index_method = "vector_cosine_ops" + else: + raise ValueError(f"Unknown distance type {distance_type}") + return index_method + + def create_index_query( + self, + table_name_quoted: str, + column_name_quoted: str, + index_name_quoted: str, + distance_type: str, + num_records_callback: Callable[[], int], + ) -> str: + raise NotImplementedError() + + +class IvfflatIndex(BaseIndex): + def __init__(self, num_records: int | None = None, num_lists: int | None = None) -> None: + """ + Pgvector's ivfflat index. + """ + self.num_records: int | None = num_records + self.num_lists: int | None = num_lists + + def get_num_records(self, num_record_callback: Callable[[], int]) -> int: + if self.num_records is not None: + return self.num_records + return num_record_callback() + + def get_num_lists(self, num_records_callback: Callable[[], int]) -> int: + if self.num_lists is not None: + return self.num_lists + + num_records = self.get_num_records(num_records_callback) + num_lists = num_records / 1000 + if num_lists < 10: + num_lists = 10 + if num_records > 1000000: + num_lists = math.sqrt(num_records) + return int(num_lists) + + def create_index_query( + self, + table_name_quoted: str, + column_name_quoted: str, + index_name_quoted: str, + distance_type: str, + num_records_callback: Callable[[], int], + ) -> str: + index_method = self.get_index_method(distance_type) + num_lists = self.get_num_lists(num_records_callback) + + return ( + f"CREATE INDEX {index_name_quoted} ON {table_name_quoted}" + f"USING ivfflat ({column_name_quoted} {index_method}) WITH (lists = {num_lists});" + ) + + +class HNSWIndex(BaseIndex): + def __init__(self, m: int | None = None, ef_construction: int | None = None) -> None: + """ + Pgvector's hnsw index. 
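+
+        Parameters left as None fall back to pgvector's own defaults. An illustrative
+        construction (values are arbitrary, not tuned recommendations)::
+
+            index = HNSWIndex(m=16, ef_construction=64)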
+ """ + self.m: int | None = m + self.ef_construction: int | None = ef_construction + + @override + def create_index_query( + self, + table_name_quoted: str, + column_name_quoted: str, + index_name_quoted: str, + distance_type: str, + num_records_callback: Callable[[], int], + ) -> str: + index_method = self.get_index_method(distance_type) + + with_clauses: list[str] = [] + if self.m is not None: + with_clauses.append(f"m = {self.m}") + if self.ef_construction is not None: + with_clauses.append(f"ef_construction = {self.ef_construction}") + + with_clause = "" + if len(with_clauses) > 0: + with_clause = "WITH (" + ", ".join(with_clauses) + ")" + + return ( + f"CREATE INDEX {index_name_quoted} ON {table_name_quoted}" + f"USING hnsw ({column_name_quoted} {index_method}) {with_clause};" + ) + + +class DiskAnnIndex(BaseIndex): + def __init__( + self, + search_list_size: int | None = None, + num_neighbors: int | None = None, + max_alpha: float | None = None, + storage_layout: str | None = None, + num_dimensions: int | None = None, + num_bits_per_dimension: int | None = None, + ) -> None: + """ + Timescale's vector index. + """ + self.search_list_size: int | None = search_list_size + self.num_neighbors: int | None = num_neighbors + self.max_alpha: float | None = max_alpha + self.storage_layout: str | None = storage_layout + self.num_dimensions: int | None = num_dimensions + self.num_bits_per_dimension: int | None = num_bits_per_dimension + + @override + def create_index_query( + self, + table_name_quoted: str, + column_name_quoted: str, + index_name_quoted: str, + distance_type: str, + num_records_callback: Callable[[], int], + ) -> str: + if distance_type != "<=>": + raise ValueError( + f"Timescale's vector index only supports cosine distance, but distance_type was {distance_type}" + ) + + with_clauses: list[str] = [] + if self.search_list_size is not None: + with_clauses.append(f"search_list_size = {self.search_list_size}") + if self.num_neighbors is not None: + with_clauses.append(f"num_neighbors = {self.num_neighbors}") + if self.max_alpha is not None: + with_clauses.append(f"max_alpha = {self.max_alpha}") + if self.storage_layout is not None: + with_clauses.append(f"storage_layout = {self.storage_layout}") + if self.num_dimensions is not None: + with_clauses.append(f"num_dimensions = {self.num_dimensions}") + if self.num_bits_per_dimension is not None: + with_clauses.append(f"num_bits_per_dimension = {self.num_bits_per_dimension}") + + with_clause = "" + if len(with_clauses) > 0: + with_clause = "WITH (" + ", ".join(with_clauses) + ")" + + return ( + f"CREATE INDEX {index_name_quoted} ON {table_name_quoted}" + f"USING diskann ({column_name_quoted}) {with_clause};" + ) + + +class QueryParams: + def __init__(self, params: dict[str, Any]) -> None: + self.params: dict[str, Any] = params + + def get_statements(self) -> list[str]: + return ["SET LOCAL " + key + " = " + str(value) for key, value in self.params.items()] + + +class DiskAnnIndexParams(QueryParams): + def __init__(self, search_list_size: int | None = None, rescore: int | None = None) -> None: + params: dict[str, Any] = {} + if search_list_size is not None: + params["diskann.query_search_list_size"] = search_list_size + if rescore is not None: + params["diskann.query_rescore"] = rescore + super().__init__(params) + + +class IvfflatIndexParams(QueryParams): + def __init__(self, probes: int) -> None: + super().__init__({"ivfflat.probes": probes}) + + +class HNSWIndexParams(QueryParams): + def __init__(self, ef_search: int) -> None: + 
super().__init__({"hnsw.ef_search": ef_search}) diff --git a/timescale_vector/client/predicates.py b/timescale_vector/client/predicates.py new file mode 100644 index 0000000..6d9f688 --- /dev/null +++ b/timescale_vector/client/predicates.py @@ -0,0 +1,176 @@ +import json +from datetime import datetime +from typing import Any, Literal, Union + + +class Predicates: + logical_operators: dict[str, str] = { + "AND": "AND", + "OR": "OR", + "NOT": "NOT", + } + + operators_mapping: dict[str, str] = { + "=": "=", + "==": "=", + ">=": ">=", + ">": ">", + "<=": "<=", + "<": "<", + "!=": "<>", + "@>": "@>", # array contains + } + + PredicateValue = str | int | float | datetime | list | tuple # type: ignore + + def __init__( + self, + *clauses: Union[ + "Predicates", + tuple[str, PredicateValue], + tuple[str, str, PredicateValue], + str, + PredicateValue, + ], + operator: Literal["AND", "OR", "NOT"] = "AND", + ): + """ + Predicates class defines predicates on the object metadata. + Predicates can be combined using logical operators (&, |, and ~). + + Parameters + ---------- + clauses + Predicate clauses. Can be either another Predicates object + or a tuple of the form (field, operator, value) or (field, value). + Operator + Logical operator to use when combining the clauses. + Can be one of 'AND', 'OR', 'NOT'. Defaults to 'AND'. + """ + if operator not in self.logical_operators: + raise ValueError(f"invalid operator: {operator}") + self.operator: str = operator + if isinstance(clauses[0], str): + if len(clauses) != 3 or not (isinstance(clauses[1], str) and isinstance(clauses[2], self.PredicateValue)): + raise ValueError(f"Invalid clause format: {clauses}") + self.clauses: list[ + Predicates + | tuple[str, Predicates.PredicateValue] + | tuple[str, str, Predicates.PredicateValue] + | str + | Predicates.PredicateValue + ] = [clauses] + else: + self.clauses = list(clauses) + + def add_clause( + self, + *clause: Union[ + "Predicates", + tuple[str, PredicateValue], + tuple[str, str, PredicateValue], + str, + PredicateValue, + ], + ) -> None: + """ + Add a clause to the predicates object. + + Parameters + ---------- + clause: 'Predicates' or Tuple[str, str] or Tuple[str, str, str] + Predicate clause. Can be either another Predicates object or a tuple of the form (field, operator, value) + or (field, value). + """ + if isinstance(clause[0], str): + if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)): + raise ValueError(f"Invalid clause format: {clause}") + self.clauses.append(clause) + else: + self.clauses.extend(list(clause)) + + def __and__(self, other: "Predicates") -> "Predicates": + new_predicates = Predicates(self, other, operator="AND") + return new_predicates + + def __or__(self, other: "Predicates") -> "Predicates": + new_predicates = Predicates(self, other, operator="OR") + return new_predicates + + def __invert__(self) -> "Predicates": + new_predicates = Predicates(self, operator="NOT") + return new_predicates + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Predicates): + return False + + return self.operator == other.operator and self.clauses == other.clauses + + def __repr__(self) -> str: + if self.operator: + return f"{self.operator}({', '.join(repr(clause) for clause in self.clauses)})" + else: + return repr(self.clauses) + + def build_query(self, params: list[Any]) -> tuple[str, list[Any]]: + """ + Build the SQL query string and parameters for the predicates object. 
+ """ + if not self.clauses: + return "", [] + + where_conditions: list[str] = [] + + for clause in self.clauses: + if isinstance(clause, Predicates): + child_where_clause, params = clause.build_query(params) + where_conditions.append(f"({child_where_clause})") + elif isinstance(clause, tuple): + if len(clause) == 2: + field, value = clause + operator = "=" # Default operator + elif len(clause) == 3: + field, operator, value = clause + if operator not in self.operators_mapping: + raise ValueError(f"Invalid operator: {operator}") + operator = self.operators_mapping[operator] + else: + raise ValueError("Invalid clause format") + + index = len(params) + 1 + param_name = f"${index}" + + if field == "__uuid_timestamp": + # convert str to timestamp in the database, it's better at it than python + if isinstance(value, str): + where_conditions.append(f"uuid_timestamp(id) {operator} ({param_name}::text)::timestamptz") + else: + where_conditions.append(f"uuid_timestamp(id) {operator} {param_name}") + params.append(value) + + elif operator == "@>" and isinstance(value, list | tuple): + if len(value) == 0: + raise ValueError("Invalid value. Empty lists and empty tuples are not supported.") + json_value = json.dumps(value) + where_conditions.append(f"metadata @> jsonb_build_object('{field}', {param_name}::jsonb)") + params.append(json_value) + + else: + field_cast = "" + if isinstance(value, int): + field_cast = "::int" + elif isinstance(value, float): + field_cast = "::numeric" + elif isinstance(value, datetime): + field_cast = "::timestamptz" + where_conditions.append(f"(metadata->>'{field}'){field_cast} {operator} {param_name}") + params.append(value) + + if self.operator == "NOT": + or_clauses = " OR ".join(where_conditions) + # use IS DISTINCT FROM to treat all-null clauses as False and pass the filter + where_clause = f"TRUE IS DISTINCT FROM ({or_clauses})" + else: + where_clause = (" " + self.operator + " ").join(where_conditions) + return where_clause, params diff --git a/timescale_vector/client/query_builder.py b/timescale_vector/client/query_builder.py new file mode 100644 index 0000000..38a2c76 --- /dev/null +++ b/timescale_vector/client/query_builder.py @@ -0,0 +1,338 @@ +import json +import uuid +from collections.abc import Callable, Mapping +from datetime import datetime, timedelta +from typing import Any + +import numpy as np + +from timescale_vector.client.index import BaseIndex +from timescale_vector.client.predicates import Predicates +from timescale_vector.client.uuid_time_range import UUIDTimeRange + + +class QueryBuilder: + def __init__( + self, + table_name: str, + num_dimensions: int, + distance_type: str, + id_type: str, + time_partition_interval: timedelta | None, + infer_filters: bool, + schema_name: str | None, + ) -> None: + """ + Initializes a base Vector object to generate queries for vector clients. + + Parameters + ---------- + table_name + The name of the table. + num_dimensions + The number of dimensions for the embedding vector. + distance_type + The distance type for indexing. + id_type + The type of the id column. Can be either 'UUID' or 'TEXT'. + time_partition_interval + The time interval for partitioning the table (optional). + infer_filters + Whether to infer start and end times from the special __start_date and __end_date filters. + schema_name + The schema name for the table (optional, uses the database's default schema if not specified). 
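+
+        An illustrative construction (argument values are placeholders)::
+
+            builder = QueryBuilder("blog_embedding", 1536, "cosine", "UUID", None, True, None)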
+ """ + self.table_name: str = table_name + self.schema_name: str | None = schema_name + self.num_dimensions: int = num_dimensions + if distance_type == "cosine" or distance_type == "<=>": + self.distance_type: str = "<=>" + elif distance_type == "euclidean" or distance_type == "<->" or distance_type == "l2": + self.distance_type = "<->" + else: + raise ValueError(f"unrecognized distance_type {distance_type}") + + if id_type.lower() != "uuid" and id_type.lower() != "text": + raise ValueError(f"unrecognized id_type {id_type}") + + if time_partition_interval is not None and id_type.lower() != "uuid": + raise ValueError("time partitioning is only supported for uuid id_type") + + self.id_type: str = id_type.lower() + self.time_partition_interval: timedelta | None = time_partition_interval + self.infer_filters: bool = infer_filters + + @staticmethod + def _quote_ident(ident: str) -> str: + """ + Quotes an identifier to prevent SQL injection. + + Parameters + ---------- + ident + The identifier to be quoted. + + Returns + ------- + str: The quoted identifier. + """ + return '"{}"'.format(ident.replace('"', '""')) + + def _quoted_table_name(self) -> str: + if self.schema_name is not None: + return self._quote_ident(self.schema_name) + "." + self._quote_ident(self.table_name) + else: + return self._quote_ident(self.table_name) + + def get_row_exists_query(self) -> str: + """ + Generates a query to check if any rows exist in the table. + + Returns + ------- + str: The query to check for row existence. + """ + return f"SELECT 1 FROM {self._quoted_table_name()} LIMIT 1" + + def get_upsert_query(self) -> str: + """ + Generates an upsert query. + + Returns + ------- + str: The upsert query. + """ + return ( + f"INSERT INTO {self._quoted_table_name()} (id, metadata, contents, embedding) " + f"VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING" + ) + + def get_approx_count_query(self) -> str: + """ + Generate a query to find the approximate count of records in the table. + + Returns + ------- + str: the query. + """ + # todo optimize with approx + return f"SELECT COUNT(*) as cnt FROM {self._quoted_table_name()}" + + def get_create_query(self) -> str: + """ + Generates a query to create the tables, indexes, and extensions needed to store the vector data. + + Returns + ------- + str: The create table query. 
+ """ + hypertable_sql = "" + if self.time_partition_interval is not None: + hypertable_sql = f""" + CREATE EXTENSION IF NOT EXISTS timescaledb; + + CREATE OR REPLACE FUNCTION public.uuid_timestamp(uuid UUID) RETURNS TIMESTAMPTZ AS $$ + DECLARE + bytes bytea; + BEGIN + bytes := uuid_send(uuid); + if (get_byte(bytes, 6) >> 4)::int2 != 1 then + RAISE EXCEPTION 'UUID version is not 1'; + end if; + RETURN to_timestamp( + ( + ( + (get_byte(bytes, 0)::bigint << 24) | + (get_byte(bytes, 1)::bigint << 16) | + (get_byte(bytes, 2)::bigint << 8) | + (get_byte(bytes, 3)::bigint << 0) + ) + ( + ((get_byte(bytes, 4)::bigint << 8 | + get_byte(bytes, 5)::bigint)) << 32 + ) + ( + (((get_byte(bytes, 6)::bigint & 15) << 8 | get_byte(bytes, 7)::bigint) & 4095) << 48 + ) - 122192928000000000 + ) / 10000 / 1000::double precision + ); + END + $$ LANGUAGE plpgsql + IMMUTABLE PARALLEL SAFE + RETURNS NULL ON NULL INPUT; + + SELECT create_hypertable('{self._quoted_table_name()}', + 'id', + if_not_exists=> true, + time_partitioning_func=>'public.uuid_timestamp', + chunk_time_interval => '{str(self.time_partition_interval.total_seconds())} seconds'::interval); + """ + return f""" +CREATE EXTENSION IF NOT EXISTS vector; +CREATE EXTENSION IF NOT EXISTS vectorscale; + + +CREATE TABLE IF NOT EXISTS {self._quoted_table_name()} ( + id {self.id_type} PRIMARY KEY, + metadata JSONB, + contents TEXT, + embedding VECTOR({self.num_dimensions}) +); + +CREATE INDEX IF NOT EXISTS {self._quote_ident(self.table_name + "_meta_idx")} ON {self._quoted_table_name()} +USING GIN(metadata jsonb_path_ops); + +{hypertable_sql} +""" + + def _get_embedding_index_name_quoted(self) -> str: + return self._quote_ident(self.table_name + "_embedding_idx") + + def _get_schema_qualified_embedding_index_name_quoted(self) -> str: + if self.schema_name is not None: + return self._quote_ident(self.schema_name) + "." + self._get_embedding_index_name_quoted() + else: + return self._get_embedding_index_name_quoted() + + def drop_embedding_index_query(self) -> str: + return f"DROP INDEX IF EXISTS {self._get_schema_qualified_embedding_index_name_quoted()};" + + def delete_all_query(self) -> str: + return f"TRUNCATE {self._quoted_table_name()};" + + def delete_by_ids_query(self, ids: list[uuid.UUID] | list[str]) -> tuple[str, list[Any]]: + query = f"DELETE FROM {self._quoted_table_name()} WHERE id = ANY($1::{self.id_type}[]);" + return (query, [ids]) + + def delete_by_metadata_query( + self, filter_conditions: dict[str, str] | list[dict[str, str]] + ) -> tuple[str, list[Any]]: + params: list[Any] = [] + (where, params) = self._where_clause_for_filter(params, filter_conditions) + query = f"DELETE FROM {self._quoted_table_name()} WHERE {where};" + return (query, params) + + def drop_table_query(self) -> str: + return f"DROP TABLE IF EXISTS {self._quoted_table_name()};" + + def default_max_db_connection_query(self) -> str: + """ + Generates a query to get the default max db connections. This uses a heuristic to determine the max connections + based on the max_connections setting in postgres + and the number of currently used connections. This heuristic leaves 4 connections in reserve. + """ + return ( + "SELECT greatest(1, ((SELECT setting::int FROM pg_settings " + "WHERE name='max_connections')-(SELECT count(*) FROM pg_stat_activity) - 4)::int)" + ) + + def create_embedding_index_query(self, index: BaseIndex, num_records_callback: Callable[[], int]) -> str: + """ + Generates an embedding index creation query. 
+ + Parameters + ---------- + index + The index to create. + num_records_callback + A callback function to get the number of records in the table. + + Returns + ------- + str: The index creation query. + """ + column_name = "embedding" + index_name_quoted = self._get_embedding_index_name_quoted() + query = index.create_index_query( + self._quoted_table_name(), + self._quote_ident(column_name), + index_name_quoted, + self.distance_type, + num_records_callback, + ) + return query + + def _where_clause_for_filter( + self, params: list[Any], filter: Mapping[str, datetime | str] | list[dict[str, str]] | None + ) -> tuple[str, list[Any]]: + if filter is None: + return "TRUE", params + + if isinstance(filter, dict): + where = f"metadata @> ${len(params)+1}" + json_object = json.dumps(filter) + params = params + [json_object] + elif isinstance(filter, list): + any_params = [] + for _idx, filter_dict in enumerate(filter, start=len(params) + 1): + any_params.append(json.dumps(filter_dict)) + where = f"metadata @> ANY(${len(params) + 1}::jsonb[])" + params = params + [any_params] + else: + raise ValueError(f"Unknown filter type: {type(filter)}") + + return where, params + + def search_query( + self, + query_embedding: list[float] | np.ndarray[Any, Any] | None, + limit: int = 10, + filter: Mapping[str, datetime | str] | list[dict[str, str]] | None = None, + predicates: Predicates | None = None, + uuid_time_filter: UUIDTimeRange | None = None, + ) -> tuple[str, list[Any]]: + """ + Generates a similarity query. + + Returns: + Tuple[str, List]: A tuple containing the query and parameters. + """ + params: list[Any] = [] + if query_embedding is not None: + distance = f"embedding {self.distance_type} ${len(params)+1}" + params = params + [query_embedding] + order_by_clause = f"ORDER BY {distance} ASC" + else: + distance = "-1.0" + order_by_clause = "" + + if ( + self.infer_filters + and uuid_time_filter is None + and isinstance(filter, dict) + and ("__start_date" in filter or "__end_date" in filter) + ): + start_date = UUIDTimeRange._parse_datetime(filter.get("__start_date")) + end_date = UUIDTimeRange._parse_datetime(filter.get("__end_date")) + + uuid_time_filter = UUIDTimeRange(start_date, end_date) + + if start_date is not None: + del filter["__start_date"] + if end_date is not None: + del filter["__end_date"] + + where_clauses = [] + if filter is not None: + (where_filter, params) = self._where_clause_for_filter(params, filter) + where_clauses.append(where_filter) + + if predicates is not None: + (where_predicates, params) = predicates.build_query(params) + where_clauses.append(where_predicates) + + if uuid_time_filter is not None: + (where_time, params) = uuid_time_filter.build_query(params) + where_clauses.append(where_time) + + where = " AND ".join(where_clauses) if len(where_clauses) > 0 else "TRUE" + + query = f""" + SELECT + id, metadata, contents, embedding, {distance} as distance + FROM + {self._quoted_table_name()} + WHERE + {where} + {order_by_clause} + LIMIT {limit} + """ + return query, params diff --git a/timescale_vector/client/sync_client.py b/timescale_vector/client/sync_client.py new file mode 100644 index 0000000..59511e9 --- /dev/null +++ b/timescale_vector/client/sync_client.py @@ -0,0 +1,341 @@ +import json +import re +import uuid +from collections.abc import Iterable, Iterator, Mapping +from contextlib import contextmanager +from datetime import datetime, timedelta +from typing import Any, Literal + +import numpy as np +import pgvector.psycopg2 +import psycopg2.extras +import 
psycopg2.pool +from numpy import ndarray + +from timescale_vector.client.index import BaseIndex, QueryParams +from timescale_vector.client.predicates import Predicates +from timescale_vector.client.query_builder import QueryBuilder +from timescale_vector.client.uuid_time_range import UUIDTimeRange + + +class Sync: + translated_queries: dict[str, str] = {} + + def __init__( + self, + service_url: str, + table_name: str, + num_dimensions: int, + distance_type: str = "cosine", + id_type: Literal["UUID"] | Literal["TEXT"] = "UUID", + time_partition_interval: timedelta | None = None, + max_db_connections: int | None = None, + infer_filters: bool = True, + schema_name: str | None = None, + ) -> None: + """ + Initializes a sync client for storing vector data. + + Parameters + ---------- + service_url + The connection string for the database. + table_name + The name of the table. + num_dimensions + The number of dimensions for the embedding vector. + distance_type + The distance type for indexing. + id_type + The type of the primary id column. Can be either 'UUID' or 'TEXT'. + time_partition_interval + The time interval for partitioning the table (optional). + infer_filters + Whether to infer start and end times from the special __start_date and __end_date filters. + schema_name + The schema name for the table (optional, uses the database's default schema if not specified). + """ + self.builder = QueryBuilder( + table_name, + num_dimensions, + distance_type, + id_type, + time_partition_interval, + infer_filters, + schema_name, + ) + self.service_url: str = service_url + self.pool: psycopg2.pool.SimpleConnectionPool | None = None + self.max_db_connections: int | None = max_db_connections + self.time_partition_interval: timedelta | None = time_partition_interval + psycopg2.extras.register_uuid() + + def default_max_db_connections(self) -> int: + """ + Gets a default value for the number of max db connections to use. + """ + query = self.builder.default_max_db_connection_query() + conn = psycopg2.connect(dsn=self.service_url) + with conn.cursor() as cur: + cur.execute(query) + num_connections = cur.fetchone() + conn.close() + return num_connections[0] # type: ignore + + @contextmanager + def connect(self) -> Iterator[psycopg2.extensions.connection]: + """ + Establishes a connection to a PostgreSQL database using psycopg2 and allows it's + use in a context manager. + """ + if self.pool is None: + if self.max_db_connections is None: + self.max_db_connections = self.default_max_db_connections() + + self.pool = psycopg2.pool.SimpleConnectionPool( + 1, + self.max_db_connections, + dsn=self.service_url, + cursor_factory=psycopg2.extras.DictCursor, + ) + + connection = self.pool.getconn() + pgvector.psycopg2.register_vector(connection) + try: + yield connection + connection.commit() + finally: + self.pool.putconn(connection) + + def close(self) -> None: + if self.pool is not None: + self.pool.closeall() + + def _translate_to_pyformat(self, query_string: str, params: list[Any] | None) -> tuple[str, dict[str, Any]]: + """ + Translates dollar sign number parameters and list parameters to pyformat strings. + + Args: + query_string (str): The query string with parameters. + params (list|None): List of parameter values. + + Returns: + str: The query string with translated pyformat parameters. + dict: A dictionary mapping parameter numbers to their values. 
+ """ + + translated_params: dict[str, Any] = {} + if params is not None: + for idx, param in enumerate(params): + translated_params[str(idx + 1)] = param + + if query_string in self.translated_queries: + return self.translated_queries[query_string], translated_params + + dollar_params = re.findall(r"\$[0-9]+", query_string) + translated_string = query_string + for dollar_param in dollar_params: + # Extract the number after the $ + param_number = int(dollar_param[1:]) + pyformat_param = ("%s" if param_number == 0 else f"%({param_number})s") if params is not None else "%s" + translated_string = translated_string.replace(dollar_param, pyformat_param) + + self.translated_queries[query_string] = translated_string + return self.translated_queries[query_string], translated_params + + def table_is_empty(self) -> bool: + """ + Checks if the table is empty. + + Returns + ------- + bool: True if the table is empty, False otherwise. + """ + query = self.builder.get_row_exists_query() + with self.connect() as conn, conn.cursor() as cur: + cur.execute(query) + rec = cur.fetchone() + return rec is None + + def munge_record(self, records: list[tuple[Any, ...]]) -> Iterable[tuple[uuid.UUID, str, str, list[float]]]: + metadata_is_dict = isinstance(records[0][1], dict) + if metadata_is_dict: + munged_records = map(lambda item: Sync._convert_record_meta_to_json(item), records) + + return munged_records if metadata_is_dict else records + + @staticmethod + def _convert_record_meta_to_json(item: tuple[Any, ...]) -> tuple[uuid.UUID, str, str, list[float]]: + if not isinstance(item[1], dict): + raise ValueError("Cannot mix dictionary and string metadata fields in the same upsert") + return item[0], json.dumps(item[1]), item[2], item[3] + + def upsert(self, records: list[tuple[Any, ...]]) -> None: + """ + Performs upsert operation for multiple records. + + Parameters + ---------- + records + Records to upsert. + + Returns + ------- + None + """ + munged_records = self.munge_record(records) + query = self.builder.get_upsert_query() + query, _ = self._translate_to_pyformat(query, None) + with self.connect() as conn, conn.cursor() as cur: + cur.executemany(query, munged_records) + + def create_tables(self) -> None: + """ + Creates necessary tables. + + Returns + ------- + None + """ + query = self.builder.get_create_query() + # don't use a connection pool for this because the vector extension may not be installed yet + # and if it's not installed, register_vector will fail. + conn = psycopg2.connect(dsn=self.service_url) + with conn.cursor() as cur: + cur.execute(query) + conn.commit() + conn.close() + + def delete_all(self, drop_index: bool = True) -> None: + """ + Deletes all data. Also drops the index if `drop_index` is true. + + Returns + ------- + None + """ + if drop_index: + self.drop_embedding_index() + query = self.builder.delete_all_query() + with self.connect() as conn, conn.cursor() as cur: + cur.execute(query) + + def delete_by_ids(self, ids: list[uuid.UUID] | list[str]) -> None: + """ + Delete records by id. + + Parameters + ---------- + ids + List of ids to delete. + """ + (query, params) = self.builder.delete_by_ids_query(ids) + translated_query, translated_params = self._translate_to_pyformat(query, params) + with self.connect() as conn, conn.cursor() as cur: + cur.execute(translated_query, translated_params) + + def delete_by_metadata(self, filter: dict[str, str] | list[dict[str, str]]) -> None: + """ + Delete records by metadata filters. 
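+
+        For example, ``vec.delete_by_metadata({"key": "val"})`` removes every row whose
+        metadata contains that key/value pair, while passing a list of dicts removes rows
+        matching any of them.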
+        """
+        (query, params) = self.builder.delete_by_metadata_query(filter)
+        translated_query, translated_params = self._translate_to_pyformat(query, params)
+        with self.connect() as conn, conn.cursor() as cur:
+            cur.execute(translated_query, translated_params)
+
+    def drop_table(self) -> None:
+        """
+        Drops the table.
+
+        Returns
+        -------
+        None
+        """
+        query = self.builder.drop_table_query()
+        with self.connect() as conn, conn.cursor() as cur:
+            cur.execute(query)
+
+    def _get_approx_count(self) -> int:
+        """
+        Retrieves an approximate count of records in the table.
+
+        Returns
+        -------
+        int: Approximate count of records.
+        """
+        query = self.builder.get_approx_count_query()
+        with self.connect() as conn, conn.cursor() as cur:
+            cur.execute(query)
+            rec = cur.fetchone()
+            return rec[0] if rec is not None else 0
+
+    def drop_embedding_index(self) -> None:
+        """
+        Drop any index on the embedding.
+
+        Returns
+        -------
+        None
+        """
+        query = self.builder.drop_embedding_index_query()
+        with self.connect() as conn, conn.cursor() as cur:
+            cur.execute(query)
+
+    def create_embedding_index(self, index: BaseIndex) -> None:
+        """
+        Creates an index on the embedding for the table.
+
+        Parameters
+        ----------
+        index
+            The index to create.
+
+        Returns
+        -------
+        None
+        """
+        query = self.builder.create_embedding_index_query(index, lambda: self._get_approx_count())
+        with self.connect() as conn, conn.cursor() as cur:
+            cur.execute(query)
+
+    def search(
+        self,
+        query_embedding: ndarray[Any, Any] | list[float] | None = None,
+        limit: int = 10,
+        filter: Mapping[str, datetime | str] | list[dict[str, str]] | None = None,
+        predicates: Predicates | None = None,
+        uuid_time_filter: UUIDTimeRange | None = None,
+        query_params: QueryParams | None = None,
+    ) -> list[tuple[Any, ...]]:
+        """
+        Retrieves similar records using a similarity query.
+
+        Parameters
+        ----------
+        query_embedding
+            The query embedding vector.
+        limit
+            The number of nearest neighbors to retrieve.
+        filter
+            A filter for metadata. Should be specified as a key-value object or a list of key-value objects
+            (where any objects in the list are matched).
+        predicates
+            A Predicates object to filter the results. Predicates support more complex queries
+            than the filter parameter. Predicates can be combined using logical operators (&, |, and ~).
+        uuid_time_filter
+            A UUIDTimeRange object to filter the results by time using the id column.
+        query_params
+            Query-time index parameters to apply to this search (e.g. `DiskAnnIndexParams`).
+
+        Returns
+        -------
+        List: List of similar records.
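+
+        A minimal usage sketch (assumes `vec` is a `Sync` client over a table of
+        2-dimensional embeddings; names and values here are illustrative only)::
+
+            results = vec.search([1.0, 2.0], limit=5, filter={"category": "blog"})
+            for rec in results:
+                print(rec["id"], rec["distance"])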
+ """ + query_embedding_np = np.array(query_embedding) if query_embedding is not None else None + + (query, params) = self.builder.search_query(query_embedding_np, limit, filter, predicates, uuid_time_filter) + translated_query, translated_params = self._translate_to_pyformat(query, params) + + if query_params is not None: + prefix = "; ".join(query_params.get_statements()) + translated_query = f"{prefix}; {translated_query}" + + with self.connect() as conn, conn.cursor() as cur: + cur.execute(translated_query, translated_params) + return cur.fetchall() diff --git a/timescale_vector/client/utils.py b/timescale_vector/client/utils.py new file mode 100644 index 0000000..7ca7a71 --- /dev/null +++ b/timescale_vector/client/utils.py @@ -0,0 +1,76 @@ +import calendar +import random +import uuid +from datetime import datetime, timezone +from typing import Any + + +# copied from Cassandra: https://docs.datastax.com/en/drivers/python/3.2/_modules/cassandra/util.html#uuid_from_time +def uuid_from_time( + time_arg: float | datetime | None = None, node: Any = None, clock_seq: int | None = None +) -> uuid.UUID: + """ + Converts a datetime or timestamp to a type 1 `uuid.UUID`. + + Parameters + ---------- + time_arg + The time to use for the timestamp portion of the UUID. + This can either be a `datetime` object or a timestamp in seconds + (as returned from `time.time()`). + node + Bytes for the UUID (up to 48 bits). If not specified, this + field is randomized. + clock_seq + Clock sequence field for the UUID (up to 14 bits). If not specified, + a random sequence is generated. + + Returns + ------- + uuid.UUID: For the given time, node, and clock sequence + """ + if time_arg is None: + return uuid.uuid1(node, clock_seq) + if hasattr(time_arg, "utctimetuple"): + # this is different from the Cassandra version, + # we assume that a naive datetime is in system time and convert it to UTC + # we do this because naive datetimes are interpreted as timestamps (without timezone) in postgres + time_arg_dt: datetime = time_arg # type: ignore + if time_arg_dt.tzinfo is None: + time_arg_dt = time_arg_dt.astimezone(timezone.utc) + seconds = int(calendar.timegm(time_arg_dt.utctimetuple())) + microseconds = (seconds * 1e6) + time_arg_dt.time().microsecond + else: + microseconds = int(float(time_arg) * 1e6) + + # 0x01b21dd213814000 is the number of 100-ns intervals between the + # UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00. 
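+    # In decimal that offset is 122192928000000000, the same constant that the SQL
+    # uuid_timestamp() function generated in get_create_query() subtracts back out.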
+ intervals = int(microseconds * 10) + 0x01B21DD213814000 + + time_low = intervals & 0xFFFFFFFF + time_mid = (intervals >> 32) & 0xFFFF + time_hi_version = (intervals >> 48) & 0x0FFF + + if clock_seq is None: + clock_seq = random.getrandbits(14) + else: + if clock_seq > 0x3FFF: + raise ValueError("clock_seq is out of range (need a 14-bit value)") + + clock_seq_low = clock_seq & 0xFF + clock_seq_hi_variant = 0x80 | ((clock_seq >> 8) & 0x3F) + + if node is None: + node = random.getrandbits(48) + + return uuid.UUID( + fields=( + time_low, + time_mid, + time_hi_version, + clock_seq_hi_variant, + clock_seq_low, + node, + ), + version=1, + ) diff --git a/timescale_vector/client/uuid_time_range.py b/timescale_vector/client/uuid_time_range.py new file mode 100644 index 0000000..2c5a5ab --- /dev/null +++ b/timescale_vector/client/uuid_time_range.py @@ -0,0 +1,99 @@ +from datetime import datetime, timedelta, timezone +from typing import Any + + +class UUIDTimeRange: + @staticmethod + def _parse_datetime(input_datetime: datetime | str | None | Any) -> datetime | None: + """ + Parse a datetime object or string representation of a datetime. + + Args: + input_datetime (datetime or str): Input datetime or string. + + Returns: + datetime: Parsed datetime object. + + Raises: + ValueError: If the input cannot be parsed as a datetime. + """ + if input_datetime is None or input_datetime == "None": + return None + + if isinstance(input_datetime, datetime): + # If input is already a datetime object, return it as is + return input_datetime + + if isinstance(input_datetime, str): + try: + # Attempt to parse the input string into a datetime + return datetime.fromisoformat(input_datetime) + except ValueError: + raise ValueError(f"Invalid datetime string format: {input_datetime}") from None + + raise ValueError("Input must be a datetime object or string") + + def __init__( + self, + start_date: datetime | str | None = None, + end_date: datetime | str | None = None, + time_delta: timedelta | None = None, + start_inclusive: bool = True, + end_inclusive: bool = False, + ): + """ + A UUIDTimeRange is a time range predicate on the UUID Version 1 timestamps. + + Note that naive datetime objects are interpreted as local time on the python client side + and converted to UTC before being sent to the database. 
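+
+        An illustrative construction (dates are placeholders)::
+
+            # everything in the first week of 2024 (UTC)
+            tr = UUIDTimeRange(datetime(2024, 1, 1, tzinfo=timezone.utc), time_delta=timedelta(days=7))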
+ """ + start_date = UUIDTimeRange._parse_datetime(start_date) + end_date = UUIDTimeRange._parse_datetime(end_date) + + if start_date is not None and end_date is not None and start_date > end_date: + raise Exception("start_date must be before end_date") + + if start_date is None and end_date is None: + raise Exception("start_date and end_date cannot both be None") + + if start_date is not None and start_date.tzinfo is None: + start_date = start_date.astimezone(timezone.utc) + + if end_date is not None and end_date.tzinfo is None: + end_date = end_date.astimezone(timezone.utc) + + if time_delta is not None: + if end_date is None and start_date is not None: + end_date = start_date + time_delta + elif start_date is None and end_date is not None: + start_date = end_date - time_delta + else: + raise Exception("time_delta, start_date and end_date cannot all be specified at the same time") + + self.start_date: datetime | None = start_date + self.end_date: datetime | None = end_date + self.start_inclusive: bool = start_inclusive + self.end_inclusive: bool = end_inclusive + + def __str__(self) -> str: + start_str = f"[{self.start_date}" if self.start_inclusive else f"({self.start_date}" + end_str = f"{self.end_date}]" if self.end_inclusive else f"{self.end_date})" + + return f"UUIDTimeRange {start_str}, {end_str}" + + def build_query(self, params: list[Any]) -> tuple[str, list[Any]]: + column = "uuid_timestamp(id)" + queries: list[str] = [] + if self.start_date is not None: + if self.start_inclusive: + queries.append(f"{column} >= ${len(params)+1}") + else: + queries.append(f"{column} > ${len(params)+1}") + params.append(self.start_date) + if self.end_date is not None: + if self.end_inclusive: + queries.append(f"{column} <= ${len(params)+1}") + else: + queries.append(f"{column} < ${len(params)+1}") + params.append(self.end_date) + return " AND ".join(queries), params diff --git a/uv.lock b/uv.lock index 3205a96..6e31814 100644 --- a/uv.lock +++ b/uv.lock @@ -728,21 +728,46 @@ wheels = [ ] [[package]] -name = "mypy-extensions" -version = "1.0.0" +name = "mypy" +version = "1.12.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/98/a4/1ab47638b92648243faf97a5aeb6ea83059cc3624972ab6b8d2316078d3f/mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782", size = 4433 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 }, +dependencies = [ + { name = "mypy-extensions" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f9/70/196a3339459fe22296ac9a883bbd998fcaf0db3e8d9a54cf4f53b722cad4/mypy-1.12.0.tar.gz", hash = "sha256:65a22d87e757ccd95cbbf6f7e181e6caa87128255eb2b6be901bb71b26d8a99d", size = 3149879 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/93/6d/9751ed6d77b42a5d704224fbadf6f1a18b5ab655c012d17bc8af819a7f06/mypy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4397081e620dc4dc18e2f124d5e1d2c288194c2c08df6bdb1db31c38cd1fe1ed", size = 11017763 }, + { url = "https://files.pythonhosted.org/packages/74/03/5fa6824555460f74873a414c7f42332c219fdfcfbd63b55b2442794b634b/mypy-1.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:684a9c508a283f324804fea3f0effeb7858eb03f85c4402a967d187f64562469", size = 10181032 }, + { url = "https://files.pythonhosted.org/packages/89/56/20d3136d6904c369422423d267c5ceb312487586cdd81e90bf7e237b67e7/mypy-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cabe4cda2fa5eca7ac94854c6c37039324baaa428ecbf4de4567279e9810f9e", size = 12587243 }, + { url = "https://files.pythonhosted.org/packages/53/cb/64043dec34fbcecaced207b077b8e5041e263da43003cc6309c90bc5e26e/mypy-1.12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:060a07b10e999ac9e7fa249ce2bdcfa9183ca2b70756f3bce9df7a92f78a3c0a", size = 13105170 }, + { url = "https://files.pythonhosted.org/packages/5e/59/e89758d47412ec6bd7a2fd9cae8074b7ffb2acee40456a4efbedd42e2dfd/mypy-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:0eff042d7257f39ba4ca06641d110ca7d2ad98c9c1fb52200fe6b1c865d360ff", size = 9633620 }, + { url = "https://files.pythonhosted.org/packages/21/68/9098b11b5c4371793237c7a2c5e9415ece358bed97bc849e9191d38c66b5/mypy-1.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4b86de37a0da945f6d48cf110d5206c5ed514b1ca2614d7ad652d4bf099c7de7", size = 10940151 }, + { url = "https://files.pythonhosted.org/packages/7c/11/14a4373e5da6636fc4c8475cabe65084ff640528bc6c4f426d9c992736a9/mypy-1.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:20c7c5ce0c1be0b0aea628374e6cf68b420bcc772d85c3c974f675b88e3e6e57", size = 10107645 }, + { url = "https://files.pythonhosted.org/packages/c7/07/b73faeeaadabb5aab23195bfd828392c9a5e21e7b8cdf8369a2546e00ce6/mypy-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a64ee25f05fc2d3d8474985c58042b6759100a475f8237da1f4faf7fcd7e6309", size = 12504561 }, + { url = "https://files.pythonhosted.org/packages/78/70/c35608364f9cdf97c048f0240be4d06d3baadede2767a5fbf60aad7c64f3/mypy-1.12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:faca7ab947c9f457a08dcb8d9a8664fd438080e002b0fa3e41b0535335edcf7f", size = 12983108 }, + { url = "https://files.pythonhosted.org/packages/74/fa/e5b0d4291ed9b94075fe13a0cdd1d9f1ba9d32ea1f8e88aec2ffcd057ac3/mypy-1.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:5bc81701d52cc8767005fdd2a08c19980de9ec61a25dbd2a937dfb1338a826f9", size = 9629293 }, + { url = "https://files.pythonhosted.org/packages/e7/c8/ef6e2a11f0de6cf4359552bf71f07a89f302d27e25bf4c9761649bf1b5a8/mypy-1.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8462655b6694feb1c99e433ea905d46c478041a8b8f0c33f1dab00ae881b2164", size = 11072079 }, + { url = "https://files.pythonhosted.org/packages/61/e7/1f9ba3965c3c445d863290d3f8521a7a726b878784f5ad642e82c038261f/mypy-1.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:923ea66d282d8af9e0f9c21ffc6653643abb95b658c3a8a32dca1eff09c06475", size = 10071930 }, + { url = "https://files.pythonhosted.org/packages/3a/11/c84fb4c3a42ffd460c2a9b27105fbd538ec501e5aa34671fd3d14a1b94ba/mypy-1.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1ebf9e796521f99d61864ed89d1fb2926d9ab6a5fab421e457cd9c7e4dd65aa9", size = 12588227 }, + { url = "https://files.pythonhosted.org/packages/f0/ad/b55d070d2001e47c4c6c7d00b13f8dafb16b74db5a99904a183e3c0a3bd6/mypy-1.12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e478601cc3e3fa9d6734d255a59c7a2e5c2934da4378f3dd1e3411ea8a248642", size = 13037186 }, + { url = 
"https://files.pythonhosted.org/packages/28/c8/5fc9ef8d3ea89490939ecdfea7a84cede31a69534d468c34807941f5a79f/mypy-1.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:c72861b7139a4f738344faa0e150834467521a3fba42dc98264e5aa9507dd601", size = 9727738 }, + { url = "https://files.pythonhosted.org/packages/a6/07/0df1b099a4a725e61782f7d9a34947fc93be688f9dfa011d86e411b2f036/mypy-1.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52b9e1492e47e1790360a43755fa04101a7ac72287b1a53ce817f35899ba0521", size = 11071648 }, + { url = "https://files.pythonhosted.org/packages/9a/60/2a8bdb4f822bcdb0fa4599b83c1ae9f9ab0e10c1bee262dd9c1ff4607b12/mypy-1.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:48d3e37dd7d9403e38fa86c46191de72705166d40b8c9f91a3de77350daa0893", size = 10065760 }, + { url = "https://files.pythonhosted.org/packages/cc/d9/065ec6bd21a0ae14b520574d531dc1aa23fdc30fd276dea25f71945172d2/mypy-1.12.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2f106db5ccb60681b622ac768455743ee0e6a857724d648c9629a9bd2ac3f721", size = 12584005 }, + { url = "https://files.pythonhosted.org/packages/e6/a8/31449fc5698d1a55062614790a885128e3b2a21de0f82a426942a5ae6a00/mypy-1.12.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:233e11b3f73ee1f10efada2e6da0f555b2f3a5316e9d8a4a1224acc10e7181d3", size = 13030941 }, + { url = "https://files.pythonhosted.org/packages/b5/8e/2347814cffccfb52fc02cbe457ae4a3fb5b660c5b361cdf72374266c231b/mypy-1.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:4ae8959c21abcf9d73aa6c74a313c45c0b5a188752bf37dace564e29f06e9c1b", size = 9734383 }, + { url = "https://files.pythonhosted.org/packages/85/fd/2cc64da1ce9fada64b5d023dfbaf763548429145d08c958c78c02876c7f6/mypy-1.12.0-py3-none-any.whl", hash = "sha256:fd313226af375d52e1e36c383f39bf3836e1f192801116b31b090dfcd3ec5266", size = 2645791 }, ] [[package]] -name = "nodeenv" -version = "1.9.1" +name = "mypy-extensions" +version = "1.0.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 } +sdist = { url = "https://files.pythonhosted.org/packages/98/a4/1ab47638b92648243faf97a5aeb6ea83059cc3624972ab6b8d2316078d3f/mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782", size = 4433 } wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 }, + { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 }, ] [[package]] @@ -1099,19 +1124,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/34/19/26bb6bdb9fdad5f0dfce538780814084fb667b4bc37fcb28459c14b8d3b5/pydantic_settings-2.6.0-py3-none-any.whl", hash = "sha256:4a819166f119b74d7f8c765196b165f95cc7487ce58ea27dec8a5a26be0970e0", size = 28578 }, ] -[[package]] -name = "pyright" -version = "1.1.385" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nodeenv" }, - { name = "typing-extensions" }, -] -sdist = { url = 
"https://files.pythonhosted.org/packages/29/ca/3238db97766ecfd6b2758fb50727a0b433e7b1bb6be0de090ed08b291fff/pyright-1.1.385.tar.gz", hash = "sha256:1bf042b8f080441534aa02101dea30f8fc2efa8f7b6f1ab05197c21317f5bfa7", size = 21971 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/39/877484412a1079003a7645375b487bd7c422692f4e5b7c2030dea3e83043/pyright-1.1.385-py3-none-any.whl", hash = "sha256:e5b9a1b8d492e13004d822af94d07d235f2c7c158457293b51ab2214c8c5b375", size = 18579 }, -] - [[package]] name = "pytest" version = "8.3.3" @@ -1456,33 +1468,39 @@ dependencies = [ { name = "python-dotenv" }, ] -[package.optional-dependencies] +[package.dev-dependencies] dev = [ { name = "langchain" }, { name = "langchain-community" }, { name = "langchain-openai" }, + { name = "mypy" }, { name = "pandas" }, - { name = "pyright" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "ruff" }, + { name = "types-psycopg2" }, ] [package.metadata] requires-dist = [ { name = "asyncpg", specifier = ">=0.29.0" }, - { name = "langchain", marker = "extra == 'dev'", specifier = ">=0.3.3" }, - { name = "langchain-community", marker = "extra == 'dev'", specifier = ">=0.3.2" }, - { name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.2" }, { name = "numpy", specifier = ">=1,<2" }, - { name = "pandas", marker = "extra == 'dev'", specifier = ">=2.2.3" }, { name = "pgvector", specifier = ">=0.3.5" }, { name = "psycopg2", specifier = ">=2.9.9" }, - { name = "pyright", marker = "extra == 'dev'", specifier = ">=1.1.384" }, - { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.3" }, - { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24.0" }, { name = "python-dotenv", specifier = ">=1.0.1" }, - { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.6.9" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "langchain", specifier = ">=0.3.3" }, + { name = "langchain-community", specifier = ">=0.3.2" }, + { name = "langchain-openai", specifier = ">=0.2.2" }, + { name = "mypy", specifier = ">=1.12.0" }, + { name = "pandas", specifier = ">=2.2.3" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-asyncio", specifier = ">=0.24.0" }, + { name = "ruff", specifier = ">=0.6.9" }, + { name = "types-psycopg2", specifier = ">=2.9.21.20240819" }, ] [[package]] @@ -1506,6 +1524,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/5d/acf5905c36149bbaec41ccf7f2b68814647347b72075ac0b1fe3022fdc73/tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd", size = 78351 }, ] +[[package]] +name = "types-psycopg2" +version = "2.9.21.20240819" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a4/86/785e0994d69cbc5044f70ff92c35f42b66a295fea5a30a0e4a7c7421e84a/types-psycopg2-2.9.21.20240819.tar.gz", hash = "sha256:4ed6b47464d6374fa64e5e3b234cea0f710e72123a4596d67ab50b7415a84666", size = 21366 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ac/4b/dad5e0fe9565a7f4cf6aa8a342b0deffaf49bb899814fe046c5f03bc8bee/types_psycopg2-2.9.21.20240819-py3-none-any.whl", hash = "sha256:c9192311c27d7ad561eef705f1b2df1074f2cdcf445a98a6a2fcaaaad43278cf", size = 19973 }, +] + [[package]] name = "typing-extensions" version = "4.12.2"