Skip to content

Commit

Permalink
Utils and helpers for sequence salience, most notably token grouping code.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 606346156
  • Loading branch information
iftenney authored and LIT team committed Feb 12, 2024
1 parent 27e6901 commit ab294bd
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 4 deletions.
8 changes: 8 additions & 0 deletions lit_nlp/client/elements/tooltip.css
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
* with tooltip positioning.
*/
--anchor-display-mode: inline-block;
--tooltip-position-left: unset;
--tooltip-position-right: unset;
--tooltip-position-top: unset;
--tooltip-position-bottom: unset;
}

/* Tooltip */
Expand All @@ -34,6 +38,10 @@
font-size: 12px;
font-weight: normal;
line-height: 16px;
left: var(--tooltip-position-left);
right: var(--tooltip-position-right);
top: var(--tooltip-position-top);
bottom: var(--tooltip-position-bottom);

display: -webkit-box;
-webkit-line-clamp: 6;
Expand Down
1 change: 1 addition & 0 deletions lit_nlp/client/elements/tooltip.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ export class LitTooltip extends ReactiveElement {
'disabled': this.disabled,
});

// prettier-ignore
return html`<div class='lit-tooltip'>
<slot name="tooltip-anchor">
${this.content === '' ? '' : html`
Expand Down
53 changes: 53 additions & 0 deletions lit_nlp/client/lib/token_utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* @fileoverview Utils for working with tokenized text.
*/

/**
 * Evil underscore used by sentencepiece to replace spaces.
 */
export const SPM_SPACE_SENTINEL = '▁';

/**
 * Clean SPM text to make it more human-readable.
 * Replaces every occurrence of the SPM space sentinel with a plain space.
 */
export function cleanSpmText(text: string): string {
  return text.split(SPM_SPACE_SENTINEL).join(' ');
}

/**
 * Use a regex to match segment prefixes. The prefix and anything
 * following it (until the next match) are treated as one segment.
 *
 * Tokens are never split: a token that overlaps a match starts the new
 * segment and carries the whole match with it.
 *
 * @param tokens token strings; their concatenation is what the regex
 *     is matched against.
 * @param matcher a regex with the global ('g') flag set, as required by
 *     String.prototype.matchAll.
 * @return the tokens, partitioned into contiguous groups.
 */
export function groupTokensByRegexPrefix(
    tokens: string[],
    matcher: RegExp,
): string[][] {
  const text = tokens.join('');
  const matches = [...text.matchAll(matcher)];

  let textCharOffset = 0;  // chars into text
  let matchIdx = 0;        // index into matches
  const groups: string[][] = [];
  let acc: string[] = [];
  for (const token of tokens) {
    const tokenEnd = textCharOffset + token.length;

    // Look ahead to see if this token intrudes on a match.
    // If so, start a new segment before pushing the token.
    if (matchIdx < matches.length && tokenEnd > matches[matchIdx].index!) {
      // Don't push an empty group if the first token is part of a match.
      if (acc.length > 0 || groups.length > 0) groups.push(acc);
      acc = [];
      // Consume EVERY match that begins inside this token. A token cannot
      // be split, so several matches within one token contribute only a
      // single segment boundary. (Consuming just one match here would leave
      // a stale match behind, producing a spurious boundary at the next
      // token even though that token overlaps no match.)
      while (matchIdx < matches.length &&
             tokenEnd > matches[matchIdx].index!) {
        matchIdx += 1;
      }
    }

    // Push the token.
    acc.push(token);
    textCharOffset = tokenEnd;
  }
  // Finally, push any open group.
  if (acc.length > 0) groups.push(acc);
  return groups;
}
88 changes: 88 additions & 0 deletions lit_nlp/client/lib/token_utils_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/**
* Testing for token_utils.ts
*/

import 'jasmine';

import * as tokenUtils from './token_utils';

