Skip to content

Commit

Permalink
Added suppotr for Llama-3.2, and images for all olllama models (#195)
Browse files Browse the repository at this point in the history
  • Loading branch information
openvmp authored Oct 18, 2024
1 parent 49cdbc1 commit b4e64a1
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 15 deletions.
2 changes: 1 addition & 1 deletion partcad/src/partcad/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"gemini-pro-vision",
"gemini-1.5-pro",
"gemini-1.5-flash",
"llama3.1*",
"llama3.*",
"codellama*",
"codegemma*",
"gemma*",
Expand Down
20 changes: 18 additions & 2 deletions partcad/src/partcad/ai_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
# Licensed under Apache License, Version 2.0.
#

import base64
import importlib
import httpx
from path import Path
import re
import threading
import time
from typing import Any
Expand Down Expand Up @@ -70,8 +73,18 @@ def generate_ollama(
if not ollama_once():
return None

if "INSERT_IMAGE_HERE" in prompt:
raise NotImplementedError("Images are not supported by Ollama")
image_content = []

def insert_image(match):
filename = match.group(1)
image_index = len(image_content)
image_content.append(
Path(filename).read_bytes(),
# base64.b64encode(Path(filename).read_bytes()).decode(),
)
return f"The attached image number {image_index}.\n"

prompt = re.sub(r"INSERT_IMAGE_HERE\(([^)]*)\)", insert_image, prompt)

if "tokens" in config:
tokens = config["tokens"]
Expand Down Expand Up @@ -106,10 +119,13 @@ def generate_ollama(
top_k=top_k,
temperature=temperature,
)
pc_logging.debug("Prompt: %s" % prompt)
pc_logging.debug("Images: %d" % len(image_content))
response = ollama.generate(
model=model,
context=[], # do not accumulate context uncontrollably
prompt=prompt,
images=image_content,
options=options,
)
except httpx.ConnectError as e:
Expand Down
38 changes: 26 additions & 12 deletions partcad/src/partcad/part_factory_feature_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ def on_init_ai(self):
constructor to finalize the AI initialization. At the time of the call
self.part and self.instantiate must be already defined."""
self.part.generate = lambda path: self._create_file(path)
self.part.change = lambda path, change=None: self._change_file(path, change)
self.part.change = lambda path, change=None: self._change_file(
path, change
)

# If uncommented out, this makes the package initialization
# unaccceptably slow
Expand Down Expand Up @@ -199,10 +201,8 @@ def _create_file(self, path):

# Validate the image by rendering it,
# attempt to correct the script if rendering doesn't work
image_filename, changed_script = (
self._validate_and_fix(
changed_script, candidate_id
)
image_filename, changed_script = self._validate_and_fix(
changed_script, candidate_id
)
# Check if the model was valid
if image_filename is not None:
Expand Down Expand Up @@ -239,6 +239,11 @@ def _change_file(self, path, change=None):

script = open(path, "r").read()

if change is not None:
change = change.strip()
if change == "":
change = None

image_filename, error_text = self._render_image(script, 0)
if image_filename is None or error_text:
pc_logging.error(
Expand All @@ -251,7 +256,9 @@ def _change_file(self, path, change=None):

# Attempt to change the script once more by comparing the result with
# the original request
changed_scripts = self._change_script(None, script, image_filename, change)
changed_scripts = self._change_script(
None, script, image_filename, change
)
for changed_script in changed_scripts:
pc_logging.debug(
"Generated the changed script: %s" % changed_script
Expand All @@ -278,7 +285,9 @@ def _change_file(self, path, change=None):
new_script = script_candidates[0][1]
else:
# Compare the images and select the best one
new_script = self.select_best_image(script_candidates, change=change)
new_script = self.select_best_image(
script_candidates, change=change
)

if new_script == script:
pc_logging.info("The script was not changed")
Expand Down Expand Up @@ -365,7 +374,7 @@ def _generate_script(self, csg_instructions):
"""This method generates a script given specific CSG description."""

prompt = """You are an AI assistant in an engineering department.
You are helping engineers to create programmatic scripts that produce CAD geometry data
You are helping engineers by writing scripts that produce CAD geometry data
for parts, mechanisms, buildings or anything else.
The scripts you create are fully functional and can be used right away, as is, in automated workflows.
Assume that the scripts you produce are used automatically to render 3D models and to validate them.
Expand Down Expand Up @@ -415,7 +424,9 @@ def _generate_script(self, csg_instructions):

return scripts

def _change_script(self, csg_instructions, script, rendered_image, change=None):
def _change_script(
self, csg_instructions, script, rendered_image, change=None
):
"""This method changes the script given the original request and the produced script."""

config = copy.copy(self.ai_config)
Expand Down Expand Up @@ -469,7 +480,7 @@ def _change_script(self, csg_instructions, script, rendered_image, change=None):
if change is not None:
prompt += f"\n\n{change}\n"
else:
prompt += """
prompt += """
Please, analyze whether the produced script and image match the original request
(where the original image and description take precedence
Expand Down Expand Up @@ -753,12 +764,15 @@ def select_best_image(self, script_candidates, change=None):
prompt += "INSERT_IMAGE_HERE(%s)\n" % image_filename

if change is not None:
prompt += """
prompt += (
"""
Subsequently, the following changes were requested (until "CHANGE END"):
%s
CHANGE END
""" % change
"""
% change
)

prompt += """
Expand Down

0 comments on commit b4e64a1

Please sign in to comment.