Optimizations for ComfyUI (taking advantage of multiprocessing?) #6766
avachon100510 started this conversation in Ideas
With some help from Copilot, and from There Is An AI For That, I took the liberty of applying some optimizations to the execution.py file: it imports ProcessPoolExecutor and as_completed from concurrent.futures to run work in parallel, plus the @jit compiler from numba. I wonder how useful this will be for generation on CPU or GPU.
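For context, here is a minimal, self-contained sketch of the pattern those imports enable. The node_task function and its payloads are hypothetical stand-ins: real ComfyUI nodes hold torch tensors and loaded models that are expensive (or impossible) to pickle across process boundaries, so treat this as the fan-out/fan-in shape rather than a drop-in change:

```python
from concurrent.futures import ProcessPoolExecutor, as_completed

def node_task(node_id, payload):
    # Hypothetical stand-in for executing one graph node with no unmet inputs.
    # Everything passed in and returned must be picklable.
    return node_id, sum(payload)

if __name__ == "__main__":
    ready = {1: [1, 2, 3], 2: [4, 5], 3: [6]}  # independent, ready-to-run nodes
    with ProcessPoolExecutor() as pool:
        futures = [pool.submit(node_task, nid, data) for nid, data in ready.items()]
        for fut in as_completed(futures):  # harvest results as each worker finishes
            node_id, result = fut.result()
            print(f"node {node_id} -> {result}")
```

The modified execution.py, abridged: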
```python
import sys
import copy
import logging
import threading
from collections import deque
import heapq
import time
from numba import jit
import traceback
from enum import Enum
import inspect
from typing import List, Literal, NamedTuple, Optional
from concurrent.futures import ProcessPoolExecutor, as_completed

import torch

import nodes
import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy.cli_args import args
class ExecutionResult(Enum):
    SUCCESS = 0
    FAILURE = 1
    PENDING = 2

class DuplicateNodeError(Exception):
    pass

class IsChangedCache:
    def __init__(self, dynprompt, outputs_cache):
        self.dynprompt = dynprompt
        self.outputs_cache = outputs_cache
        self.is_changed = {}

class CacheSet:
    def __init__(self, lru_size=None):
        if lru_size is None or lru_size == 0:
            self.init_classic_cache()
        else:
            self.init_lru_cache(lru_size)
        self.all = [self.outputs, self.ui, self.objects]
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
    valid_inputs = class_def.INPUT_TYPES()
    input_data_all = {}
    missing_keys = {}
    for x in inputs:
        input_data = inputs[x]
        input_type, input_category, input_info = get_input_info(class_def, x)
        def mark_missing():
            missing_keys[x] = True
            input_data_all[x] = (None,)
        if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if outputs is None:
                mark_missing()
                continue # This might be a lazily-evaluated input
            cached_output = outputs.get(input_unique_id)
            if cached_output is None:
                mark_missing()
                continue
            if output_index >= len(cached_output):
                mark_missing()
                continue
            obj = cached_output[output_index]
            input_data_all[x] = obj
        elif input_category is not None:
            input_data_all[x] = [input_data]
    return input_data_all, missing_keys

map_node_over_list = None # Don't hook this please
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
    # check if node wants the lists
    input_is_list = getattr(obj, "INPUT_IS_LIST", False)
    # ... (rest of function elided)

def merge_result_data(results, obj):
    # check which outputs need concatenating
    output = []
    output_is_list = [False] * len(results[0])
    if hasattr(obj, "OUTPUT_IS_LIST"):
        output_is_list = obj.OUTPUT_IS_LIST
    # ... (rest of function elided)

def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
    ...  # body elided in this excerpt

def format_value(x):
    if x is None:
        return None
    elif isinstance(x, (int, float, bool, str)):
        return x
    else:
        return str(x)
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
    unique_id = current_item
    real_node_id = dynprompt.get_real_node_id(unique_id)
    display_node_id = dynprompt.get_display_node_id(unique_id)
    parent_node_id = dynprompt.get_parent_node_id(unique_id)
    inputs = dynprompt.get_node(unique_id)['inputs']
    class_type = dynprompt.get_node(unique_id)['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    if caches.outputs.get(unique_id) is not None:
        if server.client_id is not None:
            cached_output = caches.ui.get(unique_id) or {}
            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output", None), "prompt_id": prompt_id }, server.client_id)
        return (ExecutionResult.SUCCESS, None, None)
    # ... (rest of function elided)

class PromptExecutor:
    def __init__(self, server, lru_size=None):
        self.lru_size = lru_size
        self.server = server
        self.reset()
def get_input_info_cached(obj_class, x, cache):
    cache_key = (obj_class, x)
    if cache_key not in cache:
        cache[cache_key] = get_input_info(obj_class, x)
    return cache[cache_key]

def get_input_data_cached(inputs, obj_class, unique_id, cache):
    cache_key = (obj_class, unique_id)
    if cache_key not in cache:
        cache[cache_key] = get_input_data(inputs, obj_class, unique_id)
    return cache[cache_key]
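# Note: the two helpers above hand-roll memoization. get_input_data_cached
# keys only on (obj_class, unique_id) and ignores `inputs`, so the cached
# entry goes stale if a node's inputs change between calls; folding a
# hashable digest of `inputs` into the key (or wrapping hashable arguments
# with functools.lru_cache) would avoid that.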
# numba cannot compile this function in nopython mode (it walks Python dicts
# of arbitrary objects), so force object mode to avoid a TypingError at call
# time; note that in object mode the JIT adds overhead rather than speed.
@jit(forceobj=True)
def validate_inputs(prompt, item, validated, info_cache=None, data_cache=None):
    if info_cache is None:  # avoid mutable default arguments
        info_cache = {}
    if data_cache is None:
        data_cache = {}
    unique_id = item
    if unique_id in validated:
        return validated[unique_id]
    # ... (rest of function elided)

def full_type_name(klass):
    module = klass.__module__
    if module == 'builtins':
        return klass.__qualname__
    return module + '.' + klass.__qualname__
def validate_prompt(prompt):
    outputs = set()
    for x in prompt:
        if 'class_type' not in prompt[x]:
            error = {
                "type": "invalid_prompt",
                "message": "Cannot execute because a node is missing the class_type property.",
                "details": f"Node ID '#{x}'",
                "extra_info": {}
            }
            return (False, error, [], [])
    # ... (rest of function elided)
MAXIMUM_HISTORY_SIZE = 10000

class ExecutionStatus:
    status_str: Literal['success', 'error']
    completed: bool
    messages: List[str]

class PromptQueue:
    def __init__(self, server):
        self.server = server
        self.mutex = threading.RLock()
        self.not_empty = threading.Condition(self.mutex)
        self.task_counter = 0
        self.queue = deque()
        self.currently_running = {}
        self.history = {}
        self.flags = {}
        self.execution_status = ExecutionStatus()
        server.prompt_queue = self
```
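A caveat on the numba side: @jit only pays off on numeric, loop-heavy code over arrays and scalars; dict-and-object bookkeeping like validate_inputs can only run in object mode, where the JIT adds overhead instead of removing it. A minimal sketch of the kind of function where nopython mode genuinely helps (the function and data are hypothetical, not part of execution.py):

```python
import numpy as np
from numba import jit

@jit(nopython=True)
def mean_abs_diff(a, b):
    # A pure numeric loop over contiguous arrays: exactly the shape of code
    # numba's nopython mode compiles to fast machine code.
    total = 0.0
    for i in range(a.shape[0]):
        total += abs(a[i] - b[i])
    return total / a.shape[0]

x = np.random.rand(1_000_000)
y = np.random.rand(1_000_000)
print(mean_abs_diff(x, y))  # first call compiles; later calls run at native speed
```

Since the heavy lifting in image generation happens inside torch kernels (on CPU or GPU), neither numba nor extra worker processes speed up the model math itself; any win here would be limited to Python-side scheduling and validation overhead.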