Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prefix special tokens #198

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions bsmetadata/hydra_configs/v2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,22 @@ data_config:
#- generation_length_sentence
- generation_length_text
- entity_paragraph
local_metadata_special_tokens:
entity_paragraph: "entity"
metadata_sep: ' | '
metadata_key_value_sep: ': '
local_metadata_special_tokens:
entity_paragraph: "<special_entity>"
html: "<special_html>"
prefix_sep_tokens:
title: "<special_title>"
website_description: "<special_website_description>"
datasource: "<special_datasource>"
text_length: "<special_text_length>"
url: "<special_url>"
metadata_prefix_sep: '<prefix_sep>'
metadata_sep: ''
metadata_key_value_sep: ''
metadata_prefix_start_seq: ''
metadata_probability: 0.5
treat_local_metadata_as_regular_text: true
add_local_metadata_special_tokens_in_prefix: true
metadata_prefix_sep: ' |||'
metadata_prefix_start_seq: ''
max_seq_len: 1024
html_parser_config:
all_tags_rules:
Expand Down Expand Up @@ -76,7 +83,7 @@ data_config:
entity_paragraph: "<ENTITY_CHAIN>"
html: "<HTML>"
local_metadata_special_token_end:
entity_paragraph: " </ENTITY_CHAIN> "
entity_paragraph: "</ENTITY_CHAIN>"
html: "</HTML>"
local_metadata_special_token_state: true
html_overall_sample_rate: 0.25
Expand Down
24 changes: 18 additions & 6 deletions bsmetadata/metadata_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class MetadataConfig:
},
)
metadata_prefix_sep: str = field(
default=" |||",
default="<prefix_sep>",
metadata={
"help": "The character sequence that is used to separate all global metadata and/or local metadata "
"special tokens (if `add_local_metadata_special_tokens_in_prefix` is `True`) from the actual text."
Expand Down Expand Up @@ -351,7 +351,9 @@ class UrlProcessor(MetadataProcessor):
def process_global(self, metadata_attrs: Dict[str, Any]) -> Optional[str]:
# We represent a URL with unquoted format such that less confusion for a tokenizer.
# Example: "foo.bar/Year 2021/" instead of "foo.bar/Year%202021/".
return "".join([metadata_attrs["key"], self.cfg.metadata_key_value_sep, unquote_plus(metadata_attrs["value"])])
return "".join(
[self.cfg.prefix_sep_tokens["url"], self.cfg.metadata_key_value_sep, unquote_plus(metadata_attrs["value"])]
)


class TitleProcessor(MetadataProcessor):
Expand All @@ -360,15 +362,21 @@ class TitleProcessor(MetadataProcessor):
def process_global(self, metadata_attrs: Dict[str, Any]) -> Optional[str]:
# We represent a title by the title of the corresponding webpage content.
# Example: "My Thoughts On It » Dad, I want to be an inventor".
return "".join(["Title", self.cfg.metadata_key_value_sep, metadata_attrs["value"]])
return "".join([self.cfg.prefix_sep_tokens["title"], self.cfg.metadata_key_value_sep, metadata_attrs["value"]])


class WebsiteDescriptionProcessor(MetadataProcessor):
"""An example metadata processor for website descriptions."""

def process_global(self, metadata_attrs: Dict[str, Any]) -> Optional[str]:
# Example: "website_description: BBC is a news organization".
return "".join(["Website Description", self.cfg.metadata_key_value_sep, metadata_attrs["value"]])
return "".join(
[
self.cfg.prefix_sep_tokens["website_description"],
self.cfg.metadata_key_value_sep,
metadata_attrs["value"],
]
)


class DatasourceProcessor(MetadataProcessor):
Expand All @@ -378,7 +386,9 @@ def process_global(self, metadata_attrs: Dict[str, Any]) -> Optional[str]:
# We represent the DATASOURCE by using meaningful information of the URL.
# URL: http://www.example.de/2015/forum/article/21-new-project
# Example: example.de > forum > article > new project
return "".join(["Datasource", self.cfg.metadata_key_value_sep, metadata_attrs["value"]])
return "".join(
[self.cfg.prefix_sep_tokens["datasource"], self.cfg.metadata_key_value_sep, metadata_attrs["value"]]
)


class GenerationLengthProcessor(MetadataProcessor):
Expand All @@ -388,7 +398,9 @@ def process_global(self, metadata_attrs: Dict[str, Any]) -> Optional[str]:
# We represent the length of a text by the number of characters.
# Example: Length: 123

return "".join(["Text Length", self.cfg.metadata_key_value_sep, metadata_attrs["value"]])
return "".join(
[self.cfg.prefix_sep_tokens["text_length"], self.cfg.metadata_key_value_sep, metadata_attrs["value"]]
)


class BasicStartLocalProcessor(MetadataProcessor):
Expand Down
3 changes: 3 additions & 0 deletions bsmetadata/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ def main(args: CFG) -> None:
)
)
)
new_tokens.append(args.data_config.metadata_config.metadata_prefix_sep)
new_tokens.extend(args.data_config.metadata_config.prefix_sep_tokens.values())
new_tokens.extend(args.data_config.metadata_config.local_metadata_special_tokens.values())
new_tokens = [
AddedToken(token, rstrip=False, lstrip=False, single_word=False, normalized=False) for token in new_tokens
]
Expand Down