Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow deleting rules in plugins to override behavior #21318

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions docs/docs/writing-plugins/the-rules-api/delete-rules-advanced.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
---
title: Delete rules (advanced)
sidebar_position: 8
---

Delete builtin rules and define your own to override behavior.

---

Sometimes you might want to change the behavior of some existing backend in a way that is not supported by the backend. `DeleteRule` allows you to remove a single rule from the backend, then you can implement your own version of the same rule.

For example, you might want to change the behavior of the `list` goal, here is how you do it:


```python title="pants-plugins/custom_list/register.py"
from pants.backend.project_info.list_targets import List, ListSubsystem
from pants.backend.project_info.list_targets import list_targets as original_rule
from pants.engine.addresses import Addresses
from pants.engine.console import Console
from pants.engine.rules import DeleteRule, collect_rules, goal_rule


@goal_rule
async def list_targets(addresses: Addresses, list_subsystem: ListSubsystem, console: Console) -> List:
with list_subsystem.line_oriented(console) as print_stdout:
print_stdout("ha cool")
return List(exit_code=0)


def target_types():
return []


def rules():
return (
*collect_rules(),
DeleteRule.create(original_rule),
)
```

```python title="pants.toml"
[GLOBAL]
...
backend_packages = [
...
"custom_list",
]
```

```python title="pants-plugins/custom_list/__init__.py"
```
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
title: Logging and dynamic output
sidebar_position: 8
sidebar_position: 9
---

How to add logging and influence the dynamic UI.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
title: Testing plugins
sidebar_position: 9
sidebar_position: 10
---

How to verify your plugin works.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
title: Tips and debugging
sidebar_position: 10
sidebar_position: 11
---

---
Expand Down
1 change: 1 addition & 0 deletions docs/notes/2.24.x.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Pants will now warn if any errors are encountered while fingerprinting candidate

### Plugin API changes

Plugins may now use `DeleteRule` to delete rules from other backends to override behavior.


## Full Changelog
Expand Down
19 changes: 13 additions & 6 deletions src/python/pants/build_graph/build_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from collections.abc import Iterable
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, DefaultDict
from typing import Any, Callable, DefaultDict, FrozenSet

from pants.backend.project_info.filter_targets import FilterSubsystem
from pants.build_graph.build_file_aliases import BuildFileAliases
from pants.core.util_rules.environments import EnvironmentsSubsystem
from pants.engine.goal import GoalSubsystem
from pants.engine.rules import Rule, RuleIndex
from pants.engine.rules import DeleteRule, Rule, RuleIndex
from pants.engine.target import Target
from pants.engine.unions import UnionRule
from pants.option.alias import CliOptions
Expand Down Expand Up @@ -60,6 +60,7 @@ class BuildConfiguration:
subsystem_to_providers: FrozenDict[type[Subsystem], tuple[str, ...]]
target_type_to_providers: FrozenDict[type[Target], tuple[str, ...]]
rule_to_providers: FrozenDict[Rule, tuple[str, ...]]
delete_rules: FrozenSet[DeleteRule]
union_rule_to_providers: FrozenDict[UnionRule, tuple[str, ...]]
allow_unknown_options: bool
remote_auth_plugin_func: Callable | None
Expand Down Expand Up @@ -134,6 +135,7 @@ class Builder:
default_factory=lambda: defaultdict(list)
)
_rule_to_providers: dict[Rule, list[str]] = field(default_factory=lambda: defaultdict(list))
_delete_rules: set[DeleteRule] = field(default_factory=set)
_union_rule_to_providers: dict[UnionRule, list[str]] = field(
default_factory=lambda: defaultdict(list)
)
Expand Down Expand Up @@ -189,9 +191,9 @@ def _register_exposed_context_aware_object_factory(
"Overwriting!".format(alias)
)

self._exposed_context_aware_object_factory_by_alias[
alias
] = context_aware_object_factory
self._exposed_context_aware_object_factory_by_alias[alias] = (
context_aware_object_factory
)

