From f82969d9d56a03dfd9438664abcf4f4b6b67d9a8 Mon Sep 17 00:00:00 2001 From: Elena Merdjanovska Date: Tue, 23 Jul 2024 22:59:06 +0200 Subject: [PATCH 001/333] add initial noisebench dataset class --- flair/datasets/sequence_labeling.py | 51 +++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index 38ca75e94b..044f2a8609 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -4973,6 +4973,57 @@ def _write_instances(cls, version, base_path, split, data): out_file.write("\n") +class NER_NOISEBENCH(ColumnCorpus): + label_url = "https://github.com/elenamer/NoiseBench/tree/8a32da1e06f2239afc95b3f9dc5274abc25cc46d/data/annotations" + def __init__( + self, + noise: str = None, + base_path: Optional[Union[str, Path]] = None, + in_memory: bool = True, + **corpusargs, + ) -> None: + """Initialize the NoiseBench corpus. + + Args: + noise (string): Chooses the labelset for the data. + clean (default): Clean labels + crowd,crowdbest,expert,distant,weak,llm : Different kinds of noisy labelsets (details: ...) + base_path (Optional[Union[str, Path]]): Path to the data. + Default is None, meaning the corpus gets automatically downloaded and saved. + You can override this by passing a path to a directory containing the unprocessed files but typically this + should not be necessary. + in_memory (bool): If True the dataset is kept in memory achieving speedups in training. + **corpusargs: The arguments propagated to :meth:'flair.datasets.ColumnCorpus.__init__'. + """ + if noise not in ['clean', None, 'crowd','crowdbest','expert','distant','weak','llm']: + raise Exception( + "Please choose a valid version" + ) + + base_path = self._set_path(base_path) + + filename = 'clean' if noise in ['clean',None] else f'noise_{noise}' + + cached_path(f"{self.label_url}/{filename}.traindev", base_path) + cached_path(f"{self.label_url}/index.txt", base_path) + + super().__init__( + data_folder=base_path, + train_file=f"{filename}.train", + dev_file=f"{filename}.dev", + test_file=f"clean.test", # test set is always clean (without noise) + column_format={0: "text", 1: "ner"}, + in_memory=in_memory, + column_delimiter="\t", + document_separator = "-DOCSTART-", + **corpusargs, + ) + + @classmethod + def _set_path(cls, base_path) -> Path: + base_path = flair.cache_root / "datasets" / "noisebench" if not base_path else Path(base_path) + return base_path + class MASAKHA_POS(MultiCorpus): def __init__( self, From 2ede7190217ec28320fa537f5f2d615d1c12bcf4 Mon Sep 17 00:00:00 2001 From: Elena Merdjanovska Date: Wed, 24 Jul 2024 09:34:56 +0200 Subject: [PATCH 002/333] add downloading of cleanconll --- flair/datasets/sequence_labeling.py | 32 ++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index 044f2a8609..0c89793463 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -5003,9 +5003,35 @@ def __init__( base_path = self._set_path(base_path) filename = 'clean' if noise in ['clean',None] else f'noise_{noise}' - - cached_path(f"{self.label_url}/{filename}.traindev", base_path) - cached_path(f"{self.label_url}/index.txt", base_path) + file_paths = [base_path / f'{filename}.train', base_path / f'{filename}.dev', base_path / 'clean.test'] + files_exist = [path.exists() for path in file_paths] + cleanconll_base_path = flair.cache_root / "datasets" /
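For orientation, this is how the new loader is used once the series is merged. A minimal usage sketch, assuming the final API from the later patches in this series (the class exported from flair.datasets as in PATCH 003, and `noise` defaulting to "clean" as in PATCH 005):

from flair.datasets import NER_NOISEBENCH

# Load the variant with expert-annotated noisy labels. Valid values per
# the docstring above: clean, crowd, crowdbest, expert, distant, weak, llm.
# Train/dev carry the chosen label set; the test split is always clean.
corpus = NER_NOISEBENCH(noise="expert")

print(len(corpus.train), len(corpus.dev), len(corpus.test))
print(corpus.train[0])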
"cleanconll" + + if not all(files_exist): + cached_path(f"{self.label_url}/{filename}.traindev", base_path / 'annotations_only') + cached_path(f"{self.label_url}/index.txt", base_path / 'annotations_only') + + cleanconll_files_exist = [Path(f'{cleanconll_base_path}/cleanconll.{split}').exists() for split in ['train','dev','test']] + if not all(cleanconll_files_exist): + # download cleanconll + + clone = f"git clone https://github.com/flairNLP/CleanCoNLL.git {cleanconll_base_path}/CleanCoNLL" + os.system(clone) # Cloning + cwd = os.getcwd() + + os.chdir(f"{cleanconll_base_path}/CleanCoNLL") + chmod = f"chmod u+x create_cleanconll_from_conll03.sh" + os.system(chmod) + create = f"bash create_cleanconll_from_conll03.sh" + + os.system(create) + os.chdir(cwd) + + shutil.move(f'{cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.train', cleanconll_base_path) + shutil.move(f'{cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.dev', cleanconll_base_path) + shutil.move(f'{cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.test', cleanconll_base_path) + + # create dataset files from index and train/test splits super().__init__( data_folder=base_path, From a34d2b2e5c7c60c8efc7d6ea0cd28c46cb901a47 Mon Sep 17 00:00:00 2001 From: Elena Merdjanovska Date: Wed, 24 Jul 2024 09:35:11 +0200 Subject: [PATCH 003/333] update __init__ --- flair/datasets/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py index 2837e017c0..5a7ccf7e77 100644 --- a/flair/datasets/__init__.py +++ b/flair/datasets/__init__.py @@ -197,6 +197,7 @@ NER_ENGLISH_WIKIGOLD, NER_ENGLISH_WNUT_2020, NER_ESTONIAN_NOISY, + NER_NOISEBENCH, NER_FINNISH, NER_GERMAN_BIOFID, NER_GERMAN_EUROPARL, From 2684e93136fa7e61059aa9340cb23e5aa460fc0a Mon Sep 17 00:00:00 2001 From: Elena Merdjanovska Date: Wed, 24 Jul 2024 14:22:41 +0200 Subject: [PATCH 004/333] add processing of noisebench label sets --- flair/datasets/sequence_labeling.py | 144 ++++++++++++++++++++++++---- 1 file changed, 126 insertions(+), 18 deletions(-) diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index 0c89793463..9a30c05ee4 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -4974,7 +4974,9 @@ def _write_instances(cls, version, base_path, split, data): class NER_NOISEBENCH(ColumnCorpus): - label_url = "https://github.com/elenamer/NoiseBench/tree/8a32da1e06f2239afc95b3f9dc5274abc25cc46d/data/annotations" + label_url = "https://raw.githubusercontent.com/elenamer/NoiseBench/main/data/annotations/" + SAVE_TRAINDEV_FILE = False + def __init__( self, noise: str = None, @@ -5000,26 +5002,26 @@ def __init__( "Please choose a valid version" ) - base_path = self._set_path(base_path) + self._set_path(base_path) filename = 'clean' if noise in ['clean',None] else f'noise_{noise}' - file_paths = [base_path / f'{filename}.train', base_path / f'{filename}.dev', base_path / 'clean.test'] + file_paths = [self.base_path / f'{filename}.train', self.base_path / f'{filename}.dev', self.base_path / 'clean.test'] files_exist = [path.exists() for path in file_paths] - cleanconll_base_path = flair.cache_root / "datasets" / "cleanconll" + self.cleanconll_base_path = flair.cache_root / "datasets" / "cleanconll" if not all(files_exist): - cached_path(f"{self.label_url}/{filename}.traindev", base_path / 'annotations_only') - cached_path(f"{self.label_url}/index.txt", base_path / 'annotations_only') + 
cached_path(f"{self.label_url}/{filename}.traindev", self.base_path / 'annotations_only') + cached_path(f"{self.label_url}/index.txt", self.base_path / 'annotations_only') - cleanconll_files_exist = [Path(f'{cleanconll_base_path}/cleanconll.{split}').exists() for split in ['train','dev','test']] + cleanconll_files_exist = [Path(f'{self.cleanconll_base_path}/cleanconll.{split}').exists() for split in ['train','dev','test']] if not all(cleanconll_files_exist): # download cleanconll - clone = f"git clone https://github.com/flairNLP/CleanCoNLL.git {cleanconll_base_path}/CleanCoNLL" + clone = f"git clone https://github.com/flairNLP/CleanCoNLL.git {self.cleanconll_base_path}/CleanCoNLL" os.system(clone) # Cloning cwd = os.getcwd() - os.chdir(f"{cleanconll_base_path}/CleanCoNLL") + os.chdir(f"{self.cleanconll_base_path}/CleanCoNLL") chmod = f"chmod u+x create_cleanconll_from_conll03.sh" os.system(chmod) create = f"bash create_cleanconll_from_conll03.sh" @@ -5027,29 +5029,135 @@ def __init__( os.system(create) os.chdir(cwd) - shutil.move(f'{cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.train', cleanconll_base_path) - shutil.move(f'{cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.dev', cleanconll_base_path) - shutil.move(f'{cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.test', cleanconll_base_path) + shutil.move(f'{self.cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.train', self.cleanconll_base_path) + shutil.move(f'{self.cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.dev', self.cleanconll_base_path) + shutil.move(f'{self.cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.test', self.cleanconll_base_path) + + shutil.rmtree(self.cleanconll_base_path / 'CleanCoNLL') # create dataset files from index and train/test splits + self.generate_data_files(filename,) super().__init__( - data_folder=base_path, + data_folder=self.base_path, train_file=f"{filename}.train", dev_file=f"{filename}.dev", test_file=f"clean.test", # test set is always clean (without noise) column_format={0: "text", 1: "ner"}, in_memory=in_memory, column_delimiter="\t", - document_separator = "-DOCSTART-", + document_separator_token = "-DOCSTART-", **corpusargs, ) - @classmethod - def _set_path(cls, base_path) -> Path: - base_path = flair.cache_root / "datasets" / "noisebench" if not base_path else Path(base_path) - return base_path + def _set_path(self, base_path) -> Path: + self.base_path = flair.cache_root / "datasets" / "noisebench" if not base_path else Path(base_path) + + @staticmethod + def read_column_file(filename): + raw = open(filename, 'r', errors='replace') + raw = raw.readlines() + all_x = [] + point = [] + for line in raw: + if '\t' in line.strip(): + stripped_line = line.strip().split('\t') + else: + stripped_line = line.strip().split(' ') + point.append(stripped_line) + if line.strip() == '': + if len(point[:-1]) > 0: + all_x.append(point[:-1]) + point = [] + + if len(point) > 0: + all_x.append(point) + + all_x = all_x + return all_x + + @staticmethod + def save_to_column_file(filename, list): + with open(filename, "w") as f: + for sentence in list: + for token in sentence: + f.write('\t'.join(token)) + f.write('\n') + f.write('\n') + + def _create_train_dev_splits(self, filename, all_sentences = None, datestring = '1996-08-24'): + if not all_sentences: + all_sentences = self.read_column_file(filename) + + train_sentences = [] + dev_sentences = [] + for i, s in enumerate(all_sentences): + if 'DOCSTART' in s[0][0]: + assert i+3 < 
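The read_column_file/save_to_column_file helpers above define the on-disk format: one token-label row per line (tab- or space-separated), with blank lines separating sentences. A self-contained toy parse to make the resulting data shape concrete (the sample strings are invented):

# Toy version of the column format: "<token>\t<label>" rows,
# blank line between sentences.
conll_text = "U.N.\tB-ORG\nofficial\tO\n\nPeter\tB-PER\n"

sentences, current = [], []
for line in conll_text.splitlines():
    if line.strip():
        current.append(line.split("\t"))
    elif current:
        sentences.append(current)
        current = []
if current:
    sentences.append(current)

assert sentences == [[["U.N.", "B-ORG"], ["official", "O"]], [["Peter", "B-PER"]]]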
len(all_sentences) # last document is too short + + # news date is usually in 3rd or 4th sentence of each article + if datestring in all_sentences[i+2][-1][0] or datestring in all_sentences[i+3][-1][0]: + save_to_dev = True + else: + save_to_dev = False + + if save_to_dev: + dev_sentences.append(s) + else: + train_sentences.append(s) + + self.save_to_column_file(os.sep.join(filename.split(os.sep)[:-1])+os.sep+filename.split(os.sep)[-1].split('.')[0]+'.dev',dev_sentences) + self.save_to_column_file(os.sep.join(filename.split(os.sep)[:-1])+os.sep+filename.split(os.sep)[-1].split('.')[0]+'.train',train_sentences) + + + def _merge_tokens_labels(self, corpus, all_clean_sentences, token_indices): + # generate NoiseBench dataset variants, given CleanCoNLL, noisy label files and index file + + noisy_labels = self.read_column_file(os.path.join(self.base_path,'annotations_only',f'{corpus}.traindev')) + #print(noisy_labels) + #print(token_indices) + for index, sentence in zip(token_indices, noisy_labels): + + if index.strip() == 'docstart': + assert len(sentence) == 1 + sentence[0][0] = '-DOCSTART-' + continue + clean_sentence = all_clean_sentences[int(index.strip())] + + assert len(clean_sentence) == len(sentence) # this means indexing is wrong + + for token, label in zip(clean_sentence, sentence): + label[0] = token[0] # token[0] -> text, token[1] -> BIO label + if self.SAVE_TRAINDEV_FILE: + self.save_to_column_file(os.path.join(self.base_path,f'{corpus}.traindev'),noisy_labels) + return noisy_labels + + + def generate_data_files(self, filename): + + index_file = open(os.path.join(self.base_path,'annotations_only','index.txt')) + token_indices = index_file.readlines() + + all_clean_sentences = self.read_column_file(os.path.join(self.cleanconll_base_path,'cleanconll.train')) + + #os.makedirs(os.path.join('data','noisebench'), exist_ok=True) + + noisy_sentences = self._merge_tokens_labels(filename, all_clean_sentences, token_indices) + self._create_train_dev_splits(all_sentences=noisy_sentences,filename=os.path.join(self.base_path,f'{filename}.traindev')) + + # copy test set + all_clean_test_sentences = self.read_column_file(os.path.join(self.cleanconll_base_path,'cleanconll.test')) + + test_sentences = [] + for s in all_clean_test_sentences: + new_s = [] + for token in s: + new_s.append([token[0],token[4]]) + test_sentences.append(new_s) + + self.save_to_column_file(os.path.join(self.base_path,f'clean.test'),test_sentences) + class MASAKHA_POS(MultiCorpus): def __init__( self, From 0e8183d7a75276e6c9188d0d91f6211e961c54cb Mon Sep 17 00:00:00 2001 From: Elena Merdjanovska Date: Wed, 24 Jul 2024 14:57:35 +0200 Subject: [PATCH 005/333] fix formatting --- flair/datasets/__init__.py | 3 +- flair/datasets/sequence_labeling.py | 172 +++++++++++++++------------- 2 files changed, 97 insertions(+), 78 deletions(-) diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py index 5a7ccf7e77..b38d1bd761 100644 --- a/flair/datasets/__init__.py +++ b/flair/datasets/__init__.py @@ -197,7 +197,6 @@ NER_ENGLISH_WIKIGOLD, NER_ENGLISH_WNUT_2020, NER_ESTONIAN_NOISY, - NER_NOISEBENCH, NER_FINNISH, NER_GERMAN_BIOFID, NER_GERMAN_EUROPARL, @@ -217,6 +216,7 @@ NER_MULTI_WIKINER, NER_MULTI_XTREME, NER_NERMUD, + NER_NOISEBENCH, NER_SWEDISH, NER_TURKU, NER_UKRAINIAN, @@ -495,6 +495,7 @@ "NER_GERMAN_MOBIE", "NER_GERMAN_POLITICS", "NER_HIPE_2022", + "NER_NOISEBENCH", "NER_HUNGARIAN", "NER_ICDAR_EUROPEANA", "NER_ICELANDIC", diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index 
9a30c05ee4..7cc06c8a0a 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -4976,10 +4976,10 @@ def _write_instances(cls, version, base_path, split, data): class NER_NOISEBENCH(ColumnCorpus): label_url = "https://raw.githubusercontent.com/elenamer/NoiseBench/main/data/annotations/" SAVE_TRAINDEV_FILE = False - + def __init__( self, - noise: str = None, + noise: str = "clean", base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, **corpusargs, @@ -4997,78 +4997,90 @@ def __init__( in_memory (bool): If True the dataset is kept in memory achieving speedups in training. **corpusargs: The arguments propagated to :meth:'flair.datasets.ColumnCorpus.__init__'. """ - if noise not in ['clean', None, 'crowd','crowdbest','expert','distant','weak','llm']: - raise Exception( - "Please choose a valid version" - ) + if noise not in ["clean", "crowd", "crowdbest", "expert", "distant", "weak", "llm"]: + raise Exception("Please choose a valid version") self._set_path(base_path) - filename = 'clean' if noise in ['clean',None] else f'noise_{noise}' - file_paths = [self.base_path / f'{filename}.train', self.base_path / f'{filename}.dev', self.base_path / 'clean.test'] + filename = "clean" if noise == "clean" else f"noise_{noise}" + file_paths = [ + self.base_path / f"{filename}.train", + self.base_path / f"{filename}.dev", + self.base_path / "clean.test", + ] files_exist = [path.exists() for path in file_paths] self.cleanconll_base_path = flair.cache_root / "datasets" / "cleanconll" if not all(files_exist): - cached_path(f"{self.label_url}/{filename}.traindev", self.base_path / 'annotations_only') - cached_path(f"{self.label_url}/index.txt", self.base_path / 'annotations_only') - - cleanconll_files_exist = [Path(f'{self.cleanconll_base_path}/cleanconll.{split}').exists() for split in ['train','dev','test']] + cached_path(f"{self.label_url}/{filename}.traindev", self.base_path / "annotations_only") + cached_path(f"{self.label_url}/index.txt", self.base_path / "annotations_only") + + cleanconll_files_exist = [ + Path(f"{self.cleanconll_base_path}/cleanconll.{split}").exists() for split in ["train", "dev", "test"] + ] if not all(cleanconll_files_exist): # download cleanconll - clone = f"git clone https://github.com/flairNLP/CleanCoNLL.git {self.cleanconll_base_path}/CleanCoNLL" - os.system(clone) # Cloning + clone = f"git clone https://github.com/flairNLP/CleanCoNLL.git {self.cleanconll_base_path}/CleanCoNLL" + os.system(clone) # Cloning cwd = os.getcwd() os.chdir(f"{self.cleanconll_base_path}/CleanCoNLL") - chmod = f"chmod u+x create_cleanconll_from_conll03.sh" + chmod = "chmod u+x create_cleanconll_from_conll03.sh" os.system(chmod) - create = f"bash create_cleanconll_from_conll03.sh" - + create = "bash create_cleanconll_from_conll03.sh" + os.system(create) os.chdir(cwd) - - shutil.move(f'{self.cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.train', self.cleanconll_base_path) - shutil.move(f'{self.cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.dev', self.cleanconll_base_path) - shutil.move(f'{self.cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.test', self.cleanconll_base_path) - - shutil.rmtree(self.cleanconll_base_path / 'CleanCoNLL') + + shutil.move( + f"{self.cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.train", + self.cleanconll_base_path, + ) + shutil.move( + f"{self.cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.dev", self.cleanconll_base_path + ) + shutil.move( + 
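A note on flair's download helper used above: cached_path fetches a remote file into the given cache directory once and returns the local path on later calls, which is why the surrounding code can call it unconditionally. A small sketch with the annotation URL from this patch:

import flair
from flair.file_utils import cached_path

label_url = "https://raw.githubusercontent.com/elenamer/NoiseBench/main/data/annotations"
cache_dir = flair.cache_root / "datasets" / "noisebench" / "annotations_only"

# First call downloads; later calls return the cached local path.
local_file = cached_path(f"{label_url}/index.txt", cache_dir)
print(local_file)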
f"{self.cleanconll_base_path}/CleanCoNLL/data/cleanconll/cleanconll.test", self.cleanconll_base_path + ) + + shutil.rmtree(self.cleanconll_base_path / "CleanCoNLL") # create dataset files from index and train/test splits - self.generate_data_files(filename,) + self.generate_data_files( + filename, + ) super().__init__( data_folder=self.base_path, train_file=f"{filename}.train", dev_file=f"{filename}.dev", - test_file=f"clean.test", # test set is always clean (without noise) + test_file="clean.test", # test set is always clean (without noise) column_format={0: "text", 1: "ner"}, in_memory=in_memory, column_delimiter="\t", - document_separator_token = "-DOCSTART-", + document_separator_token="-DOCSTART-", **corpusargs, ) - def _set_path(self, base_path) -> Path: + def _set_path(self, base_path): self.base_path = flair.cache_root / "datasets" / "noisebench" if not base_path else Path(base_path) - - @staticmethod + + @staticmethod def read_column_file(filename): - raw = open(filename, 'r', errors='replace') - raw = raw.readlines() - all_x = [] - point = [] - for line in raw: - if '\t' in line.strip(): - stripped_line = line.strip().split('\t') - else: - stripped_line = line.strip().split(' ') - point.append(stripped_line) - if line.strip() == '': - if len(point[:-1]) > 0: - all_x.append(point[:-1]) - point = [] + with open(filename, errors="replace") as file: + lines = file.readlines() + all_x = [] + point = [] + for line in lines: + if "\t" in line.strip(): + stripped_line = line.strip().split("\t") if "\t" in line.strip() else line.strip().split(" ") + + point.append(stripped_line) + if line.strip() == "": + if len(point[:-1]) > 0: + all_x.append(point[:-1]) + point = [] if len(point) > 0: all_x.append(point) @@ -5081,22 +5093,22 @@ def save_to_column_file(filename, list): with open(filename, "w") as f: for sentence in list: for token in sentence: - f.write('\t'.join(token)) - f.write('\n') - f.write('\n') + f.write("\t".join(token)) + f.write("\n") + f.write("\n") - def _create_train_dev_splits(self, filename, all_sentences = None, datestring = '1996-08-24'): + def _create_train_dev_splits(self, filename, all_sentences=None, datestring="1996-08-24"): if not all_sentences: all_sentences = self.read_column_file(filename) - train_sentences = [] - dev_sentences = [] + train_sentences = [] + dev_sentences = [] for i, s in enumerate(all_sentences): - if 'DOCSTART' in s[0][0]: - assert i+3 < len(all_sentences) # last document is too short - + if "DOCSTART" in s[0][0]: + assert i + 3 < len(all_sentences) # last document is too short + # news date is usually in 3rd or 4th sentence of each article - if datestring in all_sentences[i+2][-1][0] or datestring in all_sentences[i+3][-1][0]: + if datestring in all_sentences[i + 2][-1][0] or datestring in all_sentences[i + 3][-1][0]: save_to_dev = True else: save_to_dev = False @@ -5106,57 +5118,63 @@ def _create_train_dev_splits(self, filename, all_sentences = None, datestring = else: train_sentences.append(s) - self.save_to_column_file(os.sep.join(filename.split(os.sep)[:-1])+os.sep+filename.split(os.sep)[-1].split('.')[0]+'.dev',dev_sentences) - self.save_to_column_file(os.sep.join(filename.split(os.sep)[:-1])+os.sep+filename.split(os.sep)[-1].split('.')[0]+'.train',train_sentences) - + self.save_to_column_file( + os.sep.join(filename.split(os.sep)[:-1]) + os.sep + filename.split(os.sep)[-1].split(".")[0] + ".dev", + dev_sentences, + ) + self.save_to_column_file( + os.sep.join(filename.split(os.sep)[:-1]) + os.sep + 
filename.split(os.sep)[-1].split(".")[0] + ".train", + train_sentences, + ) def _merge_tokens_labels(self, corpus, all_clean_sentences, token_indices): # generate NoiseBench dataset variants, given CleanCoNLL, noisy label files and index file - noisy_labels = self.read_column_file(os.path.join(self.base_path,'annotations_only',f'{corpus}.traindev')) - #print(noisy_labels) - #print(token_indices) + noisy_labels = self.read_column_file(os.path.join(self.base_path, "annotations_only", f"{corpus}.traindev")) + # print(noisy_labels) + # print(token_indices) for index, sentence in zip(token_indices, noisy_labels): - if index.strip() == 'docstart': + if index.strip() == "docstart": assert len(sentence) == 1 - sentence[0][0] = '-DOCSTART-' + sentence[0][0] = "-DOCSTART-" continue clean_sentence = all_clean_sentences[int(index.strip())] - assert len(clean_sentence) == len(sentence) # this means indexing is wrong + assert len(clean_sentence) == len(sentence) # this means indexing is wrong for token, label in zip(clean_sentence, sentence): - label[0] = token[0] # token[0] -> text, token[1] -> BIO label + label[0] = token[0] # token[0] -> text, token[1] -> BIO label if self.SAVE_TRAINDEV_FILE: - self.save_to_column_file(os.path.join(self.base_path,f'{corpus}.traindev'),noisy_labels) + self.save_to_column_file(os.path.join(self.base_path, f"{corpus}.traindev"), noisy_labels) return noisy_labels - def generate_data_files(self, filename): - index_file = open(os.path.join(self.base_path,'annotations_only','index.txt')) - token_indices = index_file.readlines() + with open(os.path.join(self.base_path, "annotations_only", "index.txt")) as index_file: + token_indices = index_file.readlines() - all_clean_sentences = self.read_column_file(os.path.join(self.cleanconll_base_path,'cleanconll.train')) + all_clean_sentences = self.read_column_file(os.path.join(self.cleanconll_base_path, "cleanconll.train")) - #os.makedirs(os.path.join('data','noisebench'), exist_ok=True) + # os.makedirs(os.path.join('data','noisebench'), exist_ok=True) - noisy_sentences = self._merge_tokens_labels(filename, all_clean_sentences, token_indices) - self._create_train_dev_splits(all_sentences=noisy_sentences,filename=os.path.join(self.base_path,f'{filename}.traindev')) - + noisy_sentences = self._merge_tokens_labels(filename, all_clean_sentences, token_indices) + self._create_train_dev_splits( + all_sentences=noisy_sentences, filename=os.path.join(self.base_path, f"{filename}.traindev") + ) + + # copy test set + all_clean_test_sentences = self.read_column_file(os.path.join(self.cleanconll_base_path, "cleanconll.test")) - # copy test set - all_clean_test_sentences = self.read_column_file(os.path.join(self.cleanconll_base_path,'cleanconll.test')) - test_sentences = [] for s in all_clean_test_sentences: new_s = [] for token in s: - new_s.append([token[0],token[4]]) + new_s.append([token[0], token[4]]) test_sentences.append(new_s) - self.save_to_column_file(os.path.join(self.base_path,f'clean.test'),test_sentences) + self.save_to_column_file(os.path.join(self.base_path, "clean.test"), test_sentences) + class MASAKHA_POS(MultiCorpus): def __init__( From 588279f00bdad2302630fb26a8bf3959dacb610d Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 6 Sep 2024 14:54:36 +0200 Subject: [PATCH 006/333] fix T5 tokenizer loading --- flair/embeddings/transformer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py index f3492178f9..1e88787deb 100644 --- 
a/flair/embeddings/transformer.py +++ b/flair/embeddings/transformer.py @@ -26,6 +26,7 @@ LayoutLMv2FeatureExtractor, PretrainedConfig, PreTrainedTokenizer, + T5TokenizerFast, ) from transformers.tokenization_utils_base import LARGE_INTEGER from transformers.utils import PaddingStrategy @@ -444,7 +445,7 @@ def _tokenizer_from_bytes(cls, zip_data: BytesIO) -> PreTrainedTokenizer: zip_obj = zipfile.ZipFile(zip_data) with tempfile.TemporaryDirectory() as temp_dir: zip_obj.extractall(temp_dir) - return AutoTokenizer.from_pretrained(temp_dir, add_prefix_space=True) + return AutoTokenizer.from_pretrained(temp_dir) @classmethod def _feature_extractor_from_bytes(cls, zip_data: Optional[BytesIO]) -> Optional[FeatureExtractionMixin]: @@ -458,7 +459,13 @@ def _feature_extractor_from_bytes(cls, zip_data: Optional[BytesIO]) -> Optional[ def __tokenizer_bytes(self): with tempfile.TemporaryDirectory() as temp_dir: files = list(self.tokenizer.save_pretrained(temp_dir)) - if self.tokenizer.is_fast and self.tokenizer.slow_tokenizer_class: + if ( + self.tokenizer.is_fast + and self.tokenizer.slow_tokenizer_class + and not isinstance( + self.tokenizer, T5TokenizerFast + ) # do not remove slow files for T5, as it can only be created from slow tokenizer with prefix space + ): vocab_files = self.tokenizer.slow_tokenizer_class.vocab_files_names.values() files = [f for f in files if all(v not in f for v in vocab_files)] zip_data = BytesIO() From caea8bb58467f9b84b4025d21906c8af9b1fdb64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Karlo=20Do=C5=A1ilovi=C4=87?= Date: Sat, 19 Oct 2024 12:11:28 +0200 Subject: [PATCH 007/333] Initial commit. --- flair/embeddings/transformer.py | 27 ++++++++----------- .../test_transformer_document_embeddings.py | 19 +++++++++++++ 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py index 1e88787deb..3205aeabc6 100644 --- a/flair/embeddings/transformer.py +++ b/flair/embeddings/transformer.py @@ -8,18 +8,18 @@ from abc import abstractmethod from io import BytesIO from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast +from typing import Any, cast, Dict, List, Literal, Optional, Tuple, Type, Union import torch import transformers from packaging.version import Version from torch.jit import ScriptModule from transformers import ( - CONFIG_MAPPING, AutoConfig, AutoFeatureExtractor, AutoModel, AutoTokenizer, + CONFIG_MAPPING, FeatureExtractionMixin, LayoutLMTokenizer, LayoutLMTokenizerFast, @@ -32,13 +32,8 @@ from transformers.utils import PaddingStrategy import flair -from flair.data import Sentence, Token, log -from flair.embeddings.base import ( - DocumentEmbeddings, - Embeddings, - TokenEmbeddings, - register_embeddings, -) +from flair.data import log, Sentence, Token +from flair.embeddings.base import DocumentEmbeddings, Embeddings, register_embeddings, TokenEmbeddings SENTENCE_BOUNDARY_TAG: str = "[FLERT]" @@ -198,7 +193,7 @@ def fill_mean_token_embeddings( @torch.jit.script_if_tracing -def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor): +def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor: result = torch.zeros( sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype ) @@ -206,9 +201,11 @@ def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths for i in 
torch.arange(sentence_hidden_states.shape[0]): result[i] = sentence_hidden_states[i, : sentence_lengths[i]].mean(dim=0) + return result + @torch.jit.script_if_tracing -def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor): +def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor: result = torch.zeros( sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype ) @@ -216,6 +213,8 @@ def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: for i in torch.arange(sentence_hidden_states.shape[0]): result[i], _ = sentence_hidden_states[i, : sentence_lengths[i]].max(dim=0) + return result + def _legacy_reconstruct_word_ids( embedding: "TransformerBaseEmbeddings", flair_tokens: List[List[str]] @@ -1127,11 +1126,7 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool: if peft_config is not None: # add adapters for finetuning try: - from peft import ( - TaskType, - get_peft_model, - prepare_model_for_kbit_training, - ) + from peft import get_peft_model, prepare_model_for_kbit_training, TaskType except ImportError: log.error("You cannot use the PEFT finetuning without peft being installed") raise diff --git a/tests/embeddings/test_transformer_document_embeddings.py b/tests/embeddings/test_transformer_document_embeddings.py index 1a65a96fdb..f253d2d8d8 100644 --- a/tests/embeddings/test_transformer_document_embeddings.py +++ b/tests/embeddings/test_transformer_document_embeddings.py @@ -1,7 +1,10 @@ +import pytest + from flair.data import Dictionary from flair.embeddings import TransformerDocumentEmbeddings from flair.models import TextClassifier from flair.nn import Classifier + from tests.embedding_test_utils import BaseEmbeddingsTest @@ -37,3 +40,19 @@ def test_if_loaded_embeddings_have_all_attributes(tasks_base_path): # check that context_length and use_context_separator is the same for both assert model.embeddings.context_length == loaded_single_task.embeddings.context_length assert model.embeddings.use_context_separator == loaded_single_task.embeddings.use_context_separator + + +@pytest.mark.parametrize("cls_pooling", ["cls", "mean", "max"]) +def test_cls_pooling(cls_pooling): + from flair.data import Sentence + from flair.embeddings import TransformerDocumentEmbeddings + + embeddings = TransformerDocumentEmbeddings( + model="xlm-roberta-base", + layers="-1", + cls_pooling=cls_pooling, + allow_long_sentences=True, + ) + sentence = Sentence("Today is a good day.") + embeddings.embed(sentence) + assert sentence.embedding is not None From 43fc96a5370fa0260c62018d0bd4f87579a2e987 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Karlo=20Do=C5=A1ilovi=C4=87?= Date: Sat, 19 Oct 2024 12:16:52 +0200 Subject: [PATCH 008/333] Remove imports from the test function. 
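The document_mean_pooling/document_max_pooling fixes above make the helpers actually return their result; both reduce a padded batch of hidden states to one vector per sentence while ignoring padding positions. A toy demonstration of that masking logic in plain torch (shapes and values invented for illustration):

import torch

# Batch of 2 "sentences", max length 3, hidden size 2; the second
# sentence has only 2 real tokens, its last position is padding.
hidden = torch.tensor(
    [
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[10.0, 10.0], [20.0, 20.0], [99.0, 99.0]],  # last row is padding
    ]
)
lengths = torch.tensor([3, 2])

mean_pooled = torch.stack([hidden[i, : lengths[i]].mean(dim=0) for i in range(len(lengths))])
max_pooled = torch.stack([hidden[i, : lengths[i]].max(dim=0).values for i in range(len(lengths))])

print(mean_pooled)  # [[3., 4.], [15., 15.]] -- padding excluded from the mean
print(max_pooled)   # [[5., 6.], [20., 20.]]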
--- tests/embeddings/test_transformer_document_embeddings.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/embeddings/test_transformer_document_embeddings.py b/tests/embeddings/test_transformer_document_embeddings.py index f253d2d8d8..bf202110b1 100644 --- a/tests/embeddings/test_transformer_document_embeddings.py +++ b/tests/embeddings/test_transformer_document_embeddings.py @@ -1,6 +1,6 @@ import pytest -from flair.data import Dictionary +from flair.data import Dictionary, Sentence from flair.embeddings import TransformerDocumentEmbeddings from flair.models import TextClassifier from flair.nn import Classifier @@ -44,9 +44,6 @@ def test_if_loaded_embeddings_have_all_attributes(tasks_base_path): @pytest.mark.parametrize("cls_pooling", ["cls", "mean", "max"]) def test_cls_pooling(cls_pooling): - from flair.data import Sentence - from flair.embeddings import TransformerDocumentEmbeddings - embeddings = TransformerDocumentEmbeddings( model="xlm-roberta-base", layers="-1", From fdb49952c61abe3f671b0a39d1eb2685a8aff9f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Karlo=20Do=C5=A1ilovi=C4=87?= Date: Sat, 19 Oct 2024 13:12:25 +0200 Subject: [PATCH 009/333] Refactor cls pooling into a function. --- flair/embeddings/transformer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py index 3205aeabc6..8a1e76cd72 100644 --- a/flair/embeddings/transformer.py +++ b/flair/embeddings/transformer.py @@ -191,6 +191,10 @@ def fill_mean_token_embeddings( return all_token_embeddings +@torch.jit.script_if_tracing +def document_cls_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor: + return sentence_hidden_states[torch.arange(sentence_hidden_states.shape[0]), sentence_lengths - 1] + @torch.jit.script_if_tracing def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor: @@ -1436,9 +1440,7 @@ def forward( else: assert sub_token_lengths is not None if self.cls_pooling == "cls": - document_embeddings = sentence_hidden_states[ - torch.arange(sentence_hidden_states.shape[0]), sub_token_lengths - 1 - ] + document_embeddings = document_cls_pooling(sentence_hidden_states, sub_token_lengths) elif self.cls_pooling == "mean": document_embeddings = document_mean_pooling(sentence_hidden_states, sub_token_lengths) elif self.cls_pooling == "max": From 4202333d43536e3329cd44810f39b2b5763e6d24 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 25 Oct 2024 14:51:37 +0200 Subject: [PATCH 010/333] drop python 3.8 --- .github/workflows/ci.yml | 4 +- .github/workflows/publish-docs.yml | 2 +- CONTRIBUTING.md | 4 +- README.md | 2 +- docs/contributing/local_development.md | 4 +- docs/tutorial/intro.md | 2 +- flair/class_utils.py | 11 +- flair/data.py | 167 ++++++------- flair/datasets/base.py | 12 +- flair/datasets/biomedical.py | 230 +++++++++--------- flair/datasets/document_classification.py | 132 +++++----- flair/datasets/entity_linking.py | 109 +++++---- flair/datasets/ocr.py | 8 +- flair/datasets/relation_extraction.py | 30 +-- flair/datasets/sequence_labeling.py | 207 ++++++++-------- flair/datasets/text_image.py | 11 +- flair/datasets/text_text.py | 48 ++-- flair/datasets/treebanks.py | 10 +- flair/embeddings/base.py | 24 +- flair/embeddings/document.py | 62 ++--- flair/embeddings/image.py | 16 +- flair/embeddings/legacy.py | 22 +- flair/embeddings/token.py | 67 +++-- flair/embeddings/transformer.py | 70 
+++--- flair/file_utils.py | 11 +- flair/inference_utils.py | 4 +- flair/models/entity_linker_model.py | 27 +- flair/models/entity_mention_linking.py | 95 ++++---- flair/models/language_model.py | 14 +- flair/models/lemmatizer_model.py | 36 +-- flair/models/multitask_model.py | 32 +-- flair/models/pairwise_classification_model.py | 5 +- flair/models/pairwise_regression_model.py | 29 +-- flair/models/prefixed_tagger.py | 36 +-- flair/models/regexp_tagger.py | 22 +- flair/models/relation_classifier_model.py | 67 +++-- flair/models/relation_extractor_model.py | 10 +- flair/models/sequence_tagger_model.py | 32 +-- flair/models/sequence_tagger_utils/viterbi.py | 12 +- flair/models/tars_model.py | 44 ++-- flair/models/text_classification_model.py | 6 +- flair/models/text_regression_model.py | 28 +-- flair/models/triple_classification_model.py | 5 +- flair/models/word_tagger_model.py | 8 +- flair/nn/decoder.py | 6 +- flair/nn/model.py | 70 +++--- flair/nn/multitask.py | 17 +- flair/samplers.py | 3 +- flair/splitter.py | 18 +- flair/tokenization.py | 34 +-- flair/trainers/language_model_trainer.py | 17 +- flair/trainers/plugins/base.py | 19 +- .../plugins/functional/anneal_on_plateau.py | 4 +- .../plugins/functional/checkpoints.py | 4 +- .../plugins/functional/linear_scheduler.py | 4 +- .../functional/reduce_transformer_vocab.py | 3 +- .../plugins/functional/weight_extractor.py | 4 +- flair/trainers/plugins/loggers/log_file.py | 4 +- flair/trainers/plugins/loggers/loss_file.py | 8 +- .../plugins/loggers/metric_history.py | 7 +- flair/trainers/plugins/loggers/tensorboard.py | 4 +- flair/trainers/plugins/loggers/wandb.py | 4 +- flair/trainers/plugins/metric_records.py | 5 +- flair/trainers/trainer.py | 32 +-- flair/training_utils.py | 18 +- flair/visual/ner_html.py | 4 +- flair/visual/training_curves.py | 8 +- flair/visual/tree_printer.py | 6 +- pyproject.toml | 4 +- requirements-dev.txt | 2 +- resources/docs/EXPERIMENTS.md | 12 +- resources/docs/HUNFLAIR2.md | 2 +- .../KOR_docs/TUTORIAL_7_TRAINING_A_MODEL.md | 2 +- .../KOR_docs/TUTORIAL_8_MODEL_OPTIMIZATION.md | 2 +- setup.py | 2 +- tests/embedding_test_utils.py | 16 +- ...test_document_transform_word_embeddings.py | 8 +- tests/embeddings/test_word_embeddings.py | 4 +- tests/model_test_utils.py | 10 +- tests/models/test_relation_classifier.py | 12 +- tests/test_datasets_biomedical.py | 4 +- tests/test_labels.py | 38 ++- tests/test_tokenize_sentence.py | 4 +- 83 files changed, 1084 insertions(+), 1118 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5b3633d93e..5633ac850f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,11 +12,11 @@ jobs: FLAIR_CACHE_ROOT: ./cache/flair steps: - uses: actions/checkout@v3 - - name: Set up Python 3.8 + - name: Set up Python 3.9 id: setup-python uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: 3.9 - name: Install Torch cpu run: pip install torch --index-url https://download.pytorch.org/whl/cpu - name: Install Flair dependencies diff --git a/.github/workflows/publish-docs.yml b/.github/workflows/publish-docs.yml index 24a424adba..f752b1324c 100644 --- a/.github/workflows/publish-docs.yml +++ b/.github/workflows/publish-docs.yml @@ -10,7 +10,7 @@ jobs: name: Build the docs using Sphinx and push to gh-pages runs-on: ubuntu-latest env: - python-version: 3.8 + python-version: 3.9 steps: - name: Checkout code uses: actions/checkout@v3 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b40ddfe77e..b44927f17a 100644 --- a/CONTRIBUTING.md 
+++ b/CONTRIBUTING.md @@ -24,8 +24,8 @@ the code should hopefully be easy. ### Setup -Flair requires python-3.8 or higher. To make sure your code also runs on the oldest supported -python version, it is recommended to use python-3.8.x for flair development. +Flair requires python-3.9 or higher. To make sure your code also runs on the oldest supported +python version, it is recommended to use python-3.9.x for flair development. Create a python environment of your preference and run: ```bash diff --git a/README.md b/README.md index 92502c972b..fdf4130124 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ In your favorite virtual environment, simply do: pip install flair ``` -Flair requires Python 3.8+. +Flair requires Python 3.9+. ### Example 1: Tag Entities in Text diff --git a/docs/contributing/local_development.md b/docs/contributing/local_development.md index 87439439f2..9c7413703e 100644 --- a/docs/contributing/local_development.md +++ b/docs/contributing/local_development.md @@ -6,8 +6,8 @@ the code should hopefully be easy. ## Setup -Flair requires python-3.8 or higher. To make sure our code also runs on the oldest supported -python version, it is recommended to use python-3.8.x for flair development. +Flair requires python-3.9 or higher. To make sure our code also runs on the oldest supported +python version, it is recommended to use python-3.9.x for flair development. Create a python environment of your preference and run: ```bash diff --git a/docs/tutorial/intro.md b/docs/tutorial/intro.md index e652583f76..b8af9b5667 100644 --- a/docs/tutorial/intro.md +++ b/docs/tutorial/intro.md @@ -16,7 +16,7 @@ In your favorite virtual environment, simply do: pip install flair ``` -Flair requires Python 3.8+. +Flair requires Python 3.9+. ## Example 1: Tag Entities in Text diff --git a/flair/class_utils.py b/flair/class_utils.py index 9aa95cd1ee..7e01f4ff42 100644 --- a/flair/class_utils.py +++ b/flair/class_utils.py @@ -1,12 +1,13 @@ import importlib import inspect +from collections.abc import Iterable from types import ModuleType -from typing import Any, Iterable, List, Optional, Type, TypeVar, Union, overload +from typing import Any, Optional, TypeVar, Union, overload T = TypeVar("T") -def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]: +def get_non_abstract_subclasses(cls: type[T]) -> Iterable[type[T]]: for subclass in cls.__subclasses__(): yield from get_non_abstract_subclasses(subclass) if inspect.isabstract(subclass): @@ -14,7 +15,7 @@ def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]: yield subclass -def get_state_subclass_by_name(cls: Type[T], cls_name: Optional[str]) -> Type[T]: +def get_state_subclass_by_name(cls: type[T], cls_name: Optional[str]) -> type[T]: for sub_cls in get_non_abstract_subclasses(cls): if sub_cls.__name__ == cls_name: return sub_cls @@ -26,12 +27,12 @@ def lazy_import(group: str, module: str, first_symbol: None) -> ModuleType: ... @overload -def lazy_import(group: str, module: str, first_symbol: str, *symbols: str) -> List[Any]: ... +def lazy_import(group: str, module: str, first_symbol: str, *symbols: str) -> list[Any]: ... 
def lazy_import( group: str, module: str, first_symbol: Optional[str] = None, *symbols: str -) -> Union[List[Any], ModuleType]: +) -> Union[list[Any], ModuleType]: try: imported_module = importlib.import_module(module) except ImportError: diff --git a/flair/data.py b/flair/data.py index 69d85baf92..56622b249c 100644 --- a/flair/data.py +++ b/flair/data.py @@ -4,10 +4,11 @@ import typing from abc import ABC, abstractmethod from collections import Counter, defaultdict +from collections.abc import Iterable from operator import itemgetter from os import PathLike from pathlib import Path -from typing import Any, DefaultDict, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union, cast +from typing import Any, NamedTuple, Optional, Union, cast import torch from deprecated.sphinx import deprecated @@ -52,8 +53,8 @@ class Dictionary: def __init__(self, add_unk: bool = True) -> None: # init dictionaries - self.item2idx: Dict[bytes, int] = {} - self.idx2item: List[bytes] = [] + self.item2idx: dict[bytes, int] = {} + self.idx2item: list[bytes] = [] self.add_unk = add_unk self.multi_label = False self.span_labels = False @@ -101,7 +102,7 @@ def get_idx_for_item(self, item: str) -> int: ) raise IndexError - def get_idx_for_items(self, items: List[str]) -> List[int]: + def get_idx_for_items(self, items: list[str]) -> list[int]: """Returns the IDs for each item of the list of string, otherwise 0 if not found. Args: @@ -120,7 +121,7 @@ def get_idx_for_items(self, items: List[str]) -> List[int]: return [results] return list(results) - def get_items(self) -> List[str]: + def get_items(self) -> list[str]: items = [] for item in self.idx2item: items.append(item.decode("UTF-8")) @@ -151,7 +152,7 @@ def save(self, savefile: PathLike): mappings = {"idx2item": self.idx2item, "item2idx": self.item2idx} pickle.dump(mappings, f) - def __setstate__(self, d: Dict) -> None: + def __setstate__(self, d: dict) -> None: self.__dict__ = d # set 'add_unk' if the dictionary was created with a version of Flair older than 0.9 if "add_unk" not in self.__dict__: @@ -281,9 +282,9 @@ class DataPoint: """ def __init__(self) -> None: - self.annotation_layers: Dict[str, List[Label]] = {} - self._embeddings: Dict[str, torch.Tensor] = {} - self._metadata: Dict[str, Any] = {} + self.annotation_layers: dict[str, list[Label]] = {} + self._embeddings: dict[str, torch.Tensor] = {} + self._metadata: dict[str, Any] = {} @property @abstractmethod @@ -293,7 +294,7 @@ def embedding(self) -> torch.Tensor: def set_embedding(self, name: str, vector: torch.Tensor): self._embeddings[name] = vector - def get_embedding(self, names: Optional[List[str]] = None) -> torch.Tensor: + def get_embedding(self, names: Optional[list[str]] = None) -> torch.Tensor: # if one embedding name, directly return it if names and len(names) == 1: if names[0] in self._embeddings: @@ -308,7 +309,7 @@ def get_embedding(self, names: Optional[List[str]] = None) -> torch.Tensor: else: return torch.tensor([], device=flair.device) - def get_each_embedding(self, embedding_names: Optional[List[str]] = None) -> List[torch.Tensor]: + def get_each_embedding(self, embedding_names: Optional[list[str]] = None) -> list[torch.Tensor]: embeddings = [] for embed_name in sorted(self._embeddings.keys()): if embedding_names and embed_name not in embedding_names: @@ -325,7 +326,7 @@ def to(self, device: str, pin_memory: bool = False) -> None: else: self._embeddings[name] = vector.to(device, non_blocking=True) - def clear_embeddings(self, embedding_names: Optional[List[str]] = None) -> None: + def 
clear_embeddings(self, embedding_names: Optional[list[str]] = None) -> None: if embedding_names is None: self._embeddings = {} else: @@ -368,14 +369,14 @@ def get_label(self, label_type: Optional[str] = None, zero_tag_value: str = "O") return Label(self, zero_tag_value) return self.get_labels(label_type)[0] - def get_labels(self, typename: Optional[str] = None) -> List[Label]: + def get_labels(self, typename: Optional[str] = None) -> list[Label]: if typename is None: return self.labels return self.annotation_layers.get(typename, []) @property - def labels(self) -> List[Label]: + def labels(self) -> list[Label]: all_labels = [] for key in self.annotation_layers: all_labels.extend(self.annotation_layers[key]) @@ -447,8 +448,8 @@ def __init__( concept_id: str, concept_name: str, database_name: str, - additional_ids: Optional[List[str]] = None, - synonyms: Optional[List[str]] = None, + additional_ids: Optional[list[str]] = None, + synonyms: Optional[list[str]] = None, description: Optional[str] = None, ): """A Concept as part of a knowledgebase or ontology. @@ -483,7 +484,7 @@ def __str__(self) -> str: def __repr__(self) -> str: return str(self) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "concept_id": self.concept_id, "concept_name": self.concept_name, @@ -550,8 +551,8 @@ def __init__( self._start_position = start_position - self._embeddings: Dict = {} - self.tags_proba_dist: Dict[str, List[Label]] = {} + self._embeddings: dict[str, torch.Tensor] = {} + self.tags_proba_dist: dict[str, list[Label]] = {} @property def idx(self) -> int: @@ -568,10 +569,10 @@ def text(self) -> str: def unlabeled_identifier(self) -> str: return f'Token[{self.idx - 1}]: "{self.text}"' - def add_tags_proba_dist(self, tag_type: str, tags: List[Label]) -> None: + def add_tags_proba_dist(self, tag_type: str, tags: list[Label]) -> None: self.tags_proba_dist[tag_type] = tags - def get_tags_proba_dist(self, tag_type: str) -> List[Label]: + def get_tags_proba_dist(self, tag_type: str) -> list[Label]: if tag_type in self.tags_proba_dist: return self.tags_proba_dist[tag_type] return [] @@ -617,7 +618,7 @@ def set_label(self, typename: str, value: str, score: float = 1.0, **metadata): else: DataPoint.set_label(self, typename=typename, value=value, score=score, **metadata) - def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]: + def to_dict(self, tag_type: Optional[str] = None) -> dict[str, Any]: return { "text": self.text, "start_pos": self.start_position, @@ -629,7 +630,7 @@ def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]: class Span(_PartOfSentence): """This class represents one textual span consisting of Tokens.""" - def __new__(self, tokens: List[Token]): + def __new__(self, tokens: list[Token]): # check if the span already exists. 
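One consequence of the Span.__new__ logic here: span construction is cached per sentence (via sentence._known_spans), so building the "same" span twice yields one shared object, and labels attached through either reference land on the same span. A quick check:

from flair.data import Sentence, Span

sentence = Sentence("New York is big")

span_a = Span(sentence.tokens[0:2])
span_b = Span(sentence.tokens[0:2])

# __new__ found the identifier in sentence._known_spans and returned
# the existing object instead of creating a fresh one.
assert span_a is span_b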
If so, return it unlabeled_identifier = self._make_unlabeled_identifier(tokens) if unlabeled_identifier in tokens[0].sentence._known_spans: @@ -643,7 +644,7 @@ def __new__(self, tokens: List[Token]): tokens[0].sentence._known_spans[unlabeled_identifier] = span return span - def __init__(self, tokens: List[Token]) -> None: + def __init__(self, tokens: list[Token]) -> None: if not self.initialized: super().__init__(tokens[0].sentence) self.tokens = tokens @@ -662,7 +663,7 @@ def text(self) -> str: return "".join([t.text + t.whitespace_after * " " for t in self.tokens]).strip() @staticmethod - def _make_unlabeled_identifier(tokens: List[Token]): + def _make_unlabeled_identifier(tokens: list[Token]): text = "".join([t.text + t.whitespace_after * " " for t in tokens]).strip() return f'Span[{tokens[0].idx - 1}:{tokens[-1].idx}]: "{text}"' @@ -769,7 +770,7 @@ class Sentence(DataPoint): def __init__( self, - text: Union[str, List[str], List[Token]], + text: Union[str, list[str], list[Token]], use_tokenizer: Union[bool, Tokenizer] = True, language_code: Optional[str] = None, start_position: int = 0, @@ -790,10 +791,10 @@ def __init__( """ super().__init__() - self.tokens: List[Token] = [] + self.tokens: list[Token] = [] # private field for all known spans - self._known_spans: Dict[str, _PartOfSentence] = {} + self._known_spans: dict[str, _PartOfSentence] = {} self.language_code: Optional[str] = language_code @@ -818,7 +819,7 @@ def __init__( self._previous_sentence: Optional[Sentence] = None self._has_context: bool = False self._next_sentence: Optional[Sentence] = None - self._position_in_dataset: Optional[typing.Tuple[Dataset, int]] = None + self._position_in_dataset: Optional[tuple[Dataset, int]] = None # if text is passed, instantiate sentence with tokens (words) if isinstance(text, str): @@ -830,7 +831,7 @@ def __init__( self.tokens[-1].whitespace_after = 0 return else: - words = cast(List[str], text) + words = cast(list[str], text) text = " ".join(words) # determine token positions and whitespace_after flag @@ -861,15 +862,15 @@ def __init__( def unlabeled_identifier(self): return f'Sentence[{len(self)}]: "{self.text}"' - def get_relations(self, label_type: Optional[str] = None) -> List[Relation]: - relations: List[Relation] = [] + def get_relations(self, label_type: Optional[str] = None) -> list[Relation]: + relations: list[Relation] = [] for label in self.get_labels(label_type): if isinstance(label.data_point, Relation): relations.append(label.data_point) return relations - def get_spans(self, label_type: Optional[str] = None) -> List[Span]: - spans: List[Span] = [] + def get_spans(self, label_type: Optional[str] = None) -> list[Span]: + spans: list[Span] = [] for potential_span in self._known_spans.values(): if isinstance(potential_span, Span) and (label_type is None or potential_span.has_label(label_type)): spans.append(potential_span) @@ -922,16 +923,16 @@ def to(self, device: str, pin_memory: bool = False): for token in self: token.to(device, pin_memory) - def clear_embeddings(self, embedding_names: Optional[List[str]] = None): + def clear_embeddings(self, embedding_names: Optional[list[str]] = None): super().clear_embeddings(embedding_names) # clear token embeddings for token in self: token.clear_embeddings(embedding_names) - def left_context(self, context_length: int, respect_document_boundaries: bool = True) -> List[Token]: + def left_context(self, context_length: int, respect_document_boundaries: bool = True) -> list[Token]: sentence = self - left_context: List[Token] = [] + 
left_context: list[Token] = [] while len(left_context) < context_length: sentence = sentence.previous_sentence() if sentence is None: @@ -943,9 +944,9 @@ def left_context(self, context_length: int, respect_document_boundaries: bool = left_context = sentence.tokens + left_context return left_context[-context_length:] - def right_context(self, context_length: int, respect_document_boundaries: bool = True) -> List[Token]: + def right_context(self, context_length: int, respect_document_boundaries: bool = True) -> list[Token]: sentence = self - right_context: List[Token] = [] + right_context: list[Token] = [] while len(right_context) < context_length: sentence = sentence.next_sentence() if sentence is None: @@ -1037,7 +1038,7 @@ def to_original_text(self) -> str: [t.text + t.whitespace_after * " " for t in self.tokens] ).strip() - def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]: + def to_dict(self, tag_type: Optional[str] = None) -> dict[str, Any]: return { "text": self.to_original_text(), "labels": [label.to_dict() for label in self.get_labels(tag_type) if label.data_point is self], @@ -1180,7 +1181,7 @@ def copy_context_from_sentence(self, sentence: "Sentence") -> None: self._position_in_dataset = sentence._position_in_dataset @classmethod - def set_context_for_sentences(cls, sentences: List["Sentence"]) -> None: + def set_context_for_sentences(cls, sentences: list["Sentence"]) -> None: previous_sentence = None for sentence in sentences: if sentence.is_context_set(): @@ -1231,7 +1232,7 @@ def to(self, device: str, pin_memory: bool = False): self.first.to(device, pin_memory) self.second.to(device, pin_memory) - def clear_embeddings(self, embedding_names: Optional[List[str]] = None): + def clear_embeddings(self, embedding_names: Optional[list[str]] = None): self.first.clear_embeddings(embedding_names) self.second.clear_embeddings(embedding_names) if self.concatenated_data is not None: @@ -1276,7 +1277,7 @@ def to(self, device: str, pin_memory: bool = False): self.second.to(device, pin_memory) self.third.to(device, pin_memory) - def clear_embeddings(self, embedding_names: Optional[List[str]] = None): + def clear_embeddings(self, embedding_names: Optional[list[str]] = None): self.first.clear_embeddings(embedding_names) self.second.clear_embeddings(embedding_names) self.third.clear_embeddings(embedding_names) @@ -1313,7 +1314,7 @@ def __init__(self, data=None, imageURL=None): super().__init__() self.data = data - self._embeddings: Dict = {} + self._embeddings: dict[str, torch.Tensor] = {} self.imageURL = imageURL @property @@ -1497,17 +1498,17 @@ def make_vocab_dictionary(self, max_tokens: int = -1, min_freq: int = 1) -> Dict return vocab_dictionary - def _get_most_common_tokens(self, max_tokens: int, min_freq: int) -> List[str]: + def _get_most_common_tokens(self, max_tokens: int, min_freq: int) -> list[str]: tokens_and_frequencies = Counter(self._get_all_tokens()) - tokens: List[str] = [] + tokens: list[str] = [] for token, freq in tokens_and_frequencies.most_common(): if (min_freq != -1 and freq < min_freq) or (max_tokens != -1 and len(tokens) == max_tokens): break tokens.append(token) return tokens - def _get_all_tokens(self) -> List[str]: + def _get_all_tokens(self) -> list[str]: assert self.train tokens = [s.tokens for s in _iter_dataset(self.train)] tokens = [token for sublist in tokens for token in sublist] @@ -1544,13 +1545,8 @@ def _obtain_statistics_for(sentences, name, tag_type) -> dict: tags_to_count = Corpus._count_token_labels(sentences, tag_type) 
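The left_context/right_context methods retyped above power FLERT-style contextualization: they walk the previous/next sentence pointers until enough tokens are collected. A short sketch of wiring up and querying that context (the sentences are invented):

from flair.data import Sentence

first = Sentence("Berlin is big.")
second = Sentence("It is the capital.")

# Populate the previous/next pointers between consecutive sentences.
Sentence.set_context_for_sentences([first, second])

# Up to three tokens of left context for the second sentence, drawn
# from the end of the first one.
print([t.text for t in second.left_context(3)])  # ['is', 'big', '.']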
tokens_per_sentence = Corpus._get_tokens_per_sentence(sentences) - label_size_dict = {} - for label, c in classes_to_count.items(): - label_size_dict[label] = c - - tag_size_dict = {} - for tag, c in tags_to_count.items(): - tag_size_dict[tag] = c + label_size_dict = dict(classes_to_count) + tag_size_dict = dict(tags_to_count) return { "dataset": name, @@ -1566,20 +1562,20 @@ def _obtain_statistics_for(sentences, name, tag_type) -> dict: } @staticmethod - def _get_tokens_per_sentence(sentences: Iterable[Sentence]) -> List[int]: + def _get_tokens_per_sentence(sentences: Iterable[Sentence]) -> list[int]: return [len(x.tokens) for x in sentences] @staticmethod - def _count_sentence_labels(sentences: Iterable[Sentence]) -> DefaultDict[str, int]: - label_count: DefaultDict[str, int] = defaultdict(lambda: 0) + def _count_sentence_labels(sentences: Iterable[Sentence]) -> defaultdict[str, int]: + label_count: defaultdict[str, int] = defaultdict(lambda: 0) for sent in sentences: for label in sent.labels: label_count[label.value] += 1 return label_count @staticmethod - def _count_token_labels(sentences: Iterable[Sentence], label_type: str) -> DefaultDict[str, int]: - label_count: DefaultDict[str, int] = defaultdict(lambda: 0) + def _count_token_labels(sentences: Iterable[Sentence], label_type: str) -> defaultdict[str, int]: + label_count: defaultdict[str, int] = defaultdict(lambda: 0) for sent in sentences: for token in sent.tokens: if label_type in token.annotation_layers: @@ -1623,7 +1619,7 @@ def make_label_dictionary( sentence_label_type_counter: typing.Counter[str] = Counter() label_value_counter: typing.Counter[str] = Counter() - all_sentence_labels: List[str] = [] + all_sentence_labels: list[str] = [] # first, determine the datapoint type by going through dataset until first label is found datapoint_type = None @@ -1687,10 +1683,10 @@ def make_label_dictionary( def add_label_noise( self, label_type: str, - labels: List[str], + labels: list[str], noise_share: float = 0.2, split: str = "train", - noise_transition_matrix: Optional[Dict[str, List[float]]] = None, + noise_transition_matrix: Optional[dict[str, list[float]]] = None, ): """Generates uniform label noise distribution in the chosen dataset split. @@ -1817,12 +1813,12 @@ def make_tag_dictionary(self, tag_type: str) -> Dictionary: class MultiCorpus(Corpus): def __init__( self, - corpora: List[Corpus], - task_ids: Optional[List[str]] = None, + corpora: list[Corpus], + task_ids: Optional[list[str]] = None, name: str = "multicorpus", **corpusargs, ) -> None: - self.corpora: List[Corpus] = corpora + self.corpora: list[Corpus] = corpora ids = task_ids if task_ids else [f"Task_{i}" for i in range(len(corpora))] @@ -1871,8 +1867,8 @@ class ConcatFlairDataset(Dataset): datasets (sequence): List of datasets to be concatenated """ - datasets: List[Dataset] - cumulative_sizes: List[int] + datasets: list[Dataset] + cumulative_sizes: list[int] @staticmethod def cumsum(sequence): @@ -1907,36 +1903,13 @@ def __getitem__(self, idx: int) -> Sentence: return sentence @property - def cummulative_sizes(self) -> List[int]: + def cummulative_sizes(self) -> list[int]: return self.cumulative_sizes -def iob2(tags: List) -> bool: - """Converts the tags to the IOB2 format. - - Check that tags have a valid IOB format. - Tags in IOB1 format are converted to IOB2. 
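For reference, since the iob2 helper is deleted in this commit: it validated IOB tags and rewrote IOB1-style entity-initial I- tags to B-. A simplified standalone sketch of that conversion on plain string tags (the original worked on Label objects and also returned a validity flag):

def iob1_to_iob2(tags: list[str]) -> list[str]:
    """Rewrite entity-initial I- tags to B-, as the removed iob2() did."""
    out = list(tags)
    for i, tag in enumerate(out):
        if not tag.startswith("I-"):
            continue
        if i == 0 or out[i - 1] == "O" or out[i - 1][2:] != tag[2:]:
            out[i] = "B" + tag[1:]
    return out

assert iob1_to_iob2(["I-PER", "I-PER", "O", "I-LOC"]) == ["B-PER", "I-PER", "O", "B-LOC"]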
- """ - for i, tag in enumerate(tags): - if tag.value == "O": - continue - split = tag.value.split("-") - if len(split) != 2 or split[0] not in ["I", "B"]: - return False - if split[0] == "B": - continue - elif i == 0 or tags[i - 1].value == "O": # conversion IOB1 to IOB2 - tags[i].value = "B" + tag.value[1:] - elif tags[i - 1].value[1:] == tag.value[1:]: - continue - else: # conversion IOB1 to IOB2 - tags[i].value = "B" + tag.value[1:] - return True - - def randomly_split_into_two_datasets( dataset: Dataset, length_of_first: int, random_seed: Optional[int] = None -) -> Tuple[Subset, Subset]: +) -> tuple[Subset, Subset]: """Shuffles a dataset and splits into two subsets. The length of the first is specified and the remaining samples go into the second subset. @@ -1959,17 +1932,17 @@ def randomly_split_into_two_datasets( def get_spans_from_bio( - bioes_tags: List[str], bioes_scores: Optional[List[float]] = None -) -> List[typing.Tuple[List[int], float, str]]: + bioes_tags: list[str], bioes_scores: Optional[list[float]] = None +) -> list[tuple[list[int], float, str]]: # add a dummy "O" to close final prediction bioes_tags.append("O") # return complex list found_spans = [] # internal variables - current_tag_weights: Dict[str, float] = {} + current_tag_weights: dict[str, float] = {} previous_tag = "O-" - current_span: List[int] = [] - current_span_scores: List[float] = [] + current_span: list[int] = [] + current_span_scores: list[float] = [] for idx, bioes_tag in enumerate(bioes_tags): # non-set tags are OUT tags if bioes_tag == "" or bioes_tag == "O" or bioes_tag == "_": diff --git a/flair/datasets/base.py b/flair/datasets/base.py index a38d0b1321..0737b4660d 100644 --- a/flair/datasets/base.py +++ b/flair/datasets/base.py @@ -1,7 +1,7 @@ import logging from abc import abstractmethod from pathlib import Path -from typing import Generic, List, Optional, Union +from typing import Generic, Optional, Union import torch.utils.data.dataloader from deprecated.sphinx import deprecated @@ -41,7 +41,7 @@ def __init__( class FlairDatapointDataset(FlairDataset, Generic[DT]): """A simple Dataset object to wrap a List of Datapoints, for example Sentences.""" - def __init__(self, datapoints: Union[DT, List[DT]]) -> None: + def __init__(self, datapoints: Union[DT, list[DT]]) -> None: """Instantiate FlairDatapointDataset. Args: @@ -64,7 +64,7 @@ def __getitem__(self, index: int = 0) -> DT: class SentenceDataset(FlairDatapointDataset): @deprecated(version="0.11", reason="The 'SentenceDataset' class was renamed to 'FlairDatapointDataset'") - def __init__(self, sentences: Union[Sentence, List[Sentence]]) -> None: + def __init__(self, sentences: Union[Sentence, list[Sentence]]) -> None: super().__init__(sentences) @@ -73,7 +73,7 @@ class StringDataset(FlairDataset): def __init__( self, - texts: Union[str, List[str]], + texts: Union[str, list[str]], use_tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(), ) -> None: """Instantiate StringDataset. 
@@ -111,7 +111,7 @@ def __init__( database: str, collection: str, text_field: str, - categories_field: Optional[List[str]] = None, + categories_field: Optional[list[str]] = None, max_tokens_per_doc: int = -1, max_chars_per_doc: int = -1, tokenizer: Tokenizer = SegtokTokenizer(), @@ -195,7 +195,7 @@ def __init__( def _parse_document_to_sentence( self, text: str, - labels: List[str], + labels: list[str], tokenizer: Union[bool, Tokenizer], ): if self.max_chars_per_doc > 0: diff --git a/flair/datasets/biomedical.py b/flair/datasets/biomedical.py index 28f4aca98b..e99a71ccf7 100644 --- a/flair/datasets/biomedical.py +++ b/flair/datasets/biomedical.py @@ -7,6 +7,7 @@ import sys from abc import ABC, abstractmethod from collections import defaultdict, deque +from collections.abc import Iterable from copy import copy from operator import attrgetter from pathlib import Path @@ -18,7 +19,7 @@ StreamError, TarError, ) -from typing import Dict, Iterable, List, NamedTuple, Optional, Tuple, Union +from typing import NamedTuple, Optional, Union from zipfile import BadZipFile, LargeZipFile import ftfy @@ -56,7 +57,7 @@ class Entity: text as well as the type of entity (e.g. Chemical, Gene, and so on). """ - def __init__(self, char_span: Tuple[int, int], entity_type: str) -> None: + def __init__(self, char_span: tuple[int, int], entity_type: str) -> None: assert char_span[0] < char_span[1] self.char_span = range(*char_span) self.type = entity_type @@ -98,9 +99,9 @@ class InternalBioNerDataset: def __init__( self, - documents: Dict[str, str], - entities_per_document: Dict[str, List[Entity]], - entity_types: List[str] = [], + documents: dict[str, str], + entities_per_document: dict[str, list[Entity]], + entity_types: list[str] = [], ): self.documents = documents self.entities_per_document = entities_per_document @@ -134,7 +135,7 @@ def merge_datasets(data_sets: Iterable[InternalBioNerDataset]): def filter_and_map_entities( - dataset: InternalBioNerDataset, entity_type_to_canonical: Dict[str, str] + dataset: InternalBioNerDataset, entity_type_to_canonical: dict[str, str] ) -> InternalBioNerDataset: mapped_entities_per_document = {} entity_types = list(entity_type_to_canonical.values()) @@ -223,7 +224,7 @@ def bioc_to_internal(bioc_file: Path): for document in Tqdm.tqdm(documents, desc="Converting to internal"): document_id = document.xpath("./id")[0].text - texts: List[str] = [] + texts: list[str] = [] entities = [] for passage in document.xpath("passage"): @@ -358,7 +359,7 @@ def __init__( """ self.sentence_splitter = sentence_splitter - def process_dataset(self, datasets: Dict[str, InternalBioNerDataset], out_dir: Path): + def process_dataset(self, datasets: dict[str, InternalBioNerDataset], out_dir: Path): if "train" in datasets: self.write_to_conll(datasets["train"], out_dir / (self.sentence_splitter.name + "_train.conll")) if "dev" in datasets: @@ -450,7 +451,7 @@ def to_internal(self, data_folder: Path) -> InternalBioNerDataset: @staticmethod @abstractmethod - def split_url() -> Union[str, List[str]]: + def split_url() -> Union[str, list[str]]: raise NotImplementedError def get_corpus_sentence_splitter(self) -> Optional[SentenceSplitter]: @@ -596,8 +597,8 @@ def download_dataset(cls, data_dir: Path) -> Path: @classmethod def parse_dataset(cls, original_file: Path): - documents: Dict[str, str] = {} - entities_per_document: Dict[str, List[Entity]] = {} + documents: dict[str, str] = {} + entities_per_document: dict[str, list[Entity]] = {} tree = etree.parse(str(original_file)) sentence_elems = 
tree.xpath("//sentence") @@ -647,7 +648,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return merge_datasets([train_data, test_data]) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -726,14 +727,14 @@ def download_and_prepare_test(cls, data_folder: Path, sentence_tag: str) -> Inte @classmethod def read_file(cls, input_iob_file: Path, sentence_tag: str) -> InternalBioNerDataset: - documents: Dict[str, str] = {} - entities_per_document: Dict[str, List[Entity]] = defaultdict(list) + documents: dict[str, str] = {} + entities_per_document: dict[str, list[Entity]] = defaultdict(list) with open(str(input_iob_file), encoding="utf8") as file_reader: document_id: Optional[str] = None document_text: Optional[str] = None - entities: List[Entity] = [] + entities: list[Entity] = [] entity_type: Optional[str] = None entity_start = 0 @@ -818,7 +819,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return merge_datasets([train_data, test_data]) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -994,7 +995,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @staticmethod - def split_url() -> List[str]: + def split_url() -> list[str]: split_urls = [ "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/cellfinder_cellline", "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/cellfinder_species", @@ -1009,7 +1010,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -1176,7 +1177,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return merge_datasets([train_data, test_data]) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -1566,7 +1567,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(dataset, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -1747,8 +1748,8 @@ def download_dataset(data_dir: Path): @classmethod def parse_dataset(cls, original_file: Path): - documents: Dict[str, str] = {} - entities_per_document: Dict[str, List[Entity]] = {} + documents: dict[str, str] = {} + entities_per_document: dict[str, list[Entity]] = {} tree = etree.parse(str(original_file)) document_elems = tree.xpath("//document") @@ -1905,7 +1906,7 @@ def split_url() -> str: def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return LINNEAUS.download_and_parse_dataset(data_dir) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -1995,7 +1996,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return all_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2021,7 +2022,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return all_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return 
self.entity_type_mapping @@ -2033,7 +2034,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @staticmethod - def split_url() -> List[str]: + def split_url() -> list[str]: split_urls = [ "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/CDRDisease", "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/CDRChem", @@ -2052,7 +2053,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return all_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2167,7 +2168,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return all_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2190,7 +2191,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return all_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2213,7 +2214,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return all_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2230,7 +2231,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @staticmethod - def split_url() -> List[str]: + def split_url() -> list[str]: split_urls = [ "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/variome_gene", "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/variome_disease", @@ -2247,7 +2248,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return all_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2343,7 +2344,7 @@ def parse_input_file(input_file: Path): with open(str(input_file), encoding="utf8") as file: document_id = "" document_text = "" - entities: List[Entity] = [] + entities: list[Entity] = [] c = 1 for line in file: @@ -2406,7 +2407,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return merge_datasets([train_data, dev_data, test_data]) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2455,13 +2456,13 @@ def download_corpus(self, data_folder: Path) -> Path: @staticmethod def parse_input_file(input_file: Path): - documents: Dict[str, str] = {} - entities_per_document: Dict[str, List[Entity]] = {} + documents: dict[str, str] = {} + entities_per_document: dict[str, list[Entity]] = {} with open(str(input_file), encoding="iso-8859-1") as file: document_id = None document_text = "" - entities: List[Entity] = [] + entities: list[Entity] = [] entity_type = None entity_start = 0 @@ -2584,7 +2585,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(corpus, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2605,7 +2606,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(corpus, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return 
self.entity_type_mapping @@ -2628,7 +2629,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @staticmethod - def split_url() -> List[str]: + def split_url() -> list[str]: split_urls = [ "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/scai_chemicals", "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/scai_disease", @@ -2641,7 +2642,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(corpus, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2763,7 +2764,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(corpus, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2863,7 +2864,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2945,7 +2946,7 @@ def download_dev_corpus(cls, data_dir) -> Path: @staticmethod def parse_input_file(text_file: Path, ann_file: Path) -> InternalBioNerDataset: documents = {} - entities_per_document: Dict[str, List[Entity]] = {} + entities_per_document: dict[str, list[Entity]] = {} document_title_length = {} @@ -3010,7 +3011,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return merge_datasets([train_data, dev_data]) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -3071,8 +3072,8 @@ def download_corpus(cls, data_dir: Path) -> Path: @staticmethod def parse_corpus(text_dir: Path, gold_file: Path) -> InternalBioNerDataset: - documents: Dict[str, str] = {} - entities_per_document: Dict[str, List[Entity]] = {} + documents: dict[str, str] = {} + entities_per_document: dict[str, list[Entity]] = {} text_files = [file for file in os.listdir(str(text_dir)) if not file.startswith(".")] @@ -3122,7 +3123,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return DECA.parse_corpus(text_dir, gold_file) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -3221,7 +3222,7 @@ def parse_corpus(corpus_dir: Path, sentence_separator: str) -> InternalBioNerDat akt_pos += len(words[i]) + 1 sentences += [tmp_sentence] - pre_entities: List[List[Tuple[int, int, str]]] = [[] for _ in sentences] + pre_entities: list[list[tuple[int, int, str]]] = [[] for _ in sentences] for protein in protein_tree: for span in protein.get("span").split(","): start = word_to_id[span.split("..")[0]] @@ -3287,7 +3288,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(corpus, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -3450,8 +3451,8 @@ def parse_dataset(data_dir: Path) -> InternalBioNerDataset: ] text_files = sorted(text_files) - documents: Dict[str, str] = {} - entities_per_document: Dict[str, List[Entity]] = {} + documents: dict[str, str] = {} + entities_per_document: dict[str, list[Entity]] = {} for text_file in sorted(text_files): 
document_id = os.path.basename(text_file).split("_")[0] @@ -3590,7 +3591,7 @@ def parse_test_dataset(cls, data_folder: Path) -> InternalBioNerDataset: @staticmethod def parse_dataset(text_file: Path, ann_file: Path) -> InternalBioNerDataset: documents = {} - entities_per_document: Dict[str, List[Entity]] = {} + entities_per_document: dict[str, list[Entity]] = {} with open(str(text_file), encoding="utf8") as text_file_reader: for line in text_file_reader: @@ -3733,7 +3734,7 @@ def download_dev_corpus(cls, data_dir) -> Path: @staticmethod def parse_input_file(text_file: Path, ann_file: Path) -> InternalBioNerDataset: documents = {} - entities_per_document: Dict[str, List[Entity]] = {} + entities_per_document: dict[str, list[Entity]] = {} document_abstract_length = {} with open(str(text_file), encoding="utf8") as text_reader: @@ -3806,7 +3807,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: dataset = merge_datasets([train_data, dev_data]) return filter_and_map_entities(dataset, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -3945,7 +3946,7 @@ def to_internal(self, data_dir: Path, annotator: int = 0) -> InternalBioNerDatas dataset = CHEBI.parse_dataset(corpus_dir, annotator=annotator) return filter_and_map_entities(dataset, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -4038,7 +4039,7 @@ def __init__( @staticmethod @abstractmethod - def download_corpus(data_folder: Path) -> Tuple[Path, Path, Path]: + def download_corpus(data_folder: Path) -> tuple[Path, Path, Path]: pass @staticmethod @@ -4083,7 +4084,7 @@ class BIONLP2013_PC(BioNLPCorpus): """ @staticmethod - def download_corpus(download_folder: Path) -> Tuple[Path, Path, Path]: + def download_corpus(download_folder: Path) -> tuple[Path, Path, Path]: train_url = "http://2013.bionlp-st.org/tasks/BioNLP-ST_2013_PC_training_data.tar.gz" dev_url = "http://2013.bionlp-st.org/tasks/BioNLP-ST_2013_PC_development_data.tar.gz" test_url = "http://2013.bionlp-st.org/tasks/BioNLP-ST_2013_PC_test_data.tar.gz" @@ -4125,7 +4126,7 @@ class BIONLP2013_CG(BioNLPCorpus): """ @staticmethod - def download_corpus(download_folder: Path) -> Tuple[Path, Path, Path]: + def download_corpus(download_folder: Path) -> tuple[Path, Path, Path]: url = "https://github.com/openbiocorpora/bionlp-st-2013-cg/archive/refs/heads/master.zip" cached_path(url, download_folder) @@ -4292,9 +4293,10 @@ def download_corpora(download_dir: Path): @staticmethod def convert_and_write(download_folder, data_folder, tag_type): data_folder.mkdir(parents=True, exist_ok=True) - with (download_folder / "train.tsv").open(encoding="utf8") as f_in, (data_folder / "train.conll").open( - "w", encoding="utf8" - ) as f_out: + with ( + (download_folder / "train.tsv").open(encoding="utf8") as f_in, + (data_folder / "train.conll").open("w", encoding="utf8") as f_out, + ): for line in f_in: if not line.strip(): f_out.write("\n") @@ -4305,9 +4307,10 @@ def convert_and_write(download_folder, data_folder, tag_type): tag = tag + "-" + tag_type f_out.write(f"{token} {tag}\n") - with (download_folder / "devel.tsv").open(encoding="utf8") as f_in, (data_folder / "dev.conll").open( - "w", encoding="utf8" - ) as f_out: + with ( + (download_folder / "devel.tsv").open(encoding="utf8") as f_in, + (data_folder / "dev.conll").open("w", encoding="utf8") as 
f_out, + ): for line in f_in: if not line.strip(): f_out.write("\n") @@ -4317,9 +4320,10 @@ def convert_and_write(download_folder, data_folder, tag_type): tag = tag + "-" + tag_type f_out.write(f"{token} {tag}\n") - with (download_folder / "test.tsv").open(encoding="utf8") as f_in, (data_folder / "test.conll").open( - "w", encoding="utf8" - ) as f_out: + with ( + (download_folder / "test.tsv").open(encoding="utf8") as f_in, + (data_folder / "test.conll").open("w", encoding="utf8") as f_out, + ): for line in f_in: if not line.strip(): f_out.write("\n") @@ -4638,7 +4642,7 @@ def download_corpus(cls, data_dir: Path) -> Path: @staticmethod def prepare_splits( data_dir: Path, corpus: InternalBioNerDataset - ) -> Tuple[InternalBioNerDataset, InternalBioNerDataset, InternalBioNerDataset]: + ) -> tuple[InternalBioNerDataset, InternalBioNerDataset, InternalBioNerDataset]: splits_dir = data_dir / "splits" os.makedirs(str(splits_dir), exist_ok=True) @@ -4734,7 +4738,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(corpus, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -4792,7 +4796,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(corpus, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -4896,7 +4900,7 @@ def parse_corpus(input_file: Path) -> InternalBioNerDataset: prev_sentence_id: Optional[str] = None document_text: Optional[str] = None - entities: List[Entity] = [] + entities: list[Entity] = [] offset: Optional[int] = None for line in azdz_reader: @@ -5014,7 +5018,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return corpus_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -5221,7 +5225,7 @@ def __init__( sample_missing_splits=True, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: """Return the mapping of entity type given in the dataset to canonical types. Note, if a entity type is not present in the map it is discarded. @@ -5279,8 +5283,8 @@ def build_corpus_directory_name(self, dataset_name: str) -> str: def to_internal_dataset(self, dataset, split: str) -> InternalBioNerDataset: """Converts a dataset given in hugging datasets format to our internal corpus representation.""" - id_to_text: Dict[str, str] = {} - id_to_entities: Dict[str, list] = {} + id_to_text: dict[str, str] = {} + id_to_entities: dict[str, list] = {} entity_type_set = set() for document in dataset[split]: document_id = document["document_id"] @@ -5331,10 +5335,10 @@ def to_internal_dataset(self, dataset, split: str) -> InternalBioNerDataset: def bin_search_passage( self, - passages: List[Tuple[str, List[Tuple[int, int]]]], + passages: list[tuple[str, list[tuple[int, int]]]], low: int, high: int, - entity: Dict, + entity: dict, ): """Helper methods to find the passage to a given entity mention (incl. offset). 
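bin_search_passage, whose annotations are modernized above, locates the passage that contains a given entity mention by binary search over passages carrying character-offset ranges. The underlying idea in isolation, simplified to one contiguous (begin, end) range per passage (hypothetical helper, not the flair method):

    def find_passage(passages: list[tuple[int, int]], offset: int) -> int:
        # passages must be sorted, non-overlapping [begin, end) character ranges
        low, high = 0, len(passages) - 1
        while low <= high:
            mid = (low + high) // 2
            begin, end = passages[mid]
            if offset < begin:
                high = mid - 1
            elif offset >= end:
                low = mid + 1
            else:
                return mid
        raise ValueError(f"no passage covers offset {offset}")

    assert find_passage([(0, 100), (100, 250), (250, 400)], 120) == 1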
@@ -5381,7 +5385,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return { "Gene": GENE_TAG, "GENERIF": GENE_TAG, @@ -5414,7 +5418,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"GENE-N": GENE_TAG, "GENE-Y": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5441,7 +5445,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"CHEMICAL": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5452,7 +5456,7 @@ class HUNER_ALL_DRUGPROT(BIGBIO_NER_CORPUS): def __init__(self, *args, **kwargs): super().__init__(*args, dataset_name="drugprot", **kwargs) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"GENE-N": GENE_TAG, "GENE-Y": GENE_TAG, "CHEMICAL": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5479,7 +5483,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"GeneOrGeneProduct": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5506,7 +5510,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"ChemicalEntity": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5533,7 +5537,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"DiseaseOrPhenotypicFeature": DISEASE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5560,7 +5564,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"OrganismTaxon": SPECIES_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5587,7 +5591,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"CellLine": CELL_LINE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5598,7 +5602,7 @@ class HUNER_ALL_BIORED(BIGBIO_NER_CORPUS): def __init__(self, *args, **kwargs): super().__init__(*args, dataset_name="biored", **kwargs) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return { "GeneOrGeneProduct": GENE_TAG, "ChemicalEntity": CHEMICAL_TAG, @@ -5631,7 +5635,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"protein": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5658,7 +5662,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"compound": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5669,7 
+5673,7 @@ class HUNER_ALL_CPI(BIGBIO_NER_CORPUS): def __init__(self, *args, **kwargs): super().__init__(*args, dataset_name="cpi", **kwargs) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"protein": GENE_TAG, "compound": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5696,7 +5700,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Gene_or_gene_product": GENE_TAG, "Complex": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5723,7 +5727,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Simple_chemical": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5734,7 +5738,7 @@ class HUNER_ALL_BIONLP_ST_2013_PC(BIGBIO_NER_CORPUS): def __init__(self, *args, **kwargs): super().__init__(*args, dataset_name="bionlp_st_2013_pc", **kwargs) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return { "Gene_or_gene_product": GENE_TAG, "Complex": GENE_TAG, @@ -5765,7 +5769,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"protein": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5792,7 +5796,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Protein": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5819,7 +5823,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Protein": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5846,7 +5850,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Chemical": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5873,7 +5877,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Organism": SPECIES_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5884,7 +5888,7 @@ class HUNER_ALL_BIONLP_ST_2011_ID(BIGBIO_NER_CORPUS): def __init__(self, *args, **kwargs): super().__init__(*args, dataset_name="bionlp_st_2011_id", **kwargs) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return { "Protein": GENE_TAG, "Chemical": CHEMICAL_TAG, @@ -5915,7 +5919,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Protein": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5942,7 +5946,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Protein": GENE_TAG} def 
build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5969,7 +5973,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Microorganism": SPECIES_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5996,7 +6000,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"gene": GENE_TAG, "protein": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -6023,7 +6027,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"chemical": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -6050,7 +6054,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"species": SPECIES_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -6077,7 +6081,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: # TODO whether cell or cell line is the correct tag return {"cellline": CELL_LINE_TAG} @@ -6089,7 +6093,7 @@ class HUNER_ALL_BIOID(BIGBIO_NER_CORPUS): def __init__(self, *args, **kwargs): super().__init__(*args, dataset_name="bioid", **kwargs) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: # TODO whether cell or cell line is the correct tag return { "gene": GENE_TAG, @@ -6123,7 +6127,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Gene": GENE_TAG, "FamilyName": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -6155,7 +6159,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"progene_text": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -6182,7 +6186,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Chemical": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -6209,7 +6213,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Gene": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -6224,7 +6228,7 @@ def __init__(self, *args, **kwargs): **kwargs, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Gene": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: diff --git a/flair/datasets/document_classification.py b/flair/datasets/document_classification.py index 0bbc471818..363c84e561 100644 --- a/flair/datasets/document_classification.py +++ b/flair/datasets/document_classification.py @@ -2,8 +2,9 @@ import json import logging import os +import tarfile from pathlib import Path -from typing import 
Dict, List, Optional, Union +from typing import Optional, Union import flair from flair.data import ( @@ -36,8 +37,8 @@ def __init__( filter_if_longer_than: int = -1, tokenizer: Union[bool, Tokenizer] = SegtokTokenizer(), memory_mode: str = "partial", - label_name_map: Optional[Dict[str, str]] = None, - skip_labels: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + skip_labels: Optional[list[str]] = None, allow_examples_without_labels=False, sample_missing_splits: bool = True, encoding: str = "utf-8", @@ -131,8 +132,8 @@ def __init__( filter_if_longer_than: int = -1, tokenizer: Union[bool, Tokenizer] = SegtokTokenizer(), memory_mode: str = "partial", - label_name_map: Optional[Dict[str, str]] = None, - skip_labels: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + skip_labels: Optional[list[str]] = None, allow_examples_without_labels=False, encoding: str = "utf-8", ) -> None: @@ -277,11 +278,7 @@ def _parse_line_to_sentence(self, line: str, label_prefix: str, tokenizer: Union return None def is_in_memory(self) -> bool: - if self.memory_mode == "disk": - return False - if self.memory_mode == "partial": - return False - return True + return self.memory_mode not in ["disk", "partial"] def __len__(self) -> int: return self.total_sentence_count @@ -309,7 +306,7 @@ class CSVClassificationCorpus(Corpus): def __init__( self, data_folder: Union[str, Path], - column_name_map: Dict[int, str], + column_name_map: dict[int, str], label_type: str, name: str = "csv_corpus", train_file=None, @@ -404,7 +401,7 @@ class CSVClassificationDataset(FlairDataset): def __init__( self, path_to_file: Union[str, Path], - column_name_map: Dict[int, str], + column_name_map: dict[int, str], label_type: str, max_tokens_per_doc: int = -1, max_chars_per_doc: int = -1, @@ -453,8 +450,8 @@ def __init__( self.total_sentence_count: int = 0 # most data sets have the token text in the first column, if not, pass 'text' as column - self.text_columns: List[int] = [] - self.pair_columns: List[int] = [] + self.text_columns: list[int] = [] + self.pair_columns: list[int] = [] for column in column_name_map: if column_name_map[column] == "text": self.text_columns.append(column) @@ -567,7 +564,7 @@ class AMAZON_REVIEWS(ClassificationCorpus): def __init__( self, split_max: int = 30000, - label_name_map: Dict[str, str] = { + label_name_map: dict[str, str] = { "1.0": "NEGATIVE", "2.0": "NEGATIVE", "3.0": "NEGATIVE", @@ -955,9 +952,10 @@ def __init__( original_filenames = original_filenames[:-1] if not data_file.is_file(): for original_filename, new_filename in zip(original_filenames, new_filenames): - with open(data_folder / "original" / original_filename, encoding="utf-8") as open_fp, open( - data_folder / new_filename, "w", encoding="utf-8" - ) as write_fp: + with ( + open(data_folder / "original" / original_filename, encoding="utf-8") as open_fp, + open(data_folder / new_filename, "w", encoding="utf-8") as write_fp, + ): csv_reader = csv.reader( open_fp, quotechar='"', delimiter=",", quoting=csv.QUOTE_ALL, skipinitialspace=True ) @@ -1048,9 +1046,10 @@ def __init__( label_list.append(labels[int(line) - 1]) # handle data file - with (data_path / "original" / "title_StackOverflow.txt").open(encoding="latin1") as open_fp, ( - data_folder / "train.txt" - ).open("w", encoding="utf-8") as write_fp: + with ( + (data_path / "original" / "title_StackOverflow.txt").open(encoding="latin1") as open_fp, + (data_folder / "train.txt").open("w", encoding="utf-8") as write_fp, + ): for idx, line 
in enumerate(open_fp): line = line.rstrip() @@ -1104,9 +1103,10 @@ def __init__( os.makedirs(data_folder) # create train.txt file from CSV - with open(data_folder / "train.txt", "w") as train_file, open( - senteval_folder / "training.1600000.processed.noemoticon.csv", encoding="latin-1" - ) as csv_train: + with ( + open(data_folder / "train.txt", "w") as train_file, + open(senteval_folder / "training.1600000.processed.noemoticon.csv", encoding="latin-1") as csv_train, + ): csv_reader = csv.reader(csv_train) for row in csv_reader: @@ -1115,9 +1115,10 @@ def __init__( train_file.write(f"__label__{label} {text}\n") # create test.txt file from CSV - with (data_folder / "test.txt").open("w", encoding="utf-8") as train_file, ( - senteval_folder / "testdata.manual.2009.06.14.csv" - ).open(encoding="latin-1") as csv_train: + with ( + (data_folder / "test.txt").open("w", encoding="utf-8") as train_file, + (senteval_folder / "testdata.manual.2009.06.14.csv").open(encoding="latin-1") as csv_train, + ): csv_reader = csv.reader(csv_train) for row in csv_reader: @@ -1384,9 +1385,10 @@ def __init__( # create train dev and test files in fasttext format for new_filename, original_filename in zip(new_filenames, original_filenames): - with open(data_folder / new_filename, "a") as out_file, open( - data_folder / "raw" / original_filename - ) as in_file: + with ( + open(data_folder / new_filename, "a") as out_file, + open(data_folder / "raw" / original_filename) as in_file, + ): for line in in_file: fields = line.split("\t") label = "POSITIVE" if fields[1].rstrip() == "1" else "NEGATIVE" @@ -1437,9 +1439,10 @@ def __init__( # convert to FastText format for split in ["train", "dev", "test"]: - with (data_folder / f"{split}.txt").open("w", encoding="utf-8") as train_file, ( - data_folder / "raw" / f"stsa.fine.{split}" - ).open(encoding="latin1") as file: + with ( + (data_folder / f"{split}.txt").open("w", encoding="utf-8") as train_file, + (data_folder / "raw" / f"stsa.fine.{split}").open(encoding="latin1") as file, + ): for line in file: train_file.write(f"__label__{line[0]} {line[2:]}") @@ -1496,9 +1499,10 @@ def __init__( # create train and dev splits in fasttext format for split in ["train", "dev"]: - with open(data_folder / "CoLA" / (split + ".txt"), "a") as out_file, open( - data_folder / "CoLA" / "original" / (split + ".tsv") - ) as in_file: + with ( + open(data_folder / "CoLA" / (split + ".txt"), "a") as out_file, + open(data_folder / "CoLA" / "original" / (split + ".tsv")) as in_file, + ): for line in in_file: fields = line.rstrip().split("\t") label = int(fields[1]) @@ -1506,9 +1510,10 @@ def __init__( out_file.write(f"__label__{label_map[label]} {sentence}\n") # create eval_dataset file with no labels - with open(data_folder / "CoLA" / "eval_dataset.txt", "a") as out_file, open( - data_folder / "CoLA" / "original" / "test.tsv" - ) as in_file: + with ( + open(data_folder / "CoLA" / "eval_dataset.txt", "a") as out_file, + open(data_folder / "CoLA" / "original" / "test.tsv") as in_file, + ): for line in in_file: fields = line.rstrip().split("\t") sentence = fields[1] @@ -1702,9 +1707,10 @@ def __init__( data_path = flair.cache_root / "datasets" / dataset_name / "raw" # create correctly formated txt files for name in ["train", "test", "dev"]: - with (data_folder / (name + ".txt")).open("w", encoding="utf-8") as txt_file, ( - data_path / (name + ".tsv") - ).open(encoding="utf-8") as tsv_file: + with ( + (data_folder / (name + ".txt")).open("w", encoding="utf-8") as txt_file, + (data_path / (name + 
".tsv")).open(encoding="utf-8") as tsv_file, + ): lines = tsv_file.readlines() for line in lines: row = line.split("\t") @@ -1764,9 +1770,10 @@ def __init__( if not data_file.is_file(): for original_filename, new_filename in zip(original_filenames, new_filenames): - with (data_folder / "original" / original_filename).open(encoding="latin1") as open_fp, ( - data_folder / new_filename - ).open("w", encoding="utf-8") as write_fp: + with ( + (data_folder / "original" / original_filename).open(encoding="latin1") as open_fp, + (data_folder / new_filename).open("w", encoding="utf-8") as write_fp, + ): for line in open_fp: line = line.rstrip() fields = line.split() @@ -1820,9 +1827,10 @@ def __init__( if not data_file.is_file(): for original_filename, new_filename in zip(original_filenames, new_filenames): - with (data_folder / "original" / original_filename).open(encoding="latin1") as open_fp, ( - data_folder / new_filename - ).open("w", encoding="utf-8") as write_fp: + with ( + (data_folder / "original" / original_filename).open(encoding="latin1") as open_fp, + (data_folder / new_filename).open("w", encoding="utf-8") as write_fp, + ): for line in open_fp: line = line.rstrip() fields = line.split() @@ -1887,21 +1895,20 @@ def __init__( if not (data_folder / "train.txt").is_file(): cached_path(url, original) - import tarfile - - tar = tarfile.open(original / "yahoo_answers_csv.tgz", "r:gz") - members = [] + with tarfile.open(original / "yahoo_answers_csv.tgz", "r:gz") as tar: + members = [] - for member in tar.getmembers(): - if "test.csv" in member.name or "train.csv" in member.name: - members.append(member) + for member in tar.getmembers(): + if "test.csv" in member.name or "train.csv" in member.name: + members.append(member) - tar.extractall(original, members=members) + tar.extractall(original, members=members) for name in ["train", "test"]: - with (original / "yahoo_answers_csv" / (name + ".csv")).open(encoding="utf-8") as file, ( - data_folder / (name + ".txt") - ).open("w", encoding="utf-8") as writer: + with ( + (original / "yahoo_answers_csv" / (name + ".csv")).open(encoding="utf-8") as file, + (data_folder / (name + ".txt")).open("w", encoding="utf-8") as writer, + ): reader = csv.reader(file) for row in reader: writer.write("__label__" + label_map[row[0]] + " " + row[1] + "\n") @@ -1963,9 +1970,10 @@ def __init__( if not data_file.is_file(): for original_filename, new_filename in zip(original_filenames, new_filenames): - with (data_folder / "original" / original_filename).open(encoding="utf-8") as open_fp, ( - data_folder / task_setting / new_filename - ).open("w", encoding="utf-8") as write_fp: + with ( + (data_folder / "original" / original_filename).open(encoding="utf-8") as open_fp, + (data_folder / task_setting / new_filename).open("w", encoding="utf-8") as write_fp, + ): for line in open_fp: line = line.rstrip() fields = line.split("\t") diff --git a/flair/datasets/entity_linking.py b/flair/datasets/entity_linking.py index 20f2caefdd..f74bad7092 100644 --- a/flair/datasets/entity_linking.py +++ b/flair/datasets/entity_linking.py @@ -4,8 +4,9 @@ import logging import os import re +from collections.abc import Iterable, Iterator from pathlib import Path -from typing import Any, Dict, Iterable, Iterator, List, Optional, Union +from typing import Any, Optional, Union import requests from bioc import biocxml, pubtator @@ -47,7 +48,7 @@ def __init__( self._idx_to_candidates = {candidate.concept_id: candidate for candidate in candidates} # one name can map to multiple concepts - 
self._text_to_index: Dict[str, List[str]] = {} + self._text_to_index: dict[str, list[str]] = {} for candidate in candidates: for text in [candidate.concept_name, *candidate.synonyms]: if text not in self._text_to_index: @@ -60,11 +61,11 @@ def database_name(self) -> str: return self._dataset_name @property - def text_to_index(self) -> Dict[str, List[str]]: + def text_to_index(self) -> dict[str, list[str]]: return self._text_to_index @property - def candidates(self) -> List[EntityCandidate]: + def candidates(self) -> list[EntityCandidate]: return list(self._idx_to_candidates.values()) def __getitem__(self, item: str) -> EntityCandidate: @@ -80,18 +81,18 @@ def to_in_memory_dictionary(self) -> "InMemoryEntityLinkingDictionary": # NOTE: EntityLinkingDictionary are lazy-loaded from a preprocessed file. # Use this class to load into memory all candidates class InMemoryEntityLinkingDictionary(EntityLinkingDictionary): - def __init__(self, candidates: List[EntityCandidate], dataset_name: str): + def __init__(self, candidates: list[EntityCandidate], dataset_name: str): self._dataset_name = dataset_name super().__init__(candidates, dataset_name=dataset_name) - def to_state(self) -> Dict[str, Any]: + def to_state(self) -> dict[str, Any]: return { "dataset_name": self._dataset_name, "candidates": [candidate.to_dict() for candidate in self._idx_to_candidates.values()], } @classmethod - def from_state(cls, state: Dict[str, Any]) -> "InMemoryEntityLinkingDictionary": + def from_state(cls, state: dict[str, Any]) -> "InMemoryEntityLinkingDictionary": return cls( dataset_name=state["dataset_name"], candidates=[EntityCandidate(**candidate) for candidate in state["candidates"]], @@ -488,7 +489,7 @@ def __init__( to point to a different folder but typically this should not be necessary. in_memory: bool If True, keeps dataset in memory giving speedups in training. - column_format: Dict[int, str] + column_format: dict[int, str] The column-format to specify which columns correspond to the text or label types. 
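Earlier in this file's diff, EntityLinkingDictionary._text_to_index is retyped to dict[str, list[str]]: since one surface form can name several concepts, the index maps each mention text to a list of concept ids rather than to a single id. The build loop amounts to an inverted index over names and synonyms, sketched here with hypothetical MeSH-style data:

    candidates = {
        # concept_id -> (concept_name, synonyms)
        "D003920": ("diabetes mellitus", ["diabetes"]),
        "D003922": ("type 1 diabetes", ["diabetes"]),
    }

    text_to_index: dict[str, list[str]] = {}
    for concept_id, (name, synonyms) in candidates.items():
        for text in [name, *synonyms]:
            text_to_index.setdefault(text, []).append(concept_id)

    assert text_to_index["diabetes"] == ["D003920", "D003922"]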
""" base_path = flair.cache_root / "datasets" if not base_path else Path(base_path) @@ -776,9 +777,10 @@ def __init__( wiki_language + "_dev.tsv", ], ): - with open(doc_path, encoding="utf-8") as read, open( - data_folder / file_name, "w", encoding="utf-8" - ) as write: + with ( + open(doc_path, encoding="utf-8") as read, + open(data_folder / file_name, "w", encoding="utf-8") as write, + ): # ignore first line read.readline() line = read.readline() @@ -1208,9 +1210,10 @@ def __init__( if not parsed_dataset.exists(): original_file_path = cached_path(f"{tweeki_gold_el_path}", Path("datasets") / dataset_name) - with open(original_file_path, encoding="utf-8") as read, open( - parsed_dataset, "w", encoding="utf-8" - ) as write: + with ( + open(original_file_path, encoding="utf-8") as read, + open(parsed_dataset, "w", encoding="utf-8") as write, + ): line = read.readline() while line: if line.startswith("#"): @@ -1274,9 +1277,10 @@ def __init__( with open(data_folder / corpus_file_name, "w", encoding="utf-8") as txtout: # First parse the post titles - with open(data_folder / "posts.tsv", encoding="utf-8") as tsvin1, open( - data_folder / "gold_post_annotations.tsv", encoding="utf-8" - ) as tsvin2: + with ( + open(data_folder / "posts.tsv", encoding="utf-8") as tsvin1, + open(data_folder / "gold_post_annotations.tsv", encoding="utf-8") as tsvin2, + ): posts = csv.reader(tsvin1, delimiter="\t") self.post_annotations = csv.reader(tsvin2, delimiter="\t") self.curr_annot = next(self.post_annotations) @@ -1312,13 +1316,14 @@ def __init__( ) # Then parse the comments - with open(data_folder / "comments.tsv", encoding="utf-8") as tsvin3, open( - data_folder / "gold_comment_annotations.tsv", encoding="utf-8" - ) as tsvin4: + with ( + open(data_folder / "comments.tsv", encoding="utf-8") as tsvin3, + open(data_folder / "gold_comment_annotations.tsv", encoding="utf-8") as tsvin4, + ): self.comments = csv.reader(tsvin3, delimiter="\t") self.comment_annotations = csv.reader(tsvin4, delimiter="\t") self.curr_annot = next(self.comment_annotations) - self.curr_row: Optional[List[str]] = next(self.comments) + self.curr_row: Optional[list[str]] = next(self.comments) self.stop_iter = False # Iterate over the comments.tsv file, until the end is reached @@ -1545,7 +1550,7 @@ def make_line(word, begin_or_inside, attributes): return line - def split_span(word_fields: List[str], datasetname: str): + def split_span(word_fields: list[str], datasetname: str): """Function that splits a word if necessary, i.e. if it is a multiple-word-span. 
Parameters @@ -1646,12 +1651,12 @@ def determine_tsv_file(filename: str, data_folder: Path, cut_multisense: bool = class WSD_UFSAC(MultiCorpus): def __init__( self, - filenames: Union[str, List[str]] = ["masc", "semcor"], + filenames: Union[str, list[str]] = ["masc", "semcor"], base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, cut_multisense: bool = True, columns={0: "text", 3: "sense"}, - banned_sentences: Optional[List[str]] = None, + banned_sentences: Optional[list[str]] = None, sample_missing_splits_in_multicorpus: Union[bool, str] = True, sample_missing_splits_in_each_corpus: Union[bool, str] = True, use_raganato_ALL_as_test_data: bool = False, @@ -1713,7 +1718,7 @@ def __init__( if isinstance(filenames, str): filenames = [filenames] - corpora: List[Corpus] = [] + corpora: list[Corpus] = [] log.info("Transforming data into column format and creating corpora...") @@ -1784,8 +1789,8 @@ def __init__( base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, columns={0: "text", 3: "sense"}, - label_name_map: Optional[Dict[str, str]] = None, - banned_sentences: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + banned_sentences: Optional[list[str]] = None, sample_missing_splits: bool = True, cut_multisense: bool = True, ) -> None: @@ -1847,8 +1852,8 @@ def __init__( base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, columns={0: "text", 3: "sense"}, - label_name_map: Optional[Dict[str, str]] = None, - banned_sentences: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + banned_sentences: Optional[list[str]] = None, sample_missing_splits: Union[bool, str] = True, cut_multisense: bool = True, use_raganato_ALL_as_test_data: bool = False, @@ -1922,8 +1927,8 @@ def __init__( base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, columns={0: "text", 3: "sense"}, - label_name_map: Optional[Dict[str, str]] = None, - banned_sentences: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + banned_sentences: Optional[list[str]] = None, sample_missing_splits: Union[bool, str] = True, use_raganato_ALL_as_test_data: bool = False, ) -> None: @@ -1994,8 +1999,8 @@ def __init__( base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, columns={0: "text", 3: "sense"}, - label_name_map: Optional[Dict[str, str]] = None, - banned_sentences: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + banned_sentences: Optional[list[str]] = None, sample_missing_splits: Union[bool, str] = True, cut_multisense: bool = True, use_raganato_ALL_as_test_data: bool = False, @@ -2070,8 +2075,8 @@ def __init__( base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, columns={0: "text", 3: "sense"}, - label_name_map: Optional[Dict[str, str]] = None, - banned_sentences: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + banned_sentences: Optional[list[str]] = None, sample_missing_splits: Union[bool, str] = True, cut_multisense: bool = True, use_raganato_ALL_as_test_data: bool = False, @@ -2147,8 +2152,8 @@ def __init__( base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, columns={0: "text", 3: "sense"}, - label_name_map: Optional[Dict[str, str]] = None, - banned_sentences: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + banned_sentences: Optional[list[str]] = None, sample_missing_splits: Union[bool, str] = True, use_raganato_ALL_as_test_data: bool = 
False, ) -> None: @@ -2230,7 +2235,7 @@ def __init__( self, base_path: Optional[Union[str, Path]] = None, label_type: str = "el", - norm_keys: List[str] = ["db_name", "db_id"], + norm_keys: list[str] = ["db_name", "db_id"], **kwargs, ) -> None: self.label_type = label_type @@ -2250,14 +2255,14 @@ def __init__( ) @abc.abstractmethod - def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]: + def _download_dataset(self, data_folder: Path) -> dict[str, Union[Path, list[Path]]]: pass @abc.abstractmethod - def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: + def _file_to_dicts(self, filepath: Path) -> Iterator[dict[str, Any]]: pass - def _dict_to_sentences(self, entry: Dict[str, Any]) -> List[Sentence]: + def _dict_to_sentences(self, entry: dict[str, Any]) -> list[Sentence]: entities = [entity for entity in entry["entities"] if entity["normalized"]] tokenized_passages = [ @@ -2326,7 +2331,7 @@ def _dict_to_sentences(self, entry: Dict[str, Any]) -> List[Sentence]: sent_s[start_token_idx : end_token_idx + 1].add_label(self.label_type, mention_id) return passage_sentences - def _files_to_dataset(self, paths: Union[Path, List[Path]]) -> FlairDatapointDataset: + def _files_to_dataset(self, paths: Union[Path, list[Path]]) -> FlairDatapointDataset: if isinstance(paths, Path): paths = [paths] all_sentences = [] @@ -2347,7 +2352,7 @@ class BIGBIO_EL_NCBI_DISEASE(BigBioEntityLinkingCorpus): def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str = "el-diseases", **kwargs) -> None: super().__init__(base_path, label_type, **kwargs) - def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]: + def _download_dataset(self, data_folder: Path) -> dict[str, Union[Path, list[Path]]]: download_urls = { "train": ( "NCBItrainset_corpus.txt", @@ -2362,7 +2367,7 @@ def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Pat "https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/NCBItestset_corpus.zip", ), } - results_files: Dict[str, Union[Path, List[Path]]] = {} + results_files: dict[str, Union[Path, list[Path]]] = {} for split, (filename, url) in download_urls.items(): result_path = data_folder / filename @@ -2376,7 +2381,7 @@ def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Pat return results_files - def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: + def _file_to_dicts(self, filepath: Path) -> Iterator[dict[str, Any]]: with open(filepath) as f: for doc in pubtator.iterparse(f): unified_example = { @@ -2449,7 +2454,7 @@ class BIGBIO_EL_BC5CDR_CHEMICAL(BigBioEntityLinkingCorpus): def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str = "el-chemical", **kwargs) -> None: super().__init__(base_path, label_type, **kwargs) - def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]: + def _download_dataset(self, data_folder: Path) -> dict[str, Union[Path, list[Path]]]: url = "https://huggingface.co/datasets/bigbio/bc5cdr/resolve/main/CDR_Data.zip" path = cached_path(url, data_folder) @@ -2458,7 +2463,7 @@ def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Pat unpack_file(path, data_folder) assert data_folder.exists() - results_files: Dict[str, Union[Path, List[Path]]] = { + results_files: dict[str, Union[Path, list[Path]]] = { "train": data_path / "CDR_TrainingSet.BioC.xml", "dev": data_path / "CDR_DevelopmentSet.BioC.xml", "test": data_path / "CDR_TestSet.BioC.xml", @@ 
-2497,7 +2502,7 @@ def _get_bioc_entity(self, span, db_id_key="MESH"): "normalized": normalized, } - def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: + def _file_to_dicts(self, filepath: Path) -> Iterator[dict[str, Any]]: reader = biocxml.BioCXMLDocumentReader(str(filepath)) for i, xdoc in enumerate(reader): @@ -2542,7 +2547,7 @@ def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str self._re_tax_id = re.compile(r"(?P\d+)\([tT]ax:(?P\d+)\)") super().__init__(base_path, label_type, norm_keys=["db_id"], **kwargs) - def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]: + def _download_dataset(self, data_folder: Path) -> dict[str, Union[Path, list[Path]]]: url = "https://www.ncbi.nlm.nih.gov/CBBresearch/Lu/Demo/tmTools/download/GNormPlus/GNormPlusCorpus.zip" path = cached_path(url, data_folder) @@ -2551,7 +2556,7 @@ def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Pat unpack_file(path, data_folder) assert data_folder.exists() - results_files: Dict[str, Union[Path, List[Path]]] = { + results_files: dict[str, Union[Path, list[Path]]] = { "train": [data_path / "BC2GNtrain.BioC.xml", data_path / "NLMIAT.BioC.xml"], "test": data_path / "BC2GNtest.BioC.xml", } @@ -2595,7 +2600,7 @@ def _parse_bioc_entity(self, span, db_id_key="NCBIGene", insert_tax_id=False): "normalized": normalized, } - def _adjust_entity_offsets(self, text: str, entities: List[Dict]): + def _adjust_entity_offsets(self, text: str, entities: list[dict]): for entity in entities: start, end = entity["offsets"][0] entity_mention = entity["text"][0] @@ -2605,7 +2610,7 @@ def _adjust_entity_offsets(self, text: str, entities: List[Dict]): elif text[start : end - 1] == entity_mention: entity["offsets"] = [(start, end - 1)] - def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: + def _file_to_dicts(self, filepath: Path) -> Iterator[dict[str, Any]]: with filepath.open("r") as f: collection = biocxml.load(f) diff --git a/flair/datasets/ocr.py b/flair/datasets/ocr.py index bf60b2b0d6..4a58e4e7d3 100644 --- a/flair/datasets/ocr.py +++ b/flair/datasets/ocr.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Dict, Optional, Union +from typing import Optional, Union import gdown.download_folder import PIL @@ -20,7 +20,7 @@ def __init__( encoding: str = "utf-8", load_images: bool = False, normalize_coords_to_thousands: bool = True, - label_name_map: Optional[Dict[str, str]] = None, + label_name_map: Optional[dict[str, str]] = None, ) -> None: """Instantiates a Dataset from a OCR-Json format. @@ -132,7 +132,7 @@ def __init__( in_memory: bool = True, load_images: bool = False, normalize_coords_to_thousands: bool = True, - label_name_map: Optional[Dict[str, str]] = None, + label_name_map: Optional[dict[str, str]] = None, **corpusargs, ) -> None: """Instantiates a Corpus from a OCR-Json format. @@ -205,7 +205,7 @@ def __init__( in_memory: bool = True, load_images: bool = False, normalize_coords_to_thousands: bool = True, - label_name_map: Optional[Dict[str, str]] = None, + label_name_map: Optional[dict[str, str]] = None, **corpusargs, ) -> None: """Instantiates the SROIE corpus with perfect ocr boxes. 
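The dominant pattern in this patch is the PEP 585 cleanup: `typing.List`, `typing.Dict`, `typing.Tuple` and friends give way to the builtin generics, which are subscriptable in annotations since Python 3.9, while `Optional` and `Union` still come from `typing`. A minimal before/after sketch; `label_counts` is a made-up helper for illustration, not flair code:

# Illustrative only; `label_counts` is not part of flair.
# Before (pre-PEP 585):
#   from typing import Dict, List, Optional, Set
#   def label_counts(labels: List[str], ignore: Optional[Set[str]] = None) -> Dict[str, int]: ...
from typing import Optional  # Optional/Union still live in typing

def label_counts(labels: list[str], ignore: Optional[set[str]] = None) -> dict[str, int]:
    """Count label frequencies, skipping an optional ignore set."""
    counts: dict[str, int] = {}
    for label in labels:
        if ignore and label in ignore:
            continue
        counts[label] = counts.get(label, 0) + 1
    return counts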
diff --git a/flair/datasets/relation_extraction.py b/flair/datasets/relation_extraction.py index 30709a14c4..871811abc2 100644 --- a/flair/datasets/relation_extraction.py +++ b/flair/datasets/relation_extraction.py @@ -5,8 +5,9 @@ import os import re from collections import defaultdict +from collections.abc import Iterable from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Optional, Union import conllu import gdown @@ -279,7 +280,7 @@ def extract_and_convert_to_conllu(self, data_file, data_folder): token_list = self._tacred_example_to_token_list(example) target_file.write(token_list.serialize()) - def _tacred_example_to_token_list(self, example: Dict[str, Any]) -> conllu.TokenList: + def _tacred_example_to_token_list(self, example: dict[str, Any]) -> conllu.TokenList: id_ = example["id"] tokens = example["token"] ner = example["stanford_ner"] @@ -379,7 +380,7 @@ def _parse_incr(self, source_file) -> Iterable[conllu.TokenList]: } metadata_parsers = {"__fallback__": lambda k, v: tuple(k.split())} - lines: List[str] = [] + lines: list[str] = [] for index, line in enumerate(source_file): if index > 0 and line.startswith("#"): source_str = "".join(lines) @@ -416,9 +417,10 @@ def convert_to_conllu(self, source_data_folder: Path, data_folder): ] for source_filename, target_filename in zip(source_filenames, target_filenames): - with (source_data_folder / source_filename).open(encoding="utf-8") as source_file, ( - data_folder / target_filename - ).open("w", encoding="utf-8") as target_file: + with ( + (source_data_folder / source_filename).open(encoding="utf-8") as source_file, + (data_folder / target_filename).open("w", encoding="utf-8") as target_file, + ): # write CoNLL-U Plus header target_file.write("# global.columns = id form ner\n") @@ -426,7 +428,7 @@ def convert_to_conllu(self, source_data_folder: Path, data_folder): token_list = self._src_token_list_to_token_list(src_token_list) target_file.write(token_list.serialize()) - def _bio_tags_to_spans(self, tags: List[str]) -> List[Tuple[int, int]]: + def _bio_tags_to_spans(self, tags: list[str]) -> list[tuple[int, int]]: spans = [] span_start = 0 span_end = 0 @@ -590,7 +592,7 @@ def extract_and_convert_to_conllu(self, data_file, data_folder): ent2 = arg2.split(":")[1] pmid_to_relations[pmid].add((rel_type, ent1, ent2)) - tokenlists: List[conllu.TokenList] = [] + tokenlists: list[conllu.TokenList] = [] with zip_file.open( f"drugprot-gs-training-development/{split}/drugprot_{split}_abstracs.tsv" ) as abstracts_file: @@ -652,13 +654,13 @@ def has_overlap(self, a, b): def drugprot_document_to_tokenlists( self, pmid: str, - title_sentences: List[Sentence], - abstract_sentences: List[Sentence], + title_sentences: list[Sentence], + abstract_sentences: list[Sentence], abstract_offset: int, - entities: Dict[str, Tuple[str, int, int, str]], - relations: Set[Tuple[str, str, str]], - ) -> List[conllu.TokenList]: - tokenlists: List[conllu.TokenList] = [] + entities: dict[str, tuple[str, int, int, str]], + relations: set[tuple[str, str, str]], + ) -> list[conllu.TokenList]: + tokenlists: list[conllu.TokenList] = [] sentence_id = 1 for offset, sents in [ (0, title_sentences), diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index 6699548498..55e50723d1 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -4,17 +4,13 @@ import os import re import shutil +import tarfile from collections import defaultdict +from 
collections.abc import Iterable, Iterator from pathlib import Path from typing import ( Any, - DefaultDict, - Dict, - Iterable, - Iterator, - List, Optional, - Tuple, Union, cast, ) @@ -224,7 +220,7 @@ def __init__( self.label_type = label_type self.path_to_json_file = path_to_json_file - self.sentences: List[Sentence] = [] + self.sentences: list[Sentence] = [] with path_to_json_file.open(encoding=encoding) as jsonl_fp: for line in jsonl_fp: current_line = json.loads(line) @@ -238,7 +234,7 @@ def __init__( self.sentences.append(current_sentence) - def _add_labels_to_sentence(self, raw_text: str, sentence: Sentence, labels: List[List[Any]]): + def _add_labels_to_sentence(self, raw_text: str, sentence: Sentence, labels: list[list[Any]]): # Add tags for each annotated span for label in labels: self._add_label_to_sentence(raw_text, sentence, label[0], label[1], label[2]) @@ -288,7 +284,7 @@ def _add_label_to_sentence(self, text: str, sentence: Sentence, start: int, end: sentence[start_idx : end_idx + 1].add_label(self.label_type, label) - def _add_metadatas_to_sentence(self, sentence: Sentence, metadatas: List[Tuple[str, str]]): + def _add_metadatas_to_sentence(self, sentence: Sentence, metadatas: list[tuple[str, str]]): # Add metadatas for sentence for metadata in metadatas: self._add_metadata_to_sentence(sentence, metadata[0], metadata[1]) @@ -313,7 +309,7 @@ def __getitem__(self, index: int) -> Sentence: class MultiFileColumnCorpus(Corpus): def __init__( self, - column_format: Dict[int, str], + column_format: dict[int, str], train_files=None, test_files=None, dev_files=None, @@ -323,8 +319,8 @@ def __init__( document_separator_token: Optional[str] = None, skip_first_line: bool = False, in_memory: bool = True, - label_name_map: Optional[Dict[str, str]] = None, - banned_sentences: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + banned_sentences: Optional[list[str]] = None, default_whitespace_after: int = 1, **corpusargs, ) -> None: @@ -424,7 +420,7 @@ class ColumnCorpus(MultiFileColumnCorpus): def __init__( self, data_folder: Union[str, Path], - column_format: Dict[int, str], + column_format: dict[int, str], train_file=None, test_file=None, dev_file=None, @@ -475,15 +471,15 @@ class ColumnDataset(FlairDataset): def __init__( self, path_to_column_file: Union[str, Path], - column_name_map: Dict[int, str], + column_name_map: dict[int, str], column_delimiter: str = r"\s+", comment_symbol: Optional[str] = None, - banned_sentences: Optional[List[str]] = None, + banned_sentences: Optional[list[str]] = None, in_memory: bool = True, document_separator_token: Optional[str] = None, encoding: str = "utf-8", skip_first_line: bool = False, - label_name_map: Optional[Dict[str, str]] = None, + label_name_map: Optional[dict[str, str]] = None, default_whitespace_after: int = 1, ) -> None: r"""Instantiates a column dataset. 
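The with-statement rewrites in this patch (e.g. `convert_to_conllu` above) adopt the parenthesized multi-context-manager form, accepted by CPython's PEG parser since 3.9 and part of the official grammar in 3.10, in place of the older implicitly continued `with a, (b):` construction. A sketch of the same shape, assuming placeholder paths rather than files from the diff:

# Sketch of the parenthesized with-statement used throughout this patch;
# src/dst are placeholders, not paths from the diff.
from pathlib import Path

def copy_header(src: Path, dst: Path, n_lines: int = 1) -> None:
    """Copy the first n_lines of src to dst; both handles close on any exit path."""
    with (
        src.open(encoding="utf-8") as source_file,
        dst.open("w", encoding="utf-8") as target_file,
    ):
        for i, line in enumerate(source_file):
            if i >= n_lines:
                break
            target_file.write(line)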
@@ -537,7 +533,7 @@ def __init__( # option 1: keep Sentence objects in memory if self.in_memory: - self.sentences: List[Sentence] = [] + self.sentences: list[Sentence] = [] # pointer to previous previous_sentence = None @@ -579,7 +575,7 @@ def __init__( # option 2: keep source data in memory if not self.in_memory: - self.sentences_raw: List[List[str]] = [] + self.sentences_raw: list[list[str]] = [] while True: # read lines for next sentence, but don't parse @@ -679,10 +675,10 @@ def _read_next_sentence(self, file): return lines def _convert_lines_to_sentence( - self, lines, word_level_tag_columns: Dict[int, str], span_level_tag_columns: Optional[Dict[int, str]] = None + self, lines, word_level_tag_columns: dict[int, str], span_level_tag_columns: Optional[dict[int, str]] = None ): token: Optional[Token] = None - tokens: List[Token] = [] + tokens: list[Token] = [] filtered_lines = [] comments = [] for line in lines: @@ -749,9 +745,9 @@ def _convert_lines_to_sentence( return sentence return None - def _parse_token(self, line: str, column_name_map: Dict[int, str], last_token: Optional[Token] = None) -> Token: + def _parse_token(self, line: str, column_name_map: dict[int, str], last_token: Optional[Token] = None) -> Token: # get fields from line - fields: List[str] = self.column_delimiter.split(line.rstrip()) + fields: list[str] = self.column_delimiter.split(line.rstrip()) field_count = len(fields) # get head_id if exists (only in dependency parses) head_id = int(fields[self.head_id_column]) if self.head_id_column else None @@ -855,7 +851,7 @@ def __init__( base_path: Optional[Union[str, Path]] = None, version: str = "v4", language: str = "english", - domain: Union[None, str, List[str], Dict[str, Union[None, str, List[str]]]] = None, + domain: Union[None, str, list[str], dict[str, Union[None, str, list[str]]]] = None, in_memory: bool = True, **corpusargs, ) -> None: @@ -893,7 +889,7 @@ def get_available_domains( version: str = "v4", language: str = "english", split: str = "train", - ) -> List[str]: + ) -> list[str]: processed_data_path = cls._ensure_data_processed(base_path=base_path, language=language, version=version) processed_split_path = processed_data_path / "splits" / version / language / split @@ -907,7 +903,7 @@ def _get_processed_file_paths( split: str = "train", version: str = "v4", language: str = "english", - domain: Optional[Union[str, List[str], Dict[str, Union[None, str, List[str]]]]] = None, + domain: Optional[Union[str, list[str], dict[str, Union[None, str, list[str]]]]] = None, ) -> Iterable[Path]: processed_split_path = processed_data_path / "splits" / version / language / split @@ -1009,8 +1005,8 @@ def _process_coref_span_annotations_for_word( cls, label: str, word_index: int, - clusters: DefaultDict[int, List[Tuple[int, int]]], - coref_stacks: DefaultDict[int, List[int]], + clusters: defaultdict[int, list[tuple[int, int]]], + coref_stacks: defaultdict[int, list[int]], ) -> None: """For a given coref label, add it to a currently open span(s), complete a span(s) or ignore it, if it is outside of all spans. 
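The coref hunk above shows that PEP 585 extends to `collections` classes as well: `typing.DefaultDict` is dropped in favor of subscripting `collections.defaultdict` directly. A runnable sketch mirroring the cluster/stack bookkeeping shapes, with made-up data:

# Sketch of annotating defaultdict directly (Python 3.9+); the cluster/span
# shapes mirror the coref bookkeeping above, but the data here is invented.
from collections import defaultdict

clusters: defaultdict[int, list[tuple[int, int]]] = defaultdict(list)
coref_stacks: defaultdict[int, list[int]] = defaultdict(list)

coref_stacks[3].append(0)        # open a span for cluster 3 at token 0
start = coref_stacks[3].pop()    # close it at token 2
clusters[3].append((start, 2))
assert clusters[3] == [(0, 2)]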
@@ -1048,9 +1044,9 @@ def _process_coref_span_annotations_for_word( @classmethod def _process_span_annotations_for_word( cls, - annotations: List[str], - span_labels: List[List[str]], - current_span_labels: List[Optional[str]], + annotations: list[str], + span_labels: list[list[str]], + current_span_labels: list[Optional[str]], ) -> None: for annotation_index, annotation in enumerate(annotations): # strip all bracketing information to @@ -1076,33 +1072,33 @@ def _process_span_annotations_for_word( current_span_labels[annotation_index] = None @classmethod - def _conll_rows_to_sentence(cls, conll_rows: List[str]) -> Dict: + def _conll_rows_to_sentence(cls, conll_rows: list[str]) -> dict: document_id: str sentence_id: int # The words in the sentence. - sentence: List[str] = [] + sentence: list[str] = [] # The pos tags of the words in the sentence. - pos_tags: List[str] = [] + pos_tags: list[str] = [] # the pieces of the parse tree. - parse_pieces: List[Optional[str]] = [] + parse_pieces: list[Optional[str]] = [] # The lemmatised form of the words in the sentence which # have SRL or word sense information. - predicate_lemmas: List[Optional[str]] = [] + predicate_lemmas: list[Optional[str]] = [] # The FrameNet ID of the predicate. - predicate_framenet_ids: List[Optional[str]] = [] + predicate_framenet_ids: list[Optional[str]] = [] # The sense of the word, if available. - word_senses: List[Optional[float]] = [] + word_senses: list[Optional[float]] = [] # The current speaker, if available. - speakers: List[Optional[str]] = [] + speakers: list[Optional[str]] = [] - verbal_predicates: List[str] = [] - span_labels: List[List[str]] = [] - current_span_labels: List[Optional[str]] = [] + verbal_predicates: list[str] = [] + span_labels: list[list[str]] = [] + current_span_labels: list[Optional[str]] = [] # Cluster id -> List of (start_index, end_index) spans. - clusters: DefaultDict[int, List[Tuple[int, int]]] = defaultdict(list) + clusters: defaultdict[int, list[tuple[int, int]]] = defaultdict(list) # Cluster id -> List of start_indices which are open for this id. - coref_stacks: DefaultDict[int, List[int]] = defaultdict(list) + coref_stacks: defaultdict[int, list[int]] = defaultdict(list) for index, row in enumerate(conll_rows): conll_components = row.split() @@ -1178,7 +1174,7 @@ def _conll_rows_to_sentence(cls, conll_rows: List[str]) -> Dict: srl_frames = list(zip(verbal_predicates, span_labels[1:])) # this would not be reached if parse_pieces contained None, hence the cast - parse_tree = "".join(cast(List[str], parse_pieces)) if all(parse_pieces) else None + parse_tree = "".join(cast(list[str], parse_pieces)) if all(parse_pieces) else None coref_span_tuples = {(cluster_id, span) for cluster_id, span_list in clusters.items() for span in span_list} return { @@ -1197,7 +1193,7 @@ def _conll_rows_to_sentence(cls, conll_rows: List[str]) -> Dict: } @classmethod - def dataset_document_iterator(cls, file_path: Union[Path, str]) -> Iterator[List]: + def dataset_document_iterator(cls, file_path: Union[Path, str]) -> Iterator[list[dict]]: """An iterator over CONLL formatted files which yields documents, regardless of the number of document annotations in a particular file. 
This is useful for conll data which has been preprocessed, such @@ -1206,7 +1202,7 @@ def dataset_document_iterator(cls, file_path: Union[Path, str]) -> Iterator[List """ with open(file_path, encoding="utf8") as open_file: conll_rows = [] - document: List = [] + document: list[dict] = [] for line in open_file: line = line.strip() if line != "" and not line.startswith("#"): @@ -1456,17 +1452,22 @@ def __init__( cached_path(f"{conll_2000_path}train.txt.gz", Path("datasets") / dataset_name) cached_path(f"{conll_2000_path}test.txt.gz", Path("datasets") / dataset_name) import gzip - import shutil - with gzip.open(flair.cache_root / "datasets" / dataset_name / "train.txt.gz", "rb") as f_in, open( - flair.cache_root / "datasets" / dataset_name / "train.txt", - "wb", - ) as f_out: + with ( + gzip.open(flair.cache_root / "datasets" / dataset_name / "train.txt.gz", "rb") as f_in, + open( + flair.cache_root / "datasets" / dataset_name / "train.txt", + "wb", + ) as f_out, + ): shutil.copyfileobj(f_in, f_out) - with gzip.open(flair.cache_root / "datasets" / dataset_name / "test.txt.gz", "rb") as f_in, open( - flair.cache_root / "datasets" / dataset_name / "test.txt", - "wb", - ) as f_out: + with ( + gzip.open(flair.cache_root / "datasets" / dataset_name / "test.txt.gz", "rb") as f_in, + open( + flair.cache_root / "datasets" / dataset_name / "test.txt", + "wb", + ) as f_out, + ): shutil.copyfileobj(f_in, f_out) super().__init__( @@ -1735,8 +1736,6 @@ def __init__( data_file = data_path / "named_ent_eu.train" if not data_file.is_file(): cached_path(f"{ner_basque_path}/eiec_v1.0.tgz", Path("datasets") / dataset_name) - import shutil - import tarfile with tarfile.open( flair.cache_root / "datasets" / dataset_name / "eiec_v1.0.tgz", @@ -2247,15 +2246,13 @@ def __init__( if not base_path: base_path = Path(flair.cache_root) / "datasets" data_folder = base_path / dataset_name - import tarfile if not os.path.isfile(data_folder / "webpages_ner.txt"): # # download zip tar_file = "https://cogcomp.seas.upenn.edu/Data/NERWebpagesColumns.tgz" webpages_ner_path = cached_path(tar_file, Path("datasets") / dataset_name) - tf = tarfile.open(webpages_ner_path) - tf.extractall(data_folder) - tf.close() + with tarfile.open(webpages_ner_path) as tf: + tf.extractall(data_folder) outputfile = os.path.abspath(data_folder) # merge the files in one as the zip is containing multiples files @@ -2538,7 +2535,7 @@ def _add_IOB_tags(self, data_file: Union[str, Path], encoding: str = "utf8", ner Specifies the ner-tagged column. The default is 1 (the second column). 
""" - def add_I_prefix(current_line: List[str], ner: int, tag: str): + def add_I_prefix(current_line: list[str], ner: int, tag: str): for i in range(len(current_line)): if i == 0: f.write(line_list[i]) @@ -2779,9 +2776,11 @@ def _create_datasets(self, data_file: Union[str, Path], data_folder: Path): train_len = round(num_lines * 0.8) test_len = round(num_lines * 0.1) - with (data_folder / "train.txt").open("w", encoding="utf-8") as train, (data_folder / "test.txt").open( - "w", encoding="utf-8" - ) as test, (data_folder / "dev.txt").open("w", encoding="utf-8") as dev: + with ( + (data_folder / "train.txt").open("w", encoding="utf-8") as train, + (data_folder / "test.txt").open("w", encoding="utf-8") as test, + (data_folder / "dev.txt").open("w", encoding="utf-8") as dev, + ): for k, line in enumerate(file.readlines(), start=1): if k <= train_len: train.write(line) @@ -2972,7 +2971,7 @@ def __prepare_jap_wikinews_corpus(file_in: Union[str, Path], file_out: Union[str class NER_MASAKHANE(MultiCorpus): def __init__( self, - languages: Union[str, List[str]] = "luo", + languages: Union[str, list[str]] = "luo", version: str = "v2", base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, @@ -3056,7 +3055,7 @@ def __init__( if languages == ["all"]: languages = list(language_to_code.values()) - corpora: List[Corpus] = [] + corpora: list[Corpus] = [] for language in languages: if language in language_to_code: language = language_to_code[language] @@ -3239,7 +3238,7 @@ def __init__( class NER_MULTI_WIKIANN(MultiCorpus): def __init__( self, - languages: Union[str, List[str]] = "en", + languages: Union[str, list[str]] = "en", base_path: Optional[Union[str, Path]] = None, in_memory: bool = False, **corpusargs, @@ -3251,7 +3250,7 @@ def __init__( Parameters ---------- - languages : Union[str, List[str]] + languages : Union[str, list[str]] Should be an abbreviation of a language ("en", "de",..) or a list of abbreviations. The datasets of all passed languages will be saved in one MultiCorpus. (Note that, even though listed on https://elisa-ie.github.io/wikiann/ some datasets are empty. 
@@ -3282,7 +3281,7 @@ def __init__( # this list is handed to the multicorpus # list that contains the columncopora - corpora: List[Corpus] = [] + corpora: list[Corpus] = [] google_drive_path = "https://drive.google.com/uc?id=" # download data if necessary @@ -3294,8 +3293,6 @@ def __init__( # if language not downloaded yet, download it if not language_folder.exists(): if first: - import tarfile - import gdown first = False @@ -3310,10 +3307,8 @@ def __init__( # unzip log.info("Extracting data...") - tar = tarfile.open(str(language_folder / language) + ".tar.gz", "r:gz") - # tar.extractall(language_folder,members=[tar.getmember(file_name)]) - tar.extract(file_name, str(language_folder)) - tar.close() + with tarfile.open(str(language_folder / language) + ".tar.gz", "r:gz") as tar: + tar.extract(file_name, str(language_folder)) log.info("...done.") # transform data into required format @@ -3342,9 +3337,10 @@ def __init__( ) def _silver_standard_to_simple_ner_annotation(self, data_file: Union[str, Path]): - with open(data_file, encoding="utf-8") as f_read, open( - str(data_file) + "_new", "w+", encoding="utf-8" - ) as f_write: + with ( + open(data_file, encoding="utf-8") as f_read, + open(str(data_file) + "_new", "w+", encoding="utf-8") as f_write, + ): while True: line = f_read.readline() if line: @@ -3660,7 +3656,7 @@ def _google_drive_id_from_language_name(self, language): class NER_MULTI_XTREME(MultiCorpus): def __init__( self, - languages: Union[str, List[str]] = "en", + languages: Union[str, list[str]] = "en", base_path: Optional[Union[str, Path]] = None, in_memory: bool = False, **corpusargs, @@ -3672,7 +3668,7 @@ def __init__( Parameters ---------- - languages : Union[str, List[str]], optional + languages : Union[str, list[str]], optional Specify the languages you want to load. Provide an empty list or string to select all languages. base_path : Union[str, Path], optional Default is None, meaning that corpus gets auto-downloaded and loaded. You can override this to point to a different folder but typically this should not be necessary. 
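These multilingual corpora all accept `languages: Union[str, list[str]]` and normalize it to a list before building one sub-corpus per language, the same `isinstance` guard WSD_UFSAC applies to `filenames` earlier in this diff. A tiny sketch of that normalization as a standalone helper (the function name is illustrative):

# Sketch of the Union[str, list[str]] normalization these corpora rely on.
from typing import Union

def normalize_languages(languages: Union[str, list[str]]) -> list[str]:
    """Accept a single language code or a list of codes; always return a list."""
    return [languages] if isinstance(languages, str) else list(languages)

assert normalize_languages("en") == ["en"]
assert normalize_languages(["en", "de"]) == ["en", "de"]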
@@ -3743,7 +3739,7 @@ def __init__( # This list is handed to the multicorpus # list that contains the columncopora - corpora: List[Corpus] = [] + corpora: list[Corpus] = [] hu_path = "https://nlp.informatik.hu-berlin.de/resources/datasets/panx_dataset" @@ -3765,12 +3761,10 @@ def __init__( # unzip log.info("Extracting data...") - import tarfile - tar = tarfile.open(str(temp_file), "r:gz") - for part in ["train", "test", "dev"]: - tar.extract(part, str(language_folder)) - tar.close() + with tarfile.open(str(temp_file), "r:gz") as tar: + for part in ["train", "test", "dev"]: + tar.extract(part, str(language_folder)) log.info("...done.") # transform data into required format @@ -3809,7 +3803,7 @@ def _xtreme_to_simple_ner_annotation(self, data_file: Union[str, Path]): class NER_MULTI_WIKINER(MultiCorpus): def __init__( self, - languages: Union[str, List[str]] = "en", + languages: Union[str, list[str]] = "en", base_path: Optional[Union[str, Path]] = None, in_memory: bool = False, **corpusargs, @@ -3828,7 +3822,7 @@ def __init__( data_folder = base_path / dataset_name - corpora: List[Corpus] = [] + corpora: list[Corpus] = [] for language in languages: language_folder = data_folder / language @@ -3868,11 +3862,14 @@ def _download_wikiner(self, language_code: str, dataset_name: str): flair.cache_root / "datasets" / dataset_name / f"aij-wikiner-{lc}-wp3.bz2", "rb", ) - with bz_file as f, open( - flair.cache_root / "datasets" / dataset_name / f"aij-wikiner-{lc}-wp3.train", - "w", - encoding="utf-8", - ) as out: + with ( + bz_file as f, + open( + flair.cache_root / "datasets" / dataset_name / f"aij-wikiner-{lc}-wp3.train", + "w", + encoding="utf-8", + ) as out, + ): for lineb in f: line = lineb.decode("utf-8") words = line.split(" ") @@ -4740,7 +4737,7 @@ def __init__( class NER_NERMUD(MultiCorpus): def __init__( self, - domains: Union[str, List[str]] = "all", + domains: Union[str, list[str]] = "all", base_path: Optional[Union[str, Path]] = None, in_memory: bool = False, **corpusargs, @@ -4779,7 +4776,7 @@ def __init__( data_folder = base_path / dataset_name - corpora: List[Corpus] = [] + corpora: list[Corpus] = [] github_path = "https://raw.githubusercontent.com/dhfbk/KIND/main/evalita-2023" @@ -4923,7 +4920,7 @@ def _set_path(cls, base_path) -> Path: return base_path @classmethod - def _load_features(cls, base_path) -> List[List[str]]: + def _load_features(cls, base_path) -> list[list[str]]: print(base_path) unpack_file(cached_path(cls.data_url, base_path), base_path, "zip", False) with open(f"{base_path}/estner.cnll", encoding="utf-8") as in_file: @@ -4932,17 +4929,17 @@ def _load_features(cls, base_path) -> List[List[str]]: return features @classmethod - def _process_clean_labels(cls, features) -> List[List[str]]: + def _process_clean_labels(cls, features) -> list[list[str]]: preinstances = [[instance[0], instance[len(instance) - 1]] for instance in features] return preinstances @classmethod - def _rmv_clean_labels(cls, features) -> List[str]: + def _rmv_clean_labels(cls, features) -> list[str]: rdcd_features = [feature[:-1] for feature in features] return rdcd_features @classmethod - def _load_noisy_labels(cls, version, base_path) -> List[str]: + def _load_noisy_labels(cls, version, base_path) -> list[str]: file_name = f"NoisyNER_labelset{version}.labels" cached_path(f"{cls.label_url}/{file_name}", base_path) with open(f"{base_path}/{file_name}", encoding="utf-8") as in_file: @@ -4950,7 +4947,7 @@ def _load_noisy_labels(cls, version, base_path) -> List[str]: return labels @classmethod - def 
_process_noisy_labels(cls, rdcd_features, labels) -> List[List[str]]: + def _process_noisy_labels(cls, rdcd_features, labels) -> list[list[str]]: instances = [] label_idx = 0 for feature in rdcd_features: @@ -4965,7 +4962,7 @@ def _process_noisy_labels(cls, rdcd_features, labels) -> List[List[str]]: return instances @classmethod - def _delete_empty_labels(cls, version, preinstances) -> List[str]: + def _delete_empty_labels(cls, version, preinstances) -> list[str]: instances = [] if version == 0: for instance in preinstances: @@ -4978,7 +4975,7 @@ def _delete_empty_labels(cls, version, preinstances) -> List[str]: return instances @classmethod - def _split_data(cls, instances) -> Tuple[List[str], List[str], List[str]]: + def _split_data(cls, instances) -> tuple[list[str], list[str], list[str]]: train = instances[:185708] dev = instances[185708:208922] test = instances[208922:] @@ -4996,7 +4993,7 @@ def _write_instances(cls, version, base_path, split, data): class MASAKHA_POS(MultiCorpus): def __init__( self, - languages: Union[str, List[str]] = "bam", + languages: Union[str, list[str]] = "bam", version: str = "v1", base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, @@ -5063,7 +5060,7 @@ def __init__( if languages == ["all"]: languages = supported_languages - corpora: List[Corpus] = [] + corpora: list[Corpus] = [] for language in languages: if language not in supported_languages: log.error(f"Language '{language}' is not in list of supported languages!") diff --git a/flair/datasets/text_image.py b/flair/datasets/text_image.py index f7baf72be9..676b078d7f 100644 --- a/flair/datasets/text_image.py +++ b/flair/datasets/text_image.py @@ -3,7 +3,6 @@ import os import urllib from pathlib import Path -from typing import List import numpy as np import torch.utils.data.dataloader @@ -40,13 +39,13 @@ def __init__(self, **kwargs) -> None: feidegger_dataset: Dataset = FeideggerDataset(dataset_info, **kwargs) - train_indices = list(np.where(np.in1d(feidegger_dataset.split, list(range(8))))[0]) # type: ignore[attr-defined] + train_indices = list(np.where(np.isin(feidegger_dataset.split, list(range(8))))[0]) # type: ignore[attr-defined] train = torch.utils.data.dataset.Subset(feidegger_dataset, train_indices) - dev_indices = list(np.where(np.in1d(feidegger_dataset.split, [8]))[0]) # type: ignore[attr-defined] + dev_indices = list(np.where(np.isin(feidegger_dataset.split, [8]))[0]) # type: ignore[attr-defined] dev = torch.utils.data.dataset.Subset(feidegger_dataset, dev_indices) - test_indices = list(np.where(np.in1d(feidegger_dataset.split, [9]))[0]) # type: ignore[attr-defined] + test_indices = list(np.where(np.isin(feidegger_dataset.split, [9]))[0]) # type: ignore[attr-defined] test = torch.utils.data.dataset.Subset(feidegger_dataset, test_indices) super().__init__(train, dev, test, name="feidegger") @@ -56,8 +55,8 @@ class FeideggerDataset(FlairDataset): def __init__(self, dataset_info, **kwargs) -> None: super().__init__() - self.data_points: List[DataPair] = [] - self.split: List[int] = [] + self.data_points: list[DataPair] = [] + self.split: list[int] = [] def identity(x): return x diff --git a/flair/datasets/text_text.py b/flair/datasets/text_text.py index 0bf0e91020..58a40d62c9 100644 --- a/flair/datasets/text_text.py +++ b/flair/datasets/text_text.py @@ -1,7 +1,7 @@ import logging import os from pathlib import Path -from typing import List, Optional, Union +from typing import Optional, Union import flair from flair.data import ( @@ -144,14 +144,15 @@ def __init__( 
self.total_sentence_count: int = 0 if self.in_memory: - self.bi_sentences: List[DataPair] = [] + self.bi_sentences: list[DataPair] = [] else: - self.source_lines: List[str] = [] - self.target_lines: List[str] = [] + self.source_lines: list[str] = [] + self.target_lines: list[str] = [] - with open(str(path_to_source), encoding="utf-8") as source_file, open( - str(path_to_target), encoding="utf-8" - ) as target_file: + with ( + open(str(path_to_source), encoding="utf-8") as source_file, + open(str(path_to_target), encoding="utf-8") as target_file, + ): source_line = source_file.readline() target_line = target_file.readline() @@ -204,7 +205,7 @@ class DataPairCorpus(Corpus): def __init__( self, data_folder: Union[str, Path], - columns: List[int] = [0, 1, 2], + columns: list[int] = [0, 1, 2], train_file=None, test_file=None, dev_file=None, @@ -318,7 +319,7 @@ class DataPairDataset(FlairDataset): def __init__( self, path_to_data: Union[str, Path], - columns: List[int] = [0, 1, 2], + columns: list[int] = [0, 1, 2], max_tokens_per_doc=-1, max_chars_per_doc=-1, use_tokenizer=True, @@ -368,11 +369,11 @@ def __init__( self.total_data_count: int = 0 if self.in_memory: - self.data_pairs: List[DataPair] = [] + self.data_pairs: list[DataPair] = [] else: - self.first_elements: List[str] = [] - self.second_elements: List[str] = [] - self.labels: List[Optional[str]] = [] + self.first_elements: list[str] = [] + self.second_elements: list[str] = [] + self.labels: list[Optional[str]] = [] with open(str(path_to_data), encoding=encoding) as source_file: source_line = source_file.readline() @@ -448,7 +449,7 @@ class DataTripleCorpus(Corpus): def __init__( self, data_folder: Union[str, Path], - columns: List[int] = [0, 1, 2, 3], + columns: list[int] = [0, 1, 2, 3], train_file=None, test_file=None, dev_file=None, @@ -563,7 +564,7 @@ class DataTripleDataset(FlairDataset): def __init__( self, path_to_data: Union[str, Path], - columns: List[int] = [0, 1, 2, 3], + columns: list[int] = [0, 1, 2, 3], max_tokens_per_doc=-1, max_chars_per_doc=-1, use_tokenizer=True, @@ -614,12 +615,12 @@ def __init__( self.total_data_count: int = 0 if self.in_memory: - self.data_triples: List[DataTriple] = [] + self.data_triples: list[DataTriple] = [] else: - self.first_elements: List[str] = [] - self.second_elements: List[str] = [] - self.third_elements: List[str] = [] - self.labels: List[Optional[str]] = [] + self.first_elements: list[str] = [] + self.second_elements: list[str] = [] + self.third_elements: list[str] = [] + self.labels: list[Optional[str]] = [] with open(str(path_to_data), encoding=encoding) as source_file: source_line = source_file.readline() @@ -828,9 +829,10 @@ def __init__( str(data_folder / "MNLI" / temp_file), ) - with open(data_folder / "MNLI" / dev_filename, "a", encoding="utf-8") as out_file, open( - data_folder / "MNLI" / temp_file, encoding="utf-8" - ) as in_file: + with ( + open(data_folder / "MNLI" / dev_filename, "a", encoding="utf-8") as out_file, + open(data_folder / "MNLI" / temp_file, encoding="utf-8") as in_file, + ): for line in in_file: fields = line.split("\t") reordered_columns = "\t".join(fields[column_id] for column_id in range(11)) diff --git a/flair/datasets/treebanks.py b/flair/datasets/treebanks.py index ed0f0135cd..21ae327691 100644 --- a/flair/datasets/treebanks.py +++ b/flair/datasets/treebanks.py @@ -1,7 +1,7 @@ import logging import re from pathlib import Path -from typing import List, Optional, Union +from typing import Optional, Union import flair from flair.data import Corpus, 
FlairDataset, Sentence, Token @@ -82,7 +82,7 @@ def __init__( with open(str(self.path_to_conll_file), encoding="utf-8") as file: # option 1: read only sentence boundaries as offset positions if not self.in_memory: - self.indices: List[int] = [] + self.indices: list[int] = [] line = file.readline() position = 0 @@ -97,7 +97,7 @@ def __init__( # option 2: keep everything in memory if self.in_memory: - self.sentences: List[Sentence] = [] + self.sentences: list[Sentence] = [] while True: sentence = self._read_next_sentence(file) @@ -129,7 +129,7 @@ def __getitem__(self, index: int = 0) -> Sentence: def _read_next_sentence(self, file) -> Optional[Sentence]: line = file.readline() - tokens: List[Token] = [] + tokens: list[Token] = [] # current token ID token_idx = 0 @@ -143,7 +143,7 @@ def _read_next_sentence(self, file) -> Optional[Sentence]: newline_reached = False while line: line = line.strip() - fields: List[str] = re.split("\t+", line) + fields: list[str] = re.split("\t+", line) # end of sentence if line == "": diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 154f2600be..294b41fac8 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -1,7 +1,8 @@ import inspect import logging from abc import abstractmethod -from typing import Any, Dict, Generic, List, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Generic, Union import torch from torch.nn import Parameter, ParameterList @@ -37,7 +38,7 @@ def embedding_length(self) -> int: def embedding_type(self) -> str: raise NotImplementedError - def embed(self, data_points: Union[DT, List[DT]]) -> List[DT]: + def embed(self, data_points: Union[DT, list[DT]]) -> list[DT]: """Add embeddings to all words in a list of sentences. If embeddings are already added, updates only if embeddings are non-static. @@ -55,10 +56,10 @@ def _everything_embedded(self, data_points: Sequence[DT]) -> bool: return all(self.name in data_point._embeddings for data_point in data_points) @abstractmethod - def _add_embeddings_internal(self, sentences: List[DT]): + def _add_embeddings_internal(self, sentences: list[DT]): """Private method for adding embeddings to all words in a list of sentences.""" - def get_names(self) -> List[str]: + def get_names(self) -> list[str]: """Returns a list of embedding names. In most cases, it is just a list with one item, namely the name of @@ -67,9 +68,6 @@ def get_names(self) -> List[str]: """ return [self.name] - def get_named_embeddings_dict(self) -> Dict: - return {self.name: self} - @staticmethod def get_instance_parameters(locals: dict) -> dict: class_definition = locals.get("__class__") @@ -84,14 +82,14 @@ def get_instance_parameters(locals: dict) -> dict: return instance_parameters @classmethod - def from_params(cls, params: Dict[str, Any]) -> "Embeddings": + def from_params(cls, params: dict[str, Any]) -> "Embeddings": raise NotImplementedError - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: raise NotImplementedError @classmethod - def load_embedding(cls, params: Dict[str, Any]): + def load_embedding(cls, params: dict[str, Any]): state_dict = params.pop("state_dict", None) embedding = cls.from_params(params) @@ -155,7 +153,7 @@ def __init__(self, mixture_size: int, trainable: bool = False) -> None: requires_grad=trainable, ) - def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor: + def forward(self, tensors: list[torch.Tensor]) -> torch.Tensor: """Forward pass of scalar mix. 
Computes a weighted average of the ``tensors``. The input tensors an be any shape @@ -203,7 +201,7 @@ def _everything_embedded(self, data_points: Sequence[Sentence]) -> bool: return True -EMBEDDING_CLASSES: Dict[str, Type[Embeddings]] = {} +EMBEDDING_CLASSES: dict[str, type[Embeddings]] = {} def register_embeddings(*args): @@ -225,7 +223,7 @@ def _register(cls): return _register -def load_embeddings(params: Dict[str, Any]) -> Embeddings: +def load_embeddings(params: dict[str, Any]) -> Embeddings: cls_name = params.pop("__cls__") cls = EMBEDDING_CLASSES[cls_name] return cls.load_embedding(params) diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py index c1e73442e6..28867d889a 100644 --- a/flair/embeddings/document.py +++ b/flair/embeddings/document.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Optional, Union, cast import torch from sklearn.feature_extraction.text import TfidfVectorizer @@ -67,7 +67,7 @@ def create_from_state(cls, **state): class DocumentPoolEmbeddings(DocumentEmbeddings): def __init__( self, - embeddings: Union[TokenEmbeddings, List[TokenEmbeddings]], + embeddings: Union[TokenEmbeddings, list[TokenEmbeddings]], fine_tune_mode: str = "none", pooling: str = "mean", ) -> None: @@ -114,7 +114,7 @@ def __init__( def embedding_length(self) -> int: return self.__embedding_length - def embed(self, sentences: Union[List[Sentence], Sentence]): + def embed(self, sentences: Union[list[Sentence], Sentence]): """Add embeddings to every sentence in the given list of sentences. If embeddings are already added, updates only if embeddings are non-static. @@ -146,18 +146,18 @@ def embed(self, sentences: Union[List[Sentence], Sentence]): sentence.set_embedding(self.name, pooled_embedding) - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): pass def extra_repr(self): return f"fine_tune_mode={self.fine_tune_mode}, pooling={self.pooling}" @classmethod - def from_params(cls, params: Dict[str, Any]) -> "DocumentPoolEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "DocumentPoolEmbeddings": embeddings = cast(StackedEmbeddings, load_embeddings(params.pop("embeddings"))).embeddings return cls(embeddings=embeddings, **params) - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: return { "pooling": self.pooling, "fine_tune_mode": self.fine_tune_mode, @@ -169,7 +169,7 @@ def to_params(self) -> Dict[str, Any]: class DocumentTFIDFEmbeddings(DocumentEmbeddings): def __init__( self, - train_dataset: List[Sentence], + train_dataset: list[Sentence], vectorizer: Optional[TfidfVectorizer] = None, **vectorizer_params, ) -> None: @@ -203,7 +203,7 @@ def __init__( def embedding_length(self) -> int: return self.__embedding_length - def embed(self, sentences: Union[List[Sentence], Sentence]): + def embed(self, sentences: Union[list[Sentence], Sentence]): """Add embeddings to every sentence in the given list of sentences.""" # if only one sentence is passed, convert to list of sentence if isinstance(sentences, Sentence): @@ -215,14 +215,14 @@ def embed(self, sentences: Union[List[Sentence], Sentence]): for sentence_id, sentence in enumerate(sentences): sentence.set_embedding(self.name, tfidf_vectors[sentence_id]) - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): pass @classmethod - def from_params(cls, params: Dict[str, Any]) -> 
"DocumentTFIDFEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "DocumentTFIDFEmbeddings": return cls(train_dataset=[], vectorizer=params["vectorizer"]) - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: return { "vectorizer": self.vectorizer, } @@ -232,7 +232,7 @@ def to_params(self) -> Dict[str, Any]: class DocumentRNNEmbeddings(DocumentEmbeddings): def __init__( self, - embeddings: List[TokenEmbeddings], + embeddings: list[TokenEmbeddings], hidden_size=128, rnn_layers=1, reproject_words: bool = True, @@ -317,7 +317,7 @@ def __init__( def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): """Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update only if embeddings are non-static. @@ -332,7 +332,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): # embed words in the sentence self.embeddings.embed(sentences) - lengths: List[int] = [len(sentence.tokens) for sentence in sentences] + lengths: list[int] = [len(sentence.tokens) for sentence in sentences] longest_token_sequence_in_batch: int = max(lengths) pre_allocated_zero_tensor = torch.zeros( @@ -341,7 +341,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): device=flair.device, ) - all_embs: List[torch.Tensor] = [] + all_embs: list[torch.Tensor] = [] for sentence in sentences: all_embs += [emb for token in sentence for emb in token.get_each_embedding()] nb_padding_tokens = longest_token_sequence_in_batch - len(sentence) @@ -436,7 +436,7 @@ def to_params(self): return model_state @classmethod - def from_params(cls, params: Dict[str, Any]) -> "DocumentRNNEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "DocumentRNNEmbeddings": stacked_embeddings = load_embeddings(params["embeddings"]) assert isinstance(stacked_embeddings, StackedEmbeddings) return cls( @@ -484,7 +484,7 @@ def __setstate__(self, d): @register_embeddings class DocumentLMEmbeddings(DocumentEmbeddings): - def __init__(self, flair_embeddings: List[FlairEmbeddings]) -> None: + def __init__(self, flair_embeddings: list[FlairEmbeddings]) -> None: super().__init__() self.embeddings = flair_embeddings @@ -503,7 +503,7 @@ def __init__(self, flair_embeddings: List[FlairEmbeddings]) -> None: def embedding_length(self) -> int: return self._embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): for embedding in self.embeddings: embedding.embed(sentences) @@ -520,17 +520,17 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): return sentences - def get_names(self) -> List[str]: + def get_names(self) -> list[str]: if "__names" not in self.__dict__: self.__names = [name for embedding in self.embeddings for name in embedding.get_names()] return self.__names - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: return {"flair_embeddings": [embedding.save_embeddings(False) for embedding in self.embeddings]} @classmethod - def from_params(cls, params: Dict[str, Any]) -> "DocumentLMEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "DocumentLMEmbeddings": return cls([cast(FlairEmbeddings, load_embeddings(embedding)) for embedding in params["flair_embeddings"]]) @@ -566,7 +566,7 @@ def __init__( self.static_embeddings = True self.eval() - def _add_embeddings_internal(self, 
sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: sentence_batches = [ sentences[i * self.batch_size : (i + 1) * self.batch_size] for i in range((len(sentences) + self.batch_size - 1) // self.batch_size) @@ -577,7 +577,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: return sentences - def _add_embeddings_to_sentences(self, sentences: List[Sentence]): + def _add_embeddings_to_sentences(self, sentences: list[Sentence]): # convert to plain strings, embedded in a list for the encode function sentences_plain_text = [sentence.to_plain_string() for sentence in sentences] @@ -591,10 +591,10 @@ def embedding_length(self) -> int: return self.model.get_sentence_embedding_dimension() @classmethod - def from_params(cls, params: Dict[str, Any]) -> "SentenceTransformerDocumentEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "SentenceTransformerDocumentEmbeddings": return cls(**params) - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: return { "model": self.model_name, "batch_size": self.batch_size, @@ -605,7 +605,7 @@ def to_params(self) -> Dict[str, Any]: class DocumentCNNEmbeddings(DocumentEmbeddings): def __init__( self, - embeddings: List[TokenEmbeddings], + embeddings: list[TokenEmbeddings], kernels=((100, 3), (100, 4), (100, 5)), reproject_words: bool = True, reproject_words_dimension: Optional[int] = None, @@ -673,7 +673,7 @@ def __init__( def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): """Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update only if embeddings are non-static. 
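Stepping back to the base.py changes above: every embedding class serializes through a plain `dict[str, Any]` via the `to_params`/`from_params` pair, and `load_embeddings` resolves the concrete class from a registry keyed by a `__cls__` entry. A stripped-down, self-contained sketch of that contract; `ToyEmbeddings` and this registry are illustrative stand-ins, not flair's real classes:

# Minimal sketch of the register/serialize/reload contract, under the
# assumption that a "__cls__" key names the registered class (as in base.py).
from typing import Any

EMBEDDING_CLASSES: dict[str, type] = {}

def register(cls: type) -> type:
    """Make a class reconstructible by name."""
    EMBEDDING_CLASSES[cls.__name__] = cls
    return cls

@register
class ToyEmbeddings:
    def __init__(self, dim: int = 8) -> None:
        self.dim = dim

    @classmethod
    def from_params(cls, params: dict[str, Any]) -> "ToyEmbeddings":
        return cls(**params)

    def to_params(self) -> dict[str, Any]:
        return {"dim": self.dim}

def load_embedding(params: dict[str, Any]) -> Any:
    cls = EMBEDDING_CLASSES[params.pop("__cls__")]
    return cls.from_params(params)

params = {"__cls__": "ToyEmbeddings", **ToyEmbeddings(dim=16).to_params()}
restored = load_embedding(params)
assert restored.dim == 16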
@@ -689,7 +689,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): # embed words in the sentence self.embeddings.embed(sentences) - lengths: List[int] = [len(sentence.tokens) for sentence in sentences] + lengths: list[int] = [len(sentence.tokens) for sentence in sentences] padding_length: int = max(max(lengths), self.min_sequence_length) pre_allocated_zero_tensor = torch.zeros( @@ -698,7 +698,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): device=flair.device, ) - all_embs: List[torch.Tensor] = [] + all_embs: list[torch.Tensor] = [] for sentence in sentences: all_embs += [emb for token in sentence for emb in token.get_each_embedding()] nb_padding_tokens = padding_length - len(sentence) @@ -757,11 +757,11 @@ def _apply(self, fn): child_module._apply(fn) @classmethod - def from_params(cls, params: Dict[str, Any]) -> "DocumentCNNEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "DocumentCNNEmbeddings": embeddings = cast(StackedEmbeddings, load_embeddings(params.pop("embeddings"))).embeddings return cls(embeddings=embeddings, **params) - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: return { "embeddings": self.embeddings.save_embeddings(False), "kernels": self.kernels, diff --git a/flair/embeddings/image.py b/flair/embeddings/image.py index df6d1fadd9..5d79a04390 100644 --- a/flair/embeddings/image.py +++ b/flair/embeddings/image.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch import torch.nn.functional as F @@ -29,12 +29,12 @@ class ImageEmbeddings(Embeddings[Image]): def embedding_type(self) -> str: return "image-level" - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: # legacy pickle-like saving for image embeddings, as implementation details are not obvious return self.__getstate__() @classmethod - def from_params(cls, params: Dict[str, Any]) -> "Embeddings": + def from_params(cls, params: dict[str, Any]) -> "Embeddings": # legacy pickle-like loading for image embeddings, as implementation details are not obvious embedding = cls.__new__(cls) embedding.__setstate__(params) @@ -53,7 +53,7 @@ def __init__(self, transforms) -> None: self.static_embeddings = True super().__init__() - def _add_embeddings_internal(self, images: List[Image]): + def _add_embeddings_internal(self, images: list[Image]): for image in images: image_data = self.PIL.Image.open(image.imageURL) image_data.load() @@ -77,7 +77,7 @@ def __init__(self, url2tensor_dict, name) -> None: self.static_embeddings = True super().__init__() - def _add_embeddings_internal(self, images: List[Image]): + def _add_embeddings_internal(self, images: list[Image]): for image in images: if image.imageURL in self.url2tensor_dict: image.set_embedding(self.name, self.url2tensor_dict[image.imageURL]) @@ -137,7 +137,7 @@ def __init__(self, name, pretrained=True, transforms=None) -> None: else: raise Exception(f"Image embeddings {name} not available.") - def _add_embeddings_internal(self, images: List[Image]): + def _add_embeddings_internal(self, images: list[Image]): image_tensor = torch.stack([self.transforms(image.data) for image in images]) image_embeddings = self.features(image_tensor) image_embeddings = ( @@ -163,7 +163,7 @@ def __init__(self, feats_in, convnet_parms, posnet_parms, transformer_parms) -> adaptive_pool_func_map = {"max": AdaptiveMaxPool2d, "avg": AdaptiveAvgPool2d} - convnet_arch: List[Any] = [] if convnet_parms["dropout"][0] <= 0 else 
[Dropout2d(convnet_parms["dropout"][0])] + convnet_arch: list[Any] = [] if convnet_parms["dropout"][0] <= 0 else [Dropout2d(convnet_parms["dropout"][0])] convnet_arch.extend( [ Conv2d( @@ -266,7 +266,7 @@ def forward(self, x): return x - def _add_embeddings_internal(self, images: List[Image]): + def _add_embeddings_internal(self, images: list[Image]): image_tensor = torch.stack([image.data for image in images]) image_embeddings = self.forward(image_tensor) for image_id, image in enumerate(images): diff --git a/flair/embeddings/legacy.py b/flair/embeddings/legacy.py index b2658e2d2f..4b3d2a9517 100644 --- a/flair/embeddings/legacy.py +++ b/flair/embeddings/legacy.py @@ -1,7 +1,7 @@ import logging import re from pathlib import Path -from typing import List, Optional, Union +from typing import Optional, Union import torch from deprecated.sphinx import deprecated @@ -110,12 +110,12 @@ def use_layers_top(self, x): def use_layers_average(self, x): return torch.mean(torch.stack(x), 0) - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: # ELMoEmbeddings before Release 0.5 did not set self.embedding_mode_fn if not getattr(self, "embedding_mode_fn", None): self.embedding_mode_fn = self.use_layers_all - sentence_words: List[List[str]] = [] + sentence_words: list[list[str]] = [] for sentence in sentences: sentence_words.append([token.text for token in sentence]) @@ -394,7 +394,7 @@ def __getstate__(self): def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: # if cache is used, try setting embeddings from cache first if "cache" in self.__dict__ and self.cache is not None: # try populating embeddings from cache @@ -463,7 +463,7 @@ class DocumentMeanEmbeddings(DocumentEmbeddings): version="0.3.1", reason="The functionality of this class is moved to 'DocumentPoolEmbeddings'", ) - def __init__(self, token_embeddings: List[TokenEmbeddings]) -> None: + def __init__(self, token_embeddings: list[TokenEmbeddings]) -> None: """The constructor takes a list of embeddings to be combined.""" super().__init__() @@ -478,7 +478,7 @@ def __init__(self, token_embeddings: List[TokenEmbeddings]) -> None: def embedding_length(self) -> int: return self.__embedding_length - def embed(self, sentences: Union[List[Sentence], Sentence]): + def embed(self, sentences: Union[list[Sentence], Sentence]): """Add embeddings to every sentence in the given list of sentences. If embeddings are already added, updates only if embeddings are non-static. """ @@ -506,7 +506,7 @@ def embed(self, sentences: Union[List[Sentence], Sentence]): sentence.set_embedding(self.name, mean_embedding) - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): pass @@ -517,7 +517,7 @@ class DocumentLSTMEmbeddings(DocumentEmbeddings): ) def __init__( self, - embeddings: List[TokenEmbeddings], + embeddings: list[TokenEmbeddings], hidden_size=128, rnn_layers=1, reproject_words: bool = True, @@ -587,7 +587,7 @@ def __init__( def embedding_length(self) -> int: return self.__embedding_length - def embed(self, sentences: Union[List[Sentence], Sentence]): + def embed(self, sentences: Union[list[Sentence], Sentence]): """Add embeddings to all sentences in the given list of sentences. 
If embeddings are already added, update only if embeddings are non-static. """ @@ -604,7 +604,7 @@ def embed(self, sentences: Union[List[Sentence], Sentence]): longest_token_sequence_in_batch: int = len(sentences[0]) all_sentence_tensors = [] - lengths: List[int] = [] + lengths: list[int] = [] # go through each sentence in batch for _i, sentence in enumerate(sentences): @@ -669,5 +669,5 @@ def embed(self, sentences: Union[List[Sentence], Sentence]): sentence = sentences[sentence_no] sentence.set_embedding(self.name, embedding) - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): pass diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py index b068305800..700eaf4c45 100644 --- a/flair/embeddings/token.py +++ b/flair/embeddings/token.py @@ -4,7 +4,7 @@ import tempfile from collections import Counter from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import numpy as np import torch @@ -64,7 +64,7 @@ def create_from_state(cls, **state): class StackedEmbeddings(TokenEmbeddings): """A stack of embeddings, used if you need to combine several different embedding types.""" - def __init__(self, embeddings: List[TokenEmbeddings], overwrite_names: bool = True) -> None: + def __init__(self, embeddings: list[TokenEmbeddings], overwrite_names: bool = True) -> None: """The constructor takes a list of embeddings to be combined.""" super().__init__() @@ -88,7 +88,7 @@ def __init__(self, embeddings: List[TokenEmbeddings], overwrite_names: bool = Tr self.__embedding_length += embedding.embedding_length self.eval() - def embed(self, sentences: Union[Sentence, List[Sentence]], static_embeddings: bool = True): + def embed(self, sentences: Union[Sentence, list[Sentence]], static_embeddings: bool = True): # if only one sentence is passed, convert to list of sentence if type(sentences) is Sentence: sentences = [sentences] @@ -104,7 +104,7 @@ def embedding_type(self) -> str: def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: for embedding in self.embeddings: embedding._add_embeddings_internal(sentences) @@ -113,7 +113,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: def __str__(self) -> str: return f'StackedEmbeddings [{",".join([str(e) for e in self.embeddings])}]' - def get_names(self) -> List[str]: + def get_names(self) -> list[str]: """Returns a list of embedding names. In most cases, it is just a list with one item, namely the name of this embedding. 
But in some cases, the @@ -126,13 +126,6 @@ def get_names(self) -> List[str]: return self.__names - def get_named_embeddings_dict(self) -> Dict: - named_embeddings_dict = {} - for embedding in self.embeddings: - named_embeddings_dict.update(embedding.get_named_embeddings_dict()) - - return named_embeddings_dict - @classmethod def from_params(cls, params): embeddings = [load_embeddings(p) for p in params["embeddings"]] @@ -154,7 +147,7 @@ def __init__( force_cpu: bool = True, stable: bool = False, no_header: bool = False, - vocab: Optional[Dict[str, int]] = None, + vocab: Optional[dict[str, int]] = None, embedding_length: Optional[int] = None, name: Optional[str] = None, ) -> None: @@ -334,10 +327,10 @@ def get_cached_token_index(self, word: str) -> int: else: return len(self.vocab) # token - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: tokens = [token for sentence in sentences for token in sentence.tokens] - word_indices: List[int] = [] + word_indices: list[int] = [] for token in tokens: word = token.text if self.field is None else token.get_label(self.field).value word_indices.append(self.get_cached_token_index(word)) @@ -386,7 +379,7 @@ def __getattribute__(self, item): return None return super().__getattribute__(item) - def __setstate__(self, state: Dict[str, Any]): + def __setstate__(self, state: dict[str, Any]): state.pop("get_cached_vec", None) state.setdefault("embeddings", state["name"]) state.setdefault("force_cpu", True) @@ -416,10 +409,10 @@ def __setstate__(self, state: Dict[str, Any]): super().__setstate__(state) @classmethod - def from_params(cls, params: Dict[str, Any]) -> "WordEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "WordEmbeddings": return cls(embeddings=None, **params) - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: return { "vocab": self.vocab, "stable": self.stable, @@ -487,7 +480,7 @@ def __init__( def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): for sentence in sentences: tokens_char_indices = [] @@ -544,10 +537,10 @@ def __str__(self) -> str: return self.name @classmethod - def from_params(cls, params: Dict[str, Any]) -> "CharacterEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "CharacterEmbeddings": return cls(**params) - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: return { "path_to_char_dict": self.char_dictionary, "char_embedding_dim": self.char_embedding_dim, @@ -793,7 +786,7 @@ def train(self, mode=True): def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: # gradients are enable if fine-tuning is enabled gradient_context = torch.enable_grad() if self.fine_tune else torch.no_grad() @@ -885,7 +878,7 @@ def from_params(cls, params): lm = LanguageModel(**model_params) return cls(lm, **params) - def __setstate__(self, d: Dict[str, Any]): + def __setstate__(self, d: dict[str, Any]): # make compatible with old models d.setdefault("fine_tune", False) d.setdefault("chars_per_chunk", 512) @@ -920,8 +913,8 @@ def __init__( self.name = self.context_embeddings.name + "-context" # these fields are for the embedding memory - 
self.word_embeddings: Dict[str, torch.Tensor] = {} - self.word_count: Dict[str, int] = {} + self.word_embeddings: dict[str, torch.Tensor] = {} + self.word_count: dict[str, int] = {} # whether to add only capitalized words to memory (faster runtime and lower memory consumption) self.only_capitalized = only_capitalized @@ -940,7 +933,7 @@ def train(self, mode=True): self.word_embeddings = {} self.word_count = {} - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: self.context_embeddings.embed(sentences) # if we keep a pooling, it needs to be updated continuously @@ -989,10 +982,10 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: def embedding_length(self) -> int: return self.__embedding_length - def get_names(self) -> List[str]: + def get_names(self) -> list[str]: return [self.name, self.context_embeddings.name] - def __setstate__(self, d: Dict[str, Any]): + def __setstate__(self, d: dict[str, Any]): super().__setstate__(d) if flair.device.type != "cpu": @@ -1073,7 +1066,7 @@ def get_cached_vec(self, word: str) -> torch.Tensor: word_embedding = torch.tensor(word_embedding.tolist(), device=flair.device, dtype=torch.float) return word_embedding - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: for sentence in sentences: for token in sentence.tokens: word = token.text if self.field is None else token.get_label(self.field).value @@ -1152,7 +1145,7 @@ def __init__( def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: tokens = [t for sentence in sentences for t in sentence.tokens] if self.field == "text": @@ -1240,7 +1233,7 @@ def num_embeddings(self) -> int: def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): def get_idx_for_item(text): hash_function = hashlib.new(self.__hash_method) hash_function.update(bytes(str(text), "utf-8")) @@ -1282,7 +1275,7 @@ def __init__( self.name: str = "muse-crosslingual" self.static_embeddings = True self.__embedding_length: int = 300 - self.language_embeddings: Dict[str, Any] = {} + self.language_embeddings: dict[str, Any] = {} (KeyedVectors,) = lazy_import("word-embeddings", "gensim.models", "KeyedVectors") self.kv = KeyedVectors super().__init__() @@ -1304,7 +1297,7 @@ def get_cached_vec(self, language_code: str, word: str) -> torch.Tensor: word_embedding = torch.tensor(word_embedding, device=flair.device, dtype=torch.float) return word_embedding - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: for _i, sentence in enumerate(sentences): language_code = sentence.get_language_code() supported = [ @@ -1465,10 +1458,10 @@ def _preprocess(self, text: str) -> str: def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: tokens = [token for sentence in sentences for token in sentence.tokens] - word_indices: 
List[List[int]] = [] + word_indices: list[list[int]] = [] for token in tokens: word = token.text if self.field is None else token.get_label(self.field).value @@ -1601,13 +1594,13 @@ def __init__(self, embeddings: str, model: str = "skip", size: int = 100) -> Non else: embeddings_path = embeddings - log.info("Reading embeddings from %s" % embeddings_path) + log.info("Reading embeddings from %s", embeddings_path) super().__init__( embeddings=str(extract_single_zip_file(embeddings_path, cache_dir=cache_dir)), name="NILC-" + embeddings ) @classmethod - def from_params(cls, params: Dict[str, Any]) -> "WordEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "WordEmbeddings": # no need to recreate as NILCEmbeddings return WordEmbeddings(embeddings=None, **params) diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py index 1e88787deb..d09ed33699 100644 --- a/flair/embeddings/transformer.py +++ b/flair/embeddings/transformer.py @@ -8,7 +8,7 @@ from abc import abstractmethod from io import BytesIO from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast +from typing import Any, Literal, Optional, Union, cast import torch import transformers @@ -44,7 +44,7 @@ @torch.jit.script_if_tracing -def pad_sequence_embeddings(all_hidden_states: List[torch.Tensor]) -> torch.Tensor: +def pad_sequence_embeddings(all_hidden_states: list[torch.Tensor]) -> torch.Tensor: embedding_length = all_hidden_states[0].shape[1] longest_token_sequence_in_batch = 0 for hidden_states in all_hidden_states: @@ -218,13 +218,13 @@ def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: def _legacy_reconstruct_word_ids( - embedding: "TransformerBaseEmbeddings", flair_tokens: List[List[str]] -) -> List[List[Optional[int]]]: + embedding: "TransformerBaseEmbeddings", flair_tokens: list[list[str]] +) -> list[list[Optional[int]]]: word_ids_list = [] max_len = 0 for tokens in flair_tokens: token_texts = embedding.tokenizer.tokenize(" ".join(tokens), is_split_into_words=True) - token_ids = cast(List[int], embedding.tokenizer.convert_tokens_to_ids(token_texts)) + token_ids = cast(list[int], embedding.tokenizer.convert_tokens_to_ids(token_texts)) expanded_token_ids = embedding.tokenizer.build_inputs_with_special_tokens(token_ids) j = 0 for _i, token_id in enumerate(token_ids): @@ -264,10 +264,10 @@ def _get_processed_token_text(tokenizer, token: str) -> str: return token_text.strip() -def _reconstruct_word_ids_from_subtokens(embedding, tokens: List[str], subtokens: List[str]): +def _reconstruct_word_ids_from_subtokens(embedding, tokens: list[str], subtokens: list[str]): word_iterator = iter(enumerate(_get_processed_token_text(embedding.tokenizer, token) for token in tokens)) token_id, token_text = next(word_iterator) - word_ids: List[Optional[int]] = [] + word_ids: list[Optional[int]] = [] reconstructed_token = "" subtoken_count = 0 processed_first_token = False @@ -504,10 +504,10 @@ def embedding_type(self) -> str: return "word-level" if self.token_embedding else "sentence-level" @abstractmethod - def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]: + def _forward_tensors(self, tensors) -> dict[str, torch.Tensor]: return self(**tensors) - def prepare_tensors(self, sentences: List[Sentence], device: Optional[torch.device] = None): + def prepare_tensors(self, sentences: list[Sentence], device: Optional[torch.device] = None): if device is None: device = flair.device flair_tokens, offsets, lengths = 
self.__gather_flair_tokens(sentences)
@@ -535,13 +535,13 @@ def prepare_tensors(self, sentences: List[Sentence], device: Optional[torch.devi

     def __build_transformer_model_inputs(
         self,
-        sentences: List[Sentence],
-        offsets: List[int],
-        sentence_lengths: List[int],
-        flair_tokens: List[List[Token]],
+        sentences: list[Sentence],
+        offsets: list[int],
+        sentence_lengths: list[int],
+        flair_tokens: list[list[Token]],
         device: torch.device,
     ):
-        tokenizer_kwargs: Dict[str, Any] = {}
+        tokenizer_kwargs: dict[str, Any] = {}
         if self.tokenizer_needs_ocr_boxes:
             tokenizer_kwargs["boxes"] = [[t.get_metadata("bbox") for t in tokens] for tokens in flair_tokens]
         else:
@@ -662,7 +662,7 @@ def __build_transformer_model_inputs(

         return model_kwargs

-    def __gather_flair_tokens(self, sentences: List[Sentence]) -> Tuple[List[List[Token]], List[int], List[int]]:
+    def __gather_flair_tokens(self, sentences: list[Sentence]) -> tuple[list[list[Token]], list[int], list[int]]:
         offsets = []
         lengths = []
         if self.context_length > 0:
@@ -686,7 +686,7 @@ def __gather_flair_tokens(self, sentences: List[Sentence]) -> Tuple[List[List[To
             lengths.append(len(sentence))
         return sentence_tokens, offsets, lengths

-    def _expand_sentence_with_context(self, sentence) -> Tuple[List[Token], int]:
+    def _expand_sentence_with_context(self, sentence) -> tuple[list[Token], int]:
         # fields to store left and right context
         left_context = []
         right_context = []
@@ -722,7 +722,7 @@ def __extract_token_embeddings(self, sentence_embeddings, sentences):
             for token_embedding, token in zip(token_embeddings, sentence):
                 token.set_embedding(self.name, token_embedding)

-    def _add_embeddings_internal(self, sentences: List[Sentence]):
+    def _add_embeddings_internal(self, sentences: list[Sentence]):
         tensors = self.prepare_tensors(sentences, device=self.force_device)
         gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad()
         with gradient_context:
@@ -739,7 +739,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):

 @register_embeddings
 class TransformerOnnxEmbeddings(TransformerBaseEmbeddings):
-    def __init__(self, onnx_model: str, providers: List = [], session_options: Optional[Dict] = None, **kwargs) -> None:
+    def __init__(self, onnx_model: str, providers: list = [], session_options: Optional[dict] = None, **kwargs) -> None:
         # onnx prepares numpy arrays, no matter if it runs on gpu or cpu, the input is on cpu first.
super().__init__(**kwargs, force_device=torch.device("cpu")) self.onnx_model = onnx_model @@ -756,7 +756,7 @@ def to_params(self): return params @classmethod - def from_params(cls, params: Dict[str, Any]) -> "TransformerOnnxEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "TransformerOnnxEmbeddings": params["tokenizer"] = cls._tokenizer_from_bytes(params.pop("tokenizer_data")) params["feature_extractor"] = cls._feature_extractor_from_bytes(params.pop("feature_extractor_data", None)) return cls(**params) @@ -812,7 +812,7 @@ def quantize_model(self, quantize_model_path, use_external_data_format: bool = F self.onnx_model = quantize_model_path self.create_session() - def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]: + def _forward_tensors(self, tensors) -> dict[str, torch.Tensor]: input_array = {k: v.numpy() for k, v in tensors.items()} embeddings = self.session.run([], input_array) @@ -854,9 +854,9 @@ def export_from_embedding( cls, path: Union[str, Path], embedding: "TransformerEmbeddings", - example_sentences: List[Sentence], + example_sentences: list[Sentence], opset_version: int = 14, - providers: Optional[List] = None, + providers: Optional[list] = None, session_options: Optional[dict] = None, ): path = str(path) @@ -903,7 +903,7 @@ def export_from_embedding( @register_embeddings class TransformerJitEmbeddings(TransformerBaseEmbeddings): - def __init__(self, jit_model: Union[bytes, ScriptModule], param_names: List[str], **kwargs) -> None: + def __init__(self, jit_model: Union[bytes, ScriptModule], param_names: list[str], **kwargs) -> None: super().__init__(**kwargs) if isinstance(jit_model, bytes): buffer = BytesIO(jit_model) @@ -925,12 +925,12 @@ def to_params(self): return state @classmethod - def from_params(cls, params: Dict[str, Any]) -> "Embeddings": + def from_params(cls, params: dict[str, Any]) -> "Embeddings": params["tokenizer"] = cls._tokenizer_from_bytes(params.pop("tokenizer_data")) params["feature_extractor"] = cls._feature_extractor_from_bytes(params.pop("feature_extractor_data", None)) return cls(**params) - def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]: + def _forward_tensors(self, tensors) -> dict[str, torch.Tensor]: parameters = [] for param in self.param_names: parameters.append(tensors[param]) @@ -945,13 +945,13 @@ def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]: raise ValueError("either 'token_embedding' or 'document_embedding' needs to be set.") @classmethod - def create_from_embedding(cls, module: ScriptModule, embedding: "TransformerEmbeddings", param_names: List[str]): + def create_from_embedding(cls, module: ScriptModule, embedding: "TransformerEmbeddings", param_names: list[str]): return cls(jit_model=module, param_names=param_names, **embedding.to_args()) @classmethod def parameter_to_list( - cls, embedding: "TransformerEmbeddings", wrapper: torch.nn.Module, sentences: List[Sentence] - ) -> Tuple[List[str], List[torch.Tensor]]: + cls, embedding: "TransformerEmbeddings", wrapper: torch.nn.Module, sentences: list[Sentence] + ) -> tuple[list[str], list[torch.Tensor]]: tensors = embedding.prepare_tensors(sentences) param_names = list(inspect.signature(wrapper.forward).parameters.keys()) params = [] @@ -998,7 +998,7 @@ def __init__( @register_embeddings class TransformerEmbeddings(TransformerBaseEmbeddings): - onnx_cls: Type[TransformerOnnxEmbeddings] = TransformerOnnxEmbeddings + onnx_cls: type[TransformerOnnxEmbeddings] = TransformerOnnxEmbeddings def __init__( self, @@ -1021,11 +1021,11 @@ def 
__init__(
         force_max_length: bool = False,
         needs_manual_ocr: Optional[bool] = None,
         use_context_separator: bool = True,
-        transformers_tokenizer_kwargs: Dict[str, Any] = {},
-        transformers_config_kwargs: Dict[str, Any] = {},
-        transformers_model_kwargs: Dict[str, Any] = {},
+        transformers_tokenizer_kwargs: dict[str, Any] = {},
+        transformers_config_kwargs: dict[str, Any] = {},
+        transformers_model_kwargs: dict[str, Any] = {},
         peft_config=None,
-        peft_gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = {},
+        peft_gradient_checkpointing_kwargs: Optional[dict[str, Any]] = {},
         **kwargs,
     ) -> None:
         """Instantiate transformers embeddings.
@@ -1503,11 +1503,11 @@ def forward(
             result["token_embeddings"] = all_token_embeddings
         return result

-    def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]:
+    def _forward_tensors(self, tensors) -> dict[str, torch.Tensor]:
         return self.forward(**tensors)

     def export_onnx(
-        self, path: Union[str, Path], example_sentences: List[Sentence], **kwargs
+        self, path: Union[str, Path], example_sentences: list[Sentence], **kwargs
     ) -> TransformerOnnxEmbeddings:
         """Export TransformerEmbeddings to OnnxFormat.

diff --git a/flair/file_utils.py b/flair/file_utils.py
index 7a0118822b..518d69e809 100644
--- a/flair/file_utils.py
+++ b/flair/file_utils.py
@@ -12,8 +12,9 @@
 import typing
 import warnings
 import zipfile
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Optional, Sequence, Tuple, Union, cast
+from typing import Optional, Union, cast
 from urllib.parse import urlparse

 import boto3
@@ -28,10 +29,10 @@
 logger = logging.getLogger("flair")

-url_proxies: Optional[typing.Dict[str, str]] = None
+url_proxies: Optional[dict[str, str]] = None

-def set_proxies(proxies: typing.Dict[str, str]) -> None:
+def set_proxies(proxies: dict[str, str]) -> None:
     r"""Allows for data downloaded from urls to be forwarded to a proxy.

     see https://requests.readthedocs.io/en/latest/user/advanced/#proxies
@@ -74,7 +75,7 @@ def url_to_filename(url: str, etag: Optional[str] = None) -> str:
     return decoded

-def filename_to_url(filename: str) -> Tuple[str, Optional[str]]:
+def filename_to_url(filename: str) -> tuple[str, Optional[str]]:
     """Recovers the url from the encoded filename.

     Returns it and the ETag (which may be ``None``)

@@ -374,7 +375,7 @@ def create_cache(self, *args, **kwargs):
     return decorator

-def load_torch_state(model_file: str) -> typing.Dict[str, typing.Any]:
+def load_torch_state(model_file: str) -> dict[str, typing.Any]:
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore")
         # load_big_file is a workaround by https://github.com/highway11git
diff --git a/flair/inference_utils.py b/flair/inference_utils.py
index 0310671534..c811bf39a1 100644
--- a/flair/inference_utils.py
+++ b/flair/inference_utils.py
@@ -126,7 +126,7 @@ def create_stores(model, backend="sqlite")

     Also deletes the original vectors to save memory.
""" for embedding in WordEmbeddingsStore._word_embeddings(model): - if type(embedding) == WordEmbeddings: + if isinstance(embedding, WordEmbeddings): WordEmbeddingsStore(embedding, backend) del embedding.precomputed_word_embeddings @@ -135,7 +135,7 @@ def load_stores(model, backend="sqlite"): """Loads the db versions of all word embeddings in the model.""" embeds = WordEmbeddingsStore._word_embeddings(model) for i, embedding in enumerate(embeds): - if type(embedding) == WordEmbeddings: + if isinstance(embedding, WordEmbeddings): embeds[i] = WordEmbeddingsStore(embedding, backend) @staticmethod diff --git a/flair/models/entity_linker_model.py b/flair/models/entity_linker_model.py index 9f516a703c..0f1c916bfb 100644 --- a/flair/models/entity_linker_model.py +++ b/flair/models/entity_linker_model.py @@ -2,7 +2,7 @@ import re from functools import lru_cache from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Union, cast +from typing import Any, Callable, Optional, Union, cast from unicodedata import category import torch @@ -19,9 +19,9 @@ class CandidateGenerator: """Given a string, the CandidateGenerator returns possible target classes as candidates.""" - def __init__(self, candidates: Union[str, Dict[str, List[str]]], backoff: bool = True) -> None: + def __init__(self, candidates: Union[str, dict[str, list[str]]], backoff: bool = True) -> None: # internal candidate lists of generator - self.mention_to_candidates_map: Dict = {} + self.mention_to_candidates_map: dict[str, list[str]] = {} # load Zelda candidates if so passed if isinstance(candidates, str) and candidates.lower() == "zelda": @@ -39,16 +39,15 @@ def __init__(self, candidates: Union[str, Dict[str, List[str]]], backoff: bool = self.mention_to_candidates_map = candidate_lists - elif isinstance(candidates, Dict): + elif isinstance(candidates, dict): self.mention_to_candidates_map = candidates else: raise ValueError(f"'{candidates}' could not be loaded.") - self.mention_to_candidates_map = cast(Dict[str, List[str]], self.mention_to_candidates_map) # if lower casing is enabled, create candidate lists of lower cased versions self.backoff = backoff if self.backoff: # create a new dictionary for lower cased mentions - lowercased_mention_to_candidates_map: Dict = {} + lowercased_mention_to_candidates_map: dict[str, list[str]] = {} # go through each mention and its candidates for mention, candidates_list in self.mention_to_candidates_map.items(): @@ -56,8 +55,8 @@ def __init__(self, candidates: Union[str, Dict[str, List[str]]], backoff: bool = # check if backoff mention already seen. If so, add candidates. Else, create new entry. 
if backoff_mention in lowercased_mention_to_candidates_map: current_candidates = lowercased_mention_to_candidates_map[backoff_mention] - lowercased_mention_to_candidates_map[backoff_mention] = set(current_candidates).union( - candidates_list + lowercased_mention_to_candidates_map[backoff_mention] = list( + set(current_candidates).union(candidates_list) ) else: lowercased_mention_to_candidates_map[backoff_mention] = candidates_list @@ -72,7 +71,7 @@ def _make_backoff_string(self, mention: str) -> str: backoff_mention = re.sub(" +", " ", backoff_mention) return backoff_mention - def get_candidates(self, mention: str) -> Set[str]: + def get_candidates(self, mention: str) -> set[str]: """Given a mention, this method returns a set of candidate classes.""" if self.backoff: mention = self._make_backoff_string(mention) @@ -125,7 +124,7 @@ def __init__( self._label_type = label_type self._span_label_type = span_label_type - cases: Dict[str, Callable[[Span, List[str]], torch.Tensor]] = { + cases: dict[str, Callable[[Span, list[str]], torch.Tensor]] = { "average": self.emb_mean, "first": self.emb_first, "last": self.emb_last, @@ -155,7 +154,7 @@ def emb_firstAndLast(self, span: Span, embedding_names): def emb_mean(self, span, embedding_names): return torch.mean(torch.stack([token.get_embedding(embedding_names) for token in span], 0), 0) - def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Span]: + def _get_data_points_from_sentence(self, sentence: Sentence) -> list[Span]: if self._span_label_type is not None: spans = sentence.get_spans(self._span_label_type) # only use span label type if there are predictions, otherwise search for output label type (training labels) @@ -223,7 +222,7 @@ def _init_model_with_state_dict(cls, state, **kwargs): def label_type(self): return self._label_type - def _mask_scores(self, scores: torch.Tensor, data_points: List[Span]): + def _mask_scores(self, scores: torch.Tensor, data_points: list[Span]): if not self.candidates: return scores @@ -242,9 +241,7 @@ def _mask_scores(self, scores: torch.Tensor, data_points: List[Span]): return masked_scores @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "SpanClassifier": - from typing import cast - + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "SpanClassifier": return cast("SpanClassifier", super().load(model_path=model_path)) diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index cecf2d9e57..5a1382dd60 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -4,9 +4,10 @@ import re import string from abc import ABC, abstractmethod +from collections.abc import Sequence from enum import Enum, auto from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast +from typing import Any, Optional, Union, cast import numpy as np import torch @@ -89,7 +90,7 @@ "chemical": "ctd-chemicals", } -BIOMEDICAL_DICTIONARIES: Dict[str, Type] = { +BIOMEDICAL_DICTIONARIES: dict[str, type] = { "ctd-diseases": CTD_DISEASES_DICTIONARY, "ctd-chemicals": CTD_CHEMICALS_DICTIONARY, "ncbi-gene": NCBI_GENE_HUMAN_DICTIONARY, @@ -151,7 +152,7 @@ def load_dictionary( class EntityPreprocessor(ABC): """A pre-processor used to transform / clean both entity mentions and entity names.""" - def initialize(self, sentences: List[Sentence]) -> None: + def initialize(self, sentences: list[Sentence]) -> None: """Initializes the pre-processor for a batch of sentences. 
This may be necessary for more sophisticated transformations. @@ -187,14 +188,14 @@ def process_entity_name(self, entity_name: str) -> str: """ @classmethod - def _from_state(cls, state_dict: Dict[str, Any]) -> "EntityPreprocessor": + def _from_state(cls, state_dict: dict[str, Any]) -> "EntityPreprocessor": if inspect.isabstract(cls): cls_name = state_dict.pop("__cls__", None) return get_state_subclass_by_name(cls, cls_name)._from_state(state_dict) else: return cls(**state_dict) - def _get_state(self) -> Dict[str, Any]: + def _get_state(self) -> dict[str, Any]: return {"__cls__": self.__class__.__name__} @@ -237,7 +238,7 @@ def process_entity_name(self, entity_name: str) -> str: return entity_name - def _get_state(self) -> Dict[str, Any]: + def _get_state(self) -> dict[str, Any]: return { **super()._get_state(), "lowercase": self.lowercase, @@ -270,9 +271,9 @@ def __init__( self.ab3p = pyab3p.Ab3p() self.preprocessor = preprocessor - self.abbreviation_dict: Dict[str, Dict[str, str]] = {} + self.abbreviation_dict: dict[str, dict[str, str]] = {} - def initialize(self, sentences: List[Sentence]) -> None: + def initialize(self, sentences: list[Sentence]) -> None: self.abbreviation_dict = self._build_abbreviation_dict(sentences) def process_mention(self, entity_mention: str, sentence: Optional[Sentence] = None) -> str: @@ -303,7 +304,7 @@ def process_entity_name(self, entity_name: str) -> str: return entity_name - def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict[str, Dict[str, str]]: + def _build_abbreviation_dict(self, sentences: list[flair.data.Sentence]) -> dict[str, dict[str, str]]: """Processes the given sentences with the Ab3P tool. The function returns a (nested) dictionary containing the abbreviations found for each sentence, e.g.: @@ -321,7 +322,7 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict Returns: abbreviation_dict: abbreviations and their resolution detected in each input sentence """ - abbreviation_dict: Dict[str, Dict[str, str]] = {} + abbreviation_dict: dict[str, dict[str, str]] = {} for sentence in sentences: sentence_text = sentence.to_original_text() @@ -331,14 +332,14 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict return abbreviation_dict - def _get_state(self) -> Dict[str, Any]: + def _get_state(self) -> dict[str, Any]: return { **super()._get_state(), "preprocessor": None if self.preprocessor is None else self.preprocessor._get_state(), } @classmethod - def _from_state(cls, state_dict: Dict[str, Any]) -> "EntityPreprocessor": + def _from_state(cls, state_dict: dict[str, Any]) -> "EntityPreprocessor": return cls( preprocessor=( None @@ -364,7 +365,7 @@ def index(self, dictionary: EntityLinkingDictionary, preprocessor: Optional[Enti """ @abstractmethod - def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]: + def search(self, entity_mentions: list[str], top_k: int) -> list[list[tuple[str, float]]]: """Returns the top-k entity / concept identifiers for each entity mention. 
Args: @@ -376,14 +377,14 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, """ @classmethod - def _from_state(cls, state_dict: Dict[str, Any]) -> "CandidateSearchIndex": + def _from_state(cls, state_dict: dict[str, Any]) -> "CandidateSearchIndex": if inspect.isabstract(cls): cls_name = state_dict.pop("__cls__", None) return get_state_subclass_by_name(cls, cls_name)._from_state(state_dict) else: return cls(**state_dict) - def _get_state(self) -> Dict[str, Any]: + def _get_state(self) -> dict[str, Any]: return {"__cls__": self.__class__.__name__} @@ -396,7 +397,7 @@ def __init__(self): Args: name_to_id_index: internal state, should only be set when loading an initialized index. """ - self.name_to_id_index: Dict[str, str] = {} + self.name_to_id_index: dict[str, str] = {} def index(self, dictionary: EntityLinkingDictionary, preprocessor: Optional[EntityPreprocessor] = None) -> None: def p(text: str) -> str: @@ -407,8 +408,8 @@ def p(text: str) -> str: for synonym in candidate.synonyms: self.name_to_id_index[p(synonym)] = candidate.concept_id - def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]: - results: List[List[Tuple[str, float]]] = [] + def search(self, entity_mentions: list[str], top_k: int) -> list[list[tuple[str, float]]]: + results: list[list[tuple[str, float]]] = [] for mention in entity_mentions: dict_entry = self.name_to_id_index.get(mention) if dict_entry is None: @@ -419,12 +420,12 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, return results @classmethod - def _from_state(cls, state_dict: Dict[str, Any]) -> "CandidateSearchIndex": + def _from_state(cls, state_dict: dict[str, Any]) -> "CandidateSearchIndex": index = cls() index.name_to_id_index = state_dict["name_to_id_index"] return index - def _get_state(self) -> Dict[str, Any]: + def _get_state(self) -> dict[str, Any]: return { **super()._get_state(), "name_to_id_index": self.name_to_id_index, @@ -436,7 +437,7 @@ class SemanticCandidateSearchIndex(CandidateSearchIndex): def __init__( self, - embeddings: Dict[str, DocumentEmbeddings], + embeddings: dict[str, DocumentEmbeddings], hybrid_search: bool, similarity_metric: SimilarityMetric = SimilarityMetric.INNER_PRODUCT, sparse_weight: float = DEFAULT_SPARSE_WEIGHT, @@ -460,8 +461,8 @@ def __init__( self.show_progress = show_progress self.batch_size = batch_size - self.ids: List[str] = [] - self._precomputed_embeddings: Dict[str, np.ndarray] = {"sparse": np.array([]), "dense": np.array([])} + self.ids: list[str] = [] + self._precomputed_embeddings: dict[str, np.ndarray] = {"sparse": np.array([]), "dense": np.array([])} @classmethod def bi_encoder( @@ -479,7 +480,7 @@ def bi_encoder( if model_name_or_path in PRETRAINED_MODELS: similarity_metric = PRETRAINED_MODEL_TO_SIMILARITY_METRIC[model_name_or_path] - embeddings: Dict[str, DocumentEmbeddings] = {"dense": TransformerDocumentEmbeddings(model_name_or_path)} + embeddings: dict[str, DocumentEmbeddings] = {"dense": TransformerDocumentEmbeddings(model_name_or_path)} if hybrid_search: if dictionary is None: @@ -515,7 +516,7 @@ def index(self, dictionary: EntityLinkingDictionary, preprocessor: Optional[Enti def p(text: str) -> str: return preprocessor.process_entity_name(text) if preprocessor is not None else text - texts: List[str] = [] + texts: list[str] = [] self.ids = [] for candidate in dictionary.candidates: texts.append(p(candidate.concept_name)) @@ -564,8 +565,8 @@ def p(text: str) -> str: sent.clear_embeddings() 
self._precomputed_embeddings["sparse"] = np.stack(sparse_embs, axis=0) - def embed(self, entity_mentions: List[str]) -> Dict[str, np.ndarray]: - query_embeddings: Dict[str, List] = {"dense": []} + def embed(self, entity_mentions: list[str]) -> dict[str, np.ndarray]: + query_embeddings: dict[str, list[np.ndarray]] = {"dense": []} inputs = [Sentence(name) for name in entity_mentions] @@ -600,7 +601,7 @@ def embed(self, entity_mentions: List[str]) -> Dict[str, np.ndarray]: return {k: np.stack(v, axis=0) for k, v in query_embeddings.items()} - def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]: + def search(self, entity_mentions: list[str], top_k: int) -> list[list[tuple[str, float]]]: """Returns the top-k entity / concept identifiers for each entity mention. Args: @@ -634,10 +635,10 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, return results @classmethod - def _from_state(cls, state_dict: Dict[str, Any]) -> "SemanticCandidateSearchIndex": + def _from_state(cls, state_dict: dict[str, Any]) -> "SemanticCandidateSearchIndex": index = cls( embeddings=cast( - Dict[str, DocumentEmbeddings], {k: load_embeddings(emb) for k, emb in state_dict["embeddings"].items()} + dict[str, DocumentEmbeddings], {k: load_embeddings(emb) for k, emb in state_dict["embeddings"].items()} ), similarity_metric=SimilarityMetric(state_dict["similarity_metric"]), sparse_weight=state_dict["sparse_weight"], @@ -649,7 +650,7 @@ def _from_state(cls, state_dict: Dict[str, Any]) -> "SemanticCandidateSearchInde index._precomputed_embeddings = state_dict["precomputed_embeddings"] return index - def _get_state(self) -> Dict[str, Any]: + def _get_state(self) -> dict[str, Any]: return { **super()._get_state(), "embeddings": {k: emb.save_embeddings() for k, emb in self.embeddings.items()}, @@ -670,7 +671,7 @@ def __init__( self, candidate_generator: CandidateSearchIndex, preprocessor: EntityPreprocessor, - entity_label_types: Union[str, Sequence[str], Dict[str, Set[str]]], + entity_label_types: Union[str, Sequence[str], dict[str, set[str]]], label_type: str, dictionary: EntityLinkingDictionary, batch_size: int = 1024, @@ -698,8 +699,8 @@ def __init__( super().__init__() def get_entity_label_types( - self, entity_label_types: Union[str, Sequence[str], Dict[str, Set[str]]] - ) -> Dict[str, Set[str]]: + self, entity_label_types: Union[str, Sequence[str], dict[str, set[str]]] + ) -> dict[str, set[str]]: """Find out what NER labels to extract from sentence. 
Args:
@@ -709,9 +710,9 @@ def get_entity_label_types(
             To use all labels from 'ner', pass 'ner'
         """
         if isinstance(entity_label_types, str):
-            entity_label_types = cast(Dict[str, Set[str]], {entity_label_types: {}})
+            entity_label_types = cast(dict[str, set[str]], {entity_label_types: {}})
         elif isinstance(entity_label_types, Sequence):
-            entity_label_types = cast(Dict[str, Set[str]], {label: {} for label in entity_label_types})
+            entity_label_types = cast(dict[str, set[str]], {label: {} for label in entity_label_types})

         entity_label_types = {
             label: {normalize_entity_type(e) for e in entity_types}
@@ -728,9 +729,9 @@ def label_type(self):
     @property
     def dictionary(self) -> EntityLinkingDictionary:
         return self._dictionary

-    def extract_entities_mentions(self, sentence: Sentence, entity_label_types: Dict[str, Set[str]]) -> List[Label]:
+    def extract_entities_mentions(self, sentence: Sentence, entity_label_types: dict[str, set[str]]) -> list[Label]:
         """Extract tagged mentions from sentences."""
-        entities_mentions: List[Label] = []
+        entities_mentions: list[Label] = []

         # NOTE: This is a hacky workaround for the fact that
         # the `label_type`s in `Classifier.load('hunflair')` are
@@ -762,10 +763,10 @@ def extract_entities_mentions(self, sentence: Sentence, entity_label_types: Dict
     def predict(
         self,
-        sentences: Union[List[Sentence], Sentence],
+        sentences: Union[list[Sentence], Sentence],
         top_k: int = 1,
         pred_label_type: Optional[str] = None,
-        entity_label_types: Optional[Union[str, Sequence[str], Dict[str, Set[str]]]] = None,
+        entity_label_types: Optional[Union[str, Sequence[str], dict[str, set[str]]]] = None,
         batch_size: Optional[int] = None,
     ) -> None:
         """Predicts the best matching top-k entity / concept identifiers of all named entities annotated with tag input_entity_annotation_layer.
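For orientation, a minimal usage sketch of the predict() signature above. The tagger and linker model names are illustrative assumptions, not part of this patch; only the keyword arguments are taken from the signature:

from flair.data import Sentence
from flair.models import EntityMentionLinker
from flair.nn import Classifier

sentence = Sentence("The patient suffers from severe cystic fibrosis.")

# tag entity mentions first ("hunflair" is an assumed pretrained tagger name)
tagger = Classifier.load("hunflair")
tagger.predict(sentence)

# link each tagged mention to its top-k concept identifiers
linker = EntityMentionLinker.load("disease-linker")  # assumed model name
linker.predict(sentence, top_k=5, batch_size=1024)

for label in sentence.get_labels(linker.label_type):
    print(label)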
@@ -859,7 +860,7 @@ def _fetch_model(model_name: str) -> str: return hf_download(model_name) @classmethod - def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs) -> "EntityMentionLinker": + def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs) -> "EntityMentionLinker": candidate_generator = CandidateSearchIndex._from_state(state["candidate_search_index"]) preprocessor = EntityPreprocessor._from_state(state["entity_preprocessor"]) entity_label_types = state["entity_label_types"] @@ -961,7 +962,7 @@ def __get_model_path_and_entity_type( model_name_or_path: str, entity_type: Optional[str] = None, hybrid_search: bool = False, - ) -> Tuple[str, str]: + ) -> tuple[str, str]: """Try to figure out what model the user wants.""" if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES: raise ValueError( @@ -1039,24 +1040,24 @@ def __get_dictionary_path( return dictionary_name_or_path - def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, data_points: list[DT]) -> tuple[torch.Tensor, int]: raise NotImplementedError("The EntityLinker cannot be trained") @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "EntityMentionLinker": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "EntityMentionLinker": from typing import cast return cast("EntityMentionLinker", super().load(model_path=model_path)) def evaluate( self, - data_points: Union[List[Sentence], Dataset], + data_points: Union[list[Sentence], Dataset], gold_label_type: str, out_path: Optional[Union[str, Path]] = None, embedding_storage_mode: str = "none", mini_batch_size: int = 32, - main_evaluation_metric: Tuple[str, str] = ("accuracy", "f1-score"), - exclude_labels: Optional[List[str]] = None, + main_evaluation_metric: tuple[str, str] = ("accuracy", "f1-score"), + exclude_labels: Optional[list[str]] = None, gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, k: int = 1, diff --git a/flair/models/language_model.py b/flair/models/language_model.py index ed417f2434..d85db2fb93 100644 --- a/flair/models/language_model.py +++ b/flair/models/language_model.py @@ -1,6 +1,6 @@ import math from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import logsumexp, nn @@ -111,7 +111,7 @@ def init_hidden(self, bsz): def get_representation( self, - strings: List[str], + strings: list[str], start_marker: str, end_marker: str, chars_per_chunk: int = 512, @@ -119,7 +119,7 @@ def get_representation( len_longest_str: int = len(max(strings, key=len)) # pad strings with whitespaces to longest sentence - padded_strings: List[str] = [] + padded_strings: list[str] = [] for string in strings: if not self.is_forward_lm: @@ -141,11 +141,11 @@ def get_representation( padding_char_index = self.dictionary.get_idx_for_item(" ") - batches: List[torch.Tensor] = [] + batches: list[torch.Tensor] = [] # push each chunk through the RNN language model for chunk in chunks: len_longest_chunk: int = len(max(chunk, key=len)) - sequences_as_char_indices: List[List[int]] = [] + sequences_as_char_indices: list[list[int]] = [] for string in chunk: char_indices = self.dictionary.get_idx_for_items(list(string)) char_indices += [padding_char_index] * (len_longest_chunk - len(string)) @@ -176,7 +176,7 @@ def get_output(self, text: str): def repackage_hidden(self, h): """Wraps hidden states in new Variables, to detach them from their history.""" - if 
type(h) == torch.Tensor: + if isinstance(h, torch.Tensor): return h.clone().detach() else: return tuple(self.repackage_hidden(v) for v in h) @@ -296,7 +296,7 @@ def generate_text( number_of_characters: int = 1000, temperature: float = 1.0, break_on_suffix=None, - ) -> Tuple[str, float]: + ) -> tuple[str, float]: if prefix == "": prefix = "\n" diff --git a/flair/models/lemmatizer_model.py b/flair/models/lemmatizer_model.py index 6f0854d4b5..55fa34698c 100644 --- a/flair/models/lemmatizer_model.py +++ b/flair/models/lemmatizer_model.py @@ -1,6 +1,6 @@ import logging from math import inf -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import nn @@ -159,7 +159,7 @@ def label_type(self): def words_to_char_indices( self, - tokens: List[str], + tokens: list[str], end_symbol=True, start_symbol=False, padding_in_front=False, @@ -202,7 +202,7 @@ def words_to_char_indices( return tensor - def forward_pass(self, sentences: Union[List[Sentence], Sentence]): + def forward_pass(self, sentences: Union[list[Sentence], Sentence]): if isinstance(sentences, Sentence): sentences = [sentences] @@ -247,7 +247,7 @@ def decode(self, decoder_input_indices, initial_hidden_states, all_encoder_outpu output_vectors = self.character_decoder(output) return output_vectors, hidden - def _prepare_tensors(self, sentences: List[Sentence]) -> Tuple[Optional[torch.Tensor], ...]: + def _prepare_tensors(self, sentences: list[Sentence]) -> tuple[Optional[torch.Tensor], ...]: # get all tokens tokens = [token for sentence in sentences for token in sentence] @@ -290,7 +290,7 @@ def forward( encoder_input_indices: Optional[torch.Tensor], lengths: Optional[torch.Tensor], token_embedding_hidden: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: # variable to store initial hidden states for decoder initial_hidden_for_decoder = [] @@ -340,7 +340,7 @@ def forward( return initial_hidden, all_encoder_outputs - def encode(self, sentences: List[Sentence]): + def encode(self, sentences: list[Sentence]): tensors = self._prepare_tensors(sentences) return self.forward(*tensors) @@ -396,14 +396,14 @@ def _calculate_loss(self, scores, labels): return self.loss(scores_in_correct_format, target), len(labels) - def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, sentences: Union[list[Sentence], Sentence]) -> tuple[torch.Tensor, int]: scores, labels = self.forward_pass(sentences) return self._calculate_loss(scores, labels) def predict( self, - sentences: Union[List[Sentence], Sentence], + sentences: Union[list[Sentence], Sentence], mini_batch_size: int = 16, return_probabilities_for_all_classes: bool = False, verbose: bool = False, @@ -474,7 +474,7 @@ def predict( # option 1: greedy decoding if self.beam_size == 1: # predictions - predicted: List[List[Union[int, float]]] = [[] for _ in range(number_tokens)] + predicted: list[list[Union[int, float]]] = [[] for _ in range(number_tokens)] for _decode_step in range(max_length): # decode next character @@ -525,7 +525,7 @@ def predict( # keep track of how many hypothesis were completed for each token n_completed = [0 for _ in range(number_tokens)] # cpu - final_candidates: List[List[Tuple[torch.Tensor, float]]] = [[] for _ in range(number_tokens)] # cpu + final_candidates: list[list[tuple[torch.Tensor, float]]] = [[] for _ in range(number_tokens)] # cpu # if all_encoder_outputs returned, expand them to 
beam size (otherwise keep this as None)
             batched_encoding_output = (
@@ -552,24 +552,24 @@ def predict(
                     # check if an end symbol has been predicted and, in that case, set hypothesis aside
                     end_symbols = (index_candidates == self.end_index).nonzero(as_tuple=False)
-                    for tuple in end_symbols:
+                    for row in end_symbols:
                         # if the sequence is already ended, do not record as candidate
-                        if sequences[tuple[0], -1].item() == self.end_index:
+                        if sequences[row[0], -1].item() == self.end_index:
                             continue

                         # index of token in the list tokens_in_batch
-                        token_number = torch.div(tuple[0], self.beam_size, rounding_mode="trunc")
+                        token_number = torch.div(row[0], self.beam_size, rounding_mode="trunc")
                         # print(token_number)
-                        seq = sequences[tuple[0], :]  # hypothesis sequence
+                        seq = sequences[row[0], :]  # hypothesis sequence
                         # hypothesis score
-                        score = (scores[tuple[0]] + log_probabilities[tuple[0], tuple[1]]) / (len(seq) + 1)
+                        score = (scores[row[0]] + log_probabilities[row[0], row[1]]) / (len(seq) + 1)

                         final_candidates[token_number].append((seq, score.item()))
                         # TODO: remove token if number of completed hypotheses exceeds given value
                         n_completed[token_number] += 1

                         # set score of corresponding entry to -inf so it will not be expanded
-                        log_probabilities[tuple[0], tuple[1]] = -inf
+                        log_probabilities[row[0], row[1]] = -inf

                     # get leading_indices for next expansion
                     # find highest scoring hypothesis among beam_size*beam_size possible ones for each token
@@ -594,8 +594,8 @@ def predict(
                     # a list of length beam_size*batch_size
                     # where the first three indices belong to the first token, the next three to the second token,
                     # and so on
-                    beam_numbers: List[int] = []
-                    seq_numbers: List[int] = []
+                    beam_numbers: list[int] = []
+                    seq_numbers: list[int] = []

                     for i, row in enumerate(indices_per_token):
                         beam_numbers.extend(i * self.beam_size + index.item() // self.beam_size for index in row)
diff --git a/flair/models/multitask_model.py b/flair/models/multitask_model.py
index 733751eff7..414eb46197 100644
--- a/flair/models/multitask_model.py
+++ b/flair/models/multitask_model.py
@@ -2,7 +2,7 @@
 import random
 import typing
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union

 import torch

@@ -27,9 +27,9 @@ class MultitaskModel(flair.nn.Classifier):

     def __init__(
         self,
-        models: List[flair.nn.Classifier],
-        task_ids: Optional[List[str]] = None,
-        loss_factors: Optional[List[float]] = None,
+        models: list[flair.nn.Classifier],
+        task_ids: Optional[list[str]] = None,
+        loss_factors: Optional[list[float]] = None,
         use_all_tasks: bool = False,
     ) -> None:
         """Instantiates the MultiTaskModel.
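As quick orientation for the class being retyped above, a small sketch of the task-routing helper it exposes. The sentence texts are invented; the "multitask_id" label type and the return shape follow from split_batch_to_task_ids further down in this hunk:

from flair.data import Sentence
from flair.models import MultitaskModel

# each sentence carries a "multitask_id" label naming the task that owns it
s1 = Sentence("Berlin is the capital of Germany.")
s1.add_label("multitask_id", "Task_0")
s2 = Sentence("I really liked the film.")
s2.add_label("multitask_id", "Task_1")

# maps every task id to the indices of its sentences within the batch
mapping = MultitaskModel.split_batch_to_task_ids([s1, s2])
print(mapping)  # {'Task_0': [0], 'Task_1': [1]}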
@@ -42,10 +42,10 @@ def __init__( """ super().__init__() - task_ids_internal: List[str] = task_ids if task_ids else [f"Task_{i}" for i in range(len(models))] + task_ids_internal: list[str] = task_ids if task_ids else [f"Task_{i}" for i in range(len(models))] - self.tasks: Dict[str, flair.nn.Classifier] = {} - self.loss_factors: Dict[str, float] = {} + self.tasks: dict[str, flair.nn.Classifier] = {} + self.loss_factors: dict[str, float] = {} self.use_all_tasks = use_all_tasks if not loss_factors: @@ -63,10 +63,10 @@ def __init__( def forward(self, *args) -> torch.Tensor: raise NotImplementedError("`forward` is not used for multitask learning") - def _prepare_tensors(self, data_points: List[DT]) -> Tuple[torch.Tensor, ...]: + def _prepare_tensors(self, data_points: list[DT]) -> tuple[torch.Tensor, ...]: raise NotImplementedError("`_prepare_tensors` is not used for multitask learning") - def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, sentences: Union[list[Sentence], Sentence]) -> tuple[torch.Tensor, int]: """Calls the respective forward loss of each model and sums them weighted by their loss factors. Args: @@ -92,7 +92,9 @@ def predict( task.predict(sentences, **predictargs) @staticmethod - def split_batch_to_task_ids(sentences: Union[List[Sentence], Sentence], all_tasks: bool = False) -> Dict: + def split_batch_to_task_ids( + sentences: Union[list[Sentence], Sentence], all_tasks: bool = False + ) -> dict[str, list[int]]: """Splits a batch of sentences to its respective model. If single sentence is assigned to several tasks (i.e. same corpus but different tasks), then the model @@ -104,7 +106,7 @@ def split_batch_to_task_ids(sentences: Union[List[Sentence], Sentence], all_task Returns: Key-value pairs as (task_id, list of sentences ids in batch) """ - batch_to_task_mapping: Dict[str, List[int]] = {} + batch_to_task_mapping: dict[str, list[int]] = {} for sentence_id, sentence in enumerate(sentences): if all_tasks: multitask_ids = sentence.get_labels("multitask_id") @@ -122,7 +124,7 @@ def evaluate( # type: ignore[override] data_points, gold_label_type: str, out_path: Optional[Union[str, Path]] = None, - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), + main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), evaluate_all: bool = True, **evalargs, ) -> Result: @@ -161,7 +163,7 @@ def evaluate( # type: ignore[override] loss = torch.tensor(0.0, device=flair.device) main_score = 0.0 all_detailed_results = "" - all_classification_report: Dict[str, Dict[str, Any]] = {} + all_classification_report: dict[str, dict[str, Any]] = {} for task_id, split in batch_split.items(): result = self.tasks[task_id].evaluate( @@ -203,7 +205,7 @@ def evaluate( # type: ignore[override] def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: for model in self.tasks.values(): yield from model.get_used_tokens(corpus, context_length, respect_document_boundaries) @@ -272,7 +274,7 @@ def _fetch_model(model_name) -> str: return model_name @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "MultitaskModel": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "MultitaskModel": from typing import cast return cast("MultitaskModel", super().load(model_path=model_path)) diff --git a/flair/models/pairwise_classification_model.py 
b/flair/models/pairwise_classification_model.py index 262fd08cb5..6308573edc 100644 --- a/flair/models/pairwise_classification_model.py +++ b/flair/models/pairwise_classification_model.py @@ -1,5 +1,4 @@ import typing -from typing import List import torch @@ -69,7 +68,7 @@ def __init__( def label_type(self): return self._label_type - def _get_data_points_from_sentence(self, sentence: TextPair) -> List[TextPair]: + def _get_data_points_from_sentence(self, sentence: TextPair) -> list[TextPair]: return [sentence] def _get_embedding_for_data_point(self, prediction_data_point: TextPair) -> torch.Tensor: @@ -119,7 +118,7 @@ def _init_model_with_state_dict(cls, state, **kwargs): def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: for sentence_pair in _iter_dataset(corpus.get_all_sentences()): yield [t.text for t in sentence_pair.first] yield [t.text for t in sentence_pair.first.left_context(context_length, respect_document_boundaries)] diff --git a/flair/models/pairwise_regression_model.py b/flair/models/pairwise_regression_model.py index c3f34e0f69..9a1c2704be 100644 --- a/flair/models/pairwise_regression_model.py +++ b/flair/models/pairwise_regression_model.py @@ -1,5 +1,6 @@ +from collections.abc import Iterable from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import nn @@ -90,7 +91,7 @@ def label_type(self): def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> Iterable[List[str]]: + ) -> Iterable[list[str]]: for sentence_pair in _iter_dataset(corpus.get_all_sentences()): yield [t.text for t in sentence_pair.first] yield [t.text for t in sentence_pair.first.left_context(context_length, respect_document_boundaries)] @@ -99,14 +100,14 @@ def get_used_tokens( yield [t.text for t in sentence_pair.second.left_context(context_length, respect_document_boundaries)] yield [t.text for t in sentence_pair.second.right_context(context_length, respect_document_boundaries)] - def forward_loss(self, pairs: List[TextPair]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, pairs: list[TextPair]) -> tuple[torch.Tensor, int]: loss, num = self._forward_loss_and_scores(pairs=pairs, return_num=True, return_scores=False) assert isinstance(loss, torch.Tensor) assert isinstance(num, int) return loss, num - def _forward_loss_and_scores(self, pairs: List[TextPair], return_num=True, return_scores=True) -> Tuple: + def _forward_loss_and_scores(self, pairs: list[TextPair], return_num=True, return_scores=True) -> tuple: # make a forward pass to produce embedded data points and labels pairs = [pair for pair in pairs if self._filter_data_point(pair)] @@ -128,7 +129,7 @@ def _forward_loss_and_scores(self, pairs: List[TextPair], return_num=True, retur # calculate the loss loss, num = self._calculate_loss(scores, target_tensor) - return_value: Tuple[Any, ...] = (loss,) + return_value: tuple[Any, ...] 
= (loss,) if return_num: return_value += (num,) @@ -138,10 +139,10 @@ def _forward_loss_and_scores(self, pairs: List[TextPair], return_num=True, retur return return_value - def _calculate_loss(self, scores: torch.Tensor, target_tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: + def _calculate_loss(self, scores: torch.Tensor, target_tensor: torch.Tensor) -> tuple[torch.Tensor, int]: return self.loss_function(scores, target_tensor), target_tensor.size(0) - def _prepare_target_tensor(self, pairs: List[TextPair]): + def _prepare_target_tensor(self, pairs: list[TextPair]): target_values = [ torch.tensor([float(label.value) for label in pair.get_labels(self.label_name)], dtype=torch.float) for pair in pairs @@ -152,7 +153,7 @@ def _prepare_target_tensor(self, pairs: List[TextPair]): def _filter_data_point(self, pair: TextPair) -> bool: return len(pair) > 0 - def _encode_data_points(self, data_points: List[TextPair]) -> torch.Tensor: + def _encode_data_points(self, data_points: list[TextPair]) -> torch.Tensor: # get a tensor of data points data_point_tensor = torch.stack([self._get_embedding_for_data_point(data_point) for data_point in data_points]) @@ -203,7 +204,7 @@ def _get_state_dict(self): return model_state @classmethod - def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs): + def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs): """Initializes a TextPairRegressor model from a state dictionary (exported by _get_state_dict). Requires keys 'state_dict', 'document_embeddings', and 'label_type' in the state dictionary. @@ -227,12 +228,12 @@ def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs): def predict( self, - pairs: Union[TextPair, List[TextPair]], + pairs: Union[TextPair, list[TextPair]], mini_batch_size: int = 32, verbose: bool = False, label_name: Optional[str] = None, embedding_storage_mode="none", - ) -> List[TextPair]: + ) -> list[TextPair]: if label_name is None: label_name = self.label_name if self.label_name is not None else "label" @@ -278,13 +279,13 @@ def predict( def evaluate( self, - data_points: Union[List[TextPair], Dataset], + data_points: Union[list[TextPair], Dataset], gold_label_type: str, out_path: Union[str, Path, None] = None, embedding_storage_mode: EmbeddingStorageMode = "none", mini_batch_size: int = 32, - main_evaluation_metric: Tuple[str, str] = ("correlation", "pearson"), - exclude_labels: Optional[List[str]] = None, + main_evaluation_metric: tuple[str, str] = ("correlation", "pearson"), + exclude_labels: Optional[list[str]] = None, gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, **kwargs, diff --git a/flair/models/prefixed_tagger.py b/flair/models/prefixed_tagger.py index 05d8fa8c34..b001653bdc 100644 --- a/flair/models/prefixed_tagger.py +++ b/flair/models/prefixed_tagger.py @@ -1,7 +1,7 @@ import logging from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast import torch from torch.utils.data import Dataset @@ -26,7 +26,7 @@ class SentenceAugmentationStrategy(ABC): @abstractmethod def augment_sentence( - self, sentence: Sentence, annotation_layers: Optional[Union[str, List[str]]] = None + self, sentence: Sentence, annotation_layers: Optional[Union[str, list[str]]] = None ) -> PrefixedSentence: """Augments the given sentence text with additional instructions for working / predicting the task on the given annotations. 
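A brief sketch of the augmentation contract above, using the EntityTypeTaskPromptAugmentationStrategy defined further down in this file. The sentence text is invented, and the exact prompt string is an assumption based on the class docstring:

from flair.data import Sentence
from flair.models.prefixed_tagger import EntityTypeTaskPromptAugmentationStrategy

strategy = EntityTypeTaskPromptAugmentationStrategy(entity_types=["gene", "disease"])

sentence = Sentence("Mutations in the TP53 gene are found in many cancers.")
augmented = strategy.augment_sentence(sentence)  # returns a PrefixedSentence

# the task prompt is prepended to the original text,
# e.g. "[ Tag gene and disease ] Mutations in the TP53 gene ..."
print(augmented.to_original_text())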
@@ -64,7 +64,7 @@ def _init_strategy_with_state_dict(cls, state, **kwargs): """Initializes the strategy from the given state.""" def augment_dataset( - self, dataset: Dataset[Sentence], annotation_layers: Optional[Union[str, List[str]]] = None + self, dataset: Dataset[Sentence], annotation_layers: Optional[Union[str, list[str]]] = None ) -> FlairDatapointDataset[PrefixedSentence]: """Transforms a dataset into a dataset containing augmented sentences specific to the `PrefixedSequenceTagger`. @@ -78,14 +78,14 @@ def augment_dataset( Returns: A dataset of augmented sentences specific to the `PrefixedSequenceTagger` """ data_loader: DataLoader = DataLoader(dataset, batch_size=1) - original_sentences: List[Sentence] = [batch[0] for batch in iter(data_loader)] + original_sentences: list[Sentence] = [batch[0] for batch in iter(data_loader)] augmented_sentences = [self.augment_sentence(sentence, annotation_layers) for sentence in original_sentences] return FlairDatapointDataset(augmented_sentences) def augment_corpus( - self, corpus: Corpus[Sentence], annotation_layers: Optional[Union[str, List[str]]] = None + self, corpus: Corpus[Sentence], annotation_layers: Optional[Union[str, list[str]]] = None ) -> Corpus[PrefixedSentence]: """Transforms a corpus into a corpus containing augmented sentences specific to the `PrefixedSequenceTagger`. @@ -120,7 +120,7 @@ class EntityTypeTaskPromptAugmentationStrategy(SentenceAugmentationStrategy): "[Tag gene and disease] Mutations in the TP53 tumour suppressor gene are found in ~50% of human cancers" """ - def __init__(self, entity_types: List[str]): + def __init__(self, entity_types: list[str]): if len(entity_types) <= 0: raise AssertionError @@ -128,7 +128,7 @@ def __init__(self, entity_types: List[str]): self.task_prompt = self._build_tag_prompt_prefix(entity_types) def augment_sentence( - self, sentence: Sentence, annotation_layers: Optional[Union[str, List[str]]] = None + self, sentence: Sentence, annotation_layers: Optional[Union[str, list[str]]] = None ) -> PrefixedSentence: # Prepend the task description prompt to the sentence text augmented_sentence = PrefixedSentence( @@ -182,7 +182,7 @@ def apply_predictions( ] orig_span.add_label(target_annotation_layer, label.value, label.score) - def _build_tag_prompt_prefix(self, entity_types: List[str]) -> List[str]: + def _build_tag_prompt_prefix(self, entity_types: list[str]) -> list[str]: if len(self.entity_types) == 1: prompt = f"[ Tag {entity_types[0]} ]" else: @@ -219,29 +219,29 @@ def _init_model_with_state_dict(cls, state, **kwargs): return super()._init_model_with_state_dict(state, augmentation_strategy=strategy, **kwargs) @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "PrefixedSequenceTagger": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "PrefixedSequenceTagger": from typing import cast return cast("PrefixedSequenceTagger", super().load(model_path=model_path)) - def forward_loss(self, sentences: Union[List[Sentence], List[PrefixedSentence]]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, sentences: Union[list[Sentence], list[PrefixedSentence]]) -> tuple[torch.Tensor, int]: # If all sentences are not augmented -> augment them if all(isinstance(sentence, Sentence) for sentence in sentences): # mypy does not infer the type of "sentences" restricted by the if statement - sentences = cast(List[Sentence], sentences) + sentences = cast(list[Sentence], sentences) sentences = self.augment_sentences(sentences=sentences, annotation_layers=self.tag_type) 
elif not all(isinstance(sentence, PrefixedSentence) for sentence in sentences): raise ValueError("All passed sentences must be either uniformly augmented or not.") # mypy does not infer the type of "sentences" restricted by code above - sentences = cast(List[Sentence], sentences) + sentences = cast(list[Sentence], sentences) return super().forward_loss(sentences) def predict( self, - sentences: Union[List[Sentence], Sentence, List[PrefixedSentence], PrefixedSentence], + sentences: Union[list[Sentence], Sentence, list[PrefixedSentence], PrefixedSentence], mini_batch_size: int = 32, return_probabilities_for_all_classes: bool = False, verbose: bool = False, @@ -260,7 +260,7 @@ def predict( # If all sentences are already augmented (i.e. compatible with this class), just forward the sentences if all(isinstance(sentence, PrefixedSentence) for sentence in sentences): # mypy does not infer the type of "sentences" restricted by the if statement - sentences = cast(List[Sentence], sentences) + sentences = cast(list[Sentence], sentences) return super().predict( sentences, @@ -280,12 +280,12 @@ def predict( for sentence in sentences: sentence.remove_labels(prediction_label_type) - sentences = cast(List[Sentence], sentences) + sentences = cast(list[Sentence], sentences) # Augment sentences - copy all annotation of the given tag type augmented_sentences = self.augment_sentences(sentences, self.tag_type) - mypy_safe_augmented_sentences = cast(List[Sentence], augmented_sentences) + mypy_safe_augmented_sentences = cast(list[Sentence], augmented_sentences) # Predict on augmented sentence and store it in an internal annotation layer / label loss_and_count = super().predict( @@ -312,8 +312,8 @@ def predict( return loss_and_count def augment_sentences( - self, sentences: Union[Sentence, List[Sentence]], annotation_layers: Optional[Union[str, List[str]]] = None - ) -> List[PrefixedSentence]: + self, sentences: Union[Sentence, list[Sentence]], annotation_layers: Optional[Union[str, list[str]]] = None + ) -> list[PrefixedSentence]: if not isinstance(sentences, list) and not isinstance(sentences, flair.data.Dataset): sentences = [sentences] diff --git a/flair/models/regexp_tagger.py b/flair/models/regexp_tagger.py index 35c244d960..e41981c899 100644 --- a/flair/models/regexp_tagger.py +++ b/flair/models/regexp_tagger.py @@ -1,7 +1,7 @@ import re import typing from dataclasses import dataclass, field -from typing import Dict, List, Tuple, Union +from typing import Union from flair.data import Sentence, Span, Token @@ -15,8 +15,8 @@ class TokenCollection: """ sentence: Sentence - __tokens_start_pos: List[int] = field(init=False, default_factory=list) - __tokens_end_pos: List[int] = field(init=False, default_factory=list) + __tokens_start_pos: list[int] = field(init=False, default_factory=list) + __tokens_end_pos: list[int] = field(init=False, default_factory=list) def __post_init__(self): for token in self.tokens: @@ -24,10 +24,10 @@ def __post_init__(self): self.__tokens_end_pos.append(token.end_position) @property - def tokens(self) -> List[Token]: + def tokens(self) -> list[Token]: return list(self.sentence) - def get_token_span(self, span: Tuple[int, int]) -> Span: + def get_token_span(self, span: tuple[int, int]) -> Span: """Find a span by the token character positions. 
Given an interval specified with start and end pos as tuple, this function returns a Span object @@ -45,7 +45,7 @@ def get_token_span(self, span: Tuple[int, int]) -> Span: class RegexpTagger: - def __init__(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]) -> None: + def __init__(self, mapping: Union[list[tuple[str, str]], tuple[str, str]]) -> None: r"""This tagger is capable of tagging sentence objects with given regexp -> label mappings. I.e: The tuple (r'(["\'])(?:(?=(\\?))\2.)*?\1', 'QUOTE') maps every match of the regexp to @@ -58,14 +58,14 @@ def __init__(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]) -> No Args: mapping: A list of tuples or a single tuple representing a mapping as regexp -> label """ - self._regexp_mapping: Dict[str, typing.Pattern] = {} + self._regexp_mapping: dict[str, typing.Pattern] = {} self.register_labels(mapping=mapping) @property def registered_labels(self): return self._regexp_mapping - def register_labels(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]): + def register_labels(self, mapping: Union[list[tuple[str, str]], tuple[str, str]]): """Register a regexp -> label mapping. Args: @@ -81,7 +81,7 @@ def register_labels(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]] f"Couldn't compile regexp '{regexp}' for label '{label}'. Aborted with error: '{err.msg}'" ) - def remove_labels(self, labels: Union[List[str], str]): + def remove_labels(self, labels: Union[list[str], str]): """Remove a registered regexp -> label mapping given by label. Args: @@ -101,7 +101,7 @@ def _listify(element: object) -> list: else: return element - def predict(self, sentences: Union[List[Sentence], Sentence]) -> List[Sentence]: + def predict(self, sentences: Union[list[Sentence], Sentence]) -> list[Sentence]: """Predict the given sentences according to the registered mappings.""" if not isinstance(sentences, list): sentences = [sentences] @@ -122,7 +122,7 @@ def _label(self, sentence: Sentence): for label, pattern in self._regexp_mapping.items(): for match in pattern.finditer(sentence.to_original_text()): - span: Tuple[int, int] = match.span() + span: tuple[int, int] = match.span() try: token_span = collection.get_token_span(span) except ValueError: diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 53ccabac36..9c6c69577f 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -2,17 +2,12 @@ import logging import typing from abc import ABC, abstractmethod +from collections.abc import Iterator, Sequence from pathlib import Path from typing import ( Any, - Dict, - Iterator, - List, NamedTuple, Optional, - Sequence, - Set, - Tuple, Union, cast, ) @@ -50,7 +45,7 @@ class EncodedSentence(Sentence): class EncodingStrategy(ABC): """The encoding of the head and tail entities in a sentence with a relation annotation.""" - special_tokens: Set[str] = set() + special_tokens: set[str] = set() def __init__(self, add_special_tokens: bool = False) -> None: self.add_special_tokens = add_special_tokens @@ -84,7 +79,7 @@ class EntityMask(EncodingStrategy): - "Larry Page and [TAIL] founded [HEAD]" -> Relation(head='Google', tail='Sergey Brin'). """ - special_tokens: Set[str] = {"[HEAD]", "[TAIL]"} + special_tokens: set[str] = {"[HEAD]", "[TAIL]"} def encode_head(self, head_span: Span, label: Label) -> str: return "[HEAD]" @@ -126,7 +121,7 @@ class EntityMarker(EncodingStrategy): -> Relation(head='Google', tail='Sergey Brin'). 
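[Editor's sketch — not part of the patch] The RegexpTagger above keeps its mappings as dict[str, typing.Pattern]. A hedged, standalone illustration of the registration step, reusing the pattern and label from the class's own docstring example:

import re
import typing

regexp_mapping: dict[str, typing.Pattern] = {}
for regexp, label in [(r'(["\'])(?:(?=(\\?))\2.)*?\1', "QUOTE")]:
    try:
        regexp_mapping[label] = re.compile(regexp)
    except re.error as err:
        raise ValueError(f"Couldn't compile regexp '{regexp}' for label '{label}'. Aborted with error: '{err.msg}'")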
""" - special_tokens: Set[str] = {"[HEAD]", "[/HEAD]", "[TAIL]", "[/TAIL]"} + special_tokens: set[str] = {"[HEAD]", "[/HEAD]", "[TAIL]", "[/TAIL]"} def encode_head(self, head: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in head) @@ -254,8 +249,8 @@ def __init__( embeddings: DocumentEmbeddings, label_dictionary: Dictionary, label_type: str, - entity_label_types: Union[str, Sequence[str], Dict[str, Optional[Set[str]]]], - entity_pair_labels: Optional[Set[Tuple[str, str]]] = None, + entity_label_types: Union[str, Sequence[str], dict[str, Optional[set[str]]]], + entity_pair_labels: Optional[set[tuple[str, str]]] = None, entity_threshold: Optional[float] = None, cross_augmentation: bool = True, encoding_strategy: EncodingStrategy = TypedEntityMarker(), @@ -298,7 +293,7 @@ def __init__( ) if isinstance(entity_label_types, str): - self.entity_label_types: Dict[str, Optional[Set[str]]] = {entity_label_types: None} + self.entity_label_types: dict[str, Optional[set[str]]] = {entity_label_types: None} elif isinstance(entity_label_types, Sequence): self.entity_label_types = {entity_label_type: None for entity_label_type in entity_label_types} else: @@ -316,7 +311,7 @@ def __init__( and self.encoding_strategy.special_tokens and isinstance(self.embeddings, TransformerDocumentEmbeddings) ): - special_tokens: List[str] = list(self.encoding_strategy.special_tokens) + special_tokens: list[str] = list(self.encoding_strategy.special_tokens) tokenizer = self.embeddings.tokenizer tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) self.embeddings.model.resize_token_embeddings(len(tokenizer)) @@ -355,7 +350,7 @@ def _valid_entities(self, sentence: Sentence) -> Iterator[_Entity]: def _entity_pair_permutations( self, sentence: Sentence, - ) -> Iterator[Tuple[_Entity, _Entity, Optional[str]]]: + ) -> Iterator[tuple[_Entity, _Entity, Optional[str]]]: """Yields all valid entity pair permutations (relation candidates). If the passed sentence contains relation annotations, @@ -370,10 +365,10 @@ def _entity_pair_permutations( Yields: Tuples of (HEAD, TAIL, gold_label): The head and tail `_Entity`s` have span references to the passed sentence. """ - valid_entities: List[_Entity] = list(self._valid_entities(sentence)) + valid_entities: list[_Entity] = list(self._valid_entities(sentence)) # Use a dictionary to find gold relation annotations for a given entity pair - relation_to_gold_label: Dict[str, str] = { + relation_to_gold_label: dict[str, str] = { relation.unlabeled_identifier: relation.get_label(self.label_type, zero_tag_value=self.zero_tag_value).value for relation in sentence.get_relations(self.label_type) } @@ -420,13 +415,13 @@ def _encode_sentence( assert original_sentence is tail.span.sentence, "The head and tail need to come from the same sentence." # Pre-compute non-leading head and tail tokens for entity masking - non_leading_head_tokens: List[Token] = head.span.tokens[1:] - non_leading_tail_tokens: List[Token] = tail.span.tokens[1:] + non_leading_head_tokens: list[Token] = head.span.tokens[1:] + non_leading_tail_tokens: list[Token] = tail.span.tokens[1:] # We can not use the plaintext of the head/tail span in the sentence as the mask/marker # since there may be multiple occurrences of the same entity mentioned in the sentence. # Therefore, we use the span's position in the sentence. 
- encoded_sentence_tokens: List[str] = [] + encoded_sentence_tokens: list[str] = [] for token in original_sentence: if token is head.span[0]: encoded_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label)) @@ -456,7 +451,7 @@ def _encode_sentence( def _encode_sentence_for_inference( self, sentence: Sentence, - ) -> Iterator[Tuple[EncodedSentence, Relation]]: + ) -> Iterator[tuple[EncodedSentence, Relation]]: """Create Encoded Sentences and Relation pairs for Inference. Yields encoded sentences annotated with their gold relation and @@ -505,7 +500,7 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedS yield masked_sentence - def transform_sentence(self, sentences: Union[Sentence, List[Sentence]]) -> List[EncodedSentence]: + def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list[EncodedSentence]: """Transforms sentences into encoded sentences specific to the `RelationClassifier`. For more information on the internal sentence transformation procedure, @@ -541,7 +536,7 @@ def transform_dataset(self, dataset: Dataset[Sentence]) -> FlairDatapointDataset Returns: A dataset of encoded sentences specific to the `RelationClassifier` """ data_loader: DataLoader = DataLoader(dataset, batch_size=1) - original_sentences: List[Sentence] = [batch[0] for batch in iter(data_loader)] + original_sentences: list[Sentence] = [batch[0] for batch in iter(data_loader)] return FlairDatapointDataset(self.transform_sentence(original_sentences)) def transform_corpus(self, corpus: Corpus[Sentence]) -> Corpus[EncodedSentence]: @@ -568,10 +563,10 @@ def transform_corpus(self, corpus: Corpus[Sentence]) -> Corpus[EncodedSentence]: ) def _get_embedding_for_data_point(self, prediction_data_point: EncodedSentence) -> torch.Tensor: - embedding_names: List[str] = self.embeddings.get_names() + embedding_names: list[str] = self.embeddings.get_names() return prediction_data_point.get_embedding(embedding_names) - def _get_data_points_from_sentence(self, sentence: EncodedSentence) -> List[EncodedSentence]: + def _get_data_points_from_sentence(self, sentence: EncodedSentence) -> list[EncodedSentence]: """Returns the encoded sentences to which labels are added. To encode sentences, use the `transform` function of the `RelationClassifier`. @@ -597,14 +592,14 @@ def _get_data_points_from_sentence(self, sentence: EncodedSentence) -> List[Enco def predict( self, - sentences: Union[List[Sentence], List[EncodedSentence], Sentence, EncodedSentence], + sentences: Union[list[Sentence], list[EncodedSentence], Sentence, EncodedSentence], mini_batch_size: int = 32, return_probabilities_for_all_classes: bool = False, verbose: bool = False, label_name: Optional[str] = None, return_loss: bool = False, embedding_storage_mode: EmbeddingStorageMode = "none", - ) -> Optional[Tuple[torch.Tensor, int]]: + ) -> Optional[tuple[torch.Tensor, int]]: """Predicts the class labels for the given sentence(s). Standard `Sentence` objects and `EncodedSentences` specific to the `RelationClassifier` are allowed as input. 
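[Editor's note — not part of the patch] Nearly every hunk in this series is the same mechanical change: under PEP 585 (Python >= 3.9) the builtin containers can be parameterized directly, so List/Dict/Tuple/Set imports from typing become plain list/dict/tuple/set. A before/after sketch with a hypothetical signature:

# before (requires: from typing import List, Optional, Tuple)
#   def predict_stub(sentences: List[str]) -> Optional[Tuple[float, int]]: ...

from typing import Optional  # after: only the non-container helpers remain

def predict_stub(sentences: list[str]) -> Optional[tuple[float, int]]:
    return (0.0, len(sentences)) if sentences else None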
@@ -626,14 +621,14 @@ def predict( if not isinstance(sentences, list): sentences = [sentences] - loss: Optional[Tuple[torch.Tensor, int]] - encoded_sentences: List[EncodedSentence] + loss: Optional[tuple[torch.Tensor, int]] + encoded_sentences: list[EncodedSentence] if all(isinstance(sentence, EncodedSentence) for sentence in sentences): # Deal with the case where all sentences are encoded sentences # mypy does not infer the type of "sentences" restricted by the if statement - encoded_sentences = cast(List[EncodedSentence], sentences) + encoded_sentences = cast(list[EncodedSentence], sentences) loss = super().predict( encoded_sentences, mini_batch_size=mini_batch_size, @@ -646,8 +641,8 @@ def predict( elif all(not isinstance(sentence, EncodedSentence) for sentence in sentences): # Deal with the case where all sentences are standard (non-encoded) sentences - Sentence.set_context_for_sentences(cast(List[Sentence], sentences)) - sentences_with_relation_reference: List[Tuple[EncodedSentence, Relation]] = list( + Sentence.set_context_for_sentences(cast(list[Sentence], sentences)) + sentences_with_relation_reference: list[tuple[EncodedSentence, Relation]] = list( itertools.chain.from_iterable(self._encode_sentence_for_inference(sentence) for sentence in sentences) ) @@ -672,8 +667,8 @@ def predict( return loss if return_loss else None - def _get_state_dict(self) -> Dict[str, Any]: - model_state: Dict[str, Any] = { + def _get_state_dict(self) -> dict[str, Any]: + model_state: dict[str, Any] = { **super()._get_state_dict(), "embeddings": self.embeddings.save_embeddings(use_state_dict=False), "label_dictionary": self.label_dictionary, @@ -689,7 +684,7 @@ def _get_state_dict(self) -> Dict[str, Any]: return model_state @classmethod - def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs): + def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs): return super()._init_model_with_state_dict( state, embeddings=state["embeddings"], @@ -719,7 +714,7 @@ def allow_unk_tag(self) -> bool: def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: yield from super().get_used_tokens(corpus, context_length, respect_document_boundaries) for sentence in _iter_dataset(corpus.get_all_sentences()): for span in sentence.get_spans(self.label_type): @@ -727,7 +722,7 @@ def get_used_tokens( yield self.encoding_strategy.encode_tail(span, span.get_label(self.label_type)).split(" ") @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "RelationClassifier": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "RelationClassifier": from typing import cast return cast("RelationClassifier", super().load(model_path=model_path)) diff --git a/flair/models/relation_extractor_model.py b/flair/models/relation_extractor_model.py index 795e8a517f..0c56abf5bd 100644 --- a/flair/models/relation_extractor_model.py +++ b/flair/models/relation_extractor_model.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Optional, Union import torch @@ -18,7 +18,7 @@ def __init__( embeddings: flair.embeddings.TokenEmbeddings, label_type: str, entity_label_type: str, - entity_pair_filters: Optional[List[Tuple[str, str]]] = None, + entity_pair_filters: Optional[list[tuple[str, str]]] = None, pooling_operation: str = "first_last", train_on_gold_pairs_only: bool = False, 
**classifierargs, @@ -56,13 +56,13 @@ def __init__( # whether to use gold entity pairs, and whether to filter entity pairs by type if entity_pair_filters is not None: - self.entity_pair_filters: Optional[Set[Tuple[str, str]]] = set(entity_pair_filters) + self.entity_pair_filters: Optional[set[tuple[str, str]]] = set(entity_pair_filters) else: self.entity_pair_filters = None self.to(flair.device) - def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Relation]: + def _get_data_points_from_sentence(self, sentence: Sentence) -> list[Relation]: entity_pairs = [] entity_spans = sentence.get_spans(self.entity_label_type) @@ -172,7 +172,7 @@ def _fetch_model(model_name) -> str: return model_name @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "RelationExtractor": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "RelationExtractor": from typing import cast return cast("RelationExtractor", super().load(model_path=model_path)) diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index 16f20a0ddf..2c0fc00ebb 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -1,7 +1,7 @@ import logging import tempfile from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast import torch import torch.nn @@ -40,7 +40,7 @@ def __init__( word_dropout: float = 0.05, locked_dropout: float = 0.5, train_initial_hidden_state: bool = False, - loss_weights: Optional[Dict[str, float]] = None, + loss_weights: Optional[dict[str, float]] = None, init_from_state_dict: bool = False, allow_unk_predictions: bool = False, ) -> None: @@ -204,7 +204,7 @@ def __init__( def label_type(self): return self.tag_type - def _init_loss_weights(self, loss_weights: Dict[str, float]) -> torch.Tensor: + def _init_loss_weights(self, loss_weights: dict[str, float]) -> torch.Tensor: """Initializes the loss weights based on given dictionary. 
Args: @@ -267,7 +267,7 @@ def RNN( return RNN - def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, sentences: list[Sentence]) -> tuple[torch.Tensor, int]: # if there are no sentences, there is no loss if len(sentences) == 0: return torch.tensor(0.0, dtype=torch.float, device=flair.device, requires_grad=True), 0 @@ -281,7 +281,7 @@ def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]: # calculate loss given scores and labels return self._calculate_loss(scores, gold_labels) - def _prepare_tensors(self, data_points: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, torch.LongTensor]: + def _prepare_tensors(self, data_points: Union[list[Sentence], Sentence]) -> tuple[torch.Tensor, torch.LongTensor]: sentences = [data_points] if not isinstance(data_points, list) else data_points self.embeddings.embed(sentences) @@ -331,15 +331,15 @@ def forward(self, sentence_tensor: torch.Tensor, lengths: torch.LongTensor): return scores - def _calculate_loss(self, scores: torch.Tensor, labels: torch.LongTensor) -> Tuple[torch.Tensor, int]: + def _calculate_loss(self, scores: torch.Tensor, labels: torch.LongTensor) -> tuple[torch.Tensor, int]: if labels.size(0) == 0: return torch.tensor(0.0, requires_grad=True, device=flair.device), 1 return self.loss_function(scores, labels), len(labels) - def _make_padded_tensor_for_batch(self, sentences: List[Sentence]) -> Tuple[torch.LongTensor, torch.Tensor]: + def _make_padded_tensor_for_batch(self, sentences: list[Sentence]) -> tuple[torch.LongTensor, torch.Tensor]: names = self.embeddings.get_names() - lengths: List[int] = [len(sentence.tokens) for sentence in sentences] + lengths: list[int] = [len(sentence.tokens) for sentence in sentences] longest_token_sequence_in_batch: int = max(lengths) pre_allocated_zero_tensor = torch.zeros( self.embeddings.embedding_length * longest_token_sequence_in_batch, @@ -382,7 +382,7 @@ def _get_scores_from_features(features: torch.Tensor, lengths: torch.Tensor): return scores - def _get_gold_labels(self, sentences: List[Sentence]) -> List[str]: + def _get_gold_labels(self, sentences: list[Sentence]) -> list[str]: """Extracts gold labels from each sentence. 
Args: @@ -419,7 +419,7 @@ def _get_gold_labels(self, sentences: List[Sentence]) -> List[str]: return labels - def _prepare_label_tensor(self, sentences: List[Sentence]): + def _prepare_label_tensor(self, sentences: list[Sentence]): gold_labels = self._get_gold_labels(sentences) labels = torch.tensor( [self.label_dictionary.get_idx_for_item(label) for label in gold_labels], @@ -430,7 +430,7 @@ def _prepare_label_tensor(self, sentences: List[Sentence]): def predict( self, - sentences: Union[List[Sentence], Sentence], + sentences: Union[list[Sentence], Sentence], mini_batch_size: int = 32, return_probabilities_for_all_classes: bool = False, verbose: bool = False, @@ -462,7 +462,7 @@ def predict( if not isinstance(sentences, list) and not isinstance(sentences, flair.data.Dataset): sentences = [sentences] - Sentence.set_context_for_sentences(cast(List[Sentence], sentences)) + Sentence.set_context_for_sentences(cast(list[Sentence], sentences)) # filter empty sentences sentences = [sentence for sentence in sentences if len(sentence) > 0] @@ -542,7 +542,7 @@ def predict( return overall_loss, label_count return None - def _standard_inference(self, features: torch.Tensor, batch: List[Sentence], probabilities_for_all_classes: bool): + def _standard_inference(self, features: torch.Tensor, batch: list[Sentence], probabilities_for_all_classes: bool): """Softmax over emission scores from forward propagation. Args: @@ -573,7 +573,7 @@ def _standard_inference(self, features: torch.Tensor, batch: List[Sentence], pro return predictions, all_tags - def _all_scores_for_token(self, sentences: List[Sentence], score_tensor: torch.Tensor, lengths: List[int]): + def _all_scores_for_token(self, sentences: list[Sentence], score_tensor: torch.Tensor, lengths: list[int]): """Returns all scores for each tag in tag dictionary.""" scores = score_tensor.numpy() tokens = [token for sentence in sentences for token in sentence] @@ -861,7 +861,7 @@ def push_to_hub( return repo_url @staticmethod - def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]: + def _filter_empty_sentences(sentences: list[Sentence]) -> list[Sentence]: filtered_sentences = [sentence for sentence in sentences if sentence.tokens] if len(sentences) != len(filtered_sentences): log.warning(f"Ignore {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens.") @@ -919,7 +919,7 @@ def _print_predictions(self, batch, gold_label_type): return lines @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "SequenceTagger": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "SequenceTagger": from typing import cast return cast("SequenceTagger", super().load(model_path=model_path)) diff --git a/flair/models/sequence_tagger_utils/viterbi.py b/flair/models/sequence_tagger_utils/viterbi.py index ed84d1a6b7..5c87f49f0c 100644 --- a/flair/models/sequence_tagger_utils/viterbi.py +++ b/flair/models/sequence_tagger_utils/viterbi.py @@ -1,5 +1,3 @@ -from typing import Tuple - import numpy as np import torch import torch.nn @@ -7,7 +5,7 @@ from torch.nn.utils.rnn import pack_padded_sequence import flair -from flair.data import Dictionary, Label, List, Sentence +from flair.data import Dictionary, Label, Sentence START_TAG: str = "" STOP_TAG: str = "" @@ -141,8 +139,8 @@ def __init__(self, tag_dictionary: Dictionary) -> None: self.stop_tag = tag_dictionary.get_idx_for_item(STOP_TAG) def decode( - self, features_tuple: tuple, probabilities_for_all_classes: bool, sentences: List[Sentence] - ) -> Tuple[List, List]: + 
self, features_tuple: tuple, probabilities_for_all_classes: bool, sentences: list[Sentence] + ) -> tuple[list[list[tuple[str, float]]], list[list[list[Label]]]]: """Decoding function returning the most likely sequence of tags. Args: @@ -211,7 +209,7 @@ def decode( scores = softmax(scores_upto_t, dim=2) confidences = torch.max(scores, dim=2) - tags = [] + tags: list[list[tuple[str, float]]] = [] for tag_seq, tag_seq_conf, length_seq in zip(decoded, confidences.values, lengths): tags.append( [ @@ -230,7 +228,7 @@ def _all_scores_for_token( score_tensor: torch.Tensor, tag_sequences: torch.Tensor, lengths: torch.IntTensor, - sentences: List[Sentence], + sentences: list[Sentence], ): """Returns all scores for each tag in tag dictionary.""" scores = score_tensor.numpy() diff --git a/flair/models/tars_model.py b/flair/models/tars_model.py index a7a41bdb5d..4f5cb85731 100644 --- a/flair/models/tars_model.py +++ b/flair/models/tars_model.py @@ -3,7 +3,7 @@ from abc import ABC from collections import OrderedDict from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Optional, Union import numpy as np import torch @@ -30,14 +30,14 @@ class FewshotClassifier(flair.nn.Classifier[Sentence], ABC): def __init__(self) -> None: self._current_task = None - self._task_specific_attributes: Dict[str, Dict[str, Any]] = {} + self._task_specific_attributes: dict[str, dict[str, Any]] = {} self.label_nearest_map = None self.tars_model: flair.nn.Classifier[Sentence] self.separator: str super().__init__() - def forward_loss(self, data_points: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, data_points: Union[list[Sentence], Sentence]) -> tuple[torch.Tensor, int]: if not isinstance(data_points, list): data_points = [data_points] @@ -54,7 +54,7 @@ def tars_embeddings(self): def _get_tars_formatted_sentence(self, label, sentence): raise NotImplementedError - def _get_tars_formatted_sentences(self, sentences: List[Sentence]): + def _get_tars_formatted_sentences(self, sentences: list[Sentence]): label_text_pairs = [] all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item] for sentence in sentences: @@ -173,7 +173,7 @@ def is_current_task_multi_label(self): def add_and_switch_to_new_task( self, task_name: str, - label_dictionary: Union[List, Set, Dictionary, str], + label_dictionary: Union[list, set, Dictionary, str], label_type: str, multi_label: bool = True, force_switch: bool = False, @@ -219,7 +219,7 @@ def add_and_switch_to_new_task( self.switch_to_task(task_name) - def list_existing_tasks(self) -> Set[str]: + def list_existing_tasks(self) -> set[str]: """Lists existing tasks in the loaded TARS model on the console.""" return set(self._task_specific_attributes.keys()) @@ -246,7 +246,7 @@ def _drop_task(self, task_name): log.warning("No task exists with the name `%s`.", task_name) @staticmethod - def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]: + def _filter_empty_sentences(sentences: list[Sentence]) -> list[Sentence]: filtered_sentences = [sentence for sentence in sentences if sentence.tokens] if len(sentences) != len(filtered_sentences): log.warning(f"Ignore {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens.") @@ -258,8 +258,8 @@ def label_type(self): def predict_zero_shot( self, - sentences: Union[List[Sentence], Sentence], - candidate_label_set: Union[List[str], Set[str], str], + sentences: Union[list[Sentence], Sentence], + 
candidate_label_set: Union[list[str], set[str], str], multi_label: bool = True, ): """Make zero shot predictions from the TARS model. @@ -307,14 +307,14 @@ def predict_zero_shot( def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: yield from super().get_used_tokens(corpus, context_length, respect_document_boundaries) for label in self.get_current_label_dictionary().idx2item: yield [label.decode("utf-8")] yield [self.separator] @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "FewshotClassifier": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "FewshotClassifier": from typing import cast return cast("FewshotClassifier", super().load(model_path=model_path)) @@ -472,7 +472,7 @@ def tars_embeddings(self): def predict( self, - sentences: Union[List[Sentence], Sentence], + sentences: Union[list[Sentence], Sentence], mini_batch_size=32, return_probabilities_for_all_classes: bool = False, verbose: bool = False, @@ -532,12 +532,12 @@ def predict( if not batch: continue - tars_sentences: List[Sentence] = [] - all_labels_to_sentence: List[Dict[str, Sentence]] = [] + tars_sentences: list[Sentence] = [] + all_labels_to_sentence: list[dict[str, Sentence]] = [] for sentence in batch: # always remove tags first sentence.remove_labels(label_name) - labels_to_sentence: Dict[str, Sentence] = {} + labels_to_sentence: dict[str, Sentence] = {} for label in all_labels: tars_sentence = self._get_tars_formatted_sentence(label, sentence) tars_sentences.append(tars_sentence) @@ -570,7 +570,7 @@ def predict( if most_probable_first: import operator - already_set_indices: List[int] = [] + already_set_indices: list[int] = [] sorted_x = sorted(all_detected.items(), key=operator.itemgetter(1)) sorted_x.reverse() @@ -648,7 +648,7 @@ def _print_predictions(self, batch, gold_label_type): return lines @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TARSTagger": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TARSTagger": from typing import cast return cast("TARSTagger", super().load(model_path=model_path)) @@ -832,7 +832,7 @@ def tars_embeddings(self): def predict( self, - sentences: Union[List[Sentence], Sentence], + sentences: Union[list[Sentence], Sentence], mini_batch_size=32, return_probabilities_for_all_classes: bool = False, verbose: bool = False, @@ -907,12 +907,12 @@ def predict( if not batch: continue - tars_sentences: List[Sentence] = [] - all_labels_to_sentence: List[Dict[str, Sentence]] = [] + tars_sentences: list[Sentence] = [] + all_labels_to_sentence: list[dict[str, Sentence]] = [] for sentence in batch: # always remove tags first sentence.remove_labels(label_name) - labels_to_sentence: Dict[str, Sentence] = {} + labels_to_sentence: dict[str, Sentence] = {} for label in all_labels: tars_sentence = self._get_tars_formatted_sentence(label, sentence) tars_sentences.append(tars_sentence) @@ -972,7 +972,7 @@ def predict( return None @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TARSClassifier": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TARSClassifier": from typing import cast return cast("TARSClassifier", super().load(model_path=model_path)) diff --git a/flair/models/text_classification_model.py b/flair/models/text_classification_model.py index 1b330a0da3..7f4e00d2c4 100644 --- a/flair/models/text_classification_model.py +++ 
b/flair/models/text_classification_model.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, Dict, List, Union +from typing import Any, Union import torch @@ -56,7 +56,7 @@ def _get_embedding_for_data_point(self, prediction_data_point: Sentence) -> torc embedding_names = self.embeddings.get_names() return prediction_data_point.get_embedding(embedding_names) - def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Sentence]: + def _get_data_points_from_sentence(self, sentence: Sentence) -> list[Sentence]: return [sentence] def _get_state_dict(self): @@ -133,7 +133,7 @@ def label_type(self): return self._label_type @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TextClassifier": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TextClassifier": from typing import cast return cast("TextClassifier", super().load(model_path=model_path)) diff --git a/flair/models/text_regression_model.py b/flair/models/text_regression_model.py index 894ce3087e..d1ad98d4e0 100644 --- a/flair/models/text_regression_model.py +++ b/flair/models/text_regression_model.py @@ -1,7 +1,7 @@ import logging import typing from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import nn @@ -43,7 +43,7 @@ def __init__( def label_type(self): return self.label_name - def _prepare_tensors(self, sentences: List[Sentence]) -> Tuple[torch.Tensor]: + def _prepare_tensors(self, sentences: list[Sentence]) -> tuple[torch.Tensor]: self.document_embeddings.embed(sentences) embedding_names = self.document_embeddings.get_names() text_embedding_list = [sentence.get_embedding(embedding_names).unsqueeze(0) for sentence in sentences] @@ -55,14 +55,14 @@ def forward(self, *args: torch.Tensor) -> torch.Tensor: label_scores = self.decoder(text_embedding_tensor) return label_scores - def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, sentences: list[Sentence]) -> tuple[torch.Tensor, int]: labels = self._labels_to_tensor(sentences) text_embedding_tensor = self._prepare_tensors(sentences) scores = self.forward(*text_embedding_tensor) return self.loss_function(scores.squeeze(1), labels), len(sentences) - def _labels_to_tensor(self, sentences: List[Sentence]): + def _labels_to_tensor(self, sentences: list[Sentence]): indices = [ torch.tensor([float(label.value) for label in sentence.get_labels(self.label_name)], dtype=torch.float) for sentence in sentences @@ -74,12 +74,12 @@ def _labels_to_tensor(self, sentences: List[Sentence]): def predict( self, - sentences: Union[Sentence, List[Sentence]], + sentences: Union[Sentence, list[Sentence]], mini_batch_size: int = 32, verbose: bool = False, label_name: Optional[str] = None, embedding_storage_mode: EmbeddingStorageMode = "none", - ) -> List[Sentence]: + ) -> list[Sentence]: if label_name is None: label_name = self.label_name if self.label_name is not None else "label" @@ -123,7 +123,7 @@ def predict( return sentences - def forward_labels_and_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, torch.Tensor]: + def forward_labels_and_loss(self, sentences: list[Sentence]) -> tuple[torch.Tensor, torch.Tensor]: labels = self._labels_to_tensor(sentences) text_embedding_tensor = self._prepare_tensors(sentences) scores = self.forward(*text_embedding_tensor) @@ -132,13 +132,13 @@ def forward_labels_and_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tens def 
evaluate( self, - data_points: Union[List[Sentence], Dataset], + data_points: Union[list[Sentence], Dataset], gold_label_type: str, out_path: Optional[Union[str, Path]] = None, embedding_storage_mode: EmbeddingStorageMode = "none", mini_batch_size: int = 32, - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), - exclude_labels: Optional[List[str]] = None, + main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), + exclude_labels: Optional[list[str]] = None, gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, **kwargs, @@ -154,7 +154,7 @@ def evaluate( metric = MetricRegression("Evaluation") - lines: List[str] = [] + lines: list[str] = [] total_count = 0 for batch in data_loader: if isinstance(batch, Sentence): @@ -227,21 +227,21 @@ def _init_model_with_state_dict(cls, state, **kwargs): ) @staticmethod - def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]: + def _filter_empty_sentences(sentences: list[Sentence]) -> list[Sentence]: filtered_sentences = [sentence for sentence in sentences if sentence.tokens] if len(sentences) != len(filtered_sentences): log.warning(f"Ignore {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens.") return filtered_sentences @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TextRegressor": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TextRegressor": from typing import cast return cast("TextRegressor", super().load(model_path=model_path)) def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: for sentence in _iter_dataset(corpus.get_all_sentences()): yield [t.text for t in sentence] yield [t.text for t in sentence.left_context(context_length, respect_document_boundaries)] diff --git a/flair/models/triple_classification_model.py b/flair/models/triple_classification_model.py index 1c1337f9b0..9f1a57a23e 100644 --- a/flair/models/triple_classification_model.py +++ b/flair/models/triple_classification_model.py @@ -1,5 +1,4 @@ import typing -from typing import List import torch @@ -69,7 +68,7 @@ def __init__( def label_type(self): return self._label_type - def _get_data_points_from_sentence(self, sentence: TextTriple) -> List[TextTriple]: + def _get_data_points_from_sentence(self, sentence: TextTriple) -> list[TextTriple]: return [sentence] def _get_embedding_for_data_point(self, prediction_data_point: TextTriple) -> torch.Tensor: @@ -121,7 +120,7 @@ def _init_model_with_state_dict(cls, state, **kwargs): def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: for sentence_pair in _iter_dataset(corpus.get_all_sentences()): yield [t.text for t in sentence_pair.first] yield [t.text for t in sentence_pair.first.left_context(context_length, respect_document_boundaries)] diff --git a/flair/models/word_tagger_model.py b/flair/models/word_tagger_model.py index 2d32a54b06..5040a63728 100644 --- a/flair/models/word_tagger_model.py +++ b/flair/models/word_tagger_model.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, Dict, List, Union +from typing import Any, Union import torch from deprecated.sphinx import deprecated @@ -99,7 +99,7 @@ def _get_embedding_for_data_point(self, prediction_data_point: Token) -> torch.T names = self.embeddings.get_names() return 
prediction_data_point.get_embedding(names) - def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Token]: + def _get_data_points_from_sentence(self, sentence: Sentence) -> list[Token]: # special handling during training if this is a span prediction problem if self.training and self.span_prediction_problem: for token in sentence.tokens: @@ -125,7 +125,7 @@ def _post_process_batch_after_prediction(self, batch, label_name): for sentence in batch: # internal variables previous_tag = "O-" - current_span: List[Token] = [] + current_span: list[Token] = [] for token in sentence: bioes_tag = token.get_label(label_name).value @@ -222,7 +222,7 @@ def _print_predictions(self, batch, gold_label_type): return lines @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TokenClassifier": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TokenClassifier": from typing import cast return cast("TokenClassifier", super().load(model_path=model_path)) diff --git a/flair/nn/decoder.py b/flair/nn/decoder.py index 65f802148a..48cdbf39b0 100644 --- a/flair/nn/decoder.py +++ b/flair/nn/decoder.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional +from typing import Optional import torch @@ -151,11 +151,11 @@ class LabelVerbalizerDecoder(torch.nn.Module): def __init__(self, label_embedding: Embeddings, label_dictionary: Dictionary): super().__init__() self.label_embedding = label_embedding - self.verbalized_labels: List[Sentence] = self.verbalize_labels(label_dictionary) + self.verbalized_labels: list[Sentence] = self.verbalize_labels(label_dictionary) self.to(flair.device) @staticmethod - def verbalize_labels(label_dictionary: Dictionary) -> List[Sentence]: + def verbalize_labels(label_dictionary: Dictionary) -> list[Sentence]: """Takes a label dictionary and returns a list of sentences with verbalized labels. Args: diff --git a/flair/nn/model.py b/flair/nn/model.py index eeb5b7c84a..bf13baf2f1 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections import Counter from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Optional, Union import torch.nn from torch import Tensor @@ -31,7 +31,7 @@ class Model(torch.nn.Module, typing.Generic[DT], ABC): Every new type of model must implement these methods. """ - model_card: Optional[Dict[str, Any]] = None + model_card: Optional[dict[str, Any]] = None @property @abstractmethod @@ -40,7 +40,7 @@ def label_type(self) -> str: raise NotImplementedError @abstractmethod - def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, data_points: list[DT]) -> tuple[torch.Tensor, int]: """Performs a forward pass and returns a loss tensor for backpropagation. Implement this to enable training. 
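[Editor's sketch — not part of the patch] The forward_loss contract shared by these models: return the loss tensor together with the number of datapoints it covers, mirroring the _calculate_loss implementations elsewhere in this patch. A minimal stand-in (assumes torch is installed; labels are class indices):

import torch

def calculate_loss_stub(scores: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, int]:
    loss = torch.nn.functional.cross_entropy(scores, labels)
    return loss, labels.size(0)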
@@ -50,13 +50,13 @@ def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]: @abstractmethod def evaluate( self, - data_points: Union[List[DT], Dataset], + data_points: Union[list[DT], Dataset], gold_label_type: str, out_path: Optional[Union[str, Path]] = None, embedding_storage_mode: EmbeddingStorageMode = "none", mini_batch_size: int = 32, - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), - exclude_labels: Optional[List[str]] = None, + main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), + exclude_labels: Optional[list[str]] = None, gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, **kwargs, @@ -84,7 +84,7 @@ def evaluate( exclude_labels = exclude_labels if exclude_labels is not None else [] raise NotImplementedError - def _get_state_dict(self) -> Dict: + def _get_state_dict(self) -> dict: """Returns the state dictionary for this model.""" # Always include the name of the Model class for which the state dict holds state_dict = {"state_dict": self.state_dict(), "__cls__": self.__class__.__name__} @@ -92,7 +92,7 @@ def _get_state_dict(self) -> Dict: return state_dict @classmethod - def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs): + def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs): """Initialize the model from a state dictionary.""" if "embeddings" in kwargs: embeddings = kwargs.pop("embeddings") @@ -128,7 +128,7 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None: torch.save(model_state, str(model_file), pickle_protocol=4) @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "Model": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model": """Loads the model from the given file. 
Args: @@ -238,7 +238,7 @@ class ReduceTransformerVocabMixin(ABC): @abstractmethod def get_used_tokens( self, corpus: Corpus, context_lenth: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: pass @@ -251,13 +251,13 @@ class Classifier(Model[DT], typing.Generic[DT], ReduceTransformerVocabMixin, ABC def evaluate( self, - data_points: Union[List[DT], Dataset], + data_points: Union[list[DT], Dataset], gold_label_type: str, out_path: Optional[Union[str, Path]] = None, embedding_storage_mode: EmbeddingStorageMode = "none", mini_batch_size: int = 32, - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), - exclude_labels: Optional[List[str]] = None, + main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), + exclude_labels: Optional[list[str]] = None, gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, **kwargs, @@ -281,10 +281,10 @@ def evaluate( average_over = 0 # variables for printing - lines: List[str] = [] + lines: list[str] = [] # variables for computing scores - all_spans: Set[str] = set() + all_spans: set[str] = set() all_true_values = {} all_predicted_values = {} @@ -476,7 +476,7 @@ def evaluate( ) # Create and populate score object for logging with all evaluation values, plus the loss - scores: Dict[Union[Tuple[str, ...], str], Any] = {} + scores: dict[Union[tuple[str, ...], str], Any] = {} for avg_type in ("micro avg", "macro avg"): for metric_type in ("f1-score", "precision", "recall"): @@ -514,7 +514,7 @@ def evaluate( @abstractmethod def predict( self, - sentences: Union[List[DT], DT], + sentences: Union[list[DT], DT], mini_batch_size: int = 32, return_probabilities_for_all_classes: bool = False, verbose: bool = False, @@ -537,7 +537,7 @@ def predict( """ raise NotImplementedError - def _print_predictions(self, batch: List[DT], gold_label_type: str) -> List[str]: + def _print_predictions(self, batch: list[DT], gold_label_type: str) -> list[str]: lines = [] for datapoint in batch: # check if there is a label mismatch @@ -557,14 +557,14 @@ def _print_predictions(self, batch: List[DT], gold_label_type: str) -> List[str] def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: for sentence in _iter_dataset(corpus.get_all_sentences()): yield [t.text for t in sentence] yield [t.text for t in sentence.left_context(context_length, respect_document_boundaries)] yield [t.text for t in sentence.right_context(context_length, respect_document_boundaries)] @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "Classifier": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Classifier": from typing import cast return cast("Classifier", super().load(model_path=model_path)) @@ -589,7 +589,7 @@ def __init__( word_dropout: float = 0.0, multi_label: bool = False, multi_label_threshold: float = 0.5, - loss_weights: Optional[Dict[str, float]] = None, + loss_weights: Optional[dict[str, float]] = None, decoder: Optional[torch.nn.Module] = None, inverse_model: bool = False, train_on_gold_pairs_only: bool = False, @@ -663,21 +663,21 @@ def _get_embedding_for_data_point(self, prediction_data_point: DT2) -> torch.Ten raise NotImplementedError @abstractmethod - def _get_data_points_from_sentence(self, sentence: DT) -> List[DT2]: + def _get_data_points_from_sentence(self, sentence: DT) -> list[DT2]: """Returns the data_points to 
which labels are added. The results should be of any type that inherits from DataPoint (Sentence, Span, Token, ... objects). """ raise NotImplementedError - def _get_data_points_for_batch(self, sentences: List[DT]) -> List[DT2]: + def _get_data_points_for_batch(self, sentences: list[DT]) -> list[DT2]: """Returns the data_points to which labels are added. The results should be of any type that inherits from DataPoint (Sentence, Span, Token, ... objects). """ return [data_point for sentence in sentences for data_point in self._get_data_points_from_sentence(sentence)] - def _get_label_of_datapoint(self, data_point: DT2) -> List[str]: + def _get_label_of_datapoint(self, data_point: DT2) -> list[str]: """Extracts the labels from the data points. Each data point might return a list of strings, representing multiple labels. @@ -701,7 +701,7 @@ def multi_label_threshold(self, x): # setter method else: self._multi_label_threshold = {"default": x} - def _prepare_label_tensor(self, prediction_data_points: List[DT2]) -> torch.Tensor: + def _prepare_label_tensor(self, prediction_data_points: list[DT2]) -> torch.Tensor: labels = [self._get_label_of_datapoint(dp) for dp in prediction_data_points] if self.multi_label: return torch.tensor( @@ -726,7 +726,7 @@ def _prepare_label_tensor(self, prediction_data_points: List[DT2]) -> torch.Tens device=flair.device, ) - def _encode_data_points(self, sentences: List[DT], data_points: List[DT2]) -> Tensor: + def _encode_data_points(self, sentences: list[DT], data_points: list[DT2]) -> Tensor: # embed sentences if self.should_embed_sentence: self.embeddings.embed(sentences) @@ -747,7 +747,7 @@ def _mask_scores(self, scores: Tensor, data_points) -> Tensor: """Classes that inherit from DefaultClassifier may optionally mask scores.""" return scores - def forward_loss(self, sentences: List[DT]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, sentences: list[DT]) -> tuple[torch.Tensor, int]: # make a forward pass to produce embedded data points and labels sentences = [sentence for sentence in sentences if self._filter_data_point(sentence)] @@ -773,10 +773,10 @@ def forward_loss(self, sentences: List[DT]) -> Tuple[torch.Tensor, int]: # calculate the loss return self._calculate_loss(scores, label_tensor) - def _calculate_loss(self, scores: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, int]: + def _calculate_loss(self, scores: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, int]: return self.loss_function(scores, labels), labels.size(0) - def _sort_data(self, data_points: List[DT]) -> List[DT]: + def _sort_data(self, data_points: list[DT]) -> list[DT]: if len(data_points) == 0: return [] @@ -784,16 +784,16 @@ def _sort_data(self, data_points: List[DT]) -> List[DT]: return data_points # filter empty sentences - sentences = [sentence for sentence in typing.cast(List[Sentence], data_points) if len(sentence) > 0] + sentences = [sentence for sentence in typing.cast(list[Sentence], data_points) if len(sentence) > 0] # reverse sort all sequences by their length reordered_sentences = sorted(sentences, key=len, reverse=True) - return typing.cast(List[DT], reordered_sentences) + return typing.cast(list[DT], reordered_sentences) def predict( self, - sentences: Union[List[DT], DT], + sentences: Union[list[DT], DT], mini_batch_size: int = 32, return_probabilities_for_all_classes: bool = False, verbose: bool = False, @@ -824,7 +824,7 @@ def predict( sentences = [sentences] if isinstance(sentences[0], Sentence): - 
Sentence.set_context_for_sentences(typing.cast(List[Sentence], sentences)) + Sentence.set_context_for_sentences(typing.cast(list[Sentence], sentences)) reordered_sentences = self._sort_data(sentences) @@ -832,7 +832,7 @@ def predict( return sentences if len(reordered_sentences) > mini_batch_size: - batches: Union[DataLoader, List[List[DT]]] = DataLoader( + batches: Union[DataLoader, list[list[DT]]] = DataLoader( dataset=FlairDatapointDataset(reordered_sentences), batch_size=mini_batch_size, ) @@ -990,7 +990,7 @@ def _get_state_dict(self): return state @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "DefaultClassifier": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "DefaultClassifier": from typing import cast return cast("DefaultClassifier", super().load(model_path=model_path)) diff --git a/flair/nn/multitask.py b/flair/nn/multitask.py index 6fa2f20c02..42c5665141 100644 --- a/flair/nn/multitask.py +++ b/flair/nn/multitask.py @@ -1,4 +1,5 @@ -from typing import Iterable, Tuple, Union +from collections.abc import Iterable +from typing import Union from flair.data import Corpus, MultiCorpus from flair.models import MultitaskModel @@ -6,18 +7,18 @@ def make_multitask_model_and_corpus( - mapping: Iterable[Union[Tuple[Classifier, Corpus], Tuple[Classifier, Corpus, float]]] -) -> Tuple[Model, Corpus]: + mapping: Iterable[Union[tuple[Classifier, Corpus], tuple[Classifier, Corpus, float]]] +) -> tuple[Model, Corpus]: models = [] corpora = [] loss_factors = [] ids = [] - for task_id, map in enumerate(mapping): - models.append(map[0]) - corpora.append(map[1]) - if len(map) == 3: - loss_factors.append(map[2]) + for task_id, _map in enumerate(mapping): + models.append(_map[0]) + corpora.append(_map[1]) + if len(_map) == 3: + loss_factors.append(_map[2]) else: loss_factors.append(1.0) diff --git a/flair/samplers.py b/flair/samplers.py index 135dfb3310..53ad40c4c5 100644 --- a/flair/samplers.py +++ b/flair/samplers.py @@ -1,7 +1,6 @@ import logging import random from collections import defaultdict -from typing import Dict import torch from torch.utils.data.sampler import Sampler @@ -36,7 +35,7 @@ def set_dataset(self, data_source): self.indices = list(range(len(data_source))) # first determine the distribution of classes in the dataset - label_count: Dict[str, int] = defaultdict(int) + label_count: dict[str, int] = defaultdict(int) for sentence in data_source: for label in sentence.labels: label_count[label.value] += 1 diff --git a/flair/splitter.py b/flair/splitter.py index 9f7e502c87..2b6c90cd7f 100644 --- a/flair/splitter.py +++ b/flair/splitter.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from segtok.segmenter import split_multi @@ -25,7 +25,7 @@ class SentenceSplitter(ABC): the sentence splitter's configuration. 
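[Editor's note — not part of the patch] The rename from `map` to `_map` in make_multitask_model_and_corpus above avoids shadowing the builtin map(); behaviourally nothing changes. A minimal illustration with placeholder (model, corpus, factor) tuples; task_id is kept only to mirror the original loop:

pairs = [("tagger", "corpus_a", 0.5), ("classifier", "corpus_b")]
loss_factors = []
for task_id, _map in enumerate(pairs):
    loss_factors.append(_map[2] if len(_map) == 3 else 1.0)
assert loss_factors == [0.5, 1.0]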
""" - def split(self, text: str, link_sentences: Optional[bool] = True) -> List[Sentence]: + def split(self, text: str, link_sentences: Optional[bool] = True) -> list[Sentence]: sentences = self._perform_split(text) if not link_sentences: return sentences @@ -34,7 +34,7 @@ def split(self, text: str, link_sentences: Optional[bool] = True) -> List[Senten return sentences @abstractmethod - def _perform_split(self, text: str) -> List[Sentence]: + def _perform_split(self, text: str) -> list[Sentence]: raise NotImplementedError @property @@ -62,11 +62,11 @@ def __init__(self, tokenizer: Tokenizer = SegtokTokenizer()) -> None: super().__init__() self._tokenizer = tokenizer - def _perform_split(self, text: str) -> List[Sentence]: - plain_sentences: List[str] = split_multi(text) + def _perform_split(self, text: str) -> list[Sentence]: + plain_sentences: list[str] = split_multi(text) sentence_offset = 0 - sentences: List[Sentence] = [] + sentences: list[Sentence] = [] for sentence in plain_sentences: try: sentence_offset = text.index(sentence, sentence_offset) @@ -133,7 +133,7 @@ def __init__(self, model: Union[Any, str], tokenizer: Optional[Tokenizer] = None else: self._tokenizer = tokenizer - def _perform_split(self, text: str) -> List[Sentence]: + def _perform_split(self, text: str) -> list[Sentence]: document = self.model(text) sentences = [ @@ -192,7 +192,7 @@ def __init__(self, tag: str, tokenizer: Tokenizer = SegtokTokenizer()) -> None: self._tokenizer = tokenizer self.tag = tag - def _perform_split(self, text: str) -> List[Sentence]: + def _perform_split(self, text: str) -> list[Sentence]: plain_sentences = text.split(self.tag) sentences = [] @@ -252,7 +252,7 @@ def __init__(self, tokenizer: Tokenizer = SegtokTokenizer()) -> None: super().__init__() self._tokenizer = tokenizer - def _perform_split(self, text: str) -> List[Sentence]: + def _perform_split(self, text: str) -> list[Sentence]: return [Sentence(text=text, use_tokenizer=self._tokenizer, start_position=0)] @property diff --git a/flair/tokenization.py b/flair/tokenization.py index 185e944d3b..b377c419e6 100644 --- a/flair/tokenization.py +++ b/flair/tokenization.py @@ -1,7 +1,7 @@ import logging import sys from abc import ABC, abstractmethod -from typing import Callable, List +from typing import Callable from segtok.segmenter import split_single from segtok.tokenizer import split_contractions, word_tokenizer @@ -20,7 +20,7 @@ class Tokenizer(ABC): """ @abstractmethod - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: raise NotImplementedError @property @@ -57,11 +57,11 @@ def __init__(self, model) -> None: "spacy model or the name of the model to load." 
) - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: from spacy.tokens.doc import Doc doc: Doc = self.model.make_doc(text) - words: List[str] = [] + words: list[str] = [] for word in doc: if len(word.text.strip()) == 0: continue @@ -82,12 +82,12 @@ class SegtokTokenizer(Tokenizer): def __init__(self) -> None: super().__init__() - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: return SegtokTokenizer.run_tokenize(text) @staticmethod - def run_tokenize(text: str) -> List[str]: - words: List[str] = [] + def run_tokenize(text: str) -> list[str]: + words: list[str] = [] sentences = split_single(text) for sentence in sentences: @@ -105,12 +105,12 @@ class SpaceTokenizer(Tokenizer): def __init__(self) -> None: super().__init__() - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: return SpaceTokenizer.run_tokenize(text) @staticmethod - def run_tokenize(text: str) -> List[str]: - tokens: List[str] = [] + def run_tokenize(text: str) -> list[str]: + tokens: list[str] = [] word = "" index = -1 for index, char in enumerate(text): @@ -166,8 +166,8 @@ def __init__(self, tokenizer: str, sudachi_mode: str = "A") -> None: self.sentence_tokenizer = konoha.SentenceTokenizer() self.word_tokenizer = konoha.WordTokenizer(tokenizer, mode=sudachi_mode) - def tokenize(self, text: str) -> List[str]: - words: List[str] = [] + def tokenize(self, text: str) -> list[str]: + words: list[str] = [] sentences = self.sentence_tokenizer.tokenize(text) for sentence in sentences: @@ -184,11 +184,11 @@ def name(self) -> str: class TokenizerWrapper(Tokenizer): """Helper class to wrap tokenizer functions to the class-based tokenizer interface.""" - def __init__(self, tokenizer_func: Callable[[str], List[str]]) -> None: + def __init__(self, tokenizer_func: Callable[[str], list[str]]) -> None: super().__init__() self.tokenizer_func = tokenizer_func - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: return self.tokenizer_func(text) @property @@ -225,7 +225,7 @@ def __init__(self) -> None: " Note that the scispacy version and the version of the model must match to work properly!" ) - def combined_rule_prefixes() -> List[str]: + def combined_rule_prefixes() -> list[str]: """Helper function that returns the prefix pattern for the tokenizer. It is a helper function to accommodate spacy tests that only test prefixes. 
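[Editor's sketch — not part of the patch] Usage of the TokenizerWrapper shown above with the modernized Callable[[str], list[str]] signature; whitespace_tokenize is a made-up example function:

def whitespace_tokenize(text: str) -> list[str]:
    return text.split()

# wrapper = TokenizerWrapper(whitespace_tokenize)   # requires flair to be installed
# wrapper.tokenize("Mutations in TP53")             # -> ['Mutations', 'in', 'TP53']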
@@ -270,9 +270,9 @@ def combined_rule_prefixes() -> List[str]: self.model.tokenizer.prefix_search = prefix_re.search self.model.tokenizer.infix_finditer = infix_re.finditer - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: sentence = self.model(text) - words: List[str] = [] + words: list[str] = [] for word in sentence: words.append(word.text) return words diff --git a/flair/trainers/language_model_trainer.py b/flair/trainers/language_model_trainer.py index eb374ed75d..341cead776 100644 --- a/flair/trainers/language_model_trainer.py +++ b/flair/trainers/language_model_trainer.py @@ -3,8 +3,9 @@ import math import random import time +from collections.abc import Iterable from pathlib import Path -from typing import Any, Dict, Iterable, Optional, Type, Union +from typing import Any, Optional, Union import torch from torch import cuda @@ -155,16 +156,16 @@ def __init__( self, model: LanguageModel, corpus: TextCorpus, - optimizer: Type[Optimizer] = SGD, + optimizer: type[Optimizer] = SGD, test_mode: bool = False, epoch: int = 0, split: int = 0, loss: float = 10000, - optimizer_state: Optional[Dict[str, Any]] = None, - scaler_state: Optional[Dict[str, Any]] = None, + optimizer_state: Optional[dict[str, Any]] = None, + scaler_state: Optional[dict[str, Any]] = None, ) -> None: self.model: LanguageModel = model - self.optimizer: Type[Optimizer] = optimizer + self.optimizer: type[Optimizer] = optimizer self.corpus: TextCorpus = corpus self.test_mode: bool = test_mode @@ -362,7 +363,7 @@ def train( ) with open(loss_txt, "a") as myfile: - myfile.write("%s\n" % summary) + myfile.write(f"{summary}\n") log.info(summary) log.info("-" * 89) @@ -386,7 +387,7 @@ def train( summary = f"TEST: valid loss {test_loss:5.4f} | valid ppl {math.exp(test_loss):8.4f}" with open(loss_txt, "a") as myfile: - myfile.write("%s\n" % summary) + myfile.write(f"{summary}\n") log.info(summary) log.info("-" * 89) @@ -440,7 +441,7 @@ def _repackage_hidden(h): def load_checkpoint( checkpoint_file: Union[str, Path], corpus: TextCorpus, - optimizer: Type[Optimizer] = SGD, + optimizer: type[Optimizer] = SGD, ): if isinstance(checkpoint_file, str): checkpoint_file = Path(checkpoint_file) diff --git a/flair/trainers/plugins/base.py b/flair/trainers/plugins/base.py index 958d57b785..663a78d6d3 100644 --- a/flair/trainers/plugins/base.py +++ b/flair/trainers/plugins/base.py @@ -1,19 +1,14 @@ import logging from collections import defaultdict +from collections.abc import Iterator, Sequence from inspect import isclass, signature from itertools import count from queue import Queue from typing import ( Any, Callable, - Dict, - Iterator, - List, NewType, Optional, - Sequence, - Set, - Type, Union, cast, ) @@ -21,7 +16,7 @@ log = logging.getLogger("flair") -PluginArgument = Union["BasePlugin", Type["BasePlugin"]] +PluginArgument = Union["BasePlugin", type["BasePlugin"]] HookHandleId = NewType("HookHandleId", int) EventIdenifier = str @@ -34,7 +29,7 @@ class TrainingInterrupt(Exception): class Pluggable: """Dispatches events which attached plugins can react to.""" - valid_events: Optional[Set[EventIdenifier]] = None + valid_events: Optional[set[EventIdenifier]] = None def __init__(self, *, plugins: Sequence[PluginArgument] = []) -> None: """Initialize a `Pluggable`. @@ -42,11 +37,11 @@ def __init__(self, *, plugins: Sequence[PluginArgument] = []) -> None: Args: plugins: Plugins which should be attached to this `Pluggable`. 
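[Editor's note — not part of the patch] Two more mechanical modernizations appear in the hunks above, isolated here. First, the %-formatted write and its f-string replacement are behaviourally identical:

summary = "TEST: valid loss 1.2345 | valid ppl 3.4567"
assert "%s\n" % summary == f"{summary}\n"

Second, typing.Type[X] becomes the builtin type[X] (again PEP 585), as in the trainer's optimizer attribute, e.g. `optimizer: type[Optimizer] = SGD`; the typing ABCs (Iterable, Iterator, Sequence) likewise move to their collections.abc originals.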
""" - self._hook_handles: Dict[EventIdenifier, Dict[HookHandleId, HookHandle]] = defaultdict(dict) + self._hook_handles: dict[EventIdenifier, dict[HookHandleId, HookHandle]] = defaultdict(dict) self._hook_handle_id_counter = count() - self._plugins: List[BasePlugin] = [] + self._plugins: list[BasePlugin] = [] # This flag tracks, whether an event is currently being processed (otherwise it is added to the queue) self._processing_events = False @@ -181,7 +176,7 @@ class BasePlugin: def __init__(self) -> None: """Initialize the base plugin.""" - self._hook_handles: List[HookHandle] = [] + self._hook_handles: list[HookHandle] = [] self._pluggable: Optional[Pluggable] = None def attach_to(self, pluggable: Pluggable): @@ -260,7 +255,7 @@ def pluggable(self) -> Optional[Pluggable]: def __str__(self) -> str: return self.__class__.__name__ - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return {"__cls__": f"{self.__module__}.{self.__class__.__name__}"} diff --git a/flair/trainers/plugins/functional/anneal_on_plateau.py b/flair/trainers/plugins/functional/anneal_on_plateau.py index d62b21fba1..ccd330bf0f 100644 --- a/flair/trainers/plugins/functional/anneal_on_plateau.py +++ b/flair/trainers/plugins/functional/anneal_on_plateau.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, Dict +from typing import Any from flair.trainers.plugins.base import TrainerPlugin, TrainingInterrupt from flair.trainers.plugins.metric_records import MetricRecord @@ -108,7 +108,7 @@ def __str__(self) -> str: f"min_learning_rate: '{self.min_learning_rate}'" ) - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "base_path": str(self.base_path), diff --git a/flair/trainers/plugins/functional/checkpoints.py b/flair/trainers/plugins/functional/checkpoints.py index 75ecb9bd98..1936177835 100644 --- a/flair/trainers/plugins/functional/checkpoints.py +++ b/flair/trainers/plugins/functional/checkpoints.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict +from typing import Any from flair.trainers.plugins.base import TrainerPlugin @@ -29,7 +29,7 @@ def after_training_epoch(self, epoch, **kw): model_name = "model_epoch_" + str(epoch) + ".pt" self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state) - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "base_path": str(self.base_path), diff --git a/flair/trainers/plugins/functional/linear_scheduler.py b/flair/trainers/plugins/functional/linear_scheduler.py index 1000be6dd7..2258844129 100644 --- a/flair/trainers/plugins/functional/linear_scheduler.py +++ b/flair/trainers/plugins/functional/linear_scheduler.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict +from typing import Any from flair.optim import LinearSchedulerWithWarmup from flair.trainers.plugins.base import TrainerPlugin @@ -62,7 +62,7 @@ def after_training_batch(self, optimizer_was_run: bool, **kwargs): def __str__(self) -> str: return f"LinearScheduler | warmup_fraction: '{self.warmup_fraction}'" - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "warmup_fraction": self.warmup_fraction, diff --git a/flair/trainers/plugins/functional/reduce_transformer_vocab.py b/flair/trainers/plugins/functional/reduce_transformer_vocab.py index 162c667f88..86759c2fe2 100644 --- a/flair/trainers/plugins/functional/reduce_transformer_vocab.py +++ 
b/flair/trainers/plugins/functional/reduce_transformer_vocab.py @@ -1,6 +1,5 @@ import logging from pathlib import Path -from typing import List from transformer_smaller_training_vocab import reduce_train_vocab @@ -57,7 +56,7 @@ def save_model_at_the_end(self, **kw): self.model.save(self.base_path / "final-model.pt", checkpoint=self.save_optimizer_state) -def get_transformer_embeddings(model: Model) -> List[TransformerEmbeddings]: +def get_transformer_embeddings(model: Model) -> list[TransformerEmbeddings]: embeddings = model.tars_embeddings if isinstance(model, FewshotClassifier) else getattr(model, "embeddings", None) if embeddings is None: diff --git a/flair/trainers/plugins/functional/weight_extractor.py b/flair/trainers/plugins/functional/weight_extractor.py index ef5afe081e..4ba7c07621 100644 --- a/flair/trainers/plugins/functional/weight_extractor.py +++ b/flair/trainers/plugins/functional/weight_extractor.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from flair.trainers.plugins.base import TrainerPlugin from flair.training_utils import WeightExtractor @@ -21,7 +21,7 @@ def after_training_batch(self, batch_no, epoch, total_number_of_batches, **kw): if (iteration + 1) % modulo == 0: self.weight_extractor.extract_weights(self.model.state_dict(), iteration) - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "base_path": str(self.base_path), diff --git a/flair/trainers/plugins/loggers/log_file.py b/flair/trainers/plugins/loggers/log_file.py index a9b7453a09..21a8c54632 100644 --- a/flair/trainers/plugins/loggers/log_file.py +++ b/flair/trainers/plugins/loggers/log_file.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, Dict +from typing import Any from flair.trainers.plugins.base import TrainerPlugin from flair.training_utils import add_file_handler @@ -21,5 +21,5 @@ def close_file_handler(self, **kw): self.log_handler.close() log.removeHandler(self.log_handler) - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return {**super().get_state(), "base_path": str(self.base_path)} diff --git a/flair/trainers/plugins/loggers/loss_file.py b/flair/trainers/plugins/loggers/loss_file.py index 29c42fc930..b53a23a956 100644 --- a/flair/trainers/plugins/loggers/loss_file.py +++ b/flair/trainers/plugins/loggers/loss_file.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Union from flair.trainers.plugins.base import TrainerPlugin from flair.trainers.plugins.metric_records import MetricName @@ -10,7 +10,7 @@ class LossFilePlugin(TrainerPlugin): """Plugin that manages the loss.tsv file output.""" def __init__( - self, base_path, epoch: int, metrics_to_collect: Optional[Dict[Union[Tuple, str], str]] = None + self, base_path, epoch: int, metrics_to_collect: Optional[dict[Union[tuple, str], str]] = None ) -> None: super().__init__() @@ -56,9 +56,9 @@ def __init__( self.headers[metric_name] = f"{prefix.upper()}_{header}" # initialize the first log line - self.current_row: Optional[Dict[MetricName, str]] = None + self.current_row: Optional[dict[MetricName, str]] = None - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "base_path": str(self.base_path), diff --git a/flair/trainers/plugins/loggers/metric_history.py b/flair/trainers/plugins/loggers/metric_history.py index 8d7c946e8d..a22cf1b0e7 100644 --- 
a/flair/trainers/plugins/loggers/metric_history.py +++ b/flair/trainers/plugins/loggers/metric_history.py @@ -1,5 +1,6 @@ import logging -from typing import Any, Dict, Mapping +from collections.abc import Mapping +from typing import Any from flair.trainers.plugins.base import TrainerPlugin @@ -17,7 +18,7 @@ class MetricHistoryPlugin(TrainerPlugin): def __init__(self, metrics_to_collect: Mapping = default_metrics_to_collect) -> None: super().__init__() - self.metric_history: Dict[str, list] = {} + self.metric_history: dict[str, list] = {} self.metrics_to_collect: Mapping = metrics_to_collect for target in self.metrics_to_collect.values(): self.metric_history[target] = [] @@ -33,7 +34,7 @@ def after_training(self, **kw): """Returns metric history.""" self.trainer.return_values.update(self.metric_history) - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "metrics_to_collect": dict(self.metrics_to_collect), diff --git a/flair/trainers/plugins/loggers/tensorboard.py b/flair/trainers/plugins/loggers/tensorboard.py index 59bba9f2e9..a7af50a521 100644 --- a/flair/trainers/plugins/loggers/tensorboard.py +++ b/flair/trainers/plugins/loggers/tensorboard.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, Dict +from typing import Any from flair.trainers.plugins.base import TrainerPlugin from flair.training_utils import log_line @@ -59,7 +59,7 @@ def _training_finally(self, **kw): assert self.writer is not None self.writer.close() - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "log_dir": str(self.log_dir) if self.log_dir is not None else None, diff --git a/flair/trainers/plugins/loggers/wandb.py b/flair/trainers/plugins/loggers/wandb.py index 8608fcdbd9..0f8dc89f73 100644 --- a/flair/trainers/plugins/loggers/wandb.py +++ b/flair/trainers/plugins/loggers/wandb.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict +from typing import Any from flair.trainers.plugins.base import TrainerPlugin @@ -72,7 +72,7 @@ def metric_recorded(self, record): def _training_finally(self, **kw): self.writer.close() - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "emit_alerts": self.emit_alerts, diff --git a/flair/trainers/plugins/metric_records.py b/flair/trainers/plugins/metric_records.py index 034c021854..548b54fccd 100644 --- a/flair/trainers/plugins/metric_records.py +++ b/flair/trainers/plugins/metric_records.py @@ -1,14 +1,15 @@ import time +from collections.abc import Iterable, Iterator from dataclasses import dataclass from enum import Enum -from typing import Any, Iterable, Iterator, Optional, Tuple, Union +from typing import Any, Optional, Union RecordType = Enum("RecordType", ["scalar", "image", "histogram", "string", "scalar_list"]) class MetricName: def __init__(self, name) -> None: - self.parts: Tuple[str, ...] + self.parts: tuple[str, ...] 
if isinstance(name, str): self.parts = tuple(name.split("/")) diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 03e6edc083..2f32b54c01 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -7,7 +7,7 @@ import warnings from inspect import signature from pathlib import Path -from typing import List, Optional, Tuple, Type, Union +from typing import Optional, Union import torch from torch.optim.sgd import SGD @@ -128,7 +128,7 @@ def train( base_path, anneal_factor: float = 0.5, patience: int = 3, - min_learning_rate: Union[float, List[float]] = 0.0001, + min_learning_rate: Union[float, list[float]] = 0.0001, initial_extra_patience: int = 0, anneal_with_restarts: bool = False, learning_rate: float = 0.1, @@ -137,17 +137,17 @@ def train( eval_batch_size: int = 64, mini_batch_chunk_size: Optional[int] = None, max_epochs: int = 100, - optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, + optimizer: type[torch.optim.Optimizer] = torch.optim.SGD, train_with_dev: bool = False, train_with_test: bool = False, reduce_transformer_vocab: bool = False, # evaluation and monitoring - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), + main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), monitor_test: bool = False, monitor_train_sample: float = 0.0, use_final_model_for_eval: bool = False, gold_label_dictionary_for_eval: Optional[Dictionary] = None, - exclude_labels: Optional[List[str]] = None, + exclude_labels: Optional[list[str]] = None, # sampling and shuffling sampler=None, shuffle: bool = True, @@ -164,7 +164,7 @@ def train( create_loss_file: bool = True, write_weights: bool = False, # plugins - plugins: Optional[List[TrainerPlugin]] = None, + plugins: Optional[list[TrainerPlugin]] = None, attach_default_scheduler: bool = True, **kwargs, ): @@ -211,17 +211,17 @@ def fine_tune( eval_batch_size: int = 16, mini_batch_chunk_size: Optional[int] = None, max_epochs: int = 10, - optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW, + optimizer: type[torch.optim.Optimizer] = torch.optim.AdamW, train_with_dev: bool = False, train_with_test: bool = False, reduce_transformer_vocab: bool = False, # evaluation and monitoring - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), + main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), monitor_test: bool = False, monitor_train_sample: float = 0.0, use_final_model_for_eval: bool = True, gold_label_dictionary_for_eval: Optional[Dictionary] = None, - exclude_labels: Optional[List[str]] = None, + exclude_labels: Optional[list[str]] = None, # sampling and shuffling sampler=None, shuffle: bool = True, @@ -240,7 +240,7 @@ def fine_tune( # amp use_amp: bool = False, # plugins - plugins: Optional[List[TrainerPlugin]] = None, + plugins: Optional[list[TrainerPlugin]] = None, attach_default_scheduler: bool = True, **kwargs, ): @@ -304,20 +304,20 @@ def train_custom( eval_batch_size: int = 64, mini_batch_chunk_size: Optional[int] = None, max_epochs: int = 100, - optimizer: Type[torch.optim.Optimizer] = SGD, + optimizer: type[torch.optim.Optimizer] = SGD, train_with_dev: bool = False, train_with_test: bool = False, max_grad_norm: Optional[float] = 5.0, reduce_transformer_vocab: bool = False, # evaluation and monitoring - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), + main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), monitor_test: bool = False, monitor_train_sample: float = 0.0, use_final_model_for_eval: bool = False, 
gold_label_dictionary_for_eval: Optional[Dictionary] = None, - exclude_labels: Optional[List[str]] = None, + exclude_labels: Optional[list[str]] = None, # sampling and shuffling - sampler: Optional[FlairSampler] = None, + sampler: Optional[Union[FlairSampler, type[FlairSampler]]] = None, shuffle: bool = True, shuffle_first_epoch: bool = True, # evaluation and monitoring @@ -334,7 +334,7 @@ def train_custom( # amp use_amp: bool = False, # plugins - plugins: Optional[List[TrainerPlugin]] = None, + plugins: Optional[list[TrainerPlugin]] = None, **kwargs, ) -> dict: """Trains any class that implements the flair.nn.Model interface. @@ -475,7 +475,7 @@ def train_custom( # initialize sampler if provided if sampler is not None: # init with default values if only class is provided - if inspect.isclass(sampler): + if isinstance(sampler, type): sampler = sampler() # set dataset to sample from sampler.set_dataset(train_data) diff --git a/flair/training_utils.py b/flair/training_utils.py index 0b4ef91cbf..9b38ec1ddb 100644 --- a/flair/training_utils.py +++ b/flair/training_utils.py @@ -5,7 +5,7 @@ from functools import reduce from math import inf from pathlib import Path -from typing import Dict, List, Literal, Optional, Union +from typing import Literal, Optional, Union from scipy.stats import pearsonr, spearmanr from sklearn.metrics import mean_absolute_error, mean_squared_error @@ -25,7 +25,7 @@ def __init__( main_score: float, detailed_results: str, classification_report: Optional[dict] = None, - scores: Optional[Dict] = None, + scores: Optional[dict] = None, ) -> None: classification_report = classification_report if classification_report is not None else {} assert scores is not None and "loss" in scores, "No loss provided." @@ -47,8 +47,8 @@ class MetricRegression: def __init__(self, name) -> None: self.name = name - self.true: List[float] = [] - self.pred: List[float] = [] + self.true: list[float] = [] + self.pred: list[float] = [] def mean_squared_error(self): return mean_squared_error(self.true, self.pred) @@ -98,7 +98,7 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) -> if isinstance(directory, str): directory = Path(directory) self.weights_file = init_output_file(directory, "weights.txt") - self.weights_dict: Dict[str, Dict[int, List[float]]] = defaultdict(lambda: defaultdict(list)) + self.weights_dict: dict[str, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list)) self.number_of_weights = number_of_weights def extract_weights(self, state_dict, iteration): @@ -338,7 +338,7 @@ def init_output_file(base_path: Union[str, Path], file_name: str) -> Path: return file -def convert_labels_to_one_hot(label_list: List[List[str]], label_dict: Dictionary) -> List[List[int]]: +def convert_labels_to_one_hot(label_list: list[list[str]], label_dict: Dictionary) -> list[list[int]]: """Convert list of labels to a one hot list. 
Args: @@ -365,9 +365,9 @@ def add_file_handler(log, output_file): def store_embeddings( - data_points: Union[List[DT], Dataset], + data_points: Union[list[DT], Dataset], storage_mode: EmbeddingStorageMode, - dynamic_embeddings: Optional[List[str]] = None, + dynamic_embeddings: Optional[list[str]] = None, ): if isinstance(data_points, Dataset): data_points = list(_iter_dataset(data_points)) @@ -391,7 +391,7 @@ def store_embeddings( data_point.to("cpu", pin_memory=pin_memory) -def identify_dynamic_embeddings(data_points: List[DT]) -> Optional[List]: +def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]: dynamic_embeddings = [] all_embeddings = [] for data_point in data_points: diff --git a/flair/visual/ner_html.py b/flair/visual/ner_html.py index c71e108379..5b691a9e60 100644 --- a/flair/visual/ner_html.py +++ b/flair/visual/ner_html.py @@ -1,5 +1,5 @@ import html -from typing import List, Union +from typing import Union from flair.data import Sentence @@ -41,7 +41,7 @@ def split_to_spans(s: Sentence, label_name="ner"): def render_ner_html( - sentences: Union[List[Sentence], Sentence], + sentences: Union[list[Sentence], Sentence], title: str = "Flair", colors={ "PER": "#F7FF53", diff --git a/flair/visual/training_curves.py b/flair/visual/training_curves.py index 1fd856b669..32947c3348 100644 --- a/flair/visual/training_curves.py +++ b/flair/visual/training_curves.py @@ -3,7 +3,7 @@ import math from collections import defaultdict from pathlib import Path -from typing import Dict, List, Union +from typing import Union import matplotlib.pyplot as plt import numpy as np @@ -27,7 +27,7 @@ class Plotter: def _extract_evaluation_data(file_name: Union[str, Path], score: str = "F1") -> dict: file_name = Path(file_name) - training_curves: Dict[str, Dict[str, List[float]]] = { + training_curves: dict[str, dict[str, list[float]]] = { "train": {"loss": [], "score": []}, "test": {"loss": [], "score": []}, "dev": {"loss": [], "score": []}, @@ -70,7 +70,7 @@ def _extract_weight_data(file_name: Union[str, Path]) -> dict: if isinstance(file_name, str): file_name = Path(file_name) - weights: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list)) + weights: dict[str, dict[str, list[float]]] = defaultdict(lambda: defaultdict(list)) with open(file_name) as f: tsvin = csv.reader(f, delimiter="\t") @@ -151,7 +151,7 @@ def plot_weights(self, file_name: Union[str, Path]): log.info(f"Weights plots are saved in {path}") # to let user know the path of the save plots plt.close(fig) - def plot_training_curves(self, file_name: Union[str, Path], plot_values: List[str] = ["loss", "F1"]): + def plot_training_curves(self, file_name: Union[str, Path], plot_values: list[str] = ["loss", "F1"]): file_name = Path(file_name) fig = plt.figure(figsize=(15, 10)) diff --git a/flair/visual/tree_printer.py b/flair/visual/tree_printer.py index fc461d9f81..9753a37a09 100644 --- a/flair/visual/tree_printer.py +++ b/flair/visual/tree_printer.py @@ -1,5 +1,3 @@ -from typing import List - from pptree import print_tree from flair.data import Sentence, Token @@ -9,7 +7,7 @@ class NodeToken: def __init__(self, token: Token, tag_type: str) -> None: self.token: Token = token self.tag_type: str = tag_type - self.children: List[NodeToken] = [] + self.children: list[NodeToken] = [] def set_haed(self, parent): parent.children.append(self) @@ -19,7 +17,7 @@ def __str__(self) -> str: def tree_printer(sentence: Sentence, tag_type: str): - tree: List[NodeToken] = [NodeToken(token, tag_type) for token in sentence] 
+ tree: list[NodeToken] = [NodeToken(token, tag_type) for token in sentence] for x in tree: if x.token.head_id != 0: head_token = x.token.get_head() diff --git a/pyproject.toml b/pyproject.toml index 78d1692a09..9f4c5a7535 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.black] line-length = 120 -target-version = ['py37'] +target-version = ['py39'] exclude = ''' ( /( @@ -49,7 +49,7 @@ ignore_errors = true [tool.ruff] line-length = 120 -target-version = "py38" +target-version = "py39" [tool.ruff.lint] #select = ["ALL"] # Uncommit to autofix all the things diff --git a/requirements-dev.txt b/requirements-dev.txt index 61d45acf8c..3b8fbde79c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,7 +6,7 @@ pytest-black-ng==0.4.* pytest-github-actions-annotate-failures>=0.1.8 pytest-mypy>=0.10.3 pytest-ruff==0.3.* -ruff==0.3.* +ruff==0.7.* types-dataclasses>=0.6.6 types-Deprecated>=1.2.9.2 types-requests>=2.28.11.17 diff --git a/resources/docs/EXPERIMENTS.md b/resources/docs/EXPERIMENTS.md index 69f1a5fbf2..c6bbe72a1c 100644 --- a/resources/docs/EXPERIMENTS.md +++ b/resources/docs/EXPERIMENTS.md @@ -55,7 +55,7 @@ tag_type = 'ner' tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) # initialize embeddings -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ # GloVe embeddings WordEmbeddings('glove'), @@ -124,7 +124,7 @@ tag_type = 'ner' tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) # initialize embeddings -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ WordEmbeddings('de'), PooledFlairEmbeddings('german-forward'), PooledFlairEmbeddings('german-backward'), @@ -225,7 +225,7 @@ tag_type = 'ner' tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) # initialize embeddings -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ WordEmbeddings('crawl'), WordEmbeddings('twitter'), FlairEmbeddings('news-forward'), @@ -292,7 +292,7 @@ tag_type = 'ner' tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) # initialize embeddings -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ WordEmbeddings('crawl'), FlairEmbeddings('news-forward'), FlairEmbeddings('news-backward'), @@ -361,7 +361,7 @@ tag_type = 'pos' tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) # initialize embeddings -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ WordEmbeddings('extvec'), FlairEmbeddings('news-forward'), FlairEmbeddings('news-backward'), @@ -416,7 +416,7 @@ tag_type = 'np' tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) # initialize embeddings -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ WordEmbeddings('extvec'), FlairEmbeddings('news-forward'), FlairEmbeddings('news-backward'), diff --git a/resources/docs/HUNFLAIR2.md b/resources/docs/HUNFLAIR2.md index 6f2c1474b3..032b4fe075 100644 --- a/resources/docs/HUNFLAIR2.md +++ b/resources/docs/HUNFLAIR2.md @@ -14,7 +14,7 @@ NER tools on unseen corpora. ## Quick Start #### Requirements and Installation -*HunFlair2* is based on Flair 0.14+ and Python 3.8+. If you do not have Python 3.8, install it first. +*HunFlair2* is based on Flair 0.14+ and Python 3.9+. If you do not have Python 3.9, install it first. 
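A quick way to confirm your interpreter meets this requirement (an optional sanity check, not from the docs):

    import sys
    assert sys.version_info >= (3, 9), f"Python 3.9+ required, found {sys.version}"
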
Then, in your favorite virtual environment, simply do: ``` pip install flair diff --git a/resources/docs/KOR_docs/TUTORIAL_7_TRAINING_A_MODEL.md b/resources/docs/KOR_docs/TUTORIAL_7_TRAINING_A_MODEL.md index cd839acc15..382600d5be 100644 --- a/resources/docs/KOR_docs/TUTORIAL_7_TRAINING_A_MODEL.md +++ b/resources/docs/KOR_docs/TUTORIAL_7_TRAINING_A_MODEL.md @@ -313,7 +313,7 @@ label_type = 'ner' # 3. 말뭉치에서 레이블 사전 만들기 label_dict = corpus.make_label_dictionary(label_type=label_type, add_unk=False) # 4. 임베딩 초기화하기 -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ WordEmbeddings('glove') ] embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) diff --git a/resources/docs/KOR_docs/TUTORIAL_8_MODEL_OPTIMIZATION.md b/resources/docs/KOR_docs/TUTORIAL_8_MODEL_OPTIMIZATION.md index e821dc2d17..ec3affe947 100644 --- a/resources/docs/KOR_docs/TUTORIAL_8_MODEL_OPTIMIZATION.md +++ b/resources/docs/KOR_docs/TUTORIAL_8_MODEL_OPTIMIZATION.md @@ -95,7 +95,7 @@ tag_type = 'ner' tag_dictionary = corpus.make_label_dictionary(label_type=tag_type, add_unk=False) print(tag_dictionary.idx2item) # 4. 임베딩 초기화하기 -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ WordEmbeddings('glove'), ] embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) diff --git a/setup.py b/setup.py index 3c1cc06018..0573896c19 100644 --- a/setup.py +++ b/setup.py @@ -20,5 +20,5 @@ "word-embeddings": ["gensim>=4.2.0", "bpemb>=0.3.5"], }, include_package_data=True, - python_requires=">=3.8", + python_requires=">=3.9", ) diff --git a/tests/embedding_test_utils.py b/tests/embedding_test_utils.py index 554ef32777..c1a0b1a791 100644 --- a/tests/embedding_test_utils.py +++ b/tests/embedding_test_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Type +from typing import Any, Optional import pytest import torch @@ -9,15 +9,15 @@ class BaseEmbeddingsTest: - embedding_cls: Type[Embeddings[Sentence]] + embedding_cls: type[Embeddings[Sentence]] is_token_embedding: bool is_document_embedding: bool - default_args: Dict[str, Any] - valid_args: List[Dict[str, Any]] = [] - invalid_args: List[Dict[str, Any]] = [] - invalid_names: List[str] = [] + default_args: dict[str, Any] + valid_args: list[dict[str, Any]] = [] + invalid_args: list[dict[str, Any]] = [] + invalid_names: list[str] = [] name_field: Optional[str] = None - weired_texts: List[str] = [ + weired_texts: list[str] = [ "Hybrid mesons , qq ̄ states with an admixture", "typical proportionalities of \u223C 1nmV \u2212 1 [ 3,4 ] .", "🤟 🤟 🤟 hüllo", @@ -33,7 +33,7 @@ def create_embedding_from_name(self, name: str): kwargs.pop(self.name_field) return self.embedding_cls(name, **kwargs) # type: ignore[call-arg] - def create_embedding_with_args(self, args: Dict[str, Any]): + def create_embedding_with_args(self, args: dict[str, Any]): kwargs = dict(self.default_args) for k, v in args.items(): kwargs[k] = v diff --git a/tests/embeddings/test_document_transform_word_embeddings.py b/tests/embeddings/test_document_transform_word_embeddings.py index 6a06372723..73567ffbeb 100644 --- a/tests/embeddings/test_document_transform_word_embeddings.py +++ b/tests/embeddings/test_document_transform_word_embeddings.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any from flair.embeddings import ( DocumentCNNEmbeddings, @@ -19,7 +19,7 @@ class BaseDocumentsViaWordEmbeddingsTest(BaseEmbeddingsTest): is_document_embedding = True is_token_embedding = False - 
base_embeddings: List[TokenEmbeddings] = [word, flair_embedding] + base_embeddings: list[TokenEmbeddings] = [word, flair_embedding] def create_embedding_from_name(self, name: str): """Overwrite this method if it is more complex to load an embedding by name.""" @@ -28,7 +28,7 @@ def create_embedding_from_name(self, name: str): kwargs.pop(self.name_field) return self.embedding_cls(name, **kwargs) # type: ignore[call-arg] - def create_embedding_with_args(self, args: Dict[str, Any]): + def create_embedding_with_args(self, args: dict[str, Any]): kwargs = dict(self.default_args) for k, v in args.items(): kwargs[k] = v @@ -63,4 +63,4 @@ class TestDocumentCNNEmbeddings(BaseDocumentsViaWordEmbeddingsTest): class TestDocumentLMEmbeddings(BaseDocumentsViaWordEmbeddingsTest): embedding_cls = DocumentLMEmbeddings base_embeddings = [flair_embedding, flair_embedding_back] - default_args: Dict[str, Any] = {} + default_args: dict[str, Any] = {} diff --git a/tests/embeddings/test_word_embeddings.py b/tests/embeddings/test_word_embeddings.py index 34d0b3b9f7..87f56fec4f 100644 --- a/tests/embeddings/test_word_embeddings.py +++ b/tests/embeddings/test_word_embeddings.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from flair.embeddings import MuseCrosslingualEmbeddings, NILCEmbeddings, WordEmbeddings from tests.embedding_test_utils import BaseEmbeddingsTest @@ -18,7 +18,7 @@ class TestMuseCrosslingualEmbeddings(BaseEmbeddingsTest): embedding_cls = MuseCrosslingualEmbeddings is_token_embedding = True is_document_embedding = False - default_args: Dict[str, Any] = {} + default_args: dict[str, Any] = {} class TestNILCEmbeddings(BaseEmbeddingsTest): diff --git a/tests/model_test_utils.py b/tests/model_test_utils.py index 10aab0831f..b5afd81bfe 100644 --- a/tests/model_test_utils.py +++ b/tests/model_test_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Type +from typing import Any, Optional import pytest @@ -11,13 +11,13 @@ class BaseModelTest: - model_cls: Type[Model] + model_cls: type[Model] pretrained_model: Optional[str] = None empty_sentence = Sentence(" ") train_label_type: str - multiclass_prediction_labels: List[str] - model_args: Dict[str, Any] = {} - training_args: Dict[str, Any] = {} + multiclass_prediction_labels: list[str] + model_args: dict[str, Any] = {} + training_args: dict[str, Any] = {} finetune_instead_of_train: bool = False @pytest.fixture() diff --git a/tests/models/test_relation_classifier.py b/tests/models/test_relation_classifier.py index c0ca34bce5..da4de52bfc 100644 --- a/tests/models/test_relation_classifier.py +++ b/tests/models/test_relation_classifier.py @@ -1,5 +1,5 @@ from operator import itemgetter -from typing import Dict, List, Optional, Set, Tuple +from typing import Optional import pytest from torch.utils.data import Dataset @@ -20,7 +20,7 @@ ) from tests.model_test_utils import BaseModelTest -encoding_strategies: Dict[EncodingStrategy, List[Tuple[str, str]]] = { +encoding_strategies: dict[EncodingStrategy, list[tuple[str, str]]] = { EntityMask(): [("[HEAD]", "[TAIL]") for _ in range(7)], TypedEntityMask(): [ ("[HEAD-ORG]", "[TAIL-PER]"), @@ -140,7 +140,7 @@ def train_test_sentence(self): return sentence def assert_training_example(self, predicted_training_example): - relations: List[Relation] = predicted_training_example.get_relations("relation") + relations: list[Relation] = predicted_training_example.get_relations("relation") assert len(relations) == 2 # Intel ----founded_by---> Gordon Moore @@ -164,7 +164,7 @@ def 
assert_training_example(self, predicted_training_example): @staticmethod def check_transformation_correctness( split: Optional[Dataset], - ground_truth: Set[Tuple[str, Tuple[str, ...]]], + ground_truth: set[tuple[str, tuple[str, ...]]], ) -> None: # Ground truth is a set of tuples of (, ) assert split is not None @@ -190,7 +190,7 @@ def test_transform_corpus( embeddings: TransformerDocumentEmbeddings, cross_augmentation: bool, encoding_strategy: EncodingStrategy, - encoded_entity_pairs: List[Tuple[str, str]], + encoded_entity_pairs: list[tuple[str, str]], ) -> None: label_dictionary = corpus.make_label_dictionary("relation") model: RelationClassifier = self.build_model( @@ -200,7 +200,7 @@ def test_transform_corpus( # Check sentence masking and relation label annotation on # training, validation and test dataset (in this test the splits are the same) - ground_truth: Set[Tuple[str, Tuple[str, ...]]] = { + ground_truth: set[tuple[str, tuple[str, ...]]] = { # Entity pair permutations of: "Larry Page and Sergey Brin founded Google ." (f"{encoded_entity_pairs[0][1]} and Sergey Brin founded {encoded_entity_pairs[0][0]} .", ("founded_by",)), (f"Larry Page and {encoded_entity_pairs[1][1]} founded {encoded_entity_pairs[1][0]} .", ("founded_by",)), diff --git a/tests/test_datasets_biomedical.py b/tests/test_datasets_biomedical.py index 0264b08394..c15674eb6b 100644 --- a/tests/test_datasets_biomedical.py +++ b/tests/test_datasets_biomedical.py @@ -2,7 +2,7 @@ import os import tempfile from pathlib import Path -from typing import List, Optional +from typing import Optional from flair.datasets.biomedical import ( CoNLLWriter, @@ -84,7 +84,7 @@ def test_conll_writer_one_token_multiple_entities2(): def assert_conll_writer_output( dataset: InternalBioNerDataset, - expected_output: List[str], + expected_output: list[str], sentence_splitter: Optional[SentenceSplitter] = None, ): fd, outfile_path = tempfile.mkstemp() diff --git a/tests/test_labels.py b/tests/test_labels.py index 210a215889..099484162c 100644 --- a/tests/test_labels.py +++ b/tests/test_labels.py @@ -1,5 +1,3 @@ -from typing import List - from flair.data import Label, Relation, Sentence, Span @@ -14,7 +12,7 @@ def test_token_tags(): sentence[0].add_label("pos", "pronoun") # check if there are three POS labels with correct text and values - labels: List[Label] = sentence.get_labels("pos") + labels: list[Label] = sentence.get_labels("pos") assert len(labels) == 3 assert labels[0].data_point.text == "I" assert labels[0].value == "pronoun" @@ -24,7 +22,7 @@ def test_token_tags(): assert labels[2].value == "proper noun" # check if there are is one SENTIMENT label with correct text and values - labels: List[Label] = sentence.get_labels("sentiment") + labels: list[Label] = sentence.get_labels("sentiment") assert len(labels) == 1 assert labels[0].data_point.text == "love" assert labels[0].value == "positive" @@ -45,7 +43,7 @@ def test_token_tags(): # remove the pos label from the last word sentence[2].remove_labels("pos") # there should be 2 POS labels left - labels: List[Label] = sentence.get_labels("pos") + labels: list[Label] = sentence.get_labels("pos") assert len(labels) == 2 assert len(sentence[0].get_labels("pos")) == 1 assert len(sentence[1].get_labels("pos")) == 1 @@ -72,7 +70,7 @@ def test_span_tags(): sentence[7:8].add_label("ner", "City") # check if there are three labels with correct text and values - labels: List[Label] = sentence.get_labels("ner") + labels: list[Label] = sentence.get_labels("ner") assert len(labels) == 3 assert 
labels[0].data_point.text == "Humboldt Universität zu Berlin" assert labels[0].value == "Organization" @@ -82,7 +80,7 @@ def test_span_tags(): assert labels[2].value == "City" # check if there are two spans with correct text and values - spans: List[Span] = sentence.get_spans("ner") + spans: list[Span] = sentence.get_spans("ner") assert len(spans) == 2 assert spans[0].text == "Humboldt Universität zu Berlin" assert len(spans[0].get_labels("ner")) == 2 @@ -92,12 +90,12 @@ def test_span_tags(): # now delete the NER tags of "Humboldt-Universität zu Berlin" sentence[0:4].remove_labels("ner") # should be only one NER label left - labels: List[Label] = sentence.get_labels("ner") + labels: list[Label] = sentence.get_labels("ner") assert len(labels) == 1 assert labels[0].data_point.text == "Berlin" assert labels[0].value == "City" # and only one NER span - spans: List[Span] = sentence.get_spans("ner") + spans: list[Span] = sentence.get_spans("ner") assert len(spans) == 1 assert spans[0].text == "Berlin" assert spans[0].get_label("ner").value == "City" @@ -111,7 +109,7 @@ def test_different_span_tags(): sentence[7:8].add_label("ner", "City") # check if there are three labels with correct text and values - labels: List[Label] = sentence.get_labels("ner") + labels: list[Label] = sentence.get_labels("ner") assert len(labels) == 2 assert labels[0].data_point.text == "Humboldt Universität zu Berlin" assert labels[0].value == "Organization" @@ -119,7 +117,7 @@ def test_different_span_tags(): assert labels[1].value == "City" # check if there are two spans with correct text and values - spans: List[Span] = sentence.get_spans("ner") + spans: list[Span] = sentence.get_spans("ner") assert len(spans) == 2 assert spans[0].text == "Humboldt Universität zu Berlin" assert spans[0].get_label("ner").value == "Organization" @@ -131,22 +129,22 @@ def test_different_span_tags(): # now delete the NER tags of "Humboldt-Universität zu Berlin" sentence[0:4].remove_labels("ner") # should be only one NER label left - labels: List[Label] = sentence.get_labels("ner") + labels: list[Label] = sentence.get_labels("ner") assert len(labels) == 1 assert labels[0].data_point.text == "Berlin" assert labels[0].value == "City" # and only one NER span - spans: List[Span] = sentence.get_spans("ner") + spans: list[Span] = sentence.get_spans("ner") assert len(spans) == 1 assert spans[0].text == "Berlin" assert spans[0].get_label("ner").value == "City" # but there is also one orgtype span and label - labels: List[Label] = sentence.get_labels("orgtype") + labels: list[Label] = sentence.get_labels("orgtype") assert len(labels) == 1 assert labels[0].data_point.text == "Humboldt Universität zu Berlin" assert labels[0].value == "University" # and only one NER span - spans: List[Span] = sentence.get_spans("orgtype") + spans: list[Span] = sentence.get_spans("orgtype") assert len(spans) == 1 assert spans[0].text == "Humboldt Universität zu Berlin" assert spans[0].get_label("orgtype").value == "University" @@ -154,7 +152,7 @@ def test_different_span_tags(): # let's add the NER tag back sentence[0:4].add_label("ner", "Organization") # check if there are three labels with correct text and values - labels: List[Label] = sentence.get_labels("ner") + labels: list[Label] = sentence.get_labels("ner") print(labels) assert len(labels) == 2 assert labels[0].data_point.text == "Humboldt Universität zu Berlin" @@ -163,7 +161,7 @@ def test_different_span_tags(): assert labels[1].value == "City" # check if there are two spans with correct text and values - spans: 
List[Span] = sentence.get_spans("ner") + spans: list[Span] = sentence.get_spans("ner") assert len(spans) == 2 assert spans[0].text == "Humboldt Universität zu Berlin" assert spans[0].get_label("ner").value == "Organization" @@ -194,17 +192,17 @@ def test_relation_tags(): Relation(sentence[0:2], sentence[3:4]).add_label("syntactic", "apposition") # there should be two relation labels - labels: List[Label] = sentence.get_labels("rel") + labels: list[Label] = sentence.get_labels("rel") assert len(labels) == 2 assert labels[0].value == "located in" assert labels[1].value == "university of" # there should be one syntactic labels - labels: List[Label] = sentence.get_labels("syntactic") + labels: list[Label] = sentence.get_labels("syntactic") assert len(labels) == 1 # there should be two relations, one with two and one with one label - relations: List[Relation] = sentence.get_relations("rel") + relations: list[Relation] = sentence.get_relations("rel") assert len(relations) == 2 assert len(relations[0].labels) == 1 assert len(relations[1].labels) == 2 diff --git a/tests/test_tokenize_sentence.py b/tests/test_tokenize_sentence.py index fd049b642e..7fd03ac6ba 100644 --- a/tests/test_tokenize_sentence.py +++ b/tests/test_tokenize_sentence.py @@ -1,5 +1,3 @@ -from typing import List - import pytest import flair @@ -492,5 +490,5 @@ def test_line_separator_is_ignored(): assert Sentence(with_separator).to_original_text() == Sentence(without_separator).to_original_text() -def no_op_tokenizer(text: str) -> List[str]: +def no_op_tokenizer(text: str) -> list[str]: return [text] From a30fd78bfbcdd2eb8d203629dd7b827e700bdbd6 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 25 Oct 2024 15:44:04 +0200 Subject: [PATCH 011/333] fix unrequred type ignore statements --- flair/embeddings/document.py | 2 +- flair/embeddings/token.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py index 28867d889a..8f66a198ed 100644 --- a/flair/embeddings/document.py +++ b/flair/embeddings/document.py @@ -371,7 +371,7 @@ def _add_embeddings_internal(self, sentences: list[Sentence]): sentence_tensor = self.word_reprojection_map(sentence_tensor) # push through RNN - packed = pack_padded_sequence(sentence_tensor, lengths, enforce_sorted=False, batch_first=True) # type: ignore[arg-type] + packed = pack_padded_sequence(sentence_tensor, lengths, enforce_sorted=False, batch_first=True) rnn_out, hidden = self.rnn(packed) outputs, output_lengths = pad_packed_sequence(rnn_out, batch_first=True) diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py index 700eaf4c45..3d95c8ee0b 100644 --- a/flair/embeddings/token.py +++ b/flair/embeddings/token.py @@ -513,7 +513,7 @@ def _add_embeddings_internal(self, sentences: list[Sentence]): character_embeddings = self.char_embedding(chars).transpose(0, 1) - packed = torch.nn.utils.rnn.pack_padded_sequence(character_embeddings, chars2_length) # type: ignore[arg-type] + packed = torch.nn.utils.rnn.pack_padded_sequence(character_embeddings, chars2_length) lstm_out, self.hidden = self.char_rnn(packed) From 783ebc7c5db43f2d0e3ca0d6473556da34a4e70e Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 25 Oct 2024 19:00:09 +0200 Subject: [PATCH 012/333] fix transformers switching attention implementation --- flair/embeddings/transformer.py | 5 +++++ pyproject.toml | 1 + 2 files changed, 6 insertions(+) diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py index d09ed33699..245d528e5d 
100644 --- a/flair/embeddings/transformer.py +++ b/flair/embeddings/transformer.py @@ -1353,6 +1353,11 @@ def from_params(cls, params): def to_params(self): config_dict = self.model.config.to_dict() + + # do not switch the attention implementation upon reload. + config_dict["attn_implementation"] = self.model.config._attn_implementation + del config_dict["_attn_implementation_autoset"] + super_params = super().to_params() # those parameters are only from the super class and will be recreated in the constructor. diff --git a/pyproject.toml b/pyproject.toml index 9f4c5a7535..9711794abb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ filterwarnings = [ 'ignore:`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.', # transformers calls deprecated hf_hub "ignore:`torch.cuda.amp.GradScaler", # GradScaler changes in torch 2.3.0 but we want to be backwards compatible. "ignore:`clean_up_tokenization_spaces` was not set", # Default behavior changes in transformers v4.45, raising irrelevant FutureWarning for serialized models. + "ignore:1Torch was not compiled with flash attention", # You might want to install flash attention, but you don't have to. ] markers = [ "integration", From d932baf83cad1081d8eb8f2ee9715d947a6c466a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Karlo=20Do=C5=A1ilovi=C4=87?= Date: Fri, 25 Oct 2024 23:05:48 +0200 Subject: [PATCH 013/333] Fix formatting errors. --- flair/embeddings/transformer.py | 11 ++++++----- .../test_transformer_document_embeddings.py | 1 - 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py index 8a1e76cd72..8ba17b1fec 100644 --- a/flair/embeddings/transformer.py +++ b/flair/embeddings/transformer.py @@ -8,18 +8,18 @@ from abc import abstractmethod from io import BytesIO from pathlib import Path -from typing import Any, cast, Dict, List, Literal, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast import torch import transformers from packaging.version import Version from torch.jit import ScriptModule from transformers import ( + CONFIG_MAPPING, AutoConfig, AutoFeatureExtractor, AutoModel, AutoTokenizer, - CONFIG_MAPPING, FeatureExtractionMixin, LayoutLMTokenizer, LayoutLMTokenizerFast, @@ -32,8 +32,8 @@ from transformers.utils import PaddingStrategy import flair -from flair.data import log, Sentence, Token -from flair.embeddings.base import DocumentEmbeddings, Embeddings, register_embeddings, TokenEmbeddings +from flair.data import Sentence, Token, log +from flair.embeddings.base import DocumentEmbeddings, Embeddings, TokenEmbeddings, register_embeddings SENTENCE_BOUNDARY_TAG: str = "[FLERT]" @@ -191,6 +191,7 @@ def fill_mean_token_embeddings( return all_token_embeddings + @torch.jit.script_if_tracing def document_cls_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor: return sentence_hidden_states[torch.arange(sentence_hidden_states.shape[0]), sentence_lengths - 1] @@ -1130,7 +1131,7 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool: if peft_config is not None: # add adapters for finetuning try: - from peft import get_peft_model, prepare_model_for_kbit_training, TaskType + from peft import TaskType, get_peft_model, prepare_model_for_kbit_training except ImportError: log.error("You cannot use the PEFT finetuning without peft being 
installed") raise diff --git a/tests/embeddings/test_transformer_document_embeddings.py b/tests/embeddings/test_transformer_document_embeddings.py index bf202110b1..7402b2b467 100644 --- a/tests/embeddings/test_transformer_document_embeddings.py +++ b/tests/embeddings/test_transformer_document_embeddings.py @@ -4,7 +4,6 @@ from flair.embeddings import TransformerDocumentEmbeddings from flair.models import TextClassifier from flair.nn import Classifier - from tests.embedding_test_utils import BaseEmbeddingsTest From 3c3620061c4a3578a536bf9c25ea78c745e2f19b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Karlo=20Do=C5=A1ilovi=C4=87?= Date: Tue, 29 Oct 2024 21:01:30 +0100 Subject: [PATCH 014/333] Use torch.allclose for comparing tensors in BaseEmbeddingsTest. --- tests/embedding_test_utils.py | 28 +++++++++++++++---- .../test_transformer_document_embeddings.py | 2 +- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/tests/embedding_test_utils.py b/tests/embedding_test_utils.py index 554ef32777..933dafad73 100644 --- a/tests/embedding_test_utils.py +++ b/tests/embedding_test_utils.py @@ -19,11 +19,11 @@ class BaseEmbeddingsTest: name_field: Optional[str] = None weired_texts: List[str] = [ "Hybrid mesons , qq ̄ states with an admixture", - "typical proportionalities of \u223C 1nmV \u2212 1 [ 3,4 ] .", + "typical proportionalities of \u223c 1nmV \u2212 1 [ 3,4 ] .", "🤟 🤟 🤟 hüllo", "🤟hallo 🤟 🤟 🤟 🤟", "🤟", - "\uF8F9", + "\uf8f9", ] def create_embedding_from_name(self, name: str): @@ -150,9 +150,17 @@ def test_embeddings_stay_the_same_after_saving_and_loading(self, args): if self.is_token_embedding: for token_old, token_new in zip(sentence_old, sentence_new): - assert (token_old.get_embedding(names_old) == token_new.get_embedding(names_new)).all() + assert torch.allclose( + token_old.get_embedding(names_old), + token_new.get_embedding(names_new), + atol=1e-6, + ) if self.is_document_embedding: - assert (sentence_old.get_embedding(names_old) == sentence_new.get_embedding(names_new)).all() + assert torch.allclose( + sentence_old.get_embedding(names_old), + sentence_new.get_embedding(names_new), + atol=1e-6, + ) def test_default_embeddings_stay_the_same_after_saving_and_loading(self): embeddings = self.create_embedding_with_args(self.default_args) @@ -176,9 +184,17 @@ def test_default_embeddings_stay_the_same_after_saving_and_loading(self): if self.is_token_embedding: for token_old, token_new in zip(sentence_old, sentence_new): - assert (token_old.get_embedding(names_old) == token_new.get_embedding(names_new)).all() + assert torch.allclose( + token_old.get_embedding(names_old), + token_new.get_embedding(names_new), + atol=1e-6, + ) if self.is_document_embedding: - assert (sentence_old.get_embedding(names_old) == sentence_new.get_embedding(names_new)).all() + assert torch.allclose( + sentence_old.get_embedding(names_old), + sentence_new.get_embedding(names_new), + atol=1e-6, + ) def test_embeddings_load_in_eval_mode(self): embeddings = self.create_embedding_with_args(self.default_args) diff --git a/tests/embeddings/test_transformer_document_embeddings.py b/tests/embeddings/test_transformer_document_embeddings.py index 7402b2b467..f0f6389b7d 100644 --- a/tests/embeddings/test_transformer_document_embeddings.py +++ b/tests/embeddings/test_transformer_document_embeddings.py @@ -44,7 +44,7 @@ def test_if_loaded_embeddings_have_all_attributes(tasks_base_path): @pytest.mark.parametrize("cls_pooling", ["cls", "mean", "max"]) def test_cls_pooling(cls_pooling): embeddings = 
TransformerDocumentEmbeddings( - model="xlm-roberta-base", + model="distilbert-base-uncased", layers="-1", cls_pooling=cls_pooling, allow_long_sentences=True, From e71797c4f77f71519de72f304a6615c712b39f9c Mon Sep 17 00:00:00 2001 From: Jeff Picard Date: Mon, 23 Sep 2024 22:54:48 -0600 Subject: [PATCH 015/333] GH-3429: Add experimental multi GPU support --- flair/__init__.py | 6 + flair/distributed_utils.py | 49 +++++++++ flair/nn/model.py | 3 + flair/trainers/trainer.py | 220 ++++++++++++++++++++----------------- 4 files changed, 176 insertions(+), 102 deletions(-) create mode 100644 flair/distributed_utils.py diff --git a/flair/__init__.py b/flair/__init__.py index 341f630e43..55bd6d00cf 100644 --- a/flair/__init__.py +++ b/flair/__init__.py @@ -33,6 +33,12 @@ else: device = torch.device("cpu") +distributed = False +"""Experimental flag to indicate multiple GPUs are in use. + +Set by `launch_distributed` -- do not set manually. +""" + # global variable: version __version__ = "0.14.0" """The current version of the flair library installed.""" diff --git a/flair/distributed_utils.py b/flair/distributed_utils.py new file mode 100644 index 0000000000..596570b7ce --- /dev/null +++ b/flair/distributed_utils.py @@ -0,0 +1,49 @@ +import logging +import os + +import torch +import torch.multiprocessing as mp +from torch.distributed import destroy_process_group, init_process_group + +import flair + +log = logging.getLogger("flair") + + +def launch_distributed(fp, *args): + """Executes the function fp(*args) on multiple GPUs (all local GPUs)""" + world_size = torch.cuda.device_count() + log.info(f"Launching {world_size} distributed processes") + mp.spawn(entrypoint, args=(world_size, fp, *args), nprocs=world_size) + + +def entrypoint(rank, world_size, fp, *args): + ddp_setup(rank, world_size) + fp(*args) + destroy_process_group() + + +def ddp_setup(rank: int, world_size: int) -> None: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + init_process_group(backend="nccl", rank=rank, world_size=world_size) + flair.distributed = True + flair.device = torch.device(rank) + torch.cuda.set_device(flair.device) + + +def is_main_process() -> bool: + """True for exactly 1 process, regardless of whether being run on CPU/single-GPU/multi-gpu""" + if flair.distributed: + return flair.device.index == 0 + else: + return True + + +class DistributedModel(torch.nn.parallel.DistributedDataParallel): + """DistributedDataParallel, but redirects access to methods and attributes to the original Model""" + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) diff --git a/flair/nn/model.py b/flair/nn/model.py index bf13baf2f1..1710e32c76 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -17,6 +17,7 @@ from flair.class_utils import get_non_abstract_subclasses from flair.data import DT, DT2, Corpus, Dictionary, Sentence, _iter_dataset from flair.datasets import DataLoader, FlairDatapointDataset +from flair.distributed_utils import is_main_process from flair.embeddings import Embeddings from flair.embeddings.base import load_embeddings from flair.file_utils import Tqdm, load_torch_state @@ -118,6 +119,8 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None: model_file: the model file checkpoint: currently unused. 
""" + if not is_main_process(): + return model_state = self._get_state_dict() # write out a "model card" if one is set diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 2f32b54c01..05c376ec20 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -11,12 +11,14 @@ import torch from torch.optim.sgd import SGD +from torch.utils.data import DistributedSampler from torch.utils.data.dataset import ConcatDataset import flair import flair.nn from flair.data import Corpus, Dictionary, _len_dataset from flair.datasets import DataLoader +from flair.distributed_utils import DistributedModel, is_main_process from flair.samplers import FlairSampler from flair.trainers.plugins import ( AnnealingPlugin, @@ -420,6 +422,8 @@ def train_custom( base_path=base_path, ).attach_to(self) + if flair.distributed: + self.model = DistributedModel(self.model, device_ids=[flair.device.index]) # === END BLOCK: ACTIVATE PLUGINS === # # derive parameters the function was called with (or defaults) @@ -509,36 +513,37 @@ def train_custom( else "model from best epoch (best-model.pt)" ) - log_line(log) - log.info(f'Model: "{self.model}"') - log_line(log) - log.info(f"{self.corpus}") - log_line(log) - log.info(f"Train: {len(train_data)} sentences") - log.info(f" (train_with_dev={train_with_dev}, train_with_test={train_with_test})") - log_line(log) - log.info("Training Params:") - log.info( - f' - learning_rate: "{learning_rate}" ' - f'{"(decoder: " + str(decoder_learning_rate) + ")" if decoder_learning_rate else ""}' - ) - log.info(f' - mini_batch_size: "{mini_batch_size}"') - log.info(f' - max_epochs: "{max_epochs}"') - log.info(f' - shuffle: "{shuffle}"') - log_line(log) - log.info("Plugins:") - for plugin in plugins: - log.info(" - " + str(plugin)) - log_line(log) - log.info(f"Final evaluation on {final_eval_info}") - log.info(f' - metric: "{main_evaluation_metric}"') - log_line(log) - log.info("Computation:") - log.info(f" - compute on device: {flair.device}") - log.info(f" - embedding storage: {embeddings_storage_mode}") - log_line(log) - log.info(f'Model training base path: "{base_path}"') - log_line(log) + if is_main_process(): + log_line(log) + log.info(f'Model: "{self.model}"') + log_line(log) + log.info(f"{self.corpus}") + log_line(log) + log.info(f"Train: {len(train_data)} sentences") + log.info(f" (train_with_dev={train_with_dev}, train_with_test={train_with_test})") + log_line(log) + log.info("Training Params:") + log.info( + f' - learning_rate: "{learning_rate}" ' + f'{"(decoder: " + str(decoder_learning_rate) + ")" if decoder_learning_rate else ""}' + ) + log.info(f' - mini_batch_size: "{mini_batch_size}"') + log.info(f' - max_epochs: "{max_epochs}"') + log.info(f' - shuffle: "{shuffle}"') + log_line(log) + log.info("Plugins:") + for plugin in plugins: + log.info(" - " + str(plugin)) + log_line(log) + log.info(f"Final evaluation on {final_eval_info}") + log.info(f' - metric: "{main_evaluation_metric}"') + log_line(log) + log.info("Computation:") + log.info(f" - compute on device: {flair.device}") + log.info(f" - embedding storage: {embeddings_storage_mode}") + log_line(log) + log.info(f'Model training base path: "{base_path}"') + log_line(log) # At any point you can hit Ctrl + C to break out of training early. 
try: @@ -560,12 +565,21 @@ def train_custom( if not shuffle_first_epoch and epoch == 1: shuffle_data_this_epoch = False - batch_loader = DataLoader( - train_data, - batch_size=mini_batch_size, - shuffle=shuffle_data_this_epoch, - sampler=sampler, - ) + if flair.distributed: + batch_loader = DataLoader( + train_data, + batch_size=mini_batch_size, + shuffle=False, + sampler=DistributedSampler(train_data, shuffle=shuffle_data_this_epoch), + ) + batch_loader.sampler.set_epoch(epoch) + else: + batch_loader = DataLoader( + train_data, + batch_size=mini_batch_size, + shuffle=shuffle_data_this_epoch, + sampler=sampler, + ) self.model.train() @@ -682,49 +696,50 @@ def train_custom( # Determine if this is the best model or if we need to anneal current_epoch_has_best_model_so_far = False - validation_scores: tuple - - for evaluation_split, evaluation_split_data in evaluation_splits.items(): - eval_result = self.model.evaluate( - evaluation_split_data, - out_path=base_path / f"{evaluation_split}.tsv", - mini_batch_size=eval_batch_size, - exclude_labels=exclude_labels, - main_evaluation_metric=main_evaluation_metric, - gold_label_dictionary=gold_label_dictionary_for_eval, - embedding_storage_mode=embeddings_storage_mode, - gold_label_type=self.model.label_type, - gold_label_dictionary_for_eval=gold_label_dictionary_for_eval, - ) + validation_scores = () + + if is_main_process(): + for evaluation_split, evaluation_split_data in evaluation_splits.items(): + eval_result = self.model.evaluate( + evaluation_split_data, + out_path=base_path / f"{evaluation_split}.tsv", + mini_batch_size=eval_batch_size, + exclude_labels=exclude_labels, + main_evaluation_metric=main_evaluation_metric, + gold_label_dictionary=gold_label_dictionary_for_eval, + embedding_storage_mode=embeddings_storage_mode, + gold_label_type=self.model.label_type, + gold_label_dictionary_for_eval=gold_label_dictionary_for_eval, + ) - # log results - log.info( - f"{evaluation_split.upper()} : loss {eval_result.loss}" - f" - {main_evaluation_metric[1]}" - f" ({main_evaluation_metric[0]})" - f" {round(eval_result.main_score, 4)}" - ) + # log results + log.info( + f"{evaluation_split.upper()} : loss {eval_result.loss}" + f" - {main_evaluation_metric[1]}" + f" ({main_evaluation_metric[0]})" + f" {round(eval_result.main_score, 4)}" + ) - # depending on memory mode, embeddings are moved to CPU, GPU or deleted - store_embeddings(evaluation_split_data, embeddings_storage_mode) + # depending on memory mode, embeddings are moved to CPU, GPU or deleted + store_embeddings(evaluation_split_data, embeddings_storage_mode) - self._publish_eval_result(eval_result, evaluation_split, global_step=epoch) + self._publish_eval_result(eval_result, evaluation_split, global_step=epoch) - # use DEV split to determine if this is the best model so far - if determine_best_epoch_using_dev_score and evaluation_split == "dev": - validation_scores = eval_result.main_score, eval_result.loss + # use DEV split to determine if this is the best model so far + if determine_best_epoch_using_dev_score and evaluation_split == "dev": + validation_scores = eval_result.main_score, eval_result.loss - if eval_result.main_score > best_epoch_score: - current_epoch_has_best_model_so_far = True - best_epoch_score = eval_result.main_score + if eval_result.main_score > best_epoch_score: + current_epoch_has_best_model_so_far = True + best_epoch_score = eval_result.main_score - # if not using DEV score, determine best model using train loss - if not determine_best_epoch_using_dev_score: - 
validation_scores = (train_loss,) + # if not using DEV score, determine best model using train loss + if not determine_best_epoch_using_dev_score: + validation_scores = (train_loss,) - if epoch_train_loss < best_epoch_score: - current_epoch_has_best_model_so_far = True - best_epoch_score = train_loss + if epoch_train_loss < best_epoch_score: + current_epoch_has_best_model_so_far = True + best_epoch_score = train_loss # - LossFilePlugin -> somehow prints all relevant metrics # - AnnealPlugin -> scheduler step @@ -776,41 +791,42 @@ def train_custom( self.dispatch("_training_finally") # test best model if test data is present - if self.corpus.test and not train_with_test: - log_line(log) + if is_main_process(): + if self.corpus.test and not train_with_test: + log_line(log) - self.model.eval() + self.model.eval() - if (base_path / "best-model.pt").exists(): - log.info("Loading model from best epoch ...") - self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict()) - else: - log.info("Testing using last state of model ...") - - test_results = self.model.evaluate( - self.corpus.test, - gold_label_type=self.model.label_type, - mini_batch_size=eval_batch_size, - out_path=base_path / "test.tsv", - embedding_storage_mode="none", - main_evaluation_metric=main_evaluation_metric, - gold_label_dictionary=gold_label_dictionary_for_eval, - exclude_labels=exclude_labels, - return_loss=False, - ) + if (base_path / "best-model.pt").exists(): + log.info("Loading model from best epoch ...") + self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict()) + else: + log.info("Testing using last state of model ...") + + test_results = self.model.evaluate( + self.corpus.test, + gold_label_type=self.model.label_type, + mini_batch_size=eval_batch_size, + out_path=base_path / "test.tsv", + embedding_storage_mode="none", + main_evaluation_metric=main_evaluation_metric, + gold_label_dictionary=gold_label_dictionary_for_eval, + exclude_labels=exclude_labels, + return_loss=False, + ) - log.info(test_results.detailed_results) - log_line(log) + log.info(test_results.detailed_results) + log_line(log) - # get and return the final test score of best model - self.return_values["test_score"] = test_results.main_score + # get and return the final test score of best model + self.return_values["test_score"] = test_results.main_score - else: - if (base_path / "best-model.pt").exists(): - log.info("Loading model from best epoch ...") - self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict()) - self.return_values["test_score"] = 0 - log.info("Test data not provided setting final score to 0") + else: + if (base_path / "best-model.pt").exists(): + log.info("Loading model from best epoch ...") + self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict()) + self.return_values["test_score"] = 0 + log.info("Test data not provided setting final score to 0") # MetricHistoryPlugin -> stores the loss history in return_values self.dispatch("after_training") From f777cc6dac88fc6db3c277756257ef175ebdd08f Mon Sep 17 00:00:00 2001 From: Jeff Picard Date: Mon, 30 Sep 2024 16:45:25 -0600 Subject: [PATCH 016/333] GH-3429: Move process spawning inside `.train`; WIP sync gradients --- flair/__init__.py | 6 - flair/distributed_utils.py | 58 ++++--- flair/nn/model.py | 7 +- flair/trainers/trainer.py | 328 +++++++++++++++++++------------------ 4 files changed, 212 insertions(+), 187 deletions(-) diff --git a/flair/__init__.py b/flair/__init__.py index 
55bd6d00cf..341f630e43 100644 --- a/flair/__init__.py +++ b/flair/__init__.py @@ -33,12 +33,6 @@ else: device = torch.device("cpu") -distributed = False -"""Experimental flag to indicate multiple GPUs are in use. - -Set by `launch_distributed` -- do not set manually. -""" - # global variable: version __version__ = "0.14.0" """The current version of the flair library installed.""" diff --git a/flair/distributed_utils.py b/flair/distributed_utils.py index 596570b7ce..39ec8c4ce0 100644 --- a/flair/distributed_utils.py +++ b/flair/distributed_utils.py @@ -1,49 +1,63 @@ import logging import os +from multiprocessing.connection import Connection +from typing import Callable +import numpy as np import torch import torch.multiprocessing as mp from torch.distributed import destroy_process_group, init_process_group import flair +from flair.class_utils import T log = logging.getLogger("flair") -def launch_distributed(fp, *args): - """Executes the function fp(*args) on multiple GPUs (all local GPUs)""" - world_size = torch.cuda.device_count() - log.info(f"Launching {world_size} distributed processes") - mp.spawn(entrypoint, args=(world_size, fp, *args), nprocs=world_size) - +def launch_distributed(fn, *args, **kwargs): + """Executes the function fn(*args, **kwargs) on multiple processes (one for each local GPU). -def entrypoint(rank, world_size, fp, *args): - ddp_setup(rank, world_size) - fp(*args) + Returns: the return value of the function fn(*args, **kwargs) from the rank 0 process + """ + world_size = torch.cuda.device_count() + log.info(f"Launching {world_size} processes") + parent_conn, child_conn = mp.Pipe() + mp.spawn(_entrypoint, args=(world_size, child_conn, fn, args, kwargs), nprocs=world_size) + return_value = parent_conn.recv() + return return_value + + +def _entrypoint(rank: int, world_size: int, child_conn: Connection, fn: Callable, args: tuple, kwargs: dict) -> None: + """Lifecycle of a process -- setup, run, cleanup.""" + log.info(f"Started process on rank={rank}") + _ddp_setup(rank, world_size) + return_value = fn(*args, **kwargs) + if is_main_process(): + child_conn.send(return_value) destroy_process_group() -def ddp_setup(rank: int, world_size: int) -> None: os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" - init_process_group(backend="nccl", rank=rank, world_size=world_size) - flair.distributed = True flair.device = torch.device(rank) torch.cuda.set_device(flair.device) + init_process_group(backend="nccl", rank=rank, world_size=world_size) def is_main_process() -> bool: - """True for exactly 1 process, regardless of whether being run on CPU/single-GPU/multi-gpu""" - if flair.distributed: - return flair.device.index == 0 + """True for exactly 1 process, regardless of whether being run on CPU/single-GPU/multi-gpu.""" + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() == 0 else: return True -class DistributedModel(torch.nn.parallel.DistributedDataParallel): - """DistributedDataParallel, but redirects access to methods and attributes to the original Model""" - def __getattr__(self, name): - try: - return super().__getattr__(name) - except AttributeError: - return getattr(self.module, name) +def aggregate_if_distributed(value: T, aggregation_fn: Callable = np.mean) -> T: + """Gathers value from each process and returns the aggregated value according to the supplied function.""" + if torch.distributed.is_initialized(): + gathered_values = [None for _ in
range(torch.distributed.get_world_size())] + torch.distributed.all_gather_object(gathered_values, value) + return aggregation_fn(gathered_values) + else: + return value diff --git a/flair/nn/model.py b/flair/nn/model.py index 1710e32c76..c903c4000a 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -17,7 +17,6 @@ from flair.class_utils import get_non_abstract_subclasses from flair.data import DT, DT2, Corpus, Dictionary, Sentence, _iter_dataset from flair.datasets import DataLoader, FlairDatapointDataset -from flair.distributed_utils import is_main_process from flair.embeddings import Embeddings from flair.embeddings.base import load_embeddings from flair.file_utils import Tqdm, load_torch_state @@ -48,6 +47,10 @@ def forward_loss(self, data_points: list[DT]) -> tuple[torch.Tensor, int]: """ raise NotImplementedError + def forward(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]: + """Wraps forward_loss to maintain compatibility with hooks.""" + return self.forward_loss(data_points) + @abstractmethod def evaluate( self, @@ -119,8 +122,6 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None: model_file: the model file checkpoint: currently unused. """ - if not is_main_process(): - return model_state = self._get_state_dict() # write out a "model card" if one is set diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 05c376ec20..ae5e30c712 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -7,9 +7,12 @@ import warnings from inspect import signature from pathlib import Path -from typing import Optional, Union +from queue import Queue +from typing import Optional, Tuple, Type, Union +import numpy as np import torch +from torch.nn.parallel import DistributedDataParallel from torch.optim.sgd import SGD from torch.utils.data import DistributedSampler from torch.utils.data.dataset import ConcatDataset @@ -18,7 +21,7 @@ import flair.nn from flair.data import Corpus, Dictionary, _len_dataset from flair.datasets import DataLoader -from flair.distributed_utils import DistributedModel, is_main_process +from flair.distributed_utils import aggregate_if_distributed, is_main_process, launch_distributed from flair.samplers import FlairSampler from flair.trainers.plugins import ( AnnealingPlugin, @@ -165,6 +168,8 @@ def train( create_file_logs: bool = True, create_loss_file: bool = True, write_weights: bool = False, + # scaling + multi_gpu: bool = False, # plugins plugins: Optional[list[TrainerPlugin]] = None, attach_default_scheduler: bool = True, @@ -200,7 +205,12 @@ def train( "kwargs", ]: local_variables.pop(var) - return self.train_custom(**local_variables, **kwargs) + + if multi_gpu: + self._event_queue = None # Each process will make its own queue rather than share + return launch_distributed(self.train_custom, **local_variables, **kwargs) + else: + return self.train_custom(**local_variables, **kwargs) def fine_tune( self, @@ -239,8 +249,9 @@ def fine_tune( create_file_logs: bool = True, create_loss_file: bool = True, write_weights: bool = False, - # amp + # scaling use_amp: bool = False, + multi_gpu: bool = False, # plugins plugins: Optional[list[TrainerPlugin]] = None, attach_default_scheduler: bool = True, @@ -254,47 +265,21 @@ def fine_tune( if attach_default_scheduler: plugins.append(LinearSchedulerPlugin(warmup_fraction=warmup_fraction)) - return self.train_custom( - base_path=base_path, - # training parameters - learning_rate=learning_rate, - decoder_learning_rate=decoder_learning_rate, - 
mini_batch_size=mini_batch_size, - eval_batch_size=eval_batch_size, - mini_batch_chunk_size=mini_batch_chunk_size, - max_epochs=max_epochs, - optimizer=optimizer, - train_with_dev=train_with_dev, - train_with_test=train_with_test, - reduce_transformer_vocab=reduce_transformer_vocab, - # evaluation and monitoring - main_evaluation_metric=main_evaluation_metric, - monitor_test=monitor_test, - monitor_train_sample=monitor_train_sample, - use_final_model_for_eval=use_final_model_for_eval, - gold_label_dictionary_for_eval=gold_label_dictionary_for_eval, - exclude_labels=exclude_labels, - # sampling and shuffling - sampler=sampler, - shuffle=shuffle, - shuffle_first_epoch=shuffle_first_epoch, - # evaluation and monitoring - embeddings_storage_mode=embeddings_storage_mode, - epoch=epoch, - # when and what to save - save_final_model=save_final_model, - save_optimizer_state=save_optimizer_state, - save_model_each_k_epochs=save_model_each_k_epochs, - # logging parameters - create_file_logs=create_file_logs, - create_loss_file=create_loss_file, - write_weights=write_weights, - # amp - use_amp=use_amp, - # plugins - plugins=plugins, - **kwargs, - ) + # call self.train_custom with all parameters (minus the ones specific to the LinearSchedulerPlugin) + local_variables = locals() + for var in [ + "self", + "warmup_fraction", + "attach_default_scheduler", + "kwargs", + ]: + local_variables.pop(var) + + if multi_gpu: + self._event_queue = None + return launch_distributed(self.train_custom, **local_variables, **kwargs) + else: + return self.train_custom(**local_variables, **kwargs) def train_custom( self, @@ -333,8 +318,9 @@ def train_custom( create_file_logs: bool = True, create_loss_file: bool = True, write_weights: bool = False, - # amp + # scaling use_amp: bool = False, + multi_gpu: bool = False, # plugins plugins: Optional[list[TrainerPlugin]] = None, **kwargs, @@ -377,6 +363,7 @@ def train_custom( create_file_logs: If True, logging output is written to a file create_loss_file: If True, a loss file logging output is created use_amp: If True, uses the torch automatic mixed precision + multi_gpu: If True, uses all available GPUs write_weights: If True, write weights to weights.txt on each batch logging event. 
plugins: Any additional plugins you want to pass to the trainer **kwargs: Additional arguments, for instance for the optimizer @@ -422,8 +409,11 @@ def train_custom( base_path=base_path, ).attach_to(self) - if flair.distributed: - self.model = DistributedModel(self.model, device_ids=[flair.device.index]) + if multi_gpu: + self.model.to(flair.device) + self.ddp_model = DistributedDataParallel(self.model, device_ids=[flair.device.index]) + self._event_queue = Queue() # Each process uses its own queue rather than share + log.disabled = not is_main_process() # Disable logging in distributed mode for all but the main process # === END BLOCK: ACTIVATE PLUGINS === # # derive parameters the function was called with (or defaults) @@ -512,38 +502,38 @@ def train_custom( if use_final_model_for_eval else "model from best epoch (best-model.pt)" ) - - if is_main_process(): - log_line(log) - log.info(f'Model: "{self.model}"') - log_line(log) - log.info(f"{self.corpus}") - log_line(log) - log.info(f"Train: {len(train_data)} sentences") - log.info(f" (train_with_dev={train_with_dev}, train_with_test={train_with_test})") - log_line(log) - log.info("Training Params:") - log.info( - f' - learning_rate: "{learning_rate}" ' - f'{"(decoder: " + str(decoder_learning_rate) + ")" if decoder_learning_rate else ""}' - ) - log.info(f' - mini_batch_size: "{mini_batch_size}"') - log.info(f' - max_epochs: "{max_epochs}"') - log.info(f' - shuffle: "{shuffle}"') - log_line(log) - log.info("Plugins:") - for plugin in plugins: - log.info(" - " + str(plugin)) - log_line(log) - log.info(f"Final evaluation on {final_eval_info}") - log.info(f' - metric: "{main_evaluation_metric}"') - log_line(log) - log.info("Computation:") - log.info(f" - compute on device: {flair.device}") - log.info(f" - embedding storage: {embeddings_storage_mode}") - log_line(log) - log.info(f'Model training base path: "{base_path}"') - log_line(log) + computation_device_info = f"{torch.cuda.device_count()} GPUs" if multi_gpu else flair.device + + log_line(log) + log.info(f'Model: "{self.model}"') + log_line(log) + log.info(f"{self.corpus}") + log_line(log) + log.info(f"Train: {len(train_data)} sentences") + log.info(f" (train_with_dev={train_with_dev}, train_with_test={train_with_test})") + log_line(log) + log.info("Training Params:") + log.info( + f' - learning_rate: "{learning_rate}" ' + f'{"(decoder: " + str(decoder_learning_rate) + ")" if decoder_learning_rate else ""}' + ) + log.info(f' - mini_batch_size: "{mini_batch_size}"') + log.info(f' - max_epochs: "{max_epochs}"') + log.info(f' - shuffle: "{shuffle}"') + log_line(log) + log.info("Plugins:") + for plugin in plugins: + log.info(" - " + str(plugin)) + log_line(log) + log.info(f"Final evaluation on {final_eval_info}") + log.info(f' - metric: "{main_evaluation_metric}"') + log_line(log) + log.info("Computation:") + log.info(f" - compute on device: {computation_device_info}") + log.info(f" - embedding storage: {embeddings_storage_mode}") + log_line(log) + log.info(f'Model training base path: "{base_path}"') + log_line(log) # At any point you can hit Ctrl + C to break out of training early. 
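The `log.disabled = not is_main_process()` assignment above keeps console output readable under DDP: every spawned process runs this same training loop, so any unguarded side effect (logging, checkpoint writes, progress bars) would otherwise happen once per GPU. The idiom in isolation, assuming the process group was initialized by the spawning code:

    import logging
    import torch.distributed as dist

    log = logging.getLogger("flair")

    def is_main_process() -> bool:
        # With no initialized process group (CPU or a single GPU), stay verbose.
        return dist.get_rank() == 0 if dist.is_initialized() else True

    log.disabled = not is_main_process()
    log.info("emitted once per run, not once per GPU")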
try: @@ -565,14 +555,14 @@ def train_custom( if not shuffle_first_epoch and epoch == 1: shuffle_data_this_epoch = False - if flair.distributed: + if multi_gpu: batch_loader = DataLoader( train_data, batch_size=mini_batch_size, shuffle=False, sampler=DistributedSampler(train_data, shuffle=shuffle_data_this_epoch), ) - batch_loader.sampler.set_epoch(epoch) + batch_loader.sampler.set_epoch(epoch - 1) else: batch_loader = DataLoader( train_data, @@ -617,7 +607,10 @@ def train_custom( for batch_step in batch_steps: # forward pass with torch.autocast(device_type=flair.device.type, enabled=use_amp): - loss, datapoint_count = self.model.forward_loss(batch_step) + if multi_gpu: + loss, datapoint_count = self.ddp_model(batch_step) + else: + loss, datapoint_count = self.model.forward_loss(batch_step) batch_train_samples += datapoint_count batch_train_loss += loss.item() @@ -663,8 +656,11 @@ def train_custom( if epoch_train_samples > 0 else epoch_train_samples / (batch_no + 1) ) + intermittent_loss = aggregate_if_distributed(intermittent_loss) current_time = time.time() + samples_per_second = epoch_train_samples / (current_time - epoch_start_time) + samples_per_second = aggregate_if_distributed(samples_per_second, np.sum) lr_info, momentum_info = self._get_current_lr_and_momentum(batch_count) log.info( @@ -672,7 +668,7 @@ def train_custom( f" - iter {batch_no + 1}/{len(batch_loader)}" f" - loss {intermittent_loss:.8f}" f" - time (sec): {(current_time - epoch_start_time):.2f}" - f" - samples/sec: {epoch_train_samples / (current_time - epoch_start_time):.2f}" + f" - samples/sec: {samples_per_second:.2f}" f"{lr_info}{momentum_info}" ) @@ -681,6 +677,7 @@ def train_custom( self.dispatch("after_training_batch", **batch_kw) train_loss = epoch_train_loss / epoch_train_samples + train_loss = aggregate_if_distributed(train_loss) self._record(MetricRecord.scalar(("train", "loss"), train_loss, epoch)) total_train_samples += epoch_train_samples @@ -696,50 +693,49 @@ def train_custom( # Determine if this is the best model or if we need to anneal current_epoch_has_best_model_so_far = False - validation_scores = () - - if is_main_process(): - for evaluation_split, evaluation_split_data in evaluation_splits.items(): - eval_result = self.model.evaluate( - evaluation_split_data, - out_path=base_path / f"{evaluation_split}.tsv", - mini_batch_size=eval_batch_size, - exclude_labels=exclude_labels, - main_evaluation_metric=main_evaluation_metric, - gold_label_dictionary=gold_label_dictionary_for_eval, - embedding_storage_mode=embeddings_storage_mode, - gold_label_type=self.model.label_type, - gold_label_dictionary_for_eval=gold_label_dictionary_for_eval, - ) + validation_scores: tuple = () + + for evaluation_split, evaluation_split_data in evaluation_splits.items(): + eval_result = self.model.evaluate( + evaluation_split_data, + out_path=base_path / f"{evaluation_split}.tsv", + mini_batch_size=eval_batch_size, + exclude_labels=exclude_labels, + main_evaluation_metric=main_evaluation_metric, + gold_label_dictionary=gold_label_dictionary_for_eval, + embedding_storage_mode=embeddings_storage_mode, + gold_label_type=self.model.label_type, + gold_label_dictionary_for_eval=gold_label_dictionary_for_eval, + ) - # log results - log.info( - f"{evaluation_split.upper()} : loss {eval_result.loss}" - f" - {main_evaluation_metric[1]}" - f" ({main_evaluation_metric[0]})" - f" {round(eval_result.main_score, 4)}" - ) + # log results + log.info( + f"{evaluation_split.upper()} : loss {eval_result.loss}" + f" - {main_evaluation_metric[1]}" + 
f" ({main_evaluation_metric[0]})" + f" {round(eval_result.main_score, 4)}" + ) - # depending on memory mode, embeddings are moved to CPU, GPU or deleted - store_embeddings(evaluation_split_data, embeddings_storage_mode) + # depending on memory mode, embeddings are moved to CPU, GPU or deleted + store_embeddings(evaluation_split_data, embeddings_storage_mode) - self._publish_eval_result(eval_result, evaluation_split, global_step=epoch) + self._publish_eval_result(eval_result, evaluation_split, global_step=epoch) - # use DEV split to determine if this is the best model so far - if determine_best_epoch_using_dev_score and evaluation_split == "dev": - validation_scores = eval_result.main_score, eval_result.loss + # use DEV split to determine if this is the best model so far + if determine_best_epoch_using_dev_score and evaluation_split == "dev": + validation_scores = eval_result.main_score, eval_result.loss - if eval_result.main_score > best_epoch_score: - current_epoch_has_best_model_so_far = True - best_epoch_score = eval_result.main_score + if eval_result.main_score > best_epoch_score: + current_epoch_has_best_model_so_far = True + best_epoch_score = eval_result.main_score - # if not using DEV score, determine best model using train loss - if not determine_best_epoch_using_dev_score: - validation_scores = (train_loss,) + # if not using DEV score, determine best model using train loss + if not determine_best_epoch_using_dev_score: + validation_scores = (train_loss,) - if epoch_train_loss < best_epoch_score: - current_epoch_has_best_model_so_far = True - best_epoch_score = train_loss + if train_loss < best_epoch_score: + current_epoch_has_best_model_so_far = True + best_epoch_score = train_loss # - LossFilePlugin -> somehow prints all relevant metrics # - AnnealPlugin -> scheduler step @@ -752,14 +748,14 @@ def train_custom( if save_best_model and current_epoch_has_best_model_so_far: log.info("saving best model") - self.model.save(base_path / "best-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "best-model.pt", checkpoint=save_optimizer_state) # - SWAPlugin -> restores SGD weights from SWA self.dispatch("after_training_loop") # if we do not use dev data for model selection, save final model if save_final_model: - self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state) except KeyboardInterrupt: log_line(log) @@ -769,7 +765,7 @@ def train_custom( if save_final_model: log.info("Saving model ...") - self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state) log.info("Done.") except TrainingInterrupt as exc: @@ -780,7 +776,7 @@ def train_custom( if save_final_model: log.info("Saving model ...") - self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state) log.info("Done.") except Exception: @@ -791,42 +787,41 @@ def train_custom( self.dispatch("_training_finally") # test best model if test data is present - if is_main_process(): - if self.corpus.test and not train_with_test: - log_line(log) + if self.corpus.test and not train_with_test: + log_line(log) - self.model.eval() + self.model.eval() - if (base_path / "best-model.pt").exists(): - log.info("Loading model from best epoch ...") - self.model.load_state_dict(self.model.load(base_path / 
"best-model.pt").state_dict()) - else: - log.info("Testing using last state of model ...") - - test_results = self.model.evaluate( - self.corpus.test, - gold_label_type=self.model.label_type, - mini_batch_size=eval_batch_size, - out_path=base_path / "test.tsv", - embedding_storage_mode="none", - main_evaluation_metric=main_evaluation_metric, - gold_label_dictionary=gold_label_dictionary_for_eval, - exclude_labels=exclude_labels, - return_loss=False, - ) + if (base_path / "best-model.pt").exists(): + log.info("Loading model from best epoch ...") + self._load_model(base_path / "best-model.pt") + else: + log.info("Testing using last state of model ...") + + test_results = self.model.evaluate( + self.corpus.test, + gold_label_type=self.model.label_type, + mini_batch_size=eval_batch_size, + out_path=base_path / "test.tsv", + embedding_storage_mode="none", + main_evaluation_metric=main_evaluation_metric, + gold_label_dictionary=gold_label_dictionary_for_eval, + exclude_labels=exclude_labels, + return_loss=False, + ) - log.info(test_results.detailed_results) - log_line(log) + log.info(test_results.detailed_results) + log_line(log) - # get and return the final test score of best model - self.return_values["test_score"] = test_results.main_score + # get and return the final test score of best model + self.return_values["test_score"] = test_results.main_score - else: - if (base_path / "best-model.pt").exists(): - log.info("Loading model from best epoch ...") - self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict()) - self.return_values["test_score"] = 0 - log.info("Test data not provided setting final score to 0") + else: + if (base_path / "best-model.pt").exists(): + log.info("Loading model from best epoch ...") + self._load_model(base_path / "best-model.pt") + self.return_values["test_score"] = 0 + log.info("Test data not provided setting final score to 0") # MetricHistoryPlugin -> stores the loss history in return_values self.dispatch("after_training") @@ -840,7 +835,9 @@ def train_custom( def _get_current_lr_and_momentum(self, batch_count): current_learning_rate = [group["lr"] for group in self.optimizer.param_groups] + current_learning_rate = [aggregate_if_distributed(m) for m in current_learning_rate] momentum = [group.get("momentum", 0) for group in self.optimizer.param_groups] + momentum = [aggregate_if_distributed(m) for m in momentum] lr_info = " - lr: " + ",".join([f"{m:.6f}" for m in current_learning_rate]) momentum_info = " - momentum: " + ",".join([f"{m:.6f}" for m in momentum]) self._record(MetricRecord.scalar_list("learning_rate", current_learning_rate, batch_count)) @@ -921,3 +918,22 @@ def _initialize_model_card(self, **training_parameters): def _record(self, metric): self.dispatch("metric_recorded", metric) + + def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) -> None: + """Saves the current model. Safe to call from a distributed context. + + Args: + model_file: the model file + checkpoint: currently unused. + """ + if is_main_process(): + self.model.save(model_file, checkpoint) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() # Prevent any process from loading a model until writing is complete + + def _load_model(self, model_file: Union[str, Path]) -> None: + """Loads the model from the given file into the current state. 
Safe to call from a distributed context.""" + self.model.load_state_dict(self.model.load(model_file).state_dict()) + if torch.distributed.is_initialized(): + self.ddp_model = DistributedDataParallel(self.model, device_ids=[flair.device.index]) From c91809379402f2c359e1aff22b51adfb7539d111 Mon Sep 17 00:00:00 2001 From: Jeff Picard Date: Fri, 25 Oct 2024 02:32:13 -0700 Subject: [PATCH 017/333] GH-3429: Redirect forward_loss; User calls launch; Handle plugins --- flair/distributed_utils.py | 31 +++-- flair/nn/model.py | 7 +- flair/trainers/plugins/base.py | 9 ++ .../plugins/functional/checkpoints.py | 4 + .../plugins/functional/linear_scheduler.py | 5 +- .../functional/reduce_transformer_vocab.py | 4 + .../plugins/functional/weight_extractor.py | 4 + .../plugins/loggers/clearml_logger.py | 4 + flair/trainers/plugins/loggers/log_file.py | 4 + flair/trainers/plugins/loggers/loss_file.py | 4 + .../plugins/loggers/metric_history.py | 4 + flair/trainers/plugins/loggers/tensorboard.py | 4 + flair/trainers/plugins/loggers/wandb.py | 4 + flair/trainers/trainer.py | 114 ++++++++++++------ 14 files changed, 146 insertions(+), 56 deletions(-) diff --git a/flair/distributed_utils.py b/flair/distributed_utils.py index 39ec8c4ce0..d391251b31 100644 --- a/flair/distributed_utils.py +++ b/flair/distributed_utils.py @@ -9,7 +9,6 @@ from torch.distributed import destroy_process_group, init_process_group import flair -from flair.class_utils import T log = logging.getLogger("flair") @@ -17,24 +16,30 @@ def launch_distributed(fn, *args, **kwargs): """Executes the function fn(*args, **kwargs) on multiple processes (one for each local GPU). + If training with multi_gpu=True, launch_distributed should wrap your code that calls .train or .fine_tune. + Returns: the return value of the function fn(*args, **kwargs) from the rank 0 process """ world_size = torch.cuda.device_count() log.info(f"Launching {world_size} processes") parent_conn, child_conn = mp.Pipe() - mp.spawn(_entrypoint, args=(world_size, child_conn, fn, args, kwargs), nprocs=world_size) + mp.spawn(_process_entrypoint, args=(world_size, child_conn, fn, args, kwargs), nprocs=world_size) return_value = parent_conn.recv() return return_value -def _entrypoint(rank: int, world_size: int, child_conn: Connection, fn: Callable, args: tuple, kwargs: dict) -> None: - """Lifecycle of a process -- setup, run, cleanup.""" +def _process_entrypoint( + rank: int, world_size: int, child_conn: Connection, fn: Callable, args: tuple, kwargs: dict +) -> None: + """Lifecycle of a distributed process -- setup, run, cleanup.""" log.info(f"Started process on rank={rank}") - _ddp_setup(rank, world_size) - return_value = fn(*args, **kwargs) - if is_main_process(): - child_conn.send(return_value) - destroy_process_group() + try: + _ddp_setup(rank, world_size) + return_value = fn(*args, **kwargs) + if is_main_process(): + child_conn.send(return_value) + finally: + destroy_process_group() def _ddp_setup(rank: int, world_size: int) -> None: @@ -53,11 +58,11 @@ def is_main_process() -> bool: return True -def aggregate_if_distributed(value: T, aggregation_fn: Callable = np.mean) -> T: - """Gathers value from each process and returns the aggregated value according to the supplied function.""" +def aggregate(value, aggregation_fn=np.mean): + """Gather `value` from all processes and send to `aggregation_fn` to get a single return value.""" if torch.distributed.is_initialized(): gathered_values = [None for _ in range(torch.distributed.get_world_size())]
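+        # all_gather_object is a collective: it pickles `value` on every rank and
+        # fills `gathered_values` with one entry per rank, so each process can then
+        # reduce the identical list locally with `aggregation_fn`.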
torch.distributed.all_gather_object(gathered_values, value) - return aggregation_fn(gathered_values) else: - return value + gathered_values = [value] + return aggregation_fn(gathered_values) diff --git a/flair/nn/model.py b/flair/nn/model.py index c903c4000a..f670c969a0 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -17,6 +17,7 @@ from flair.class_utils import get_non_abstract_subclasses from flair.data import DT, DT2, Corpus, Dictionary, Sentence, _iter_dataset from flair.datasets import DataLoader, FlairDatapointDataset +from flair.distributed_utils import is_main_process from flair.embeddings import Embeddings from flair.embeddings.base import load_embeddings from flair.file_utils import Tqdm, load_torch_state @@ -47,10 +48,6 @@ def forward_loss(self, data_points: list[DT]) -> tuple[torch.Tensor, int]: """ raise NotImplementedError - def forward(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]: - """Wraps forward_loss to maintain compatibility with hooks.""" - return self.forward_loss(data_points) - @abstractmethod def evaluate( self, @@ -295,7 +292,7 @@ def evaluate( loader = DataLoader(data_points, batch_size=mini_batch_size) sentence_id = 0 - for batch in Tqdm.tqdm(loader): + for batch in Tqdm.tqdm(loader, disable=not is_main_process()): # remove any previously predicted labels for datapoint in batch: datapoint.remove_labels("predicted") diff --git a/flair/trainers/plugins/base.py b/flair/trainers/plugins/base.py index 663a78d6d3..dcf7240a83 100644 --- a/flair/trainers/plugins/base.py +++ b/flair/trainers/plugins/base.py @@ -13,6 +13,8 @@ cast, ) +from flair.distributed_utils import is_main_process + log = logging.getLogger("flair") @@ -184,6 +186,8 @@ def attach_to(self, pluggable: Pluggable): assert self._pluggable is None assert len(self._hook_handles) == 0 + if not is_main_process() and not self.attach_to_all_processes: + return self._pluggable = pluggable pluggable.append_plugin(self) @@ -252,6 +256,11 @@ def decorator_func(func: Callable): def pluggable(self) -> Optional[Pluggable]: return self._pluggable + @property + def attach_to_all_processes(self) -> bool: + """If set, the plugin will be attached to all processes when distributed, not just the main process.""" + return True + def __str__(self) -> str: return self.__class__.__name__ diff --git a/flair/trainers/plugins/functional/checkpoints.py b/flair/trainers/plugins/functional/checkpoints.py index 1936177835..cf1a21468a 100644 --- a/flair/trainers/plugins/functional/checkpoints.py +++ b/flair/trainers/plugins/functional/checkpoints.py @@ -29,6 +29,10 @@ def after_training_epoch(self, epoch, **kw): model_name = "model_epoch_" + str(epoch) + ".pt" self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state) + @property + def attach_to_all_processes(self) -> bool: + return False + def get_state(self) -> dict[str, Any]: return { **super().get_state(), diff --git a/flair/trainers/plugins/functional/linear_scheduler.py b/flair/trainers/plugins/functional/linear_scheduler.py index 2258844129..fdf4752bde 100644 --- a/flair/trainers/plugins/functional/linear_scheduler.py +++ b/flair/trainers/plugins/functional/linear_scheduler.py @@ -1,6 +1,8 @@ import logging from typing import Any +import torch.distributed + from flair.optim import LinearSchedulerWithWarmup from flair.trainers.plugins.base import TrainerPlugin @@ -34,7 +36,8 @@ def after_setup( ): """Initialize different schedulers, including anneal target for AnnealOnPlateau, batch_growth_annealing, loading schedulers.""" # calculate 
warmup steps - steps_per_epoch = (dataset_size + mini_batch_size - 1) / mini_batch_size + num_processes = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + steps_per_epoch = (dataset_size + mini_batch_size - 1) / mini_batch_size / num_processes num_train_steps = int(steps_per_epoch * max_epochs) num_warmup_steps = int(num_train_steps * self.warmup_fraction) diff --git a/flair/trainers/plugins/functional/reduce_transformer_vocab.py b/flair/trainers/plugins/functional/reduce_transformer_vocab.py index 86759c2fe2..eed6d7f1b2 100644 --- a/flair/trainers/plugins/functional/reduce_transformer_vocab.py +++ b/flair/trainers/plugins/functional/reduce_transformer_vocab.py @@ -55,6 +55,10 @@ def save_model_at_the_end(self, **kw): elif (self.base_path / "final-model.pt").exists(): self.model.save(self.base_path / "final-model.pt", checkpoint=self.save_optimizer_state) + @property + def attach_to_all_processes(self) -> bool: + return False + def get_transformer_embeddings(model: Model) -> list[TransformerEmbeddings]: embeddings = model.tars_embeddings if isinstance(model, FewshotClassifier) else getattr(model, "embeddings", None) diff --git a/flair/trainers/plugins/functional/weight_extractor.py b/flair/trainers/plugins/functional/weight_extractor.py index 4ba7c07621..5c9bd4c4ac 100644 --- a/flair/trainers/plugins/functional/weight_extractor.py +++ b/flair/trainers/plugins/functional/weight_extractor.py @@ -21,6 +21,10 @@ def after_training_batch(self, batch_no, epoch, total_number_of_batches, **kw): if (iteration + 1) % modulo == 0: self.weight_extractor.extract_weights(self.model.state_dict(), iteration) + @property + def attach_to_all_processes(self) -> bool: + return False + def get_state(self) -> dict[str, Any]: return { **super().get_state(), diff --git a/flair/trainers/plugins/loggers/clearml_logger.py b/flair/trainers/plugins/loggers/clearml_logger.py index 891b9f9244..18228d2db6 100644 --- a/flair/trainers/plugins/loggers/clearml_logger.py +++ b/flair/trainers/plugins/loggers/clearml_logger.py @@ -40,3 +40,7 @@ def metric_recorded(self, record: MetricRecord) -> None: self.logger.report_text(record.value, print_console=False) elif record.is_histogram: self.logger.report_histogram(record_name, record_name, record.value, record.global_step) + + @property + def attach_to_all_processes(self) -> bool: + return False diff --git a/flair/trainers/plugins/loggers/log_file.py b/flair/trainers/plugins/loggers/log_file.py index 21a8c54632..9bf22a284d 100644 --- a/flair/trainers/plugins/loggers/log_file.py +++ b/flair/trainers/plugins/loggers/log_file.py @@ -21,5 +21,9 @@ def close_file_handler(self, **kw): self.log_handler.close() log.removeHandler(self.log_handler) + @property + def attach_to_all_processes(self) -> bool: + return False + def get_state(self) -> dict[str, Any]: return {**super().get_state(), "base_path": str(self.base_path)} diff --git a/flair/trainers/plugins/loggers/loss_file.py b/flair/trainers/plugins/loggers/loss_file.py index b53a23a956..bfe938b72d 100644 --- a/flair/trainers/plugins/loggers/loss_file.py +++ b/flair/trainers/plugins/loggers/loss_file.py @@ -113,3 +113,7 @@ def after_evaluation(self, epoch, **kw): f.write("\t".join([str(self.current_row[col]) for col in self.headers]) + "\n") self.current_row = {} + + @property + def attach_to_all_processes(self) -> bool: + return False diff --git a/flair/trainers/plugins/loggers/metric_history.py b/flair/trainers/plugins/loggers/metric_history.py index a22cf1b0e7..426e055186 100644 --- 
a/flair/trainers/plugins/loggers/metric_history.py +++ b/flair/trainers/plugins/loggers/metric_history.py @@ -34,6 +34,10 @@ def after_training(self, **kw): """Returns metric history.""" self.trainer.return_values.update(self.metric_history) + @property + def attach_to_all_processes(self) -> bool: + return False + def get_state(self) -> dict[str, Any]: return { **super().get_state(), diff --git a/flair/trainers/plugins/loggers/tensorboard.py b/flair/trainers/plugins/loggers/tensorboard.py index a7af50a521..bf2dfcc29d 100644 --- a/flair/trainers/plugins/loggers/tensorboard.py +++ b/flair/trainers/plugins/loggers/tensorboard.py @@ -59,6 +59,10 @@ def _training_finally(self, **kw): assert self.writer is not None self.writer.close() + @property + def attach_to_all_processes(self) -> bool: + return False + def get_state(self) -> dict[str, Any]: return { **super().get_state(), diff --git a/flair/trainers/plugins/loggers/wandb.py b/flair/trainers/plugins/loggers/wandb.py index 0f8dc89f73..410d0377ba 100644 --- a/flair/trainers/plugins/loggers/wandb.py +++ b/flair/trainers/plugins/loggers/wandb.py @@ -72,6 +72,10 @@ def metric_recorded(self, record): def _training_finally(self, **kw): self.writer.close() + @property + def attach_to_all_processes(self) -> bool: + return False + def get_state(self) -> dict[str, Any]: return { **super().get_state(), diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index ae5e30c712..36c4f14295 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -21,7 +21,7 @@ import flair.nn from flair.data import Corpus, Dictionary, _len_dataset from flair.datasets import DataLoader -from flair.distributed_utils import aggregate_if_distributed, is_main_process, launch_distributed +from flair.distributed_utils import aggregate, is_main_process from flair.samplers import FlairSampler from flair.trainers.plugins import ( AnnealingPlugin, @@ -205,12 +205,7 @@ def train( "kwargs", ]: local_variables.pop(var) - - if multi_gpu: - self._event_queue = None # Each process will make its own queue rather than share - return launch_distributed(self.train_custom, **local_variables, **kwargs) - else: - return self.train_custom(**local_variables, **kwargs) + return self.train_custom(**local_variables, **kwargs) def fine_tune( self, @@ -265,21 +260,48 @@ def fine_tune( if attach_default_scheduler: plugins.append(LinearSchedulerPlugin(warmup_fraction=warmup_fraction)) - # call self.train_custom with all parameters (minus the ones specific to the LinearSchedulerPlugin) - local_variables = locals() - for var in [ - "self", - "warmup_fraction", - "attach_default_scheduler", - "kwargs", - ]: - local_variables.pop(var) - - if multi_gpu: - self._event_queue = None - return launch_distributed(self.train_custom, **local_variables, **kwargs) - else: - return self.train_custom(**local_variables, **kwargs) + return self.train_custom( + base_path=base_path, + # training parameters + learning_rate=learning_rate, + decoder_learning_rate=decoder_learning_rate, + mini_batch_size=mini_batch_size, + eval_batch_size=eval_batch_size, + mini_batch_chunk_size=mini_batch_chunk_size, + max_epochs=max_epochs, + optimizer=optimizer, + train_with_dev=train_with_dev, + train_with_test=train_with_test, + reduce_transformer_vocab=reduce_transformer_vocab, + # evaluation and monitoring + main_evaluation_metric=main_evaluation_metric, + monitor_test=monitor_test, + monitor_train_sample=monitor_train_sample, + use_final_model_for_eval=use_final_model_for_eval, + 
gold_label_dictionary_for_eval=gold_label_dictionary_for_eval, + exclude_labels=exclude_labels, + # sampling and shuffling + sampler=sampler, + shuffle=shuffle, + shuffle_first_epoch=shuffle_first_epoch, + # evaluation and monitoring + embeddings_storage_mode=embeddings_storage_mode, + epoch=epoch, + # when and what to save + save_final_model=save_final_model, + save_optimizer_state=save_optimizer_state, + save_model_each_k_epochs=save_model_each_k_epochs, + # logging parameters + create_file_logs=create_file_logs, + create_loss_file=create_loss_file, + write_weights=write_weights, + # scaling + use_amp=use_amp, + multi_gpu=multi_gpu, + # plugins + plugins=plugins, + **kwargs, + ) def train_custom( self, @@ -363,7 +385,7 @@ def train_custom( create_file_logs: If True, logging output is written to a file create_loss_file: If True, a loss file logging output is created use_amp: If True, uses the torch automatic mixed precision - multi_gpu: If True, uses all available GPUs + multi_gpu: If True, distributes training across local GPUs write_weights: If True, write weights to weights.txt on each batch logging event. plugins: Any additional plugins you want to pass to the trainer **kwargs: Additional arguments, for instance for the optimizer @@ -409,11 +431,6 @@ def train_custom( base_path=base_path, ).attach_to(self) - if multi_gpu: - self.model.to(flair.device) - self.ddp_model = DistributedDataParallel(self.model, device_ids=[flair.device.index]) - self._event_queue = Queue() # Each process uses its own queue rather than share - log.disabled = not is_main_process() # Disable logging in distributed mode for all but the main process # === END BLOCK: ACTIVATE PLUGINS === # # derive parameters the function was called with (or defaults) @@ -475,6 +492,16 @@ def train_custom( sampler.set_dataset(train_data) shuffle = False + # configure special behavior to use multiple GPUs + if multi_gpu: + if not torch.distributed.is_initialized(): + raise RuntimeError("multi_gpu=True can only be used inside flair.distributed_utils.launch_distributed()") + self.ddp_model = DistributedDataParallel( + self.model, device_ids=[flair.device.index], find_unused_parameters=True + ) + log.disabled = not is_main_process() # Only print logs once + original_forward = self.model.forward + # this field stores the names of all dynamic embeddings in the model (determined after first forward pass) dynamic_embeddings = None @@ -502,7 +529,9 @@ def train_custom( if use_final_model_for_eval else "model from best epoch (best-model.pt)" ) - computation_device_info = f"{torch.cuda.device_count()} GPUs" if multi_gpu else flair.device + computation_device_info = aggregate( + flair.device, lambda devices: ", ".join([str(device) for device in devices]) + ) log_line(log) log.info(f'Model: "{self.model}"') @@ -556,13 +585,16 @@ def train_custom( shuffle_data_this_epoch = False if multi_gpu: + distributed_sampler: DistributedSampler = DistributedSampler( + train_data, shuffle=shuffle_data_this_epoch + ) + distributed_sampler.set_epoch(epoch - 1) batch_loader = DataLoader( train_data, batch_size=mini_batch_size, shuffle=False, - sampler=DistributedSampler(train_data, shuffle=shuffle_data_this_epoch), + sampler=distributed_sampler, ) - batch_loader.sampler.set_epoch(epoch - 1) else: batch_loader = DataLoader( train_data, @@ -608,6 +640,14 @@ def train_custom( # forward pass with torch.autocast(device_type=flair.device.type, enabled=use_amp): if multi_gpu: + # We need to __call__ ddp_model() because this triggers hooks that sync gradients.
+ # But that calls forward rather than forward_loss. So we patch forward to redirect + # to forward_loss. Then undo the patch in case forward_loss itself calls forward. + def wrapped_forward_loss(*args, **kwargs2): + self.model.forward = original_forward + return self.model.forward_loss(*args, **kwargs2) + + self.model.forward = wrapped_forward_loss loss, datapoint_count = self.ddp_model(batch_step) else: loss, datapoint_count = self.model.forward_loss(batch_step) @@ -656,11 +696,11 @@ def train_custom( if epoch_train_samples > 0 else epoch_train_samples / (batch_no + 1) ) - intermittent_loss = aggregate_if_distributed(intermittent_loss) + intermittent_loss = aggregate(intermittent_loss) current_time = time.time() samples_per_second = epoch_train_samples / (current_time - epoch_start_time) - samples_per_second = aggregate_if_distributed(samples_per_second, np.sum) + samples_per_second = aggregate(samples_per_second, np.sum) lr_info, momentum_info = self._get_current_lr_and_momentum(batch_count) log.info( @@ -677,7 +717,7 @@ def train_custom( self.dispatch("after_training_batch", **batch_kw) train_loss = epoch_train_loss / epoch_train_samples - train_loss = aggregate_if_distributed(train_loss) + train_loss = aggregate(train_loss) self._record(MetricRecord.scalar(("train", "loss"), train_loss, epoch)) total_train_samples += epoch_train_samples @@ -835,9 +875,7 @@ def train_custom( def _get_current_lr_and_momentum(self, batch_count): current_learning_rate = [group["lr"] for group in self.optimizer.param_groups] - current_learning_rate = [aggregate_if_distributed(m) for m in current_learning_rate] momentum = [group.get("momentum", 0) for group in self.optimizer.param_groups] - momentum = [aggregate_if_distributed(m) for m in momentum] lr_info = " - lr: " + ",".join([f"{m:.6f}" for m in current_learning_rate]) momentum_info = " - momentum: " + ",".join([f"{m:.6f}" for m in momentum]) self._record(MetricRecord.scalar_list("learning_rate", current_learning_rate, batch_count)) @@ -936,4 +974,6 @@ def _load_model(self, model_file: Union[str, Path]) -> None: """Loads the model from the given file into the current state. Safe to call from a distributed context.""" self.model.load_state_dict(self.model.load(model_file).state_dict()) if torch.distributed.is_initialized(): - self.ddp_model = DistributedDataParallel(self.model, device_ids=[flair.device.index]) + self.ddp_model = DistributedDataParallel( + self.model, device_ids=[flair.device.index], find_unused_parameters=True + ) From 0cbeb9fb786e27bffbca034131d4f09bf120c646 Mon Sep 17 00:00:00 2001 From: Jeff Picard Date: Mon, 28 Oct 2024 18:11:00 -0700 Subject: [PATCH 018/333] fix transformers switching attention implementation --- flair/embeddings/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py index 245d528e5d..c43258a193 100644 --- a/flair/embeddings/transformer.py +++ b/flair/embeddings/transformer.py @@ -1356,7 +1356,7 @@ def to_params(self): # do not switch the attention implementation upon reload. 
config_dict["attn_implementation"] = self.model.config._attn_implementation - del config_dict["_attn_implementation_autoset"] + config_dict.pop("_attn_implementation_autoset", None) super_params = super().to_params() From a590cd6bb02c57f6a700179390650a2ca4329447 Mon Sep 17 00:00:00 2001 From: Jeff Picard Date: Mon, 28 Oct 2024 18:11:08 -0700 Subject: [PATCH 019/333] GH-3429: Rename parameter group --- flair/trainers/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 36c4f14295..86a0ce3527 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -168,7 +168,7 @@ def train( create_file_logs: bool = True, create_loss_file: bool = True, write_weights: bool = False, - # scaling + # acceleration multi_gpu: bool = False, # plugins plugins: Optional[list[TrainerPlugin]] = None, @@ -244,7 +244,7 @@ def fine_tune( create_file_logs: bool = True, create_loss_file: bool = True, write_weights: bool = False, - # scaling + # acceleration use_amp: bool = False, multi_gpu: bool = False, # plugins @@ -295,7 +295,7 @@ def fine_tune( create_file_logs=create_file_logs, create_loss_file=create_loss_file, write_weights=write_weights, - # scaling + # acceleration use_amp=use_amp, multi_gpu=multi_gpu, # plugins @@ -340,7 +340,7 @@ def train_custom( create_file_logs: bool = True, create_loss_file: bool = True, write_weights: bool = False, - # scaling + # acceleration use_amp: bool = False, multi_gpu: bool = False, # plugins From ffb8e297a70f77f57405dffaa5c61cdb79f6ab39 Mon Sep 17 00:00:00 2001 From: Jeff Picard Date: Fri, 8 Nov 2024 02:45:23 -0800 Subject: [PATCH 020/333] GH-3429: Add example to docs; Guard against corpus being different on each process --- examples/README.md | 5 +-- examples/multi_gpu/README.md | 32 +++++++++++++++++ examples/multi_gpu/__init__.py | 0 examples/multi_gpu/run_multi_gpu.py | 54 +++++++++++++++++++++++++++++ flair/data.py | 5 +++ flair/distributed_utils.py | 19 ++++++++++ flair/nn/model.py | 6 +++- flair/trainers/trainer.py | 30 ++++------------ tests/test_sentence.py | 10 ++++-- 9 files changed, 133 insertions(+), 28 deletions(-) create mode 100644 examples/multi_gpu/README.md create mode 100644 examples/multi_gpu/__init__.py create mode 100644 examples/multi_gpu/run_multi_gpu.py diff --git a/examples/README.md b/examples/README.md index 53e11ffe21..1221fbf49f 100644 --- a/examples/README.md +++ b/examples/README.md @@ -4,6 +4,7 @@ This folder contains actively maintained examples of use of Flair, organized alo ## Table of Tasks -| Task | Documentation -| ----------------------------- | ------------- +| Task | Documentation +| ------------------------------ | ------------- | Named Entity Recognition (NER) | [Here](ner/) +| Multi GPU | [Here](multi_gpu/) diff --git a/examples/multi_gpu/README.md b/examples/multi_gpu/README.md new file mode 100644 index 0000000000..4c12b9d0bd --- /dev/null +++ b/examples/multi_gpu/README.md @@ -0,0 +1,32 @@ +# Multi GPU + +Training can be distributed across multiple GPUs on a local machine when using +[`ModelTrainer`](#flair.trainers.trainer.ModelTrainer). + +## Example + +See the script `run_multi_gpu.py` and its comments. 
+ +## Tutorial + +There are 2 changes that are always required, as well as a few things to consider: + +Always Required: +1) Pass the argument `multi_gpu=True` to your [`.train()`](#flair.trainers.trainer.ModelTrainer.train) or `.fine_tune()` +2) Wrap your code in [`launch_distributed`](#flair.distributed_utils.launch_distributed), e.g. + `launch_distributed(main, *args)`. This spawns multiple processes, each driving a GPU. + +Other considerations: +- The corpus and other preprocessing must be the same on all processes. For example, if corpus initialization involves + anything random, you should either + - Set the random seed before initializing the corpus (e.g. `flair.set_seed(42)`) OR + - Initialize the corpus before calling `launch_distributed` and pass the corpus as an argument so it's serialized to + all processes +- The effective batch size will be larger by a factor of num_gpus + - Each GPU will now process `mini_batch_size` examples before the optimizer steps, resulting in fewer total steps + taken relative to training with a single device. To obtain comparable results between single/multi gpu, + both mathematically and in terms of wall time, consider the method in the example script. +- Large batch sizes may be necessary to see faster runs, otherwise the communication overhead may dominate + +Only the parameter updates in the training process will be distributed across multiple GPUs. Evaluation and prediction +are still done on a single device. diff --git a/examples/multi_gpu/__init__.py b/examples/multi_gpu/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/multi_gpu/run_multi_gpu.py b/examples/multi_gpu/run_multi_gpu.py new file mode 100644 index 0000000000..bba529b1b2 --- /dev/null +++ b/examples/multi_gpu/run_multi_gpu.py @@ -0,0 +1,54 @@ +import torch + +import flair +from flair.datasets import IMDB +from flair.distributed_utils import launch_distributed +from flair.embeddings import TransformerDocumentEmbeddings +from flair.models import TextClassifier +from flair.trainers import ModelTrainer + + +def main(multi_gpu): + # Note: Multi-GPU can affect corpus loading + # This code will run multiple times -- each GPU gets its own process and each process runs this code. We need to + # ensure that the corpus has the same elements and order on all processes, despite sampling. We do that by using + # the same seed on all processes. + flair.set_seed(1336) + + corpus = IMDB() + corpus.downsample(0.01) + label_type = "sentiment" + label_dictionary = corpus.make_label_dictionary(label_type) + + embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased") + model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary) + + # Note: Multi-GPU can affect choice of batch size. + # In order to compare batch updates fairly between single and multi-GPU training, we should: + # 1) Step the optimizer after the same number of examples to achieve comparable updates + # 2) Process the same number of examples in each forward pass + mini_batch_chunk_size = 32 # Make this as large as possible without running out of GPU memory, to fully pack each device + num_devices_when_distributing = max(torch.cuda.device_count(), 1) + mini_batch_size = mini_batch_chunk_size if multi_gpu else mini_batch_chunk_size * num_devices_when_distributing + # e.g. Suppose your machine has 2 GPUs. If multi_gpu=False, the first gpu will process 32 examples, then the + # first gpu will process another 32 examples, then the optimizer will step.
If multi_gpu=True, each gpu will + # process 32 examples at the same time, then the optimizer will step. + + trainer = ModelTrainer(model, corpus) + trainer.train( + "resources/taggers/multi-gpu", + multi_gpu=multi_gpu, # Required for multi-gpu + max_epochs=2, + mini_batch_chunk_size=mini_batch_chunk_size, + mini_batch_size=mini_batch_size, + ) + + +if __name__ == "__main__": + """Minimal example demonstrating how to train a model on multiple GPUs.""" + multi_gpu = True + + if multi_gpu: + launch_distributed(main, multi_gpu) # Required for multi-gpu + else: + main(multi_gpu) diff --git a/flair/data.py b/flair/data.py index 56622b249c..ce97a4ce91 100644 --- a/flair/data.py +++ b/flair/data.py @@ -1072,6 +1072,11 @@ def __len__(self) -> int: def __repr__(self) -> str: return self.__str__() + def __eq__(self, o: object) -> bool: + if not isinstance(o, Sentence): + return False + return self.to_dict() == o.to_dict() + @property def start_position(self) -> int: return self._start_position diff --git a/flair/distributed_utils.py b/flair/distributed_utils.py index d391251b31..f3296eda58 100644 --- a/flair/distributed_utils.py +++ b/flair/distributed_utils.py @@ -1,5 +1,6 @@ import logging import os +import random from multiprocessing.connection import Connection from typing import Callable @@ -7,8 +8,10 @@ import torch import torch.multiprocessing as mp from torch.distributed import destroy_process_group, init_process_group +from torch.utils.data import Dataset import flair +from flair.data import Corpus, _len_dataset log = logging.getLogger("flair") @@ -66,3 +69,19 @@ def aggregate(value, aggregation_fn=np.mean): else: gathered_values = [value] return aggregation_fn(gathered_values) + + +def validate_corpus_same_across_processes(corpus: Corpus) -> None: + for dataset in [corpus.train, corpus.dev, corpus.test]: + if dataset is not None: + validate_dataset_same_across_processes(dataset) + + +def validate_dataset_same_across_processes(dataset: Dataset, sample_size: int = 10) -> None: + """Sanity checks a few examples to catch datasets that are obviously different, but not exhaustive to save time.""" + random_indices = random.sample(range(_len_dataset(dataset)), min(sample_size, _len_dataset(dataset))) + for i in random_indices: + example = dataset[i] + examples = aggregate(example, list) + if not all(example == examples[0] for example in examples): + raise ValueError("Dataset must be the same on each process") diff --git a/flair/nn/model.py b/flair/nn/model.py index f670c969a0..40761b5f52 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -126,7 +126,11 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None: model_state["model_card"] = self.model_card # save model - torch.save(model_state, str(model_file), pickle_protocol=4) + if is_main_process(): + torch.save(model_state, str(model_file), pickle_protocol=4) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() # Prevent any process from loading a model until writing is complete @classmethod def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model": diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 86a0ce3527..25c33491a1 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -21,7 +21,7 @@ import flair.nn from flair.data import Corpus, Dictionary, _len_dataset from flair.datasets import DataLoader -from flair.distributed_utils import aggregate, is_main_process +from flair.distributed_utils import aggregate, is_main_process, 
 from flair.samplers import FlairSampler
 from flair.trainers.plugins import (
     AnnealingPlugin,
@@ -496,6 +496,8 @@ def train_custom(
         if multi_gpu:
             if not torch.distributed.is_initialized():
                 raise RuntimeError("multi_gpu=True can only be used inside flair.distributed_utils.launch_distributed()")
+            # Guard against each process initializing the corpus differently due to e.g. different random seeds
+            validate_corpus_same_across_processes(self.corpus)
             self.ddp_model = DistributedDataParallel(
                 self.model, device_ids=[flair.device.index], find_unused_parameters=True
             )
@@ -788,14 +790,14 @@ def wrapped_forward_loss(*args, **kwargs2):
 
                 if save_best_model and current_epoch_has_best_model_so_far:
                     log.info("saving best model")
-                    self._save_model(base_path / "best-model.pt", checkpoint=save_optimizer_state)
+                    self.model.save(base_path / "best-model.pt", checkpoint=save_optimizer_state)
 
             # - SWAPlugin -> restores SGD weights from SWA
             self.dispatch("after_training_loop")
 
             # if we do not use dev data for model selection, save final model
             if save_final_model:
-                self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+                self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
 
         except KeyboardInterrupt:
             log_line(log)
@@ -805,7 +807,7 @@
 
             if save_final_model:
                 log.info("Saving model ...")
-                self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+                self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
             log.info("Done.")
 
         except TrainingInterrupt as exc:
@@ -816,7 +818,7 @@
 
             if save_final_model:
                 log.info("Saving model ...")
-                self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+                self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
             log.info("Done.")
 
         except Exception:
@@ -957,23 +959,5 @@ def _initialize_model_card(self, **training_parameters):
     def _record(self, metric):
         self.dispatch("metric_recorded", metric)
 
-    def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) -> None:
-        """Saves the current model. Safe to call from a distributed context.
-
-        Args:
-            model_file: the model file
-            checkpoint: currently unused.
-        """
-        if is_main_process():
-            self.model.save(model_file, checkpoint)
-
-        if torch.distributed.is_initialized():
-            torch.distributed.barrier()  # Prevent any process from loading a model until writing is complete
-
     def _load_model(self, model_file: Union[str, Path]) -> None:
-        """Loads the model from the given file into the current state.
Safe to call from a distributed context.""" self.model.load_state_dict(self.model.load(model_file).state_dict()) - if torch.distributed.is_initialized(): - self.ddp_model = DistributedDataParallel( - self.model, device_ids=[flair.device.index], find_unused_parameters=True - ) diff --git a/tests/test_sentence.py b/tests/test_sentence.py index 3e31422648..90f7b2f204 100644 --- a/tests/test_sentence.py +++ b/tests/test_sentence.py @@ -13,9 +13,15 @@ def test_sentence_context(): def test_equality(): assert Sentence("Guten Tag!") != Sentence("Good day!") assert Sentence("Guten Tag!", use_tokenizer=True) != Sentence("Guten Tag!", use_tokenizer=False) + sentence1 = Sentence("This sentence will be labeled") + sentence1[1].set_label("ner", "B-subject") + sentence2 = Sentence("This sentence will be labeled") + sentence2[1].set_label("ner", "B-object") + assert sentence1 != sentence2 - # TODO: is this desirable? Or should two sentences with same text be considered same objects? - assert Sentence("Guten Tag!") != Sentence("Guten Tag!") + assert Sentence("Guten Tag!") == Sentence("Guten Tag!") + sentence2[1].set_label("ner", "B-subject") + assert sentence1 == sentence2 def test_token_labeling(): From 9a82b0eb9cbd6ed65996f23e7d3f710b53b0b197 Mon Sep 17 00:00:00 2001 From: Jeff Picard Date: Thu, 14 Nov 2024 11:58:34 -0700 Subject: [PATCH 021/333] GH-3429: Remove serializability requirement --- examples/multi_gpu/run_multi_gpu.py | 6 +++--- flair/data.py | 5 ----- flair/distributed_utils.py | 11 ++++++----- flair/trainers/trainer.py | 7 +++---- tests/test_sentence.py | 10 ++-------- 5 files changed, 14 insertions(+), 25 deletions(-) diff --git a/examples/multi_gpu/run_multi_gpu.py b/examples/multi_gpu/run_multi_gpu.py index bba529b1b2..f7d059111b 100644 --- a/examples/multi_gpu/run_multi_gpu.py +++ b/examples/multi_gpu/run_multi_gpu.py @@ -13,10 +13,10 @@ def main(multi_gpu): # This code will run multiple times -- each GPU gets its own process and each process runs this code. We need to # ensure that the corpus has the same elements and order on all processes, despite sampling. We do that by using # the same seed on all processes. - flair.set_seed(1336) + flair.set_seed(42) corpus = IMDB() - corpus.downsample(0.01) + corpus.downsample(0.1) label_type = "sentiment" label_dictionary = corpus.make_label_dictionary(label_type) @@ -35,7 +35,7 @@ def main(multi_gpu): # process 32 examples at the same time, then the optimizer will step. trainer = ModelTrainer(model, corpus) - trainer.train( + trainer.fine_tune( "resources/taggers/multi-gpu", multi_gpu=multi_gpu, # Required for multi-gpu max_epochs=2, diff --git a/flair/data.py b/flair/data.py index ce97a4ce91..56622b249c 100644 --- a/flair/data.py +++ b/flair/data.py @@ -1072,11 +1072,6 @@ def __len__(self) -> int: def __repr__(self) -> str: return self.__str__() - def __eq__(self, o: object) -> bool: - if not isinstance(o, Sentence): - return False - return self.to_dict() == o.to_dict() - @property def start_position(self) -> int: return self._start_position diff --git a/flair/distributed_utils.py b/flair/distributed_utils.py index f3296eda58..e774084009 100644 --- a/flair/distributed_utils.py +++ b/flair/distributed_utils.py @@ -71,17 +71,18 @@ def aggregate(value, aggregation_fn=np.mean): return aggregation_fn(gathered_values) -def validate_corpus_same_across_processes(corpus: Corpus) -> None: +def validate_corpus_same_each_process(corpus: Corpus) -> None: + """Catches most cases in which a corpus is not the same on each process. 
However, there is no guarantee for two + reasons: 1) It uses a sample for speed 2) It compares strings to avoid requiring the datasets to be serializable""" for dataset in [corpus.train, corpus.dev, corpus.test]: if dataset is not None: - validate_dataset_same_across_processes(dataset) + _validate_dataset_same_each_process(dataset) -def validate_dataset_same_across_processes(dataset: Dataset, sample_size: int = 10) -> None: - """Sanity checks a few examples to catch datasets that are obviously different, but not exhaustive to save time.""" +def _validate_dataset_same_each_process(dataset: Dataset, sample_size: int = 10) -> None: random_indices = random.sample(range(_len_dataset(dataset)), min(sample_size, _len_dataset(dataset))) for i in random_indices: - example = dataset[i] + example = str(dataset[i]) examples = aggregate(example, list) if not all(example == examples[0] for example in examples): raise ValueError("Dataset must be the same on each process") diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 25c33491a1..37c63f5d33 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -7,8 +7,7 @@ import warnings from inspect import signature from pathlib import Path -from queue import Queue -from typing import Optional, Tuple, Type, Union +from typing import Optional, Union import numpy as np import torch @@ -21,7 +20,7 @@ import flair.nn from flair.data import Corpus, Dictionary, _len_dataset from flair.datasets import DataLoader -from flair.distributed_utils import aggregate, is_main_process, validate_corpus_same_across_processes +from flair.distributed_utils import aggregate, is_main_process, validate_corpus_same_each_process from flair.samplers import FlairSampler from flair.trainers.plugins import ( AnnealingPlugin, @@ -497,7 +496,7 @@ def train_custom( if not torch.distributed.is_initialized(): raise RuntimeError("multi_gpu=True can only used inside flair.distributed_utils.launch_distributed()") # Guard against each process initializing corpus differently due to e.g. different random seeds - validate_corpus_same_across_processes(self.corpus) + validate_corpus_same_each_process(self.corpus) self.ddp_model = DistributedDataParallel( self.model, device_ids=[flair.device.index], find_unused_parameters=True ) diff --git a/tests/test_sentence.py b/tests/test_sentence.py index 90f7b2f204..3e31422648 100644 --- a/tests/test_sentence.py +++ b/tests/test_sentence.py @@ -13,15 +13,9 @@ def test_sentence_context(): def test_equality(): assert Sentence("Guten Tag!") != Sentence("Good day!") assert Sentence("Guten Tag!", use_tokenizer=True) != Sentence("Guten Tag!", use_tokenizer=False) - sentence1 = Sentence("This sentence will be labeled") - sentence1[1].set_label("ner", "B-subject") - sentence2 = Sentence("This sentence will be labeled") - sentence2[1].set_label("ner", "B-object") - assert sentence1 != sentence2 - assert Sentence("Guten Tag!") == Sentence("Guten Tag!") - sentence2[1].set_label("ner", "B-subject") - assert sentence1 == sentence2 + # TODO: is this desirable? Or should two sentences with same text be considered same objects? 
+ assert Sentence("Guten Tag!") != Sentence("Guten Tag!") def test_token_labeling(): From 952e9fa445464bcd70a654008b3ced5136376da9 Mon Sep 17 00:00:00 2001 From: Jeff Picard Date: Thu, 21 Nov 2024 16:38:59 -0700 Subject: [PATCH 022/333] GH-3429: Fix checkpointing --- flair/nn/model.py | 6 +----- flair/trainers/plugins/functional/checkpoints.py | 4 ++++ flair/trainers/trainer.py | 14 ++++++++++---- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/flair/nn/model.py b/flair/nn/model.py index 40761b5f52..f670c969a0 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -126,11 +126,7 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None: model_state["model_card"] = self.model_card # save model - if is_main_process(): - torch.save(model_state, str(model_file), pickle_protocol=4) - - if torch.distributed.is_initialized(): - torch.distributed.barrier() # Prevent any process from loading a model until writing is complete + torch.save(model_state, str(model_file), pickle_protocol=4) @classmethod def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model": diff --git a/flair/trainers/plugins/functional/checkpoints.py b/flair/trainers/plugins/functional/checkpoints.py index cf1a21468a..a8179edbc5 100644 --- a/flair/trainers/plugins/functional/checkpoints.py +++ b/flair/trainers/plugins/functional/checkpoints.py @@ -1,6 +1,8 @@ import logging from typing import Any +import torch + from flair.trainers.plugins.base import TrainerPlugin log = logging.getLogger("flair") @@ -28,6 +30,8 @@ def after_training_epoch(self, epoch, **kw): ) model_name = "model_epoch_" + str(epoch) + ".pt" self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state) + if torch.distributed.is_initialized(): + torch.distributed.barrier() # Prevent any process from loading a model until writing is complete @property def attach_to_all_processes(self) -> bool: diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 37c63f5d33..cf3bb80eb6 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -789,14 +789,14 @@ def wrapped_forward_loss(*args, **kwargs2): if save_best_model and current_epoch_has_best_model_so_far: log.info("saving best model") - self.model.save(base_path / "best-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "best-model.pt", checkpoint=save_optimizer_state) # - SWAPlugin -> restores SGD weights from SWA self.dispatch("after_training_loop") # if we do not use dev data for model selection, save final model if save_final_model: - self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state) except KeyboardInterrupt: log_line(log) @@ -806,7 +806,7 @@ def wrapped_forward_loss(*args, **kwargs2): if save_final_model: log.info("Saving model ...") - self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state) log.info("Done.") except TrainingInterrupt as exc: @@ -817,7 +817,7 @@ def wrapped_forward_loss(*args, **kwargs2): if save_final_model: log.info("Saving model ...") - self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state) log.info("Done.") except Exception: @@ -960,3 +960,9 @@ def _record(self, metric): def _load_model(self, model_file: Union[str, Path]) -> None: 
self.model.load_state_dict(self.model.load(model_file).state_dict()) + + def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) -> None: + if is_main_process(): + self.model.save(model_file, checkpoint) + if torch.distributed.is_initialized(): + torch.distributed.barrier() # Prevent any process from loading a model until writing is complete From 83238458c44333a97e751925289cc4c94a21b575 Mon Sep 17 00:00:00 2001 From: Alan Akbik Date: Fri, 22 Nov 2024 16:21:25 +0100 Subject: [PATCH 023/333] GH-3561: Update SECURITY.md with current contact --- SECURITY.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/SECURITY.md b/SECURITY.md index 5b9338fe9d..7c578c0b3f 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,8 +1,4 @@ We acknowledge that every line of code that we write may potentially contain security issues. We are trying to deal with it responsibly and provide patches as quickly as possible. -We host our bug bounty program on HackerOne, it is currently private, therefore if you would like to report a vulnerability and get rewarded for it, please ask to join our program by filling this form: - -https://corporate.zalando.com/en/services-and-contact#security-form - -You can also send you report via this form if you do not want to join our bug bounty program and just want to report a vulnerability or security issue. +Please report any issues to [Alan Akbik](http://alanakbik.github.io/). From f8696b31e3b89432c251bb021b3128c820d0b7f9 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 22 Nov 2024 16:21:43 +0100 Subject: [PATCH 024/333] remove clustering support --- flair/models/__init__.py | 2 - flair/models/clustering.py | 120 --------------- resources/docs/TUTORIAL_12_CLUSTERING.md | 180 ----------------------- 3 files changed, 302 deletions(-) delete mode 100644 flair/models/clustering.py delete mode 100644 resources/docs/TUTORIAL_12_CLUSTERING.md diff --git a/flair/models/__init__.py b/flair/models/__init__.py index 8357cc47ea..e75daf074b 100644 --- a/flair/models/__init__.py +++ b/flair/models/__init__.py @@ -1,4 +1,3 @@ -from .clustering import ClusteringModel from .entity_linker_model import SpanClassifier from .entity_mention_linking import EntityMentionLinker from .language_model import LanguageModel @@ -37,6 +36,5 @@ "TARSTagger", "TextClassifier", "TextRegressor", - "ClusteringModel", "MultitaskModel", ] diff --git a/flair/models/clustering.py b/flair/models/clustering.py deleted file mode 100644 index e9902f6f67..0000000000 --- a/flair/models/clustering.py +++ /dev/null @@ -1,120 +0,0 @@ -import logging -import pickle -from collections import OrderedDict -from pathlib import Path -from typing import Optional, Union - -import joblib -from sklearn.base import BaseEstimator, ClusterMixin -from sklearn.metrics import normalized_mutual_info_score -from tqdm import tqdm - -from flair.data import Corpus, _iter_dataset -from flair.datasets import DataLoader -from flair.embeddings import DocumentEmbeddings - -log = logging.getLogger("flair") - - -class ClusteringModel: - """A wrapper class to apply sklearn clustering models on DocumentEmbeddings.""" - - def __init__(self, model: Union[ClusterMixin, BaseEstimator], embeddings: DocumentEmbeddings) -> None: - """Instantiate the ClusteringModel. - - Args: - model: the clustering algorithm from sklearn this wrapper will use. - embeddings: the flair DocumentEmbedding this wrapper uses to calculate a vector for each sentence. 
- """ - self.model = model - self.embeddings = embeddings - - def fit(self, corpus: Corpus, **kwargs): - """Trains the model. - - Args: - corpus: the flair corpus this wrapper will use for fitting the model. - **kwargs: parameters propagated to the models `.fit()` method. - """ - X = self._convert_dataset(corpus) - - log.info("Start clustering " + str(self.model) + " with " + str(len(X)) + " Datapoints.") - self.model.fit(X, **kwargs) - log.info("Finished clustering.") - - def predict(self, corpus: Corpus): - """Predict labels given a list of sentences and returns the respective class indices. - - Args: - corpus: the flair corpus this wrapper will use for predicting the labels. - """ - X = self._convert_dataset(corpus) - log.info("Start the prediction " + str(self.model) + " with " + str(len(X)) + " Datapoints.") - predict = self.model.predict(X) - - for idx, sentence in enumerate(_iter_dataset(corpus.get_all_sentences())): - sentence.set_label("cluster", str(predict[idx])) - - log.info("Finished prediction and labeled all sentences.") - return predict - - def save(self, model_file: Union[str, Path]): - """Saves current model. - - Args: - model_file: path where to save the model. - """ - joblib.dump(pickle.dumps(self), str(model_file)) - - log.info("Saved the model to: " + str(model_file)) - - @staticmethod - def load(model_file: Union[str, Path]): - """Loads a model from a given path. - - Args: - model_file: path to the file where the model is saved. - """ - log.info("Loading model from: " + str(model_file)) - return pickle.loads(joblib.load(str(model_file))) - - def _convert_dataset( - self, corpus, label_type: Optional[str] = None, batch_size: int = 32, return_label_dict: bool = False - ): - """Makes a flair-corpus sklearn compatible. - - Turns the corpora into X, y datasets as required for most sklearn clustering models. - Ref.: https://scikit-learn.org/stable/modules/classes.html#module-sklearn.cluster - """ - log.info("Embed sentences...") - sentences = [] - for batch in tqdm(DataLoader(corpus.get_all_sentences(), batch_size=batch_size)): - self.embeddings.embed(batch) - sentences.extend(batch) - - X = [sentence.embedding.cpu().detach().numpy() for sentence in sentences] - - if label_type is None: - return X - - labels = [sentence.get_labels(label_type)[0].value for sentence in sentences] - label_dict = {v: k for k, v in enumerate(OrderedDict.fromkeys(labels))} - y = [label_dict.get(label) for label in labels] - - if return_label_dict: - return X, y, label_dict - - return X, y - - def evaluate(self, corpus: Corpus, label_type: str): - """This method calculates some evaluation metrics for the clustering. - - Also, the result of the evaluation is logged. - - Args: - corpus: the flair corpus this wrapper will use for evaluation. - label_type: the label from the sentence will be used for the evaluation. - """ - X, Y = self._convert_dataset(corpus, label_type=label_type) - predict = self.model.predict(X) - log.info("NMI - Score: " + str(normalized_mutual_info_score(predict, Y))) diff --git a/resources/docs/TUTORIAL_12_CLUSTERING.md b/resources/docs/TUTORIAL_12_CLUSTERING.md deleted file mode 100644 index 376e5d5639..0000000000 --- a/resources/docs/TUTORIAL_12_CLUSTERING.md +++ /dev/null @@ -1,180 +0,0 @@ -Text Clustering in flair ----------- - -In this package text clustering is implemented. This module has the following -clustering algorithms implemented: -- k-Means -- BIRCH -- Expectation Maximization - -Each of the implemented algorithm needs to have an instanced DocumentEmbedding. 
This embedding will -transform each text/document to a vector. With these vectors the clustering algorithm can be performed. - ---------------------------- - -k-Means ------- -k-Means is a classical and well known clustering algorithm. k-Means is a partitioning-based Clustering algorithm. -The user defines with the parameter *k* how many clusters the given data has. -So the choice of *k* is very important. -More about k-Means can be read on the official [sklearn](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html). - - -```python -from flair.models import ClusteringModel -from flair.datasets import TREC_6 -from flair.embeddings import SentenceTransformerDocumentEmbeddings -from sklearn.cluster import KMeans - -embeddings = SentenceTransformerDocumentEmbeddings() - -# store all embeddings in memory which is required to perform clustering -corpus = TREC_6(memory_mode='full').downsample(0.05) - -model = KMeans(n_clusters=6) - -clustering_model = ClusteringModel( - model=model, - embeddings=embeddings -) - -# fit the model on a corpus -clustering_model.fit(corpus) - -# evaluate the model on a corpus with the given label -clustering_model.evaluate(corpus, label_type="question_class") -``` - -BIRCH ---------- -BIRCH (Balanced Iterative Reducing and Clustering using Hierarchies) is a hierarchical clustering algorithm. -BIRCH is specialized to handle large amounts of data. BIRCH scans the data a single time and builds an internal data -structure. This data structure contains the data but in a compressed way. -More about BIRCH can be read on the official [sklearn](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html). - -```python -from sklearn.cluster import Birch -from flair.datasets import TREC_6 -from flair.embeddings import SentenceTransformerDocumentEmbeddings -from flair.models import ClusteringModel - -embeddings = SentenceTransformerDocumentEmbeddings() - -# store all embeddings in memory which is required to perform clustering -corpus = TREC_6(memory_mode='full').downsample(0.05) - -model = Birch(n_clusters=6) - -clustering_model = ClusteringModel( - model=model, - embeddings=embeddings -) - -# fit the model on a corpus -clustering_model.fit(corpus) - -# evaluate the model on a corpus with the given label -clustering_model.evaluate(corpus, label_type="question_class") -``` - - -Expectation Maximization --------------------------- -Expectation Maximization (EM) is a different class of clustering algorithms called soft clustering algorithms. -Here each point isn't directly assigned to a cluster by a hard decision. -Each data point has a probability to which cluster the data point belongs. The Expectation Maximization (EM) -algorithm is a soft clustering algorithm. -More about EM can be read on the official [sklearn](https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html). 
- - -```python -from sklearn.mixture import GaussianMixture -from flair.datasets import TREC_6 -from flair.embeddings import SentenceTransformerDocumentEmbeddings -from flair.models import ClusteringModel - -embeddings = SentenceTransformerDocumentEmbeddings() - -# store all embeddings in memory which is required to perform clustering -corpus = TREC_6(memory_mode='full').downsample(0.05) - -model = GaussianMixture(n_components=6) - -clustering_model = ClusteringModel( - model=model, - embeddings=embeddings -) - -# fit the model on a corpus -clustering_model.fit(corpus) - -# evaluate the model on a corpus with the given label -clustering_model.evaluate(corpus, label_type="question_class") -``` - ---------------------------- - -Loading/Saving the model ------------ - -The model can be saved and loaded. The code below shows how to save a model. -```python -from flair.models import ClusteringModel -from flair.datasets import TREC_6 -from flair.embeddings import SentenceTransformerDocumentEmbeddings -from sklearn.cluster import KMeans - -embeddings = SentenceTransformerDocumentEmbeddings() - -# store all embeddings in memory which is required to perform clustering -corpus = TREC_6(memory_mode='full').downsample(0.05) - -model = KMeans(n_clusters=6) - -clustering_model = ClusteringModel( - model=model, - embeddings=embeddings -) - -# fit the model on a corpus -clustering_model.fit(corpus) - -# save the model -clustering_model.save(model_file="clustering_model.pt") -``` - -The code for loading a model. - -````python -# load saved clustering model -model = ClusteringModel.load(model_file="clustering_model.pt") - -# load a corpus -corpus = TREC_6(memory_mode='full').downsample(0.05) - -# predict the corpus -model.predict(corpus) -```` - ---------------------- - -Evaluation ---------- -The result of the clustering can be evaluated. For this we will use the -[NMI](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html). -(Normalized Mutual Info) score. - -````python -# need to fit() the model first -# evaluate the model on a corpus with the given label -clustering_model.evaluate(corpus, label_type="question_class") -```` - -The result of the evaluation can be seen below with the SentenceTransformerDocumentEmbeddings: - - -| Clustering Algorithm | Dataset | NMI | -|--------------------------|:-------------:|--------:| -| k Means | StackOverflow | ~0.2122 | -| BIRCH | StackOverflow | ~0,2424 | -| Expectation Maximization | 20News group | ~0,2222 | From 8f934a4c444ab53020a7b6daf9a99c5cee699228 Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Fri, 22 Nov 2024 12:15:37 -0800 Subject: [PATCH 025/333] perf: optimize dictionary items by providing function to check for presence of label in constant time, rather than requiring the creation of a list to be checked. update slow uses of get_items() to use this new function. 
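A minimal sketch of what this change buys (illustrative labels; `Dictionary`, `get_items()`, and the new `has_item()` are the APIs touched by this patch):

```python
from flair.data import Dictionary

label_dict = Dictionary(add_unk=False)
for label in ["PER", "LOC", "ORG"]:
    label_dict.add_item(label)

# Old pattern: get_items() decodes every entry into a fresh list, then scans it -- O(n) per check.
assert "PER" in label_dict.get_items()

# New pattern: has_item() is a single lookup in the internal item2idx dict -- O(1) per check.
assert label_dict.has_item("PER")
```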
---
 flair/data.py                    | 14 +++++++-------
 flair/models/lemmatizer_model.py |  2 +-
 flair/nn/model.py                |  4 +---
 tests/test_corpus_dictionary.py  | 18 +++++++++---------
 4 files changed, 18 insertions(+), 20 deletions(-)

diff --git a/flair/data.py b/flair/data.py
index 56622b249c..8a76a82c1a 100644
--- a/flair/data.py
+++ b/flair/data.py
@@ -122,18 +122,18 @@ def get_idx_for_items(self, items: list[str]) -> list[int]:
         return list(results)
 
     def get_items(self) -> list[str]:
-        items = []
-        for item in self.idx2item:
-            items.append(item.decode("UTF-8"))
-        return items
+        return [item.decode("UTF-8") for item in self.idx2item]
 
     def __len__(self) -> int:
         return len(self.idx2item)
 
-    def get_item_for_index(self, idx):
+    def get_item_for_index(self, idx: int) -> str:
         return self.idx2item[idx].decode("UTF-8")
 
-    def set_start_stop_tags(self):
+    def has_item(self, item: str) -> bool:
+        return item.encode("utf-8") in self.item2idx
+
+    def set_start_stop_tags(self) -> None:
         self.add_item("<START>")
         self.add_item("<STOP>")
 
@@ -1659,7 +1659,7 @@ def make_label_dictionary(
                 unked_count += count
 
         if len(label_dictionary.idx2item) == 0 or (
-            len(label_dictionary.idx2item) == 1 and "<unk>" in label_dictionary.get_items()
+            len(label_dictionary.idx2item) == 1 and label_dictionary.has_item("<unk>")
         ):
             log.error(f"ERROR: You specified label_type='{label_type}' which is not in this dataset!")
             contained_labels = ", ".join(
diff --git a/flair/models/lemmatizer_model.py b/flair/models/lemmatizer_model.py
index 55fa34698c..16ec734941 100644
--- a/flair/models/lemmatizer_model.py
+++ b/flair/models/lemmatizer_model.py
@@ -491,7 +491,7 @@ def predict(
                 for t_id, token in enumerate(tokens_in_batch):
                     predicted_lemma = "".join(
-                        self.char_dictionary.get_item_for_index(idx) if idx != self.end_index else ""
+                        self.char_dictionary.get_item_for_index(int(idx)) if idx != self.end_index else ""
                         for idx in predicted[t_id]
                     )
                     token.set_label(typename=label_name, value=predicted_lemma)
diff --git a/flair/nn/model.py b/flair/nn/model.py
index f670c969a0..8f9d0aa09f 100644
--- a/flair/nn/model.py
+++ b/flair/nn/model.py
@@ -877,9 +877,7 @@ def predict(
                     filtered_indices = []
                     has_unknown_label = False
                     for idx, dp in enumerate(data_points):
-                        if all(
-                            label in self.label_dictionary.get_items() for label in self._get_label_of_datapoint(dp)
-                        ):
+                        if all(self.label_dictionary.has_item(label) for label in self._get_label_of_datapoint(dp)):
                             filtered_indices.append(idx)
                         else:
                             has_unknown_label = True
diff --git a/tests/test_corpus_dictionary.py b/tests/test_corpus_dictionary.py
index cc65019b67..d107c0b62e 100644
--- a/tests/test_corpus_dictionary.py
+++ b/tests/test_corpus_dictionary.py
@@ -110,9 +110,9 @@ def test_tagged_corpus_make_vocab_dictionary():
     vocab = corpus.make_vocab_dictionary(max_tokens=2, min_freq=-1)
 
     assert len(vocab) == 3
-    assert "<unk>" in vocab.get_items()
-    assert "training" in vocab.get_items()
-    assert "." in vocab.get_items()
+    assert vocab.has_item("<unk>")
+    assert vocab.has_item("training")
+    assert vocab.has_item(".")
 
     vocab = corpus.make_vocab_dictionary(max_tokens=-1, min_freq=-1)
 
@@ -121,9 +121,9 @@
     assert len(vocab) == 3
 
     vocab = corpus.make_vocab_dictionary(max_tokens=-1, min_freq=2)
 
     assert len(vocab) == 3
-    assert "<unk>" in vocab.get_items()
-    assert "training" in vocab.get_items()
-    assert "." in vocab.get_items()
+    assert vocab.has_item("<unk>")
+    assert vocab.has_item("training")
+    assert vocab.has_item(".")
 
 
 def test_label_set_confidence():
@@ -153,9 +153,9 @@ def test_tagged_corpus_make_label_dictionary():
     label_dict = corpus.make_label_dictionary("label", add_unk=True)
 
     assert len(label_dict) == 3
-    assert "<unk>" in label_dict.get_items()
-    assert "class_1" in label_dict.get_items()
-    assert "class_2" in label_dict.get_items()
+    assert label_dict.has_item("<unk>")
+    assert label_dict.has_item("class_1")
+    assert label_dict.has_item("class_2")
 
     with pytest.warns(DeprecationWarning):  # test to make sure the warning comes, but function works
         corpus.make_tag_dictionary("label")

From 1f831e08a03e8570af463f38c060a3ef68fac7ca Mon Sep 17 00:00:00 2001
From: MdMotahar
Date: Wed, 1 May 2024 16:18:57 +0000
Subject: [PATCH 026/333] fixed _all_scores_for_token method to correctly calculate token probability

---
 flair/models/sequence_tagger_model.py |  2 +-
 tests/models/test_sequence_tagger.py  | 35 +++++++++++++++++++++++++++
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py
index 2c0fc00ebb..68e8e95b57 100644
--- a/flair/models/sequence_tagger_model.py
+++ b/flair/models/sequence_tagger_model.py
@@ -589,7 +589,7 @@ def _all_scores_for_token(self, sentences: list[Sentence], score_tensor: torch.T
         previous = 0
         for length in lengths:
             prob_tags_per_sentence.append(prob_all_tags[previous : previous + length])
-            previous = length
+            previous += length
         return prob_tags_per_sentence
 
     def _get_state_dict(self):
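To see why this one-character fix matters, a small worked sketch (not part of the patch; values are illustrative):

```python
lengths = [3, 2, 2]  # tokens per sentence in a flattened batch of per-token scores

# Buggy version (previous = length): slices come out as [0:3], [3:5], then [2:4],
# because `previous` is overwritten with the last length instead of accumulating,
# so the third sentence re-reads scores that belong to the second sentence.

# Fixed version (previous += length): a clean partition [0:3], [3:5], [5:7].
previous = 0
for length in lengths:
    print(f"slice [{previous}:{previous + length}]")
    previous += length
```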
diff --git a/tests/models/test_sequence_tagger.py b/tests/models/test_sequence_tagger.py
index 67bf976899..13cc15ddc1 100644
--- a/tests/models/test_sequence_tagger.py
+++ b/tests/models/test_sequence_tagger.py
@@ -1,4 +1,6 @@
 import pytest
+import torch
+import torch.nn.functional as F
 
 import flair
 from flair.embeddings import FlairEmbeddings, WordEmbeddings
@@ -121,3 +123,36 @@ def test_train_load_use_tagger_disjunct_tags(
         loaded_model.predict([example_sentence, self.empty_sentence])
         loaded_model.predict([self.empty_sentence])
         del loaded_model
+
+    @pytest.mark.integration()
+    def test_all_token_prob_distribution(self, embeddings, corpus):
+        tag_dictionary = corpus.make_label_dictionary("ner", add_unk=False)
+        model = self.build_model(embeddings, tag_dictionary)
+
+        # get features from forward propagation
+        sentences = [corpus.train[i] for i in range(len(corpus.train))]
+
+        # reverse sort all sequences by their length
+        sentences = sorted(sentences, key=len, reverse=True)
+
+        with torch.no_grad():
+            sentence_tensor, lengths = model._prepare_tensors(sentences)
+            features = model.forward(sentence_tensor, lengths)
+
+        # remove previously predicted labels of this type
+        for sentence in sentences:
+            sentence.remove_labels(model.label_type)
+
+        softmax_batch = F.softmax(features, dim=1).cpu()
+        lengths = [len(sentence) for sentence in sentences]
+        all_tokens_prob_distrib = model._all_scores_for_token(sentences, softmax_batch, lengths)
+
+        for i, sen_tokens_prob_distribution in enumerate(all_tokens_prob_distrib):
+            assert len(sen_tokens_prob_distribution) == lengths[i]
+            for token_prob_distrib, token in zip(sen_tokens_prob_distribution, sentences[i]):
+                assert len(token_prob_distrib) == len(model.label_dictionary)
+                score_sum = 0.0
+                for token_label in token_prob_distrib:
+                    assert token_label.data_point == token
+                    score_sum += token_label.score
+                assert abs(score_sum - 1.0) < 1.0e-5

From
2c053af584ea48ff1fb6a7d9f287ebb8598ab8dc Mon Sep 17 00:00:00 2001 From: konstantin Date: Sun, 4 Aug 2024 10:55:18 +0200 Subject: [PATCH 027/333] changes --- docs/_static/css/header.css | 89 +++++++++++++++++++++++++++ docs/_static/css/main.css | 118 ++++++++++++++++++++++++++++++++++++ docs/_static/flair_logo.svg | 101 ++++++++++++++++++++++++++++++ docs/conf.py | 10 ++- docs/glossary/index.rst | 7 --- 5 files changed, 317 insertions(+), 8 deletions(-) create mode 100644 docs/_static/css/header.css create mode 100644 docs/_static/css/main.css create mode 100755 docs/_static/flair_logo.svg delete mode 100644 docs/glossary/index.rst diff --git a/docs/_static/css/header.css b/docs/_static/css/header.css new file mode 100644 index 0000000000..2c086e1a40 --- /dev/null +++ b/docs/_static/css/header.css @@ -0,0 +1,89 @@ +.bd-header { + width: 100%; + height: var(--header-height); + background: var(--white-blue) !important; + .header-wrapper { + margin: 0 calc(10% + 10px); + height: 100%; + display: flex; + justify-content: space-between; + nav { + height: 100%; + ul { + display: flex; + height: 100%; + li { + height: 100%; + display: flex; + align-items: center; + a { + font-size: 1.1rem; + color: var(--orange-white); + svg { + height: 50%; + flex-grow: 1; + } + &.header-logo { + height: 100%; + display: flex; + align-items: center; + } + &.header-hover-underline { + position: relative; + &::after { + transition: width 150ms cubic-bezier(.17,.67,0,1); + content: ""; + width: 0; + height: 3px; + background-color: var(--orange-white); + position: absolute; + left: 50%; + transform: translateX(-50%); + bottom: -1rem; + } + &:hover::after { + width: 1rem; + } + } + } + margin-right: 3em; + } + } + } + .header-controls { + height: 100%; + ul { + height: 100%; + display: flex; + li { + margin-left: 3rem; + height: 100%; + display: flex; + align-items: center; + &:not(:last-of-type) { + cursor: pointer; + } + &:nth-of-type(2) { + svg { + &:first-of-type { + :root .dark-mode & { + display: none; + } + } + &:nth-of-type(2) { + display: none; + :root .dark-mode & { + display: block; + } + } + } + } + svg { + height: 35%; + width: auto; + } + } + } + } + } +} \ No newline at end of file diff --git a/docs/_static/css/main.css b/docs/_static/css/main.css new file mode 100644 index 0000000000..0903efca63 --- /dev/null +++ b/docs/_static/css/main.css @@ -0,0 +1,118 @@ +@import url('https://fonts.googleapis.com/css2?family=Afacad:ital,wght@0,400..700;1,400..700&display=swap'); + +:root { + --flair-orange: #f79910; + --flair-orange-light: #ffc063; + --flair-orange-transparent: rgba(247, 153, 16, .5); + --white-transparent: rgba(255, 255, 255, .5); + --flair-blue: #2f2e41; + --white-blue: white; + --blue-white: var(--flair-blue); + --orange-white: var(--flair-orange); + --orange-white-transparent: var(--flair-orange-transparent); + --orange-blue: var(--flair-orange); + --blue-orange: var(--flair-blue); + --header-height: 90px; + --footer-height: 60px; + font-family: "Afacad", sans-serif; + font-optical-sizing: auto; + font-weight: 400; + font-size: 20px; + font-style: normal; +} + +:root[data-theme=dark] { + --white-blue: var(--flair-blue); + --blue-white: white; + --orange-white: white; + --orange-blue: var(--flair-blue); + --blue-orange: var(--flair-orange); + --orange-white-transparent: var(--white-transparent); +} + +.fill-white-blue { + fill: var(--white-blue); + stroke: none; +} + + +.fill-orange { + fill: var(--flair-orange); + stroke: none; +} + +.stroke-orange-white { + stroke: var(--orange-white); + 
fill: none; +} + +html, body, div, span, applet, object, iframe, +h1, h2, h3, h4, h5, h6, p, blockquote, pre, +a, abbr, acronym, address, big, cite, code, +del, dfn, em, img, ins, kbd, q, s, samp, +small, strike, strong, sub, sup, tt, var, +b, u, i, center, +dl, dt, dd, ol, ul, li, +fieldset, form, label, legend, +table, caption, tbody, tfoot, thead, tr, th, td, +article, aside, canvas, details, embed, +figure, figcaption, footer, header, hgroup, +menu, nav, output, ruby, section, summary, +time, mark, audio, video { + margin: 0; + padding: 0; + border: 0; + font: inherit; + vertical-align: baseline; + line-height: 1.2em; +} + +button { + font-family: inherit; +} + +article, aside, details, figcaption, figure, +footer, header, hgroup, menu, nav, section { + display: block; +} + +body { + line-height: 1; + background-color: var(--white-blue); +} + +main { + min-height: calc(100vh - var(--header-height) - var(--footer-height)); +} + +ol, ul { + list-style: none; +} + +blockquote, q { + quotes: none; +} + +table { + border-collapse: collapse; + border-spacing: 0; +} + +a { + color: var(--flair-orange); + text-decoration: none; + transition: color 200ms cubic-bezier(0,.35,.08,.89); + &:hover { + color: var(--flair-orange-light); + } +} + +::selection { + background: var(--flair-orange); + color: var(--white-blue); +} + +.inv-sel::selection { + background: var(--white-blue); + color: var(--flair-orange); +} \ No newline at end of file diff --git a/docs/_static/flair_logo.svg b/docs/_static/flair_logo.svg new file mode 100755 index 0000000000..e37b4f5a45 --- /dev/null +++ b/docs/_static/flair_logo.svg @@ -0,0 +1,101 @@ + + + + diff --git a/docs/conf.py b/docs/conf.py index 64624043e0..1664a3bb0f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -83,6 +83,14 @@ def linkcode_resolve(*args): # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] +html_css_files = [ + 'css/main.css', + 'css/header.css', +] + +html_logo = "_static/flair_logo.svg" +html_show_sphinx = False + # Napoleon settings napoleon_include_init_with_doc = True napoleon_include_private_with_doc = True @@ -116,7 +124,7 @@ def linkcode_resolve(*args): smv_tag_whitelist = r"^v\d+\.\d+\.\d+$" # Whitelist pattern for branches (set to None to ignore all branches) -smv_branch_whitelist = r"^master$" +smv_branch_whitelist = r"^master|documentation$" # Whitelist pattern for remotes (set to None to use local branches only) smv_remote_whitelist = r"^origin$" diff --git a/docs/glossary/index.rst b/docs/glossary/index.rst deleted file mode 100644 index c732a1a121..0000000000 --- a/docs/glossary/index.rst +++ /dev/null @@ -1,7 +0,0 @@ -Glossary -======== - -.. glossary:: - - Sentence - a sentence is a text-unit consisting of tokens, labels and possibly metadata. Notice that a sentence is not limited in size, hence a Sentence itself could hold either a full document, a paragraph, a simple phrase or a linguistic \ No newline at end of file From 2cf6097bc3623cf91c3a5e9aa83a4ce56297ce81 Mon Sep 17 00:00:00 2001 From: konstantin Date: Sun, 4 Aug 2024 10:58:03 +0200 Subject: [PATCH 028/333] changes --- docs/index.rst | 85 ++------------------------------------------------ 1 file changed, 2 insertions(+), 83 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 3cff769118..b7adde6b62 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -4,92 +4,11 @@ flair .. 
_flair_docs_mainpage: -**Version**: |version| - -**Useful links**: -`Getting started `_ | -`Source Repository `_ | -`Issue Tracker `_ | - -Flair is a very simple framework for state-of-the-art Natural Language Processing (NLP) - -.. grid:: 2 - - .. grid-item-card:: - :img-top: ./_static/tutorial.svg - - Tutorial - ^^^^^^^^ - - New to Flair? Check out the Tutorials. It contains an introduction to Flair's main concepts. - - +++ - - .. button-ref:: tutorial/index - :expand: - :color: secondary - :click-parent: - - To the tutorials - - .. grid-item-card:: - :img-top: ./_static/api.svg - - API-docs - ^^^^^^^^ - - The API-docs provides in-depth information on the classes and functions designed for public use. - - +++ - - .. button-ref:: api/index - :expand: - :color: secondary - :click-parent: - - To the API docs - - .. grid-item-card:: - :img-top: ./_static/contributing.svg - - Contributor's Guide - ^^^^^^^^^^^^^^^^^^^ - - Want to add to the codebase? Can help add to the - documentation? The contributing guidelines will guide you through the - process of improving Flair. - - +++ - - .. button-ref:: contributing/index - :expand: - :color: secondary - :click-parent: - - To the contributor's guide - - .. grid-item-card:: - :img-top: ./_static/glossary.svg - - Glossary - ^^^^^^^^ - - Not sure what the exact meaning of certain terms is? Find their definition in the Glossary. - - +++ - - .. button-ref:: glossary/index - :expand: - :color: secondary - :click-parent: - - To the glossary .. toctree:: :maxdepth: 3 :hidden: Tutorials - API reference - Contributing - Glossary \ No newline at end of file + API + Contributing \ No newline at end of file From 4354525c1d7b5991eaad685b048567a925d477f5 Mon Sep 17 00:00:00 2001 From: konstantin Date: Sun, 4 Aug 2024 11:01:42 +0200 Subject: [PATCH 029/333] changes --- docs/_static/html/landing_page_header_styles.html | 5 +++++ docs/index.rst | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 docs/_static/html/landing_page_header_styles.html diff --git a/docs/_static/html/landing_page_header_styles.html b/docs/_static/html/landing_page_header_styles.html new file mode 100644 index 0000000000..d2fa50bbfd --- /dev/null +++ b/docs/_static/html/landing_page_header_styles.html @@ -0,0 +1,5 @@ + \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index b7adde6b62..269af35b37 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,7 +3,8 @@ flair .. _flair_docs_mainpage: - +.. raw:: html + :file: html/landing_page_header_styles.html .. toctree:: :maxdepth: 3 From 097e9f6ddd3743e57ac3935489abd31019a2a99a Mon Sep 17 00:00:00 2001 From: konstantin Date: Sun, 4 Aug 2024 11:03:41 +0200 Subject: [PATCH 030/333] changes --- docs/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index 269af35b37..be769eca7a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -4,7 +4,7 @@ flair .. _flair_docs_mainpage: .. raw:: html - :file: html/landing_page_header_styles.html + :file: _static/html/landing_page_header_styles.html .. 
toctree:: :maxdepth: 3 From 0322720d861cfbf09e30623f2bb3c8f17108a5cb Mon Sep 17 00:00:00 2001 From: konstantin Date: Sun, 4 Aug 2024 11:15:11 +0200 Subject: [PATCH 031/333] changes --- docs/_static/css/header.css | 2 +- docs/_static/flair_logo.svg | 4 ++-- docs/_static/html/landing_page_header_styles.html | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/_static/css/header.css b/docs/_static/css/header.css index 2c086e1a40..efb28a6436 100644 --- a/docs/_static/css/header.css +++ b/docs/_static/css/header.css @@ -1,7 +1,7 @@ .bd-header { width: 100%; height: var(--header-height); - background: var(--white-blue) !important; + background: var(--flair-orange) !important; .header-wrapper { margin: 0 calc(10% + 10px); height: 100%; diff --git a/docs/_static/flair_logo.svg b/docs/_static/flair_logo.svg index e37b4f5a45..aad189d1b8 100755 --- a/docs/_static/flair_logo.svg +++ b/docs/_static/flair_logo.svg @@ -89,12 +89,12 @@ d="m 224.21,3.67 c -0.28,0.04 -0.3,0.33 0.01,0.42 1.21,0.42 2.32,1.25 3.11,2.44 1.92,2.92 1.11,6.77 -1.88,8.18 -0.18,0.09 -0.38,-0.08 -0.32,-0.27 0.16,-0.45 0.25,-0.94 0.25,-1.45 0,-1.57 -0.85,-2.95 -2.11,-3.69 -0.18,-0.11 -0.39,0.08 -0.32,0.28 0.22,0.66 0.16,1.46 -0.23,2.18 -0.41,0.77 -0.75,1.55 -2.6,1.77 -1.61,0.22 -2.69,0.45 -4.08,1.46 -0.03,0.02 -0.05,0.04 -0.08,0.06 -0.26,0.19 -0.51,0.4 -0.75,0.61 -0.06,0.05 -0.12,0.11 -0.17,0.16 -0.21,0.2 -0.42,0.42 -0.62,0.64 -0.05,0.06 -0.11,0.12 -0.16,0.18 -0.21,0.26 -0.42,0.52 -0.61,0.8 -0.02,0.03 -0.04,0.05 -0.06,0.08 -0.21,0.31 -0.4,0.64 -0.58,0.98 -0.02,0.04 -0.04,0.09 -0.06,0.13 -0.14,0.28 -0.27,0.57 -0.39,0.87 -0.04,0.11 -0.08,0.22 -0.12,0.32 -0.09,0.24 -0.16,0.49 -0.23,0.74 -0.03,0.12 -0.07,0.25 -0.09,0.38 -0.06,0.26 -0.1,0.52 -0.14,0.78 -0.02,0.12 -0.04,0.23 -0.05,0.35 -0.04,0.38 -0.07,0.77 -0.07,1.16 0,3.95 2.21,7.37 5.43,9.02 0.05,0.03 0.12,0.06 0.19,0.1 0.17,0.08 0.34,0.17 0.52,0.24 1.82,0.79 3.83,1.23 5.94,1.23 3.49,0 6.7,-1.18 9.26,-3.17 3.55,-2.76 5.84,-7.07 5.84,-11.92 0,-8.23 -6.62,-14.92 -14.83,-15.06 z m 7.46,18.53 c -0.82,2.9 -3.19,4.93 -5.96,5.47 -0.97,0.27 -2.06,0.29 -3.16,-0.02 -0.08,-0.02 -0.15,-0.05 -0.23,-0.08 -0.07,-0.02 -0.13,-0.03 -0.2,-0.05 V 27.5 c -2.48,-0.9 -3.96,-3.24 -3.35,-5.4 0.65,-2.29 3.39,-3.51 6.13,-2.74 1.28,0.36 2.33,1.1 3.02,2.02 1.38,-0.59 2.48,-1.77 2.92,-3.32 0.16,-0.55 0.21,-1.1 0.19,-1.64 0.93,1.71 1.22,3.76 0.64,5.78 z" id="path21" inkscape:connector-curvature="0" - style="fill:#f79910" /> \ No newline at end of file From 7211f2c8945762fa7a030909a2950e36c84d5be9 Mon Sep 17 00:00:00 2001 From: konstantin Date: Sun, 4 Aug 2024 11:22:27 +0200 Subject: [PATCH 032/333] changes --- docs/_static/{flair_logo.svg => flair_logo_white.svg} | 0 docs/_static/html/landing_page_header_styles.html | 6 ++++++ docs/conf.py | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) rename docs/_static/{flair_logo.svg => flair_logo_white.svg} (100%) diff --git a/docs/_static/flair_logo.svg b/docs/_static/flair_logo_white.svg similarity index 100% rename from docs/_static/flair_logo.svg rename to docs/_static/flair_logo_white.svg diff --git a/docs/_static/html/landing_page_header_styles.html b/docs/_static/html/landing_page_header_styles.html index e539264515..6b8883796d 100644 --- a/docs/_static/html/landing_page_header_styles.html +++ b/docs/_static/html/landing_page_header_styles.html @@ -1,5 +1,11 @@ \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 1664a3bb0f..929358d3af 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -88,7 +88,7 @@ def 
linkcode_resolve(*args): 'css/header.css', ] -html_logo = "_static/flair_logo.svg" +html_logo = "_static/flair_logo_white.svg" html_show_sphinx = False # Napoleon settings From 33dbba3f743966f2ed16aca8ef38ccc30674367b Mon Sep 17 00:00:00 2001 From: konstantin Date: Sun, 4 Aug 2024 11:26:59 +0200 Subject: [PATCH 033/333] changes --- docs/_static/flair_logo_orange.svg | 101 ++++++++++++++++++ .../html/landing_page_header_styles.html | 4 +- 2 files changed, 104 insertions(+), 1 deletion(-) create mode 100755 docs/_static/flair_logo_orange.svg diff --git a/docs/_static/flair_logo_orange.svg b/docs/_static/flair_logo_orange.svg new file mode 100755 index 0000000000..aad189d1b8 --- /dev/null +++ b/docs/_static/flair_logo_orange.svg @@ -0,0 +1,101 @@ + + + + diff --git a/docs/_static/html/landing_page_header_styles.html b/docs/_static/html/landing_page_header_styles.html index 6b8883796d..b9a064d87d 100644 --- a/docs/_static/html/landing_page_header_styles.html +++ b/docs/_static/html/landing_page_header_styles.html @@ -5,7 +5,9 @@ img { visibility: hidden; } - background: red; + background-image: url("../flair_logo_orange.svg"); + background-size: contain; + background-repeat: no-repeat; } } \ No newline at end of file From fb4ee203b5b5879ee30f21dd37fd125fde0fa4c3 Mon Sep 17 00:00:00 2001 From: konstantin Date: Sun, 4 Aug 2024 11:29:04 +0200 Subject: [PATCH 034/333] changes --- docs/_static/flair_logo_orange.svg | 4 ++-- docs/_static/html/landing_page_header_styles.html | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/_static/flair_logo_orange.svg b/docs/_static/flair_logo_orange.svg index aad189d1b8..e37b4f5a45 100755 --- a/docs/_static/flair_logo_orange.svg +++ b/docs/_static/flair_logo_orange.svg @@ -89,12 +89,12 @@ d="m 224.21,3.67 c -0.28,0.04 -0.3,0.33 0.01,0.42 1.21,0.42 2.32,1.25 3.11,2.44 1.92,2.92 1.11,6.77 -1.88,8.18 -0.18,0.09 -0.38,-0.08 -0.32,-0.27 0.16,-0.45 0.25,-0.94 0.25,-1.45 0,-1.57 -0.85,-2.95 -2.11,-3.69 -0.18,-0.11 -0.39,0.08 -0.32,0.28 0.22,0.66 0.16,1.46 -0.23,2.18 -0.41,0.77 -0.75,1.55 -2.6,1.77 -1.61,0.22 -2.69,0.45 -4.08,1.46 -0.03,0.02 -0.05,0.04 -0.08,0.06 -0.26,0.19 -0.51,0.4 -0.75,0.61 -0.06,0.05 -0.12,0.11 -0.17,0.16 -0.21,0.2 -0.42,0.42 -0.62,0.64 -0.05,0.06 -0.11,0.12 -0.16,0.18 -0.21,0.26 -0.42,0.52 -0.61,0.8 -0.02,0.03 -0.04,0.05 -0.06,0.08 -0.21,0.31 -0.4,0.64 -0.58,0.98 -0.02,0.04 -0.04,0.09 -0.06,0.13 -0.14,0.28 -0.27,0.57 -0.39,0.87 -0.04,0.11 -0.08,0.22 -0.12,0.32 -0.09,0.24 -0.16,0.49 -0.23,0.74 -0.03,0.12 -0.07,0.25 -0.09,0.38 -0.06,0.26 -0.1,0.52 -0.14,0.78 -0.02,0.12 -0.04,0.23 -0.05,0.35 -0.04,0.38 -0.07,0.77 -0.07,1.16 0,3.95 2.21,7.37 5.43,9.02 0.05,0.03 0.12,0.06 0.19,0.1 0.17,0.08 0.34,0.17 0.52,0.24 1.82,0.79 3.83,1.23 5.94,1.23 3.49,0 6.7,-1.18 9.26,-3.17 3.55,-2.76 5.84,-7.07 5.84,-11.92 0,-8.23 -6.62,-14.92 -14.83,-15.06 z m 7.46,18.53 c -0.82,2.9 -3.19,4.93 -5.96,5.47 -0.97,0.27 -2.06,0.29 -3.16,-0.02 -0.08,-0.02 -0.15,-0.05 -0.23,-0.08 -0.07,-0.02 -0.13,-0.03 -0.2,-0.05 V 27.5 c -2.48,-0.9 -3.96,-3.24 -3.35,-5.4 0.65,-2.29 3.39,-3.51 6.13,-2.74 1.28,0.36 2.33,1.1 3.02,2.02 1.38,-0.59 2.48,-1.77 2.92,-3.32 0.16,-0.55 0.21,-1.1 0.19,-1.64 0.93,1.71 1.22,3.76 0.64,5.78 z" id="path21" inkscape:connector-curvature="0" - style="fill:#ffffff" /> \ No newline at end of file diff --git a/docs/_templates/landing-page-banner.html b/docs/_templates/landing-page-banner.html new file mode 100644 index 0000000000..8130a83d52 --- /dev/null +++ b/docs/_templates/landing-page-banner.html @@ -0,0 +1,353 @@ +
+ +
\ No newline at end of file diff --git a/docs/_templates/landing-page-illustrations.html b/docs/_templates/landing-page-illustrations.html new file mode 100644 index 0000000000..3b4069cb75 --- /dev/null +++ b/docs/_templates/landing-page-illustrations.html @@ -0,0 +1,30 @@ +
+
+ + + Easy to Use +

+ State-of-the-art NLP with just a few lines of code! Find entities, detect sentiment, and more. + Check out our demo! +

+
+
+ + + Huge Community +

+ With a community of ~200 contributors, Flair is used in hundreds of companies, + over 2,000 open source projects, and + 2,000+ papers! +

+
+
+ + + Open Source and Free +

+ Flair is completely free and open source, making it accessible for everyone to use + and report issues. +

+
+
\ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index be769eca7a..6a6d6b795f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,6 +6,12 @@ flair .. raw:: html :file: _static/html/landing_page_header_styles.html +.. raw:: html + :file: _templates/landing-page-banner.html + +.. raw:: html + :file: _templates/landing-page-illustrations.html + .. toctree:: :maxdepth: 3 :hidden: From c92f1b5e48cdad5759c8122c8c0a65a634a6e71a Mon Sep 17 00:00:00 2001 From: konstantin Date: Sun, 4 Aug 2024 11:36:33 +0200 Subject: [PATCH 036/333] changes --- docs/_static/css/landing-page-banner.css | 71 ----------- .../css/landing-page-illustrations.css | 38 ------ docs/_static/html/landing_page_styles.html | 111 +++++++++++++++++- 3 files changed, 109 insertions(+), 111 deletions(-) delete mode 100644 docs/_static/css/landing-page-banner.css delete mode 100644 docs/_static/css/landing-page-illustrations.css diff --git a/docs/_static/css/landing-page-banner.css b/docs/_static/css/landing-page-banner.css deleted file mode 100644 index ae3a78525c..0000000000 --- a/docs/_static/css/landing-page-banner.css +++ /dev/null @@ -1,71 +0,0 @@ -.landing-page { - .banner { - display: flex; - height: 510px; - background: var(--flair-orange); - .left-side { - align-items: flex-start; - display: flex; - flex-direction: column; - height: 100%; - justify-content: center; - width: 50%; - padding-left: 10%; - padding-right: 10%; - box-sizing: border-box; - svg { - height: 20%; - width: auto; - } - & > span { - font-size: 1.5rem; - margin: 1.3em 0; - line-height: 1.2em; - color: var(--white-blue); - } - #get-started { - background: var(--white-blue); - border: none; - padding: 0.2em 0.9em; - font-size: 1.5rem; - color: var(--blue-white); - font-weight: 600; - border-radius: .5em; - transition: transform 200ms cubic-bezier(0,.35,.08,.89); - &:hover { - transform: scale(1.1); - } - } - } - .right-side { - display: flex; - justify-content: flex-start; - align-items: center; - width: 50%; - .card { - margin-left: 3rem; - border-radius: 1rem; - width: 20%; - display: flex; - justify-content: center; - align-items: center; - flex-direction: column; - transition: transform 200ms cubic-bezier(0,.35,.08,.89); - background: var(--white-blue); - color: var(--blue-white); - min-height: 42%; - &:hover { - transform: scale(1.1); - } - svg { - height: 40%; - width: auto; - } - & > span { - color: var(--blue-white); - margin-top: 1em; - } - } - } - } -} \ No newline at end of file diff --git a/docs/_static/css/landing-page-illustrations.css b/docs/_static/css/landing-page-illustrations.css deleted file mode 100644 index 28c58dcb7a..0000000000 --- a/docs/_static/css/landing-page-illustrations.css +++ /dev/null @@ -1,38 +0,0 @@ -.landing-page-illustrations { - display: flex; - padding: 5rem; - background: var(--white-blue); - .item { - flex: 1 1 0; - display: flex; - flex-direction: column; - justify-content: flex-start; - align-items: flex-start; - padding: 0 5%; - &:not(:last-of-type) { - border-right: #ccc 2px solid; - } - svg { - height: 12vw; - width: auto; - & + span { - margin-top: 2.5em; - margin-bottom: 0.75em; - font-size: 1.5rem; - font-weight: 600; - } - &:nth-of-type(2) { - display: none; - :root .dark-mode & { - display: block; - } - } - :root .dark-mode & { - display: none; - } - } - span, p { - color: var(--blue-white); - } - } -} \ No newline at end of file diff --git a/docs/_static/html/landing_page_styles.html b/docs/_static/html/landing_page_styles.html index c9ca7a8e5c..2cfc618bdc 100644 --- 
a/docs/_static/html/landing_page_styles.html +++ b/docs/_static/html/landing_page_styles.html @@ -1,6 +1,4 @@ \ No newline at end of file From 72fa60296d292e0cd22f0b3c7651e018be4fb6fe Mon Sep 17 00:00:00 2001 From: konstantin Date: Sun, 4 Aug 2024 11:38:40 +0200 Subject: [PATCH 037/333] changes --- docs/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index 6a6d6b795f..47876aea55 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -4,7 +4,7 @@ flair .. _flair_docs_mainpage: .. raw:: html - :file: _static/html/landing_page_header_styles.html + :file: _static/html/landing_page_styles.html .. raw:: html :file: _templates/landing-page-banner.html From dbdf9d247ac6c4dea878c1aca2c238813283f55d Mon Sep 17 00:00:00 2001 From: konstantin Date: Sun, 4 Aug 2024 11:44:42 +0200 Subject: [PATCH 038/333] changes --- docs/_static/html/landing_page_styles.html | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/_static/html/landing_page_styles.html b/docs/_static/html/landing_page_styles.html index 2cfc618bdc..fe99d482d7 100644 --- a/docs/_static/html/landing_page_styles.html +++ b/docs/_static/html/landing_page_styles.html @@ -1,4 +1,10 @@ \ No newline at end of file From da3ee908aaa0cecc8ab64420f49558471014afc6 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 12:34:33 +0200 Subject: [PATCH 060/333] changes --- docs/_static/html/landing_page_styles.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_static/html/landing_page_styles.html b/docs/_static/html/landing_page_styles.html index 2a42b60e9d..35661ab101 100644 --- a/docs/_static/html/landing_page_styles.html +++ b/docs/_static/html/landing_page_styles.html @@ -159,7 +159,7 @@ a.list-group-item { background: var(--white-blue); &:hover { - background: var(--white-transparent); + background: var(--orange-white-transparent); color: var(--orange-white); } span { From 4960ad748c8b1bac3aa6f647896038083617a950 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 12:35:45 +0200 Subject: [PATCH 061/333] changes --- docs/_static/html/landing_page_styles.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_static/html/landing_page_styles.html b/docs/_static/html/landing_page_styles.html index 35661ab101..9871541a87 100644 --- a/docs/_static/html/landing_page_styles.html +++ b/docs/_static/html/landing_page_styles.html @@ -160,7 +160,7 @@ background: var(--white-blue); &:hover { background: var(--orange-white-transparent); - color: var(--orange-white); + color: var(--white-blue); } span { color: var(--orange-white); From ac06833edba1f4582882d3a574aee6e8c13115dd Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 12:38:07 +0200 Subject: [PATCH 062/333] changes --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 112a1aec01..dc117caf74 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -27,7 +27,7 @@ } # dummy value that sphinx-github-style won't crash when run in temp folder. 
html_theme_options = { - "navbar_end": ["theme-switcher", "version-switcher", "navbar-icon-links"] + "navbar_end": ["version-switcher", "navbar-icon-links"] } From ebf02556ab91cfb7b21dc8ee9921bbe065bc3557 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 12:40:58 +0200 Subject: [PATCH 063/333] changes --- docs/_templates/darkmode-toggle.html | 138 +++++++++++++++++++++++++++ docs/conf.py | 2 +- 2 files changed, 139 insertions(+), 1 deletion(-) create mode 100644 docs/_templates/darkmode-toggle.html diff --git a/docs/_templates/darkmode-toggle.html b/docs/_templates/darkmode-toggle.html new file mode 100644 index 0000000000..9610954b69 --- /dev/null +++ b/docs/_templates/darkmode-toggle.html @@ -0,0 +1,138 @@ +
+ <!-- 138 added lines of darkmode-toggle template markup (SVG icon plus toggle control); the tag content was not preserved in this extract -->
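A hedged sketch of what a toggle template like this typically wires up; this is not
the actual file content, which is not preserved here. Only the dark-mode class on the
root element is confirmed by the CSS in later patches; the selector and the storage
key below are assumptions.

    <script>
      const toggle = document.querySelector(".darkmode-toggle");   // assumed selector
      toggle.addEventListener("click", () => {
        // main.css (patch 064) keys its color variables off :root.dark-mode
        const dark = document.documentElement.classList.toggle("dark-mode");
        localStorage.setItem("theme", dark ? "dark" : "light");    // assumed storage key
      });
      // restore the saved preference on page load
      if (localStorage.getItem("theme") === "dark") {
        document.documentElement.classList.add("dark-mode");
      }
    </script>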
+ diff --git a/docs/conf.py b/docs/conf.py index dc117caf74..f24ec214ac 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -27,7 +27,7 @@ } # dummy value that sphinx-github-style won't crash when run in temp folder. html_theme_options = { - "navbar_end": ["version-switcher", "navbar-icon-links"] + "navbar_end": ["darkmode-toggle", "version-switcher", "navbar-icon-links"] } From 665d495ebd06d18036a910d99365388547744549 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 12:47:38 +0200 Subject: [PATCH 064/333] changes --- docs/_static/css/main.css | 18 +++++++++++++++++- docs/_templates/darkmode-toggle.html | 5 ----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/docs/_static/css/main.css b/docs/_static/css/main.css index 5c28405bbf..67b489b6ae 100644 --- a/docs/_static/css/main.css +++ b/docs/_static/css/main.css @@ -21,7 +21,7 @@ font-style: normal; } -:root[data-theme=dark] { +:root.dark-mode { --white-blue: var(--flair-blue); --blue-white: white; --orange-white: white; @@ -119,4 +119,20 @@ a { .inv-sel::selection { background: var(--white-blue); color: var(--flair-orange); +} + +.fill-white-blue { + fill: var(--white-blue); + stroke: none; +} + + +.fill-orange { + fill: var(--flair-orange); + stroke: none; +} + +.stroke-orange-white { + stroke: var(--orange-white); + fill: none; } \ No newline at end of file diff --git a/docs/_templates/darkmode-toggle.html b/docs/_templates/darkmode-toggle.html index 9610954b69..a5851fb072 100644 --- a/docs/_templates/darkmode-toggle.html +++ b/docs/_templates/darkmode-toggle.html @@ -117,11 +117,6 @@ id="circle1" style="stroke:#ffffff;stroke-width:0.50868;stroke-dasharray:none;stroke-opacity:1" /> - + From d9314fad3fb42cd33da8de9929ba4682bc8dd936 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 13:28:44 +0200 Subject: [PATCH 070/333] changes --- docs/conf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index f24ec214ac..8352d92de1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -27,7 +27,8 @@ } # dummy value that sphinx-github-style won't crash when run in temp folder. 
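# "show_prev_next" is a stock pydata-sphinx-theme option controlling the
# previous/next article buttons rendered at the bottom of each page; the hunk
# below switches them off.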
html_theme_options = { - "navbar_end": ["darkmode-toggle", "version-switcher", "navbar-icon-links"] + "navbar_end": ["darkmode-toggle", "version-switcher", "navbar-icon-links"], + "show_prev_next": False } From 711baafc0f1fc0ab3bd0dfab5eba12ac1e5560ed Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 13:33:43 +0200 Subject: [PATCH 071/333] changes --- docs/_static/css/header.css | 112 ++++++++++++++++++++++++++ docs/_templates/version-switcher.html | 2 +- 2 files changed, 113 insertions(+), 1 deletion(-) diff --git a/docs/_static/css/header.css b/docs/_static/css/header.css index 55fc362886..16157e3360 100644 --- a/docs/_static/css/header.css +++ b/docs/_static/css/header.css @@ -3,6 +3,19 @@ height: var(--header-height); background: var(--flair-orange) !important; box-shadow: none; + .bd-header__inner { + margin: 0 10%; + padding: 0; + height: 100%; + display: flex; + justify-content: space-between; + .navbar-header-items__start { + padding: 0; + height: 55%;.bd-header { + width: 100%; + height: var(--header-height); + background: var(--flair-orange) !important; + box-shadow: none; .bd-header__inner { margin: 0 10%; padding: 0; @@ -69,6 +82,105 @@ } } } + .navbar-header-items__end { + + .navbar-item:first-of-type { + button.search-button { + background-size: contain; + background-repeat: no-repeat; + background-position: center; + border-radius: 0; + height: 1.75rem; + width: 1.75rem; + i { + visibility: hidden; + &:focus, &, &:active { + border: none; + outline: none; + } + } + } + } + .navbar-item:nth-of-type(2) { + svg { + cursor: pointer; + &:first-of-type { + :root .dark-mode & { + display: none; + } + } + &:nth-of-type(2) { + display: none; + :root .dark-mode & { + display: block; + } + } + } + } + svg { + height: 35%; + width: auto; + } + } + } +} + .navbar-item { + height: 100%; + .navbar-brand.logo { + height: 100%; + padding: 0; + display: block; + .logo__image { + height: 100%; + width: auto; + } + } + } + } + nav { + height: 100%; + ul.bd-navbar-elements { + display: flex; + height: 100%; + li.nav-item { + height: auto; + display: flex; + align-items: center; + margin-right: 0; + margin-left: 3rem; + + &.active a { + font-weight: inherit; + &::after { + width: 100%; + } + } + a { + position: relative; + font-size: 1.1rem; + padding: 0; + color: white; + &:hover { + color: white; + &::after { + width: 100%; + } + } + &::after { + transition: width 200ms cubic-bezier(0,.35,.08,.89); + content: ""; + width: 0; + height: 3px; + background-color: white; + position: absolute; + bottom: -.5rem; + left: 50%; + transform: translateX(-50%); + } + } + } + } + } .navbar-header-items__end { .navbar-item:first-of-type { diff --git a/docs/_templates/version-switcher.html b/docs/_templates/version-switcher.html index 6b598954e5..a90f974b89 100644 --- a/docs/_templates/version-switcher.html +++ b/docs/_templates/version-switcher.html @@ -43,7 +43,7 @@ .navbar-header-items__end { .navbar-item:first-of-type { button.search-button { - background-image: url("_static/magnifying_glass.svg"); + background-image: url("{{ pathto(_static/magnifying_glass.svg) }}"); } } } From cc658b1874264ea22830b314b175b6ffd9d691dd Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 13:37:03 +0200 Subject: [PATCH 072/333] changes --- docs/_templates/version-switcher.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_templates/version-switcher.html b/docs/_templates/version-switcher.html index a90f974b89..85f7f7efc1 100644 --- a/docs/_templates/version-switcher.html 
+++ b/docs/_templates/version-switcher.html @@ -43,7 +43,7 @@ .navbar-header-items__end { .navbar-item:first-of-type { button.search-button { - background-image: url("{{ pathto(_static/magnifying_glass.svg) }}"); + background-image: url("{{ pathto(magnifying_glass.svg) }}"); } } } From cc1f12c52a72ded58603d8c663afd28be3e8972e Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 13:38:17 +0200 Subject: [PATCH 073/333] changes --- docs/_templates/version-switcher.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_templates/version-switcher.html b/docs/_templates/version-switcher.html index 85f7f7efc1..b033e82285 100644 --- a/docs/_templates/version-switcher.html +++ b/docs/_templates/version-switcher.html @@ -43,7 +43,7 @@ .navbar-header-items__end { .navbar-item:first-of-type { button.search-button { - background-image: url("{{ pathto(magnifying_glass.svg) }}"); + background-image: url("{{ pathto(../_static/magnifying_glass.svg) }}"); } } } From c7ec7f113ac683d2261f3459926e91ed9ad498af Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 13:52:03 +0200 Subject: [PATCH 074/333] changes --- docs/_static/css/header.css | 113 --------------------- docs/_static/html/landing_page_styles.html | 15 ++- docs/_static/magnifying_glass.svg | 10 +- docs/_static/magnifying_glass_dark.svg | 56 ++++++++++ docs/_templates/version-switcher.html | 2 +- 5 files changed, 76 insertions(+), 120 deletions(-) create mode 100644 docs/_static/magnifying_glass_dark.svg diff --git a/docs/_static/css/header.css b/docs/_static/css/header.css index 16157e3360..b927575b82 100644 --- a/docs/_static/css/header.css +++ b/docs/_static/css/header.css @@ -3,19 +3,6 @@ height: var(--header-height); background: var(--flair-orange) !important; box-shadow: none; - .bd-header__inner { - margin: 0 10%; - padding: 0; - height: 100%; - display: flex; - justify-content: space-between; - .navbar-header-items__start { - padding: 0; - height: 55%;.bd-header { - width: 100%; - height: var(--header-height); - background: var(--flair-orange) !important; - box-shadow: none; .bd-header__inner { margin: 0 10%; padding: 0; @@ -123,104 +110,4 @@ } } } -} - .navbar-item { - height: 100%; - .navbar-brand.logo { - height: 100%; - padding: 0; - display: block; - .logo__image { - height: 100%; - width: auto; - } - } - } - } - nav { - height: 100%; - ul.bd-navbar-elements { - display: flex; - height: 100%; - li.nav-item { - height: auto; - display: flex; - align-items: center; - margin-right: 0; - margin-left: 3rem; - - &.active a { - font-weight: inherit; - &::after { - width: 100%; - } - } - a { - position: relative; - font-size: 1.1rem; - padding: 0; - color: white; - &:hover { - color: white; - &::after { - width: 100%; - } - } - &::after { - transition: width 200ms cubic-bezier(0,.35,.08,.89); - content: ""; - width: 0; - height: 3px; - background-color: white; - position: absolute; - bottom: -.5rem; - left: 50%; - transform: translateX(-50%); - } - } - } - } - } - .navbar-header-items__end { - - .navbar-item:first-of-type { - button.search-button { - background-size: contain; - background-repeat: no-repeat; - background-position: center; - border-radius: 0; - background-color: red; - height: 55%; - width: auto; - i { - visibility: hidden; - &:focus, &, &:active { - border: none; - outline: none; - } - } - } - } - .navbar-item:nth-of-type(2) { - svg { - cursor: pointer; - &:first-of-type { - :root .dark-mode & { - display: none; - } - } - &:nth-of-type(2) { - display: none; - :root .dark-mode & { - 
display: block; - } - } - } - } - svg { - height: 35%; - width: auto; - } - } - } } \ No newline at end of file diff --git a/docs/_static/html/landing_page_styles.html b/docs/_static/html/landing_page_styles.html index 014a148462..2815b9573a 100644 --- a/docs/_static/html/landing_page_styles.html +++ b/docs/_static/html/landing_page_styles.html @@ -172,4 +172,17 @@ } } } - \ No newline at end of file + .bd-header { + .bd-header__inner { + .navbar-header-items__end { + .navbar-item:first-of-type { + button.search-button { + :root .dark-mode & { + background-image: url("{{ pathto('_static/magnifying_glass_dark.svg', 1) }}"); + } + } + } + } + } + } + diff --git a/docs/_static/magnifying_glass.svg b/docs/_static/magnifying_glass.svg index 5350d0e540..3048ce30c6 100644 --- a/docs/_static/magnifying_glass.svg +++ b/docs/_static/magnifying_glass.svg @@ -14,7 +14,7 @@ xmlns:svg="http://www.w3.org/2000/svg"> + style="fill:none;stroke:#f79910;stroke-opacity:1"> + style="fill:none;stroke:#f79910;stroke-opacity:1"> + style="fill:none;stroke:#f79910;stroke-opacity:1" /> + style="fill:none;stroke:#f79910;stroke-opacity:1" /> diff --git a/docs/_static/magnifying_glass_dark.svg b/docs/_static/magnifying_glass_dark.svg new file mode 100644 index 0000000000..5350d0e540 --- /dev/null +++ b/docs/_static/magnifying_glass_dark.svg @@ -0,0 +1,56 @@ + + + + + + + + + diff --git a/docs/_templates/version-switcher.html b/docs/_templates/version-switcher.html index b033e82285..a908b0044e 100644 --- a/docs/_templates/version-switcher.html +++ b/docs/_templates/version-switcher.html @@ -43,7 +43,7 @@ .navbar-header-items__end { .navbar-item:first-of-type { button.search-button { - background-image: url("{{ pathto(../_static/magnifying_glass.svg) }}"); + background-image: url("{{ pathto('_static/magnifying_glass.svg', 1) }}"); } } } From 413317d318dfaf41530da337a0900a3e494d5c56 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 13:54:32 +0200 Subject: [PATCH 075/333] changes --- docs/_static/html/landing_page_styles.html | 13 ------------- docs/index.rst | 18 ++++++++++++++++++ 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/docs/_static/html/landing_page_styles.html b/docs/_static/html/landing_page_styles.html index 2815b9573a..0b43d7bb72 100644 --- a/docs/_static/html/landing_page_styles.html +++ b/docs/_static/html/landing_page_styles.html @@ -172,17 +172,4 @@ } } } - .bd-header { - .bd-header__inner { - .navbar-header-items__end { - .navbar-item:first-of-type { - button.search-button { - :root .dark-mode & { - background-image: url("{{ pathto('_static/magnifying_glass_dark.svg', 1) }}"); - } - } - } - } - } - } diff --git a/docs/index.rst b/docs/index.rst index c8ddac1c9f..08241759ef 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -11,6 +11,24 @@ .. raw:: html :file: _templates/landing-page-illustrations.html +.. raw:: html + + + .. 
toctree:: :maxdepth: 3 :hidden: From 7335d3dfd1ffb4d48ef9701a603838e659a1afdc Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 13:56:48 +0200 Subject: [PATCH 076/333] changes --- .../landing_page_styles.html | 13 ++++++++++++ docs/index.rst | 20 +------------------ 2 files changed, 14 insertions(+), 19 deletions(-) rename docs/{_static/html => _templates}/landing_page_styles.html (92%) diff --git a/docs/_static/html/landing_page_styles.html b/docs/_templates/landing_page_styles.html similarity index 92% rename from docs/_static/html/landing_page_styles.html rename to docs/_templates/landing_page_styles.html index 0b43d7bb72..3c9efe3ce1 100644 --- a/docs/_static/html/landing_page_styles.html +++ b/docs/_templates/landing_page_styles.html @@ -172,4 +172,17 @@ } } } + .bd-header { + .bd-header__inner { + .navbar-header-items__end { + .navbar-item:first-of-type { + button.search-button { + :root .dark-mode & { + background-image: url("{{ pathto('_static/magnifying_glass_dark.svg', 1) }}"); + } + } + } + } + } + } diff --git a/docs/index.rst b/docs/index.rst index 08241759ef..c39010a2c7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,7 +3,7 @@ .. title:: Home .. raw:: html - :file: _static/html/landing_page_styles.html + :file: _templates/landing_page_styles.html .. raw:: html :file: _templates/landing-page-banner.html @@ -11,24 +11,6 @@ .. raw:: html :file: _templates/landing-page-illustrations.html -.. raw:: html - - - .. toctree:: :maxdepth: 3 :hidden: From 82c73beab26207a89b14a976f0992d61e8a19cde Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 14:01:48 +0200 Subject: [PATCH 077/333] changes --- docs/_templates/landing_page_styles.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_templates/landing_page_styles.html b/docs/_templates/landing_page_styles.html index 3c9efe3ce1..d9bf936818 100644 --- a/docs/_templates/landing_page_styles.html +++ b/docs/_templates/landing_page_styles.html @@ -178,7 +178,7 @@ .navbar-item:first-of-type { button.search-button { :root .dark-mode & { - background-image: url("{{ pathto('_static/magnifying_glass_dark.svg', 1) }}"); + background-image: url("_static/magnifying_glass_dark.svg"); } } } From 081db63cb8a5166f4d19a1c87bfa086cfee7b92d Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 14:04:50 +0200 Subject: [PATCH 078/333] changes --- docs/_templates/version-switcher.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_templates/version-switcher.html b/docs/_templates/version-switcher.html index a908b0044e..9fb4bc511e 100644 --- a/docs/_templates/version-switcher.html +++ b/docs/_templates/version-switcher.html @@ -43,7 +43,7 @@ .navbar-header-items__end { .navbar-item:first-of-type { button.search-button { - background-image: url("{{ pathto('_static/magnifying_glass.svg', 1) }}"); + background-image: url("{{ pathto('_static/magnifying_glass_dark.svg', 1) }}"); } } } From 0de42e8804d18ef30e32f368f39890e912a2a70a Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 14:07:23 +0200 Subject: [PATCH 079/333] changes --- docs/_templates/landing_page_styles.html | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/_templates/landing_page_styles.html b/docs/_templates/landing_page_styles.html index d9bf936818..ca4c834d8b 100644 --- a/docs/_templates/landing_page_styles.html +++ b/docs/_templates/landing_page_styles.html @@ -177,6 +177,7 @@ .navbar-header-items__end { .navbar-item:first-of-type { button.search-button { + background-image: url("{{ 
pathto('_static/magnifying_glass.svg', 1) }}"); :root .dark-mode & { background-image: url("_static/magnifying_glass_dark.svg"); } From 056bdfffa04b01b7819e67b696c55ca193b737f4 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 14:07:39 +0200 Subject: [PATCH 080/333] changes --- docs/_templates/landing_page_styles.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_templates/landing_page_styles.html b/docs/_templates/landing_page_styles.html index ca4c834d8b..f88a033b3f 100644 --- a/docs/_templates/landing_page_styles.html +++ b/docs/_templates/landing_page_styles.html @@ -177,7 +177,7 @@ .navbar-header-items__end { .navbar-item:first-of-type { button.search-button { - background-image: url("{{ pathto('_static/magnifying_glass.svg', 1) }}"); + background-image: url("_static/magnifying_glass.svg"); :root .dark-mode & { background-image: url("_static/magnifying_glass_dark.svg"); } From 1d5f3c73bb738965b8e5274e5963e7d28ee610a1 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 14:13:07 +0200 Subject: [PATCH 081/333] changes --- docs/_static/css/header.css | 15 ++++++++++++ docs/_static/css/main.css | 16 ------------- docs/_templates/landing_page_styles.html | 30 ++++++++++++++---------- 3 files changed, 32 insertions(+), 29 deletions(-) diff --git a/docs/_static/css/header.css b/docs/_static/css/header.css index b927575b82..6ef02d0faf 100644 --- a/docs/_static/css/header.css +++ b/docs/_static/css/header.css @@ -110,4 +110,19 @@ } } } +} + +.fill-white-blue { + fill: var(--white-blue); + stroke: none; +} + +.fill-orange { + fill: var(--flair-orange); + stroke: none; +} + +.stroke-moon { + stroke: white; + fill: none; } \ No newline at end of file diff --git a/docs/_static/css/main.css b/docs/_static/css/main.css index 1b5007b26e..1c24bfc662 100644 --- a/docs/_static/css/main.css +++ b/docs/_static/css/main.css @@ -119,20 +119,4 @@ a { .inv-sel::selection { background: var(--white-blue); color: var(--flair-orange); -} - -.fill-white-blue { - fill: var(--white-blue); - stroke: none; -} - - -.fill-orange { - fill: var(--flair-orange); - stroke: none; -} - -.stroke-orange-white { - stroke: var(--orange-white); - fill: none; } \ No newline at end of file diff --git a/docs/_templates/landing_page_styles.html b/docs/_templates/landing_page_styles.html index f88a033b3f..b9fcebe573 100644 --- a/docs/_templates/landing_page_styles.html +++ b/docs/_templates/landing_page_styles.html @@ -173,17 +173,21 @@ } } .bd-header { - .bd-header__inner { - .navbar-header-items__end { - .navbar-item:first-of-type { - button.search-button { - background-image: url("_static/magnifying_glass.svg"); - :root .dark-mode & { - background-image: url("_static/magnifying_glass_dark.svg"); - } - } - } - } - } - } + .bd-header__inner { + .navbar-header-items__end { + .navbar-item:first-of-type { + button.search-button { + background-image: url("_static/magnifying_glass.svg"); + :root .dark-mode & { + background-image: url("_static/magnifying_glass_dark.svg"); + } + } + } + } + } + } + .stroke-moon { + stroke: var(--orange-white); + fill: none; + } From 9c25250143c9a258c4da6568bbac1bd24a80cfc3 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 14:13:30 +0200 Subject: [PATCH 082/333] changes --- docs/_templates/darkmode-toggle.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_templates/darkmode-toggle.html b/docs/_templates/darkmode-toggle.html index a5851fb072..da11101150 100644 --- a/docs/_templates/darkmode-toggle.html +++ 
b/docs/_templates/darkmode-toggle.html @@ -1,7 +1,7 @@
- + Date: Wed, 7 Aug 2024 14:15:47 +0200 Subject: [PATCH 083/333] changes --- docs/_static/css/header.css | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/_static/css/header.css b/docs/_static/css/header.css index 6ef02d0faf..3011bab05c 100644 --- a/docs/_static/css/header.css +++ b/docs/_static/css/header.css @@ -70,7 +70,9 @@ } } .navbar-header-items__end { - + .navbar-item:not(:last-of-type) { + margin-right: 2rem; + } .navbar-item:first-of-type { button.search-button { background-size: contain; From 7a74d52f5abefe8cd036d41502efc6559d72a6b2 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 15:00:37 +0200 Subject: [PATCH 084/333] changes --- docs/_static/css/footer.css | 12 ++++++++++++ docs/conf.py | 4 +++- 2 files changed, 15 insertions(+), 1 deletion(-) create mode 100644 docs/_static/css/footer.css diff --git a/docs/_static/css/footer.css b/docs/_static/css/footer.css new file mode 100644 index 0000000000..a28e39b799 --- /dev/null +++ b/docs/_static/css/footer.css @@ -0,0 +1,12 @@ +.bd-footer { + border: none; + background: var(--blue-orange); + .bd-footer__inner { + padding: 1rem 10%; + margin: 0; + box-sizing: border-box; + .footer-item * { + color: white; + } + } +} \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 8352d92de1..c9b02310b6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,7 +28,8 @@ html_theme_options = { "navbar_end": ["darkmode-toggle", "version-switcher", "navbar-icon-links"], - "show_prev_next": False + "show_prev_next": False, + "footer_end": [] } @@ -80,6 +81,7 @@ def linkcode_resolve(*args): html_css_files = [ 'css/main.css', 'css/header.css', + 'css/footer.css', 'css/version-switcher.css', ] From e067a437351996fd463fc242ea50f963c6fd79fa Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 15:17:31 +0200 Subject: [PATCH 085/333] changes --- docs/_templates/footer-links.html | 5 +++++ docs/_templates/landing-page-illustrations.html | 6 +++--- docs/_templates/version-switcher.html | 1 + docs/conf.py | 2 +- 4 files changed, 10 insertions(+), 4 deletions(-) create mode 100644 docs/_templates/footer-links.html diff --git a/docs/_templates/footer-links.html b/docs/_templates/footer-links.html new file mode 100644 index 0000000000..074731dfb5 --- /dev/null +++ b/docs/_templates/footer-links.html @@ -0,0 +1,5 @@ + \ No newline at end of file diff --git a/docs/_templates/landing-page-illustrations.html b/docs/_templates/landing-page-illustrations.html index 3b4069cb75..7a2185a4e4 100644 --- a/docs/_templates/landing-page-illustrations.html +++ b/docs/_templates/landing-page-illustrations.html @@ -5,7 +5,7 @@ Easy to Use
 <!-- surrounding markup stripped in this extract; each -/+ pair below changed only
      content inside the stripped tags, presumably a link target -->
 State-of-the-art NLP with just a few lines of code! Find entities, detect sentiment, and more.
- Check out our demo! + Check out our demo!
 With a community of ~200 contributors, Flair is used in hundreds of companies,
- over 2,000 open source projects, and + over 2,000 open source projects, and 2,000+ papers!
 Flair is completely free and open source, making it accessible for everyone to use
- and report issues. + and report issues.
\ No newline at end of file diff --git a/docs/_templates/version-switcher.html b/docs/_templates/version-switcher.html index 9fb4bc511e..642039f919 100644 --- a/docs/_templates/version-switcher.html +++ b/docs/_templates/version-switcher.html @@ -36,6 +36,7 @@ `); + document.querySelector('footer.bd-footer') Open Source and Free
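The version-switcher hunk above is cut off right after document.querySelector('footer.bd-footer').
A plausible continuation, given that this same patch adds the footer-links templates; everything
past the selector is an assumption:

    // assumed continuation of the injected script
    const footer = document.querySelector('footer.bd-footer');
    if (footer) {
      footer.classList.add('flair-footer');   // hypothetical hook for the new footer links
    }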

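The footer-links.html template created in patch 085 is only five lines, and its markup was
stripped above. Given the conf.py entries that later reference footer-links/legal-notice.html,
footer-links/x.html and footer-links/linkedin.html, each such template is plausibly a single
anchor. A sketch with assumed href and asset name; note the quoted pathto('...', 1) form,
Sphinx's Jinja helper for static resources, which patches 071 through 074 eventually converge on:

    <div class="footer-item">
        <a href="https://x.com/..." target="_blank">
            <img src="{{ pathto('_static/x-logo.svg', 1) }}" alt="X">  <!-- hypothetical asset -->
        </a>
    </div>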
From 627a0d69bf22c30b450f245c342d30ea931aa589 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 16:09:30 +0200 Subject: [PATCH 094/333] changes --- docs/_templates/landing-page-illustrations.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_templates/landing-page-illustrations.html b/docs/_templates/landing-page-illustrations.html index fa6dc03869..7a2185a4e4 100644 --- a/docs/_templates/landing-page-illustrations.html +++ b/docs/_templates/landing-page-illustrations.html @@ -19,7 +19,7 @@
- <!-- stripped markup line --> + <!-- stripped markup line --> Open Source and Free
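This hunk and the next patch both touch stripped markup lines around "Open Source and Free".
The illustrations CSS deleted in patch 036 hides the second svg of each item by default and
shows it only when dark mode is active, so each illustration plausibly ships as a light/dark
pair along these lines (class names and structure are illustrative only):

    <div class="item">
        <svg><!-- light-mode illustration --></svg>
        <svg><!-- dark-mode illustration, shown when the dark-mode class is set --></svg>
        <span>Open Source and Free</span>
    </div>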
From a801fc205f3847af4dcd060ccbd05cf5bba471b0 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 16:10:33 +0200 Subject: [PATCH 095/333] changes --- docs/_templates/landing-page-illustrations.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_templates/landing-page-illustrations.html b/docs/_templates/landing-page-illustrations.html index 7a2185a4e4..b6b55230a5 100644 --- a/docs/_templates/landing-page-illustrations.html +++ b/docs/_templates/landing-page-illustrations.html @@ -20,7 +20,7 @@
- <!-- stripped markup line --> + <!-- stripped markup line --> Open Source and Free
Flair is completely free and open source, making it accessible for everyone to use From e6128a47acb886081842fd881430611b5416a38b Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 20:22:52 +0200 Subject: [PATCH 096/333] changes --- docs/_static/css/header.css | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/docs/_static/css/header.css b/docs/_static/css/header.css index 7ddca94db2..8cc42304f4 100644 --- a/docs/_static/css/header.css +++ b/docs/_static/css/header.css @@ -116,6 +116,28 @@ } } +.search-button__wrapper.show { + .search-button__overlay { + background: var(--flair-blue); + } + form.bd-search { + i { + display: none; + } + input { + border: 2px var(--flair-orange) solid; + color: var(--flair-blue); + box-shadow: none; + padding: .25em 5.5em .25em .75em; + border-radius: 1rem; + font-size: 1.2rem; + } + .search-button__kbd-shortcut { + align-items: center; + } + } +} + .fill-white-blue { fill: var(--white-blue); stroke: none; From 78c22a349f9f6ee21e0c6378db3af4d63381acef Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 20:32:21 +0200 Subject: [PATCH 097/333] changes --- docs/_static/css/main.css | 4 ++++ docs/conf.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/_static/css/main.css b/docs/_static/css/main.css index 1c24bfc662..cb2312234c 100644 --- a/docs/_static/css/main.css +++ b/docs/_static/css/main.css @@ -119,4 +119,8 @@ a { .inv-sel::selection { background: var(--white-blue); color: var(--flair-orange); +} + +.bd-container__inner { + max-width: 100%; } \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index ee93d563c9..b607ec474c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,7 +29,8 @@ html_theme_options = { "navbar_end": ["darkmode-toggle", "version-switcher", "navbar-icon-links"], "show_prev_next": False, - "footer_end": ["footer-links/legal-notice.html", "footer-links/x.html", "footer-links/linkedin.html"] + "footer_end": ["footer-links/legal-notice.html", "footer-links/x.html", "footer-links/linkedin.html"], + "secondary_sidebar_items": [] } From 37da6d5b4ba486d0292a284e87fd7beb7da91b96 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 20:41:24 +0200 Subject: [PATCH 098/333] changes --- docs/_static/css/sidebar.css | 5 +++++ docs/conf.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 docs/_static/css/sidebar.css diff --git a/docs/_static/css/sidebar.css b/docs/_static/css/sidebar.css new file mode 100644 index 0000000000..b72a05b072 --- /dev/null +++ b/docs/_static/css/sidebar.css @@ -0,0 +1,5 @@ +.bd-sidebar-primary.bd-sidebar { + h3 { + display: none; + } +} \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index b607ec474c..b421b261e4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -84,6 +84,7 @@ def linkcode_resolve(*args): 'css/header.css', 'css/footer.css', 'css/version-switcher.css', + 'css/sidebar.css', ] html_logo = "_static/flair_logo_white.svg" @@ -111,7 +112,6 @@ def linkcode_resolve(*args): "**": [ "globaltoc.html", "searchbox.html", - "versioning.html", ], "index": [], } From 8e6e0d7becac2bb6e5144e5d602880c59e0ee75c Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 20:51:26 +0200 Subject: [PATCH 099/333] changes --- docs/_static/css/sidebar.css | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/_static/css/sidebar.css b/docs/_static/css/sidebar.css index b72a05b072..38eba96a5c 100644 --- a/docs/_static/css/sidebar.css +++ 
b/docs/_static/css/sidebar.css @@ -1,5 +1,29 @@ .bd-sidebar-primary.bd-sidebar { + padding: 2rem; + width: auto; h3 { display: none; } + * { + color: var(--flair-blue); + } + .toctree-l1 { + font-size: 1.5rem; + font-weight: 600; + margin-top: .8rem; + } + .toctree-l2 { + font-size: 1rem; + font-weight: 400; + margin-top: .8rem; + ul { + margin-left: 1rem; + } + } + .toctree-l3 { + margin-top: .8rem; + } + .current { + color: var(--flair-orange); + } } \ No newline at end of file From 86cb934433ee374d7ee63ca5c2d960520a05ab7b Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 21:34:23 +0200 Subject: [PATCH 100/333] changes --- docs/_static/css/main.css | 20 +++++++++++++ docs/_static/css/sidebar.css | 36 ++++++++++++++++++++++-- docs/_templates/landing_page_styles.html | 3 ++ docs/conf.py | 3 +- 4 files changed, 57 insertions(+), 5 deletions(-) diff --git a/docs/_static/css/main.css b/docs/_static/css/main.css index cb2312234c..ee52630e51 100644 --- a/docs/_static/css/main.css +++ b/docs/_static/css/main.css @@ -123,4 +123,24 @@ a { .bd-container__inner { max-width: 100%; +} + + /* width */ +::-webkit-scrollbar { + width: 10px; +} + +/* Track */ +::-webkit-scrollbar-track { + background: #2f2e41; +} + +/* Handle */ +::-webkit-scrollbar-thumb { + background: #888; +} + +/* Handle on hover */ +::-webkit-scrollbar-thumb:hover { + background: #555; } \ No newline at end of file diff --git a/docs/_static/css/sidebar.css b/docs/_static/css/sidebar.css index 38eba96a5c..4a219be4f9 100644 --- a/docs/_static/css/sidebar.css +++ b/docs/_static/css/sidebar.css @@ -1,11 +1,35 @@ .bd-sidebar-primary.bd-sidebar { - padding: 2rem; - width: auto; + top: 90px; + width: 350px; + padding: 0; + border: none; + overflow: initial; + background: var(--white-blue); + #rtd-footer-container, .sidebar-primary-items__end { + display: none; + } + &::after { + content: ""; + background: linear-gradient(90deg, rgba(0,0,0,.2) 0%, rgba(0,0,0,0) 100%); + height: 100%; + width: 7px; + position: absolute; + right: 0; + transform: translateX(100%); + top: 0; + } h3 { display: none; } * { - color: var(--flair-blue); + color: var(--blue-white); + } + .sidebar-primary-items__start { + overflow-y: auto; + padding: 2rem; + } + .sidebar-primary-item { + padding: 0; } .toctree-l1 { font-size: 1.5rem; @@ -26,4 +50,10 @@ .current { color: var(--flair-orange); } + code { + padding: 0; + background: transparent; + border: none; + font-weight: 400; + } } \ No newline at end of file diff --git a/docs/_templates/landing_page_styles.html b/docs/_templates/landing_page_styles.html index b9fcebe573..042cca785c 100644 --- a/docs/_templates/landing_page_styles.html +++ b/docs/_templates/landing_page_styles.html @@ -190,4 +190,7 @@ stroke: var(--orange-white); fill: none; } + #searchbox { + display: none; + } diff --git a/docs/conf.py b/docs/conf.py index b421b261e4..9aa061a52b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -110,8 +110,7 @@ def linkcode_resolve(*args): html_sidebars = { "**": [ - "globaltoc.html", - "searchbox.html", + "globaltoc.html" ], "index": [], } From 1834e12d5bd65d06158318910aeeb748d2a920c3 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 21:38:59 +0200 Subject: [PATCH 101/333] changes --- docs/_static/css/header.css | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_static/css/header.css b/docs/_static/css/header.css index 8cc42304f4..4f075f5fc7 100644 --- a/docs/_static/css/header.css +++ b/docs/_static/css/header.css @@ -60,7 +60,7 @@ transition: width 200ms 
cubic-bezier(0,.35,.08,.89); content: ""; width: 0; - height: 3px; + height: 2px; background-color: white; position: absolute; bottom: -.5rem; From e788605c40c9a32d2cd0e0bcf2f77ea1adbaa7cd Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 21:49:49 +0200 Subject: [PATCH 102/333] changes --- docs/_static/css/main.css | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/docs/_static/css/main.css b/docs/_static/css/main.css index ee52630e51..5b68d02269 100644 --- a/docs/_static/css/main.css +++ b/docs/_static/css/main.css @@ -14,6 +14,9 @@ --blue-orange: var(--flair-blue); --header-height: 90px; --footer-height: 60px; + --track-color: #dddddd; + --track-thumb: #bbbbbb; + --track-thumb-hover: #999999; font-family: "Afacad", sans-serif; font-optical-sizing: auto; font-weight: 400; @@ -28,6 +31,9 @@ --orange-blue: var(--flair-blue); --blue-orange: var(--flair-orange); --orange-white-transparent: var(--white-transparent); + --track-color: #38374a; + --track-thumb: #4c4b5c; + --track-thumb-hover: #5c5b69; } .fill-white-blue { @@ -125,22 +131,18 @@ a { max-width: 100%; } - /* width */ ::-webkit-scrollbar { width: 10px; } -/* Track */ ::-webkit-scrollbar-track { - background: #2f2e41; + background: var(--track-color); } -/* Handle */ ::-webkit-scrollbar-thumb { - background: #888; + background: var(--track-thumb); } -/* Handle on hover */ ::-webkit-scrollbar-thumb:hover { - background: #555; + background: var(--track-thumb-hover); } \ No newline at end of file From 509f6847190c8afb6b72abe4345e69656cb7ed59 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 21:56:10 +0200 Subject: [PATCH 103/333] changes --- docs/_static/css/main.css | 17 +++++++++-------- docs/_static/css/sidebar.css | 3 +++ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/docs/_static/css/main.css b/docs/_static/css/main.css index 5b68d02269..8968068092 100644 --- a/docs/_static/css/main.css +++ b/docs/_static/css/main.css @@ -131,18 +131,19 @@ a { max-width: 100%; } -::-webkit-scrollbar { - width: 10px; +body::-webkit-scrollbar, ::-webkit-scrollbar { + width: 10px; } -::-webkit-scrollbar-track { - background: var(--track-color); +body::-webkit-scrollbar-track, ::-webkit-scrollbar-track { + background: var(--track-color); } -::-webkit-scrollbar-thumb { - background: var(--track-thumb); +body::-webkit-scrollbar-thumb, ::-webkit-scrollbar-thumb { + background: var(--track-thumb); + border-radius: 0; } -::-webkit-scrollbar-thumb:hover { - background: var(--track-thumb-hover); +body::-webkit-scrollbar-thumb:hover, ::-webkit-scrollbar-thumb:hover { + background: var(--track-thumb-hover); } \ No newline at end of file diff --git a/docs/_static/css/sidebar.css b/docs/_static/css/sidebar.css index 4a219be4f9..749122db8c 100644 --- a/docs/_static/css/sidebar.css +++ b/docs/_static/css/sidebar.css @@ -24,6 +24,9 @@ * { color: var(--blue-white); } + a:hover { + color: var(--flair-orange); + } .sidebar-primary-items__start { overflow-y: auto; padding: 2rem; From 694cbba3a6d21affc800a8c7a5ace58bd9b113ae Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 7 Aug 2024 21:58:47 +0200 Subject: [PATCH 104/333] changes --- docs/_static/css/sidebar.css | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/_static/css/sidebar.css b/docs/_static/css/sidebar.css index 749122db8c..598be603cd 100644 --- a/docs/_static/css/sidebar.css +++ b/docs/_static/css/sidebar.css @@ -26,6 +26,12 @@ } a:hover { color: var(--flair-orange); + * { + color: var(--flair-orange); + } + } + a * { + 
transition: color 200ms cubic-bezier(0,.35,.08,.89); } .sidebar-primary-items__start { overflow-y: auto; From 6bbc2bbfc23abda65e347bc6685786b5aaa4edca Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 09:04:48 +0200 Subject: [PATCH 105/333] changes --- docs/_static/css/tutorial.css | 33 +++++++++++++++++++++++++++++++++ docs/conf.py | 1 + 2 files changed, 34 insertions(+) create mode 100644 docs/_static/css/tutorial.css diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css new file mode 100644 index 0000000000..16b47118d1 --- /dev/null +++ b/docs/_static/css/tutorial.css @@ -0,0 +1,33 @@ +ul.bd-breadcrumbs { + li.breadcrumb-home .fa-home { + &::before { + content: "Home"; + font-family: "Afacad", sans-serif; + font-size: 1rem; + font-weight: 400; + color: var(--blue-white); + } + } + li.breadcrumb-item:not(.breadcrumb-home) { + &::before { + content: ""; + width: 1px; + height: 1rem; + background: white; + display: inline-block; + padding: 0; + transform: rotate(30deg); + margin: 0 .6rem; + } + a { + color: var(--blue-white); + font-size: 1rem; + font-weight: 400; + } + &.active { + color: var(--blue-white); + font-size: 1rem; + font-weight: 400; + } + } +} \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 9aa061a52b..f5e6d6c6a1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -85,6 +85,7 @@ def linkcode_resolve(*args): 'css/footer.css', 'css/version-switcher.css', 'css/sidebar.css', + 'css/tutorial.css' ] html_logo = "_static/flair_logo_white.svg" From 60e2b72f76085b2c3a6f1d70d264603361e94bd0 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 09:13:27 +0200 Subject: [PATCH 106/333] changes --- docs/_static/css/tutorial.css | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index 16b47118d1..d0a0462e48 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -1,3 +1,21 @@ +.bd-header-article { + margin-bottom: 2rem; +} + +.bd-main .bd-content .bd-article-container{ + max-width: initial; + padding: 4rem 5rem; + box-sizing: border-box; +} + +.bd-main .bd-content .bd-article-container .bd-article { + padding: 0; +} + +.header-article-items.header-article__inner { + padding: 0; +} + ul.bd-breadcrumbs { li.breadcrumb-home .fa-home { &::before { From a5b859863d6e65f78753ac496c1823ec79d1c8ec Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 09:21:51 +0200 Subject: [PATCH 107/333] changes --- docs/_static/css/main.css | 2 +- docs/_static/css/tutorial.css | 20 ++++++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/docs/_static/css/main.css b/docs/_static/css/main.css index 8968068092..1778597b8b 100644 --- a/docs/_static/css/main.css +++ b/docs/_static/css/main.css @@ -20,7 +20,7 @@ font-family: "Afacad", sans-serif; font-optical-sizing: auto; font-weight: 400; - font-size: 20px; + font-size: 21px; font-style: normal; } diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index d0a0462e48..9d700f7fd9 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -1,7 +1,3 @@ -.bd-header-article { - margin-bottom: 2rem; -} - .bd-main .bd-content .bd-article-container{ max-width: initial; padding: 4rem 5rem; @@ -10,6 +6,22 @@ .bd-main .bd-content .bd-article-container .bd-article { padding: 0; + h1 { + font-size: 3rem; + color: var(--blue-white); + font-weight: 600; + margin-top: 0.5em; + } + h2 { + font-size: 1.6rem; + color: 
var(--blue-white); + font-weight: 600; + margin-top: 2em; + } + p { + color: var(--blue-white); + margin: 0.5em 0 1em 0; + } } .header-article-items.header-article__inner { From bd558be123d38316c2686738034ed720f095c49d Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 09:31:04 +0200 Subject: [PATCH 108/333] changes --- docs/_static/css/main.css | 4 ++++ docs/_static/css/tutorial.css | 15 +++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/docs/_static/css/main.css b/docs/_static/css/main.css index 1778597b8b..bf73dbb939 100644 --- a/docs/_static/css/main.css +++ b/docs/_static/css/main.css @@ -6,12 +6,15 @@ --flair-orange-transparent: rgba(247, 153, 16, .5); --white-transparent: rgba(255, 255, 255, .5); --flair-blue: #2f2e41; + --flair-dark-blue: #262635; + --light-gray: #e6e6e6; --white-blue: white; --blue-white: var(--flair-blue); --orange-white: var(--flair-orange); --orange-white-transparent: var(--flair-orange-transparent); --orange-blue: var(--flair-orange); --blue-orange: var(--flair-blue); + --gray-dark-blue: var(--light-gray); --header-height: 90px; --footer-height: 60px; --track-color: #dddddd; @@ -31,6 +34,7 @@ --orange-blue: var(--flair-blue); --blue-orange: var(--flair-orange); --orange-white-transparent: var(--white-transparent); + --gray-dark-blue: var(--flair-dark-blue); --track-color: #38374a; --track-thumb: #4c4b5c; --track-thumb-hover: #5c5b69; diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index 9d700f7fd9..a09ba9974a 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -1,6 +1,6 @@ .bd-main .bd-content .bd-article-container{ max-width: initial; - padding: 4rem 5rem; + padding: 4rem 30% 4rem 5rem; box-sizing: border-box; } @@ -20,7 +20,18 @@ } p { color: var(--blue-white); - margin: 0.5em 0 1em 0; + margin: 0.4em 0 0.9em 0; + } + .highlight-python { + .highlight { + background: transparent; + pre { + background: var(--gray-dark-blue); + .nn, .n { + color: var(--blue-white); + } + } + } } } From 6d49182704225eea5185f2cda4ee55fa50d9a412 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 09:42:52 +0200 Subject: [PATCH 109/333] changes --- docs/_static/css/tutorial.css | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index a09ba9974a..fde1236f2a 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -26,10 +26,18 @@ .highlight { background: transparent; pre { + padding: 1.1em 1.5em; + border-radius: 1rem; background: var(--gray-dark-blue); .nn, .n { color: var(--blue-white); } + .c1 { + color: #999; + :root .dark-mode & { + color: #696991; + } + } } } } From 7373488224b51980c85729a987c5feadfb615009 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 10:07:59 +0200 Subject: [PATCH 110/333] changes --- docs/_static/css/main.css | 2 +- docs/_static/css/tutorial.css | 21 +++++++++++++++++---- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/docs/_static/css/main.css b/docs/_static/css/main.css index bf73dbb939..e2b94342c0 100644 --- a/docs/_static/css/main.css +++ b/docs/_static/css/main.css @@ -1,4 +1,4 @@ -@import url('https://fonts.googleapis.com/css2?family=Afacad:ital,wght@0,400..700;1,400..700&display=swap'); +@import url('https://fonts.googleapis.com/css2?family=Afacad:ital,wght@0,400..700;1,400..700&family=Source+Code+Pro:ital,wght@0,200..900;1,200..900&display=swap'); :root { --flair-orange: #f79910; diff --git 
a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index fde1236f2a..56241ee3f5 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -26,17 +26,30 @@ .highlight { background: transparent; pre { + font-family: "Source Code Pro", monospace; padding: 1.1em 1.5em; border-radius: 1rem; background: var(--gray-dark-blue); - .nn, .n { + .nn, .n, .o, .p { color: var(--blue-white); } + .kn { + color: #CF8E6D; + } + .kp { + color: #57AAF7; + } + .nb { + color: #8888C6; + } .c1 { - color: #999; - :root .dark-mode & { + color: #5F826B; + /*:root .dark-mode & { color: #696991; - } + }*/ + } + .s1 { + color : #6AAB73; } } } From ac2ec964c71eb53d976648b23be289352a5d6394 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 10:37:24 +0200 Subject: [PATCH 111/333] changes --- docs/_static/css/tutorial.css | 46 +++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index 56241ee3f5..c03fbf9758 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -54,6 +54,52 @@ } } } + .highlight-console { + .highlight { + background: transparent; + } + pre { + font-family: "Source Code Pro", monospace; + border-radius: 1rem; + padding: 1.1em 1.5em; + background: var(--gray-dark-blue); + color: var(--blue-white); + .go { + font-size: inherit; + padding-right: 1.1em; + } + } + } + .admonition, div.admonition, .admonition.note, div.admonition.note { + box-shadow: none !important; + background: linear-gradient(180deg, var(--flair-orange) 0, var(--flair-orange) 2rem, var(--white-blue) 2rem, var(--white-blue) 100%); + border: 2px var(--flair-orange) solid; + border-radius: 1rem; + padding: 0; + box-sizing: border-box; + p.admonition-title { + background: var(--flair-orange); + border-radius: 0 0 1rem 1rem; + margin: 0; + font-size: 1.5rem; + font-weight: 400; + padding: 0.5em 2.2em; + position: relative; + line-height: 1em; + &::before { + content: none; + } + &::after { + color: white; + margin-left: 0.4em; + font-size: 0.9em; + } + } + p:not(.admonition-title) { + padding: 0.5em 2.2em 0.25em 3.3em; + line-height: 1em; + } + } } .header-article-items.header-article__inner { From 730cb18104394355c6e8bf2cafd24bca530af2fe Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 10:55:27 +0200 Subject: [PATCH 112/333] changes --- docs/_static/css/tutorial.css | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index c03fbf9758..3c4f5d53a8 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -138,4 +138,34 @@ ul.bd-breadcrumbs { font-weight: 400; } } +} + +table.table { + tr { + &:not(:last-of-type), &:has(th) { + border-bottom: 2px var(--blue-white) solid; + } + th { + font-weight: 600; + } + th, td { + padding: .5em .2em .15em .2em; + &:first-of-type { + padding-left: 0; + } + &:last-of-type { + padding-right: 0; + } + } + } +} + +.reference .docutils { + border: 2px var(--flair-orange) solid; + background: transparent; + border-radius: 0.5em; + font-weight: 400; + font-family: "Source Code Pro", monospace; + font-size: 0.8em; + color: var(--blue-white); } \ No newline at end of file From 4010213ada6c24f5d529d7a184f9b63daa583b9a Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 11:10:54 +0200 Subject: [PATCH 113/333] changes --- docs/_static/css/tutorial.css | 46 ++++++++++++++++++++++++----------- 1 file changed, 32 
insertions(+), 14 deletions(-) diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index 3c4f5d53a8..dd6de62e0b 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -18,6 +18,12 @@ font-weight: 600; margin-top: 2em; } + h3 { + font-size: 1.2rem; + color: var(--blue-white); + font-weight: 600; + margin-top: 2em; + } p { color: var(--blue-white); margin: 0.4em 0 0.9em 0; @@ -30,31 +36,31 @@ padding: 1.1em 1.5em; border-radius: 1rem; background: var(--gray-dark-blue); - .nn, .n, .o, .p { + .nn, .n, .o, .p, .kp { color: var(--blue-white); } - .kn { + .kn, .k, .ow { color: #CF8E6D; } - .kp { - color: #57AAF7; - } .nb { color: #8888C6; } .c1 { - color: #5F826B; - /*:root .dark-mode & { - color: #696991; - }*/ + color: #7A7E85; } - .s1 { + .s1, .sa, .si { color : #6AAB73; } + .mi { + color: #2AACB8; + } } } } - .highlight-console { + .highlight-console, .highlight-default { + * { + color: var(--blue-white); + } .highlight { background: transparent; } @@ -97,7 +103,9 @@ } p:not(.admonition-title) { padding: 0.5em 2.2em 0.25em 3.3em; - line-height: 1em; + } + .highlight-python, .highlight-default, .highlight-console, .highlight { + margin-bottom: 0.7rem; } } } @@ -160,12 +168,22 @@ table.table { } } -.reference .docutils { - border: 2px var(--flair-orange) solid; +.docutils, .docutils.literal { + border: 2px var(--blue-white) solid; background: transparent; border-radius: 0.5em; font-weight: 400; font-family: "Source Code Pro", monospace; font-size: 0.8em; color: var(--blue-white); +} +.reference { + .docutils, .docutils.literal { + border: 2px var(--flair-orange) solid; + transition: color, background-color 200ms ease-in-out; + &:hover { + background: var(--flair-orange); + color: white; + } + } } \ No newline at end of file From b0014ebb394c0dea507d2a71ea96a44c8c7b3d65 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 11:21:40 +0200 Subject: [PATCH 114/333] changes --- docs/_static/css/tutorial.css | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index dd6de62e0b..bc07a9001b 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -148,7 +148,9 @@ ul.bd-breadcrumbs { } } -table.table { +table.table {border-collapse: collapse; + border-radius: 1rem; + box-shadow: 0 0 0 2px var(--blue-white); tr { &:not(:last-of-type), &:has(th) { border-bottom: 2px var(--blue-white) solid; @@ -157,12 +159,12 @@ table.table { font-weight: 600; } th, td { - padding: .5em .2em .15em .2em; + padding: .4em .2em 0 .2em; &:first-of-type { - padding-left: 0; + padding-left: .5em; } &:last-of-type { - padding-right: 0; + padding-right: .5em; } } } From 4a76a8be036b442714f2f57a256fa30dfe770cb7 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 11:30:18 +0200 Subject: [PATCH 115/333] changes --- docs/_static/css/tutorial.css | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index bc07a9001b..728e2b7363 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -48,7 +48,7 @@ .c1 { color: #7A7E85; } - .s1, .sa, .si { + .s1, .s2, .sa, .si { color : #6AAB73; } .mi { @@ -171,21 +171,40 @@ table.table {border-collapse: collapse; } .docutils, .docutils.literal { - border: 2px var(--blue-white) solid; + border: none; background: transparent; border-radius: 0.5em; font-weight: 400; font-family: "Source Code 
Pro", monospace; font-size: 0.8em; color: var(--blue-white); + position: relative; + z-index: 1; + &::before { + content: ""; + width: 100%; + height: 90%; + top: 50%; + left: 0; + transform: translateY(-50%); + border-radius: 0.5em; + border: 2px var(--blue-white) solid; + position: absolute; + } } .reference { .docutils, .docutils.literal { - border: 2px var(--flair-orange) solid; - transition: color, background-color 200ms ease-in-out; + transition: color 200ms ease-in-out; &:hover { - background: var(--flair-orange); + &::before { + background: var(--flair-orange); + } color: white; } + &::before { + z-index: -1; + border-color: var(--flair-orange); + transition: background-color 200ms ease-in-out; + } } } \ No newline at end of file From 9fe360a8780cdd2eb8712e0306fbba1b83b840ec Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 11:43:14 +0200 Subject: [PATCH 116/333] changes --- docs/_static/css/tutorial.css | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index 728e2b7363..26c9ad3d2d 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -6,6 +6,9 @@ .bd-main .bd-content .bd-article-container .bd-article { padding: 0; + .headerlink { + display: none; + } h1 { font-size: 3rem; color: var(--blue-white); @@ -28,6 +31,13 @@ color: var(--blue-white); margin: 0.4em 0 0.9em 0; } + ul { + list-style: initial; + margin-left: 1rem; + ::marker { + color: var(--blue-white); + } + } .highlight-python { .highlight { background: transparent; @@ -51,7 +61,7 @@ .s1, .s2, .sa, .si { color : #6AAB73; } - .mi { + .mi, .mf { color: #2AACB8; } } @@ -122,14 +132,18 @@ ul.bd-breadcrumbs { font-size: 1rem; font-weight: 400; color: var(--blue-white); + transition: color 200ms ease-in-out; } } + li.breadcrumb-home a:hover .fa-home::before { + color: var(--flair-orange); + } li.breadcrumb-item:not(.breadcrumb-home) { &::before { content: ""; width: 1px; height: 1rem; - background: white; + background: var(--blue-white); display: inline-block; padding: 0; transform: rotate(30deg); @@ -139,6 +153,9 @@ ul.bd-breadcrumbs { color: var(--blue-white); font-size: 1rem; font-weight: 400; + &:hover { + color: var(--flair-orange); + } } &.active { color: var(--blue-white); From 2a25cadddd45c5fb8f99e4eee84c5325f366f2de Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 11:49:30 +0200 Subject: [PATCH 117/333] changes --- docs/_static/css/tutorial.css | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index 26c9ad3d2d..cf66b9f027 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -67,7 +67,7 @@ } } } - .highlight-console, .highlight-default { + [class^="highlight-"]:not(.highlight-python) { * { color: var(--blue-white); } From f81ef2ce0489fbc0725a7d8e22f32e066c4ef81e Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 12:05:52 +0200 Subject: [PATCH 118/333] changes --- docs/_static/css/tutorial.css | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index cf66b9f027..cd2dc7e9ba 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -49,12 +49,15 @@ .nn, .n, .o, .p, .kp { color: var(--blue-white); } - .kn, .k, .ow { + .kn, .k, .ow, .kc { color: #CF8E6D; } .nb { color: #8888C6; } + .nf { + color: #56A8F5; + } .c1 { color: #7A7E85; } @@ 
-224,4 +227,14 @@ table.table {border-collapse: collapse; transition: background-color 200ms ease-in-out; } } +} + +#flair-tutorials, [id^="tutorial"] { + li { + list-style-type: "»"; + padding-left: 0.25em; + ::marker { + color: var(--blue-white); + } + } } \ No newline at end of file From cedca6d981f11540fc62de1099d09484d9d17dcb Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 12:22:26 +0200 Subject: [PATCH 119/333] changes --- docs/_static/css/api.css | 0 docs/conf.py | 3 ++- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 docs/_static/css/api.css diff --git a/docs/_static/css/api.css b/docs/_static/css/api.css new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/conf.py b/docs/conf.py index f5e6d6c6a1..b1238878bc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -85,7 +85,8 @@ def linkcode_resolve(*args): 'css/footer.css', 'css/version-switcher.css', 'css/sidebar.css', - 'css/tutorial.css' + 'css/tutorial.css', + 'css/api.css' ] html_logo = "_static/flair_logo_white.svg" From 3feaa21e12f1e29ef8224cde604170423d7f418f Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 12:49:54 +0200 Subject: [PATCH 120/333] changes --- docs/_static/css/api.css | 12 ++++++++++++ docs/_static/css/tutorial.css | 27 ++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/docs/_static/css/api.css b/docs/_static/css/api.css index e69de29bb2..b25b39cd2f 100644 --- a/docs/_static/css/api.css +++ b/docs/_static/css/api.css @@ -0,0 +1,12 @@ +.sidebar-primary-item { + .docutils { + &::before { + content: none; + } + .pre { + font-family: "Afacad", sans-serif; + font-size: 1rem; + text-overflow: ellipsis; + } + } +} \ No newline at end of file diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index cd2dc7e9ba..0c48a3a79f 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -229,12 +229,37 @@ table.table {border-collapse: collapse; } } -#flair-tutorials, [id^="tutorial"] { +#flair-tutorials, #api-docs, [id^="tutorial"], [id^="flair"] { li { list-style-type: "»"; padding-left: 0.25em; + margin-top: .5em; ::marker { color: var(--blue-white); + font-size: 1.2rem; + } + } + .toctree-wrapper li[class^="toctree-l"] { + margin-top: .5em; + &::marker { + font-size: 1.2rem; + } + & > a { + font-size: 1.2rem; + } + } + .docutils { + color: var(--flair-orange); + &:hover { + color: var(--flair-orange-light); + } + &::before { + content: none; + } + .pre { + font-family: "Afacad", sans-serif; + text-overflow: ellipsis; + font-size: 1.2rem; } } } \ No newline at end of file From 59024414f238a661865be8c0452191ce00ac07f6 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 12:54:29 +0200 Subject: [PATCH 121/333] changes --- docs/_static/css/tutorial.css | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css index 0c48a3a79f..fa960786e2 100644 --- a/docs/_static/css/tutorial.css +++ b/docs/_static/css/tutorial.css @@ -32,10 +32,14 @@ margin: 0.4em 0 0.9em 0; } ul { - list-style: initial; - margin-left: 1rem; + list-style-type: "»"; + margin-left: .6em; + li { + padding-left: .3em; + } ::marker { color: var(--blue-white); + font-size: inherit; } } .highlight-python { From ec80d59c49202750d8cf9c6bc802e6fba5f37f41 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 13:26:50 +0200 Subject: [PATCH 122/333] changes --- docs/_static/css/api.css | 13 +++++++++++++ docs/_static/octocat.svg 
| 41 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 docs/_static/octocat.svg diff --git a/docs/_static/css/api.css b/docs/_static/css/api.css index b25b39cd2f..39d05de307 100644 --- a/docs/_static/css/api.css +++ b/docs/_static/css/api.css @@ -9,4 +9,17 @@ text-overflow: ellipsis; } } +} + +.sig { + color: var(--blue-white); + font-family: "Source Code Pro", monospace; + .linkcode-link { + .pre { + display: none; + } + &::after { + content: url("_static/octocat.svg"); + } + } } \ No newline at end of file diff --git a/docs/_static/octocat.svg b/docs/_static/octocat.svg new file mode 100644 index 0000000000..ce2cf1d99d --- /dev/null +++ b/docs/_static/octocat.svg @@ -0,0 +1,41 @@ + + + + + + From 5491a45ba2050a383385f402badb33bc03f8c1fc Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 13:28:59 +0200 Subject: [PATCH 123/333] changes --- docs/_static/css/api.css | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_static/css/api.css b/docs/_static/css/api.css index 39d05de307..378063ee37 100644 --- a/docs/_static/css/api.css +++ b/docs/_static/css/api.css @@ -19,7 +19,7 @@ display: none; } &::after { - content: url("_static/octocat.svg"); + content: url("../../_static/octocat.svg"); } } } \ No newline at end of file From 00f4e1bf1fca567d0875fba09a124a27cd2cc2a0 Mon Sep 17 00:00:00 2001 From: konstantin Date: Wed, 14 Aug 2024 13:58:02 +0200 Subject: [PATCH 124/333] changes --- docs/_static/css/api.css | 46 ++++++++++++++++++++++++++++++++++++++- docs/_static/css/main.css | 2 ++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/docs/_static/css/api.css b/docs/_static/css/api.css index 378063ee37..dc3f3a5435 100644 --- a/docs/_static/css/api.css +++ b/docs/_static/css/api.css @@ -15,11 +15,55 @@ color: var(--blue-white); font-family: "Source Code Pro", monospace; .linkcode-link { + padding: 0; + display: inline-block; + width: 1.5rem; + height: 1.5rem; + transform: translateY(25%); .pre { display: none; } &::after { - content: url("../../_static/octocat.svg"); + content:""; + background-image: url('../../_static/octocat.svg'); + background-size: cover; + display: inline-block; + width: 100%; + height: 100%; + } + } +} + +.sig-name.descname { + color: var(--flair-orange); +} + +.class, .method { + .sig { + margin-top: 2rem; + } + dd { + margin-left: 0 !important; + padding-left: 3rem !important; + border-left: 2px var(--gray-white) solid; + dl.field-list { + background-color: transparent; + border: 2px var(--flair-orange) solid; + border-radius: 1rem; + padding: 1rem 1.5rem; + dt { + background-color: transparent !important; + color: var(--blue-white); + font-weight: 600; + font-size: 1.5rem; + padding: 0 !important; + margin: 0 !important; + } + dd { + margin: 0 !important; + padding: 0 !important; + border-left: none; + } } } } \ No newline at end of file diff --git a/docs/_static/css/main.css b/docs/_static/css/main.css index e2b94342c0..d3c6242c41 100644 --- a/docs/_static/css/main.css +++ b/docs/_static/css/main.css @@ -15,6 +15,7 @@ --orange-blue: var(--flair-orange); --blue-orange: var(--flair-blue); --gray-dark-blue: var(--light-gray); + --gray-white: var(--light-gray); --header-height: 90px; --footer-height: 60px; --track-color: #dddddd; @@ -35,6 +36,7 @@ --blue-orange: var(--flair-orange); --orange-white-transparent: var(--white-transparent); --gray-dark-blue: var(--flair-dark-blue); + --gray-white: white; --track-color: #38374a; --track-thumb: #4c4b5c; --track-thumb-hover: #5c5b69; From 
From 00f4e1bf1fca567d0875fba09a124a27cd2cc2a0 Mon Sep 17 00:00:00 2001
From: konstantin
Date: Wed, 14 Aug 2024 13:58:02 +0200
Subject: [PATCH 124/333] changes

---
 docs/_static/css/api.css  | 46 ++++++++++++++++++++++++++++++++++++++-
 docs/_static/css/main.css |  2 ++
 2 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/docs/_static/css/api.css b/docs/_static/css/api.css
index 378063ee37..dc3f3a5435 100644
--- a/docs/_static/css/api.css
+++ b/docs/_static/css/api.css
@@ -15,11 +15,55 @@
     color: var(--blue-white);
     font-family: "Source Code Pro", monospace;
     .linkcode-link {
+        padding: 0;
+        display: inline-block;
+        width: 1.5rem;
+        height: 1.5rem;
+        transform: translateY(25%);
         .pre {
             display: none;
         }
         &::after {
-            content: url("../../_static/octocat.svg");
+            content:"";
+            background-image: url('../../_static/octocat.svg');
+            background-size: cover;
+            display: inline-block;
+            width: 100%;
+            height: 100%;
         }
     }
 }
+
+.sig-name.descname {
+    color: var(--flair-orange);
+}
+
+.class, .method {
+    .sig {
+        margin-top: 2rem;
+    }
+    dd {
+        margin-left: 0 !important;
+        padding-left: 3rem !important;
+        border-left: 2px var(--gray-white) solid;
+        dl.field-list {
+            background-color: transparent;
+            border: 2px var(--flair-orange) solid;
+            border-radius: 1rem;
+            padding: 1rem 1.5rem;
+            dt {
+                background-color: transparent !important;
+                color: var(--blue-white);
+                font-weight: 600;
+                font-size: 1.5rem;
+                padding: 0 !important;
+                margin: 0 !important;
+            }
+            dd {
+                margin: 0 !important;
+                padding: 0 !important;
+                border-left: none;
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/docs/_static/css/main.css b/docs/_static/css/main.css
index e2b94342c0..d3c6242c41 100644
--- a/docs/_static/css/main.css
+++ b/docs/_static/css/main.css
@@ -15,6 +15,7 @@
     --orange-blue: var(--flair-orange);
     --blue-orange: var(--flair-blue);
     --gray-dark-blue: var(--light-gray);
+    --gray-white: var(--light-gray);
     --header-height: 90px;
     --footer-height: 60px;
     --track-color: #dddddd;
@@ -35,6 +36,7 @@
     --blue-orange: var(--flair-orange);
     --orange-white-transparent: var(--white-transparent);
     --gray-dark-blue: var(--flair-dark-blue);
+    --gray-white: white;
     --track-color: #38374a;
     --track-thumb: #4c4b5c;
     --track-thumb-hover: #5c5b69;
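Patch 124 trades content: url(...) for an empty content string plus a background image. An image injected through content renders at its intrinsic size and cannot be scaled from CSS, while a sized pseudo-element box with background-size can take any dimensions, which is also what lets a later patch in this series shrink the icon. A sketch of the two behaviors, with illustrative selector names:

/* Not scalable: the SVG keeps its intrinsic size. */
.icon-intrinsic::after {
    content: url("../../_static/octocat.svg");
}

/* Scalable: the empty content creates a box; the background fills it. */
.icon-sized::after {
    content: "";
    display: inline-block;
    width: 1.5rem;
    height: 1.5rem;
    background-image: url("../../_static/octocat.svg");
    background-size: cover; /* image scales with the box */
}
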
From f4805474da2dbb741e768405913a0fead318f316 Mon Sep 17 00:00:00 2001
From: konstantin
Date: Wed, 14 Aug 2024 14:06:00 +0200
Subject: [PATCH 125/333] changes

---
 docs/_static/css/api.css      | 3 ++-
 docs/_static/css/tutorial.css | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/docs/_static/css/api.css b/docs/_static/css/api.css
index dc3f3a5435..0a9bba2b67 100644
--- a/docs/_static/css/api.css
+++ b/docs/_static/css/api.css
@@ -38,9 +38,10 @@
     color: var(--flair-orange);
 }
 
-.class, .method {
+.class, .method, .function {
     .sig {
         margin-top: 2rem;
+        margin-bottom: 1rem;
     }
     dd {
         margin-left: 0 !important;
diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css
index fa960786e2..c9ec7d4f64 100644
--- a/docs/_static/css/tutorial.css
+++ b/docs/_static/css/tutorial.css
@@ -252,7 +252,7 @@ table.table {border-collapse: collapse;
             font-size: 1.2rem;
         }
     }
-    .docutils {
+    .docutils:not(.py) {
         color: var(--flair-orange);
         &:hover {
             color: var(--flair-orange-light);

From 55fcb22a311dc7c5282c986bf51710eada5f3340 Mon Sep 17 00:00:00 2001
From: konstantin
Date: Wed, 14 Aug 2024 14:13:47 +0200
Subject: [PATCH 126/333] changes

---
 docs/_static/css/api.css      | 35 +++++++++++++++++++++++++++++++++++
 docs/_static/css/main.css     |  1 +
 docs/_static/css/tutorial.css |  2 +-
 3 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/docs/_static/css/api.css b/docs/_static/css/api.css
index 0a9bba2b67..e0640a06ef 100644
--- a/docs/_static/css/api.css
+++ b/docs/_static/css/api.css
@@ -67,4 +67,39 @@
                 border-left: none;
             }
         }
     }
 }
+
+div.deprecated {
+    box-shadow: none !important;
+    background: linear-gradient(180deg, var(--error-red) 0, var(--error-red) 2rem, var(--white-blue) 2rem, var(--white-blue) 100%);
+    border: 2px var(--error-red) solid;
+    border-radius: 1rem;
+    padding: 0;
+    box-sizing: border-box;
+    p {
+        &::before {
+            content: none;
+        }
+    }
+    /* p {
+        background: var(--flair-orange);
+        border-radius: 0 0 1rem 1rem;
+        margin: 0;
+        font-size: 1.5rem;
+        font-weight: 400;
+        padding: 0.5em 2.2em;
+        position: relative;
+        line-height: 1em;
+        &::before {
+            content: none;
+        }
+        &::after {
+            color: white;
+            margin-left: 0.4em;
+            font-size: 0.9em;
+        }
+    }
+    p:not(.admonition-title) {
+        padding: 0.5em 2.2em 0.25em 3.3em;
+    }*/
+}
\ No newline at end of file
diff --git a/docs/_static/css/main.css b/docs/_static/css/main.css
index d3c6242c41..12a8a7abcb 100644
--- a/docs/_static/css/main.css
+++ b/docs/_static/css/main.css
@@ -16,6 +16,7 @@
     --blue-orange: var(--flair-blue);
     --gray-dark-blue: var(--light-gray);
     --gray-white: var(--light-gray);
+    --error-red: #e32239;
     --header-height: 90px;
     --footer-height: 60px;
     --track-color: #dddddd;
diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css
index c9ec7d4f64..58880393a8 100644
--- a/docs/_static/css/tutorial.css
+++ b/docs/_static/css/tutorial.css
@@ -121,7 +121,7 @@
     p:not(.admonition-title) {
         padding: 0.5em 2.2em 0.25em 3.3em;
     }
-    .highlight-python, .highlight-default, .highlight-console, .highlight {
+    [class^="highlight-"] {
         margin-bottom: 0.7rem;
     }
 }

From 719c740c33b58476b628ca41a4f888c31891f1c1 Mon Sep 17 00:00:00 2001
From: konstantin
Date: Wed, 14 Aug 2024 14:25:40 +0200
Subject: [PATCH 127/333] changes

---
 docs/_static/css/api.css | 44 +++++++++++++++++++---------------------
 1 file changed, 21 insertions(+), 23 deletions(-)

diff --git a/docs/_static/css/api.css b/docs/_static/css/api.css
index e0640a06ef..16c6f2777b 100644
--- a/docs/_static/css/api.css
+++ b/docs/_static/css/api.css
@@ -41,7 +41,12 @@
 .class, .method, .function {
     .sig {
         margin-top: 2rem;
-        margin-bottom: 1rem;
+        margin-bottom: .5em;
+        padding-bottom: .5em;
+    }
+    dt:target {
+        background: var(--flair-orange-transparent);
+        border-radius: 10px;
     }
     dd {
         margin-left: 0 !important;
@@ -76,35 +81,28 @@
 
 div.deprecated {
     box-shadow: none !important;
-    background: linear-gradient(180deg, var(--error-red) 0, var(--error-red) 2rem, var(--white-blue) 2rem, var(--white-blue) 100%);
+    background: var(--error-red);
     border: 2px var(--error-red) solid;
     border-radius: 1rem;
-    padding: 0;
+    padding: 0.5em 2.2em 0 2.2em;
     box-sizing: border-box;
     p {
+        span.versionmodified.deprecated::before {
+            color: white;
+        }
         &::before {
             content: none;
         }
     }
-    /* p {
-        background: var(--flair-orange);
-        border-radius: 0 0 1rem 1rem;
-        margin: 0;
-        font-size: 1.5rem;
-        font-weight: 400;
-        padding: 0.5em 2.2em;
-        position: relative;
-        line-height: 1em;
-        &::before {
-            content: none;
-        }
-        &::after {
-            color: white;
-            margin-left: 0.4em;
-            font-size: 0.9em;
-        }
-    }
-    p:not(.admonition-title) {
-        padding: 0.5em 2.2em 0.25em 3.3em;
-    }*/
+}
+
+a.github::before {
+    content: "";
+    height: 1em;
+    width: 1em;
+    display: inline-block;
+    background-image: url("../../_static/octocat.svg");
+    background-size: contain;
+    background-position: center;
+    transform: translateY(25%);
 }
\ No newline at end of file
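Patch 127 above introduces a dt:target rule. The :target pseudo-class matches the element whose id equals the current URL fragment, so following a permalink such as a page URL ending in #SomeClass.method highlights exactly that definition term. A reduced sketch of the mechanism; the fragment id is made up for illustration:

/* Given a URL ending in #SomeClass.method, the dt carrying
   id="SomeClass.method" matches :target and gets the highlight. */
dt:target {
    background: var(--flair-orange-transparent); /* translucent theme orange */
    border-radius: 10px;
}
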
From 34d76c86492461749f49faa572da8b9a7b6570fc Mon Sep 17 00:00:00 2001
From: konstantin
Date: Wed, 14 Aug 2024 14:42:42 +0200
Subject: [PATCH 128/333] changes

---
 docs/_static/css/api.css      | 24 +++++++++++++-----------
 docs/_static/css/tutorial.css |  2 +-
 2 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/docs/_static/css/api.css b/docs/_static/css/api.css
index 16c6f2777b..fad83624d7 100644
--- a/docs/_static/css/api.css
+++ b/docs/_static/css/api.css
@@ -17,9 +17,9 @@
     .linkcode-link {
         padding: 0;
         display: inline-block;
-        width: 1.5rem;
-        height: 1.5rem;
-        transform: translateY(25%);
+        width: 1rem;
+        height: 1rem;
+        transform: translateY(10%) scale(1.5);
         .pre {
             display: none;
         }
@@ -38,16 +38,18 @@
     color: var(--flair-orange);
 }
 
-.class, .method, .function {
-    .sig {
-        margin-top: 2rem;
-        margin-bottom: .5em;
-        padding-bottom: .5em;
-    }
-    dt:target {
-        background: var(--flair-orange-transparent);
-        border-radius: 10px;
+.sig {
+    margin-top: 2rem;
+    margin-bottom: .5em;
+    padding-bottom: .5em;
+    padding-top: .5em;
+    &:target {
+        background: var(--flair-orange-transparent);
+        border-radius: .5rem;
     }
+}
+
+.class, .method, .function, .data {
     dd {
         margin-left: 0 !important;
diff --git a/docs/_static/css/tutorial.css b/docs/_static/css/tutorial.css
index 58880393a8..8b1e0d72e9 100644
--- a/docs/_static/css/tutorial.css
+++ b/docs/_static/css/tutorial.css
@@ -233,7 +233,7 @@ table.table {border-collapse: collapse;
     }
 }
 
-#flair-tutorials, #api-docs, [id^="tutorial"], [id^="flair"] {
+#flair-tutorials, #api-docs, #contributing, [id^="tutorial"], [id^="flair"] {
     li {
         list-style-type: "»";
         padding-left: 0.25em;

From fabfec107979e6f37f8f033eac9116519ccf1b1b Mon Sep 17 00:00:00 2001
From: konstantin
Date: Wed, 14 Aug 2024 15:08:49 +0200
Subject: [PATCH 129/333] changes

---
 docs/_static/css/api.css | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/docs/_static/css/api.css b/docs/_static/css/api.css
index fad83624d7..1dd843d54c 100644
--- a/docs/_static/css/api.css
+++ b/docs/_static/css/api.css
@@ -14,6 +14,7 @@
 .sig {
     color: var(--blue-white);
     font-family: "Source Code Pro", monospace;
+    line-height: 2em;
     .linkcode-link {
         padding: 0;
         display: inline-block;
@@ -60,6 +61,9 @@
             border: 2px var(--flair-orange) solid;
             border-radius: 1rem;
             padding: 1rem 1.5rem;
+            ul {
+                padding-left: 0 !important;
+            }
             dt {
                 background-color: transparent !important;
                 color: var(--blue-white);
@@ -74,6 +78,12 @@
                 padding: 0 !important;
                 border-left: none;
             }
         }
     }
+    .n {
+        color: #2AACB8;
+    }
+    .default_value {
+        color: #8888C6;
+    }
 }
@@ -112,4 +122,18 @@
     background-image: url("../../_static/octocat.svg");
     background-size: contain;
     background-position: center;
     transform: translateY(25%);
+}
+
+blockquote {
+    font-family: "Source Code Pro", monospace;
+    border-radius: 1rem;
+    padding: 1.1em 1.5em;
+    background: var(--gray-dark-blue);
+    color: var(--blue-white);
+    &::before {
+        content: none;
+    }
+    p {
+        margin: 0 !important;
+    }
 }
\ No newline at end of file

From 4fbdcd272c3e19f9fdf37c7519efa67aeb2f34b6 Mon Sep 17 00:00:00 2001
From: konstantin
Date: Wed, 14 Aug 2024 15:12:43 +0200
Subject: [PATCH 130/333] changes

---
 docs/_static/css/api.css | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/docs/_static/css/api.css b/docs/_static/css/api.css
index 1dd843d54c..5fd8ef8c89 100644
--- a/docs/_static/css/api.css
+++ b/docs/_static/css/api.css
@@ -45,9 +45,12 @@
     padding-bottom: .5em;
     padding-top: .5em;
     &:target {
-        background: var(--flair-orange-transparent);
+        background: rgba(47, 46, 65, 0.15);
         border-radius: .5rem;
     }
+    :root .dark-mode &:target {
+        background: rgba(255, 255, 255, 0.15);
+    }
 }
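Patch 130 replaces the shared translucent orange with per-theme overlays: a dark rgba on the light theme and a light rgba under .dark-mode. The nested :root .dark-mode &:target selector relies on CSS nesting, where & stands in for the outer .sig selector; a sketch of what the nesting resolves to, with the .dark-mode class name taken from this stylesheet's own conventions:

.sig {
    &:target {
        background: rgba(47, 46, 65, 0.15);    /* dark overlay for the light theme */
    }
    :root .dark-mode &:target {
        background: rgba(255, 255, 255, 0.15); /* light overlay for the dark theme */
    }
}
/* The nested rules compile to roughly:
   .sig:target { ... }
   :root .dark-mode .sig:target { ... } */
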
From 0422de10e73bf797e93417d4c32f710196755087 Mon Sep 17 00:00:00 2001
From: konstantin
Date: Sat, 17 Aug 2024 11:16:32 +0200
Subject: [PATCH 131/333] changes

---
 docs/legal-notice/index.rst | 12 ++++++++++++
 1 file changed, 12 insertions(+)
 create mode 100644 docs/legal-notice/index.rst

diff --git a/docs/legal-notice/index.rst b/docs/legal-notice/index.rst
new file mode 100644
index 0000000000..08b272d915
--- /dev/null
+++ b/docs/legal-notice/index.rst
@@ -0,0 +1,12 @@
+.. _flair_legal-notice:
+
+.. title:: Legal notice
+
+.. raw:: html
+   :file: _templates/landing_page_styles.html
+
+.. raw:: html
+   :file: _templates/landing-page-banner.html
+
+.. raw:: html
+   :file: _templates/landing-page-illustrations.html
\ No newline at end of file

From 79d664fafc9d01b96d8e3b83f06bbe8ae6f393cd Mon Sep 17 00:00:00 2001
From: konstantin
Date: Sat, 17 Aug 2024 11:21:24 +0200
Subject: [PATCH 132/333] changes

---
 docs/_templates/footer-links.html              | 5 -----
 docs/_templates/footer-links/legal-notice.html | 2 +-
 2 files changed, 1 insertion(+), 6 deletions(-)
 delete mode 100644 docs/_templates/footer-links.html

diff --git a/docs/_templates/footer-links.html b/docs/_templates/footer-links.html
deleted file mode 100644
index 074731dfb5..0000000000
--- a/docs/_templates/footer-links.html
+++ /dev/null
@@ -1,5 +0,0 @@
-
-
-
-
-
\ No newline at end of file
diff --git a/docs/_templates/footer-links/legal-notice.html b/docs/_templates/footer-links/legal-notice.html
index 44658a6eba..ff1eed5ca9 100644
--- a/docs/_templates/footer-links/legal-notice.html
+++ b/docs/_templates/footer-links/legal-notice.html
@@ -1 +1 @@
-Legal notice
\ No newline at end of file
+Legal notice
\ No newline at end of file

From 3b3379da6f564aea40174572af9029e72390752c Mon Sep 17 00:00:00 2001
From: konstantin
Date: Sat, 17 Aug 2024 11:24:51 +0200
Subject: [PATCH 133/333] changes

---
 docs/_templates/footer-links/legal-notice.html | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/_templates/footer-links/legal-notice.html b/docs/_templates/footer-links/legal-notice.html
index ff1eed5ca9..30553d5aed 100644
--- a/docs/_templates/footer-links/legal-notice.html
+++ b/docs/_templates/footer-links/legal-notice.html
@@ -1 +1 @@
-Legal notice
\ No newline at end of file
+Legal notice
\ No newline at end of file

From 1cadc5b898ea4ba66fd688dc7b6754598b375a6b Mon Sep 17 00:00:00 2001
From: konstantin
Date: Sat, 17 Aug 2024 11:30:07 +0200
Subject: [PATCH 134/333] changes

---
 docs/_templates/footer-links/legal-notice.html | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/_templates/footer-links/legal-notice.html b/docs/_templates/footer-links/legal-notice.html
index 30553d5aed..ab3b0dd4ac 100644
--- a/docs/_templates/footer-links/legal-notice.html
+++ b/docs/_templates/footer-links/legal-notice.html
@@ -1 +1 @@
-Legal notice
\ No newline at end of file
+Legal notice
\ No newline at end of file

From c95f8598d8b5cb479f38bf74d6135d59e5718f28 Mon Sep 17 00:00:00 2001
From: konstantin
Date: Sat, 17 Aug 2024 11:32:07 +0200
Subject: [PATCH 135/333] changes

---
 docs/_templates/legal-notice-content.html | 3 +++
 docs/legal-notice/index.rst               | 8 +-------
 2 files changed, 4 insertions(+), 7 deletions(-)
 create mode 100644 docs/_templates/legal-notice-content.html

diff --git a/docs/_templates/legal-notice-content.html b/docs/_templates/legal-notice-content.html
new file mode 100644
index 0000000000..7b25d9a135
--- /dev/null
+++ b/docs/_templates/legal-notice-content.html
@@ -0,0 +1,3 @@
+
+ HELLO; WORLD;
+
\ No newline at end of file
diff --git a/docs/legal-notice/index.rst b/docs/legal-notice/index.rst
index 08b272d915..812268cebb 100644
--- a/docs/legal-notice/index.rst
+++ b/docs/legal-notice/index.rst
@@ -3,10 +3,4 @@
 .. title:: Legal notice
 
 .. raw:: html
-   :file: _templates/landing_page_styles.html
-
-.. raw:: html
-   :file: _templates/landing-page-banner.html
-
-.. raw:: html
-   :file: _templates/landing-page-illustrations.html
\ No newline at end of file
+   :file: _templates/legal-notice-content-html
\ No newline at end of file

From 5e63d976b2b0f47647760875d74aa3ef6899b655 Mon Sep 17 00:00:00 2001
From: konstantin
Date: Sat, 17 Aug 2024 11:33:48 +0200
Subject: [PATCH 136/333] changes

---
 docs/legal-notice/index.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/legal-notice/index.rst b/docs/legal-notice/index.rst
index 812268cebb..7eafaeeb3f 100644
--- a/docs/legal-notice/index.rst
+++ b/docs/legal-notice/index.rst
@@ -3,4 +3,4 @@
 .. title:: Legal notice
 
 .. raw:: html
-   :file: _templates/legal-notice-content-html
\ No newline at end of file
+   :file: _templates/legal-notice-content.html
\ No newline at end of file

From fda3778b10b24a144715575ff051bc1bfb42a8f6 Mon Sep 17 00:00:00 2001
From: konstantin
Date: Sat, 17 Aug 2024 11:35:32 +0200
Subject: [PATCH 137/333] changes

---
 docs/legal-notice/index.rst | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/docs/legal-notice/index.rst b/docs/legal-notice/index.rst
index 7eafaeeb3f..70779f7ce0 100644
--- a/docs/legal-notice/index.rst
+++ b/docs/legal-notice/index.rst
@@ -3,4 +3,12 @@
 .. title:: Legal notice
 
 .. raw:: html
-   :file: _templates/legal-notice-content.html
\ No newline at end of file
+   :file: _templates/legal-notice-content.html
+
+.. toctree::
+   :maxdepth: 3
+   :hidden:
+
+   Tutorials
+   API
+   Contributing
\ No newline at end of file

From 197e58fa5412dc0ebed39aa4336e4971beeaea66 Mon Sep 17 00:00:00 2001
From: konstantin
Date: Sat, 17 Aug 2024 11:39:12 +0200
Subject: [PATCH 138/333] changes

---
 docs/legal-notice/index.rst | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/docs/legal-notice/index.rst b/docs/legal-notice/index.rst
index 70779f7ce0..54b24b860c 100644
--- a/docs/legal-notice/index.rst
+++ b/docs/legal-notice/index.rst
@@ -1,6 +1,7 @@
-.. _flair_legal-notice:
+Legal Notice
+============
 
-.. title:: Legal notice
+.. title:: Legal Notice
 
 .. raw:: html
    :file: _templates/legal-notice-content.html

From 2a61780dcc8e0cda682683cb0cc5bf3427d536f7 Mon Sep 17 00:00:00 2001
From: konstantin
Date: Sat, 17 Aug 2024 11:42:50 +0200
Subject: [PATCH 139/333] changes

---
 docs/legal-notice/index.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/legal-notice/index.rst b/docs/legal-notice/index.rst
index 54b24b860c..5da0252fd4 100644
--- a/docs/legal-notice/index.rst
+++ b/docs/legal-notice/index.rst
@@ -4,7 +4,7 @@ Legal Notice
 .. title:: Legal Notice
 
 .. raw:: html
-   :file: _templates/legal-notice-content.html
+   :file: ../_templates/legal-notice-content.html
 
 .. toctree::
    :maxdepth: 3

From da9a507659c19841bd8b1096962dd3f98736a979 Mon Sep 17 00:00:00 2001
From: konstantin
Date: Sat, 17 Aug 2024 11:49:29 +0200
Subject: [PATCH 140/333] changes

---
 docs/_static/css/legal-notice.css         |  0
 docs/_templates/legal-notice-content.html | 35 ++++++++++++++++++++++-
 docs/conf.py                              |  3 ++-
 3 files changed, 36 insertions(+), 2 deletions(-)
 create mode 100644 docs/_static/css/legal-notice.css

diff --git a/docs/_static/css/legal-notice.css b/docs/_static/css/legal-notice.css
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/docs/_templates/legal-notice-content.html b/docs/_templates/legal-notice-content.html
index 7b25d9a135..7e4951ef01 100644
--- a/docs/_templates/legal-notice-content.html
+++ b/docs/_templates/legal-notice-content.html
@@ -1,3 +1,36 @@
 
- HELLO; WORLD;
+
+
+    Contact
+
+
+    In case of questions, feel free to open a
+    Github Issue
+    or write me an email:
+    alan.akbik@hu-berlin.de.
+
+
+
+    Flair NLP is maintained by:
+
+    Alan Akbik
+    Humboldt-Universität zu Berlin
+    Institut für Informatik / Lehrstuhl Maschinelles Lernen
+    Unter den Linden 6
+    10099 Berlin
+    Germany
+
+
+
+    Website
+
+
+
+    Privacy Policy
+    The webserver / web hosting company might collect certain log files to prevent abuse of services.
+    These log files can include: IP address, URL, date and time.
+    We do not use any tracking services or cookies to track or re-identify visitors.
+
+
+
 
\ No newline at end of file
diff --git a/docs/conf.py b/docs/conf.py
index b1238878bc..f5f0541ab1 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -86,7 +86,8 @@ def linkcode_resolve(*args):
     'css/version-switcher.css',
     'css/sidebar.css',
     'css/tutorial.css',
-    'css/api.css'
+    'css/api.css',
+    'css/legal-notice.css',
 ]
 
 html_logo = "_static/flair_logo_white.svg"

From 033edd96eb8fe4c7278fdff5c48f3e6ae83b7e36 Mon Sep 17 00:00:00 2001
From: konstantin
Date: Sat, 17 Aug 2024 11:56:57 +0200
Subject: [PATCH 141/333] changes

---
 docs/_static/css/legal-notice.css         | 6 ++++++
 docs/_templates/legal-notice-content.html | 7 +++----
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/docs/_static/css/legal-notice.css b/docs/_static/css/legal-notice.css
index e69de29bb2..e9794bb7b4 100644
--- a/docs/_static/css/legal-notice.css
+++ b/docs/_static/css/legal-notice.css
@@ -0,0 +1,6 @@
+.legal-notice {
+    h2 {
+        font-size: 1.75rem;
+        font-weight: 600;
+    }
+}
\ No newline at end of file
diff --git a/docs/_templates/legal-notice-content.html b/docs/_templates/legal-notice-content.html
index 7e4951ef01..15a5a0ac8a 100644
--- a/docs/_templates/legal-notice-content.html
+++ b/docs/_templates/legal-notice-content.html
@@ -1,8 +1,7 @@
-
-
-
+