New model interface #516
base: master
Conversation
class BaseDataHandler:
    def to_proto(self, output):
Why does it have to take 'output' here? I think it should convert whatever it holds into an output proto?
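A minimal sketch of what that suggestion might look like (the data attribute and converter helper below are hypothetical, just to illustrate):

    class BaseDataHandler:
        def __init__(self, data=None):
            # the handler holds its own data
            self._data = data

        def to_proto(self):
            # convert whatever this handler holds into an output proto,
            # instead of receiving it as an `output` argument
            return self._data_to_proto(self._data)  # _data_to_proto is a hypothetical converter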
Nice --- really on the right track. Things from our discussion for next changes:
outputs = []
inputs = self._convert_proto_to_python(request.inputs)
if len(inputs) == 1:
    inputs = inputs[0]
Oh, this will be nice: in this wrapper we can handle things that the user shouldn't have to deal with, like including the input.id in each output too, which some of our APIs / code expect.
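For example, a rough sketch of what the wrapper could take care of (the exact proto field paths are assumptions here):

    def _wrap_outputs(self, request, outputs):
        wrapped = []
        for inp, out in zip(request.inputs, outputs):
            out_proto = out.to_proto()
            # copy the originating input id into each output so user code
            # never has to do this bookkeeping itself (assumed field path)
            out_proto.input.id = inp.id
            wrapped.append(out_proto)
        return wrapped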
Looks like this is making progress. I'm still reading over this revision, but have some comments so far on generate and stream:
- There are a few issues with batched_predict() and batched_generate(), which I mention in my other comments. Because of these, I think it would be best to remove these functions from this PR and implement them in a subsequent one.
- stream() should be called once with an iterator, not once per input in the stream (see inline comment on this)
"""Batch generate method for multiple inputs.""" | ||
with ThreadPoolExecutor() as executor: | ||
futures = [executor.submit(self.generate, **input) for input in inputs] | ||
return [future.result() for future in futures] |
This will return a list of generators, not one generator that zips all the outputs. Most simply, this should use zip_longest. Also, this won't run the generators in different threads -- only the generate call that produces the generator will run in another thread. All the calls to next() will be in the zip call in the main thread. That's inconsistent with the multithreaded batch_predict() behavior, which does run them in different threads.
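A sketch of the zip_longest version (sequential, no extra threads; assumes each self.generate(**input) returns a generator of outputs):

    from itertools import zip_longest

    def batch_generate(self, inputs):
        """Yield one list of outputs per step, combining the per-input generators."""
        generators = [self.generate(**inp) for inp in inputs]
        # keep yielding until the longest generator is exhausted; inputs whose
        # generators finish early contribute None for the remaining steps
        for step in zip_longest(*generators, fillvalue=None):
            yield list(step)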
For now I'm not implementing the batch_generate function; as we discussed on the call, we will deal with it in a separate PR.
"""Batch predict method for multiple inputs.""" | ||
with ThreadPoolExecutor() as executor: | ||
futures = [executor.submit(self.predict, **input) for input in inputs] | ||
return [future.result() for future in futures] |
I'm not sure whether to use multiple threads by default; multithreading for a batched call should be optional/configurable, but I don't know what the default number of threads should be. My inclination is to use the safest route, which is no multithreading unless enabled. There can be examples in the examples repo that have it enabled, so if you start by copying an example it will be enabled to start with for those.
For now, I've updated this to a simple for-loop implementation as per your suggestion; we can deal with optimising batching in a separate PR.
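The simple sequential version would look roughly like this (a sketch; multithreading is deferred to a later PR):

    def batch_predict(self, inputs):
        """Batch predict method for multiple inputs, processed one at a time."""
        return [self.predict(**inp) for inp in inputs]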
inputs = self._convert_proto_to_python(request.inputs)
if len(inputs) == 1:
    inputs = inputs[0]
for output in self.stream(**inputs):
A call to stream() should call the function once, with an iterator of inputs that we get from the request stream, so the stream() function impl can take inputs by reading off the stream iterator. This will call stream() once for each input in the stream instead. That will make it more difficult for the user to maintain state (and currently they can't distinguish which stream source is which).
The python input types make this a little difficult, as we'd like to be able to pass in a stream of inputs while still using the converters in this PR. One decent option would be to use an InputStream type or Stream[Input(...)] type, mirroring the Output type. All part names that we put into the stream are streamed in. Part names typed directly as function kwargs (not as a stream) just take the first value; subsequent values that are nonempty but different result in a warning log.
For example:
    def stream(self, stream: InputStream(img=Image, text=str), drop: bool = False, param1: str = "value1", param2: int = 2):
        input_q = queue.Queue(2)
        output_q = queue.Queue(2)

        def _read():
            try:
                for input in stream:
                    try:
                        input_q.put(input, block=not drop)  # user passes in drop for whether to drop or block input reading when not keeping up, as a function param (in this example)
                    except queue.Full:
                        pass  # drop when not keeping up
            finally:
                input_q.put(None)

        def _work():
            try:
                for input in iter(input_q.get, None):
                    # process the stream input --- param1, param2 are values from the first request, input is values for the current input stream request
                    output = _process(input.img, input.text, param1, param2)
                    output_q.put(output)  # blocking
            finally:
                output_q.put(None)

        threading.Thread(target=_read).start()
        threading.Thread(target=_work).start()
        yield from iter(output_q.get, None)
At some point (not this PR) we should also add util functions for this streaming pattern.
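Such a util might eventually look something like this sketch (the name stream_map and its signature are assumptions):

    import queue
    import threading

    def stream_map(process_fn, stream, drop=False, maxsize=2):
        # Apply process_fn to each item read off `stream`, using the same
        # reader/worker queue pattern as above, yielding outputs as they arrive.
        input_q = queue.Queue(maxsize)
        output_q = queue.Queue(maxsize)

        def _read():
            try:
                for item in stream:
                    try:
                        input_q.put(item, block=not drop)
                    except queue.Full:
                        pass  # drop items when not keeping up, if requested
            finally:
                input_q.put(None)

        def _work():
            try:
                for item in iter(input_q.get, None):
                    output_q.put(process_fn(item))
            finally:
                output_q.put(None)

        threading.Thread(target=_read, daemon=True).start()
        threading.Thread(target=_work, daemon=True).start()
        yield from iter(output_q.get, None)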
Might be good to add this discussion back to the design doc RFC as well.
Oh yeah, the Stream_wrapper function is currently implemented incorrectly. I'm not sure how to implement it correctly.
    part.data.image.CopyFrom(image.to_proto())
elif isinstance(part_value, list):
    if len(part_value) == 0:
        raise ValueError("List must have at least one element")
Should be OK with empty list.
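So the list branch could simply iterate, producing no parts for an empty list rather than raising (a sketch; the nested `parts` field and per-element helper are assumptions):

    elif isinstance(part_value, list):
        # an empty list is fine: it simply produces no nested parts
        for item in part_value:
            sub_part = part.data.parts.add()  # assumes a repeated `parts` field on Data
            self._convert_value_to_part(sub_part, item)  # hypothetical per-element helper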
@@ -116,7 +117,7 @@ def _convert_proto_to_python(self, inputs: List[resources_pb2.Input]) -> List[Di
     def _convert_part_data(self, data: resources_pb2.Data, param_type: type) -> Any:
         if param_type == str:
-            return data.text.value
+            return data.text.raw
I think we should use the new string_value field.
Should we use data.string_value or data.text.raw? It should be consistent everywhere. For now, I have used data.text.raw in all places.
    list_output.append(Audio(part.data.audio))
elif part.data.HasField("video"):
    list_output.append(Video(part.data.video))
elif part.data.HasField("bytes_value"):
I don't think HasField works on the built-in types, only on message fields.
So you have to check for zero values, which I think is the right approach, but it's unfortunately not great since you can't tell whether the value is 0 or simply not provided.
This may end up biting us, and we might need to wrap each of the new fields in a message so we can do something like Bytes.value, as David was saying. But let's see, since the zero convention in protobufs is well known and the MessageToDict type of methods understand what to do properly there. We may have to validate that an arg can only have zero values as its default, i.e. users shouldn't be allowed to set defaults in Python (though that kind of sucks too).
Yeah, you're correct: HasField won't work on built-in types. And when checking for zero values, we can't determine whether a field was not provided or the user explicitly set it to 0. This is the main concern.
For now I'm checking for zero values, but I think we can wrap each of the new fields in a message for proper validation.
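To illustrate the two options being weighed (the wrapper message below is hypothetical, not the current proto):

    # current approach: rely on the proto3 zero-value convention, so an empty
    # bytes_value is treated the same as "not provided"
    if part.data.bytes_value:
        list_output.append(part.data.bytes_value)

    # wrapper-message approach: wrapping the scalar in a message (e.g. a
    # hypothetical `message BytesValue { bytes value = 1; }`) makes presence
    # explicit, so HasField works even when the value itself is empty/zero
    if part.data.HasField("bytes_value"):
        list_output.append(part.data.bytes_value.value)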
@deigen, I've addressed all the points you mentioned in your comment and everything we discussed on the call. You specifically asked me to focus on these two points:
For both cases, the model now throws errors with clear messages (a sketch of such checks follows below):
- Case 1: invalid parameter in the predict method
- Case 2: incorrect data type
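For illustration, the checks producing those errors might look roughly like this (self._predict_params, a mapping of declared parameter names to their annotated types, is an assumption):

    def _validate_predict_kwargs(self, kwargs):
        for name, value in kwargs.items():
            # Case 1: a parameter that predict() does not declare
            if name not in self._predict_params:
                raise ValueError(
                    f"Invalid parameter '{name}' for predict(); "
                    f"expected one of {sorted(self._predict_params)}")
            # Case 2: a value whose type does not match the annotation
            expected = self._predict_params[name]
            if not isinstance(value, expected):
                raise TypeError(
                    f"Parameter '{name}' expects {expected.__name__}, "
                    f"got {type(value).__name__}")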