diff --git a/.github/workflows/ruff.yaml b/.github/workflows/ruff.yaml
new file mode 100644
index 0000000..85236a9
--- /dev/null
+++ b/.github/workflows/ruff.yaml
@@ -0,0 +1,27 @@
+name: Ruff Linting and Formatting
+
+on:
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  ruff:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Install uv
+        run: pip install uv
+      - name: Create venv
+        run: uv venv
+      - name: Install dependencies
+        run: |
+          uv pip install .
+          uv pip install .[dev]
+      - name: Run Ruff linter
+        run: uv run ruff check .
+      - name: Run Ruff formatter
+        run: uv run ruff format . --check
\ No newline at end of file
diff --git a/timescale_vector/client.py b/timescale_vector/client.py
index a4e118b..9e56453 100644
--- a/timescale_vector/client.py
+++ b/timescale_vector/client.py
@@ -24,13 +24,18 @@
 import json
 import math
 import random
+import re
 import uuid
 from collections.abc import Callable, Iterable
+from contextlib import contextmanager
 from datetime import datetime, timedelta, timezone
 from typing import Any, Union

 import asyncpg
 import numpy as np
+import pgvector.psycopg2
+import psycopg2.extras
+import psycopg2.pool
 from pgvector.asyncpg import register_vector

@@ -409,7 +414,7 @@ class Predicates:
         "@>": "@>",  # array contains
     }

-    PredicateValue = Union[str, int, float, datetime, list, tuple]
+    PredicateValue = str | int | float | datetime | list | tuple

     def __init__(
         self,
@@ -423,14 +428,17 @@ def __init__(
         operator: str = "AND",
     ):
         """
-        Predicates class defines predicates on the object metadata. Predicates can be combined using logical operators (&, |, and ~).
+        Predicates class defines predicates on the object metadata.
+        Predicates can be combined using logical operators (&, |, and ~).

         Parameters
         ----------
         clauses
-            Predicate clauses. Can be either another Predicates object or a tuple of the form (field, operator, value) or (field, value).
+            Predicate clauses. Can be either another Predicates object
+            or a tuple of the form (field, operator, value) or (field, value).
         Operator
-            Logical operator to use when combining the clauses. Can be one of 'AND', 'OR', 'NOT'. Defaults to 'AND'.
+            Logical operator to use when combining the clauses.
+            Can be one of 'AND', 'OR', 'NOT'. Defaults to 'AND'.
         """
         if operator not in self.logical_operators:
             raise ValueError(f"invalid operator: {operator}")
@@ -458,7 +466,8 @@ def add_clause(
         Parameters
         ----------
         clause: 'Predicates' or Tuple[str, str] or Tuple[str, str, str]
-            Predicate clause. Can be either another Predicates object or a tuple of the form (field, operator, value) or (field, value).
+            Predicate clause. Can be either another Predicates object or a tuple of the form (field, operator, value)
+            or (field, value).
         """
         if isinstance(clause[0], str):
             if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)):
@@ -527,7 +536,7 @@ def build_query(self, params: list) -> tuple[str, list]:
                 where_conditions.append(f"uuid_timestamp(id) {operator} {param_name}")
                 params.append(value)

-            elif operator == "@>" and (isinstance(value, list) or isinstance(value, tuple)):
+            elif operator == "@>" and (isinstance(value, list | tuple)):
                 if len(value) == 0:
                     raise ValueError("Invalid value. Empty lists and empty tuples are not supported.")
                 json_value = json.dumps(value)
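As context for reviewers, the tuple clause form that the Predicates docstring above describes composes with `&`, `|`, and `~`; the field names and values below are made up:

```python
from timescale_vector import client

# Clauses are (field, operator, value) tuples, per the docstring above.
seen_often = client.Predicates(("times", ">", 10))
by_jane = client.Predicates(("author", "==", "jane"))
not_draft = ~client.Predicates(("status", "==", "draft"))
combined = (seen_often & by_jane) | not_draft
```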
""" - return f"INSERT INTO {self._quoted_table_name()} (id, metadata, contents, embedding) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING" + return ( + f"INSERT INTO {self._quoted_table_name()} (id, metadata, contents, embedding) " + f"VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING" + ) def get_approx_count_query(self): """ @@ -700,10 +712,10 @@ def get_create_query(self): IMMUTABLE PARALLEL SAFE RETURNS NULL ON NULL INPUT; - SELECT create_hypertable('{self._quoted_table_name()}', - 'id', - if_not_exists=> true, - time_partitioning_func=>'public.uuid_timestamp', + SELECT create_hypertable('{self._quoted_table_name()}', + 'id', + if_not_exists=> true, + time_partitioning_func=>'public.uuid_timestamp', chunk_time_interval => '{str(self.time_partition_interval.total_seconds())} seconds'::interval); """ return """ @@ -759,10 +771,14 @@ def drop_table_query(self): def default_max_db_connection_query(self): """ - Generates a query to get the default max db connections. This uses a heuristic to determine the max connections based on the max_connections setting in postgres + Generates a query to get the default max db connections. This uses a heuristic to determine the max connections + based on the max_connections setting in postgres and the number of currently used connections. This heuristic leaves 4 connections in reserve. """ - return "SELECT greatest(1, ((SELECT setting::int FROM pg_settings WHERE name='max_connections')-(SELECT count(*) FROM pg_stat_activity) - 4)::int)" + return ( + "SELECT greatest(1, ((SELECT setting::int FROM pg_settings " + "WHERE name='max_connections')-(SELECT count(*) FROM pg_stat_activity) - 4)::int)" + ) def create_embedding_index_query(self, index: BaseIndex, num_records_callback: Callable[[], int]) -> str: """ @@ -793,8 +809,8 @@ def create_embedding_index_query(self, index: BaseIndex, num_records_callback: C def _where_clause_for_filter( self, params: list, filter: dict[str, str] | list[dict[str, str]] | None ) -> tuple[str, list]: - if filter == None: - return ("TRUE", params) + if filter is None: + return "TRUE", params if isinstance(filter, dict): where = f"metadata @> ${len(params)+1}" @@ -802,14 +818,14 @@ def _where_clause_for_filter( params = params + [json_object] elif isinstance(filter, list): any_params = [] - for idx, filter_dict in enumerate(filter, start=len(params) + 1): + for _idx, filter_dict in enumerate(filter, start=len(params) + 1): any_params.append(json.dumps(filter_dict)) where = f"metadata @> ANY(${len(params) + 1}::jsonb[])" params = params + [any_params] else: raise ValueError(f"Unknown filter type: {type(filter)}") - return (where, params) + return where, params def search_query( self, @@ -834,18 +850,21 @@ def search_query( distance = "-1.0" order_by_clause = "" - if self.infer_filters: - if uuid_time_filter is None and isinstance(filter, dict): - if "__start_date" in filter or "__end_date" in filter: - start_date = UUIDTimeRange._parse_datetime(filter.get("__start_date")) - end_date = UUIDTimeRange._parse_datetime(filter.get("__end_date")) + if ( + self.infer_filters + and uuid_time_filter is None + and isinstance(filter, dict) + and ("__start_date" in filter or "__end_date" in filter) + ): + start_date = UUIDTimeRange._parse_datetime(filter.get("__start_date")) + end_date = UUIDTimeRange._parse_datetime(filter.get("__end_date")) - uuid_time_filter = UUIDTimeRange(start_date, end_date) + uuid_time_filter = UUIDTimeRange(start_date, end_date) - if start_date is not None: - del filter["__start_date"] - if end_date is not None: - del 
filter["__end_date"] + if start_date is not None: + del filter["__start_date"] + if end_date is not None: + del filter["__end_date"] where_clauses = [] if filter is not None: @@ -863,17 +882,14 @@ def search_query( (where_time, params) = uuid_time_filter.build_query(params) where_clauses.append(where_time) - if len(where_clauses) > 0: - where = " AND ".join(where_clauses) - else: - where = "TRUE" + where = " AND ".join(where_clauses) if len(where_clauses) > 0 else "TRUE" query = f""" SELECT id, metadata, contents, embedding, {distance} as distance FROM {self._quoted_table_name()} - WHERE + WHERE {where} {order_by_clause} LIMIT {limit} @@ -952,8 +968,8 @@ async def connect(self): ------- asyncpg.Connection: The established database connection. """ - if self.pool == None: - if self.max_db_connections == None: + if self.pool is None: + if self.max_db_connections is None: self.max_db_connections = await self._default_max_db_connections() async def init(conn): @@ -970,7 +986,7 @@ async def init(conn): return self.pool.acquire() async def close(self): - if self.pool != None: + if self.pool is not None: await self.pool.close() async def table_is_empty(self): @@ -984,7 +1000,7 @@ async def table_is_empty(self): query = self.builder.get_row_exists_query() async with await self.connect() as pool: rec = await pool.fetchrow(query) - return rec == None + return rec is None def munge_record(self, records) -> Iterable[tuple[uuid.UUID, str, str, list[float]]]: metadata_is_dict = isinstance(records[0][1], dict) @@ -993,10 +1009,11 @@ def munge_record(self, records) -> Iterable[tuple[uuid.UUID, str, str, list[floa return records + @staticmethod def _convert_record_meta_to_json(item): if not isinstance(item[1], dict): raise ValueError("Cannot mix dictionary and string metadata fields in the same upsert") - return (item[0], json.dumps(item[1]), item[2], item[3]) + return item[0], json.dumps(item[1]), item[2], item[3] async def upsert(self, records): """ @@ -1025,7 +1042,8 @@ async def create_tables(self): None """ query = self.builder.get_create_query() - # don't use a connection pool for this because the vector extension may not be installed yet and if it's not installed, register_vector will fail. + # don't use a connection pool for this because the vector extension may not be installed yet + # and if it's not installed, register_vector will fail. conn = await asyncpg.connect(dsn=self.service_url) await conn.execute(query) await conn.close() @@ -1136,9 +1154,11 @@ async def search( limit The number of nearest neighbors to retrieve. filter - A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched). + A filter for metadata. Should be specified as a key-value object or a list of key-value objects + (where any objects in the list are matched). predicates - A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, |, and ~). + A Predicates object to filter the results. Predicates support more complex queries than the filter + parameter. Predicates can be combined using logical operators (&, |, and ~). uuid_time_filter A UUIDTimeRange object to filter the results by time using the id column. 
@@ -1136,9 +1154,11 @@ async def search(
         limit
             The number of nearest neighbors to retrieve.
         filter
-            A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched).
+            A filter for metadata. Should be specified as a key-value object or a list of key-value objects
+            (where any objects in the list are matched).
         predicates
-            A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, |, and ~).
+            A Predicates object to filter the results. Predicates support more complex queries than the filter
+            parameter. Predicates can be combined using logical operators (&, |, and ~).
         uuid_time_filter
             A UUIDTimeRange object to filter the results by time using the id column.
         query_params
@@ -1149,27 +1169,17 @@
         """
         (query, params) = self.builder.search_query(query_embedding, limit, filter, predicates, uuid_time_filter)
         if query_params is not None:
-            async with await self.connect() as pool:
-                async with pool.transaction():
-                    # Looks like there is no way to pipeline this: https://github.com/MagicStack/asyncpg/issues/588
-                    statements = query_params.get_statements()
-                    for statement in statements:
-                        await pool.execute(statement)
-                    return await pool.fetch(query, *params)
+            async with await self.connect() as pool, pool.transaction():
+                # Looks like there is no way to pipeline this: https://github.com/MagicStack/asyncpg/issues/588
+                statements = query_params.get_statements()
+                for statement in statements:
+                    await pool.execute(statement)
+                return await pool.fetch(query, *params)
         else:
             async with await self.connect() as pool:
                 return await pool.fetch(query, *params)

-import re
-from contextlib import contextmanager
-
-import numpy as np
-import pgvector.psycopg2
-import psycopg2.extras
-import psycopg2.pool
-
-
 class Sync:
     translated_queries: dict[str, str] = {}

@@ -1225,10 +1235,6 @@ def __init__(
     def default_max_db_connections(self):
         """
         Gets a default value for the number of max db connections to use.
-
-        Returns
-        -------
-        None
         """
         query = self.builder.default_max_db_connection_query()
         conn = psycopg2.connect(dsn=self.service_url)
@@ -1244,8 +1250,8 @@ def connect(self):
         Establishes a connection to a PostgreSQL database using psycopg2 and allows it's use in a context manager.
         """
-        if self.pool == None:
-            if self.max_db_connections == None:
+        if self.pool is None:
+            if self.max_db_connections is None:
                 self.max_db_connections = self.default_max_db_connections()

             self.pool = psycopg2.pool.SimpleConnectionPool(
@@ -1264,7 +1270,7 @@ def connect(self):
             self.pool.putconn(connection)

     def close(self):
-        if self.pool != None:
+        if self.pool is not None:
             self.pool.closeall()

     def _translate_to_pyformat(self, query_string, params):
@@ -1273,7 +1279,7 @@ def _translate_to_pyformat(self, query_string, params):
         Args:
             query_string (str): The query string with parameters.
-            params (list): List of parameter values.
+            params (list|None): List of parameter values.

         Returns:
             str: The query string with translated pyformat parameters.
@@ -1281,7 +1287,7 @@
         """
         translated_params = {}
-        if params != None:
+        if params is not None:
             for idx, param in enumerate(params):
                 translated_params[str(idx + 1)] = param

@@ -1293,10 +1299,7 @@
         for dollar_param in dollar_params:
             # Extract the number after the $
             param_number = int(dollar_param[1:])
-            if params != None:
-                pyformat_param = "%s" if param_number == 0 else f"%({param_number})s"
-            else:
-                pyformat_param = "%s"
+            pyformat_param = ("%s" if param_number == 0 else f"%({param_number})s") if params is not None else "%s"
             translated_string = translated_string.replace(dollar_param, pyformat_param)

         self.translated_queries[query_string] = translated_string
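The $n-to-pyformat rewrite can be verified in isolation; this mirrors the regex approach in the hunk above, minus the per-instance query cache:

```python
import re

def translate_to_pyformat(query: str, params: list | None):
    # $1, $2, ... -> %(1)s, %(2)s (or bare %s when there are no params),
    # mirroring Sync._translate_to_pyformat above.
    translated_params = {str(i + 1): p for i, p in enumerate(params)} if params is not None else {}
    for dollar_param in re.findall(r"\$[0-9]+", query):
        number = int(dollar_param[1:])
        pyformat = ("%s" if number == 0 else f"%({number})s") if params is not None else "%s"
        query = query.replace(dollar_param, pyformat)
    return query, translated_params

print(translate_to_pyformat("SELECT * FROM t WHERE id = $1", [42]))
# ('SELECT * FROM t WHERE id = %(1)s', {'1': 42})
```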
""" query = self.builder.get_row_exists_query() - with self.connect() as conn: - with conn.cursor() as cur: - cur.execute(query) - rec = cur.fetchone() - return rec == None + with self.connect() as conn, conn.cursor() as cur: + cur.execute(query) + rec = cur.fetchone() + return rec is None def munge_record(self, records) -> Iterable[tuple[uuid.UUID, str, str, list[float]]]: metadata_is_dict = isinstance(records[0][1], dict) @@ -1324,10 +1326,11 @@ def munge_record(self, records) -> Iterable[tuple[uuid.UUID, str, str, list[floa return records + @staticmethod def _convert_record_meta_to_json(item): if not isinstance(item[1], dict): raise ValueError("Cannot mix dictionary and string metadata fields in the same upsert") - return (item[0], json.dumps(item[1]), item[2], item[3]) + return item[0], json.dumps(item[1]), item[2], item[3] def upsert(self, records): """ @@ -1345,9 +1348,8 @@ def upsert(self, records): records = self.munge_record(records) query = self.builder.get_upsert_query() query, _ = self._translate_to_pyformat(query, None) - with self.connect() as conn: - with conn.cursor() as cur: - cur.executemany(query, records) + with self.connect() as conn, conn.cursor() as cur: + cur.executemany(query, records) def create_tables(self): """ @@ -1358,7 +1360,8 @@ def create_tables(self): None """ query = self.builder.get_create_query() - # don't use a connection pool for this because the vector extension may not be installed yet and if it's not installed, register_vector will fail. + # don't use a connection pool for this because the vector extension may not be installed yet + # and if it's not installed, register_vector will fail. conn = psycopg2.connect(dsn=self.service_url) with conn.cursor() as cur: cur.execute(query) @@ -1376,9 +1379,8 @@ def delete_all(self, drop_index=True): if drop_index: self.drop_embedding_index() query = self.builder.delete_all_query() - with self.connect() as conn: - with conn.cursor() as cur: - cur.execute(query) + with self.connect() as conn, conn.cursor() as cur: + cur.execute(query) def delete_by_ids(self, ids: list[uuid.UUID] | list[str]): """ @@ -1391,9 +1393,8 @@ def delete_by_ids(self, ids: list[uuid.UUID] | list[str]): """ (query, params) = self.builder.delete_by_ids_query(ids) query, params = self._translate_to_pyformat(query, params) - with self.connect() as conn: - with conn.cursor() as cur: - cur.execute(query, params) + with self.connect() as conn, conn.cursor() as cur: + cur.execute(query, params) def delete_by_metadata(self, filter: dict[str, str] | list[dict[str, str]]): """ @@ -1401,9 +1402,8 @@ def delete_by_metadata(self, filter: dict[str, str] | list[dict[str, str]]): """ (query, params) = self.builder.delete_by_metadata_query(filter) query, params = self._translate_to_pyformat(query, params) - with self.connect() as conn: - with conn.cursor() as cur: - cur.execute(query, params) + with self.connect() as conn, conn.cursor() as cur: + cur.execute(query, params) def drop_table(self): """ @@ -1414,9 +1414,8 @@ def drop_table(self): None """ query = self.builder.drop_table_query() - with self.connect() as conn: - with conn.cursor() as cur: - cur.execute(query) + with self.connect() as conn, conn.cursor() as cur: + cur.execute(query) def _get_approx_count(self): """ @@ -1427,11 +1426,10 @@ def _get_approx_count(self): int: Approximate count of records. 
""" query = self.builder.get_approx_count_query() - with self.connect() as conn: - with conn.cursor() as cur: - cur.execute(query) - rec = cur.fetchone() - return rec[0] + with self.connect() as conn, conn.cursor() as cur: + cur.execute(query) + rec = cur.fetchone() + return rec[0] def drop_embedding_index(self): """ @@ -1442,9 +1440,8 @@ def drop_embedding_index(self): None """ query = self.builder.drop_embedding_index_query() - with self.connect() as conn: - with conn.cursor() as cur: - cur.execute(query) + with self.connect() as conn, conn.cursor() as cur: + cur.execute(query) def create_embedding_index(self, index: BaseIndex): """ @@ -1460,9 +1457,8 @@ def create_embedding_index(self, index: BaseIndex): None """ query = self.builder.create_embedding_index_query(index, lambda: self._get_approx_count()) - with self.connect() as conn: - with conn.cursor() as cur: - cur.execute(query) + with self.connect() as conn, conn.cursor() as cur: + cur.execute(query) def search( self, @@ -1483,18 +1479,17 @@ def search( limit The number of nearest neighbors to retrieve. filter - A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched). + A filter for metadata. Should be specified as a key-value object or a list of key-value objects + (where any objects in the list are matched). predicates - A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, |, and ~). + A Predicates object to filter the results. Predicates support more complex queries + than the filter parameter. Predicates can be combined using logical operators (&, |, and ~). Returns -------- List: List of similar records. 
""" - if query_embedding is not None: - query_embedding_np = np.array(query_embedding) - else: - query_embedding_np = None + query_embedding_np = np.array(query_embedding) if query_embedding is not None else None (query, params) = self.builder.search_query(query_embedding_np, limit, filter, predicates, uuid_time_filter) query, params = self._translate_to_pyformat(query, params) @@ -1503,7 +1498,6 @@ def search( prefix = "; ".join(query_params.get_statements()) query = f"{prefix}; {query}" - with self.connect() as conn: - with conn.cursor() as cur: - cur.execute(query, params) - return cur.fetchall() + with self.connect() as conn, conn.cursor() as cur: + cur.execute(query, params) + return cur.fetchall() diff --git a/timescale_vector/pgvectorizer.py b/timescale_vector/pgvectorizer.py index 6ef3d3a..8f2675d 100644 --- a/timescale_vector/pgvectorizer.py +++ b/timescale_vector/pgvectorizer.py @@ -42,76 +42,80 @@ def __init__( self.trigger_name_fn = client.QueryBuilder._quote_ident(trigger_name_fn) def register(self): - with psycopg2.connect(self.service_url) as conn: - with conn.cursor() as cursor: - cursor.execute(f""" - SELECT to_regclass('{self.schema_name}.{self.work_queue_table_name}') is not null; - """) - table_exists = cursor.fetchone()[0] - if table_exists: - return - - cursor.execute(f""" - CREATE TABLE {self.schema_name}.{self.work_queue_table_name} ( - id int - ); - - CREATE INDEX ON {self.schema_name}.{self.work_queue_table_name}(id); - - CREATE OR REPLACE FUNCTION {self.schema_name}.{self.trigger_name_fn}() RETURNS TRIGGER LANGUAGE PLPGSQL AS $$ - BEGIN - IF (TG_OP = 'DELETE') THEN - INSERT INTO {self.work_queue_table_name} - VALUES (OLD.{self.id_column_name}); - ELSE - INSERT INTO {self.work_queue_table_name} - VALUES (NEW.{self.id_column_name}); - END IF; - RETURN NULL; - END; - $$; - - CREATE TRIGGER {self.trigger_name} - AFTER INSERT OR UPDATE OR DELETE - ON {self.schema_name}.{self.table_name} - FOR EACH ROW EXECUTE PROCEDURE {self.schema_name}.{self.trigger_name_fn}(); - - INSERT INTO {self.schema_name}.{self.work_queue_table_name} SELECT {self.id_column_name} FROM {self.schema_name}.{self.table_name}; - """) + with psycopg2.connect(self.service_url) as conn, conn.cursor() as cursor: + cursor.execute(f""" + SELECT to_regclass('{self.schema_name}.{self.work_queue_table_name}') is not null; + """) + table_exists = cursor.fetchone()[0] + if table_exists: + return + + cursor.execute(f""" + CREATE TABLE {self.schema_name}.{self.work_queue_table_name} ( + id int + ); + + CREATE INDEX ON {self.schema_name}.{self.work_queue_table_name}(id); + + CREATE OR REPLACE FUNCTION {self.schema_name}.{self.trigger_name_fn}() + RETURNS TRIGGER LANGUAGE PLPGSQL AS $$ + BEGIN + IF (TG_OP = 'DELETE') THEN + INSERT INTO {self.work_queue_table_name} + VALUES (OLD.{self.id_column_name}); + ELSE + INSERT INTO {self.work_queue_table_name} + VALUES (NEW.{self.id_column_name}); + END IF; + RETURN NULL; + END; + $$; + + CREATE TRIGGER {self.trigger_name} + AFTER INSERT OR UPDATE OR DELETE + ON {self.schema_name}.{self.table_name} + FOR EACH ROW EXECUTE PROCEDURE {self.schema_name}.{self.trigger_name_fn}(); + + INSERT INTO {self.schema_name}.{self.work_queue_table_name} SELECT {self.id_column_name} + FROM {self.schema_name}.{self.table_name}; + """) def process(self, embed_and_write_cb, batch_size: int = 10, autoregister=True): if autoregister: self.register() - with psycopg2.connect(self.service_url) as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor: - cursor.execute(f""" - 
diff --git a/timescale_vector/pgvectorizer.py b/timescale_vector/pgvectorizer.py
index 6ef3d3a..8f2675d 100644
--- a/timescale_vector/pgvectorizer.py
+++ b/timescale_vector/pgvectorizer.py
@@ -42,76 +42,80 @@ def __init__(
         self.trigger_name_fn = client.QueryBuilder._quote_ident(trigger_name_fn)

     def register(self):
-        with psycopg2.connect(self.service_url) as conn:
-            with conn.cursor() as cursor:
-                cursor.execute(f"""
-                SELECT to_regclass('{self.schema_name}.{self.work_queue_table_name}') is not null;
-                """)
-                table_exists = cursor.fetchone()[0]
-                if table_exists:
-                    return
-
-                cursor.execute(f"""
-                CREATE TABLE {self.schema_name}.{self.work_queue_table_name} (
-                    id int
-                );
-
-                CREATE INDEX ON {self.schema_name}.{self.work_queue_table_name}(id);
-
-                CREATE OR REPLACE FUNCTION {self.schema_name}.{self.trigger_name_fn}() RETURNS TRIGGER LANGUAGE PLPGSQL AS $$
-                BEGIN
-                    IF (TG_OP = 'DELETE') THEN
-                        INSERT INTO {self.work_queue_table_name}
-                        VALUES (OLD.{self.id_column_name});
-                    ELSE
-                        INSERT INTO {self.work_queue_table_name}
-                        VALUES (NEW.{self.id_column_name});
-                    END IF;
-                    RETURN NULL;
-                END;
-                $$;
-
-                CREATE TRIGGER {self.trigger_name}
-                AFTER INSERT OR UPDATE OR DELETE
-                ON {self.schema_name}.{self.table_name}
-                FOR EACH ROW EXECUTE PROCEDURE {self.schema_name}.{self.trigger_name_fn}();
-
-                INSERT INTO {self.schema_name}.{self.work_queue_table_name} SELECT {self.id_column_name} FROM {self.schema_name}.{self.table_name};
-                """)
+        with psycopg2.connect(self.service_url) as conn, conn.cursor() as cursor:
+            cursor.execute(f"""
+            SELECT to_regclass('{self.schema_name}.{self.work_queue_table_name}') is not null;
+            """)
+            table_exists = cursor.fetchone()[0]
+            if table_exists:
+                return
+
+            cursor.execute(f"""
+            CREATE TABLE {self.schema_name}.{self.work_queue_table_name} (
+                id int
+            );
+
+            CREATE INDEX ON {self.schema_name}.{self.work_queue_table_name}(id);
+
+            CREATE OR REPLACE FUNCTION {self.schema_name}.{self.trigger_name_fn}()
+            RETURNS TRIGGER LANGUAGE PLPGSQL AS $$
+            BEGIN
+                IF (TG_OP = 'DELETE') THEN
+                    INSERT INTO {self.work_queue_table_name}
+                    VALUES (OLD.{self.id_column_name});
+                ELSE
+                    INSERT INTO {self.work_queue_table_name}
+                    VALUES (NEW.{self.id_column_name});
+                END IF;
+                RETURN NULL;
+            END;
+            $$;
+
+            CREATE TRIGGER {self.trigger_name}
+            AFTER INSERT OR UPDATE OR DELETE
+            ON {self.schema_name}.{self.table_name}
+            FOR EACH ROW EXECUTE PROCEDURE {self.schema_name}.{self.trigger_name_fn}();
+
+            INSERT INTO {self.schema_name}.{self.work_queue_table_name} SELECT {self.id_column_name}
+            FROM {self.schema_name}.{self.table_name};
+            """)

     def process(self, embed_and_write_cb, batch_size: int = 10, autoregister=True):
         if autoregister:
             self.register()

-        with psycopg2.connect(self.service_url) as conn:
-            with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
-                cursor.execute(f"""
-                SELECT to_regclass('{self.schema_name}.{self.work_queue_table_name}')::oid;
-                """)
-                table_oid = cursor.fetchone()[0]
-
-                cursor.execute(f"""
-                WITH selected_rows AS (
-                    SELECT id
-                    FROM {self.schema_name}.{self.work_queue_table_name}
-                    LIMIT {int(batch_size)}
-                    FOR UPDATE SKIP LOCKED
-                ),
-                locked_items AS (
-                    SELECT id, pg_try_advisory_xact_lock({int(table_oid)}, id) AS locked
-                    FROM (SELECT DISTINCT id FROM selected_rows ORDER BY id) as ids
-                ),
-                deleted_rows AS (
-                    DELETE FROM {self.schema_name}.{self.work_queue_table_name}
-                    WHERE id IN (SELECT id FROM locked_items WHERE locked = true ORDER BY id)
-                )
-                SELECT locked_items.id as locked_id, {self.table_name}.*
-                FROM locked_items
-                LEFT JOIN {self.schema_name}.{self.table_name} ON {self.table_name}.{self.id_column_name} = locked_items.id
-                WHERE locked = true
-                ORDER BY locked_items.id
-                """)
-                res = cursor.fetchall()
-                if len(res) > 0:
-                    embed_and_write_cb(res, self)
-                return len(res)
+        with (
+            psycopg2.connect(self.service_url) as conn,
+            conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor,
+        ):
+            cursor.execute(f"""
+            SELECT to_regclass('{self.schema_name}.{self.work_queue_table_name}')::oid;
+            """)
+            table_oid = cursor.fetchone()[0]
+
+            cursor.execute(f"""
+            WITH selected_rows AS (
+                SELECT id
+                FROM {self.schema_name}.{self.work_queue_table_name}
+                LIMIT {int(batch_size)}
+                FOR UPDATE SKIP LOCKED
+            ),
+            locked_items AS (
+                SELECT id, pg_try_advisory_xact_lock({int(table_oid)}, id) AS locked
+                FROM (SELECT DISTINCT id FROM selected_rows ORDER BY id) as ids
+            ),
+            deleted_rows AS (
+                DELETE FROM {self.schema_name}.{self.work_queue_table_name}
+                WHERE id IN (SELECT id FROM locked_items WHERE locked = true ORDER BY id)
+            )
+            SELECT locked_items.id as locked_id, {self.table_name}.*
+            FROM locked_items
+            LEFT JOIN {self.schema_name}.{self.table_name}
+                ON {self.table_name}.{self.id_column_name} = locked_items.id
+            WHERE locked = true
+            ORDER BY locked_items.id
+            """)
+            res = cursor.fetchall()
+            if len(res) > 0:
+                embed_and_write_cb(res, self)
+            return len(res)
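For context on how this queue processor is driven: a hedged sketch, assuming the module's Vectorize class and a source table named blog with an id column (both assumptions, not shown in this diff):

```python
from timescale_vector.pgvectorizer import Vectorize

def embed_and_write(rows, vectorizer):
    for row in rows:
        # locked_id is always present; the source columns are NULL when the
        # row was deleted, because of the LEFT JOIN in the query above.
        if row["id"] is None:
            ...  # drop embeddings for row["locked_id"]
        else:
            ...  # embed the row and upsert it

vectorizer = Vectorize("postgres://user:pass@host:5432/db", "blog")
while vectorizer.process(embed_and_write) > 0:
    pass  # drain the work queue in batches (batch_size defaults to 10)
```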