
Repeating training data beyond epoch limit if batch_size > len(dataset) #30

Open
jelmervdl opened this issue Aug 14, 2023 · 1 comment
Labels
enhancement New feature or request minor

Comments

@jelmervdl
Contributor

In this bit of code, we use islice to read batch_size elements from the dataset if the dataset has not yet read past its epoch (for this stage):

```python
while self.stage.until_epoch is None or self.epoch_tracker.epoch < self.stage.until_epoch:
    batch: List[str] = []
    # Read from each dataset according to its weight in this stage
    # (They will reshuffle and repeat if necessary)
    for dataset, weight in self.stage.datasets:
        batch.extend(islice(self.readers[dataset.name], 0, int(batch_size * weight)))
```

The problem is that if batch_size is anywhere near the size of the dataset, we can easily read beyond that epoch boundary, so the output contains more lines than you'd expect. If no errors occur with the modifiers, the output will always be a multiple of batch_size.
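A minimal sketch of the overshoot, using itertools.cycle to stand in for a reader that reshuffles and repeats (the real reader is more involved, but the islice behaviour is the same):

```python
from itertools import cycle, islice

dataset = ["a", "b", "c"]   # 3 lines = one epoch
reader = cycle(dataset)     # stand-in for a reader that repeats
batch_size = 5

batch = list(islice(reader, 0, batch_size))
# One epoch is only 3 lines, but islice happily reads into the next
# epoch to fill the batch, so we get 5 lines instead of 3.
assert batch == ["a", "b", "c", "a", "b"]
```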

One way to solve this is to put the epoch tracker around the iterator directly, so it stops generating once it has reached its epoch limit. The batch might then be smaller, but that's okay.
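Roughly what that could look like (a minimal sketch with a hypothetical EpochBoundedReader; the real reader/EpochTracker interplay is different, but the key point is that the iterator itself raises StopIteration at the epoch boundary so islice stops early):

```python
from itertools import islice
from typing import Iterator, List

class EpochBoundedReader:
    """Hypothetical wrapper: yields lines from a dataset but signals
    exhaustion at the epoch boundary instead of wrapping around."""

    def __init__(self, lines: List[str]):
        self.lines = lines
        self.pos = 0
        self.epoch = 0

    def __iter__(self) -> Iterator[str]:
        return self

    def __next__(self) -> str:
        if self.pos >= len(self.lines):
            # Epoch boundary: stop the current islice early,
            # then rewind so the next batch starts a fresh epoch.
            self.pos = 0
            self.epoch += 1
            raise StopIteration
        line = self.lines[self.pos]
        self.pos += 1
        return line

reader = EpochBoundedReader(["a", "b", "c"])
batch = list(islice(reader, 0, 5))
# We asked for 5 lines but the epoch only has 3; the batch is
# smaller rather than bleeding into the next epoch.
assert batch == ["a", "b", "c"]
assert reader.epoch == 1
```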

Another thing I'd like to look at is that we take int(batch_size * weight) items per dataset. If batch_size is small, this doesn't work very well.

@XapaJIaMnu XapaJIaMnu added minor enhancement New feature or request labels Aug 15, 2023
@graemenail
Contributor

I also saw int(batch_size * weight) and came to open an issue. At least one test calls for a batch size of 1. Any dataset with a weight $w < 1$ does not contribute to the batch. The only reason this test passes is that it is the only dataset in the test and has a weight of 1. Reducing this sole weight below 1 causes the test to hang, because the dataset never advances and the until-epoch condition is never triggered.
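The truncation is easy to demonstrate in isolation:

```python
batch_size = 1
weights = [0.5, 0.5]

# int() truncates toward zero, so with batch_size=1 every dataset
# with weight < 1 is asked for zero lines and never advances.
takes = [int(batch_size * w) for w in weights]
assert takes == [0, 0]
```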

In a more realistic scenario with multiple datasets, we may miss this if the dataset is never part of an until condition, because we do not impose any checks on batch size.

Since this is ultimately a sampling issue, I'll also add that we should be normalising weights where possible. This would allow dataset ratios like 2:1 to behave properly, and it would also make our sampling more robust.
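A sketch of what normalised allocation could look like (a hypothetical allocate() helper, not the project's API; it normalises the weights and distributes leftover slots by largest remainder so the counts always sum to batch_size and no positive-weight dataset is starved by truncation):

```python
import math
from typing import List

def allocate(batch_size: int, weights: List[float]) -> List[int]:
    """Hypothetical helper: split batch_size across datasets in
    proportion to their (normalised) weights, handing out any
    leftover slots to the largest fractional shares."""
    total = sum(weights)
    exact = [batch_size * w / total for w in weights]
    counts = [math.floor(x) for x in exact]
    # Give the remaining slots to the datasets whose share was
    # truncated the most (largest-remainder method).
    by_remainder = sorted(range(len(exact)),
                          key=lambda i: exact[i] - counts[i],
                          reverse=True)
    for i in by_remainder[:batch_size - sum(counts)]:
        counts[i] += 1
    return counts

# Raw weights 2 and 1 behave as a 2:1 ratio of the batch:
assert allocate(100, [2, 1]) == [67, 33]
# Even batch_size=1 still draws from one dataset instead of none:
assert allocate(1, [0.5, 0.5]) == [1, 0]
```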
