diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index 1283204cea..6e77b67bb3 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -2897,6 +2897,7 @@ def get_return_object(self, responses, return_meta_data):
     "watsonx-sdk",
     "rits",
     "azure",
+    "vertex-ai",
 ]


@@ -2911,7 +2912,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
     user requests.

     Current _supported_apis = ["watsonx", "together-ai", "open-ai", "aws", "ollama",
-    "bam", "watsonx-sdk", "rits"]
+    "bam", "watsonx-sdk", "rits", "vertex-ai"]

     Args:
         provider (Optional):
@@ -3020,6 +3021,11 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
             "gpt-3.5-turbo-16k-0613": "azure/gpt-3.5-turbo-16k-0613",
             "gpt-4-vision": "azure/gpt-4-vision",
         },
+        "vertex-ai": {
+            "llama-3-1-8b-instruct": "vertex_ai/meta/llama-3.1-8b-instruct-maas",
+            "llama-3-1-70b-instruct": "vertex_ai/meta/llama-3.1-70b-instruct-maas",
+            "llama-3-1-405b-instruct": "vertex_ai/meta/llama-3.1-405b-instruct-maas",
+        },
     }

     _provider_to_base_class = {
@@ -3032,6 +3038,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
         "watsonx-sdk": WMLInferenceEngine,
         "rits": RITSInferenceEngine,
         "azure": LiteLLMInferenceEngine,
+        "vertex-ai": LiteLLMInferenceEngine,
     }

     _provider_param_renaming = {
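For context when reviewing, here is a minimal usage sketch of the new provider path. It assumes Vertex AI credentials are already configured the way LiteLLM expects them (e.g. `GOOGLE_APPLICATION_CREDENTIALS` plus project/location settings); the model alias, generation parameters, and the hand-built input instance are illustrative only, not part of the diff.

```python
# Sketch: selecting the new "vertex-ai" provider through CrossProviderInferenceEngine.
# Assumes Vertex AI auth/project configuration is already set in the environment
# as required by LiteLLM; values below are illustrative.
from unitxt.inference import CrossProviderInferenceEngine

engine = CrossProviderInferenceEngine(
    model="llama-3-1-8b-instruct",  # resolved via the new vertex-ai model map
    provider="vertex-ai",           # routed to LiteLLMInferenceEngine per _provider_to_base_class
    max_tokens=256,
    temperature=0.0,
)

# A single hand-built instance for illustration; normally the dataset comes
# from a unitxt recipe.
instances = [
    {
        "source": "What is the capital of France?",
        "data_classification_policy": ["public"],
    }
]

predictions = engine.infer(instances)
print(predictions[0])
```

With this mapping in place, the same alias (e.g. `llama-3-1-8b-instruct`) can be kept in recipes while switching `provider` between `watsonx`, `azure`, and now `vertex-ai`.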