Skip to content

Commit

Permalink
Added HF MM LLM
Browse files Browse the repository at this point in the history
  • Loading branch information
dhuynh-lovable committed May 18, 2024
1 parent 0b67c8a commit fec8e9e
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions lavague-core/lavague/core/world_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,62 @@ def get_instruction(self, state: str, objective: str) -> str:
Thought:
"""
)


from text_generation import Client
import os
import cloudinary
import cloudinary.uploader
from llama_index.core.base.llms.types import CompletionResponse

BASE_URL = "https://api-inference.huggingface.co/models/"
BASE_MODEL= "HuggingFaceM4/idefics2-8b"

class HuggingFaceMMLLM:
def __init__(self, api_key=None, cloudinary_config=None, model=BASE_MODEL, base_url = BASE_URL):
if api_key is None:
api_key = os.getenv("HF_TOKEN")
if api_key is None:
raise ValueError("HF_TOKEN is not set")

if cloudinary_config is None:
cloudinary_config = {
"cloud_name": os.getenv("CLOUDINARY_CLOUD_NAME"),
"api_key": os.getenv("CLOUDINARY_API_KEY"),
"api_secret": os.getenv("CLOUDINARY_API_SECRET"),
}
if None in cloudinary_config.values():
raise ValueError("CLOUDINARY_CLOUD_NAME, CLOUDINARY_API_KEY, or CLOUDINARY_API_SECRET is not set")

cloudinary.config(**cloudinary_config)
api_url = base_url + model

self.client = Client(
base_url=api_url,
headers={"x-use-cache": "0", "Authorization": f"Bearer {api_key}"},
)

def upload_image(self, image_documents):
n_documents = len(image_documents)
if n_documents != 1:
raise ValueError(f"Expected 1 image document, but got {n_documents}")
file_path = image_documents[0].metadata["file_path"]

img_url = cloudinary.uploader.upload(file_path)["url"]
return img_url

def complete(self, prompt, image_documents):

generation_args = {
"max_new_tokens": 512,
"repetition_penalty": 1.1,
"do_sample": False,
}

img_url = self.upload_image(image_documents)
prompt_with_image = f"User:![]({img_url})" + prompt
print(prompt_with_image)
generated_text = self.client.generate(prompt=prompt_with_image, **generation_args).generated_text

output = CompletionResponse(text=generated_text)
return output

0 comments on commit fec8e9e

Please sign in to comment.