utilities for video streaming #503

Open

wants to merge 31 commits into base: master from video-inference

Commits (31)
a860ea9
make modelinfo for inference thread-safe
deigen Jan 16, 2025
477f2f0
download in parallel to current stream
deigen Jan 21, 2025
33bac1b
update requirements
deigen Jan 28, 2025
5e96b5c
maintain venv in user cache dir and add pdb flags
deigen Jan 28, 2025
74cd874
stream and video utils files
deigen Jan 30, 2025
f49b1f9
stream and video utils files
deigen Jan 30, 2025
01398b7
stream video file function
deigen Jan 30, 2025
917f075
create error model when model fails to load
deigen Jan 31, 2025
6fb0396
write ffmpeg command in pyav
deigen Feb 3, 2025
91f5aca
remove pip install clarifai from model_run_locally
deigen Feb 4, 2025
9f18f9e
rename function
deigen Feb 4, 2025
7fbd78c
stream_util changes
deigen Feb 5, 2025
00ebae0
revert running-related changes
deigen Feb 5, 2025
8bac797
move ensure_urls_downloaded into model class with enable flag
deigen Feb 5, 2025
55e3f78
revert unrelated change
deigen Feb 6, 2025
0389776
fix import
deigen Feb 6, 2025
e002eb2
add test
deigen Feb 6, 2025
9ac9293
fix name
deigen Feb 6, 2025
2eecf55
Merge branch 'master' into video-inference
deigen Feb 6, 2025
6568c04
fix model param updates
deigen Feb 6, 2025
87c60c7
fix unused line
deigen Feb 6, 2025
ca6ab9c
fixes
deigen Feb 6, 2025
1c5a334
fix order of ensure_urls_downloaded and parse_input_request
deigen Feb 6, 2025
083f14d
rename function
deigen Feb 7, 2025
50a8551
Merge branch 'master' into video-inference
deigen Feb 11, 2025
c397e5d
use optional import
deigen Feb 11, 2025
df603bd
add aiohttp to reqs
deigen Feb 13, 2025
7bc9cbb
Merge branch 'master' of github.com:Clarifai/clarifai-python
deigen Feb 18, 2025
2480acc
Merge branch 'master' into video-inference
deigen Feb 18, 2025
f063319
move modules to utils
deigen Feb 18, 2025
071592e
Merge branch 'master' into video-inference
deigen Feb 21, 2025
80 changes: 63 additions & 17 deletions clarifai/client/model.py
@@ -25,6 +25,7 @@
MODEL_EXPORT_TIMEOUT, RANGE_SIZE, TRAINABLE_MODEL_TYPES)
from clarifai.errors import UserError
from clarifai.urls.helper import ClarifaiUrlHelper
from clarifai.utils import video_utils
from clarifai.utils.logging import logger
from clarifai.utils.misc import BackoffIterator, status_is_retryable
from clarifai.utils.model_train import (find_and_replace_key, params_parser,
@@ -424,14 +425,14 @@ def predict(self,
raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}."
) # TODO Use Chunker for inputs len > 128

self._override_model_version(inference_params, output_config)
model_info = self._get_model_info_for_inference(inference_params, output_config)
request = service_pb2.PostModelOutputsRequest(
user_app_id=self.user_app_id,
model_id=self.id,
version_id=self.model_version.id,
inputs=inputs,
runner_selector=runner_selector,
model=self.model_info)
model=model_info)

start_time = time.time()
backoff_iterator = BackoffIterator(10)
@@ -704,14 +705,14 @@ def generate(self,
raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}."
) # TODO Use Chunker for inputs len > 128

self._override_model_version(inference_params, output_config)
model_info = self._get_model_info_for_inference(inference_params, output_config)
request = service_pb2.PostModelOutputsRequest(
user_app_id=self.user_app_id,
model_id=self.id,
version_id=self.model_version.id,
inputs=inputs,
runner_selector=runner_selector,
model=self.model_info)
model=model_info)

start_time = time.time()
backoff_iterator = BackoffIterator(10)
@@ -922,15 +923,16 @@ def generate_by_url(self,
inference_params=inference_params,
output_config=output_config)

def _req_iterator(self, input_iterator: Iterator[List[Input]], runner_selector: RunnerSelector):
def _req_iterator(self, input_iterator: Iterator[List[Input]], runner_selector: RunnerSelector,
model_info: resources_pb2.Model):
for inputs in input_iterator:
yield service_pb2.PostModelOutputsRequest(
user_app_id=self.user_app_id,
model_id=self.id,
version_id=self.model_version.id,
inputs=inputs,
runner_selector=runner_selector,
model=self.model_info)
model=model_info)

def stream(self,
inputs: Iterator[List[Input]],
@@ -954,8 +956,8 @@ def stream(self,
# if not isinstance(inputs, Iterator[List[Input]]):
# raise UserError('Invalid inputs, inputs must be a iterator of list of Input objects.')

self._override_model_version(inference_params, output_config)
request = self._req_iterator(inputs, runner_selector)
model_info = self._get_model_info_for_inference(inference_params, output_config)
request = self._req_iterator(inputs, runner_selector, model_info)

start_time = time.time()
backoff_iterator = BackoffIterator(10)
@@ -1168,8 +1170,53 @@ def input_generator():
inference_params=inference_params,
output_config=output_config)

def _override_model_version(self, inference_params: Dict = {}, output_config: Dict = {}) -> None:
"""Overrides the model version.
def stream_by_video_file(self,
filepath: str,
input_type: str = 'video',
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
user_id: str = None,
inference_params: Dict = {},
output_config: Dict = {}):
"""
Stream the model output based on the given video file.

Converts the video file to a streamable format, streams as bytes to the model,
and streams back the model outputs.

Args:
filepath (str): The filepath to predict.
input_type (str, optional): The type of input. Can be 'image', 'text', 'video', or 'audio'.
compute_cluster_id (str): The compute cluster ID to use for the model.
nodepool_id (str): The nodepool ID to use for the model.
deployment_id (str): The deployment ID to use for the model.
inference_params (dict): The inference params to override.
output_config (dict): The output config to override.
"""

if not os.path.isfile(filepath):
raise UserError('Invalid filepath.')

# TODO check if the file is streamable already

# Convert the video file to a streamable format
# TODO this conversion can offset the start time by a little bit; we should account for this
# by getting the original start time with ffprobe and either sending that to the model so it can adjust
# with the ts of the first frame (too fragile to do all of this adjustment in the client input stream)
# or by adjusting the timestamps in the output stream
stream = video_utils.convert_to_streamable(filepath)

# TODO accumulate reads to fill the chunk size
chunk_size = 1024 * 1024 # 1 MB
chunk_iterator = iter(lambda: stream.read(chunk_size), b'')

return self.stream_by_bytes(chunk_iterator, input_type, compute_cluster_id, nodepool_id,
deployment_id, user_id, inference_params, output_config)

def _get_model_info_for_inference(self, inference_params: Dict = {},
output_config: Dict = {}) -> resources_pb2.Model:
"""Gets the model_info with modified inference params and output config.

Args:
inference_params (dict): The inference params to override.
@@ -1179,13 +1226,12 @@ def _override_model_version(self, inference_params: Dict = {}, output_config: Di
select_concepts (list[Concept]): The concepts to select.
sample_ms (int): The number of milliseconds to sample.
"""
params = Struct()
if inference_params is not None:
params.update(inference_params)

self.model_info.model_version.output_info.CopyFrom(
resources_pb2.OutputInfo(
output_config=resources_pb2.OutputConfig(**output_config), params=params))
model_info = resources_pb2.Model()
model_info.CopyFrom(self.model_info)
model_info.model_version.output_info.params.update(inference_params)
model_info.model_version.output_info.output_config.CopyFrom(
resources_pb2.OutputConfig(**output_config))
return model_info

def _list_concepts(self) -> List[str]:
"""Lists all the concepts for the model type.
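For context on how the new entrypoint is meant to be called from the client side, here is a minimal usage sketch. The model URL, PAT, file path, and deployment ID are placeholders, and only `stream_by_video_file` itself comes from this diff; the rest is assumed to follow the existing SDK conventions.

```python
from clarifai.client.model import Model

# Placeholder URL and PAT.
model = Model(url="https://clarifai.com/<user-id>/<app-id>/models/<model-id>",
              pat="MY_PAT")

# stream_by_video_file converts the file to a streamable container, reads it in
# 1 MB chunks, and forwards the chunks to stream_by_bytes; model outputs stream
# back as MultiOutputResponse protos.
for response in model.stream_by_video_file(
    filepath="/path/to/video.mp4",
    input_type="video",
    deployment_id="<deployment-id>"):
  for output in response.outputs:
    print(output.status.code)
```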
15 changes: 12 additions & 3 deletions clarifai/runners/models/base_typed_model.py
@@ -6,6 +6,9 @@
from clarifai_grpc.grpc.api.service_pb2 import PostModelOutputsRequest
from google.protobuf import json_format

from clarifai.runners.utils.url_fetcher import ensure_urls_downloaded
from clarifai.utils.stream_utils import readahead

from ..utils.data_handler import InputDataHandler, OutputDataHandler
from .model_class import ModelClass

@@ -46,12 +49,16 @@ def convert_output_to_proto(self, outputs: list):

def predict_wrapper(
self, request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse:
if self.download_request_urls:
ensure_urls_downloaded(request)
list_dict_input, inference_params = self.parse_input_request(request)
outputs = self.predict(list_dict_input, inference_parameters=inference_params)
return self.convert_output_to_proto(outputs)

def generate_wrapper(
self, request: PostModelOutputsRequest) -> Iterator[service_pb2.MultiOutputResponse]:
if self.download_request_urls:
ensure_urls_downloaded(request)
list_dict_input, inference_params = self.parse_input_request(request)
outputs = self.generate(list_dict_input, inference_parameters=inference_params)
for output in outputs:
@@ -64,11 +71,13 @@ def _preprocess_stream(
input_data, _ = self.parse_input_request(req)
yield input_data

def stream_wrapper(self, request: Iterator[PostModelOutputsRequest]
def stream_wrapper(self, request_iterator: Iterator[PostModelOutputsRequest]
) -> Iterator[service_pb2.MultiOutputResponse]:
first_request = next(request)
if self.download_request_urls:
request_iterator = readahead(map(ensure_urls_downloaded, request_iterator))
first_request = next(request_iterator)
_, inference_params = self.parse_input_request(first_request)
request_iterator = itertools.chain([first_request], request)
request_iterator = itertools.chain([first_request], request_iterator)
outputs = self.stream(self._preprocess_stream(request_iterator), inference_params)
for output in outputs:
yield self.convert_output_to_proto(output)
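`readahead` from `clarifai.utils.stream_utils` is used here but its implementation is not part of this diff. For reviewers, a minimal thread-backed sketch of what such a helper could look like is below; this is an assumption about its behavior (prefetch the next item while the caller processes the current one), not the actual implementation.

```python
import queue
import threading
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def readahead(iterable: Iterable[T], buffer_size: int = 2) -> Iterator[T]:
  """Consume `iterable` in a background thread so the next item (e.g. a request
  whose URLs are still being downloaded) is prepared while the caller handles
  the current one."""
  buf: queue.Queue = queue.Queue(maxsize=buffer_size)
  _end = object()

  def _producer():
    try:
      for item in iterable:
        buf.put(item)
    finally:
      buf.put(_end)  # always signal completion, even on error

  threading.Thread(target=_producer, daemon=True).start()
  while True:
    item = buf.get()
    if item is _end:
      return
    yield item
```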
22 changes: 20 additions & 2 deletions clarifai/runners/models/model_class.py
@@ -3,23 +3,41 @@

from clarifai_grpc.grpc.api import service_pb2

from clarifai.runners.utils.url_fetcher import ensure_urls_downloaded
from clarifai.utils.stream_utils import readahead


class ModelClass(ABC):

download_request_urls = True

def predict_wrapper(
self, request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse:
"""This method is used for input/output proto data conversion"""
# Download any urls that are not already bytes.
if self.download_request_urls:
ensure_urls_downloaded(request)

return self.predict(request)

def generate_wrapper(self, request: service_pb2.PostModelOutputsRequest
) -> Iterator[service_pb2.MultiOutputResponse]:
"""This method is used for input/output proto data conversion and yield outcome"""
# Download any urls that are not already bytes.
if self.download_request_urls:
ensure_urls_downloaded(request)

return self.generate(request)

def stream_wrapper(self, request: service_pb2.PostModelOutputsRequest
def stream_wrapper(self, request_stream: Iterator[service_pb2.PostModelOutputsRequest]
) -> Iterator[service_pb2.MultiOutputResponse]:
"""This method is used for input/output proto data conversion and yield outcome"""
return self.stream(request)

# Download any urls that are not already bytes.
if self.download_request_urls:
request_stream = readahead(map(ensure_urls_downloaded, request_stream))

return self.stream(request_stream)

@abstractmethod
def load_model(self):
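The new `download_request_urls` class flag lets a model opt out of the automatic URL fetching done by the wrappers. A short sketch of how a subclass would use it (the class name is illustrative; the remaining model methods are elided):

```python
from clarifai.runners.models.model_class import ModelClass


class PassthroughModel(ModelClass):
  # Skip ensure_urls_downloaded in the *_wrapper methods; this model expects
  # its inputs to already arrive as raw bytes.
  download_request_urls = False
```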
4 changes: 0 additions & 4 deletions clarifai/runners/models/model_runner.py
@@ -5,7 +5,6 @@

from clarifai_protocol import BaseRunner
from clarifai_protocol.utils.health import HealthProbeRequestHandler
from ..utils.url_fetcher import ensure_urls_downloaded

from .model_class import ModelClass

@@ -79,7 +78,6 @@ def runner_item_predict(self,
if not runner_item.HasField('post_model_outputs_request'):
raise Exception("Unexpected work item type: {}".format(runner_item))
request = runner_item.post_model_outputs_request
ensure_urls_downloaded(request)

resp = self.model.predict_wrapper(request)
successes = [o.status.code == status_code_pb2.SUCCESS for o in resp.outputs]
@@ -109,7 +107,6 @@ def runner_item_generate(
if not runner_item.HasField('post_model_outputs_request'):
raise Exception("Unexpected work item type: {}".format(runner_item))
request = runner_item.post_model_outputs_request
ensure_urls_downloaded(request)

for resp in self.model.generate_wrapper(request):
successes = []
@@ -169,5 +166,4 @@ def pmo_iterator(runner_item_iterator):
for runner_item in runner_item_iterator:
if not runner_item.HasField('post_model_outputs_request'):
raise Exception("Unexpected work item type: {}".format(runner_item))
ensure_urls_downloaded(runner_item.post_model_outputs_request)
yield runner_item.post_model_outputs_request
18 changes: 1 addition & 17 deletions clarifai/runners/models/model_servicer.py
@@ -1,11 +1,8 @@
from itertools import tee
from typing import Iterator

from clarifai_grpc.grpc.api import service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2

from ..utils.url_fetcher import ensure_urls_downloaded


class ModelServicer(service_pb2_grpc.V2Servicer):
"""
@@ -27,9 +24,6 @@ def PostModelOutputs(self, request: service_pb2.PostModelOutputsRequest,
returns an output.
"""

# Download any urls that are not already bytes.
ensure_urls_downloaded(request)

try:
return self.model.predict_wrapper(request)
except Exception as e:
@@ -46,9 +40,6 @@ def GenerateModelOutputs(self, request: service_pb2.PostModelOutputsRequest,
This is the method that will be called when the servicer is run. It takes in an input and
returns an output.
"""
# Download any urls that are not already bytes.
ensure_urls_downloaded(request)

try:
return self.model.generate_wrapper(request)
except Exception as e:
@@ -66,15 +57,8 @@ def StreamModelOutputs(self,
This is the method that will be called when the servicer is run. It takes in an input and
returns an output.
"""
# Duplicate the iterator
request, request_copy = tee(request)

# Download any urls that are not already bytes.
for req in request:
ensure_urls_downloaded(req)

try:
return self.model.stream_wrapper(request_copy)
return self.model_class.stream_wrapper(request)
except Exception as e:
yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
code=status_code_pb2.MODEL_PREDICTION_FAILED,
12 changes: 12 additions & 0 deletions clarifai/runners/utils/url_fetcher.py
@@ -1,8 +1,10 @@
import concurrent.futures
from typing import Iterable

import fsspec

from clarifai.utils.logging import logger
from clarifai.utils.stream_utils import MB


def download_input(input):
@@ -47,3 +49,13 @@ def ensure_urls_downloaded(request, max_threads=128):
future.result()
except Exception as e:
logger.exception(f"Error downloading input: {e}")
return request


def stream_url(url: str, chunk_size: int = 1 * MB) -> Iterable[bytes]:
"""
Opens a stream of byte chunks from a URL.
"""
# block_size=0 means that the file is streamed
with fsspec.open(url, 'rb', block_size=0) as f:
yield from iter(lambda: f.read(chunk_size), b'')
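A small usage sketch for the new `stream_url` helper; the URL and output path are placeholders. Each iteration yields up to 1 MB of bytes, so the whole file never has to be buffered in memory.

```python
from clarifai.runners.utils.url_fetcher import stream_url

# Placeholder URL and destination path.
with open("/tmp/video_copy.mp4", "wb") as out:
  for chunk in stream_url("https://example.com/video.mp4"):
    out.write(chunk)
```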
24 changes: 24 additions & 0 deletions clarifai/utils/misc.py
@@ -1,3 +1,4 @@
import importlib
import os
import re
import uuid
@@ -18,6 +19,29 @@ def status_is_retryable(status_code: int) -> bool:
return status_code in RETRYABLE_CODES


def optional_import(module_name: str, pip_package: str = None):
"""Import a module if it exists.
Otherwise, return an object that will raise an error when accessed.
"""
try:
return importlib.import_module(module_name)
except ImportError:
return _MissingModule(module_name, pip_package=pip_package)


class _MissingModule:
"""Object that raises an error when accessed."""

def __init__(self, module_name, pip_package=None):
self.module_name = module_name
self.message = f"Module `{module_name}` is not installed."
if pip_package:
self.message += f" Please add `{pip_package}` to your requirements.txt file."

def __getattr__(self, name):
raise ImportError(self.message)


class Chunker:
"""Split an input sequence into small chunks."""

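A brief sketch of how `optional_import` is intended to be used; the PyAV example is an assumption based on this PR's video utilities, and the ImportError only fires when the missing module is actually accessed.

```python
from clarifai.utils.misc import optional_import

# Import PyAV lazily; if it is not installed, accessing `av` raises an
# ImportError that points at the pip package to add to requirements.txt.
av = optional_import("av", pip_package="av")


def open_container(path):
  return av.open(path)  # ImportError raised here if PyAV is missing
```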