diff --git a/libcst/codemod/_cli.py b/libcst/codemod/_cli.py index 2481bf9d..aa482c30 100644 --- a/libcst/codemod/_cli.py +++ b/libcst/codemod/_cli.py @@ -10,20 +10,22 @@ import difflib import os.path import re +import functools import subprocess import sys import time import traceback +from concurrent.futures import as_completed, Executor from copy import deepcopy from dataclasses import dataclass, replace -from multiprocessing import cpu_count, Pool +from multiprocessing import cpu_count from pathlib import Path -from typing import Any, AnyStr, cast, Dict, List, Optional, Sequence, Union +from typing import Any, AnyStr, cast, Dict, List, Optional, Sequence, Union, Callable from libcst import parse_module, PartialParserConfig from libcst.codemod._codemod import Codemod from libcst.codemod._context import CodemodContext -from libcst.codemod._dummy_pool import DummyPool +from libcst.codemod._dummy_pool import DummyExecutor from libcst.codemod._runner import ( SkipFile, SkipReason, @@ -607,13 +609,20 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901 python_version=python_version, ) + pool_impl: Callable[[], Executor] if total == 1 or jobs == 1: # Simple case, we should not pay for process overhead. - # Let's just use a dummy synchronous pool. + # Let's just use a dummy synchronous executor. jobs = 1 - pool_impl = DummyPool + pool_impl = DummyExecutor + elif getattr(sys, "_is_gil_enabled", lambda: False)(): + from concurrent.futures import ThreadPoolExecutor + + pool_impl = functools.partial(ThreadPoolExecutor, max_workers=jobs) else: - pool_impl = Pool + from concurrent.futures import ProcessPoolExecutor + + pool_impl = functools.partial(ProcessPoolExecutor, max_workers=jobs) # Warm the parser, pre-fork. parse_module( "", @@ -629,7 +638,7 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901 warnings: int = 0 skips: int = 0 - with pool_impl(processes=jobs) as p: # type: ignore + with pool_impl() as executor: # type: ignore args = [ { "transformer": transform, @@ -640,9 +649,9 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901 for filename in files ] try: - for result in p.imap_unordered( - _execute_transform_wrap, args, chunksize=chunksize - ): + futures = [executor.submit(_execute_transform_wrap, arg) for arg in args] + for future in as_completed(futures): + result = future.result() # Print an execution result, keep track of failures _print_parallel_result( result, diff --git a/libcst/codemod/_dummy_pool.py b/libcst/codemod/_dummy_pool.py index c4a24932..f256e405 100644 --- a/libcst/codemod/_dummy_pool.py +++ b/libcst/codemod/_dummy_pool.py @@ -3,37 +3,50 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import sys +from concurrent.futures import Executor, Future from types import TracebackType -from typing import Callable, Generator, Iterable, Optional, Type, TypeVar +from typing import Callable, Optional, Type, TypeVar -RetT = TypeVar("RetT") -ArgT = TypeVar("ArgT") +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec +Return = TypeVar("Return") +ParamSpec = ParamSpec("ParamSpec") -class DummyPool: + +class DummyExecutor(Executor): """ - Synchronous dummy `multiprocessing.Pool` analogue. + Synchronous dummy `concurrent.futures.Executor` analogue. """ - def __init__(self, processes: Optional[int] = None) -> None: - pass - - def imap_unordered( + def submit( self, - func: Callable[[ArgT], RetT], - iterable: Iterable[ArgT], - chunksize: Optional[int] = None, - ) -> Generator[RetT, None, None]: - for args in iterable: - yield func(args) - - def __enter__(self) -> "DummyPool": + # pyre-ignore + fn: Callable[ParamSpec, Return], + # pyre-ignore + *args: ParamSpec.args, + # pyre-ignore + **kwargs: ParamSpec.kwargs, + # pyre-ignore + ) -> Future[Return]: + future: Future[Return] = Future() + try: + result = fn(*args, **kwargs) + future.set_result(result) + except Exception as exc: + future.set_exception(exc) + return future + + def __enter__(self) -> "DummyExecutor": return self def __exit__( self, - exc_type: Optional[Type[Exception]], - exc: Optional[Exception], - tb: Optional[TracebackType], + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], ) -> None: pass diff --git a/native/libcst/src/py.rs b/native/libcst/src/py.rs index bd7dfe6d..57da11e7 100644 --- a/native/libcst/src/py.rs +++ b/native/libcst/src/py.rs @@ -6,7 +6,7 @@ use crate::nodes::traits::py::TryIntoPy; use pyo3::prelude::*; -#[pymodule] +#[pymodule(gil_used = false)] #[pyo3(name = "native")] pub fn libcst_native(_py: Python, m: &Bound) -> PyResult<()> { #[pyfn(m)]