Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose progress bar class control #1110

Merged
merged 15 commits into from
May 20, 2024
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Updated `TermSetWrapper` to support validating a single field within a compound array. @mavaylon1 [#1061](https://github.com/hdmf-dev/hdmf/pull/1061)
- Updated testing to not install in editable mode and not run `coverage` by default. @rly [#1107](https://github.com/hdmf-dev/hdmf/pull/1107)
- Add `post_init_method` parameter when generating classes to perform post-init functionality, i.e., validation. @mavaylon1 [#1089](https://github.com/hdmf-dev/hdmf/pull/1089)
- Exposed `progress_bar_class` to the `GenericDataChunkIterator` for more custom control over display of progress while iterating. @codycbakerphd [#1110](https://github.com/hdmf-dev/hdmf/pull/1110)
- Updated loading, unloading, and getting the `TypeConfigurator` to support a `TypeMap` parameter. @mavaylon1 [#1117](https://github.com/hdmf-dev/hdmf/pull/1117)

### Bug Fixes
Expand Down
35 changes: 29 additions & 6 deletions src/hdmf/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import copy
import math
from abc import ABCMeta, abstractmethod
from collections.abc import Iterable
from collections.abc import Iterable, Callable
from warnings import warn
from typing import Tuple, Callable
from typing import Tuple
from itertools import product, chain

import h5py
Expand Down Expand Up @@ -179,9 +179,15 @@
doc="Display a progress bar with iteration rate and estimated completion time.",
default=False,
),
dict(
name="progress_bar_class",
type=Callable,
CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved
doc="The progress bar class to use. Defaults to tqdm.tqdm if the TQDM package is installed.",
default=None,
),
dict(
name="progress_bar_options",
type=None,
type=dict,
doc="Dictionary of keyword arguments to be passed directly to tqdm.",
default=None,
),
Expand All @@ -199,8 +205,23 @@
HDF5 recommends chunk size in the range of 2 to 16 MB for optimal cloud performance.
https://youtu.be/rcS5vt-mKok?t=621
"""
buffer_gb, buffer_shape, chunk_mb, chunk_shape, self.display_progress, progress_bar_options = getargs(
"buffer_gb", "buffer_shape", "chunk_mb", "chunk_shape", "display_progress", "progress_bar_options", kwargs
(
buffer_gb,
buffer_shape,
chunk_mb,
chunk_shape,
self.display_progress,
progress_bar_class,
progress_bar_options,
) = getargs(
"buffer_gb",
"buffer_shape",
"chunk_mb",
"chunk_shape",
"display_progress",
"progress_bar_class",
"progress_bar_options",
kwargs,
)
self.progress_bar_options = progress_bar_options or dict()

Expand Down Expand Up @@ -277,11 +298,13 @@
try:
from tqdm import tqdm

progress_bar_class = progress_bar_class or tqdm

Check warning on line 301 in src/hdmf/data_utils.py

View check run for this annotation

Codecov / codecov/patch

src/hdmf/data_utils.py#L301

Added line #L301 was not covered by tests

if "total" in self.progress_bar_options:
warn("Option 'total' in 'progress_bar_options' is not allowed to be over-written! Ignoring.")
self.progress_bar_options.pop("total")

self.progress_bar = tqdm(total=self.num_buffers, **self.progress_bar_options)
self.progress_bar = progress_bar_class(total=self.num_buffers, **self.progress_bar_options)

Check warning on line 307 in src/hdmf/data_utils.py

View check run for this annotation

Codecov / codecov/patch

src/hdmf/data_utils.py#L307

Added line #L307 was not covered by tests
except ImportError:
warn(
"You must install tqdm to use the progress bar feature (pip install tqdm)! "
Expand Down
29 changes: 28 additions & 1 deletion tests/unit/utils_test/test_core_GenericDataChunkIterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
from tempfile import mkdtemp
from shutil import rmtree
from typing import Tuple, Iterable, Callable
from typing import Tuple, Iterable, Callable, Union
from sys import version_info

import h5py
Expand Down Expand Up @@ -408,6 +408,33 @@ def test_progress_bar(self):
first_line = file.read()
self.assertIn(member=desc, container=first_line)

@unittest.skipIf(not TQDM_INSTALLED, "optional tqdm module is not installed")
def test_progress_bar_class(self):
import tqdm

class MyCustomProgressBar(tqdm.tqdm):
def update(self, n: int = 1) -> Union[bool, None]:
displayed = super().update(n)
print(f"Custom injection on step {n}") # noqa: T201

return displayed

out_text_file = self.test_dir / "test_progress_bar_class.txt"
desc = "Testing progress bar..."
with open(file=out_text_file, mode="w") as file:
iterator = self.TestNumpyArrayDataChunkIterator(
array=self.test_array,
display_progress=True,
progress_bar_class=MyCustomProgressBar,
progress_bar_options=dict(file=file, desc=desc),
)
j = 0
for buffer in iterator:
j += 1 # dummy operation; must be silent for proper updating of bar
with open(file=out_text_file, mode="r") as file:
first_line = file.read()
self.assertIn(member=desc, container=first_line)

@unittest.skipIf(not TQDM_INSTALLED, "optional tqdm module is installed")
def test_progress_bar_no_options(self):
dci = self.TestNumpyArrayDataChunkIterator(array=self.test_array, display_progress=True)
Expand Down
Loading