diff --git a/trankit/adapter_transformers/tokenization_utils.py b/trankit/adapter_transformers/tokenization_utils.py
index 3bb91c0..663ddde 100644
--- a/trankit/adapter_transformers/tokenization_utils.py
+++ b/trankit/adapter_transformers/tokenization_utils.py
@@ -924,7 +924,15 @@ def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs)
         if pretrained_model_name_or_path in s3_models:
             # Get the vocabulary from AWS S3 bucket
             for file_id, map_list in cls.pretrained_vocab_files_map.items():
-                vocab_files[file_id] = map_list[pretrained_model_name_or_path]
+                basename_from_s3 = os.path.basename(map_list[pretrained_model_name_or_path])
+                model_prefix = slice(len(pretrained_model_name_or_path) + 1, None)
+                fname = basename_from_s3[model_prefix]  # filename as mentioned on hf.co/models
+                model_folder_exists = os.path.isdir(pretrained_model_name_or_path)
+                fname_exists = os.path.exists(os.path.join(pretrained_model_name_or_path, fname))
+                if model_folder_exists and fname_exists:
+                    vocab_files[file_id] = os.path.join(pretrained_model_name_or_path, fname)
+                else:
+                    vocab_files[file_id] = map_list[pretrained_model_name_or_path]
             if (
                 cls.pretrained_init_configuration
                 and pretrained_model_name_or_path in cls.pretrained_init_configuration
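
The change above prefers a vocabulary file found in a local folder named after the model over the S3 download, and falls back to the original URL otherwise. A minimal sketch of the filename derivation, assuming an illustrative model name and S3 URL (neither value is taken from the diff):

    import os

    # Illustrative values; in the patch they come from cls.pretrained_vocab_files_map.
    pretrained_model_name_or_path = "bert-base-uncased"
    s3_url = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt"

    basename_from_s3 = os.path.basename(s3_url)                        # "bert-base-uncased-vocab.txt"
    model_prefix = slice(len(pretrained_model_name_or_path) + 1, None) # drop "<model name>-"
    fname = basename_from_s3[model_prefix]                             # "vocab.txt"

    # If ./bert-base-uncased/vocab.txt exists, the patch uses it; otherwise it keeps the URL.
    local_path = os.path.join(pretrained_model_name_or_path, fname)
    vocab_file = local_path if os.path.exists(local_path) else s3_url
    print(vocab_file)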