Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change the order of local tags #26

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion bsmetadata/metadata_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,13 @@ class HtmlProcessor(MetadataProcessor):
def process_local(self, metadata_attrs: Dict[str, Any]) -> Optional[Tuple[str, str]]:
    """Render an HTML tag as an (opening, closing) pair of marker strings.

    We represent a html tag `T` by enclosing the corresponding text span with
    "<T>" and "</T>". Example: An <b>apple</b> is an edible fruit.

    Args:
        metadata_attrs: A local-metadata entry whose "value" holds the tag
            name under "tag" and two parallel lists of attribute names and
            attribute values under "attrs" ("attr" / "value").

    Returns:
        A tuple of (opening tag text, closing tag text).
    """
    # NOTE(review): the rendered diff left the pre-change `return` statement
    # above this code; it is dropped here so the attribute handling added by
    # this PR is actually reachable.
    value = metadata_attrs["value"]
    # Attribute names and values arrive as two parallel lists; note that zip
    # silently truncates to the shorter list if they ever disagree in length.
    attributes = " ".join(
        f'{attr}:"{attr_value}"' for attr, attr_value in zip(value["attrs"]["attr"], value["attrs"]["value"])
    )
    if attributes:
        attributes = " " + attributes
    return f"<{value['tag']}{attributes}>", f"</{value['tag']}>"


class UrlProcessor(MetadataProcessor):
Expand Down
68 changes: 51 additions & 17 deletions bsmetadata/metadata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