def register_subsystems(
self, plugin_or_backend: str, subsystems: Iterable[type[Subsystem]]
Expand All @@ -215,7 +217,9 @@ def register_subsystems(
for subsystem in subsystems:
self._subsystem_to_providers[subsystem].append(plugin_or_backend)

def register_rules(self, plugin_or_backend: str, rules: Iterable[Rule | UnionRule]):
def register_rules(
self, plugin_or_backend: str, rules: Iterable[Rule | UnionRule | DeleteRule]
):
"""Registers the given rules."""
if not isinstance(rules, Iterable):
raise TypeError(f"The rules must be an iterable, given {rules!r}")
Expand All @@ -225,6 +229,8 @@ def register_rules(self, plugin_or_backend: str, rules: Iterable[Rule | UnionRul
rules_and_queries: tuple[Rule, ...] = (*rule_index.rules, *rule_index.queries)
for rule in rules_and_queries:
self._rule_to_providers[rule].append(plugin_or_backend)
for delete_rule in rule_index.delete_rules:
self._delete_rules.add(delete_rule)
for union_rule in rule_index.union_rules:
self._union_rule_to_providers[union_rule].append(plugin_or_backend)
self.register_subsystems(
Expand Down Expand Up @@ -311,6 +317,7 @@ def create(self) -> BuildConfiguration:
rule_to_providers=FrozenDict(
(k, tuple(v)) for k, v in self._rule_to_providers.items()
),
delete_rules=frozenset(self._delete_rules),
union_rule_to_providers=FrozenDict(
(k, tuple(v)) for k, v in self._union_rule_to_providers.items()
),
Expand Down
117 changes: 117 additions & 0 deletions src/python/pants/engine/delete_rule_integration_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2024 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
from dataclasses import dataclass

from pants.engine.rules import DeleteRule, collect_rules, rule
from pants.testutil.rule_runner import QueryRule, RuleRunner
from pants.engine.rules import Get
import pytest


@dataclass(frozen=True)
class IntRequest:
pass


@rule
async def original_rule(request: IntRequest) -> int:
return 0


@rule
def new_rule(request: IntRequest) -> int:
return 42


@dataclass(frozen=True)
class WrapperUsingCallByTypeRequest:
pass


@rule
async def wrapper_using_call_by_type(request: WrapperUsingCallByTypeRequest) -> int:
return await Get(int, IntRequest())


@dataclass(frozen=True)
class WrapperUsingCallByNameRequest:
pass


@rule
async def wrapper_using_call_by_name(request: WrapperUsingCallByNameRequest) -> int:
return await original_rule(IntRequest())


def test_delete_call_by_type() -> None:
rule_runner = RuleRunner(
target_types=[],
rules=[
*collect_rules(
{
"original_rule": original_rule,
"wrapper_using_call_by_type": wrapper_using_call_by_type,
}
),
QueryRule(int, [WrapperUsingCallByTypeRequest]),
],
)

result = rule_runner.request(int, [WrapperUsingCallByTypeRequest()])
assert result == 0

rule_runner = RuleRunner(
target_types=[],
rules=[
*collect_rules(
{
"original_rule": original_rule,
"wrapper_using_call_by_type": wrapper_using_call_by_type,
"new_rule": new_rule,
}
),
DeleteRule.create(original_rule),
QueryRule(int, [WrapperUsingCallByTypeRequest]),
],
)

result = rule_runner.request(int, [WrapperUsingCallByTypeRequest()])
assert result == 42

assert 0


def test_delete_call_by_name() -> None:
# rule_runner = RuleRunner(
# target_types=[],
# rules=[
# *collect_rules(
# {
# "original_rule": original_rule,
# "wrapper_using_call_by_name": wrapper_using_call_by_name,
# }
# ),
# QueryRule(int, [WrapperUsingCallByNameRequest]),
# ],
# )

# result = rule_runner.request(int, [WrapperUsingCallByNameRequest()])
# assert result == 0

rule_runner = RuleRunner(
target_types=[],
rules=[
*collect_rules(
{
"original_rule": original_rule,
"wrapper_using_call_by_name": wrapper_using_call_by_name,
"new_rule": new_rule,
}
),
DeleteRule.create(original_rule),
QueryRule(int, [WrapperUsingCallByNameRequest]),
],
)

result = rule_runner.request(int, [WrapperUsingCallByNameRequest()])
assert result == 42
11 changes: 11 additions & 0 deletions src/python/pants/engine/internals/rule_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,13 @@ def _get_legacy_awaitable(self, call_node: ast.Call, is_effect: bool) -> Awaitab
is_effect,
)

def _get_input_types(self, input_nodes: Sequence[Any]) -> tuple[type, ...]:
input_nodes, input_type_nodes = self._get_inputs(input_nodes)
return tuple(
self._check_constraint_arg_type(input_type, input_node)
for input_type, input_node in zip(input_type_nodes, input_nodes)
)

def _get_byname_awaitable(
self, rule_id: str, rule_func: Callable, call_node: ast.Call
) -> AwaitableConstraints:
Expand All @@ -265,6 +272,9 @@ def _get_byname_awaitable(
# argument names of kwargs. But positional-only callsites can avoid those allocations.
explicit_args_arity = len(call_node.args)

# if "delete_rule_integration_test" in rule_id.split("."):
sys.stderr.write(f"{call_node=}\n")

input_types: tuple[type, ...]
if not call_node.keywords:
input_types = ()
Expand Down Expand Up @@ -295,6 +305,7 @@ def _get_byname_awaitable(
# TODO: Extract this from the callee? Currently only intrinsics can be Effects, so need
# to figure out their new syntax first.
is_effect=False,
rule_func=rule_func,
)

def visit_Call(self, call_node: ast.Call) -> None:
Expand Down
39 changes: 34 additions & 5 deletions src/python/pants/engine/internals/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from __future__ import annotations

import logging
import sys
import os
import time
from dataclasses import dataclass
from pathlib import PurePath
from types import CoroutineType
from typing import Any, Callable, Dict, Iterable, NoReturn, Sequence, cast
from typing import Any, Callable, Dict, Iterable, NoReturn, Sequence, cast, get_type_hints

from typing_extensions import TypedDict

Expand Down Expand Up @@ -66,7 +67,7 @@
Process,
ProcessResultMetadata,
)
from pants.engine.rules import Rule, RuleIndex, TaskRule
from pants.engine.rules import DeleteRule, RuleIndex, TaskRule
from pants.engine.unions import UnionMembership, is_union, union_in_scope_types
from pants.option.global_options import (
LOCAL_STORE_LEASE_TIME_SECS,
Expand Down Expand Up @@ -119,7 +120,7 @@ def __init__(
local_execution_root_dir: str,
named_caches_dir: str,
ca_certs_path: str | None,
rules: Iterable[Rule],
rule_index: RuleIndex,
union_membership: UnionMembership,
execution_options: ExecutionOptions,
local_store_options: LocalStoreOptions,
Expand Down Expand Up @@ -150,7 +151,6 @@ def __init__(
self._visualize_to_dir = visualize_to_dir
self._visualize_run_count = 0
# Validate and register all provided and intrinsic tasks.
rule_index = RuleIndex.create(rules)
tasks = register_rules(rule_index, union_membership)

# Create the native Scheduler and Session.
Expand Down Expand Up @@ -685,6 +685,12 @@ def register_task(rule: TaskRule) -> None:
)

for awaitable in rule.awaitables:
if "delete_rule_integration_test" in task_rule.canonical_name.split("."):
sys.stderr.write(f"{awaitable=}\n")
sys.stderr.write(f" {awaitable.input_types=}\n")
sys.stderr.write(f" {awaitable.output_type=}\n")
sys.stderr.write(f" {awaitable.explicit_args_arity=}\n")

unions = [t for t in awaitable.input_types if is_union(t)]
if len(unions) == 1:
# Register the union by recording a copy of the Get for each union member.
Expand All @@ -702,6 +708,20 @@ def register_task(rule: TaskRule) -> None:
raise TypeError(
"Only one @union may be used in a Get, but {awaitable} used: {unions}."
)
elif (
awaitable.rule_id is not None
and DeleteRule(awaitable.rule_id) in rule_index.delete_rules
):
# This is a call to a known rule, but some plugin has deleted
# it, so it wants to override it with some other rule. We have
# to call it by type to make it possible.
awaitable.rule_id = None
native_engine.tasks_add_get(
tasks,
awaitable.output_type,
[v for k, v in get_type_hints(awaitable.rule_func).items() if k != "return"],
)
sys.stderr.write(f"add_get for deleted {awaitable=}\n")
elif awaitable.rule_id is not None:
# Is a call to a known rule.
native_engine.tasks_add_call(
Expand All @@ -718,7 +738,16 @@ def register_task(rule: TaskRule) -> None:
native_engine.tasks_task_end(tasks)

for task_rule in rule_index.rules:
register_task(task_rule)
do_register = DeleteRule(rule_id=task_rule.canonical_name) not in rule_index.delete_rules
if "delete_rule_integration_test" in task_rule.canonical_name.split("."):
sys.stderr.write(
f"register {task_rule.canonical_name=}, {task_rule.func=}: {do_register}\n"
)
for awaitable in task_rule.awaitables:
sys.stderr.write(f" {awaitable=}\n")
sys.stderr.write(f" {awaitable.rule_id=}\n")
if do_register:
register_task(task_rule)
for query in rule_index.queries:
native_engine.tasks_add_query(
tasks,
Expand Down
Loading