diff --git a/run_local.sh b/run_local.sh
index 5fe2f2e5..d2319076 100755
--- a/run_local.sh
+++ b/run_local.sh
@@ -1,32 +1,77 @@
 #!/usr/bin/env bash
-
-### Runs a local uvicorn server with the default configuration
-
-
+# Run a local uvicorn server with the default configuration
 set -euo pipefail
 IFS=$'\n\t'
 
 tmp_dir=$(mktemp -d)
+echo "Using temp dir: ${tmp_dir}"
 mkdir -p "${tmp_dir}/signing-key" "${tmp_dir}/cs_store/"
 
-ssh-keygen -P "" -t rsa -b 4096 -m PEM -f "${tmp_dir}/signing-key/rs256.key"
+signing_key="${tmp_dir}/signing-key/rsa256.key"
+ssh-keygen -P "" -t rsa -b 4096 -m PEM -f "${signing_key}"
 
-dirac internal generate-cs "${tmp_dir}/cs_store/initialRepo" --vo=diracAdmin --user-group=admin --idp-url=runlocal.diracx.invalid
+# Make a fake CS
+dirac internal generate-cs "${tmp_dir}/cs_store/initialRepo" \
+    --vo=diracAdmin --user-group=admin --idp-url=runlocal.diracx.invalid
+dirac internal add-user "${tmp_dir}/cs_store/initialRepo" \
+    --vo=diracAdmin --user-group=admin \
+    --sub=75212b23-14c2-47be-9374-eb0113b0575e \
+    --preferred-username=localuser
 
 export DIRACX_CONFIG_BACKEND_URL="git+file://${tmp_dir}/cs_store/initialRepo"
 export DIRACX_DB_URL_AUTHDB="sqlite+aiosqlite:///:memory:"
 export DIRACX_DB_URL_JOBDB="sqlite+aiosqlite:///:memory:"
 export DIRACX_DB_URL_JOBLOGGINGDB="sqlite+aiosqlite:///:memory:"
-export DIRACX_SERVICE_AUTH_TOKEN_KEY="file://${tmp_dir}/signing-key/rs256.key"
+export DIRACX_DB_URL_SANDBOXMETADATADB="sqlite+aiosqlite:///:memory:"
+export DIRACX_SERVICE_AUTH_TOKEN_KEY="file://${signing_key}"
 export DIRACX_SERVICE_AUTH_ALLOWED_REDIRECTS='["http://'$(hostname| tr -s '[:upper:]' '[:lower:]')':8000/docs/oauth2-redirect"]'
+export DIRACX_SANDBOX_STORE_BUCKET_NAME=sandboxes
+export DIRACX_SANDBOX_STORE_AUTO_CREATE_BUCKET=true
+export DIRACX_SANDBOX_STORE_S3_CLIENT_KWARGS='{"endpoint_url": "http://localhost:3000", "aws_access_key_id": "console", "aws_secret_access_key": "console123"}'
 
+moto_server -p3000 &
+moto_pid=$!
+uvicorn --factory diracx.routers:create_app --reload &
+diracx_pid=$!
 
-uvicorn --factory diracx.routers:create_app --reload
+success=0
+for i in {1..10}; do
+  if curl --silent --head http://localhost:8000 > /dev/null; then
+    success=1
+    break
+  fi
+  sleep 1
+done
+
+echo ""
+echo ""
+echo ""
+if [ $success -eq 0 ]; then
+  echo "Failed to start DiracX"
+else
+  echo "DiracX is running on http://localhost:8000"
+fi
+echo "To interact with DiracX you can:"
+echo ""
+echo "1. Use the CLI:"
+echo ""
+echo "    export DIRACX_URL=http://localhost:8000"
+echo "    tests/make-token-local.py ${signing_key}"
+echo ""
+echo "2. Use Swagger: http://localhost:8000/api/docs"
 
 function cleanup(){
-  trap - SIGTERM;
-  echo "Cleaning up";
+  trap - SIGTERM
+  kill $moto_pid
+  kill $diracx_pid
+  echo "Waiting for processes to exit"
+  wait $moto_pid $diracx_pid
+  echo "Cleaning up"
   rm -rf "${tmp_dir}"
 }
 
 trap "cleanup" EXIT
+
+while true; do
+  sleep 1
+done
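
Note on the sandbox-store settings above: moto_server stands in for S3 on port 3000. A quick way to confirm that the endpoint and credentials from DIRACX_SANDBOX_STORE_S3_CLIENT_KWARGS line up is the sketch below; it assumes boto3 is installed and is not part of this patch.

    # Sanity-check the moto S3 endpoint started by run_local.sh.
    # Values mirror DIRACX_SANDBOX_STORE_S3_CLIENT_KWARGS and
    # DIRACX_SANDBOX_STORE_BUCKET_NAME.
    import boto3

    s3 = boto3.client(
        "s3",
        endpoint_url="http://localhost:3000",
        aws_access_key_id="console",
        aws_secret_access_key="console123",
    )
    # With DIRACX_SANDBOX_STORE_AUTO_CREATE_BUCKET=true diracx creates the
    # bucket itself; doing it here is harmless for a manual check.
    s3.create_bucket(Bucket="sandboxes")
    print([b["Name"] for b in s3.list_buckets()["Buckets"]])
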
diff --git a/src/diracx/__init__.py b/src/diracx/__init__.py
index a3755310..e3073033 100644
--- a/src/diracx/__init__.py
+++ b/src/diracx/__init__.py
@@ -2,7 +2,9 @@
 from importlib.metadata import PackageNotFoundError, version
 
 logging.basicConfig(
-    level=logging.DEBUG, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s"
+    # level=logging.DEBUG, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s"
+    level=logging.WARNING,
+    format="%(asctime)s | %(name)s | %(levelname)s | %(message)s",
 )
 
 try:
diff --git a/src/diracx/db/sql/jobs/db.py b/src/diracx/db/sql/jobs/db.py
index 864ca761..0af5958d 100644
--- a/src/diracx/db/sql/jobs/db.py
+++ b/src/diracx/db/sql/jobs/db.py
@@ -21,15 +21,18 @@
 )
 
 
-def _get_columns(table, parameters):
-    columns = [x for x in table.columns]
-    if parameters:
-        if unrecognised_parameters := set(parameters) - set(table.columns.keys()):
-            raise InvalidQueryError(
-                f"Unrecognised parameters requested {unrecognised_parameters}"
-            )
-        columns = [c for c in columns if c.name in parameters]
-    return columns
+def _get_columns(tables, parameters):
+    assert parameters, "TODO: Not needed when JobDB.summary is updated"
+    columns = {}
+    # Iterate in reverse order so we prefer using the first possible table,
+    # i.e. if tables = [Jobs, JobJDLs] we should prefer getting JobID from
+    # Jobs.JobID instead of JobJDLs.JobID
+    for table in tables[::-1]:
+        # Columns from earlier tables overwrite those from later ones
+        columns |= {c.name: c for c in table.columns}
+    if unrecognised := set(parameters) - set(columns):
+        raise InvalidQueryError(f"Unrecognised parameters requested {unrecognised}")
+    return [columns[c] for c in parameters]
 
 
 class JobDB(BaseSQLDB):
@@ -41,7 +44,8 @@ class JobDB(BaseSQLDB):
     jdl2DBParameters = ["JobName", "JobType", "JobGroup"]
 
     async def summary(self, group_by, search) -> list[dict[str, str | int]]:
-        columns = _get_columns(Jobs.__table__, group_by)
+        columns = _get_columns([Jobs.__table__], group_by)
+        # TODO: We probably now need a join
 
         stmt = select(*columns, func.count(Jobs.JobID).label("count"))
         stmt = apply_search_filters(Jobs.__table__, stmt, search)
@@ -55,11 +59,22 @@ async def summary(self, group_by, search) -> list[dict[str, str | int]]:
         ]
 
     async def search(
-        self, parameters, search, sorts, *, per_page: int = 100, page: int | None = None
+        self,
+        parameters,
+        search,
+        sorts,
+        *,
+        distinct: bool = False,
+        per_page: int = 100,
+        page: int | None = None,
     ) -> list[dict[str, Any]]:
+        tables = [Jobs.__table__, JobJDLs.__table__]
+        # TODO: We probably now need a join
         # Find which columns to select
-        columns = _get_columns(Jobs.__table__, parameters)
+        columns = _get_columns(tables, parameters)
         stmt = select(*columns)
+        if distinct:
+            stmt = stmt.distinct()
         stmt = apply_search_filters(Jobs.__table__, stmt, search)
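
The reversed iteration in _get_columns is subtle: merging the tables back to front means that when a column name exists in several tables, the first table in the list wins. A standalone sketch of just that rule, with plain dicts standing in for SQLAlchemy Table.columns (the dict values are illustrative only):

    # Demonstrates the table-preference rule used by _get_columns.
    jobs = {"JobID": "Jobs.JobID", "Status": "Jobs.Status"}
    job_jdls = {"JobID": "JobJDLs.JobID", "JDL": "JobJDLs.JDL"}

    columns: dict[str, str] = {}
    # Reverse order: entries from earlier tables overwrite later ones.
    for table in [jobs, job_jdls][::-1]:
        columns |= table

    assert columns["JobID"] == "Jobs.JobID"  # Jobs wins over JobJDLs
    assert columns["JDL"] == "JobJDLs.JDL"   # JobJDLs-only columns still resolve
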
diff --git a/src/diracx/routers/dependencies.py b/src/diracx/routers/dependencies.py
index deb55167..e677b422 100644
--- a/src/diracx/routers/dependencies.py
+++ b/src/diracx/routers/dependencies.py
@@ -16,6 +16,7 @@
 from diracx.core.config import Config as _Config
 from diracx.core.config import ConfigSource
 from diracx.core.properties import SecurityProperty
+from diracx.db.os import JobParametersDB as _JobParametersDB
 from diracx.db.sql import AuthDB as _AuthDB
 from diracx.db.sql import JobDB as _JobDB
 from diracx.db.sql import JobLoggingDB as _JobLoggingDB
@@ -36,6 +37,7 @@ def add_settings_annotation(cls: T) -> T:
 SandboxMetadataDB = Annotated[
     _SandboxMetadataDB, Depends(_SandboxMetadataDB.transaction)
 ]
+JobParametersDB = Annotated[_JobParametersDB, Depends(_JobParametersDB.session)]
 
 # Miscellaneous
 Config = Annotated[_Config, Depends(ConfigSource.create)]
diff --git a/src/diracx/routers/job_manager/__init__.py b/src/diracx/routers/job_manager/__init__.py
index 39592c52..0e071bb8 100644
--- a/src/diracx/routers/job_manager/__init__.py
+++ b/src/diracx/routers/job_manager/__init__.py
@@ -53,6 +53,7 @@ class JobSearchParams(BaseModel):
     parameters: list[str] | None = None
     search: list[SearchSpec] = []
     sort: list[SortSpec] = []
+    distinct: bool = False
 
     @root_validator
     def validate_fields(cls, v):
@@ -355,6 +356,7 @@ async def get_job_status_history_bulk(
 async def search(
     config: Annotated[Config, Depends(ConfigSource.create)],
     job_db: JobDB,
+    # job_parameters_db: JobParametersDB,
     user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
     page: int = 0,
     per_page: int = 100,
@@ -377,10 +379,103 @@ async def search(
                 "value": user_info.sub,
             }
         )
-    # TODO: Pagination
-    return await job_db.search(
-        body.parameters, body.search, body.sort, page=page, per_page=per_page
-    )
+
+    # TODO: Put a property to BaseSQLDB so this can be JobDB.tables.Jobs
+    default_params = [c.name for c in job_db.metadata.tables["Jobs"].columns]
+    # By default only the contents of the Jobs table are returned
+    sql_columns = set(default_params)
+    for table_name in ["JobJDLs"]:
+        sql_columns |= {c.name for c in job_db.metadata.tables[table_name].columns}
+    # TODO: Support opensearch
+    os_fields = set()  # set(job_parameters_db.fields)
+
+    # Ensure all of the requested parameters are known
+    # TODO: Do we need to be able to support arbitrary job parameters?
+    response_params: list[str] = body.parameters or default_params
+    if unrecognised_params := set(response_params) - (sql_columns | os_fields):
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST,
+            detail=f"Requested unknown parameters: {unrecognised_params}",
+        )
+
+    # Ensure the search parameters can be satisfied by a single DB technology
+    search_params = {x["parameter"] for x in body.search}
+    if unrecognised_params := search_params - (sql_columns | os_fields):
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST,
+            detail=f"Requested unknown search parameters: {unrecognised_params}",
+        )
+    if not search_params:
+        sql_search = None
+    elif search_params.issubset(sql_columns):
+        # TODO: Limit to indexed columns?
+        sql_search = True
+    elif search_params.issubset(os_fields):
+        sql_search = False
+    else:
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST,
+            detail=(
+                f"Can not search by {search_params - sql_columns} at the "
+                f"same time as {search_params - os_fields}."
+            ),
+        )
+
+    # Ensure the sort parameters can be satisfied by a single DB technology
+    sort_parameters = {x["parameter"] for x in body.sort}
+    if unrecognised_params := sort_parameters - (sql_columns | os_fields):
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST,
+            detail=f"Requested unknown sort parameters: {unrecognised_params}",
+        )
+    if not sort_parameters:
+        sql_sort = None
+    elif sort_parameters.issubset(sql_columns):
+        sql_sort = True
+    elif sort_parameters.issubset(os_fields):
+        sql_sort = False
+    else:
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST,
+            detail=(
+                f"Can not sort by {sort_parameters - sql_columns} at the "
+                f"same time as {sort_parameters - os_fields}."
+            ),
+        )
+
+    # If the request can be satisfied by either DB, prefer SQL
+    if sql_search is None and sql_sort is None:
+        sql_search = True
+
+    # Ensure that the search and sort can be done with the same DB technology
+    sql_search = sql_sort if sql_search is None else sql_search
+    sql_sort = sql_search if sql_sort is None else sql_sort
+    if sql_search != sql_sort:
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST,
+            detail=(
+                f"Searches by {search_params} can only be sorted by {sql_columns}."
+            ),
+        )
+
+    if sql_search:
+        results = await job_db.search(
+            # We don't use a set here as we want to maintain the column ordering
+            [c for c in response_params if c in sql_columns],
+            body.search,
+            body.sort,
+            distinct=body.distinct,
+            page=page,
+            per_page=per_page,
+        )
+        if os_params := set(response_params) - sql_columns:
+            # TODO: Don't forget to maintain column order
+            raise NotImplementedError(
+                "TODO: Support querying some parameters from opensearch"
+            )
+    else:
+        raise NotImplementedError("TODO: Support opensearch")
+    return results
 
 
 @router.post("/summary")
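
For reference, a request exercising the new distinct flag might look like the sketch below. The /api/jobs/search path, the SearchSpec/SortSpec field names, and the requests dependency are assumptions inferred from the models above rather than something this diff defines.

    # Hypothetical client call against the search endpoint of the run_local.sh
    # server; requests is assumed to be installed.
    import requests

    token = "..."  # paste an access token, e.g. from tests/make-token-local.py

    body = {
        "parameters": ["JobID", "Status"],  # must all be known SQL columns (or OS fields)
        "search": [{"parameter": "Status", "operator": "eq", "value": "RECEIVED"}],
        "sort": [{"parameter": "JobID", "direction": "asc"}],
        "distinct": True,  # forwarded to JobDB.search, becomes SELECT DISTINCT
    }
    r = requests.post(
        "http://localhost:8000/api/jobs/search",  # route prefix is an assumption
        json=body,
        params={"page": 0, "per_page": 100},
        headers={"Authorization": f"Bearer {token}"},
    )
    r.raise_for_status()
    print(r.json())
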
diff --git a/tests/make-token-local.py b/tests/make-token-local.py
new file mode 100755
index 00000000..9417b0b6
--- /dev/null
+++ b/tests/make-token-local.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python
+import argparse
+import uuid
+from datetime import datetime, timedelta, timezone
+from pathlib import Path
+
+from diracx.core.models import TokenResponse
+from diracx.core.properties import NORMAL_USER
+from diracx.core.utils import write_credentials
+from diracx.routers.auth import AuthSettings, create_token
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("token_key", type=Path, help="The key to sign the token with")
+    args = parser.parse_args()
+    main(args.token_key.read_text())
+
+
+def main(token_key):
+    vo = "diracAdmin"
+    dirac_group = "admin"
+    sub = "75212b23-14c2-47be-9374-eb0113b0575e"
+    preferred_username = "localuser"
+    dirac_properties = [NORMAL_USER]
+    settings = AuthSettings(token_key=token_key)
+    creation_time = datetime.now(tz=timezone.utc)
+    expires_in = 7 * 24 * 60 * 60
+
+    access_payload = {
+        "sub": f"{vo}:{sub}",
+        "vo": vo,
+        "aud": settings.token_audience,
+        "iss": settings.token_issuer,
+        "dirac_properties": dirac_properties,
+        "jti": str(uuid.uuid4()),
+        "preferred_username": preferred_username,
+        "dirac_group": dirac_group,
+        "exp": creation_time + timedelta(seconds=expires_in),
+    }
+    token = TokenResponse(
+        access_token=create_token(access_payload, settings),
+        expires_in=expires_in,
+        refresh_token=None,
+    )
+    write_credentials(token)
+
+
+if __name__ == "__main__":
+    parse_args()
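
To check what make-token-local.py produced, the JWT payload can be decoded without verifying the signature. Where write_credentials() stores the credentials depends on your setup; the path below is an assumption, so adjust it to wherever the file lands on your machine.

    # Decode the access token payload (no signature verification).
    import base64
    import json
    from pathlib import Path

    # Assumed credential location; not defined by this patch.
    creds_path = Path.home() / ".cache" / "diracx" / "credentials.json"
    creds = json.loads(creds_path.read_text())

    payload_b64 = creds["access_token"].split(".")[1]
    # Restore base64 padding before decoding.
    payload = json.loads(
        base64.urlsafe_b64decode(payload_b64 + "=" * (-len(payload_b64) % 4))
    )
    print(payload["preferred_username"], payload["dirac_group"], payload["exp"])
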