From 3302ee7311904c56b4e029fa68e71c2cbba512f1 Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 13 Aug 2024 21:08:50 +0800 Subject: [PATCH 1/2] addd interleave dataset --- wenet/dataset/datapipes.py | 41 +++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/wenet/dataset/datapipes.py b/wenet/dataset/datapipes.py index 6d89ab552..bbf715038 100644 --- a/wenet/dataset/datapipes.py +++ b/wenet/dataset/datapipes.py @@ -18,7 +18,7 @@ import sys import tarfile import logging -from typing import List +from typing import List, Optional import torch from torch.utils.data import IterDataPipe, functional_datapipe from torch.utils.data import datapipes @@ -28,6 +28,7 @@ from torch.utils.data.datapipes.utils.common import _check_unpickable_fn from wenet.dataset.processor import parse_url +import random @functional_datapipe("map_ignore_error") @@ -302,6 +303,44 @@ def apply_sharding(self, num_of_instances: int, instance_id: int, self.instance_id = info.id +@functional_datapipe("interleave") +class InterlaveDataPipe(IterDataPipe): + + def __init__( + self, + source_datapipes: List[IterDataPipe], + weights: Optional[List[float]] = None, + ): + super().__init__() + self.source_datapipes = source_datapipes + self.weights = weights + if weights is None: + self.weights = [1 / len(self.source_datapipes)] * len( + self.source_datapipes) + else: + self.weights = [weight / sum(weights) for weight in weights] + self.iters = None + + def __iter__(self): + exhausted = len(self.source_datapipes) * [False] + if self.iters is None: + self.iters = [(i, iter(d)) + for i, d in enumerate(self.source_datapipes)] + + while True: + # TODO(Mddct): rng + index_iter = random.choices(self.iters, self.weights)[0] + i, ite = index_iter + try: + elem = next(ite) + yield elem + except StopIteration: + self.weights[i] = 0. + exhausted[i] = True + if all(exhausted): + return + + class TextLineDataPipe(IterDataPipe): """ Streamming Text line """ From 60b25bdfd551f68934fe8931bf5f8d4c76ec975d Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 13 Aug 2024 23:20:37 +0800 Subject: [PATCH 2/2] add ut --- test/wenet/dataset/test_datapipes.py | 26 ++++++++++++++++++++++++-- wenet/dataset/datapipes.py | 11 +++++++---- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/test/wenet/dataset/test_datapipes.py b/test/wenet/dataset/test_datapipes.py index f269788c9..f8a530559 100644 --- a/test/wenet/dataset/test_datapipes.py +++ b/test/wenet/dataset/test_datapipes.py @@ -4,8 +4,8 @@ from torch.utils.data.datapipes.iter import IterableWrapper from functools import partial -from wenet.dataset.datapipes import (RepeatDatapipe, SortDataPipe, - WenetRawDatasetSource, +from wenet.dataset.datapipes import (InterlaveDataPipe, RepeatDatapipe, + SortDataPipe, WenetRawDatasetSource, WenetTarShardDatasetSource) from wenet.dataset.processor import (DynamicBatchWindow, decode_wav, padding, parse_json, compute_fbank, @@ -224,3 +224,25 @@ def test_repeat(): assert len(result) == len(expected) all(h == r for h, r in zip(result, expected)) + + +def test_interleave(): + dataset_1 = IterableWrapper(range(10)) + dataset_2 = IterableWrapper(range(20, 30, 2)) + + dataset = InterlaveDataPipe([dataset_1, dataset_2]) + dataset = dataset.batch(4) + generator = torch.Generator() + generator.manual_seed(100) + dataloader = torch.utils.data.DataLoader(dataset, + batch_size=None, + num_workers=0, + generator=generator, + persistent_workers=False) + expected = [[0, 1, 2, 3], [4, 20, 5, 22], [24, 6, 7, 8], [26, 9, 28]] + + result = [] + for batch in dataloader: + result.append(batch) + + assert expected == result diff --git a/wenet/dataset/datapipes.py b/wenet/dataset/datapipes.py index bbf715038..54127a821 100644 --- a/wenet/dataset/datapipes.py +++ b/wenet/dataset/datapipes.py @@ -19,6 +19,7 @@ import tarfile import logging from typing import List, Optional +import numpy as np import torch from torch.utils.data import IterDataPipe, functional_datapipe from torch.utils.data import datapipes @@ -28,7 +29,6 @@ from torch.utils.data.datapipes.utils.common import _check_unpickable_fn from wenet.dataset.processor import parse_url -import random @functional_datapipe("map_ignore_error") @@ -310,8 +310,10 @@ def __init__( self, source_datapipes: List[IterDataPipe], weights: Optional[List[float]] = None, + seed=2027, ): super().__init__() + self.rng = np.random.default_rng(seed) self.source_datapipes = source_datapipes self.weights = weights if weights is None: @@ -322,23 +324,24 @@ def __init__( self.iters = None def __iter__(self): + weights = copy.deepcopy(self.weights) exhausted = len(self.source_datapipes) * [False] if self.iters is None: self.iters = [(i, iter(d)) for i, d in enumerate(self.source_datapipes)] - while True: # TODO(Mddct): rng - index_iter = random.choices(self.iters, self.weights)[0] + index_iter = self.rng.choice(self.iters, p=weights) i, ite = index_iter try: elem = next(ite) yield elem except StopIteration: - self.weights[i] = 0. + weights[i] = 0. exhausted[i] = True if all(exhausted): return + weights = [weight / sum(weights) for weight in weights] class TextLineDataPipe(IterDataPipe):