Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-3488: Support for writing a ColumnCorpus instance to files #3497

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
178 changes: 165 additions & 13 deletions flair/datasets/sequence_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,9 @@
import shutil
from collections import defaultdict
from pathlib import Path
from typing import (
Any,
DefaultDict,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
Union,
cast,
)
from typing import Any, DefaultDict, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union, cast

from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data import ConcatDataset, Dataset, Subset

import flair
from flair.data import (
Expand All @@ -28,7 +17,9 @@
MultiCorpus,
Relation,
Sentence,
Span,
Token,
_iter_dataset,
get_spans_from_bio,
)
from flair.datasets.base import find_train_dev_test_files
Expand Down Expand Up @@ -443,6 +434,167 @@ def __init__(
**corpusargs,
)

@staticmethod
def get_token_level_label_of_each_token(sentence: Sentence, label_type: str) -> List[str]:
"""Generates a label for each token in the sentence. This function requires that the labels corresponding to the label_type are token-level tokens.

Args:
sentence: a flair sentence to generate labels for
label_type: a string representing the type of the labels, e.g., "pos"
"""
list_of_labels = ["O" for _ in range(len(sentence.tokens))]
for label in sentence.get_labels(label_type):
label_token_index = label.data_point._internal_index
chelseagzr marked this conversation as resolved.
Show resolved Hide resolved
list_of_labels[label_token_index - 1] = label.value
return list_of_labels

@staticmethod
def get_span_level_label_of_each_token(sentence: Sentence, label_type: str) -> List[str]:
chelseagzr marked this conversation as resolved.
Show resolved Hide resolved
"""Generates a label for each token in the sentence in BIO format. This function requires that the labels corresponding to the label_type are span-level tokens.

Args:
sentence: a flair sentence to generate labels for
label_type: a string representing the type of the labels, e.g., "ner"
"""
list_of_labels = ["O" for _ in range(len(sentence.tokens))]
for label in sentence.get_labels(label_type):
tokens = label.data_point.tokens
start_token_index = tokens[0]._internal_index
list_of_labels[start_token_index - 1] = f"B-{label.value}"
for token in tokens[1:]:
token_index = token._internal_index
list_of_labels[token_index - 1] = f"I-{label.value}"
return list_of_labels

@staticmethod
def write_dataset_to_file(
chelseagzr marked this conversation as resolved.
Show resolved Hide resolved
dataset: Dataset, file_path: Path, label_type_tuples: List[tuple], column_delimiter: str = "\t"
) -> None:
"""Writes a dataset to a file.

Following these two rules:
(1) the text and the label(s) of every token is represented in one line separated by column_delimiter
(2) every sentence is separated from the previous one by an empty line
"""
with open(file_path, mode="w") as output_file:
for sentence in _iter_dataset(dataset):
texts = [token.text for token in sentence.tokens]
texts_and_labels = [texts]
for label_type, level in label_type_tuples:
if level is Token:
texts_and_labels.append(ColumnCorpus.get_token_level_label_of_each_token(sentence, label_type))
elif level is Span:
texts_and_labels.append(ColumnCorpus.get_span_level_label_of_each_token(sentence, label_type))
else:
raise NotImplementedError(f"The level of {label_type} is neither token nor span.")

for text_and_labels_of_a_token in zip(*texts_and_labels):
output_file.write(column_delimiter.join(text_and_labels_of_a_token) + "\n")
output_file.write("\n")

@classmethod
def load_corpus_with_meta_data(cls, directory: Path) -> "ColumnCorpus":
"""Creates a ColumnCorpus instance from the directory generated by 'write_to_directory'."""
with open(directory / "meta_data.json") as file:
meta_data = json.load(file)

meta_data["column_format"] = {int(key): value for key, value in meta_data["column_format"].items()}

return cls(
data_folder=directory,
autofind_splits=True,
skip_first_line=False,
**meta_data,
)

