From 5fd4a9152a198faea67f0f2d5607a59a8be9b73c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 30 Aug 2022 11:45:37 -0400 Subject: [PATCH] add within_gradient --- Project.toml | 2 +- docs/src/reference.md | 1 + src/utils.jl | 15 +++++++++++++++ test/utils.jl | 5 +++++ 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2b8688353..ab97e761b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "NNlib" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.8.9" +version = "0.8.10" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/docs/src/reference.md b/docs/src/reference.md index 69dc7ed5c..ae849b47f 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -130,4 +130,5 @@ ctc_loss ```@docs logsumexp NNlib.glu +NNlib.within_gradient ``` diff --git a/src/utils.jl b/src/utils.jl index cd1b9f03b..1bdd5e56a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,18 @@ +""" + within_gradient(x) --> Bool + +Returns `false` except when used inside a `gradient` call, when it returns `true`. +Useful for Flux regularisation layers which behave differently during training and inference. + +Works with any ChainRules-based differentiation package, in which case `x` is ignored. +But Tracker.jl overloads `within_gradient(x::TrackedArray)`, thus for widest use you should +pass it an array whose gradient is of interest. +""" +within_gradient(x) = false + +ChainRulesCore.rrule(::typeof(within_gradient), x) = true, _ -> (NoTangent(), NoTangent()) + + """ safe_div(x, y) diff --git a/test/utils.jl b/test/utils.jl index a5264dc5a..dcc882c3b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,3 +1,8 @@ +@testset "within_gradient" begin + @test NNlib.within_gradient([1.0]) === false + @test gradient(x -> NNlib.within_gradient(x) * x, 2.0) == (1.0,) +end + @testset "maximum_dims" begin ind1 = [1,2,3,4,5,6] @test NNlib.maximum_dims(ind1) == (6,)