from transformers import PreTrainedTokenizerFast
Expand All @@ -27,12 +28,10 @@ def add_metadata_and_chunk_examples(
examples: Dict[str, List], tokenizer: PreTrainedTokenizerFast, cfg: DataConfig
) -> Dict[str, List]:
"""Adds metadata to the provided input examples, encodes them and groups them in chunks of size `cfg.max_seq_len`.

Args:
examples: The examples to process, with required keys "text" and "metadata".
tokenizer: The pretrained tokenizer to use.
cfg: The config to use for adding metadata and chunking.

Returns:
A new (potentially larger) collection of examples with keys "input_ids", "attention_mask" and "metadata_mask", where:
- the input ids are a list of token ids corresponding to the input text with metadata;
Expand Down Expand Up @@ -100,11 +99,9 @@ def is_metadata(idx: int) -> bool:

def create_global_metadata_prefix(example: Dict[str, Any], cfg: DataConfig) -> str:
    """Creates a prefix containing all global metadata information (including URLs, timestamps, etc).

    Args:
        example: The example to create a global metadata prefix for.
        cfg: The data config to use.

    Returns:
        A string containing the global metadata prefix, or "" when the example
        carries no applicable global metadata.
    """
    # NOTE(review): the diff view hides the body between here and the return
    # below (presumably the code that builds `sorted_metadata`) — confirm
    # against the full file before relying on this summary.
Expand All @@ -122,19 +119,25 @@ def create_global_metadata_prefix(example: Dict[str, Any], cfg: DataConfig) -> s
    # Join all rendered global metadata with the configured separator and
    # terminate with the global separator; empty metadata yields an empty prefix.
    return cfg.metadata_sep.join(sorted_metadata) + cfg.global_metadata_sep if sorted_metadata else ""


@dataclass
class MetadataIdxStorage:
    """Per-character-index buckets of rendered metadata tag texts.

    Tags whose span is empty (char_start_idx == char_end_idx, e.g. a zero-width
    `<div></div>`) are kept separate from tags that enclose text, so the caller
    can interleave closing, zero-width and opening tags in the right order when
    rebuilding the text. Each field maps a character index to the list of tag
    strings anchored there.
    """

    # Opening / closing texts of tags that wrap a non-empty text span.
    start_idx_tag_with_content: dict = field(default_factory=lambda: defaultdict(list))
    end_idx_tag_with_content: dict = field(default_factory=lambda: defaultdict(list))
    # Opening / closing texts of zero-width tags (start index == end index).
    start_idx_tag_without_content: dict = field(default_factory=lambda: defaultdict(list))
    end_idx_tag_without_content: dict = field(default_factory=lambda: defaultdict(list))


def add_local_metadata_to_text(example: Dict[str, Any], cfg: DataConfig) -> Tuple[str, List[bool]]:
    """Adds local metadata (such as HTML tags and entity names) to the given input text.

    Args:
        example: The example for which local metadata should be added, with
            required keys "text" and "metadata".
        cfg: The data config to use.

    Returns:
        A tuple of two elements, where:
        - the first element is the text with metadata;
        - the second element is a boolean mask where `mask[i]` is set iff `text[i]` is some kind of metadata.
    """
    # NOTE(review): the next line is the pre-change initialization left in the
    # rendered diff; the PR replaces it with `MetadataIdxStorage` below.
    metadata_start_texts, metadata_end_texts = defaultdict(list), defaultdict(list)
    # Buckets of rendered tag texts keyed by character index; zero-width tags
    # (start index == end index) are tracked separately from content-bearing ones.
    metadata_idx_storage = MetadataIdxStorage()

    # Filter and sort all metadata so that they are processed in the requested order.
    filtered_metadata = [md for md in example["metadata"] if md["type"] == "local" and md["key"] in cfg.metadata_list]
    # NOTE(review): the diff hides the lines between here and the index lookups
    # below (presumably the loop over `filtered_metadata` that produces
    # `start_text` / `end_text` via the per-key processors) — confirm in the full file.
Expand All @@ -152,27 +155,58 @@ def add_local_metadata_to_text(example: Dict[str, Any], cfg: DataConfig) -> Tupl
        # -1 marks metadata entries without an explicit character span.
        char_start_idx = metadata.get("char_start_idx", -1)
        char_end_idx = metadata.get("char_end_idx", -1)

        # NOTE(review): the next two lines are the pre-change code left in the
        # diff view; the `if/else` below replaces them.
        metadata_start_texts[char_start_idx].insert(0, start_text)
        metadata_end_texts[char_end_idx].append(end_text)
        # Zero-width tags are stored apart from tags that enclose text so they
        # can be interleaved correctly when the output is rebuilt below.
        if char_start_idx == char_end_idx:
            # `insert(0, ...)` reverses processing order so that the tag
            # processed first ends up emitted first when the list is read.
            metadata_idx_storage.start_idx_tag_without_content[char_start_idx].insert(0, start_text)
            metadata_idx_storage.end_idx_tag_without_content[char_end_idx].append(end_text)
        else:
            metadata_idx_storage.start_idx_tag_with_content[char_start_idx].insert(0, start_text)
            metadata_idx_storage.end_idx_tag_with_content[char_end_idx].append(end_text)

    # Build the final text with local metadata and the corresponding mask.
    text_with_local_metadata = []
    metadata_mask = []

    def _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask):
        # Append each rendered tag text and mark every one of its characters
        # as metadata in the mask.
        for metadata_text in metadata_text_list:
            text_with_local_metadata.append(metadata_text)
            metadata_mask += [True] * len(metadata_text)

    for idx, char in enumerate(example["text"]):
        # Emission order at each index: closing tags of content-bearing spans,
        # then zero-width openers, then zero-width closers, then opening tags
        # of content-bearing spans — all before the character itself.
        if idx in metadata_idx_storage.end_idx_tag_with_content:
            metadata_text_list = metadata_idx_storage.end_idx_tag_with_content[idx]
            _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

        if idx in metadata_idx_storage.start_idx_tag_without_content:
            metadata_text_list = metadata_idx_storage.start_idx_tag_without_content[idx]
            _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

        if idx in metadata_idx_storage.end_idx_tag_without_content:
            metadata_text_list = metadata_idx_storage.end_idx_tag_without_content[idx]
            _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

        # NOTE(review): the next block is pre-change code left in the diff
        # view; the `start_idx_tag_with_content` lookup below replaces it.
        if idx in metadata_start_texts:
            for start_text in metadata_start_texts[idx]:
                text_with_local_metadata.append(start_text)
                metadata_mask += [True] * len(start_text)
        if idx in metadata_idx_storage.start_idx_tag_with_content:
            metadata_text_list = metadata_idx_storage.start_idx_tag_with_content[idx]
            _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

        # The character itself is never metadata.
        text_with_local_metadata.append(char)
        metadata_mask += [False]

        # NOTE(review): pre-change code left in the diff view; replaced by the
        # post-loop flush below.
        if idx + 1 in metadata_end_texts:
            for end_text in metadata_end_texts[idx + 1]:
                text_with_local_metadata.append(end_text)
                metadata_mask += [True] * len(end_text)
    # Flush tags anchored at the end of the text (index == len(text)).
    # NOTE(review): if example["text"] is empty, `idx` is unbound here and this
    # raises NameError — confirm empty input cannot reach this function.
    idx += 1
    if idx in metadata_idx_storage.end_idx_tag_with_content:
        metadata_text_list = metadata_idx_storage.end_idx_tag_with_content[idx]
        _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

    if idx in metadata_idx_storage.start_idx_tag_without_content:
        metadata_text_list = metadata_idx_storage.start_idx_tag_without_content[idx]
        _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

    if idx in metadata_idx_storage.end_idx_tag_without_content:
        metadata_text_list = metadata_idx_storage.end_idx_tag_without_content[idx]
        _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

    if idx in metadata_idx_storage.start_idx_tag_with_content:
        metadata_text_list = metadata_idx_storage.start_idx_tag_with_content[idx]
        _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

    return "".join(text_with_local_metadata), metadata_mask

Expand Down
84 changes: 83 additions & 1 deletion tests/test_metadata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from transformers import GPT2TokenizerFast

from bsmetadata.input_pipeline import DataConfig
from bsmetadata.metadata_processors import PROCESSORS, MetadataProcessor
from bsmetadata.metadata_processors import PROCESSORS, HtmlProcessor, MetadataProcessor
from bsmetadata.metadata_utils import (
add_local_metadata_to_text,
add_metadata_and_chunk_examples,
Expand Down Expand Up @@ -57,6 +57,76 @@ def setUp(self) -> None:
{"key": "url", "type": "global", "value": "callto:RickAndMorty/Year%202021/"},
],
},
{
"id": "0004",
"text": "useless text The Walking Dead (season 8)\n",
"metadata": [
{
"char_start_idx": 13,
"value": {
"tag": "h1",
"attrs": {"attr": [], "value": []},
},
"char_end_idx": 40,
"key": "html",
"type": "local",
},
{
"char_start_idx": 13,
"value": {
"tag": "div",
"attrs": {"attr": [], "value": []},
},
"char_end_idx": 13,
"key": "html",
"type": "local",
},
{
"char_start_idx": 0,
"value": {"tag": "a", "attrs": {"attr": [], "value": []}},
"char_end_idx": 13,
"key": "html",
"type": "local",
},
{
"char_start_idx": 13,
"value": {
"tag": "div",
"attrs": {"attr": [], "value": []},
},
"char_end_idx": 13,
"key": "html",
"type": "local",
},
{
"char_start_idx": 13,
"value": {
"tag": "a",
"attrs": {"attr": [], "value": []},
},
"char_end_idx": 13,
"key": "html",
"type": "local",
},
{
"char_start_idx": 13,
"value": {
"tag": "div",
"attrs": {"attr": [], "value": []},
},
"char_end_idx": 13,
"key": "html",
"type": "local",
},
{
"char_start_idx": 13,
"value": {"tag": "i", "attrs": {"attr": [], "value": []}},
"char_end_idx": 29,
"key": "html",
"type": "local",
},
],
},
]

def test_chunks(self):
Expand Down Expand Up @@ -133,6 +203,18 @@ def test_add_no_metadata_and_chunk_examples(self):
for example in mapped_ds:
self.assertTrue(all(not x for x in example["metadata_mask"]))

def test_add_html_tags(self):
    """Local HTML tags (including zero-width ones) are inserted in the expected order."""
    cfg = DataConfig()
    cfg.metadata_list = ["html"]
    PROCESSORS["html"] = HtmlProcessor

    text1, mask1 = add_local_metadata_to_text(self.examples[3], cfg)
    target_text = (
        "<a>useless text </a><div><a><div><div></div></div></a></div><h1><i>The Walking Dead</i> (season 8)</h1>\n"
    )

    self.assertEqual(text1, target_text)
    # The mask must flag exactly the inserted tag characters: it is as long as
    # the output text, and stripping the masked positions recovers the input.
    self.assertEqual(len(mask1), len(text1))
    self.assertEqual(
        "".join(char for char, is_meta in zip(text1, mask1) if not is_meta),
        self.examples[3]["text"],
    )

def test_add_metadata_and_chunk_examples(self):
cfg = DataConfig()
cfg.metadata_list = ["url", "timestamp", "html", "entity"]
Expand Down