From 3277488f93795389d2d969e689e4b17e237feca5 Mon Sep 17 00:00:00 2001 From: HR Wu <5631010+heiruwu@users.noreply.github.com> Date: Tue, 23 Jul 2024 02:31:44 +0800 Subject: [PATCH] fix(ray): use dict instead of protobuf message (#178) Because - The generated field names for protobuf messages are not aligned between golang and python - Task protobuf definitions are going to be retired This commit - use `MessageToDict` instead to parse task inputs --- instill/helpers/ray_io.py | 420 +++++++++++++++----------------------- 1 file changed, 163 insertions(+), 257 deletions(-) diff --git a/instill/helpers/ray_io.py b/instill/helpers/ray_io.py index 02c56e1..7e35dd2 100644 --- a/instill/helpers/ray_io.py +++ b/instill/helpers/ray_io.py @@ -63,86 +63,32 @@ def protobuf_to_struct(pb_msg): """Convert Protobuf message to Struct""" dict_data = json_format.MessageToDict(pb_msg) - lower_camel_dict: Dict[str, Dict[str, Any]] = {} - # task layer - for k, v in dict_data.items(): - lower_camel_dict[k[0].lower() + k[1:]] = {} - # field layer - for kk, vv in v.items(): - lower_camel_dict[k[0].lower() + k[1:]][snake_to_lower_camel(kk)] = vv - # Convert dictionary to struct_pb2.Struct struct_pb = struct_pb2.Struct() - json_format.ParseDict(lower_camel_dict, struct_pb) + json_format.ParseDict(dict_data, struct_pb) return struct_pb -def struct_to_protobuf(struct_pb, pb_message_type): - """Convert Struct to Protobuf message""" - dict_data = json_format.MessageToDict(struct_pb) - - lower_camel_dict: Dict[str, Dict[str, Any]] = {} - # task layer - for k, v in dict_data.items(): - lower_camel_dict[k[0].lower() + k[1:]] = {} - # field layer - for kk, vv in v.items(): - lower_camel_dict[k[0].lower() + k[1:]][snake_to_lower_camel(kk)] = vv - - # Parse dictionary to Protobuf message - pb_msg = pb_message_type() - json_format.ParseDict(lower_camel_dict, pb_msg) - - return pb_msg - - -def struct_to_dict(struct_obj): - """Convert Protobuf Struct to dictionary""" - if isinstance(struct_obj, 
struct_pb2.Struct): - return {k: struct_to_dict(v) for k, v in struct_obj.fields.items()} - if isinstance(struct_obj, struct_pb2.ListValue): - return [struct_to_dict(v) for v in struct_obj.values] - if isinstance(struct_obj, struct_pb2.Value): - kind = struct_obj.WhichOneof("kind") - if kind == "null_value": - return None - if kind == "number_value": - return struct_obj.number_value - if kind == "string_value": - return struct_obj.string_value - if kind == "bool_value": - return struct_obj.bool_value - if kind == "struct_value": - return struct_to_dict(struct_obj.struct_value) - if kind == "list_value": - return struct_to_dict(struct_obj.list_value) - else: - return struct_obj - - def parse_task_classification_to_vision_input( request: TriggerRequest, ) -> List[VisionInput]: input_list = [] for task_input in request.task_inputs: - task_input_pb = struct_to_protobuf(task_input, modelpb.TaskInput) - - classification_pb: classificationpb.ClassificationInput = ( - task_input_pb.classification - ) - + task_input_dict = json_format.MessageToDict(task_input)["Classification"][ + "Type" + ] inp = VisionInput() - if ( - classification_pb.image_base64 != "" and classification_pb.image_url != "" - ) or ( - classification_pb.image_base64 == "" and classification_pb.image_url == "" + if ("ImageBase64" in task_input_dict and "ImageUrl" in task_input_dict) or ( + not "ImageBase64" in task_input_dict and not "ImageUrl" in task_input_dict ): raise InvalidInputException - if classification_pb.image_base64 != "": - inp.image = base64_to_pil_image(classification_pb.image_base64) - elif classification_pb.image_url != "": - inp.image = url_to_pil_image(classification_pb.image_url) + if "ImageBase64" in task_input_dict and task_input_dict["ImageBase64"] != "": + inp.image = base64_to_pil_image(task_input_dict["ImageBase64"]) + elif "ImageUrl" in task_input_dict and task_input_dict["ImageUrl"] != "": + inp.image = url_to_pil_image(task_input_dict["ImageUrl"]) + else: + raise 
InvalidInputException input_list.append(inp) @@ -177,19 +123,18 @@ def parse_task_detection_to_vision_input( ) -> List[VisionInput]: input_list = [] for task_input in request.task_inputs: - task_input_pb = struct_to_protobuf(task_input, modelpb.TaskInput) - - detection_pb: detectionpb.DetectionInput = task_input_pb.detection - + task_input_dict = json_format.MessageToDict(task_input)["Detection"]["Type"] inp = VisionInput() - if (detection_pb.image_base64 != "" and detection_pb.image_url != "") or ( - detection_pb.image_base64 == "" and detection_pb.image_url == "" + if ("ImageBase64" in task_input_dict and "ImageUrl" in task_input_dict) or ( + not "ImageBase64" in task_input_dict and not "ImageUrl" in task_input_dict ): raise InvalidInputException - if detection_pb.image_base64 != "": - inp.image = base64_to_pil_image(detection_pb.image_base64) - elif detection_pb.image_url != "": - inp.image = url_to_pil_image(detection_pb.image_url) + if "ImageBase64" in task_input_dict and task_input_dict["ImageBase64"] != "": + inp.image = base64_to_pil_image(task_input_dict["ImageBase64"]) + elif "ImageUrl" in task_input_dict and task_input_dict["ImageUrl"] != "": + inp.image = url_to_pil_image(task_input_dict["ImageUrl"]) + else: + raise InvalidInputException input_list.append(inp) @@ -244,19 +189,18 @@ def parse_task_ocr_to_vision_input( ) -> List[VisionInput]: input_list = [] for task_input in request.task_inputs: - task_input_pb = struct_to_protobuf(task_input, modelpb.TaskInput) - - ocr_pb: ocrpb.OcrInput = task_input_pb.ocr - + task_input_dict = json_format.MessageToDict(task_input)["Ocr"]["Type"] inp = VisionInput() - if (ocr_pb.image_base64 != "" and ocr_pb.image_url != "") or ( - ocr_pb.image_base64 == "" and ocr_pb.image_url == "" + if ("ImageBase64" in task_input_dict and "ImageUrl" in task_input_dict) or ( + not "ImageBase64" in task_input_dict and not "ImageUrl" in task_input_dict ): raise InvalidInputException - if ocr_pb.image_base64 != "": - inp.image = 
base64_to_pil_image(ocr_pb.image_base64) - elif ocr_pb.image_url != "": - inp.image = url_to_pil_image(ocr_pb.image_url) + if "ImageBase64" in task_input_dict and task_input_dict["ImageBase64"] != "": + inp.image = base64_to_pil_image(task_input_dict["ImageBase64"]) + elif "ImageUrl" in task_input_dict and task_input_dict["ImageUrl"] != "": + inp.image = url_to_pil_image(task_input_dict["ImageUrl"]) + else: + raise InvalidInputException input_list.append(inp) @@ -307,27 +251,20 @@ def parse_task_instance_segmentation_to_vision_input( ) -> List[VisionInput]: input_list = [] for task_input in request.task_inputs: - task_input_pb = struct_to_protobuf(task_input, modelpb.TaskInput) - - instance_segmentation_pb: instancesegmentationpb.InstanceSegmentationInput = ( - task_input_pb.instance_segmentation - ) - + task_input_dict = json_format.MessageToDict(task_input)["InstanceSegmentation"][ + "Type" + ] inp = VisionInput() - if ( - instance_segmentation_pb.image_base64 != "" - and instance_segmentation_pb.image_url != "" - ) or ( - instance_segmentation_pb.image_base64 == "" - and instance_segmentation_pb.image_url == "" + if ("ImageBase64" in task_input_dict and "ImageUrl" in task_input_dict) or ( + not "ImageBase64" in task_input_dict and not "ImageUrl" in task_input_dict ): raise InvalidInputException - if instance_segmentation_pb.image_base64 != "": - inp.image = base64_to_pil_image( - instance_segmentation_pb.image_base64, - ) - elif instance_segmentation_pb.image_url != "": - inp.image = url_to_pil_image(instance_segmentation_pb.image_url) + if "ImageBase64" in task_input_dict and task_input_dict["ImageBase64"] != "": + inp.image = base64_to_pil_image(task_input_dict["ImageBase64"]) + elif "ImageUrl" in task_input_dict and task_input_dict["ImageUrl"] != "": + inp.image = url_to_pil_image(task_input_dict["ImageUrl"]) + else: + raise InvalidInputException input_list.append(inp) @@ -387,27 +324,20 @@ def parse_task_semantic_segmentation_to_vision_input( ) -> 
List[VisionInput]: input_list = [] for task_input in request.task_inputs: - task_input_pb = struct_to_protobuf(task_input, modelpb.TaskInput) - - semantic_segmentation_pb: semanticsegmentationpb.SemanticSegmentationInput = ( - task_input_pb.semantic_segmentation - ) - + task_input_dict = json_format.MessageToDict(task_input)["SemanticSegmentation"][ + "Type" + ] inp = VisionInput() - if ( - semantic_segmentation_pb.image_base64 != "" - and semantic_segmentation_pb.image_url != "" - ) or ( - semantic_segmentation_pb.image_base64 == "" - and semantic_segmentation_pb.image_url == "" + if ("ImageBase64" in task_input_dict and "ImageUrl" in task_input_dict) or ( + not "ImageBase64" in task_input_dict and not "ImageUrl" in task_input_dict ): raise InvalidInputException - if semantic_segmentation_pb.image_base64 != "": - inp.image = base64_to_pil_image( - semantic_segmentation_pb.image_base64, - ) - elif semantic_segmentation_pb.image_url != "": - inp.image = url_to_pil_image(semantic_segmentation_pb.image_url) + if "ImageBase64" in task_input_dict and task_input_dict["ImageBase64"] != "": + inp.image = base64_to_pil_image(task_input_dict["ImageBase64"]) + elif "ImageUrl" in task_input_dict and task_input_dict["ImageUrl"] != "": + inp.image = url_to_pil_image(task_input_dict["ImageUrl"]) + else: + raise InvalidInputException input_list.append(inp) @@ -456,19 +386,18 @@ def parse_task_keypoint_to_vision_input( ) -> List[VisionInput]: input_list = [] for task_input in request.task_inputs: - task_input_pb = struct_to_protobuf(task_input, modelpb.TaskInput) - - keypoint_pb: keypointpb.KeypointInput = task_input_pb.keypoint - + task_input_dict = json_format.MessageToDict(task_input)["Keypoint"]["Type"] inp = VisionInput() - if (keypoint_pb.image_base64 != "" and keypoint_pb.image_url != "") or ( - keypoint_pb.image_base64 == "" and keypoint_pb.image_url == "" + if ("ImageBase64" in task_input_dict and "ImageUrl" in task_input_dict) or ( + not "ImageBase64" in task_input_dict 
and not "ImageUrl" in task_input_dict ): raise InvalidInputException - if keypoint_pb.image_base64 != "": - inp.image = base64_to_pil_image(keypoint_pb.image_base64) - elif keypoint_pb.image_url != "": - inp.image = url_to_pil_image(keypoint_pb.image_url) + if "ImageBase64" in task_input_dict and task_input_dict["ImageBase64"] != "": + inp.image = base64_to_pil_image(task_input_dict["ImageBase64"]) + elif "ImageUrl" in task_input_dict and task_input_dict["ImageUrl"] != "": + inp.image = url_to_pil_image(task_input_dict["ImageUrl"]) + else: + raise InvalidInputException input_list.append(inp) @@ -533,11 +462,7 @@ def parse_task_text_generation_to_conversation_input( input_list = [] for task_input in request.task_inputs: - task_input_pb = struct_to_protobuf(task_input, modelpb.TaskInput) - - text_generation_pb: textgenerationpb.TextGenerationInput = ( - task_input_pb.text_generation - ) + task_input_dict = json_format.MessageToDict(task_input)["TextGeneration"] inp = ConversationInput() @@ -545,19 +470,19 @@ def parse_task_text_generation_to_conversation_input( # system message if ( - text_generation_pb.system_message is not None - and len(text_generation_pb.system_message) > 0 + "system_message" in task_input_dict + and len(task_input_dict["system_message"]) > 0 ): conversation.append( - {"role": "system", "content": text_generation_pb.system_message} + {"role": "system", "content": task_input_dict["system_message"]} ) # conversation history if ( - text_generation_pb.chat_history is not None - and len(text_generation_pb.chat_history) > 0 + "chat_history" in task_input_dict + and len(task_input_dict["chat_history"]) > 0 ): - for chat_entity in text_generation_pb.chat_history: + for chat_entity in task_input_dict["chat_history"]: chat_message = None if len(chat_entity["content"]) > 1: raise ValueError( @@ -574,7 +499,7 @@ def parse_task_text_generation_to_conversation_input( chat_message = chat_entity["content"][0]["text"] else: raise ValueError( - f"Unknown structure 
of chat_history: {text_generation_pb.chat_history}" + f"Unknown structure of chat_history: {task_input_dict['chat_history']}" ) else: raise ValueError( @@ -590,8 +515,8 @@ def parse_task_text_generation_to_conversation_input( ) if ( chat_entity["role"] == PROMPT_ROLES[-1] - and text_generation_pb.system_message is not None - and len(text_generation_pb.system_message) > 0 + and task_input_dict["system_message"] is not None + and len(task_input_dict["system_message"]) > 0 ): continue if chat_message is None: @@ -614,7 +539,7 @@ def parse_task_text_generation_to_conversation_input( ) # conversation - prompt = text_generation_pb.prompt + prompt = task_input_dict["prompt"] if len(conversation) > 0 and conversation[-1]["role"] == PROMPT_ROLES[0]: last_conversation = conversation.pop() prompt = f"{last_conversation['content']}\n\n{prompt}" @@ -624,20 +549,20 @@ def parse_task_text_generation_to_conversation_input( inp.conversation = conversation # max new tokens - if text_generation_pb.max_new_tokens is not None: - inp.max_new_tokens = text_generation_pb.max_new_tokens + if "max_new_tokens" in task_input_dict: + inp.max_new_tokens = task_input_dict["max_new_tokens"] # temperature - if text_generation_pb.temperature is not None: - inp.temperature = text_generation_pb.temperature + if "temperature" in task_input_dict: + inp.temperature = task_input_dict["temperature"] # top k - if text_generation_pb.top_k is not None: - inp.top_k = text_generation_pb.top_k + if "top_k" in task_input_dict: + inp.top_k = task_input_dict["top_k"] # seed - if text_generation_pb.seed is not None: - inp.seed = text_generation_pb.seed + if "seed" in task_input_dict: + inp.seed = task_input_dict["seed"] input_list.append(inp) @@ -666,34 +591,27 @@ def parse_task_text_generation_chat_to_conversation_input( input_list = [] for task_input in request.task_inputs: - task_input_pb = struct_to_protobuf(task_input, modelpb.TaskInput) - - text_generation_chat_pb: textgenerationchatpb.TextGenerationChatInput 
= ( - task_input_pb.text_generation_chat - ) + task_input_dict = json_format.MessageToDict(task_input)["TextGenerationChat"] inp = ConversationInput() - conversation = [] + conversation: List[Dict[str, str]] = [] # system message if ( - text_generation_chat_pb.system_message is not None - and len(text_generation_chat_pb.system_message) > 0 + "system_message" in task_input_dict + and len(task_input_dict["system_message"]) > 0 ): conversation.append( - { - "role": "system", - "content": text_generation_chat_pb.system_message, - } + {"role": "system", "content": task_input_dict["system_message"]} ) # conversation history if ( - text_generation_chat_pb.chat_history is not None - and len(text_generation_chat_pb.chat_history) > 0 + "chat_history" in task_input_dict + and len(task_input_dict["chat_history"]) > 0 ): - for chat_entity in text_generation_chat_pb.chat_history: + for chat_entity in task_input_dict["chat_history"]: chat_message = None if len(chat_entity["content"]) > 1: raise ValueError( @@ -710,7 +628,7 @@ def parse_task_text_generation_chat_to_conversation_input( chat_message = chat_entity["content"][0]["text"] else: raise ValueError( - f"Unknown structure of chat_history: {text_generation_chat_pb.chat_history}" + f"Unknown structure of chat_history: {task_input_dict['chat_history']}" ) else: raise ValueError( @@ -726,8 +644,8 @@ def parse_task_text_generation_chat_to_conversation_input( ) if ( chat_entity["role"] == PROMPT_ROLES[-1] - and text_generation_chat_pb.system_message is not None - and len(text_generation_chat_pb.system_message) > 0 + and task_input_dict["system_message"] is not None + and len(task_input_dict["system_message"]) > 0 ): continue if chat_message is None: @@ -750,7 +668,7 @@ def parse_task_text_generation_chat_to_conversation_input( ) # conversation - prompt = text_generation_chat_pb.prompt + prompt = task_input_dict["prompt"] if len(conversation) > 0 and conversation[-1]["role"] == PROMPT_ROLES[0]: last_conversation = 
conversation.pop() prompt = f"{last_conversation['content']}\n\n{prompt}" @@ -760,20 +678,20 @@ def parse_task_text_generation_chat_to_conversation_input( inp.conversation = conversation # max new tokens - if text_generation_chat_pb.max_new_tokens is not None: - inp.max_new_tokens = text_generation_chat_pb.max_new_tokens + if "max_new_tokens" in task_input_dict: + inp.max_new_tokens = task_input_dict["max_new_tokens"] # temperature - if text_generation_chat_pb.temperature is not None: - inp.temperature = text_generation_chat_pb.temperature + if "temperature" in task_input_dict: + inp.temperature = task_input_dict["temperature"] # top k - if text_generation_chat_pb.top_k is not None: - inp.top_k = text_generation_chat_pb.top_k + if "top_k" in task_input_dict: + inp.top_k = task_input_dict["top_k"] # seed - if text_generation_chat_pb.seed is not None: - inp.seed = text_generation_chat_pb.seed + if "seed" in task_input_dict: + inp.seed = task_input_dict["seed"] input_list.append(inp) @@ -804,11 +722,9 @@ def parse_task_visual_question_answering_to_conversation_multimodal_input( input_list = [] for task_input in request.task_inputs: - task_input_pb = struct_to_protobuf(task_input, modelpb.TaskInput) - - visual_question_answering_pb: ( - visualquestionansweringpb.VisualQuestionAnsweringInput - ) = task_input_pb.visual_question_answering + task_input_dict = json_format.MessageToDict(task_input)[ + "VisualQuestionAnswering" + ] inp = ConversationMultiModelInput() @@ -816,30 +732,24 @@ def parse_task_visual_question_answering_to_conversation_multimodal_input( # system message if ( - visual_question_answering_pb.system_message is not None - and len(visual_question_answering_pb.system_message) > 0 + "system_message" in task_input_dict + and len(task_input_dict["system_message"]) > 0 ): conversation.append( - { - "role": "system", - "content": { - "type": "text", - "content": visual_question_answering_pb.system_message, - }, - } + {"role": "system", "content": 
task_input_dict["system_message"]} ) # conversation history if ( - visual_question_answering_pb.chat_history is not None - and len(visual_question_answering_pb.chat_history) > 0 + "chat_history" in task_input_dict + and len(task_input_dict["chat_history"]) > 0 ): - for chat_entity in visual_question_answering_pb.chat_history: + for chat_entity in task_input_dict["chat_history"]: chat_dict = json_format.MessageToDict(chat_entity) conversation.append(chat_dict) # conversation - prompt = visual_question_answering_pb.prompt + prompt = task_input_dict["prompt"] if len(conversation) > 0 and conversation[-1]["role"] == PROMPT_ROLES[0]: last_conversation = conversation.pop() prompt = f"{last_conversation['content']['content']}\n\n{prompt}" # type: ignore @@ -859,43 +769,43 @@ def parse_task_visual_question_answering_to_conversation_multimodal_input( # prompt images prompt_image_list = [] if ( - visual_question_answering_pb.prompt_images is not None - and len(visual_question_answering_pb.prompt_images) > 0 + "prompt_images" in task_input_dict + and len(task_input_dict["prompt_images"]) > 0 ): - for prompt_image in visual_question_answering_pb.prompt_images: + for prompt_image in task_input_dict["prompt_images"]: if ( - prompt_image.prompt_image_base64 != "" - and prompt_image.prompt_image_url != "" + "PromptImageUrl" in prompt_image["Type"] + and "PromptImageBase64" in prompt_image["Type"] ) or ( - prompt_image.prompt_image_base64 == "" - and prompt_image.prompt_image_url == "" + "PromptImageUrl" not in prompt_image["Type"] + and "PromptImageBase64" not in prompt_image["Type"] ): raise InvalidInputException - if prompt_image.prompt_image_base64 != "": + if "PromptImageUrl" in prompt_image["Type"]: prompt_image_list.append( - base64_to_pil_image(prompt_image.prompt_image_base64) + url_to_pil_image(prompt_image["Type"]["PromptImageUrl"]) ) - elif prompt_image.prompt_image_url != "": + elif "PromptImageBase64" in prompt_image["Type"]: prompt_image_list.append( - 
url_to_pil_image(prompt_image.prompt_image_url) + base64_to_pil_image(prompt_image["Type"]["PromptImageBase64"]) ) inp.prompt_images = prompt_image_list # max new tokens - if visual_question_answering_pb.max_new_tokens is not None: - inp.max_new_tokens = visual_question_answering_pb.max_new_tokens + if "max_new_tokens" in task_input_dict: + inp.max_new_tokens = task_input_dict["max_new_tokens"] # temperature - if visual_question_answering_pb.temperature is not None: - inp.temperature = visual_question_answering_pb.temperature + if "temperature" in task_input_dict: + inp.temperature = task_input_dict["temperature"] # top k - if visual_question_answering_pb.top_k is not None: - inp.top_k = visual_question_answering_pb.top_k + if "top_k" in task_input_dict: + inp.top_k = task_input_dict["top_k"] # seed - if visual_question_answering_pb.seed is not None: - inp.seed = visual_question_answering_pb.seed + if "seed" in task_input_dict: + inp.seed = task_input_dict["seed"] input_list.append(inp) @@ -928,30 +838,28 @@ def parse_task_text_to_image_input( input_list = [] for task_input in request.task_inputs: - task_input_pb = struct_to_protobuf(task_input, modelpb.TaskInput) - - text_to_image_pb: texttoimagepb.TextToImageInput = task_input_pb.text_to_image + task_input_dict = json_format.MessageToDict(task_input)["TextToImage"] inp = TextToImageInput() # prompt - inp.prompt = text_to_image_pb.prompt + inp.prompt = task_input_dict["prompt"] # steps - if text_to_image_pb.steps is not None: - inp.steps = text_to_image_pb.steps + if "steps" in task_input_dict: + inp.steps = task_input_dict["steps"] - # temperature - if text_to_image_pb.cfg_scale is not None: - inp.cfg_scale = text_to_image_pb.cfg_scale + # cfg_scale + if "cfg_scale" in task_input_dict: + inp.cfg_scale = task_input_dict["cfg_scale"] - # top k - if text_to_image_pb.samples is not None: - inp.samples = text_to_image_pb.samples + # samples + if "samples" in task_input_dict: + inp.samples = task_input_dict["samples"] 
# seed - if text_to_image_pb.seed is not None: - inp.seed = text_to_image_pb.seed + if "seed" in task_input_dict: + inp.seed = task_input_dict["seed"] input_list.append(inp) @@ -987,48 +895,46 @@ def parse_task_image_to_image_input( input_list = [] for task_input in request.task_inputs: - task_input_pb = struct_to_protobuf(task_input, modelpb.TaskInput) - - image_to_image_pb: imagetoimagepb.ImageToImageInput = ( - task_input_pb.text_to_image - ) + task_input_dict = json_format.MessageToDict(task_input)["ImageToImage"] inp = ImageToImageInput() # prompt - inp.prompt = image_to_image_pb.prompt + inp.prompt = task_input_dict["prompt"] # prompt images if ( - image_to_image_pb.prompt_image_base64 != "" - and image_to_image_pb.prompt_image_url != "" + "PromptImageUrl" in task_input_dict["Type"] + and "PromptImageBase64" in task_input_dict["Type"] ) or ( - image_to_image_pb.prompt_image_base64 == "" - and image_to_image_pb.prompt_image_url == "" + "PromptImageUrl" not in task_input_dict["Type"] + and "PromptImageBase64" not in task_input_dict["Type"] ): raise InvalidInputException - if image_to_image_pb.prompt_image_base64 != "": + if "PromptImageUrl" in task_input_dict["Type"]: + inp.prompt_image = url_to_pil_image( + task_input_dict["Type"]["PromptImageUrl"] + ) + elif "PromptImageBase64" in task_input_dict["Type"]: inp.prompt_image = base64_to_pil_image( - image_to_image_pb.prompt_image_base64 + task_input_dict["Type"]["PromptImageBase64"] ) - elif image_to_image_pb.prompt_image_url != "": - inp.prompt_image = url_to_pil_image(image_to_image_pb.prompt_image_url) # steps - if image_to_image_pb.steps is not None: - inp.steps = image_to_image_pb.steps + if "steps" in task_input_dict: + inp.steps = task_input_dict["steps"] - # temperature - if image_to_image_pb.cfg_scale is not None: - inp.cfg_scale = image_to_image_pb.cfg_scale + # cfg_scale + if "cfg_scale" in task_input_dict: + inp.cfg_scale = task_input_dict["cfg_scale"] - # top k - if image_to_image_pb.samples is not 
None: - inp.samples = image_to_image_pb.samples + # samples + if "samples" in task_input_dict: + inp.samples = task_input_dict["samples"] # seed - if image_to_image_pb.seed is not None: - inp.seed = image_to_image_pb.seed + if "seed" in task_input_dict: + inp.seed = task_input_dict["seed"] input_list.append(inp)