@functional_datapipe("interleave")
class InterlaveDataPipe(IterDataPipe):
    """Randomly interleave the elements of several source datapipes.

    At every step one source is sampled according to ``weights`` and its
    next element is yielded. A source that runs out gets weight zero and the
    remaining weights are renormalized. Iteration ends once no source with a
    positive weight is left. The relative order of elements coming from the
    same source is preserved.

    NOTE: the class name is misspelled ("Interlave"); it is kept as-is
    because it is part of the public interface.
    NOTE(review): the functional form registered as ``interleave`` presumably
    passes the calling datapipe as ``source_datapipes`` -- confirm that call
    style is actually usable, since a list of pipes is expected here.

    Args:
        source_datapipes: the datapipes (or any iterables) to draw from.
        weights: optional non-negative sampling weights, one per source.
            They are normalized to sum to 1. ``None`` means uniform sampling.
            A source whose weight is 0 is never sampled.
        seed: seed of the numpy random generator that picks the source.

    Raises:
        ValueError: if ``source_datapipes`` is empty, if ``weights`` does not
            have one entry per source, or if the weights are negative or do
            not have a positive sum.
    """

    def __init__(
        self,
        source_datapipes: List[IterDataPipe],
        weights: Optional[List[float]] = None,
        seed: int = 2027,
    ):
        super().__init__()
        if not source_datapipes:
            raise ValueError("source_datapipes must not be empty")
        num_sources = len(source_datapipes)
        if weights is None:
            weights = [1.0] * num_sources
        elif len(weights) != num_sources:
            raise ValueError(
                "weights must have one entry per source datapipe, got "
                "{} weights for {} sources".format(len(weights),
                                                   num_sources))
        total = sum(weights)
        if total <= 0 or any(weight < 0 for weight in weights):
            raise ValueError(
                "weights must be non-negative and have a positive sum")

        self.rng = np.random.default_rng(seed)
        self.source_datapipes = source_datapipes
        self.weights = [weight / total for weight in weights]

    def __iter__(self):
        # Build fresh iterators on every call: caching them on the instance
        # would leave only exhausted iterators for the next epoch and would
        # store live generators on the (otherwise picklable) datapipe.
        iters = [iter(datapipe) for datapipe in self.source_datapipes]
        # Work on a copy so exhausting a source never alters self.weights.
        weights = list(self.weights)
        while True:
            # Sample an index rather than the iterator itself, so numpy does
            # not have to coerce a list of iterators into an object array.
            index = int(self.rng.choice(len(iters), p=weights))
            try:
                elem = next(iters[index])
            except StopIteration:
                weights[index] = 0.
                total = sum(weights)
                if total <= 0:
                    # Every source that can still be sampled is exhausted.
                    return
                weights = [weight / total for weight in weights]
            else:
                yield elem