Skip to content

Commit

Permalink
fix a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Apr 18, 2020
1 parent 327e3a6 commit 0819179
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
8 changes: 5 additions & 3 deletions src/autodiff/instructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,18 @@ end
end

@i @inline function (-)(out!::GVar, x::GVar, y::GVar)
value(out!) -= value(x) + value(y)
value(out!) -= value(x) - value(y)
grad(x) += identity(grad(out!))
grad(y) -= identity(grad(out!))
end

@i @inline function (-)(out!::GVar, x::Real, y::GVar)
value(out!) -= value(x) + value(y)
value(out!) -= value(x) - value(y)
grad(y) -= identity(grad(out!))
end

@i @inline function (-)(out!::GVar, x::GVar, y::Real)
value(out!) -= value(x) + value(y)
value(out!) -= value(x) - value(y)
grad(x) += identity(grad(out!))
end

Expand Down Expand Up @@ -200,6 +200,8 @@ end
grad(x) += grad(out!) * x2
~@routine
end
@nograd (abs2)(a!::GVar, b::Real)
@nograd (abs2)(a!::Real, b::GVar)

for op in [:*, :/, :^, :+, :-]
@eval @nograd ($op)(out!::GVar, x::Real, y::Real)
Expand Down
6 changes: 3 additions & 3 deletions src/autodiff/vars.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ Attach a gradient field to `x`.
end
@i function GVar(x::Function)
end
@i function GVar(x::AbstractArray)
GVar.(x)
end
@i function GVar(x::Tuple)
GVar.(x)
end
Expand All @@ -47,6 +44,9 @@ end
GVar(x::Complex) = Complex(GVar(x.re), GVar(x.im))
GVar(x::Complex, y::Complex) = Complex(GVar(x.re, y.re), GVar(x.im, y.im))
(_::Inv{GVar})(x::Complex) = Complex((~GVar)(x.re), (~GVar)(x.im))
GVar(x::AbstractArray) = GVar.(x)
GVar(x::AbstractArray, y::AbstractArray) = GVar.(x, y)
(_::Inv{GVar})(x::AbstractArray) = (~GVar).(x)

Base.copy(b::GVar) = GVar(b.x, copy(b.g))
Base.zero(x::GVar) = GVar(Base.zero(x.x), Base.zero(x.g))
Expand Down
2 changes: 1 addition & 1 deletion test/autodiff/vars.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ end
@testset "assign tuple" begin
x = 0.3
z = Loss(0.3)
@instr GVar.((x,))
@instr for i=1:length(x) GVar(x) end
@test x === GVar(0.3)
end

Expand Down

0 comments on commit 0819179

Please sign in to comment.