diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py
index ac62eec4..b083e057 100644
--- a/src/modalities/dataloader/dataloader.py
+++ b/src/modalities/dataloader/dataloader.py
@@ -289,6 +289,8 @@ def __len__(self) -> int:
 
 
 class WebDataLoader(DataLoaderIF):
+    """WebDataLoader is a custom DataLoader class that wraps the webdataset.WebLoader class."""
+
     def __init__(
         self,
         dataloader_tag: str,
@@ -299,6 +301,18 @@ def __init__(
         pin_memory: bool = False,
         drop_last: bool = False,
     ):
+        """Initializes WebDataLoader, which is a wrapper for webdataset.WebLoader.
+
+        Args:
+            dataloader_tag (str): The tag for the dataloader.
+            dataset (Dataset[T_co]): The dataset to load the data from.
+            batch_size (Optional[int], optional): The batch size. Defaults to 1.
+            num_workers (int, optional): The number of worker processes to use for data loading. Defaults to 0.
+            collate_fn (Optional[_collate_fn_t], optional): The function used to collate the data samples.
+                Defaults to None.
+            pin_memory (bool, optional): Flag indicating whether to pin the memory. Defaults to False.
+            drop_last (bool, optional): Flag indicating whether to drop the last incomplete batch. Defaults to False.
+        """
         self.num_batches = len(dataset) // batch_size + int(not drop_last)
         dataset = dataset.batched(batch_size, collation_fn=collate_fn)
         self.webloader = wd.WebLoader(dataset=dataset, batch_size=None, num_workers=num_workers, pin_memory=pin_memory)
diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py
index e551cb17..3d4aa759 100644
--- a/src/modalities/models/coca/coca_model.py
+++ b/src/modalities/models/coca/coca_model.py
@@ -309,7 +309,10 @@ def __init__(
         # Logit scale for contrastive loss
         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
 
-    def _init_modality(self, encoder_class, encoder_config, n_queries):
+    def _init_modality(
+        self, encoder_class: type, encoder_config: VisionTransformerConfig | AudioTransformerConfig, n_queries: int
+    ) -> tuple[VisionTransformer | AudioTransformer, nn.Parameter, AttentionPooling]:
+        # Initializes the modality encoder and returns a tuple of (encoder, queries, attention pooling layer).
         encoder = encoder_class(**dict(encoder_config))
         queries = nn.Parameter(torch.randn(n_queries + 1, encoder_config.n_embd))
         attn_pool = AttentionPooling(
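For context on what the new `drop_last` docstring describes, here is a minimal sketch of the batch-count arithmetic from `WebDataLoader.__init__`; the helper name `num_batches` and the sample values are illustrative, not part of the modalities codebase:

```python
# Minimal sketch of the batch-count arithmetic used in WebDataLoader.__init__,
# i.e. `len(dataset) // batch_size + int(not drop_last)`. Illustrative only.
def num_batches(dataset_len: int, batch_size: int, drop_last: bool) -> int:
    # Number of full batches, plus one partial batch when drop_last is False.
    return dataset_len // batch_size + int(not drop_last)

assert num_batches(10, 4, drop_last=True) == 2   # 4 + 4; trailing 2 samples dropped
assert num_batches(10, 4, drop_last=False) == 3  # 4 + 4 + 2; partial batch kept
```

One observation on that (pre-existing) context line: when `len(dataset)` is an exact multiple of `batch_size`, the formula still adds one for `drop_last=False` (e.g. 8 samples at batch size 4 give `num_batches == 3`, though only two batches are produced), which may be worth a follow-up outside this patch.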