From 29bfb51675e60564ac617da5a8bd330b4108c0b2 Mon Sep 17 00:00:00 2001 From: Vinutha Karanth Date: Mon, 24 Jul 2023 10:32:44 -0700 Subject: [PATCH] Individual feature importance interpret QA (#2186) * add Signed-off-by: vinutha karanth * update Signed-off-by: vinutha karanth * cleanup Signed-off-by: vinutha karanth * lintfix Signed-off-by: vinutha karanth * lintfix Signed-off-by: vinutha karanth * lintfix Signed-off-by: vinutha karanth * fix row change err Signed-off-by: vinutha karanth * address comments Signed-off-by: vinutha karanth --------- Signed-off-by: vinutha karanth --- .../lib/Interfaces/ExplanationInterfaces.ts | 1 + .../Interfaces/TextExplanationInterfaces.ts | 3 + .../TextExplanationDashboard/CommonUtils.ts | 8 + .../getTokenImportancesChartOptions.ts | 83 ++++- .../ITextExplanationViewSpec.ts | 3 +- .../TextExplanationView/SidePanelOfChart.tsx | 6 + .../TextExplanationView.styles.ts | 9 + .../TextExplanationView.tsx | 296 ++++++++---------- .../TextExplanationViewUtils.ts | 100 ++++++ .../TextInputOutputAreaWithLegend.tsx | 94 ++++++ .../TrueAndPredictedAnswerView.tsx | 46 +++ .../TextFeatureLegend.styles.ts | 12 +- .../TextFeatureLegend/TextFeatureLegend.tsx | 29 +- .../TextHighlighting.styles.ts | 72 +++-- .../TextHighlighting/TextHightlighting.tsx | 32 +- .../Interfaces/IChartProps.ts | 5 + .../Interfaces/IExplanationDashboardProps.ts | 3 + libs/localization/src/lib/en.json | 11 +- .../Controls/FeatureImportances.tsx | 1 - .../TextLocalImportancePlots.tsx | 18 +- 20 files changed, 595 insertions(+), 237 deletions(-) create mode 100644 libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationViewUtils.ts create mode 100644 libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextInputOutputAreaWithLegend.tsx create mode 100644 libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TrueAndPredictedAnswerView.tsx diff --git a/libs/core-ui/src/lib/Interfaces/ExplanationInterfaces.ts b/libs/core-ui/src/lib/Interfaces/ExplanationInterfaces.ts index 2546b38352..a7087ca029 100644 --- a/libs/core-ui/src/lib/Interfaces/ExplanationInterfaces.ts +++ b/libs/core-ui/src/lib/Interfaces/ExplanationInterfaces.ts @@ -57,6 +57,7 @@ export interface IPrecomputedExplanations { export interface ITextFeatureImportance { text: string[]; localExplanations: number[][]; + baseValues?: number[][]; } export interface IEBMGlobalExplanation { diff --git a/libs/core-ui/src/lib/Interfaces/TextExplanationInterfaces.ts b/libs/core-ui/src/lib/Interfaces/TextExplanationInterfaces.ts index 9923f9c6f1..965915c3da 100644 --- a/libs/core-ui/src/lib/Interfaces/TextExplanationInterfaces.ts +++ b/libs/core-ui/src/lib/Interfaces/TextExplanationInterfaces.ts @@ -6,4 +6,7 @@ export interface ITextExplanationDashboardData { localExplanations: number[][]; prediction: number[]; text: string[]; + baseValues?: number[][]; + predictedY?: number[] | number[][] | string[] | string | number; + trueY?: number[] | number[][] | string[] | string | number; } diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts index 03ca37e2de..8b59ba9c0a 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts @@ -78,6 +78,14 @@ export class Utils { return sortedList; } + public static addItem(value: number, radio: string | undefined): boolean { + return ( + radio === RadioKeys.All || + (radio === RadioKeys.Neg && value <= 0) || + (radio === RadioKeys.Pos && value >= 0) + ); + } + public static takeTopK(list: number[], k: number): number[] { /* * Returns a list after splicing and taking the top K diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/BarChart/getTokenImportancesChartOptions.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/BarChart/getTokenImportancesChartOptions.ts index bd9d2612a3..cc7f2630c1 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/BarChart/getTokenImportancesChartOptions.ts +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/BarChart/getTokenImportancesChartOptions.ts @@ -2,17 +2,27 @@ // Licensed under the MIT License. import { ITheme } from "@fluentui/react"; -import { - IHighchartsConfig, - getPrimaryChartColor, - getPrimaryBackgroundChartColor -} from "@responsible-ai/core-ui"; +import { IHighchartsConfig } from "@responsible-ai/core-ui"; import { localization } from "@responsible-ai/localization"; import { SeriesOptionsType } from "highcharts"; +import _ from "lodash"; import { Utils } from "../../CommonUtils"; import { IChartProps } from "../../Interfaces/IChartProps"; +function findNearestIndex( + array: number[], + target?: number +): number | undefined { + if (!target) { + return array.length; + } + const nearestElement = _.minBy(array, (element) => + Math.abs(element - target) + ); + return _.indexOf(array, nearestElement); +} + export function getTokenImportancesChartOptions( props: IChartProps, theme: ITheme @@ -20,6 +30,11 @@ export function getTokenImportancesChartOptions( const importances = props.localExplanations; const k = props.topK; const sortedList = Utils.sortedTopK(importances, k, props.radio); + + const outputFeatureImportanceLabel = `f ${ + props.text[props.selectedTokenIndex || 0] + } (inputs)`; + const baseValueLabel = "base value"; const [x, y, ylabel, tooltip]: [number[], number[], string[], string[]] = [ [], [], @@ -46,6 +61,36 @@ export function getTokenImportancesChartOptions( ylabel.push(props.text[idx]); tooltip.push(str); }); + + // add output feature importance + if (props.outputFeatureValue && props.baseValue) { + const outputFeatureValueIndex = findNearestIndex( + x, + props.outputFeatureValue + ); + const baseValueFeatureValueIndex = findNearestIndex(x, props.baseValue); + if (outputFeatureValueIndex && baseValueFeatureValueIndex) { + if (Utils.addItem(props.outputFeatureValue, props.radio)) { + addItem( + x, + props.outputFeatureValue, + ylabel, + outputFeatureImportanceLabel, + outputFeatureValueIndex + ); + } + if (Utils.addItem(props.baseValue, props.radio)) { + addItem( + x, + props.baseValue, + ylabel, + baseValueLabel, + baseValueFeatureValueIndex + ); + } + } + } + // Put most significant word at the top by reversing order tooltip.reverse(); ylabel.reverse(); @@ -54,11 +99,10 @@ export function getTokenImportancesChartOptions( const data: any[] = []; x.forEach((p, index) => { const temp = { - borderColor: getPrimaryChartColor(theme), color: (p || 0) >= 0 - ? getPrimaryChartColor(theme) - : getPrimaryBackgroundChartColor(theme), + ? theme.semanticColors.errorText + : theme.semanticColors.link, x: index, y: p }; @@ -68,6 +112,15 @@ export function getTokenImportancesChartOptions( const series: SeriesOptionsType[] = [ { data, + dataLabels: { + align: "center", + color: theme.semanticColors.bodyBackground, + enabled: true, + formatter(): string | number | undefined { + return this.x; // Display the Y-axis value inside the bar + }, + inside: true + }, name: "", showInLegend: false, type: "bar" @@ -80,11 +133,12 @@ export function getTokenImportancesChartOptions( }, plotOptions: { bar: { + minPointLength: 10, tooltip: { pointFormatter(): string { return `${tooltip[this.x || 0]}: ${this.y || 0}`; } - } + } // Set the minimum pixel width for bars } }, series, @@ -98,3 +152,14 @@ export function getTokenImportancesChartOptions( } }; } + +function addItem( + x: any[], + xValue: any, + yLabel: any[], + yLabelValue: any, + index: number +): void { + x.splice(index, 0, xValue); + yLabel.splice(index, 0, yLabelValue); +} diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/ITextExplanationViewSpec.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/ITextExplanationViewSpec.ts index 8255279c61..0fe5c3364c 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/ITextExplanationViewSpec.ts +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/ITextExplanationViewSpec.ts @@ -13,12 +13,13 @@ export interface ITextExplanationViewState { maxK: number; topK: number; radio: string; - // qaRadio?: string; + qaRadio?: string; importances: number[]; singleTokenImportances: number[]; selectedToken: number; tokenIndexes: number[]; text: string[]; + outputFeatureImportances: number[][]; } export const options: IChoiceGroupOption[] = [ diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/SidePanelOfChart.tsx b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/SidePanelOfChart.tsx index 4b5e412e24..8e0e453227 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/SidePanelOfChart.tsx +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/SidePanelOfChart.tsx @@ -38,6 +38,9 @@ export interface ISidePanelOfChartProps { selectedWeightVector: WeightVectorOption; weightOptions: WeightVectorOption[]; weightLabels: any; + baseValue?: number; + outputFeatureValue?: number; + selectedTokenIndex?: number; changeRadioButton: ( _event?: React.FormEvent, item?: IChoiceGroupOption @@ -63,6 +66,9 @@ export class SidePanelOfChart extends React.PureComponent diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.styles.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.styles.ts index cb77508059..9f38b9d5fe 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.styles.ts +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.styles.ts @@ -11,16 +11,25 @@ import { export interface ITextExplanationDashboardStyles { chartRight: IStyle; textHighlighting: IStyle; + predictedAnswer: IStyle; + boldText: IStyle; } export const textExplanationDashboardStyles: () => IProcessedStyleSet = () => { const theme = getTheme(); return mergeStyleSets({ + boldText: { + fontWeight: "bold" + }, chartRight: { maxWidth: "230px", minWidth: "230px" }, + predictedAnswer: { + fontWeight: "bold", + paddingBottom: "14px" + }, textHighlighting: { borderColor: theme.semanticColors.variantBorder, borderRadius: "1px", diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx index de8c28a28f..636862564e 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx @@ -2,24 +2,29 @@ // Licensed under the MIT License. import { IChoiceGroupOption, Stack, Text } from "@fluentui/react"; -import { WeightVectorOption, WeightVectors } from "@responsible-ai/core-ui"; +import { WeightVectorOption } from "@responsible-ai/core-ui"; import { localization } from "@responsible-ai/localization"; import React from "react"; -import { RadioKeys, Utils } from "../../CommonUtils"; +import { QAExplanationType, RadioKeys } from "../../CommonUtils"; import { ITextExplanationViewProps } from "../../Interfaces/IExplanationViewProps"; -import { TextFeatureLegend } from "../TextFeatureLegend/TextFeatureLegend"; -import { TextHighlighting } from "../TextHighlighting/TextHightlighting"; import { ITextExplanationViewState, - MaxImportantWords, componentStackTokens } from "./ITextExplanationViewSpec"; import { SidePanelOfChart } from "./SidePanelOfChart"; -import { textExplanationDashboardStyles } from "./TextExplanationView.styles"; - -export class TextExplanationView extends React.PureComponent< +import { + calculateMaxKImportances, + calculateTopKImportances, + computeImportancesForAllTokens, + computeImportancesForWeightVector, + getOutputFeatureImportances +} from "./TextExplanationViewUtils"; +import { TextInputOutputAreaWithLegend } from "./TextInputOutputAreaWithLegend"; +import { TrueAndPredictedAnswerView } from "./TrueAndPredictedAnswerView"; + +export class TextExplanationView extends React.Component< ITextExplanationViewProps, ITextExplanationViewState > { @@ -31,23 +36,32 @@ export class TextExplanationView extends React.PureComponent< const weightVector = this.props.selectedWeightVector; const importances = this.props.isQA - ? this.computeImportancesForAllTokens( - this.props.dataSummary.localExplanations + ? computeImportancesForAllTokens( + this.props.dataSummary.localExplanations, + true ) - : this.computeImportancesForWeightVector( + : computeImportancesForWeightVector( this.props.dataSummary.localExplanations, weightVector ); - const maxK = this.calculateMaxKImportances(importances); - const topK = this.calculateTopKImportances(importances); + const maxK = calculateMaxKImportances(importances); + const topK = calculateTopKImportances(importances); this.state = { importances, maxK, - // qaRadio: QAExplanationType.Start, + outputFeatureImportances: getOutputFeatureImportances( + this.props.dataSummary.localExplanations, + this.props.dataSummary.baseValues + ), + qaRadio: QAExplanationType.Start, radio: RadioKeys.All, - selectedToken: 0, // default to the first token - singleTokenImportances: this.getImportanceForSingleToken(0), // get importance for first token + selectedToken: 0, + // default to the first token + singleTokenImportances: this.props.dataSummary.localExplanations[0].map( + (row) => row[0] + ), + // get importance for first token text: this.props.dataSummary.text, tokenIndexes: [...this.props.dataSummary.text].map((_, index) => index), topK @@ -60,28 +74,19 @@ export class TextExplanationView extends React.PureComponent< this.props.dataSummary.localExplanations !== prevProps.dataSummary.localExplanations ) { - if (this.props.isQA) { - this.setState( - { - selectedToken: 0, - //update token dropdown - tokenIndexes: [...this.props.dataSummary.text].map( - (_, index) => index - ) - }, - () => { - this.updateTokenImportances(); - this.updateSingleTokenImportances(); - } - ); - } else { - this.updateImportances(this.props.selectedWeightVector); - } + this.updateState(); } } public render(): React.ReactNode { - const classNames = textExplanationDashboardStyles(); + const outputLocalExplanations = + this.state.qaRadio === QAExplanationType.Start + ? this.state.outputFeatureImportances[0] + : this.state.outputFeatureImportances[1]; + const inputLocalExplanations = this.props.isQA + ? this.state.singleTokenImportances + : this.state.importances; + const baseValue = this.props.isQA ? this.getBaseValue() : undefined; return ( @@ -93,9 +98,30 @@ export class TextExplanationView extends React.PureComponent< )} + + {this.props.isQA && ( + + )} + + + + - - - - - - - {this.props.isQA && ( - - - - )} - - - - - ); } - private onWeightVectorChange = (weightOption: WeightVectorOption): void => { - this.updateImportances(weightOption); - this.props.onWeightChange(weightOption); - }; - - private onSelectedTokenChange = (newIndex: number): void => { - this.setState({ selectedToken: newIndex }, () => { - this.updateSingleTokenImportances(); - }); - }; - - private updateImportances(weightOption: WeightVectorOption): void { - const importances = this.computeImportancesForWeightVector( - this.props.dataSummary.localExplanations, - weightOption - ); - - const topK = this.calculateTopKImportances(importances); - const maxK = this.calculateMaxKImportances(importances); + private updateState(): void { + const importances = this.props.isQA + ? this.getTokenImportances() + : this.getImportances(this.props.selectedWeightVector); + const [topK, maxK] = this.getTopKMaxK(importances); this.setState({ importances, maxK, + outputFeatureImportances: getOutputFeatureImportances( + this.props.dataSummary.localExplanations, + this.props.dataSummary.baseValues + ), + selectedToken: 0, + singleTokenImportances: this.getImportanceForSingleToken( + this.state.selectedToken + ), text: this.props.dataSummary.text, + tokenIndexes: [...this.props.dataSummary.text].map((_, index) => index), topK }); } - // for QA - private updateTokenImportances(): void { - const importances = this.computeImportancesForAllTokens( - this.props.dataSummary.localExplanations - ); - const topK = this.calculateTopKImportances(importances); - const maxK = this.calculateMaxKImportances(importances); + private onWeightVectorChange = (weightOption: WeightVectorOption): void => { + const importances = this.getImportances(weightOption); + const [topK, maxK] = this.getTopKMaxK(importances); + this.setState({ importances, maxK, topK }); + this.props.onWeightChange(weightOption); + }; + + private onSelectedTokenChange = (newIndex: number): void => { + const singleTokenImportances = this.getImportanceForSingleToken(newIndex); this.setState({ - importances, - maxK, - text: this.props.dataSummary.text, - topK + selectedToken: newIndex, + singleTokenImportances }); - } + }; - private updateSingleTokenImportances(): void { - const singleTokenImportances = this.getImportanceForSingleToken( - this.state.selectedToken - ); - this.setState({ singleTokenImportances }); - } + private getSelectedWord = (): string => { + return this.props.dataSummary.text[this.state.selectedToken]; + }; - private calculateTopKImportances(importances: number[]): number { - return Math.min( - MaxImportantWords, - Math.ceil(Utils.countNonzeros(importances) / 2) - ); + private getTopKMaxK(importances: number[]): [number, number] { + const topK = calculateTopKImportances(importances); + const maxK = calculateMaxKImportances(importances); + return [topK, maxK]; } - private calculateMaxKImportances(importances: number[]): number { - return Math.min( - MaxImportantWords, - Math.ceil(Utils.countNonzeros(importances)) + private getImportances(weightOption: WeightVectorOption): number[] { + return computeImportancesForWeightVector( + this.props.dataSummary.localExplanations, + weightOption ); } - private computeImportancesForWeightVector( - importances: number[][], - weightVector: WeightVectorOption - ): number[] { - if (weightVector === WeightVectors.AbsAvg) { - // Sum the multidimensional array to one dimension across rows for each token - const numClasses = importances[0].length; - const sumImportances = importances.map((row) => - row.reduce((a, b): number => { - return (a + Math.abs(b)) / numClasses; - }, 0) - ); - return sumImportances; - } - return importances.map( - (perClassImportances) => perClassImportances[weightVector as number] + // for QA + private getTokenImportances(): number[] { + return computeImportancesForAllTokens( + this.props.dataSummary.localExplanations ); } - private computeImportancesForAllTokens(importances: number[][]): number[] { - /* - * sum the tokens importance - * TODO: add base values? - */ - - const sumImportances = importances[0].map((_, index) => - importances.reduce((sum, row) => sum + row[index], 0) + private getImportanceForSingleToken(index: number): number[] { + const expIndex = this.state.qaRadio === QAExplanationType.Start ? 0 : 1; + return this.props.dataSummary.localExplanations[expIndex].map( + (row) => row[index] ); - - return sumImportances; } - private getImportanceForSingleToken(index: number): number[] { - return this.props.dataSummary.localExplanations.map((row) => row[index]); + private getBaseValue(): number { + if (this.props.dataSummary.baseValues) { + const expIndex = this.state.qaRadio === QAExplanationType.Start ? 0 : 1; + return this.props.dataSummary.baseValues?.[expIndex][ + this.state.selectedToken + ]; + } + return 0; } private setTopK = (newNumber: number): void => { - /* - * Changes the state of K - */ this.setState({ topK: newNumber }); }; @@ -265,23 +231,23 @@ export class TextExplanationView extends React.PureComponent< _event?: React.FormEvent, item?: IChoiceGroupOption ): void => { - /* - * Changes the state of the radio button - */ - if (item?.key !== undefined) { + if (item?.key) { this.setState({ radio: item.key }); } }; - private switchQAPrediction = (): // _event?: React.FormEvent, - // _item?: IChoiceGroupOption - void => { - /* - * switch to the target predictions(starting or ending) - * TODO: add logic for switching explanation data - */ - // if (item?.key !== undefined) { - // this.setState({ qaRadio: item.key }); - // } + private switchQAPrediction = ( + _event?: React.FormEvent, + item?: IChoiceGroupOption + ): void => { + if (item?.key) { + const singleTokenImportances = this.getImportanceForSingleToken( + this.state.selectedToken + ); + this.setState({ + qaRadio: item.key, + singleTokenImportances + }); + } }; } diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationViewUtils.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationViewUtils.ts new file mode 100644 index 0000000000..313deaee99 --- /dev/null +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationViewUtils.ts @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { WeightVectorOption, WeightVectors } from "@responsible-ai/core-ui"; + +import { QAExplanationType, Utils } from "../../CommonUtils"; + +import { MaxImportantWords } from "./ITextExplanationViewSpec"; + +export function getOutputFeatureImportances( + localExplanations: number[][], + baseValues?: number[][] +): number[][] { + const startSumOfFeatureImportances = getSumOfFeatureImportances( + localExplanations[0] + ); + const endSumOfFeatureImportances = getSumOfFeatureImportances( + localExplanations[1] + ); + const startOutputFeatureImportances = getOutputFeatureImportancesIntl( + startSumOfFeatureImportances, + baseValues?.[0] + ); + const endOutputFeatureImportances = getOutputFeatureImportancesIntl( + endSumOfFeatureImportances, + baseValues?.[1] + ); + return [ + startOutputFeatureImportances || [], + endOutputFeatureImportances || [] + ]; +} + +export function getSumOfFeatureImportances(importances: number[]): number[] { + return importances.map((_, index) => + importances.reduce((sum, row) => sum + row[index], 0) + ); +} + +export function getOutputFeatureImportancesIntl( + sumOfFeatureImportances: number[], + baseValues?: number[] +): number[] | undefined { + return baseValues?.map( + (bValue, index) => sumOfFeatureImportances[index] + bValue + ); +} + +export function calculateTopKImportances(importances: number[]): number { + return Math.min( + MaxImportantWords, + Math.ceil(Utils.countNonzeros(importances) / 2) + ); +} + +export function calculateMaxKImportances(importances: number[]): number { + return Math.min( + MaxImportantWords, + Math.ceil(Utils.countNonzeros(importances)) + ); +} + +export function computeImportancesForWeightVector( + importances: number[][], + weightVector: WeightVectorOption +): number[] { + if (weightVector === WeightVectors.AbsAvg) { + // Sum the multidimensional array to one dimension across rows for each token + const numClasses = importances[0].length; + const sumImportances = importances.map((row) => + row.reduce((a, b): number => { + return (a + Math.abs(b)) / numClasses; + }, 0) + ); + return sumImportances; + } + return importances.map( + (perClassImportances) => perClassImportances[weightVector as number] + ); +} + +export function computeImportancesForAllTokens( + importances: number[][], + isInitialState?: boolean, + qaRadio?: string +): number[] { + const startSumImportances = importances[0].map((_, index) => + importances.reduce((sum, row) => sum + row[index], 0) + ); + const endSumImportances = importances[1].map((_, index) => + importances.reduce((sum, row) => sum + row[index], 0) + ); + if (isInitialState) { + return startSumImportances; + } + + return qaRadio === QAExplanationType.Start + ? startSumImportances + : endSumImportances; +} diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextInputOutputAreaWithLegend.tsx b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextInputOutputAreaWithLegend.tsx new file mode 100644 index 0000000000..579d56bc0a --- /dev/null +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextInputOutputAreaWithLegend.tsx @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { Stack, Text } from "@fluentui/react"; +import { localization } from "@responsible-ai/localization"; +import React from "react"; + +import { TextFeatureLegend } from "../TextFeatureLegend/TextFeatureLegend"; +import { TextHighlighting } from "../TextHighlighting/TextHightlighting"; + +import { componentStackTokens } from "./ITextExplanationViewSpec"; +import { textExplanationDashboardStyles } from "./TextExplanationView.styles"; + +interface ITextInputOutputAreaWithLegendProps { + topK: number; + radio: string; + selectedToken: number; + text: string[]; + outputLocalExplanations: number[]; + inputLocalExplanations: number[]; + isQA?: boolean; + getSelectedWord: () => string; + onSelectedTokenChange: (newIndex: number) => void; +} + +export class TextInputOutputAreaWithLegend extends React.Component { + public render(): React.ReactNode { + const classNames = textExplanationDashboardStyles(); + + return ( + + {this.props.isQA && ( + + + + + {localization.InterpretText.View.outputs} + + + + + + + + )} + + + + + {localization.InterpretText.View.inputs} + + + + + + + + + + + + ); + } +} diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TrueAndPredictedAnswerView.tsx b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TrueAndPredictedAnswerView.tsx new file mode 100644 index 0000000000..bc94fa8247 --- /dev/null +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TrueAndPredictedAnswerView.tsx @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { Stack, Text } from "@fluentui/react"; +import { localization } from "@responsible-ai/localization"; +import React from "react"; + +import { textExplanationDashboardStyles } from "./TextExplanationView.styles"; + +interface ITrueAndPredictedAnswerViewProps { + predictedY: string | number | number[] | string[] | number[][] | undefined; + trueY: string | number | number[] | string[] | number[][] | undefined; +} + +export class TrueAndPredictedAnswerView extends React.Component { + public render(): React.ReactNode { + const classNames = textExplanationDashboardStyles(); + + return ( + + + + + {localization.InterpretText.View.predictedAnswer}   + + + + + {this.props.predictedY} + + + + + + + {localization.InterpretText.View.trueAnswer}   + + + + {this.props.trueY} + + + + ); + } +} diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextFeatureLegend/TextFeatureLegend.styles.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextFeatureLegend/TextFeatureLegend.styles.ts index a72c4b0993..38fc87aecf 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextFeatureLegend/TextFeatureLegend.styles.ts +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextFeatureLegend/TextFeatureLegend.styles.ts @@ -7,10 +7,6 @@ import { IProcessedStyleSet, getTheme } from "@fluentui/react"; -import { - getPrimaryBackgroundChartColor, - getPrimaryChartColor -} from "@responsible-ai/core-ui"; export interface ITextFeatureLegendStyles { legend: IStyle; @@ -26,12 +22,12 @@ export const textFeatureLegendStyles: () => IProcessedStyleSet { public render(): React.ReactNode { const classNames = textFeatureLegendStyles(); return ( @@ -51,6 +56,28 @@ export class TextFeatureLegend extends React.Component { + {this.props.isQA && ( + + + {localization.InterpretText.Legend.cls} + + + {localization.InterpretText.Legend.sep} + + + + + + {localization.InterpretText.Legend.selectedWord}   + + + + {this.props.selectedWord} + + + + + )} ); } diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHighlighting.styles.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHighlighting.styles.ts index bff576c338..c98ee8bb88 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHighlighting.styles.ts +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHighlighting.styles.ts @@ -2,14 +2,13 @@ // Licensed under the MIT License. import { - IStyle, - mergeStyles, - mergeStyleSets, IProcessedStyleSet, IStackStyles, - getTheme + IStyle, + getTheme, + mergeStyleSets, + mergeStyles } from "@fluentui/react"; -import { getPrimaryChartColor } from "@responsible-ai/core-ui"; export const textStackStyles: IStackStyles = { root: { @@ -31,30 +30,41 @@ export interface ITextHighlightingStyles { boldunderline: IStyle; } -export const textHighlightingStyles: () => IProcessedStyleSet = - () => { - const theme = getTheme(); - const normal = { - color: theme.semanticColors.bodyText - }; - return mergeStyleSets({ - boldunderline: mergeStyles([ - normal, - { - color: getPrimaryChartColor(theme), - fontSize: theme.fonts.large.fontSize, - margin: "2px", - padding: 0, - textDecorationLine: "underline" - } - ]), - highlighted: mergeStyles([ - normal, - { - backgroundColor: getPrimaryChartColor(theme), - color: theme.semanticColors.bodyBackground - } - ]), - normal - }); +export const textHighlightingStyles: ( + isTextSelected: boolean +) => IProcessedStyleSet = (isTextSelected) => { + const theme = getTheme(); + const normal = { + color: theme.semanticColors.bodyText }; + const selectedTextStyle = isTextSelected + ? { + textDecorationColor: "black", + textDecorationLine: "underline", + textDecorationStyle: "solid", + textDecorationThickness: "4px" + } + : {}; + return mergeStyleSets({ + boldunderline: mergeStyles([ + normal, + { + backgroundColor: theme.semanticColors.link, + color: theme.semanticColors.bodyBackground, + fontSize: theme.fonts.large.fontSize, + margin: "2px", + padding: 0 + }, + selectedTextStyle + ]), + highlighted: mergeStyles([ + normal, + selectedTextStyle, + { + backgroundColor: theme.semanticColors.errorText, + color: theme.semanticColors.bodyBackground + } + ]), + normal: mergeStyles([normal, selectedTextStyle]) + }); +}; diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHightlighting.tsx b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHightlighting.tsx index 2fbac58256..1101cfd0b4 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHightlighting.tsx +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHightlighting.tsx @@ -2,7 +2,6 @@ // Licensed under the MIT License. import { - Label, Text, Stack, IStackTokens, @@ -25,12 +24,11 @@ const textStackTokens: IStackTokens = { padding: "s2" }; -export class TextHighlighting extends React.PureComponent { +export class TextHighlighting extends React.Component { /* * Presents the document in an accessible manner with text highlighting */ public render(): React.ReactNode { - const classNames = textHighlightingStyles(); const text = this.props.text; const importances = this.props.localExplanations; const k = this.props.topK; @@ -47,30 +45,22 @@ export class TextHighlighting extends React.PureComponent { styles={textStackStyles} > {text.map((word, wordIndex) => { + const isWordSelected = + (this.props.selectedTokenIndex && + wordIndex === this.props.selectedTokenIndex) || + false; + const classNames = textHighlightingStyles(isWordSelected); let styleType = classNames.normal; const score = importances[wordIndex]; - let isBold = false; if (sortedList.includes(wordIndex)) { if (score > 0) { styleType = classNames.highlighted; } else if (score < 0) { styleType = classNames.boldunderline; - isBold = true; } else { styleType = classNames.normal; } } - if (isBold) { - return ( - - ); - } return ( { key={wordIndex} className={styleType} title={score.toString()} + onClick={(): void => this.handleClick(wordIndex)} > {word} @@ -88,4 +79,13 @@ export class TextHighlighting extends React.PureComponent { ); } + + private readonly handleClick = (wordIndex: number): void => { + if (this.props.isInput) { + return; + } + if (this.props.onSelectedTokenChange) { + this.props.onSelectedTokenChange(wordIndex); + } + }; } diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IChartProps.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IChartProps.ts index fa2770820f..f16f1f6d9e 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IChartProps.ts +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IChartProps.ts @@ -9,4 +9,9 @@ export interface IChartProps { localExplanations: number[]; topK?: number; radio?: string; + isInput?: boolean; + baseValue?: number; + outputFeatureValue?: number; + selectedTokenIndex?: number; + onSelectedTokenChange?: (newIndex: number) => void; } diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IExplanationDashboardProps.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IExplanationDashboardProps.ts index d82261257e..903a652cdf 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IExplanationDashboardProps.ts +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IExplanationDashboardProps.ts @@ -19,5 +19,8 @@ export interface IDatasetSummary { text: string[]; classNames?: string[]; localExplanations: number[][]; + baseValues?: number[][]; prediction?: number[]; + predictedY?: number[] | number[][] | string[] | string | number; + trueY?: number[] | number[][] | string[] | string | number; } diff --git a/libs/localization/src/lib/en.json b/libs/localization/src/lib/en.json index 5e6824eedd..430bb44530 100644 --- a/libs/localization/src/lib/en.json +++ b/libs/localization/src/lib/en.json @@ -1374,12 +1374,19 @@ "label": "Label", "colon": ": ", "startingPosition": "STARTING POSITION", - "endingPosition": "ENDING POSITION" + "endingPosition": "ENDING POSITION", + "predictedAnswer": "Predicted answer: ", + "trueAnswer": "True answer: ", + "inputs": "Inputs", + "outputs": "Outputs" }, "Legend": { "featureLegend": "TEXT FEATURE LEGEND", "posFeatureImportance": "POSITIVE FEATURE IMPORTANCE", - "negFeatureImportance": "NEGATIVE FEATURE IMPORTANCE" + "negFeatureImportance": "NEGATIVE FEATURE IMPORTANCE", + "cls": "CLS: start of the sentence", + "sep": "SEP: end of the sentence", + "selectedWord": "Selected word: " }, "BarChart": { "featureImportance": "FEATURE IMPORTANCE" diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/FeatureImportances.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/FeatureImportances.tsx index 7abd051ec1..903b20caf1 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/FeatureImportances.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/FeatureImportances.tsx @@ -75,7 +75,6 @@ export class FeatureImportancesTab extends React.PureComponent< return React.Fragment; } const classNames = featureImportanceTabStyles(); - return ( { @@ -44,10 +47,13 @@ export class TextLocalImportancePlots extends React.Component