// Verifies that every SPM space sentinel '▁' is replaced with a plain
// space while all other characters (including newlines) pass through.
describe('cleanSpmText test', () => {
  it('cleans magic underscores from SPM output', () => {
    const text = 'Summarize▁this▁sentence:\n\nOnce▁upon▁a▁time';
    expect(tokenUtils.cleanSpmText(text))
        .toEqual('Summarize this sentence:\n\nOnce upon a time');
  });
});

// Parameterized tests for groupTokensByRegexPrefix. Each case joins the
// tokens into one string, matches the regex against it, and expects a new
// group to start at every token that overlaps a match.
//
// NOTE(review): JavaScript `\s` does NOT match the SPM sentinel '▁'
// (U+2581). The word- and sentence-grouping cases below expect boundaries
// at '▁'-prefixed tokens, which seems to require the matcher to also treat
// '▁' as whitespace — verify these regexes against the implementation.
describe('groupTokensByRegexPrefix test', () => {
  [{
    testcaseName: 'groups tokens by word',
    tokens: ['Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':'],
    regex: /[\s]+/g,
    expectedGroups: [['Sum', 'mar', 'ize'], ['▁this'], ['▁sent', 'ence', ':']],
  },
  {
    testcaseName: 'groups tokens by word, handling newlines',
    tokens: [
      'Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':', '\n', '\n', 'Once',
      '▁upon', '▁a', '▁time'
    ],
    // Consecutive newlines should be their own segment.
    // Start a new word on the first non-\n afterwards.
    regex: /([\s]+)|(?<=\n)[^\n]/g,
    expectedGroups: [
      ['Sum', 'mar', 'ize'], ['▁this'], ['▁sent', 'ence', ':'], ['\n', '\n'],
      ['Once'], ['▁upon'], ['▁a'], ['▁time']
    ],
  },
  {
    testcaseName: 'groups tokens by sentence, simple version',
    tokens: [
      'Sent', 'ence', '▁one', '.', '▁Sent', 'ence', '▁two', '!', '▁Sent',
      'ence', '▁three', '?'
    ],
    // A sentence boundary is whitespace following end punctuation.
    regex: /(?<=[.?!])[\s]+/g,
    expectedGroups: [
      ['Sent', 'ence', '▁one', '.'],
      ['▁Sent', 'ence', '▁two', '!'],
      ['▁Sent', 'ence', '▁three', '?'],
    ],
  },
  {
    testcaseName: 'groups tokens by sentence, handling newlines',
    tokens: [
      'Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':', '\n', '\n', 'Once',
      '▁upon', '▁a', '▁time'
    ],
    // Sentence start is one of:
    // - a run of consecutive \n as its own segment
    // - any non-\n following \n
    // - whitespace or magic underscore following punctuation [.?!]
    regex: /(\n+)|((?<=\n)[^\n])|((?<=[.?!])([\s]+))/g,
    expectedGroups: [
      ['Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':'], ['\n', '\n'],
      ['Once', '▁upon', '▁a', '▁time']
    ],
  },
  {
    testcaseName: 'groups tokens by line',
    tokens: [
      'Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':', '\n', '\n', 'Once',
      '▁upon', '▁a', '▁time'
    ],
    // Line start is either:
    // - a run of consecutive \n as its own segment
    // - any non-\n following \n
    regex: /(\n+)|([^\n]+)/g,
    expectedGroups: [
      ['Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':'], ['\n', '\n'],
      ['Once', '▁upon', '▁a', '▁time']
    ],
  },
  ].forEach(({testcaseName, tokens, regex, expectedGroups}) => {
    it(testcaseName, () => {
      const groups = tokenUtils.groupTokensByRegexPrefix(tokens, regex);
      expect(groups).toEqual(expectedGroups);
    });
  });
});
21 changes: 21 additions & 0 deletions lit_nlp/client/lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,27 @@ export function cumSumArray(array: number[]) {
return newArray;
}

/**
 * Group elements of one list to match the partitions of another.
 *
 * Example:
 *   groupAlike([0, 1, 2, 3, 4, 5], [['a', 'b'], ['c'], ['d', 'e', 'f']])
 * returns [[0, 1], [2], [3, 4, 5]].
 *
 * Throws an Error if the group sizes do not sum to the number of items.
 */
export function groupAlike<T>(items: T[], groups: unknown[][]): T[][] {
  // Prefix sums of the group sizes give the slice boundaries.
  const boundaries = [0, ...cumSumArray(groups.map(group => group.length))];
  const total = boundaries[boundaries.length - 1];
  if (total !== items.length) {
    throw new Error(`Total length of groups (${
        total}) !== number of items (${items.length}).`);
  }
  return groups.map(
      (unused, i) => items.slice(boundaries[i], boundaries[i + 1]));
}

/**
* Python-style array comparison.
* Compare on first element, then second, and so on until a mismatch is found.
Expand Down
13 changes: 13 additions & 0 deletions lit_nlp/client/lib/utils_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,19 @@ describe('cumSumArray test', () => {
});
});

// Tests for groupAlike: the happy path, and the error raised when group
// sizes do not sum to the item count (error text pins the message format).
describe('groupAlike test', () => {
  it('groups items', () => {
    const result = utils.groupAlike(
        [0, 1, 2, 3, 4, 5], [['a', 'b'], ['c'], ['d', 'e', 'f']]);
    expect(result).toEqual([[0, 1], [2], [3, 4, 5]]);
  });
  it('raises an error if lengths do not match', () => {
    expect(() => utils.groupAlike([0, 1, 2, 3, 4, 5], [['a', 'b'], ['c']]))
        .toThrow(
            new Error('Total length of groups (3) !== number of items (6).'));
  });
});

describe('compareArrays test', () => {
it('Correctly tests normal comparison', () => {
// Shorter arrays.
Expand Down
15 changes: 13 additions & 2 deletions lit_nlp/lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,20 @@ def unbatch_preds(
yield {key: value[i] for key, value in preds.items()}


def pad1d(arr: list[T], min_len: int, pad_val: T) -> list[T]:
def pad1d(
arr: list[T],
min_len: int,
pad_val: T,
pad_left: bool = False,
max_len: int | None = None,
) -> list[T]:
"""Pad a list to the target length."""
return arr + [pad_val] * max(0, min_len - len(arr))
if pad_left:
padded = [pad_val] * max(0, min_len - len(arr)) + arr
return padded[-max_len:] if max_len is not None else padded
else:
padded = arr + [pad_val] * max(0, min_len - len(arr))
return padded[:max_len] if max_len is not None else padded


def find_all_combinations(
Expand Down
48 changes: 46 additions & 2 deletions lit_nlp/lib/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,55 @@ def test_batch_inputs_raises(
pad_val="",
expected=["one", "two", "three", "", ""],
),
dict(
testcase_name="truncate_max_len",
inputs=[1, 2, 3, 4, 5],
min_len=3,
pad_val=0,
max_len=3,
expected=[1, 2, 3],
),
dict(
testcase_name="pad_left",
inputs=[1, 2, 3],
min_len=5,
pad_val=0,
pad_left=True,
expected=[0, 0, 1, 2, 3],
),
dict(
testcase_name="truncate_max_len_left",
inputs=[1, 2, 3, 4, 5],
min_len=3,
pad_val=0,
pad_left=True,
max_len=3,
expected=[3, 4, 5],
),
dict(
testcase_name="pad_left_with_strings",
inputs=["one", "two", "three"],
min_len=5,
pad_val="",
pad_left=True,
expected=["", "", "one", "two", "three"],
),
)
def test_pad1d(
self, inputs: list[T], min_len: T, pad_val: T, expected: list[T]
self,
inputs: list[T],
min_len: T,
pad_val: T,
expected: list[T],
pad_left: bool = False,
max_len: int | None = None,
):
self.assertEqual(utils.pad1d(inputs, min_len, pad_val), expected)
self.assertEqual(
utils.pad1d(
inputs, min_len, pad_val, pad_left=pad_left, max_len=max_len
),
expected,
)

@parameterized.named_parameters(
dict(
Expand Down

0 comments on commit ab294bd

Please sign in to comment.