From d1ce7660022204d312c2432ec557f689175c5dc9 Mon Sep 17 00:00:00 2001
From: Chelsea Gu
Date: Wed, 10 Jul 2024 14:51:49 -0700
Subject: [PATCH] GH-3488: Add a function to write a ColumnCorpus instance to files

---
 flair/datasets/sequence_labeling.py | 158 +++++++++++++++++++++++++++-
 tests/test_datasets.py              |  18 ++++
 2 files changed, 175 insertions(+), 1 deletion(-)

diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py
index 9652070521..d3f0b4371f 100644
--- a/flair/datasets/sequence_labeling.py
+++ b/flair/datasets/sequence_labeling.py
@@ -19,7 +19,7 @@
     cast,
 )

-from torch.utils.data import ConcatDataset, Dataset
+from torch.utils.data import ConcatDataset, Dataset, Subset

 import flair
 from flair.data import (
@@ -28,6 +28,7 @@
     MultiCorpus,
     Relation,
     Sentence,
+    Span,
     Token,
     get_spans_from_bio,
 )
@@ -443,6 +444,161 @@ def __init__(
             **corpusargs,
         )

+    @staticmethod
+    def get_token_level_label_of_each_token(sentence: Sentence, label_type: str) -> List[str]:
+        """Generates a label for each token in the sentence. This function requires that the labels corresponding to the label_type are token-level labels.
+
+        Args:
+            sentence: a flair sentence to generate labels for
+            label_type: a string representing the type of the labels, e.g., "pos"
+        """
+        list_of_labels = ["O" for _ in range(len(sentence.tokens))]
+        for label in sentence.get_labels(label_type):
+            label_token_index = label.data_point._internal_index
+            list_of_labels[label_token_index - 1] = label.value  # token index is treated as 1-based
+        return list_of_labels
+
+    @staticmethod
+    def get_span_level_label_of_each_token(sentence: Sentence, label_type: str) -> List[str]:
+        """Generates a label for each token in the sentence in BIO format. This function requires that the labels corresponding to the label_type are span-level labels.
+
+        Args:
+            sentence: a flair sentence to generate labels for
+            label_type: a string representing the type of the labels, e.g., "ner"
+        """
+        list_of_labels = ["O" for _ in range(len(sentence.tokens))]
+        for label in sentence.get_labels(label_type):
+            tokens = label.data_point.tokens
+            start_token_index = tokens[0]._internal_index
+            list_of_labels[start_token_index - 1] = f"B-{label.value}"  # first token of the span
+            for token in tokens[1:]:
+                token_index = token._internal_index
+                list_of_labels[token_index - 1] = f"I-{label.value}"  # remaining tokens of the span
+        return list_of_labels
+
+    @staticmethod
+    def write_dataset_to_file(
+        dataset: Dataset, file_path: Path, label_type_tuples: List[tuple], column_delimiter: str = "\t"
+    ) -> None:
+        """Writes a dataset to a UTF-8 encoded file.
+
+        Following these two rules:
+        (1) the text and the label of every token is represented in one line separated by column_delimiter
+        (2) every sentence is separated from the previous one by an empty line
+        """
+        with open(file_path, mode="w", encoding="utf-8") as output_file:
+            for sentence in dataset:
+                texts = [token.text for token in sentence.tokens]
+                texts_and_labels = [texts]
+                for label_type, level in label_type_tuples:
+                    if level is Token:
+                        texts_and_labels.append(ColumnCorpus.get_token_level_label_of_each_token(sentence, label_type))
+                    elif level is Span:
+                        texts_and_labels.append(ColumnCorpus.get_span_level_label_of_each_token(sentence, label_type))
+                    else:
+                        raise NotImplementedError(f"The level of {label_type} is neither token nor span.")
+
+                for text_and_labels_of_a_token in zip(*texts_and_labels):
+                    output_file.write(column_delimiter.join(text_and_labels_of_a_token) + "\n")
+                output_file.write("\n")
+
+    @classmethod
+    def load_corpus_with_meta_data(cls, directory: Path) -> "ColumnCorpus":
+        """Creates a ColumnCorpus instance from the directory generated by 'write_to_directory'."""
+        with open(directory / "meta_data.json", encoding="utf-8") as file:
+            meta_data = json.load(file)
+
+        meta_data["column_format"] = {int(key): value for key, value in meta_data["column_format"].items()}
+
+        return cls(
+            data_folder=directory,
+            autofind_splits=True,
+            skip_first_line=False,
+            **meta_data,
+        )
+
+    def get_level_of_label(self, label_type: str):
+        """Gets level of label type by checking the first label in this corpus."""
+        for dataset in [self.train, self.dev, self.test]:
+            if dataset:
+                for sentence in dataset:
+                    for label in sentence.get_labels(label_type):
+                        if isinstance(label.data_point, Token):
+                            return Token
+                        elif isinstance(label.data_point, Span):
+                            return Span
+                        else:
+                            raise NotImplementedError(
+                                f"The level of {label_type} is neither token nor span. Only token level labels and span level labels can be handled now."
+                            )
+        raise RuntimeError(f"There is no label of type {label_type} in this corpus.")
+
+    def write_corpus_meta_data(self, label_types: List[str], file_path: Path, column_delimiter: str) -> None:
+        """Writes meta data of this corpus to a UTF-8 encoded json file.
+
+        Note:
+            Currently, the whitespace_after attribute of each token will not be preserved. Only default_whitespace_after attribute of each dataset will be written to the file.
+        """
+        meta_data = {
+            "name": self.name,
+            "sample_missing_splits": False,
+            "column_delimiter": column_delimiter,
+        }
+
+        column_format = {0: "text"}
+        for label_type_index, label_type in enumerate(label_types):
+            column_format[label_type_index + 1] = label_type
+        meta_data["column_format"] = column_format
+
+        nonempty_dataset = self.train or self.dev or self.test
+        MAX_DEPTH = 5
+        for _ in range(MAX_DEPTH):
+            if isinstance(nonempty_dataset, ColumnDataset):
+                break
+            elif isinstance(nonempty_dataset, ConcatDataset):
+                nonempty_dataset = nonempty_dataset.datasets[0]
+            elif isinstance(nonempty_dataset, Subset):
+                nonempty_dataset = nonempty_dataset.dataset
+            else:
+                raise NotImplementedError("Unsupported type")
+
+        if not isinstance(nonempty_dataset, ColumnDataset):
+            raise NotImplementedError("Unsupported type")
+
+        meta_data["encoding"] = "utf-8"  # must match the encoding write_dataset_to_file uses for the data files
+        meta_data["in_memory"] = nonempty_dataset.in_memory
+        meta_data["banned_sentences"] = nonempty_dataset.banned_sentences
+        meta_data["default_whitespace_after"] = nonempty_dataset.default_whitespace_after
+
+        with open(file_path, mode="w", encoding="utf-8") as output_file:
+            json.dump(meta_data, output_file)
+
+    def write_to_directory(self, label_types: List[str], output_directory: Path, column_delimiter: str = "\t") -> None:
+        """Writes train, dev, test dataset (if exist) and the meta data of the corpus to a directory.
+
+        Note:
+            Only labels corresponding to label_types will be written.
+            Only token level or span level sequence tagging labels are supported.
+            Currently, the whitespace_after attribute of each token will not be preserved in the written file.
+        """
+        label_type_tuples = [(label_type, self.get_level_of_label(label_type)) for label_type in label_types]
+
+        os.makedirs(output_directory, exist_ok=True)
+        if self.train:
+            ColumnCorpus.write_dataset_to_file(
+                self.train, output_directory / "train.conll", label_type_tuples, column_delimiter
+            )
+        if self.dev:
+            ColumnCorpus.write_dataset_to_file(
+                self.dev, output_directory / "dev.conll", label_type_tuples, column_delimiter
+            )
+        if self.test:
+            ColumnCorpus.write_dataset_to_file(
+                self.test, output_directory / "test.conll", label_type_tuples, column_delimiter
+            )
+
+        self.write_corpus_meta_data(label_types, output_directory / "meta_data.json", column_delimiter)
+


 class ColumnDataset(FlairDataset):
     # special key for space after
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 52fec1c5ea..cdbeb86784 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -418,6 +418,24 @@ def test_load_universal_dependencies_conllu_corpus(tasks_base_path):
     _assert_universal_dependencies_conllu_dataset(corpus.train)


+def test_write_to_and_load_from_directory(tasks_base_path, tmp_path):
+    """Round-trips a corpus through write_to_directory and load_corpus_with_meta_data."""
+    # Use pytest's tmp_path so the written files never land in the working directory.
+    corpus = ColumnCorpus(
+        tasks_base_path / "column_with_whitespaces",
+        train_file="eng.train",
+        column_format={0: "text", 1: "ner"},
+        column_delimiter=" ",
+        skip_first_line=False,
+        sample_missing_splits=False,
+    )
+    directory = tmp_path / "written_corpus"
+    corpus.write_to_directory(["ner"], directory, column_delimiter="\t")
+    loaded_corpus = ColumnCorpus.load_corpus_with_meta_data(directory)
+    assert len(loaded_corpus.train) == len(corpus.train)
+    assert loaded_corpus.train[0].to_tagged_string() == corpus.train[0].to_tagged_string()
+
+
 def test_hipe_2022_corpus(tasks_base_path):
     # This test covers the complete HIPE 2022 dataset.
     # https://github.com/hipe-eval/HIPE-2022-data