Skip to content

Commit

Permalink
Update to use vertexai package
Browse files Browse the repository at this point in the history
  • Loading branch information
sydp committed Dec 19, 2024
1 parent 1abb28c commit 326d6d7
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

import backoff
from google.api_core import exceptions
import google.generativeai as genai
from google.cloud import aiplatform
import ratelimit
import vertexai
from vertexai import generative_models

from dftimewolf.lib.processors.llmproviders import interface
from dftimewolf.lib.processors.llmproviders import manager
Expand Down Expand Up @@ -37,29 +37,43 @@ class VertexAILLMProvider(interface.LLMProvider):
def __init__(self) -> None:
"""Initializes the VertexAILLMProvider."""
super().__init__()
self.chat_session: genai.ChatSession | None = None
self.chat_session: generative_models.ChatSession | None = None
self._configure()

def _configure(self) -> None:
"""Configures the vertexai client."""
if 'api_key' in self.options:
aiplatform.init(api_key=self.options['api_key'])
vertexai.init(api_key=self.options['api_key'])
elif 'project_id' in self.options or 'region' in self.options:
aiplatform.init(
vertexai.init(
project=self.options['project_id'],
location=self.options['region']
)
elif os.environ.get('GOOGLE_API_KEY'):
aiplatform.init(api_key=os.environ.get('GOOGLE_API_KEY'))
vertexai.init(api_key=os.environ.get('GOOGLE_API_KEY'))
else:
raise RuntimeError('API key or project_id/region must be set.')

def _get_model(self, model: str) -> genai.GenerativeModel:
def _get_model(
self,
model: str
) -> generative_models.GenerativeModel:
"""Returns the generative model."""
model_name = f"models/{model}"
generation_config = self.models[model]['options'].get('generative_config')
safety_settings = self.models[model]['options'].get('safety_settings')
return genai.GenerativeModel(
safety_settings = [
generative_models.SafetySetting(
category=generative_models.HarmCategory[
safety_setting['category']
],
threshold=generative_models.HarmBlockThreshold[
safety_setting['threshold']
]
) for safety_setting in (
self.models[model]['options'].get('safety_settings')
)
]
return generative_models.GenerativeModel(
model_name=model_name,
generation_config=generation_config,
safety_settings=safety_settings
Expand Down Expand Up @@ -101,8 +115,6 @@ def Generate(self, prompt: str, model: str, **kwargs: str) -> str:
genai_model = self._get_model(model)
try:
response = genai_model.generate_content(contents=prompt, **kwargs)
except genai.types.generation_types.StopCandidateException as e:
return f"VertexAI LLM response was stopped because of: {e}"
except Exception as e:
log.warning("Exception while calling VertexAI: %s", e)
raise
Expand Down Expand Up @@ -134,8 +146,6 @@ def GenerateWithHistory(self, prompt: str, model: str, **kwargs: str) -> str:
self.chat_session = self._get_model(model).start_chat()
try:
response = self.chat_session.send_message(prompt, **kwargs)
except genai.types.generation_types.StopCandidateException as e:
return f"VertexAI LLM response was stopped because of: {e}"
except Exception as e:
log.warning("Exception while calling VertexAI: %s", e)
raise
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ grr-api-client = "^3.4.7"
libcloudforensics = {git = "https://github.com/google/cloud-forensics-utils.git"}
docker = "^7.0.0"
setuptools = "^70.0.0" # needed by docker
google-cloud-aiplatform = "^1.74.0"
google-generativeai = "^0.8.3"
ratelimit = "^2.2.1"
backoff = "^2.2.1"
vertexai = "^1.71.1"

[tool.poetry.extras]
turbinia_legacy = ["turbinia"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from unittest import mock

from dftimewolf import config
from dftimewolf.lib.processors.llmproviders import vertexai
from dftimewolf.lib.processors.llmproviders import vertex_ai


GENAI_CONFIG = {
VERTEX_AI_CONFIG = {
'llm_providers': {
'vertexai': {
'options': {
Expand All @@ -22,12 +22,7 @@
'temperature': 0.2,
'max_output_tokens': 8192,
},
'safety_settings': [
{
'category': 'example_category',
'threshold': 'BLOCK_NONE',
}
]
'safety_settings': []
},
'tasks': [
'generate'
Expand All @@ -42,7 +37,7 @@
class VertexAILLMProviderTest(unittest.TestCase):
"""Test for the VertexAILLMProvider."""

@mock.patch('google.cloud.aiplatform.init')
@mock.patch('vertexai.init')
def test_configure_api_key(self, mock_gen_config):
"""Tests the configuration with an API key."""
config.Config.LoadExtraData(json.dumps(
Expand All @@ -58,12 +53,12 @@ def test_configure_api_key(self, mock_gen_config):
}
}
).encode('utf-8'))
provider = vertexai.VertexAILLMProvider()
provider = vertex_ai.VertexAILLMProvider()

self.assertEqual(provider.options['api_key'], 'test_api_key')
mock_gen_config.assert_called_with(api_key='test_api_key')

@mock.patch('google.cloud.aiplatform.init')
@mock.patch('vertexai.init')
def test_configure_project_id_region(self, mock_gen_config):
"""Tests the configuration with a project ID/region"""
config.Config.LoadExtraData(json.dumps(
Expand All @@ -80,7 +75,7 @@ def test_configure_project_id_region(self, mock_gen_config):
}
}
).encode('utf-8'))
provider = vertexai.VertexAILLMProvider()
provider = vertex_ai.VertexAILLMProvider()

self.assertEqual(provider.options['project_id'], 'myproject')
self.assertEqual(provider.options['region'], 'australia-southeast2')
Expand All @@ -89,16 +84,16 @@ def test_configure_project_id_region(self, mock_gen_config):
)

@mock.patch.dict(
vertexai.os.environ,
vertex_ai.os.environ,
values={'GOOGLE_API_KEY': 'fake_env_key'},
clear=True)
@mock.patch('google.cloud.aiplatform.init')
@mock.patch('vertexai.init')
def test_configure_env(self, mock_gen_config):
"""Tests the configuration with an environment variable."""
config.Config.LoadExtraData(json.dumps(
{'llm_providers': {'vertexai': {'options': {},'models': {}}}}
).encode('utf-8'))
provider = vertexai.VertexAILLMProvider()
provider = vertex_ai.VertexAILLMProvider()
self.assertIsNotNone(provider)
mock_gen_config.assert_called_with(api_key='fake_env_key')

Expand All @@ -109,19 +104,19 @@ def test_configure_empty(self):
).encode('utf-8'))
with self.assertRaisesRegex(
RuntimeError, 'API key or project_id/region must be set'):
_ = vertexai.VertexAILLMProvider()
_ = vertex_ai.VertexAILLMProvider()


@mock.patch('google.cloud.aiplatform.init')
@mock.patch('google.generativeai.GenerativeModel', autospec=True)
@mock.patch('vertexai.init')
@mock.patch('vertexai.generative_models.GenerativeModel', autospec=True)
def test_generate(self, mock_gen_model, mock_gen_config):
"""Tests the generate method."""
mock_gen_model.return_value.generate_content.return_value.text = (
'test generate'
)

config.Config.LoadExtraData(json.dumps(GENAI_CONFIG).encode('utf-8'))
provider = vertexai.VertexAILLMProvider()
config.Config.LoadExtraData(json.dumps(VERTEX_AI_CONFIG).encode('utf-8'))
provider = vertex_ai.VertexAILLMProvider()
resp = provider.Generate(prompt='123', model='fake-gemini')

self.assertEqual(resp, 'test generate')
Expand All @@ -134,24 +129,19 @@ def test_generate(self, mock_gen_model, mock_gen_config):
'temperature': 0.2,
'max_output_tokens': 8192,
},
safety_settings=[
{
'category': 'example_category',
'threshold': 'BLOCK_NONE',
}
]
safety_settings=[]
)

