diff --git a/src/rcapi/api/convertor.py b/src/rcapi/api/convertor.py index 31f3753..26b7f53 100644 --- a/src/rcapi/api/convertor.py +++ b/src/rcapi/api/convertor.py @@ -11,13 +11,13 @@ import ramanchada2 as rc2 from pynanomapper.clients.datamodel_simple import StudyRaman from numcompress import compress -from rcapi.services.convertor_service import empty_figure,dict2figure, image, thumbnail, knnquery, recursive_copy -from rcapi.services.kc import inject_api_key_h5pyd, inject_api_key_into_requests, get_api_key +from rcapi.services.convertor_service import empty_figure,dict2figure, solr2image, knnquery, recursive_copy +from rcapi.services.kc import inject_api_key_h5pyd, inject_api_key_into_httpx, get_api_key import h5py, h5pyd +from rcapi.services.solr_query import SOLR_ROOT,SOLR_VECTOR,SOLR_COLLECTION router = APIRouter() -solr_root = "https://solr-kc.ideaconsult.net/solr/" @router.get("/download", ) async def convert_get( @@ -33,7 +33,7 @@ async def convert_get( #tr.set_error("missing domain") raise HTTPException(status_code=400, detail=str("missing domain")) - solr_url = "{}charisma/select".format(solr_root) + solr_url = "{}{}/select".format(SOLR_ROOT,SOLR_COLLECTION) width = w height = h @@ -47,7 +47,7 @@ async def convert_get( FigureCanvas(fig).print_png(output) return Response(content=output.getvalue(), media_type='image/png') - if what in ["dict","thumbnail"]: #solr query + if what in ["dict","thumbnail","b64png","image"]: #solr query if what == "dict": prm = dict(request.query_params) prm["what"] = None @@ -55,29 +55,22 @@ async def convert_get( output = io.BytesIO() FigureCanvas(fig).print_png(output) return Response(content=output.getvalue(), media_type='image/png') - if what == "thumbnail": - with inject_api_key_into_requests(api_key): - fig = thumbnail(solr_url, domain, figsize, extra) + else: + async with inject_api_key_into_httpx(api_key): + fig = await solr2image(solr_url, domain, figsize, extra) output = io.BytesIO() FigureCanvas(fig).print_png(output) - return Response(content=output.getvalue(), media_type='image/png') - + if what == "b64png": + base64_bytes = base64.b64encode(output.getvalue()) + return Response(content=base64_bytes, media_type='text/plain') + else: + return Response(content=output.getvalue(), media_type='image/png') + - elif what in ["knnquery","h5","b64png","image"]: # h5 query + elif what in ["knnquery","h5"]: # h5 query try: with inject_api_key_h5pyd(api_key): - if what == "b64png": - fig = image(domain, dataset) - output = io.BytesIO() - FigureCanvas(fig).print_png(output) - base64_bytes = base64.b64encode(output.getvalue()) - return Response(content=base64_bytes, media_type='text/plain') - elif what == "image": - fig = image(domain, dataset, figsize, extra) - output = io.BytesIO() - FigureCanvas(fig).print_png(output) - return Response(content=output.getvalue(), media_type='image/png') - elif what == "knnquery": + if what == "knnquery": return knnquery(domain, dataset) elif what == "h5": try: @@ -88,6 +81,7 @@ async def convert_get( tmpfile.seek(0) return Response(content=tmpfile.read(), media_type="application/x-hdf5", headers={"Content-Disposition": "attachment; filename=download.h5"}) except Exception as e: + print(traceback.format_exc()) raise HTTPException(status_code=400, detail=f" error: {str(e)}") except Exception as e: raise HTTPException(status_code=500, detail=f" error: {str(e)}") @@ -121,7 +115,6 @@ async def convert_post( if file_extension != ".cha": x, y, _meta = rc2.spectrum.from_local_file(file=uf.file, f_name=f_name) - print(x) if what == "knnquery": _cdf, pdf = StudyRaman.xy2embedding(x, y, StudyRaman.x4search(dim=2048)) result_json = {"cdf": compress(pdf.tolist(), precision=6)} diff --git a/src/rcapi/api/query.py b/src/rcapi/api/query.py index 5ebc642..a2ac637 100644 --- a/src/rcapi/api/query.py +++ b/src/rcapi/api/query.py @@ -3,10 +3,9 @@ from typing import Optional, Literal from rcapi.services import query_service import traceback -router = APIRouter() +from rcapi.services.solr_query import SOLR_ROOT,SOLR_VECTOR,SOLR_COLLECTION -solr_root = "https://solr-kc.ideaconsult.net/solr/" -solr_vector_field = "spectrum_p1024" +router = APIRouter() @router.get("/query", ) async def get_query( @@ -18,7 +17,7 @@ async def get_query( page : Optional[int] = 0, pagesize : Optional[int] = 10, img: Optional[Literal["embedded", "original", "thumbnail"]] = "thumbnail", ): - solr_url = "{}charisma/select".format(solr_root) + solr_url = "{}{}/select".format(SOLR_ROOT,SOLR_COLLECTION) textQuery = q textQuery = "*" if textQuery is None or textQuery=="" else textQuery @@ -36,7 +35,7 @@ async def get_query( page=page, pagesize=pagesize, img=img, - vector_field=solr_vector_field + vector_field=SOLR_VECTOR ) return results except Exception as err: diff --git a/src/rcapi/services/convertor_service.py b/src/rcapi/services/convertor_service.py index 55ffe68..8000575 100644 --- a/src/rcapi/services/convertor_service.py +++ b/src/rcapi/services/convertor_service.py @@ -10,8 +10,7 @@ from numcompress import compress from pynanomapper.clients.datamodel_simple import StudyRaman import h5py, h5pyd -import requests -from .kc import AuthenticatedRequest +from rcapi.services.solr_query import solr_query_post,solr_query_get,SOLR_ROOT,SOLR_COLLECTION,SOLR_VECTOR import traceback def empty_figure(figsize,title,label): @@ -60,33 +59,6 @@ def dict2figure(pm,figsize): axis.set_title(pm["domain"]) return fig -def image(domain,dataset="raw",figsize=(6,4),extraprm=""): - try: - with h5pyd.File(domain,mode="r") as h5: - x = h5[dataset][0] - y = h5[dataset][1] - try: - _sample = h5["annotation_sample"].attrs["sample"] - except: - _sample = None - try: - _provider = h5["annotation_study"].attrs["provider"] - except: - _provider = None - try: - _wavelength = h5["annotation_study"].attrs["wavelength"] - except: - _wavelength = None - fig = Figure(figsize=figsize) - axis = fig.add_subplot(1, 1, 1) - axis.plot(x, y, color='black') - axis.set_ylabel(h5[dataset].dims[1].label) - axis.set_xlabel(h5[dataset].dims[0].label) - axis.title.set_text("{} {} ({}) {}".format(extraprm,_sample,_provider,_wavelength)) - #domain.split("/")[-1],dataset)) - return fig - except Exception as err: - return empty_figure(figsize,"Error","{}".format(domain.split("/")[-1])) def knnquery(domain,dataset="raw"): try: @@ -116,22 +88,25 @@ def knnquery(domain,dataset="raw"): except Exception as err: raise(err) -def thumbnail(solr_url,domain,figsize=(6,4),extraprm=""): +async def solr2image(solr_url,domain,figsize=(6,4),extraprm=None): rs = None try: query="textValue_s:\"{}\"".format(domain.replace(" ","\ ")) - params = {"q": query, "fq" : ["type_s:study"], "fl" : "name_s,textValue_s,reference_s,reference_owner_s,spectrum_p1024"} - rs = solrquery_get(solr_url, params = params) + params = {"q": query, "fq" : ["type_s:study"], "fl" : "name_s,textValue_s,reference_s,reference_owner_s,{}".format(SOLR_VECTOR)} + rs = await solr_query_get(solr_url, params) if rs.status_code==200: - x = StudyRaman.x4search() + x = None for doc in rs.json()["response"]["docs"]: - y = doc["spectrum_p1024"] + y = doc[SOLR_VECTOR] + if x is None: + x = StudyRaman.x4search(len(y)) fig = Figure(figsize=figsize) axis = fig.add_subplot(1, 1, 1) axis.plot(x, y) axis.set_ylabel("a.u.") - axis.set_xlabel("Raman shift [1/cm]") - axis.title.set_text("{} {} {} ({})".format(extraprm,doc["name_s"],doc["reference_owner_s"],doc["reference_s"])) + axis.set_xlabel("Wavenumber [1/cm]") + axis.title.set_text("{} {} {} ({})".format("" if extraprm is None else extraprm, + doc["name_s"],doc["reference_owner_s"],doc["reference_s"])) return fig else: return empty_figure(figsize,"{} {}".format(rs.status_code,rs.reason),"{}".format(domain.split("/")[-1])) diff --git a/src/rcapi/services/kc.py b/src/rcapi/services/kc.py index d6e9088..3a75476 100644 --- a/src/rcapi/services/kc.py +++ b/src/rcapi/services/kc.py @@ -3,8 +3,8 @@ import logging import threading import h5pyd -import requests -from contextlib import contextmanager +import httpx +from contextlib import contextmanager, asynccontextmanager from typing import Optional # Thread-local storage for API key @@ -23,84 +23,6 @@ def get_api_key(authorization: Optional[str] = Header(None)): return None # raise HTTPException(status_code=401, detail="Invalid or missing Authorization header") -class AuthenticatedRequest: - def __init__(self, get_token): - """ - Initialize the AuthenticatedRequest context manager. - - Args: - get_token (callable): A function to retrieve the current token. - - with AuthenticatedRequest(get_token): - requests.get() - ... - """ - self.get_token = get_token - self.original_post = None - - def __enter__(self): - # Save the original requests.post method - self.original_post = requests.post - # Override requests.post with our modified version - requests.post = self.modified_post - - def __exit__(self, exc_type, exc_value, traceback): - # Restore the original requests.post method - requests.post = self.original_post - - - def modified_post(self, url, **kwargs): - # Retrieve the current token - token = self.get_token() - # Ensure the headers dictionary exists - headers = kwargs.setdefault('headers', {}) - # Add the Authorization header, without overwriting other headers - if 'Authorization' not in headers: - headers['Authorization'] = f'Bearer {token}' - # Make the request with the modified headers - return self.original_post(url, **kwargs) - - -# Context manager to inject the API key into a thread-local session for all requests -@contextmanager -def inject_api_key_into_requests(api_key): - """ - Thread-safe context manager to inject the API key into the thread-local `requests.Session` - used for making HTTP requests. - - Each thread will have its own independent `requests.Session`, ensuring that the - `Authorization: Bearer ` header is correctly applied without affecting other threads. - - Parameters: - ----------- - api_key : str - The API key to inject into the requests headers. - - Yields: - ------- - session : requests.Session - The thread-local session object, with the Authorization header pre-configured. - - Example: - -------- - >>> with inject_api_key_into_requests(api_key="your_api_key") as session: - >>> response = session.get("https://example.com/data") - >>> response = session.post("https://example.com/data", json={"key": "value"}) - >>> # The Authorization header will automatically be included in both requests. - - """ - if not hasattr(thread_local, "session"): - thread_local.session = requests.Session() - - # Add Authorization header to the session for this thread - thread_local.session.headers.update({"Authorization": f"Bearer {api_key}"}) - - try: - yield thread_local.session # Provide the session for this thread's context - finally: - # Optional: Clean up if necessary (e.g., closing the session) - pass - # Context manager to temporarily patch h5pyd.File to inject the api_key @contextmanager @@ -163,3 +85,41 @@ def patched_h5pyd_folder(*args, **kwargs): h5pyd.File = original_h5pyd_file h5pyd.Folder = original_h5pyd_folder +@asynccontextmanager +async def inject_api_key_into_httpx(api_key: str): + """ + Thread-safe async context manager to inject the API key into the thread-local `httpx.AsyncClient` + used for making HTTP requests. + + Each thread will have its own independent `httpx.AsyncClient`, ensuring that the + `Authorization: Bearer ` header is correctly applied without affecting other threads. + + Parameters: + ----------- + api_key : str + The API key to inject into the requests headers. + + Yields: + ------- + client : httpx.AsyncClient + The thread-local `AsyncClient` object, with the Authorization header pre-configured. + + Example: + -------- + >>> async with inject_api_key_into_requests(api_key="your_api_key") as client: + >>> response = await client.get("https://example.com/data") + >>> response = await client.post("https://example.com/data", json={"key": "value"}) + >>> # The Authorization header will automatically be included in both requests. + """ + if not hasattr(thread_local, "client"): + thread_local.client = httpx.AsyncClient() + + # Add Authorization header to the client for this thread + thread_local.client.headers.update({"Authorization": f"Bearer {api_key}"}) + + try: + yield thread_local.client # Provide the client for this thread's context + finally: + # Optionally close the client if required + await thread_local.client.aclose() + del thread_local.client # Clean up the client after use \ No newline at end of file diff --git a/src/rcapi/services/query_service.py b/src/rcapi/services/query_service.py index e6cd320..d7c8e17 100644 --- a/src/rcapi/services/query_service.py +++ b/src/rcapi/services/query_service.py @@ -2,36 +2,8 @@ from typing import Optional, Literal from fastapi import Request, HTTPException from numcompress import decompress -import httpx -import traceback +from rcapi.services.solr_query import solr_query_post,solr_query_get -async def solr_query_post(solr_url,query_params = None,post_param = None): - async with httpx.AsyncClient() as client: - try: - response = await client.post( - solr_url, - json=post_param, - params = query_params - ) - response.raise_for_status() # Check for HTTP errors - return response - except httpx.HTTPStatusError as e: - print(traceback.format_exc()) - raise HTTPException(status_code=e.response.status_code, detail="Error fetching data from external service") - -async def solr_query_get(solr_url,query_params = None): - async with httpx.AsyncClient() as client: - try: - response = await client.get( - solr_url, - params = query_params - ) - response.raise_for_status() # Check for HTTP errors - return response - except httpx.HTTPStatusError as e: - print(traceback.format_exc()) - raise HTTPException(status_code=e.response.status_code, detail="Error fetching data from external service") - async def process(request: Request, solr_url: str, q: Optional[str] = "*", diff --git a/src/rcapi/services/solr_query.py b/src/rcapi/services/solr_query.py new file mode 100644 index 0000000..84d6b84 --- /dev/null +++ b/src/rcapi/services/solr_query.py @@ -0,0 +1,34 @@ +import httpx +import traceback +from fastapi import HTTPException + +SOLR_ROOT = "https://solr-kc.ideaconsult.net/solr/" +SOLR_VECTOR = "spectrum_p1024" +SOLR_COLLECTION = "charisma" + +async def solr_query_post(solr_url,query_params = None,post_param = None): + async with httpx.AsyncClient() as client: + try: + response = await client.post( + solr_url, + json=post_param, + params = query_params + ) + response.raise_for_status() # Check for HTTP errors + return response + except httpx.HTTPStatusError as e: + print(traceback.format_exc()) + raise HTTPException(status_code=e.response.status_code, detail="Error fetching data from external service") + +async def solr_query_get(solr_url,params = None): + async with httpx.AsyncClient() as client: + try: + response = await client.get( + solr_url, + params = params + ) + response.raise_for_status() # Check for HTTP errors + return response + except httpx.HTTPStatusError as e: + print(traceback.format_exc()) + raise HTTPException(status_code=e.response.status_code, detail="Error fetching data from external service") \ No newline at end of file