-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchorny.py
167 lines (145 loc) · 5.32 KB
/
chorny.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import dis
from importlib.machinery import ExtensionFileLoader, SourceFileLoader
from pathlib import Path
import sys
def _patch_function(func, new_code, new_names):
func.__code__ = func.__code__.replace(co_code=new_code, co_names=new_names)
def _no_mypyc_for_black(cls, fullname, path):
if fullname.startswith("black.") or fullname == "black":
path = Path(path)
name = path.name
path = path.with_name(f"{name[:name.find('.')]}.py")
return SourceFileLoader(fullname, str(path))
return object.__new__(cls)
def patch_black():
ExtensionFileLoader.__new__ = _no_mypyc_for_black # there is no way back
import black
from black import (
_format_str_once,
format_file_contents,
Line,
patched_main,
syms,
token,
transform_line,
)
from black.nodes import is_one_sequence_between
from blib2to3.pytree import Leaf
def transform_line_patch(line, mode, features=()):
for transformed_line in transform_line(line, mode, features=features):
removed = []
leaves = transformed_line.leaves
last_index = len(leaves) - 1
for i, leaf in enumerate(leaves):
if leaf.type == token.ERRORTOKEN and leaf.value == ",":
if (i > 0 and leaves[i - 1].type == token.COMMA) or (
i < last_index and leaves[i + 1].type == token.COMMA
):
removed.append(i)
for i in removed[::-1]:
del leaves[i]
yield transformed_line
def format_file_contents_patch(src_contents, *, fast, mode):
try:
return format_file_contents(src_contents, fast=fast, mode=mode)
except AssertionError:
return format_file_contents(
_format_str_once(src_contents, mode=mode), fast=fast, mode=mode,
)
def has_magic_trailing_comma_patch(self, closing: Leaf) -> bool:
if (
(self.is_def or self.is_decorator)
and closing.opening_bracket is not None
and not is_one_sequence_between(closing.opening_bracket, closing, self.leaves)
):
return True
if closing.opening_bracket is None:
return False
# if already multiline, set multiline
parent = self.leaves[-1].parent
if parent.type not in (syms.arglist, syms.typedargslist):
return False
line = 0
distinct_lines = 0
for child in parent.children:
if child.type == token.COMMA:
if child.lineno != line:
line = child.lineno
distinct_lines += 1
if distinct_lines <= 1:
first_leaf = next(iter((parent.leaves())))
comma_leaf = self.leaves[-1]
assert comma_leaf.type == token.COMMA
if first_leaf.lineno < line or comma_leaf.column in (0, self.mode.line_length):
return False
if line == closing.lineno:
self.remove_trailing_comma()
else: # first_leaf.lineno == line
comma_leaf.type = token.ERRORTOKEN
return False
return True
black.format_file_contents = format_file_contents_patch
black.transform_line = transform_line_patch
black.linegen.transform_line = transform_line_patch
Line.has_magic_trailing_comma_patch = has_magic_trailing_comma_patch
func = Line.has_magic_trailing_comma
code = func.__code__
ops = bytearray(code.co_code)
if sys.version_info >= (3, 11):
pattern = [
dis.opmap["LOAD_FAST"],
code.co_varnames.index("self"),
dis.opmap["LOAD_ATTR"],
code.co_names.index("is_import"),
*([0] * 8),
dis.opmap["POP_JUMP_FORWARD_IF_FALSE"],
]
else:
pattern = [
dis.opmap["LOAD_FAST"],
code.co_varnames.index("self"),
dis.opmap["LOAD_ATTR"],
code.co_names.index("is_import"),
dis.opmap["POP_JUMP_IF_FALSE"],
]
pos = ops.find(bytes(pattern))
assert pos >= 0, "patch failed for has_magic_trailing_comma"
pos += len(pattern) + 5
if sys.version_info >= (3, 11):
patch = [
dis.opmap["LOAD_FAST"],
code.co_varnames.index("self"),
dis.opmap["LOAD_METHOD"],
len(code.co_names),
*([0] * 20),
dis.opmap["LOAD_FAST"],
code.co_varnames.index("closing"),
dis.opmap["PRECALL"],
1,
*([0] * 2),
dis.opmap["CALL"],
1,
*([0] * 8),
dis.opmap["RETURN_VALUE"],
0,
]
else:
patch = [
dis.opmap["LOAD_FAST"],
code.co_varnames.index("self"),
dis.opmap["LOAD_METHOD"],
len(code.co_names),
dis.opmap["LOAD_FAST"],
code.co_varnames.index("closing"),
dis.opmap["CALL_METHOD"],
1,
dis.opmap["RETURN_VALUE"],
]
ops[pos : pos + len(patch)] = patch # noqa: E203
ops = ops[: pos + len(patch)]
_patch_function(func, bytes(ops), code.co_names + ("has_magic_trailing_comma_patch",))
return patched_main
def main():
patch_black()()
if __name__ == "__main__":
main()