Add semaphore to control number of connections that asyncio.gather attempts to instantiate when making inference requests
leothomas committed Nov 21, 2023
1 parent b2d8136 commit 6685499
Showing 2 changed files with 20 additions and 9 deletions.
22 changes: 15 additions & 7 deletions cerulean_cloud/cloud_run_orchestrator/clients.py
@@ -6,6 +6,7 @@
 from io import BytesIO
 from typing import List, Tuple
 
+import asyncio
 import httpx
 import morecantile
 import numpy as np
@@ -71,6 +72,7 @@ async def get_base_tile_inference(
         self,
         tile: morecantile.Tile,
         http_client: httpx.AsyncClient,
+        semaphore: asyncio.Semaphore,
         rescale=(0, 255),
     ) -> InferenceResultStack:
         """fetch inference for base tiles"""
@@ -96,9 +98,10 @@ async def get_base_tile_inference(
 
         inf_stack = [InferenceInput(image=encoded, bounds=TMS.bounds(tile))]
         payload = PredictPayload(inf_stack=inf_stack, inf_parms=self.inference_parms)
-        res = await http_client.post(
-            self.url + "/predict", json=payload.dict(), timeout=None
-        )
+        async with semaphore:
+            res = await http_client.post(
+                self.url + "/predict", json=payload.dict(), timeout=None
+            )
         if res.status_code == 200:
             return InferenceResultStack(**res.json())
         else:
@@ -108,7 +111,11 @@ async def get_base_tile_inference(
             )
 
     async def get_offset_tile_inference(
-        self, bounds: List[float], http_client: httpx.AsyncClient, rescale=(0, 255)
+        self,
+        bounds: List[float],
+        http_client: httpx.AsyncClient,
+        semaphore: asyncio.Semaphore,
+        rescale=(0, 255),
     ) -> InferenceResultStack:
         """fetch inference for offset tiles"""
         hw = self.scale * 256
@@ -133,9 +140,10 @@ async def get_offset_tile_inference(
         inf_stack = [InferenceInput(image=encoded, bounds=bounds)]
 
         payload = PredictPayload(inf_stack=inf_stack, inf_parms=self.inference_parms)
-        res = await http_client.post(
-            self.url + "/predict", json=payload.dict(), timeout=None
-        )
+        async with semaphore:
+            res = await http_client.post(
+                self.url + "/predict", json=payload.dict(), timeout=None
+            )
         if res.status_code == 200:
             return InferenceResultStack(**res.json())
         else:
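
For reference, a minimal sketch of the pattern this clients.py change applies, not the repository's actual client class: the shared asyncio.Semaphore is held only while the POST to the /predict endpoint is in flight, so payload preparation is not serialized. The function name post_inference and the example URL below are illustrative assumptions.

import asyncio

import httpx


async def post_inference(
    payload: dict,
    http_client: httpx.AsyncClient,
    semaphore: asyncio.Semaphore,
    url: str = "https://inference.example.com/predict",  # hypothetical endpoint
) -> dict:
    """POST a single inference request, bounded by the shared semaphore."""
    async with semaphore:  # at most N requests are in flight at once
        res = await http_client.post(url, json=payload, timeout=None)
    res.raise_for_status()  # surface non-200 responses as exceptions
    return res.json()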
7 changes: 5 additions & 2 deletions cerulean_cloud/cloud_run_orchestrator/handler.py
@@ -263,12 +263,15 @@ async def perform_inference(tiles, inference_func, description):
     """
     print(f"Inference on {description}!")
 
+    semaphore = asyncio.Semaphore(value=25)
     async with httpx.AsyncClient(
-        headers={"Authorization": f"Bearer {os.getenv('API_KEY')}"}
+        headers={"Authorization": f"Bearer {os.getenv('API_KEY')}"},
+        timeout=None,
+        pool_limits=httpx.Limits(max_connections=50),
     ) as async_http_client:
         inferences = await asyncio.gather(
             *[
-                inference_func(tile, async_http_client, rescale=(0, 255))
+                inference_func(tile, async_http_client, semaphore, rescale=(0, 255))
                 for tile in tiles
             ],
             return_exceptions=False,  # This raises exceptions
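
And a sketch of the orchestration side under the same assumptions, building on post_inference from the sketch above: a single Semaphore(25) is shared by every coroutine handed to asyncio.gather, and the AsyncClient additionally caps its connection pool. Note that recent httpx releases spell this keyword limits= rather than the pool_limits= seen in this commit.

import asyncio
import os
from typing import List

import httpx


async def perform_inference_sketch(payloads: List[dict]) -> List[dict]:
    """Run all inference requests concurrently, never more than 25 at a time."""
    semaphore = asyncio.Semaphore(25)  # max concurrent /predict requests
    async with httpx.AsyncClient(
        headers={"Authorization": f"Bearer {os.getenv('API_KEY')}"},
        timeout=None,
        limits=httpx.Limits(max_connections=50),
    ) as client:
        return await asyncio.gather(
            *[post_inference(p, client, semaphore) for p in payloads],
            return_exceptions=False,  # propagate the first failure
        )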
