Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
rectalogic committed Oct 2, 2024
1 parent c13787b commit 22221bf
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions llm_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from transformers.pipelines import Pipeline, check_task, get_supported_tasks
from transformers.utils import get_available_devices

log = logging.getLogger(__name__)

TASK_BLACKLIST = (
"feature-extraction",
"image-feature-extraction",
Expand Down Expand Up @@ -50,7 +52,6 @@ def save_audio(audio: numpy.ndarray, sample_rate: int, output: pathlib.Path | No
def save(f: ta.BinaryIO) -> None:
# musicgen is shape (batch_size, num_channels, sequence_length)
# https://huggingface.co/docs/transformers/v4.45.1/en/model_doc/musicgen#unconditional-generation
# XXX check shape of other audio pipelines
sf.write(f, audio[0].T, sample_rate)

if output is None:
Expand Down Expand Up @@ -380,9 +381,8 @@ def handle_result(
}:
response.response_json = {task: result}
yield "\n".join(f"{label} ({score})" for label, score in zip(labels, scores, strict=True))
case _, _:
breakpoint() # XXX log an error and try json
print("DEFAULT CASE") # XXX
case str(task), _:
log.error("Unhandled pipeline task '%s'. Attempting to show results as JSON.", task)
yield json.dumps(result, indent=4)

def execute(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@ def validate(out: str):
path = out.strip()
actual_sample_rate = sf.read(path)[1]
pathlib.Path(path).unlink(missing_ok=True)
assert actual_sample_rate == sample_rate
assert sample_rate == actual_sample_rate

return validate


def equals_validator(value):
def validate(out):
assert value == out
assert out == value

return validate


def json_validator(value: dict):
def validate(out):
assert value == json.loads(out)
assert json.loads(out) == value

return validate

Expand Down

0 comments on commit 22221bf

Please sign in to comment.