Skip to content

Commit

Permalink
Conditionally use prologue in test_vjp_correctness (#1438)
Browse files Browse the repository at this point in the history
  • Loading branch information
beverlylytle authored Nov 18, 2024
1 parent b21378c commit d60f85c
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions thunder/tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _dot(x, y):
return sum([_tensor_dot(a, b) for a, b in zip(x, y)])


def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = False):
def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = False, prologue_required: bool = False):
"""Check that the vector-Jacobian product of a function is correct.
Args:
Expand Down Expand Up @@ -296,15 +296,21 @@ def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = Fals

u = tree_map(make, primals)

comp_f = thunder.jit(f, disable_torch_autograd=True)
# dirty little trick for speed: skip the prologue, however, the prologue is required when
# there are non-differentiable kwargs
jf = executor.make_callable(f, disable_torch_autograd=True)
if prologue_required:
comp_f = thunder.jit(f, disable_torch_autograd=True)
else:
comp_f = thunder.compile_data(jf).get_computation_and_inputs(*primals)[0].computation_fn

outs_p, J_u = numerical_jvp(comp_f)(primals, u)

multiple_results = isinstance(outs_p, Sequence)

v = tree_map(make, outs_p)
if set_compile_data:
with thunder.core.compile_data.compile_data_and_stats(thunder.compile_data(comp_f), None):
with thunder.core.compile_data.compile_data_and_stats(thunder.compile_data(jf), None):
initial_trace_vjp_f = thunder.trace()(vjp(f), primals, v)
else:
initial_trace_vjp_f = thunder.trace()(vjp(f), primals, v)
Expand Down Expand Up @@ -364,8 +370,15 @@ def wrapper(*differentiable_args):
return wrapper, filtered_args


def snippet_vjp_correctness(func, args, comp, executor, set_compile_data):
check_vjp(func, *args, comp=comp, executor=executor, set_compile_data=set_compile_data)
def snippet_vjp_correctness(func, args, comp, executor, set_compile_data, prologue_required):
check_vjp(
func,
*args,
comp=comp,
executor=executor,
set_compile_data=set_compile_data,
prologue_required=prologue_required,
)


# TODO Use the given comparator
Expand Down Expand Up @@ -408,6 +421,7 @@ def test_vjp_correctness(op, device, dtype, executor, comp):
comp,
executor,
"adaptive_avg_pool2d" in op.name,
len(sample.kwargs) != 0,
)
if result is not None:
return result
Expand Down

0 comments on commit d60f85c

Please sign in to comment.