Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed asyncpg to psycopg3 #94

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions autoagora/logs_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass

import asyncpg
import graphql
import psycopg_pool
from psycopg import sql


class LogsDB:
Expand All @@ -16,7 +17,7 @@ class QueryStats:
avg_time: float
stddev_time: float

def __init__(self, pgpool: asyncpg.Pool) -> None:
def __init__(self, pgpool: psycopg_pool.AsyncConnectionPool) -> None:
self.pgpool = pgpool

def return_query_body(self, query):
Expand All @@ -30,9 +31,11 @@ def return_query_body(self, query):
async def get_most_frequent_queries(
self, subgraph_ipfs_hash: str, min_count: int = 100
):
async with self.pgpool.acquire() as connection:
rows = await connection.fetch(
"""

async with self.pgpool.connection() as connection:
rows = await connection.execute(
sql.SQL(
"""
SELECT
query,
count_id,
Expand All @@ -54,22 +57,21 @@ async def get_most_frequent_queries(
FROM
query_logs
WHERE
subgraph = $1
subgraph = {hash}
AND query_time_ms IS NOT NULL
GROUP BY
qhash
HAVING
Count(id) >= $2
Count(id) >= {min_count}
) as query_logs
ON
qhash = hash
ORDER BY
count_id DESC
""",
subgraph_ipfs_hash,
min_count,
"""
).format(hash=subgraph_ipfs_hash, min_count=str(min_count)),
)

rows = await rows.fetchall()
return [
LogsDB.QueryStats(
query=self.return_query_body(row[0])
Expand Down
22 changes: 13 additions & 9 deletions autoagora/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from typing import Dict, Optional

import asyncpg
import psycopg_pool
from prometheus_async.aio.web import start_http_server

from autoagora.config import args, init_config
Expand Down Expand Up @@ -39,15 +39,19 @@ async def allocated_subgraph_watcher():

# Initialize connection pool to PG database
try:
pgpool = await asyncpg.create_pool(
host=args.postgres_host,
database=args.postgres_database,
user=args.postgres_username,
password=args.postgres_password,
port=args.postgres_port,
min_size=1,
max_size=args.postgres_max_connections,
conn_string = (
f"host={args.postgres_host} "
f"dbname={args.postgres_database} "
f"user={args.postgres_username} "
f'password="{args.postgres_password}" '
f"port={args.postgres_port}"
)

pgpool = psycopg_pool.AsyncConnectionPool(
conn_string, min_size=1, max_size=args.postgres_max_connections, open=False
)
await pgpool.open()
await pgpool.wait()
assert pgpool
except:
logging.exception(
Expand Down
4 changes: 2 additions & 2 deletions autoagora/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
from importlib.metadata import version

import asyncpg
import psycopg_pool
from jinja2 import Template

from autoagora.config import args
Expand All @@ -15,7 +15,7 @@
from autoagora.utils.constants import AGORA_ENTRY_TEMPLATE


async def model_builder(subgraph: str, pgpool: asyncpg.Pool) -> str:
async def model_builder(subgraph: str, pgpool: psycopg_pool.AsyncConnectionPool) -> str:
logs_db = LogsDB(pgpool)
most_frequent_queries = await logs_db.get_most_frequent_queries(subgraph)
model = build_template(subgraph, most_frequent_queries)
Expand Down
6 changes: 4 additions & 2 deletions autoagora/price_multiplier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from datetime import datetime, timedelta, timezone
from typing import Tuple

import asyncpg
import psycopg_pool
from autoagora_agents.agent_factory import AgentFactory
from prometheus_client import Gauge

Expand Down Expand Up @@ -36,7 +36,9 @@


async def price_bandit_loop(
subgraph: str, pgpool: asyncpg.Pool, metrics_endpoints: MetricsEndpoints
subgraph: str,
pgpool: psycopg_pool.AsyncConnectionPool,
metrics_endpoints: MetricsEndpoints,
):
try:
# Instantiate environment.
Expand Down
53 changes: 29 additions & 24 deletions autoagora/price_save_state_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from datetime import datetime, timezone
from typing import Optional

import asyncpg
import psycopg_pool
from psycopg import sql


@dataclass
Expand All @@ -16,13 +17,13 @@ class SaveState:


class PriceSaveStateDB:
def __init__(self, pgpool: asyncpg.Pool) -> None:
def __init__(self, pgpool: psycopg_pool.AsyncConnectionPool) -> None:
self.pgpool = pgpool
self._table_created = False

async def _create_table_if_not_exists(self) -> None:
if not self._table_created:
async with self.pgpool.acquire() as connection:
async with self.pgpool.connection() as connection:
await connection.execute( # type: ignore
"""
CREATE TABLE IF NOT EXISTS price_save_state (
Expand All @@ -38,45 +39,49 @@ async def _create_table_if_not_exists(self) -> None:
async def save_state(self, subgraph: str, mean: float, stddev: float):
await self._create_table_if_not_exists()

async with self.pgpool.acquire() as connection:
async with self.pgpool.connection() as connection:
await connection.execute(
"""
sql.SQL(
"""
INSERT INTO price_save_state (subgraph, last_update, mean, stddev)
VALUES($1, $2, $3, $4)
VALUES({subgraph_hash}, {datetime}, {mean}, {stddev})
ON CONFLICT (subgraph)
DO
UPDATE SET
last_update = $2,
mean = $3,
stddev = $4
""",
subgraph,
datetime.now(timezone.utc),
mean,
stddev,
last_update = {datetime},
mean = {mean},
stddev = {stddev}
"""
).format(
subgraph_hash=subgraph,
datetime=str(datetime.now(timezone.utc)),
mean=mean,
stddev=stddev,
)
)

async def load_state(self, subgraph: str) -> Optional[SaveState]:
await self._create_table_if_not_exists()

async with self.pgpool.acquire() as connection:
row = await connection.fetchrow(
"""
async with self.pgpool.connection() as connection:
row = await connection.execute(
sql.SQL(
"""
SELECT
last_update,
mean,
stddev
FROM
price_save_state
WHERE
subgraph = $1
""",
subgraph,
subgraph = {subgraph_hash}
"""
).format(subgraph_hash=subgraph)
)

row = await row.fetchone()
if row:
return SaveState(
last_update=row["last_update"], # type: ignore
mean=row["mean"], # type: ignore
stddev=row["stddev"], # type: ignore
last_update=row[0], # type: ignore
mean=row[1], # type: ignore
stddev=row[2], # type: ignore
)
Loading
Loading