From 2c9b5679a842cfcadaba7791e665afd31fbb4f0c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 10:56:30 +0800 Subject: [PATCH 01/21] refactor Signed-off-by: youkaichao --- vllm/compilation/backends.py | 366 +++++++++++------------------------ vllm/config.py | 9 - 2 files changed, 109 insertions(+), 266 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index b9f96c00284b9..bb124e05379af 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -6,7 +6,8 @@ import time from collections import defaultdict from contextlib import ExitStack -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple +from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, + Type) from unittest.mock import patch import torch @@ -17,6 +18,8 @@ from vllm.logger import init_logger from vllm.utils import weak_ref_tensors +from .compiler_interface import (CompilerInterface, EagerAdaptor, + InductorAdaptor) from .counter import compilation_counter from .inductor_pass import InductorPass from .monitor import end_monitoring_torch_compile @@ -31,39 +34,28 @@ class InductorArtifact: file_path: str = "" -class InductorHashCache: +class CompilerManager: """ - Disk format: a Python list of tuples, each tuple is - (runtime_shape, graph_index, hash_str, file_path) - We use list of tuple for readability. - - In-memory format: a defaultdict of dict, where the key is - runtime_shape, and the value is a dict of graph_index to hash_str. - - The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`, - we don't use json here because json doesn't support int as key. - - TODO: better off-the-shelf solution to serialize the data? + A dict mapping `(runtime_shape, graph_index, backend_name)` + to `any_data`. + When serializing the cache, we save it to a Python file + for readability. We don't use json here because json doesn't + support int as key. """ - def __init__(self, cache_dir: str, disabled: bool = False): - self.cache: Dict[Optional[int], - Dict[int, InductorArtifact]] = defaultdict(dict) + def __init__(self, cache_dir: str, compilers: List[Type[CompilerInterface]], disabled: bool = False): + self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict() + self.compilers = compilers self.disabled = disabled self.cache_dir = cache_dir self.cache_file_path = os.path.join(cache_dir, - "inductor_hash_cache.py") + "compiler_manager.py") if disabled: return - # set flags so that Inductor and Triton store their cache - # in the cache_dir, then users only need to copy the cache_dir - # to another machine to reuse the cache. - inductor_cache = os.path.join(cache_dir, "inductor_cache") - os.makedirs(inductor_cache, exist_ok=True) - os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache - triton_cache = os.path.join(cache_dir, "triton_cache") - os.makedirs(triton_cache, exist_ok=True) - os.environ["TRITON_CACHE_DIR"] = triton_cache + + for compiler in self.compilers: + compiler.init_with_cache_dir(cache_dir) + if os.path.exists(self.cache_file_path): with open(self.cache_file_path) as f: self.deserialize(f.read()) @@ -72,28 +64,10 @@ def deserialize(self, data: str): # we use ast.literal_eval to parse the data # because it is a safe way to parse Python literals. # do not use eval(), it is unsafe. - list_data = ast.literal_eval(data) - for item in list_data: - runtime_shape = item[0] - graph_index = item[1] - hash_str = item[2] - # for compatibility of old version, - # where we don't have file_path. 
- # NOTE: after running the new code, the file_path - # will be updated. - file_path = "" if len(item) == 3 else item[3] - self.cache[runtime_shape][graph_index] = InductorArtifact( - hash_str=hash_str, file_path=file_path) + self.cache = ast.literal_eval(data) def serialize(self) -> str: - data = [] - for runtime_shape, value in self.cache.items(): - for graph_index, inductor_artifact in value.items(): - data.append( - (runtime_shape, graph_index, inductor_artifact.hash_str, - inductor_artifact.file_path)) - printer = pprint.PrettyPrinter(indent=4) - return printer.pformat(data) + return printer.pformat(self.cache) def save_to_file(self): if self.disabled: @@ -104,11 +78,12 @@ def save_to_file(self): def __contains__(self, key: Tuple[Optional[int], int]) -> bool: if self.disabled: return False - runtime_shape, graph_index = key - return runtime_shape in self.cache and graph_index in self.cache[ - runtime_shape] + for compiler in self.compilers: + if (key[0], key[1], compiler.name) in self.cache: + return True + return False - def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact: + def __getitem__(self, key: Tuple[Optional[int], int]) -> Any: if self.disabled: raise KeyError("cannot read from disabled cache") runtime_shape, graph_index = key @@ -121,210 +96,86 @@ def __setitem__(self, key: Tuple[Optional[int], int], runtime_shape, graph_index = key self.cache[runtime_shape][graph_index] = value - -class AlwaysHitShapeEnv: - """ - Why do we need this class: - - For normal `torch.compile` usage, every compilation will have - one Dynamo bytecode compilation and one Inductor compilation. - The Inductor compilation happens under the context of the - Dynamo bytecode compilation, and that context is used to - determine the dynamic shape information, etc. - - For our use case, we only run Dynamo bytecode compilation once, - and run Inductor compilation multiple times with different shapes - plus a general shape. The compilation for specific shapes happens - outside of the context of the Dynamo bytecode compilation. At that - time, we don't have shape environment to provide to Inductor, and - it will fail the Inductor code cache lookup. - - By providing a dummy shape environment that always hits, we can - make the Inductor code cache lookup always hit, and we can - compile the graph for different shapes as needed. - - The following dummy methods are obtained by trial-and-error - until it works. 
- """ - - def __init__(self) -> None: - self.guards: List[Any] = [] - - def evaluate_guards_expression(self, *args, **kwargs): - return True - - def get_pruned_guards(self, *args, **kwargs): - return [] - - def produce_guards_expression(self, *args, **kwargs): - return "" - - -def wrap_inductor(graph: fx.GraphModule, - example_inputs, - additional_inductor_config, - compilation_config: CompilationConfig, - vllm_backend: "VllmBackend", - graph_index: int = 0, - num_graphs: int = 1, - runtime_shape: Optional[int] = None, - use_inductor: bool = True) -> Any: - if graph_index == 0: - # before compiling the first graph, record the start time - global compilation_start_time - compilation_start_time = time.time() - - if not use_inductor: - return graph - - compilation_counter.num_inductor_compilations += 1 - - from torch._inductor import config - current_config = config.get_config_copy() - from torch._inductor.compile_fx import compile_fx - - if additional_inductor_config is not None: - current_config.update(additional_inductor_config) - - if isinstance(runtime_shape, int): - # for a specific batchsize, tuning triton kernel parameters - # can be beneficial - current_config["max_autotune"] = True - current_config["coordinate_descent_tuning"] = True - - # inductor can inplace modify the graph, so we need to copy it - # see https://github.com/pytorch/pytorch/issues/138980 - graph = copy.deepcopy(graph) - - cache_data = vllm_backend.inductor_hash_cache - if (runtime_shape, graph_index) in cache_data: - # we compiled this graph before - # so we can directly lookup the compiled graph via hash - inductor_artifact = cache_data[(runtime_shape, graph_index)] - hash_str = inductor_artifact.hash_str + def compile(self, graph: fx.GraphModule, + example_inputs, + additional_inductor_config, + compilation_config: CompilationConfig, + vllm_backend: "VllmBackend", + graph_index: int = 0, + num_graphs: int = 1, + runtime_shape: Optional[int] = None, + use_inductor: bool = True) -> Any: if graph_index == 0: - # adds some info logging for the first graph - logger.info( - "Directly lookup the graph for shape %s from the cache", - str(runtime_shape)) # noqa - logger.debug( - "directly lookup the %s-th graph for shape %s via hash %s", - graph_index, str(runtime_shape), hash_str) - from torch._inductor.codecache import FxGraphCache - with patch("torch._inductor.codecache.FxGraphCache._get_shape_env", - lambda *args, **kwargs: AlwaysHitShapeEnv()): - inductor_compiled_graph = FxGraphCache._lookup_graph( - hash_str, example_inputs, True, False) - assert inductor_compiled_graph is not None, ( - "Inductor cache lookup failed. Please remove" - f"the cache file {cache_data.cache_file_path} and try again." 
# noqa - ) - inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa - - # Inductor calling convention (function signature): - # f(list) -> tuple - # Dynamo calling convention (function signature): - # f(*args) -> Any - - # need to know if the graph returns a tuple - from torch._inductor.compile_fx import graph_returns_tuple - returns_tuple = graph_returns_tuple(graph) - - # this is the callable we return to Dynamo to run - def compiled_graph(*args): - # convert args to list - list_args = list(args) - graph_output = inductor_compiled_graph(list_args) - # unpack the tuple if needed - if returns_tuple: - return graph_output + # before compiling the first graph, record the start time + global compilation_start_time + compilation_start_time = time.time() + + if not use_inductor: + return graph + + compilation_counter.num_inductor_compilations += 1 + + compiled_graph = None + + if not self.disabled: + for compiler in self.compilers: + if compiled_graph is not None: + break + if (runtime_shape, graph_index, compiler.name) in self.cache: + try: + handle = self.cache[(runtime_shape, graph_index, compiler.name)] + compiled_graph = compiler.load(handle) + if graph_index == 0: + # adds some info logging for the first graph + logger.info( + "Directly load the compiled graph for shape %s from the cache", + str(runtime_shape)) # noqa + logger.debug( + "Directly load the %s-th graph for shape %s from %s via handle %s", + graph_index, str(runtime_shape), compiler.name, handle) + except Exception as e: + logger.warning( + "Failed to load the compiled graph from the cache. " + "Error: %s", str(e) + ) + + # no compiler cached the graph, or the cache is disabled, + # we need to compile it + for compiler in self.compilers: + if compiled_graph is not None: + break + try: + compiled_graph, handle = compiler.compile(graph, example_inputs, + additional_inductor_config, + runtime_shape) + # store the inductor_artifact in the cache + self.cache[(runtime_shape, graph_index, compiler.name)] = handle + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Cache the graph of shape %s for later use", + str(runtime_shape)) + logger.debug( + "store the %s-th graph for shape %s from %s via handle %s", + graph_index, str(runtime_shape), compiler.name, handle) + except Exception as e: + logger.warning( + "Failed to compile the graph. Error: %s", str(e)) + + assert compiled_graph is not None, "Failed to compile the graph" + + # after compiling the last graph, record the end time + if graph_index == num_graphs - 1: + now = time.time() + elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed + if runtime_shape is None: + logger.info("Compiling a graph for general shape takes %.2f s", + elapsed) else: - return graph_output[0] - else: - # it's the first time we compile this graph - # the assumption is that we don't have nested Inductor compilation. - # compiled_fx_graph_hash will only be called once, and we can hook - # it to get the hash of the compiled graph directly. 
- - inductor_artifact = InductorArtifact() - from torch._inductor.codecache import (FxGraphCache, - compiled_fx_graph_hash) - original_load = FxGraphCache.load - - def hijack_load(*args, **kwargs): - inductor_compiled_graph = original_load(*args, **kwargs) - inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa - return inductor_compiled_graph - - def hijack_compiled_fx_graph_hash(*args, **kwargs): - out = compiled_fx_graph_hash(*args, **kwargs) - inductor_artifact.hash_str = out[0] - return out - - def _check_can_cache(*args, **kwargs): - # no error means it can be cached. - # Inductor refuses to cache the graph outside of Dynamo - # tracing context, and also disables caching for graphs - # with high-order ops. - # For vLLM, in either case, we want to cache the graph. - # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa - return - - def _get_shape_env() -> AlwaysHitShapeEnv: - return AlwaysHitShapeEnv() - - with ExitStack() as stack: - if not cache_data.disabled: - # compilation cache is enabled, patch several functions - - # hijack to get the compiled graph itself - stack.enter_context( - patch("torch._inductor.codecache.FxGraphCache.load", - hijack_load)) - - # for hijacking the hash of the compiled graph - stack.enter_context( - patch("torch._inductor.codecache.compiled_fx_graph_hash", - hijack_compiled_fx_graph_hash)) - - # for providing a dummy shape environment - stack.enter_context( - patch( - "torch._inductor.codecache.FxGraphCache._get_shape_env", - _get_shape_env)) - - # for forcing the graph to be cached - stack.enter_context( - patch( - "torch._inductor.codecache.FxGraphCache._check_can_cache", - _check_can_cache)) - - compiled_graph = compile_fx(graph, - example_inputs, - config_patches=current_config) - # store the inductor_artifact in the cache - cache_data[(runtime_shape, graph_index)] = inductor_artifact - if graph_index == 0: - # adds some info logging for the first graph - logger.info("Cache the graph of shape %s for later use", - str(runtime_shape)) - logger.debug( - "store the %s-th graph for shape %s via hash %s from file %s", - graph_index, str(runtime_shape), inductor_artifact.hash_str, - inductor_artifact.file_path) - # after compiling the last graph, record the end time - if graph_index == num_graphs - 1: - now = time.time() - elapsed = now - compilation_start_time - compilation_config.compilation_time += elapsed - if runtime_shape is None: - logger.info("Compiling a graph for general shape takes %.2f s", - elapsed) - else: - logger.info("Compiling a graph for shape %s takes %.2f s", - runtime_shape, elapsed) + logger.info("Compiling a graph for shape %s takes %.2f s", + runtime_shape, elapsed) - return compiled_graph + return compiled_graph @dataclasses.dataclass @@ -434,7 +285,7 @@ def call_module(self, target: torch.fx.node.Target, i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] global compilation_start_time - compiled_graph_for_general_shape = wrap_inductor( + compiled_graph_for_general_shape = self.vllm_backend.compiler_manager.compile( submod, args, self.compilation_config.inductor_compile_config, @@ -481,7 +332,7 @@ class VllmBackend: post_grad_passes: Sequence[Callable] sym_tensor_indices: List[int] input_buffers: List[torch.Tensor] - inductor_hash_cache: InductorHashCache + compiler_manager: CompilerManager def __init__( self, @@ -569,8 +420,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: 
self.compilation_config.local_cache_dir = local_cache_dir disabled = envs.VLLM_DISABLE_COMPILE_CACHE - self.inductor_hash_cache: InductorHashCache = InductorHashCache( - local_cache_dir, disabled=disabled) + compilers = [InductorAdaptor] if self.compilation_config.use_inductor else [EagerAdaptor] + self.compiler_manager: CompilerManager = CompilerManager( + local_cache_dir, compilers, disabled=disabled) if disabled: logger.info("vLLM's torch.compile cache is disabled.") else: @@ -757,7 +609,7 @@ def check_for_ending_compilation(self): if self.is_last_graph and not self.to_be_compiled_sizes: # no specific sizes to compile # save the hash of the inductor graph for the next run - self.vllm_backend.inductor_hash_cache.save_to_file() + self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) def __call__(self, *args) -> Any: @@ -780,7 +632,7 @@ def __call__(self, *args) -> Any: entry.compiled = True self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments - entry.runnable = wrap_inductor( + entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, args, self.compilation_config.inductor_compile_config, diff --git a/vllm/config.py b/vllm/config.py index f4548c4466e48..2d61b2fb8cfc6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3008,15 +3008,6 @@ def compute_hash(self) -> str: the final hidden states. """ factors: List[Any] = [] - # summarize system state - from torch._inductor.codecache import CacheBase - system_factors = CacheBase.get_system() - factors.append(system_factors) - - # summarize pytorch state - from torch._inductor.codecache import torch_key - torch_factors = torch_key() - factors.append(torch_factors) # summarize vllm config vllm_factors: List[Any] = [] From a10d23e4b8a169b578d3089f3928622f060f9b9b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 16:24:14 +0800 Subject: [PATCH 02/21] refactor Signed-off-by: youkaichao --- vllm/compilation/backends.py | 152 +++++-------- vllm/compilation/compiler_interface.py | 302 +++++++++++++++++++++++++ 2 files changed, 364 insertions(+), 90 deletions(-) create mode 100644 vllm/compilation/compiler_interface.py diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index e65640408557b..d87bae2008f66 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,13 +1,10 @@ import ast -import copy import dataclasses import os import pprint import time -from collections import defaultdict from contextlib import ExitStack -from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, - Type) +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from unittest.mock import patch import torch @@ -18,8 +15,7 @@ from vllm.logger import init_logger from vllm.utils import weak_ref_tensors -from .compiler_interface import (CompilerInterface, EagerAdaptor, - InductorAdaptor) +from .compiler_interface import EagerAdaptor, InductorAdaptor from .counter import compilation_counter from .inductor_pass import InductorPass from .monitor import end_monitoring_torch_compile @@ -28,12 +24,6 @@ logger = init_logger(__name__) -@dataclasses.dataclass -class InductorArtifact: - hash_str: str = "" - file_path: str = "" - - class CompilerManager: """ A dict mapping `(runtime_shape, graph_index, backend_name)` @@ -43,18 +33,19 @@ class CompilerManager: support int as key. 
""" - def __init__(self, cache_dir: str, compilers: List[Type[CompilerInterface]], disabled: bool = False): + def __init__(self, + cache_dir: str, + use_inductor: bool, + disabled: bool = False): self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict() - self.compilers = compilers + self.compiler = InductorAdaptor() if use_inductor else EagerAdaptor() self.disabled = disabled self.cache_dir = cache_dir - self.cache_file_path = os.path.join(cache_dir, - "compiler_manager.py") + self.cache_file_path = os.path.join(cache_dir, "compiler_manager.py") if disabled: return - for compiler in self.compilers: - compiler.init_with_cache_dir(cache_dir) + self.compiler.init_with_cache_dir(cache_dir) if os.path.exists(self.cache_file_path): with open(self.cache_file_path) as f: @@ -67,6 +58,7 @@ def deserialize(self, data: str): self.cache = ast.literal_eval(data) def serialize(self) -> str: + printer = pprint.PrettyPrinter(indent=4) return printer.pformat(self.cache) def save_to_file(self): @@ -75,36 +67,32 @@ def save_to_file(self): with open(self.cache_file_path, "w") as f: f.write(self.serialize()) - def __contains__(self, key: Tuple[Optional[int], int]) -> bool: - if self.disabled: - return False - for compiler in self.compilers: - if (key[0], key[1], compiler.name) in self.cache: - return True - return False + def load(self, + graph: fx.GraphModule, + example_inputs: List[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + if (runtime_shape, graph_index, self.compiler.name) not in self.cache: + return None + handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] + compiled_graph = self.compiler.load(handle, graph, example_inputs, + graph_index, runtime_shape) + logger.debug( + "Directly load the %s-th graph for shape %s from %s via " + "handle %s", graph_index, str(runtime_shape), self.compiler.name, + handle) + return compiled_graph - def __getitem__(self, key: Tuple[Optional[int], int]) -> Any: - if self.disabled: - raise KeyError("cannot read from disabled cache") - runtime_shape, graph_index = key - return self.cache[runtime_shape][graph_index] - - def __setitem__(self, key: Tuple[Optional[int], int], - value: InductorArtifact): - # setitem for disabled cache is fine, because we - # don't actually write to the disk - runtime_shape, graph_index = key - self.cache[runtime_shape][graph_index] = value - - def compile(self, graph: fx.GraphModule, - example_inputs, - additional_inductor_config, - compilation_config: CompilationConfig, - vllm_backend: "VllmBackend", - graph_index: int = 0, - num_graphs: int = 1, - runtime_shape: Optional[int] = None, - use_inductor: bool = True) -> Any: + def compile(self, + graph: fx.GraphModule, + example_inputs, + additional_inductor_config, + compilation_config: CompilationConfig, + vllm_backend: "VllmBackend", + graph_index: int = 0, + num_graphs: int = 1, + runtime_shape: Optional[int] = None, + use_inductor: bool = True) -> Any: if graph_index == 0: # before compiling the first graph, record the start time global compilation_start_time @@ -118,51 +106,33 @@ def compile(self, graph: fx.GraphModule, compiled_graph = None if not self.disabled: - for compiler in self.compilers: - if compiled_graph is not None: - break - if (runtime_shape, graph_index, compiler.name) in self.cache: - try: - handle = self.cache[(runtime_shape, graph_index, compiler.name)] - compiled_graph = compiler.load(handle) - if graph_index == 0: - # adds some info logging for the first graph - logger.info( - "Directly load the compiled graph for 
shape %s from the cache", - str(runtime_shape)) # noqa - logger.debug( - "Directly load the %s-th graph for shape %s from %s via handle %s", - graph_index, str(runtime_shape), compiler.name, handle) - except Exception as e: - logger.warning( - "Failed to load the compiled graph from the cache. " - "Error: %s", str(e) - ) - - # no compiler cached the graph, or the cache is disabled, - # we need to compile it - for compiler in self.compilers: + compiled_graph = self.load(graph, example_inputs, graph_index, + runtime_shape) if compiled_graph is not None: - break - try: - compiled_graph, handle = compiler.compile(graph, example_inputs, - additional_inductor_config, - runtime_shape) - # store the inductor_artifact in the cache - self.cache[(runtime_shape, graph_index, compiler.name)] = handle if graph_index == 0: # adds some info logging for the first graph - logger.info("Cache the graph of shape %s for later use", - str(runtime_shape)) - logger.debug( - "store the %s-th graph for shape %s from %s via handle %s", - graph_index, str(runtime_shape), compiler.name, handle) - except Exception as e: - logger.warning( - "Failed to compile the graph. Error: %s", str(e)) - + logger.info( + "Directly load the compiled graph for shape %s " + "from the cache", str(runtime_shape)) # noqa + return compiled_graph + + # no compiler cached the graph, or the cache is disabled, + # we need to compile it + compiled_graph, handle = self.compiler.compile( + graph, example_inputs, additional_inductor_config, runtime_shape) + assert compiled_graph is not None, "Failed to compile the graph" + # store the artifact in the cache + self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Cache the graph of shape %s for later use", + str(runtime_shape)) + logger.debug( + "store the %s-th graph for shape %s from %s via handle %s", + graph_index, str(runtime_shape), self.compiler.name, handle) + # after compiling the last graph, record the end time if graph_index == num_graphs - 1: now = time.time() @@ -285,7 +255,8 @@ def call_module(self, target: torch.fx.node.Target, i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] global compilation_start_time - compiled_graph_for_general_shape = self.vllm_backend.compiler_manager.compile( + compiled_graph_for_general_shape = self.vllm_backend.\ + compiler_manager.compile( submod, args, self.compilation_config.inductor_compile_config, @@ -420,9 +391,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.compilation_config.local_cache_dir = local_cache_dir disabled = envs.VLLM_DISABLE_COMPILE_CACHE - compilers = [InductorAdaptor] if self.compilation_config.use_inductor else [EagerAdaptor] self.compiler_manager: CompilerManager = CompilerManager( - local_cache_dir, compilers, disabled=disabled) + local_cache_dir, + self.compilation_config.use_inductor, + disabled=disabled) if disabled: logger.info("vLLM's torch.compile cache is disabled.") else: diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py new file mode 100644 index 0000000000000..911c86614f2cb --- /dev/null +++ b/vllm/compilation/compiler_interface.py @@ -0,0 +1,302 @@ +import copy +import hashlib +import os +from contextlib import ExitStack +from typing import Any, Callable, Dict, List, Optional, Tuple +from unittest.mock import patch + +import torch.fx as fx + +from vllm.config import VllmConfig + + +class CompilerInterface: + """ + The interface for a 
compiler that can be used by vLLM. + """ + # The name of the compiler, e.g. inductor. + # This is a class-level attribute. + name: str + + def compute_hash(self, vllm_config: VllmConfig) -> str: + """ + Gather all the relevant information from the VLLM config, + to compute a hash so that we can cache the compiled model. + + See :meth:`VllmConfig.compute_hash` to check what information + is already considered by default. This function should only + consider the information that is specific to the compiler. + """ + pass + + def init_with_cache_dir(self, cache_dir: str) -> None: + """ + when the vLLM process uses `cache_dir` as the cache directory, + the compiler should initialize itself with the cache directory, + e.g. by re-directing its own cache directory to a sub-directory. + """ + pass + + def compile(self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None) -> Tuple[Callable, Any]: + """ + Compile the graph with the given example inputs and compiler config, + with a runtime shape. If the `runtime_shape` is None, it means + the `example_inputs` have a dynamic shape. Otherwise, the + `runtime_shape` specifies the shape of the inputs. Right now we only + support one variable shape for all inputs, which is the batchsize + (number of tokens) during inference. + + Dynamo will make sure `graph(*example_inputs)` is valid. + + The function should return a compiled callable function, as well as + a handle that can be used to directly load the compiled function. + + The handle should be a plain Python object, preferably a string or a + file path for readability. + """ + pass + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: List[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + """ + Load the compiled function from the handle. + Raises an error if the handle is invalid. + + The handle is the second return value of the `compile` function. + """ + pass + + +class AlwaysHitShapeEnv: + """ + Why do we need this class: + + For normal `torch.compile` usage, every compilation will have + one Dynamo bytecode compilation and one Inductor compilation. + The Inductor compilation happens under the context of the + Dynamo bytecode compilation, and that context is used to + determine the dynamic shape information, etc. + + For our use case, we only run Dynamo bytecode compilation once, + and run Inductor compilation multiple times with different shapes + plus a general shape. The compilation for specific shapes happens + outside of the context of the Dynamo bytecode compilation. At that + time, we don't have shape environment to provide to Inductor, and + it will fail the Inductor code cache lookup. + + By providing a dummy shape environment that always hits, we can + make the Inductor code cache lookup always hit, and we can + compile the graph for different shapes as needed. + + The following dummy methods are obtained by trial-and-error + until it works. 
+ """ + + def __init__(self) -> None: + self.guards: List[Any] = [] + + def evaluate_guards_expression(self, *args, **kwargs): + return True + + def get_pruned_guards(self, *args, **kwargs): + return [] + + def produce_guards_expression(self, *args, **kwargs): + return "" + + +class InductorAdaptor(CompilerInterface): + name = "inductor" + dynamic_shape = True + + def compute_hash(self, vllm_config: VllmConfig) -> str: + factors: List[Any] = [] + # summarize system state + from torch._inductor.codecache import CacheBase + system_factors = CacheBase.get_system() + factors.append(system_factors) + + # summarize pytorch state + from torch._inductor.codecache import torch_key + torch_factors = torch_key() + factors.append(torch_factors) + hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10] + return hash_str + + def init_with_cache_dir(self, cache_dir: str) -> None: + # redirect the cache directory to a sub-directory + # set flags so that Inductor and Triton store their cache + # in the cache_dir, then users only need to copy the cache_dir + # to another machine to reuse the cache. + inductor_cache = os.path.join(cache_dir, "inductor_cache") + os.makedirs(inductor_cache, exist_ok=True) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache + triton_cache = os.path.join(cache_dir, "triton_cache") + os.makedirs(triton_cache, exist_ok=True) + os.environ["TRITON_CACHE_DIR"] = triton_cache + + def compile(self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None) -> Tuple[Callable, Any]: + from torch._inductor import config + current_config = config.get_config_copy() + from torch._inductor.compile_fx import compile_fx + + if compiler_config is not None: + current_config.update(compiler_config) + + if isinstance(runtime_shape, int): + # for a specific batchsize, tuning triton kernel parameters + # can be beneficial + current_config["max_autotune"] = True + current_config["coordinate_descent_tuning"] = True + + # inductor can inplace modify the graph, so we need to copy it + # see https://github.com/pytorch/pytorch/issues/138980 + graph = copy.deepcopy(graph) + + # it's the first time we compile this graph + # the assumption is that we don't have nested Inductor compilation. + # compiled_fx_graph_hash will only be called once, and we can hook + # it to get the hash of the compiled graph directly. + + hash_str, file_path = None, None + from torch._inductor.codecache import (FxGraphCache, + compiled_fx_graph_hash) + original_load = FxGraphCache.load + + def hijack_load(*args, **kwargs): + inductor_compiled_graph = original_load(*args, **kwargs) + nonlocal file_path + file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa + return inductor_compiled_graph + + def hijack_compiled_fx_graph_hash(*args, **kwargs): + out = compiled_fx_graph_hash(*args, **kwargs) + nonlocal hash_str + hash_str = out[0] + return out + + def _check_can_cache(*args, **kwargs): + # no error means it can be cached. + # Inductor refuses to cache the graph outside of Dynamo + # tracing context, and also disables caching for graphs + # with high-order ops. + # For vLLM, in either case, we want to cache the graph. 
+ # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa + return + + def _get_shape_env() -> AlwaysHitShapeEnv: + return AlwaysHitShapeEnv() + + with ExitStack() as stack: + # hijack to get the compiled graph itself + stack.enter_context( + patch("torch._inductor.codecache.FxGraphCache.load", + hijack_load)) + + # for hijacking the hash of the compiled graph + stack.enter_context( + patch("torch._inductor.codecache.compiled_fx_graph_hash", + hijack_compiled_fx_graph_hash)) + + # for providing a dummy shape environment + stack.enter_context( + patch("torch._inductor.codecache.FxGraphCache._get_shape_env", + _get_shape_env)) + + # for forcing the graph to be cached + stack.enter_context( + patch( + "torch._inductor.codecache.FxGraphCache._check_can_cache", + _check_can_cache)) + + compiled_graph = compile_fx(graph, + example_inputs, + config_patches=current_config) + + assert hash_str is not None, ( + "failed to get the hash of the compiled graph") + return compiled_graph, (hash_str, file_path) + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: List[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + assert isinstance(handle, tuple) + assert isinstance(handle[0], str) + assert isinstance(handle[1], str) + hash_str = handle[0] + + from torch._inductor.codecache import FxGraphCache + with patch("torch._inductor.codecache.FxGraphCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv()): + inductor_compiled_graph = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, False) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache directory and try again." # noqa + ) + + # Inductor calling convention (function signature): + # f(list) -> tuple + # Dynamo calling convention (function signature): + # f(*args) -> Any + + # need to know if the graph returns a tuple + from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) + + # this is the callable we return to Dynamo to run + def compiled_graph(*args): + # convert args to list + list_args = list(args) + graph_output = inductor_compiled_graph(list_args) + # unpack the tuple if needed + if returns_tuple: + return graph_output + else: + return graph_output[0] + + return compiled_graph + + +class EagerAdaptor(CompilerInterface): + name = "eager" + dynamic_shape = True + + def compute_hash(self, vllm_config: VllmConfig) -> str: + """ + We don't need to cache the compiled model for the eager compiler, + which just runs the graph directly.""" + return "" + + def compile(self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None) -> Tuple[Callable, Any]: + # we don't need to compile the graph, just return the graph itself + return graph, None + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: List[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + # handle is None, we don't need to load anything + raise NotImplementedError( + "Eager compiler doesn't need compilation cache") From 5559d416450b4eb5d6c206a253ba3cf22c873138 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 16:28:59 +0800 Subject: [PATCH 03/21] rename variable Signed-off-by: youkaichao --- vllm/compilation/backends.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git 
a/vllm/compilation/backends.py b/vllm/compilation/backends.py index d87bae2008f66..f4dc594ed7c96 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -36,13 +36,13 @@ class CompilerManager: def __init__(self, cache_dir: str, use_inductor: bool, - disabled: bool = False): + disable_cache: bool = False): self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict() self.compiler = InductorAdaptor() if use_inductor else EagerAdaptor() - self.disabled = disabled + self.disable_cache = disable_cache self.cache_dir = cache_dir self.cache_file_path = os.path.join(cache_dir, "compiler_manager.py") - if disabled: + if disable_cache: return self.compiler.init_with_cache_dir(cache_dir) @@ -62,7 +62,7 @@ def serialize(self) -> str: return printer.pformat(self.cache) def save_to_file(self): - if self.disabled: + if self.disable_cache: return with open(self.cache_file_path, "w") as f: f.write(self.serialize()) @@ -105,7 +105,7 @@ def compile(self, compiled_graph = None - if not self.disabled: + if not self.disable_cache: compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape) if compiled_graph is not None: @@ -390,12 +390,12 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: cache_dir, f"rank_{vllm_config.parallel_config.rank}") self.compilation_config.local_cache_dir = local_cache_dir - disabled = envs.VLLM_DISABLE_COMPILE_CACHE + disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE self.compiler_manager: CompilerManager = CompilerManager( local_cache_dir, self.compilation_config.use_inductor, - disabled=disabled) - if disabled: + disable_cache=disable_cache) + if disable_cache: logger.info("vLLM's torch.compile cache is disabled.") else: logger.info("Using cache directory: %s for vLLM's torch.compile", From ecff8c793cc4e88fcbadb5ab26ff4ad678ea0942 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 16:34:02 +0800 Subject: [PATCH 04/21] simplify functions Signed-off-by: youkaichao --- vllm/compilation/backends.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index f4dc594ed7c96..90f77a8e1f024 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -88,19 +88,15 @@ def compile(self, example_inputs, additional_inductor_config, compilation_config: CompilationConfig, - vllm_backend: "VllmBackend", graph_index: int = 0, num_graphs: int = 1, - runtime_shape: Optional[int] = None, - use_inductor: bool = True) -> Any: + runtime_shape: Optional[int] = None) -> Any: if graph_index == 0: # before compiling the first graph, record the start time global compilation_start_time compilation_start_time = time.time() - if not use_inductor: - return graph - + # TODO: rename to num_backend_compilations compilation_counter.num_inductor_compilations += 1 compiled_graph = None @@ -261,11 +257,9 @@ def call_module(self, target: torch.fx.node.Target, args, self.compilation_config.inductor_compile_config, self.compilation_config, - self.vllm_backend, graph_index=index, num_graphs=len(self.compile_submod_names), - runtime_shape=None, - use_inductor=self.compilation_config.use_inductor) + runtime_shape=None) self.module.__dict__[target] = PiecewiseBackend( submod, self.vllm_config, self.graph_pool, index, @@ -609,11 +603,9 @@ def __call__(self, *args) -> Any: args, self.compilation_config.inductor_compile_config, self.compilation_config, - self.vllm_backend, graph_index=self.piecewise_compile_index, 
num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape, - use_inductor=self.compilation_config.use_inductor) + runtime_shape=runtime_shape) # finished compilations for all required shapes if self.is_last_graph and not self.to_be_compiled_sizes: From 30dc83fdf234df1e01af3d2b3550c39758fdcfe2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 16:47:50 +0800 Subject: [PATCH 05/21] simplify Signed-off-by: youkaichao --- vllm/compilation/backends.py | 77 +++++++++++++------------- vllm/compilation/compiler_interface.py | 52 +++++++++++------ 2 files changed, 74 insertions(+), 55 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 90f77a8e1f024..d48fced7027c8 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -26,8 +26,14 @@ class CompilerManager: """ - A dict mapping `(runtime_shape, graph_index, backend_name)` - to `any_data`. + A manager to manage the compilation process, including + caching the compiled graph, loading the compiled graph, + and compiling the graph. + + The cache is a dict mapping + `(runtime_shape, graph_index, backend_name)` + to `any_data` returned from the compiler. + When serializing the cache, we save it to a Python file for readability. We don't use json here because json doesn't support int as key. @@ -38,40 +44,36 @@ def __init__(self, use_inductor: bool, disable_cache: bool = False): self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict() - self.compiler = InductorAdaptor() if use_inductor else EagerAdaptor() + self.compiler = InductorAdaptor( + cache_dir=cache_dir, + disable_cache=disable_cache) if use_inductor else EagerAdaptor( + cache_dir=cache_dir, disable_cache=disable_cache) self.disable_cache = disable_cache self.cache_dir = cache_dir - self.cache_file_path = os.path.join(cache_dir, "compiler_manager.py") + self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py") if disable_cache: return - self.compiler.init_with_cache_dir(cache_dir) - if os.path.exists(self.cache_file_path): with open(self.cache_file_path) as f: - self.deserialize(f.read()) - - def deserialize(self, data: str): - # we use ast.literal_eval to parse the data - # because it is a safe way to parse Python literals. - # do not use eval(), it is unsafe. - self.cache = ast.literal_eval(data) - - def serialize(self) -> str: - printer = pprint.PrettyPrinter(indent=4) - return printer.pformat(self.cache) + # we use ast.literal_eval to parse the data + # because it is a safe way to parse Python literals. + # do not use eval(), it is unsafe. 
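                # Illustrative sketch (the concrete values below are invented,
                # not taken from this patch series): because the file stores a
                # plain Python literal, a saved `vllm_compile_cache.py` might
                # look roughly like:
                #     {
                #         (None, 0, 'inductor'): ('fpegyiq3v3', '.../inductor_cache/.../output_code.py'),
                #         (16, 0, 'inductor'): ('f7fmlvwzw6', '.../inductor_cache/.../output_code.py'),
                #     }
                # i.e. keys are (runtime_shape, graph_index, compiler_name)
                # tuples and each value is whatever handle the compiler
                # returned, a (hash_str, file_path) pair in the Inductor
                # adaptor's case. The tuple and int keys are exactly what JSON
                # cannot represent, which is why a Python literal is used.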
+ self.cache = ast.literal_eval(f.read()) def save_to_file(self): if self.disable_cache: return with open(self.cache_file_path, "w") as f: - f.write(self.serialize()) + printer = pprint.PrettyPrinter(indent=4) + data = printer.pformat(self.cache) + f.write(data) def load(self, graph: fx.GraphModule, example_inputs: List[Any], graph_index: int, - runtime_shape: Optional[int] = None) -> Callable: + runtime_shape: Optional[int] = None) -> Optional[Callable]: if (runtime_shape, graph_index, self.compiler.name) not in self.cache: return None handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] @@ -101,16 +103,15 @@ def compile(self, compiled_graph = None - if not self.disable_cache: - compiled_graph = self.load(graph, example_inputs, graph_index, - runtime_shape) - if compiled_graph is not None: - if graph_index == 0: - # adds some info logging for the first graph - logger.info( - "Directly load the compiled graph for shape %s " - "from the cache", str(runtime_shape)) # noqa - return compiled_graph + # try to load from the cache + compiled_graph = self.load(graph, example_inputs, graph_index, + runtime_shape) + if compiled_graph is not None: + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Directly load the compiled graph for shape %s " + "from the cache", str(runtime_shape)) # noqa + return compiled_graph # no compiler cached the graph, or the cache is disabled, # we need to compile it @@ -120,14 +121,16 @@ def compile(self, assert compiled_graph is not None, "Failed to compile the graph" # store the artifact in the cache - self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle - if graph_index == 0: - # adds some info logging for the first graph - logger.info("Cache the graph of shape %s for later use", - str(runtime_shape)) - logger.debug( - "store the %s-th graph for shape %s from %s via handle %s", - graph_index, str(runtime_shape), self.compiler.name, handle) + if handle is not None: + self.cache[(runtime_shape, graph_index, + self.compiler.name)] = handle + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Cache the graph of shape %s for later use", + str(runtime_shape)) + logger.debug( + "store the %s-th graph for shape %s from %s via handle %s", + graph_index, str(runtime_shape), self.compiler.name, handle) # after compiling the last graph, record the end time if graph_index == num_graphs - 1: diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 911c86614f2cb..3de0ef7946731 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -18,6 +18,9 @@ class CompilerInterface: # This is a class-level attribute. name: str + def __init__(self, cache_dir: str, disable_cache: bool = False): + pass + def compute_hash(self, vllm_config: VllmConfig) -> str: """ Gather all the relevant information from the VLLM config, @@ -37,11 +40,13 @@ def init_with_cache_dir(self, cache_dir: str) -> None: """ pass - def compile(self, - graph: fx.GraphModule, - example_inputs: List[Any], - compiler_config: Dict[str, Any], - runtime_shape: Optional[int] = None) -> Tuple[Callable, Any]: + def compile( + self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None + ) -> Tuple[Optional[Callable], Optional[Any]]: """ Compile the graph with the given example inputs and compiler config, with a runtime shape. 
If the `runtime_shape` is None, it means @@ -57,6 +62,10 @@ def compile(self, The handle should be a plain Python object, preferably a string or a file path for readability. + + If the compiler doesn't support caching, it should return None for the + handle. If the compiler fails to compile the graph, it should return + None for the compiled function as well. """ pass @@ -131,7 +140,9 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10] return hash_str - def init_with_cache_dir(self, cache_dir: str) -> None: + def __init__(self, cache_dir: str, disable_cache: bool = False): + if disable_cache: + return # redirect the cache directory to a sub-directory # set flags so that Inductor and Triton store their cache # in the cache_dir, then users only need to copy the cache_dir @@ -143,11 +154,13 @@ def init_with_cache_dir(self, cache_dir: str) -> None: os.makedirs(triton_cache, exist_ok=True) os.environ["TRITON_CACHE_DIR"] = triton_cache - def compile(self, - graph: fx.GraphModule, - example_inputs: List[Any], - compiler_config: Dict[str, Any], - runtime_shape: Optional[int] = None) -> Tuple[Callable, Any]: + def compile( + self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None + ) -> Tuple[Optional[Callable], Optional[Any]]: from torch._inductor import config current_config = config.get_config_copy() from torch._inductor.compile_fx import compile_fx @@ -283,12 +296,15 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: which just runs the graph directly.""" return "" - def compile(self, - graph: fx.GraphModule, - example_inputs: List[Any], - compiler_config: Dict[str, Any], - runtime_shape: Optional[int] = None) -> Tuple[Callable, Any]: - # we don't need to compile the graph, just return the graph itself + def compile( + self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None + ) -> Tuple[Optional[Callable], Optional[Any]]: + # we don't need to compile the graph, just return the graph itself. + # It does not support caching, return None for the handle. 
return graph, None def load(self, @@ -299,4 +315,4 @@ def load(self, runtime_shape: Optional[int] = None) -> Callable: # handle is None, we don't need to load anything raise NotImplementedError( - "Eager compiler doesn't need compilation cache") + "Eager compiler doesn't support compilation cache") From 71655166299ad2d1855b78a1baadf75f04c53b8f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 16:49:25 +0800 Subject: [PATCH 06/21] simplify Signed-off-by: youkaichao --- vllm/compilation/backends.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index d48fced7027c8..8c3e1e3b304ce 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -44,10 +44,8 @@ def __init__(self, use_inductor: bool, disable_cache: bool = False): self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict() - self.compiler = InductorAdaptor( - cache_dir=cache_dir, - disable_cache=disable_cache) if use_inductor else EagerAdaptor( - cache_dir=cache_dir, disable_cache=disable_cache) + cls = InductorAdaptor if use_inductor else EagerAdaptor + self.compiler = cls(cache_dir=cache_dir, disable_cache=disable_cache) self.disable_cache = disable_cache self.cache_dir = cache_dir self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py") From 5f010a70d3b23549942fe1fcb571cdf757c0088d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 16:50:09 +0800 Subject: [PATCH 07/21] simplify Signed-off-by: youkaichao --- vllm/compilation/backends.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 8c3e1e3b304ce..354f184609468 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -49,10 +49,9 @@ def __init__(self, self.disable_cache = disable_cache self.cache_dir = cache_dir self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py") - if disable_cache: - return - if os.path.exists(self.cache_file_path): + if not disable_cache and os.path.exists(self.cache_file_path): + # load the cache from the file with open(self.cache_file_path) as f: # we use ast.literal_eval to parse the data # because it is a safe way to parse Python literals. From a76b6fbaad1c2a3a512ff15c9e8227cafdcde969 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 17:05:45 +0800 Subject: [PATCH 08/21] use hash from compiler Signed-off-by: youkaichao --- vllm/compilation/backends.py | 38 ++++++++++++++++++-------- vllm/compilation/compiler_interface.py | 4 +-- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 354f184609468..576b5b1783ac1 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -39,13 +39,15 @@ class CompilerManager: support int as key. 
""" - def __init__(self, - cache_dir: str, - use_inductor: bool, - disable_cache: bool = False): + def __init__(self, use_inductor: bool): self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict() cls = InductorAdaptor if use_inductor else EagerAdaptor - self.compiler = cls(cache_dir=cache_dir, disable_cache=disable_cache) + self.compiler = cls() + + def compute_hash(self, vllm_config: VllmConfig) -> str: + return self.compiler.compute_hash(vllm_config) + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): self.disable_cache = disable_cache self.cache_dir = cache_dir self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py") @@ -58,6 +60,9 @@ def __init__(self, # do not use eval(), it is unsafe. self.cache = ast.literal_eval(f.read()) + self.compiler.initialize_cache(cache_dir=cache_dir, + disable_cache=disable_cache) + def save_to_file(self): if self.disable_cache: return @@ -321,6 +326,9 @@ def __init__( self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config + self.compiler_manager: CompilerManager = CompilerManager( + self.compilation_config.use_inductor) + # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -347,9 +355,11 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # the cache dir will be the same so that we can reuse the compiled # graph. + factors = [] # 1. factors come from the vllm_config (it mainly summarizes how the # model is created) config_hash = vllm_config.compute_hash() + factors.append(config_hash) # 2. factors come from the code files that are traced by Dynamo ( # it mainly summarizes how the model is used in forward pass) @@ -367,10 +377,15 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: import hashlib code_hash = hashlib.md5( "\n".join(hash_content).encode()).hexdigest() + factors.append(code_hash) + + # 3. compiler hash + compiler_hash = self.compiler_manager.compute_hash(vllm_config) + factors.append(compiler_hash) + + # combine all factors to generate the cache dir + hash_key = hashlib.md5(str(factors).encode()).hexdigest()[:10] - # combine the two hashes to generate the cache dir - hash_key = hashlib.md5( - f"{config_hash}_{code_hash}".encode()).hexdigest()[:10] cache_dir = os.path.join( envs.VLLM_CACHE_ROOT, "torch_compile_cache", @@ -385,16 +400,15 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.compilation_config.local_cache_dir = local_cache_dir disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE - self.compiler_manager: CompilerManager = CompilerManager( - local_cache_dir, - self.compilation_config.use_inductor, - disable_cache=disable_cache) + if disable_cache: logger.info("vLLM's torch.compile cache is disabled.") else: logger.info("Using cache directory: %s for vLLM's torch.compile", local_cache_dir) + self.compiler_manager.initialize_cache(local_cache_dir, disable_cache) + # when dynamo calls the backend, it means the bytecode # transform and analysis are done compilation_counter.num_graphs_seen += 1 diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 3de0ef7946731..6871ce62ae59d 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -18,7 +18,7 @@ class CompilerInterface: # This is a class-level attribute. 
name: str - def __init__(self, cache_dir: str, disable_cache: bool = False): + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): pass def compute_hash(self, vllm_config: VllmConfig) -> str: @@ -140,7 +140,7 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10] return hash_str - def __init__(self, cache_dir: str, disable_cache: bool = False): + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): if disable_cache: return # redirect the cache directory to a sub-directory From fcd4128de61fce6ce5b0c78bfdae5194c290c055 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 17:07:40 +0800 Subject: [PATCH 09/21] rename Signed-off-by: youkaichao --- tests/compile/piecewise/test_simple.py | 2 +- tests/compile/piecewise/test_toy_llama.py | 6 +++--- vllm/compilation/backends.py | 3 +-- vllm/compilation/counter.py | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index aa11524812cdd..d2d8a6fe7304f 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -91,7 +91,7 @@ def test_simple_piecewise_compile(): num_graphs_seen=1, # one graph for the model num_piecewise_graphs_seen=5, # 2 * num_layers + 1 num_piecewise_capturable_graphs_seen=3, # 1 + num_layers - num_inductor_compilations=3, # num_piecewise_capturable_graphs_seen + num_backend_compilations=3, # num_piecewise_capturable_graphs_seen num_cudagraph_caputured= 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index d4ede4d2320a7..92da74a171e7d 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -321,7 +321,7 @@ def test_toy_llama(): num_graphs_seen=0, num_piecewise_graphs_seen=0, num_piecewise_capturable_graphs_seen=0, - num_inductor_compilations=0, + num_backend_compilations=0, num_cudagraph_caputured=0, ): outputs.append(run_model(llama_config, use_compile=False)) @@ -331,7 +331,7 @@ def test_toy_llama(): num_graphs_seen=1, # one graph for the model num_piecewise_graphs_seen=1, num_piecewise_capturable_graphs_seen=1, - num_inductor_compilations=1, # num_piecewise_capturable_graphs_seen + num_backend_compilations=1, # num_piecewise_capturable_graphs_seen num_cudagraph_caputured= 2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): @@ -344,7 +344,7 @@ def test_toy_llama(): 1, # 2 * num_layers + 1 num_piecewise_capturable_graphs_seen=1 + llama_config.num_layers, # 1 + num_layers - num_inductor_compilations=1 + + num_backend_compilations=1 + llama_config.num_layers, # num_piecewise_capturable_graphs_seen num_cudagraph_caputured=2 * (1 + llama_config.num_layers diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 576b5b1783ac1..2f18d10c682ef 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -100,8 +100,7 @@ def compile(self, global compilation_start_time compilation_start_time = time.time() - # TODO: rename to num_backend_compilations - compilation_counter.num_inductor_compilations += 1 + compilation_counter.num_backend_compilations += 1 compiled_graph = None diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index 6385f1c5dbf81..932795c86477f 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -11,7 +11,7 @@ 
class CompilationCounter: num_piecewise_graphs_seen: int = 0 # not including the splitting ops num_piecewise_capturable_graphs_seen: int = 0 - num_inductor_compilations: int = 0 + num_backend_compilations: int = 0 num_cudagraph_caputured: int = 0 def clone(self) -> "CompilationCounter": From 788014ef174bc54d5b66c1d0da6bdcb6e159a522 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 17:17:34 +0800 Subject: [PATCH 10/21] clean up Signed-off-by: youkaichao --- vllm/compilation/compiler_interface.py | 43 +++++++++----------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 6871ce62ae59d..33d0a550bfec6 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -19,6 +19,11 @@ class CompilerInterface: name: str def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + """ + when the vLLM process uses `cache_dir` as the cache directory, + the compiler should initialize itself with the cache directory, + e.g. by re-directing its own cache directory to a sub-directory. + """ pass def compute_hash(self, vllm_config: VllmConfig) -> str: @@ -30,15 +35,7 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: is already considered by default. This function should only consider the information that is specific to the compiler. """ - pass - - def init_with_cache_dir(self, cache_dir: str) -> None: - """ - when the vLLM process uses `cache_dir` as the cache directory, - the compiler should initialize itself with the cache directory, - e.g. by re-directing its own cache directory to a sub-directory. - """ - pass + return "" def compile( self, @@ -67,7 +64,7 @@ def compile( handle. If the compiler fails to compile the graph, it should return None for the compiled function as well. """ - pass + return None, None def load(self, handle: Any, @@ -81,7 +78,7 @@ def load(self, The handle is the second return value of the `compile` function. """ - pass + raise NotImplementedError("caching is not supported") class AlwaysHitShapeEnv: @@ -122,9 +119,11 @@ def produce_guards_expression(self, *args, **kwargs): return "" -class InductorAdaptor(CompilerInterface): +class Inductor25Adaptor(CompilerInterface): + """ + The adaptor for the Inductor compiler, version 2.5. + """ name = "inductor" - dynamic_shape = True def compute_hash(self, vllm_config: VllmConfig) -> str: factors: List[Any] = [] @@ -288,13 +287,6 @@ def compiled_graph(*args): class EagerAdaptor(CompilerInterface): name = "eager" - dynamic_shape = True - - def compute_hash(self, vllm_config: VllmConfig) -> str: - """ - We don't need to cache the compiled model for the eager compiler, - which just runs the graph directly.""" - return "" def compile( self, @@ -307,12 +299,5 @@ def compile( # It does not support caching, return None for the handle. 
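
The interface contract documented above — `compile()` hands back `(compiled_function, handle)` and `load()` rebuilds the callable from that handle — can be seen end to end in a toy adaptor. The class below is purely illustrative and is not vLLM's actual `CompilerInterface`:

    from typing import Any, Callable, List, Optional, Tuple

    class ToyAdaptor:
        # Stand-in for a CompilerInterface implementation; not vLLM code.
        name = "toy"

        def compile(self, fn: Callable, example_inputs: List[Any]
                    ) -> Tuple[Optional[Callable], Optional[Any]]:
            # "Compilation" is a no-op here; the handle carries whatever
            # load() will need to restore the compiled function later.
            handle = {"marker": "state needed to reload"}
            return fn, handle

        def load(self, handle: Any, fn: Callable) -> Callable:
            assert handle["marker"]  # a real adaptor would consult its cache here
            return fn

    adaptor = ToyAdaptor()
    compiled, handle = adaptor.compile(lambda x: x + 1, [0])
    print(compiled(41), adaptor.load(handle, compiled)(41))  # 42 42
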
return graph, None - def load(self, - handle: Any, - graph: fx.GraphModule, - example_inputs: List[Any], - graph_index: int, - runtime_shape: Optional[int] = None) -> Callable: - # handle is None, we don't need to load anything - raise NotImplementedError( - "Eager compiler doesn't support compilation cache") + +InductorAdaptor = Inductor25Adaptor From 303276270660c18047caff8c22a79735ce7e9eda Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 20:34:47 +0800 Subject: [PATCH 11/21] fix load Signed-off-by: youkaichao --- vllm/compilation/compiler_interface.py | 32 ++++++++++++++++++-------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 33d0a550bfec6..9c6b704ece463 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock import patch +import torch import torch.fx as fx from vllm.config import VllmConfig @@ -185,13 +186,28 @@ def compile( hash_str, file_path = None, None from torch._inductor.codecache import (FxGraphCache, compiled_fx_graph_hash) - original_load = FxGraphCache.load - def hijack_load(*args, **kwargs): - inductor_compiled_graph = original_load(*args, **kwargs) - nonlocal file_path - file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa - return inductor_compiled_graph + if torch.__version__.startswith("2.5"): + original_load = FxGraphCache.load + original_load_name = "torch._inductor.codecache.FxGraphCache.load" + + def hijack_load(*args, **kwargs): + inductor_compiled_graph = original_load(*args, **kwargs) + nonlocal file_path + file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa + return inductor_compiled_graph + elif torch.__version__.startswith("2.6"): + # function renamed in 2.6 + original_load = FxGraphCache.load_with_key + original_load_name = ("torch._inductor.codecache" + ".FxGraphCache.load_with_key") + + def hijack_load(*args, **kwargs): + # it returns a tuple, we only need the first element + inductor_compiled_graph, _ = original_load(*args, **kwargs) + nonlocal file_path + file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa + return inductor_compiled_graph, _ def hijack_compiled_fx_graph_hash(*args, **kwargs): out = compiled_fx_graph_hash(*args, **kwargs) @@ -213,9 +229,7 @@ def _get_shape_env() -> AlwaysHitShapeEnv: with ExitStack() as stack: # hijack to get the compiled graph itself - stack.enter_context( - patch("torch._inductor.codecache.FxGraphCache.load", - hijack_load)) + stack.enter_context(patch(original_load_name, hijack_load)) # for hijacking the hash of the compiled graph stack.enter_context( From 2dde996ffb15aee40fd514cef2ccbd25f58f4965 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 20:35:31 +0800 Subject: [PATCH 12/21] fix load Signed-off-by: youkaichao --- vllm/compilation/compiler_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 9c6b704ece463..9031aa28de2de 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -196,7 +196,7 @@ def hijack_load(*args, **kwargs): nonlocal file_path file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa return inductor_compiled_graph - elif torch.__version__.startswith("2.6"): + elif 
torch.__version__ >= "2.6": # function renamed in 2.6 original_load = FxGraphCache.load_with_key original_load_name = ("torch._inductor.codecache" From d3139d8ba772597b7f28409e0b4f4a4e51acc91d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 20:39:17 +0800 Subject: [PATCH 13/21] add parent class Signed-off-by: youkaichao --- vllm/compilation/inductor_pass.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index f6846c08ac841..b0ba9e579ee5c 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -8,10 +8,19 @@ from torch import fx -class InductorPass(ABC): +class PlaceHolder: + pass + + +if torch.__version__ < "2.6": + parent = PlaceHolder +else: + parent = torch._inductor.custom_graph_pass.CustomGraphPass + + +class InductorPass(ABC, parent): """ General custom inductor pass interface. - TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass """ @abstractmethod From 0e945702539167703eac7e2b8e1b1f31ea4716b2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 21:00:34 +0800 Subject: [PATCH 14/21] fix compiler interface Signed-off-by: youkaichao --- vllm/compilation/compiler_interface.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 9031aa28de2de..2e2d563da5ce7 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -198,16 +198,18 @@ def hijack_load(*args, **kwargs): return inductor_compiled_graph elif torch.__version__ >= "2.6": # function renamed in 2.6 - original_load = FxGraphCache.load_with_key + original_load = FxGraphCache._save_graph original_load_name = ("torch._inductor.codecache" - ".FxGraphCache.load_with_key") + ".FxGraphCache._save_graph") def hijack_load(*args, **kwargs): - # it returns a tuple, we only need the first element - inductor_compiled_graph, _ = original_load(*args, **kwargs) + output = original_load(*args, **kwargs) nonlocal file_path + nonlocal hash_str + inductor_compiled_graph = args[1] + hash_str = args[0] file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa - return inductor_compiled_graph, _ + return output def hijack_compiled_fx_graph_hash(*args, **kwargs): out = compiled_fx_graph_hash(*args, **kwargs) @@ -253,6 +255,8 @@ def _get_shape_env() -> AlwaysHitShapeEnv: assert hash_str is not None, ( "failed to get the hash of the compiled graph") + assert file_path is not None, ( + "failed to get the file path of the compiled graph") return compiled_graph, (hash_str, file_path) def load(self, From 8a03aa921d81ceb75b496995b601820bb17f0234 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 21:02:14 +0800 Subject: [PATCH 15/21] pass manager Signed-off-by: youkaichao --- vllm/compilation/pass_manager.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 34f5f355798b2..e79fce5bb9a9c 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List +import torch from torch import fx as fx from vllm.config import CompilationConfig @@ -13,7 +14,17 @@ logger = init_logger(__name__) -class PostGradPassManager: +class PlaceHolder: + pass + + +if torch.__version__ < "2.6": + Parent = PlaceHolder +else: + Parent = 
torch._inductor.custom_graph_pass.CustomGraphPass + + +class PostGradPassManager(Parent): """ The pass manager for post-grad passes. It handles configuration, adding custom passes, and running passes. @@ -53,6 +64,9 @@ def add(self, pass_: InductorPass): assert isinstance(pass_, InductorPass) self.passes.append(pass_) + def uuid(self): + return self.__getstate__() + def __getstate__(self) -> Dict[str, List[Any]]: """ Custom pickling for the pass manager, as some passes cannot be pickled. From babedb126d121565bf2ab4af199d62c0ed84183f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 21:03:16 +0800 Subject: [PATCH 16/21] revert inductor pass Signed-off-by: youkaichao --- vllm/compilation/inductor_pass.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index b0ba9e579ee5c..f6846c08ac841 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -8,19 +8,10 @@ from torch import fx -class PlaceHolder: - pass - - -if torch.__version__ < "2.6": - parent = PlaceHolder -else: - parent = torch._inductor.custom_graph_pass.CustomGraphPass - - -class InductorPass(ABC, parent): +class InductorPass(ABC): """ General custom inductor pass interface. + TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass """ @abstractmethod From eed15a17d5b993f1d18ff95d4ee65f91af00baa0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 21:03:27 +0800 Subject: [PATCH 17/21] revert inductor pass Signed-off-by: youkaichao --- vllm/compilation/inductor_pass.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index f6846c08ac841..a548348ba2ae8 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -11,7 +11,6 @@ class InductorPass(ABC): """ General custom inductor pass interface. - TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass """ @abstractmethod From 276090b954fc52e82c6a97fe24b15e76ce790037 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 21:11:51 +0800 Subject: [PATCH 18/21] load fix Signed-off-by: youkaichao --- vllm/compilation/compiler_interface.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 2e2d563da5ce7..324dda3151231 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -273,12 +273,23 @@ def load(self, from torch._inductor.codecache import FxGraphCache with patch("torch._inductor.codecache.FxGraphCache._get_shape_env", lambda *args, **kwargs: AlwaysHitShapeEnv()): - inductor_compiled_graph = FxGraphCache._lookup_graph( - hash_str, example_inputs, True, False) - assert inductor_compiled_graph is not None, ( - "Inductor cache lookup failed. Please remove" - f"the cache directory and try again." # noqa - ) + if torch.__version__.startswith("2.5"): + inductor_compiled_graph = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, False) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache directory and try again." 
# noqa + ) + elif torch.__version__ >= "2.6": + from torch._inductor.output_code import ( + CompiledFxGraphConstantsWithGm) + constants = CompiledFxGraphConstantsWithGm(graph) + inductor_compiled_graph, _ = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, None, constants) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache directory and try again." # noqa + ) # Inductor calling convention (function signature): # f(list) -> tuple From 9128b54fdb96b7f8dd4933a503d214bb037700d3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 21:12:25 +0800 Subject: [PATCH 19/21] rename Signed-off-by: youkaichao --- vllm/compilation/compiler_interface.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 324dda3151231..910b1ff775a2c 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -120,9 +120,9 @@ def produce_guards_expression(self, *args, **kwargs): return "" -class Inductor25Adaptor(CompilerInterface): +class InductorAdaptor(CompilerInterface): """ - The adaptor for the Inductor compiler, version 2.5. + The adaptor for the Inductor compiler, version 2.5 and 2.6. """ name = "inductor" @@ -327,6 +327,3 @@ def compile( # we don't need to compile the graph, just return the graph itself. # It does not support caching, return None for the handle. return graph, None - - -InductorAdaptor = Inductor25Adaptor From fa3114468f1edab55033431d65043073a34ac2db Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 21:32:50 +0800 Subject: [PATCH 20/21] load Signed-off-by: youkaichao --- vllm/compilation/compiler_interface.py | 32 +++++++++++++++----------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 910b1ff775a2c..aa9ee2e61b6d7 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -6,6 +6,7 @@ from unittest.mock import patch import torch +import torch._inductor.compile_fx import torch.fx as fx from vllm.config import VllmConfig @@ -196,19 +197,21 @@ def hijack_load(*args, **kwargs): nonlocal file_path file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa return inductor_compiled_graph + + hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa elif torch.__version__ >= "2.6": # function renamed in 2.6 - original_load = FxGraphCache._save_graph - original_load_name = ("torch._inductor.codecache" - ".FxGraphCache._save_graph") + original_load_name = None - def hijack_load(*args, **kwargs): - output = original_load(*args, **kwargs) - nonlocal file_path + def hijacked_compile_fx_inner(*args, **kwargs): + output = torch._inductor.compile_fx.compile_fx_inner( + *args, **kwargs) nonlocal hash_str - inductor_compiled_graph = args[1] - hash_str = args[0] - file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa + inductor_compiled_graph = output + if inductor_compiled_graph is not None: + nonlocal file_path + file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa + hash_str = inductor_compiled_graph._fx_graph_cache_key return output def hijack_compiled_fx_graph_hash(*args, **kwargs): @@ -231,7 +234,8 @@ def _get_shape_env() -> AlwaysHitShapeEnv: with ExitStack() as stack: # hijack to get the compiled graph itself - 
stack.enter_context(patch(original_load_name, hijack_load)) + if original_load_name is not None: + stack.enter_context(patch(original_load_name, hijack_load)) # for hijacking the hash of the compiled graph stack.enter_context( @@ -249,9 +253,11 @@ def _get_shape_env() -> AlwaysHitShapeEnv: "torch._inductor.codecache.FxGraphCache._check_can_cache", _check_can_cache)) - compiled_graph = compile_fx(graph, - example_inputs, - config_patches=current_config) + compiled_graph = compile_fx( + graph, + example_inputs, + inner_compile=hijacked_compile_fx_inner, + config_patches=current_config) assert hash_str is not None, ( "failed to get the hash of the compiled graph") From d636fbc5ff41828f664b39844b30031d1b0596bf Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 21:45:58 +0800 Subject: [PATCH 21/21] fix types Signed-off-by: youkaichao --- vllm/compilation/pass_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index e79fce5bb9a9c..7c83d1619f2f6 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -19,9 +19,9 @@ class PlaceHolder: if torch.__version__ < "2.6": - Parent = PlaceHolder + Parent = PlaceHolder # type: ignore else: - Parent = torch._inductor.custom_graph_pass.CustomGraphPass + Parent = torch._inductor.custom_graph_pass.CustomGraphPass # type: ignore class PostGradPassManager(Parent):
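
Patch 20 stops monkey-patching a cache loader by name and instead wraps the `inner_compile` callable that `compile_fx` accepts, capturing metadata about the compiled artifact as it passes through. A self-contained sketch of that wrapper pattern follows; no torch is required to run it, and `fake_inner_compile` merely stands in for `torch._inductor.compile_fx.compile_fx_inner`.

    import hashlib
    from typing import Callable, Dict, List

    def fake_inner_compile(graph: Callable, example_inputs: List) -> Callable:
        return graph  # pretend the graph compiles to itself

    captured: Dict[str, str] = {}

    def hijacked_inner_compile(graph: Callable, example_inputs: List) -> Callable:
        compiled = fake_inner_compile(graph, example_inputs)
        if compiled is not None:
            # Mirror the patch: remember where the artifact lives and a cache key.
            captured["file_path"] = compiled.__code__.co_filename
            captured["hash_str"] = hashlib.md5(
                compiled.__code__.co_code).hexdigest()[:10]
        return compiled

    fn = hijacked_inner_compile(lambda xs: [x * 2 for x in xs], [[1, 2, 3]])
    print(fn([1, 2, 3]), captured)
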