
New model interface #516

Open
wants to merge 30 commits into master
Conversation

luv-bansal
Contributor

Why

How

Tests

Notes


class BaseDataHandler:
    def to_proto(self, output):
Contributor

Why does it need an 'output' parameter here? I'd think it should convert whatever it already holds into the output proto?

@deigen
Contributor

deigen commented Feb 12, 2025

Nice --- really on the right track. Things from our discussion for next changes:

  • Named outputs, e.g. def predict(x: Image) -> Output(y: str, z: str)
  • This is always using parts, which is different from using data.image directly (now, we'll use data.parts[argname].data.image). That needs to be discussed more broadly, as it can make things around monitoring, UI, etc. different.
    • A possible heuristic that might address this most of the way, is serialize the first parameter in the top-level data, then all subsequent params in the parts.
  • Need to add name/id field to the Parts proto in the protos repo.
  • Provide some example models (either in examples or here) for testing
  • We might be better off not supporting unlimited dict nesting levels for now. However, a json dict might be good to support for params and other definitions. That may not need to be in this PR though.
  • We should support the following python types as "atomics" --- some of these are not yet there, e.g. bytes or int:
    • str, bytes, int, float, bool, np.ndarray, PIL.Image.Image
    • for types without corresponding fields in the proto now, we need to either add fields or make correspondences to existing ones (e.g. int might use ndarray, though I don't particularly like that, it would be simpler to use int64 in protobuf. or possibly json.)
  • Make sure to test with invalid client calls with the wrong types --- What is the error provided to the user in these cases? It should be along the lines of what they would get calling a function with the wrong args.
  • Test with the case where the server defines def predict(x: str) -> str: return x and the client calls with model.predict(x=Image.open('test.jpg')). What happens? Right now I think it will return the empty string (there is nothing in parts[x].text.raw), but it would be better to error on mismatched types. However, if the user actually calls with an empty string, it should still return the empty string for model.predict(x='')
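The "atomics" bullet above could be sketched as a simple type-to-field lookup. This is a hypothetical illustration only — the field names (string_value, int_value, etc.) and the helper field_for are assumptions, not the actual resources_pb2 schema:

```python
# Hypothetical sketch: mapping the "atomic" Python types listed above to
# candidate proto field names. Names are illustrative, not the real schema.
ATOMIC_FIELD_MAP = {
    bool: "bool_value",    # check bool before int: bool is a subclass of int
    str: "string_value",
    bytes: "bytes_value",
    int: "int_value",      # assumed int64 proto field, per the discussion above
    float: "float_value",
    # np.ndarray and PIL.Image.Image would map to message fields (ndarray, image)
}

def field_for(value):
    """Return the proto field name an atomic value would serialize into."""
    for py_type, field in ATOMIC_FIELD_MAP.items():
        if isinstance(value, py_type):
            return field
    raise TypeError(f"Unsupported atomic type: {type(value).__name__}")
```

The bool-before-int ordering matters because isinstance(True, int) is True in Python; a naive mapping would serialize booleans into the integer field.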

outputs = []
inputs = self._convert_proto_to_python(request.inputs)
if len(inputs) == 1:
    inputs = inputs[0]
Member

oh this will be nice --- in this wrapper we can handle things that the user should not have to deal with, like including the input.id in each output, which some of our APIs / code expect

Contributor

@deigen deigen left a comment

Looks like this is making progress. I'm still reading over this revision, but have some comments so far on generate and stream:

  • There are a few issues with batched_predict() and batched_generate(), which I mention in my other comments. Because of these, I think it would be best to remove these functions from this PR and implement them in a subsequent one.
  • stream() should be called once with an iterator, not once per input in the stream (see inline comment on this)

"""Batch generate method for multiple inputs."""
with ThreadPoolExecutor() as executor:
futures = [executor.submit(self.generate, **input) for input in inputs]
return [future.result() for future in futures]
Contributor

This will return a list of generators, not one generator that zips all the outputs. Most simply, this should use zip_longest. Also, this won't run the generators in different threads -- only the generate call that produces the generator will run in another thread. All the calls to next() will be in the zip call in the main thread. That's inconsistent with the multithreaded batch_predict() behavior, which does run them in different threads.
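The zip_longest fix the reviewer describes might look like the sketch below. This is a hedged illustration, not the PR's code: batch_generate, generate, and the input shapes are assumptions, and (as the reviewer notes) all the next() calls still run in the caller's thread.

```python
from itertools import zip_longest

def batch_generate(generate, inputs):
    """Zip per-input generators into one stream of output batches.

    Sketch only: `generate` stands in for the model's generate method; each
    call returns a generator of outputs for one input. Generators may yield
    different lengths, so exhausted ones pad with None via zip_longest.
    """
    gens = [generate(**inp) for inp in inputs]
    yield from zip_longest(*gens, fillvalue=None)
```

Note that this merges the outputs but does not parallelize generation; running each generator's next() calls in their own threads would need extra machinery, which is part of why deferring batched_generate to a later PR makes sense.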

Contributor Author

@luv-bansal luv-bansal Feb 20, 2025

For now I'm not implementing the batch_generate function; as we discussed on the call, we'll deal with it in a separate PR.

"""Batch predict method for multiple inputs."""
with ThreadPoolExecutor() as executor:
futures = [executor.submit(self.predict, **input) for input in inputs]
return [future.result() for future in futures]
Contributor

I'm not sure whether to use multiple threads by default; multithreading for a batched call should be optional/configurable, but I don't know what the default number of threads should be. My inclination is to use the safest route, which is no multithreading unless enabled. There can be examples in the examples repo that have it enabled, so if you start by copying an example it will be enabled to start with for those.
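The opt-in threading the reviewer suggests could be sketched like this. It is a hypothetical shape, not the PR's implementation: batch_predict, predict, and num_threads are assumed names, with sequential execution as the safe default.

```python
from concurrent.futures import ThreadPoolExecutor

def batch_predict(predict, inputs, num_threads=1):
    """Batched predict: sequential by default, threaded only when enabled.

    Sketch of the reviewer's suggestion that multithreading be opt-in.
    `predict` stands in for the model's predict method; `inputs` is a list
    of kwargs dicts, one per batched input.
    """
    if num_threads <= 1:
        # Safe default: plain loop, no thread pool.
        return [predict(**inp) for inp in inputs]
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = [executor.submit(predict, **inp) for inp in inputs]
        return [f.result() for f in futures]
```

An example model in the examples repo could then simply pass num_threads explicitly, so users who copy it get threading enabled from the start.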

Contributor Author

For now, I've updated this to a simple for-loop implementation as per your suggestion; we can deal with optimising batching in a separate PR.

inputs = self._convert_proto_to_python(request.inputs)
if len(inputs) == 1:
    inputs = inputs[0]
for output in self.stream(**inputs):
Contributor

A call to stream() should call the function once, with an iterator of inputs that we get from the request stream, so the stream() function impl can take inputs by reading off the stream iterator. This will call stream() once for each input in the stream instead. That will make it more difficult for the user to maintain state (and currently they can't distinguish which stream source is which).

The python input types make this a little difficult, as we'd like to be able to pass in a stream of inputs while still using the converters in this PR. One decent option would be to use an InputStream type or Stream[Input(...)] type, mirroring the Output type. All parts names that we put into the stream are streamed in. Parts names typed directly in the function kwargs (not as a stream) just get the first value; subsequent values that are nonempty but different result in a warning log.

For example:

  def stream(self, stream: InputStream(img=Image, text=str), drop: bool = False, param1: str = "value1", param2: int = 2):
     input_q = queue.Queue(2)
     output_q = queue.Queue(2)
     def _read():
       try:
         for input in stream:
           try:
             input_q.put(input, block=not drop)  # user passes in drop for whether to drop or block input reading when not keeping up, as a function param (in this example)
           except queue.Full:
             pass   # drop when not keeping up
       finally:
         input_q.put(None)
     def _work():
       try:
         for input in iter(input_q.get, None):
           # process the stream input --- param1, param2 are values from the first request, input is values for the current input stream request
           output = _process(input.img, input.text, param1, param2) 
           output_q.put(output)  # blocking
       finally:
         output_q.put(None)
     threading.Thread(target=_read).start()
     threading.Thread(target=_work).start()
     yield from iter(output_q.get, None)

At some point (not this PR) we should also add util functions for this streaming pattern.

Might be good to add this discussion back to the design doc RFC as well.
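A reusable util for this pattern might look like the self-contained sketch below. It generalizes the reader/worker queue approach from the example above; stream_process and all its parameter names are hypothetical, not an existing API.

```python
import queue
import threading

def stream_process(inputs, process, maxsize=2, drop=False):
    """Reader/worker queue pattern for processing a stream of inputs.

    Sketch of a generic util: `inputs` is any iterable of stream items and
    `process` is a per-item function. With drop=True the reader drops items
    instead of blocking when the worker falls behind.
    """
    input_q = queue.Queue(maxsize)
    output_q = queue.Queue(maxsize)

    def _read():
        try:
            for item in inputs:
                try:
                    input_q.put(item, block=not drop)
                except queue.Full:
                    pass  # drop when not keeping up
        finally:
            input_q.put(None)  # sentinel: end of stream

    def _work():
        try:
            for item in iter(input_q.get, None):
                output_q.put(process(item))  # blocking put
        finally:
            output_q.put(None)

    threading.Thread(target=_read).start()
    threading.Thread(target=_work).start()
    yield from iter(output_q.get, None)
```

With a single worker thread, output order matches input order; a multi-worker variant would need to track ordering explicitly.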

Contributor Author

Oh yeah, the Stream_wrapper function is currently implemented incorrectly. I'm not sure how to implement it correctly.

    part.data.image.CopyFrom(image.to_proto())
elif isinstance(part_value, list):
    if len(part_value) == 0:
        raise ValueError("List must have at least one element")
Contributor

Should be OK with empty list.

@@ -116,7 +117,7 @@ def _convert_proto_to_python(self, inputs: List[resources_pb2.Input]) -> List[Di

 def _convert_part_data(self, data: resources_pb2.Data, param_type: type) -> Any:
     if param_type == str:
-        return data.text.value
+        return data.text.raw
Member

I think we should use the new string_value field

Contributor Author

Should we use data.string_value or data.text.raw? It should be consistent everywhere. For now, I have used data.text.raw in all places.

    list_output.append(Audio(part.data.audio))
elif part.data.HasField("video"):
    list_output.append(Video(part.data.video))
elif part.data.HasField("bytes_value"):
Member

I don't think HasField works on the built-in scalar types, only message fields

Member

So checking for zero values is, I think, the right approach --- which is unfortunately not great, as you can't tell whether the value is 0 or simply not provided.

Member

This may end up biting us, and we might need to wrap each of the new fields in a message so we can do something like Bytes.value, as David was saying. But let's see --- the zero convention in protobufs is well known, and the MessageToDict type of methods handle it properly. We may have to validate that an arg can only have zero values as its default, i.e. users shouldn't be allowed to set defaults in python (though that kind of sucks too).

Contributor Author

Yeah, you're correct --- HasField won't work on built-in types. And when checking for zero values, we can't determine whether a field was not provided or the user explicitly set it to 0. This is the main concern.

Contributor Author

For now I'm checking for zero values, but I think we can wrap each of the new fields in a message for proper validation.
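As a toy illustration of why wrapping helps, here is a pure-Python analogue of a scalar wrapped in a message with an explicit presence bit. Field and _UNSET are hypothetical names for the sketch, not part of the protos:

```python
_UNSET = object()  # sentinel distinct from any user value, including 0

class Field:
    """Toy analogue of wrapping a scalar proto field in a message.

    With a bare proto3 scalar, 0 and "not provided" are indistinguishable
    (zero-value semantics). Wrapping the scalar in a message adds presence
    information, which is what wrapping bytes_value etc. would give us.
    """
    def __init__(self, value=_UNSET):
        self._value = value

    def has_value(self):
        # Presence check: True only if a value was explicitly set.
        return self._value is not _UNSET

    @property
    def value(self):
        return None if self._value is _UNSET else self._value
```

This mirrors what protobuf's well-known wrapper types (e.g. google.protobuf.BytesValue) provide: an unset wrapper message is detectable with HasField, even when the wrapped scalar would be zero.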

@luv-bansal luv-bansal marked this pull request as ready for review February 21, 2025 11:08
@luv-bansal luv-bansal requested a review from deigen February 21, 2025 11:08
@luv-bansal
Contributor Author

@deigen, I've addressed all the points you mentioned in your comment and everything we discussed on the call.

You specifically asked me to focus on these two points:

  • Make sure to test with invalid client calls with the wrong types --- What is the error provided to the user in these cases? It should be along the lines of what they would get calling a function with the wrong args.
  • Test with the case where the server defines def predict(x: str) -> str: return x and the client calls with model.predict(x=Image.open('test.jpg')). What happens? Right now I think it will return the empty string (there is nothing in parts[x].text.raw) but it would be better to error on mismatched types. However, if the user actually calls with an empty string, it should still return the empty string for model.predict(x='')

For both cases, the model now throws errors with clear messages:

Case 1 (Invalid parameter in predict method):

Exception: Model Predict failed with response code: FAILURE
details: "Unknown parameter: `text3` in predict method, available parameters: odict_keys([\'text1\', \'text2\'])"
req_id: "sdk-python-11.1.5-dec9f9995bac4d53ba30cbadac651ae2"

Case 2 (Incorrect data type)

Exception: Model Predict failed with response code: FAILURE
details: "expected str datatype but the provided input is not a str"
req_id: "sdk-python-11.1.5-f79bd7b9739a4879b9f2abcd774012f8"
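A minimal sketch of the kind of validation behind Case 1 might look like this. validate_kwargs is a hypothetical helper, not the PR's actual code; it only shows how the "Unknown parameter" message could be produced from the method's declared parameters:

```python
from collections import OrderedDict

def validate_kwargs(kwargs, params):
    """Reject unknown parameter names before dispatching to predict.

    `params` is assumed to be an OrderedDict of the predict method's
    declared parameters (name -> annotation), matching the odict_keys
    shown in the Case 1 error message above.
    """
    for name in kwargs:
        if name not in params:
            raise ValueError(
                f"Unknown parameter: `{name}` in predict method, "
                f"available parameters: {params.keys()}"
            )
```

Raising before any proto conversion keeps the failure close to what a user would see calling a plain Python function with the wrong keyword argument.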
