add RAITextInsights question answering notebook example and tests (mi…
imatiach-msft authored Aug 28, 2023
1 parent f22769f commit 4721d35
Showing 5 changed files with 318 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/CI-notebook-text.yml
@@ -9,6 +9,9 @@ on:
- "raiwidgets/**"
- "responsibleai_text/**"
- ".github/workflows/CI-notebook-text.yml"
- "libs/e2e/src/lib/describer/modelAssessment/**"
- "libs/interpret-text/**"
- "notebooks/**"

jobs:
ci-notebook-text:
10 changes: 9 additions & 1 deletion .github/workflows/CI-responsibleai-text-vision-pytest.yml
@@ -97,12 +97,20 @@ jobs:
run: |
pip install -r ${{ matrix.packageDirectory }}/requirements-automl.txt
- name: Install package
- if: ${{ (matrix.packageDirectory == 'responsibleai_vision') }}
name: Install vision package
shell: bash -l {0}
run: |
pip install -v -e .
working-directory: ${{ matrix.packageDirectory }}

- if: ${{ (matrix.packageDirectory == 'responsibleai_text') }}
name: Install text package
shell: bash -l {0}
run: |
pip install -v -e .[qa]
working-directory: ${{ matrix.packageDirectory }}

- name: Run tests
shell: bash -l {0}
run: |
@@ -0,0 +1,294 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a50648b4",
"metadata": {},
"source": [
"# Assess predictions on Stanford Question Answering Dataset (SQuAD) with a huggingface question answering model"
]
},
{
"cell_type": "markdown",
"id": "72398086",
"metadata": {},
"source": [
"This notebook demonstrates the use of the `responsibleai` API to assess a huggingface question answering model on the SQuAD dataset (see https://huggingface.co/datasets/squad for more information about the dataset). It walks through the API calls necessary to create a widget with model analysis insights, then guides a visual analysis of the model."
]
},
{
"cell_type": "markdown",
"id": "739385c6",
"metadata": {},
"source": [
"* [Launch Responsible AI Toolbox](#Launch-Responsible-AI-Toolbox)\n",
" * [Load Model and Data](#Load-Model-and-Data)\n",
" * [Create Model and Data Insights](#Create-Model-and-Data-Insights)"
]
},
{
"cell_type": "markdown",
"id": "1343e9b0",
"metadata": {},
"source": [
"## Launch Responsible AI Toolbox"
]
},
{
"cell_type": "markdown",
"id": "ea121102",
"metadata": {},
"source": [
"The following section examines the code necessary to create datasets and a model. It then generates insights using the `responsibleai` API that can be visually analyzed."
]
},
{
"cell_type": "markdown",
"id": "40739025",
"metadata": {},
"source": [
"### Load Model and Data\n",
"*The following section can be skipped. It loads a dataset and fetches a pretrained model for illustrative purposes.*"
]
},
{
"cell_type": "markdown",
"id": "ac9d0df6",
"metadata": {},
"source": [
"First we import all necessary dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75ef9e91",
"metadata": {},
"outputs": [],
"source": [
"import datasets\n",
"import pandas as pd\n",
"from transformers import pipeline"
]
},
{
"cell_type": "markdown",
"id": "f2132e2f",
"metadata": {},
"source": [
"Next we load the SQuAD dataset from huggingface datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3ae1bf09",
"metadata": {},
"outputs": [],
"source": [
"dataset = datasets.load_dataset(\"squad\", split=\"train\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0eef443",
"metadata": {},
"outputs": [],
"source": [
"dataset"
]
},
{
"cell_type": "markdown",
"id": "42786ee8",
"metadata": {},
"source": [
"Reformat the dataset to be a pandas dataframe with three columns: context, questions and answers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88cd5fed",
"metadata": {},
"outputs": [],
"source": [
"questions = []\n",
"context = []\n",
"answers = []\n",
"for row in dataset:\n",
"    context.append(row['context'])\n",
"    questions.append(row['question'])\n",
"    answers.append(row['answers']['text'][0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "051d4017",
"metadata": {},
"outputs": [],
"source": [
"data = pd.DataFrame({'context': context, 'questions': questions, 'answers': answers})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6f87e9c",
"metadata": {},
"outputs": [],
"source": [
"data"
]
},
{
"cell_type": "markdown",
"id": "8e694f20",
"metadata": {},
"source": [
"Fetch a huggingface question answering model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2236e3d",
"metadata": {},
"outputs": [],
"source": [
"# load the question-answering model\n",
"pmodel = pipeline('question-answering')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04801887",
"metadata": {},
"outputs": [],
"source": [
"train_data = data\n",
"test_data = data[:5]"
]
},
{
"cell_type": "markdown",
"id": "7eec0c5c",
"metadata": {},
"source": [
"### Create Model and Data Insights"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1dd2b169",
"metadata": {},
"outputs": [],
"source": [
"from responsibleai_text import RAITextInsights, ModelTask\n",
"from raiwidgets import ResponsibleAIDashboard"
]
},
{
"cell_type": "markdown",
"id": "c6d97b2c",
"metadata": {},
"source": [
"To use the Responsible AI Dashboard, initialize a RAITextInsights object upon which different components can be loaded.\n",
"\n",
"RAITextInsights accepts the model, the test dataset, the name of the target column (here `answers`) and the task type as its arguments."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc8021d3",
"metadata": {},
"outputs": [],
"source": [
"rai_insights = RAITextInsights(pmodel, test_data,\n",
"                               \"answers\",\n",
"                               task_type=ModelTask.QUESTION_ANSWERING)"
]
},
{
"cell_type": "markdown",
"id": "a0331491",
"metadata": {},
"source": [
"Add the components of the toolbox for model assessment."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "526aca04",
"metadata": {},
"outputs": [],
"source": [
"rai_insights.error_analysis.add()\n",
"rai_insights.explainer.add()"
]
},
{
"cell_type": "markdown",
"id": "210aaaf6",
"metadata": {},
"source": [
"Once all the desired components have been loaded, compute insights on the test set."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "901b1863",
"metadata": {},
"outputs": [],
"source": [
"rai_insights.compute()"
]
},
{
"cell_type": "markdown",
"id": "3f844646",
"metadata": {},
"source": [
"Finally, visualize and explore the model insights. Use the resulting widget or follow the link to view this in a new tab."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b63dbfa2",
"metadata": {},
"outputs": [],
"source": [
"ResponsibleAIDashboard(rai_insights)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
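
For reviewers who want to check the notebook's reformatting step without downloading SQuAD, the following is a minimal sketch of the same flattening logic on two hand-written records. The sample rows and the helper name `flatten_squad_rows` are hypothetical and not part of this commit; only the field names (`context`, `question`, `answers` with a `text` list) follow the record layout the notebook reads. It uses the standard library only.

```python
# Hypothetical sample rows that mimic the SQuAD record layout the notebook
# reads; the real data comes from datasets.load_dataset("squad", split="train").
sample_rows = [
    {
        "context": "Paris is the capital of France.",
        "question": "What is the capital of France?",
        "answers": {"text": ["Paris"], "answer_start": [0]},
    },
    {
        "context": "Water boils at 100 degrees Celsius at sea level.",
        "question": "At what temperature does water boil at sea level?",
        "answers": {"text": ["100 degrees Celsius"], "answer_start": [15]},
    },
]


def flatten_squad_rows(rows):
    """Collect context, question and first reference answer into columns."""
    columns = {"context": [], "questions": [], "answers": []}
    for row in rows:
        columns["context"].append(row["context"])
        columns["questions"].append(row["question"])
        # The answers field holds a list of reference answers; like the
        # notebook, keep only the first one.
        columns["answers"].append(row["answers"]["text"][0])
    return columns


columns = flatten_squad_rows(sample_rows)
print(columns["answers"])
```

The resulting dictionary of equal-length lists has the same three keys the notebook passes to `pd.DataFrame`, so it could be handed to that constructor directly.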
10 changes: 10 additions & 0 deletions notebooks/test_notebooks.py
@@ -343,3 +343,13 @@ def test_responsibleaidashboard_blbooksgenre_text_model_debugging():

    test_values = {}
    assay_one_notebook(nb_path, nb_name, test_values)


@pytest.mark.text_notebooks
def test_responsibleaidashboard_question_answering_model_debugging():
    nb_path = TEXT
    nb_name = ("responsibleaidashboard-question-" +
               "answering-model-debugging")

    test_values = {}
    assay_one_notebook(nb_path, nb_name, test_values)
3 changes: 2 additions & 1 deletion scripts/e2e-widget.js
@@ -32,7 +32,8 @@ const textFileNames = [
"responsibleaidashboard-blbooksgenre-binary-text-classification-model-debugging"
];
const ignoredFiles = [
"responsibleaidashboard-covid19-event-multilabel-text-classification-model-debugging"
"responsibleaidashboard-covid19-event-multilabel-text-classification-model-debugging",
"responsibleaidashboard-question-answering-model-debugging"
];
const fileNames = tabularFileNames
.concat(visionFileNames)