Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow silencing of predict_tile stdout based on verbose flag #878

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions src/deepforest/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# entry point for deepforest model
import importlib
import os
import contextlib
import io
import typing
import warnings

Expand Down Expand Up @@ -486,7 +488,8 @@ def predict_tile(self,
thickness=1,
crop_model=None,
crop_transform=None,
crop_augment=False):
crop_augment=False,
verbose=True):
"""For images too large to input into the model, predict_tile cuts the
image into overlapping windows, predicts trees on each window and
reassambles into a single array.
Expand All @@ -507,19 +510,27 @@ def predict_tile(self,
return_plot: return a plot of the image with predictions overlaid (deprecated)
color: color of the bounding box as a tuple of BGR color (deprecated)
thickness: thickness of the rectangle border line in px (deprecated)

verbose: whether to show progress bar
Returns:
pd.DataFrame or tuple: Predictions dataframe or (predictions, crops) tuple
"""
self.model.eval()
self.model.nms_thresh = self.config["nms_thresh"]

# if more than one GPU present, use only a the first available gpu
# if more than one GPU present, use only the first available gpu
if torch.cuda.device_count() > 1:
# Get available gpus and regenerate trainer
warnings.warn(
"More than one GPU detected. Using only the first GPU for predict_tile.")
self.config["devices"] = 1

# Configure trainer based on verbose setting
if not verbose:
callbacks = [
cb for cb in self.trainer.callbacks
if not isinstance(cb, pl.callbacks.ProgressBar)
]
self.create_trainer(enable_progress_bar=False, callbacks=callbacks)
else:
self.create_trainer()

if (raster_path is None) and (image is None):
Expand Down Expand Up @@ -551,6 +562,7 @@ def predict_tile(self,
patch_overlap=patch_overlap,
patch_size=patch_size)

# Predict using trainer
batched_results = self.trainer.predict(self, self.predict_dataloader(ds))

# Flatten list from batched prediction
Expand All @@ -560,11 +572,21 @@ def predict_tile(self,
results.append(boxes)

if mosaic:
results = predict.mosiac(results,
ds.windows,
sigma=sigma,
thresh=thresh,
iou_threshold=iou_threshold)
# Suppress output if not verbose
if not verbose:
f = io.StringIO()
with contextlib.redirect_stdout(f):
results = predict.mosiac(results,
ds.windows,
sigma=sigma,
thresh=thresh,
iou_threshold=iou_threshold)
else:
results = predict.mosiac(results,
ds.windows,
sigma=sigma,
thresh=thresh,
iou_threshold=iou_threshold)
results["label"] = results.label.apply(
lambda x: self.numeric_to_label_dict[x])
if raster_path:
Expand Down
22 changes: 21 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,4 +906,24 @@ def test_evaluate_on_epoch_interval(m):
m.create_trainer()
m.trainer.fit(m)
assert m.trainer.logged_metrics["box_precision"]
assert m.trainer.logged_metrics["box_recall"]
assert m.trainer.logged_metrics["box_recall"]

@pytest.mark.parametrize("verbose", [True, False])
def test_predict_tile_verbose(m, raster_path, capsys, verbose):
"""Test that verbose output can be controlled in predict_tile"""
m.config["train"]["fast_dev_run"] = False
m.create_trainer()

m.predict_tile(
raster_path=raster_path,
patch_size=300,
patch_overlap=0,
mosaic=True,
verbose=verbose
)

captured = capsys.readouterr()
if verbose:
assert captured.out.strip()
else:
assert not captured.out.strip()
Loading