Skip to content

Commit

Permalink
Change default behavior for TrainDatasets overwrite (#2121)
Browse files Browse the repository at this point in the history
Co-authored-by: nklingen <[email protected]>
  • Loading branch information
nok-halfspace and nklingen authored Jun 30, 2022
1 parent 4e5ee8a commit f529423
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/gluonts/dataset/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def save(
self,
path_str: str,
writer: DatasetWriter,
overwrite=True,
overwrite=False,
) -> None:
"""
Saves an TrainDatasets object to a JSON Lines file.
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/dataset/repository/_airpassengers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ def generate_airpassengers_dataset(
meta = MetaData(freq="1M", prediction_length=12)

dataset = TrainDatasets(metadata=meta, train=train, test=test)
dataset.save(str(dataset_path), writer=dataset_writer)
dataset.save(str(dataset_path), writer=dataset_writer, overwrite=True)
4 changes: 3 additions & 1 deletion src/gluonts/dataset/repository/_artificial.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,6 @@ def generate_artificial_dataset(
ds.metadata.prediction_length = prediction_length

dataset = TrainDatasets(metadata=ds.metadata, train=ds.train, test=ds.test)
dataset.save(path_str=str(dataset_path), writer=dataset_writer)
dataset.save(
path_str=str(dataset_path), writer=dataset_writer, overwrite=True
)
4 changes: 3 additions & 1 deletion src/gluonts/dataset/repository/_gp_copula_2019.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def generate_gp_copula_dataset(
)

dataset = TrainDatasets(metadata=meta, train=train_data, test=test_data)
dataset.save(path_str=str(dataset_path), writer=dataset_writer)
dataset.save(
path_str=str(dataset_path), writer=dataset_writer, overwrite=True
)
clean_up_dataset(dataset_path, ds_info)


Expand Down
4 changes: 3 additions & 1 deletion src/gluonts/dataset/repository/_lstnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,6 @@ def generate_lstnet_dataset(
)

dataset = TrainDatasets(metadata=meta, train=train_ts, test=test_ts)
dataset.save(path_str=str(dataset_path), writer=dataset_writer)
dataset.save(
path_str=str(dataset_path), writer=dataset_writer, overwrite=True
)
4 changes: 3 additions & 1 deletion src/gluonts/dataset/repository/_m3.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def normalize_category(c: str):
)

dataset = TrainDatasets(metadata=meta, train=train_data, test=test_data)
dataset.save(path_str=str(dataset_path), writer=dataset_writer)
dataset.save(
path_str=str(dataset_path), writer=dataset_writer, overwrite=True
)

check_dataset(dataset_path, len(df), subset.sheet_name)
4 changes: 3 additions & 1 deletion src/gluonts/dataset/repository/_m4.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,6 @@ def generate_m4_dataset(
)

dataset = TrainDatasets(metadata=meta, train=train_data, test=test_data)
dataset.save(path_str=str(dataset_path), writer=dataset_writer)
dataset.save(
path_str=str(dataset_path), writer=dataset_writer, overwrite=True
)
4 changes: 3 additions & 1 deletion src/gluonts/dataset/repository/_m5.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,6 @@ def generate_m5_dataset(
)

dataset = TrainDatasets(metadata=meta, train=train_ds, test=test_ds)
dataset.save(path_str=str(dataset_path), writer=dataset_writer)
dataset.save(
path_str=str(dataset_path), writer=dataset_writer, overwrite=True
)
4 changes: 3 additions & 1 deletion src/gluonts/dataset/repository/_tsf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ def generate_forecasting_dataset(
)

dataset = TrainDatasets(metadata=meta, train=train_data, test=test_data)
dataset.save(path_str=str(dataset_path), writer=dataset_writer)
dataset.save(
path_str=str(dataset_path), writer=dataset_writer, overwrite=True
)


def default_prediction_length_from_frequency(freq: str) -> int:
Expand Down
4 changes: 3 additions & 1 deletion src/gluonts/dataset/repository/_uber_tlc.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,6 @@ def generate_uber_dataset(
)

dataset = TrainDatasets(metadata=meta, train=train_data, test=test_data)
dataset.save(path_str=str(dataset_path), writer=dataset_writer)
dataset.save(
path_str=str(dataset_path), writer=dataset_writer, overwrite=True
)

0 comments on commit f529423

Please sign in to comment.