Skip to content

Commit

Permalink
pass bearer tokens to solr and hsds
Browse files Browse the repository at this point in the history
  • Loading branch information
vedina committed Sep 19, 2024
1 parent 9a19a1e commit 46e23ca
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 26 deletions.
23 changes: 11 additions & 12 deletions src/rcapi/api/convertor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
import base64
import traceback
import os.path
import ramanchada2 as rc2

from pynanomapper.clients.datamodel_simple import StudyRaman
from numcompress import compress
from rcapi.services.convertor_service import empty_figure,dict2figure, solr2image, knnquery, recursive_copy,read_spectrum_native
#from rcapi.services.kc import inject_api_key_h5pyd, inject_api_key_into_httpx, get_api_key
from rcapi.services.convertor_service import empty_figure,dict2figure, solr2image, recursive_copy,read_spectrum_native
from rcapi.services.kc import get_token
import h5py, h5pyd
from rcapi.services.solr_query import SOLR_ROOT,SOLR_VECTOR,SOLR_COLLECTION
import tempfile
import shutil
from rcapi.services.solr_query import SOLR_ROOT,SOLR_COLLECTION



router = APIRouter()
Expand All @@ -30,8 +29,8 @@ async def convert_get(
dataset: Optional[str] = "raw",
w: Optional[int] = 300,
h: Optional[int] = 200,
extra: Optional[str] = None
#api_key: Optional[str] = Depends(get_api_key)
extra: Optional[str] = None,
token: Optional[str] = Depends(get_token)
) :
if not domain:
#tr.set_error("missing domain")
Expand Down Expand Up @@ -62,7 +61,7 @@ async def convert_get(
elif what in ["thumbnail","b64png","image"]: #solr query
#async with inject_api_key_into_httpx(api_key):
try:
fig,etag = await solr2image(solr_url, domain, figsize, extra)
fig,etag = await solr2image(solr_url, domain, figsize, extra,token)
# Check if ETag matches the client's If-None-Match header
_headers = {}
if etag is not None:
Expand All @@ -86,7 +85,7 @@ async def convert_get(
if what == "h5":
try:
with io.BytesIO() as tmpfile:
with h5pyd.File(domain,mode="r") as fin:
with h5pyd.File(domain,mode="r",api_key=token) as fin:
with h5py.File(tmpfile,"w") as fout:
recursive_copy(fin,fout)
tmpfile.seek(0)
Expand All @@ -105,8 +104,8 @@ async def convert_get(
@router.post("/download")
async def convert_post(
what: Literal["knnquery", "b64png" ] = Query("knnquery") ,
files: list[UploadFile] = File(...)
#api_key: Optional[str] = Depends(get_api_key)
files: list[UploadFile] = File(...),
token: Optional[str] = Depends(get_token)
):
logging.info("convert_file function called")
logging.info(f"Received parameter 'what': {what}")
Expand Down
10 changes: 7 additions & 3 deletions src/rcapi/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from rcapi.services import query_service
import traceback
from rcapi.services.solr_query import SOLR_ROOT,SOLR_VECTOR,SOLR_COLLECTION, solr_query_get
from rcapi.services.kc import get_token

router = APIRouter()

Expand All @@ -17,6 +18,7 @@ async def get_query(
page : Optional[int] = 0, pagesize : Optional[int] = 10,
img: Optional[Literal["embedded", "original", "thumbnail"]] = "thumbnail",
vector_field : Optional[str] = None,
token: Optional[str] = Depends(get_token)
):
solr_url = "{}{}/select".format(SOLR_ROOT,SOLR_COLLECTION)

Expand All @@ -36,7 +38,8 @@ async def get_query(
page=page,
pagesize=pagesize,
img=img,
vector_field=SOLR_VECTOR if vector_field is None else vector_field
vector_field=SOLR_VECTOR if vector_field is None else vector_field,
token=token
)
return results
except Exception as err:
Expand All @@ -48,12 +51,13 @@ async def get_query(
@router.get("/query/field", )
async def get_field(
request: Request,
name: str = "publicname_s"
name: str = "publicname_s",
token: Optional[str] = Depends(get_token)
):
solr_url = "{}{}/select".format(SOLR_ROOT,SOLR_COLLECTION)
try:
params= {"q" : "*", "rows" : 0, "facet.field": name, "facet" : "true"}
rs = await solr_query_get(solr_url, params)
rs = await solr_query_get(solr_url, params,token)
result = []
# Extract the facet field values
facet_field_values = rs.json()["facet_counts"]["facet_fields"][name]
Expand Down
4 changes: 2 additions & 2 deletions src/rcapi/services/convertor_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,14 @@ def knnquery(domain,dataset="raw"):
def generate_etag(content: str) -> str:
return hashlib.md5(content.encode()).hexdigest()

async def solr2image(solr_url: str,domain : str,figsize=(6,4),extraprm =None) -> Tuple[Figure, str]:
async def solr2image(solr_url: str,domain : str,figsize=(6,4),extraprm =None,token : str = None) -> Tuple[Figure, str]:
rs = None
try:

query="textValue_s:{}{}{}".format('"',domain,'"')
params = {"q": query, "fq" : ["type_s:study"], "fl" : "name_s,textValue_s,reference_s,reference_owner_s,{},updated_s,_version_".format(SOLR_VECTOR)}

rs = await solr_query_get(solr_url, params)
rs = await solr_query_get(solr_url, params, token = token)
if rs is not None and rs.status_code == 200:
response_json = rs.json()
if "response" in response_json:
Expand Down
4 changes: 2 additions & 2 deletions src/rcapi/services/kc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

logger = logging.getLogger(__name__)

# Dependency to extract API key from Bearer token
def get_api_key(authorization: Optional[str] = Header(None)):
# Dependency to extract Bearer token
def get_token(authorization: Optional[str] = Header(None)):
if authorization is None:
return None
elif authorization.startswith("Bearer "):
Expand Down
7 changes: 4 additions & 3 deletions src/rcapi/services/query_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ async def process(request: Request,
page: Optional[int] = 0,
pagesize: Optional[int] = 10,
img: Optional[Literal["embedded", "original", "thumbnail"]] = "thumbnail",
vector_field="spectrum_p1024"):
vector_field="spectrum_p1024",
token=None):

query_fields = "id,name_s,textValue_s"
embedded_images = img=="embedded"
Expand All @@ -35,7 +36,7 @@ async def process(request: Request,
"reference_s:{}".format(q_reference),"reference_owner_s:{}".format(q_provider)],
"fields" : query_fields}
try:
response = await solr_query_post(solr_url,query_params,solr_params)
response = await solr_query_post(solr_url,query_params,solr_params,token)
response_data = response.json()
return parse_solr_response(response_data,get_baseurl(request),embedded_images,thumbnail,vector_field=None)
except Exception as err:
Expand All @@ -56,7 +57,7 @@ async def process(request: Request,
"reference_owner_s:{}".format(q_provider) ],
"fields" : query_fields}
try:
response = await solr_query_post(solr_url,query_params,solr_params)
response = await solr_query_post(solr_url,query_params,solr_params,token)
response_data = response.json()
return parse_solr_response(response_data,request.base_url,embedded_images,thumbnail,vector_field)
except Exception as err:
Expand Down
16 changes: 12 additions & 4 deletions src/rcapi/services/solr_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,33 @@
SOLR_COLLECTION = config.SOLR_COLLECTION


async def solr_query_post(solr_url,query_params = None,post_param = None):
async def solr_query_post(solr_url,query_params = None,post_param = None, token = None):
async with httpx.AsyncClient() as client:
try:
headers = {}
if token:
headers['Authorization'] = f'Bearer {token}' # Add token to headers
response = await client.post(
solr_url,
json=post_param,
params = query_params
params = query_params,
headers=headers # Pass headers
)
response.raise_for_status() # Check for HTTP errors
return response
except httpx.HTTPStatusError as e:
raise HTTPException(status_code=e.response.status_code, detail="Error fetching data from external service")

async def solr_query_get(solr_url,params = None):
async def solr_query_get(solr_url,params = None, token = None):
async with httpx.AsyncClient() as client:
try:
headers = {}
if token:
headers['Authorization'] = f'Bearer {token}' # Add token to headers
response = await client.get(
solr_url,
params = params
params = params,
headers=headers # Pass headers
)
response.raise_for_status() # Check for HTTP errors
return response
Expand Down

0 comments on commit 46e23ca

Please sign in to comment.