
feat: DIA-1868: Azure AI Studio support in Prompts #329

Draft: wants to merge 13 commits into base: master
6 changes: 3 additions & 3 deletions poetry.lock


2 changes: 1 addition & 1 deletion pyproject.toml
@@ -42,7 +42,7 @@ celery = {version = "^5.3.6", extras = ["redis"]}
 kombu = ">=5.4.0rc2" # Pin version to fix https://github.com/celery/celery/issues/8030. TODO: remove when this fix will be included in celery
 uvicorn = "*"
 pydantic-settings = "^2.2.1"
-label-studio-sdk = {url = "https://github.com/HumanSignal/label-studio-sdk/archive/2547755090341885fedf17181f72290a0b48034a.zip"}
+label-studio-sdk = {url = "https://github.com/HumanSignal/label-studio-sdk/archive/c3464d51e546db3a5acf2399d8277db7f76cc79e.zip"}
 kafka-python-ng = "^2.2.3"
 requests = "^2.32.0"
 # Using litellm from forked repo until vertex fix is released: https://github.com/BerriAI/litellm/issues/7904
6 changes: 6 additions & 0 deletions server/app.py
@@ -246,6 +246,7 @@ async def submit_batch(batch: BatchData):
 
 @app.post("/validate-connection", response_model=Response[ValidateConnectionResponse])
 async def validate_connection(request: ValidateConnectionRequest):
+    # TODO: move this logic to LSE, this is the last place Adala needs to be updated when adding a provider connection
     multi_model_provider_test_models = {
         "openai": "gpt-4o-mini",
         "vertexai": "vertex_ai/gemini-1.5-flash",
@@ -290,6 +291,9 @@ async def validate_connection(request: ValidateConnectionRequest):
     if provider.lower() == "azureopenai":
         model = "azure/" + request.deployment_name
         model_extra = {"base_url": request.endpoint}
+    elif provider.lower() == "azureaifoundry":
+        model = "azure_ai/" + request.deployment_name
+        model_extra = {"base_url": request.endpoint}
     elif provider.lower() == "custom":
         model = "openai/" + request.deployment_name
         model_extra = (
@@ -364,6 +368,8 @@ async def estimate_cost(
     agent = request.agent
     provider = request.provider
     runtime = agent.get_runtime()
+    with open("cost_estimate.log", "w") as f:
+        json.dump(request.model_dump(), f, indent=4)
 
     try:
         cost_estimates = []
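The provider branch added in `validate_connection` maps each connection type to a litellm model string by prefixing the deployment name: Azure OpenAI uses `azure/`, the new Azure AI Foundry provider uses `azure_ai/`, and custom OpenAI-compatible endpoints use `openai/`. A minimal sketch of that mapping is below; the function name, the dict-based dispatch, and the `ValueError` fallback are illustrative assumptions, not code from this PR:

```python
def build_litellm_model(provider: str, deployment_name: str) -> str:
    """Build a litellm model string from a provider name and deployment name.

    Mirrors the branch in validate_connection: "azureopenai" -> "azure/",
    "azureaifoundry" -> "azure_ai/", "custom" -> "openai/". The ValueError
    fallback is a hypothetical addition for unknown providers.
    """
    prefixes = {
        "azureopenai": "azure/",
        "azureaifoundry": "azure_ai/",
        "custom": "openai/",
    }
    prefix = prefixes.get(provider.lower())
    if prefix is None:
        raise ValueError(f"unsupported provider: {provider}")
    return prefix + deployment_name
```

Like the PR's `elif` chain, the lookup is case-insensitive, so `"AzureAIFoundry"` and `"azureaifoundry"` resolve to the same `azure_ai/` prefix.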