fix mp in DataLoader (#2506) (#2507)
MengqingCao authored May 2, 2024
1 parent 675c037 commit 5576e6f
Showing 5 changed files with 27 additions and 11 deletions.
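The change is the same in all five files: every DataLoader that spins up worker processes now gets an explicit "spawn" multiprocessing context instead of the platform default ("fork" on Linux). Spawned workers start from a fresh interpreter, so they do not inherit the parent's threads, locks, or initialized CUDA state the way forked workers do. Since DataLoader rejects a multiprocessing_context when num_workers == 0, the scripts guard the context behind the worker count. A minimal sketch of the pattern (make_loader is an illustrative helper, not part of the diff):

    import torch.multiprocessing as mp
    from torch.utils.data import DataLoader

    def make_loader(dataset, num_workers):
        # "spawn" starts clean worker processes rather than forking the
        # parent; a context is only legal with multi-process loading,
        # hence the num_workers > 0 guard.
        mp_context = mp.get_context("spawn") if num_workers > 0 else None
        return DataLoader(dataset,
                          batch_size=None,
                          num_workers=num_workers,
                          multiprocessing_context=mp_context)

The test files pass the spawn context unconditionally because they always run with num_workers > 0; the command-line tool and the training utilities use the guarded form, since their worker count comes from the arguments.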
test/wenet/dataset/test_datapipes.py (9 changes: 6 additions & 3 deletions)
@@ -2,6 +2,7 @@
 import torch
 from torch.utils.data import datapipes
 from torch.utils.data.datapipes.iter import IterableWrapper
+import torch.multiprocessing as mp
 from functools import partial

 from wenet.dataset.datapipes import (RepeatDatapipe, SortDataPipe,
@@ -108,9 +109,11 @@ def test_dynamic_batch_datapipe(data_list):
         window_class=DynamicBatchWindow(max_frames_in_batch),
         wrapper_class=padding)

-    dataloader = torch.utils.data.DataLoader(dataset,
-                                             batch_size=None,
-                                             num_workers=2)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=None,
+        num_workers=2,
+        multiprocessing_context=mp.get_context("spawn"))
     for d in dataloader:
         assert d['feats'].size(1) <= max_frames_in_batch
test/wenet/dataset/test_dataset.py (11 changes: 7 additions & 4 deletions)
@@ -1,5 +1,6 @@
 import pytest
 import torch
+import torch.multiprocessing as mp
 from wenet.dataset.dataset import Dataset
 from wenet.text.char_tokenizer import CharTokenizer

@@ -54,9 +55,11 @@ def test_dataset(params):
                       data_list,
                       tokenizer=tokenizer,
                       conf=dataset_conf)
-    dataloader = torch.utils.data.DataLoader(dataset,
-                                             batch_size=None,
-                                             num_workers=4,
-                                             persistent_workers=True)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=None,
+        num_workers=4,
+        persistent_workers=True,
+        multiprocessing_context=mp.get_context("spawn"))
     for d in dataloader:
         pass
tools/compute_cmvn_stats.py (5 changes: 4 additions & 1 deletion)
@@ -11,6 +11,7 @@
 import torchaudio
 import torchaudio.compliance.kaldi as kaldi
 from torch.utils.data import Dataset, DataLoader
+import torch.multiprocessing as mp


 class CollateFunc(object):
@@ -107,12 +108,14 @@ def __getitem__(self, idx):
     collate_func = CollateFunc(feat_dim, resample_rate)
     dataset = AudioDataset(args.in_scp)
     batch_size = 20
+    mp_context = mp.get_context("spawn") if args.num_workers > 0 else None
     data_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              sampler=None,
                              num_workers=args.num_workers,
-                             collate_fn=collate_func)
+                             collate_fn=collate_func,
+                             multiprocessing_context=mp_context)

     with torch.no_grad():
         all_number = 0
wenet/bin/recognize.py (5 changes: 4 additions & 1 deletion)
@@ -22,6 +22,7 @@
 import torch
 import yaml
 from torch.utils.data import DataLoader
+import torch.multiprocessing as mp

 from wenet.dataset.dataset import Dataset
 from wenet.utils.config import override_config
@@ -222,9 +223,11 @@ def main():
                            test_conf,
                            partition=False)

+    mp_context = mp.get_context("spawn") if args.num_workers > 0 else None
     test_data_loader = DataLoader(test_dataset,
                                   batch_size=None,
-                                  num_workers=args.num_workers)
+                                  num_workers=args.num_workers,
+                                  multiprocessing_context=mp_context)

     # Init asr model from configs
     args.jit = False
wenet/utils/train_utils.py (8 changes: 6 additions & 2 deletions)
@@ -26,6 +26,7 @@

 import torch.optim as optim
 import torch.distributed as dist
+import torch.multiprocessing as mp

 from tensorboardX import SummaryWriter
 from torch.utils.data import DataLoader
@@ -345,20 +346,23 @@ def init_dataset_and_dataloader(args, configs, tokenizer, seed=777):

     # NOTE(xcsong): Why we prefer persistent_workers=True ?
     # https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110
+    mp_context = mp.get_context("spawn") if args.num_workers > 0 else None
     train_data_loader = DataLoader(train_dataset,
                                    batch_size=None,
                                    pin_memory=args.pin_memory,
                                    num_workers=args.num_workers,
                                    persistent_workers=True,
                                    generator=generator,
-                                   prefetch_factor=args.prefetch)
+                                   prefetch_factor=args.prefetch,
+                                   multiprocessing_context=mp_context)
     cv_data_loader = DataLoader(cv_dataset,
                                 batch_size=None,
                                 pin_memory=args.pin_memory,
                                 num_workers=args.num_workers,
                                 persistent_workers=True,
                                 generator=generator,
-                                prefetch_factor=args.prefetch)
+                                prefetch_factor=args.prefetch,
+                                multiprocessing_context=mp_context)
     return train_dataset, cv_dataset, train_data_loader, cv_data_loader


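A note on the persistent_workers comment in the hunk above: under a spawn context each worker boots a fresh interpreter and re-imports the entry module, so startup is noticeably costlier than fork, and keeping workers alive across epochs (persistent_workers=True) pays that cost only once. Re-importing also means any script that builds a loader with this context needs an if __name__ == "__main__" guard. A self-contained sketch, with a stand-in dataset (TinyDataset is illustrative, not from the repo):

    import torch.multiprocessing as mp
    from torch.utils.data import DataLoader, IterableDataset

    class TinyDataset(IterableDataset):
        # Stand-in stream; with several workers each one would replay
        # the full range, which is fine for a demo.
        def __iter__(self):
            yield from range(8)

    if __name__ == "__main__":  # required: spawn re-imports this module
        loader = DataLoader(TinyDataset(),
                            batch_size=None,
                            num_workers=1,
                            persistent_workers=True,
                            multiprocessing_context=mp.get_context("spawn"))
        for epoch in range(2):
            for item in loader:  # the same worker serves both epochs
                pass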
