Add information on deploying distillation output model (#3354)
* Add information on deploying distillation output model

* Reformat with black

* Fix versioning info

* Add FT model name output method to deploy

* Formatting fixes

* format with black

* add teacher model deployment

* uncomment dependencies for the auto-runner

---------

Co-authored-by: Sharvin Jondhale <[email protected]>
sharvin2187 and Sharvin Jondhale authored Sep 11, 2024
1 parent 28351c9 commit b01a262
Showing 2 changed files with 268 additions and 76 deletions.
@@ -39,6 +39,8 @@
"source": [
"%pip install azure-ai-ml\n",
"%pip install azure-identity\n",
"%pip install azure-core\n",
"%pip install azure-ai-inference\n",
"\n",
"%pip install mlflow\n",
"%pip install azureml-mlflow\n",
@@ -62,13 +64,18 @@
"\n",
"import base64\n",
"import json\n",
"import os\n",
"import uuid\n",
"\n",
"from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential\n",
"\n",
"from azure.ai.ml import MLClient, Input\n",
"from azure.ai.inference import ChatCompletionsClient\n",
"from azure.ai.inference.models import SystemMessage, UserMessage\n",
"from azure.ai.ml import Input, MLClient\n",
"from azure.ai.ml.constants import AssetTypes\n",
"from azure.ai.ml.dsl import pipeline\n",
"from azure.ai.ml.entities import Data"
"from azure.ai.ml.entities import Data, ServerlessEndpoint\n",
"from azure.core.credentials import AzureKeyCredential\n",
"from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential\n",
"from azure.core.exceptions import ResourceNotFoundError"
]
},
{
@@ -110,6 +117,19 @@
"`DefaultAzureCredential` should be capable of handling most Azure SDK authentication scenarios. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"try:\n",
" credential = DefaultAzureCredential()\n",
" # Check if given credential can get token successfully.\n",
" credential.get_token(\"https://management.azure.com/.default\")\n",
"except Exception as ex:\n",
" # Fall back to InteractiveBrowserCredential in case DefaultAzureCredential not work\n",
" credential = InteractiveBrowserCredential()"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -150,14 +170,7 @@
"source": [
"## Pick a teacher model\n",
"\n",
"We support **Meta-Llama-3.1-405B-Instruct** as the teacher model. \n",
"### First deploy the teacher model in Azure AI Studio\n",
"* Go to Azure AI Studio (ai.azure.com)\n",
"* Select Meta-Llama-3.1-405B-Instruct model from Model catalog.\n",
"* Deploy with \"Pay-as-you-go\"\n",
"* Once deployed successfully, you should be assigned for an API endpoint and a security key for inference.\n",
"\n",
"Update the following cell with the information of the deployment you just created."
"We support **Meta-Llama-3.1-405B-Instruct** as the teacher model. "
]
},
{
@@ -166,13 +179,23 @@
"metadata": {},
"outputs": [],
"source": [
"# Llama-3-405B Teacher model endpoint name\n",
"# The serverless model name is the name found in ML Studio > Endpoints > Serverless endpoints > Model column\n",
"# We will reuse or create a serverless endpoint\n",
"TEACHER_MODEL_NAME = \"Meta-Llama-3.1-405B-Instruct\"\n",
"TEACHER_MODEL_ENDPOINT_NAME = \"Meta-Llama-3-1-405B-Instruct-vum\"\n",
"\n",
"# The serverless model endpoint name is the name found in ML Studio > Endpoints > Serverless endpoints > Name column\n",
"# The endpoint URL will be resolved from this name by the MLFlow component\n",
"TEACHER_MODEL_ENDPOINT_NAME = \"Meta-Llama-3-1-405B-Instruct-vum\""
"mlclient_azureml_meta = MLClient(credential, registry_name=\"azureml-meta\")\n",
"try:\n",
" ml_client.serverless_endpoints.get(TEACHER_MODEL_ENDPOINT_NAME)\n",
"except ResourceNotFoundError:\n",
" # create the endpoint\n",
" teacher_model_id = (\n",
" \"azureml://registries/azureml-meta/models/Meta-Llama-3.1-405B-Instruct\"\n",
" )\n",
" teacher_endpoint = ServerlessEndpoint(\n",
" name=TEACHER_MODEL_ENDPOINT_NAME,\n",
" model_id=teacher_model_id,\n",
" )\n",
" ml_client.begin_create_or_update(teacher_endpoint).result()"
]
},
{
@@ -194,7 +217,6 @@
"STUDENT_MODEL_VERSION = 1\n",
"\n",
"# retrieve student model from model registry\n",
"mlclient_azureml_meta = MLClient(credential, registry_name=\"azureml-meta\")\n",
"student_model = mlclient_azureml_meta.models.get(\n",
" STUDENT_MODEL_NAME, version=STUDENT_MODEL_VERSION\n",
")\n",
@@ -307,7 +329,7 @@
"metadata": {},
"outputs": [],
"source": [
"! mkdir -p data"
"!mkdir data"
]
},
{
@@ -319,44 +341,44 @@
"train_data_path = \"data/train_conjnli_512.jsonl\"\n",
"valid_data_path = \"data/valid_conjnli_256.jsonl\"\n",
"\n",
"for row in train:\n",
" data = {\"messages\": []}\n",
" data[\"messages\"].append(\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You are a helpful assistant. Your output should only be one of the three labels: 'entailment', 'contradiction', or 'neutral'.\",\n",
" }\n",
" )\n",
" data[\"messages\"].append(\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Given the following two texts, your task is to determine the logical relationship between them. The first text is the 'premise' and the second text is the 'hypothesis'. The relationship should be labeled as one of the following: 'entailment' if the premise entails the hypothesis, 'contradiction' if the premise contradicts the hypothesis, or 'neutral' if the premise neither entails nor contradicts the hypothesis.\\n\\nPremise: \"\n",
" + row[\"premise\"]\n",
" + \"\\nHypothesis: \"\n",
" + row[\"hypothesis\"],\n",
" }\n",
" )\n",
" with open(train_data_path, \"a\") as f:\n",
"with open(train_data_path, \"w+\") as f:\n",
" for row in train:\n",
" data = {\"messages\": []}\n",
" data[\"messages\"].append(\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You are a helpful assistant. Your output should only be one of the three labels: 'entailment', 'contradiction', or 'neutral'.\",\n",
" }\n",
" )\n",
" data[\"messages\"].append(\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Given the following two texts, your task is to determine the logical relationship between them. The first text is the 'premise' and the second text is the 'hypothesis'. The relationship should be labeled as one of the following: 'entailment' if the premise entails the hypothesis, 'contradiction' if the premise contradicts the hypothesis, or 'neutral' if the premise neither entails nor contradicts the hypothesis.\\n\\nPremise: \"\n",
" + row[\"premise\"]\n",
" + \"\\nHypothesis: \"\n",
" + row[\"hypothesis\"],\n",
" }\n",
" )\n",
" f.write(json.dumps(data) + \"\\n\")\n",
"\n",
"for row in val:\n",
" data = {\"messages\": []}\n",
" data[\"messages\"].append(\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You are a helpful assistant. Your output should only be one of the three labels: 'entailment', 'contradiction', or 'neutral'.\",\n",
" }\n",
" )\n",
" data[\"messages\"].append(\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Given the following two texts, your task is to determine the logical relationship between them. The first text is the 'premise' and the second text is the 'hypothesis'. The relationship should be labeled as one of the following: 'entailment' if the premise entails the hypothesis, 'contradiction' if the premise contradicts the hypothesis, or 'neutral' if the premise neither entails nor contradicts the hypothesis.\\n\\nPremise: \"\n",
" + row[\"premise\"]\n",
" + \"\\nHypothesis: \"\n",
" + row[\"hypothesis\"],\n",
" }\n",
" )\n",
" with open(valid_data_path, \"a\") as f:\n",
"with open(valid_data_path, \"w+\") as f:\n",
" for row in val:\n",
" data = {\"messages\": []}\n",
" data[\"messages\"].append(\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You are a helpful assistant. Your output should only be one of the three labels: 'entailment', 'contradiction', or 'neutral'.\",\n",
" }\n",
" )\n",
" data[\"messages\"].append(\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Given the following two texts, your task is to determine the logical relationship between them. The first text is the 'premise' and the second text is the 'hypothesis'. The relationship should be labeled as one of the following: 'entailment' if the premise entails the hypothesis, 'contradiction' if the premise contradicts the hypothesis, or 'neutral' if the premise neither entails nor contradicts the hypothesis.\\n\\nPremise: \"\n",
" + row[\"premise\"]\n",
" + \"\\nHypothesis: \"\n",
" + row[\"hypothesis\"],\n",
" }\n",
" )\n",
" f.write(json.dumps(data) + \"\\n\")"
]
},
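The rewritten loops above open each JSONL file once in `w+` mode and write one chat-formatted record per row, rather than reopening the file in append mode on every iteration. A standalone sketch of that record shape and write pattern (pure Python; the premise/hypothesis rows are made up and the task prompt is abridged for illustration):

```python
import json

# System prompt copied from the notebook; the task prompt below is abridged.
SYSTEM_PROMPT = (
    "You are a helpful assistant. Your output should only be one of the three "
    "labels: 'entailment', 'contradiction', or 'neutral'."
)
TASK_PROMPT = "Determine the logical relationship between the two texts."


def format_nli_example(premise: str, hypothesis: str) -> dict:
    """Build one chat-format record shaped like the notebook's training rows."""
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": f"{TASK_PROMPT}\n\nPremise: {premise}\nHypothesis: {hypothesis}",
            },
        ]
    }


# Open the file once and write every record, mirroring the commit's fix.
rows = [
    {"premise": "All dogs bark.", "hypothesis": "Some dogs bark."},
    {"premise": "It is raining.", "hypothesis": "The ground is dry."},
]
with open("sample_train.jsonl", "w") as f:
    for row in rows:
        record = format_nli_example(row["premise"], row["hypothesis"])
        f.write(json.dumps(record) + "\n")
```

The earlier version reopened the file with `open(..., "a")` inside the loop, so re-running the cell kept appending duplicates; writing through a single `w+` handle makes each run start from an empty file.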
@@ -375,7 +397,7 @@
"outputs": [],
"source": [
"train_data = None\n",
"train_data_name = \"nli_train_70-70\"\n",
"train_data_name = \"nli_train_70\"\n",
"\n",
"train_data = ml_client.data.create_or_update(\n",
" Data(\n",
@@ -427,7 +449,7 @@
"metadata": {},
"outputs": [],
"source": [
"ENABLE_CHAIN_OF_THOUGHT = \"true\""
"ENABLE_CHAIN_OF_THOUGHT = \"True\""
]
},
{
@@ -569,16 +591,114 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Consuming the distilled model\n",
"## Create a serverless endpoint to consume the model (optional)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Wait for the job to complete\n",
"ml_client.jobs.stream(ft_job.name)\n",
"registered_model_name = ml_client.jobs.get(ft_job.name).properties[\n",
" \"registered_ft_model_name\"\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create the model url for registered endpoint\n",
"rg_model_vs = ml_client.models.get(registered_model_name, label=\"latest\")._version\n",
"\n",
"rg_model_asset_id = (\n",
" \"azureml://locations/\"\n",
" f\"{ai_project.location}\"\n",
" \"/workspaces/\"\n",
" f\"{ai_project._workspace_id}\"\n",
" \"/models/\"\n",
" f\"{registered_model_name}\"\n",
" \"/versions/\"\n",
" f\"{rg_model_vs}\"\n",
")"
]
},
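The cell above assembles the registered model's asset id by plain string concatenation. With placeholder values standing in for `ai_project.location`, `ai_project._workspace_id`, the registered model name, and its version (all hypothetical here), the resulting URI follows this pattern:

```python
# Hypothetical values; in the notebook these come from the AI project and the
# fine-tuning job's registered model.
location = "eastus2"
workspace_id = "00000000-0000-0000-0000-000000000000"
model_name = "Meta-Llama-3-1-8B-distilled"
model_version = "1"

model_asset_id = (
    f"azureml://locations/{location}"
    f"/workspaces/{workspace_id}"
    f"/models/{model_name}"
    f"/versions/{model_version}"
)
print(model_asset_id)
# azureml://locations/eastus2/workspaces/00000000-0000-0000-0000-000000000000/models/Meta-Llama-3-1-8B-distilled/versions/1
```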
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create serverless endpoint - names must be unique, we will use suffix of the model\n",
"short_id = registered_model_name[-9:]\n",
"serverless_endpoint_name = \"my-endpoint-\" + short_id\n",
"\n",
"serverless_endpoint = ServerlessEndpoint(\n",
" name=serverless_endpoint_name,\n",
" model_id=rg_model_asset_id,\n",
")\n",
"\n",
"created_endpoint = ml_client.serverless_endpoints.begin_create_or_update(\n",
" serverless_endpoint\n",
").result()"
]
},
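Serverless endpoint names must be unique, so the cell above derives one from the last nine characters of the registered model name. A minimal sketch of that naming step, with a hypothetical model name standing in for the one read from the fine-tuning job:

```python
# Hypothetical registered model name; in the notebook this comes from the
# fine-tuning job's "registered_ft_model_name" property.
registered_model_name = "Meta-Llama-3-1-8B-Instruct-ft-123456789"

short_id = registered_model_name[-9:]  # last 9 characters as a unique-ish suffix
serverless_endpoint_name = "my-endpoint-" + short_id
print(serverless_endpoint_name)  # my-endpoint-123456789
```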
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sample inference against the deployed endpoint (optional)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = created_endpoint.scoring_uri\n",
"key = ml_client.serverless_endpoints.get_keys(created_endpoint.name).primary_key\n",
"model = ChatCompletionsClient(\n",
" endpoint=url,\n",
" credential=AzureKeyCredential(key),\n",
")\n",
"\n",
"Once the above job completes, you should be able to deploy the model and use it for inferencing. To deploy this model, do the following:\n",
"response = model.complete(\n",
" messages=[\n",
" SystemMessage(\n",
" content=\"You are a helpful assistant. Your output should only be one of the five choices: 'A', 'B', 'C', 'D', or 'E'.\"\n",
" ),\n",
" UserMessage(\n",
" content=\"Answer the following multiple-choice question by selecting the correct option.\\n\\nQuestion: Can you name a good reason for attending school?\\nAnswer Choices:\\n(A) get smart\\n(B) boredom\\n(C) colds and flu\\n(D) taking tests\\n(E) spend time\"\n",
" ),\n",
" ],\n",
")\n",
"\n",
"* Go to AI Studio\n",
"* Navigate to the Fine-tuning tab on the left menu\n",
"* In the list of models you see, click on the model which got created from the distillation\n",
"* This should take you to the details page where you can see the model attributes and other details\n",
"* Click on the Deploy button on top of the page\n",
"* Follow the steps to deploy the model"
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cleanup endpoints created (optional)\n",
"\n",
"Endpoint deployments are chargeable and incurr costs on the subscription. Optionally clean up the endpoints after finishing experiments"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"_ = ml_client.serverless_endpoints.begin_delete(TEACHER_MODEL_ENDPOINT_NAME)\n",
"_ = ml_client.serverless_endpoints.begin_delete(serverless_endpoint_name)"
]
}
],
