Skip to content

Commit

Permalink
add e2e UI notebook tests to DBPedia text classification notebook (mi…
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft authored Aug 23, 2023
1 parent b06768f commit e5af558
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 36 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: CI e2e notebooks vision
name: CI e2e notebooks text vision

on:
push:
Expand All @@ -8,9 +8,11 @@ on:
paths:
- "raiwidgets/**"
- "responsibleai_vision/**"
- "responsibleai_text/**"
- ".github/workflows/CI-e2e-notebooks-vision.yml"
- "libs/e2e/src/lib/describer/modelAssessment/**"
- "libs/interpret-vision/**"
- "libs/interpret-text/**"
- "notebooks/**"

jobs:
Expand Down Expand Up @@ -43,7 +45,7 @@ jobs:
name: raiwidgets-js
path: raiwidgets/raiwidgets/widget

ci-e2e-notebook-vision:
ci-e2e-notebook-text-vision:
needs: ui-build

env:
Expand All @@ -56,7 +58,7 @@ jobs:
operatingSystem: [ubuntu-latest, windows-latest]
pythonVersion: [3.7, 3.8, 3.9, "3.10"]
flights: [""]
notebookGroup: ["nb_group_1"]
notebookGroup: ["vis_nb_group_1", "text_nb_group_1"]

runs-on: ${{ matrix.operatingSystem }}

Expand Down Expand Up @@ -104,13 +106,28 @@ jobs:
pip install -v -e .
working-directory: raiwidgets

- name: Install vision dependencies
- if: ${{ matrix.notebookGroup == 'vis_nb_group_1'}}
name: Install vision dependencies
shell: bash -l {0}
run: |
pip install -r requirements-dev.txt
pip install .
working-directory: responsibleai_vision

- if: ${{ matrix.notebookGroup == 'text_nb_group_1'}}
name: Install text dependencies
shell: bash -l {0}
run: |
pip install -r requirements-dev.txt
pip install .
working-directory: responsibleai_text

- if: ${{ matrix.notebookGroup == 'text_nb_group_1'}}
name: Setup spacy
shell: bash -l {0}
run: |
python -m spacy download en_core_web_sm
- name: Pip freeze
shell: bash -l {0}
run: |
Expand All @@ -125,14 +142,20 @@ jobs:
path: raiwidgets/installed-requirements-dev.txt

# keep list of notebooks in sync with scripts/e2e-widget.js, create new notebook group if necessary
- if: ${{ matrix.notebookGroup == 'nb_group_1'}}
- if: ${{ matrix.notebookGroup == 'vis_nb_group_1'}}
name: Run widget tests
shell: bash -l {0}
run: |
yarn e2e-widget -n "responsibleaidashboard-fridge-image-classification-model-debugging" -f ${{ matrix.flights }}
yarn e2e-widget -n "responsibleaidashboard-fridge-multilabel-image-classification-model-debugging" -f ${{ matrix.flights }}
yarn e2e-widget -n "responsibleaidashboard-fridge-object-detection-model-debugging" -f ${{ matrix.flights }}
- if: ${{ matrix.notebookGroup == 'text_nb_group_1'}}
name: Run widget tests
shell: bash -l {0}
run: |
yarn e2e-widget -n "responsibleaidashboard-DBPedia-text-classification-model-debugging" -f ${{ matrix.flights }}
- name: Upload e2e test screen shot
if: always()
uses: actions/upload-artifact@v3
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import {
describeDatasetExplorer,
modelAssessmentDatasets
} from "@responsible-ai/e2e";
const datasetShape =
modelAssessmentDatasets.DBPediaTextClassificationModelDebugging;
describeDatasetExplorer(
datasetShape,
"DBPediaTextClassificationModelDebugging"
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import {
describeErrorAnalysis,
modelAssessmentDatasets
} from "@responsible-ai/e2e";

const datasetShape =
modelAssessmentDatasets.DBPediaTextClassificationModelDebugging;
describeErrorAnalysis(datasetShape, "DBPediaTextClassificationModelDebugging");
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import {
describeModelOverview,
modelAssessmentDatasets
} from "@responsible-ai/e2e";
const datasetShape =
modelAssessmentDatasets.DBPediaTextClassificationModelDebugging;
describeModelOverview(datasetShape, "DBPediaTextClassificationModelDebugging");
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ export interface IModelAssessmentData {
isObjectDetection?: boolean;
isMultiLabel?: boolean;
isImageClassification?: boolean;
isTextClassification?: boolean;
}

export interface IErrorAnalysisData {
Expand Down Expand Up @@ -204,5 +205,6 @@ export enum RAINotebookNames {
"OrangeJuiceForecastingDataBalanceExperience" = "responsibleaidashboard-orange-juice-forecasting.py",
"FridgeImageClassificationModelDebugging" = "responsibleaidashboard-fridge-image-classification-model-debugging.py",
"FridgeMultilabelModelDebugging" = "responsibleaidashboard-fridge-multilabel-image-classification-model-debugging.py",
"FridgeObjectDetectionModelDebugging" = "responsibleaidashboard-fridge-object-detection-model-debugging.py"
"FridgeObjectDetectionModelDebugging" = "responsibleaidashboard-fridge-object-detection-model-debugging.py",
"DBPediaTextClassificationModelDebugging" = "responsibleaidashboard-DBPedia-text-classification-model-debugging.py"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

export const DBPediaTextClassificationModelDebugging = {
causalAnalysisData: {
hasCausalAnalysisComponent: false
},
checkDupCohort: true,
cohortDefaultName: "All data",
dataBalanceData: {
aggregateBalanceMeasuresComputed: false,
distributionBalanceMeasuresComputed: false,
featureBalanceMeasuresComputed: false
},
errorAnalysisData: {
hasErrorAnalysisComponent: true
},
featureImportanceData: {
hasFeatureImportanceComponent: false
},
featureNames: ["text"],
isTextClassification: true,
modelOverviewData: {
featureCohortView: {
firstFeatureToSelect: "positive_words",
multiFeatureCohorts: 7,
secondFeatureToSelect: "negative_words",
singleFeatureCohorts: 3
},
hasModelOverviewComponent: true,
initialCohorts: [
{
metrics: {
accuracy: "0.6",
macroF1: "0.649",
macroPrecision: "0.625",
macroRecall: "0.675"
},
name: "All data",
sampleSize: "134"
}
],
newCohort: {
metrics: {
accuracy: "0.9",
macroF1: "0.9",
macroPrecision: "0.9",
macroRecall: "0.9"
},
name: "CohortCreateE2E-text-classification",
sampleSize: "5"
}
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import _ from "lodash";
import { IModelAssessmentData } from "../IModelAssessmentData";

import { CensusClassificationModelDebugging } from "./CensusClassificationModelDebugging";
import { DBPediaTextClassificationModelDebugging } from "./DBPediaTextClassificationModelDebugging";
import { DiabetesDecisionMaking } from "./DiabetesDecisionMaking";
import { DiabetesRegressionModelDebugging } from "./DiabetesRegressionModelDebugging";
import { FridgeImageClassificationModelDebugging } from "./FridgeImageClassificationModelDebugging";
Expand All @@ -21,6 +22,7 @@ export const regExForNumbersWithBrackets = /^\((\d+)\)$/; // Ex: (60)

const modelAssessmentDatasets: { [name: string]: IModelAssessmentData } = {
CensusClassificationModelDebugging,
DBPediaTextClassificationModelDebugging,
DiabetesDecisionMaking,
DiabetesRegressionModelDebugging,
FridgeImageClassificationModelDebugging,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ import { ensureNewCohortsShowUpInCharts } from "./ensureNewCohortsShowUpInCharts

const testName = "Model Overview v2";

function getDashboardName(isVision: boolean, isText: boolean): string {
if (isVision) {
return "modelAssessmentVision";
} else if (isText) {
return "modelAssessmentText";
}
return "modelAssessment";
}

export function describeModelOverview(
datasetShape: IModelAssessmentData,
name?: keyof typeof modelAssessmentDatasetsIncludingFlights,
Expand All @@ -28,15 +37,15 @@ export function describeModelOverview(
datasetShape.isImageClassification
? true
: false;
const isText = datasetShape.isTextClassification ? true : false;
const isTabular = !isVision && !isText;
if (isNotebookTest) {
before(() => {
visit(name);
});
} else {
before(() => {
const dashboardName = isVision
? "modelAssessmentVision"
: "modelAssessment";
const dashboardName = getDashboardName(isVision, isText);
cy.visit(`#/${dashboardName}/${name}/light/english/Version-2`);
});
}
Expand All @@ -48,7 +57,7 @@ export function describeModelOverview(
datasetShape,
false,
isNotebookTest,
isVision
isTabular
);
});

Expand All @@ -68,7 +77,7 @@ export function describeModelOverview(
ensureAllModelOverviewFeatureCohortsViewElementsAfterSelectionArePresent(
datasetShape,
1,
isVision
isTabular
);
});

Expand All @@ -81,16 +90,16 @@ export function describeModelOverview(
ensureAllModelOverviewFeatureCohortsViewElementsAfterSelectionArePresent(
datasetShape,
2,
isVision
isTabular
);
});

it("should show new cohorts in charts", () => {
ensureNewCohortsShowUpInCharts(datasetShape, isNotebookTest, isVision);
ensureNewCohortsShowUpInCharts(datasetShape, isNotebookTest, isTabular);
});

it("should pivot between charts when clicking", () => {
if (!isVision) {
if (isTabular) {
ensureChartsPivot(datasetShape, isNotebookTest, true);
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export function ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
datasetShape: IModelAssessmentData,
includeNewCohort: boolean,
isNotebookTest: boolean,
isVision: boolean
isTabular: boolean
): void {
const data = datasetShape.modelOverviewData;
const initialCohorts = data?.initialCohorts;
Expand Down Expand Up @@ -104,7 +104,7 @@ export function ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
"not.exist"
);

if (!isVision) {
if (isTabular) {
if (isNotebookTest) {
cy.get(Locators.ModelOverviewHeatmapCells)
.should(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ import {
export function ensureAllModelOverviewFeatureCohortsViewElementsAfterSelectionArePresent(
datasetShape: IModelAssessmentData,
selectedFeatures: number,
isVision: boolean
isTabular: boolean
): void {
cy.get(Locators.ModelOverviewFeatureSelection).should("exist");
cy.get(Locators.ModelOverviewFeatureConfigurationActionButton).should(
"exist"
);
cy.get(Locators.ModelOverviewDatasetCohortStatsTable).should("not.exist");

if (!isVision) {
if (isTabular) {
cy.get(Locators.ModelOverviewHeatmapVisualDisplayToggle).should("exist"); // TODO: check!
cy.get(Locators.ModelOverviewDisaggregatedAnalysisTable).should("exist");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@ import { ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent } from
export function ensureNewCohortsShowUpInCharts(
datasetShape: IModelAssessmentData,
isNotebookTest: boolean,
isVision: boolean
isTabular: boolean
): void {
cy.get(Locators.ModelOverviewCohortViewDatasetCohortViewButton).click();
ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
datasetShape,
false,
isNotebookTest,
isVision
isTabular
);
createCohort(datasetShape.modelOverviewData?.newCohort?.name);
ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
datasetShape,
true,
isNotebookTest,
isVision
isTabular
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
"id": "b0804d4d",
"metadata": {},
"source": [
"Next we load the DBPedia dataset from huggingface datasets"
"Next we load the DBPedia dataset from huggingface datasets. Note we use only 6 examples and 8 additional error instances here since it can take some time to compute explanations, especially on CPU. You can increase the NUM_TEST_SAMPLES to 100 or more to get a more interesting dashboard."
]
},
{
Expand All @@ -95,25 +95,24 @@
"metadata": {},
"outputs": [],
"source": [
"NUM_TEST_SAMPLES = 100\n",
"# Bump up the number of examples to 100 or greater to view more\n",
"# information, but it may take longer to compute\n",
"NUM_TEST_SAMPLES = 6\n",
"\n",
"def load_dataset(split):\n",
" dataset = datasets.load_dataset(\"DeveloperOats/DBPedia_Classes\", split=split)\n",
" return pd.DataFrame({\"text\": dataset[\"text\"], \"l1\": dataset[\"l1\"]})\n",
"\n",
"pd_data = load_dataset(\"train\")\n",
"pd_valid_data = load_dataset(\"test\")\n",
"\n",
"def rename_label_column(dataset):\n",
" dataset[\"label\"] = dataset[\"l1\"]\n",
" dataset = dataset.drop(columns=\"l1\")\n",
" return dataset\n",
"\n",
"pd_data = rename_label_column(pd_data)\n",
"pd_valid_data = rename_label_column(pd_valid_data)\n",
"\n",
"START_INDEX = 0\n",
"train_data = pd_data[NUM_TEST_SAMPLES:]\n",
"test_data = pd_valid_data[:NUM_TEST_SAMPLES]"
]
},
Expand Down Expand Up @@ -353,7 +352,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
"version": "3.8.17"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit e5af558

Please sign in to comment.