test: add test for webdataset dataset and dataloader

sthoduka committed Oct 1, 2024
1 parent b9bfaea commit d6da583

Showing 2 changed files with 251 additions and 0 deletions.
tests/dataloader/test_webdataset.py (new file, 140 additions, 0 deletions)
@@ -0,0 +1,140 @@
import io
import tarfile
from pathlib import Path

import numpy as np
import pytest
import torch
import torchaudio
import webdataset as wds
from pydantic import BaseModel

from modalities.__main__ import load_app_config_dict
from modalities.config.component_factory import ComponentFactory
from modalities.config.pydanctic_if_types import PydanticDataLoaderIFType
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry
from tests.conftest import _ROOT_DIR


def create_image_sample():
    img = np.random.randint(0, 255, size=(224, 224, 3)).astype(np.uint8)
    img = wds.writer.imageencoder(img, format="JPG")
    text = {"text0": "this is an image caption %d" % np.random.randint(10)}
    return img, text


@pytest.fixture(scope="session")
def image_tar_path(tmp_path_factory):
    data_path = str(tmp_path_factory.mktemp("data") / "images.tar")
    dataset_sink = wds.TarWriter(data_path)
    # 10 image samples
    for idx in range(10):
        img, text = create_image_sample()
        dataset_sink.write(
            {
                "__key__": "%02d" % idx,
                "jpg": img,
                "json": text,
            }
        )
    dataset_sink.close()
    return data_path


def create_audio_sample():
    sample_rate = 16000
    audio = torch.from_numpy(np.random.uniform(-1, 1, sample_rate)).unsqueeze(0)
    audio_buf = io.BytesIO()
    torchaudio.save(audio_buf, audio, sample_rate, format="wav")
    audio_buf.seek(0)
    text = "this is an audio caption %d" % np.random.randint(10)
    text_f = io.BytesIO()
    text_f.write(text.encode("utf-8"))
    text_f.seek(0)
    return audio_buf, text_f


@pytest.fixture(scope="session")
def audio_tar_path(tmp_path_factory):
    data_path = str(tmp_path_factory.mktemp("data") / "audio.tar")
    with tarfile.open(data_path, "w") as fp:
        # 25 audio samples
        for idx in range(25):
            key = "%02d" % idx
            wav, text = create_audio_sample()
            info = tarfile.TarInfo(key + ".wav")
            info.size = wav.getbuffer().nbytes
            fp.addfile(info, wav)
            info = tarfile.TarInfo(key + ".transcript.txt")
            info.size = text.getbuffer().nbytes
            fp.addfile(info, text)
    return data_path


@pytest.mark.parametrize(
    "mixing_ratios,resample,batch_size",
    [
        ([0.9, 0.1], False, 10),  # we run out of image samples after the second batch
        ([0.9, 0.1], True, 10),  # since we resample, there are enough samples for >2 batches
        ([0.7, 0.3], False, 20),  # only 10 image samples exist, so the first batch can't have 0.7 * 20 = 14 images
        ([0.3, 0.6], False, 10),  # ratios don't add up to 1
        ([0.8, 0.2], True, 100),
    ],
)
def test_web_dataloader(image_tar_path, audio_tar_path, mixing_ratios, resample, batch_size):
    class DataloaderTestModel(BaseModel):
        train_dataloader: PydanticDataLoaderIFType

    config_file_path = _ROOT_DIR / Path("tests/dataloader/yaml_configs/web_dataloader.yaml")
    config_dict = load_app_config_dict(config_file_path=config_file_path)
    config_dict["image_dataset"]["config"]["urls"] = image_tar_path
    config_dict["audio_dataset"]["config"]["urls"] = audio_tar_path
    config_dict["train_dataset"]["config"]["mixing_ratios"] = mixing_ratios
    config_dict["train_dataset"]["config"]["resample"] = resample
    config_dict["train_dataset"]["config"]["batch_size"] = batch_size
    config_dict["train_dataloader"]["config"]["batch_size"] = batch_size
    registry = Registry(COMPONENTS)
    component_factory = ComponentFactory(registry=registry)
    components = component_factory.build_components(config_dict=config_dict, components_model_type=DataloaderTestModel)

    expected_images = int(mixing_ratios[0] * batch_size)
    expected_audio = int(mixing_ratios[1] * batch_size)
    # if ratios don't add up to 1, extra samples are added to first modality
    remaining = batch_size - (expected_audio + expected_images)
    expected_images += remaining
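    # e.g. mixing_ratios=[0.3, 0.6] with batch_size=10 gives 3 image + 6 audio
    # samples; the 1 leftover slot goes to images, so each full batch has 4 + 6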

    loader = iter(components.train_dataloader)

    # image, audio
    total_samples = [10, 25]
    seen_samples = [0, 0]

    for idx in range(5):
        batch_expected_images = expected_images
        batch_expected_audio = expected_audio
        try:
            batch = next(loader)
        except StopIteration:
            break

        if not resample:
            # if resample is False, the last batch may have fewer
            # samples than expected if one of the modalities
            # runs out of samples
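            # e.g. with mixing_ratios=[0.9, 0.1] and batch_size=10, only 1 of the
            # 10 image samples remains for the second batch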
            if total_samples[0] - seen_samples[0] < expected_images:
                batch_expected_images = total_samples[0] - seen_samples[0]
            if total_samples[1] - seen_samples[1] < expected_audio:
                batch_expected_audio = total_samples[1] - seen_samples[1]

        assert batch.samples["images"].shape[0] == batch_expected_images
        seen_samples[0] += batch.samples["images"].shape[0]
        assert batch.samples["audio"].shape[0] == batch_expected_audio
        seen_samples[1] += batch.samples["audio"].shape[0]
        assert batch.samples["input_ids"].shape[0] == batch_expected_audio + batch_expected_images
        for modality_idx in range(2):
            # reset if the complete dataset has been seen already
            if seen_samples[modality_idx] == total_samples[modality_idx]:
                seen_samples[modality_idx] = 0
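As a side note, a quick way to sanity-check the tars these fixtures produce is to iterate them with webdataset directly. The sketch below is not part of the commit; it assumes webdataset's standard extension-based grouping, where the tar members "00.wav" and "00.transcript.txt" become the keys "wav" and "transcript.txt" of a single sample:

import webdataset as wds

# not part of the commit: inspect the raw samples in a fixture tar
dataset = wds.WebDataset("audio.tar")  # e.g. the path returned by audio_tar_path
for sample in dataset:
    # each sample is a dict of raw bytes keyed by extension, plus __key__/__url__
    print(sample["__key__"], sorted(k for k in sample if not k.startswith("__")))
    break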
tests/dataloader/yaml_configs/web_dataloader.yaml (new file, 111 additions, 0 deletions)
@@ -0,0 +1,111 @@
tokenizer:
  component_key: tokenizer
  variant_key: pretrained_hf_tokenizer
  config:
    pretrained_model_name_or_path: openai/clip-vit-base-patch32
    padding: true
    max_length: 50

train_image_transform:
  component_key: transform
  variant_key: image_transform
  config:
    is_training: True
    input_size: 224

train_audio_transform:
  component_key: transform
  variant_key: audio_transform
  config:
    is_training: True
    block_size_audio_encoder: 500
    freq_domain_mask_length: 30
    time_domain_mask_length: 100

text_transform:
  component_key: transform
  variant_key: text_transform
  config:
    tokenizer:
      instance_key: tokenizer
      pass_type: BY_REFERENCE

collate_fn:
  component_key: collate_fn
  variant_key: coca_collator
  config:
    sample_keys:
      - images
      - audio
      - audio_len
      - input_ids
    target_keys: []
    text_sample_key: input_ids
    text_target_key: logits

image_dataset:
  component_key: dataset
  variant_key: web_dataset_builder
  config:
    urls: None
    modality_key_mapping:
      TEXT: ["json_text0", "input_ids"]
      IMAGE: ["jpg", "images"]
    modality_transforms:
      IMAGE:
        instance_key: train_image_transform
        pass_type: BY_REFERENCE
      TEXT:
        instance_key: text_transform
        pass_type: BY_REFERENCE
    num_samples: 10

audio_dataset:
  component_key: dataset
  variant_key: web_dataset_builder
  config:
    urls: None
    modality_key_mapping:
      TEXT: ["transcript.txt", "input_ids"]  # source and target keys
      AUDIO: ["wav", "audio"]
    modality_transforms:
      AUDIO:
        instance_key: train_audio_transform
        pass_type: BY_REFERENCE
      TEXT:
        instance_key: text_transform
        pass_type: BY_REFERENCE
    num_samples: 10

train_dataset:
  component_key: dataset
  variant_key: web_dataset
  config:
    builders:
      - instance_key: image_dataset
        pass_type: BY_REFERENCE
      - instance_key: audio_dataset
        pass_type: BY_REFERENCE
    mixing_ratios: [0.9, 0.1]
    batch_size: 10
    shardshuffle: 100
    repeat: false
    resample: false
    shuffle_buffer: 10_000

train_dataloader:
  component_key: data_loader
  variant_key: web_dataloader
  config:
    num_workers: 0
    pin_memory: true
    drop_last: true
    dataloader_tag: "train"
    dataset:
      instance_key: train_dataset
      pass_type: BY_REFERENCE
    batch_size: 10
    collate_fn:
      instance_key: collate_fn
      pass_type: BY_REFERENCE
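For reference, a minimal sketch of consuming the dataloader this config describes, assuming components was built via ComponentFactory as in the test above; the per-batch counts follow from mixing_ratios [0.9, 0.1] and batch_size 10:

# sketch, not part of the commit
batch = next(iter(components.train_dataloader))
assert batch.samples["images"].shape[0] == 9      # int(0.9 * 10) image samples
assert batch.samples["audio"].shape[0] == 1       # int(0.1 * 10) audio sample
assert batch.samples["input_ids"].shape[0] == 10  # captions from both modalities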
