Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extension installer using threadpoolexecutor #16548

Draft
wants to merge 2 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions modules/cmd_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="does not do anything")
parser.add_argument("--max-install-thread", type=int, default=1, help="Maximum Thread number for for asynchronously install extensions. 1 = normal install. ⚠ Enabling this feature may cause unintended issues with some extensions."),
parser.add_argument("--embeddings-dir", type=normalized_filepath, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
parser.add_argument("--textual-inversion-templates-dir", type=normalized_filepath, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
parser.add_argument("--hypernetwork-dir", type=normalized_filepath, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
Expand Down
24 changes: 21 additions & 3 deletions modules/launch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import shlex
from functools import lru_cache

from concurrent.futures import ThreadPoolExecutor, as_completed
from modules import cmd_args, errors
from modules.paths_internal import script_path, extensions_dir
from modules.timer import startup_timer
Expand Down Expand Up @@ -228,7 +229,10 @@ def version_check(commit):
def run_extension_installer(extension_dir):
path_installer = os.path.join(extension_dir, "install.py")
if not os.path.isfile(path_installer):
return
return False

dirname = os.path.basename(extension_dir)
logging.debug(f"Installing {dirname}")

try:
env = os.environ.copy()
Expand All @@ -240,6 +244,8 @@ def run_extension_installer(extension_dir):
except Exception as e:
errors.report(str(e))

return True


def list_extensions(settings_file):
settings = {}
Expand Down Expand Up @@ -267,14 +273,26 @@ def run_extensions_installers(settings_file):
return

with startup_timer.subcategory("run extensions installers"):
paths = {}
for dirname_extension in list_extensions(settings_file):
logging.debug(f"Installing {dirname_extension}")

path = os.path.join(extensions_dir, dirname_extension)

if os.path.isdir(path):
paths[dirname_extension] = path

max_workers = args.max_install_thread
if max_workers == 1:
for dirname_extension, path in paths.items():
run_extension_installer(path)
startup_timer.record(dirname_extension)
else:
max_workers = min(max_workers, 4)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(run_extension_installer, path): dirname_extension for dirname_extension, path in paths.items()}
for future in as_completed(futures):
dirname_extension = futures[future]
if future.result():
startup_timer.record(dirname_extension)


re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
Expand Down