forked from microsoft/responsible-ai-toolbox
Commit
add RAITextInsights question answering notebook example and tests (mi…
1 parent f22769f, commit 4721d35
Showing 5 changed files with 318 additions and 2 deletions.
294 changes: 294 additions & 0 deletions
...sponsibleaidashboard/text/responsibleaidashboard-question-answering-model-debugging.ipynb
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,294 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a50648b4", | ||
"metadata": {}, | ||
"source": [ | ||
"# Assess predictions on the Stanford Question Answering Dataset (SQuAD) with a Hugging Face question answering model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "72398086", | ||
"metadata": {}, | ||
"source": [ | ||
"This notebook demonstrates how to use the `responsibleai` API to assess a Hugging Face question answering model on the SQuAD dataset (see https://huggingface.co/datasets/squad for more information about the dataset). It walks through the API calls needed to create a widget with model analysis insights, then guides a visual analysis of the model." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "739385c6", | ||
"metadata": {}, | ||
"source": [ | ||
"* [Launch Responsible AI Toolbox](#Launch-Responsible-AI-Toolbox)\n", | ||
" * [Load Model and Data](#Load-Model-and-Data)\n", | ||
" * [Create Model and Data Insights](#Create-Model-and-Data-Insights)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1343e9b0", | ||
"metadata": {}, | ||
"source": [ | ||
"## Launch Responsible AI Toolbox" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ea121102", | ||
"metadata": {}, | ||
"source": [ | ||
"The following section examines the code necessary to create datasets and a model. It then generates insights using the `responsibleai` API that can be visually analyzed." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "40739025", | ||
"metadata": {}, | ||
"source": [ | ||
"### Load Model and Data\n", | ||
"*The following section can be skipped. It loads a dataset and a pretrained model for illustrative purposes.*" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ac9d0df6", | ||
"metadata": {}, | ||
"source": [ | ||
"First, import the necessary dependencies." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "75ef9e91", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import datasets\n", | ||
"import pandas as pd\n", | ||
"from transformers import pipeline" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f2132e2f", | ||
"metadata": {}, | ||
"source": [ | ||
"Next, load the training split of the SQuAD dataset from Hugging Face Datasets." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "3ae1bf09", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"dataset = datasets.load_dataset(\"squad\", split=\"train\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a0eef443", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"dataset" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "42786ee8", | ||
"metadata": {}, | ||
"source": [ | ||
"Reformat the dataset into a pandas DataFrame with three columns: `context`, `questions` and `answers`. Only the first reference answer for each question is kept." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "88cd5fed", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"questions = []\n", | ||
"context = []\n", | ||
"answers = []\n", | ||
"for row in dataset:\n", | ||
" context.append(row['context'])\n", | ||
" questions.append(row['question'])\n", | ||
" answers.append(row['answers']['text'][0])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "051d4017", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"data = pd.DataFrame({'context': context, 'questions': questions, 'answers': answers})" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e6f87e9c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8e694f20", | ||
"metadata": {}, | ||
"source": [ | ||
"Fetch a Hugging Face question answering model." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d2236e3d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# load the default question-answering pipeline (no model name is specified)\n", | ||
"pmodel = pipeline('question-answering')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "04801887", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"train_data = data\n", | ||
"test_data = data[:5]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "7eec0c5c", | ||
"metadata": {}, | ||
"source": [ | ||
"### Create Model and Data Insights" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1dd2b169", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from responsibleai_text import RAITextInsights, ModelTask\n", | ||
"from raiwidgets import ResponsibleAIDashboard" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "c6d97b2c", | ||
"metadata": {}, | ||
"source": [ | ||
"To use the Responsible AI Dashboard, initialize a `RAITextInsights` object, to which different components can be added.\n", | ||
"\n", | ||
"`RAITextInsights` accepts the model, the test dataset, the name of the target column (here `answers`) and the task type as its arguments." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "cc8021d3", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"rai_insights = RAITextInsights(pmodel, test_data,\n", | ||
" \"answers\",\n", | ||
" task_type=ModelTask.QUESTION_ANSWERING)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a0331491", | ||
"metadata": {}, | ||
"source": [ | ||
"Add the components of the toolbox for model assessment." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "526aca04", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"rai_insights.error_analysis.add()\n", | ||
"rai_insights.explainer.add()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "210aaaf6", | ||
"metadata": {}, | ||
"source": [ | ||
"Once all the desired components have been loaded, compute insights on the test set." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "901b1863", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"rai_insights.compute()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3f844646", | ||
"metadata": {}, | ||
"source": [ | ||
"Finally, visualize and explore the model insights. Use the resulting widget or follow the link to view this in a new tab." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b63dbfa2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"ResponsibleAIDashboard(rai_insights)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.11" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
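The reshaping step in the notebook above (collecting `context`, `questions` and `answers` from each SQuAD row) can be tried without downloading the dataset. The following is a minimal sketch, assuming each row follows the SQuAD record layout in which `answers` is a dict with parallel `text` and `answer_start` lists. The sample rows and the helper name `squad_rows_to_frame` are hypothetical and are not part of this commit.

```python
import pandas as pd

# Hypothetical rows that mimic the SQuAD record layout; they are not
# taken from the real dataset.
sample_rows = [
    {
        "context": "The Eiffel Tower is located in Paris.",
        "question": "Where is the Eiffel Tower located?",
        "answers": {"text": ["Paris", "in Paris"], "answer_start": [31, 28]},
    },
    {
        "context": "Python was created by Guido van Rossum.",
        "question": "Who created Python?",
        "answers": {"text": ["Guido van Rossum"], "answer_start": [22]},
    },
]


def squad_rows_to_frame(rows):
    """Build the three-column frame used in the notebook.

    Hypothetical helper: keeps only the first reference answer per
    question, mirroring row['answers']['text'][0] in the notebook.
    """
    records = [
        {
            "context": row["context"],
            "questions": row["question"],
            "answers": row["answers"]["text"][0],
        }
        for row in rows
    ]
    return pd.DataFrame(records, columns=["context", "questions", "answers"])


frame = squad_rows_to_frame(sample_rows)
print(frame)
```

Keeping a single answer per row gives the frame one string-valued target column, which is the column name the notebook passes to `RAITextInsights`. Any additional reference answers are dropped by this step.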