-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add files for Foundry client and server * Fix Docker requirements * add swagger3.json * place swagger correctly, incorporate feedback * minor * reorganize Dockerfile for faster iteration * dispatch based on HTTP method * minor * docstr * linting * update api.py * simplify * WIP * WIP * WIP * api tests passing * fix * minor * Add missing dev dependencies * Fix formatting * Test blob storage protocol also for Docker image * simplify submission spec * linting * auto-generate swagger-file * add make target to build docker image on ACR * Simplify test and use right Docker image * Rework interaction between client and host * Improve URL parsing * Only require one communication channel * Simplify CI pipeline * Fix url parsing * Add some types * Add test that uses a real blob storage container * Add Docker test for an actual container * Remove deps from notebooks * Add outline of demo * Let the demo depend on env vars * Fix plotting * Remove random data * Increment version * Run demo * Silence warnings * Read the acknowledgement only once * enable aurora logger * Produce fancy visualisation for the notebook * Improve the writing a little * Expand on error message * Expand on error message * Remove abstract Foundry client * Replace split with partition * Prevent completely empty files * Remove image name from docs * Move interactive plotting code to aurora.foundry.demo * Simplify write call * Fix titles in panes * Move H2 out of first cell * Expose size of interactive plot * Improve sentence * Minor improvements * Fix docstrings * Move comment out of docstring * Fix docs * mlflow packaging * minor * load mlflow artifacts from right folder * fix logging * Fix formatting * Fix tests for MLflow workflow * Add missing copyright notice * Fix minor issues * Fix incorrect links, outdated API response, and clarify * Fix pyproject.toml * Update dev deps * Update Foundry demo notebook --------- Co-authored-by: Hannes Schulz <[email protected]> Co-authored-by: Hannes Schulz <[email protected]> Co-authored-by: Maik Riechert <[email protected]>
- Loading branch information
1 parent
060ffbe
commit 00352e2
Showing
27 changed files
with
1,701 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,3 +29,8 @@ htmlcov | |
.DS_Store | ||
*.swp | ||
.envrc | ||
|
||
checkpoints/ | ||
mlflow_tmp/ | ||
aurora_mlflow_pyfunc/ | ||
mlruns/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" | ||
|
||
from aurora.foundry.client.api import SubmissionError, submit | ||
from aurora.foundry.client.foundry import FoundryClient | ||
from aurora.foundry.common.channel import BlobStorageChannel | ||
|
||
__all__ = [ | ||
"BlobStorageChannel", | ||
"FoundryClient", | ||
"submit", | ||
"SubmissionError", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license. | ||
This is the API that the end user uses to submit jobs to the model running on Azure AI Foundry. | ||
""" | ||
|
||
import logging | ||
from typing import Generator | ||
|
||
from pydantic import BaseModel | ||
|
||
from aurora import Batch | ||
from aurora.foundry.client.foundry import FoundryClient | ||
from aurora.foundry.common.channel import CommunicationChannel, iterate_prediction_files | ||
from aurora.foundry.common.model import models | ||
|
||
__all__ = ["SubmissionError", "submit"] | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class CreationInfo(BaseModel): | ||
task_id: str | ||
|
||
|
||
class TaskInfo(BaseModel): | ||
task_id: str | ||
completed: bool | ||
progress_percentage: int | ||
success: bool | None | ||
submitted: bool | ||
status: str | ||
|
||
|
||
class SubmissionError(Exception): | ||
"""The submission could not be completed for some reason.""" | ||
|
||
|
||
def submit( | ||
batch: Batch, | ||
model_name: str, | ||
num_steps: int, | ||
channel: CommunicationChannel, | ||
foundry_client: FoundryClient, | ||
) -> Generator[Batch, None, None]: | ||
"""Submit a request to Azure AI Foundry and retrieve the predictions. | ||
Args: | ||
batch (:class:`aurora.Batch`): Initial condition. | ||
model_name (str): Name of the model. This name must be available in | ||
:mod:`aurora.foundry.common.model`. | ||
num_steps (int): Number of prediction steps. | ||
channel (:class:`aurora.foundry.common.channel.CommunicationChannel`): Channel to use for | ||
sending and receiving data. | ||
foundry_client (:class:`aurora.foundry.client.foundry.FoundryClient`): Client to | ||
communicate with Azure Foundry AI. | ||
Yields: | ||
:class:`aurora.Batch`: Predictions. | ||
""" | ||
if model_name not in models: | ||
raise KeyError(f"Model `{model_name}` is not a valid model.") | ||
|
||
# Create a task at the endpoint. | ||
task = { | ||
"model_name": model_name, | ||
"num_steps": num_steps, | ||
"data_folder_uri": channel.to_spec(), | ||
} | ||
response = foundry_client.submit_task(task) | ||
try: | ||
submission_info = CreationInfo(**response) | ||
except Exception as e: | ||
raise SubmissionError("Failed to create task.") from e | ||
task_id = submission_info.task_id | ||
logger.info(f"Created task `{task_id}` at endpoint.") | ||
|
||
# Send the initial condition over. | ||
logger.info("Uploading initial condition.") | ||
channel.send(batch, task_id, "input.nc") | ||
|
||
previous_status: str = "No status" | ||
previous_progress: int = 0 | ||
ack_read: bool = False | ||
|
||
while True: | ||
# Check on the progress of the task. The first progress check will trigger the task to be | ||
# submitted. | ||
response = foundry_client.get_progress(task_id) | ||
task_info = TaskInfo(**response) | ||
|
||
if task_info.submitted and not ack_read: | ||
# If the task has been submitted, we must be able to read the acknowledgement of the | ||
# initial condition. | ||
try: | ||
channel.read(task_id, "input.nc.ack", timeout=120) | ||
ack_read = True # Read the acknowledgement only once. | ||
except TimeoutError as e: | ||
raise SubmissionError( | ||
"Could not read acknowledgement of initial condition. " | ||
"This acknowledgement should be availabe, " | ||
"since the task has been successfully submitted. " | ||
"Something might have gone wrong in the communication " | ||
"between the client and the server. " | ||
"Please check the logs and your SAS token should you be using one." | ||
) from e | ||
|
||
if task_info.status != previous_status: | ||
logger.info(f"Task status update: {task_info.status}") | ||
previous_status = task_info.status | ||
|
||
if task_info.progress_percentage > previous_progress: | ||
logger.info(f"Task progress update: {task_info.progress_percentage}%.") | ||
previous_progress = task_info.progress_percentage | ||
|
||
if task_info.completed: | ||
if task_info.success: | ||
logger.info("Task has been successfully completed!") | ||
break | ||
else: | ||
raise SubmissionError(f"Task failed: {task_info.status}") | ||
|
||
logger.info("Retrieving predictions.") | ||
for prediction_name in iterate_prediction_files("prediction.nc", num_steps): | ||
yield channel.receive(task_id, prediction_name) | ||
logger.info("All predictions have been retrieved.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" | ||
|
||
import json | ||
import logging | ||
|
||
import requests | ||
|
||
__all__ = ["FoundryClient"] | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class FoundryClient: | ||
def __init__(self, endpoint: str, token: str) -> None: | ||
"""Initialise. | ||
Args: | ||
endpoint (str): URL to the endpoint. | ||
token (str): Authorisation token. | ||
""" | ||
self.endpoint = endpoint | ||
self.token = token | ||
|
||
def _req( | ||
self, | ||
data: dict | None = None, | ||
) -> requests.Response: | ||
wrapped = {"data": json.dumps(data)} | ||
return requests.request( | ||
"POST", | ||
self.endpoint, | ||
headers={ | ||
"Authorization": f"Bearer {self.token}", | ||
"Content-Type": "application/json", | ||
}, | ||
json={"input_data": wrapped}, | ||
) | ||
|
||
def _unwrap(self, response: requests.Response) -> dict: | ||
if not response.ok: | ||
logger.error(response.text) | ||
response.raise_for_status() | ||
response_json = response.json() | ||
return response_json | ||
|
||
def submit_task(self, data: dict) -> dict: | ||
"""Send `data` to the scoring path. | ||
Args: | ||
data (dict): Data to send. | ||
Returns: | ||
dict: Submission information. | ||
""" | ||
answer = self._req({"type": "submission", "msg": data}) | ||
return self._unwrap(answer) | ||
|
||
def get_progress(self, task_id: str) -> dict: | ||
"""Get the progress of the task. | ||
Args: | ||
task_id (str): Task ID to get progress info for. | ||
Returns: | ||
dict: Progress information. | ||
""" | ||
answer = self._req({"type": "task_info", "msg": {"task_id": task_id}}) | ||
return self._unwrap(answer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" |
Oops, something went wrong.