
Commit 152ebf2
docs: add docstring to WebDataLoader, type hints to _init_modality
davidkaczer committed Oct 7, 2024
1 parent 8312ed6 commit 152ebf2
Showing 2 changed files with 18 additions and 1 deletion.
14 changes: 14 additions & 0 deletions src/modalities/dataloader/dataloader.py
@@ -289,6 +289,8 @@ def __len__(self) -> int:
 
 
 class WebDataLoader(DataLoaderIF):
+    """WebDataLoader is a custom DataLoader class that wraps the webdataset.WebLoader class."""
+
     def __init__(
         self,
         dataloader_tag: str,
@@ -299,6 +301,18 @@ def __init__(
         pin_memory: bool = False,
         drop_last: bool = False,
     ):
+        """Initializes WebDataLoader, which is a wrapper for webdataset.WebLoader.
+
+        Args:
+            dataloader_tag (str): The tag for the dataloader.
+            dataset (Dataset[T_co]): The dataset to load the data from.
+            batch_size (Optional[int], optional): The batch size. Defaults to 1.
+            num_workers (int, optional): The number of worker processes to use for data loading. Defaults to 0.
+            collate_fn (Optional[_collate_fn_t], optional): The function used to collate the data samples.
+                Defaults to None.
+            pin_memory (bool, optional): Flag indicating whether to pin the memory. Defaults to False.
+            drop_last (bool, optional): Flag indicating whether to drop the last incomplete batch. Defaults to False.
+        """
         self.num_batches = len(dataset) // batch_size + int(not drop_last)
         dataset = dataset.batched(batch_size, collation_fn=collate_fn)
         self.webloader = wd.WebLoader(dataset=dataset, batch_size=None, num_workers=num_workers, pin_memory=pin_memory)
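
For orientation, a minimal usage sketch of the wrapper documented above (not part of this commit). The shard pattern, pipeline stages, length, and batch size are illustrative; the dataset must expose a length, since the constructor computes num_batches from len(dataset):

import webdataset as wds

from modalities.dataloader.dataloader import WebDataLoader  # per the file path above

# Illustrative WebDataset pipeline; with_length supplies the __len__ that
# WebDataLoader uses to derive num_batches.
dataset = (
    wds.WebDataset("shards/train-{000000..000099}.tar")
    .decode("pil")
    .to_tuple("jpg", "cls")
    .with_length(10_000)
)

loader = WebDataLoader(
    dataloader_tag="train",
    dataset=dataset,
    batch_size=32,
    num_workers=4,
    pin_memory=True,
    drop_last=False,  # 10_000 // 32 + 1 = 313 batches, the last one partial
)
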
5 changes: 4 additions & 1 deletion src/modalities/models/coca/coca_model.py
@@ -309,7 +309,10 @@ def __init__(
         # Logit scale for contrastive loss
         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
 
-    def _init_modality(self, encoder_class, encoder_config, n_queries):
+    def _init_modality(
+        self, encoder_class: type, encoder_config: VisionTransformerConfig | AudioTransformerConfig, n_queries: int
+    ) -> tuple[VisionTransformer | AudioTransformer, nn.Parameter, AttentionPooling]:
+        # initialize modality encoder, returns a tuple containing the encoder, queries and attention pooling layer
         encoder = encoder_class(**dict(encoder_config))
         queries = nn.Parameter(torch.randn(n_queries + 1, encoder_config.n_embd))
         attn_pool = AttentionPooling(
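
For orientation, a standalone sketch of the pattern _init_modality implements, using illustrative stand-ins rather than the repository's classes (only n_embd and the n_queries + 1 query allocation come from the diff above):

import torch
from torch import nn

def init_modality_sketch(
    encoder_class: type, n_embd: int, n_queries: int
) -> tuple[nn.Module, nn.Parameter, nn.Module]:
    # Mirrors the helper: build the encoder, allocate n_queries + 1 learnable
    # queries of the encoder's width, and pair them with a pooling module.
    encoder = encoder_class()
    queries = nn.Parameter(torch.randn(n_queries + 1, n_embd))
    attn_pool = nn.Identity()  # stand-in for AttentionPooling
    return encoder, queries, attn_pool

encoder, queries, attn_pool = init_modality_sketch(nn.Identity, n_embd=64, n_queries=256)
print(queries.shape)  # torch.Size([257, 64])
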