@mock.patch('google.cloud.aiplatform.init')
@mock.patch('google.generativeai.GenerativeModel', autospec=True)
@mock.patch('vertexai.init')
@mock.patch('vertexai.generative_models.GenerativeModel', autospec=True)
def test_generate_with_history(self, mock_gen_model, mock_gen_config):
"""Tests the GenerateWithHistory method."""
chat_instance = mock.MagicMock()
mock_gen_model.return_value.start_chat.return_value = chat_instance
chat_instance.send_message.return_value.text = 'test generate'
config.Config.LoadExtraData(json.dumps(GENAI_CONFIG).encode('utf-8'))
config.Config.LoadExtraData(json.dumps(VERTEX_AI_CONFIG).encode('utf-8'))

provider = vertexai.VertexAILLMProvider()
provider = vertex_ai.VertexAILLMProvider()
resp = provider.GenerateWithHistory(prompt='123', model='fake-gemini')
self.assertEqual(resp, 'test generate')
mock_gen_config.assert_called_once_with(
Expand All @@ -163,12 +153,7 @@ def test_generate_with_history(self, mock_gen_model, mock_gen_config):
'temperature': 0.2,
'max_output_tokens': 8192,
},
safety_settings=[
{
'category': 'example_category',
'threshold': 'BLOCK_NONE',
}
]
safety_settings=[]
)


Expand Down

0 comments on commit 326d6d7

Please sign in to comment.