def get_level_of_label(self, label_type: str) -> Union[Type[Token], Type[Span]]:
chelseagzr marked this conversation as resolved.
Show resolved Hide resolved
"""Gets level of label type by checking the first label in this corpus.

Raises:
RuntimeError: if there is no label of label_type
NotImplementedError: if level of label_type is other than Token or Span
"""
for dataset in [self.train, self.dev, self.test]:
if dataset:
for sentence in _iter_dataset(dataset):
for label in sentence.get_labels(label_type):
if isinstance(label.data_point, Token):
return Token
elif isinstance(label.data_point, Span):
return Span
else:
raise NotImplementedError(
f"The level of {label_type} is neither token nor span. Only token level labels and span level labels can be handled now."
)
raise RuntimeError(f"There is no label of type {label_type} in this corpus.")

def write_corpus_meta_data(
chelseagzr marked this conversation as resolved.
Show resolved Hide resolved
self, label_types: List[str], file_path: Path, column_delimiter: str, max_depth=5
) -> None:
"""Writes meta data of this corpus to a json file.

Note:
Currently, the whitespace_after attribute of each token will not be preserved. Only default_whitespace_after attribute of each dataset will be written to the file.
"""
meta_data = {
"name": self.name,
"sample_missing_splits": False,
"column_delimiter": column_delimiter,
}

column_format = {0: "text"}
for label_type_index, label_type in enumerate(label_types):
column_format[label_type_index + 1] = label_type
meta_data["column_format"] = column_format

nonempty_dataset = self.train or self.dev or self.test
for _ in range(max_depth):
if type(nonempty_dataset) is ColumnDataset:
break
elif type(nonempty_dataset) is ConcatDataset:
nonempty_dataset = nonempty_dataset.datasets[0]
elif type(nonempty_dataset) is Subset:
nonempty_dataset = nonempty_dataset.dataset
else:
raise NotImplementedError("Unsupported type")
chelseagzr marked this conversation as resolved.
Show resolved Hide resolved

if type(nonempty_dataset) is not ColumnDataset:
raise NotImplementedError("Unsupported type")

meta_data["encoding"] = nonempty_dataset.encoding
meta_data["in_memory"] = nonempty_dataset.in_memory
meta_data["banned_sentences"] = nonempty_dataset.banned_sentences
meta_data["default_whitespace_after"] = nonempty_dataset.default_whitespace_after

with open(file_path, mode="w") as output_file:
json.dump(meta_data, output_file)

def write_to_directory(self, label_types: List[str], output_directory: Path, column_delimiter: str = "\t") -> None:
"""Writes train, dev, test dataset (if exist) and the meta data of the corpus to a directory.

Note:
Only labels corresponding to label_types will be written.
Only token level or span level sequence tagging labels are supported.
Currently, the whitespace_after attribute of each token will not be preserved in the written file.
"""
label_type_tuples = [(label_type, self.get_level_of_label(label_type)) for label_type in label_types]

os.makedirs(output_directory, exist_ok=True)
if self.train:
ColumnCorpus.write_dataset_to_file(
self.train, output_directory / "train.conll", label_type_tuples, column_delimiter
)
if self.dev:
ColumnCorpus.write_dataset_to_file(
self.dev, output_directory / "dev.conll", label_type_tuples, column_delimiter
)
if self.test:
ColumnCorpus.write_dataset_to_file(
self.test, output_directory / "test.conll", label_type_tuples, column_delimiter
)

self.write_corpus_meta_data(label_types, output_directory / "meta_data.json", column_delimiter)


class ColumnDataset(FlairDataset):
# special key for space after
Expand Down
18 changes: 18 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,24 @@ def test_load_universal_dependencies_conllu_corpus(tasks_base_path):
_assert_universal_dependencies_conllu_dataset(corpus.train)


def test_write_to_and_load_from_directory(tasks_base_path):
from pathlib import Path

corpus = ColumnCorpus(
tasks_base_path / "column_with_whitespaces",
train_file="eng.train",
column_format={0: "text", 1: "ner"},
column_delimiter=" ",
skip_first_line=False,
sample_missing_splits=False,
)
directory = Path("resources/taggers/")
corpus.write_to_directory(["ner"], directory, column_delimiter="\t")
loaded_corpus = ColumnCorpus.load_corpus_with_meta_data(directory)
assert len(loaded_corpus.train) == len(corpus.train)
assert loaded_corpus.train[0].to_tagged_string() == corpus.train[0].to_tagged_string()


def test_hipe_2022_corpus(tasks_base_path):
# This test covers the complete HIPE 2022 dataset.
# https://github.com/hipe-eval/HIPE-2022-data
Expand Down
Loading