Skip to content

Commit

Permalink
solr query in separate module
Browse files Browse the repository at this point in the history
async context manager for api key injection
download as image working (via solr)
  • Loading branch information
vedina committed Sep 13, 2024
1 parent 251b86d commit b8c3c82
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 174 deletions.
41 changes: 17 additions & 24 deletions src/rcapi/api/convertor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
import ramanchada2 as rc2
from pynanomapper.clients.datamodel_simple import StudyRaman
from numcompress import compress
from rcapi.services.convertor_service import empty_figure,dict2figure, image, thumbnail, knnquery, recursive_copy
from rcapi.services.kc import inject_api_key_h5pyd, inject_api_key_into_requests, get_api_key
from rcapi.services.convertor_service import empty_figure,dict2figure, solr2image, knnquery, recursive_copy
from rcapi.services.kc import inject_api_key_h5pyd, inject_api_key_into_httpx, get_api_key
import h5py, h5pyd
from rcapi.services.solr_query import SOLR_ROOT,SOLR_VECTOR,SOLR_COLLECTION

router = APIRouter()

solr_root = "https://solr-kc.ideaconsult.net/solr/"

@router.get("/download", )
async def convert_get(
Expand All @@ -33,7 +33,7 @@ async def convert_get(
#tr.set_error("missing domain")
raise HTTPException(status_code=400, detail=str("missing domain"))

solr_url = "{}charisma/select".format(solr_root)
solr_url = "{}{}/select".format(SOLR_ROOT,SOLR_COLLECTION)

width = w
height = h
Expand All @@ -47,37 +47,30 @@ async def convert_get(
FigureCanvas(fig).print_png(output)
return Response(content=output.getvalue(), media_type='image/png')

if what in ["dict","thumbnail"]: #solr query
if what in ["dict","thumbnail","b64png","image"]: #solr query
if what == "dict":
prm = dict(request.query_params)
prm["what"] = None
fig = dict2figure(prm, figsize)
output = io.BytesIO()
FigureCanvas(fig).print_png(output)
return Response(content=output.getvalue(), media_type='image/png')
if what == "thumbnail":
with inject_api_key_into_requests(api_key):
fig = thumbnail(solr_url, domain, figsize, extra)
else:
async with inject_api_key_into_httpx(api_key):
fig = await solr2image(solr_url, domain, figsize, extra)
output = io.BytesIO()
FigureCanvas(fig).print_png(output)
return Response(content=output.getvalue(), media_type='image/png')

if what == "b64png":
base64_bytes = base64.b64encode(output.getvalue())
return Response(content=base64_bytes, media_type='text/plain')
else:
return Response(content=output.getvalue(), media_type='image/png')


elif what in ["knnquery","h5","b64png","image"]: # h5 query
elif what in ["knnquery","h5"]: # h5 query
try:
with inject_api_key_h5pyd(api_key):
if what == "b64png":
fig = image(domain, dataset)
output = io.BytesIO()
FigureCanvas(fig).print_png(output)
base64_bytes = base64.b64encode(output.getvalue())
return Response(content=base64_bytes, media_type='text/plain')
elif what == "image":
fig = image(domain, dataset, figsize, extra)
output = io.BytesIO()
FigureCanvas(fig).print_png(output)
return Response(content=output.getvalue(), media_type='image/png')
elif what == "knnquery":
if what == "knnquery":
return knnquery(domain, dataset)
elif what == "h5":
try:
Expand All @@ -88,6 +81,7 @@ async def convert_get(
tmpfile.seek(0)
return Response(content=tmpfile.read(), media_type="application/x-hdf5", headers={"Content-Disposition": "attachment; filename=download.h5"})
except Exception as e:
print(traceback.format_exc())
raise HTTPException(status_code=400, detail=f" error: {str(e)}")
except Exception as e:
raise HTTPException(status_code=500, detail=f" error: {str(e)}")
Expand Down Expand Up @@ -121,7 +115,6 @@ async def convert_post(

if file_extension != ".cha":
x, y, _meta = rc2.spectrum.from_local_file(file=uf.file, f_name=f_name)
print(x)
if what == "knnquery":
_cdf, pdf = StudyRaman.xy2embedding(x, y, StudyRaman.x4search(dim=2048))
result_json = {"cdf": compress(pdf.tolist(), precision=6)}
Expand Down
9 changes: 4 additions & 5 deletions src/rcapi/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from typing import Optional, Literal
from rcapi.services import query_service
import traceback
router = APIRouter()
from rcapi.services.solr_query import SOLR_ROOT,SOLR_VECTOR,SOLR_COLLECTION

solr_root = "https://solr-kc.ideaconsult.net/solr/"
solr_vector_field = "spectrum_p1024"
router = APIRouter()

@router.get("/query", )
async def get_query(
Expand All @@ -18,7 +17,7 @@ async def get_query(
page : Optional[int] = 0, pagesize : Optional[int] = 10,
img: Optional[Literal["embedded", "original", "thumbnail"]] = "thumbnail",
):
solr_url = "{}charisma/select".format(solr_root)
solr_url = "{}{}/select".format(SOLR_ROOT,SOLR_COLLECTION)

textQuery = q
textQuery = "*" if textQuery is None or textQuery=="" else textQuery
Expand All @@ -36,7 +35,7 @@ async def get_query(
page=page,
pagesize=pagesize,
img=img,
vector_field=solr_vector_field
vector_field=SOLR_VECTOR
)
return results
except Exception as err:
Expand Down
47 changes: 11 additions & 36 deletions src/rcapi/services/convertor_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from numcompress import compress
from pynanomapper.clients.datamodel_simple import StudyRaman
import h5py, h5pyd
import requests
from .kc import AuthenticatedRequest
from rcapi.services.solr_query import solr_query_post,solr_query_get,SOLR_ROOT,SOLR_COLLECTION,SOLR_VECTOR
import traceback

def empty_figure(figsize,title,label):
Expand Down Expand Up @@ -60,33 +59,6 @@ def dict2figure(pm,figsize):
axis.set_title(pm["domain"])
return fig

def image(domain,dataset="raw",figsize=(6,4),extraprm=""):
try:
with h5pyd.File(domain,mode="r") as h5:
x = h5[dataset][0]
y = h5[dataset][1]
try:
_sample = h5["annotation_sample"].attrs["sample"]
except:
_sample = None
try:
_provider = h5["annotation_study"].attrs["provider"]
except:
_provider = None
try:
_wavelength = h5["annotation_study"].attrs["wavelength"]
except:
_wavelength = None
fig = Figure(figsize=figsize)
axis = fig.add_subplot(1, 1, 1)
axis.plot(x, y, color='black')
axis.set_ylabel(h5[dataset].dims[1].label)
axis.set_xlabel(h5[dataset].dims[0].label)
axis.title.set_text("{} {} ({}) {}".format(extraprm,_sample,_provider,_wavelength))
#domain.split("/")[-1],dataset))
return fig
except Exception as err:
return empty_figure(figsize,"Error","{}".format(domain.split("/")[-1]))

def knnquery(domain,dataset="raw"):
try:
Expand Down Expand Up @@ -116,22 +88,25 @@ def knnquery(domain,dataset="raw"):
except Exception as err:
raise(err)

def thumbnail(solr_url,domain,figsize=(6,4),extraprm=""):
async def solr2image(solr_url,domain,figsize=(6,4),extraprm=None):
rs = None
try:
query="textValue_s:\"{}\"".format(domain.replace(" ","\ "))
params = {"q": query, "fq" : ["type_s:study"], "fl" : "name_s,textValue_s,reference_s,reference_owner_s,spectrum_p1024"}
rs = solrquery_get(solr_url, params = params)
params = {"q": query, "fq" : ["type_s:study"], "fl" : "name_s,textValue_s,reference_s,reference_owner_s,{}".format(SOLR_VECTOR)}
rs = await solr_query_get(solr_url, params)
if rs.status_code==200:
x = StudyRaman.x4search()
x = None
for doc in rs.json()["response"]["docs"]:
y = doc["spectrum_p1024"]
y = doc[SOLR_VECTOR]
if x is None:
x = StudyRaman.x4search(len(y))
fig = Figure(figsize=figsize)
axis = fig.add_subplot(1, 1, 1)
axis.plot(x, y)
axis.set_ylabel("a.u.")
axis.set_xlabel("Raman shift [1/cm]")
axis.title.set_text("{} {} {} ({})".format(extraprm,doc["name_s"],doc["reference_owner_s"],doc["reference_s"]))
axis.set_xlabel("Wavenumber [1/cm]")
axis.title.set_text("{} {} {} ({})".format("" if extraprm is None else extraprm,
doc["name_s"],doc["reference_owner_s"],doc["reference_s"]))
return fig
else:
return empty_figure(figsize,"{} {}".format(rs.status_code,rs.reason),"{}".format(domain.split("/")[-1]))
Expand Down
120 changes: 40 additions & 80 deletions src/rcapi/services/kc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import logging
import threading
import h5pyd
import requests
from contextlib import contextmanager
import httpx
from contextlib import contextmanager, asynccontextmanager
from typing import Optional

# Thread-local storage for API key
Expand All @@ -23,84 +23,6 @@ def get_api_key(authorization: Optional[str] = Header(None)):
return None
# raise HTTPException(status_code=401, detail="Invalid or missing Authorization header")

class AuthenticatedRequest:
def __init__(self, get_token):
"""
Initialize the AuthenticatedRequest context manager.
Args:
get_token (callable): A function to retrieve the current token.
with AuthenticatedRequest(get_token):
requests.get()
...
"""
self.get_token = get_token
self.original_post = None

def __enter__(self):
# Save the original requests.post method
self.original_post = requests.post
# Override requests.post with our modified version
requests.post = self.modified_post

def __exit__(self, exc_type, exc_value, traceback):
# Restore the original requests.post method
requests.post = self.original_post


def modified_post(self, url, **kwargs):
# Retrieve the current token
token = self.get_token()
# Ensure the headers dictionary exists
headers = kwargs.setdefault('headers', {})
# Add the Authorization header, without overwriting other headers
if 'Authorization' not in headers:
headers['Authorization'] = f'Bearer {token}'
# Make the request with the modified headers
return self.original_post(url, **kwargs)


# Context manager to inject the API key into a thread-local session for all requests
@contextmanager
def inject_api_key_into_requests(api_key):
"""
Thread-safe context manager to inject the API key into the thread-local `requests.Session`
used for making HTTP requests.
Each thread will have its own independent `requests.Session`, ensuring that the
`Authorization: Bearer <api_key>` header is correctly applied without affecting other threads.
Parameters:
-----------
api_key : str
The API key to inject into the requests headers.
Yields:
-------
session : requests.Session
The thread-local session object, with the Authorization header pre-configured.
Example:
--------
>>> with inject_api_key_into_requests(api_key="your_api_key") as session:
>>> response = session.get("https://example.com/data")
>>> response = session.post("https://example.com/data", json={"key": "value"})
>>> # The Authorization header will automatically be included in both requests.
"""
if not hasattr(thread_local, "session"):
thread_local.session = requests.Session()

# Add Authorization header to the session for this thread
thread_local.session.headers.update({"Authorization": f"Bearer {api_key}"})

try:
yield thread_local.session # Provide the session for this thread's context
finally:
# Optional: Clean up if necessary (e.g., closing the session)
pass


# Context manager to temporarily patch h5pyd.File to inject the api_key
@contextmanager
Expand Down Expand Up @@ -163,3 +85,41 @@ def patched_h5pyd_folder(*args, **kwargs):
h5pyd.File = original_h5pyd_file
h5pyd.Folder = original_h5pyd_folder

@asynccontextmanager
async def inject_api_key_into_httpx(api_key: str):
"""
Thread-safe async context manager to inject the API key into the thread-local `httpx.AsyncClient`
used for making HTTP requests.
Each thread will have its own independent `httpx.AsyncClient`, ensuring that the
`Authorization: Bearer <api_key>` header is correctly applied without affecting other threads.
Parameters:
-----------
api_key : str
The API key to inject into the requests headers.
Yields:
-------
client : httpx.AsyncClient
The thread-local `AsyncClient` object, with the Authorization header pre-configured.
Example:
--------
>>> async with inject_api_key_into_requests(api_key="your_api_key") as client:
>>> response = await client.get("https://example.com/data")
>>> response = await client.post("https://example.com/data", json={"key": "value"})
>>> # The Authorization header will automatically be included in both requests.
"""
if not hasattr(thread_local, "client"):
thread_local.client = httpx.AsyncClient()

# Add Authorization header to the client for this thread
thread_local.client.headers.update({"Authorization": f"Bearer {api_key}"})

try:
yield thread_local.client # Provide the client for this thread's context
finally:
# Optionally close the client if required
await thread_local.client.aclose()
del thread_local.client # Clean up the client after use
30 changes: 1 addition & 29 deletions src/rcapi/services/query_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,8 @@
from typing import Optional, Literal
from fastapi import Request, HTTPException
from numcompress import decompress
import httpx
import traceback
from rcapi.services.solr_query import solr_query_post,solr_query_get

async def solr_query_post(solr_url,query_params = None,post_param = None):
async with httpx.AsyncClient() as client:
try:
response = await client.post(
solr_url,
json=post_param,
params = query_params
)
response.raise_for_status() # Check for HTTP errors
return response
except httpx.HTTPStatusError as e:
print(traceback.format_exc())
raise HTTPException(status_code=e.response.status_code, detail="Error fetching data from external service")

async def solr_query_get(solr_url,query_params = None):
async with httpx.AsyncClient() as client:
try:
response = await client.get(
solr_url,
params = query_params
)
response.raise_for_status() # Check for HTTP errors
return response
except httpx.HTTPStatusError as e:
print(traceback.format_exc())
raise HTTPException(status_code=e.response.status_code, detail="Error fetching data from external service")

async def process(request: Request,
solr_url: str,
q: Optional[str] = "*",
Expand Down
Loading

0 comments on commit b8c3c82

Please sign in to comment.