diff --git a/lavague-core/lavague/core/world_model.py b/lavague-core/lavague/core/world_model.py
index 783fbf57..907e64a1 100644
--- a/lavague-core/lavague/core/world_model.py
+++ b/lavague-core/lavague/core/world_model.py
@@ -92,3 +92,68 @@ def get_instruction(self, state: str, objective: str) -> str:
         Thought:
         """
         )
+
+
+from text_generation import Client
+import os
+import cloudinary
+import cloudinary.uploader
+from llama_index.core.base.llms.types import CompletionResponse
+
+BASE_URL = "https://api-inference.huggingface.co/models/"
+BASE_MODEL = "HuggingFaceM4/idefics2-8b"
+
+
+class HuggingFaceMMLLM:
+    """Multimodal LLM served through the Hugging Face Inference API.
+
+    Images are uploaded to Cloudinary so they can be referenced by URL
+    inside the prompt sent to the model.
+    """
+
+    def __init__(self, api_key=None, cloudinary_config=None, model=BASE_MODEL, base_url=BASE_URL):
+        if api_key is None:
+            api_key = os.getenv("HF_TOKEN")
+        if api_key is None:
+            raise ValueError("HF_TOKEN is not set")
+
+        if cloudinary_config is None:
+            cloudinary_config = {
+                "cloud_name": os.getenv("CLOUDINARY_CLOUD_NAME"),
+                "api_key": os.getenv("CLOUDINARY_API_KEY"),
+                "api_secret": os.getenv("CLOUDINARY_API_SECRET"),
+            }
+            if None in cloudinary_config.values():
+                raise ValueError("CLOUDINARY_CLOUD_NAME, CLOUDINARY_API_KEY, or CLOUDINARY_API_SECRET is not set")
+
+        cloudinary.config(**cloudinary_config)
+        api_url = base_url + model
+
+        self.client = Client(
+            base_url=api_url,
+            headers={"x-use-cache": "0", "Authorization": f"Bearer {api_key}"},
+        )
+
+    def upload_image(self, image_documents):
+        # Expect exactly one image document; upload it and return its public URL.
+        n_documents = len(image_documents)
+        if n_documents != 1:
+            raise ValueError(f"Expected 1 image document, but got {n_documents}")
+        file_path = image_documents[0].metadata["file_path"]
+
+        img_url = cloudinary.uploader.upload(file_path)["url"]
+        return img_url
+
+    def complete(self, prompt, image_documents):
+        generation_args = {
+            "max_new_tokens": 512,
+            "repetition_penalty": 1.1,
+            "do_sample": False,
+        }
+
+        # Prepend the uploaded image as a markdown-style reference before the text prompt.
+        img_url = self.upload_image(image_documents)
+        prompt_with_image = f"User:![]({img_url})" + prompt
+        generated_text = self.client.generate(prompt=prompt_with_image, **generation_args).generated_text
+
+        return CompletionResponse(text=generated_text)
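
A minimal usage sketch for the new class (untested; assumes HF_TOKEN and the CLOUDINARY_* environment variables are set, and uses an illustrative file name "screenshot.png"; the wrapper only reads metadata["file_path"] from each document):

    from llama_index.core.schema import ImageDocument

    # Picks up HF_TOKEN and the CLOUDINARY_* variables from the environment.
    llm = HuggingFaceMMLLM()

    # "screenshot.png" is a hypothetical local file; only metadata["file_path"]
    # is consumed by upload_image().
    docs = [ImageDocument(image_path="screenshot.png", metadata={"file_path": "screenshot.png"})]

    response = llm.complete("\nDescribe this screenshot.\nAssistant:", docs)
    print(response.text)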