[dataset] support bucket by seq length (#2333)
* [dataset] support bucket by seq length

* support bucket in dataset.py

* add it
Mddct authored Feb 4, 2024
1 parent f605684 commit b115913
Showing 4 changed files with 135 additions and 0 deletions.
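
In short: the commit adds a TensorFlow-style `bucket_by_sequence_length` datapipe plus a `group_by_window` helper to `wenet/dataset/datapipes.py`, wires up a new `'bucket'` batch type in `wenet/dataset/dataset.py`, adds a `feats_length_fn` length hook to `wenet/dataset/processor.py`, and covers the new datapipe with a unit test. Utterances are grouped into buckets by feature length and each bucket is batched with its own batch size, so each utterance is padded against similar-length neighbors.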
32 changes: 32 additions & 0 deletions test/wenet/dataset/test_datapipes.py
@@ -110,6 +110,38 @@ def test_dynamic_batch_datapipe(data_list):
        assert d['feats'].size(1) <= max_frames_in_batch


def test_bucket_batch_datapipe():
    dataset = datapipes.iter.IterableWrapper(range(10))

    def _seq_len_fn(elem):
        # Elements 0-4 have "length" 2, elements 5-6 length 4, 7-9 length 8.
        if elem < 5:
            return 2
        elif 5 <= elem < 7:
            return 4
        else:
            return 8

    dataset = dataset.bucket_by_sequence_length(
        _seq_len_fn,
        bucket_boundaries=[3, 5],
        bucket_batch_sizes=[3, 2, 2],
    )
    # Buckets: len < 3 -> batch size 3 (elements 0-4); 3 <= len < 5 ->
    # batch size 2 (elements 5-6); len >= 5 -> batch size 2 (elements 7-9).
    # [3, 4] and [9] are partial groups flushed at the end of the stream.
    expected = [
        [0, 1, 2],
        [5, 6],
        [7, 8],
        [3, 4],
        [9],
    ]
    result = []
    for d in dataset:
        result.append(d)
    assert len(result) == len(expected)
    for (r, h) in zip(expected, result):
        assert len(r) == len(h)
        assert all(rr == hh for (rr, hh) in zip(r, h))


def test_shuffle_deterministic():
    dataset = datapipes.iter.IterableWrapper(range(10))
    dataset = dataset.shuffle()
90 changes: 90 additions & 0 deletions wenet/dataset/datapipes.py
@@ -14,8 +14,10 @@

import collections
from collections.abc import Callable
import sys
import tarfile
import logging
from typing import List
import torch
from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data import datapipes
@@ -56,6 +58,94 @@ def __iter__(self):
            logging.warning(str(ex))


@functional_datapipe('bucket_by_sequence_length')
class BucketBySequenceLengthDataPipe(IterDataPipe):

    def __init__(
        self,
        dataset: IterDataPipe,
        elem_length_func,
        bucket_boundaries: List[int],
        bucket_batch_sizes: List[int],
        wrapper_class=None,
    ) -> None:
        super().__init__()
        _check_unpickable_fn(elem_length_func)
        # One batch size per bucket, including the final unbounded bucket.
        assert len(bucket_batch_sizes) == len(bucket_boundaries) + 1
        self.bucket_batch_sizes = bucket_batch_sizes
        self.bucket_boundaries = bucket_boundaries + [sys.maxsize]
        self.elem_length_func = elem_length_func

        self._group_dp = GroupByWindowDataPipe(dataset,
                                               self._element_to_bucket_id,
                                               self._window_size_func,
                                               wrapper_class=wrapper_class)

    def __iter__(self):
        yield from self._group_dp

    def _element_to_bucket_id(self, elem):
        # Return the index of the first boundary the element's length falls
        # below; the sys.maxsize sentinel guarantees a match.
        seq_len = self.elem_length_func(elem)
        bucket_id = 0
        for (i, b) in enumerate(self.bucket_boundaries):
            if seq_len < b:
                bucket_id = i
                break
        return bucket_id

    def _window_size_func(self, bucket_id):
        return self.bucket_batch_sizes[bucket_id]
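
For orientation, a minimal standalone sketch of the new datapipe (the values are illustrative; it assumes `wenet.dataset.datapipes` has been imported so the functional form is registered):

    from torch.utils.data import datapipes

    import wenet.dataset.datapipes  # registers 'bucket_by_sequence_length'

    def _len_fn(x):
        return x  # treat each integer as its own sequence length

    dp = datapipes.iter.IterableWrapper(range(8))
    dp = dp.bucket_by_sequence_length(
        _len_fn,
        bucket_boundaries=[4],      # two buckets: len < 4 and len >= 4
        bucket_batch_sizes=[2, 3],  # one size per bucket
    )
    for batch in dp:
        print(list(batch))  # [0, 1], [2, 3], [4, 5, 6], then leftover [7]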


@functional_datapipe("group_by_window")
class GroupByWindowDataPipe(datapipes.iter.Grouper):

def __init__(
self,
dataset: IterDataPipe,
key_func,
window_size_func,
wrapper_class=None,
):
super().__init__(dataset,
key_func,
keep_key=False,
group_size=None,
drop_remaining=False)
_check_unpickable_fn(window_size_func)
self.dp = dataset
self.window_size_func = window_size_func
if wrapper_class is not None:
_check_unpickable_fn(wrapper_class)
del self.wrapper_class
self.wrapper_class = wrapper_class

def __iter__(self):
for x in self.datapipe:
key = self.group_key_fn(x)

self.buffer_elements[key].append(x)
self.curr_buffer_size += 1

group_size = self.window_size_func(key)
if group_size == len(self.buffer_elements[key]):
result = self.wrapper_class(self.buffer_elements[key])
yield result
self.curr_buffer_size -= len(self.buffer_elements[key])
del self.buffer_elements[key]

if self.curr_buffer_size == self.max_buffer_size:
result_to_yield = self._remove_biggest_key()
if result_to_yield is not None:
result = self.wrapper_class(result_to_yield)
yield result

for key in tuple(self.buffer_elements.keys()):
result = self.wrapper_class(self.buffer_elements.pop(key))
self.curr_buffer_size -= len(result)
yield result
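
Design note: `GroupByWindowDataPipe` leans on torch's `datapipes.iter.Grouper` for its buffer bookkeeping (`buffer_elements`, `curr_buffer_size`, `max_buffer_size`, `_remove_biggest_key`) and overrides only `__iter__` so the group size can vary per key via `window_size_func`. When the shared buffer fills up, the largest pending group is emitted early, and the trailing loop flushes partial groups at the end of the stream rather than dropping them, as the `[3, 4]` and `[9]` batches in the test above illustrate.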


@functional_datapipe("sort")
class SortDataPipe(IterDataPipe):

8 changes: 8 additions & 0 deletions wenet/dataset/dataset.py
@@ -111,6 +111,14 @@ def Dataset(data_type,
        assert 'batch_size' in batch_conf
        batch_size = batch_conf.get('batch_size', 16)
        dataset = dataset.batch(batch_size, wrapper_class=processor.padding)
    elif batch_type == 'bucket':
        assert 'bucket_boundaries' in batch_conf
        assert 'bucket_batch_sizes' in batch_conf
        dataset = dataset.bucket_by_sequence_length(
            processor.feats_length_fn,
            batch_conf['bucket_boundaries'],
            batch_conf['bucket_batch_sizes'],
            wrapper_class=processor.padding)
    else:
        max_frames_in_batch = batch_conf.get('max_frames_in_batch', 12000)
        dataset = dataset.dynamic_batch(
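
For illustration, a `batch_conf` exercising the new branch might look like the sketch below; the boundary and batch-size values are made up, only `bucket_boundaries` and `bucket_batch_sizes` are actually required by the asserts above, and how `batch_type` is derived follows the surrounding `Dataset` code:

    batch_conf = {
        'batch_type': 'bucket',
        # utterances with < 500, 500-999, and >= 1000 feature frames
        # (as measured by processor.feats_length_fn) land in separate buckets
        'bucket_boundaries': [500, 1000],
        # one batch size per bucket: len(sizes) == len(boundaries) + 1
        'bucket_batch_sizes': [32, 16, 8],
    }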
5 changes: 5 additions & 0 deletions wenet/dataset/processor.py
@@ -200,6 +200,11 @@ def sort_by_feats(sample):
    return sample['feat'].size(0)


def feats_length_fn(sample) -> int:
    # Length of an utterance = its number of feature frames.
    assert 'feat' in sample
    return sample['feat'].size(0)


def compute_mfcc(sample,
                 num_mel_bins=23,
                 frame_length=25,
