Support returning results from other tables/database with /jobs/search #142

Draft · wants to merge 4 commits into main

Changes from all commits
66 changes: 56 additions & 10 deletions run_local.sh
@@ -1,32 +1,78 @@
#!/usr/bin/env bash

# Run a local uvicorn server with the default configuration
set -euo pipefail
IFS=$'\n\t'

tmp_dir=$(mktemp -d)
echo "Using temp dir: ${tmp_dir}"
mkdir -p "${tmp_dir}/signing-key" "${tmp_dir}/cs_store/"

signing_key="${tmp_dir}/signing-key/rsa256.key"
ssh-keygen -P "" -t rsa -b 4096 -m PEM -f "${signing_key}"

# Make a fake CS
dirac internal generate-cs "${tmp_dir}/cs_store/initialRepo" \
    --vo=diracAdmin --user-group=admin --idp-url=runlocal.diracx.invalid
dirac internal add-user "${tmp_dir}/cs_store/initialRepo" \
    --vo=diracAdmin --user-group=admin \
    --sub=75212b23-14c2-47be-9374-eb0113b0575e \
    --preferred-username=localuser

export DIRACX_CONFIG_BACKEND_URL="git+file://${tmp_dir}/cs_store/initialRepo"
export DIRACX_DB_URL_AUTHDB="sqlite+aiosqlite:///:memory:"
export DIRACX_DB_URL_JOBDB="sqlite+aiosqlite:///:memory:"
export DIRACX_DB_URL_JOBLOGGINGDB="sqlite+aiosqlite:///:memory:"
export DIRACX_DB_URL_SANDBOXMETADATADB="sqlite+aiosqlite:///:memory:"
export DIRACX_SERVICE_AUTH_TOKEN_KEY="file://${signing_key}"
export DIRACX_SERVICE_AUTH_ALLOWED_REDIRECTS='["http://'$(hostname| tr -s '[:upper:]' '[:lower:]')':8000/docs/oauth2-redirect"]'
export DIRACX_SANDBOX_STORE_BUCKET_NAME=sandboxes
export DIRACX_SANDBOX_STORE_AUTO_CREATE_BUCKET=true
export DIRACX_SANDBOX_STORE_S3_CLIENT_KWARGS='{"endpoint_url": "http://localhost:3000", "aws_access_key_id": "console", "aws_secret_access_key": "console123"}'

moto_server -p3000 &
moto_pid=$!
uvicorn --factory diracx.routers:create_app --reload &
diracx_pid=$!

success=0
for i in {1..10}; do
    if curl --silent --head http://localhost:8000 > /dev/null; then
        success=1
        break
    fi
    sleep 1
done

echo ""
echo ""
echo ""
if [ $success -eq 0 ]; then
    echo "Failed to start DiracX"
else
    echo "DiracX is running on http://localhost:8000"
fi
echo "To interact with DiracX you can:"
echo ""
echo "1. Use the CLI:"
echo ""
echo " export DIRACX_URL=http://localhost:8000"
echo " tests/make-token-local.py ${signing_key}"
echo ""
echo "2. Usisng swagger: http://localhost:8000/api/docs"

function cleanup(){
    trap - SIGTERM
    kill $moto_pid
    kill $diracx_pid
    echo "Waiting for processes to exit"
    wait $moto_pid $diracx_pid
    echo "Cleaning up"
    rm -rf "${tmp_dir}"
}

trap "cleanup" EXIT

while true; do
    sleep 1
done
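The curl loop above is a simple readiness poll: try the endpoint once per second and give up after ten attempts. For reference, a minimal Python equivalent of the same pattern (the URL and attempt count mirror the script; none of this is part of the PR):

```python
import time
import urllib.error
import urllib.request


def wait_until_up(url: str = "http://localhost:8000", attempts: int = 10) -> bool:
    """Poll the given URL once per second until it answers or we give up."""
    for _ in range(attempts):
        try:
            urllib.request.urlopen(url, timeout=1)
            return True
        except urllib.error.HTTPError:
            # Any HTTP response (even 4xx/5xx) means the server is up
            return True
        except OSError:
            # Connection refused or timed out; wait and retry
            time.sleep(1)
    return False
```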
4 changes: 3 additions & 1 deletion src/diracx/__init__.py
@@ -2,7 +2,9 @@
from importlib.metadata import PackageNotFoundError, version

logging.basicConfig(
    # level=logging.DEBUG, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s"
    level=logging.WARNING,
    format="%(asctime)s | %(name)s | %(levelname)s | %(message)s",
)

try:
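Dropping the root logger from DEBUG to WARNING quiets everything by default. Individual subsystems can still be made verbose without touching this call; a small stdlib-only sketch (the logger name is illustrative):

```python
import logging

# Root stays at WARNING, but one subsystem can be turned up for debugging:
logging.getLogger("diracx.db").setLevel(logging.DEBUG)
```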
39 changes: 27 additions & 12 deletions src/diracx/db/sql/jobs/db.py
@@ -21,15 +21,18 @@
)


def _get_columns(tables, parameters):
    assert parameters, "TODO: Not needed when JobDB.summary is updated"
    columns = {}
    # Iterate in reverse order so we prefer using the first possible table,
    # i.e. if tables = [Jobs, JobJDLs] we should prefer getting JobID from
    # Jobs.JobID instead of JobJDLs.JobID
    for table in tables[::-1]:
        columns |= {c.name: c for c in table.columns}
    if unrecognised := set(parameters) - set(columns):
        raise InvalidQueryError(f"Unrecognised parameters requested {unrecognised}")
    return [columns[c] for c in parameters]


class JobDB(BaseSQLDB):
@@ -41,7 +44,8 @@ class JobDB(BaseSQLDB):
    jdl2DBParameters = ["JobName", "JobType", "JobGroup"]

    async def summary(self, group_by, search) -> list[dict[str, str | int]]:
        columns = _get_columns([Jobs.__table__], group_by)
        # TODO: We probably now need a join

        stmt = select(*columns, func.count(Jobs.JobID).label("count"))
        stmt = apply_search_filters(Jobs.__table__, stmt, search)
@@ -55,11 +59,22 @@ async def summary(self, group_by, search) -> list[dict[str, str | int]]:
        ]

    async def search(
        self,
        parameters,
        search,
        sorts,
        *,
        distinct: bool = False,
        per_page: int = 100,
        page: int | None = None,
    ) -> list[dict[str, Any]]:
        tables = [Jobs.__table__, JobJDLs.__table__]
        # TODO: We probably now need a join
        # Find which columns to select
        columns = _get_columns(tables, parameters)
        stmt = select(*columns)
        if distinct:
            stmt = stmt.distinct()

        stmt = apply_search_filters(Jobs.__table__, stmt, search)

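To make the new `_get_columns` resolution rule concrete: merging the tables in reverse order means the first table in the list wins whenever column names clash. A standalone sketch with toy tables (the columns here are illustrative, not the real schema):

```python
from sqlalchemy import Column, Integer, MetaData, String, Table

metadata = MetaData()
jobs = Table("Jobs", metadata, Column("JobID", Integer), Column("Status", String))
jdls = Table("JobJDLs", metadata, Column("JobID", Integer), Column("JDL", String))

# Same rule as _get_columns: later tables are merged first, so the
# earlier table overwrites them on name clashes.
columns: dict = {}
for table in [jobs, jdls][::-1]:
    columns |= {c.name: c for c in table.columns}

selected = [columns[name] for name in ["JobID", "JDL"]]
assert selected[0].table.name == "Jobs"  # JobID preferred from Jobs
assert selected[1].table.name == "JobJDLs"  # JDL only exists in JobJDLs
```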
2 changes: 2 additions & 0 deletions src/diracx/routers/dependencies.py
@@ -16,6 +16,7 @@
from diracx.core.config import Config as _Config
from diracx.core.config import ConfigSource
from diracx.core.properties import SecurityProperty
from diracx.db.os import JobParametersDB as _JobParametersDB
from diracx.db.sql import AuthDB as _AuthDB
from diracx.db.sql import JobDB as _JobDB
from diracx.db.sql import JobLoggingDB as _JobLoggingDB
@@ -36,6 +37,7 @@ def add_settings_annotation(cls: T) -> T:
SandboxMetadataDB = Annotated[
    _SandboxMetadataDB, Depends(_SandboxMetadataDB.transaction)
]
JobParametersDB = Annotated[_JobParametersDB, Depends(_JobParametersDB.session)]

# Miscellaneous
Config = Annotated[_Config, Depends(ConfigSource.create)]
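The new `JobParametersDB` alias follows the existing pattern in this module: `Annotated` bundles a type with its FastAPI dependency so route signatures stay short. A minimal self-contained sketch of the same pattern (the names here are stand-ins, not diracx code):

```python
from typing import Annotated

from fastapi import Depends, FastAPI

app = FastAPI()


async def get_fake_session():
    # Stand-in for something like _JobParametersDB.session
    yield {"backend": "opensearch"}


# One alias, reused by every route that needs the dependency
FakeDB = Annotated[dict, Depends(get_fake_session)]


@app.get("/ping")
async def ping(db: FakeDB) -> dict:
    return db
```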
103 changes: 99 additions & 4 deletions src/diracx/routers/job_manager/__init__.py
@@ -53,6 +53,7 @@ class JobSearchParams(BaseModel):
    parameters: list[str] | None = None
    search: list[SearchSpec] = []
    sort: list[SortSpec] = []
    distinct: bool = False

    @root_validator
    def validate_fields(cls, v):
@@ -355,6 +356,7 @@ async def get_job_status_history_bulk(
async def search(
    config: Annotated[Config, Depends(ConfigSource.create)],
    job_db: JobDB,
    # job_parameters_db: JobParametersDB,
    user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
    page: int = 0,
    per_page: int = 100,
@@ -377,10 +379,103 @@ async def search(
"value": user_info.sub,
}
)
# TODO: Pagination
return await job_db.search(
body.parameters, body.search, body.sort, page=page, per_page=per_page
)

    # TODO: Put a property to BaseSQLDB so this can be JobDB.tables.Jobs
    default_params = [c.name for c in job_db.metadata.tables["Jobs"].columns]
    # By default only the contents of the Jobs table are returned
    sql_columns = set(default_params)
    for table_name in ["JobJDLs"]:
        sql_columns |= {c.name for c in job_db.metadata.tables[table_name].columns}
    # TODO: Support opensearch
    os_fields = set()  # set(job_parameters_db.fields)

    # Ensure all of the requested parameters are known
    # TODO: Do we need to be able to support arbitrary job parameters?
    response_params: list[str] = body.parameters or default_params
    if unrecognised_params := set(response_params) - (sql_columns | os_fields):
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=f"Requested unknown parameters: {unrecognised_params}",
        )

    # Ensure the search parameters can be satisfied by a single DB technology
    search_params = {x["parameter"] for x in body.search}
    if unrecognised_params := search_params - (sql_columns | os_fields):
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=f"Requested unknown search parameters: {unrecognised_params}",
        )
    if not search_params:
        sql_search = None
    elif search_params.issubset(sql_columns):
        # TODO: Limit to indexed columns?
        sql_search = True
    elif search_params.issubset(os_fields):
        sql_search = False
    else:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=(
                f"Can not search by {search_params - sql_columns} at the "
                f"same time as {search_params - os_fields}."
            ),
        )

    # Ensure the sort parameters can be satisfied by a single DB technology
    sort_parameters = {x["parameter"] for x in body.sort}
    if unrecognised_params := sort_parameters - (sql_columns | os_fields):
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=f"Requested unknown sort parameters: {unrecognised_params}",
        )
    if not sort_parameters:
        sql_sort = None
    elif sort_parameters.issubset(sql_columns):
        sql_sort = True
    elif sort_parameters.issubset(os_fields):
        sql_sort = False
    else:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=(
                f"Can not sort by {sort_parameters - sql_columns} at the "
                f"same time as {sort_parameters - os_fields}."
            ),
        )

    # If the request can be satisfied by either DB, prefer SQL
    if sql_search is None and sql_sort is None:
        sql_search = True

    # Ensure that the search and sort can be done with the same DB technology
    sql_search = sql_sort if sql_search is None else sql_search
    sql_sort = sql_search if sql_sort is None else sql_sort
    if sql_search != sql_sort:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=(
                f"Searches by {search_params} can only be sorted by {sql_columns}."
            ),
        )

    if sql_search:
        results = await job_db.search(
            # We don't use a set here as we want to maintain the column ordering
            [c for c in response_params if c in sql_columns],
            body.search,
            body.sort,
            distinct=body.distinct,
            page=page,
            per_page=per_page,
        )
        if os_params := set(response_params) - sql_columns:
            # TODO: Don't forget to maintain column order
            raise NotImplementedError(
                "TODO: Support querying some parameters from opensearch"
            )
    else:
        raise NotImplementedError("TODO: Support opensearch")
    return results
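The `None`/`True`/`False` juggling above reduces to a small decision rule: `None` means either backend can satisfy the request, and SQL wins ties. A standalone sketch of that rule, with the HTTP error replaced by a plain `ValueError` (the PR's logic extracted, not its code):

```python
def resolve_backend(sql_search: bool | None, sql_sort: bool | None) -> bool:
    """Return True to use SQL, False to use OpenSearch."""
    if sql_search is None and sql_sort is None:
        return True  # unconstrained requests prefer SQL
    sql_search = sql_sort if sql_search is None else sql_search
    sql_sort = sql_search if sql_sort is None else sql_sort
    if sql_search != sql_sort:
        raise ValueError("search and sort must be served by the same backend")
    return sql_search


assert resolve_backend(None, None) is True    # nothing constrains: SQL
assert resolve_backend(True, None) is True    # sort follows search
assert resolve_backend(None, False) is False  # search follows sort
```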


@router.post("/summary")
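With this change a client can request columns from both Jobs and JobJDLs and ask for distinct rows in one call. A hedged example of what such a request might look like against the local server from run_local.sh (the /api/jobs/search path and the field names are assumptions based on this diff, not a documented API):

```python
import requests

access_token = "<token produced by tests/make-token-local.py>"

response = requests.post(
    "http://localhost:8000/api/jobs/search",
    headers={"Authorization": f"Bearer {access_token}"},
    json={
        # JDL lives in JobJDLs; JobID and Status live in Jobs
        "parameters": ["JobID", "Status", "JDL"],
        "search": [{"parameter": "Status", "operator": "eq", "value": "Running"}],
        "sort": [],
        "distinct": True,
    },
)
response.raise_for_status()
print(response.json())
```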
50 changes: 50 additions & 0 deletions tests/make-token-local.py
@@ -0,0 +1,50 @@
#!/usr/bin/env python
import argparse
import uuid
from datetime import datetime, timedelta, timezone
from pathlib import Path

from diracx.core.models import TokenResponse
from diracx.core.properties import NORMAL_USER
from diracx.core.utils import write_credentials
from diracx.routers.auth import AuthSettings, create_token


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("token_key", type=Path, help="The key to sign the token with")
    args = parser.parse_args()
    main(args.token_key.read_text())


def main(token_key):
    vo = "diracAdmin"
    dirac_group = "admin"
    sub = "75212b23-14c2-47be-9374-eb0113b0575e"
    preferred_username = "localuser"
    dirac_properties = [NORMAL_USER]
    settings = AuthSettings(token_key=token_key)
    creation_time = datetime.now(tz=timezone.utc)
    expires_in = 7 * 24 * 60 * 60

    access_payload = {
        "sub": f"{vo}:{sub}",
        "vo": vo,
        "aud": settings.token_audience,
        "iss": settings.token_issuer,
        "dirac_properties": dirac_properties,
        "jti": str(uuid.uuid4()),
        "preferred_username": preferred_username,
        "dirac_group": dirac_group,
        "exp": creation_time + timedelta(seconds=expires_in),
    }
    token = TokenResponse(
        access_token=create_token(access_payload, settings),
        expires_in=expires_in,
        refresh_token=None,
    )
    write_credentials(token)


if __name__ == "__main__":
    parse_args()
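To sanity-check a token minted by this script, the claims can be inspected without verifying the signature; a short sketch assuming PyJWT is available (it is not a dependency declared anywhere in this diff):

```python
import jwt  # PyJWT, assumed to be installed

access_token = "<paste the freshly minted token here>"
# Debugging only: read the claims without checking the RS256 signature
claims = jwt.decode(access_token, options={"verify_signature": False})
print(claims["preferred_username"], claims["dirac_group"], claims["exp"])
```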