Enable support for free-threading #1295

Draft · wants to merge 2 commits into base: main

29 changes: 19 additions & 10 deletions libcst/codemod/_cli.py
@@ -10,20 +10,22 @@
 import difflib
 import os.path
 import re
+import functools
 import subprocess
 import sys
 import time
 import traceback
+from concurrent.futures import as_completed, Executor
 from copy import deepcopy
 from dataclasses import dataclass, replace
-from multiprocessing import cpu_count, Pool
+from multiprocessing import cpu_count
 from pathlib import Path
-from typing import Any, AnyStr, cast, Dict, List, Optional, Sequence, Union
+from typing import Any, AnyStr, cast, Dict, List, Optional, Sequence, Union, Callable
 
 from libcst import parse_module, PartialParserConfig
 from libcst.codemod._codemod import Codemod
 from libcst.codemod._context import CodemodContext
-from libcst.codemod._dummy_pool import DummyPool
+from libcst.codemod._dummy_pool import DummyExecutor
 from libcst.codemod._runner import (
     SkipFile,
     SkipReason,
@@ -607,13 +609,20 @@ def parallel_exec_transform_with_prettyprint(  # noqa: C901
         python_version=python_version,
     )
 
+    pool_impl: Callable[[], Executor]
     if total == 1 or jobs == 1:
         # Simple case, we should not pay for process overhead.
-        # Let's just use a dummy synchronous pool.
+        # Let's just use a dummy synchronous executor.
         jobs = 1
-        pool_impl = DummyPool
+        pool_impl = DummyExecutor
+    elif getattr(sys, "_is_gil_enabled", lambda: False)():
+        from concurrent.futures import ThreadPoolExecutor
+
+        pool_impl = functools.partial(ThreadPoolExecutor, max_workers=jobs)
     else:
-        pool_impl = Pool
+        from concurrent.futures import ProcessPoolExecutor
+
+        pool_impl = functools.partial(ProcessPoolExecutor, max_workers=jobs)
         # Warm the parser, pre-fork.
         parse_module(
             "",
@@ -629,7 +638,7 @@ def parallel_exec_transform_with_prettyprint(  # noqa: C901
     warnings: int = 0
     skips: int = 0
 
-    with pool_impl(processes=jobs) as p:  # type: ignore
+    with pool_impl() as executor:  # type: ignore
         args = [
            {
                "transformer": transform,

Contributor:

won't transform.context get clobbered here:

    # Apart from metadata_manager, every field of context should be reset per file
    transformer.context = CodemodContext(
        scratch=deepcopy(scratch),
        filename=filename,
        full_module_name=mod_name,
        full_package_name=pkg_name,
        metadata_manager=transformer.context.metadata_manager,
    )

This is mostly safe with the GIL, so long as Python doesn't switch to another thread that also owns a reference to transform, but in free-threading, with the sort of parallel thread pool setup you have here, it's much more likely to happen.

In general - how do you want to go about adding multithreaded tests? It looks like there's a decent amount of stateful objects that get passed around and updated. How much should we worry about threads simultaneously updating stateful objects?
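
A minimal, self-contained illustration of that clobbering hazard (toy classes and names, not LibCST code): a shared object whose attribute is reset per task can be observed mid-update by another worker thread.

```python
# Toy illustration of the clobbering concern above -- not LibCST code.
# One shared object has its .context reassigned per task, mimicking the
# per-file reset; with several threads, the window between "assign" and
# "use" is routinely crossed by another thread's assignment.
import threading
import time
from dataclasses import dataclass, field


@dataclass
class Context:
    filename: str = ""


@dataclass
class SharedTransformer:
    context: Context = field(default_factory=Context)


shared = SharedTransformer()
clobbered: list[str] = []


def process(filename: str) -> None:
    shared.context = Context(filename=filename)  # per-task reset on a shared object
    time.sleep(0.001)                            # stand-in for real transform work
    if shared.context.filename != filename:      # another thread reassigned it
        clobbered.append(filename)


threads = [threading.Thread(target=process, args=(f"file_{i}.py",)) for i in range(16)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(f"{len(clobbered)} of 16 tasks saw someone else's context")
```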


Member Author:

oh yeah absolutely - we'll probably have to change how state is passed around between threads. In fact, it might be better to just make a new transformer instance for each file and only pass immutable data structures between threads/processes.

> In general - how do you want to go about adding multithreaded tests?

We could run the test suite with something like unittest-ft, but the main thing I'd look for is concurrent execution of the same set of visitors/transformers on multiple input files (basically what this PR sketches out), and then we can write targeted tests for any bugs we find. How does that sound?
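
A rough sketch of both ideas (a fresh transformer per file, plus running the same transform concurrently over several inputs), using toy classes rather than LibCST's real API:

```python
# Hypothetical sketch, not the PR's implementation: each task builds its own
# transformer from immutable inputs, so no mutable state crosses threads, and
# the same transform is exercised concurrently over several files.
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass


@dataclass(frozen=True)
class Config:                  # frozen: safe to share read-only across threads
    old_name: str
    new_name: str


@dataclass
class FileContext:             # owned by exactly one task, never shared
    filename: str


class RenameTransformer:
    def __init__(self, context: FileContext, config: Config) -> None:
        self.context = context
        self.config = config

    def transform(self, source: str) -> str:
        return source.replace(self.config.old_name, self.config.new_name)


def run_one(filename: str, source: str, config: Config) -> str:
    # A brand-new transformer per file: nothing for other threads to clobber.
    return RenameTransformer(FileContext(filename), config).transform(source)


files = {"a.py": "old_api()", "b.py": "x = old_api", "c.py": "print(old_api)"}
config = Config(old_name="old_api", new_name="new_api")

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = {executor.submit(run_one, name, src, config): name for name, src in files.items()}
    for future in as_completed(futures):
        assert "old_api" not in future.result()  # results are per-file and independent
        print(futures[future], "->", future.result())
```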


Contributor:

> How does that sound?

Sounds great, that's more or less what I've been doing with pyca/cryptography this week: pyca/cryptography#12555

Another thing I learned there is that moving code from pure Python to Rust can also be a way to avoid thread-safety issues. If we move state from Python to Rust, then we have much more control over concurrent multithreaded access.

I'll look at this next week - thanks for merging my other PR!

@@ -640,9 +649,9 @@ def parallel_exec_transform_with_prettyprint(  # noqa: C901
            for filename in files
        ]
        try:
-            for result in p.imap_unordered(
-                _execute_transform_wrap, args, chunksize=chunksize
-            ):
+            futures = [executor.submit(_execute_transform_wrap, arg) for arg in args]
+            for future in as_completed(futures):
+                result = future.result()
                # Print an execution result, keep track of failures
                _print_parallel_result(
                    result,

53 changes: 33 additions & 20 deletions libcst/codemod/_dummy_pool.py
@@ -3,37 +3,50 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import sys
+from concurrent.futures import Executor, Future
 from types import TracebackType
-from typing import Callable, Generator, Iterable, Optional, Type, TypeVar
+from typing import Callable, Optional, Type, TypeVar
 
-RetT = TypeVar("RetT")
-ArgT = TypeVar("ArgT")
+if sys.version_info >= (3, 10):
+    from typing import ParamSpec
+else:
+    from typing_extensions import ParamSpec
+
+Return = TypeVar("Return")
+ParamSpec = ParamSpec("ParamSpec")
 
 
-class DummyPool:
+class DummyExecutor(Executor):
     """
-    Synchronous dummy `multiprocessing.Pool` analogue.
+    Synchronous dummy `concurrent.futures.Executor` analogue.
     """
 
-    def __init__(self, processes: Optional[int] = None) -> None:
-        pass
-
-    def imap_unordered(
+    def submit(
         self,
-        func: Callable[[ArgT], RetT],
-        iterable: Iterable[ArgT],
-        chunksize: Optional[int] = None,
-    ) -> Generator[RetT, None, None]:
-        for args in iterable:
-            yield func(args)
-
-    def __enter__(self) -> "DummyPool":
+        # pyre-ignore
+        fn: Callable[ParamSpec, Return],
+        # pyre-ignore
+        *args: ParamSpec.args,
+        # pyre-ignore
+        **kwargs: ParamSpec.kwargs,
+        # pyre-ignore
+    ) -> Future[Return]:
+        future: Future[Return] = Future()
+        try:
+            result = fn(*args, **kwargs)
+            future.set_result(result)
+        except Exception as exc:
+            future.set_exception(exc)
+        return future
+
+    def __enter__(self) -> "DummyExecutor":
         return self
 
     def __exit__(
         self,
-        exc_type: Optional[Type[Exception]],
-        exc: Optional[Exception],
-        tb: Optional[TracebackType],
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
     ) -> None:
         pass
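
For context, a small usage sketch (not part of the diff): because the calling code in `_cli.py` only relies on `Executor.submit()` and `as_completed()`, the same loop works whether the pool is the `DummyExecutor` above, a `ThreadPoolExecutor`, or a `ProcessPoolExecutor`. The `DummyExecutor` import assumes this branch is installed; the worker function is a stand-in for `_execute_transform_wrap`.

```python
# Usage sketch only: the same submit()/as_completed() loop, driven by any of
# the three executor flavours selected in _cli.py.
import functools
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed

from libcst.codemod._dummy_pool import DummyExecutor  # as defined in this PR branch


def toy_work(n: int) -> int:
    return n * n


def run(pool_impl) -> list[int]:
    results = []
    with pool_impl() as executor:
        futures = [executor.submit(toy_work, n) for n in range(8)]
        for future in as_completed(futures):
            results.append(future.result())
    return sorted(results)


if __name__ == "__main__":
    # Serial fallback, threads, or processes: the calling code does not change.
    for impl in (
        DummyExecutor,
        functools.partial(ThreadPoolExecutor, max_workers=4),
        functools.partial(ProcessPoolExecutor, max_workers=4),
    ):
        print(run(impl))
```
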
2 changes: 1 addition & 1 deletion native/libcst/src/py.rs
@@ -6,7 +6,7 @@
 use crate::nodes::traits::py::TryIntoPy;
 use pyo3::prelude::*;
 
-#[pymodule]
+#[pymodule(gil_used = false)]
 #[pyo3(name = "native")]
 pub fn libcst_native(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
     #[pyfn(m)]
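
For reference, `gil_used = false` tells PyO3 to mark the extension as safe for free-threaded Python, so importing it should not force the interpreter to re-enable the GIL. A small hedged check (the `libcst.native` import path is assumed from the `#[pyo3(name = "native")]` attribute; `sys._is_gil_enabled` only exists on free-threaded builds):

```python
# Hedged sketch: confirm the interpreter is still running without the GIL after
# the native extension is imported. Import path assumed, not taken from the PR.
import sys

import libcst.native  # noqa: F401

gil_check = getattr(sys, "_is_gil_enabled", None)  # present only on free-threaded builds
if gil_check is None:
    print("Regular (GIL) build of CPython")
elif gil_check():
    print("Free-threaded build, but the GIL was re-enabled (e.g. by an extension)")
else:
    print("GIL disabled: the native module did not force it back on")
```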