Skip to content

Commit

Permalink
replacing tqdm_leave argument with more general tqdm_kwargs to allow …
Browse files Browse the repository at this point in the history
…passing a dict of kwargs to the tqdm progress bar in parallel.py (by default the dict is empty and no kwargs are passed)
  • Loading branch information
ejhigson committed Apr 5, 2018
1 parent 28af123 commit 9d2438b
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 20 deletions.
14 changes: 7 additions & 7 deletions fgivenx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def compute_samples(f, x, samples, logZ=None, **kwargs):
parallel:
see docstring for fgivenx.parallel.parallel_apply.
tqdm_leave: bool
tqdm_kwargs: dict
see docstring for fgivenx.parallel.parallel_apply.
Returns
Expand All @@ -88,7 +88,7 @@ def compute_samples(f, x, samples, logZ=None, **kwargs):
parallel = kwargs.pop('parallel', False)
ntrim = kwargs.pop('ntrim', None)
cache = kwargs.pop('cache', None)
tqdm_leave = kwargs.pop('tqdm_leave', True)
tqdm_kwargs = kwargs.pop('tqdm_kwargs', {})
if kwargs:
raise TypeError('Unexpected **kwargs: %r' % kwargs)

Expand All @@ -101,7 +101,7 @@ def compute_samples(f, x, samples, logZ=None, **kwargs):

return fgivenx.samples.compute_samples(f, x, samples,
parallel=parallel, cache=cache,
tqdm_leave=tqdm_leave)
tqdm_kwargs=tqdm_kwargs)


def compute_pmf(f, x, samples, logZ=None, **kwargs):
Expand Down Expand Up @@ -135,7 +135,7 @@ def compute_pmf(f, x, samples, logZ=None, **kwargs):
Keywords
--------
tqdm_leave: bool
tqdm_kwargs: dict
see docstring for fgivenx.parallel.parallel_apply.
Returns
Expand All @@ -153,7 +153,7 @@ def compute_pmf(f, x, samples, logZ=None, **kwargs):
ny = kwargs.pop('ny', 100)
y = kwargs.pop('y', None)
cache = kwargs.pop('cache', None)
tqdm_leave = kwargs.pop('tqdm_leave', True)
tqdm_kwargs = kwargs.pop('tqdm_kwargs', {})
if kwargs:
raise TypeError('Unexpected **kwargs: %r' % kwargs)

Expand All @@ -166,15 +166,15 @@ def compute_pmf(f, x, samples, logZ=None, **kwargs):
fsamps = compute_samples(f, x, samples, logZ=logZ,
weights=weights, ntrim=ntrim,
parallel=parallel, cache=cache,
tqdm_leave=tqdm_leave)
tqdm_kwargs=tqdm_kwargs)

if y is None:
ymin = fsamps[~numpy.isnan(fsamps)].min(axis=None)
ymax = fsamps[~numpy.isnan(fsamps)].max(axis=None)
y = numpy.linspace(ymin, ymax, ny)

return y, fgivenx.mass.compute_pmf(fsamps, y, parallel=parallel,
cache=cache, tqdm_leave=tqdm_leave)
cache=cache, tqdm_kwargs=tqdm_kwargs)


def compute_dkl(f, x, samples, prior_samples, logZ=None, **kwargs):
Expand Down
8 changes: 4 additions & 4 deletions fgivenx/mass.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,18 @@ def compute_pmf(fsamps, y, **kwargs):
Keywords
--------
parallel:
parallel: bool
see docstring for fgivenx.parallel.parallel_apply.
tqdm_leave:
tqdm_kwargs: dict
see docstring for fgivenx.parallel.parallel_apply.
Returns
-------
"""
parallel = kwargs.pop('parallel', False)
cache = kwargs.pop('cache', None)
tqdm_leave = kwargs.pop('tqdm_leave', True)
tqdm_kwargs = kwargs.pop('tqdm_kwargs', {})
if kwargs:
raise TypeError('Unexpected **kwargs: %r' % kwargs)

Expand All @@ -151,7 +151,7 @@ def compute_pmf(fsamps, y, **kwargs):
print(e)

masses = parallel_apply(PMF, fsamps, postcurry=(y,), parallel=parallel,
tqdm_leave=tqdm_leave)
tqdm_kwargs=tqdm_kwargs)
masses = numpy.array(masses).transpose().copy()

if cache is not None:
Expand Down
11 changes: 5 additions & 6 deletions fgivenx/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ def parallel_apply(f, array, **kwargs):
int < 0 or bool=True: use OMP_NUM_THREADS to choose parallelisation
bool=False or int=0: do not parallelise
tqdm_leave: bool
tqdm progress bars' 'leave' setting - set to False to have progress
bars disappear when finished.
tqdm_kwargs: dict
additional kwargs for tqdm progress bars.
precurry: tuple
immutable arguments to pass to f before x,
Expand All @@ -42,7 +41,7 @@ def parallel_apply(f, array, **kwargs):
precurry = tuple(kwargs.pop('precurry', ()))
postcurry = tuple(kwargs.pop('postcurry', ()))
parallel = kwargs.pop('parallel', False)
tqdm_leave = kwargs.pop('tqdm_leave', True)
tqdm_kwargs = kwargs.pop('tqdm_kwargs', {})
if kwargs:
raise TypeError('Unexpected **kwargs: %r' % kwargs)
# If running in a jupyter notebook then use tqdm_notebook. Otherwise use
Expand All @@ -55,7 +54,7 @@ def parallel_apply(f, array, **kwargs):
progress = tqdm.tqdm
if not parallel:
return [f(*(precurry + (x,) + postcurry)) for x in
progress(array, leave=tqdm_leave)]
progress(array, **tqdm_kwargs)]
elif parallel is True:
nprocs = cpu_count()
elif isinstance(parallel, int):
Expand All @@ -67,4 +66,4 @@ def parallel_apply(f, array, **kwargs):
raise ValueError("parallel keyword must be an integer or bool")

return Parallel(n_jobs=nprocs)(delayed(f)(*(precurry + (x,) + postcurry))
for x in progress(array, leave=tqdm_leave))
for x in progress(array, **tqdm_kwargs))
6 changes: 3 additions & 3 deletions fgivenx/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def compute_samples(f, x, samples, **kwargs):
parallel:
see docstring for fgivenx.parallel.parallel_apply.
tqdm_leave: bool
tqdm_kwargs: dict
see docstring for fgivenx.parallel.parallel_apply.
Expand All @@ -72,7 +72,7 @@ def compute_samples(f, x, samples, **kwargs):

parallel = kwargs.pop('parallel', False)
cache = kwargs.pop('cache', None)
tqdm_leave = kwargs.pop('tqdm_leave', True)
tqdm_kwargs = kwargs.pop('tqdm_kwargs', {})
if kwargs:
raise TypeError('Unexpected **kwargs: %r' % kwargs)

Expand All @@ -87,7 +87,7 @@ def compute_samples(f, x, samples, **kwargs):
for fi, s in zip(f, samples):
if len(s) > 0:
fsamps = parallel_apply(fi, s, precurry=(x,), parallel=parallel,
tqdm_leave=tqdm_leave)
tqdm_kwargs=tqdm_kwargs)
fsamps = numpy.array(fsamps).transpose().copy()
fsamples.append(fsamps)
fsamples = numpy.concatenate(fsamples, axis=1)
Expand Down

0 comments on commit 9d2438b

Please sign in to comment.