Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update AutoGluon to 1.0 API #604

Merged
merged 2 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions frameworks/AutoGluon/exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ def run(dataset, config):

is_classification = config.type == 'classification'
training_params = {k: v for k, v in config.framework_params.items() if not k.startswith('_')}
time_limit = config.max_runtime_seconds
presets = training_params.get("presets", [])
presets = presets if isinstance(presets, list) else [presets]
if preset_with_refit_full := (set(presets) & {"good_quality", "high_quality"}):
if (preset_with_refit_full := (set(presets) & {"good_quality", "high_quality"})) and (time_limit is not None):
preserve = 0.9
preset = next(iter(preset_with_refit_full))
msg = (
Expand All @@ -61,7 +62,7 @@ def run(dataset, config):
"See https://auto.gluon.ai/stable/api/autogluon.tabular.TabularPredictor.refit_full.html"
)
log.info(msg)
config.max_runtime_seconds = preserve * config.max_runtime_seconds
time_limit = preserve * config.max_runtime_seconds

train_path, test_path = dataset.train.path, dataset.test.path
label = dataset.target.name
Expand All @@ -77,15 +78,17 @@ def run(dataset, config):
problem_type=problem_type,
).fit(
train_data=train_path,
time_limit=config.max_runtime_seconds,
time_limit=time_limit,
**training_params
)

log.info(f"Finished fit in {training.duration}s.")

# Persist model in memory that is going to be predicting to get correct inference latency
# max_memory=0.4 will be future default: https://github.com/autogluon/autogluon/pull/3338
predictor.persist_models('best', max_memory=0.4)
if hasattr(predictor, 'persist'): # autogluon>=1.0
predictor.persist('best')
else:
predictor.persist_models('best')

def inference_time_classification(data: Union[str, pd.DataFrame]):
return None, predictor.predict_proba(data, as_multiclass=True)
Expand All @@ -108,14 +111,17 @@ def inference_time_regression(data: Union[str, pd.DataFrame]):
with Timer() as predict:
predictions, probabilities = infer(test_data)
if is_classification:
predictions = probabilities.idxmax(axis=1).to_numpy()
if hasattr(predictor, 'predict_from_proba'): # autogluon>=1.0
predictions = predictor.predict_from_proba(probabilities).to_numpy()
else:
predictions = probabilities.idxmax(axis=1).to_numpy()

prob_labels = probabilities.columns.values.astype(str).tolist() if probabilities is not None else None
log.info(f"Finished predict in {predict.duration}s.")

_leaderboard_extra_info = config.framework_params.get('_leaderboard_extra_info', False) # whether to get extra model info (very verbose)
_leaderboard_test = config.framework_params.get('_leaderboard_test', False) # whether to compute test scores in leaderboard (expensive)
leaderboard_kwargs = dict(silent=True, extra_info=_leaderboard_extra_info)
leaderboard_kwargs = dict(extra_info=_leaderboard_extra_info)
# Disabled leaderboard test data input by default to avoid long running computation, remove 7200s timeout limitation to re-enable
if _leaderboard_test:
leaderboard_kwargs['data'] = test_data
Expand Down
5 changes: 4 additions & 1 deletion frameworks/AutoGluon/setup.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#!/usr/bin/env bash

# exit when any command fails
set -e

HERE=$(dirname "$0")
VERSION=${1:-"stable"}
REPO=${2:-"https://github.com/awslabs/autogluon.git"}
REPO=${2:-"https://github.com/autogluon/autogluon.git"}
PKG=${3:-"autogluon"}
if [[ "$VERSION" == "latest" ]]; then
VERSION="master"
Expand Down
Loading