Merge pull request #371 from NexaAI/yifei/fix_nexa_models
Fixing NexaVLMInference, NexaOmniVlmInference and NexaAudioLMInference
Davidqian123 authored Feb 5, 2025
2 parents 156d40c + e1e6265 commit 7813e7d
Showing 14 changed files with 142 additions and 66 deletions.
3 changes: 1 addition & 2 deletions nexa/cli/entry.py
@@ -125,11 +125,10 @@ def run_ggml_inference(args):
tts_engine = "outetts" if "OuteTTS" in local_path else "bark"
inference = NexaTTSInference(model_path=model_path, local_path=local_path, tts_engine=tts_engine, **kwargs)
elif run_type == "AudioLM":
from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference
if is_local_path:
from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference
inference = NexaAudioLMInference(model_path=model_path, local_path=local_path, projector_local_path=projector_local_path, **kwargs)
else:
from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference
inference = NexaAudioLMInference(model_path=model_path, local_path=local_path, **kwargs)
else:
print(f"Unknown task: {run_type}. Skipping inference.")
17 changes: 16 additions & 1 deletion nexa/gguf/llama/audio_lm_cpp.py
@@ -38,6 +38,7 @@ def _load_shared_library(lib_base_name: str, base_path: Path = None):
f"Shared library with base name '{lib_base_name}' not found"
)


def _get_lib(is_qwen: bool = True):
# Specify the base name of the shared library to load
_lib_base_name = "nexa-qwen2-audio-lib_shared" if is_qwen else "omni_audio_shared"
@@ -50,6 +51,7 @@ def _get_lib(is_qwen: bool = True):
)
return _load_shared_library(_lib_base_name, base_path)


# Initialize both libraries
_lib_omni = _get_lib(is_qwen=False)
_lib_qwen = _get_lib(is_qwen=True)
@@ -64,6 +66,8 @@ def _get_lib(is_qwen: bool = True):
# char *prompt;
# int32_t n_gpu_layers;
# };


class omni_context_params(ctypes.Structure):
_fields_ = [
("model", ctypes.c_char_p),
@@ -73,15 +77,20 @@ class omni_context_params(ctypes.Structure):
("n_gpu_layers", ctypes.c_int32),
]


omni_context_params_p = ctypes.POINTER(omni_context_params)
omni_context_p = ctypes.c_void_p

# OMNI_AUDIO_API omni_context_params omni_context_default_params();


def context_default_params(is_qwen: bool = True) -> omni_context_params:
_lib = _lib_qwen if is_qwen else _lib_omni
return _lib.omni_context_default_params()

# OMNI_AUDIO_API struct omni_context *omni_init_context(omni_context_params &params);


def init_context(params: omni_context_params_p, is_qwen: bool = True) -> omni_context_p: # type: ignore
_lib = _lib_qwen if is_qwen else _lib_omni
return _lib.omni_init_context(params)
@@ -90,6 +99,8 @@ def init_context(params: omni_context_params_p, is_qwen: bool = True) -> omni_context_p:
# struct omni_context *ctx_omni,
# omni_context_params &params
# );


def process_full(ctx: omni_context_p, params: omni_context_params_p, is_qwen: bool = True): # type: ignore
_lib = _lib_qwen if is_qwen else _lib_omni
return _lib.omni_process_full(ctx, params)
@@ -110,10 +121,13 @@ def get_str(omni_streaming: ctypes.c_void_p, is_qwen: bool = True):
return _lib.get_str(omni_streaming)

# OMNI_AUDIO_API void omni_free(struct omni_context *ctx_omni);


def free(ctx: omni_context_p, is_qwen: bool = True):
_lib = _lib_qwen if is_qwen else _lib_omni
return _lib.omni_free(ctx)


for lib in [_lib_omni, _lib_qwen]:
# Configure context_default_params
lib.omni_context_default_params.argtypes = []
@@ -127,7 +141,8 @@ def free(ctx: omni_context_p, is_qwen: bool = True):
lib.omni_process_full.argtypes = [omni_context_p, omni_context_params_p]
lib.omni_process_full.restype = ctypes.c_char_p

lib.omni_process_streaming.argtypes = [omni_context_p, omni_context_params_p]
lib.omni_process_streaming.argtypes = [
omni_context_p, omni_context_params_p]
lib.omni_process_streaming.restype = ctypes.c_void_p

lib.sample.argtypes = [ctypes.c_void_p]
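Note: the bindings above follow a simple lifecycle: fill an omni_context_params struct, create a context, run inference, and free the context. Below is a minimal usage sketch, not code from this commit. The paths and prompt are placeholders, the mmproj/file field names are assumed (the struct definition is partially collapsed above), and it assumes omni_context_default_params has its restype configured to return the struct by value.

import ctypes
from nexa.gguf.llama import audio_lm_cpp

# Hypothetical paths and prompt, for illustration only.
params = audio_lm_cpp.context_default_params(is_qwen=True)
params.model = b"/path/to/qwen2-audio-model.gguf"
params.mmproj = b"/path/to/audio-projector.gguf"   # field name assumed
params.file = b"/path/to/input.wav"                # field name assumed
params.prompt = b"Transcribe this clip."
params.n_gpu_layers = 0

ctx = audio_lm_cpp.init_context(ctypes.byref(params), is_qwen=True)
try:
    # omni_process_full returns the whole response as bytes (restype is c_char_p).
    response = audio_lm_cpp.process_full(ctx, ctypes.byref(params), is_qwen=True)
    print(response.decode("utf-8"))
finally:
    audio_lm_cpp.free(ctx, is_qwen=True)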
3 changes: 2 additions & 1 deletion nexa/gguf/llama/llava_cpp.py
@@ -36,8 +36,9 @@

# Specify the base name of the shared library to load
_libllava_base_name = "llava_shared"
_lib_subdir_name = "llama"
# Load the library
_libllava = load_library(_libllava_base_name)
_libllava = load_library(_libllava_base_name, _lib_subdir_name)

ctypes_function = ctypes_function_for_shared_library(_libllava)

6 changes: 4 additions & 2 deletions nexa/gguf/llama/omni_vlm_cpp.py
@@ -38,7 +38,7 @@ def _load_shared_library(lib_base_name: str, base_path: Path = None):
f"Shared library with base name '{lib_base_name}' not found"
)

def _get_lib():
def _get_lib(sub_dir:str=""):
# Specify the base name of the shared library to load
_lib_base_name = "omni_vlm_wrapper_shared"
base_path = (
@@ -47,10 +47,12 @@ def _get_lib():
/ "gguf"
/ "lib"
)
if sub_dir != "":
base_path = base_path / sub_dir
return _load_shared_library(_lib_base_name, base_path)

# Initialize both libraries
_lib = _get_lib()
_lib = _get_lib('llama')

omni_char_p = ctypes.c_char_p

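Note: the net effect of the new sub_dir argument is that omni_vlm_wrapper_shared is now searched for under the package's gguf/lib/llama/ directory instead of gguf/lib/ directly. A small sketch of the resolution logic under that assumption, with a made-up package root for illustration:

from pathlib import Path

def resolve_lib_dir(package_root: Path, sub_dir: str = "") -> Path:
    # Mirrors _get_lib: start at <package>/gguf/lib and optionally descend
    # into a sub-directory such as "llama" before loading the shared library.
    base_path = package_root / "gguf" / "lib"
    if sub_dir != "":
        base_path = base_path / sub_dir
    return base_path

# With sub_dir="llama", the wrapper is looked up under .../gguf/lib/llama/.
print(resolve_lib_dir(Path("/site-packages/nexa"), "llama"))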
3 changes: 2 additions & 1 deletion nexa/gguf/nexa_inference_audio_lm.py
@@ -6,6 +6,7 @@
import librosa
import soundfile as sf
from pathlib import Path
from typing import Generator
from streamlit.web import cli as stcli
from nexa.utils import SpinningCursorAnimation, nexa_prompt
from nexa.constants import (
@@ -225,7 +226,7 @@ def inference(self, audio_path: str, prompt: str = "") -> str:
except Exception as e:
raise RuntimeError(f"Error during inference: {str(e)}")

def inference_streaming(self, audio_path: str, prompt: str = "") -> str:
def inference_streaming(self, audio_path: str, prompt: str = "") -> Generator[str, None, None]:
"""
Perform a single inference with the audio language model.
"""
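Note: with the return annotation corrected to Generator[str, None, None], the streaming API can be consumed with a plain for loop. A hypothetical caller sketch follows; the model name and audio path are placeholders, not values from this commit.

from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference

# Placeholder model and audio file, for illustration only.
inference = NexaAudioLMInference(model_path="qwen2audio")
for fragment in inference.inference_streaming("sample.wav", "Describe this audio."):
    # Each yielded item is a str fragment of the response.
    print(fragment, end="", flush=True)
print()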
6 changes: 3 additions & 3 deletions nexa/gguf/server/nexa_service.py
@@ -20,6 +20,7 @@
from urllib.parse import urlparse
import asyncio


from nexa.constants import (
NEXA_MODELS_HUB_OFFICIAL_DIR,
NEXA_OFFICIAL_MODELS_TYPE,
@@ -49,9 +50,6 @@
from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr
from nexa.general import add_model_to_list, default_use_processes, download_file_with_progress, get_model_info, is_model_exists, pull_model
from nexa.gguf.llama.llama import Llama
# temporarily disabled NexaOmniVlmInference and NexaAudioLMInference
# from nexa.gguf.nexa_inference_vlm_omni import NexaOmniVlmInference
# from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference
from faster_whisper import WhisperModel
import numpy as np
import argparse
@@ -486,6 +484,7 @@ async def load_model():
# Therefore, model initialization is deferred until the text-to-speech API is called.
model = None
elif model_type == "Multimodal":
from nexa.gguf.nexa_inference_vlm_omni import NexaOmniVlmInference
with suppress_stdout_stderr():
if 'omni' in model_path.lower():
try:
@@ -520,6 +519,7 @@ async def load_model():
)
logging.info(f"Model loaded as {model}")
elif model_type == "AudioLM":
from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference
with suppress_stdout_stderr():
try:
model = NexaAudioLMInference(
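Note: the pattern restored here is deliberately lazy: NexaOmniVlmInference and NexaAudioLMInference are imported inside their load_model branches, so their native libraries are only pulled in when a model of that type is actually requested. A stripped-down sketch of the idea; the constructor arguments are assumed, not copied from the service code.

async def load_model(model_type: str, model_path: str):
    if model_type == "Multimodal":
        # Import deferred until a multimodal model is requested.
        from nexa.gguf.nexa_inference_vlm_omni import NexaOmniVlmInference
        return NexaOmniVlmInference(model_path=model_path)  # args assumed
    elif model_type == "AudioLM":
        # Import deferred until an audio-language model is requested.
        from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference
        return NexaAudioLMInference(model_path=model_path)  # args assumed
    raise ValueError(f"Unsupported model type: {model_type}")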
43 changes: 38 additions & 5 deletions tests/test_image_generation.py
@@ -1,6 +1,37 @@
from nexa.gguf import NexaImageInference
from tempfile import TemporaryDirectory
from tests.utils import download_model
import os
import subprocess


def download_model(url, output_dir):
"""
Download a file from a given URL using curl, if it doesn't already exist.
Args:
- url: str, the URL of the file to download.
- output_dir: str, the directory where the file should be saved.
Returns:
- str: The path to the downloaded file.
"""
file_name = url.split("/")[-1]
output_path = os.path.join(output_dir, file_name)

if os.path.exists(output_path):
print(
f"File {file_name} already exists in {output_dir}. Skipping download.")
return output_path

try:
subprocess.run(["curl", url, "--output", output_path], check=True)
print(f"Downloaded {file_name} to {output_dir}")
except subprocess.CalledProcessError as e:
print(f"Failed to download {file_name}: {e}")
raise

return output_path


sd = NexaImageInference(
model_path="sd1-4",
@@ -15,22 +46,24 @@ def test_txt_to_img():
output = sd.txt2img("a lovely cat", width=128, height=128, sample_steps=2)
output[0].save("output_txt_to_img.png")


# Test image-to-image generation
def test_img_to_img():

global sd
img_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
img_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
with TemporaryDirectory() as temp_dir:
img_path = download_model(img_url, temp_dir)
output = sd.img2img(
image_path=img_path,
prompt="blue sky",
image_path=img_path,
prompt="blue sky",
width=128,
height=128,
negative_prompt="black soil",
sample_steps=2
)


# Main execution
if __name__ == "__main__":
test_txt_to_img()
18 changes: 12 additions & 6 deletions tests/test_structure_decoding.py
@@ -2,16 +2,21 @@
from nexa.gguf import NexaTextInference
from pathlib import Path


def test_structural_output():
nexa = NexaTextInference(model_path="Meta-Llama-3.1-8B-Instruct:q4_0")
prompt = "Emily Carter, a 32-year-old owner, drives a 2023 Audi Q5 available in White, Black, and Gray, equipped with Bang & Olufsen audio (19 speakers, Bluetooth), 10 airbags, 8 parking sensors, lane assist, and a turbocharged inline-4 engine delivering 261 horsepower and a top speed of 130 mph."
schema_abs_path = (Path(__file__).parent / "structure_decoding_resources/schema.json").resolve()
response = nexa.structure_output(prompt=prompt, json_schema_path=schema_abs_path)
schema_abs_path = (Path(__file__).parent /
"structure_decoding_resources/schema.json").resolve()
response = nexa.structure_output(
prompt=prompt, json_schema_path=schema_abs_path)
print(f"response: {json.dumps(response, indent=4)}")


def add_integer(num1, num2):
return num1, num2


def test_function_calling():

system_prompt = (
Expand All @@ -31,7 +36,7 @@ def test_function_calling():
# "type": "object",
# "properties": {
# "num1": {"type": "integer", "description": "An integer to add."},
# "num2": {"type": "integer", "description": "An integer to add."}
# "num2": {"type": "integer", "description": "An integer to add."}
# },
# "required": ["number"],
# "additionalProperties": False
@@ -55,7 +60,7 @@ def test_function_calling():
# "type": "object",
# "properties": {
# "num1": {"type": "integer", "description": "An input integer."},
# "num2": {"type": "integer", "description": "An input integer."}
# "num2": {"type": "integer", "description": "An input integer."}
# },
# "required": ["num1", "num2"],
# "additionalProperties": False
@@ -91,8 +96,9 @@ def test_function_calling():
{"role": "system", "content": system_prompt},
{"role": "user", "content": "What's the weather like in Paris today?"}
]

nexa = NexaTextInference(model_path="Meta-Llama-3.1-8B-Instruct:q4_0", function_calling=True)

nexa = NexaTextInference(
model_path="Meta-Llama-3.1-8B-Instruct:q4_0", function_calling=True)
response = nexa.function_calling(messages=messages, tools=tools)
print(response)

9 changes: 7 additions & 2 deletions tests/test_text_generation.py
@@ -10,6 +10,7 @@
chat_format="llama-2",
)


# Test text generation from a prompt
def test_text_generation():
global model
@@ -22,6 +23,7 @@ def test_text_generation():
# print(output)
# TODO: add assertions here


# Test chat completion in streaming mode
def test_streaming():
global model
@@ -36,6 +38,7 @@ def test_streaming():
print(chunk["choices"][0]["text"], end="", flush=True)
# TODO: add assertions here


# Test conversation mode with chat format
def test_create_chat_completion():
global model
@@ -53,17 +56,19 @@ def test_create_chat_completion():
elif "content" in delta:
print(delta["content"], end="", flush=True)


def test_create_embedding():
model = NexaTextInference(
model_path="gemma",
verbose=False,
n_gpu_layers=-1 if is_gpu_available() else 0,
chat_format="llama-2",
embedding=True,
)
)
embeddings = model.create_embedding("Hello, world!")
print("Embeddings:\n", embeddings)


# Main execution
if __name__ == "__main__":
print("=== Testing 1 ===")
@@ -73,4 +78,4 @@ def test_create_embedding():
print("=== Testing 3 ===")
test_create_chat_completion()
print("=== Testing 4 ===")
test_create_embedding()
test_create_embedding()
13 changes: 7 additions & 6 deletions tests/test_tts_generation.py
@@ -1,5 +1,6 @@
from nexa.gguf import NexaTTSInference


def test_tts_generation_barkcpp():
tts = NexaTTSInference(
model_path="bark-small",
@@ -10,11 +11,11 @@ def test_tts_generation_barkcpp():
sampling_rate=24000,
verbosity=2
)

# Generate audio from prompt
prompt = "Hello, this is a test of the Bark text to speech system."
audio_data = tts.audio_generation(prompt)

# Save the generated audio
tts._save_audio(audio_data, tts.sampling_rate, "tts_output/barkcpp")

@@ -28,16 +29,16 @@ def test_tts_generation_outetts():
sampling_rate=24000,
verbosity=2
)

# Generate audio from prompt
prompt = "Hello, this is a test of the OuteTTS text to speech system."
audio_data = tts.audio_generation(prompt)

# Save the generated audio
tts._save_audio(audio_data, tts.sampling_rate, "tts_output/outetts")


if __name__ == "__main__":
test_tts_generation_barkcpp()
test_tts_generation_outetts()
print("TTS generation test completed successfully!")
print("TTS generation test completed successfully!")