+
diff --git a/docs/_templates/footer-links/legal-notice.html b/docs/_templates/footer-links/legal-notice.html
new file mode 100644
index 0000000000..ab3b0dd4ac
--- /dev/null
+++ b/docs/_templates/footer-links/legal-notice.html
@@ -0,0 +1 @@
+Legal notice
\ No newline at end of file
diff --git a/docs/_templates/footer-links/linkedin.html b/docs/_templates/footer-links/linkedin.html
new file mode 100644
index 0000000000..f6f83d745d
--- /dev/null
+++ b/docs/_templates/footer-links/linkedin.html
@@ -0,0 +1 @@
+LinkedIn
\ No newline at end of file
diff --git a/docs/_templates/footer-links/x.html b/docs/_templates/footer-links/x.html
new file mode 100644
index 0000000000..c9976cdddf
--- /dev/null
+++ b/docs/_templates/footer-links/x.html
@@ -0,0 +1 @@
+Twitter/X
\ No newline at end of file
diff --git a/docs/_templates/landing-page-banner.html b/docs/_templates/landing-page-banner.html
new file mode 100644
index 0000000000..8130a83d52
--- /dev/null
+++ b/docs/_templates/landing-page-banner.html
@@ -0,0 +1,353 @@
+
+
+
+
+ A very simple framework for state-of-the-art natural language processing.
+ Get started
+
\ No newline at end of file
diff --git a/docs/_templates/landing-page-illustrations.html b/docs/_templates/landing-page-illustrations.html
new file mode 100644
index 0000000000..b6b55230a5
--- /dev/null
+++ b/docs/_templates/landing-page-illustrations.html
@@ -0,0 +1,30 @@
+
+
+
+
+ Easy to Use
+
+ State-of-the-art NLP with just a few lines of code! Find entities, detect sentiment, and more.
+ Check out our demo!
+
+
+
+
+
+ Huge Community
+
+ With a community of ~200 contributors, Flair is used in hundreds of companies,
+ over 2,000 open source projects, and
+ 2,000+ papers!
+
+
+
+
+
+ Open Source and Free
+
+ Flair is completely free and open source, making it accessible for everyone to use
+ and report issues.
+
+
+
\ No newline at end of file
diff --git a/docs/_templates/landing_page_styles.html b/docs/_templates/landing_page_styles.html
new file mode 100644
index 0000000000..d98569463a
--- /dev/null
+++ b/docs/_templates/landing_page_styles.html
@@ -0,0 +1,339 @@
+
diff --git a/docs/_templates/legal-notice-content.html b/docs/_templates/legal-notice-content.html
new file mode 100644
index 0000000000..15a5a0ac8a
--- /dev/null
+++ b/docs/_templates/legal-notice-content.html
@@ -0,0 +1,35 @@
+
+
+ Flair NLP is maintained by:
+
+
+ Alan Akbik
+ Humboldt-Universität zu Berlin
+ Institut für Informatik / Lehrstuhl Maschinelles Lernen
+ Unter den Linden 6
+ 10099 Berlin
+ Germany
+
+ Privacy Policy
+ The webserver / web hosting company might collect certain log files to prevent abuse of services.
+ These log files can include: IP address, URL, date and time.
+ We do not use any tracking services or cookies to track or re-identify visitors.
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/_templates/version-switcher.html b/docs/_templates/version-switcher.html
index 1d21c6c65f..ad1676b4d8 100644
--- a/docs/_templates/version-switcher.html
+++ b/docs/_templates/version-switcher.html
@@ -1,30 +1,45 @@
{# As the version switcher will only work when JavaScript is enabled, we add it through JavaScript.
#}
+
diff --git a/docs/api/datasets/base.rst b/docs/api/datasets/base.rst
index e42784deb0..80c375eec8 100644
--- a/docs/api/datasets/base.rst
+++ b/docs/api/datasets/base.rst
@@ -1,4 +1,8 @@
flair.datasets.base
===================
-.. automodule:: flair.datasets.base
\ No newline at end of file
+.. currentmodule:: flair.datasets.base
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/datasets/biomedical.rst b/docs/api/datasets/biomedical.rst
index d59bd8c589..c1d2525ece 100644
--- a/docs/api/datasets/biomedical.rst
+++ b/docs/api/datasets/biomedical.rst
@@ -1,4 +1,8 @@
flair.datasets.biomedical
=========================
-.. automodule:: flair.datasets.biomedical
\ No newline at end of file
+.. currentmodule:: flair.datasets.biomedical
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/datasets/document_classification.rst b/docs/api/datasets/document_classification.rst
index d8303f3aeb..0de14b6cdd 100644
--- a/docs/api/datasets/document_classification.rst
+++ b/docs/api/datasets/document_classification.rst
@@ -1,4 +1,8 @@
flair.datasets.document_classification
======================================
-.. automodule:: flair.datasets.document_classification
\ No newline at end of file
+.. currentmodule:: flair.datasets.document_classification
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/datasets/entity_linking.rst b/docs/api/datasets/entity_linking.rst
index cdb2b32356..a88308e97d 100644
--- a/docs/api/datasets/entity_linking.rst
+++ b/docs/api/datasets/entity_linking.rst
@@ -1,4 +1,8 @@
flair.datasets.entity_linking
=============================
-.. automodule:: flair.datasets.entity_linking
\ No newline at end of file
+.. currentmodule:: flair.datasets.entity_linking
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/datasets/ocr.rst b/docs/api/datasets/ocr.rst
index 3f85340440..f5629a5631 100644
--- a/docs/api/datasets/ocr.rst
+++ b/docs/api/datasets/ocr.rst
@@ -1,4 +1,8 @@
flair.datasets.ocr
==================
-.. automodule:: flair.datasets.ocr
\ No newline at end of file
+.. currentmodule:: flair.datasets.ocr
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/datasets/relation_extraction.rst b/docs/api/datasets/relation_extraction.rst
index 62dcdd55d1..fdbb690cfe 100644
--- a/docs/api/datasets/relation_extraction.rst
+++ b/docs/api/datasets/relation_extraction.rst
@@ -1,4 +1,8 @@
flair.datasets.relation_extraction
==================================
-.. automodule:: flair.datasets.relation_extraction
\ No newline at end of file
+.. currentmodule:: flair.datasets.relation_extraction
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/datasets/sequence_labeling.rst b/docs/api/datasets/sequence_labeling.rst
index 875d4831b1..0c0abc520d 100644
--- a/docs/api/datasets/sequence_labeling.rst
+++ b/docs/api/datasets/sequence_labeling.rst
@@ -1,4 +1,8 @@
flair.datasets.sequence_labeling
================================
-.. automodule:: flair.datasets.sequence_labeling
\ No newline at end of file
+.. currentmodule:: flair.datasets.sequence_labeling
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/datasets/text_image.rst b/docs/api/datasets/text_image.rst
index f14e564916..173928dfc1 100644
--- a/docs/api/datasets/text_image.rst
+++ b/docs/api/datasets/text_image.rst
@@ -1,4 +1,8 @@
flair.datasets.text_image
=========================
-.. automodule:: flair.datasets.text_image
\ No newline at end of file
+.. currentmodule:: flair.datasets.text_image
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/datasets/text_text.rst b/docs/api/datasets/text_text.rst
index f88dfd1aed..79a60ae609 100644
--- a/docs/api/datasets/text_text.rst
+++ b/docs/api/datasets/text_text.rst
@@ -1,4 +1,10 @@
flair.datasets.text_text
-=========================
+========================
+
+.. currentmodule:: flair.datasets.text_text
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
+
-.. automodule:: flair.datasets.text_text
\ No newline at end of file
diff --git a/docs/api/datasets/treebanks.rst b/docs/api/datasets/treebanks.rst
index 0d6c14a281..82cba954f7 100644
--- a/docs/api/datasets/treebanks.rst
+++ b/docs/api/datasets/treebanks.rst
@@ -1,4 +1,8 @@
flair.datasets.treebanks
========================
-.. automodule:: flair.datasets.treebanks
\ No newline at end of file
+.. currentmodule:: flair.datasets.treebanks
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/embeddings/base.rst b/docs/api/embeddings/base.rst
index 1bf51ffa7a..02a65a20b9 100644
--- a/docs/api/embeddings/base.rst
+++ b/docs/api/embeddings/base.rst
@@ -1,4 +1,8 @@
flair.embeddings.base
=====================
-.. automodule:: flair.embeddings.base
\ No newline at end of file
+.. currentmodule:: flair.embeddings.base
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/embeddings/document.rst b/docs/api/embeddings/document.rst
index ca870fc8ea..8c5aea548d 100644
--- a/docs/api/embeddings/document.rst
+++ b/docs/api/embeddings/document.rst
@@ -1,4 +1,8 @@
flair.embeddings.document
=========================
-.. automodule:: flair.embeddings.document
\ No newline at end of file
+.. currentmodule:: flair.embeddings.document
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/embeddings/image.rst b/docs/api/embeddings/image.rst
index 2a701b9e0b..6a115e705c 100644
--- a/docs/api/embeddings/image.rst
+++ b/docs/api/embeddings/image.rst
@@ -1,4 +1,8 @@
flair.embeddings.image
======================
-.. automodule:: flair.embeddings.image
\ No newline at end of file
+.. currentmodule:: flair.embeddings.image
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/embeddings/legacy.rst b/docs/api/embeddings/legacy.rst
index 974a777eb9..a177a1ffc2 100644
--- a/docs/api/embeddings/legacy.rst
+++ b/docs/api/embeddings/legacy.rst
@@ -1,8 +1,8 @@
flair.embeddings.legacy
-============================
+=======================
-.. warning::
- All embeddings in `flair.embeddings.legacy` are considered deprecated.
- there is no guarantee that they are still working and we recommend using different embeddings instead.
+.. currentmodule:: flair.embeddings.legacy
-.. automodule:: flair.embeddings.legacy
\ No newline at end of file
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/embeddings/token.rst b/docs/api/embeddings/token.rst
index 3705fedb1d..7fc6305bdb 100644
--- a/docs/api/embeddings/token.rst
+++ b/docs/api/embeddings/token.rst
@@ -1,4 +1,8 @@
flair.embeddings.token
======================
-.. automodule:: flair.embeddings.token
\ No newline at end of file
+.. currentmodule:: flair.embeddings.token
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/embeddings/transformer.rst b/docs/api/embeddings/transformer.rst
index 2bda02f771..51e04a43c9 100644
--- a/docs/api/embeddings/transformer.rst
+++ b/docs/api/embeddings/transformer.rst
@@ -1,4 +1,8 @@
flair.embeddings.transformer
============================
-.. automodule:: flair.embeddings.transformer
\ No newline at end of file
+.. currentmodule:: flair.embeddings.transformer
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/flair.data.rst b/docs/api/flair.data.rst
index 00dd67a521..dae99e093b 100644
--- a/docs/api/flair.data.rst
+++ b/docs/api/flair.data.rst
@@ -1,4 +1,8 @@
flair.data
==========
-.. automodule:: flair.data
\ No newline at end of file
+.. currentmodule:: flair.data
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/flair.datasets.rst b/docs/api/flair.datasets.rst
index 9a883c3e61..d822186e67 100644
--- a/docs/api/flair.datasets.rst
+++ b/docs/api/flair.datasets.rst
@@ -1,6 +1,8 @@
flair.datasets
==============
+.. currentmodule:: flair.datasets
+
.. toctree::
:glob:
:maxdepth: 2
diff --git a/docs/api/flair.embeddings.rst b/docs/api/flair.embeddings.rst
index 3f70e62bef..82905def03 100644
--- a/docs/api/flair.embeddings.rst
+++ b/docs/api/flair.embeddings.rst
@@ -1,6 +1,8 @@
flair.embeddings
================
+.. currentmodule:: flair.embeddings
+
.. toctree::
:glob:
:maxdepth: 2
diff --git a/docs/api/flair.models.rst b/docs/api/flair.models.rst
index 8679b3fb7d..0451fb30bd 100644
--- a/docs/api/flair.models.rst
+++ b/docs/api/flair.models.rst
@@ -1,4 +1,8 @@
flair.models
============
-.. automodule:: flair.models
\ No newline at end of file
+.. currentmodule:: flair.models
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/flair.nn.rst b/docs/api/flair.nn.rst
index 4eb066d3ea..6a7247256a 100644
--- a/docs/api/flair.nn.rst
+++ b/docs/api/flair.nn.rst
@@ -1,4 +1,8 @@
flair.nn
========
-.. automodule:: flair.nn
\ No newline at end of file
+.. currentmodule:: flair.nn
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/flair.rst b/docs/api/flair.rst
index 4e12a03829..946c68a6a4 100644
--- a/docs/api/flair.rst
+++ b/docs/api/flair.rst
@@ -1,4 +1,8 @@
flair
=====
-.. automodule:: flair
\ No newline at end of file
+.. currentmodule:: flair
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/flair.splitter.rst b/docs/api/flair.splitter.rst
index 5863df5788..f9b71de316 100644
--- a/docs/api/flair.splitter.rst
+++ b/docs/api/flair.splitter.rst
@@ -1,4 +1,9 @@
flair.splitter
==============
-.. automodule:: flair.splitter
\ No newline at end of file
+.. currentmodule:: flair.splitter
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
+
diff --git a/docs/api/flair.tokenization.rst b/docs/api/flair.tokenization.rst
index 00f2bc4bfd..ec5ca557d1 100644
--- a/docs/api/flair.tokenization.rst
+++ b/docs/api/flair.tokenization.rst
@@ -1,4 +1,8 @@
flair.tokenization
==================
-.. automodule:: flair.tokenization
\ No newline at end of file
+.. currentmodule:: flair.tokenization
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/flair.trainers.plugins.rst b/docs/api/flair.trainers.plugins.rst
index 4bb766876b..acd8109cf4 100644
--- a/docs/api/flair.trainers.plugins.rst
+++ b/docs/api/flair.trainers.plugins.rst
@@ -1,4 +1,8 @@
flair.trainers.plugins
======================
-.. automodule:: flair.trainers.plugins
\ No newline at end of file
+.. currentmodule:: flair.trainers.plugins
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/api/flair.trainers.rst b/docs/api/flair.trainers.rst
index db11b5029c..ac59a85039 100644
--- a/docs/api/flair.trainers.rst
+++ b/docs/api/flair.trainers.rst
@@ -1,4 +1,8 @@
flair.trainers
==============
-.. automodule:: flair.trainers
\ No newline at end of file
+.. currentmodule:: flair.trainers
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
diff --git a/docs/conf.py b/docs/conf.py
index 64624043e0..3f5d95cb04 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,9 +1,11 @@
# noqa: INP001
+import inspect
import importlib_metadata
# -- Project information -----------------------------------------------------
from sphinx_github_style import get_linkcode_resolve
+from torch.nn import Module
version = "0.14.0"
release = "0.14.0"
@@ -27,21 +29,18 @@
} # dummy value that sphinx-github-style won't crash when run in temp folder.
html_theme_options = {
- "navbar_end": ["theme-switcher", "version-switcher", "navbar-icon-links"],
- "github_url": linkcode_url,
- "icon_links": [
- {
- "name": "PyPI",
- "url": "https://pypi.org/project/flair",
- "icon": "fas fa-box",
- },
- ],
+ "navbar_end": ["darkmode-toggle", "version-switcher", "navbar-icon-links"],
+ "show_prev_next": False,
+ "footer_end": ["footer-links/legal-notice.html", "footer-links/x.html", "footer-links/linkedin.html"],
+ "secondary_sidebar_items": [],
}
def linkcode_resolve(*args):
+ app = inspect.currentframe().f_back.f_locals.get("app")
+ current_version = app.config.smv_current_version
# use smv_current_version as the git url
- real_linkcode_url = linkcode_url + f"/blob/{smv_current_version}/" + "{filepath}#L{linestart}-L{linestop}"
+ real_linkcode_url = linkcode_url + f"/blob/{current_version}/" + "{filepath}#L{linestart}-L{linestop}"
return get_linkcode_resolve(real_linkcode_url)(*args)
@@ -56,13 +55,15 @@ def linkcode_resolve(*args):
"sphinx.ext.ifconfig",
"sphinx.ext.napoleon", # to render Google format docstrings
"sphinx.ext.githubpages",
+ "sphinx_autosummary_autocollect",
"myst_parser",
"sphinx_github_style",
"sphinx_autodoc_typehints",
"sphinx_multiversion",
"sphinx_design",
]
-
+autosummary_generate = True
+autosummary_ignore_module_all = False
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
@@ -82,10 +83,26 @@ def linkcode_resolve(*args):
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
+html_title = "Flair Documentation"
+
+html_css_files = [
+ "css/main.css",
+ "css/header.css",
+ "css/footer.css",
+ "css/version-switcher.css",
+ "css/sidebar.css",
+ "css/tutorial.css",
+ "css/api.css",
+ "css/legal-notice.css",
+ "css/search.css",
+]
+
+html_logo = "_static/flair_logo_white.svg"
+html_show_sphinx = False
# Napoleon settings
-napoleon_include_init_with_doc = True
-napoleon_include_private_with_doc = True
+napoleon_include_init_with_doc = False
+napoleon_include_private_with_doc = False
autodoc_default_options = {
"member-order": "bysource",
@@ -102,11 +119,7 @@ def linkcode_resolve(*args):
}
html_sidebars = {
- "**": [
- "globaltoc.html",
- "searchbox.html",
- "versioning.html",
- ],
+ "**": ["globaltoc.html"],
"index": [],
}
diff --git a/docs/contributing/local_development.md b/docs/contributing/local_development.md
index 87439439f2..9c7413703e 100644
--- a/docs/contributing/local_development.md
+++ b/docs/contributing/local_development.md
@@ -6,8 +6,8 @@ the code should hopefully be easy.
## Setup
-Flair requires python-3.8 or higher. To make sure our code also runs on the oldest supported
-python version, it is recommended to use python-3.8.x for flair development.
+Flair requires python-3.9 or higher. To make sure our code also runs on the oldest supported
+python version, it is recommended to use python-3.9.x for flair development.
Create a python environment of your preference and run:
```bash
diff --git a/docs/glossary/index.rst b/docs/glossary/index.rst
deleted file mode 100644
index c732a1a121..0000000000
--- a/docs/glossary/index.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-Glossary
-========
-
-.. glossary::
-
- Sentence
- a sentence is a text-unit consisting of tokens, labels and possibly metadata. Notice that a sentence is not limited in size, hence a Sentence itself could hold either a full document, a paragraph, a simple phrase or a linguistic
\ No newline at end of file
diff --git a/docs/index.rst b/docs/index.rst
index 3cff769118..c39010a2c7 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -1,95 +1,20 @@
-flair
-=====
-
.. _flair_docs_mainpage:
+.. title:: Home
-**Version**: |version|
-
-**Useful links**:
-`Getting started `_ |
-`Source Repository `_ |
-`Issue Tracker `_ |
-
-Flair is a very simple framework for state-of-the-art Natural Language Processing (NLP)
-
-.. grid:: 2
-
- .. grid-item-card::
- :img-top: ./_static/tutorial.svg
-
- Tutorial
- ^^^^^^^^
-
- New to Flair? Check out the Tutorials. It contains an introduction to Flair's main concepts.
-
- +++
-
- .. button-ref:: tutorial/index
- :expand:
- :color: secondary
- :click-parent:
-
- To the tutorials
-
- .. grid-item-card::
- :img-top: ./_static/api.svg
-
- API-docs
- ^^^^^^^^
-
- The API-docs provides in-depth information on the classes and functions designed for public use.
-
- +++
-
- .. button-ref:: api/index
- :expand:
- :color: secondary
- :click-parent:
-
- To the API docs
-
- .. grid-item-card::
- :img-top: ./_static/contributing.svg
-
- Contributor's Guide
- ^^^^^^^^^^^^^^^^^^^
-
- Want to add to the codebase? Can help add to the
- documentation? The contributing guidelines will guide you through the
- process of improving Flair.
-
- +++
-
- .. button-ref:: contributing/index
- :expand:
- :color: secondary
- :click-parent:
-
- To the contributor's guide
-
- .. grid-item-card::
- :img-top: ./_static/glossary.svg
-
- Glossary
- ^^^^^^^^
-
- Not sure what the exact meaning of certain terms is? Find their definition in the Glossary.
-
- +++
+.. raw:: html
+ :file: _templates/landing_page_styles.html
- .. button-ref:: glossary/index
- :expand:
- :color: secondary
- :click-parent:
+.. raw:: html
+ :file: _templates/landing-page-banner.html
- To the glossary
+.. raw:: html
+ :file: _templates/landing-page-illustrations.html
.. toctree::
:maxdepth: 3
:hidden:
Tutorials
- API reference
- Contributing
- Glossary
\ No newline at end of file
+ API
+ Contributing
\ No newline at end of file
diff --git a/docs/legal-notice/index.rst b/docs/legal-notice/index.rst
new file mode 100644
index 0000000000..585047fc11
--- /dev/null
+++ b/docs/legal-notice/index.rst
@@ -0,0 +1,15 @@
+Legal Notice
+============
+
+.. title:: Legal Notice
+
+.. raw:: html
+ :file: ../_templates/legal-notice-content.html
+
+.. toctree::
+ :maxdepth: 3
+ :hidden:
+
+ Tutorials <../tutorial/index>
+ API <../api/index>
+ Contributing <../contributing/index>
\ No newline at end of file
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 0e8c4f6141..ff23c9f129 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,11 +1,12 @@
sphinx-github-style<=1.0.2 # 1.0.3 changes logic that breaks with sphinx-multiversion
sphinx-autodoc-typehints
myst-parser
-sphinx
+sphinx<8.0.0
importlib-metadata
sphinx-multiversion
pydata-sphinx-theme<0.14
sphinx_design
+sphinx-autosummary-autocollect
# previous dependencies that are required to build docs for later versions too.
semver
diff --git a/docs/tutorial/intro.md b/docs/tutorial/intro.md
index e652583f76..b8af9b5667 100644
--- a/docs/tutorial/intro.md
+++ b/docs/tutorial/intro.md
@@ -16,7 +16,7 @@ In your favorite virtual environment, simply do:
pip install flair
```
-Flair requires Python 3.8+.
+Flair requires Python 3.9+.
## Example 1: Tag Entities in Text
diff --git a/docs/tutorial/tutorial-basics/basic-types.md b/docs/tutorial/tutorial-basics/basic-types.md
index 703a5d7cd5..5ddf247166 100644
--- a/docs/tutorial/tutorial-basics/basic-types.md
+++ b/docs/tutorial/tutorial-basics/basic-types.md
@@ -242,7 +242,7 @@ for label in sentence.get_labels('ner'):
### Information for each label
-Each label is of class `Label` which next to the value has a score indicating confidence. It also has a pointer back to the data point to which it attaches.
+Each label is of class [`Label`](#flair.data.Label), which holds the label value along with a confidence score. It also has a pointer back to the data point to which it attaches.
This means that you can print the value, the confidence and the labeled text of each label:
@@ -267,3 +267,9 @@ This should print:
Our color tag has a score of 1.0 since we manually added it. If a tag is predicted by our sequence labeler, the score value will indicate classifier confidence.
+
+### Next
+
+Congrats, you now understand Flair's basic types.
+
+Next, learn how to use [Flair models to make predictions](how-predictions-work.md).
\ No newline at end of file
diff --git a/docs/tutorial/tutorial-basics/entity-linking.md b/docs/tutorial/tutorial-basics/entity-linking.md
index 8137c2dc5f..808d5c91ad 100644
--- a/docs/tutorial/tutorial-basics/entity-linking.md
+++ b/docs/tutorial/tutorial-basics/entity-linking.md
@@ -83,3 +83,15 @@ As we can see, the linker can resolve that:
- the first mention of "Barcelona" refers to the soccer club "[FC Barcelona](https://en.wikipedia.org/wiki/FC_Barcelona)"
- the second mention of "Barcelona" refers to the city of "[Barcelona](https://en.wikipedia.org/wiki/Barcelona)"
+
+### Linking biomedical entities
+
+If you are working with biomedical data, we have a special entity linker capable of linking
+biomedical entities to specific knowledge bases. In this case, check out this [advanced tutorial on
+linking biomedical entities](entity-mention-linking.md).
+
+### Next
+
+Congrats, you learned how to link entities with Flair!
+
+Next, let's discuss how to [predict part-of-speech tags with Flair](part-of-speech-tagging.md).
\ No newline at end of file
diff --git a/docs/tutorial/tutorial-basics/entity-mention-linking.md b/docs/tutorial/tutorial-basics/entity-mention-linking.md
index e9d442c9e1..56803f1b02 100644
--- a/docs/tutorial/tutorial-basics/entity-mention-linking.md
+++ b/docs/tutorial/tutorial-basics/entity-mention-linking.md
@@ -1,6 +1,6 @@
# Using and creating entity mention linker
-As of Flair 0.14 we ship the [entity mention linker](#flair.models.EntityMentionLinker) - the core framework behind the [Hunflair BioNEN approach](https://huggingface.co/hunflair)].
+As of Flair 0.14 we ship the [entity mention linker](#flair.models.EntityMentionLinker) - the core framework behind the [Hunflair BioNEN approach](https://huggingface.co/hunflair).
You can read more at the [Hunflair2 tutorials](project:../tutorial-hunflair2/overview.md)
## Example 1: Printing Entity linking outputs to console
@@ -124,5 +124,11 @@ print(result_mentions)
```{note}
If you need more than the extracted ids, you can use `nen_tagger.dictionary[span_data["nen_id"]]`
- to look up the [`flair.data.EntityCandidate`](#flair.data.EntityCandidate) which contains further information.
-```
\ No newline at end of file
+ to look up the [`EntityCandidate`](#flair.data.EntityCandidate) which contains further information.
+```
+
+### Next
+
+Congrats, you learned how to link biomedical entities with Flair!
+
+Next, let's discuss how to [predict part-of-speech tags with Flair](part-of-speech-tagging.md).
\ No newline at end of file
diff --git a/docs/tutorial/tutorial-basics/how-predictions-work.md b/docs/tutorial/tutorial-basics/how-predictions-work.md
index 9911f6efa5..a1dffa8913 100644
--- a/docs/tutorial/tutorial-basics/how-predictions-work.md
+++ b/docs/tutorial/tutorial-basics/how-predictions-work.md
@@ -76,3 +76,8 @@ the text of label.data_point is: "Washington"
```
+### Next
+
+Congrats, you've made your first predictions with Flair and accessed the value and confidence score of each prediction.
+
+Next, let's discuss specifically how to [predict named entities with Flair](tagging-entities.md).
diff --git a/docs/tutorial/tutorial-basics/how-to-tag-corpus.md b/docs/tutorial/tutorial-basics/how-to-tag-corpus.md
index 8aa75a4027..1dee865e34 100644
--- a/docs/tutorial/tutorial-basics/how-to-tag-corpus.md
+++ b/docs/tutorial/tutorial-basics/how-to-tag-corpus.md
@@ -30,3 +30,10 @@ for sentence in sentences:
Using the `mini_batch_size` parameter of the [`Classifier.predict()`](#flair.nn.Classifier.predict) method, you can set the size of mini batches passed to the
tagger. Depending on your resources, you might want to play around with this parameter to optimize speed.
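+
+For example, a minimal sketch (assuming `tagger` and `sentences` are defined as in the snippet above; 64 is an illustrative value):
+
+```python
+# larger mini-batches are usually faster on a GPU, at the cost of more memory
+tagger.predict(sentences, mini_batch_size=64)
+```
+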
+### Next
+
+That's it - you completed tutorial 1! Congrats!
+
+You've learned how basic classes work and how to use Flair to make various predictions.
+
+Next, you can check out our tutorial on how to [train your own model](../tutorial-training/how-model-training-works.md).
diff --git a/docs/tutorial/tutorial-basics/other-models.md b/docs/tutorial/tutorial-basics/other-models.md
index dbab4f40d1..9bd02bda7e 100644
--- a/docs/tutorial/tutorial-basics/other-models.md
+++ b/docs/tutorial/tutorial-basics/other-models.md
@@ -150,3 +150,10 @@ We end this section with a list of all other models we currently ship with Flair
| 'de-historic-reported' | historical reported speech | German | @redewiedergabe project | **87.94** (F1) | [redewiedergabe](https://github.com/redewiedergabe/tagger) | |
| 'de-historic-free-indirect' | historical free-indirect speech | German | @redewiedergabe project | **87.94** (F1) | [redewiedergabe](https://github.com/redewiedergabe/tagger) | |
+
+### Next
+
+Congrats, you learned about some other models we have in Flair!
+
+So far, we only focused on predicting for single sentences. Next, let's discuss how
+to create [predictions for a whole corpus of documents](how-to-tag-corpus.md).
\ No newline at end of file
diff --git a/docs/tutorial/tutorial-basics/part-of-speech-tagging.md b/docs/tutorial/tutorial-basics/part-of-speech-tagging.md
index 3da1774bf1..9a9dc54f55 100644
--- a/docs/tutorial/tutorial-basics/part-of-speech-tagging.md
+++ b/docs/tutorial/tutorial-basics/part-of-speech-tagging.md
@@ -167,4 +167,9 @@ You choose which pre-trained model you load by passing the appropriate string to
A full list of our current and community-contributed models can be browsed on the [__model hub__](https://huggingface.co/models?library=flair&sort=downloads).
+### Next
+
+Congrats, you learned how to predict part-of-speech tags with Flair!
+
+Next, we'll present some [other models in Flair](other-models.md) you might find useful.
diff --git a/docs/tutorial/tutorial-basics/tagging-entities.md b/docs/tutorial/tutorial-basics/tagging-entities.md
index a3b41ff80c..39ccdd8e1a 100644
--- a/docs/tutorial/tutorial-basics/tagging-entities.md
+++ b/docs/tutorial/tutorial-basics/tagging-entities.md
@@ -200,3 +200,10 @@ You choose which pre-trained model you load by passing the appropriate string to
A full list of our current and community-contributed models can be browsed on the [__model hub__](https://huggingface.co/models?library=flair&sort=downloads).
+
+### Next
+
+Congrats, you learned how to predict entities with Flair and got an overview of different models!
+
+Next, let's discuss how to [predict sentiment with Flair](tagging-sentiment.md).
+
diff --git a/docs/tutorial/tutorial-basics/tagging-sentiment.md b/docs/tutorial/tutorial-basics/tagging-sentiment.md
index 0c6c9f5789..6bbebfb178 100644
--- a/docs/tutorial/tutorial-basics/tagging-sentiment.md
+++ b/docs/tutorial/tutorial-basics/tagging-sentiment.md
@@ -75,5 +75,9 @@ We end this section with a list of all models we currently ship with Flair:
| 'de-offensive-language' | German | detecting offensive language | [GermEval 2018 Task 1](https://projects.fzai.h-da.de/iggsa/projekt/) | **75.71** (Macro F1) |
+### Next
+Congrats, you learned how to predict sentiment with Flair!
+
+Next, let's discuss how to [link entities to Wikipedia with Flair](entity-linking.md).
diff --git a/docs/tutorial/tutorial-training/how-model-training-works.md b/docs/tutorial/tutorial-training/how-model-training-works.md
index 9241213c98..a4b2392a7c 100644
--- a/docs/tutorial/tutorial-training/how-model-training-works.md
+++ b/docs/tutorial/tutorial-training/how-model-training-works.md
@@ -279,16 +279,10 @@ print(sentence.to_tagged_string())
If the model works well, it will correctly tag 'love' as a verb in this example.
-## Summary
+## Next
-This tutorial gave you a general overview of the main steps to train a model:
+Congrats, you now have a general overview of the main steps to train a model in Flair!
-- load a corpus
-- choose a label type
-- create a label dictionary
-- choose embeddings
-- initialize model
-- initialize trainer
-- train
+Next, learn about the [two main training approaches in Flair](train-vs-fine-tune.md).
diff --git a/docs/tutorial/tutorial-training/how-to-load-custom-dataset.md b/docs/tutorial/tutorial-training/how-to-load-custom-dataset.md
index 1e7fadb0f1..4bd71d4340 100644
--- a/docs/tutorial/tutorial-training/how-to-load-custom-dataset.md
+++ b/docs/tutorial/tutorial-training/how-to-load-custom-dataset.md
@@ -159,3 +159,6 @@ example we chose `label_type='topic'` to denote that we are loading a corpus wit
+## Next
+
+Next, learn [how to train a sequence tagger](how-to-train-sequence-tagger.md).
diff --git a/docs/tutorial/tutorial-training/how-to-load-prepared-dataset.md b/docs/tutorial/tutorial-training/how-to-load-prepared-dataset.md
index ed29bea502..b53aeef917 100644
--- a/docs/tutorial/tutorial-training/how-to-load-prepared-dataset.md
+++ b/docs/tutorial/tutorial-training/how-to-load-prepared-dataset.md
@@ -168,7 +168,7 @@ from flair.data import MultiCorpus
multi_corpus = MultiCorpus([english_corpus, german_corpus, dutch_corpus])
```
-The [`MultiCorpus`](#flair.data.MultiCorpus) inherits from `[`Corpus`](#flair.data.Corpus), so you can use it like any other corpus to train your models.
+The [`MultiCorpus`](#flair.data.MultiCorpus) inherits from [`Corpus`](#flair.data.Corpus), so you can use it like any other corpus to train your models.
## Datasets included in Flair
@@ -193,3 +193,7 @@ The following datasets are supported:
| Universal Dependency Treebanks | [flair.datasets.treebanks](#flair.datasets.treebanks) |
| OCR-Layout-NER | [flair.datasets.ocr](#flair.datasets.ocr) |
+
+## Next
+
+Next, learn how to load a [custom dataset](how-to-load-custom-dataset.md).
\ No newline at end of file
diff --git a/docs/tutorial/tutorial-training/how-to-train-sequence-tagger.md b/docs/tutorial/tutorial-training/how-to-train-sequence-tagger.md
index fc9bc492b1..0f79022b70 100644
--- a/docs/tutorial/tutorial-training/how-to-train-sequence-tagger.md
+++ b/docs/tutorial/tutorial-training/how-to-train-sequence-tagger.md
@@ -223,3 +223,6 @@ trainer.train('resources/taggers/example-universal-pos',
This gives you a multilingual model. Try experimenting with more languages!
+## Next
+
+Next, learn [how to train a text classifier](how-to-train-text-classifier.md).
diff --git a/docs/tutorial/tutorial-training/how-to-train-span-classifier.md b/docs/tutorial/tutorial-training/how-to-train-span-classifier.md
index e5d32cb426..21c5ee7de7 100644
--- a/docs/tutorial/tutorial-training/how-to-train-span-classifier.md
+++ b/docs/tutorial/tutorial-training/how-to-train-span-classifier.md
@@ -1,8 +1,7 @@
# Train a span classifier
Span Classification models are used to model problems such as entity linking, where you already have extracted some
-relevant spans
-within the {term}`Sentence` and want to predict some more fine-grained labels.
+relevant spans within the `Sentence` and want to predict some more fine-grained labels.
This tutorial section shows you how to train models using the [Span Classifier](#flair.models.SpanClassifier) in Flair.
diff --git a/docs/tutorial/tutorial-training/how-to-train-text-classifier.md b/docs/tutorial/tutorial-training/how-to-train-text-classifier.md
index 265689c21f..b88084ff00 100644
--- a/docs/tutorial/tutorial-training/how-to-train-text-classifier.md
+++ b/docs/tutorial/tutorial-training/how-to-train-text-classifier.md
@@ -58,3 +58,7 @@ classifier.predict(sentence)
print(sentence.labels)
```
+
+## Next
+
+Next, learn [how to train an entity linker](how-to-train-span-classifier.md).
\ No newline at end of file
diff --git a/docs/tutorial/tutorial-training/train-vs-fine-tune.md b/docs/tutorial/tutorial-training/train-vs-fine-tune.md
index fd45e90ea0..657a3a0aa8 100644
--- a/docs/tutorial/tutorial-training/train-vs-fine-tune.md
+++ b/docs/tutorial/tutorial-training/train-vs-fine-tune.md
@@ -1,11 +1,50 @@
# Training vs fine-tuning
There are two broad ways you train a model: The "classic" approach and the fine-tuning approach. This section
-explains the differences, and the things you need to do.
+explains the differences.
## Fine-Tuning
+Fine-tuning is the current state-of-the-art approach. The main idea is that you take a pre-trained language model that
+consists of (hundreds of) millions of trained parameters. To this language model you add a simple prediction head with
+randomly initialized weights.
+
+Since in this case the vast majority of parameters in the model are already trained, you only need to "fine-tune" this
+model. This means: a very small learning rate (LR) and just a few epochs. You are essentially just minimally modifying
+the model to adapt it to the task you want to solve.
+
+Use this method by calling [`ModelTrainer.fine_tune()`](#flair.trainers.ModelTrainer.fine_tune).
+Since most models in Flair were trained this way, this is likely the approach you'll want to use.
+
## Training
+On the other hand, you should use the classic training approach if the majority of the trainable parameters in your
+model are randomly initialized. This can happen, for instance, if you freeze the weights of the pre-trained language
+model, leaving only the randomly initialized prediction head as trainable parameters. This training approach is also
+referred to as "feature-based" or "probing" in some papers.
+
+Since the majority of parameters are randomly initialized, you need to fully train the model. This means: a high learning
+rate and many epochs.
+
+Use this method by calling [`ModelTrainer.train()`](#flair.trainers.ModelTrainer.train).
+
+```{note}
+Another application of classic training is for linear probing of pre-trained language models. In this scenario, you
+"freeze" the weights of the language model (meaning that they cannot be changed) and add a prediction head that is
+trained from scratch. So, even though a language model is involved, its parameters are not trainable. This means that
+all trainable parameters in this scenario are randomly initialized, therefore necessitating the use of the classic
+training approach.
+```
+
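+To make the distinction concrete, here is a minimal sketch of both calls (assuming a `model` and `corpus` created as in the training tutorials; the path and hyperparameter values are purely illustrative):
+
+```python
+from flair.trainers import ModelTrainer
+
+trainer = ModelTrainer(model, corpus)
+
+# fine-tuning: very small learning rate and only a few epochs
+trainer.fine_tune("resources/taggers/example", learning_rate=5.0e-5, max_epochs=4)
+
+# classic training (the alternative): higher learning rate and many epochs
+trainer.train("resources/taggers/example", learning_rate=0.1, max_epochs=150)
+```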
+
+## Paper
+
+If you are interested in an experimental comparison of the two above-mentioned approaches, check out [our paper](https://arxiv.org/pdf/2011.06993)
+that compares fine-tuning to the feature-based approach.
+
+
+## Next
+
+Next, learn how to load a [training dataset](how-to-load-prepared-dataset.md).
\ No newline at end of file
diff --git a/examples/README.md b/examples/README.md
index 53e11ffe21..1221fbf49f 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -4,6 +4,7 @@ This folder contains actively maintained examples of use of Flair, organized alo
## Table of Tasks
-| Task | Documentation
-| ----------------------------- | -------------
+| Task | Documentation
+| ------------------------------ | -------------
| Named Entity Recognition (NER) | [Here](ner/)
+| Multi GPU | [Here](multi_gpu/)
diff --git a/examples/multi_gpu/README.md b/examples/multi_gpu/README.md
new file mode 100644
index 0000000000..4c12b9d0bd
--- /dev/null
+++ b/examples/multi_gpu/README.md
@@ -0,0 +1,32 @@
+# Multi GPU
+
+Training can be distributed across multiple GPUs on a local machine when using
+[`ModelTrainer`](#flair.trainers.trainer.ModelTrainer).
+
+## Example
+
+See the script `run_multi_gpu.py` and its comments.
+
+## Tutorial
+
+There are two changes that are always required, as well as a few things to consider.
+
+Always Required:
+1) Pass the argument `multi_gpu=True` to your [`.train()`](#flair.trainers.trainer.ModelTrainer.train) or `.fine_tune()`
+2) Wrap your code in [`launch_distributed`](#flair.distributed_utils.launch_distributed), e.g.
+ `launch_distributed(main, *args)`. This spawns multiple processes, each driving a GPU
+
+Other considerations:
+- The corpus and other preprocessing must be the same on all processes. For example, if corpus initialization involves
+ anything random, you should either
+  - Set the random seed before initializing the corpus (e.g. `flair.set_seed(42)`) OR
+ - Initialize the corpus before calling `launch_distributed` and pass the corpus as an argument so it's serialized to
+ all processes
+- The effective batch size will be larger by a factor of num_gpus
+ - Each GPU will now process `mini_batch_size` examples before the optimizer steps, resulting in fewer total steps
+    taken relative to training with a single device. To obtain comparable results between single- and multi-GPU training,
+    both mathematically and in terms of wall time, consider the method in the example script.
+- Large batch sizes may be necessary to see faster runs, otherwise the communication overhead may dominate
+
+Only the parameter updates in the training process will be distributed across multiple GPUs. Evaluation and prediction
+are still done on a single device.
diff --git a/examples/multi_gpu/__init__.py b/examples/multi_gpu/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/multi_gpu/run_multi_gpu.py b/examples/multi_gpu/run_multi_gpu.py
new file mode 100644
index 0000000000..f7d059111b
--- /dev/null
+++ b/examples/multi_gpu/run_multi_gpu.py
@@ -0,0 +1,54 @@
+import torch
+
+import flair
+from flair.datasets import IMDB
+from flair.distributed_utils import launch_distributed
+from flair.embeddings import TransformerDocumentEmbeddings
+from flair.models import TextClassifier
+from flair.trainers import ModelTrainer
+
+
+def main(multi_gpu):
+ # Note: Multi-GPU can affect corpus loading
+ # This code will run multiple times -- each GPU gets its own process and each process runs this code. We need to
+ # ensure that the corpus has the same elements and order on all processes, despite sampling. We do that by using
+ # the same seed on all processes.
+ flair.set_seed(42)
+
+ corpus = IMDB()
+ corpus.downsample(0.1)
+ label_type = "sentiment"
+ label_dictionary = corpus.make_label_dictionary(label_type)
+
+ embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")
+ model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
+
+ # Note: Multi-GPU can affect choice of batch size.
+ # In order to compare batch updates fairly between single and multi-GPU training, we should:
+    # 1) Step the optimizer after the same number of examples, so that parameter updates are comparable
+ # 2) Process the same number of examples in each forward pass
+    mini_batch_chunk_size = 32  # Make this as large as possible without running out of GPU memory, to fully utilize each device
+ num_devices_when_distributing = max(torch.cuda.device_count(), 1)
+ mini_batch_size = mini_batch_chunk_size if multi_gpu else mini_batch_chunk_size * num_devices_when_distributing
+ # e.g. Suppose your machine has 2 GPUs. If multi_gpu=False, the first gpu will process 32 examples, then the
+ # first gpu will process another 32 examples, then the optimizer will step. If multi_gpu=True, each gpu will
+ # process 32 examples at the same time, then the optimizer will step.
+
+ trainer = ModelTrainer(model, corpus)
+ trainer.fine_tune(
+ "resources/taggers/multi-gpu",
+ multi_gpu=multi_gpu, # Required for multi-gpu
+ max_epochs=2,
+ mini_batch_chunk_size=mini_batch_chunk_size,
+ mini_batch_size=mini_batch_size,
+ )
+
+
+if __name__ == "__main__":
+ """Minimal example demonstrating how to train a model on multiple GPUs."""
+ multi_gpu = True
+
+ if multi_gpu:
+ launch_distributed(main, multi_gpu) # Required for multi-gpu
+ else:
+ main(multi_gpu)
diff --git a/flair/class_utils.py b/flair/class_utils.py
index 9aa95cd1ee..7e01f4ff42 100644
--- a/flair/class_utils.py
+++ b/flair/class_utils.py
@@ -1,12 +1,13 @@
import importlib
import inspect
+from collections.abc import Iterable
from types import ModuleType
-from typing import Any, Iterable, List, Optional, Type, TypeVar, Union, overload
+from typing import Any, Optional, TypeVar, Union, overload
T = TypeVar("T")
-def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]:
+def get_non_abstract_subclasses(cls: type[T]) -> Iterable[type[T]]:
for subclass in cls.__subclasses__():
yield from get_non_abstract_subclasses(subclass)
if inspect.isabstract(subclass):
@@ -14,7 +15,7 @@ def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]:
yield subclass
-def get_state_subclass_by_name(cls: Type[T], cls_name: Optional[str]) -> Type[T]:
+def get_state_subclass_by_name(cls: type[T], cls_name: Optional[str]) -> type[T]:
for sub_cls in get_non_abstract_subclasses(cls):
if sub_cls.__name__ == cls_name:
return sub_cls
@@ -26,12 +27,12 @@ def lazy_import(group: str, module: str, first_symbol: None) -> ModuleType: ...
@overload
-def lazy_import(group: str, module: str, first_symbol: str, *symbols: str) -> List[Any]: ...
+def lazy_import(group: str, module: str, first_symbol: str, *symbols: str) -> list[Any]: ...
def lazy_import(
group: str, module: str, first_symbol: Optional[str] = None, *symbols: str
-) -> Union[List[Any], ModuleType]:
+) -> Union[list[Any], ModuleType]:
try:
imported_module = importlib.import_module(module)
except ImportError:
diff --git a/flair/data.py b/flair/data.py
index 56aad301c7..6fd41e759c 100644
--- a/flair/data.py
+++ b/flair/data.py
@@ -4,10 +4,11 @@
import typing
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
+from collections.abc import Iterable
from operator import itemgetter
from os import PathLike
from pathlib import Path
-from typing import Any, DefaultDict, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union, cast
+from typing import Any, NamedTuple, Optional, Union, cast
import torch
from deprecated.sphinx import deprecated
@@ -52,8 +53,8 @@ class Dictionary:
def __init__(self, add_unk: bool = True) -> None:
# init dictionaries
- self.item2idx: Dict[bytes, int] = {}
- self.idx2item: List[bytes] = []
+ self.item2idx: dict[bytes, int] = {}
+ self.idx2item: list[bytes] = []
self.add_unk = add_unk
self.multi_label = False
self.span_labels = False
@@ -73,7 +74,8 @@ def add_item(self, item: str) -> int:
Args:
item: a string for which to assign an id.
- Returns: ID of string
+ Returns:
+ ID of string
"""
bytes_item = item.encode("utf-8")
if bytes_item not in self.item2idx:
@@ -87,7 +89,8 @@ def get_idx_for_item(self, item: str) -> int:
Args:
item: string for which ID is requested
- Returns: ID of string, otherwise 0
+ Returns:
+ ID of string, otherwise 0
"""
item_encoded = item.encode("utf-8")
if item_encoded in self.item2idx:
@@ -101,13 +104,14 @@ def get_idx_for_item(self, item: str) -> int:
)
raise IndexError
- def get_idx_for_items(self, items: List[str]) -> List[int]:
+ def get_idx_for_items(self, items: list[str]) -> list[int]:
"""Returns the IDs for each item of the list of string, otherwise 0 if not found.
Args:
items: List of string for which IDs are requested
- Returns: List of ID of strings
+ Returns:
+ List of ID of strings
"""
if not hasattr(self, "item2idx_not_encoded"):
d = {key.decode("UTF-8"): value for key, value in self.item2idx.items()}
@@ -120,19 +124,19 @@ def get_idx_for_items(self, items: List[str]) -> List[int]:
return [results]
return list(results)
- def get_items(self) -> List[str]:
- items = []
- for item in self.idx2item:
- items.append(item.decode("UTF-8"))
- return items
+ def get_items(self) -> list[str]:
+ return [item.decode("UTF-8") for item in self.idx2item]
def __len__(self) -> int:
return len(self.idx2item)
- def get_item_for_index(self, idx):
+ def get_item_for_index(self, idx: int) -> str:
return self.idx2item[idx].decode("UTF-8")
- def set_start_stop_tags(self):
+ def has_item(self, item: str) -> bool:
+ return item.encode("utf-8") in self.item2idx
+
+ def set_start_stop_tags(self) -> None:
self.add_item("")
self.add_item("")
@@ -151,7 +155,7 @@ def save(self, savefile: PathLike):
mappings = {"idx2item": self.idx2item, "item2idx": self.item2idx}
pickle.dump(mappings, f)
- def __setstate__(self, d: Dict) -> None:
+ def __setstate__(self, d: dict) -> None:
self.__dict__ = d
# set 'add_unk' if the dictionary was created with a version of Flair older than 0.9
if "add_unk" not in self.__dict__:
@@ -281,9 +285,9 @@ class DataPoint:
"""
def __init__(self) -> None:
- self.annotation_layers: Dict[str, List[Label]] = {}
- self._embeddings: Dict[str, torch.Tensor] = {}
- self._metadata: Dict[str, Any] = {}
+ self.annotation_layers: dict[str, list[Label]] = {}
+ self._embeddings: dict[str, torch.Tensor] = {}
+ self._metadata: dict[str, Any] = {}
@property
@abstractmethod
@@ -293,7 +297,7 @@ def embedding(self) -> torch.Tensor:
def set_embedding(self, name: str, vector: torch.Tensor):
self._embeddings[name] = vector
- def get_embedding(self, names: Optional[List[str]] = None) -> torch.Tensor:
+ def get_embedding(self, names: Optional[list[str]] = None) -> torch.Tensor:
# if one embedding name, directly return it
if names and len(names) == 1:
if names[0] in self._embeddings:
@@ -308,7 +312,7 @@ def get_embedding(self, names: Optional[List[str]] = None) -> torch.Tensor:
else:
return torch.tensor([], device=flair.device)
- def get_each_embedding(self, embedding_names: Optional[List[str]] = None) -> List[torch.Tensor]:
+ def get_each_embedding(self, embedding_names: Optional[list[str]] = None) -> list[torch.Tensor]:
embeddings = []
for embed_name in sorted(self._embeddings.keys()):
if embedding_names and embed_name not in embedding_names:
@@ -325,7 +329,7 @@ def to(self, device: str, pin_memory: bool = False) -> None:
else:
self._embeddings[name] = vector.to(device, non_blocking=True)
- def clear_embeddings(self, embedding_names: Optional[List[str]] = None) -> None:
+ def clear_embeddings(self, embedding_names: Optional[list[str]] = None) -> None:
if embedding_names is None:
self._embeddings = {}
else:
@@ -346,6 +350,17 @@ def has_metadata(self, key: str) -> bool:
return key in self._metadata
def add_label(self, typename: str, value: str, score: float = 1.0, **metadata) -> "DataPoint":
+ """Adds a label to the :class:`DataPoint` by internally creating a :class:`Label` object.
+
+ Args:
+ typename: A string that identifies the layer of annotation, such as "ner" for named entity labels or "sentiment" for sentiment labels
+ value: A string that sets the value of the label.
+ score: Optional value setting the confidence level of the label (between 0 and 1). If not set, a default confidence of 1 is used.
+ **metadata: Additional metadata information.
+
+ Returns:
+ A pointer to itself (DataPoint object, now with an added label).
+ """
label = Label(self, value, score, **metadata)
if typename not in self.annotation_layers:
@@ -368,14 +383,25 @@ def get_label(self, label_type: Optional[str] = None, zero_tag_value: str = "O")
return Label(self, zero_tag_value)
return self.get_labels(label_type)[0]
- def get_labels(self, typename: Optional[str] = None) -> List[Label]:
+ def get_labels(self, typename: Optional[str] = None) -> list[Label]:
+ """Returns all labels of this datapoint belonging to a specific annotation layer.
+
+ For instance, if a data point has been labeled with `"sentiment"`-labels, you can call this function as
+ `get_labels("sentiment")` to return a list of all sentiment labels.
+
+ Args:
+ typename: The string identifier of the annotation layer, like "sentiment" or "ner".
+
+ Returns:
+ A list of :class:`Label` objects belonging to this annotation layer for this data point.
+ """
if typename is None:
return self.labels
return self.annotation_layers.get(typename, [])
@property
- def labels(self) -> List[Label]:
+ def labels(self) -> list[Label]:
all_labels = []
for key in self.annotation_layers:
all_labels.extend(self.annotation_layers[key])
@@ -447,8 +473,8 @@ def __init__(
concept_id: str,
concept_name: str,
database_name: str,
- additional_ids: Optional[List[str]] = None,
- synonyms: Optional[List[str]] = None,
+ additional_ids: Optional[list[str]] = None,
+ synonyms: Optional[list[str]] = None,
description: Optional[str] = None,
):
"""A Concept as part of a knowledgebase or ontology.
@@ -483,7 +509,7 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return str(self)
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
return {
"concept_id": self.concept_id,
"concept_name": self.concept_name,
@@ -550,8 +576,8 @@ def __init__(
self._start_position = start_position
- self._embeddings: Dict = {}
- self.tags_proba_dist: Dict[str, List[Label]] = {}
+ self._embeddings: dict[str, torch.Tensor] = {}
+ self.tags_proba_dist: dict[str, list[Label]] = {}
@property
def idx(self) -> int:
@@ -568,10 +594,10 @@ def text(self) -> str:
def unlabeled_identifier(self) -> str:
return f'Token[{self.idx - 1}]: "{self.text}"'
- def add_tags_proba_dist(self, tag_type: str, tags: List[Label]) -> None:
+ def add_tags_proba_dist(self, tag_type: str, tags: list[Label]) -> None:
self.tags_proba_dist[tag_type] = tags
- def get_tags_proba_dist(self, tag_type: str) -> List[Label]:
+ def get_tags_proba_dist(self, tag_type: str) -> list[Label]:
if tag_type in self.tags_proba_dist:
return self.tags_proba_dist[tag_type]
return []
@@ -617,7 +643,7 @@ def set_label(self, typename: str, value: str, score: float = 1.0, **metadata):
else:
DataPoint.set_label(self, typename=typename, value=value, score=score, **metadata)
- def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]:
+ def to_dict(self, tag_type: Optional[str] = None) -> dict[str, Any]:
return {
"text": self.text,
"start_pos": self.start_position,
@@ -629,7 +655,7 @@ def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]:
class Span(_PartOfSentence):
"""This class represents one textual span consisting of Tokens."""
- def __new__(self, tokens: List[Token]):
+ def __new__(self, tokens: list[Token]):
# check if the span already exists. If so, return it
unlabeled_identifier = self._make_unlabeled_identifier(tokens)
if unlabeled_identifier in tokens[0].sentence._known_spans:
@@ -643,7 +669,7 @@ def __new__(self, tokens: List[Token]):
tokens[0].sentence._known_spans[unlabeled_identifier] = span
return span
- def __init__(self, tokens: List[Token]) -> None:
+ def __init__(self, tokens: list[Token]) -> None:
if not self.initialized:
super().__init__(tokens[0].sentence)
self.tokens = tokens
@@ -662,7 +688,7 @@ def text(self) -> str:
return "".join([t.text + t.whitespace_after * " " for t in self.tokens]).strip()
@staticmethod
- def _make_unlabeled_identifier(tokens: List[Token]):
+ def _make_unlabeled_identifier(tokens: list[Token]):
text = "".join([t.text + t.whitespace_after * " " for t in tokens]).strip()
return f'Span[{tokens[0].idx - 1}:{tokens[-1].idx}]: "{text}"'
@@ -765,23 +791,25 @@ def to_dict(self, tag_type: Optional[str] = None):
class Sentence(DataPoint):
- """A Sentence is a list of tokens and is used to represent a sentence or text fragment."""
+ """A Sentence is a central object in Flair that represents either a single sentence or a whole text.
+
+ Internally, it consists of a list of Token objects that represent each word in the text. Additionally,
+ this object stores all metadata related to a text such as labels, language code, etc.
+ """
def __init__(
self,
- text: Union[str, List[str], List[Token]],
+ text: Union[str, list[str], list[Token]],
use_tokenizer: Union[bool, Tokenizer] = True,
language_code: Optional[str] = None,
start_position: int = 0,
) -> None:
- """Class to hold all metadata related to a text.
-
- Metadata can be tokens, labels, predictions, language code, etc.
+ """Create a sentence object by passing either a text or a list of tokens.
Args:
- text: original string (sentence), or a pre tokenized list of tokens.
- use_tokenizer: Specify a custom tokenizer to split the text into tokens. The Default is
- :class:`flair.tokenization.SegTokTokenizer`. If `use_tokenizer` is set to False,
+ text: Either pass the text as a string, or provide an already tokenized text as either a list of strings or a list of :class:`Token` objects.
+ use_tokenizer: You can optionally specify a custom tokenizer to split the text into tokens. By default we use
+ :class:`flair.tokenization.SegtokTokenizer`. If `use_tokenizer` is set to False,
:class:`flair.tokenization.SpaceTokenizer` will be used instead. The tokenizer will be ignored,
if `text` refers to pretokenized tokens.
language_code: Language of the sentence. If not provided, `langdetect `_
@@ -790,10 +818,10 @@ def __init__(
"""
super().__init__()
- self.tokens: List[Token] = []
+ self.tokens: list[Token] = []
# private field for all known spans
- self._known_spans: Dict[str, _PartOfSentence] = {}
+ self._known_spans: dict[str, _PartOfSentence] = {}
self.language_code: Optional[str] = language_code
@@ -818,7 +846,7 @@ def __init__(
self._previous_sentence: Optional[Sentence] = None
self._has_context: bool = False
self._next_sentence: Optional[Sentence] = None
- self._position_in_dataset: Optional[typing.Tuple[Dataset, int]] = None
+ self._position_in_dataset: Optional[tuple[Dataset, int]] = None
# if text is passed, instantiate sentence with tokens (words)
if isinstance(text, str):
@@ -830,7 +858,7 @@ def __init__(
self.tokens[-1].whitespace_after = 0
return
else:
- words = cast(List[str], text)
+ words = cast(list[str], text)
text = " ".join(words)
# determine token positions and whitespace_after flag
@@ -861,15 +889,15 @@ def __init__(
def unlabeled_identifier(self):
return f'Sentence[{len(self)}]: "{self.text}"'
- def get_relations(self, label_type: Optional[str] = None) -> List[Relation]:
- relations: List[Relation] = []
+ def get_relations(self, label_type: Optional[str] = None) -> list[Relation]:
+ relations: list[Relation] = []
for label in self.get_labels(label_type):
if isinstance(label.data_point, Relation):
relations.append(label.data_point)
return relations
- def get_spans(self, label_type: Optional[str] = None) -> List[Span]:
- spans: List[Span] = []
+ def get_spans(self, label_type: Optional[str] = None) -> list[Span]:
+ spans: list[Span] = []
for potential_span in self._known_spans.values():
if isinstance(potential_span, Span) and (label_type is None or potential_span.has_label(label_type)):
spans.append(potential_span)
@@ -922,16 +950,16 @@ def to(self, device: str, pin_memory: bool = False):
for token in self:
token.to(device, pin_memory)
- def clear_embeddings(self, embedding_names: Optional[List[str]] = None):
+ def clear_embeddings(self, embedding_names: Optional[list[str]] = None):
super().clear_embeddings(embedding_names)
# clear token embeddings
for token in self:
token.clear_embeddings(embedding_names)
- def left_context(self, context_length: int, respect_document_boundaries: bool = True) -> List[Token]:
+ def left_context(self, context_length: int, respect_document_boundaries: bool = True) -> list[Token]:
sentence = self
- left_context: List[Token] = []
+ left_context: list[Token] = []
while len(left_context) < context_length:
sentence = sentence.previous_sentence()
if sentence is None:
@@ -943,9 +971,9 @@ def left_context(self, context_length: int, respect_document_boundaries: bool =
left_context = sentence.tokens + left_context
return left_context[-context_length:]
- def right_context(self, context_length: int, respect_document_boundaries: bool = True) -> List[Token]:
+ def right_context(self, context_length: int, respect_document_boundaries: bool = True) -> list[Token]:
sentence = self
- right_context: List[Token] = []
+ right_context: list[Token] = []
while len(right_context) < context_length:
sentence = sentence.next_sentence()
if sentence is None:
@@ -1037,7 +1065,7 @@ def to_original_text(self) -> str:
[t.text + t.whitespace_after * " " for t in self.tokens]
).strip()
- def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]:
+ def to_dict(self, tag_type: Optional[str] = None) -> dict[str, Any]:
return {
"text": self.to_original_text(),
"labels": [label.to_dict() for label in self.get_labels(tag_type) if label.data_point is self],
@@ -1180,7 +1208,7 @@ def copy_context_from_sentence(self, sentence: "Sentence") -> None:
self._position_in_dataset = sentence._position_in_dataset
@classmethod
- def set_context_for_sentences(cls, sentences: List["Sentence"]) -> None:
+ def set_context_for_sentences(cls, sentences: list["Sentence"]) -> None:
previous_sentence = None
for sentence in sentences:
if sentence.is_context_set():
@@ -1261,7 +1289,7 @@ def to(self, device: str, pin_memory: bool = False):
self.first.to(device, pin_memory)
self.second.to(device, pin_memory)
- def clear_embeddings(self, embedding_names: Optional[List[str]] = None):
+ def clear_embeddings(self, embedding_names: Optional[list[str]] = None):
self.first.clear_embeddings(embedding_names)
self.second.clear_embeddings(embedding_names)
if self.concatenated_data is not None:
@@ -1306,7 +1334,7 @@ def to(self, device: str, pin_memory: bool = False):
self.second.to(device, pin_memory)
self.third.to(device, pin_memory)
- def clear_embeddings(self, embedding_names: Optional[List[str]] = None):
+ def clear_embeddings(self, embedding_names: Optional[list[str]] = None):
self.first.clear_embeddings(embedding_names)
self.second.clear_embeddings(embedding_names)
self.third.clear_embeddings(embedding_names)
@@ -1343,7 +1371,7 @@ def __init__(self, data=None, imageURL=None):
super().__init__()
self.data = data
- self._embeddings: Dict = {}
+ self._embeddings: dict[str, torch.Tensor] = {}
self.imageURL = imageURL
@property
@@ -1374,6 +1402,14 @@ def unlabeled_identifier(self) -> str:
class Corpus(typing.Generic[T_co]):
+ """The main object in Flair for holding a dataset used for training and testing.
+
+ A corpus consists of three splits: a `train` split used for training, a `dev` split used for model selection
+ and/or early stopping, and a `test` split used for testing. All three splits are optional, so it is possible
+ to create a corpus using only one or two splits. If the option `sample_missing_splits` is set to True,
+ missing splits will be randomly sampled from the training split.
+ """
+
def __init__(
self,
train: Optional[Dataset[T_co]] = None,
@@ -1383,6 +1419,26 @@ def __init__(
sample_missing_splits: Union[bool, str] = True,
random_seed: Optional[int] = None,
) -> None:
+ """
+ Constructor method to initialize a :class:`Corpus`. You can define the train, dev and test splits
+ by passing the corresponding Dataset objects to the constructor. At least one split should be defined.
+ If the option `sample_missing_splits` is set to True, missing splits will be randomly sampled from the
+ train split.
+
+ In most cases, you will not use the constructor yourself. Rather, you will create a corpus using one of our
+ helper methods that read common NLP filetypes. For instance, you can use
+ :class:`flair.datasets.sequence_labeling.ColumnCorpus` to read CoNLL-formatted files directly into
+ a :class:`Corpus`.
+
+ Args:
+ train: The split you use for model training.
+ dev: A holdout split typically used for model selection or early stopping.
+ test: The final test data to compute the score of the model.
+ name: A name that identifies the corpus.
+ sample_missing_splits: If set to True, missing splits are sampled from train. If set to False,
+ missing splits are not sampled and left empty. Default: True.
+ random_seed: Set a random seed to control the sampling of missing splits.
+ """
# set name
self.name: str = name
@@ -1421,14 +1477,17 @@ def __init__(
@property
def train(self) -> Optional[Dataset[T_co]]:
+ """The training split as a :class:`torch.utils.data.Dataset` object."""
return self._train
@property
def dev(self) -> Optional[Dataset[T_co]]:
+ """The dev split as a :class:`torch.utils.data.Dataset` object."""
return self._dev
@property
def test(self) -> Optional[Dataset[T_co]]:
+ """The test split as a :class:`torch.utils.data.Dataset` object."""
return self._test
def downsample(
@@ -1439,7 +1498,23 @@ def downsample(
downsample_test: bool = True,
random_seed: Optional[int] = None,
) -> "Corpus":
- """Reduce all datasets in corpus proportionally to the given percentage."""
+ """Randomly downsample the corpus to the given percentage (by removing data points).
+
+ This method is an in-place operation, meaning that the Corpus object itself is modified by removing
+ data points. It additionally returns a pointer to itself for use in method chaining.
+
+ Args:
+ percentage: A float value between 0. and 1. that indicates to which percentage the corpus
+ should be downsampled. Default value is 0.1, meaning it gets downsampled to 10%.
+ downsample_train: Whether or not to include the training split in downsampling. Default is True.
+ downsample_dev: Whether or not to include the dev split in downsampling. Default is True.
+ downsample_test: Whether or not to include the test split in downsampling. Default is True.
+ random_seed: An optional random seed to make downsampling reproducible.
+
+ Returns:
+ A pointer to itself for optional use in method chaining.
+ """
+
if downsample_train and self._train is not None:
self._train = self._downsample_to_proportion(self._train, percentage, random_seed)
@@ -1452,6 +1527,10 @@ def downsample(
return self
def filter_empty_sentences(self):
+ """A method that filters all sentences consisting of 0 tokens.
+
+ This is an in-place operation that directly modifies the Corpus object itself by removing these sentences.
+ """
log.info("Filtering empty sentences")
if self._train is not None:
self._train = Corpus._filter_empty_sentences(self._train)
@@ -1462,6 +1541,15 @@ def filter_empty_sentences(self):
log.info(self)
def filter_long_sentences(self, max_charlength: int):
+ """
+ A method that filters out all sentences whose plain text is longer than a specified number of characters.
+
+ This is an in-place operation that directly modifies the Corpus object itself by removing these sentences.
+
+ Args:
+ max_charlength: The maximum permissible character length of a sentence.
+
+ """
log.info("Filtering long sentences")
if self._train is not None:
self._train = Corpus._filter_long_sentences(self._train, max_charlength)
@@ -1506,7 +1594,7 @@ def _filter_empty_sentences(dataset) -> Dataset:
return subset
def make_vocab_dictionary(self, max_tokens: int = -1, min_freq: int = 1) -> Dictionary:
- """Creates a dictionary of all tokens contained in the corpus.
+ """Creates a :class:`Dictionary` of all tokens contained in the corpus.
By defining `max_tokens` you can set the maximum number of tokens that should be contained in the dictionary.
If there are more than `max_tokens` tokens in the corpus, the most frequent tokens are added first.
@@ -1514,10 +1602,13 @@ def make_vocab_dictionary(self, max_tokens: int = -1, min_freq: int = 1) -> Dict
to be added to the dictionary.
Args:
- max_tokens: the maximum number of tokens that should be added to the dictionary (-1 = take all tokens)
- min_freq: a token needs to occur at least `min_freq` times to be added to the dictionary (-1 = there is no limitation)
+ max_tokens: The maximum number of tokens that should be added to the dictionary (providing a value of "-1"
+ means that there is no maximum in this regard).
+ min_freq: A token needs to occur at least `min_freq` times to be added to the dictionary (providing a value
+ of "-1" means that there is no limitation in this regard).
- Returns: dictionary of tokens
+ Returns:
+ A :class:`Dictionary` of all unique tokens in the corpus.
"""
tokens = self._get_most_common_tokens(max_tokens, min_freq)
@@ -1527,17 +1618,17 @@ def make_vocab_dictionary(self, max_tokens: int = -1, min_freq: int = 1) -> Dict
return vocab_dictionary
- def _get_most_common_tokens(self, max_tokens: int, min_freq: int) -> List[str]:
+ def _get_most_common_tokens(self, max_tokens: int, min_freq: int) -> list[str]:
tokens_and_frequencies = Counter(self._get_all_tokens())
- tokens: List[str] = []
+ tokens: list[str] = []
for token, freq in tokens_and_frequencies.most_common():
if (min_freq != -1 and freq < min_freq) or (max_tokens != -1 and len(tokens) == max_tokens):
break
tokens.append(token)
return tokens
- def _get_all_tokens(self) -> List[str]:
+ def _get_all_tokens(self) -> list[str]:
assert self.train
tokens = [s.tokens for s in _iter_dataset(self.train)]
tokens = [token for sublist in tokens for token in sublist]
@@ -1550,9 +1641,17 @@ def _downsample_to_proportion(dataset: Dataset, proportion: float, random_seed:
return splits[0]
def obtain_statistics(self, label_type: Optional[str] = None, pretty_print: bool = True) -> Union[dict, str]:
- """Print statistics about the class distribution and sentence sizes.
+ """Print statistics about the corpus, including the length of the sentences and the labels in the corpus.
- only labels of sentences are taken into account
+ Args:
+ label_type: Optionally set this value to obtain statistics only for one specific type of label (such
+ as "ner" or "pos"). If not set, statistics for all labels will be returned.
+ pretty_print: If set to True, returns a pretty-printed JSON string (indented for readability). If set
+ to False, the statistics are returned as a dictionary instead. Default: True.
+
+ Returns:
+ If pretty_print is True, a pretty-printed string in JSON format. Otherwise, a dictionary holding
+ the statistics.
"""
json_data = {
"TRAIN": self._obtain_statistics_for(self.train, "TRAIN", label_type),
@@ -1574,13 +1673,8 @@ def _obtain_statistics_for(sentences, name, tag_type) -> dict:
tags_to_count = Corpus._count_token_labels(sentences, tag_type)
tokens_per_sentence = Corpus._get_tokens_per_sentence(sentences)
- label_size_dict = {}
- for label, c in classes_to_count.items():
- label_size_dict[label] = c
-
- tag_size_dict = {}
- for tag, c in tags_to_count.items():
- tag_size_dict[tag] = c
+ label_size_dict = dict(classes_to_count)
+ tag_size_dict = dict(tags_to_count)
return {
"dataset": name,
@@ -1596,20 +1690,20 @@ def _obtain_statistics_for(sentences, name, tag_type) -> dict:
}
@staticmethod
- def _get_tokens_per_sentence(sentences: Iterable[Sentence]) -> List[int]:
+ def _get_tokens_per_sentence(sentences: Iterable[Sentence]) -> list[int]:
return [len(x.tokens) for x in sentences]
@staticmethod
- def _count_sentence_labels(sentences: Iterable[Sentence]) -> DefaultDict[str, int]:
- label_count: DefaultDict[str, int] = defaultdict(lambda: 0)
+ def _count_sentence_labels(sentences: Iterable[Sentence]) -> defaultdict[str, int]:
+ label_count: defaultdict[str, int] = defaultdict(lambda: 0)
for sent in sentences:
for label in sent.labels:
label_count[label.value] += 1
return label_count
@staticmethod
- def _count_token_labels(sentences: Iterable[Sentence], label_type: str) -> DefaultDict[str, int]:
- label_count: DefaultDict[str, int] = defaultdict(lambda: 0)
+ def _count_token_labels(sentences: Iterable[Sentence], label_type: str) -> defaultdict[str, int]:
+ label_count: defaultdict[str, int] = defaultdict(lambda: 0)
for sent in sentences:
for token in sent.tokens:
if label_type in token.annotation_layers:
@@ -1629,7 +1723,21 @@ def make_label_dictionary(
) -> Dictionary:
"""Creates a dictionary of all labels assigned to the sentences in the corpus.
- :return: dictionary of labels
+ Args:
+ label_type: The name of the label type for which the dictionary should be created. Some corpora have
+ multiple layers of annotation, such as "pos" and "ner". In this case, you should choose the label type
+ you are interested in.
+ min_count: Optionally set this to exclude rare labels from the dictionary (i.e., labels seen fewer
+ times than the provided integer value).
+ add_unk: Optionally set this to True to include a "UNK" value in the dictionary. In most cases, this
+ is not needed since the label dictionary is well-defined, but some use cases might have open classes
+ and require this.
+ add_dev_test: Optionally set this to True to construct the label dictionary not only from the train
+ split, but also from dev and test. This is only necessary if some labels never appear in train but do
+ appear in one of the other splits.
+
+ Returns:
+ A Dictionary of all unique labels in the corpus.
"""
if min_count > 0 and not add_unk:
add_unk = True
@@ -1653,7 +1761,7 @@ def make_label_dictionary(
sentence_label_type_counter: typing.Counter[str] = Counter()
label_value_counter: typing.Counter[str] = Counter()
- all_sentence_labels: List[str] = []
+ all_sentence_labels: list[str] = []
# first, determine the datapoint type by going through dataset until first label is found
datapoint_type = None
@@ -1693,7 +1801,7 @@ def make_label_dictionary(
unked_count += count
if len(label_dictionary.idx2item) == 0 or (
- len(label_dictionary.idx2item) == 1 and "<unk>" in label_dictionary.get_items()
+ len(label_dictionary.idx2item) == 1 and label_dictionary.has_item("<unk>")
):
log.error(f"ERROR: You specified label_type='{label_type}' which is not in this dataset!")
contained_labels = ", ".join(
@@ -1717,10 +1825,10 @@ def make_label_dictionary(
def add_label_noise(
self,
label_type: str,
- labels: List[str],
+ labels: list[str],
noise_share: float = 0.2,
split: str = "train",
- noise_transition_matrix: Optional[Dict[str, List[float]]] = None,
+ noise_transition_matrix: Optional[dict[str, list[float]]] = None,
):
"""Generates uniform label noise distribution in the chosen dataset split.
@@ -1808,6 +1916,13 @@ def add_label_noise(
)
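# Usage sketch (continuing the corpus example): corrupt 10% of the "ner" labels in the training split
# with uniform noise (the label inventory below is a placeholder for whatever labels the corpus uses).
conll_corpus.add_label_noise(
    label_type="ner",
    labels=["PER", "LOC", "ORG", "MISC"],
    noise_share=0.1,
    split="train",
)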
def get_label_distribution(self):
+ """Counts occurrences of each label in the corpus and returns them as a dictionary object.
+
+ This allows you to get an idea of how often each label appears in the Corpus.
+
+ Returns:
+ Dictionary with labels as keys and their occurrences as values.
+ """
class_to_count = defaultdict(lambda: 0)
for sent in self.train:
for label in sent.labels:
@@ -1815,6 +1930,11 @@ def get_label_distribution(self):
return class_to_count
def get_all_sentences(self) -> ConcatDataset:
+ """Returns all sentences (spanning all three splits) in the :class:`Corpus`.
+
+ Returns:
+ A :class:`torch.utils.data.Dataset` object that includes all sentences of this corpus.
+ """
parts = []
if self.train:
parts.append(self.train)
@@ -1831,7 +1951,8 @@ def make_tag_dictionary(self, tag_type: str) -> Dictionary:
Args:
tag_type: the label type to gather the tag labels
- Returns: A Dictionary containing the labeled tags, including "O" and "<START>" and "<STOP>"
+ Returns:
+ A Dictionary containing the labeled tags, including "O" and "<START>" and "<STOP>"
"""
tag_dictionary: Dictionary = Dictionary(add_unk=False)
@@ -1847,12 +1968,12 @@ def make_tag_dictionary(self, tag_type: str) -> Dictionary:
class MultiCorpus(Corpus):
def __init__(
self,
- corpora: List[Corpus],
- task_ids: Optional[List[str]] = None,
+ corpora: list[Corpus],
+ task_ids: Optional[list[str]] = None,
name: str = "multicorpus",
**corpusargs,
) -> None:
- self.corpora: List[Corpus] = corpora
+ self.corpora: list[Corpus] = corpora
ids = task_ids if task_ids else [f"Task_{i}" for i in range(len(corpora))]
@@ -1901,8 +2022,8 @@ class ConcatFlairDataset(Dataset):
datasets (sequence): List of datasets to be concatenated
"""
- datasets: List[Dataset]
- cumulative_sizes: List[int]
+ datasets: list[Dataset]
+ cumulative_sizes: list[int]
@staticmethod
def cumsum(sequence):
@@ -1937,36 +2058,13 @@ def __getitem__(self, idx: int) -> Sentence:
return sentence
@property
- def cummulative_sizes(self) -> List[int]:
+ def cummulative_sizes(self) -> list[int]:
return self.cumulative_sizes
-def iob2(tags: List) -> bool:
- """Converts the tags to the IOB2 format.
-
- Check that tags have a valid IOB format.
- Tags in IOB1 format are converted to IOB2.
- """
- for i, tag in enumerate(tags):
- if tag.value == "O":
- continue
- split = tag.value.split("-")
- if len(split) != 2 or split[0] not in ["I", "B"]:
- return False
- if split[0] == "B":
- continue
- elif i == 0 or tags[i - 1].value == "O": # conversion IOB1 to IOB2
- tags[i].value = "B" + tag.value[1:]
- elif tags[i - 1].value[1:] == tag.value[1:]:
- continue
- else: # conversion IOB1 to IOB2
- tags[i].value = "B" + tag.value[1:]
- return True
-
-
def randomly_split_into_two_datasets(
dataset: Dataset, length_of_first: int, random_seed: Optional[int] = None
-) -> Tuple[Subset, Subset]:
+) -> tuple[Subset, Subset]:
"""Shuffles a dataset and splits into two subsets.
The length of the first is specified and the remaining samples go into the second subset.
@@ -1989,17 +2087,17 @@ def randomly_split_into_two_datasets(
def get_spans_from_bio(
- bioes_tags: List[str], bioes_scores: Optional[List[float]] = None
-) -> List[typing.Tuple[List[int], float, str]]:
+ bioes_tags: list[str], bioes_scores: Optional[list[float]] = None
+) -> list[tuple[list[int], float, str]]:
# add a dummy "O" to close final prediction
bioes_tags.append("O")
# return complex list
found_spans = []
# internal variables
- current_tag_weights: Dict[str, float] = {}
+ current_tag_weights: dict[str, float] = {}
previous_tag = "O-"
- current_span: List[int] = []
- current_span_scores: List[float] = []
+ current_span: list[int] = []
+ current_span_scores: list[float] = []
for idx, bioes_tag in enumerate(bioes_tags):
# non-set tags are OUT tags
if bioes_tag == "" or bioes_tag == "O" or bioes_tag == "_":
diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py
index 2837e017c0..d54ff35e01 100644
--- a/flair/datasets/__init__.py
+++ b/flair/datasets/__init__.py
@@ -171,6 +171,7 @@
# Expose all sequence labeling datasets
from .sequence_labeling import (
BIOSCOPE,
+ CLEANCONLL,
CONLL_03,
CONLL_03_DUTCH,
CONLL_03_GERMAN,
@@ -216,6 +217,7 @@
NER_MULTI_WIKINER,
NER_MULTI_XTREME,
NER_NERMUD,
+ NER_NOISEBENCH,
NER_SWEDISH,
NER_TURKU,
NER_UKRAINIAN,
@@ -465,6 +467,7 @@
"CONLL_03_DUTCH",
"CONLL_03_GERMAN",
"CONLL_03_SPANISH",
+ "CLEANCONLL",
"CONLL_2000",
"FEWNERD",
"KEYPHRASE_INSPEC",
@@ -494,6 +497,7 @@
"NER_GERMAN_MOBIE",
"NER_GERMAN_POLITICS",
"NER_HIPE_2022",
+ "NER_NOISEBENCH",
"NER_HUNGARIAN",
"NER_ICDAR_EUROPEANA",
"NER_ICELANDIC",
diff --git a/flair/datasets/base.py b/flair/datasets/base.py
index a38d0b1321..0737b4660d 100644
--- a/flair/datasets/base.py
+++ b/flair/datasets/base.py
@@ -1,7 +1,7 @@
import logging
from abc import abstractmethod
from pathlib import Path
-from typing import Generic, List, Optional, Union
+from typing import Generic, Optional, Union
import torch.utils.data.dataloader
from deprecated.sphinx import deprecated
@@ -41,7 +41,7 @@ def __init__(
class FlairDatapointDataset(FlairDataset, Generic[DT]):
"""A simple Dataset object to wrap a List of Datapoints, for example Sentences."""
- def __init__(self, datapoints: Union[DT, List[DT]]) -> None:
+ def __init__(self, datapoints: Union[DT, list[DT]]) -> None:
"""Instantiate FlairDatapointDataset.
Args:
@@ -64,7 +64,7 @@ def __getitem__(self, index: int = 0) -> DT:
class SentenceDataset(FlairDatapointDataset):
@deprecated(version="0.11", reason="The 'SentenceDataset' class was renamed to 'FlairDatapointDataset'")
- def __init__(self, sentences: Union[Sentence, List[Sentence]]) -> None:
+ def __init__(self, sentences: Union[Sentence, list[Sentence]]) -> None:
super().__init__(sentences)
@@ -73,7 +73,7 @@ class StringDataset(FlairDataset):
def __init__(
self,
- texts: Union[str, List[str]],
+ texts: Union[str, list[str]],
use_tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
) -> None:
"""Instantiate StringDataset.
@@ -111,7 +111,7 @@ def __init__(
database: str,
collection: str,
text_field: str,
- categories_field: Optional[List[str]] = None,
+ categories_field: Optional[list[str]] = None,
max_tokens_per_doc: int = -1,
max_chars_per_doc: int = -1,
tokenizer: Tokenizer = SegtokTokenizer(),
@@ -195,7 +195,7 @@ def __init__(
def _parse_document_to_sentence(
self,
text: str,
- labels: List[str],
+ labels: list[str],
tokenizer: Union[bool, Tokenizer],
):
if self.max_chars_per_doc > 0:
diff --git a/flair/datasets/biomedical.py b/flair/datasets/biomedical.py
index 28f4aca98b..e99a71ccf7 100644
--- a/flair/datasets/biomedical.py
+++ b/flair/datasets/biomedical.py
@@ -7,6 +7,7 @@
import sys
from abc import ABC, abstractmethod
from collections import defaultdict, deque
+from collections.abc import Iterable
from copy import copy
from operator import attrgetter
from pathlib import Path
@@ -18,7 +19,7 @@
StreamError,
TarError,
)
-from typing import Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
+from typing import NamedTuple, Optional, Union
from zipfile import BadZipFile, LargeZipFile
import ftfy
@@ -56,7 +57,7 @@ class Entity:
text as well as the type of entity (e.g. Chemical, Gene, and so on).
"""
- def __init__(self, char_span: Tuple[int, int], entity_type: str) -> None:
+ def __init__(self, char_span: tuple[int, int], entity_type: str) -> None:
assert char_span[0] < char_span[1]
self.char_span = range(*char_span)
self.type = entity_type
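# Usage sketch: an Entity pairs a character span of the document text with an entity type
# (the offsets below are placeholders).
from flair.datasets.biomedical import Entity  # the class defined in this file

chemical_mention = Entity(char_span=(17, 24), entity_type="Chemical")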
@@ -98,9 +99,9 @@ class InternalBioNerDataset:
def __init__(
self,
- documents: Dict[str, str],
- entities_per_document: Dict[str, List[Entity]],
- entity_types: List[str] = [],
+ documents: dict[str, str],
+ entities_per_document: dict[str, list[Entity]],
+ entity_types: list[str] = [],
):
self.documents = documents
self.entities_per_document = entities_per_document
@@ -134,7 +135,7 @@ def merge_datasets(data_sets: Iterable[InternalBioNerDataset]):
def filter_and_map_entities(
- dataset: InternalBioNerDataset, entity_type_to_canonical: Dict[str, str]
+ dataset: InternalBioNerDataset, entity_type_to_canonical: dict[str, str]
) -> InternalBioNerDataset:
mapped_entities_per_document = {}
entity_types = list(entity_type_to_canonical.values())
@@ -223,7 +224,7 @@ def bioc_to_internal(bioc_file: Path):
for document in Tqdm.tqdm(documents, desc="Converting to internal"):
document_id = document.xpath("./id")[0].text
- texts: List[str] = []
+ texts: list[str] = []
entities = []
for passage in document.xpath("passage"):
@@ -358,7 +359,7 @@ def __init__(
"""
self.sentence_splitter = sentence_splitter
- def process_dataset(self, datasets: Dict[str, InternalBioNerDataset], out_dir: Path):
+ def process_dataset(self, datasets: dict[str, InternalBioNerDataset], out_dir: Path):
if "train" in datasets:
self.write_to_conll(datasets["train"], out_dir / (self.sentence_splitter.name + "_train.conll"))
if "dev" in datasets:
@@ -450,7 +451,7 @@ def to_internal(self, data_folder: Path) -> InternalBioNerDataset:
@staticmethod
@abstractmethod
- def split_url() -> Union[str, List[str]]:
+ def split_url() -> Union[str, list[str]]:
raise NotImplementedError
def get_corpus_sentence_splitter(self) -> Optional[SentenceSplitter]:
@@ -596,8 +597,8 @@ def download_dataset(cls, data_dir: Path) -> Path:
@classmethod
def parse_dataset(cls, original_file: Path):
- documents: Dict[str, str] = {}
- entities_per_document: Dict[str, List[Entity]] = {}
+ documents: dict[str, str] = {}
+ entities_per_document: dict[str, list[Entity]] = {}
tree = etree.parse(str(original_file))
sentence_elems = tree.xpath("//sentence")
@@ -647,7 +648,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return merge_datasets([train_data, test_data])
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -726,14 +727,14 @@ def download_and_prepare_test(cls, data_folder: Path, sentence_tag: str) -> Inte
@classmethod
def read_file(cls, input_iob_file: Path, sentence_tag: str) -> InternalBioNerDataset:
- documents: Dict[str, str] = {}
- entities_per_document: Dict[str, List[Entity]] = defaultdict(list)
+ documents: dict[str, str] = {}
+ entities_per_document: dict[str, list[Entity]] = defaultdict(list)
with open(str(input_iob_file), encoding="utf8") as file_reader:
document_id: Optional[str] = None
document_text: Optional[str] = None
- entities: List[Entity] = []
+ entities: list[Entity] = []
entity_type: Optional[str] = None
entity_start = 0
@@ -818,7 +819,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return merge_datasets([train_data, test_data])
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -994,7 +995,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@staticmethod
- def split_url() -> List[str]:
+ def split_url() -> list[str]:
split_urls = [
"https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/cellfinder_cellline",
"https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/cellfinder_species",
@@ -1009,7 +1010,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return data
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -1176,7 +1177,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return merge_datasets([train_data, test_data])
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -1566,7 +1567,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return filter_and_map_entities(dataset, self.entity_type_mapping)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -1747,8 +1748,8 @@ def download_dataset(data_dir: Path):
@classmethod
def parse_dataset(cls, original_file: Path):
- documents: Dict[str, str] = {}
- entities_per_document: Dict[str, List[Entity]] = {}
+ documents: dict[str, str] = {}
+ entities_per_document: dict[str, list[Entity]] = {}
tree = etree.parse(str(original_file))
document_elems = tree.xpath("//document")
@@ -1905,7 +1906,7 @@ def split_url() -> str:
def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return LINNEAUS.download_and_parse_dataset(data_dir)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -1995,7 +1996,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return all_data
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -2021,7 +2022,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return all_data
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -2033,7 +2034,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@staticmethod
- def split_url() -> List[str]:
+ def split_url() -> list[str]:
split_urls = [
"https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/CDRDisease",
"https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/CDRChem",
@@ -2052,7 +2053,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return all_data
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -2167,7 +2168,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return all_data
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -2190,7 +2191,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return all_data
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -2213,7 +2214,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return all_data
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -2230,7 +2231,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@staticmethod
- def split_url() -> List[str]:
+ def split_url() -> list[str]:
split_urls = [
"https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/variome_gene",
"https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/variome_disease",
@@ -2247,7 +2248,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return all_data
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -2343,7 +2344,7 @@ def parse_input_file(input_file: Path):
with open(str(input_file), encoding="utf8") as file:
document_id = ""
document_text = ""
- entities: List[Entity] = []
+ entities: list[Entity] = []
c = 1
for line in file:
@@ -2406,7 +2407,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return merge_datasets([train_data, dev_data, test_data])
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -2455,13 +2456,13 @@ def download_corpus(self, data_folder: Path) -> Path:
@staticmethod
def parse_input_file(input_file: Path):
- documents: Dict[str, str] = {}
- entities_per_document: Dict[str, List[Entity]] = {}
+ documents: dict[str, str] = {}
+ entities_per_document: dict[str, list[Entity]] = {}
with open(str(input_file), encoding="iso-8859-1") as file:
document_id = None
document_text = ""
- entities: List[Entity] = []
+ entities: list[Entity] = []
entity_type = None
entity_start = 0
@@ -2584,7 +2585,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return filter_and_map_entities(corpus, self.entity_type_mapping)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -2605,7 +2606,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return filter_and_map_entities(corpus, self.entity_type_mapping)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -2628,7 +2629,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@staticmethod
- def split_url() -> List[str]:
+ def split_url() -> list[str]:
split_urls = [
"https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/scai_chemicals",
"https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/scai_disease",
@@ -2641,7 +2642,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return filter_and_map_entities(corpus, self.entity_type_mapping)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -2763,7 +2764,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return filter_and_map_entities(corpus, self.entity_type_mapping)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -2863,7 +2864,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return data
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -2945,7 +2946,7 @@ def download_dev_corpus(cls, data_dir) -> Path:
@staticmethod
def parse_input_file(text_file: Path, ann_file: Path) -> InternalBioNerDataset:
documents = {}
- entities_per_document: Dict[str, List[Entity]] = {}
+ entities_per_document: dict[str, list[Entity]] = {}
document_title_length = {}
@@ -3010,7 +3011,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return merge_datasets([train_data, dev_data])
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -3071,8 +3072,8 @@ def download_corpus(cls, data_dir: Path) -> Path:
@staticmethod
def parse_corpus(text_dir: Path, gold_file: Path) -> InternalBioNerDataset:
- documents: Dict[str, str] = {}
- entities_per_document: Dict[str, List[Entity]] = {}
+ documents: dict[str, str] = {}
+ entities_per_document: dict[str, list[Entity]] = {}
text_files = [file for file in os.listdir(str(text_dir)) if not file.startswith(".")]
@@ -3122,7 +3123,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return DECA.parse_corpus(text_dir, gold_file)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -3221,7 +3222,7 @@ def parse_corpus(corpus_dir: Path, sentence_separator: str) -> InternalBioNerDat
akt_pos += len(words[i]) + 1
sentences += [tmp_sentence]
- pre_entities: List[List[Tuple[int, int, str]]] = [[] for _ in sentences]
+ pre_entities: list[list[tuple[int, int, str]]] = [[] for _ in sentences]
for protein in protein_tree:
for span in protein.get("span").split(","):
start = word_to_id[span.split("..")[0]]
@@ -3287,7 +3288,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return filter_and_map_entities(corpus, self.entity_type_mapping)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -3450,8 +3451,8 @@ def parse_dataset(data_dir: Path) -> InternalBioNerDataset:
]
text_files = sorted(text_files)
- documents: Dict[str, str] = {}
- entities_per_document: Dict[str, List[Entity]] = {}
+ documents: dict[str, str] = {}
+ entities_per_document: dict[str, list[Entity]] = {}
for text_file in sorted(text_files):
document_id = os.path.basename(text_file).split("_")[0]
@@ -3590,7 +3591,7 @@ def parse_test_dataset(cls, data_folder: Path) -> InternalBioNerDataset:
@staticmethod
def parse_dataset(text_file: Path, ann_file: Path) -> InternalBioNerDataset:
documents = {}
- entities_per_document: Dict[str, List[Entity]] = {}
+ entities_per_document: dict[str, list[Entity]] = {}
with open(str(text_file), encoding="utf8") as text_file_reader:
for line in text_file_reader:
@@ -3733,7 +3734,7 @@ def download_dev_corpus(cls, data_dir) -> Path:
@staticmethod
def parse_input_file(text_file: Path, ann_file: Path) -> InternalBioNerDataset:
documents = {}
- entities_per_document: Dict[str, List[Entity]] = {}
+ entities_per_document: dict[str, list[Entity]] = {}
document_abstract_length = {}
with open(str(text_file), encoding="utf8") as text_reader:
@@ -3806,7 +3807,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
dataset = merge_datasets([train_data, dev_data])
return filter_and_map_entities(dataset, self.entity_type_mapping)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -3945,7 +3946,7 @@ def to_internal(self, data_dir: Path, annotator: int = 0) -> InternalBioNerDatas
dataset = CHEBI.parse_dataset(corpus_dir, annotator=annotator)
return filter_and_map_entities(dataset, self.entity_type_mapping)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -4038,7 +4039,7 @@ def __init__(
@staticmethod
@abstractmethod
- def download_corpus(data_folder: Path) -> Tuple[Path, Path, Path]:
+ def download_corpus(data_folder: Path) -> tuple[Path, Path, Path]:
pass
@staticmethod
@@ -4083,7 +4084,7 @@ class BIONLP2013_PC(BioNLPCorpus):
"""
@staticmethod
- def download_corpus(download_folder: Path) -> Tuple[Path, Path, Path]:
+ def download_corpus(download_folder: Path) -> tuple[Path, Path, Path]:
train_url = "http://2013.bionlp-st.org/tasks/BioNLP-ST_2013_PC_training_data.tar.gz"
dev_url = "http://2013.bionlp-st.org/tasks/BioNLP-ST_2013_PC_development_data.tar.gz"
test_url = "http://2013.bionlp-st.org/tasks/BioNLP-ST_2013_PC_test_data.tar.gz"
@@ -4125,7 +4126,7 @@ class BIONLP2013_CG(BioNLPCorpus):
"""
@staticmethod
- def download_corpus(download_folder: Path) -> Tuple[Path, Path, Path]:
+ def download_corpus(download_folder: Path) -> tuple[Path, Path, Path]:
url = "https://github.com/openbiocorpora/bionlp-st-2013-cg/archive/refs/heads/master.zip"
cached_path(url, download_folder)
@@ -4292,9 +4293,10 @@ def download_corpora(download_dir: Path):
@staticmethod
def convert_and_write(download_folder, data_folder, tag_type):
data_folder.mkdir(parents=True, exist_ok=True)
- with (download_folder / "train.tsv").open(encoding="utf8") as f_in, (data_folder / "train.conll").open(
- "w", encoding="utf8"
- ) as f_out:
+ with (
+ (download_folder / "train.tsv").open(encoding="utf8") as f_in,
+ (data_folder / "train.conll").open("w", encoding="utf8") as f_out,
+ ):
for line in f_in:
if not line.strip():
f_out.write("\n")
@@ -4305,9 +4307,10 @@ def convert_and_write(download_folder, data_folder, tag_type):
tag = tag + "-" + tag_type
f_out.write(f"{token} {tag}\n")
- with (download_folder / "devel.tsv").open(encoding="utf8") as f_in, (data_folder / "dev.conll").open(
- "w", encoding="utf8"
- ) as f_out:
+ with (
+ (download_folder / "devel.tsv").open(encoding="utf8") as f_in,
+ (data_folder / "dev.conll").open("w", encoding="utf8") as f_out,
+ ):
for line in f_in:
if not line.strip():
f_out.write("\n")
@@ -4317,9 +4320,10 @@ def convert_and_write(download_folder, data_folder, tag_type):
tag = tag + "-" + tag_type
f_out.write(f"{token} {tag}\n")
- with (download_folder / "test.tsv").open(encoding="utf8") as f_in, (data_folder / "test.conll").open(
- "w", encoding="utf8"
- ) as f_out:
+ with (
+ (download_folder / "test.tsv").open(encoding="utf8") as f_in,
+ (data_folder / "test.conll").open("w", encoding="utf8") as f_out,
+ ):
for line in f_in:
if not line.strip():
f_out.write("\n")
@@ -4638,7 +4642,7 @@ def download_corpus(cls, data_dir: Path) -> Path:
@staticmethod
def prepare_splits(
data_dir: Path, corpus: InternalBioNerDataset
- ) -> Tuple[InternalBioNerDataset, InternalBioNerDataset, InternalBioNerDataset]:
+ ) -> tuple[InternalBioNerDataset, InternalBioNerDataset, InternalBioNerDataset]:
splits_dir = data_dir / "splits"
os.makedirs(str(splits_dir), exist_ok=True)
@@ -4734,7 +4738,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return filter_and_map_entities(corpus, self.entity_type_mapping)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -4792,7 +4796,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return filter_and_map_entities(corpus, self.entity_type_mapping)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -4896,7 +4900,7 @@ def parse_corpus(input_file: Path) -> InternalBioNerDataset:
prev_sentence_id: Optional[str] = None
document_text: Optional[str] = None
- entities: List[Entity] = []
+ entities: list[Entity] = []
offset: Optional[int] = None
for line in azdz_reader:
@@ -5014,7 +5018,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
return corpus_data
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return self.entity_type_mapping
@@ -5221,7 +5225,7 @@ def __init__(
sample_missing_splits=True,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
"""Return the mapping of entity type given in the dataset to canonical types.
Note, if an entity type is not present in the map, it is discarded.
@@ -5279,8 +5283,8 @@ def build_corpus_directory_name(self, dataset_name: str) -> str:
def to_internal_dataset(self, dataset, split: str) -> InternalBioNerDataset:
"""Converts a dataset given in hugging datasets format to our internal corpus representation."""
- id_to_text: Dict[str, str] = {}
- id_to_entities: Dict[str, list] = {}
+ id_to_text: dict[str, str] = {}
+ id_to_entities: dict[str, list] = {}
entity_type_set = set()
for document in dataset[split]:
document_id = document["document_id"]
@@ -5331,10 +5335,10 @@ def to_internal_dataset(self, dataset, split: str) -> InternalBioNerDataset:
def bin_search_passage(
self,
- passages: List[Tuple[str, List[Tuple[int, int]]]],
+ passages: list[tuple[str, list[tuple[int, int]]]],
low: int,
high: int,
- entity: Dict,
+ entity: dict,
):
"""Helper methods to find the passage to a given entity mention (incl. offset).
@@ -5381,7 +5385,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {
"Gene": GENE_TAG,
"GENERIF": GENE_TAG,
@@ -5414,7 +5418,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"GENE-N": GENE_TAG, "GENE-Y": GENE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5441,7 +5445,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"CHEMICAL": CHEMICAL_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5452,7 +5456,7 @@ class HUNER_ALL_DRUGPROT(BIGBIO_NER_CORPUS):
def __init__(self, *args, **kwargs):
super().__init__(*args, dataset_name="drugprot", **kwargs)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"GENE-N": GENE_TAG, "GENE-Y": GENE_TAG, "CHEMICAL": CHEMICAL_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5479,7 +5483,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"GeneOrGeneProduct": GENE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5506,7 +5510,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"ChemicalEntity": CHEMICAL_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5533,7 +5537,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"DiseaseOrPhenotypicFeature": DISEASE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5560,7 +5564,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"OrganismTaxon": SPECIES_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5587,7 +5591,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"CellLine": CELL_LINE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5598,7 +5602,7 @@ class HUNER_ALL_BIORED(BIGBIO_NER_CORPUS):
def __init__(self, *args, **kwargs):
super().__init__(*args, dataset_name="biored", **kwargs)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {
"GeneOrGeneProduct": GENE_TAG,
"ChemicalEntity": CHEMICAL_TAG,
@@ -5631,7 +5635,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"protein": GENE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5658,7 +5662,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"compound": CHEMICAL_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5669,7 +5673,7 @@ class HUNER_ALL_CPI(BIGBIO_NER_CORPUS):
def __init__(self, *args, **kwargs):
super().__init__(*args, dataset_name="cpi", **kwargs)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"protein": GENE_TAG, "compound": CHEMICAL_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5696,7 +5700,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"Gene_or_gene_product": GENE_TAG, "Complex": GENE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5723,7 +5727,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"Simple_chemical": CHEMICAL_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5734,7 +5738,7 @@ class HUNER_ALL_BIONLP_ST_2013_PC(BIGBIO_NER_CORPUS):
def __init__(self, *args, **kwargs):
super().__init__(*args, dataset_name="bionlp_st_2013_pc", **kwargs)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {
"Gene_or_gene_product": GENE_TAG,
"Complex": GENE_TAG,
@@ -5765,7 +5769,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"protein": GENE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5792,7 +5796,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"Protein": GENE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5819,7 +5823,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"Protein": GENE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5846,7 +5850,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"Chemical": CHEMICAL_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5873,7 +5877,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"Organism": SPECIES_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5884,7 +5888,7 @@ class HUNER_ALL_BIONLP_ST_2011_ID(BIGBIO_NER_CORPUS):
def __init__(self, *args, **kwargs):
super().__init__(*args, dataset_name="bionlp_st_2011_id", **kwargs)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {
"Protein": GENE_TAG,
"Chemical": CHEMICAL_TAG,
@@ -5915,7 +5919,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"Protein": GENE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5942,7 +5946,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"Protein": GENE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5969,7 +5973,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"Microorganism": SPECIES_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -5996,7 +6000,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"gene": GENE_TAG, "protein": GENE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -6023,7 +6027,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"chemical": CHEMICAL_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -6050,7 +6054,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"species": SPECIES_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -6077,7 +6081,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
# TODO whether cell or cell line is the correct tag
return {"cellline": CELL_LINE_TAG}
@@ -6089,7 +6093,7 @@ class HUNER_ALL_BIOID(BIGBIO_NER_CORPUS):
def __init__(self, *args, **kwargs):
super().__init__(*args, dataset_name="bioid", **kwargs)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
# TODO whether cell or cell line is the correct tag
return {
"gene": GENE_TAG,
@@ -6123,7 +6127,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"Gene": GENE_TAG, "FamilyName": GENE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -6155,7 +6159,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"progene_text": GENE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -6182,7 +6186,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"Chemical": CHEMICAL_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -6209,7 +6213,7 @@ def __init__(
test_split_name=test_split_name,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"Gene": GENE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
@@ -6224,7 +6228,7 @@ def __init__(self, *args, **kwargs):
**kwargs,
)
- def get_entity_type_mapping(self) -> Optional[Dict]:
+ def get_entity_type_mapping(self) -> Optional[dict]:
return {"Gene": GENE_TAG}
def build_corpus_directory_name(self, dataset_name: str) -> str:
diff --git a/flair/datasets/document_classification.py b/flair/datasets/document_classification.py
index 0bbc471818..363c84e561 100644
--- a/flair/datasets/document_classification.py
+++ b/flair/datasets/document_classification.py
@@ -2,8 +2,9 @@
import json
import logging
import os
+import tarfile
from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Optional, Union
import flair
from flair.data import (
@@ -36,8 +37,8 @@ def __init__(
filter_if_longer_than: int = -1,
tokenizer: Union[bool, Tokenizer] = SegtokTokenizer(),
memory_mode: str = "partial",
- label_name_map: Optional[Dict[str, str]] = None,
- skip_labels: Optional[List[str]] = None,
+ label_name_map: Optional[dict[str, str]] = None,
+ skip_labels: Optional[list[str]] = None,
allow_examples_without_labels=False,
sample_missing_splits: bool = True,
encoding: str = "utf-8",
@@ -131,8 +132,8 @@ def __init__(
filter_if_longer_than: int = -1,
tokenizer: Union[bool, Tokenizer] = SegtokTokenizer(),
memory_mode: str = "partial",
- label_name_map: Optional[Dict[str, str]] = None,
- skip_labels: Optional[List[str]] = None,
+ label_name_map: Optional[dict[str, str]] = None,
+ skip_labels: Optional[list[str]] = None,
allow_examples_without_labels=False,
encoding: str = "utf-8",
) -> None:
@@ -277,11 +278,7 @@ def _parse_line_to_sentence(self, line: str, label_prefix: str, tokenizer: Union
return None
def is_in_memory(self) -> bool:
- if self.memory_mode == "disk":
- return False
- if self.memory_mode == "partial":
- return False
- return True
+ return self.memory_mode not in ["disk", "partial"]
def __len__(self) -> int:
return self.total_sentence_count
@@ -309,7 +306,7 @@ class CSVClassificationCorpus(Corpus):
def __init__(
self,
data_folder: Union[str, Path],
- column_name_map: Dict[int, str],
+ column_name_map: dict[int, str],
label_type: str,
name: str = "csv_corpus",
train_file=None,
@@ -404,7 +401,7 @@ class CSVClassificationDataset(FlairDataset):
def __init__(
self,
path_to_file: Union[str, Path],
- column_name_map: Dict[int, str],
+ column_name_map: dict[int, str],
label_type: str,
max_tokens_per_doc: int = -1,
max_chars_per_doc: int = -1,
@@ -453,8 +450,8 @@ def __init__(
self.total_sentence_count: int = 0
# most data sets have the token text in the first column, if not, pass 'text' as column
- self.text_columns: List[int] = []
- self.pair_columns: List[int] = []
+ self.text_columns: list[int] = []
+ self.pair_columns: list[int] = []
for column in column_name_map:
if column_name_map[column] == "text":
self.text_columns.append(column)
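# Usage sketch: map CSV columns to roles when loading a classification corpus (the folder,
# column indices and label type below are placeholder assumptions).
from flair.datasets import CSVClassificationCorpus

csv_corpus = CSVClassificationCorpus("resources/my_csv_data", {0: "label", 1: "text"}, label_type="topic")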
@@ -567,7 +564,7 @@ class AMAZON_REVIEWS(ClassificationCorpus):
def __init__(
self,
split_max: int = 30000,
- label_name_map: Dict[str, str] = {
+ label_name_map: dict[str, str] = {
"1.0": "NEGATIVE",
"2.0": "NEGATIVE",
"3.0": "NEGATIVE",
@@ -955,9 +952,10 @@ def __init__(
original_filenames = original_filenames[:-1]
if not data_file.is_file():
for original_filename, new_filename in zip(original_filenames, new_filenames):
- with open(data_folder / "original" / original_filename, encoding="utf-8") as open_fp, open(
- data_folder / new_filename, "w", encoding="utf-8"
- ) as write_fp:
+ with (
+ open(data_folder / "original" / original_filename, encoding="utf-8") as open_fp,
+ open(data_folder / new_filename, "w", encoding="utf-8") as write_fp,
+ ):
csv_reader = csv.reader(
open_fp, quotechar='"', delimiter=",", quoting=csv.QUOTE_ALL, skipinitialspace=True
)
@@ -1048,9 +1046,10 @@ def __init__(
label_list.append(labels[int(line) - 1])
# handle data file
- with (data_path / "original" / "title_StackOverflow.txt").open(encoding="latin1") as open_fp, (
- data_folder / "train.txt"
- ).open("w", encoding="utf-8") as write_fp:
+ with (
+ (data_path / "original" / "title_StackOverflow.txt").open(encoding="latin1") as open_fp,
+ (data_folder / "train.txt").open("w", encoding="utf-8") as write_fp,
+ ):
for idx, line in enumerate(open_fp):
line = line.rstrip()
@@ -1104,9 +1103,10 @@ def __init__(
os.makedirs(data_folder)
# create train.txt file from CSV
- with open(data_folder / "train.txt", "w") as train_file, open(
- senteval_folder / "training.1600000.processed.noemoticon.csv", encoding="latin-1"
- ) as csv_train:
+ with (
+ open(data_folder / "train.txt", "w") as train_file,
+ open(senteval_folder / "training.1600000.processed.noemoticon.csv", encoding="latin-1") as csv_train,
+ ):
csv_reader = csv.reader(csv_train)
for row in csv_reader:
@@ -1115,9 +1115,10 @@ def __init__(
train_file.write(f"__label__{label} {text}\n")
# create test.txt file from CSV
- with (data_folder / "test.txt").open("w", encoding="utf-8") as train_file, (
- senteval_folder / "testdata.manual.2009.06.14.csv"
- ).open(encoding="latin-1") as csv_train:
+ with (
+ (data_folder / "test.txt").open("w", encoding="utf-8") as train_file,
+ (senteval_folder / "testdata.manual.2009.06.14.csv").open(encoding="latin-1") as csv_train,
+ ):
csv_reader = csv.reader(csv_train)
for row in csv_reader:
@@ -1384,9 +1385,10 @@ def __init__(
# create train dev and test files in fasttext format
for new_filename, original_filename in zip(new_filenames, original_filenames):
- with open(data_folder / new_filename, "a") as out_file, open(
- data_folder / "raw" / original_filename
- ) as in_file:
+ with (
+ open(data_folder / new_filename, "a") as out_file,
+ open(data_folder / "raw" / original_filename) as in_file,
+ ):
for line in in_file:
fields = line.split("\t")
label = "POSITIVE" if fields[1].rstrip() == "1" else "NEGATIVE"
@@ -1437,9 +1439,10 @@ def __init__(
# convert to FastText format
for split in ["train", "dev", "test"]:
- with (data_folder / f"{split}.txt").open("w", encoding="utf-8") as train_file, (
- data_folder / "raw" / f"stsa.fine.{split}"
- ).open(encoding="latin1") as file:
+ with (
+ (data_folder / f"{split}.txt").open("w", encoding="utf-8") as train_file,
+ (data_folder / "raw" / f"stsa.fine.{split}").open(encoding="latin1") as file,
+ ):
for line in file:
train_file.write(f"__label__{line[0]} {line[2:]}")
@@ -1496,9 +1499,10 @@ def __init__(
# create train and dev splits in fasttext format
for split in ["train", "dev"]:
- with open(data_folder / "CoLA" / (split + ".txt"), "a") as out_file, open(
- data_folder / "CoLA" / "original" / (split + ".tsv")
- ) as in_file:
+ with (
+ open(data_folder / "CoLA" / (split + ".txt"), "a") as out_file,
+ open(data_folder / "CoLA" / "original" / (split + ".tsv")) as in_file,
+ ):
for line in in_file:
fields = line.rstrip().split("\t")
label = int(fields[1])
@@ -1506,9 +1510,10 @@ def __init__(
out_file.write(f"__label__{label_map[label]} {sentence}\n")
# create eval_dataset file with no labels
- with open(data_folder / "CoLA" / "eval_dataset.txt", "a") as out_file, open(
- data_folder / "CoLA" / "original" / "test.tsv"
- ) as in_file:
+ with (
+ open(data_folder / "CoLA" / "eval_dataset.txt", "a") as out_file,
+ open(data_folder / "CoLA" / "original" / "test.tsv") as in_file,
+ ):
for line in in_file:
fields = line.rstrip().split("\t")
sentence = fields[1]
@@ -1702,9 +1707,10 @@ def __init__(
data_path = flair.cache_root / "datasets" / dataset_name / "raw"
# create correctly formated txt files
for name in ["train", "test", "dev"]:
- with (data_folder / (name + ".txt")).open("w", encoding="utf-8") as txt_file, (
- data_path / (name + ".tsv")
- ).open(encoding="utf-8") as tsv_file:
+ with (
+ (data_folder / (name + ".txt")).open("w", encoding="utf-8") as txt_file,
+ (data_path / (name + ".tsv")).open(encoding="utf-8") as tsv_file,
+ ):
lines = tsv_file.readlines()
for line in lines:
row = line.split("\t")
@@ -1764,9 +1770,10 @@ def __init__(
if not data_file.is_file():
for original_filename, new_filename in zip(original_filenames, new_filenames):
- with (data_folder / "original" / original_filename).open(encoding="latin1") as open_fp, (
- data_folder / new_filename
- ).open("w", encoding="utf-8") as write_fp:
+ with (
+ (data_folder / "original" / original_filename).open(encoding="latin1") as open_fp,
+ (data_folder / new_filename).open("w", encoding="utf-8") as write_fp,
+ ):
for line in open_fp:
line = line.rstrip()
fields = line.split()
@@ -1820,9 +1827,10 @@ def __init__(
if not data_file.is_file():
for original_filename, new_filename in zip(original_filenames, new_filenames):
- with (data_folder / "original" / original_filename).open(encoding="latin1") as open_fp, (
- data_folder / new_filename
- ).open("w", encoding="utf-8") as write_fp:
+ with (
+ (data_folder / "original" / original_filename).open(encoding="latin1") as open_fp,
+ (data_folder / new_filename).open("w", encoding="utf-8") as write_fp,
+ ):
for line in open_fp:
line = line.rstrip()
fields = line.split()
@@ -1887,21 +1895,20 @@ def __init__(
if not (data_folder / "train.txt").is_file():
cached_path(url, original)
- import tarfile
-
- tar = tarfile.open(original / "yahoo_answers_csv.tgz", "r:gz")
- members = []
+ with tarfile.open(original / "yahoo_answers_csv.tgz", "r:gz") as tar:
+ members = []
- for member in tar.getmembers():
- if "test.csv" in member.name or "train.csv" in member.name:
- members.append(member)
+ for member in tar.getmembers():
+ if "test.csv" in member.name or "train.csv" in member.name:
+ members.append(member)
- tar.extractall(original, members=members)
+ tar.extractall(original, members=members)
for name in ["train", "test"]:
- with (original / "yahoo_answers_csv" / (name + ".csv")).open(encoding="utf-8") as file, (
- data_folder / (name + ".txt")
- ).open("w", encoding="utf-8") as writer:
+ with (
+ (original / "yahoo_answers_csv" / (name + ".csv")).open(encoding="utf-8") as file,
+ (data_folder / (name + ".txt")).open("w", encoding="utf-8") as writer,
+ ):
reader = csv.reader(file)
for row in reader:
writer.write("__label__" + label_map[row[0]] + " " + row[1] + "\n")
@@ -1963,9 +1970,10 @@ def __init__(
if not data_file.is_file():
for original_filename, new_filename in zip(original_filenames, new_filenames):
- with (data_folder / "original" / original_filename).open(encoding="utf-8") as open_fp, (
- data_folder / task_setting / new_filename
- ).open("w", encoding="utf-8") as write_fp:
+ with (
+ (data_folder / "original" / original_filename).open(encoding="utf-8") as open_fp,
+ (data_folder / task_setting / new_filename).open("w", encoding="utf-8") as write_fp,
+ ):
for line in open_fp:
line = line.rstrip()
fields = line.split("\t")
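# Illustration (not part of the patch): the hunks above rewrite chained context managers
# into the parenthesized form (officially part of the grammar in Python 3.10, already
# accepted by the 3.9 PEG parser). A minimal, self-contained sketch of the before/after
# shapes, using throwaway temp files:
import tempfile
from pathlib import Path

tmp = Path(tempfile.mkdtemp())
(tmp / "in.txt").write_text("hello\n", encoding="utf-8")

# old style: managers chained on one logical line, broken by wrapping the call arguments
with open(tmp / "in.txt", encoding="utf-8") as src, open(
    tmp / "out.txt", "w", encoding="utf-8"
) as dst:
    dst.write(src.read())

# new style: the whole manager list sits inside parentheses, one per line, trailing
# comma allowed -- same runtime behaviour, easier to read and to diff
with (
    open(tmp / "in.txt", encoding="utf-8") as src,
    open(tmp / "out.txt", "a", encoding="utf-8") as dst,
):
    dst.write(src.read())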
diff --git a/flair/datasets/entity_linking.py b/flair/datasets/entity_linking.py
index 20f2caefdd..f74bad7092 100644
--- a/flair/datasets/entity_linking.py
+++ b/flair/datasets/entity_linking.py
@@ -4,8 +4,9 @@
import logging
import os
import re
+from collections.abc import Iterable, Iterator
from pathlib import Path
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
+from typing import Any, Optional, Union
import requests
from bioc import biocxml, pubtator
@@ -47,7 +48,7 @@ def __init__(
self._idx_to_candidates = {candidate.concept_id: candidate for candidate in candidates}
# one name can map to multiple concepts
- self._text_to_index: Dict[str, List[str]] = {}
+ self._text_to_index: dict[str, list[str]] = {}
for candidate in candidates:
for text in [candidate.concept_name, *candidate.synonyms]:
if text not in self._text_to_index:
@@ -60,11 +61,11 @@ def database_name(self) -> str:
return self._dataset_name
@property
- def text_to_index(self) -> Dict[str, List[str]]:
+ def text_to_index(self) -> dict[str, list[str]]:
return self._text_to_index
@property
- def candidates(self) -> List[EntityCandidate]:
+ def candidates(self) -> list[EntityCandidate]:
return list(self._idx_to_candidates.values())
def __getitem__(self, item: str) -> EntityCandidate:
@@ -80,18 +81,18 @@ def to_in_memory_dictionary(self) -> "InMemoryEntityLinkingDictionary":
# NOTE: EntityLinkingDictionary are lazy-loaded from a preprocessed file.
# Use this class to load into memory all candidates
class InMemoryEntityLinkingDictionary(EntityLinkingDictionary):
- def __init__(self, candidates: List[EntityCandidate], dataset_name: str):
+ def __init__(self, candidates: list[EntityCandidate], dataset_name: str):
self._dataset_name = dataset_name
super().__init__(candidates, dataset_name=dataset_name)
- def to_state(self) -> Dict[str, Any]:
+ def to_state(self) -> dict[str, Any]:
return {
"dataset_name": self._dataset_name,
"candidates": [candidate.to_dict() for candidate in self._idx_to_candidates.values()],
}
@classmethod
- def from_state(cls, state: Dict[str, Any]) -> "InMemoryEntityLinkingDictionary":
+ def from_state(cls, state: dict[str, Any]) -> "InMemoryEntityLinkingDictionary":
return cls(
dataset_name=state["dataset_name"],
candidates=[EntityCandidate(**candidate) for candidate in state["candidates"]],
@@ -488,7 +489,7 @@ def __init__(
to point to a different folder but typically this should not be necessary.
in_memory: bool
If True, keeps dataset in memory giving speedups in training.
- column_format: Dict[int, str]
+ column_format: dict[int, str]
The column-format to specify which columns correspond to the text or label types.
"""
base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
@@ -776,9 +777,10 @@ def __init__(
wiki_language + "_dev.tsv",
],
):
- with open(doc_path, encoding="utf-8") as read, open(
- data_folder / file_name, "w", encoding="utf-8"
- ) as write:
+ with (
+ open(doc_path, encoding="utf-8") as read,
+ open(data_folder / file_name, "w", encoding="utf-8") as write,
+ ):
# ignore first line
read.readline()
line = read.readline()
@@ -1208,9 +1210,10 @@ def __init__(
if not parsed_dataset.exists():
original_file_path = cached_path(f"{tweeki_gold_el_path}", Path("datasets") / dataset_name)
- with open(original_file_path, encoding="utf-8") as read, open(
- parsed_dataset, "w", encoding="utf-8"
- ) as write:
+ with (
+ open(original_file_path, encoding="utf-8") as read,
+ open(parsed_dataset, "w", encoding="utf-8") as write,
+ ):
line = read.readline()
while line:
if line.startswith("#"):
@@ -1274,9 +1277,10 @@ def __init__(
with open(data_folder / corpus_file_name, "w", encoding="utf-8") as txtout:
# First parse the post titles
- with open(data_folder / "posts.tsv", encoding="utf-8") as tsvin1, open(
- data_folder / "gold_post_annotations.tsv", encoding="utf-8"
- ) as tsvin2:
+ with (
+ open(data_folder / "posts.tsv", encoding="utf-8") as tsvin1,
+ open(data_folder / "gold_post_annotations.tsv", encoding="utf-8") as tsvin2,
+ ):
posts = csv.reader(tsvin1, delimiter="\t")
self.post_annotations = csv.reader(tsvin2, delimiter="\t")
self.curr_annot = next(self.post_annotations)
@@ -1312,13 +1316,14 @@ def __init__(
)
# Then parse the comments
- with open(data_folder / "comments.tsv", encoding="utf-8") as tsvin3, open(
- data_folder / "gold_comment_annotations.tsv", encoding="utf-8"
- ) as tsvin4:
+ with (
+ open(data_folder / "comments.tsv", encoding="utf-8") as tsvin3,
+ open(data_folder / "gold_comment_annotations.tsv", encoding="utf-8") as tsvin4,
+ ):
self.comments = csv.reader(tsvin3, delimiter="\t")
self.comment_annotations = csv.reader(tsvin4, delimiter="\t")
self.curr_annot = next(self.comment_annotations)
- self.curr_row: Optional[List[str]] = next(self.comments)
+ self.curr_row: Optional[list[str]] = next(self.comments)
self.stop_iter = False
# Iterate over the comments.tsv file, until the end is reached
@@ -1545,7 +1550,7 @@ def make_line(word, begin_or_inside, attributes):
return line
- def split_span(word_fields: List[str], datasetname: str):
+ def split_span(word_fields: list[str], datasetname: str):
"""Function that splits a word if necessary, i.e. if it is a multiple-word-span.
Parameters
@@ -1646,12 +1651,12 @@ def determine_tsv_file(filename: str, data_folder: Path, cut_multisense: bool =
class WSD_UFSAC(MultiCorpus):
def __init__(
self,
- filenames: Union[str, List[str]] = ["masc", "semcor"],
+ filenames: Union[str, list[str]] = ["masc", "semcor"],
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = True,
cut_multisense: bool = True,
columns={0: "text", 3: "sense"},
- banned_sentences: Optional[List[str]] = None,
+ banned_sentences: Optional[list[str]] = None,
sample_missing_splits_in_multicorpus: Union[bool, str] = True,
sample_missing_splits_in_each_corpus: Union[bool, str] = True,
use_raganato_ALL_as_test_data: bool = False,
@@ -1713,7 +1718,7 @@ def __init__(
if isinstance(filenames, str):
filenames = [filenames]
- corpora: List[Corpus] = []
+ corpora: list[Corpus] = []
log.info("Transforming data into column format and creating corpora...")
@@ -1784,8 +1789,8 @@ def __init__(
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = True,
columns={0: "text", 3: "sense"},
- label_name_map: Optional[Dict[str, str]] = None,
- banned_sentences: Optional[List[str]] = None,
+ label_name_map: Optional[dict[str, str]] = None,
+ banned_sentences: Optional[list[str]] = None,
sample_missing_splits: bool = True,
cut_multisense: bool = True,
) -> None:
@@ -1847,8 +1852,8 @@ def __init__(
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = True,
columns={0: "text", 3: "sense"},
- label_name_map: Optional[Dict[str, str]] = None,
- banned_sentences: Optional[List[str]] = None,
+ label_name_map: Optional[dict[str, str]] = None,
+ banned_sentences: Optional[list[str]] = None,
sample_missing_splits: Union[bool, str] = True,
cut_multisense: bool = True,
use_raganato_ALL_as_test_data: bool = False,
@@ -1922,8 +1927,8 @@ def __init__(
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = True,
columns={0: "text", 3: "sense"},
- label_name_map: Optional[Dict[str, str]] = None,
- banned_sentences: Optional[List[str]] = None,
+ label_name_map: Optional[dict[str, str]] = None,
+ banned_sentences: Optional[list[str]] = None,
sample_missing_splits: Union[bool, str] = True,
use_raganato_ALL_as_test_data: bool = False,
) -> None:
@@ -1994,8 +1999,8 @@ def __init__(
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = True,
columns={0: "text", 3: "sense"},
- label_name_map: Optional[Dict[str, str]] = None,
- banned_sentences: Optional[List[str]] = None,
+ label_name_map: Optional[dict[str, str]] = None,
+ banned_sentences: Optional[list[str]] = None,
sample_missing_splits: Union[bool, str] = True,
cut_multisense: bool = True,
use_raganato_ALL_as_test_data: bool = False,
@@ -2070,8 +2075,8 @@ def __init__(
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = True,
columns={0: "text", 3: "sense"},
- label_name_map: Optional[Dict[str, str]] = None,
- banned_sentences: Optional[List[str]] = None,
+ label_name_map: Optional[dict[str, str]] = None,
+ banned_sentences: Optional[list[str]] = None,
sample_missing_splits: Union[bool, str] = True,
cut_multisense: bool = True,
use_raganato_ALL_as_test_data: bool = False,
@@ -2147,8 +2152,8 @@ def __init__(
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = True,
columns={0: "text", 3: "sense"},
- label_name_map: Optional[Dict[str, str]] = None,
- banned_sentences: Optional[List[str]] = None,
+ label_name_map: Optional[dict[str, str]] = None,
+ banned_sentences: Optional[list[str]] = None,
sample_missing_splits: Union[bool, str] = True,
use_raganato_ALL_as_test_data: bool = False,
) -> None:
@@ -2230,7 +2235,7 @@ def __init__(
self,
base_path: Optional[Union[str, Path]] = None,
label_type: str = "el",
- norm_keys: List[str] = ["db_name", "db_id"],
+ norm_keys: list[str] = ["db_name", "db_id"],
**kwargs,
) -> None:
self.label_type = label_type
@@ -2250,14 +2255,14 @@ def __init__(
)
@abc.abstractmethod
- def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]:
+ def _download_dataset(self, data_folder: Path) -> dict[str, Union[Path, list[Path]]]:
pass
@abc.abstractmethod
- def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]:
+ def _file_to_dicts(self, filepath: Path) -> Iterator[dict[str, Any]]:
pass
- def _dict_to_sentences(self, entry: Dict[str, Any]) -> List[Sentence]:
+ def _dict_to_sentences(self, entry: dict[str, Any]) -> list[Sentence]:
entities = [entity for entity in entry["entities"] if entity["normalized"]]
tokenized_passages = [
@@ -2326,7 +2331,7 @@ def _dict_to_sentences(self, entry: Dict[str, Any]) -> List[Sentence]:
sent_s[start_token_idx : end_token_idx + 1].add_label(self.label_type, mention_id)
return passage_sentences
- def _files_to_dataset(self, paths: Union[Path, List[Path]]) -> FlairDatapointDataset:
+ def _files_to_dataset(self, paths: Union[Path, list[Path]]) -> FlairDatapointDataset:
if isinstance(paths, Path):
paths = [paths]
all_sentences = []
@@ -2347,7 +2352,7 @@ class BIGBIO_EL_NCBI_DISEASE(BigBioEntityLinkingCorpus):
def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str = "el-diseases", **kwargs) -> None:
super().__init__(base_path, label_type, **kwargs)
- def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]:
+ def _download_dataset(self, data_folder: Path) -> dict[str, Union[Path, list[Path]]]:
download_urls = {
"train": (
"NCBItrainset_corpus.txt",
@@ -2362,7 +2367,7 @@ def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Pat
"https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/NCBItestset_corpus.zip",
),
}
- results_files: Dict[str, Union[Path, List[Path]]] = {}
+ results_files: dict[str, Union[Path, list[Path]]] = {}
for split, (filename, url) in download_urls.items():
result_path = data_folder / filename
@@ -2376,7 +2381,7 @@ def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Pat
return results_files
- def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]:
+ def _file_to_dicts(self, filepath: Path) -> Iterator[dict[str, Any]]:
with open(filepath) as f:
for doc in pubtator.iterparse(f):
unified_example = {
@@ -2449,7 +2454,7 @@ class BIGBIO_EL_BC5CDR_CHEMICAL(BigBioEntityLinkingCorpus):
def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str = "el-chemical", **kwargs) -> None:
super().__init__(base_path, label_type, **kwargs)
- def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]:
+ def _download_dataset(self, data_folder: Path) -> dict[str, Union[Path, list[Path]]]:
url = "https://huggingface.co/datasets/bigbio/bc5cdr/resolve/main/CDR_Data.zip"
path = cached_path(url, data_folder)
@@ -2458,7 +2463,7 @@ def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Pat
unpack_file(path, data_folder)
assert data_folder.exists()
- results_files: Dict[str, Union[Path, List[Path]]] = {
+ results_files: dict[str, Union[Path, list[Path]]] = {
"train": data_path / "CDR_TrainingSet.BioC.xml",
"dev": data_path / "CDR_DevelopmentSet.BioC.xml",
"test": data_path / "CDR_TestSet.BioC.xml",
@@ -2497,7 +2502,7 @@ def _get_bioc_entity(self, span, db_id_key="MESH"):
"normalized": normalized,
}
- def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]:
+ def _file_to_dicts(self, filepath: Path) -> Iterator[dict[str, Any]]:
reader = biocxml.BioCXMLDocumentReader(str(filepath))
for i, xdoc in enumerate(reader):
@@ -2542,7 +2547,7 @@ def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str
self._re_tax_id = re.compile(r"(?P\d+)\([tT]ax:(?P\d+)\)")
super().__init__(base_path, label_type, norm_keys=["db_id"], **kwargs)
- def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]:
+ def _download_dataset(self, data_folder: Path) -> dict[str, Union[Path, list[Path]]]:
url = "https://www.ncbi.nlm.nih.gov/CBBresearch/Lu/Demo/tmTools/download/GNormPlus/GNormPlusCorpus.zip"
path = cached_path(url, data_folder)
@@ -2551,7 +2556,7 @@ def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Pat
unpack_file(path, data_folder)
assert data_folder.exists()
- results_files: Dict[str, Union[Path, List[Path]]] = {
+ results_files: dict[str, Union[Path, list[Path]]] = {
"train": [data_path / "BC2GNtrain.BioC.xml", data_path / "NLMIAT.BioC.xml"],
"test": data_path / "BC2GNtest.BioC.xml",
}
@@ -2595,7 +2600,7 @@ def _parse_bioc_entity(self, span, db_id_key="NCBIGene", insert_tax_id=False):
"normalized": normalized,
}
- def _adjust_entity_offsets(self, text: str, entities: List[Dict]):
+ def _adjust_entity_offsets(self, text: str, entities: list[dict]):
for entity in entities:
start, end = entity["offsets"][0]
entity_mention = entity["text"][0]
@@ -2605,7 +2610,7 @@ def _adjust_entity_offsets(self, text: str, entities: List[Dict]):
elif text[start : end - 1] == entity_mention:
entity["offsets"] = [(start, end - 1)]
- def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]:
+ def _file_to_dicts(self, filepath: Path) -> Iterator[dict[str, Any]]:
with filepath.open("r") as f:
collection = biocxml.load(f)
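# Illustration (not part of the patch): the annotation changes in entity_linking.py
# follow PEP 585 -- builtin dict/list/tuple are subscriptable from Python 3.9, and the
# abstract container types now come from collections.abc rather than typing. A small
# self-contained sketch with hypothetical names:
from collections.abc import Iterator

def build_text_to_index(names: list[str]) -> dict[str, list[int]]:
    # map each surface form to every position where it occurs (builtin generics only)
    index: dict[str, list[int]] = {}
    for position, name in enumerate(names):
        index.setdefault(name, []).append(position)
    return index

def iter_mentions(index: dict[str, list[int]]) -> Iterator[tuple[str, int]]:
    for name, positions in index.items():
        for position in positions:
            yield name, position

assert build_text_to_index(["aspirin", "aspirin", "ibuprofen"]) == {
    "aspirin": [0, 1],
    "ibuprofen": [2],
}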
diff --git a/flair/datasets/ocr.py b/flair/datasets/ocr.py
index bf60b2b0d6..4a58e4e7d3 100644
--- a/flair/datasets/ocr.py
+++ b/flair/datasets/ocr.py
@@ -1,6 +1,6 @@
import json
from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Optional, Union
import gdown.download_folder
import PIL
@@ -20,7 +20,7 @@ def __init__(
encoding: str = "utf-8",
load_images: bool = False,
normalize_coords_to_thousands: bool = True,
- label_name_map: Optional[Dict[str, str]] = None,
+ label_name_map: Optional[dict[str, str]] = None,
) -> None:
"""Instantiates a Dataset from a OCR-Json format.
@@ -132,7 +132,7 @@ def __init__(
in_memory: bool = True,
load_images: bool = False,
normalize_coords_to_thousands: bool = True,
- label_name_map: Optional[Dict[str, str]] = None,
+ label_name_map: Optional[dict[str, str]] = None,
**corpusargs,
) -> None:
"""Instantiates a Corpus from a OCR-Json format.
@@ -205,7 +205,7 @@ def __init__(
in_memory: bool = True,
load_images: bool = False,
normalize_coords_to_thousands: bool = True,
- label_name_map: Optional[Dict[str, str]] = None,
+ label_name_map: Optional[dict[str, str]] = None,
**corpusargs,
) -> None:
"""Instantiates the SROIE corpus with perfect ocr boxes.
diff --git a/flair/datasets/relation_extraction.py b/flair/datasets/relation_extraction.py
index 30709a14c4..871811abc2 100644
--- a/flair/datasets/relation_extraction.py
+++ b/flair/datasets/relation_extraction.py
@@ -5,8 +5,9 @@
import os
import re
from collections import defaultdict
+from collections.abc import Iterable
from pathlib import Path
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Optional, Union
import conllu
import gdown
@@ -279,7 +280,7 @@ def extract_and_convert_to_conllu(self, data_file, data_folder):
token_list = self._tacred_example_to_token_list(example)
target_file.write(token_list.serialize())
- def _tacred_example_to_token_list(self, example: Dict[str, Any]) -> conllu.TokenList:
+ def _tacred_example_to_token_list(self, example: dict[str, Any]) -> conllu.TokenList:
id_ = example["id"]
tokens = example["token"]
ner = example["stanford_ner"]
@@ -379,7 +380,7 @@ def _parse_incr(self, source_file) -> Iterable[conllu.TokenList]:
}
metadata_parsers = {"__fallback__": lambda k, v: tuple(k.split())}
- lines: List[str] = []
+ lines: list[str] = []
for index, line in enumerate(source_file):
if index > 0 and line.startswith("#"):
source_str = "".join(lines)
@@ -416,9 +417,10 @@ def convert_to_conllu(self, source_data_folder: Path, data_folder):
]
for source_filename, target_filename in zip(source_filenames, target_filenames):
- with (source_data_folder / source_filename).open(encoding="utf-8") as source_file, (
- data_folder / target_filename
- ).open("w", encoding="utf-8") as target_file:
+ with (
+ (source_data_folder / source_filename).open(encoding="utf-8") as source_file,
+ (data_folder / target_filename).open("w", encoding="utf-8") as target_file,
+ ):
# write CoNLL-U Plus header
target_file.write("# global.columns = id form ner\n")
@@ -426,7 +428,7 @@ def convert_to_conllu(self, source_data_folder: Path, data_folder):
token_list = self._src_token_list_to_token_list(src_token_list)
target_file.write(token_list.serialize())
- def _bio_tags_to_spans(self, tags: List[str]) -> List[Tuple[int, int]]:
+ def _bio_tags_to_spans(self, tags: list[str]) -> list[tuple[int, int]]:
spans = []
span_start = 0
span_end = 0
@@ -590,7 +592,7 @@ def extract_and_convert_to_conllu(self, data_file, data_folder):
ent2 = arg2.split(":")[1]
pmid_to_relations[pmid].add((rel_type, ent1, ent2))
- tokenlists: List[conllu.TokenList] = []
+ tokenlists: list[conllu.TokenList] = []
with zip_file.open(
f"drugprot-gs-training-development/{split}/drugprot_{split}_abstracs.tsv"
) as abstracts_file:
@@ -652,13 +654,13 @@ def has_overlap(self, a, b):
def drugprot_document_to_tokenlists(
self,
pmid: str,
- title_sentences: List[Sentence],
- abstract_sentences: List[Sentence],
+ title_sentences: list[Sentence],
+ abstract_sentences: list[Sentence],
abstract_offset: int,
- entities: Dict[str, Tuple[str, int, int, str]],
- relations: Set[Tuple[str, str, str]],
- ) -> List[conllu.TokenList]:
- tokenlists: List[conllu.TokenList] = []
+ entities: dict[str, tuple[str, int, int, str]],
+ relations: set[tuple[str, str, str]],
+ ) -> list[conllu.TokenList]:
+ tokenlists: list[conllu.TokenList] = []
sentence_id = 1
for offset, sents in [
(0, title_sentences),
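# Illustration (not part of the patch): _bio_tags_to_spans above now returns
# list[tuple[int, int]]. An independent sketch of the BIO-to-span idea it names
# (token-index spans, end exclusive; not the method's exact code):
def bio_tags_to_spans(tags: list[str]) -> list[tuple[int, int]]:
    spans: list[tuple[int, int]] = []
    start = None
    for i, tag in enumerate(tags):
        if tag.startswith("B"):
            if start is not None:  # a new B closes any open span
                spans.append((start, i))
            start = i
        elif tag.startswith("I") and start is not None:
            continue  # extend the open span
        else:
            if start is not None:  # "O" (or an "I" with no opener) closes the open span
                spans.append((start, i))
            start = None
    if start is not None:
        spans.append((start, len(tags)))
    return spans

assert bio_tags_to_spans(["B", "I", "O", "B"]) == [(0, 2), (3, 4)]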
diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py
index 613ee639d0..bf5aa4e9e9 100644
--- a/flair/datasets/sequence_labeling.py
+++ b/flair/datasets/sequence_labeling.py
@@ -1,15 +1,21 @@
import copy
+import gzip
import json
import logging
import os
import re
import shutil
+import tarfile
+import tempfile
+import zipfile
from collections import defaultdict
+from collections.abc import Iterable, Iterator
from pathlib import Path
-from typing import Any, DefaultDict, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union, cast
+from typing import Any, Optional, Union, cast
-
from torch.utils.data import ConcatDataset, Dataset, Subset
+import requests
+
import flair
from flair.data import (
Corpus,
@@ -215,7 +221,7 @@ def __init__(
self.label_type = label_type
self.path_to_json_file = path_to_json_file
- self.sentences: List[Sentence] = []
+ self.sentences: list[Sentence] = []
with path_to_json_file.open(encoding=encoding) as jsonl_fp:
for line in jsonl_fp:
current_line = json.loads(line)
@@ -229,7 +235,7 @@ def __init__(
self.sentences.append(current_sentence)
- def _add_labels_to_sentence(self, raw_text: str, sentence: Sentence, labels: List[List[Any]]):
+ def _add_labels_to_sentence(self, raw_text: str, sentence: Sentence, labels: list[list[Any]]):
# Add tags for each annotated span
for label in labels:
self._add_label_to_sentence(raw_text, sentence, label[0], label[1], label[2])
@@ -279,7 +285,7 @@ def _add_label_to_sentence(self, text: str, sentence: Sentence, start: int, end:
sentence[start_idx : end_idx + 1].add_label(self.label_type, label)
- def _add_metadatas_to_sentence(self, sentence: Sentence, metadatas: List[Tuple[str, str]]):
+ def _add_metadatas_to_sentence(self, sentence: Sentence, metadatas: list[tuple[str, str]]):
# Add metadatas for sentence
for metadata in metadatas:
self._add_metadata_to_sentence(sentence, metadata[0], metadata[1])
@@ -304,7 +310,7 @@ def __getitem__(self, index: int) -> Sentence:
class MultiFileColumnCorpus(Corpus):
def __init__(
self,
- column_format: Dict[int, str],
+ column_format: dict[int, str],
train_files=None,
test_files=None,
dev_files=None,
@@ -314,8 +320,8 @@ def __init__(
document_separator_token: Optional[str] = None,
skip_first_line: bool = False,
in_memory: bool = True,
- label_name_map: Optional[Dict[str, str]] = None,
- banned_sentences: Optional[List[str]] = None,
+ label_name_map: Optional[dict[str, str]] = None,
+ banned_sentences: Optional[list[str]] = None,
default_whitespace_after: int = 1,
**corpusargs,
) -> None:
@@ -415,7 +421,7 @@ class ColumnCorpus(MultiFileColumnCorpus):
def __init__(
self,
data_folder: Union[str, Path],
- column_format: Dict[int, str],
+ column_format: dict[int, str],
train_file=None,
test_file=None,
dev_file=None,
@@ -609,15 +615,15 @@ class ColumnDataset(FlairDataset):
def __init__(
self,
path_to_column_file: Union[str, Path],
- column_name_map: Dict[int, str],
+ column_name_map: dict[int, str],
column_delimiter: str = r"\s+",
comment_symbol: Optional[str] = None,
- banned_sentences: Optional[List[str]] = None,
+ banned_sentences: Optional[list[str]] = None,
in_memory: bool = True,
document_separator_token: Optional[str] = None,
encoding: str = "utf-8",
skip_first_line: bool = False,
- label_name_map: Optional[Dict[str, str]] = None,
+ label_name_map: Optional[dict[str, str]] = None,
default_whitespace_after: int = 1,
) -> None:
r"""Instantiates a column dataset.
@@ -671,7 +677,7 @@ def __init__(
# option 1: keep Sentence objects in memory
if self.in_memory:
- self.sentences: List[Sentence] = []
+ self.sentences: list[Sentence] = []
# pointer to previous
previous_sentence = None
@@ -713,7 +719,7 @@ def __init__(
# option 2: keep source data in memory
if not self.in_memory:
- self.sentences_raw: List[List[str]] = []
+ self.sentences_raw: list[list[str]] = []
while True:
# read lines for next sentence, but don't parse
@@ -813,10 +819,10 @@ def _read_next_sentence(self, file):
return lines
def _convert_lines_to_sentence(
- self, lines, word_level_tag_columns: Dict[int, str], span_level_tag_columns: Optional[Dict[int, str]] = None
+ self, lines, word_level_tag_columns: dict[int, str], span_level_tag_columns: Optional[dict[int, str]] = None
):
token: Optional[Token] = None
- tokens: List[Token] = []
+ tokens: list[Token] = []
filtered_lines = []
comments = []
for line in lines:
@@ -883,9 +889,9 @@ def _convert_lines_to_sentence(
return sentence
return None
- def _parse_token(self, line: str, column_name_map: Dict[int, str], last_token: Optional[Token] = None) -> Token:
+ def _parse_token(self, line: str, column_name_map: dict[int, str], last_token: Optional[Token] = None) -> Token:
# get fields from line
- fields: List[str] = self.column_delimiter.split(line.rstrip())
+ fields: list[str] = self.column_delimiter.split(line.rstrip())
field_count = len(fields)
# get head_id if exists (only in dependency parses)
head_id = int(fields[self.head_id_column]) if self.head_id_column else None
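# Illustration (not part of the patch): column_format / column_name_map, now typed as
# dict[int, str], maps a column index in the file to a layer name. A hedged usage
# sketch with a hypothetical two-column CoNLL-style data folder:
from flair.datasets import ColumnCorpus

corpus = ColumnCorpus(
    data_folder="resources/my_ner_data",    # hypothetical folder with the three files below
    column_format={0: "text", 1: "ner"},    # column 0 is the token, column 1 its NER tag
    train_file="train.txt",
    dev_file="dev.txt",
    test_file="test.txt",
)
print(corpus)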
@@ -1009,7 +1015,7 @@ def __init__(
base_path: Optional[Union[str, Path]] = None,
version: str = "v4",
language: str = "english",
- domain: Union[None, str, List[str], Dict[str, Union[None, str, List[str]]]] = None,
+ domain: Union[None, str, list[str], dict[str, Union[None, str, list[str]]]] = None,
in_memory: bool = True,
**corpusargs,
) -> None:
@@ -1047,7 +1053,7 @@ def get_available_domains(
version: str = "v4",
language: str = "english",
split: str = "train",
- ) -> List[str]:
+ ) -> list[str]:
processed_data_path = cls._ensure_data_processed(base_path=base_path, language=language, version=version)
processed_split_path = processed_data_path / "splits" / version / language / split
@@ -1061,7 +1067,7 @@ def _get_processed_file_paths(
split: str = "train",
version: str = "v4",
language: str = "english",
- domain: Optional[Union[str, List[str], Dict[str, Union[None, str, List[str]]]]] = None,
+ domain: Optional[Union[str, list[str], dict[str, Union[None, str, list[str]]]]] = None,
) -> Iterable[Path]:
processed_split_path = processed_data_path / "splits" / version / language / split
@@ -1163,8 +1169,8 @@ def _process_coref_span_annotations_for_word(
cls,
label: str,
word_index: int,
- clusters: DefaultDict[int, List[Tuple[int, int]]],
- coref_stacks: DefaultDict[int, List[int]],
+ clusters: defaultdict[int, list[tuple[int, int]]],
+ coref_stacks: defaultdict[int, list[int]],
) -> None:
"""For a given coref label, add it to a currently open span(s), complete a span(s) or ignore it, if it is outside of all spans.
@@ -1202,9 +1208,9 @@ def _process_coref_span_annotations_for_word(
@classmethod
def _process_span_annotations_for_word(
cls,
- annotations: List[str],
- span_labels: List[List[str]],
- current_span_labels: List[Optional[str]],
+ annotations: list[str],
+ span_labels: list[list[str]],
+ current_span_labels: list[Optional[str]],
) -> None:
for annotation_index, annotation in enumerate(annotations):
# strip all bracketing information to
@@ -1230,33 +1236,33 @@ def _process_span_annotations_for_word(
current_span_labels[annotation_index] = None
@classmethod
- def _conll_rows_to_sentence(cls, conll_rows: List[str]) -> Dict:
+ def _conll_rows_to_sentence(cls, conll_rows: list[str]) -> dict:
document_id: str
sentence_id: int
# The words in the sentence.
- sentence: List[str] = []
+ sentence: list[str] = []
# The pos tags of the words in the sentence.
- pos_tags: List[str] = []
+ pos_tags: list[str] = []
# the pieces of the parse tree.
- parse_pieces: List[Optional[str]] = []
+ parse_pieces: list[Optional[str]] = []
# The lemmatised form of the words in the sentence which
# have SRL or word sense information.
- predicate_lemmas: List[Optional[str]] = []
+ predicate_lemmas: list[Optional[str]] = []
# The FrameNet ID of the predicate.
- predicate_framenet_ids: List[Optional[str]] = []
+ predicate_framenet_ids: list[Optional[str]] = []
# The sense of the word, if available.
- word_senses: List[Optional[float]] = []
+ word_senses: list[Optional[float]] = []
# The current speaker, if available.
- speakers: List[Optional[str]] = []
+ speakers: list[Optional[str]] = []
- verbal_predicates: List[str] = []
- span_labels: List[List[str]] = []
- current_span_labels: List[Optional[str]] = []
+ verbal_predicates: list[str] = []
+ span_labels: list[list[str]] = []
+ current_span_labels: list[Optional[str]] = []
# Cluster id -> List of (start_index, end_index) spans.
- clusters: DefaultDict[int, List[Tuple[int, int]]] = defaultdict(list)
+ clusters: defaultdict[int, list[tuple[int, int]]] = defaultdict(list)
# Cluster id -> List of start_indices which are open for this id.
- coref_stacks: DefaultDict[int, List[int]] = defaultdict(list)
+ coref_stacks: defaultdict[int, list[int]] = defaultdict(list)
for index, row in enumerate(conll_rows):
conll_components = row.split()
@@ -1332,7 +1338,7 @@ def _conll_rows_to_sentence(cls, conll_rows: List[str]) -> Dict:
srl_frames = list(zip(verbal_predicates, span_labels[1:]))
# this would not be reached if parse_pieces contained None, hence the cast
- parse_tree = "".join(cast(List[str], parse_pieces)) if all(parse_pieces) else None
+ parse_tree = "".join(cast(list[str], parse_pieces)) if all(parse_pieces) else None
coref_span_tuples = {(cluster_id, span) for cluster_id, span_list in clusters.items() for span in span_list}
return {
@@ -1351,7 +1357,7 @@ def _conll_rows_to_sentence(cls, conll_rows: List[str]) -> Dict:
}
@classmethod
- def dataset_document_iterator(cls, file_path: Union[Path, str]) -> Iterator[List]:
+ def dataset_document_iterator(cls, file_path: Union[Path, str]) -> Iterator[list[dict]]:
"""An iterator over CONLL formatted files which yields documents, regardless of the number of document annotations in a particular file.
This is useful for conll data which has been preprocessed, such
@@ -1360,7 +1366,7 @@ def dataset_document_iterator(cls, file_path: Union[Path, str]) -> Iterator[List
"""
with open(file_path, encoding="utf8") as open_file:
conll_rows = []
- document: List = []
+ document: list[dict] = []
for line in open_file:
line = line.strip()
if line != "" and not line.startswith("#"):
@@ -1579,6 +1585,240 @@ def __init__(
)
+class CLEANCONLL(ColumnCorpus):
+ def __init__(
+ self,
+ base_path: Optional[Union[str, Path]] = None,
+ in_memory: bool = True,
+ **corpusargs,
+ ) -> None:
+ """Initialize the CleanCoNLL corpus.
+
+ Args:
+ base_path: Base directory for the dataset. If None, defaults to flair.cache_root / "datasets".
+ in_memory: If True, keeps dataset in memory for faster training.
+ """
+ # Set the base path for the dataset
+ base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
+
+ # Define column format
+ columns = {0: "text", 1: "pos", 2: "nel", 3: "ner*", 4: "ner"}
+
+ # Define dataset name
+ dataset_name = self.__class__.__name__.lower()
+
+ # Define data folder path
+ data_folder = base_path / dataset_name
+
+ # Check if the train data file exists, otherwise download and prepare the dataset
+ train_set = data_folder / "cleanconll.train"
+
+ if not train_set.exists():
+ print("CleanCoNLL files not found, so downloading and creating them.")
+
+ # Download and prepare the dataset
+ self.download_and_prepare_data(data_folder)
+
+ else:
+ print("Found files for CleanCoNLL in:", data_folder)
+
+ # Initialize the parent class with the specified parameters
+ super().__init__(
+ data_folder,
+ columns,
+ encoding="utf-8",
+ in_memory=in_memory,
+ document_separator_token="-DOCSTART-",
+ **corpusargs,
+ )
+
+ @staticmethod
+ def download_and_prepare_data(data_folder: Path):
+ def parse_patch(patch_file_path):
+ """Parses a patch file and returns a structured representation of the changes."""
+ changes = []
+ current_change = None
+
+ with open(patch_file_path, encoding="utf-8") as patch_file:
+ for line in patch_file:
+ # Check if the line is a change, delete or add command (like 17721c17703,17705 or 5728d5727)
+ if line and (line[0].isdigit() and ("c" in line or "d" in line or "a" in line)):
+ if current_change:
+ # Append the previous change block to the changes list
+ changes.append(current_change)
+
+ # Start a new change block
+ current_change = {"command": line, "original": [], "new": []}
+
+ # Capture original lines (those marked with "<")
+ elif line.startswith("<"):
+ if current_change:
+ current_change["original"].append(line[2:]) # Remove the "< " part
+
+ # Capture new lines (those marked with ">")
+ elif line.startswith(">"):
+ if current_change:
+ current_change["new"].append(line[2:]) # Remove the "> " part
+
+ # Append the last change block to the changes list
+ if current_change:
+ changes.append(current_change)
+
+ return changes
+
+ def parse_line_range(line_range_str):
+            """Parse an ed-style line range like '17703,17705' or '5727' into a 0-based (start, end) slice tuple."""
+ parts = line_range_str.split(",")
+ if len(parts) == 1:
+ start = int(parts[0]) - 1
+ return (start, start + 1)
+ else:
+ start = int(parts[0]) - 1
+ end = int(parts[1])
+ return (start, end)
+
+ def apply_patch_to_file(original_file, changes, output_file_path):
+ """Applies the patch instructions to the content of the original file."""
+ with open(original_file, encoding="utf-8") as f:
+ original_lines = f.readlines()
+
+ modified_lines = original_lines[:] # Make a copy of original lines
+
+ # Apply each change in reverse order (important to avoid index shift issues)
+ for change in reversed(changes):
+ command = change["command"]
+
+ # Determine the type of the change: `c` for change, `d` for delete, `a` for add
+ if "c" in command:
+ # Example command: 17721c17703,17705
+ original_line_range, new_line_range = command.split("c")
+ original_line_range = parse_line_range(original_line_range)
+ modified_lines[original_line_range[0] : original_line_range[1]] = change["new"]
+
+ elif "d" in command:
+ # Example command: 5728d5727
+ original_line_number = int(command.split("d")[0]) - 1
+ del modified_lines[original_line_number]
+
+ elif "a" in command:
+ # Example command: 1000a1001,1002
+ original_line_number = int(command.split("a")[0]) - 1
+ insertion_point = original_line_number + 1
+ for new_token in reversed(change["new"]):
+ modified_lines.insert(insertion_point, new_token)
+
+ # Write the modified content to the output file
+ with open(output_file_path, "w", encoding="utf-8") as output_file:
+ output_file.writelines(modified_lines)
+
+ def apply_patch(file_path, patch_path, output_path):
+ patch_instructions = parse_patch(patch_path)
+ apply_patch_to_file(file_path, patch_instructions, output_path)
+
+ def extract_tokens(file_path: Path, output_path: Path):
+ with open(file_path, encoding="utf-8") as f_in, open(output_path, "w", encoding="utf-8") as f_out:
+ for line in f_in:
+ # Strip whitespace to check if the line is empty
+ stripped_line = line.strip()
+ if stripped_line:
+ # Write the first token followed by a newline if the line is not empty
+ f_out.write(stripped_line.split()[0] + "\n")
+ else:
+ # Write an empty line if the line is empty
+ f_out.write("\n")
+
+ def merge_annotations(tokens_file, annotations_file, output_file):
+ with (
+ open(tokens_file, encoding="utf-8") as tokens_file,
+ open(annotations_file, encoding="utf-8") as annotations_file,
+ open(output_file, "w", encoding="utf-8") as output_file,
+ ):
+ tokens = tokens_file.readlines()
+ annotations = annotations_file.readlines()
+
+ for token, annotation in zip(tokens, annotations):
+ # Strip the leading '[TOKEN]\t' from the annotation
+ stripped_annotation = "\t".join(annotation.strip().split("\t")[1:])
+ output_file.write(token.strip() + "\t" + stripped_annotation + "\n")
+
+ # Create a temporary directory
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ tmpdir = Path(tmpdirname)
+
+ github_url = "https://github.com/flairNLP/CleanCoNLL/archive/main.zip"
+ zip_path = cached_path(github_url, tmpdir)
+ unpack_file(zip_path, tmpdir, "zip", False)
+ cleanconll_data_root = tmpdir / "CleanCoNLL-main"
+
+ # Check the contents of the temporary directory
+ print(f"Contents of the temporary directory: {list(tmpdir.iterdir())}")
+
+ conll03_dir = data_folder / "original_conll-03"
+ if conll03_dir.exists() and conll03_dir.is_dir() and "train.txt" in [f.name for f in conll03_dir.iterdir()]:
+ print(f"Original CoNLL03 files detected here: {conll03_dir}")
+
+ else:
+ conll_url = "https://data.deepai.org/conll2003.zip"
+
+ conll03_dir.mkdir(parents=True, exist_ok=True)
+ print(f"Downloading the original CoNLL03 from {conll_url} into {conll03_dir} ...")
+
+ zip_path = conll03_dir / "conll2003.zip"
+ response = requests.get(conll_url)
+ zip_path.write_bytes(response.content)
+
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
+ zip_ref.extractall(conll03_dir)
+
+ conll03_train = conll03_dir / "train.txt"
+ conll03_dev = conll03_dir / "valid.txt"
+ conll03_test = conll03_dir / "test.txt"
+
+ patch_dir = cleanconll_data_root / "data" / "patch_files"
+ tokens_dir = cleanconll_data_root / "data" / "tokens_updated"
+ tokens_dir.mkdir(parents=True, exist_ok=True)
+
+ # Extract only the tokens from the original CoNLL03 files
+ extract_tokens(conll03_train, tokens_dir / "train_tokens.txt")
+ extract_tokens(conll03_dev, tokens_dir / "valid_tokens.txt")
+ extract_tokens(conll03_test, tokens_dir / "test_tokens.txt")
+
+ # Apply the downloaded patch files to apply our token modifications (e.g. line breaks)
+ apply_patch(
+ tokens_dir / "train_tokens.txt",
+ patch_dir / "train_tokens.patch",
+ tokens_dir / "train_tokens_updated.txt",
+ )
+ apply_patch(
+ tokens_dir / "valid_tokens.txt", patch_dir / "dev_tokens.patch", tokens_dir / "dev_tokens_updated.txt"
+ )
+ apply_patch(
+ tokens_dir / "test_tokens.txt", patch_dir / "test_tokens.patch", tokens_dir / "test_tokens_updated.txt"
+ )
+
+ # Merge the updated token files with the CleanCoNLL annotations
+ cleanconll_annotations_dir = cleanconll_data_root / "data" / "cleanconll_annotations"
+ data_folder.mkdir(parents=True, exist_ok=True)
+
+ merge_annotations(
+ tokens_dir / "train_tokens_updated.txt",
+ cleanconll_annotations_dir / "cleanconll_annotations.train",
+ data_folder / "cleanconll.train",
+ )
+ merge_annotations(
+ tokens_dir / "dev_tokens_updated.txt",
+ cleanconll_annotations_dir / "cleanconll_annotations.dev",
+ data_folder / "cleanconll.dev",
+ )
+ merge_annotations(
+ tokens_dir / "test_tokens_updated.txt",
+ cleanconll_annotations_dir / "cleanconll_annotations.test",
+ data_folder / "cleanconll.test",
+ )
+
+ print("Done with creating. CleanCoNLL files are placed here:", data_folder)
+
+
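# Illustration (not part of the patch): a hedged usage sketch for the CLEANCONLL class
# above, assuming it is re-exported from flair.datasets like the other corpora. The
# first call downloads CoNLL-03, applies the ed-style token patches and writes
# cleanconll.{train,dev,test} into the cache folder; later calls reuse those files.
from flair.datasets import CLEANCONLL

corpus = CLEANCONLL()
print(corpus)                          # sentence counts per split
sentence = corpus.train[0]
print(sentence.get_labels("ner"))      # CleanCoNLL NER layer (column 4 above)
print(sentence.get_labels("nel"))      # entity-linking layer (column 2 above)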
class CONLL_2000(ColumnCorpus):
def __init__(
self,
@@ -1609,18 +1849,22 @@ def __init__(
if not data_file.is_file():
cached_path(f"{conll_2000_path}train.txt.gz", Path("datasets") / dataset_name)
cached_path(f"{conll_2000_path}test.txt.gz", Path("datasets") / dataset_name)
- import gzip
- import shutil
- with gzip.open(flair.cache_root / "datasets" / dataset_name / "train.txt.gz", "rb") as f_in, open(
- flair.cache_root / "datasets" / dataset_name / "train.txt",
- "wb",
- ) as f_out:
+ with (
+ gzip.open(flair.cache_root / "datasets" / dataset_name / "train.txt.gz", "rb") as f_in,
+ open(
+ flair.cache_root / "datasets" / dataset_name / "train.txt",
+ "wb",
+ ) as f_out,
+ ):
shutil.copyfileobj(f_in, f_out)
- with gzip.open(flair.cache_root / "datasets" / dataset_name / "test.txt.gz", "rb") as f_in, open(
- flair.cache_root / "datasets" / dataset_name / "test.txt",
- "wb",
- ) as f_out:
+ with (
+ gzip.open(flair.cache_root / "datasets" / dataset_name / "test.txt.gz", "rb") as f_in,
+ open(
+ flair.cache_root / "datasets" / dataset_name / "test.txt",
+ "wb",
+ ) as f_out,
+ ):
shutil.copyfileobj(f_in, f_out)
super().__init__(
@@ -1889,8 +2133,6 @@ def __init__(
data_file = data_path / "named_ent_eu.train"
if not data_file.is_file():
cached_path(f"{ner_basque_path}/eiec_v1.0.tgz", Path("datasets") / dataset_name)
- import shutil
- import tarfile
with tarfile.open(
flair.cache_root / "datasets" / dataset_name / "eiec_v1.0.tgz",
@@ -2401,15 +2643,13 @@ def __init__(
if not base_path:
base_path = Path(flair.cache_root) / "datasets"
data_folder = base_path / dataset_name
- import tarfile
if not os.path.isfile(data_folder / "webpages_ner.txt"):
# # download zip
tar_file = "https://cogcomp.seas.upenn.edu/Data/NERWebpagesColumns.tgz"
webpages_ner_path = cached_path(tar_file, Path("datasets") / dataset_name)
- tf = tarfile.open(webpages_ner_path)
- tf.extractall(data_folder)
- tf.close()
+ with tarfile.open(webpages_ner_path) as tf:
+ tf.extractall(data_folder)
outputfile = os.path.abspath(data_folder)
# merge the files in one as the zip is containing multiples files
@@ -2692,7 +2932,7 @@ def _add_IOB_tags(self, data_file: Union[str, Path], encoding: str = "utf8", ner
Specifies the ner-tagged column. The default is 1 (the second column).
"""
- def add_I_prefix(current_line: List[str], ner: int, tag: str):
+ def add_I_prefix(current_line: list[str], ner: int, tag: str):
for i in range(len(current_line)):
if i == 0:
f.write(line_list[i])
@@ -2933,9 +3173,11 @@ def _create_datasets(self, data_file: Union[str, Path], data_folder: Path):
train_len = round(num_lines * 0.8)
test_len = round(num_lines * 0.1)
- with (data_folder / "train.txt").open("w", encoding="utf-8") as train, (data_folder / "test.txt").open(
- "w", encoding="utf-8"
- ) as test, (data_folder / "dev.txt").open("w", encoding="utf-8") as dev:
+ with (
+ (data_folder / "train.txt").open("w", encoding="utf-8") as train,
+ (data_folder / "test.txt").open("w", encoding="utf-8") as test,
+ (data_folder / "dev.txt").open("w", encoding="utf-8") as dev,
+ ):
for k, line in enumerate(file.readlines(), start=1):
if k <= train_len:
train.write(line)
@@ -3126,7 +3368,7 @@ def __prepare_jap_wikinews_corpus(file_in: Union[str, Path], file_out: Union[str
class NER_MASAKHANE(MultiCorpus):
def __init__(
self,
- languages: Union[str, List[str]] = "luo",
+ languages: Union[str, list[str]] = "luo",
version: str = "v2",
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = True,
@@ -3210,7 +3452,7 @@ def __init__(
if languages == ["all"]:
languages = list(language_to_code.values())
- corpora: List[Corpus] = []
+ corpora: list[Corpus] = []
for language in languages:
if language in language_to_code:
language = language_to_code[language]
@@ -3393,7 +3635,7 @@ def __init__(
class NER_MULTI_WIKIANN(MultiCorpus):
def __init__(
self,
- languages: Union[str, List[str]] = "en",
+ languages: Union[str, list[str]] = "en",
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = False,
**corpusargs,
@@ -3405,7 +3647,7 @@ def __init__(
Parameters
----------
- languages : Union[str, List[str]]
+ languages : Union[str, list[str]]
Should be an abbreviation of a language ("en", "de",..) or a list of abbreviations.
The datasets of all passed languages will be saved in one MultiCorpus.
(Note that, even though listed on https://elisa-ie.github.io/wikiann/ some datasets are empty.
@@ -3436,7 +3678,7 @@ def __init__(
# this list is handed to the multicorpus
# list that contains the columncopora
- corpora: List[Corpus] = []
+ corpora: list[Corpus] = []
google_drive_path = "https://drive.google.com/uc?id="
# download data if necessary
@@ -3448,8 +3690,6 @@ def __init__(
# if language not downloaded yet, download it
if not language_folder.exists():
if first:
- import tarfile
-
import gdown
first = False
@@ -3464,10 +3704,8 @@ def __init__(
# unzip
log.info("Extracting data...")
- tar = tarfile.open(str(language_folder / language) + ".tar.gz", "r:gz")
- # tar.extractall(language_folder,members=[tar.getmember(file_name)])
- tar.extract(file_name, str(language_folder))
- tar.close()
+ with tarfile.open(str(language_folder / language) + ".tar.gz", "r:gz") as tar:
+ tar.extract(file_name, str(language_folder))
log.info("...done.")
# transform data into required format
@@ -3496,9 +3734,10 @@ def __init__(
)
def _silver_standard_to_simple_ner_annotation(self, data_file: Union[str, Path]):
- with open(data_file, encoding="utf-8") as f_read, open(
- str(data_file) + "_new", "w+", encoding="utf-8"
- ) as f_write:
+ with (
+ open(data_file, encoding="utf-8") as f_read,
+ open(str(data_file) + "_new", "w+", encoding="utf-8") as f_write,
+ ):
while True:
line = f_read.readline()
if line:
@@ -3814,7 +4053,7 @@ def _google_drive_id_from_language_name(self, language):
class NER_MULTI_XTREME(MultiCorpus):
def __init__(
self,
- languages: Union[str, List[str]] = "en",
+ languages: Union[str, list[str]] = "en",
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = False,
**corpusargs,
@@ -3826,7 +4065,7 @@ def __init__(
Parameters
----------
- languages : Union[str, List[str]], optional
+ languages : Union[str, list[str]], optional
Specify the languages you want to load. Provide an empty list or string to select all languages.
base_path : Union[str, Path], optional
Default is None, meaning that corpus gets auto-downloaded and loaded. You can override this to point to a different folder but typically this should not be necessary.
@@ -3897,7 +4136,7 @@ def __init__(
# This list is handed to the multicorpus
# list that contains the columncopora
- corpora: List[Corpus] = []
+ corpora: list[Corpus] = []
hu_path = "https://nlp.informatik.hu-berlin.de/resources/datasets/panx_dataset"
@@ -3919,12 +4158,10 @@ def __init__(
# unzip
log.info("Extracting data...")
- import tarfile
- tar = tarfile.open(str(temp_file), "r:gz")
- for part in ["train", "test", "dev"]:
- tar.extract(part, str(language_folder))
- tar.close()
+ with tarfile.open(str(temp_file), "r:gz") as tar:
+ for part in ["train", "test", "dev"]:
+ tar.extract(part, str(language_folder))
log.info("...done.")
# transform data into required format
@@ -3963,7 +4200,7 @@ def _xtreme_to_simple_ner_annotation(self, data_file: Union[str, Path]):
class NER_MULTI_WIKINER(MultiCorpus):
def __init__(
self,
- languages: Union[str, List[str]] = "en",
+ languages: Union[str, list[str]] = "en",
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = False,
**corpusargs,
@@ -3982,7 +4219,7 @@ def __init__(
data_folder = base_path / dataset_name
- corpora: List[Corpus] = []
+ corpora: list[Corpus] = []
for language in languages:
language_folder = data_folder / language
@@ -4022,11 +4259,14 @@ def _download_wikiner(self, language_code: str, dataset_name: str):
flair.cache_root / "datasets" / dataset_name / f"aij-wikiner-{lc}-wp3.bz2",
"rb",
)
- with bz_file as f, open(
- flair.cache_root / "datasets" / dataset_name / f"aij-wikiner-{lc}-wp3.train",
- "w",
- encoding="utf-8",
- ) as out:
+ with (
+ bz_file as f,
+ open(
+ flair.cache_root / "datasets" / dataset_name / f"aij-wikiner-{lc}-wp3.train",
+ "w",
+ encoding="utf-8",
+ ) as out,
+ ):
for lineb in f:
line = lineb.decode("utf-8")
words = line.split(" ")
@@ -4894,7 +5134,7 @@ def __init__(
class NER_NERMUD(MultiCorpus):
def __init__(
self,
- domains: Union[str, List[str]] = "all",
+ domains: Union[str, list[str]] = "all",
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = False,
**corpusargs,
@@ -4933,7 +5173,7 @@ def __init__(
data_folder = base_path / dataset_name
- corpora: List[Corpus] = []
+ corpora: list[Corpus] = []
github_path = "https://raw.githubusercontent.com/dhfbk/KIND/main/evalita-2023"
@@ -5077,7 +5317,7 @@ def _set_path(cls, base_path) -> Path:
return base_path
@classmethod
- def _load_features(cls, base_path) -> List[List[str]]:
+ def _load_features(cls, base_path) -> list[list[str]]:
print(base_path)
unpack_file(cached_path(cls.data_url, base_path), base_path, "zip", False)
with open(f"{base_path}/estner.cnll", encoding="utf-8") as in_file:
@@ -5086,17 +5326,17 @@ def _load_features(cls, base_path) -> List[List[str]]:
return features
@classmethod
- def _process_clean_labels(cls, features) -> List[List[str]]:
+ def _process_clean_labels(cls, features) -> list[list[str]]:
preinstances = [[instance[0], instance[len(instance) - 1]] for instance in features]
return preinstances
@classmethod
- def _rmv_clean_labels(cls, features) -> List[str]:
+ def _rmv_clean_labels(cls, features) -> list[str]:
rdcd_features = [feature[:-1] for feature in features]
return rdcd_features
@classmethod
- def _load_noisy_labels(cls, version, base_path) -> List[str]:
+ def _load_noisy_labels(cls, version, base_path) -> list[str]:
file_name = f"NoisyNER_labelset{version}.labels"
cached_path(f"{cls.label_url}/{file_name}", base_path)
with open(f"{base_path}/{file_name}", encoding="utf-8") as in_file:
@@ -5104,7 +5344,7 @@ def _load_noisy_labels(cls, version, base_path) -> List[str]:
return labels
@classmethod
- def _process_noisy_labels(cls, rdcd_features, labels) -> List[List[str]]:
+ def _process_noisy_labels(cls, rdcd_features, labels) -> list[list[str]]:
instances = []
label_idx = 0
for feature in rdcd_features:
@@ -5119,7 +5359,7 @@ def _process_noisy_labels(cls, rdcd_features, labels) -> List[List[str]]:
return instances
@classmethod
- def _delete_empty_labels(cls, version, preinstances) -> List[str]:
+ def _delete_empty_labels(cls, version, preinstances) -> list[str]:
instances = []
if version == 0:
for instance in preinstances:
@@ -5132,7 +5372,7 @@ def _delete_empty_labels(cls, version, preinstances) -> List[str]:
return instances
@classmethod
- def _split_data(cls, instances) -> Tuple[List[str], List[str], List[str]]:
+ def _split_data(cls, instances) -> tuple[list[str], list[str], list[str]]:
train = instances[:185708]
dev = instances[185708:208922]
test = instances[208922:]
@@ -5147,10 +5387,183 @@ def _write_instances(cls, version, base_path, split, data):
out_file.write("\n")
+class NER_NOISEBENCH(ColumnCorpus):
+ label_url = "https://raw.githubusercontent.com/elenamer/NoiseBench/main/data/annotations/"
+ SAVE_TRAINDEV_FILE = False
+
+ def __init__(
+ self,
+ noise: str = "clean",
+ base_path: Optional[Union[str, Path]] = None,
+ in_memory: bool = True,
+ **corpusargs,
+ ) -> None:
+ """Initialize the NoiseBench corpus.
+
+ Args:
+            noise (str): Selects the label set for the data.
+                clean (default): the clean labels.
+                crowd, crowdbest, expert, distant, weak, llm: different kinds of noisy label sets (details: ...).
+            base_path (Optional[Union[str, Path]]): Path to the data.
+                Default is None, meaning the corpus gets automatically downloaded and saved.
+                You can override this by passing a path to a directory containing the unprocessed files,
+                but typically this should not be necessary.
+            in_memory (bool): If True, the dataset is kept in memory, achieving speedups in training.
+            **corpusargs: The arguments propagated to :meth:`flair.datasets.ColumnCorpus.__init__`.
+ """
+ VALUE_NOISE_VALUES = ["clean", "crowd", "crowdbest", "expert", "distant", "weak", "llm"]
+
+ if noise not in VALUE_NOISE_VALUES:
+ raise ValueError(
+ f"Unsupported value for noise type argument. Got {noise}, expected one of {VALUE_NOISE_VALUES}!"
+ )
+
+ self.base_path = flair.cache_root / "datasets" / "noisebench" if not base_path else Path(base_path)
+
+ filename = "clean" if noise == "clean" else f"noise_{noise}"
+ file_paths = [
+ self.base_path / f"{filename}.train",
+ self.base_path / f"{filename}.dev",
+ self.base_path / "clean.test",
+ ]
+ files_exist = [path.exists() for path in file_paths]
+
+ if not all(files_exist):
+ cached_path(f"{self.label_url}/{filename}.traindev", self.base_path / "annotations_only")
+ cached_path(f"{self.label_url}/index.txt", self.base_path / "annotations_only")
+
+ cleanconll_corpus = CLEANCONLL()
+
+ self.cleanconll_base_path = flair.cache_root / "datasets" / cleanconll_corpus.__class__.__name__.lower()
+
+ # create dataset files from index and train/test splits
+ self._generate_data_files(filename, cleanconll_corpus.__class__.__name__.lower())
+
+ super().__init__(
+ data_folder=self.base_path,
+ train_file=f"{filename}.train",
+ dev_file=f"{filename}.dev",
+ test_file="clean.test", # test set is always clean (without noise)
+ column_format={0: "text", 1: "ner"},
+ in_memory=in_memory,
+ column_delimiter="\t",
+ document_separator_token="-DOCSTART-",
+ **corpusargs,
+ )
+
+ @staticmethod
+ def _read_column_file(filename: Union[str, Path]) -> list[list[list[str]]]:
+ with open(filename, errors="replace", encoding="utf-8") as file:
+ lines = file.readlines()
+ all_sentences = []
+ sentence = []
+ for line in lines:
+ stripped_line = line.strip().split("\t") if "\t" in line.strip() else line.strip().split(" ")
+
+ sentence.append(stripped_line)
+ if line.strip() == "":
+ if len(sentence[:-1]) > 0:
+ all_sentences.append(sentence[:-1])
+ sentence = []
+
+ if len(sentence) > 0:
+ all_sentences.append(sentence)
+
+ return all_sentences
+
+ @staticmethod
+ def _save_to_column_file(filename: Union[str, Path], sentences: list[list[list[str]]]) -> None:
+ with open(filename, "w", encoding="utf-8") as f:
+ for sentence in sentences:
+ for token in sentence:
+ f.write("\t".join(token))
+ f.write("\n")
+ f.write("\n")
+
+ def _create_train_dev_splits(
+ self, filename: Path, all_sentences: Optional[list] = None, datestring: str = "1996-08-24"
+ ) -> None:
+ if not all_sentences:
+ all_sentences = self._read_column_file(filename)
+
+ train_sentences = []
+ dev_sentences = []
+ for i, s in enumerate(all_sentences):
+ if "DOCSTART" in s[0][0]:
+                assert i + 3 < len(all_sentences)  # fails only if the last document is too short
+
+ # news date is usually in 3rd or 4th sentence of each article
+ if datestring in all_sentences[i + 2][-1][0] or datestring in all_sentences[i + 3][-1][0]:
+ save_to_dev = True
+ else:
+ save_to_dev = False
+
+ if save_to_dev:
+ dev_sentences.append(s)
+ else:
+ train_sentences.append(s)
+
+ self._save_to_column_file(
+ filename.parent / f"{filename.stem}.dev",
+ dev_sentences,
+ )
+ self._save_to_column_file(
+ filename.parent / f"{filename.stem}.train",
+ train_sentences,
+ )
+
+ def _merge_tokens_labels(
+ self, corpus: str, all_clean_sentences: list, token_indices: list
+ ) -> list[list[list[str]]]:
+ # generate NoiseBench dataset variants, given CleanCoNLL, noisy label files and index file
+
+ noisy_labels = self._read_column_file(self.base_path / "annotations_only" / f"{corpus}.traindev")
+ for index, sentence in zip(token_indices, noisy_labels):
+
+ if index.strip() == "docstart":
+ assert len(sentence) == 1
+ sentence[0][0] = "-DOCSTART-"
+ continue
+ clean_sentence = all_clean_sentences[int(index.strip())]
+
+            assert len(clean_sentence) == len(sentence)  # a mismatch here means the indexing is wrong
+
+ for token, label in zip(clean_sentence, sentence):
+ label[0] = token[0] # token[0] -> text, token[1] -> BIO label
+ if self.SAVE_TRAINDEV_FILE:
+ self._save_to_column_file(self.base_path / f"{corpus}.traindev", noisy_labels)
+ return noisy_labels
+
+ def _generate_data_files(self, filename: str, origin_dataset_name: str) -> None:
+
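+        # Rebuild the dataset files: merge CleanCoNLL tokens with the NoiseBench label
+        # annotations via the index file, create train/dev splits, and copy the clean test set.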
+ with open(self.base_path / "annotations_only" / "index.txt", encoding="utf-8") as index_file:
+ token_indices = index_file.readlines()
+ all_clean_sentences = self._read_column_file(self.cleanconll_base_path / f"{origin_dataset_name}.train")
+
+ noisy_sentences = self._merge_tokens_labels(filename, all_clean_sentences, token_indices)
+ self._create_train_dev_splits(
+ all_sentences=noisy_sentences, filename=self.base_path / f"{filename}.traindev"
+ )
+
+ # copy test set
+ all_clean_test_sentences = self._read_column_file(self.cleanconll_base_path / f"{origin_dataset_name}.test")
+
+ test_sentences = []
+
+ for sentence in all_clean_test_sentences:
+ new_sentence = [[tokens[0], tokens[4]] for tokens in sentence]
+ test_sentences.append(new_sentence)
+
+ self._save_to_column_file(self.base_path / "clean.test", test_sentences)
+
+
class MASAKHA_POS(MultiCorpus):
def __init__(
self,
- languages: Union[str, List[str]] = "bam",
+ languages: Union[str, list[str]] = "bam",
version: str = "v1",
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = True,
@@ -5217,7 +5630,7 @@ def __init__(
if languages == ["all"]:
languages = supported_languages
- corpora: List[Corpus] = []
+ corpora: list[Corpus] = []
for language in languages:
if language not in supported_languages:
log.error(f"Language '{language}' is not in list of supported languages!")
diff --git a/flair/datasets/text_image.py b/flair/datasets/text_image.py
index f7baf72be9..676b078d7f 100644
--- a/flair/datasets/text_image.py
+++ b/flair/datasets/text_image.py
@@ -3,7 +3,6 @@
import os
import urllib
from pathlib import Path
-from typing import List
import numpy as np
import torch.utils.data.dataloader
@@ -40,13 +39,13 @@ def __init__(self, **kwargs) -> None:
feidegger_dataset: Dataset = FeideggerDataset(dataset_info, **kwargs)
- train_indices = list(np.where(np.in1d(feidegger_dataset.split, list(range(8))))[0]) # type: ignore[attr-defined]
+ train_indices = list(np.where(np.isin(feidegger_dataset.split, list(range(8))))[0]) # type: ignore[attr-defined]
train = torch.utils.data.dataset.Subset(feidegger_dataset, train_indices)
- dev_indices = list(np.where(np.in1d(feidegger_dataset.split, [8]))[0]) # type: ignore[attr-defined]
+ dev_indices = list(np.where(np.isin(feidegger_dataset.split, [8]))[0]) # type: ignore[attr-defined]
dev = torch.utils.data.dataset.Subset(feidegger_dataset, dev_indices)
- test_indices = list(np.where(np.in1d(feidegger_dataset.split, [9]))[0]) # type: ignore[attr-defined]
+ test_indices = list(np.where(np.isin(feidegger_dataset.split, [9]))[0]) # type: ignore[attr-defined]
test = torch.utils.data.dataset.Subset(feidegger_dataset, test_indices)
super().__init__(train, dev, test, name="feidegger")
@@ -56,8 +55,8 @@ class FeideggerDataset(FlairDataset):
def __init__(self, dataset_info, **kwargs) -> None:
super().__init__()
- self.data_points: List[DataPair] = []
- self.split: List[int] = []
+ self.data_points: list[DataPair] = []
+ self.split: list[int] = []
def identity(x):
return x
diff --git a/flair/datasets/text_text.py b/flair/datasets/text_text.py
index 0bf0e91020..58a40d62c9 100644
--- a/flair/datasets/text_text.py
+++ b/flair/datasets/text_text.py
@@ -1,7 +1,7 @@
import logging
import os
from pathlib import Path
-from typing import List, Optional, Union
+from typing import Optional, Union
import flair
from flair.data import (
@@ -144,14 +144,15 @@ def __init__(
self.total_sentence_count: int = 0
if self.in_memory:
- self.bi_sentences: List[DataPair] = []
+ self.bi_sentences: list[DataPair] = []
else:
- self.source_lines: List[str] = []
- self.target_lines: List[str] = []
+ self.source_lines: list[str] = []
+ self.target_lines: list[str] = []
- with open(str(path_to_source), encoding="utf-8") as source_file, open(
- str(path_to_target), encoding="utf-8"
- ) as target_file:
+ with (
+ open(str(path_to_source), encoding="utf-8") as source_file,
+ open(str(path_to_target), encoding="utf-8") as target_file,
+ ):
source_line = source_file.readline()
target_line = target_file.readline()
@@ -204,7 +205,7 @@ class DataPairCorpus(Corpus):
def __init__(
self,
data_folder: Union[str, Path],
- columns: List[int] = [0, 1, 2],
+ columns: list[int] = [0, 1, 2],
train_file=None,
test_file=None,
dev_file=None,
@@ -318,7 +319,7 @@ class DataPairDataset(FlairDataset):
def __init__(
self,
path_to_data: Union[str, Path],
- columns: List[int] = [0, 1, 2],
+ columns: list[int] = [0, 1, 2],
max_tokens_per_doc=-1,
max_chars_per_doc=-1,
use_tokenizer=True,
@@ -368,11 +369,11 @@ def __init__(
self.total_data_count: int = 0
if self.in_memory:
- self.data_pairs: List[DataPair] = []
+ self.data_pairs: list[DataPair] = []
else:
- self.first_elements: List[str] = []
- self.second_elements: List[str] = []
- self.labels: List[Optional[str]] = []
+ self.first_elements: list[str] = []
+ self.second_elements: list[str] = []
+ self.labels: list[Optional[str]] = []
with open(str(path_to_data), encoding=encoding) as source_file:
source_line = source_file.readline()
@@ -448,7 +449,7 @@ class DataTripleCorpus(Corpus):
def __init__(
self,
data_folder: Union[str, Path],
- columns: List[int] = [0, 1, 2, 3],
+ columns: list[int] = [0, 1, 2, 3],
train_file=None,
test_file=None,
dev_file=None,
@@ -563,7 +564,7 @@ class DataTripleDataset(FlairDataset):
def __init__(
self,
path_to_data: Union[str, Path],
- columns: List[int] = [0, 1, 2, 3],
+ columns: list[int] = [0, 1, 2, 3],
max_tokens_per_doc=-1,
max_chars_per_doc=-1,
use_tokenizer=True,
@@ -614,12 +615,12 @@ def __init__(
self.total_data_count: int = 0
if self.in_memory:
- self.data_triples: List[DataTriple] = []
+ self.data_triples: list[DataTriple] = []
else:
- self.first_elements: List[str] = []
- self.second_elements: List[str] = []
- self.third_elements: List[str] = []
- self.labels: List[Optional[str]] = []
+ self.first_elements: list[str] = []
+ self.second_elements: list[str] = []
+ self.third_elements: list[str] = []
+ self.labels: list[Optional[str]] = []
with open(str(path_to_data), encoding=encoding) as source_file:
source_line = source_file.readline()
@@ -828,9 +829,10 @@ def __init__(
str(data_folder / "MNLI" / temp_file),
)
- with open(data_folder / "MNLI" / dev_filename, "a", encoding="utf-8") as out_file, open(
- data_folder / "MNLI" / temp_file, encoding="utf-8"
- ) as in_file:
+ with (
+ open(data_folder / "MNLI" / dev_filename, "a", encoding="utf-8") as out_file,
+ open(data_folder / "MNLI" / temp_file, encoding="utf-8") as in_file,
+ ):
for line in in_file:
fields = line.split("\t")
reordered_columns = "\t".join(fields[column_id] for column_id in range(11))
diff --git a/flair/datasets/treebanks.py b/flair/datasets/treebanks.py
index ed0f0135cd..21ae327691 100644
--- a/flair/datasets/treebanks.py
+++ b/flair/datasets/treebanks.py
@@ -1,7 +1,7 @@
import logging
import re
from pathlib import Path
-from typing import List, Optional, Union
+from typing import Optional, Union
import flair
from flair.data import Corpus, FlairDataset, Sentence, Token
@@ -82,7 +82,7 @@ def __init__(
with open(str(self.path_to_conll_file), encoding="utf-8") as file:
# option 1: read only sentence boundaries as offset positions
if not self.in_memory:
- self.indices: List[int] = []
+ self.indices: list[int] = []
line = file.readline()
position = 0
@@ -97,7 +97,7 @@ def __init__(
# option 2: keep everything in memory
if self.in_memory:
- self.sentences: List[Sentence] = []
+ self.sentences: list[Sentence] = []
while True:
sentence = self._read_next_sentence(file)
@@ -129,7 +129,7 @@ def __getitem__(self, index: int = 0) -> Sentence:
def _read_next_sentence(self, file) -> Optional[Sentence]:
line = file.readline()
- tokens: List[Token] = []
+ tokens: list[Token] = []
# current token ID
token_idx = 0
@@ -143,7 +143,7 @@ def _read_next_sentence(self, file) -> Optional[Sentence]:
newline_reached = False
while line:
line = line.strip()
- fields: List[str] = re.split("\t+", line)
+ fields: list[str] = re.split("\t+", line)
# end of sentence
if line == "":
diff --git a/flair/distributed_utils.py b/flair/distributed_utils.py
new file mode 100644
index 0000000000..e774084009
--- /dev/null
+++ b/flair/distributed_utils.py
@@ -0,0 +1,88 @@
+import logging
+import os
+import random
+from multiprocessing.connection import Connection
+from typing import Callable
+
+import numpy as np
+import torch
+import torch.multiprocessing as mp
+from torch.distributed import destroy_process_group, init_process_group
+from torch.utils.data import Dataset
+
+import flair
+from flair.data import Corpus, _len_dataset
+
+log = logging.getLogger("flair")
+
+
+def launch_distributed(fn, *args, **kwargs):
+ """Executes the function fn(*args, **kwargs) on multiple processes (one for each local GPU).
+
+ If training with multi_gpu=True, launch_distributed should wrap your code that calls .train or .fine_tune.
+
+    Returns: the return value of the function fn(*args, **kwargs) from the rank 0 process.
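+
+    Example (sketch; ``train_ner`` stands for your own function that builds a corpus, a model
+    and a ModelTrainer and calls ``.fine_tune(..., multi_gpu=True)``)::
+
+        from flair.distributed_utils import launch_distributed
+
+        # each GPU runs train_ner; the rank-0 return value is forwarded here
+        final_score = launch_distributed(train_ner, corpus_name="CONLL_03")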
+ """
+ world_size = torch.cuda.device_count()
+ log.info(f"Launching {world_size} processes")
+ parent_conn, child_conn = mp.Pipe()
+ mp.spawn(_process_entrypoint, args=(world_size, child_conn, fn, args, kwargs), nprocs=world_size)
+ return_value = parent_conn.recv()
+ return return_value
+
+
+def _process_entrypoint(
+ rank: int, world_size: int, child_conn: Connection, fn: Callable, args: tuple, kwargs: dict
+) -> None:
+ """Lifecycle of a distributed process -- setup, run, cleanup."""
+ log.info(f"Started process on rank={rank}")
+ try:
+ _ddp_setup(rank, world_size)
+ return_value = fn(*args, **kwargs)
+ if is_main_process():
+ child_conn.send(return_value)
+ finally:
+ destroy_process_group()
+
+
+def _ddp_setup(rank: int, world_size: int) -> None:
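+    # single-node setup: all ranks rendezvous on localhost via a fixed port, and
+    # each rank binds to the CUDA device whose index matches its rank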
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "12355"
+ flair.device = torch.device(rank)
+ torch.cuda.set_device(flair.device)
+ init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
+
+def is_main_process() -> bool:
+ """True for exactly 1 process, regardless of whether being run on CPU/single-GPU/multi-gpu."""
+ if torch.distributed.is_initialized():
+ return torch.distributed.get_rank() == 0
+ else:
+ return True
+
+
+def aggregate(value, aggregation_fn=np.mean):
+ """Gather `value` from all processes and send to `aggregation_fn` to get a single return value."""
+ if torch.distributed.is_initialized():
+ gathered_values = [None for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather_object(gathered_values, value)
+ else:
+ gathered_values = [value]
+ return aggregation_fn(gathered_values)
+
+
+def validate_corpus_same_each_process(corpus: Corpus) -> None:
+ """Catches most cases in which a corpus is not the same on each process. However, there is no guarantee for two
+ reasons: 1) It uses a sample for speed 2) It compares strings to avoid requiring the datasets to be serializable"""
+ for dataset in [corpus.train, corpus.dev, corpus.test]:
+ if dataset is not None:
+ _validate_dataset_same_each_process(dataset)
+
+
+def _validate_dataset_same_each_process(dataset: Dataset, sample_size: int = 10) -> None:
+ random_indices = random.sample(range(_len_dataset(dataset)), min(sample_size, _len_dataset(dataset)))
+ for i in random_indices:
+ example = str(dataset[i])
+ examples = aggregate(example, list)
+ if not all(example == examples[0] for example in examples):
+ raise ValueError("Dataset must be the same on each process")
diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py
index 154f2600be..294b41fac8 100644
--- a/flair/embeddings/base.py
+++ b/flair/embeddings/base.py
@@ -1,7 +1,8 @@
import inspect
import logging
from abc import abstractmethod
-from typing import Any, Dict, Generic, List, Sequence, Type, Union
+from collections.abc import Sequence
+from typing import Any, Generic, Union
import torch
from torch.nn import Parameter, ParameterList
@@ -37,7 +38,7 @@ def embedding_length(self) -> int:
def embedding_type(self) -> str:
raise NotImplementedError
- def embed(self, data_points: Union[DT, List[DT]]) -> List[DT]:
+ def embed(self, data_points: Union[DT, list[DT]]) -> list[DT]:
"""Add embeddings to all words in a list of sentences.
If embeddings are already added, updates only if embeddings are non-static.
@@ -55,10 +56,10 @@ def _everything_embedded(self, data_points: Sequence[DT]) -> bool:
return all(self.name in data_point._embeddings for data_point in data_points)
@abstractmethod
- def _add_embeddings_internal(self, sentences: List[DT]):
+ def _add_embeddings_internal(self, sentences: list[DT]):
"""Private method for adding embeddings to all words in a list of sentences."""
- def get_names(self) -> List[str]:
+ def get_names(self) -> list[str]:
"""Returns a list of embedding names.
In most cases, it is just a list with one item, namely the name of
@@ -67,9 +68,6 @@ def get_names(self) -> List[str]:
"""
return [self.name]
- def get_named_embeddings_dict(self) -> Dict:
- return {self.name: self}
-
@staticmethod
def get_instance_parameters(locals: dict) -> dict:
class_definition = locals.get("__class__")
@@ -84,14 +82,14 @@ def get_instance_parameters(locals: dict) -> dict:
return instance_parameters
@classmethod
- def from_params(cls, params: Dict[str, Any]) -> "Embeddings":
+ def from_params(cls, params: dict[str, Any]) -> "Embeddings":
raise NotImplementedError
- def to_params(self) -> Dict[str, Any]:
+ def to_params(self) -> dict[str, Any]:
raise NotImplementedError
@classmethod
- def load_embedding(cls, params: Dict[str, Any]):
+ def load_embedding(cls, params: dict[str, Any]):
state_dict = params.pop("state_dict", None)
embedding = cls.from_params(params)
@@ -155,7 +153,7 @@ def __init__(self, mixture_size: int, trainable: bool = False) -> None:
requires_grad=trainable,
)
- def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor:
+ def forward(self, tensors: list[torch.Tensor]) -> torch.Tensor:
"""Forward pass of scalar mix.
Computes a weighted average of the ``tensors``. The input tensors an be any shape
@@ -203,7 +201,7 @@ def _everything_embedded(self, data_points: Sequence[Sentence]) -> bool:
return True
-EMBEDDING_CLASSES: Dict[str, Type[Embeddings]] = {}
+EMBEDDING_CLASSES: dict[str, type[Embeddings]] = {}
def register_embeddings(*args):
@@ -225,7 +223,7 @@ def _register(cls):
return _register
-def load_embeddings(params: Dict[str, Any]) -> Embeddings:
+def load_embeddings(params: dict[str, Any]) -> Embeddings:
cls_name = params.pop("__cls__")
cls = EMBEDDING_CLASSES[cls_name]
return cls.load_embedding(params)
diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py
index c1e73442e6..8f66a198ed 100644
--- a/flair/embeddings/document.py
+++ b/flair/embeddings/document.py
@@ -1,5 +1,5 @@
import logging
-from typing import Any, Dict, List, Optional, Union, cast
+from typing import Any, Optional, Union, cast
import torch
from sklearn.feature_extraction.text import TfidfVectorizer
@@ -67,7 +67,7 @@ def create_from_state(cls, **state):
class DocumentPoolEmbeddings(DocumentEmbeddings):
def __init__(
self,
- embeddings: Union[TokenEmbeddings, List[TokenEmbeddings]],
+ embeddings: Union[TokenEmbeddings, list[TokenEmbeddings]],
fine_tune_mode: str = "none",
pooling: str = "mean",
) -> None:
@@ -114,7 +114,7 @@ def __init__(
def embedding_length(self) -> int:
return self.__embedding_length
- def embed(self, sentences: Union[List[Sentence], Sentence]):
+ def embed(self, sentences: Union[list[Sentence], Sentence]):
"""Add embeddings to every sentence in the given list of sentences.
If embeddings are already added, updates only if embeddings are non-static.
@@ -146,18 +146,18 @@ def embed(self, sentences: Union[List[Sentence], Sentence]):
sentence.set_embedding(self.name, pooled_embedding)
- def _add_embeddings_internal(self, sentences: List[Sentence]):
+ def _add_embeddings_internal(self, sentences: list[Sentence]):
pass
def extra_repr(self):
return f"fine_tune_mode={self.fine_tune_mode}, pooling={self.pooling}"
@classmethod
- def from_params(cls, params: Dict[str, Any]) -> "DocumentPoolEmbeddings":
+ def from_params(cls, params: dict[str, Any]) -> "DocumentPoolEmbeddings":
embeddings = cast(StackedEmbeddings, load_embeddings(params.pop("embeddings"))).embeddings
return cls(embeddings=embeddings, **params)
- def to_params(self) -> Dict[str, Any]:
+ def to_params(self) -> dict[str, Any]:
return {
"pooling": self.pooling,
"fine_tune_mode": self.fine_tune_mode,
@@ -169,7 +169,7 @@ def to_params(self) -> Dict[str, Any]:
class DocumentTFIDFEmbeddings(DocumentEmbeddings):
def __init__(
self,
- train_dataset: List[Sentence],
+ train_dataset: list[Sentence],
vectorizer: Optional[TfidfVectorizer] = None,
**vectorizer_params,
) -> None:
@@ -203,7 +203,7 @@ def __init__(
def embedding_length(self) -> int:
return self.__embedding_length
- def embed(self, sentences: Union[List[Sentence], Sentence]):
+ def embed(self, sentences: Union[list[Sentence], Sentence]):
"""Add embeddings to every sentence in the given list of sentences."""
# if only one sentence is passed, convert to list of sentence
if isinstance(sentences, Sentence):
@@ -215,14 +215,14 @@ def embed(self, sentences: Union[List[Sentence], Sentence]):
for sentence_id, sentence in enumerate(sentences):
sentence.set_embedding(self.name, tfidf_vectors[sentence_id])
- def _add_embeddings_internal(self, sentences: List[Sentence]):
+ def _add_embeddings_internal(self, sentences: list[Sentence]):
pass
@classmethod
- def from_params(cls, params: Dict[str, Any]) -> "DocumentTFIDFEmbeddings":
+ def from_params(cls, params: dict[str, Any]) -> "DocumentTFIDFEmbeddings":
return cls(train_dataset=[], vectorizer=params["vectorizer"])
- def to_params(self) -> Dict[str, Any]:
+ def to_params(self) -> dict[str, Any]:
return {
"vectorizer": self.vectorizer,
}
@@ -232,7 +232,7 @@ def to_params(self) -> Dict[str, Any]:
class DocumentRNNEmbeddings(DocumentEmbeddings):
def __init__(
self,
- embeddings: List[TokenEmbeddings],
+ embeddings: list[TokenEmbeddings],
hidden_size=128,
rnn_layers=1,
reproject_words: bool = True,
@@ -317,7 +317,7 @@ def __init__(
def embedding_length(self) -> int:
return self.__embedding_length
- def _add_embeddings_internal(self, sentences: List[Sentence]):
+ def _add_embeddings_internal(self, sentences: list[Sentence]):
"""Add embeddings to all sentences in the given list of sentences.
If embeddings are already added, update only if embeddings are non-static.
@@ -332,7 +332,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
# embed words in the sentence
self.embeddings.embed(sentences)
- lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
+ lengths: list[int] = [len(sentence.tokens) for sentence in sentences]
longest_token_sequence_in_batch: int = max(lengths)
pre_allocated_zero_tensor = torch.zeros(
@@ -341,7 +341,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
device=flair.device,
)
- all_embs: List[torch.Tensor] = []
+ all_embs: list[torch.Tensor] = []
for sentence in sentences:
all_embs += [emb for token in sentence for emb in token.get_each_embedding()]
nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)
@@ -371,7 +371,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
sentence_tensor = self.word_reprojection_map(sentence_tensor)
# push through RNN
- packed = pack_padded_sequence(sentence_tensor, lengths, enforce_sorted=False, batch_first=True) # type: ignore[arg-type]
+ packed = pack_padded_sequence(sentence_tensor, lengths, enforce_sorted=False, batch_first=True)
rnn_out, hidden = self.rnn(packed)
outputs, output_lengths = pad_packed_sequence(rnn_out, batch_first=True)
@@ -436,7 +436,7 @@ def to_params(self):
return model_state
@classmethod
- def from_params(cls, params: Dict[str, Any]) -> "DocumentRNNEmbeddings":
+ def from_params(cls, params: dict[str, Any]) -> "DocumentRNNEmbeddings":
stacked_embeddings = load_embeddings(params["embeddings"])
assert isinstance(stacked_embeddings, StackedEmbeddings)
return cls(
@@ -484,7 +484,7 @@ def __setstate__(self, d):
@register_embeddings
class DocumentLMEmbeddings(DocumentEmbeddings):
- def __init__(self, flair_embeddings: List[FlairEmbeddings]) -> None:
+ def __init__(self, flair_embeddings: list[FlairEmbeddings]) -> None:
super().__init__()
self.embeddings = flair_embeddings
@@ -503,7 +503,7 @@ def __init__(self, flair_embeddings: List[FlairEmbeddings]) -> None:
def embedding_length(self) -> int:
return self._embedding_length
- def _add_embeddings_internal(self, sentences: List[Sentence]):
+ def _add_embeddings_internal(self, sentences: list[Sentence]):
for embedding in self.embeddings:
embedding.embed(sentences)
@@ -520,17 +520,17 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
return sentences
- def get_names(self) -> List[str]:
+ def get_names(self) -> list[str]:
if "__names" not in self.__dict__:
self.__names = [name for embedding in self.embeddings for name in embedding.get_names()]
return self.__names
- def to_params(self) -> Dict[str, Any]:
+ def to_params(self) -> dict[str, Any]:
return {"flair_embeddings": [embedding.save_embeddings(False) for embedding in self.embeddings]}
@classmethod
- def from_params(cls, params: Dict[str, Any]) -> "DocumentLMEmbeddings":
+ def from_params(cls, params: dict[str, Any]) -> "DocumentLMEmbeddings":
return cls([cast(FlairEmbeddings, load_embeddings(embedding)) for embedding in params["flair_embeddings"]])
@@ -566,7 +566,7 @@ def __init__(
self.static_embeddings = True
self.eval()
- def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
+ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:
sentence_batches = [
sentences[i * self.batch_size : (i + 1) * self.batch_size]
for i in range((len(sentences) + self.batch_size - 1) // self.batch_size)
@@ -577,7 +577,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
return sentences
- def _add_embeddings_to_sentences(self, sentences: List[Sentence]):
+ def _add_embeddings_to_sentences(self, sentences: list[Sentence]):
# convert to plain strings, embedded in a list for the encode function
sentences_plain_text = [sentence.to_plain_string() for sentence in sentences]
@@ -591,10 +591,10 @@ def embedding_length(self) -> int:
return self.model.get_sentence_embedding_dimension()
@classmethod
- def from_params(cls, params: Dict[str, Any]) -> "SentenceTransformerDocumentEmbeddings":
+ def from_params(cls, params: dict[str, Any]) -> "SentenceTransformerDocumentEmbeddings":
return cls(**params)
- def to_params(self) -> Dict[str, Any]:
+ def to_params(self) -> dict[str, Any]:
return {
"model": self.model_name,
"batch_size": self.batch_size,
@@ -605,7 +605,7 @@ def to_params(self) -> Dict[str, Any]:
class DocumentCNNEmbeddings(DocumentEmbeddings):
def __init__(
self,
- embeddings: List[TokenEmbeddings],
+ embeddings: list[TokenEmbeddings],
kernels=((100, 3), (100, 4), (100, 5)),
reproject_words: bool = True,
reproject_words_dimension: Optional[int] = None,
@@ -673,7 +673,7 @@ def __init__(
def embedding_length(self) -> int:
return self.__embedding_length
- def _add_embeddings_internal(self, sentences: List[Sentence]):
+ def _add_embeddings_internal(self, sentences: list[Sentence]):
"""Add embeddings to all sentences in the given list of sentences.
If embeddings are already added, update only if embeddings are non-static.
@@ -689,7 +689,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
# embed words in the sentence
self.embeddings.embed(sentences)
- lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
+ lengths: list[int] = [len(sentence.tokens) for sentence in sentences]
padding_length: int = max(max(lengths), self.min_sequence_length)
pre_allocated_zero_tensor = torch.zeros(
@@ -698,7 +698,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
device=flair.device,
)
- all_embs: List[torch.Tensor] = []
+ all_embs: list[torch.Tensor] = []
for sentence in sentences:
all_embs += [emb for token in sentence for emb in token.get_each_embedding()]
nb_padding_tokens = padding_length - len(sentence)
@@ -757,11 +757,11 @@ def _apply(self, fn):
child_module._apply(fn)
@classmethod
- def from_params(cls, params: Dict[str, Any]) -> "DocumentCNNEmbeddings":
+ def from_params(cls, params: dict[str, Any]) -> "DocumentCNNEmbeddings":
embeddings = cast(StackedEmbeddings, load_embeddings(params.pop("embeddings"))).embeddings
return cls(embeddings=embeddings, **params)
- def to_params(self) -> Dict[str, Any]:
+ def to_params(self) -> dict[str, Any]:
return {
"embeddings": self.embeddings.save_embeddings(False),
"kernels": self.kernels,
diff --git a/flair/embeddings/image.py b/flair/embeddings/image.py
index df6d1fadd9..5d79a04390 100644
--- a/flair/embeddings/image.py
+++ b/flair/embeddings/image.py
@@ -1,5 +1,5 @@
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
import torch
import torch.nn.functional as F
@@ -29,12 +29,12 @@ class ImageEmbeddings(Embeddings[Image]):
def embedding_type(self) -> str:
return "image-level"
- def to_params(self) -> Dict[str, Any]:
+ def to_params(self) -> dict[str, Any]:
# legacy pickle-like saving for image embeddings, as implementation details are not obvious
return self.__getstate__()
@classmethod
- def from_params(cls, params: Dict[str, Any]) -> "Embeddings":
+ def from_params(cls, params: dict[str, Any]) -> "Embeddings":
# legacy pickle-like loading for image embeddings, as implementation details are not obvious
embedding = cls.__new__(cls)
embedding.__setstate__(params)
@@ -53,7 +53,7 @@ def __init__(self, transforms) -> None:
self.static_embeddings = True
super().__init__()
- def _add_embeddings_internal(self, images: List[Image]):
+ def _add_embeddings_internal(self, images: list[Image]):
for image in images:
image_data = self.PIL.Image.open(image.imageURL)
image_data.load()
@@ -77,7 +77,7 @@ def __init__(self, url2tensor_dict, name) -> None:
self.static_embeddings = True
super().__init__()
- def _add_embeddings_internal(self, images: List[Image]):
+ def _add_embeddings_internal(self, images: list[Image]):
for image in images:
if image.imageURL in self.url2tensor_dict:
image.set_embedding(self.name, self.url2tensor_dict[image.imageURL])
@@ -137,7 +137,7 @@ def __init__(self, name, pretrained=True, transforms=None) -> None:
else:
raise Exception(f"Image embeddings {name} not available.")
- def _add_embeddings_internal(self, images: List[Image]):
+ def _add_embeddings_internal(self, images: list[Image]):
image_tensor = torch.stack([self.transforms(image.data) for image in images])
image_embeddings = self.features(image_tensor)
image_embeddings = (
@@ -163,7 +163,7 @@ def __init__(self, feats_in, convnet_parms, posnet_parms, transformer_parms) ->
adaptive_pool_func_map = {"max": AdaptiveMaxPool2d, "avg": AdaptiveAvgPool2d}
- convnet_arch: List[Any] = [] if convnet_parms["dropout"][0] <= 0 else [Dropout2d(convnet_parms["dropout"][0])]
+ convnet_arch: list[Any] = [] if convnet_parms["dropout"][0] <= 0 else [Dropout2d(convnet_parms["dropout"][0])]
convnet_arch.extend(
[
Conv2d(
@@ -266,7 +266,7 @@ def forward(self, x):
return x
- def _add_embeddings_internal(self, images: List[Image]):
+ def _add_embeddings_internal(self, images: list[Image]):
image_tensor = torch.stack([image.data for image in images])
image_embeddings = self.forward(image_tensor)
for image_id, image in enumerate(images):
diff --git a/flair/embeddings/legacy.py b/flair/embeddings/legacy.py
index b2658e2d2f..4b3d2a9517 100644
--- a/flair/embeddings/legacy.py
+++ b/flair/embeddings/legacy.py
@@ -1,7 +1,7 @@
import logging
import re
from pathlib import Path
-from typing import List, Optional, Union
+from typing import Optional, Union
import torch
from deprecated.sphinx import deprecated
@@ -110,12 +110,12 @@ def use_layers_top(self, x):
def use_layers_average(self, x):
return torch.mean(torch.stack(x), 0)
- def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
+ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:
# ELMoEmbeddings before Release 0.5 did not set self.embedding_mode_fn
if not getattr(self, "embedding_mode_fn", None):
self.embedding_mode_fn = self.use_layers_all
- sentence_words: List[List[str]] = []
+ sentence_words: list[list[str]] = []
for sentence in sentences:
sentence_words.append([token.text for token in sentence])
@@ -394,7 +394,7 @@ def __getstate__(self):
def embedding_length(self) -> int:
return self.__embedding_length
- def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
+ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:
# if cache is used, try setting embeddings from cache first
if "cache" in self.__dict__ and self.cache is not None:
# try populating embeddings from cache
@@ -463,7 +463,7 @@ class DocumentMeanEmbeddings(DocumentEmbeddings):
version="0.3.1",
reason="The functionality of this class is moved to 'DocumentPoolEmbeddings'",
)
- def __init__(self, token_embeddings: List[TokenEmbeddings]) -> None:
+ def __init__(self, token_embeddings: list[TokenEmbeddings]) -> None:
"""The constructor takes a list of embeddings to be combined."""
super().__init__()
@@ -478,7 +478,7 @@ def __init__(self, token_embeddings: List[TokenEmbeddings]) -> None:
def embedding_length(self) -> int:
return self.__embedding_length
- def embed(self, sentences: Union[List[Sentence], Sentence]):
+ def embed(self, sentences: Union[list[Sentence], Sentence]):
"""Add embeddings to every sentence in the given list of sentences. If embeddings are already added, updates
only if embeddings are non-static.
"""
@@ -506,7 +506,7 @@ def embed(self, sentences: Union[List[Sentence], Sentence]):
sentence.set_embedding(self.name, mean_embedding)
- def _add_embeddings_internal(self, sentences: List[Sentence]):
+ def _add_embeddings_internal(self, sentences: list[Sentence]):
pass
@@ -517,7 +517,7 @@ class DocumentLSTMEmbeddings(DocumentEmbeddings):
)
def __init__(
self,
- embeddings: List[TokenEmbeddings],
+ embeddings: list[TokenEmbeddings],
hidden_size=128,
rnn_layers=1,
reproject_words: bool = True,
@@ -587,7 +587,7 @@ def __init__(
def embedding_length(self) -> int:
return self.__embedding_length
- def embed(self, sentences: Union[List[Sentence], Sentence]):
+ def embed(self, sentences: Union[list[Sentence], Sentence]):
"""Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update
only if embeddings are non-static.
"""
@@ -604,7 +604,7 @@ def embed(self, sentences: Union[List[Sentence], Sentence]):
longest_token_sequence_in_batch: int = len(sentences[0])
all_sentence_tensors = []
- lengths: List[int] = []
+ lengths: list[int] = []
# go through each sentence in batch
for _i, sentence in enumerate(sentences):
@@ -669,5 +669,5 @@ def embed(self, sentences: Union[List[Sentence], Sentence]):
sentence = sentences[sentence_no]
sentence.set_embedding(self.name, embedding)
- def _add_embeddings_internal(self, sentences: List[Sentence]):
+ def _add_embeddings_internal(self, sentences: list[Sentence]):
pass
diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py
index b068305800..3d95c8ee0b 100644
--- a/flair/embeddings/token.py
+++ b/flair/embeddings/token.py
@@ -4,7 +4,7 @@
import tempfile
from collections import Counter
from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
import numpy as np
import torch
@@ -64,7 +64,7 @@ def create_from_state(cls, **state):
class StackedEmbeddings(TokenEmbeddings):
"""A stack of embeddings, used if you need to combine several different embedding types."""
- def __init__(self, embeddings: List[TokenEmbeddings], overwrite_names: bool = True) -> None:
+ def __init__(self, embeddings: list[TokenEmbeddings], overwrite_names: bool = True) -> None:
"""The constructor takes a list of embeddings to be combined."""
super().__init__()
@@ -88,7 +88,7 @@ def __init__(self, embeddings: List[TokenEmbeddings], overwrite_names: bool = Tr
self.__embedding_length += embedding.embedding_length
self.eval()
- def embed(self, sentences: Union[Sentence, List[Sentence]], static_embeddings: bool = True):
+ def embed(self, sentences: Union[Sentence, list[Sentence]], static_embeddings: bool = True):
# if only one sentence is passed, convert to list of sentence
if type(sentences) is Sentence:
sentences = [sentences]
@@ -104,7 +104,7 @@ def embedding_type(self) -> str:
def embedding_length(self) -> int:
return self.__embedding_length
- def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
+ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:
for embedding in self.embeddings:
embedding._add_embeddings_internal(sentences)
@@ -113,7 +113,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
def __str__(self) -> str:
return f'StackedEmbeddings [{",".join([str(e) for e in self.embeddings])}]'
- def get_names(self) -> List[str]:
+ def get_names(self) -> list[str]:
"""Returns a list of embedding names.
In most cases, it is just a list with one item, namely the name of this embedding. But in some cases, the
@@ -126,13 +126,6 @@ def get_names(self) -> List[str]:
return self.__names
- def get_named_embeddings_dict(self) -> Dict:
- named_embeddings_dict = {}
- for embedding in self.embeddings:
- named_embeddings_dict.update(embedding.get_named_embeddings_dict())
-
- return named_embeddings_dict
-
@classmethod
def from_params(cls, params):
embeddings = [load_embeddings(p) for p in params["embeddings"]]
@@ -154,7 +147,7 @@ def __init__(
force_cpu: bool = True,
stable: bool = False,
no_header: bool = False,
- vocab: Optional[Dict[str, int]] = None,
+ vocab: Optional[dict[str, int]] = None,
embedding_length: Optional[int] = None,
name: Optional[str] = None,
) -> None:
@@ -334,10 +327,10 @@ def get_cached_token_index(self, word: str) -> int:
else:
return len(self.vocab) # token
- def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
+ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:
tokens = [token for sentence in sentences for token in sentence.tokens]
- word_indices: List[int] = []
+ word_indices: list[int] = []
for token in tokens:
word = token.text if self.field is None else token.get_label(self.field).value
word_indices.append(self.get_cached_token_index(word))
@@ -386,7 +379,7 @@ def __getattribute__(self, item):
return None
return super().__getattribute__(item)
- def __setstate__(self, state: Dict[str, Any]):
+ def __setstate__(self, state: dict[str, Any]):
state.pop("get_cached_vec", None)
state.setdefault("embeddings", state["name"])
state.setdefault("force_cpu", True)
@@ -416,10 +409,10 @@ def __setstate__(self, state: Dict[str, Any]):
super().__setstate__(state)
@classmethod
- def from_params(cls, params: Dict[str, Any]) -> "WordEmbeddings":
+ def from_params(cls, params: dict[str, Any]) -> "WordEmbeddings":
return cls(embeddings=None, **params)
- def to_params(self) -> Dict[str, Any]:
+ def to_params(self) -> dict[str, Any]:
return {
"vocab": self.vocab,
"stable": self.stable,
@@ -487,7 +480,7 @@ def __init__(
def embedding_length(self) -> int:
return self.__embedding_length
- def _add_embeddings_internal(self, sentences: List[Sentence]):
+ def _add_embeddings_internal(self, sentences: list[Sentence]):
for sentence in sentences:
tokens_char_indices = []
@@ -520,7 +513,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
character_embeddings = self.char_embedding(chars).transpose(0, 1)
- packed = torch.nn.utils.rnn.pack_padded_sequence(character_embeddings, chars2_length) # type: ignore[arg-type]
+ packed = torch.nn.utils.rnn.pack_padded_sequence(character_embeddings, chars2_length)
lstm_out, self.hidden = self.char_rnn(packed)
@@ -544,10 +537,10 @@ def __str__(self) -> str:
return self.name
@classmethod
- def from_params(cls, params: Dict[str, Any]) -> "CharacterEmbeddings":
+ def from_params(cls, params: dict[str, Any]) -> "CharacterEmbeddings":
return cls(**params)
- def to_params(self) -> Dict[str, Any]:
+ def to_params(self) -> dict[str, Any]:
return {
"path_to_char_dict": self.char_dictionary,
"char_embedding_dim": self.char_embedding_dim,
@@ -793,7 +786,7 @@ def train(self, mode=True):
def embedding_length(self) -> int:
return self.__embedding_length
- def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
+ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:
# gradients are enable if fine-tuning is enabled
gradient_context = torch.enable_grad() if self.fine_tune else torch.no_grad()
@@ -885,7 +878,7 @@ def from_params(cls, params):
lm = LanguageModel(**model_params)
return cls(lm, **params)
- def __setstate__(self, d: Dict[str, Any]):
+ def __setstate__(self, d: dict[str, Any]):
# make compatible with old models
d.setdefault("fine_tune", False)
d.setdefault("chars_per_chunk", 512)
@@ -920,8 +913,8 @@ def __init__(
self.name = self.context_embeddings.name + "-context"
# these fields are for the embedding memory
- self.word_embeddings: Dict[str, torch.Tensor] = {}
- self.word_count: Dict[str, int] = {}
+ self.word_embeddings: dict[str, torch.Tensor] = {}
+ self.word_count: dict[str, int] = {}
# whether to add only capitalized words to memory (faster runtime and lower memory consumption)
self.only_capitalized = only_capitalized
@@ -940,7 +933,7 @@ def train(self, mode=True):
self.word_embeddings = {}
self.word_count = {}
- def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
+ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:
self.context_embeddings.embed(sentences)
# if we keep a pooling, it needs to be updated continuously
@@ -989,10 +982,10 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
def embedding_length(self) -> int:
return self.__embedding_length
- def get_names(self) -> List[str]:
+ def get_names(self) -> list[str]:
return [self.name, self.context_embeddings.name]
- def __setstate__(self, d: Dict[str, Any]):
+ def __setstate__(self, d: dict[str, Any]):
super().__setstate__(d)
if flair.device.type != "cpu":
@@ -1073,7 +1066,7 @@ def get_cached_vec(self, word: str) -> torch.Tensor:
word_embedding = torch.tensor(word_embedding.tolist(), device=flair.device, dtype=torch.float)
return word_embedding
- def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
+ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:
for sentence in sentences:
for token in sentence.tokens:
word = token.text if self.field is None else token.get_label(self.field).value
@@ -1152,7 +1145,7 @@ def __init__(
def embedding_length(self) -> int:
return self.__embedding_length
- def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
+ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:
tokens = [t for sentence in sentences for t in sentence.tokens]
if self.field == "text":
@@ -1240,7 +1233,7 @@ def num_embeddings(self) -> int:
def embedding_length(self) -> int:
return self.__embedding_length
- def _add_embeddings_internal(self, sentences: List[Sentence]):
+ def _add_embeddings_internal(self, sentences: list[Sentence]):
def get_idx_for_item(text):
hash_function = hashlib.new(self.__hash_method)
hash_function.update(bytes(str(text), "utf-8"))
@@ -1282,7 +1275,7 @@ def __init__(
self.name: str = "muse-crosslingual"
self.static_embeddings = True
self.__embedding_length: int = 300
- self.language_embeddings: Dict[str, Any] = {}
+ self.language_embeddings: dict[str, Any] = {}
(KeyedVectors,) = lazy_import("word-embeddings", "gensim.models", "KeyedVectors")
self.kv = KeyedVectors
super().__init__()
@@ -1304,7 +1297,7 @@ def get_cached_vec(self, language_code: str, word: str) -> torch.Tensor:
word_embedding = torch.tensor(word_embedding, device=flair.device, dtype=torch.float)
return word_embedding
- def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
+ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:
for _i, sentence in enumerate(sentences):
language_code = sentence.get_language_code()
supported = [
@@ -1465,10 +1458,10 @@ def _preprocess(self, text: str) -> str:
def embedding_length(self) -> int:
return self.__embedding_length
- def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
+ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:
tokens = [token for sentence in sentences for token in sentence.tokens]
- word_indices: List[List[int]] = []
+ word_indices: list[list[int]] = []
for token in tokens:
word = token.text if self.field is None else token.get_label(self.field).value
@@ -1601,13 +1594,13 @@ def __init__(self, embeddings: str, model: str = "skip", size: int = 100) -> Non
else:
embeddings_path = embeddings
- log.info("Reading embeddings from %s" % embeddings_path)
+ log.info("Reading embeddings from %s", embeddings_path)
super().__init__(
embeddings=str(extract_single_zip_file(embeddings_path, cache_dir=cache_dir)), name="NILC-" + embeddings
)
@classmethod
- def from_params(cls, params: Dict[str, Any]) -> "WordEmbeddings":
+ def from_params(cls, params: dict[str, Any]) -> "WordEmbeddings":
# no need to recreate as NILCEmbeddings
return WordEmbeddings(embeddings=None, **params)
diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index f3492178f9..fdb16eea28 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -8,7 +8,7 @@
from abc import abstractmethod
from io import BytesIO
from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast
+from typing import Any, Literal, Optional, Union, cast
import torch
import transformers
@@ -26,24 +26,20 @@
LayoutLMv2FeatureExtractor,
PretrainedConfig,
PreTrainedTokenizer,
+ T5TokenizerFast,
)
from transformers.tokenization_utils_base import LARGE_INTEGER
from transformers.utils import PaddingStrategy
import flair
from flair.data import Sentence, Token, log
-from flair.embeddings.base import (
- DocumentEmbeddings,
- Embeddings,
- TokenEmbeddings,
- register_embeddings,
-)
+from flair.embeddings.base import DocumentEmbeddings, Embeddings, TokenEmbeddings, register_embeddings
SENTENCE_BOUNDARY_TAG: str = "[FLERT]"
@torch.jit.script_if_tracing
-def pad_sequence_embeddings(all_hidden_states: List[torch.Tensor]) -> torch.Tensor:
+def pad_sequence_embeddings(all_hidden_states: list[torch.Tensor]) -> torch.Tensor:
embedding_length = all_hidden_states[0].shape[1]
longest_token_sequence_in_batch = 0
for hidden_states in all_hidden_states:
@@ -197,7 +193,12 @@ def fill_mean_token_embeddings(
@torch.jit.script_if_tracing
-def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
+def document_cls_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
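+    # for each sentence, select the hidden state at position sentence_lengths - 1
+    # (the last subtoken), matching the previous inline "cls" pooling behaviour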
+ return sentence_hidden_states[torch.arange(sentence_hidden_states.shape[0]), sentence_lengths - 1]
+
+
+@torch.jit.script_if_tracing
+def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
result = torch.zeros(
sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
)
@@ -205,9 +206,11 @@ def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths
for i in torch.arange(sentence_hidden_states.shape[0]):
result[i] = sentence_hidden_states[i, : sentence_lengths[i]].mean(dim=0)
+ return result
+
@torch.jit.script_if_tracing
-def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
+def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
result = torch.zeros(
sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
)
@@ -215,15 +218,17 @@ def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths:
for i in torch.arange(sentence_hidden_states.shape[0]):
result[i], _ = sentence_hidden_states[i, : sentence_lengths[i]].max(dim=0)
+ return result
+
def _legacy_reconstruct_word_ids(
- embedding: "TransformerBaseEmbeddings", flair_tokens: List[List[str]]
-) -> List[List[Optional[int]]]:
+ embedding: "TransformerBaseEmbeddings", flair_tokens: list[list[str]]
+) -> list[list[Optional[int]]]:
word_ids_list = []
max_len = 0
for tokens in flair_tokens:
token_texts = embedding.tokenizer.tokenize(" ".join(tokens), is_split_into_words=True)
- token_ids = cast(List[int], embedding.tokenizer.convert_tokens_to_ids(token_texts))
+ token_ids = cast(list[int], embedding.tokenizer.convert_tokens_to_ids(token_texts))
expanded_token_ids = embedding.tokenizer.build_inputs_with_special_tokens(token_ids)
j = 0
for _i, token_id in enumerate(token_ids):
@@ -263,21 +268,21 @@ def _get_processed_token_text(tokenizer, token: str) -> str:
return token_text.strip()
-def _reconstruct_word_ids_from_subtokens(embedding, tokens: List[str], subtokens: List[str]):
+def _reconstruct_word_ids_from_subtokens(embedding, tokens: list[str], subtokens: list[str]):
word_iterator = iter(enumerate(_get_processed_token_text(embedding.tokenizer, token) for token in tokens))
token_id, token_text = next(word_iterator)
- word_ids: List[Optional[int]] = []
+ word_ids: list[Optional[int]] = []
reconstructed_token = ""
subtoken_count = 0
processed_first_token = False
special_tokens = []
# check if special tokens exist to circumvent error message
- if embedding.tokenizer._bos_token:
+ if embedding.tokenizer.bos_token is not None:
special_tokens.append(embedding.tokenizer.bos_token)
- if embedding.tokenizer._cls_token:
+ if embedding.tokenizer.cls_token is not None:
special_tokens.append(embedding.tokenizer.cls_token)
- if embedding.tokenizer._sep_token:
+ if embedding.tokenizer.sep_token is not None:
special_tokens.append(embedding.tokenizer.sep_token)
# iterate over subtokens and reconstruct tokens
@@ -444,7 +449,7 @@ def _tokenizer_from_bytes(cls, zip_data: BytesIO) -> PreTrainedTokenizer:
zip_obj = zipfile.ZipFile(zip_data)
with tempfile.TemporaryDirectory() as temp_dir:
zip_obj.extractall(temp_dir)
- return AutoTokenizer.from_pretrained(temp_dir, add_prefix_space=True)
+ return AutoTokenizer.from_pretrained(temp_dir)
@classmethod
def _feature_extractor_from_bytes(cls, zip_data: Optional[BytesIO]) -> Optional[FeatureExtractionMixin]:
@@ -458,7 +463,13 @@ def _feature_extractor_from_bytes(cls, zip_data: Optional[BytesIO]) -> Optional[
def __tokenizer_bytes(self):
with tempfile.TemporaryDirectory() as temp_dir:
files = list(self.tokenizer.save_pretrained(temp_dir))
- if self.tokenizer.is_fast and self.tokenizer.slow_tokenizer_class:
+ if (
+ self.tokenizer.is_fast
+ and self.tokenizer.slow_tokenizer_class
+ and not isinstance(
+ self.tokenizer, T5TokenizerFast
+ ) # do not remove slow files for T5, as it can only be created from slow tokenizer with prefix space
+ ):
vocab_files = self.tokenizer.slow_tokenizer_class.vocab_files_names.values()
files = [f for f in files if all(v not in f for v in vocab_files)]
zip_data = BytesIO()
@@ -497,10 +508,10 @@ def embedding_type(self) -> str:
return "word-level" if self.token_embedding else "sentence-level"
@abstractmethod
- def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]:
+ def _forward_tensors(self, tensors) -> dict[str, torch.Tensor]:
return self(**tensors)
- def prepare_tensors(self, sentences: List[Sentence], device: Optional[torch.device] = None):
+ def prepare_tensors(self, sentences: list[Sentence], device: Optional[torch.device] = None):
if device is None:
device = flair.device
flair_tokens, offsets, lengths = self.__gather_flair_tokens(sentences)
@@ -528,13 +539,13 @@ def prepare_tensors(self, sentences: List[Sentence], device: Optional[torch.devi
def __build_transformer_model_inputs(
self,
- sentences: List[Sentence],
- offsets: List[int],
- sentence_lengths: List[int],
- flair_tokens: List[List[Token]],
+ sentences: list[Sentence],
+ offsets: list[int],
+ sentence_lengths: list[int],
+ flair_tokens: list[list[Token]],
device: torch.device,
):
- tokenizer_kwargs: Dict[str, Any] = {}
+ tokenizer_kwargs: dict[str, Any] = {}
if self.tokenizer_needs_ocr_boxes:
tokenizer_kwargs["boxes"] = [[t.get_metadata("bbox") for t in tokens] for tokens in flair_tokens]
else:
@@ -655,7 +666,7 @@ def __build_transformer_model_inputs(
return model_kwargs
- def __gather_flair_tokens(self, sentences: List[Sentence]) -> Tuple[List[List[Token]], List[int], List[int]]:
+ def __gather_flair_tokens(self, sentences: list[Sentence]) -> tuple[list[list[Token]], list[int], list[int]]:
offsets = []
lengths = []
if self.context_length > 0:
@@ -679,7 +690,7 @@ def __gather_flair_tokens(self, sentences: List[Sentence]) -> Tuple[List[List[To
lengths.append(len(sentence))
return sentence_tokens, offsets, lengths
- def _expand_sentence_with_context(self, sentence) -> Tuple[List[Token], int]:
+ def _expand_sentence_with_context(self, sentence) -> tuple[list[Token], int]:
# fields to store left and right context
left_context = []
right_context = []
@@ -715,7 +726,7 @@ def __extract_token_embeddings(self, sentence_embeddings, sentences):
for token_embedding, token in zip(token_embeddings, sentence):
token.set_embedding(self.name, token_embedding)
- def _add_embeddings_internal(self, sentences: List[Sentence]):
+ def _add_embeddings_internal(self, sentences: list[Sentence]):
tensors = self.prepare_tensors(sentences, device=self.force_device)
gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad()
with gradient_context:
@@ -732,7 +743,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
@register_embeddings
class TransformerOnnxEmbeddings(TransformerBaseEmbeddings):
- def __init__(self, onnx_model: str, providers: List = [], session_options: Optional[Dict] = None, **kwargs) -> None:
+ def __init__(self, onnx_model: str, providers: list = [], session_options: Optional[dict] = None, **kwargs) -> None:
# onnx prepares numpy arrays, no mather if it runs on gpu or cpu, the input is on cpu first.
super().__init__(**kwargs, force_device=torch.device("cpu"))
self.onnx_model = onnx_model
@@ -749,7 +760,7 @@ def to_params(self):
return params
@classmethod
- def from_params(cls, params: Dict[str, Any]) -> "TransformerOnnxEmbeddings":
+ def from_params(cls, params: dict[str, Any]) -> "TransformerOnnxEmbeddings":
params["tokenizer"] = cls._tokenizer_from_bytes(params.pop("tokenizer_data"))
params["feature_extractor"] = cls._feature_extractor_from_bytes(params.pop("feature_extractor_data", None))
return cls(**params)
@@ -805,7 +816,7 @@ def quantize_model(self, quantize_model_path, use_external_data_format: bool = F
self.onnx_model = quantize_model_path
self.create_session()
- def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]:
+ def _forward_tensors(self, tensors) -> dict[str, torch.Tensor]:
input_array = {k: v.numpy() for k, v in tensors.items()}
embeddings = self.session.run([], input_array)
@@ -847,9 +858,9 @@ def export_from_embedding(
cls,
path: Union[str, Path],
embedding: "TransformerEmbeddings",
- example_sentences: List[Sentence],
+ example_sentences: list[Sentence],
opset_version: int = 14,
- providers: Optional[List] = None,
+ providers: Optional[list] = None,
session_options: Optional[dict] = None,
):
path = str(path)
@@ -896,7 +907,7 @@ def export_from_embedding(
@register_embeddings
class TransformerJitEmbeddings(TransformerBaseEmbeddings):
- def __init__(self, jit_model: Union[bytes, ScriptModule], param_names: List[str], **kwargs) -> None:
+ def __init__(self, jit_model: Union[bytes, ScriptModule], param_names: list[str], **kwargs) -> None:
super().__init__(**kwargs)
if isinstance(jit_model, bytes):
buffer = BytesIO(jit_model)
@@ -918,12 +929,12 @@ def to_params(self):
return state
@classmethod
- def from_params(cls, params: Dict[str, Any]) -> "Embeddings":
+ def from_params(cls, params: dict[str, Any]) -> "Embeddings":
params["tokenizer"] = cls._tokenizer_from_bytes(params.pop("tokenizer_data"))
params["feature_extractor"] = cls._feature_extractor_from_bytes(params.pop("feature_extractor_data", None))
return cls(**params)
- def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]:
+ def _forward_tensors(self, tensors) -> dict[str, torch.Tensor]:
parameters = []
for param in self.param_names:
parameters.append(tensors[param])
@@ -938,13 +949,13 @@ def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]:
raise ValueError("either 'token_embedding' or 'document_embedding' needs to be set.")
@classmethod
- def create_from_embedding(cls, module: ScriptModule, embedding: "TransformerEmbeddings", param_names: List[str]):
+ def create_from_embedding(cls, module: ScriptModule, embedding: "TransformerEmbeddings", param_names: list[str]):
return cls(jit_model=module, param_names=param_names, **embedding.to_args())
@classmethod
def parameter_to_list(
- cls, embedding: "TransformerEmbeddings", wrapper: torch.nn.Module, sentences: List[Sentence]
- ) -> Tuple[List[str], List[torch.Tensor]]:
+ cls, embedding: "TransformerEmbeddings", wrapper: torch.nn.Module, sentences: list[Sentence]
+ ) -> tuple[list[str], list[torch.Tensor]]:
tensors = embedding.prepare_tensors(sentences)
param_names = list(inspect.signature(wrapper.forward).parameters.keys())
params = []
@@ -991,7 +1002,7 @@ def __init__(
@register_embeddings
class TransformerEmbeddings(TransformerBaseEmbeddings):
- onnx_cls: Type[TransformerOnnxEmbeddings] = TransformerOnnxEmbeddings
+ onnx_cls: type[TransformerOnnxEmbeddings] = TransformerOnnxEmbeddings
def __init__(
self,
@@ -1014,11 +1025,11 @@ def __init__(
force_max_length: bool = False,
needs_manual_ocr: Optional[bool] = None,
use_context_separator: bool = True,
- transformers_tokenizer_kwargs: Dict[str, Any] = {},
- transformers_config_kwargs: Dict[str, Any] = {},
- transformers_model_kwargs: Dict[str, Any] = {},
+ transformers_tokenizer_kwargs: dict[str, Any] = {},
+ transformers_config_kwargs: dict[str, Any] = {},
+ transformers_model_kwargs: dict[str, Any] = {},
peft_config=None,
- peft_gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = {},
+ peft_gradient_checkpointing_kwargs: Optional[dict[str, Any]] = {},
**kwargs,
) -> None:
"""Instantiate transformers embeddings.
@@ -1120,11 +1131,7 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
if peft_config is not None:
# add adapters for finetuning
try:
- from peft import (
- TaskType,
- get_peft_model,
- prepare_model_for_kbit_training,
- )
+ from peft import TaskType, get_peft_model, prepare_model_for_kbit_training
except ImportError:
log.error("You cannot use the PEFT finetuning without peft being installed")
raise
@@ -1346,6 +1353,12 @@ def from_params(cls, params):
def to_params(self):
config_dict = self.model.config.to_dict()
+
+ if hasattr(self.model.config, "_attn_implementation"):
+ # do not switch the attention implementation upon reload.
+ config_dict["attn_implementation"] = self.model.config._attn_implementation
+ config_dict.pop("_attn_implementation_autoset", None)
+
super_params = super().to_params()
# those parameters are only from the super class and will be recreated in the constructor.
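For context, the hunk above records the active attention implementation in the serialized parameters, so a save/load round trip keeps the configured backend instead of letting transformers pick a new default. A minimal sketch of that round trip (model name and the "sdpa" value are illustrative; `transformers_model_kwargs` forwarding to `from_pretrained` is assumed):

```python
from flair.embeddings import TransformerWordEmbeddings

# request a specific attention implementation for the underlying transformers model
emb = TransformerWordEmbeddings(
    "bert-base-uncased",
    transformers_model_kwargs={"attn_implementation": "sdpa"},
)

# serialize and restore; with the change above, the restored embedding keeps "sdpa"
params = emb.to_params()
restored = TransformerWordEmbeddings.from_params(params)
```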
@@ -1434,9 +1447,7 @@ def forward(
else:
assert sub_token_lengths is not None
if self.cls_pooling == "cls":
- document_embeddings = sentence_hidden_states[
- torch.arange(sentence_hidden_states.shape[0]), sub_token_lengths - 1
- ]
+ document_embeddings = document_cls_pooling(sentence_hidden_states, sub_token_lengths)
elif self.cls_pooling == "mean":
document_embeddings = document_mean_pooling(sentence_hidden_states, sub_token_lengths)
elif self.cls_pooling == "max":
@@ -1496,11 +1507,11 @@ def forward(
result["token_embeddings"] = all_token_embeddings
return result
- def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]:
+ def _forward_tensors(self, tensors) -> dict[str, torch.Tensor]:
return self.forward(**tensors)
def export_onnx(
- self, path: Union[str, Path], example_sentences: List[Sentence], **kwargs
+ self, path: Union[str, Path], example_sentences: list[Sentence], **kwargs
) -> TransformerOnnxEmbeddings:
"""Export TransformerEmbeddings to OnnxFormat.
diff --git a/flair/file_utils.py b/flair/file_utils.py
index 7a0118822b..518d69e809 100644
--- a/flair/file_utils.py
+++ b/flair/file_utils.py
@@ -12,8 +12,9 @@
import typing
import warnings
import zipfile
+from collections.abc import Sequence
from pathlib import Path
-from typing import Optional, Sequence, Tuple, Union, cast
+from typing import Optional, Union, cast
from urllib.parse import urlparse
import boto3
@@ -28,10 +29,10 @@
logger = logging.getLogger("flair")
-url_proxies: Optional[typing.Dict[str, str]] = None
+url_proxies: Optional[dict[str, str]] = None
-def set_proxies(proxies: typing.Dict[str, str]) -> None:
+def set_proxies(proxies: dict[str, str]) -> None:
r"""Allows for data downloaded from urls to be forwarded to a proxy.
see https://requests.readthedocs.io/en/latest/user/advanced/#proxies
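A minimal sketch of the call (proxy addresses are placeholders):

```python
from flair.file_utils import set_proxies

# requests-style proxy mapping; subsequent flair downloads are routed through the proxy
set_proxies({
    "http": "http://proxy.example.com:3128",
    "https": "http://proxy.example.com:3128",
})
```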
@@ -74,7 +75,7 @@ def url_to_filename(url: str, etag: Optional[str] = None) -> str:
return decoded
-def filename_to_url(filename: str) -> Tuple[str, Optional[str]]:
+def filename_to_url(filename: str) -> tuple[str, Optional[str]]:
"""Recovers the the url from the encoded filename.
Returns it and the ETag (which may be ``None``)
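The helper is the inverse of `url_to_filename`; a small round trip for illustration (URL and ETag are made up):

```python
from flair.file_utils import filename_to_url, url_to_filename

cached_name = url_to_filename("https://example.com/model.pt", etag="abc123")
url, etag = filename_to_url(cached_name)  # ("https://example.com/model.pt", "abc123")
```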
@@ -374,7 +375,7 @@ def create_cache(self, *args, **kwargs):
return decorator
-def load_torch_state(model_file: str) -> typing.Dict[str, typing.Any]:
+def load_torch_state(model_file: str) -> dict[str, typing.Any]:
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
        # load_big_file is a workaround by https://github.com/highway11git
diff --git a/flair/inference_utils.py b/flair/inference_utils.py
index 0310671534..c811bf39a1 100644
--- a/flair/inference_utils.py
+++ b/flair/inference_utils.py
@@ -126,7 +126,7 @@ def create_stores(model, backend="sqlite"):
Also deletes the original vectors to save memory.
"""
for embedding in WordEmbeddingsStore._word_embeddings(model):
- if type(embedding) == WordEmbeddings:
+ if isinstance(embedding, WordEmbeddings):
WordEmbeddingsStore(embedding, backend)
del embedding.precomputed_word_embeddings
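For orientation, the documented workflow around these helpers (the model name is illustrative; sqlite is the default backend):

```python
from flair.data import Sentence
from flair.inference_utils import WordEmbeddingsStore
from flair.models import SequenceTagger

tagger = SequenceTagger.load("multi-ner-fast")
WordEmbeddingsStore.create_stores(tagger)  # one-off: dump classic WordEmbeddings to sqlite, free the in-memory vectors
WordEmbeddingsStore.load_stores(tagger)    # in the low-memory process: swap in the db-backed stores
tagger.predict(Sentence("George Washington went to Washington ."))
```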
@@ -135,7 +135,7 @@ def load_stores(model, backend="sqlite"):
"""Loads the db versions of all word embeddings in the model."""
embeds = WordEmbeddingsStore._word_embeddings(model)
for i, embedding in enumerate(embeds):
- if type(embedding) == WordEmbeddings:
+ if isinstance(embedding, WordEmbeddings):
embeds[i] = WordEmbeddingsStore(embedding, backend)
@staticmethod
diff --git a/flair/models/__init__.py b/flair/models/__init__.py
index 8357cc47ea..e75daf074b 100644
--- a/flair/models/__init__.py
+++ b/flair/models/__init__.py
@@ -1,4 +1,3 @@
-from .clustering import ClusteringModel
from .entity_linker_model import SpanClassifier
from .entity_mention_linking import EntityMentionLinker
from .language_model import LanguageModel
@@ -37,6 +36,5 @@
"TARSTagger",
"TextClassifier",
"TextRegressor",
- "ClusteringModel",
"MultitaskModel",
]
diff --git a/flair/models/clustering.py b/flair/models/clustering.py
deleted file mode 100644
index e9902f6f67..0000000000
--- a/flair/models/clustering.py
+++ /dev/null
@@ -1,120 +0,0 @@
-import logging
-import pickle
-from collections import OrderedDict
-from pathlib import Path
-from typing import Optional, Union
-
-import joblib
-from sklearn.base import BaseEstimator, ClusterMixin
-from sklearn.metrics import normalized_mutual_info_score
-from tqdm import tqdm
-
-from flair.data import Corpus, _iter_dataset
-from flair.datasets import DataLoader
-from flair.embeddings import DocumentEmbeddings
-
-log = logging.getLogger("flair")
-
-
-class ClusteringModel:
- """A wrapper class to apply sklearn clustering models on DocumentEmbeddings."""
-
- def __init__(self, model: Union[ClusterMixin, BaseEstimator], embeddings: DocumentEmbeddings) -> None:
- """Instantiate the ClusteringModel.
-
- Args:
- model: the clustering algorithm from sklearn this wrapper will use.
- embeddings: the flair DocumentEmbedding this wrapper uses to calculate a vector for each sentence.
- """
- self.model = model
- self.embeddings = embeddings
-
- def fit(self, corpus: Corpus, **kwargs):
- """Trains the model.
-
- Args:
- corpus: the flair corpus this wrapper will use for fitting the model.
- **kwargs: parameters propagated to the models `.fit()` method.
- """
- X = self._convert_dataset(corpus)
-
- log.info("Start clustering " + str(self.model) + " with " + str(len(X)) + " Datapoints.")
- self.model.fit(X, **kwargs)
- log.info("Finished clustering.")
-
- def predict(self, corpus: Corpus):
- """Predict labels given a list of sentences and returns the respective class indices.
-
- Args:
- corpus: the flair corpus this wrapper will use for predicting the labels.
- """
- X = self._convert_dataset(corpus)
- log.info("Start the prediction " + str(self.model) + " with " + str(len(X)) + " Datapoints.")
- predict = self.model.predict(X)
-
- for idx, sentence in enumerate(_iter_dataset(corpus.get_all_sentences())):
- sentence.set_label("cluster", str(predict[idx]))
-
- log.info("Finished prediction and labeled all sentences.")
- return predict
-
- def save(self, model_file: Union[str, Path]):
- """Saves current model.
-
- Args:
- model_file: path where to save the model.
- """
- joblib.dump(pickle.dumps(self), str(model_file))
-
- log.info("Saved the model to: " + str(model_file))
-
- @staticmethod
- def load(model_file: Union[str, Path]):
- """Loads a model from a given path.
-
- Args:
- model_file: path to the file where the model is saved.
- """
- log.info("Loading model from: " + str(model_file))
- return pickle.loads(joblib.load(str(model_file)))
-
- def _convert_dataset(
- self, corpus, label_type: Optional[str] = None, batch_size: int = 32, return_label_dict: bool = False
- ):
- """Makes a flair-corpus sklearn compatible.
-
- Turns the corpora into X, y datasets as required for most sklearn clustering models.
- Ref.: https://scikit-learn.org/stable/modules/classes.html#module-sklearn.cluster
- """
- log.info("Embed sentences...")
- sentences = []
- for batch in tqdm(DataLoader(corpus.get_all_sentences(), batch_size=batch_size)):
- self.embeddings.embed(batch)
- sentences.extend(batch)
-
- X = [sentence.embedding.cpu().detach().numpy() for sentence in sentences]
-
- if label_type is None:
- return X
-
- labels = [sentence.get_labels(label_type)[0].value for sentence in sentences]
- label_dict = {v: k for k, v in enumerate(OrderedDict.fromkeys(labels))}
- y = [label_dict.get(label) for label in labels]
-
- if return_label_dict:
- return X, y, label_dict
-
- return X, y
-
- def evaluate(self, corpus: Corpus, label_type: str):
- """This method calculates some evaluation metrics for the clustering.
-
- Also, the result of the evaluation is logged.
-
- Args:
- corpus: the flair corpus this wrapper will use for evaluation.
- label_type: the label from the sentence will be used for the evaluation.
- """
- X, Y = self._convert_dataset(corpus, label_type=label_type)
- predict = self.model.predict(X)
- log.info("NMI - Score: " + str(normalized_mutual_info_score(predict, Y)))
diff --git a/flair/models/entity_linker_model.py b/flair/models/entity_linker_model.py
index 9f516a703c..0f1c916bfb 100644
--- a/flair/models/entity_linker_model.py
+++ b/flair/models/entity_linker_model.py
@@ -2,7 +2,7 @@
import re
from functools import lru_cache
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Set, Union, cast
+from typing import Any, Callable, Optional, Union, cast
from unicodedata import category
import torch
@@ -19,9 +19,9 @@
class CandidateGenerator:
"""Given a string, the CandidateGenerator returns possible target classes as candidates."""
- def __init__(self, candidates: Union[str, Dict[str, List[str]]], backoff: bool = True) -> None:
+ def __init__(self, candidates: Union[str, dict[str, list[str]]], backoff: bool = True) -> None:
# internal candidate lists of generator
- self.mention_to_candidates_map: Dict = {}
+ self.mention_to_candidates_map: dict[str, list[str]] = {}
# load Zelda candidates if so passed
if isinstance(candidates, str) and candidates.lower() == "zelda":
@@ -39,16 +39,15 @@ def __init__(self, candidates: Union[str, Dict[str, List[str]]], backoff: bool =
self.mention_to_candidates_map = candidate_lists
- elif isinstance(candidates, Dict):
+ elif isinstance(candidates, dict):
self.mention_to_candidates_map = candidates
else:
raise ValueError(f"'{candidates}' could not be loaded.")
- self.mention_to_candidates_map = cast(Dict[str, List[str]], self.mention_to_candidates_map)
# if lower casing is enabled, create candidate lists of lower cased versions
self.backoff = backoff
if self.backoff:
# create a new dictionary for lower cased mentions
- lowercased_mention_to_candidates_map: Dict = {}
+ lowercased_mention_to_candidates_map: dict[str, list[str]] = {}
# go through each mention and its candidates
for mention, candidates_list in self.mention_to_candidates_map.items():
@@ -56,8 +55,8 @@ def __init__(self, candidates: Union[str, Dict[str, List[str]]], backoff: bool =
# check if backoff mention already seen. If so, add candidates. Else, create new entry.
if backoff_mention in lowercased_mention_to_candidates_map:
current_candidates = lowercased_mention_to_candidates_map[backoff_mention]
- lowercased_mention_to_candidates_map[backoff_mention] = set(current_candidates).union(
- candidates_list
+ lowercased_mention_to_candidates_map[backoff_mention] = list(
+ set(current_candidates).union(candidates_list)
)
else:
lowercased_mention_to_candidates_map[backoff_mention] = candidates_list
@@ -72,7 +71,7 @@ def _make_backoff_string(self, mention: str) -> str:
backoff_mention = re.sub(" +", " ", backoff_mention)
return backoff_mention
- def get_candidates(self, mention: str) -> Set[str]:
+ def get_candidates(self, mention: str) -> set[str]:
"""Given a mention, this method returns a set of candidate classes."""
if self.backoff:
mention = self._make_backoff_string(mention)
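A tiny sketch of the generator with a hand-rolled candidate map (entries are made up):

```python
from flair.models.entity_linker_model import CandidateGenerator

generator = CandidateGenerator(
    candidates={"NYC": ["New_York_City"], "N.Y.C.": ["New_York_City"]},
    backoff=True,
)

# with backoff enabled, mentions are lower-cased and lightly normalized before lookup
print(generator.get_candidates("nyc"))  # expected: {"New_York_City"}
```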
@@ -125,7 +124,7 @@ def __init__(
self._label_type = label_type
self._span_label_type = span_label_type
- cases: Dict[str, Callable[[Span, List[str]], torch.Tensor]] = {
+ cases: dict[str, Callable[[Span, list[str]], torch.Tensor]] = {
"average": self.emb_mean,
"first": self.emb_first,
"last": self.emb_last,
@@ -155,7 +154,7 @@ def emb_firstAndLast(self, span: Span, embedding_names):
def emb_mean(self, span, embedding_names):
return torch.mean(torch.stack([token.get_embedding(embedding_names) for token in span], 0), 0)
- def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Span]:
+ def _get_data_points_from_sentence(self, sentence: Sentence) -> list[Span]:
if self._span_label_type is not None:
spans = sentence.get_spans(self._span_label_type)
# only use span label type if there are predictions, otherwise search for output label type (training labels)
@@ -223,7 +222,7 @@ def _init_model_with_state_dict(cls, state, **kwargs):
def label_type(self):
return self._label_type
- def _mask_scores(self, scores: torch.Tensor, data_points: List[Span]):
+ def _mask_scores(self, scores: torch.Tensor, data_points: list[Span]):
if not self.candidates:
return scores
@@ -242,9 +241,7 @@ def _mask_scores(self, scores: torch.Tensor, data_points: List[Span]):
return masked_scores
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "SpanClassifier":
- from typing import cast
-
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "SpanClassifier":
return cast("SpanClassifier", super().load(model_path=model_path))
diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py
index cecf2d9e57..5a1382dd60 100644
--- a/flair/models/entity_mention_linking.py
+++ b/flair/models/entity_mention_linking.py
@@ -4,9 +4,10 @@
import re
import string
from abc import ABC, abstractmethod
+from collections.abc import Sequence
from enum import Enum, auto
from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
+from typing import Any, Optional, Union, cast
import numpy as np
import torch
@@ -89,7 +90,7 @@
"chemical": "ctd-chemicals",
}
-BIOMEDICAL_DICTIONARIES: Dict[str, Type] = {
+BIOMEDICAL_DICTIONARIES: dict[str, type] = {
"ctd-diseases": CTD_DISEASES_DICTIONARY,
"ctd-chemicals": CTD_CHEMICALS_DICTIONARY,
"ncbi-gene": NCBI_GENE_HUMAN_DICTIONARY,
@@ -151,7 +152,7 @@ def load_dictionary(
class EntityPreprocessor(ABC):
"""A pre-processor used to transform / clean both entity mentions and entity names."""
- def initialize(self, sentences: List[Sentence]) -> None:
+ def initialize(self, sentences: list[Sentence]) -> None:
"""Initializes the pre-processor for a batch of sentences.
This may be necessary for more sophisticated transformations.
@@ -187,14 +188,14 @@ def process_entity_name(self, entity_name: str) -> str:
"""
@classmethod
- def _from_state(cls, state_dict: Dict[str, Any]) -> "EntityPreprocessor":
+ def _from_state(cls, state_dict: dict[str, Any]) -> "EntityPreprocessor":
if inspect.isabstract(cls):
cls_name = state_dict.pop("__cls__", None)
return get_state_subclass_by_name(cls, cls_name)._from_state(state_dict)
else:
return cls(**state_dict)
- def _get_state(self) -> Dict[str, Any]:
+ def _get_state(self) -> dict[str, Any]:
return {"__cls__": self.__class__.__name__}
@@ -237,7 +238,7 @@ def process_entity_name(self, entity_name: str) -> str:
return entity_name
- def _get_state(self) -> Dict[str, Any]:
+ def _get_state(self) -> dict[str, Any]:
return {
**super()._get_state(),
"lowercase": self.lowercase,
@@ -270,9 +271,9 @@ def __init__(
self.ab3p = pyab3p.Ab3p()
self.preprocessor = preprocessor
- self.abbreviation_dict: Dict[str, Dict[str, str]] = {}
+ self.abbreviation_dict: dict[str, dict[str, str]] = {}
- def initialize(self, sentences: List[Sentence]) -> None:
+ def initialize(self, sentences: list[Sentence]) -> None:
self.abbreviation_dict = self._build_abbreviation_dict(sentences)
def process_mention(self, entity_mention: str, sentence: Optional[Sentence] = None) -> str:
@@ -303,7 +304,7 @@ def process_entity_name(self, entity_name: str) -> str:
return entity_name
- def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict[str, Dict[str, str]]:
+ def _build_abbreviation_dict(self, sentences: list[flair.data.Sentence]) -> dict[str, dict[str, str]]:
"""Processes the given sentences with the Ab3P tool.
The function returns a (nested) dictionary containing the abbreviations found for each sentence, e.g.:
@@ -321,7 +322,7 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict
Returns:
abbreviation_dict: abbreviations and their resolution detected in each input sentence
"""
- abbreviation_dict: Dict[str, Dict[str, str]] = {}
+ abbreviation_dict: dict[str, dict[str, str]] = {}
for sentence in sentences:
sentence_text = sentence.to_original_text()
@@ -331,14 +332,14 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict
return abbreviation_dict
- def _get_state(self) -> Dict[str, Any]:
+ def _get_state(self) -> dict[str, Any]:
return {
**super()._get_state(),
"preprocessor": None if self.preprocessor is None else self.preprocessor._get_state(),
}
@classmethod
- def _from_state(cls, state_dict: Dict[str, Any]) -> "EntityPreprocessor":
+ def _from_state(cls, state_dict: dict[str, Any]) -> "EntityPreprocessor":
return cls(
preprocessor=(
None
@@ -364,7 +365,7 @@ def index(self, dictionary: EntityLinkingDictionary, preprocessor: Optional[Enti
"""
@abstractmethod
- def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]:
+ def search(self, entity_mentions: list[str], top_k: int) -> list[list[tuple[str, float]]]:
"""Returns the top-k entity / concept identifiers for each entity mention.
Args:
@@ -376,14 +377,14 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str,
"""
@classmethod
- def _from_state(cls, state_dict: Dict[str, Any]) -> "CandidateSearchIndex":
+ def _from_state(cls, state_dict: dict[str, Any]) -> "CandidateSearchIndex":
if inspect.isabstract(cls):
cls_name = state_dict.pop("__cls__", None)
return get_state_subclass_by_name(cls, cls_name)._from_state(state_dict)
else:
return cls(**state_dict)
- def _get_state(self) -> Dict[str, Any]:
+ def _get_state(self) -> dict[str, Any]:
return {"__cls__": self.__class__.__name__}
@@ -396,7 +397,7 @@ def __init__(self):
Args:
name_to_id_index: internal state, should only be set when loading an initialized index.
"""
- self.name_to_id_index: Dict[str, str] = {}
+ self.name_to_id_index: dict[str, str] = {}
def index(self, dictionary: EntityLinkingDictionary, preprocessor: Optional[EntityPreprocessor] = None) -> None:
def p(text: str) -> str:
@@ -407,8 +408,8 @@ def p(text: str) -> str:
for synonym in candidate.synonyms:
self.name_to_id_index[p(synonym)] = candidate.concept_id
- def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]:
- results: List[List[Tuple[str, float]]] = []
+ def search(self, entity_mentions: list[str], top_k: int) -> list[list[tuple[str, float]]]:
+ results: list[list[tuple[str, float]]] = []
for mention in entity_mentions:
dict_entry = self.name_to_id_index.get(mention)
if dict_entry is None:
@@ -419,12 +420,12 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str,
return results
@classmethod
- def _from_state(cls, state_dict: Dict[str, Any]) -> "CandidateSearchIndex":
+ def _from_state(cls, state_dict: dict[str, Any]) -> "CandidateSearchIndex":
index = cls()
index.name_to_id_index = state_dict["name_to_id_index"]
return index
- def _get_state(self) -> Dict[str, Any]:
+ def _get_state(self) -> dict[str, Any]:
return {
**super()._get_state(),
"name_to_id_index": self.name_to_id_index,
@@ -436,7 +437,7 @@ class SemanticCandidateSearchIndex(CandidateSearchIndex):
def __init__(
self,
- embeddings: Dict[str, DocumentEmbeddings],
+ embeddings: dict[str, DocumentEmbeddings],
hybrid_search: bool,
similarity_metric: SimilarityMetric = SimilarityMetric.INNER_PRODUCT,
sparse_weight: float = DEFAULT_SPARSE_WEIGHT,
@@ -460,8 +461,8 @@ def __init__(
self.show_progress = show_progress
self.batch_size = batch_size
- self.ids: List[str] = []
- self._precomputed_embeddings: Dict[str, np.ndarray] = {"sparse": np.array([]), "dense": np.array([])}
+ self.ids: list[str] = []
+ self._precomputed_embeddings: dict[str, np.ndarray] = {"sparse": np.array([]), "dense": np.array([])}
@classmethod
def bi_encoder(
@@ -479,7 +480,7 @@ def bi_encoder(
if model_name_or_path in PRETRAINED_MODELS:
similarity_metric = PRETRAINED_MODEL_TO_SIMILARITY_METRIC[model_name_or_path]
- embeddings: Dict[str, DocumentEmbeddings] = {"dense": TransformerDocumentEmbeddings(model_name_or_path)}
+ embeddings: dict[str, DocumentEmbeddings] = {"dense": TransformerDocumentEmbeddings(model_name_or_path)}
if hybrid_search:
if dictionary is None:
@@ -515,7 +516,7 @@ def index(self, dictionary: EntityLinkingDictionary, preprocessor: Optional[Enti
def p(text: str) -> str:
return preprocessor.process_entity_name(text) if preprocessor is not None else text
- texts: List[str] = []
+ texts: list[str] = []
self.ids = []
for candidate in dictionary.candidates:
texts.append(p(candidate.concept_name))
@@ -564,8 +565,8 @@ def p(text: str) -> str:
sent.clear_embeddings()
self._precomputed_embeddings["sparse"] = np.stack(sparse_embs, axis=0)
- def embed(self, entity_mentions: List[str]) -> Dict[str, np.ndarray]:
- query_embeddings: Dict[str, List] = {"dense": []}
+ def embed(self, entity_mentions: list[str]) -> dict[str, np.ndarray]:
+ query_embeddings: dict[str, list[np.ndarray]] = {"dense": []}
inputs = [Sentence(name) for name in entity_mentions]
@@ -600,7 +601,7 @@ def embed(self, entity_mentions: List[str]) -> Dict[str, np.ndarray]:
return {k: np.stack(v, axis=0) for k, v in query_embeddings.items()}
- def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]:
+ def search(self, entity_mentions: list[str], top_k: int) -> list[list[tuple[str, float]]]:
"""Returns the top-k entity / concept identifiers for each entity mention.
Args:
@@ -634,10 +635,10 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str,
return results
@classmethod
- def _from_state(cls, state_dict: Dict[str, Any]) -> "SemanticCandidateSearchIndex":
+ def _from_state(cls, state_dict: dict[str, Any]) -> "SemanticCandidateSearchIndex":
index = cls(
embeddings=cast(
- Dict[str, DocumentEmbeddings], {k: load_embeddings(emb) for k, emb in state_dict["embeddings"].items()}
+ dict[str, DocumentEmbeddings], {k: load_embeddings(emb) for k, emb in state_dict["embeddings"].items()}
),
similarity_metric=SimilarityMetric(state_dict["similarity_metric"]),
sparse_weight=state_dict["sparse_weight"],
@@ -649,7 +650,7 @@ def _from_state(cls, state_dict: Dict[str, Any]) -> "SemanticCandidateSearchInde
index._precomputed_embeddings = state_dict["precomputed_embeddings"]
return index
- def _get_state(self) -> Dict[str, Any]:
+ def _get_state(self) -> dict[str, Any]:
return {
**super()._get_state(),
"embeddings": {k: emb.save_embeddings() for k, emb in self.embeddings.items()},
@@ -670,7 +671,7 @@ def __init__(
self,
candidate_generator: CandidateSearchIndex,
preprocessor: EntityPreprocessor,
- entity_label_types: Union[str, Sequence[str], Dict[str, Set[str]]],
+ entity_label_types: Union[str, Sequence[str], dict[str, set[str]]],
label_type: str,
dictionary: EntityLinkingDictionary,
batch_size: int = 1024,
@@ -698,8 +699,8 @@ def __init__(
super().__init__()
def get_entity_label_types(
- self, entity_label_types: Union[str, Sequence[str], Dict[str, Set[str]]]
- ) -> Dict[str, Set[str]]:
+ self, entity_label_types: Union[str, Sequence[str], dict[str, set[str]]]
+ ) -> dict[str, set[str]]:
"""Find out what NER labels to extract from sentence.
Args:
@@ -709,9 +710,9 @@ def get_entity_label_types(
To use all labels from 'ner', pass 'ner'
"""
if isinstance(entity_label_types, str):
- entity_label_types = cast(Dict[str, Set[str]], {entity_label_types: {}})
+ entity_label_types = cast(dict[str, set[str]], {entity_label_types: {}})
elif isinstance(entity_label_types, Sequence):
- entity_label_types = cast(Dict[str, Set[str]], {label: {} for label in entity_label_types})
+ entity_label_types = cast(dict[str, set[str]], {label: {} for label in entity_label_types})
entity_label_types = {
label: {normalize_entity_type(e) for e in entity_types}
@@ -728,9 +729,9 @@ def label_type(self):
def dictionary(self) -> EntityLinkingDictionary:
return self._dictionary
- def extract_entities_mentions(self, sentence: Sentence, entity_label_types: Dict[str, Set[str]]) -> List[Label]:
+ def extract_entities_mentions(self, sentence: Sentence, entity_label_types: dict[str, set[str]]) -> list[Label]:
"""Extract tagged mentions from sentences."""
- entities_mentions: List[Label] = []
+ entities_mentions: list[Label] = []
# NOTE: This is a hacky workaround for the fact that
        # the `label_type`s in `Classifier.load('hunflair')` are
@@ -762,10 +763,10 @@ def extract_entities_mentions(self, sentence: Sentence, entity_label_types: Dict
def predict(
self,
- sentences: Union[List[Sentence], Sentence],
+ sentences: Union[list[Sentence], Sentence],
top_k: int = 1,
pred_label_type: Optional[str] = None,
- entity_label_types: Optional[Union[str, Sequence[str], Dict[str, Set[str]]]] = None,
+ entity_label_types: Optional[Union[str, Sequence[str], dict[str, set[str]]]] = None,
batch_size: Optional[int] = None,
) -> None:
"""Predicts the best matching top-k entity / concept identifiers of all named entities annotated with tag input_entity_annotation_layer.
@@ -859,7 +860,7 @@ def _fetch_model(model_name: str) -> str:
return hf_download(model_name)
@classmethod
- def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs) -> "EntityMentionLinker":
+ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs) -> "EntityMentionLinker":
candidate_generator = CandidateSearchIndex._from_state(state["candidate_search_index"])
preprocessor = EntityPreprocessor._from_state(state["entity_preprocessor"])
entity_label_types = state["entity_label_types"]
@@ -961,7 +962,7 @@ def __get_model_path_and_entity_type(
model_name_or_path: str,
entity_type: Optional[str] = None,
hybrid_search: bool = False,
- ) -> Tuple[str, str]:
+ ) -> tuple[str, str]:
"""Try to figure out what model the user wants."""
if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES:
raise ValueError(
@@ -1039,24 +1040,24 @@ def __get_dictionary_path(
return dictionary_name_or_path
- def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]:
+ def forward_loss(self, data_points: list[DT]) -> tuple[torch.Tensor, int]:
raise NotImplementedError("The EntityLinker cannot be trained")
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "EntityMentionLinker":
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "EntityMentionLinker":
from typing import cast
return cast("EntityMentionLinker", super().load(model_path=model_path))
def evaluate(
self,
- data_points: Union[List[Sentence], Dataset],
+ data_points: Union[list[Sentence], Dataset],
gold_label_type: str,
out_path: Optional[Union[str, Path]] = None,
embedding_storage_mode: str = "none",
mini_batch_size: int = 32,
- main_evaluation_metric: Tuple[str, str] = ("accuracy", "f1-score"),
- exclude_labels: Optional[List[str]] = None,
+ main_evaluation_metric: tuple[str, str] = ("accuracy", "f1-score"),
+ exclude_labels: Optional[list[str]] = None,
gold_label_dictionary: Optional[Dictionary] = None,
return_loss: bool = True,
k: int = 1,
diff --git a/flair/models/language_model.py b/flair/models/language_model.py
index ed417f2434..d85db2fb93 100644
--- a/flair/models/language_model.py
+++ b/flair/models/language_model.py
@@ -1,6 +1,6 @@
import math
from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union
import torch
from torch import logsumexp, nn
@@ -111,7 +111,7 @@ def init_hidden(self, bsz):
def get_representation(
self,
- strings: List[str],
+ strings: list[str],
start_marker: str,
end_marker: str,
chars_per_chunk: int = 512,
@@ -119,7 +119,7 @@ def get_representation(
len_longest_str: int = len(max(strings, key=len))
# pad strings with whitespaces to longest sentence
- padded_strings: List[str] = []
+ padded_strings: list[str] = []
for string in strings:
if not self.is_forward_lm:
@@ -141,11 +141,11 @@ def get_representation(
padding_char_index = self.dictionary.get_idx_for_item(" ")
- batches: List[torch.Tensor] = []
+ batches: list[torch.Tensor] = []
# push each chunk through the RNN language model
for chunk in chunks:
len_longest_chunk: int = len(max(chunk, key=len))
- sequences_as_char_indices: List[List[int]] = []
+ sequences_as_char_indices: list[list[int]] = []
for string in chunk:
char_indices = self.dictionary.get_idx_for_items(list(string))
char_indices += [padding_char_index] * (len_longest_chunk - len(string))
@@ -176,7 +176,7 @@ def get_output(self, text: str):
def repackage_hidden(self, h):
"""Wraps hidden states in new Variables, to detach them from their history."""
- if type(h) == torch.Tensor:
+ if isinstance(h, torch.Tensor):
return h.clone().detach()
else:
return tuple(self.repackage_hidden(v) for v in h)
@@ -296,7 +296,7 @@ def generate_text(
number_of_characters: int = 1000,
temperature: float = 1.0,
break_on_suffix=None,
- ) -> Tuple[str, float]:
+ ) -> tuple[str, float]:
if prefix == "":
prefix = "\n"
diff --git a/flair/models/lemmatizer_model.py b/flair/models/lemmatizer_model.py
index 6f0854d4b5..16ec734941 100644
--- a/flair/models/lemmatizer_model.py
+++ b/flair/models/lemmatizer_model.py
@@ -1,6 +1,6 @@
import logging
from math import inf
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union
import torch
from torch import nn
@@ -159,7 +159,7 @@ def label_type(self):
def words_to_char_indices(
self,
- tokens: List[str],
+ tokens: list[str],
end_symbol=True,
start_symbol=False,
padding_in_front=False,
@@ -202,7 +202,7 @@ def words_to_char_indices(
return tensor
- def forward_pass(self, sentences: Union[List[Sentence], Sentence]):
+ def forward_pass(self, sentences: Union[list[Sentence], Sentence]):
if isinstance(sentences, Sentence):
sentences = [sentences]
@@ -247,7 +247,7 @@ def decode(self, decoder_input_indices, initial_hidden_states, all_encoder_outpu
output_vectors = self.character_decoder(output)
return output_vectors, hidden
- def _prepare_tensors(self, sentences: List[Sentence]) -> Tuple[Optional[torch.Tensor], ...]:
+ def _prepare_tensors(self, sentences: list[Sentence]) -> tuple[Optional[torch.Tensor], ...]:
# get all tokens
tokens = [token for sentence in sentences for token in sentence]
@@ -290,7 +290,7 @@ def forward(
encoder_input_indices: Optional[torch.Tensor],
lengths: Optional[torch.Tensor],
token_embedding_hidden: Optional[torch.Tensor],
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
# variable to store initial hidden states for decoder
initial_hidden_for_decoder = []
@@ -340,7 +340,7 @@ def forward(
return initial_hidden, all_encoder_outputs
- def encode(self, sentences: List[Sentence]):
+ def encode(self, sentences: list[Sentence]):
tensors = self._prepare_tensors(sentences)
return self.forward(*tensors)
@@ -396,14 +396,14 @@ def _calculate_loss(self, scores, labels):
return self.loss(scores_in_correct_format, target), len(labels)
- def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]:
+ def forward_loss(self, sentences: Union[list[Sentence], Sentence]) -> tuple[torch.Tensor, int]:
scores, labels = self.forward_pass(sentences)
return self._calculate_loss(scores, labels)
def predict(
self,
- sentences: Union[List[Sentence], Sentence],
+ sentences: Union[list[Sentence], Sentence],
mini_batch_size: int = 16,
return_probabilities_for_all_classes: bool = False,
verbose: bool = False,
@@ -474,7 +474,7 @@ def predict(
# option 1: greedy decoding
if self.beam_size == 1:
# predictions
- predicted: List[List[Union[int, float]]] = [[] for _ in range(number_tokens)]
+ predicted: list[list[Union[int, float]]] = [[] for _ in range(number_tokens)]
for _decode_step in range(max_length):
# decode next character
@@ -491,7 +491,7 @@ def predict(
for t_id, token in enumerate(tokens_in_batch):
predicted_lemma = "".join(
- self.char_dictionary.get_item_for_index(idx) if idx != self.end_index else ""
+ self.char_dictionary.get_item_for_index(int(idx)) if idx != self.end_index else ""
for idx in predicted[t_id]
)
token.set_label(typename=label_name, value=predicted_lemma)
@@ -525,7 +525,7 @@ def predict(
# keep track of how many hypothesis were completed for each token
n_completed = [0 for _ in range(number_tokens)] # cpu
- final_candidates: List[List[Tuple[torch.Tensor, float]]] = [[] for _ in range(number_tokens)] # cpu
+ final_candidates: list[list[tuple[torch.Tensor, float]]] = [[] for _ in range(number_tokens)] # cpu
# if all_encoder_outputs returned, expand them to beam size (otherwise keep this as None)
batched_encoding_output = (
@@ -552,24 +552,24 @@ def predict(
# check if an end symbol has been predicted and, in that case, set hypothesis aside
end_symbols = (index_candidates == self.end_index).nonzero(as_tuple=False)
- for tuple in end_symbols:
+ for row in end_symbols:
# if the sequence is already ended, do not record as candidate
- if sequences[tuple[0], -1].item() == self.end_index:
+ if sequences[row[0], -1].item() == self.end_index:
continue
                    # index of token in list tokens_in_batch
- token_number = torch.div(tuple[0], self.beam_size, rounding_mode="trunc")
+ token_number = torch.div(row[0], self.beam_size, rounding_mode="trunc")
# print(token_number)
- seq = sequences[tuple[0], :] # hypothesis sequence
+ seq = sequences[row[0], :] # hypothesis sequence
# hypothesis score
- score = (scores[tuple[0]] + log_probabilities[tuple[0], tuple[1]]) / (len(seq) + 1)
+ score = (scores[row[0]] + log_probabilities[row[0], row[1]]) / (len(seq) + 1)
final_candidates[token_number].append((seq, score.item()))
# TODO: remove token if number of completed hypothesis exceeds given value
n_completed[token_number] += 1
# set score of corresponding entry to -inf so it will not be expanded
- log_probabilities[tuple[0], tuple[1]] = -inf
+ log_probabilities[row[0], row[1]] = -inf
# get leading_indices for next expansion
# find highest scoring hypothesis among beam_size*beam_size possible ones for each token
@@ -594,8 +594,8 @@ def predict(
# a list of length beam_size*batch_size
                    # where the first three indices belong to the first token, the next three to the second token,
# and so on
- beam_numbers: List[int] = []
- seq_numbers: List[int] = []
+ beam_numbers: list[int] = []
+ seq_numbers: list[int] = []
for i, row in enumerate(indices_per_token):
beam_numbers.extend(i * self.beam_size + index.item() // self.beam_size for index in row)
diff --git a/flair/models/multitask_model.py b/flair/models/multitask_model.py
index 733751eff7..414eb46197 100644
--- a/flair/models/multitask_model.py
+++ b/flair/models/multitask_model.py
@@ -2,7 +2,7 @@
import random
import typing
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
import torch
@@ -27,9 +27,9 @@ class MultitaskModel(flair.nn.Classifier):
def __init__(
self,
- models: List[flair.nn.Classifier],
- task_ids: Optional[List[str]] = None,
- loss_factors: Optional[List[float]] = None,
+ models: list[flair.nn.Classifier],
+ task_ids: Optional[list[str]] = None,
+ loss_factors: Optional[list[float]] = None,
use_all_tasks: bool = False,
) -> None:
"""Instantiates the MultiTaskModel.
@@ -42,10 +42,10 @@ def __init__(
"""
super().__init__()
- task_ids_internal: List[str] = task_ids if task_ids else [f"Task_{i}" for i in range(len(models))]
+ task_ids_internal: list[str] = task_ids if task_ids else [f"Task_{i}" for i in range(len(models))]
- self.tasks: Dict[str, flair.nn.Classifier] = {}
- self.loss_factors: Dict[str, float] = {}
+ self.tasks: dict[str, flair.nn.Classifier] = {}
+ self.loss_factors: dict[str, float] = {}
self.use_all_tasks = use_all_tasks
if not loss_factors:
@@ -63,10 +63,10 @@ def __init__(
def forward(self, *args) -> torch.Tensor:
raise NotImplementedError("`forward` is not used for multitask learning")
- def _prepare_tensors(self, data_points: List[DT]) -> Tuple[torch.Tensor, ...]:
+ def _prepare_tensors(self, data_points: list[DT]) -> tuple[torch.Tensor, ...]:
raise NotImplementedError("`_prepare_tensors` is not used for multitask learning")
- def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]:
+ def forward_loss(self, sentences: Union[list[Sentence], Sentence]) -> tuple[torch.Tensor, int]:
"""Calls the respective forward loss of each model and sums them weighted by their loss factors.
Args:
@@ -92,7 +92,9 @@ def predict(
task.predict(sentences, **predictargs)
@staticmethod
- def split_batch_to_task_ids(sentences: Union[List[Sentence], Sentence], all_tasks: bool = False) -> Dict:
+ def split_batch_to_task_ids(
+ sentences: Union[list[Sentence], Sentence], all_tasks: bool = False
+ ) -> dict[str, list[int]]:
"""Splits a batch of sentences to its respective model.
        If a single sentence is assigned to several tasks (i.e. same corpus but different tasks), then the model
@@ -104,7 +106,7 @@ def split_batch_to_task_ids(sentences: Union[List[Sentence], Sentence], all_task
        Returns: Key-value pairs as (task_id, list of sentence ids in batch)
"""
- batch_to_task_mapping: Dict[str, List[int]] = {}
+ batch_to_task_mapping: dict[str, list[int]] = {}
for sentence_id, sentence in enumerate(sentences):
if all_tasks:
multitask_ids = sentence.get_labels("multitask_id")
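A minimal sketch of the routing behavior (labels are assigned by hand here; in training they come from the multitask corpus):

```python
from flair.data import Sentence
from flair.models import MultitaskModel

sentences = [Sentence("First sentence ."), Sentence("Second sentence .")]
sentences[0].add_label("multitask_id", "Task_0")
sentences[1].add_label("multitask_id", "Task_1")

print(MultitaskModel.split_batch_to_task_ids(sentences))
# expected: {"Task_0": [0], "Task_1": [1]}
```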
@@ -122,7 +124,7 @@ def evaluate( # type: ignore[override]
data_points,
gold_label_type: str,
out_path: Optional[Union[str, Path]] = None,
- main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
+ main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"),
evaluate_all: bool = True,
**evalargs,
) -> Result:
@@ -161,7 +163,7 @@ def evaluate( # type: ignore[override]
loss = torch.tensor(0.0, device=flair.device)
main_score = 0.0
all_detailed_results = ""
- all_classification_report: Dict[str, Dict[str, Any]] = {}
+ all_classification_report: dict[str, dict[str, Any]] = {}
for task_id, split in batch_split.items():
result = self.tasks[task_id].evaluate(
@@ -203,7 +205,7 @@ def evaluate( # type: ignore[override]
def get_used_tokens(
self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True
- ) -> typing.Iterable[List[str]]:
+ ) -> typing.Iterable[list[str]]:
for model in self.tasks.values():
yield from model.get_used_tokens(corpus, context_length, respect_document_boundaries)
@@ -272,7 +274,7 @@ def _fetch_model(model_name) -> str:
return model_name
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "MultitaskModel":
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "MultitaskModel":
from typing import cast
return cast("MultitaskModel", super().load(model_path=model_path))
diff --git a/flair/models/pairwise_classification_model.py b/flair/models/pairwise_classification_model.py
index 262fd08cb5..6308573edc 100644
--- a/flair/models/pairwise_classification_model.py
+++ b/flair/models/pairwise_classification_model.py
@@ -1,5 +1,4 @@
import typing
-from typing import List
import torch
@@ -69,7 +68,7 @@ def __init__(
def label_type(self):
return self._label_type
- def _get_data_points_from_sentence(self, sentence: TextPair) -> List[TextPair]:
+ def _get_data_points_from_sentence(self, sentence: TextPair) -> list[TextPair]:
return [sentence]
def _get_embedding_for_data_point(self, prediction_data_point: TextPair) -> torch.Tensor:
@@ -119,7 +118,7 @@ def _init_model_with_state_dict(cls, state, **kwargs):
def get_used_tokens(
self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True
- ) -> typing.Iterable[List[str]]:
+ ) -> typing.Iterable[list[str]]:
for sentence_pair in _iter_dataset(corpus.get_all_sentences()):
yield [t.text for t in sentence_pair.first]
yield [t.text for t in sentence_pair.first.left_context(context_length, respect_document_boundaries)]
diff --git a/flair/models/pairwise_regression_model.py b/flair/models/pairwise_regression_model.py
index c3f34e0f69..9a1c2704be 100644
--- a/flair/models/pairwise_regression_model.py
+++ b/flair/models/pairwise_regression_model.py
@@ -1,5 +1,6 @@
+from collections.abc import Iterable
from pathlib import Path
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
import torch
from torch import nn
@@ -90,7 +91,7 @@ def label_type(self):
def get_used_tokens(
self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True
- ) -> Iterable[List[str]]:
+ ) -> Iterable[list[str]]:
for sentence_pair in _iter_dataset(corpus.get_all_sentences()):
yield [t.text for t in sentence_pair.first]
yield [t.text for t in sentence_pair.first.left_context(context_length, respect_document_boundaries)]
@@ -99,14 +100,14 @@ def get_used_tokens(
yield [t.text for t in sentence_pair.second.left_context(context_length, respect_document_boundaries)]
yield [t.text for t in sentence_pair.second.right_context(context_length, respect_document_boundaries)]
- def forward_loss(self, pairs: List[TextPair]) -> Tuple[torch.Tensor, int]:
+ def forward_loss(self, pairs: list[TextPair]) -> tuple[torch.Tensor, int]:
loss, num = self._forward_loss_and_scores(pairs=pairs, return_num=True, return_scores=False)
assert isinstance(loss, torch.Tensor)
assert isinstance(num, int)
return loss, num
- def _forward_loss_and_scores(self, pairs: List[TextPair], return_num=True, return_scores=True) -> Tuple:
+ def _forward_loss_and_scores(self, pairs: list[TextPair], return_num=True, return_scores=True) -> tuple:
# make a forward pass to produce embedded data points and labels
pairs = [pair for pair in pairs if self._filter_data_point(pair)]
@@ -128,7 +129,7 @@ def _forward_loss_and_scores(self, pairs: List[TextPair], return_num=True, retur
# calculate the loss
loss, num = self._calculate_loss(scores, target_tensor)
- return_value: Tuple[Any, ...] = (loss,)
+ return_value: tuple[Any, ...] = (loss,)
if return_num:
return_value += (num,)
@@ -138,10 +139,10 @@ def _forward_loss_and_scores(self, pairs: List[TextPair], return_num=True, retur
return return_value
- def _calculate_loss(self, scores: torch.Tensor, target_tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
+ def _calculate_loss(self, scores: torch.Tensor, target_tensor: torch.Tensor) -> tuple[torch.Tensor, int]:
return self.loss_function(scores, target_tensor), target_tensor.size(0)
- def _prepare_target_tensor(self, pairs: List[TextPair]):
+ def _prepare_target_tensor(self, pairs: list[TextPair]):
target_values = [
torch.tensor([float(label.value) for label in pair.get_labels(self.label_name)], dtype=torch.float)
for pair in pairs
@@ -152,7 +153,7 @@ def _prepare_target_tensor(self, pairs: List[TextPair]):
def _filter_data_point(self, pair: TextPair) -> bool:
return len(pair) > 0
- def _encode_data_points(self, data_points: List[TextPair]) -> torch.Tensor:
+ def _encode_data_points(self, data_points: list[TextPair]) -> torch.Tensor:
# get a tensor of data points
data_point_tensor = torch.stack([self._get_embedding_for_data_point(data_point) for data_point in data_points])
@@ -203,7 +204,7 @@ def _get_state_dict(self):
return model_state
@classmethod
- def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs):
+ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
"""Initializes a TextPairRegressor model from a state dictionary (exported by _get_state_dict).
Requires keys 'state_dict', 'document_embeddings', and 'label_type' in the state dictionary.
@@ -227,12 +228,12 @@ def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs):
def predict(
self,
- pairs: Union[TextPair, List[TextPair]],
+ pairs: Union[TextPair, list[TextPair]],
mini_batch_size: int = 32,
verbose: bool = False,
label_name: Optional[str] = None,
embedding_storage_mode="none",
- ) -> List[TextPair]:
+ ) -> list[TextPair]:
if label_name is None:
label_name = self.label_name if self.label_name is not None else "label"
@@ -278,13 +279,13 @@ def predict(
def evaluate(
self,
- data_points: Union[List[TextPair], Dataset],
+ data_points: Union[list[TextPair], Dataset],
gold_label_type: str,
out_path: Union[str, Path, None] = None,
embedding_storage_mode: EmbeddingStorageMode = "none",
mini_batch_size: int = 32,
- main_evaluation_metric: Tuple[str, str] = ("correlation", "pearson"),
- exclude_labels: Optional[List[str]] = None,
+ main_evaluation_metric: tuple[str, str] = ("correlation", "pearson"),
+ exclude_labels: Optional[list[str]] = None,
gold_label_dictionary: Optional[Dictionary] = None,
return_loss: bool = True,
**kwargs,
diff --git a/flair/models/prefixed_tagger.py b/flair/models/prefixed_tagger.py
index 05d8fa8c34..b001653bdc 100644
--- a/flair/models/prefixed_tagger.py
+++ b/flair/models/prefixed_tagger.py
@@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union, cast
+from typing import Any, Optional, Union, cast
import torch
from torch.utils.data import Dataset
@@ -26,7 +26,7 @@ class SentenceAugmentationStrategy(ABC):
@abstractmethod
def augment_sentence(
- self, sentence: Sentence, annotation_layers: Optional[Union[str, List[str]]] = None
+ self, sentence: Sentence, annotation_layers: Optional[Union[str, list[str]]] = None
) -> PrefixedSentence:
"""Augments the given sentence text with additional instructions for working / predicting the task on the given annotations.
@@ -64,7 +64,7 @@ def _init_strategy_with_state_dict(cls, state, **kwargs):
"""Initializes the strategy from the given state."""
def augment_dataset(
- self, dataset: Dataset[Sentence], annotation_layers: Optional[Union[str, List[str]]] = None
+ self, dataset: Dataset[Sentence], annotation_layers: Optional[Union[str, list[str]]] = None
) -> FlairDatapointDataset[PrefixedSentence]:
"""Transforms a dataset into a dataset containing augmented sentences specific to the `PrefixedSequenceTagger`.
@@ -78,14 +78,14 @@ def augment_dataset(
Returns: A dataset of augmented sentences specific to the `PrefixedSequenceTagger`
"""
data_loader: DataLoader = DataLoader(dataset, batch_size=1)
- original_sentences: List[Sentence] = [batch[0] for batch in iter(data_loader)]
+ original_sentences: list[Sentence] = [batch[0] for batch in iter(data_loader)]
augmented_sentences = [self.augment_sentence(sentence, annotation_layers) for sentence in original_sentences]
return FlairDatapointDataset(augmented_sentences)
def augment_corpus(
- self, corpus: Corpus[Sentence], annotation_layers: Optional[Union[str, List[str]]] = None
+ self, corpus: Corpus[Sentence], annotation_layers: Optional[Union[str, list[str]]] = None
) -> Corpus[PrefixedSentence]:
"""Transforms a corpus into a corpus containing augmented sentences specific to the `PrefixedSequenceTagger`.
@@ -120,7 +120,7 @@ class EntityTypeTaskPromptAugmentationStrategy(SentenceAugmentationStrategy):
"[Tag gene and disease] Mutations in the TP53 tumour suppressor gene are found in ~50% of human cancers"
"""
- def __init__(self, entity_types: List[str]):
+ def __init__(self, entity_types: list[str]):
if len(entity_types) <= 0:
raise AssertionError
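A short sketch of the augmentation itself (the sentence text reuses the docstring example above):

```python
from flair.data import Sentence
from flair.models.prefixed_tagger import EntityTypeTaskPromptAugmentationStrategy

strategy = EntityTypeTaskPromptAugmentationStrategy(["gene", "disease"])
augmented = strategy.augment_sentence(
    Sentence("Mutations in the TP53 tumour suppressor gene are found in ~50% of human cancers")
)
# augmented is a PrefixedSentence whose text starts with the task prompt,
# e.g. "[Tag gene and disease] Mutations in the TP53 ..."
print(augmented)
```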
@@ -128,7 +128,7 @@ def __init__(self, entity_types: List[str]):
self.task_prompt = self._build_tag_prompt_prefix(entity_types)
def augment_sentence(
- self, sentence: Sentence, annotation_layers: Optional[Union[str, List[str]]] = None
+ self, sentence: Sentence, annotation_layers: Optional[Union[str, list[str]]] = None
) -> PrefixedSentence:
# Prepend the task description prompt to the sentence text
augmented_sentence = PrefixedSentence(
@@ -182,7 +182,7 @@ def apply_predictions(
]
orig_span.add_label(target_annotation_layer, label.value, label.score)
- def _build_tag_prompt_prefix(self, entity_types: List[str]) -> List[str]:
+ def _build_tag_prompt_prefix(self, entity_types: list[str]) -> list[str]:
if len(self.entity_types) == 1:
prompt = f"[ Tag {entity_types[0]} ]"
else:
@@ -219,29 +219,29 @@ def _init_model_with_state_dict(cls, state, **kwargs):
return super()._init_model_with_state_dict(state, augmentation_strategy=strategy, **kwargs)
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "PrefixedSequenceTagger":
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "PrefixedSequenceTagger":
from typing import cast
return cast("PrefixedSequenceTagger", super().load(model_path=model_path))
- def forward_loss(self, sentences: Union[List[Sentence], List[PrefixedSentence]]) -> Tuple[torch.Tensor, int]:
+ def forward_loss(self, sentences: Union[list[Sentence], list[PrefixedSentence]]) -> tuple[torch.Tensor, int]:
# If all sentences are not augmented -> augment them
if all(isinstance(sentence, Sentence) for sentence in sentences):
# mypy does not infer the type of "sentences" restricted by the if statement
- sentences = cast(List[Sentence], sentences)
+ sentences = cast(list[Sentence], sentences)
sentences = self.augment_sentences(sentences=sentences, annotation_layers=self.tag_type)
elif not all(isinstance(sentence, PrefixedSentence) for sentence in sentences):
raise ValueError("All passed sentences must be either uniformly augmented or not.")
# mypy does not infer the type of "sentences" restricted by code above
- sentences = cast(List[Sentence], sentences)
+ sentences = cast(list[Sentence], sentences)
return super().forward_loss(sentences)
def predict(
self,
- sentences: Union[List[Sentence], Sentence, List[PrefixedSentence], PrefixedSentence],
+ sentences: Union[list[Sentence], Sentence, list[PrefixedSentence], PrefixedSentence],
mini_batch_size: int = 32,
return_probabilities_for_all_classes: bool = False,
verbose: bool = False,
@@ -260,7 +260,7 @@ def predict(
# If all sentences are already augmented (i.e. compatible with this class), just forward the sentences
if all(isinstance(sentence, PrefixedSentence) for sentence in sentences):
# mypy does not infer the type of "sentences" restricted by the if statement
- sentences = cast(List[Sentence], sentences)
+ sentences = cast(list[Sentence], sentences)
return super().predict(
sentences,
@@ -280,12 +280,12 @@ def predict(
for sentence in sentences:
sentence.remove_labels(prediction_label_type)
- sentences = cast(List[Sentence], sentences)
+ sentences = cast(list[Sentence], sentences)
# Augment sentences - copy all annotation of the given tag type
augmented_sentences = self.augment_sentences(sentences, self.tag_type)
- mypy_safe_augmented_sentences = cast(List[Sentence], augmented_sentences)
+ mypy_safe_augmented_sentences = cast(list[Sentence], augmented_sentences)
# Predict on augmented sentence and store it in an internal annotation layer / label
loss_and_count = super().predict(
@@ -312,8 +312,8 @@ def predict(
return loss_and_count
def augment_sentences(
- self, sentences: Union[Sentence, List[Sentence]], annotation_layers: Optional[Union[str, List[str]]] = None
- ) -> List[PrefixedSentence]:
+ self, sentences: Union[Sentence, list[Sentence]], annotation_layers: Optional[Union[str, list[str]]] = None
+ ) -> list[PrefixedSentence]:
if not isinstance(sentences, list) and not isinstance(sentences, flair.data.Dataset):
sentences = [sentences]
diff --git a/flair/models/regexp_tagger.py b/flair/models/regexp_tagger.py
index 35c244d960..e41981c899 100644
--- a/flair/models/regexp_tagger.py
+++ b/flair/models/regexp_tagger.py
@@ -1,7 +1,7 @@
import re
import typing
from dataclasses import dataclass, field
-from typing import Dict, List, Tuple, Union
+from typing import Union
from flair.data import Sentence, Span, Token
@@ -15,8 +15,8 @@ class TokenCollection:
"""
sentence: Sentence
- __tokens_start_pos: List[int] = field(init=False, default_factory=list)
- __tokens_end_pos: List[int] = field(init=False, default_factory=list)
+ __tokens_start_pos: list[int] = field(init=False, default_factory=list)
+ __tokens_end_pos: list[int] = field(init=False, default_factory=list)
def __post_init__(self):
for token in self.tokens:
@@ -24,10 +24,10 @@ def __post_init__(self):
self.__tokens_end_pos.append(token.end_position)
@property
- def tokens(self) -> List[Token]:
+ def tokens(self) -> list[Token]:
return list(self.sentence)
- def get_token_span(self, span: Tuple[int, int]) -> Span:
+ def get_token_span(self, span: tuple[int, int]) -> Span:
"""Find a span by the token character positions.
Given an interval specified with start and end pos as tuple, this function returns a Span object
@@ -45,7 +45,7 @@ def get_token_span(self, span: Tuple[int, int]) -> Span:
class RegexpTagger:
- def __init__(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]) -> None:
+ def __init__(self, mapping: Union[list[tuple[str, str]], tuple[str, str]]) -> None:
r"""This tagger is capable of tagging sentence objects with given regexp -> label mappings.
I.e: The tuple (r'(["\'])(?:(?=(\\?))\2.)*?\1', 'QUOTE') maps every match of the regexp to
@@ -58,14 +58,14 @@ def __init__(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]) -> No
Args:
mapping: A list of tuples or a single tuple representing a mapping as regexp -> label
"""
- self._regexp_mapping: Dict[str, typing.Pattern] = {}
+ self._regexp_mapping: dict[str, typing.Pattern] = {}
self.register_labels(mapping=mapping)
@property
def registered_labels(self):
return self._regexp_mapping
- def register_labels(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]):
+ def register_labels(self, mapping: Union[list[tuple[str, str]], tuple[str, str]]):
"""Register a regexp -> label mapping.
Args:
@@ -81,7 +81,7 @@ def register_labels(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]
f"Couldn't compile regexp '{regexp}' for label '{label}'. Aborted with error: '{err.msg}'"
)
- def remove_labels(self, labels: Union[List[str], str]):
+ def remove_labels(self, labels: Union[list[str], str]):
"""Remove a registered regexp -> label mapping given by label.
Args:
@@ -101,7 +101,7 @@ def _listify(element: object) -> list:
else:
return element
- def predict(self, sentences: Union[List[Sentence], Sentence]) -> List[Sentence]:
+ def predict(self, sentences: Union[list[Sentence], Sentence]) -> list[Sentence]:
"""Predict the given sentences according to the registered mappings."""
if not isinstance(sentences, list):
sentences = [sentences]
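A quick usage sketch mirroring the QUOTE mapping from the class docstring (that matched spans are retrievable under the registered label name is an assumption):

```python
from flair.data import Sentence
from flair.models.regexp_tagger import RegexpTagger

tagger = RegexpTagger(mapping=(r'(["\'])(?:(?=(\\?))\2.)*?\1', "QUOTE"))

sentence = Sentence('Someone said "hello world" and left .')
tagger.predict(sentence)
print(sentence.get_labels("QUOTE"))  # assumed label layer name; spans covering the quoted text
```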
@@ -122,7 +122,7 @@ def _label(self, sentence: Sentence):
for label, pattern in self._regexp_mapping.items():
for match in pattern.finditer(sentence.to_original_text()):
- span: Tuple[int, int] = match.span()
+ span: tuple[int, int] = match.span()
try:
token_span = collection.get_token_span(span)
except ValueError:
diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py
index 53ccabac36..9c6c69577f 100644
--- a/flair/models/relation_classifier_model.py
+++ b/flair/models/relation_classifier_model.py
@@ -2,17 +2,12 @@
import logging
import typing
from abc import ABC, abstractmethod
+from collections.abc import Iterator, Sequence
from pathlib import Path
from typing import (
Any,
- Dict,
- Iterator,
- List,
NamedTuple,
Optional,
- Sequence,
- Set,
- Tuple,
Union,
cast,
)
@@ -50,7 +45,7 @@ class EncodedSentence(Sentence):
class EncodingStrategy(ABC):
"""The encoding of the head and tail entities in a sentence with a relation annotation."""
- special_tokens: Set[str] = set()
+ special_tokens: set[str] = set()
def __init__(self, add_special_tokens: bool = False) -> None:
self.add_special_tokens = add_special_tokens
@@ -84,7 +79,7 @@ class EntityMask(EncodingStrategy):
- "Larry Page and [TAIL] founded [HEAD]" -> Relation(head='Google', tail='Sergey Brin').
"""
- special_tokens: Set[str] = {"[HEAD]", "[TAIL]"}
+ special_tokens: set[str] = {"[HEAD]", "[TAIL]"}
def encode_head(self, head_span: Span, label: Label) -> str:
return "[HEAD]"
@@ -126,7 +121,7 @@ class EntityMarker(EncodingStrategy):
-> Relation(head='Google', tail='Sergey Brin').
"""
- special_tokens: Set[str] = {"[HEAD]", "[/HEAD]", "[TAIL]", "[/TAIL]"}
+ special_tokens: set[str] = {"[HEAD]", "[/HEAD]", "[TAIL]", "[/TAIL]"}
def encode_head(self, head: Span, label: Label) -> str:
space_tokenized_text: str = " ".join(token.text for token in head)
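For orientation, the built-in strategies side by side (behavior paraphrased from the docstrings; the exact marker text is an assumption):

```python
from flair.models.relation_classifier_model import EntityMarker, EntityMask, TypedEntityMarker

EntityMask()         # masks the spans:  "Larry Page and [TAIL] founded [HEAD]"
EntityMarker()       # keeps the text:   "Larry Page and [TAIL] Sergey Brin [/TAIL] founded [HEAD] Google [/HEAD]"
TypedEntityMarker()  # like EntityMarker, but the markers also carry the entity type
```

Any of these can be passed as `encoding_strategy=` when constructing the `RelationClassifier`; strategies that define special tokens also extend the tokenizer (see the `resize_token_embeddings` hunk further below).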
@@ -254,8 +249,8 @@ def __init__(
embeddings: DocumentEmbeddings,
label_dictionary: Dictionary,
label_type: str,
- entity_label_types: Union[str, Sequence[str], Dict[str, Optional[Set[str]]]],
- entity_pair_labels: Optional[Set[Tuple[str, str]]] = None,
+ entity_label_types: Union[str, Sequence[str], dict[str, Optional[set[str]]]],
+ entity_pair_labels: Optional[set[tuple[str, str]]] = None,
entity_threshold: Optional[float] = None,
cross_augmentation: bool = True,
encoding_strategy: EncodingStrategy = TypedEntityMarker(),
@@ -298,7 +293,7 @@ def __init__(
)
if isinstance(entity_label_types, str):
- self.entity_label_types: Dict[str, Optional[Set[str]]] = {entity_label_types: None}
+ self.entity_label_types: dict[str, Optional[set[str]]] = {entity_label_types: None}
elif isinstance(entity_label_types, Sequence):
self.entity_label_types = {entity_label_type: None for entity_label_type in entity_label_types}
else:
@@ -316,7 +311,7 @@ def __init__(
and self.encoding_strategy.special_tokens
and isinstance(self.embeddings, TransformerDocumentEmbeddings)
):
- special_tokens: List[str] = list(self.encoding_strategy.special_tokens)
+ special_tokens: list[str] = list(self.encoding_strategy.special_tokens)
tokenizer = self.embeddings.tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
self.embeddings.model.resize_token_embeddings(len(tokenizer))
@@ -355,7 +350,7 @@ def _valid_entities(self, sentence: Sentence) -> Iterator[_Entity]:
def _entity_pair_permutations(
self,
sentence: Sentence,
- ) -> Iterator[Tuple[_Entity, _Entity, Optional[str]]]:
+ ) -> Iterator[tuple[_Entity, _Entity, Optional[str]]]:
"""Yields all valid entity pair permutations (relation candidates).
If the passed sentence contains relation annotations,
@@ -370,10 +365,10 @@ def _entity_pair_permutations(
Yields:
Tuples of (HEAD, TAIL, gold_label): The head and tail `_Entity` objects have span references to the passed sentence.
"""
- valid_entities: List[_Entity] = list(self._valid_entities(sentence))
+ valid_entities: list[_Entity] = list(self._valid_entities(sentence))
# Use a dictionary to find gold relation annotations for a given entity pair
- relation_to_gold_label: Dict[str, str] = {
+ relation_to_gold_label: dict[str, str] = {
relation.unlabeled_identifier: relation.get_label(self.label_type, zero_tag_value=self.zero_tag_value).value
for relation in sentence.get_relations(self.label_type)
}
@@ -420,13 +415,13 @@ def _encode_sentence(
assert original_sentence is tail.span.sentence, "The head and tail need to come from the same sentence."
# Pre-compute non-leading head and tail tokens for entity masking
- non_leading_head_tokens: List[Token] = head.span.tokens[1:]
- non_leading_tail_tokens: List[Token] = tail.span.tokens[1:]
+ non_leading_head_tokens: list[Token] = head.span.tokens[1:]
+ non_leading_tail_tokens: list[Token] = tail.span.tokens[1:]
# We can not use the plaintext of the head/tail span in the sentence as the mask/marker
# since there may be multiple occurrences of the same entity mentioned in the sentence.
# Therefore, we use the span's position in the sentence.
- encoded_sentence_tokens: List[str] = []
+ encoded_sentence_tokens: list[str] = []
for token in original_sentence:
if token is head.span[0]:
encoded_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label))
@@ -456,7 +451,7 @@ def _encode_sentence(
def _encode_sentence_for_inference(
self,
sentence: Sentence,
- ) -> Iterator[Tuple[EncodedSentence, Relation]]:
+ ) -> Iterator[tuple[EncodedSentence, Relation]]:
"""Create Encoded Sentences and Relation pairs for Inference.
Yields encoded sentences annotated with their gold relation and
@@ -505,7 +500,7 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedS
yield masked_sentence
- def transform_sentence(self, sentences: Union[Sentence, List[Sentence]]) -> List[EncodedSentence]:
+ def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list[EncodedSentence]:
"""Transforms sentences into encoded sentences specific to the `RelationClassifier`.
For more information on the internal sentence transformation procedure,
@@ -541,7 +536,7 @@ def transform_dataset(self, dataset: Dataset[Sentence]) -> FlairDatapointDataset
Returns: A dataset of encoded sentences specific to the `RelationClassifier`
"""
data_loader: DataLoader = DataLoader(dataset, batch_size=1)
- original_sentences: List[Sentence] = [batch[0] for batch in iter(data_loader)]
+ original_sentences: list[Sentence] = [batch[0] for batch in iter(data_loader)]
return FlairDatapointDataset(self.transform_sentence(original_sentences))
def transform_corpus(self, corpus: Corpus[Sentence]) -> Corpus[EncodedSentence]:
@@ -568,10 +563,10 @@ def transform_corpus(self, corpus: Corpus[Sentence]) -> Corpus[EncodedSentence]:
)
def _get_embedding_for_data_point(self, prediction_data_point: EncodedSentence) -> torch.Tensor:
- embedding_names: List[str] = self.embeddings.get_names()
+ embedding_names: list[str] = self.embeddings.get_names()
return prediction_data_point.get_embedding(embedding_names)
- def _get_data_points_from_sentence(self, sentence: EncodedSentence) -> List[EncodedSentence]:
+ def _get_data_points_from_sentence(self, sentence: EncodedSentence) -> list[EncodedSentence]:
"""Returns the encoded sentences to which labels are added.
To encode sentences, use the `transform` function of the `RelationClassifier`.
@@ -597,14 +592,14 @@ def _get_data_points_from_sentence(self, sentence: EncodedSentence) -> List[Enco
def predict(
self,
- sentences: Union[List[Sentence], List[EncodedSentence], Sentence, EncodedSentence],
+ sentences: Union[list[Sentence], list[EncodedSentence], Sentence, EncodedSentence],
mini_batch_size: int = 32,
return_probabilities_for_all_classes: bool = False,
verbose: bool = False,
label_name: Optional[str] = None,
return_loss: bool = False,
embedding_storage_mode: EmbeddingStorageMode = "none",
- ) -> Optional[Tuple[torch.Tensor, int]]:
+ ) -> Optional[tuple[torch.Tensor, int]]:
"""Predicts the class labels for the given sentence(s).
Standard `Sentence` objects and `EncodedSentences` specific to the `RelationClassifier` are allowed as input.
@@ -626,14 +621,14 @@ def predict(
if not isinstance(sentences, list):
sentences = [sentences]
- loss: Optional[Tuple[torch.Tensor, int]]
- encoded_sentences: List[EncodedSentence]
+ loss: Optional[tuple[torch.Tensor, int]]
+ encoded_sentences: list[EncodedSentence]
if all(isinstance(sentence, EncodedSentence) for sentence in sentences):
# Deal with the case where all sentences are encoded sentences
# mypy does not infer the type of "sentences" restricted by the if statement
- encoded_sentences = cast(List[EncodedSentence], sentences)
+ encoded_sentences = cast(list[EncodedSentence], sentences)
loss = super().predict(
encoded_sentences,
mini_batch_size=mini_batch_size,
@@ -646,8 +641,8 @@ def predict(
elif all(not isinstance(sentence, EncodedSentence) for sentence in sentences):
# Deal with the case where all sentences are standard (non-encoded) sentences
- Sentence.set_context_for_sentences(cast(List[Sentence], sentences))
- sentences_with_relation_reference: List[Tuple[EncodedSentence, Relation]] = list(
+ Sentence.set_context_for_sentences(cast(list[Sentence], sentences))
+ sentences_with_relation_reference: list[tuple[EncodedSentence, Relation]] = list(
itertools.chain.from_iterable(self._encode_sentence_for_inference(sentence) for sentence in sentences)
)
@@ -672,8 +667,8 @@ def predict(
return loss if return_loss else None
- def _get_state_dict(self) -> Dict[str, Any]:
- model_state: Dict[str, Any] = {
+ def _get_state_dict(self) -> dict[str, Any]:
+ model_state: dict[str, Any] = {
**super()._get_state_dict(),
"embeddings": self.embeddings.save_embeddings(use_state_dict=False),
"label_dictionary": self.label_dictionary,
@@ -689,7 +684,7 @@ def _get_state_dict(self) -> Dict[str, Any]:
return model_state
@classmethod
- def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs):
+ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
return super()._init_model_with_state_dict(
state,
embeddings=state["embeddings"],
@@ -719,7 +714,7 @@ def allow_unk_tag(self) -> bool:
def get_used_tokens(
self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True
- ) -> typing.Iterable[List[str]]:
+ ) -> typing.Iterable[list[str]]:
yield from super().get_used_tokens(corpus, context_length, respect_document_boundaries)
for sentence in _iter_dataset(corpus.get_all_sentences()):
for span in sentence.get_spans(self.label_type):
@@ -727,7 +722,7 @@ def get_used_tokens(
yield self.encoding_strategy.encode_tail(span, span.get_label(self.label_type)).split(" ")
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "RelationClassifier":
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "RelationClassifier":
from typing import cast
return cast("RelationClassifier", super().load(model_path=model_path))
diff --git a/flair/models/relation_extractor_model.py b/flair/models/relation_extractor_model.py
index 795e8a517f..0c56abf5bd 100644
--- a/flair/models/relation_extractor_model.py
+++ b/flair/models/relation_extractor_model.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Optional, Union
import torch
@@ -18,7 +18,7 @@ def __init__(
embeddings: flair.embeddings.TokenEmbeddings,
label_type: str,
entity_label_type: str,
- entity_pair_filters: Optional[List[Tuple[str, str]]] = None,
+ entity_pair_filters: Optional[list[tuple[str, str]]] = None,
pooling_operation: str = "first_last",
train_on_gold_pairs_only: bool = False,
**classifierargs,
@@ -56,13 +56,13 @@ def __init__(
# whether to use gold entity pairs, and whether to filter entity pairs by type
if entity_pair_filters is not None:
- self.entity_pair_filters: Optional[Set[Tuple[str, str]]] = set(entity_pair_filters)
+ self.entity_pair_filters: Optional[set[tuple[str, str]]] = set(entity_pair_filters)
else:
self.entity_pair_filters = None
self.to(flair.device)
- def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Relation]:
+ def _get_data_points_from_sentence(self, sentence: Sentence) -> list[Relation]:
entity_pairs = []
entity_spans = sentence.get_spans(self.entity_label_type)
@@ -172,7 +172,7 @@ def _fetch_model(model_name) -> str:
return model_name
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "RelationExtractor":
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "RelationExtractor":
from typing import cast
return cast("RelationExtractor", super().load(model_path=model_path))
diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py
index 16f20a0ddf..7bc9ec051b 100644
--- a/flair/models/sequence_tagger_model.py
+++ b/flair/models/sequence_tagger_model.py
@@ -1,7 +1,7 @@
import logging
import tempfile
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union, cast
+from typing import Any, Optional, Union, cast
import torch
import torch.nn
@@ -22,6 +22,17 @@
class SequenceTagger(flair.nn.Classifier[Sentence]):
+ """The SequenceTagger is one of two main architectures in Flair used for sequence tagging.
+
+ Sequence tagging means classifying words in a sentence, for instance for part-of-speech tagging or named entity
+ recognition. The SequenceTagger implements the "classic" model based on the LSTM-CRF architecture: words are first
+ embedded using one or multiple :class:`flair.embeddings.TokenEmbeddings`; these embeddings are then passed to the
+ LSTM, whose hidden states for each input word are used to make the final prediction with a softmax classifier.
+ For decoding, the SequenceTagger by default uses a CRF approach.
+
+ Alternatively, you can use the class :class:`flair.models.TokenClassifier` for sequence tagging without an LSTM-CRF.
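+
+ Example (a minimal usage sketch; the model name "flair/ner-english" and the example sentence are illustrative)::
+
+     from flair.data import Sentence
+     from flair.models import SequenceTagger
+
+     # load a pre-trained sequence tagger and predict NER tags for one sentence
+     tagger = SequenceTagger.load("flair/ner-english")
+     sentence = Sentence("George Washington went to Washington.")
+     tagger.predict(sentence)
+
+     # predictions are added as labels to the sentence
+     for label in sentence.get_labels():
+         print(label)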
+ """
+
def __init__(
self,
embeddings: TokenEmbeddings,
@@ -40,13 +51,11 @@ def __init__(
word_dropout: float = 0.05,
locked_dropout: float = 0.5,
train_initial_hidden_state: bool = False,
- loss_weights: Optional[Dict[str, float]] = None,
+ loss_weights: Optional[dict[str, float]] = None,
init_from_state_dict: bool = False,
allow_unk_predictions: bool = False,
) -> None:
- """Sequence Tagger class for predicting labels for single tokens. Can be parameterized by several attributes.
-
- In case of multitask learning, pass shared embeddings or shared rnn into respective attributes.
+ """Constructor for this class.
Args:
embeddings: Embeddings to use during training and prediction
@@ -204,7 +213,7 @@ def __init__(
def label_type(self):
return self.tag_type
- def _init_loss_weights(self, loss_weights: Dict[str, float]) -> torch.Tensor:
+ def _init_loss_weights(self, loss_weights: dict[str, float]) -> torch.Tensor:
"""Initializes the loss weights based on given dictionary.
Args:
@@ -267,7 +276,17 @@ def RNN(
return RNN
- def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]:
+ def forward_loss(self, sentences: list[Sentence]) -> tuple[torch.Tensor, int]:
+ """Conducts a forward pass through the SequenceTagger using labeled sentences and return the loss.
+
+ Args:
+ sentences: A batch of labeled sentences.
+
+ Returns:
+ A tuple consisting of the loss tensor and the number of tokens in the batch.
+
+ """
+
# if there are no sentences, there is no loss
if len(sentences) == 0:
return torch.tensor(0.0, dtype=torch.float, device=flair.device, requires_grad=True), 0
@@ -281,7 +300,7 @@ def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]:
# calculate loss given scores and labels
return self._calculate_loss(scores, gold_labels)
- def _prepare_tensors(self, data_points: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, torch.LongTensor]:
+ def _prepare_tensors(self, data_points: Union[list[Sentence], Sentence]) -> tuple[torch.Tensor, torch.LongTensor]:
sentences = [data_points] if not isinstance(data_points, list) else data_points
self.embeddings.embed(sentences)
@@ -291,7 +310,7 @@ def _prepare_tensors(self, data_points: Union[List[Sentence], Sentence]) -> Tupl
return sentence_tensor, lengths
def forward(self, sentence_tensor: torch.Tensor, lengths: torch.LongTensor):
- """Forward propagation through network.
+ """Forward pass through the SequenceTagger.
Args:
sentence_tensor: A tensor representing the batch of sentences.
@@ -331,15 +350,15 @@ def forward(self, sentence_tensor: torch.Tensor, lengths: torch.LongTensor):
return scores
- def _calculate_loss(self, scores: torch.Tensor, labels: torch.LongTensor) -> Tuple[torch.Tensor, int]:
+ def _calculate_loss(self, scores: torch.Tensor, labels: torch.LongTensor) -> tuple[torch.Tensor, int]:
if labels.size(0) == 0:
return torch.tensor(0.0, requires_grad=True, device=flair.device), 1
return self.loss_function(scores, labels), len(labels)
- def _make_padded_tensor_for_batch(self, sentences: List[Sentence]) -> Tuple[torch.LongTensor, torch.Tensor]:
+ def _make_padded_tensor_for_batch(self, sentences: list[Sentence]) -> tuple[torch.LongTensor, torch.Tensor]:
names = self.embeddings.get_names()
- lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
+ lengths: list[int] = [len(sentence.tokens) for sentence in sentences]
longest_token_sequence_in_batch: int = max(lengths)
pre_allocated_zero_tensor = torch.zeros(
self.embeddings.embedding_length * longest_token_sequence_in_batch,
@@ -382,7 +401,7 @@ def _get_scores_from_features(features: torch.Tensor, lengths: torch.Tensor):
return scores
- def _get_gold_labels(self, sentences: List[Sentence]) -> List[str]:
+ def _get_gold_labels(self, sentences: list[Sentence]) -> list[str]:
"""Extracts gold labels from each sentence.
Args:
@@ -419,7 +438,7 @@ def _get_gold_labels(self, sentences: List[Sentence]) -> List[str]:
return labels
- def _prepare_label_tensor(self, sentences: List[Sentence]):
+ def _prepare_label_tensor(self, sentences: list[Sentence]):
gold_labels = self._get_gold_labels(sentences)
labels = torch.tensor(
[self.label_dictionary.get_idx_for_item(label) for label in gold_labels],
@@ -430,7 +449,7 @@ def _prepare_label_tensor(self, sentences: List[Sentence]):
def predict(
self,
- sentences: Union[List[Sentence], Sentence],
+ sentences: Union[list[Sentence], Sentence],
mini_batch_size: int = 32,
return_probabilities_for_all_classes: bool = False,
verbose: bool = False,
@@ -439,7 +458,11 @@ def predict(
embedding_storage_mode="none",
force_token_predictions: bool = False,
):
- """Predicts labels for current batch with CRF or Softmax.
+ """Call this method to predict labels for sentences.
+
+ Predictions are directly added to the Sentence objects that are passed to this method. This means that
+ predict() does not return the predictions. Rather, predictions are stored in each sentence and can
+ be retrieved by calling :func:`flair.data.Sentence.get_labels` on each :class:`flair.data.Sentence`.
Args:
sentences: List of sentences in batch
@@ -462,7 +485,7 @@ def predict(
if not isinstance(sentences, list) and not isinstance(sentences, flair.data.Dataset):
sentences = [sentences]
- Sentence.set_context_for_sentences(cast(List[Sentence], sentences))
+ Sentence.set_context_for_sentences(cast(list[Sentence], sentences))
# filter empty sentences
sentences = [sentence for sentence in sentences if len(sentence) > 0]
@@ -542,7 +565,7 @@ def predict(
return overall_loss, label_count
return None
- def _standard_inference(self, features: torch.Tensor, batch: List[Sentence], probabilities_for_all_classes: bool):
+ def _standard_inference(self, features: torch.Tensor, batch: list[Sentence], probabilities_for_all_classes: bool):
"""Softmax over emission scores from forward propagation.
Args:
@@ -573,7 +596,7 @@ def _standard_inference(self, features: torch.Tensor, batch: List[Sentence], pro
return predictions, all_tags
- def _all_scores_for_token(self, sentences: List[Sentence], score_tensor: torch.Tensor, lengths: List[int]):
+ def _all_scores_for_token(self, sentences: list[Sentence], score_tensor: torch.Tensor, lengths: list[int]):
"""Returns all scores for each tag in tag dictionary."""
scores = score_tensor.numpy()
tokens = [token for sentence in sentences for token in sentence]
@@ -589,7 +612,7 @@ def _all_scores_for_token(self, sentences: List[Sentence], score_tensor: torch.T
previous = 0
for length in lengths:
prob_tags_per_sentence.append(prob_all_tags[previous : previous + length])
- previous = length
+ previous += length
return prob_tags_per_sentence
def _get_state_dict(self):
@@ -861,7 +884,7 @@ def push_to_hub(
return repo_url
@staticmethod
- def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
+ def _filter_empty_sentences(sentences: list[Sentence]) -> list[Sentence]:
filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
if len(sentences) != len(filtered_sentences):
log.warning(f"Ignore {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens.")
@@ -919,7 +942,7 @@ def _print_predictions(self, batch, gold_label_type):
return lines
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "SequenceTagger":
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "SequenceTagger":
from typing import cast
return cast("SequenceTagger", super().load(model_path=model_path))
diff --git a/flair/models/sequence_tagger_utils/viterbi.py b/flair/models/sequence_tagger_utils/viterbi.py
index ed84d1a6b7..5c87f49f0c 100644
--- a/flair/models/sequence_tagger_utils/viterbi.py
+++ b/flair/models/sequence_tagger_utils/viterbi.py
@@ -1,5 +1,3 @@
-from typing import Tuple
-
import numpy as np
import torch
import torch.nn
@@ -7,7 +5,7 @@
from torch.nn.utils.rnn import pack_padded_sequence
import flair
-from flair.data import Dictionary, Label, List, Sentence
+from flair.data import Dictionary, Label, Sentence
START_TAG: str = "<START>"
STOP_TAG: str = "<STOP>"
@@ -141,8 +139,8 @@ def __init__(self, tag_dictionary: Dictionary) -> None:
self.stop_tag = tag_dictionary.get_idx_for_item(STOP_TAG)
def decode(
- self, features_tuple: tuple, probabilities_for_all_classes: bool, sentences: List[Sentence]
- ) -> Tuple[List, List]:
+ self, features_tuple: tuple, probabilities_for_all_classes: bool, sentences: list[Sentence]
+ ) -> tuple[list[list[tuple[str, float]]], list[list[list[Label]]]]:
"""Decoding function returning the most likely sequence of tags.
Args:
@@ -211,7 +209,7 @@ def decode(
scores = softmax(scores_upto_t, dim=2)
confidences = torch.max(scores, dim=2)
- tags = []
+ tags: list[list[tuple[str, float]]] = []
for tag_seq, tag_seq_conf, length_seq in zip(decoded, confidences.values, lengths):
tags.append(
[
@@ -230,7 +228,7 @@ def _all_scores_for_token(
score_tensor: torch.Tensor,
tag_sequences: torch.Tensor,
lengths: torch.IntTensor,
- sentences: List[Sentence],
+ sentences: list[Sentence],
):
"""Returns all scores for each tag in tag dictionary."""
scores = score_tensor.numpy()
diff --git a/flair/models/tars_model.py b/flair/models/tars_model.py
index a7a41bdb5d..f4171fdb27 100644
--- a/flair/models/tars_model.py
+++ b/flair/models/tars_model.py
@@ -3,7 +3,7 @@
from abc import ABC
from collections import OrderedDict
from pathlib import Path
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Optional, Union
import numpy as np
import torch
@@ -30,14 +30,14 @@
class FewshotClassifier(flair.nn.Classifier[Sentence], ABC):
def __init__(self) -> None:
self._current_task = None
- self._task_specific_attributes: Dict[str, Dict[str, Any]] = {}
+ self._task_specific_attributes: dict[str, dict[str, Any]] = {}
self.label_nearest_map = None
self.tars_model: flair.nn.Classifier[Sentence]
self.separator: str
super().__init__()
- def forward_loss(self, data_points: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]:
+ def forward_loss(self, data_points: Union[list[Sentence], Sentence]) -> tuple[torch.Tensor, int]:
if not isinstance(data_points, list):
data_points = [data_points]
@@ -54,7 +54,7 @@ def tars_embeddings(self):
def _get_tars_formatted_sentence(self, label, sentence):
raise NotImplementedError
- def _get_tars_formatted_sentences(self, sentences: List[Sentence]):
+ def _get_tars_formatted_sentences(self, sentences: list[Sentence]):
label_text_pairs = []
all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]
for sentence in sentences:
@@ -173,7 +173,7 @@ def is_current_task_multi_label(self):
def add_and_switch_to_new_task(
self,
task_name: str,
- label_dictionary: Union[List, Set, Dictionary, str],
+ label_dictionary: Union[list, set, Dictionary, str],
label_type: str,
multi_label: bool = True,
force_switch: bool = False,
@@ -219,7 +219,7 @@ def add_and_switch_to_new_task(
self.switch_to_task(task_name)
- def list_existing_tasks(self) -> Set[str]:
+ def list_existing_tasks(self) -> set[str]:
"""Lists existing tasks in the loaded TARS model on the console."""
return set(self._task_specific_attributes.keys())
@@ -246,7 +246,7 @@ def _drop_task(self, task_name):
log.warning("No task exists with the name `%s`.", task_name)
@staticmethod
- def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
+ def _filter_empty_sentences(sentences: list[Sentence]) -> list[Sentence]:
filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
if len(sentences) != len(filtered_sentences):
log.warning(f"Ignore {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens.")
@@ -258,8 +258,8 @@ def label_type(self):
def predict_zero_shot(
self,
- sentences: Union[List[Sentence], Sentence],
- candidate_label_set: Union[List[str], Set[str], str],
+ sentences: Union[list[Sentence], Sentence],
+ candidate_label_set: Union[list[str], set[str], str],
multi_label: bool = True,
):
"""Make zero shot predictions from the TARS model.
@@ -307,14 +307,14 @@ def predict_zero_shot(
def get_used_tokens(
self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True
- ) -> typing.Iterable[List[str]]:
+ ) -> typing.Iterable[list[str]]:
yield from super().get_used_tokens(corpus, context_length, respect_document_boundaries)
for label in self.get_current_label_dictionary().idx2item:
yield [label.decode("utf-8")]
yield [self.separator]
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "FewshotClassifier":
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "FewshotClassifier":
from typing import cast
return cast("FewshotClassifier", super().load(model_path=model_path))
@@ -383,7 +383,7 @@ def __init__(
# transformer separator
self.separator = str(self.tars_embeddings.tokenizer.sep_token)
- if self.tars_embeddings.tokenizer._bos_token:
+ if self.tars_embeddings.tokenizer.bos_token is not None:
self.separator += str(self.tars_embeddings.tokenizer.bos_token)
self.prefix = prefix
@@ -472,7 +472,7 @@ def tars_embeddings(self):
def predict(
self,
- sentences: Union[List[Sentence], Sentence],
+ sentences: Union[list[Sentence], Sentence],
mini_batch_size=32,
return_probabilities_for_all_classes: bool = False,
verbose: bool = False,
@@ -532,12 +532,12 @@ def predict(
if not batch:
continue
- tars_sentences: List[Sentence] = []
- all_labels_to_sentence: List[Dict[str, Sentence]] = []
+ tars_sentences: list[Sentence] = []
+ all_labels_to_sentence: list[dict[str, Sentence]] = []
for sentence in batch:
# always remove tags first
sentence.remove_labels(label_name)
- labels_to_sentence: Dict[str, Sentence] = {}
+ labels_to_sentence: dict[str, Sentence] = {}
for label in all_labels:
tars_sentence = self._get_tars_formatted_sentence(label, sentence)
tars_sentences.append(tars_sentence)
@@ -570,7 +570,7 @@ def predict(
if most_probable_first:
import operator
- already_set_indices: List[int] = []
+ already_set_indices: list[int] = []
sorted_x = sorted(all_detected.items(), key=operator.itemgetter(1))
sorted_x.reverse()
@@ -648,7 +648,7 @@ def _print_predictions(self, batch, gold_label_type):
return lines
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TARSTagger":
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TARSTagger":
from typing import cast
return cast("TARSTagger", super().load(model_path=model_path))
@@ -718,9 +718,11 @@ def __init__(
)
# transformer separator
- self.separator = str(self.tars_embeddings.tokenizer.sep_token)
- if self.tars_embeddings.tokenizer._bos_token:
- self.separator += str(self.tars_embeddings.tokenizer.bos_token)
+ self.separator = (
+ self.tars_embeddings.tokenizer.sep_token if self.tars_embeddings.tokenizer.sep_token is not None else ""
+ )
+ if self.tars_embeddings.tokenizer.bos_token is not None:
+ self.separator += self.tars_embeddings.tokenizer.bos_token
self.prefix = prefix
self.num_negative_labels_to_sample = num_negative_labels_to_sample
@@ -832,7 +834,7 @@ def tars_embeddings(self):
def predict(
self,
- sentences: Union[List[Sentence], Sentence],
+ sentences: Union[list[Sentence], Sentence],
mini_batch_size=32,
return_probabilities_for_all_classes: bool = False,
verbose: bool = False,
@@ -907,12 +909,12 @@ def predict(
if not batch:
continue
- tars_sentences: List[Sentence] = []
- all_labels_to_sentence: List[Dict[str, Sentence]] = []
+ tars_sentences: list[Sentence] = []
+ all_labels_to_sentence: list[dict[str, Sentence]] = []
for sentence in batch:
# always remove tags first
sentence.remove_labels(label_name)
- labels_to_sentence: Dict[str, Sentence] = {}
+ labels_to_sentence: dict[str, Sentence] = {}
for label in all_labels:
tars_sentence = self._get_tars_formatted_sentence(label, sentence)
tars_sentences.append(tars_sentence)
@@ -972,7 +974,7 @@ def predict(
return None
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TARSClassifier":
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TARSClassifier":
from typing import cast
return cast("TARSClassifier", super().load(model_path=model_path))
diff --git a/flair/models/text_classification_model.py b/flair/models/text_classification_model.py
index 1b330a0da3..7f4e00d2c4 100644
--- a/flair/models/text_classification_model.py
+++ b/flair/models/text_classification_model.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
-from typing import Any, Dict, List, Union
+from typing import Any, Union
import torch
@@ -56,7 +56,7 @@ def _get_embedding_for_data_point(self, prediction_data_point: Sentence) -> torc
embedding_names = self.embeddings.get_names()
return prediction_data_point.get_embedding(embedding_names)
- def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Sentence]:
+ def _get_data_points_from_sentence(self, sentence: Sentence) -> list[Sentence]:
return [sentence]
def _get_state_dict(self):
@@ -133,7 +133,7 @@ def label_type(self):
return self._label_type
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TextClassifier":
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TextClassifier":
from typing import cast
return cast("TextClassifier", super().load(model_path=model_path))
diff --git a/flair/models/text_regression_model.py b/flair/models/text_regression_model.py
index 894ce3087e..d1ad98d4e0 100644
--- a/flair/models/text_regression_model.py
+++ b/flair/models/text_regression_model.py
@@ -1,7 +1,7 @@
import logging
import typing
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
import torch
from torch import nn
@@ -43,7 +43,7 @@ def __init__(
def label_type(self):
return self.label_name
- def _prepare_tensors(self, sentences: List[Sentence]) -> Tuple[torch.Tensor]:
+ def _prepare_tensors(self, sentences: list[Sentence]) -> tuple[torch.Tensor]:
self.document_embeddings.embed(sentences)
embedding_names = self.document_embeddings.get_names()
text_embedding_list = [sentence.get_embedding(embedding_names).unsqueeze(0) for sentence in sentences]
@@ -55,14 +55,14 @@ def forward(self, *args: torch.Tensor) -> torch.Tensor:
label_scores = self.decoder(text_embedding_tensor)
return label_scores
- def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]:
+ def forward_loss(self, sentences: list[Sentence]) -> tuple[torch.Tensor, int]:
labels = self._labels_to_tensor(sentences)
text_embedding_tensor = self._prepare_tensors(sentences)
scores = self.forward(*text_embedding_tensor)
return self.loss_function(scores.squeeze(1), labels), len(sentences)
- def _labels_to_tensor(self, sentences: List[Sentence]):
+ def _labels_to_tensor(self, sentences: list[Sentence]):
indices = [
torch.tensor([float(label.value) for label in sentence.get_labels(self.label_name)], dtype=torch.float)
for sentence in sentences
@@ -74,12 +74,12 @@ def _labels_to_tensor(self, sentences: List[Sentence]):
def predict(
self,
- sentences: Union[Sentence, List[Sentence]],
+ sentences: Union[Sentence, list[Sentence]],
mini_batch_size: int = 32,
verbose: bool = False,
label_name: Optional[str] = None,
embedding_storage_mode: EmbeddingStorageMode = "none",
- ) -> List[Sentence]:
+ ) -> list[Sentence]:
if label_name is None:
label_name = self.label_name if self.label_name is not None else "label"
@@ -123,7 +123,7 @@ def predict(
return sentences
- def forward_labels_and_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, torch.Tensor]:
+ def forward_labels_and_loss(self, sentences: list[Sentence]) -> tuple[torch.Tensor, torch.Tensor]:
labels = self._labels_to_tensor(sentences)
text_embedding_tensor = self._prepare_tensors(sentences)
scores = self.forward(*text_embedding_tensor)
@@ -132,13 +132,13 @@ def forward_labels_and_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tens
def evaluate(
self,
- data_points: Union[List[Sentence], Dataset],
+ data_points: Union[list[Sentence], Dataset],
gold_label_type: str,
out_path: Optional[Union[str, Path]] = None,
embedding_storage_mode: EmbeddingStorageMode = "none",
mini_batch_size: int = 32,
- main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
- exclude_labels: Optional[List[str]] = None,
+ main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"),
+ exclude_labels: Optional[list[str]] = None,
gold_label_dictionary: Optional[Dictionary] = None,
return_loss: bool = True,
**kwargs,
@@ -154,7 +154,7 @@ def evaluate(
metric = MetricRegression("Evaluation")
- lines: List[str] = []
+ lines: list[str] = []
total_count = 0
for batch in data_loader:
if isinstance(batch, Sentence):
@@ -227,21 +227,21 @@ def _init_model_with_state_dict(cls, state, **kwargs):
)
@staticmethod
- def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
+ def _filter_empty_sentences(sentences: list[Sentence]) -> list[Sentence]:
filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
if len(sentences) != len(filtered_sentences):
log.warning(f"Ignore {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens.")
return filtered_sentences
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TextRegressor":
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TextRegressor":
from typing import cast
return cast("TextRegressor", super().load(model_path=model_path))
def get_used_tokens(
self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True
- ) -> typing.Iterable[List[str]]:
+ ) -> typing.Iterable[list[str]]:
for sentence in _iter_dataset(corpus.get_all_sentences()):
yield [t.text for t in sentence]
yield [t.text for t in sentence.left_context(context_length, respect_document_boundaries)]
diff --git a/flair/models/triple_classification_model.py b/flair/models/triple_classification_model.py
index 1c1337f9b0..9f1a57a23e 100644
--- a/flair/models/triple_classification_model.py
+++ b/flair/models/triple_classification_model.py
@@ -1,5 +1,4 @@
import typing
-from typing import List
import torch
@@ -69,7 +68,7 @@ def __init__(
def label_type(self):
return self._label_type
- def _get_data_points_from_sentence(self, sentence: TextTriple) -> List[TextTriple]:
+ def _get_data_points_from_sentence(self, sentence: TextTriple) -> list[TextTriple]:
return [sentence]
def _get_embedding_for_data_point(self, prediction_data_point: TextTriple) -> torch.Tensor:
@@ -121,7 +120,7 @@ def _init_model_with_state_dict(cls, state, **kwargs):
def get_used_tokens(
self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True
- ) -> typing.Iterable[List[str]]:
+ ) -> typing.Iterable[list[str]]:
for sentence_pair in _iter_dataset(corpus.get_all_sentences()):
yield [t.text for t in sentence_pair.first]
yield [t.text for t in sentence_pair.first.left_context(context_length, respect_document_boundaries)]
diff --git a/flair/models/word_tagger_model.py b/flair/models/word_tagger_model.py
index 2d32a54b06..5040a63728 100644
--- a/flair/models/word_tagger_model.py
+++ b/flair/models/word_tagger_model.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
-from typing import Any, Dict, List, Union
+from typing import Any, Union
import torch
from deprecated.sphinx import deprecated
@@ -99,7 +99,7 @@ def _get_embedding_for_data_point(self, prediction_data_point: Token) -> torch.T
names = self.embeddings.get_names()
return prediction_data_point.get_embedding(names)
- def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Token]:
+ def _get_data_points_from_sentence(self, sentence: Sentence) -> list[Token]:
# special handling during training if this is a span prediction problem
if self.training and self.span_prediction_problem:
for token in sentence.tokens:
@@ -125,7 +125,7 @@ def _post_process_batch_after_prediction(self, batch, label_name):
for sentence in batch:
# internal variables
previous_tag = "O-"
- current_span: List[Token] = []
+ current_span: list[Token] = []
for token in sentence:
bioes_tag = token.get_label(label_name).value
@@ -222,7 +222,7 @@ def _print_predictions(self, batch, gold_label_type):
return lines
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TokenClassifier":
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TokenClassifier":
from typing import cast
return cast("TokenClassifier", super().load(model_path=model_path))
diff --git a/flair/nn/decoder.py b/flair/nn/decoder.py
index 65f802148a..48cdbf39b0 100644
--- a/flair/nn/decoder.py
+++ b/flair/nn/decoder.py
@@ -1,5 +1,5 @@
import logging
-from typing import List, Optional
+from typing import Optional
import torch
@@ -151,11 +151,11 @@ class LabelVerbalizerDecoder(torch.nn.Module):
def __init__(self, label_embedding: Embeddings, label_dictionary: Dictionary):
super().__init__()
self.label_embedding = label_embedding
- self.verbalized_labels: List[Sentence] = self.verbalize_labels(label_dictionary)
+ self.verbalized_labels: list[Sentence] = self.verbalize_labels(label_dictionary)
self.to(flair.device)
@staticmethod
- def verbalize_labels(label_dictionary: Dictionary) -> List[Sentence]:
+ def verbalize_labels(label_dictionary: Dictionary) -> list[Sentence]:
"""Takes a label dictionary and returns a list of sentences with verbalized labels.
Args:
diff --git a/flair/nn/model.py b/flair/nn/model.py
index eeb5b7c84a..f28b5e993f 100644
--- a/flair/nn/model.py
+++ b/flair/nn/model.py
@@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from collections import Counter
from pathlib import Path
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Optional, Union
import torch.nn
from torch import Tensor
@@ -17,6 +17,7 @@
from flair.class_utils import get_non_abstract_subclasses
from flair.data import DT, DT2, Corpus, Dictionary, Sentence, _iter_dataset
from flair.datasets import DataLoader, FlairDatapointDataset
+from flair.distributed_utils import is_main_process
from flair.embeddings import Embeddings
from flair.embeddings.base import load_embeddings
from flair.file_utils import Tqdm, load_torch_state
@@ -31,7 +32,7 @@ class Model(torch.nn.Module, typing.Generic[DT], ABC):
Every new type of model must implement these methods.
"""
- model_card: Optional[Dict[str, Any]] = None
+ model_card: Optional[dict[str, Any]] = None
@property
@abstractmethod
@@ -40,7 +41,7 @@ def label_type(self) -> str:
raise NotImplementedError
@abstractmethod
- def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]:
+ def forward_loss(self, data_points: list[DT]) -> tuple[torch.Tensor, int]:
"""Performs a forward pass and returns a loss tensor for backpropagation.
Implement this to enable training.
@@ -50,13 +51,13 @@ def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]:
@abstractmethod
def evaluate(
self,
- data_points: Union[List[DT], Dataset],
+ data_points: Union[list[DT], Dataset],
gold_label_type: str,
out_path: Optional[Union[str, Path]] = None,
embedding_storage_mode: EmbeddingStorageMode = "none",
mini_batch_size: int = 32,
- main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
- exclude_labels: Optional[List[str]] = None,
+ main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"),
+ exclude_labels: Optional[list[str]] = None,
gold_label_dictionary: Optional[Dictionary] = None,
return_loss: bool = True,
**kwargs,
@@ -68,12 +69,12 @@ def evaluate(
Args:
data_points: The labeled data_points to evaluate.
gold_label_type: The label type indicating the gold labels
- out_path: Optional output path to store predictions
+ out_path: Optional output path to store predictions.
embedding_storage_mode: One of 'none', 'cpu' or 'gpu'. 'none' means all embeddings are deleted and freshly
recomputed, 'cpu' means all embeddings are stored on CPU, or 'gpu' means all embeddings are stored on GPU
- mini_batch_size: The batch_size to use for predictions
- main_evaluation_metric: Specify which metric to highlight as main_score
- exclude_labels: Specify classes that won't be considered in evaluation
+ mini_batch_size: The batch_size to use for predictions.
+ main_evaluation_metric: Specify which metric to highlight as main_score.
+ exclude_labels: Specify classes that won't be considered in evaluation.
gold_label_dictionary: Specify which classes should be considered, all other classes will be taken as <unk>.
return_loss: Whether to additionally compute the loss on the data-points.
**kwargs: Arguments that will be ignored.
@@ -84,7 +85,7 @@ def evaluate(
exclude_labels = exclude_labels if exclude_labels is not None else []
raise NotImplementedError
- def _get_state_dict(self) -> Dict:
+ def _get_state_dict(self) -> dict:
"""Returns the state dictionary for this model."""
# Always include the name of the Model class for which the state dict holds
state_dict = {"state_dict": self.state_dict(), "__cls__": self.__class__.__name__}
@@ -92,7 +93,7 @@ def _get_state_dict(self) -> Dict:
return state_dict
@classmethod
- def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs):
+ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
"""Initialize the model from a state dictionary."""
if "embeddings" in kwargs:
embeddings = kwargs.pop("embeddings")
@@ -115,8 +116,8 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None:
"""Saves the current model to the provided file.
Args:
- model_file: the model file
- checkpoint: currently unused.
+ model_file: The model file.
+ checkpoint: This parameter is currently unused.
"""
model_state = self._get_state_dict()
@@ -128,13 +129,14 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None:
torch.save(model_state, str(model_file), pickle_protocol=4)
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "Model":
- """Loads the model from the given file.
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model":
+ """Loads a Flair model from the given file or state dictionary.
Args:
- model_path: the model file or the already loaded state dict
+ model_path: Either the path to the model (as a string or `Path` object) or the already loaded state dict.
- Returns: the loaded text classifier model
+ Returns:
+ The loaded Flair model.
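+
+ Example (a minimal sketch; the file path below is purely illustrative)::
+
+     from flair.nn import Classifier
+
+     # load a model previously saved with model.save() (a model identifier from the model hub also works)
+     model = Classifier.load("resources/taggers/example-model/final-model.pt")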
"""
# if this class is abstract, go through all inheriting classes and try to fetch and load the model
if inspect.isabstract(cls):
@@ -206,6 +208,14 @@ def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "Model":
return model
def print_model_card(self):
+ """Produces a log message that lists all recorded parameters the model was trained with.
+
+ The model card includes information such as the Flair, PyTorch and Transformers versions used during training,
+ and the training parameters.
+
+ Only available for models trained with Flair >= 0.9.1.
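+
+ Example (illustrative; any trained Flair model works)::
+
+     from flair.nn import Classifier
+
+     model = Classifier.load("ner")  # the model name "ner" is illustrative
+     model.print_model_card()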
+ """
if hasattr(self, "model_card"):
param_out = "\n------------------------------------\n"
param_out += "--------- Flair Model Card ---------\n"
@@ -238,7 +248,7 @@ class ReduceTransformerVocabMixin(ABC):
@abstractmethod
def get_used_tokens(
self, corpus: Corpus, context_lenth: int = 0, respect_document_boundaries: bool = True
- ) -> typing.Iterable[List[str]]:
+ ) -> typing.Iterable[list[str]]:
pass
@@ -251,13 +261,13 @@ class Classifier(Model[DT], typing.Generic[DT], ReduceTransformerVocabMixin, ABC
def evaluate(
self,
- data_points: Union[List[DT], Dataset],
+ data_points: Union[list[DT], Dataset],
gold_label_type: str,
out_path: Optional[Union[str, Path]] = None,
embedding_storage_mode: EmbeddingStorageMode = "none",
mini_batch_size: int = 32,
- main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
- exclude_labels: Optional[List[str]] = None,
+ main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"),
+ exclude_labels: Optional[list[str]] = None,
gold_label_dictionary: Optional[Dictionary] = None,
return_loss: bool = True,
**kwargs,
@@ -281,17 +291,17 @@ def evaluate(
average_over = 0
# variables for printing
- lines: List[str] = []
+ lines: list[str] = []
# variables for computing scores
- all_spans: Set[str] = set()
+ all_spans: set[str] = set()
all_true_values = {}
all_predicted_values = {}
loader = DataLoader(data_points, batch_size=mini_batch_size)
sentence_id = 0
- for batch in Tqdm.tqdm(loader):
+ for batch in Tqdm.tqdm(loader, disable=not is_main_process()):
# remove any previously predicted labels
for datapoint in batch:
datapoint.remove_labels("predicted")
@@ -476,7 +486,7 @@ def evaluate(
)
# Create and populate score object for logging with all evaluation values, plus the loss
- scores: Dict[Union[Tuple[str, ...], str], Any] = {}
+ scores: dict[Union[tuple[str, ...], str], Any] = {}
for avg_type in ("micro avg", "macro avg"):
for metric_type in ("f1-score", "precision", "recall"):
@@ -514,7 +524,7 @@ def evaluate(
@abstractmethod
def predict(
self,
- sentences: Union[List[DT], DT],
+ sentences: Union[list[DT], DT],
mini_batch_size: int = 32,
return_probabilities_for_all_classes: bool = False,
verbose: bool = False,
@@ -522,22 +532,26 @@ def predict(
return_loss: bool = False,
embedding_storage_mode: EmbeddingStorageMode = "none",
):
- """Predicts the class labels for the given sentences.
+ """Uses the model to predict labels for a given set of data points.
- The labels are directly added to the sentences.
+ The method does not directly return the predicted labels. Rather, labels are added as :class:`flair.data.Label` objects to
+ the respective data points. You can then access these predictions by calling :func:`flair.data.DataPoint.get_labels`
+ on each data point that you passed through this method.
Args:
- sentences: list of sentences
- mini_batch_size: mini batch size to use
- return_probabilities_for_all_classes: return probabilities for all classes instead of only best predicted
- verbose: set to True to display a progress bar
- return_loss: set to True to return loss
- label_name: set this to change the name of the label type that is predicted # noqa: E501
- embedding_storage_mode: default is 'none' which is always best. Only set to 'cpu' or 'gpu' if you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively. 'gpu' to store embeddings in GPU memory. # noqa: E501
+ sentences: The data points for which the model should predict labels, most commonly Sentence objects.
+ mini_batch_size: The mini batch size to use. Setting this value higher typically makes predictions faster,
+ but also costs more memory.
+ return_probabilities_for_all_classes: If set to True, the model will store probabilities for all classes
+ instead of only the predicted class.
+ verbose: If set to True, will display a progress bar while predicting. By default, this parameter is set to False.
+ return_loss: Set this to True to return loss (only possible if gold labels are set for the sentences).
+ label_name: Optional parameter that, if set, changes the identifier of the label type that is predicted. # noqa: E501
+ embedding_storage_mode: Default is 'none' which is always best. Only set to 'cpu' or 'gpu' if you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively. 'gpu' to store embeddings in GPU memory. # noqa: E501
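+
+ Example (a minimal sketch; the model name "sentiment" and the example sentence are illustrative)::
+
+     from flair.data import Sentence
+     from flair.nn import Classifier
+
+     classifier = Classifier.load("sentiment")
+     sentence = Sentence("This movie was great!")
+     classifier.predict(sentence)
+
+     # the predicted labels are now attached to the sentence
+     for label in sentence.get_labels():
+         print(label)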
"""
raise NotImplementedError
- def _print_predictions(self, batch: List[DT], gold_label_type: str) -> List[str]:
+ def _print_predictions(self, batch: list[DT], gold_label_type: str) -> list[str]:
lines = []
for datapoint in batch:
# check if there is a label mismatch
@@ -557,14 +571,14 @@ def _print_predictions(self, batch: List[DT], gold_label_type: str) -> List[str]
def get_used_tokens(
self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True
- ) -> typing.Iterable[List[str]]:
+ ) -> typing.Iterable[list[str]]:
for sentence in _iter_dataset(corpus.get_all_sentences()):
yield [t.text for t in sentence]
yield [t.text for t in sentence.left_context(context_length, respect_document_boundaries)]
yield [t.text for t in sentence.right_context(context_length, respect_document_boundaries)]
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "Classifier":
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Classifier":
from typing import cast
return cast("Classifier", super().load(model_path=model_path))
@@ -589,7 +603,7 @@ def __init__(
word_dropout: float = 0.0,
multi_label: bool = False,
multi_label_threshold: float = 0.5,
- loss_weights: Optional[Dict[str, float]] = None,
+ loss_weights: Optional[dict[str, float]] = None,
decoder: Optional[torch.nn.Module] = None,
inverse_model: bool = False,
train_on_gold_pairs_only: bool = False,
@@ -663,21 +677,21 @@ def _get_embedding_for_data_point(self, prediction_data_point: DT2) -> torch.Ten
raise NotImplementedError
@abstractmethod
- def _get_data_points_from_sentence(self, sentence: DT) -> List[DT2]:
+ def _get_data_points_from_sentence(self, sentence: DT) -> list[DT2]:
"""Returns the data_points to which labels are added.
The results should be of any type that inherits from DataPoint (Sentence, Span, Token, ... objects).
"""
raise NotImplementedError
- def _get_data_points_for_batch(self, sentences: List[DT]) -> List[DT2]:
+ def _get_data_points_for_batch(self, sentences: list[DT]) -> list[DT2]:
"""Returns the data_points to which labels are added.
The results should be of any type that inherits from DataPoint (Sentence, Span, Token, ... objects).
"""
return [data_point for sentence in sentences for data_point in self._get_data_points_from_sentence(sentence)]
- def _get_label_of_datapoint(self, data_point: DT2) -> List[str]:
+ def _get_label_of_datapoint(self, data_point: DT2) -> list[str]:
"""Extracts the labels from the data points.
Each data point might return a list of strings, representing multiple labels.
@@ -701,7 +715,7 @@ def multi_label_threshold(self, x): # setter method
else:
self._multi_label_threshold = {"default": x}
- def _prepare_label_tensor(self, prediction_data_points: List[DT2]) -> torch.Tensor:
+ def _prepare_label_tensor(self, prediction_data_points: list[DT2]) -> torch.Tensor:
labels = [self._get_label_of_datapoint(dp) for dp in prediction_data_points]
if self.multi_label:
return torch.tensor(
@@ -726,7 +740,7 @@ def _prepare_label_tensor(self, prediction_data_points: List[DT2]) -> torch.Tens
device=flair.device,
)
- def _encode_data_points(self, sentences: List[DT], data_points: List[DT2]) -> Tensor:
+ def _encode_data_points(self, sentences: list[DT], data_points: list[DT2]) -> Tensor:
# embed sentences
if self.should_embed_sentence:
self.embeddings.embed(sentences)
@@ -747,7 +761,7 @@ def _mask_scores(self, scores: Tensor, data_points) -> Tensor:
"""Classes that inherit from DefaultClassifier may optionally mask scores."""
return scores
- def forward_loss(self, sentences: List[DT]) -> Tuple[torch.Tensor, int]:
+ def forward_loss(self, sentences: list[DT]) -> tuple[torch.Tensor, int]:
# make a forward pass to produce embedded data points and labels
sentences = [sentence for sentence in sentences if self._filter_data_point(sentence)]
@@ -773,10 +787,10 @@ def forward_loss(self, sentences: List[DT]) -> Tuple[torch.Tensor, int]:
# calculate the loss
return self._calculate_loss(scores, label_tensor)
- def _calculate_loss(self, scores: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, int]:
+ def _calculate_loss(self, scores: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, int]:
return self.loss_function(scores, labels), labels.size(0)
- def _sort_data(self, data_points: List[DT]) -> List[DT]:
+ def _sort_data(self, data_points: list[DT]) -> list[DT]:
if len(data_points) == 0:
return []
@@ -784,16 +798,16 @@ def _sort_data(self, data_points: List[DT]) -> List[DT]:
return data_points
# filter empty sentences
- sentences = [sentence for sentence in typing.cast(List[Sentence], data_points) if len(sentence) > 0]
+ sentences = [sentence for sentence in typing.cast(list[Sentence], data_points) if len(sentence) > 0]
# reverse sort all sequences by their length
reordered_sentences = sorted(sentences, key=len, reverse=True)
- return typing.cast(List[DT], reordered_sentences)
+ return typing.cast(list[DT], reordered_sentences)
def predict(
self,
- sentences: Union[List[DT], DT],
+ sentences: Union[list[DT], DT],
mini_batch_size: int = 32,
return_probabilities_for_all_classes: bool = False,
verbose: bool = False,
@@ -824,7 +838,7 @@ def predict(
sentences = [sentences]
if isinstance(sentences[0], Sentence):
- Sentence.set_context_for_sentences(typing.cast(List[Sentence], sentences))
+ Sentence.set_context_for_sentences(typing.cast(list[Sentence], sentences))
reordered_sentences = self._sort_data(sentences)
@@ -832,7 +846,7 @@ def predict(
return sentences
if len(reordered_sentences) > mini_batch_size:
- batches: Union[DataLoader, List[List[DT]]] = DataLoader(
+ batches: Union[DataLoader, list[list[DT]]] = DataLoader(
dataset=FlairDatapointDataset(reordered_sentences),
batch_size=mini_batch_size,
)
@@ -876,9 +890,7 @@ def predict(
filtered_indices = []
has_unknown_label = False
for idx, dp in enumerate(data_points):
- if all(
- label in self.label_dictionary.get_items() for label in self._get_label_of_datapoint(dp)
- ):
+ if all(self.label_dictionary.has_item(label) for label in self._get_label_of_datapoint(dp)):
filtered_indices.append(idx)
else:
has_unknown_label = True
@@ -981,7 +993,7 @@ def _get_state_dict(self):
state["locked_dropout"] = self.locked_dropout.dropout_rate
state["multi_label"] = self.multi_label
state["multi_label_threshold"] = self.multi_label_threshold
- state["loss_weights"] = self.loss_weights
+ state["loss_weights"] = self.weight_dict
state["train_on_gold_pairs_only"] = self.train_on_gold_pairs_only
state["inverse_model"] = self.inverse_model
if self._custom_decoder:
@@ -990,7 +1002,7 @@ def _get_state_dict(self):
return state
@classmethod
- def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "DefaultClassifier":
+ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "DefaultClassifier":
from typing import cast
return cast("DefaultClassifier", super().load(model_path=model_path))
diff --git a/flair/nn/multitask.py b/flair/nn/multitask.py
index 6fa2f20c02..42c5665141 100644
--- a/flair/nn/multitask.py
+++ b/flair/nn/multitask.py
@@ -1,4 +1,5 @@
-from typing import Iterable, Tuple, Union
+from collections.abc import Iterable
+from typing import Union
from flair.data import Corpus, MultiCorpus
from flair.models import MultitaskModel
@@ -6,18 +7,18 @@
def make_multitask_model_and_corpus(
- mapping: Iterable[Union[Tuple[Classifier, Corpus], Tuple[Classifier, Corpus, float]]]
-) -> Tuple[Model, Corpus]:
+ mapping: Iterable[Union[tuple[Classifier, Corpus], tuple[Classifier, Corpus, float]]]
+) -> tuple[Model, Corpus]:
models = []
corpora = []
loss_factors = []
ids = []
- for task_id, map in enumerate(mapping):
- models.append(map[0])
- corpora.append(map[1])
- if len(map) == 3:
- loss_factors.append(map[2])
+ for task_id, _map in enumerate(mapping):
+ models.append(_map[0])
+ corpora.append(_map[1])
+ if len(_map) == 3:
+ loss_factors.append(_map[2])
else:
loss_factors.append(1.0)
diff --git a/flair/samplers.py b/flair/samplers.py
index 135dfb3310..53ad40c4c5 100644
--- a/flair/samplers.py
+++ b/flair/samplers.py
@@ -1,7 +1,6 @@
import logging
import random
from collections import defaultdict
-from typing import Dict
import torch
from torch.utils.data.sampler import Sampler
@@ -36,7 +35,7 @@ def set_dataset(self, data_source):
self.indices = list(range(len(data_source)))
# first determine the distribution of classes in the dataset
- label_count: Dict[str, int] = defaultdict(int)
+ label_count: dict[str, int] = defaultdict(int)
for sentence in data_source:
for label in sentence.labels:
label_count[label.value] += 1
diff --git a/flair/splitter.py b/flair/splitter.py
index 9f7e502c87..6246969f28 100644
--- a/flair/splitter.py
+++ b/flair/splitter.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Union
+from typing import Any, Optional, Union
from segtok.segmenter import split_multi
@@ -16,16 +16,34 @@ class SentenceSplitter(ABC):
r"""An abstract class representing a :class:`SentenceSplitter`.
Sentence splitters are used to represent algorithms and models to split plain text into
- sentences and individual tokens / words. All subclasses should overwrite :meth:`splits`,
- which splits the given plain text into a sequence of sentences (:class:`Sentence`). The
- individual sentences are in turn subdivided into tokens / words. In most cases, this can
- be controlled by passing custom implementation of :class:`Tokenizer`.
+ sentences and individual tokens / words. All subclasses should overwrite :meth:`_perform_split`,
+ which splits the given plain text into a list of :class:`flair.data.Sentence` objects. The
+ individual sentences are in turn subdivided into tokens. In most cases, this tokenization can
+ be controlled by passing a custom implementation of :class:`flair.tokenization.Tokenizer`.
Moreover, subclasses may overwrite :meth:`name`, returning a unique identifier representing
the sentence splitter's configuration.
+
+ The most common class in Flair that implements this base class is :class:`SegtokSentenceSplitter`.
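+
+ Example (a minimal usage sketch; the input text is illustrative)::
+
+     from flair.splitter import SegtokSentenceSplitter
+
+     splitter = SegtokSentenceSplitter()
+     sentences = splitter.split("I love Berlin. George Washington went to Washington.")
+
+     # each element is a tokenized flair.data.Sentence, linked to its neighboring sentences
+     for sentence in sentences:
+         print(sentence)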
"""
- def split(self, text: str, link_sentences: Optional[bool] = True) -> List[Sentence]:
+ def split(self, text: str, link_sentences: bool = True) -> list[Sentence]:
+ """Takes a plain text string as input and outputs a list of :class:`flair.data.Sentence` objects.
+
+ If link_sentences is set (by default, it is), the :class:`flair.data.Sentence` objects will include pointers
+ to the preceding and following sentences in the original text. This way, the original sequence information is
+ always preserved.
+
+ Args:
+ text (str): The plain text to split.
+ link_sentences (bool): If set to True, :class:`flair.data.Sentence` objects will include pointers
+ to the preceding and following sentences in the original text.
+
+ Returns:
+ A list of :class:`flair.data.Sentence` objects that each represent one sentence in the given text.
+
+ """
sentences = self._perform_split(text)
if not link_sentences:
return sentences
@@ -34,15 +52,17 @@ def split(self, text: str, link_sentences: Optional[bool] = True) -> List[Senten
return sentences
@abstractmethod
- def _perform_split(self, text: str) -> List[Sentence]:
+ def _perform_split(self, text: str) -> list[Sentence]:
raise NotImplementedError
@property
def name(self) -> str:
+ """A string identifier of the sentence splitter."""
return self.__class__.__name__
@property
def tokenizer(self) -> Tokenizer:
+ """The :class:`flair.tokenization.Tokenizer` class used to tokenize sentences after they are split."""
raise NotImplementedError
@tokenizer.setter
@@ -62,11 +82,11 @@ def __init__(self, tokenizer: Tokenizer = SegtokTokenizer()) -> None:
super().__init__()
self._tokenizer = tokenizer
- def _perform_split(self, text: str) -> List[Sentence]:
- plain_sentences: List[str] = split_multi(text)
+ def _perform_split(self, text: str) -> list[Sentence]:
+ plain_sentences: list[str] = split_multi(text)
sentence_offset = 0
- sentences: List[Sentence] = []
+ sentences: list[Sentence] = []
for sentence in plain_sentences:
try:
sentence_offset = text.index(sentence, sentence_offset)
@@ -133,7 +153,7 @@ def __init__(self, model: Union[Any, str], tokenizer: Optional[Tokenizer] = None
else:
self._tokenizer = tokenizer
- def _perform_split(self, text: str) -> List[Sentence]:
+ def _perform_split(self, text: str) -> list[Sentence]:
document = self.model(text)
sentences = [
@@ -192,7 +212,7 @@ def __init__(self, tag: str, tokenizer: Tokenizer = SegtokTokenizer()) -> None:
self._tokenizer = tokenizer
self.tag = tag
- def _perform_split(self, text: str) -> List[Sentence]:
+ def _perform_split(self, text: str) -> list[Sentence]:
plain_sentences = text.split(self.tag)
sentences = []
@@ -252,7 +272,7 @@ def __init__(self, tokenizer: Tokenizer = SegtokTokenizer()) -> None:
super().__init__()
self._tokenizer = tokenizer
- def _perform_split(self, text: str) -> List[Sentence]:
+ def _perform_split(self, text: str) -> list[Sentence]:
return [Sentence(text=text, use_tokenizer=self._tokenizer, start_position=0)]
@property
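
The expanded docstrings above describe how :func:`split` links each returned Sentence to its neighbors unless `link_sentences` is disabled. A minimal usage sketch, assuming the public API shown in this diff (`SegtokSentenceSplitter` is the implementation the base-class docstring points to); the example text is illustrative:

```python
# Minimal sketch of SentenceSplitter.split() as documented above.
from flair.splitter import SegtokSentenceSplitter

splitter = SegtokSentenceSplitter()

# By default, link_sentences=True: each Sentence keeps pointers to its
# neighbors, so the original ordering of the text is preserved.
sentences = splitter.split("Flair is a framework for NLP. It is built on PyTorch.")
print(len(sentences))  # 2
print(sentences[0])

# Passing link_sentences=False returns the same sentences without the links.
unlinked = splitter.split("First sentence. Second sentence.", link_sentences=False)
```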
diff --git a/flair/tokenization.py b/flair/tokenization.py
index 185e944d3b..b377c419e6 100644
--- a/flair/tokenization.py
+++ b/flair/tokenization.py
@@ -1,7 +1,7 @@
import logging
import sys
from abc import ABC, abstractmethod
-from typing import Callable, List
+from typing import Callable
from segtok.segmenter import split_single
from segtok.tokenizer import split_contractions, word_tokenizer
@@ -20,7 +20,7 @@ class Tokenizer(ABC):
"""
@abstractmethod
- def tokenize(self, text: str) -> List[str]:
+ def tokenize(self, text: str) -> list[str]:
raise NotImplementedError
@property
@@ -57,11 +57,11 @@ def __init__(self, model) -> None:
"spacy model or the name of the model to load."
)
- def tokenize(self, text: str) -> List[str]:
+ def tokenize(self, text: str) -> list[str]:
from spacy.tokens.doc import Doc
doc: Doc = self.model.make_doc(text)
- words: List[str] = []
+ words: list[str] = []
for word in doc:
if len(word.text.strip()) == 0:
continue
@@ -82,12 +82,12 @@ class SegtokTokenizer(Tokenizer):
def __init__(self) -> None:
super().__init__()
- def tokenize(self, text: str) -> List[str]:
+ def tokenize(self, text: str) -> list[str]:
return SegtokTokenizer.run_tokenize(text)
@staticmethod
- def run_tokenize(text: str) -> List[str]:
- words: List[str] = []
+ def run_tokenize(text: str) -> list[str]:
+ words: list[str] = []
sentences = split_single(text)
for sentence in sentences:
@@ -105,12 +105,12 @@ class SpaceTokenizer(Tokenizer):
def __init__(self) -> None:
super().__init__()
- def tokenize(self, text: str) -> List[str]:
+ def tokenize(self, text: str) -> list[str]:
return SpaceTokenizer.run_tokenize(text)
@staticmethod
- def run_tokenize(text: str) -> List[str]:
- tokens: List[str] = []
+ def run_tokenize(text: str) -> list[str]:
+ tokens: list[str] = []
word = ""
index = -1
for index, char in enumerate(text):
@@ -166,8 +166,8 @@ def __init__(self, tokenizer: str, sudachi_mode: str = "A") -> None:
self.sentence_tokenizer = konoha.SentenceTokenizer()
self.word_tokenizer = konoha.WordTokenizer(tokenizer, mode=sudachi_mode)
- def tokenize(self, text: str) -> List[str]:
- words: List[str] = []
+ def tokenize(self, text: str) -> list[str]:
+ words: list[str] = []
sentences = self.sentence_tokenizer.tokenize(text)
for sentence in sentences:
@@ -184,11 +184,11 @@ def name(self) -> str:
class TokenizerWrapper(Tokenizer):
"""Helper class to wrap tokenizer functions to the class-based tokenizer interface."""
- def __init__(self, tokenizer_func: Callable[[str], List[str]]) -> None:
+ def __init__(self, tokenizer_func: Callable[[str], list[str]]) -> None:
super().__init__()
self.tokenizer_func = tokenizer_func
- def tokenize(self, text: str) -> List[str]:
+ def tokenize(self, text: str) -> list[str]:
return self.tokenizer_func(text)
@property
@@ -225,7 +225,7 @@ def __init__(self) -> None:
" Note that the scispacy version and the version of the model must match to work properly!"
)
- def combined_rule_prefixes() -> List[str]:
+ def combined_rule_prefixes() -> list[str]:
"""Helper function that returns the prefix pattern for the tokenizer.
It is a helper function to accommodate spacy tests that only test prefixes.
@@ -270,9 +270,9 @@ def combined_rule_prefixes() -> List[str]:
self.model.tokenizer.prefix_search = prefix_re.search
self.model.tokenizer.infix_finditer = infix_re.finditer
- def tokenize(self, text: str) -> List[str]:
+ def tokenize(self, text: str) -> list[str]:
sentence = self.model(text)
- words: List[str] = []
+ words: list[str] = []
for word in sentence:
words.append(word.text)
return words
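
Since `Tokenizer.tokenize` is now annotated as returning `list[str]`, any plain function with that signature can still be wrapped via `TokenizerWrapper`. A short sketch, assuming the wrapper and `Sentence(use_tokenizer=...)` interfaces present in this codebase; the whitespace tokenizer is a toy example:

```python
from flair.data import Sentence
from flair.tokenization import TokenizerWrapper


def whitespace_tokenize(text: str) -> list[str]:
    # Toy tokenizer matching the Callable[[str], list[str]] signature expected by TokenizerWrapper.
    return text.split()


tokenizer = TokenizerWrapper(whitespace_tokenize)
sentence = Sentence("Flair supports custom tokenizers.", use_tokenizer=tokenizer)
print([token.text for token in sentence])
```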
diff --git a/flair/trainers/language_model_trainer.py b/flair/trainers/language_model_trainer.py
index eb374ed75d..341cead776 100644
--- a/flair/trainers/language_model_trainer.py
+++ b/flair/trainers/language_model_trainer.py
@@ -3,8 +3,9 @@
import math
import random
import time
+from collections.abc import Iterable
from pathlib import Path
-from typing import Any, Dict, Iterable, Optional, Type, Union
+from typing import Any, Optional, Union
import torch
from torch import cuda
@@ -155,16 +156,16 @@ def __init__(
self,
model: LanguageModel,
corpus: TextCorpus,
- optimizer: Type[Optimizer] = SGD,
+ optimizer: type[Optimizer] = SGD,
test_mode: bool = False,
epoch: int = 0,
split: int = 0,
loss: float = 10000,
- optimizer_state: Optional[Dict[str, Any]] = None,
- scaler_state: Optional[Dict[str, Any]] = None,
+ optimizer_state: Optional[dict[str, Any]] = None,
+ scaler_state: Optional[dict[str, Any]] = None,
) -> None:
self.model: LanguageModel = model
- self.optimizer: Type[Optimizer] = optimizer
+ self.optimizer: type[Optimizer] = optimizer
self.corpus: TextCorpus = corpus
self.test_mode: bool = test_mode
@@ -362,7 +363,7 @@ def train(
)
with open(loss_txt, "a") as myfile:
- myfile.write("%s\n" % summary)
+ myfile.write(f"{summary}\n")
log.info(summary)
log.info("-" * 89)
@@ -386,7 +387,7 @@ def train(
summary = f"TEST: valid loss {test_loss:5.4f} | valid ppl {math.exp(test_loss):8.4f}"
with open(loss_txt, "a") as myfile:
- myfile.write("%s\n" % summary)
+ myfile.write(f"{summary}\n")
log.info(summary)
log.info("-" * 89)
@@ -440,7 +441,7 @@ def _repackage_hidden(h):
def load_checkpoint(
checkpoint_file: Union[str, Path],
corpus: TextCorpus,
- optimizer: Type[Optimizer] = SGD,
+ optimizer: type[Optimizer] = SGD,
):
if isinstance(checkpoint_file, str):
checkpoint_file = Path(checkpoint_file)
diff --git a/flair/trainers/plugins/base.py b/flair/trainers/plugins/base.py
index 958d57b785..33e787a063 100644
--- a/flair/trainers/plugins/base.py
+++ b/flair/trainers/plugins/base.py
@@ -1,27 +1,24 @@
import logging
from collections import defaultdict
+from collections.abc import Iterator, Sequence
from inspect import isclass, signature
from itertools import count
from queue import Queue
from typing import (
Any,
Callable,
- Dict,
- Iterator,
- List,
NewType,
Optional,
- Sequence,
- Set,
- Type,
Union,
cast,
)
+from flair.distributed_utils import is_main_process
+
log = logging.getLogger("flair")
-PluginArgument = Union["BasePlugin", Type["BasePlugin"]]
+PluginArgument = Union["BasePlugin", type["BasePlugin"]]
HookHandleId = NewType("HookHandleId", int)
EventIdenifier = str
@@ -34,7 +31,7 @@ class TrainingInterrupt(Exception):
class Pluggable:
"""Dispatches events which attached plugins can react to."""
- valid_events: Optional[Set[EventIdenifier]] = None
+ valid_events: Optional[set[EventIdenifier]] = None
def __init__(self, *, plugins: Sequence[PluginArgument] = []) -> None:
"""Initialize a `Pluggable`.
@@ -42,11 +39,11 @@ def __init__(self, *, plugins: Sequence[PluginArgument] = []) -> None:
Args:
plugins: Plugins which should be attached to this `Pluggable`.
"""
- self._hook_handles: Dict[EventIdenifier, Dict[HookHandleId, HookHandle]] = defaultdict(dict)
+ self._hook_handles: dict[EventIdenifier, dict[HookHandleId, HookHandle]] = defaultdict(dict)
self._hook_handle_id_counter = count()
- self._plugins: List[BasePlugin] = []
+ self._plugins: list[BasePlugin] = []
# This flag tracks, whether an event is currently being processed (otherwise it is added to the queue)
self._processing_events = False
@@ -62,6 +59,11 @@ def __init__(self, *, plugins: Sequence[PluginArgument] = []) -> None:
@property
def plugins(self):
+ """Returns all plugins attached to this instance as a list of :class:`BasePlugin`.
+
+ Returns:
+ List of :class:`BasePlugin` instances attached to this `Pluggable`.
+ """
return self._plugins
def append_plugin(self, plugin):
@@ -181,7 +183,7 @@ class BasePlugin:
def __init__(self) -> None:
"""Initialize the base plugin."""
- self._hook_handles: List[HookHandle] = []
+ self._hook_handles: list[HookHandle] = []
self._pluggable: Optional[Pluggable] = None
def attach_to(self, pluggable: Pluggable):
@@ -189,6 +191,8 @@ def attach_to(self, pluggable: Pluggable):
assert self._pluggable is None
assert len(self._hook_handles) == 0
+ if not is_main_process() and not self.attach_to_all_processes:
+ return
self._pluggable = pluggable
pluggable.append_plugin(self)
@@ -257,10 +261,15 @@ def decorator_func(func: Callable):
def pluggable(self) -> Optional[Pluggable]:
return self._pluggable
+ @property
+ def attach_to_all_processes(self) -> bool:
+ """If set, the plugin will be attached to all processes when distributed, not just the main process."""
+ return True
+
def __str__(self) -> str:
return self.__class__.__name__
- def get_state(self) -> Dict[str, Any]:
+ def get_state(self) -> dict[str, Any]:
return {"__cls__": f"{self.__module__}.{self.__class__.__name__}"}
diff --git a/flair/trainers/plugins/functional/anneal_on_plateau.py b/flair/trainers/plugins/functional/anneal_on_plateau.py
index d62b21fba1..ccd330bf0f 100644
--- a/flair/trainers/plugins/functional/anneal_on_plateau.py
+++ b/flair/trainers/plugins/functional/anneal_on_plateau.py
@@ -1,6 +1,6 @@
import logging
import os
-from typing import Any, Dict
+from typing import Any
from flair.trainers.plugins.base import TrainerPlugin, TrainingInterrupt
from flair.trainers.plugins.metric_records import MetricRecord
@@ -108,7 +108,7 @@ def __str__(self) -> str:
f"min_learning_rate: '{self.min_learning_rate}'"
)
- def get_state(self) -> Dict[str, Any]:
+ def get_state(self) -> dict[str, Any]:
return {
**super().get_state(),
"base_path": str(self.base_path),
diff --git a/flair/trainers/plugins/functional/checkpoints.py b/flair/trainers/plugins/functional/checkpoints.py
index 75ecb9bd98..a8179edbc5 100644
--- a/flair/trainers/plugins/functional/checkpoints.py
+++ b/flair/trainers/plugins/functional/checkpoints.py
@@ -1,5 +1,7 @@
import logging
-from typing import Any, Dict
+from typing import Any
+
+import torch
from flair.trainers.plugins.base import TrainerPlugin
@@ -28,8 +30,14 @@ def after_training_epoch(self, epoch, **kw):
)
model_name = "model_epoch_" + str(epoch) + ".pt"
self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state)
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier() # Prevent any process from loading a model until writing is complete
+
+ @property
+ def attach_to_all_processes(self) -> bool:
+ return False
- def get_state(self) -> Dict[str, Any]:
+ def get_state(self) -> dict[str, Any]:
return {
**super().get_state(),
"base_path": str(self.base_path),
diff --git a/flair/trainers/plugins/functional/linear_scheduler.py b/flair/trainers/plugins/functional/linear_scheduler.py
index 1000be6dd7..fdf4752bde 100644
--- a/flair/trainers/plugins/functional/linear_scheduler.py
+++ b/flair/trainers/plugins/functional/linear_scheduler.py
@@ -1,5 +1,7 @@
import logging
-from typing import Any, Dict
+from typing import Any
+
+import torch.distributed
from flair.optim import LinearSchedulerWithWarmup
from flair.trainers.plugins.base import TrainerPlugin
@@ -34,7 +36,8 @@ def after_setup(
):
"""Initialize different schedulers, including anneal target for AnnealOnPlateau, batch_growth_annealing, loading schedulers."""
# calculate warmup steps
- steps_per_epoch = (dataset_size + mini_batch_size - 1) / mini_batch_size
+ num_processes = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
+ steps_per_epoch = (dataset_size + mini_batch_size - 1) / mini_batch_size / num_processes
num_train_steps = int(steps_per_epoch * max_epochs)
num_warmup_steps = int(num_train_steps * self.warmup_fraction)
@@ -62,7 +65,7 @@ def after_training_batch(self, optimizer_was_run: bool, **kwargs):
def __str__(self) -> str:
return f"LinearScheduler | warmup_fraction: '{self.warmup_fraction}'"
- def get_state(self) -> Dict[str, Any]:
+ def get_state(self) -> dict[str, Any]:
return {
**super().get_state(),
"warmup_fraction": self.warmup_fraction,
diff --git a/flair/trainers/plugins/functional/reduce_transformer_vocab.py b/flair/trainers/plugins/functional/reduce_transformer_vocab.py
index 162c667f88..eed6d7f1b2 100644
--- a/flair/trainers/plugins/functional/reduce_transformer_vocab.py
+++ b/flair/trainers/plugins/functional/reduce_transformer_vocab.py
@@ -1,6 +1,5 @@
import logging
from pathlib import Path
-from typing import List
from transformer_smaller_training_vocab import reduce_train_vocab
@@ -56,8 +55,12 @@ def save_model_at_the_end(self, **kw):
elif (self.base_path / "final-model.pt").exists():
self.model.save(self.base_path / "final-model.pt", checkpoint=self.save_optimizer_state)
+ @property
+ def attach_to_all_processes(self) -> bool:
+ return False
-def get_transformer_embeddings(model: Model) -> List[TransformerEmbeddings]:
+
+def get_transformer_embeddings(model: Model) -> list[TransformerEmbeddings]:
embeddings = model.tars_embeddings if isinstance(model, FewshotClassifier) else getattr(model, "embeddings", None)
if embeddings is None:
diff --git a/flair/trainers/plugins/functional/weight_extractor.py b/flair/trainers/plugins/functional/weight_extractor.py
index ef5afe081e..5c9bd4c4ac 100644
--- a/flair/trainers/plugins/functional/weight_extractor.py
+++ b/flair/trainers/plugins/functional/weight_extractor.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any
from flair.trainers.plugins.base import TrainerPlugin
from flair.training_utils import WeightExtractor
@@ -21,7 +21,11 @@ def after_training_batch(self, batch_no, epoch, total_number_of_batches, **kw):
if (iteration + 1) % modulo == 0:
self.weight_extractor.extract_weights(self.model.state_dict(), iteration)
- def get_state(self) -> Dict[str, Any]:
+ @property
+ def attach_to_all_processes(self) -> bool:
+ return False
+
+ def get_state(self) -> dict[str, Any]:
return {
**super().get_state(),
"base_path": str(self.base_path),
diff --git a/flair/trainers/plugins/loggers/clearml_logger.py b/flair/trainers/plugins/loggers/clearml_logger.py
index 891b9f9244..18228d2db6 100644
--- a/flair/trainers/plugins/loggers/clearml_logger.py
+++ b/flair/trainers/plugins/loggers/clearml_logger.py
@@ -40,3 +40,7 @@ def metric_recorded(self, record: MetricRecord) -> None:
self.logger.report_text(record.value, print_console=False)
elif record.is_histogram:
self.logger.report_histogram(record_name, record_name, record.value, record.global_step)
+
+ @property
+ def attach_to_all_processes(self) -> bool:
+ return False
diff --git a/flair/trainers/plugins/loggers/log_file.py b/flair/trainers/plugins/loggers/log_file.py
index a9b7453a09..9bf22a284d 100644
--- a/flair/trainers/plugins/loggers/log_file.py
+++ b/flair/trainers/plugins/loggers/log_file.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
-from typing import Any, Dict
+from typing import Any
from flair.trainers.plugins.base import TrainerPlugin
from flair.training_utils import add_file_handler
@@ -21,5 +21,9 @@ def close_file_handler(self, **kw):
self.log_handler.close()
log.removeHandler(self.log_handler)
- def get_state(self) -> Dict[str, Any]:
+ @property
+ def attach_to_all_processes(self) -> bool:
+ return False
+
+ def get_state(self) -> dict[str, Any]:
return {**super().get_state(), "base_path": str(self.base_path)}
diff --git a/flair/trainers/plugins/loggers/loss_file.py b/flair/trainers/plugins/loggers/loss_file.py
index 29c42fc930..bfe938b72d 100644
--- a/flair/trainers/plugins/loggers/loss_file.py
+++ b/flair/trainers/plugins/loggers/loss_file.py
@@ -1,5 +1,5 @@
from datetime import datetime
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Optional, Union
from flair.trainers.plugins.base import TrainerPlugin
from flair.trainers.plugins.metric_records import MetricName
@@ -10,7 +10,7 @@ class LossFilePlugin(TrainerPlugin):
"""Plugin that manages the loss.tsv file output."""
def __init__(
- self, base_path, epoch: int, metrics_to_collect: Optional[Dict[Union[Tuple, str], str]] = None
+ self, base_path, epoch: int, metrics_to_collect: Optional[dict[Union[tuple, str], str]] = None
) -> None:
super().__init__()
@@ -56,9 +56,9 @@ def __init__(
self.headers[metric_name] = f"{prefix.upper()}_{header}"
# initialize the first log line
- self.current_row: Optional[Dict[MetricName, str]] = None
+ self.current_row: Optional[dict[MetricName, str]] = None
- def get_state(self) -> Dict[str, Any]:
+ def get_state(self) -> dict[str, Any]:
return {
**super().get_state(),
"base_path": str(self.base_path),
@@ -113,3 +113,7 @@ def after_evaluation(self, epoch, **kw):
f.write("\t".join([str(self.current_row[col]) for col in self.headers]) + "\n")
self.current_row = {}
+
+ @property
+ def attach_to_all_processes(self) -> bool:
+ return False
diff --git a/flair/trainers/plugins/loggers/metric_history.py b/flair/trainers/plugins/loggers/metric_history.py
index 8d7c946e8d..426e055186 100644
--- a/flair/trainers/plugins/loggers/metric_history.py
+++ b/flair/trainers/plugins/loggers/metric_history.py
@@ -1,5 +1,6 @@
import logging
-from typing import Any, Dict, Mapping
+from collections.abc import Mapping
+from typing import Any
from flair.trainers.plugins.base import TrainerPlugin
@@ -17,7 +18,7 @@ class MetricHistoryPlugin(TrainerPlugin):
def __init__(self, metrics_to_collect: Mapping = default_metrics_to_collect) -> None:
super().__init__()
- self.metric_history: Dict[str, list] = {}
+ self.metric_history: dict[str, list] = {}
self.metrics_to_collect: Mapping = metrics_to_collect
for target in self.metrics_to_collect.values():
self.metric_history[target] = []
@@ -33,7 +34,11 @@ def after_training(self, **kw):
"""Returns metric history."""
self.trainer.return_values.update(self.metric_history)
- def get_state(self) -> Dict[str, Any]:
+ @property
+ def attach_to_all_processes(self) -> bool:
+ return False
+
+ def get_state(self) -> dict[str, Any]:
return {
**super().get_state(),
"metrics_to_collect": dict(self.metrics_to_collect),
diff --git a/flair/trainers/plugins/loggers/tensorboard.py b/flair/trainers/plugins/loggers/tensorboard.py
index 59bba9f2e9..bf2dfcc29d 100644
--- a/flair/trainers/plugins/loggers/tensorboard.py
+++ b/flair/trainers/plugins/loggers/tensorboard.py
@@ -1,6 +1,6 @@
import logging
import os
-from typing import Any, Dict
+from typing import Any
from flair.trainers.plugins.base import TrainerPlugin
from flair.training_utils import log_line
@@ -59,7 +59,11 @@ def _training_finally(self, **kw):
assert self.writer is not None
self.writer.close()
- def get_state(self) -> Dict[str, Any]:
+ @property
+ def attach_to_all_processes(self) -> bool:
+ return False
+
+ def get_state(self) -> dict[str, Any]:
return {
**super().get_state(),
"log_dir": str(self.log_dir) if self.log_dir is not None else None,
diff --git a/flair/trainers/plugins/loggers/wandb.py b/flair/trainers/plugins/loggers/wandb.py
index 8608fcdbd9..410d0377ba 100644
--- a/flair/trainers/plugins/loggers/wandb.py
+++ b/flair/trainers/plugins/loggers/wandb.py
@@ -1,5 +1,5 @@
import logging
-from typing import Any, Dict
+from typing import Any
from flair.trainers.plugins.base import TrainerPlugin
@@ -72,7 +72,11 @@ def metric_recorded(self, record):
def _training_finally(self, **kw):
self.writer.close()
- def get_state(self) -> Dict[str, Any]:
+ @property
+ def attach_to_all_processes(self) -> bool:
+ return False
+
+ def get_state(self) -> dict[str, Any]:
return {
**super().get_state(),
"emit_alerts": self.emit_alerts,
diff --git a/flair/trainers/plugins/metric_records.py b/flair/trainers/plugins/metric_records.py
index 034c021854..548b54fccd 100644
--- a/flair/trainers/plugins/metric_records.py
+++ b/flair/trainers/plugins/metric_records.py
@@ -1,14 +1,15 @@
import time
+from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from enum import Enum
-from typing import Any, Iterable, Iterator, Optional, Tuple, Union
+from typing import Any, Optional, Union
RecordType = Enum("RecordType", ["scalar", "image", "histogram", "string", "scalar_list"])
class MetricName:
def __init__(self, name) -> None:
- self.parts: Tuple[str, ...]
+ self.parts: tuple[str, ...]
if isinstance(name, str):
self.parts = tuple(name.split("/"))
diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py
index 03e6edc083..03879a2b13 100644
--- a/flair/trainers/trainer.py
+++ b/flair/trainers/trainer.py
@@ -7,16 +7,20 @@
import warnings
from inspect import signature
from pathlib import Path
-from typing import List, Optional, Tuple, Type, Union
+from typing import Optional, Union
+import numpy as np
import torch
+from torch.nn.parallel import DistributedDataParallel
from torch.optim.sgd import SGD
+from torch.utils.data import DistributedSampler
from torch.utils.data.dataset import ConcatDataset
import flair
import flair.nn
from flair.data import Corpus, Dictionary, _len_dataset
from flair.datasets import DataLoader
+from flair.distributed_utils import aggregate, is_main_process, validate_corpus_same_each_process
from flair.samplers import FlairSampler
from flair.trainers.plugins import (
AnnealingPlugin,
@@ -38,6 +42,22 @@
class ModelTrainer(Pluggable):
+ """Use this class to train a Flair model.
+
+ The ModelTrainer is initialized using a :class:`flair.nn.Model` (the architecture you want to train) and a
+ :class:`flair.data.Corpus` (the labeled data you use to train and evaluate the model). It offers two main training
+ functions for the two main modes of training a model: (1) :func:`train`, which is used to train a model from scratch or
+ to fit a classification head on a frozen transformer language model. (2) :func:`fine_tune`, which is used if you
+ do not freeze the transformer language model but rather fine-tune it for a specific task.
+
+ Additionally, there is a `train_custom` method that allows you to fully customize the training run.
+
+ ModelTrainer inherits from :class:`flair.trainers.plugins.base.Pluggable` and thus uses a plugin system to inject
+ specific functionality into the training process. You can add any number of plugins to the above-mentioned training
+ modes. For instance, if you want to use an annealing scheduler during training, you can add the
+ :class:`flair.trainers.plugins.functional.AnnealingPlugin` plugin to the train command.
+ """
+
valid_events = {
"after_setup",
"before_training_epoch",
@@ -55,11 +75,14 @@ class ModelTrainer(Pluggable):
}
def __init__(self, model: flair.nn.Model, corpus: Corpus) -> None:
- """Initialize a model trainer.
+ """Initialize a model trainer by passing a :class:`flair.nn.Model` (the architecture you want to train) and a
+ :class:`flair.data.Corpus` (the labeled data you use to train and evaluate the model).
Args:
- model: The model that you want to train. The model should inherit from flair.nn.Model # noqa: E501
- corpus: The dataset used to train the model, should be of type Corpus
+ model: The model that you want to train. The model should inherit from :class:`flair.nn.Model`. For
+ instance, you should pass a :class:`flair.models.TextClassifier` if you want to train a text classifier,
+ or :class:`flair.models.SequenceTagger` if you want to train an RNN-based sequence labeler.
+ corpus: The dataset (of type :class:`flair.data.Corpus`) used to train the model.
"""
super().__init__()
self.model: flair.nn.Model = model
@@ -128,7 +151,7 @@ def train(
base_path,
anneal_factor: float = 0.5,
patience: int = 3,
- min_learning_rate: Union[float, List[float]] = 0.0001,
+ min_learning_rate: Union[float, list[float]] = 0.0001,
initial_extra_patience: int = 0,
anneal_with_restarts: bool = False,
learning_rate: float = 0.1,
@@ -137,17 +160,17 @@ def train(
eval_batch_size: int = 64,
mini_batch_chunk_size: Optional[int] = None,
max_epochs: int = 100,
- optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
+ optimizer: type[torch.optim.Optimizer] = torch.optim.SGD,
train_with_dev: bool = False,
train_with_test: bool = False,
reduce_transformer_vocab: bool = False,
# evaluation and monitoring
- main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
+ main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"),
monitor_test: bool = False,
monitor_train_sample: float = 0.0,
use_final_model_for_eval: bool = False,
gold_label_dictionary_for_eval: Optional[Dictionary] = None,
- exclude_labels: Optional[List[str]] = None,
+ exclude_labels: Optional[list[str]] = None,
# sampling and shuffling
sampler=None,
shuffle: bool = True,
@@ -163,8 +186,10 @@ def train(
create_file_logs: bool = True,
create_loss_file: bool = True,
write_weights: bool = False,
+ # acceleration
+ multi_gpu: bool = False,
# plugins
- plugins: Optional[List[TrainerPlugin]] = None,
+ plugins: Optional[list[TrainerPlugin]] = None,
attach_default_scheduler: bool = True,
**kwargs,
):
@@ -211,17 +236,17 @@ def fine_tune(
eval_batch_size: int = 16,
mini_batch_chunk_size: Optional[int] = None,
max_epochs: int = 10,
- optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW,
+ optimizer: type[torch.optim.Optimizer] = torch.optim.AdamW,
train_with_dev: bool = False,
train_with_test: bool = False,
reduce_transformer_vocab: bool = False,
# evaluation and monitoring
- main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
+ main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"),
monitor_test: bool = False,
monitor_train_sample: float = 0.0,
use_final_model_for_eval: bool = True,
gold_label_dictionary_for_eval: Optional[Dictionary] = None,
- exclude_labels: Optional[List[str]] = None,
+ exclude_labels: Optional[list[str]] = None,
# sampling and shuffling
sampler=None,
shuffle: bool = True,
@@ -237,10 +262,11 @@ def fine_tune(
create_file_logs: bool = True,
create_loss_file: bool = True,
write_weights: bool = False,
- # amp
+ # acceleration
use_amp: bool = False,
+ multi_gpu: bool = False,
# plugins
- plugins: Optional[List[TrainerPlugin]] = None,
+ plugins: Optional[list[TrainerPlugin]] = None,
attach_default_scheduler: bool = True,
**kwargs,
):
@@ -287,8 +313,9 @@ def fine_tune(
create_file_logs=create_file_logs,
create_loss_file=create_loss_file,
write_weights=write_weights,
- # amp
+ # acceleration
use_amp=use_amp,
+ multi_gpu=multi_gpu,
# plugins
plugins=plugins,
**kwargs,
@@ -304,20 +331,20 @@ def train_custom(
eval_batch_size: int = 64,
mini_batch_chunk_size: Optional[int] = None,
max_epochs: int = 100,
- optimizer: Type[torch.optim.Optimizer] = SGD,
+ optimizer: type[torch.optim.Optimizer] = SGD,
train_with_dev: bool = False,
train_with_test: bool = False,
max_grad_norm: Optional[float] = 5.0,
reduce_transformer_vocab: bool = False,
# evaluation and monitoring
- main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
+ main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"),
monitor_test: bool = False,
monitor_train_sample: float = 0.0,
use_final_model_for_eval: bool = False,
gold_label_dictionary_for_eval: Optional[Dictionary] = None,
- exclude_labels: Optional[List[str]] = None,
+ exclude_labels: Optional[list[str]] = None,
# sampling and shuffling
- sampler: Optional[FlairSampler] = None,
+ sampler: Optional[Union[FlairSampler, type[FlairSampler]]] = None,
shuffle: bool = True,
shuffle_first_epoch: bool = True,
# evaluation and monitoring
@@ -331,13 +358,14 @@ def train_custom(
create_file_logs: bool = True,
create_loss_file: bool = True,
write_weights: bool = False,
- # amp
+ # acceleration
use_amp: bool = False,
+ multi_gpu: bool = False,
# plugins
- plugins: Optional[List[TrainerPlugin]] = None,
+ plugins: Optional[list[TrainerPlugin]] = None,
**kwargs,
) -> dict:
- """Trains any class that implements the flair.nn.Model interface.
+ """Trains any class that implements the :class:`flair.nn.Model` interface.
Args:
base_path: Main path to which all output during training is logged and models are saved
@@ -375,6 +403,7 @@ def train_custom(
create_file_logs: If True, logging output is written to a file
create_loss_file: If True, a loss file logging output is created
use_amp: If True, uses the torch automatic mixed precision
+ multi_gpu: If True, distributes training across local GPUs
write_weights: If True, write weights to weights.txt on each batch logging event.
plugins: Any additional plugins you want to pass to the trainer
**kwargs: Additional arguments, for instance for the optimizer
@@ -475,12 +504,24 @@ def train_custom(
# initialize sampler if provided
if sampler is not None:
# init with default values if only class is provided
- if inspect.isclass(sampler):
+ if isinstance(sampler, type):
sampler = sampler()
# set dataset to sample from
sampler.set_dataset(train_data)
shuffle = False
+ # configure special behavior to use multiple GPUs
+ if multi_gpu:
+ if not torch.distributed.is_initialized():
+ raise RuntimeError("multi_gpu=True can only be used inside flair.distributed_utils.launch_distributed()")
+ # Guard against each process initializing corpus differently due to e.g. different random seeds
+ validate_corpus_same_each_process(self.corpus)
+ self.ddp_model = DistributedDataParallel(
+ self.model, device_ids=[flair.device.index], find_unused_parameters=True
+ )
+ log.disabled = not is_main_process() # Only print logs once
+ original_forward = self.model.forward
+
# this field stores the names of all dynamic embeddings in the model (determined after first forward pass)
dynamic_embeddings = None
@@ -508,6 +549,9 @@ def train_custom(
if use_final_model_for_eval
else "model from best epoch (best-model.pt)"
)
+ computation_device_info = aggregate(
+ flair.device, lambda devices: ", ".join([str(device) for device in devices])
+ )
log_line(log)
log.info(f'Model: "{self.model}"')
@@ -534,7 +578,7 @@ def train_custom(
log.info(f' - metric: "{main_evaluation_metric}"')
log_line(log)
log.info("Computation:")
- log.info(f" - compute on device: {flair.device}")
+ log.info(f" - compute on device: {computation_device_info}")
log.info(f" - embedding storage: {embeddings_storage_mode}")
log_line(log)
log.info(f'Model training base path: "{base_path}"')
@@ -560,12 +604,24 @@ def train_custom(
if not shuffle_first_epoch and epoch == 1:
shuffle_data_this_epoch = False
- batch_loader = DataLoader(
- train_data,
- batch_size=mini_batch_size,
- shuffle=shuffle_data_this_epoch,
- sampler=sampler,
- )
+ if multi_gpu:
+ distributed_sampler: DistributedSampler = DistributedSampler(
+ train_data, shuffle=shuffle_data_this_epoch
+ )
+ distributed_sampler.set_epoch(epoch - 1)
+ batch_loader = DataLoader(
+ train_data,
+ batch_size=mini_batch_size,
+ shuffle=False,
+ sampler=distributed_sampler,
+ )
+ else:
+ batch_loader = DataLoader(
+ train_data,
+ batch_size=mini_batch_size,
+ shuffle=shuffle_data_this_epoch,
+ sampler=sampler,
+ )
self.model.train()
@@ -603,7 +659,18 @@ def train_custom(
for batch_step in batch_steps:
# forward pass
with torch.autocast(device_type=flair.device.type, enabled=use_amp):
- loss, datapoint_count = self.model.forward_loss(batch_step)
+ if multi_gpu:
+ # We need to __call__ ddp_model() because this triggers hooks that sync gradients.
+ # But that calls forward rather than forward_loss. So we patch forward to redirect
+ # to forward_loss. Then undo the patch in case forward_loss itself calls forward.
+ def wrapped_forward_loss(*args, **kwargs2):
+ self.model.forward = original_forward
+ return self.model.forward_loss(*args, **kwargs2)
+
+ self.model.forward = wrapped_forward_loss
+ loss, datapoint_count = self.ddp_model(batch_step)
+ else:
+ loss, datapoint_count = self.model.forward_loss(batch_step)
batch_train_samples += datapoint_count
batch_train_loss += loss.item()
@@ -649,8 +716,11 @@ def train_custom(
if epoch_train_samples > 0
else epoch_train_samples / (batch_no + 1)
)
+ intermittent_loss = aggregate(intermittent_loss)
current_time = time.time()
+ samples_per_second = epoch_train_samples / (current_time - epoch_start_time)
+ samples_per_second = aggregate(samples_per_second, np.sum)
lr_info, momentum_info = self._get_current_lr_and_momentum(batch_count)
log.info(
@@ -658,7 +728,7 @@ def train_custom(
f" - iter {batch_no + 1}/{len(batch_loader)}"
f" - loss {intermittent_loss:.8f}"
f" - time (sec): {(current_time - epoch_start_time):.2f}"
- f" - samples/sec: {epoch_train_samples / (current_time - epoch_start_time):.2f}"
+ f" - samples/sec: {samples_per_second:.2f}"
f"{lr_info}{momentum_info}"
)
@@ -667,6 +737,7 @@ def train_custom(
self.dispatch("after_training_batch", **batch_kw)
train_loss = epoch_train_loss / epoch_train_samples
+ train_loss = aggregate(train_loss)
self._record(MetricRecord.scalar(("train", "loss"), train_loss, epoch))
total_train_samples += epoch_train_samples
@@ -682,7 +753,7 @@ def train_custom(
# Determine if this is the best model or if we need to anneal
current_epoch_has_best_model_so_far = False
- validation_scores: tuple
+ validation_scores: tuple = ()
for evaluation_split, evaluation_split_data in evaluation_splits.items():
eval_result = self.model.evaluate(
@@ -722,7 +793,7 @@ def train_custom(
if not determine_best_epoch_using_dev_score:
validation_scores = (train_loss,)
- if epoch_train_loss < best_epoch_score:
+ if train_loss < best_epoch_score:
current_epoch_has_best_model_so_far = True
best_epoch_score = train_loss
@@ -737,14 +808,14 @@ def train_custom(
if save_best_model and current_epoch_has_best_model_so_far:
log.info("saving best model")
- self.model.save(base_path / "best-model.pt", checkpoint=save_optimizer_state)
+ self._save_model(base_path / "best-model.pt", checkpoint=save_optimizer_state)
# - SWAPlugin -> restores SGD weights from SWA
self.dispatch("after_training_loop")
# if we do not use dev data for model selection, save final model
if save_final_model:
- self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+ self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
except KeyboardInterrupt:
log_line(log)
@@ -754,7 +825,7 @@ def train_custom(
if save_final_model:
log.info("Saving model ...")
- self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+ self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
log.info("Done.")
except TrainingInterrupt as exc:
@@ -765,7 +836,7 @@ def train_custom(
if save_final_model:
log.info("Saving model ...")
- self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+ self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
log.info("Done.")
except Exception:
@@ -783,7 +854,7 @@ def train_custom(
if (base_path / "best-model.pt").exists():
log.info("Loading model from best epoch ...")
- self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict())
+ self._load_model(base_path / "best-model.pt")
else:
log.info("Testing using last state of model ...")
@@ -808,7 +879,7 @@ def train_custom(
else:
if (base_path / "best-model.pt").exists():
log.info("Loading model from best epoch ...")
- self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict())
+ self._load_model(base_path / "best-model.pt")
self.return_values["test_score"] = 0
log.info("Test data not provided setting final score to 0")
@@ -905,3 +976,12 @@ def _initialize_model_card(self, **training_parameters):
def _record(self, metric):
self.dispatch("metric_recorded", metric)
+
+ def _load_model(self, model_file: Union[str, Path]) -> None:
+ self.model.load_state_dict(self.model.load(model_file).state_dict())
+
+ def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) -> None:
+ if is_main_process():
+ self.model.save(model_file, checkpoint)
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier() # Prevent any process from loading a model until writing is complete
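
The trainer now accepts a `multi_gpu` flag and, when set, wraps the model in `DistributedDataParallel`, swaps in a `DistributedSampler`, aggregates logged metrics across ranks, and saves only from the main process. A hedged end-to-end sketch; the exact call pattern of `flair.distributed_utils.launch_distributed` is assumed here, and the corpus, model, and output path are placeholders:

```python
from flair.datasets import TREC_6
from flair.distributed_utils import launch_distributed  # referenced by the RuntimeError above
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer


def train() -> None:
    corpus = TREC_6()
    label_dict = corpus.make_label_dictionary(label_type="question_class")
    embeddings = TransformerDocumentEmbeddings("distilbert-base-uncased", fine_tune=True)
    classifier = TextClassifier(embeddings, label_dictionary=label_dict, label_type="question_class")

    trainer = ModelTrainer(classifier, corpus)
    # multi_gpu=True is only valid inside launch_distributed(); each process then
    # trains on its own shard via DistributedSampler and gradients are synced by DDP.
    trainer.fine_tune("resources/taggers/trec", mini_batch_size=16, multi_gpu=True)


if __name__ == "__main__":
    launch_distributed(train)  # assumed usage: spawns one process per local GPU
```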
diff --git a/flair/training_utils.py b/flair/training_utils.py
index 0b4ef91cbf..9b38ec1ddb 100644
--- a/flair/training_utils.py
+++ b/flair/training_utils.py
@@ -5,7 +5,7 @@
from functools import reduce
from math import inf
from pathlib import Path
-from typing import Dict, List, Literal, Optional, Union
+from typing import Literal, Optional, Union
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mean_absolute_error, mean_squared_error
@@ -25,7 +25,7 @@ def __init__(
main_score: float,
detailed_results: str,
classification_report: Optional[dict] = None,
- scores: Optional[Dict] = None,
+ scores: Optional[dict] = None,
) -> None:
classification_report = classification_report if classification_report is not None else {}
assert scores is not None and "loss" in scores, "No loss provided."
@@ -47,8 +47,8 @@ class MetricRegression:
def __init__(self, name) -> None:
self.name = name
- self.true: List[float] = []
- self.pred: List[float] = []
+ self.true: list[float] = []
+ self.pred: list[float] = []
def mean_squared_error(self):
return mean_squared_error(self.true, self.pred)
@@ -98,7 +98,7 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) ->
if isinstance(directory, str):
directory = Path(directory)
self.weights_file = init_output_file(directory, "weights.txt")
- self.weights_dict: Dict[str, Dict[int, List[float]]] = defaultdict(lambda: defaultdict(list))
+ self.weights_dict: dict[str, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list))
self.number_of_weights = number_of_weights
def extract_weights(self, state_dict, iteration):
@@ -338,7 +338,7 @@ def init_output_file(base_path: Union[str, Path], file_name: str) -> Path:
return file
-def convert_labels_to_one_hot(label_list: List[List[str]], label_dict: Dictionary) -> List[List[int]]:
+def convert_labels_to_one_hot(label_list: list[list[str]], label_dict: Dictionary) -> list[list[int]]:
"""Convert list of labels to a one hot list.
Args:
@@ -365,9 +365,9 @@ def add_file_handler(log, output_file):
def store_embeddings(
- data_points: Union[List[DT], Dataset],
+ data_points: Union[list[DT], Dataset],
storage_mode: EmbeddingStorageMode,
- dynamic_embeddings: Optional[List[str]] = None,
+ dynamic_embeddings: Optional[list[str]] = None,
):
if isinstance(data_points, Dataset):
data_points = list(_iter_dataset(data_points))
@@ -391,7 +391,7 @@ def store_embeddings(
data_point.to("cpu", pin_memory=pin_memory)
-def identify_dynamic_embeddings(data_points: List[DT]) -> Optional[List]:
+def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]:
dynamic_embeddings = []
all_embeddings = []
for data_point in data_points:
diff --git a/flair/visual/ner_html.py b/flair/visual/ner_html.py
index c71e108379..5b691a9e60 100644
--- a/flair/visual/ner_html.py
+++ b/flair/visual/ner_html.py
@@ -1,5 +1,5 @@
import html
-from typing import List, Union
+from typing import Union
from flair.data import Sentence
@@ -41,7 +41,7 @@ def split_to_spans(s: Sentence, label_name="ner"):
def render_ner_html(
- sentences: Union[List[Sentence], Sentence],
+ sentences: Union[list[Sentence], Sentence],
title: str = "Flair",
colors={
"PER": "#F7FF53",
diff --git a/flair/visual/training_curves.py b/flair/visual/training_curves.py
index 1fd856b669..32947c3348 100644
--- a/flair/visual/training_curves.py
+++ b/flair/visual/training_curves.py
@@ -3,7 +3,7 @@
import math
from collections import defaultdict
from pathlib import Path
-from typing import Dict, List, Union
+from typing import Union
import matplotlib.pyplot as plt
import numpy as np
@@ -27,7 +27,7 @@ class Plotter:
def _extract_evaluation_data(file_name: Union[str, Path], score: str = "F1") -> dict:
file_name = Path(file_name)
- training_curves: Dict[str, Dict[str, List[float]]] = {
+ training_curves: dict[str, dict[str, list[float]]] = {
"train": {"loss": [], "score": []},
"test": {"loss": [], "score": []},
"dev": {"loss": [], "score": []},
@@ -70,7 +70,7 @@ def _extract_weight_data(file_name: Union[str, Path]) -> dict:
if isinstance(file_name, str):
file_name = Path(file_name)
- weights: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
+ weights: dict[str, dict[str, list[float]]] = defaultdict(lambda: defaultdict(list))
with open(file_name) as f:
tsvin = csv.reader(f, delimiter="\t")
@@ -151,7 +151,7 @@ def plot_weights(self, file_name: Union[str, Path]):
log.info(f"Weights plots are saved in {path}") # to let user know the path of the save plots
plt.close(fig)
- def plot_training_curves(self, file_name: Union[str, Path], plot_values: List[str] = ["loss", "F1"]):
+ def plot_training_curves(self, file_name: Union[str, Path], plot_values: list[str] = ["loss", "F1"]):
file_name = Path(file_name)
fig = plt.figure(figsize=(15, 10))
diff --git a/flair/visual/tree_printer.py b/flair/visual/tree_printer.py
index fc461d9f81..9753a37a09 100644
--- a/flair/visual/tree_printer.py
+++ b/flair/visual/tree_printer.py
@@ -1,5 +1,3 @@
-from typing import List
-
from pptree import print_tree
from flair.data import Sentence, Token
@@ -9,7 +7,7 @@ class NodeToken:
def __init__(self, token: Token, tag_type: str) -> None:
self.token: Token = token
self.tag_type: str = tag_type
- self.children: List[NodeToken] = []
+ self.children: list[NodeToken] = []
def set_haed(self, parent):
parent.children.append(self)
@@ -19,7 +17,7 @@ def __str__(self) -> str:
def tree_printer(sentence: Sentence, tag_type: str):
- tree: List[NodeToken] = [NodeToken(token, tag_type) for token in sentence]
+ tree: list[NodeToken] = [NodeToken(token, tag_type) for token in sentence]
for x in tree:
if x.token.head_id != 0:
head_token = x.token.get_head()
diff --git a/pyproject.toml b/pyproject.toml
index 78d1692a09..9711794abb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.black]
line-length = 120
-target-version = ['py37']
+target-version = ['py39']
exclude = '''
(
/(
@@ -32,6 +32,7 @@ filterwarnings = [
'ignore:`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.', # transformers calls deprecated hf_hub
"ignore:`torch.cuda.amp.GradScaler", # GradScaler changes in torch 2.3.0 but we want to be backwards compatible.
"ignore:`clean_up_tokenization_spaces` was not set", # Default behavior changes in transformers v4.45, raising irrelevant FutureWarning for serialized models.
+ "ignore:1Torch was not compiled with flash attention", # You might want to install flash attention, but you don't have to.
]
markers = [
"integration",
@@ -49,7 +50,7 @@ ignore_errors = true
[tool.ruff]
line-length = 120
-target-version = "py38"
+target-version = "py39"
[tool.ruff.lint]
#select = ["ALL"] # Uncommit to autofix all the things
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 61d45acf8c..3b8fbde79c 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -6,7 +6,7 @@ pytest-black-ng==0.4.*
pytest-github-actions-annotate-failures>=0.1.8
pytest-mypy>=0.10.3
pytest-ruff==0.3.*
-ruff==0.3.*
+ruff==0.7.*
types-dataclasses>=0.6.6
types-Deprecated>=1.2.9.2
types-requests>=2.28.11.17
diff --git a/requirements.txt b/requirements.txt
index bb5ecafd45..2704114ace 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,6 +20,6 @@ tabulate>=0.8.10
torch>=1.5.0,!=1.8
tqdm>=4.63.0
transformer-smaller-training-vocab>=0.2.3
-transformers[sentencepiece]>=4.18.0,<5.0.0
+transformers[sentencepiece]>=4.25.0,<5.0.0
wikipedia-api>=0.5.7
bioc<3.0.0,>=2.0.0
diff --git a/resources/docs/EXPERIMENTS.md b/resources/docs/EXPERIMENTS.md
index 69f1a5fbf2..c6bbe72a1c 100644
--- a/resources/docs/EXPERIMENTS.md
+++ b/resources/docs/EXPERIMENTS.md
@@ -55,7 +55,7 @@ tag_type = 'ner'
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
# initialize embeddings
-embedding_types: List[TokenEmbeddings] = [
+embedding_types: list[TokenEmbeddings] = [
# GloVe embeddings
WordEmbeddings('glove'),
@@ -124,7 +124,7 @@ tag_type = 'ner'
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
# initialize embeddings
-embedding_types: List[TokenEmbeddings] = [
+embedding_types: list[TokenEmbeddings] = [
WordEmbeddings('de'),
PooledFlairEmbeddings('german-forward'),
PooledFlairEmbeddings('german-backward'),
@@ -225,7 +225,7 @@ tag_type = 'ner'
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
# initialize embeddings
-embedding_types: List[TokenEmbeddings] = [
+embedding_types: list[TokenEmbeddings] = [
WordEmbeddings('crawl'),
WordEmbeddings('twitter'),
FlairEmbeddings('news-forward'),
@@ -292,7 +292,7 @@ tag_type = 'ner'
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
# initialize embeddings
-embedding_types: List[TokenEmbeddings] = [
+embedding_types: list[TokenEmbeddings] = [
WordEmbeddings('crawl'),
FlairEmbeddings('news-forward'),
FlairEmbeddings('news-backward'),
@@ -361,7 +361,7 @@ tag_type = 'pos'
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
# initialize embeddings
-embedding_types: List[TokenEmbeddings] = [
+embedding_types: list[TokenEmbeddings] = [
WordEmbeddings('extvec'),
FlairEmbeddings('news-forward'),
FlairEmbeddings('news-backward'),
@@ -416,7 +416,7 @@ tag_type = 'np'
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
# initialize embeddings
-embedding_types: List[TokenEmbeddings] = [
+embedding_types: list[TokenEmbeddings] = [
WordEmbeddings('extvec'),
FlairEmbeddings('news-forward'),
FlairEmbeddings('news-backward'),
diff --git a/resources/docs/HUNFLAIR2.md b/resources/docs/HUNFLAIR2.md
index 6f2c1474b3..032b4fe075 100644
--- a/resources/docs/HUNFLAIR2.md
+++ b/resources/docs/HUNFLAIR2.md
@@ -14,7 +14,7 @@ NER tools on unseen corpora.
## Quick Start
#### Requirements and Installation
-*HunFlair2* is based on Flair 0.14+ and Python 3.8+. If you do not have Python 3.8, install it first.
+*HunFlair2* is based on Flair 0.14+ and Python 3.9+. If you do not have Python 3.9, install it first.
Then, in your favorite virtual environment, simply do:
```
pip install flair
diff --git a/resources/docs/KOR_docs/TUTORIAL_7_TRAINING_A_MODEL.md b/resources/docs/KOR_docs/TUTORIAL_7_TRAINING_A_MODEL.md
index cd839acc15..382600d5be 100644
--- a/resources/docs/KOR_docs/TUTORIAL_7_TRAINING_A_MODEL.md
+++ b/resources/docs/KOR_docs/TUTORIAL_7_TRAINING_A_MODEL.md
@@ -313,7 +313,7 @@ label_type = 'ner'
# 3. 말뭉치에서 레이블 사전 만들기
label_dict = corpus.make_label_dictionary(label_type=label_type, add_unk=False)
# 4. 임베딩 초기화하기
-embedding_types: List[TokenEmbeddings] = [
+embedding_types: list[TokenEmbeddings] = [
WordEmbeddings('glove')
]
embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)
diff --git a/resources/docs/KOR_docs/TUTORIAL_8_MODEL_OPTIMIZATION.md b/resources/docs/KOR_docs/TUTORIAL_8_MODEL_OPTIMIZATION.md
index e821dc2d17..ec3affe947 100644
--- a/resources/docs/KOR_docs/TUTORIAL_8_MODEL_OPTIMIZATION.md
+++ b/resources/docs/KOR_docs/TUTORIAL_8_MODEL_OPTIMIZATION.md
@@ -95,7 +95,7 @@ tag_type = 'ner'
tag_dictionary = corpus.make_label_dictionary(label_type=tag_type, add_unk=False)
print(tag_dictionary.idx2item)
# 4. 임베딩 초기화하기
-embedding_types: List[TokenEmbeddings] = [
+embedding_types: list[TokenEmbeddings] = [
WordEmbeddings('glove'),
]
embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)
diff --git a/resources/docs/TUTORIAL_12_CLUSTERING.md b/resources/docs/TUTORIAL_12_CLUSTERING.md
deleted file mode 100644
index 376e5d5639..0000000000
--- a/resources/docs/TUTORIAL_12_CLUSTERING.md
+++ /dev/null
@@ -1,180 +0,0 @@
-Text Clustering in flair
-----------
-
-In this package text clustering is implemented. This module has the following
-clustering algorithms implemented:
-- k-Means
-- BIRCH
-- Expectation Maximization
-
-Each of the implemented algorithm needs to have an instanced DocumentEmbedding. This embedding will
-transform each text/document to a vector. With these vectors the clustering algorithm can be performed.
-
----------------------------
-
-k-Means
-------
-k-Means is a classical and well known clustering algorithm. k-Means is a partitioning-based Clustering algorithm.
-The user defines with the parameter *k* how many clusters the given data has.
-So the choice of *k* is very important.
-More about k-Means can be read on the official [sklearn](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html).
-
-
-```python
-from flair.models import ClusteringModel
-from flair.datasets import TREC_6
-from flair.embeddings import SentenceTransformerDocumentEmbeddings
-from sklearn.cluster import KMeans
-
-embeddings = SentenceTransformerDocumentEmbeddings()
-
-# store all embeddings in memory which is required to perform clustering
-corpus = TREC_6(memory_mode='full').downsample(0.05)
-
-model = KMeans(n_clusters=6)
-
-clustering_model = ClusteringModel(
- model=model,
- embeddings=embeddings
-)
-
-# fit the model on a corpus
-clustering_model.fit(corpus)
-
-# evaluate the model on a corpus with the given label
-clustering_model.evaluate(corpus, label_type="question_class")
-```
-
-BIRCH
----------
-BIRCH (Balanced Iterative Reducing and Clustering using Hierarchies) is a hierarchical clustering algorithm.
-BIRCH is specialized to handle large amounts of data. BIRCH scans the data a single time and builds an internal data
-structure. This data structure contains the data but in a compressed way.
-More about BIRCH can be read on the official [sklearn](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html).
-
-```python
-from sklearn.cluster import Birch
-from flair.datasets import TREC_6
-from flair.embeddings import SentenceTransformerDocumentEmbeddings
-from flair.models import ClusteringModel
-
-embeddings = SentenceTransformerDocumentEmbeddings()
-
-# store all embeddings in memory which is required to perform clustering
-corpus = TREC_6(memory_mode='full').downsample(0.05)
-
-model = Birch(n_clusters=6)
-
-clustering_model = ClusteringModel(
- model=model,
- embeddings=embeddings
-)
-
-# fit the model on a corpus
-clustering_model.fit(corpus)
-
-# evaluate the model on a corpus with the given label
-clustering_model.evaluate(corpus, label_type="question_class")
-```
-
-
-Expectation Maximization
---------------------------
-Expectation Maximization (EM) is a different class of clustering algorithms called soft clustering algorithms.
-Here each point isn't directly assigned to a cluster by a hard decision.
-Each data point has a probability to which cluster the data point belongs. The Expectation Maximization (EM)
-algorithm is a soft clustering algorithm.
-More about EM can be read on the official [sklearn](https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html).
-
-
-```python
-from sklearn.mixture import GaussianMixture
-from flair.datasets import TREC_6
-from flair.embeddings import SentenceTransformerDocumentEmbeddings
-from flair.models import ClusteringModel
-
-embeddings = SentenceTransformerDocumentEmbeddings()
-
-# store all embeddings in memory which is required to perform clustering
-corpus = TREC_6(memory_mode='full').downsample(0.05)
-
-model = GaussianMixture(n_components=6)
-
-clustering_model = ClusteringModel(
- model=model,
- embeddings=embeddings
-)
-
-# fit the model on a corpus
-clustering_model.fit(corpus)
-
-# evaluate the model on a corpus with the given label
-clustering_model.evaluate(corpus, label_type="question_class")
-```
-
----------------------------
-
-Loading/Saving the model
------------
-
-The model can be saved and loaded. The code below shows how to save a model.
-```python
-from flair.models import ClusteringModel
-from flair.datasets import TREC_6
-from flair.embeddings import SentenceTransformerDocumentEmbeddings
-from sklearn.cluster import KMeans
-
-embeddings = SentenceTransformerDocumentEmbeddings()
-
-# store all embeddings in memory which is required to perform clustering
-corpus = TREC_6(memory_mode='full').downsample(0.05)
-
-model = KMeans(n_clusters=6)
-
-clustering_model = ClusteringModel(
- model=model,
- embeddings=embeddings
-)
-
-# fit the model on a corpus
-clustering_model.fit(corpus)
-
-# save the model
-clustering_model.save(model_file="clustering_model.pt")
-```
-
-The code for loading a model.
-
-````python
-# load saved clustering model
-model = ClusteringModel.load(model_file="clustering_model.pt")
-
-# load a corpus
-corpus = TREC_6(memory_mode='full').downsample(0.05)
-
-# predict the corpus
-model.predict(corpus)
-````
-
----------------------
-
-Evaluation
----------
-The result of the clustering can be evaluated. For this we will use the
-[NMI](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html).
-(Normalized Mutual Info) score.
-
-````python
-# need to fit() the model first
-# evaluate the model on a corpus with the given label
-clustering_model.evaluate(corpus, label_type="question_class")
-````
-
-The result of the evaluation can be seen below with the SentenceTransformerDocumentEmbeddings:
-
-
-| Clustering Algorithm | Dataset | NMI |
-|--------------------------|:-------------:|--------:|
-| k Means | StackOverflow | ~0.2122 |
-| BIRCH | StackOverflow | ~0,2424 |
-| Expectation Maximization | 20News group | ~0,2222 |
diff --git a/setup.py b/setup.py
index 3c1cc06018..0573896c19 100644
--- a/setup.py
+++ b/setup.py
@@ -20,5 +20,5 @@
"word-embeddings": ["gensim>=4.2.0", "bpemb>=0.3.5"],
},
include_package_data=True,
- python_requires=">=3.8",
+ python_requires=">=3.9",
)
diff --git a/tests/embedding_test_utils.py b/tests/embedding_test_utils.py
index 554ef32777..c1a0b1a791 100644
--- a/tests/embedding_test_utils.py
+++ b/tests/embedding_test_utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Optional
import pytest
import torch
@@ -9,15 +9,15 @@
class BaseEmbeddingsTest:
- embedding_cls: Type[Embeddings[Sentence]]
+ embedding_cls: type[Embeddings[Sentence]]
is_token_embedding: bool
is_document_embedding: bool
- default_args: Dict[str, Any]
- valid_args: List[Dict[str, Any]] = []
- invalid_args: List[Dict[str, Any]] = []
- invalid_names: List[str] = []
+ default_args: dict[str, Any]
+ valid_args: list[dict[str, Any]] = []
+ invalid_args: list[dict[str, Any]] = []
+ invalid_names: list[str] = []
name_field: Optional[str] = None
- weired_texts: List[str] = [
+ weired_texts: list[str] = [
"Hybrid mesons , qq ̄ states with an admixture",
"typical proportionalities of \u223C 1nmV \u2212 1 [ 3,4 ] .",
"🤟 🤟 🤟 hüllo",
@@ -33,7 +33,7 @@ def create_embedding_from_name(self, name: str):
kwargs.pop(self.name_field)
return self.embedding_cls(name, **kwargs) # type: ignore[call-arg]
- def create_embedding_with_args(self, args: Dict[str, Any]):
+ def create_embedding_with_args(self, args: dict[str, Any]):
kwargs = dict(self.default_args)
for k, v in args.items():
kwargs[k] = v
diff --git a/tests/embeddings/test_document_transform_word_embeddings.py b/tests/embeddings/test_document_transform_word_embeddings.py
index 6a06372723..73567ffbeb 100644
--- a/tests/embeddings/test_document_transform_word_embeddings.py
+++ b/tests/embeddings/test_document_transform_word_embeddings.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any
from flair.embeddings import (
DocumentCNNEmbeddings,
@@ -19,7 +19,7 @@
class BaseDocumentsViaWordEmbeddingsTest(BaseEmbeddingsTest):
is_document_embedding = True
is_token_embedding = False
- base_embeddings: List[TokenEmbeddings] = [word, flair_embedding]
+ base_embeddings: list[TokenEmbeddings] = [word, flair_embedding]
def create_embedding_from_name(self, name: str):
"""Overwrite this method if it is more complex to load an embedding by name."""
@@ -28,7 +28,7 @@ def create_embedding_from_name(self, name: str):
kwargs.pop(self.name_field)
return self.embedding_cls(name, **kwargs) # type: ignore[call-arg]
- def create_embedding_with_args(self, args: Dict[str, Any]):
+ def create_embedding_with_args(self, args: dict[str, Any]):
kwargs = dict(self.default_args)
for k, v in args.items():
kwargs[k] = v
@@ -63,4 +63,4 @@ class TestDocumentCNNEmbeddings(BaseDocumentsViaWordEmbeddingsTest):
class TestDocumentLMEmbeddings(BaseDocumentsViaWordEmbeddingsTest):
embedding_cls = DocumentLMEmbeddings
base_embeddings = [flair_embedding, flair_embedding_back]
- default_args: Dict[str, Any] = {}
+ default_args: dict[str, Any] = {}
diff --git a/tests/embeddings/test_transformer_document_embeddings.py b/tests/embeddings/test_transformer_document_embeddings.py
index 1a65a96fdb..f0f6389b7d 100644
--- a/tests/embeddings/test_transformer_document_embeddings.py
+++ b/tests/embeddings/test_transformer_document_embeddings.py
@@ -1,4 +1,6 @@
-from flair.data import Dictionary
+import pytest
+
+from flair.data import Dictionary, Sentence
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.nn import Classifier
@@ -37,3 +39,16 @@ def test_if_loaded_embeddings_have_all_attributes(tasks_base_path):
# check that context_length and use_context_separator is the same for both
assert model.embeddings.context_length == loaded_single_task.embeddings.context_length
assert model.embeddings.use_context_separator == loaded_single_task.embeddings.use_context_separator
+
+
+@pytest.mark.parametrize("cls_pooling", ["cls", "mean", "max"])
+def test_cls_pooling(cls_pooling):
+ embeddings = TransformerDocumentEmbeddings(
+ model="distilbert-base-uncased",
+ layers="-1",
+ cls_pooling=cls_pooling,
+ allow_long_sentences=True,
+ )
+ sentence = Sentence("Today is a good day.")
+ embeddings.embed(sentence)
+ assert sentence.embedding is not None
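The new parametrized test exercises the `cls_pooling` option of `TransformerDocumentEmbeddings`. A usage sketch outside the test harness, mirroring the test's arguments (the printed shape assumes distilbert's 768-dimensional hidden size):

```python
from flair.data import Sentence
from flair.embeddings import TransformerDocumentEmbeddings

# pool over all subtoken embeddings ("mean") instead of taking the [CLS] vector ("cls")
embeddings = TransformerDocumentEmbeddings(
    model="distilbert-base-uncased",
    layers="-1",
    cls_pooling="mean",
    allow_long_sentences=True,
)

sentence = Sentence("Today is a good day.")
embeddings.embed(sentence)
print(sentence.embedding.shape)  # expected: torch.Size([768]) for distilbert
```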
diff --git a/tests/embeddings/test_word_embeddings.py b/tests/embeddings/test_word_embeddings.py
index 34d0b3b9f7..87f56fec4f 100644
--- a/tests/embeddings/test_word_embeddings.py
+++ b/tests/embeddings/test_word_embeddings.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any
from flair.embeddings import MuseCrosslingualEmbeddings, NILCEmbeddings, WordEmbeddings
from tests.embedding_test_utils import BaseEmbeddingsTest
@@ -18,7 +18,7 @@ class TestMuseCrosslingualEmbeddings(BaseEmbeddingsTest):
embedding_cls = MuseCrosslingualEmbeddings
is_token_embedding = True
is_document_embedding = False
- default_args: Dict[str, Any] = {}
+ default_args: dict[str, Any] = {}
class TestNILCEmbeddings(BaseEmbeddingsTest):
diff --git a/tests/model_test_utils.py b/tests/model_test_utils.py
index 10aab0831f..b5afd81bfe 100644
--- a/tests/model_test_utils.py
+++ b/tests/model_test_utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Optional
import pytest
@@ -11,13 +11,13 @@
class BaseModelTest:
- model_cls: Type[Model]
+ model_cls: type[Model]
pretrained_model: Optional[str] = None
empty_sentence = Sentence(" ")
train_label_type: str
- multiclass_prediction_labels: List[str]
- model_args: Dict[str, Any] = {}
- training_args: Dict[str, Any] = {}
+ multiclass_prediction_labels: list[str]
+ model_args: dict[str, Any] = {}
+ training_args: dict[str, Any] = {}
finetune_instead_of_train: bool = False
@pytest.fixture()
diff --git a/tests/models/test_relation_classifier.py b/tests/models/test_relation_classifier.py
index c0ca34bce5..da4de52bfc 100644
--- a/tests/models/test_relation_classifier.py
+++ b/tests/models/test_relation_classifier.py
@@ -1,5 +1,5 @@
from operator import itemgetter
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Optional
import pytest
from torch.utils.data import Dataset
@@ -20,7 +20,7 @@
)
from tests.model_test_utils import BaseModelTest
-encoding_strategies: Dict[EncodingStrategy, List[Tuple[str, str]]] = {
+encoding_strategies: dict[EncodingStrategy, list[tuple[str, str]]] = {
EntityMask(): [("[HEAD]", "[TAIL]") for _ in range(7)],
TypedEntityMask(): [
("[HEAD-ORG]", "[TAIL-PER]"),
@@ -140,7 +140,7 @@ def train_test_sentence(self):
return sentence
def assert_training_example(self, predicted_training_example):
- relations: List[Relation] = predicted_training_example.get_relations("relation")
+ relations: list[Relation] = predicted_training_example.get_relations("relation")
assert len(relations) == 2
# Intel ----founded_by---> Gordon Moore
@@ -164,7 +164,7 @@ def assert_training_example(self, predicted_training_example):
@staticmethod
def check_transformation_correctness(
split: Optional[Dataset],
- ground_truth: Set[Tuple[str, Tuple[str, ...]]],
+ ground_truth: set[tuple[str, tuple[str, ...]]],
) -> None:
# Ground truth is a set of tuples of (<masked sentence>, <relation labels>)
assert split is not None
@@ -190,7 +190,7 @@ def test_transform_corpus(
embeddings: TransformerDocumentEmbeddings,
cross_augmentation: bool,
encoding_strategy: EncodingStrategy,
- encoded_entity_pairs: List[Tuple[str, str]],
+ encoded_entity_pairs: list[tuple[str, str]],
) -> None:
label_dictionary = corpus.make_label_dictionary("relation")
model: RelationClassifier = self.build_model(
@@ -200,7 +200,7 @@ def test_transform_corpus(
# Check sentence masking and relation label annotation on
# training, validation and test dataset (in this test the splits are the same)
- ground_truth: Set[Tuple[str, Tuple[str, ...]]] = {
+ ground_truth: set[tuple[str, tuple[str, ...]]] = {
# Entity pair permutations of: "Larry Page and Sergey Brin founded Google ."
(f"{encoded_entity_pairs[0][1]} and Sergey Brin founded {encoded_entity_pairs[0][0]} .", ("founded_by",)),
(f"Larry Page and {encoded_entity_pairs[1][1]} founded {encoded_entity_pairs[1][0]} .", ("founded_by",)),
diff --git a/tests/models/test_sequence_tagger.py b/tests/models/test_sequence_tagger.py
index 67bf976899..13cc15ddc1 100644
--- a/tests/models/test_sequence_tagger.py
+++ b/tests/models/test_sequence_tagger.py
@@ -1,4 +1,6 @@
import pytest
+import torch
+import torch.nn.functional as F
import flair
from flair.embeddings import FlairEmbeddings, WordEmbeddings
@@ -121,3 +123,36 @@ def test_train_load_use_tagger_disjunct_tags(
loaded_model.predict([example_sentence, self.empty_sentence])
loaded_model.predict([self.empty_sentence])
del loaded_model
+
+ @pytest.mark.integration()
+ def test_all_token_prob_distribution(self, embeddings, corpus):
+ tag_dictionary = corpus.make_label_dictionary("ner", add_unk=False)
+ model = self.build_model(embeddings, tag_dictionary)
+
+ # get features from forward propagation
+ sentences = [corpus.train[i] for i in range(len(corpus.train))]
+
+ # reverse sort all sequences by their length
+ sentences = sorted(sentences, key=len, reverse=True)
+
+ with torch.no_grad():
+ sentence_tensor, lengths = model._prepare_tensors(sentences)
+ features = model.forward(sentence_tensor, lengths)
+
+ # remove previously predicted labels of this type
+ for sentence in sentences:
+ sentence.remove_labels(model.label_type)
+
+ softmax_batch = F.softmax(features, dim=1).cpu()
+ lengths = [len(sentence) for sentence in sentences]
+ all_tokens_prob_distrib = model._all_scores_for_token(sentences, softmax_batch, lengths)
+
+ for i, sen_tokens_prob_distribution in enumerate(all_tokens_prob_distrib):
+ assert len(sen_tokens_prob_distribution) == lengths[i]
+ for token_prob_distrib, token in zip(sen_tokens_prob_distribution, sentences[i]):
+ assert len(token_prob_distrib) == len(model.label_dictionary)
+ score_sum = 0.0
+ for token_label in token_prob_distrib:
+ assert token_label.data_point == token
+ score_sum += token_label.score
+ assert abs(score_sum - 1.0) < 1.0e-5
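The new integration test calls the internal `_all_scores_for_token` helper directly; in user code the same per-token score distributions are usually requested through `predict`. A hedged sketch, assuming a pretrained tagger such as `flair/ner-english-fast` is available and that `return_probabilities_for_all_classes` attaches the full distribution to each token as in recent `SequenceTagger` versions:

```python
from flair.data import Sentence
from flair.models import SequenceTagger

tagger = SequenceTagger.load("flair/ner-english-fast")

sentence = Sentence("George Washington went to Washington .")
# request the full score distribution over all tags for every token
tagger.predict(sentence, return_probabilities_for_all_classes=True)

for token in sentence:
    scores = [(label.value, round(label.score, 3)) for label in token.get_labels(tagger.label_type)]
    print(token.text, scores)  # the scores per token should sum to roughly 1.0
```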
diff --git a/tests/test_corpus_dictionary.py b/tests/test_corpus_dictionary.py
index cc65019b67..d107c0b62e 100644
--- a/tests/test_corpus_dictionary.py
+++ b/tests/test_corpus_dictionary.py
@@ -110,9 +110,9 @@ def test_tagged_corpus_make_vocab_dictionary():
vocab = corpus.make_vocab_dictionary(max_tokens=2, min_freq=-1)
assert len(vocab) == 3
-    assert "<unk>" in vocab.get_items()
-    assert "training" in vocab.get_items()
-    assert "." in vocab.get_items()
+    assert vocab.has_item("<unk>")
+    assert vocab.has_item("training")
+    assert vocab.has_item(".")
vocab = corpus.make_vocab_dictionary(max_tokens=-1, min_freq=-1)
@@ -121,9 +121,9 @@ def test_tagged_corpus_make_vocab_dictionary():
vocab = corpus.make_vocab_dictionary(max_tokens=-1, min_freq=2)
assert len(vocab) == 3
-    assert "<unk>" in vocab.get_items()
-    assert "training" in vocab.get_items()
-    assert "." in vocab.get_items()
+    assert vocab.has_item("<unk>")
+    assert vocab.has_item("training")
+    assert vocab.has_item(".")
def test_label_set_confidence():
@@ -153,9 +153,9 @@ def test_tagged_corpus_make_label_dictionary():
label_dict = corpus.make_label_dictionary("label", add_unk=True)
assert len(label_dict) == 3
-    assert "<unk>" in label_dict.get_items()
-    assert "class_1" in label_dict.get_items()
-    assert "class_2" in label_dict.get_items()
+    assert label_dict.has_item("<unk>")
+    assert label_dict.has_item("class_1")
+    assert label_dict.has_item("class_2")
with pytest.warns(DeprecationWarning): # test to make sure the warning comes, but function works
corpus.make_tag_dictionary("label")
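The assertions now go through the `has_item` accessor instead of scanning `get_items()`. A minimal sketch of the same pattern on a hand-built dictionary, assuming `has_item` behaves as the updated assertions imply:

```python
from flair.data import Dictionary

label_dict = Dictionary(add_unk=True)  # adds the "<unk>" item up front
label_dict.add_item("class_1")
label_dict.add_item("class_2")

assert label_dict.has_item("<unk>")
assert label_dict.has_item("class_1")
assert not label_dict.has_item("class_3")
print(len(label_dict))  # 3
```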
diff --git a/tests/test_datasets_biomedical.py b/tests/test_datasets_biomedical.py
index 0264b08394..c15674eb6b 100644
--- a/tests/test_datasets_biomedical.py
+++ b/tests/test_datasets_biomedical.py
@@ -2,7 +2,7 @@
import os
import tempfile
from pathlib import Path
-from typing import List, Optional
+from typing import Optional
from flair.datasets.biomedical import (
CoNLLWriter,
@@ -84,7 +84,7 @@ def test_conll_writer_one_token_multiple_entities2():
def assert_conll_writer_output(
dataset: InternalBioNerDataset,
- expected_output: List[str],
+ expected_output: list[str],
sentence_splitter: Optional[SentenceSplitter] = None,
):
fd, outfile_path = tempfile.mkstemp()
diff --git a/tests/test_labels.py b/tests/test_labels.py
index 210a215889..099484162c 100644
--- a/tests/test_labels.py
+++ b/tests/test_labels.py
@@ -1,5 +1,3 @@
-from typing import List
-
from flair.data import Label, Relation, Sentence, Span
@@ -14,7 +12,7 @@ def test_token_tags():
sentence[0].add_label("pos", "pronoun")
# check if there are three POS labels with correct text and values
- labels: List[Label] = sentence.get_labels("pos")
+ labels: list[Label] = sentence.get_labels("pos")
assert len(labels) == 3
assert labels[0].data_point.text == "I"
assert labels[0].value == "pronoun"
@@ -24,7 +22,7 @@ def test_token_tags():
assert labels[2].value == "proper noun"
# check if there are is one SENTIMENT label with correct text and values
- labels: List[Label] = sentence.get_labels("sentiment")
+ labels: list[Label] = sentence.get_labels("sentiment")
assert len(labels) == 1
assert labels[0].data_point.text == "love"
assert labels[0].value == "positive"
@@ -45,7 +43,7 @@ def test_token_tags():
# remove the pos label from the last word
sentence[2].remove_labels("pos")
# there should be 2 POS labels left
- labels: List[Label] = sentence.get_labels("pos")
+ labels: list[Label] = sentence.get_labels("pos")
assert len(labels) == 2
assert len(sentence[0].get_labels("pos")) == 1
assert len(sentence[1].get_labels("pos")) == 1
@@ -72,7 +70,7 @@ def test_span_tags():
sentence[7:8].add_label("ner", "City")
# check if there are three labels with correct text and values
- labels: List[Label] = sentence.get_labels("ner")
+ labels: list[Label] = sentence.get_labels("ner")
assert len(labels) == 3
assert labels[0].data_point.text == "Humboldt Universität zu Berlin"
assert labels[0].value == "Organization"
@@ -82,7 +80,7 @@ def test_span_tags():
assert labels[2].value == "City"
# check if there are two spans with correct text and values
- spans: List[Span] = sentence.get_spans("ner")
+ spans: list[Span] = sentence.get_spans("ner")
assert len(spans) == 2
assert spans[0].text == "Humboldt Universität zu Berlin"
assert len(spans[0].get_labels("ner")) == 2
@@ -92,12 +90,12 @@ def test_span_tags():
# now delete the NER tags of "Humboldt-Universität zu Berlin"
sentence[0:4].remove_labels("ner")
# should be only one NER label left
- labels: List[Label] = sentence.get_labels("ner")
+ labels: list[Label] = sentence.get_labels("ner")
assert len(labels) == 1
assert labels[0].data_point.text == "Berlin"
assert labels[0].value == "City"
# and only one NER span
- spans: List[Span] = sentence.get_spans("ner")
+ spans: list[Span] = sentence.get_spans("ner")
assert len(spans) == 1
assert spans[0].text == "Berlin"
assert spans[0].get_label("ner").value == "City"
@@ -111,7 +109,7 @@ def test_different_span_tags():
sentence[7:8].add_label("ner", "City")
# check if there are three labels with correct text and values
- labels: List[Label] = sentence.get_labels("ner")
+ labels: list[Label] = sentence.get_labels("ner")
assert len(labels) == 2
assert labels[0].data_point.text == "Humboldt Universität zu Berlin"
assert labels[0].value == "Organization"
@@ -119,7 +117,7 @@ def test_different_span_tags():
assert labels[1].value == "City"
# check if there are two spans with correct text and values
- spans: List[Span] = sentence.get_spans("ner")
+ spans: list[Span] = sentence.get_spans("ner")
assert len(spans) == 2
assert spans[0].text == "Humboldt Universität zu Berlin"
assert spans[0].get_label("ner").value == "Organization"
@@ -131,22 +129,22 @@ def test_different_span_tags():
# now delete the NER tags of "Humboldt-Universität zu Berlin"
sentence[0:4].remove_labels("ner")
# should be only one NER label left
- labels: List[Label] = sentence.get_labels("ner")
+ labels: list[Label] = sentence.get_labels("ner")
assert len(labels) == 1
assert labels[0].data_point.text == "Berlin"
assert labels[0].value == "City"
# and only one NER span
- spans: List[Span] = sentence.get_spans("ner")
+ spans: list[Span] = sentence.get_spans("ner")
assert len(spans) == 1
assert spans[0].text == "Berlin"
assert spans[0].get_label("ner").value == "City"
# but there is also one orgtype span and label
- labels: List[Label] = sentence.get_labels("orgtype")
+ labels: list[Label] = sentence.get_labels("orgtype")
assert len(labels) == 1
assert labels[0].data_point.text == "Humboldt Universität zu Berlin"
assert labels[0].value == "University"
# and only one NER span
- spans: List[Span] = sentence.get_spans("orgtype")
+ spans: list[Span] = sentence.get_spans("orgtype")
assert len(spans) == 1
assert spans[0].text == "Humboldt Universität zu Berlin"
assert spans[0].get_label("orgtype").value == "University"
@@ -154,7 +152,7 @@ def test_different_span_tags():
# let's add the NER tag back
sentence[0:4].add_label("ner", "Organization")
# check if there are three labels with correct text and values
- labels: List[Label] = sentence.get_labels("ner")
+ labels: list[Label] = sentence.get_labels("ner")
print(labels)
assert len(labels) == 2
assert labels[0].data_point.text == "Humboldt Universität zu Berlin"
@@ -163,7 +161,7 @@ def test_different_span_tags():
assert labels[1].value == "City"
# check if there are two spans with correct text and values
- spans: List[Span] = sentence.get_spans("ner")
+ spans: list[Span] = sentence.get_spans("ner")
assert len(spans) == 2
assert spans[0].text == "Humboldt Universität zu Berlin"
assert spans[0].get_label("ner").value == "Organization"
@@ -194,17 +192,17 @@ def test_relation_tags():
Relation(sentence[0:2], sentence[3:4]).add_label("syntactic", "apposition")
# there should be two relation labels
- labels: List[Label] = sentence.get_labels("rel")
+ labels: list[Label] = sentence.get_labels("rel")
assert len(labels) == 2
assert labels[0].value == "located in"
assert labels[1].value == "university of"
# there should be one syntactic labels
- labels: List[Label] = sentence.get_labels("syntactic")
+ labels: list[Label] = sentence.get_labels("syntactic")
assert len(labels) == 1
# there should be two relations, one with two and one with one label
- relations: List[Relation] = sentence.get_relations("rel")
+ relations: list[Relation] = sentence.get_relations("rel")
assert len(relations) == 2
assert len(relations[0].labels) == 1
assert len(relations[1].labels) == 2
diff --git a/tests/test_tokenize_sentence.py b/tests/test_tokenize_sentence.py
index fd049b642e..7fd03ac6ba 100644
--- a/tests/test_tokenize_sentence.py
+++ b/tests/test_tokenize_sentence.py
@@ -1,5 +1,3 @@
-from typing import List
-
import pytest
import flair
@@ -492,5 +490,5 @@ def test_line_separator_is_ignored():
assert Sentence(with_separator).to_original_text() == Sentence(without_separator).to_original_text()
-def no_op_tokenizer(text: str) -> List[str]:
+def no_op_tokenizer(text: str) -> list[str]:
return [text]
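`no_op_tokenizer` above is a plain callable from `str` to `list[str]`; such callables can be plugged into `Sentence` by wrapping them in `TokenizerWrapper`. A hedged sketch with an illustrative whitespace tokenizer:

```python
from flair.data import Sentence
from flair.tokenization import TokenizerWrapper

def whitespace_tokenizer(text: str) -> list[str]:
    # trivial callable tokenizer: split on single spaces only
    return text.split(" ")

sentence = Sentence("Berlin is great", use_tokenizer=TokenizerWrapper(whitespace_tokenizer))
print([token.text for token in sentence])  # ['Berlin', 'is', 'great']
```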