diff --git a/src/rcapi/api/convertor.py b/src/rcapi/api/convertor.py index f2f8256..d3e5131 100644 --- a/src/rcapi/api/convertor.py +++ b/src/rcapi/api/convertor.py @@ -8,15 +8,14 @@ import base64 import traceback import os.path -import ramanchada2 as rc2 + from pynanomapper.clients.datamodel_simple import StudyRaman from numcompress import compress -from rcapi.services.convertor_service import empty_figure,dict2figure, solr2image, knnquery, recursive_copy,read_spectrum_native -#from rcapi.services.kc import inject_api_key_h5pyd, inject_api_key_into_httpx, get_api_key +from rcapi.services.convertor_service import empty_figure,dict2figure, solr2image, recursive_copy,read_spectrum_native +from rcapi.services.kc import get_token import h5py, h5pyd -from rcapi.services.solr_query import SOLR_ROOT,SOLR_VECTOR,SOLR_COLLECTION -import tempfile -import shutil +from rcapi.services.solr_query import SOLR_ROOT,SOLR_COLLECTION + router = APIRouter() @@ -30,8 +29,8 @@ async def convert_get( dataset: Optional[str] = "raw", w: Optional[int] = 300, h: Optional[int] = 200, - extra: Optional[str] = None - #api_key: Optional[str] = Depends(get_api_key) + extra: Optional[str] = None, + token: Optional[str] = Depends(get_token) ) : if not domain: #tr.set_error("missing domain") @@ -62,7 +61,7 @@ async def convert_get( elif what in ["thumbnail","b64png","image"]: #solr query #async with inject_api_key_into_httpx(api_key): try: - fig,etag = await solr2image(solr_url, domain, figsize, extra) + fig,etag = await solr2image(solr_url, domain, figsize, extra,token) # Check if ETag matches the client's If-None-Match header _headers = {} if etag is not None: @@ -86,7 +85,7 @@ async def convert_get( if what == "h5": try: with io.BytesIO() as tmpfile: - with h5pyd.File(domain,mode="r") as fin: + with h5pyd.File(domain,mode="r",api_key=token) as fin: with h5py.File(tmpfile,"w") as fout: recursive_copy(fin,fout) tmpfile.seek(0) @@ -105,8 +104,8 @@ async def convert_get( @router.post("/download") async def convert_post( what: Literal["knnquery", "b64png" ] = Query("knnquery") , - files: list[UploadFile] = File(...) - #api_key: Optional[str] = Depends(get_api_key) + files: list[UploadFile] = File(...), + token: Optional[str] = Depends(get_token) ): logging.info("convert_file function called") logging.info(f"Received parameter 'what': {what}") diff --git a/src/rcapi/api/query.py b/src/rcapi/api/query.py index 06cbadf..a4dc629 100644 --- a/src/rcapi/api/query.py +++ b/src/rcapi/api/query.py @@ -4,6 +4,7 @@ from rcapi.services import query_service import traceback from rcapi.services.solr_query import SOLR_ROOT,SOLR_VECTOR,SOLR_COLLECTION, solr_query_get +from rcapi.services.kc import get_token router = APIRouter() @@ -17,6 +18,7 @@ async def get_query( page : Optional[int] = 0, pagesize : Optional[int] = 10, img: Optional[Literal["embedded", "original", "thumbnail"]] = "thumbnail", vector_field : Optional[str] = None, + token: Optional[str] = Depends(get_token) ): solr_url = "{}{}/select".format(SOLR_ROOT,SOLR_COLLECTION) @@ -36,7 +38,8 @@ async def get_query( page=page, pagesize=pagesize, img=img, - vector_field=SOLR_VECTOR if vector_field is None else vector_field + vector_field=SOLR_VECTOR if vector_field is None else vector_field, + token=token ) return results except Exception as err: @@ -48,12 +51,13 @@ async def get_query( @router.get("/query/field", ) async def get_field( request: Request, - name: str = "publicname_s" + name: str = "publicname_s", + token: Optional[str] = Depends(get_token) ): solr_url = "{}{}/select".format(SOLR_ROOT,SOLR_COLLECTION) try: params= {"q" : "*", "rows" : 0, "facet.field": name, "facet" : "true"} - rs = await solr_query_get(solr_url, params) + rs = await solr_query_get(solr_url, params,token) result = [] # Extract the facet field values facet_field_values = rs.json()["facet_counts"]["facet_fields"][name] diff --git a/src/rcapi/services/convertor_service.py b/src/rcapi/services/convertor_service.py index f177b93..2b16a90 100644 --- a/src/rcapi/services/convertor_service.py +++ b/src/rcapi/services/convertor_service.py @@ -94,14 +94,14 @@ def knnquery(domain,dataset="raw"): def generate_etag(content: str) -> str: return hashlib.md5(content.encode()).hexdigest() -async def solr2image(solr_url: str,domain : str,figsize=(6,4),extraprm =None) -> Tuple[Figure, str]: +async def solr2image(solr_url: str,domain : str,figsize=(6,4),extraprm =None,token : str = None) -> Tuple[Figure, str]: rs = None try: query="textValue_s:{}{}{}".format('"',domain,'"') params = {"q": query, "fq" : ["type_s:study"], "fl" : "name_s,textValue_s,reference_s,reference_owner_s,{},updated_s,_version_".format(SOLR_VECTOR)} - rs = await solr_query_get(solr_url, params) + rs = await solr_query_get(solr_url, params, token = token) if rs is not None and rs.status_code == 200: response_json = rs.json() if "response" in response_json: diff --git a/src/rcapi/services/kc.py b/src/rcapi/services/kc.py index b6d0f31..306480f 100644 --- a/src/rcapi/services/kc.py +++ b/src/rcapi/services/kc.py @@ -12,8 +12,8 @@ logger = logging.getLogger(__name__) -# Dependency to extract API key from Bearer token -def get_api_key(authorization: Optional[str] = Header(None)): +# Dependency to extract Bearer token +def get_token(authorization: Optional[str] = Header(None)): if authorization is None: return None elif authorization.startswith("Bearer "): diff --git a/src/rcapi/services/query_service.py b/src/rcapi/services/query_service.py index 07da4e3..11489a3 100644 --- a/src/rcapi/services/query_service.py +++ b/src/rcapi/services/query_service.py @@ -15,7 +15,8 @@ async def process(request: Request, page: Optional[int] = 0, pagesize: Optional[int] = 10, img: Optional[Literal["embedded", "original", "thumbnail"]] = "thumbnail", - vector_field="spectrum_p1024"): + vector_field="spectrum_p1024", + token=None): query_fields = "id,name_s,textValue_s" embedded_images = img=="embedded" @@ -35,7 +36,7 @@ async def process(request: Request, "reference_s:{}".format(q_reference),"reference_owner_s:{}".format(q_provider)], "fields" : query_fields} try: - response = await solr_query_post(solr_url,query_params,solr_params) + response = await solr_query_post(solr_url,query_params,solr_params,token) response_data = response.json() return parse_solr_response(response_data,get_baseurl(request),embedded_images,thumbnail,vector_field=None) except Exception as err: @@ -56,7 +57,7 @@ async def process(request: Request, "reference_owner_s:{}".format(q_provider) ], "fields" : query_fields} try: - response = await solr_query_post(solr_url,query_params,solr_params) + response = await solr_query_post(solr_url,query_params,solr_params,token) response_data = response.json() return parse_solr_response(response_data,request.base_url,embedded_images,thumbnail,vector_field) except Exception as err: diff --git a/src/rcapi/services/solr_query.py b/src/rcapi/services/solr_query.py index f83c3de..b46237c 100644 --- a/src/rcapi/services/solr_query.py +++ b/src/rcapi/services/solr_query.py @@ -9,25 +9,33 @@ SOLR_COLLECTION = config.SOLR_COLLECTION -async def solr_query_post(solr_url,query_params = None,post_param = None): +async def solr_query_post(solr_url,query_params = None,post_param = None, token = None): async with httpx.AsyncClient() as client: try: + headers = {} + if token: + headers['Authorization'] = f'Bearer {token}' # Add token to headers response = await client.post( solr_url, json=post_param, - params = query_params + params = query_params, + headers=headers # Pass headers ) response.raise_for_status() # Check for HTTP errors return response except httpx.HTTPStatusError as e: raise HTTPException(status_code=e.response.status_code, detail="Error fetching data from external service") -async def solr_query_get(solr_url,params = None): +async def solr_query_get(solr_url,params = None, token = None): async with httpx.AsyncClient() as client: try: + headers = {} + if token: + headers['Authorization'] = f'Bearer {token}' # Add token to headers response = await client.get( solr_url, - params = params + params = params, + headers=headers # Pass headers ) response.raise_for_status() # Check for HTTP errors return response