diff --git a/clip/clip.py b/clip/clip.py index 257511e1d..535219421 100644 --- a/clip/clip.py +++ b/clip/clip.py @@ -2,7 +2,7 @@ import os import urllib import warnings -from typing import Any, Union, List +from typing import Any, Union, List, Sequence from pkg_resources import packaging import torch @@ -194,7 +194,7 @@ def patch_float(module): return model, _transform(model.input_resolution.item()) -def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: +def tokenize(texts: Union[str, Sequence[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: """ Returns the tokenized representation of given input string(s)