Commit

Merge branch 'main' into tweak-et-dashboard

huydhn committed Jan 31, 2025
2 parents a63570f + 063abf8 commit 390d469
Showing 110 changed files with 1,220 additions and 476 deletions.
11 changes: 11 additions & 0 deletions .lintrunner.toml
@@ -381,6 +381,17 @@ init_command = [
]
is_formatter = true

[[linter]]
code = 'SQL_PARAMS'
include_patterns = ['torchci/clickhouse_queries/**/params.json']
exclude_patterns = []
command = [
'python3',
'tools/linter/adapters/sql_params_linter.py',
'@{{PATHSFILE}}',
]
is_formatter = false

[[linter]]
code = 'RUSTFMT'
include_patterns = ['**/*.rs']
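For reference, a minimal params.json in the shape this linter accepts might look like the sketch below. The parameter names and type-string values are illustrative assumptions; the linter only checks that "params" is a dict and "tests" is a list. The {"from_now": ...} relative-time syntax is the one expanded by tools/torchci/clickhouse_query_perf.py further down.

{
    "params": {
        "startTime": "DateTime64(3)",
        "stopTime": "DateTime64(3)"
    },
    "tests": [
        {"startTime": {"from_now": -7}, "stopTime": {"from_now": 0}}
    ]
}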
129 changes: 129 additions & 0 deletions tools/linter/adapters/sql_params_linter.py
@@ -0,0 +1,129 @@
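"""
Lintrunner adapter that validates the params.json file accompanying each saved
ClickHouse query: the file must contain a "params" dict and a "tests" list.
"""
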
import argparse
import concurrent.futures
import json
import logging
import os
import re
import subprocess
import time
from enum import Enum
from typing import List, NamedTuple, Optional, Pattern


LINTER_CODE = "SQL_PARAMS"


class LintSeverity(str, Enum):
ERROR = "error"
WARNING = "warning"
ADVICE = "advice"
DISABLED = "disabled"


class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]


RESULTS_RE: Pattern[str] = re.compile(
r"""(?mx)
^
(?P<file>.*?):
(?P<line>\d+):
(?P<char>\d+):
\s(?P<message>.*)
\s(?P<code>\[.*\])
$
"""
)


def run_command(
args: List[str],
) -> "subprocess.CompletedProcess[bytes]":
logging.debug("$ %s", " ".join(args))
start_time = time.monotonic()
try:
return subprocess.run(
args,
capture_output=True,
)
finally:
end_time = time.monotonic()
logging.debug("took %dms", (end_time - start_time) * 1000)


def check_file(
filename: str,
) -> List[LintMessage]:
with open(filename, "rb") as f:
data = json.load(f)

    messages = []
    if "params" not in data:
        messages.append("The file does not contain a 'params' key.")
    elif not isinstance(data["params"], dict):
        messages.append("The 'params' key is not a dictionary.")
    if "tests" not in data:
        messages.append("The file does not contain a 'tests' key.")
    elif not isinstance(data["tests"], list):
        messages.append("The 'tests' key is not a list.")
    if messages:
return [
LintMessage(
path=filename,
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.WARNING,
name="lint",
replacement=None,
original=None,
                description="; ".join(messages),
)
]
return []


def main() -> None:
parser = argparse.ArgumentParser(
        description="A simple linter for params.json files for SQL queries",
fromfile_prefix_chars="@",
)
parser.add_argument(
"filenames",
nargs="+",
help="paths to lint",
)

args = parser.parse_args()

with concurrent.futures.ThreadPoolExecutor(
max_workers=os.cpu_count(),
thread_name_prefix="Thread",
) as executor:
futures = {
executor.submit(
check_file,
filename,
): filename
for filename in args.filenames
}
for future in concurrent.futures.as_completed(futures):
try:
for lint_message in future.result():
print(json.dumps(lint_message._asdict()), flush=True)
except Exception:
logging.critical('Failed at "%s".', futures[future])
raise


if __name__ == "__main__":
main()
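lintrunner invokes this adapter with an @-prefixed paths file (see the .lintrunner.toml stanza above), which argparse expands into the filenames list via fromfile_prefix_chars. For each failing file the adapter prints one LintMessage per line as JSON on stdout; a representative line, with a hypothetical path, would be:

{"path": "torchci/clickhouse_queries/example/params.json", "line": null, "char": null, "code": "SQL_PARAMS", "severity": "warning", "name": "lint", "original": null, "replacement": null, "description": "The file does not contain a 'tests' key."}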
2 changes: 1 addition & 1 deletion tools/torchci/clickhouse.py
@@ -31,7 +31,7 @@ def query_clickhouse_saved(queryName: str, inputParams: Dict[str, Any]) -> Any:
with open(path / "query.sql") as f:
queryText = f.read()
with open(path / "params.json") as f:
-        paramsText = json.load(f)
+        paramsText = json.load(f).get("params", {})

queryParams = {name: inputParams[name] for name in paramsText}
return query_clickhouse(queryText, queryParams)
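The .get("params", {}) change reflects a nested params.json layout: parameter declarations now live under a top-level "params" key, with "tests" alongside (the layout the new SQL_PARAMS linter enforces). A minimal sketch of the effect, with hypothetical names and values:

# Hypothetical nested params.json content:
data = {
    "params": {"startTime": "DateTime64(3)"},
    "tests": [{"startTime": {"from_now": -7}}],
}
paramsText = data.get("params", {})  # {"startTime": "DateTime64(3)"}
# Only the declared parameter names are picked out of inputParams:
inputParams = {"startTime": "2025-01-24 00:00:00", "extra": 1}
queryParams = {name: inputParams[name] for name in paramsText}
assert queryParams == {"startTime": "2025-01-24 00:00:00"}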
207 changes: 207 additions & 0 deletions tools/torchci/clickhouse_query_perf.py
@@ -0,0 +1,207 @@
"""
Check query results and performance. Note that query performance is not stable
and can vary significantly between runs.
"""

import argparse
import json
import subprocess
import time
from datetime import datetime, timedelta, timezone
from functools import cache
from typing import Optional

from prettytable import PrettyTable
from torchci.clickhouse import get_clickhouse_client, query_clickhouse
from torchci.utils import REPO_ROOT
from tqdm import tqdm # type: ignore[import]


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Check ClickHouse query results and performance"
    )
parser.add_argument("--query", type=str, help="Query name", required=True)
parser.add_argument(
"--head",
type=str,
        help="Commit SHA whose version of the query to evaluate or compare",
required=True,
)
parser.add_argument("--base", type=str, help="Base sha for comparison")
parser.add_argument(
"--perf", action="store_true", help="Run performance analysis/comparison"
)
parser.add_argument(
"--results",
action="store_true",
help="Run results comparison. Requires --base",
)
parser.add_argument(
"--times",
type=int,
help="Number of times to run the query. Only relevant if --perf is used",
default=10,
)
parser.add_argument(
"--strict-results",
action="store_true",
help="Only relevant if --results is used. If set, it will sort the query results before comparing",
)
args = parser.parse_args()
return args


@cache
def get_base_query(query: str, sha: str) -> str:
return subprocess.check_output(
["git", "show", f"{sha}:torchci/clickhouse_queries/{query}/query.sql"]
).decode("utf-8")


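# Note: {query_ids: Array(String)} in the SQL below is ClickHouse's
# server-side parameter-binding syntax; the query averages duration and
# memory over the given query IDs from system.query_log across all
# replicas, counting only rows with type = 'QueryFinish'.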
EXECUTION_METRICS = """
SELECT
    round(avg(query_duration_ms)) AS realTimeMSAvg,
    avg(memory_usage) AS memoryBytesAvg
FROM
    clusterAllReplicas(default, system.query_log)
WHERE
    has({query_ids: Array(String)}, query_id)
    AND type = 'QueryFinish'
"""


def get_avg_stats(query_ids: list[str]) -> tuple[float, float]:
metrics = query_clickhouse(EXECUTION_METRICS, {"query_ids": query_ids})
return metrics[0]["realTimeMSAvg"], metrics[0]["memoryBytesAvg"]


def get_query_ids(query: str, params: dict, times: int) -> list[str]:
def _get_query_id(query: str, params: dict) -> Optional[str]:
try:
res = get_clickhouse_client().query(query, params)
return res.query_id
except Exception as e:
print(f"Error: {e}")
return None

return [
x for _ in tqdm(range(times)) if (x := _get_query_id(query, params)) is not None
]


@cache
def get_query(query: str, sha: str) -> tuple[str, list[dict]]:
def _get_file(file_path: str) -> str:
return subprocess.check_output(["git", "show", f"{sha}:{file_path}"]).decode(
"utf-8"
)

tests = json.loads(_get_file(f"torchci/clickhouse_queries/{query}/params.json"))[
"tests"
]
    query_text = _get_file(f"torchci/clickhouse_queries/{query}/query.sql")
for test in tests:
for key, value in test.items():
if isinstance(value, dict):
# special syntax for time values
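                # e.g. a hypothetical {"startTime": {"from_now": -7}} becomes
                # now (UTC) minus seven days, formatted "%Y-%m-%d %H:%M:%S"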
test[key] = (
datetime.now(timezone.utc) + timedelta(days=value["from_now"])
).strftime("%Y-%m-%d %H:%M:%S")
    return query_text, tests


def perf_compare(args: argparse.Namespace) -> None:
query, tests = get_query(args.query, args.head)

print(
f"Gathering perf stats for: {args.query}\nNum tests: {len(tests)}\nNum times: {args.times}"
)

query_ids = []
    for test in tests:
new = get_query_ids(query, test, args.times)

base = None
if args.base:
base_query, _ = get_query(args.query, args.base)
base = get_query_ids(base_query, test, args.times)
query_ids.append((new, base))

    # Split up the query execution and the stats collection because the stats
    # table needs time to populate. Also sleep for 20 seconds to give the
    # table more time to populate
time.sleep(20)
table = PrettyTable()
if args.base:
table.field_names = [
"Test",
"Avg Time",
"Base Time",
"Time Change",
"% Time Change",
"Avg Mem",
"Base Mem",
"Mem Change",
"% Mem Change",
]
else:
table.field_names = ["Test", "Avg Time", "Avg Mem"]
for i, (new, base) in enumerate(query_ids):
avg_time, avg_bytes = get_avg_stats(new)
if base:
old_avg_time, old_avg_bytes = get_avg_stats(base)
table.add_row(
[
i,
avg_time,
old_avg_time,
avg_time - old_avg_time,
round(100 * (avg_time - old_avg_time) / old_avg_time),
avg_bytes,
old_avg_bytes,
avg_bytes - old_avg_bytes,
round(100 * (avg_bytes - old_avg_bytes) / old_avg_bytes),
]
)
else:
table.add_row([i, avg_time, avg_bytes])
print(table)


def results_compare(args: argparse.Namespace) -> None:
if not args.base:
print("Base sha is required for results comparison")
return
query, tests = get_query(args.query, args.head)
base_query, _ = get_query(args.query, args.base)
print(
f"Comparing results for query: {args.query}\nNum tests: {len(tests)}\nHead: {args.head} Base: {args.base}"
)
for i, test in enumerate(tests):
new_results = query_clickhouse(query, test)
base_results = query_clickhouse(base_query, test)
if args.strict_results:
new_results = sorted(
new_results, key=lambda x: json.dumps(x, sort_keys=True)
)
base_results = sorted(
base_results, key=lambda x: json.dumps(x, sort_keys=True)
)
if new_results != base_results:
print(f"Results for test {i} differ")
print(f"Test: {json.dumps(test, indent=2)}")
print(f"New: {new_results}")
print(f"Base: {base_results}")
print()
else:
print(f"Results for test {i} match")


if __name__ == "__main__":
args = parse_args()
if not args.perf and not args.results:
print("Please specify --perf or --results")
exit(1)
if args.perf:
perf_compare(args)
if args.results:
results_compare(args)
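Typical invocations might look like the following; the query name and shas are placeholders, and the script assumes the torchci package is importable:

python3 tools/torchci/clickhouse_query_perf.py --query <query_name> --head <head_sha> --perf --times 10
python3 tools/torchci/clickhouse_query_perf.py --query <query_name> --head <head_sha> --base <base_sha> --results --strict-results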