
Simple ML package that computes gradients without backpropagation


ForceBru/Matches.jl


Matches.jl

The most barebones machine learning "framework" that can compute (unbiased estimates of) gradients without backpropagation. Based on "Gradients without backpropagation" [1].
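The core estimator from [1], as summarized here: draw a random direction v from a standard normal distribution, compute the directional derivative ∇f(θ)·v with a single forward-mode (dual-number) pass, and scale v by it. Because E[v vᵀ] = I, the result is an unbiased estimate of the true gradient:

```math
g(\theta) = \big(\nabla f(\theta) \cdot v\big)\, v, \qquad v \sim \mathcal{N}(0, I)
```

```math
\mathbb{E}\big[g(\theta)\big] = \mathbb{E}\big[v v^\top\big]\, \nabla f(\theta) = I\, \nabla f(\theta) = \nabla f(\theta)
```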

NOTE: I'm not one of the authors and have no affiliation with the paper. I just found it interesting and tried to see whether it works.

The interface mimics PyTorch, but since this package is just a toy, it's not a torch, but a bunch of matches :P

Basic usage

using Matches

# 1. Get data.
n_observations, n_features = 100, 5
X, Y = cumsum(randn(n_observations, n_features), dims=1), cumsum(randn(n_observations, 1), dims=1)

# 2. Build model, PyTorch-like.
model = Sequential(
    Linear(n_features, 7), Activations.sigmoid,
    Linear(7, 5), Activations.tanh,
    Linear(5, 1)
)

# 3. Set up gradient descent optimizer.
# It can compute the "forward gradient" from the paper.
optim = Descent(params(model), 1e-4)

# 4. Train.
for epoch in 1:50_000
    # Randomize the dual parts of model parameters
    random_dual!(optim)
    loss = Losses.mse(model(X), Y)
    # Estimate the gradient and use it in gradient descent
    step!(optim, loss)

    if epoch % 1000 == 0
        @show (epoch, real(loss))
    end
end

# 5. Predict!
Y_hat = real(model(X))
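To show what steps 3 and 4 boil down to, here is a self-contained sketch in plain Julia that does not use Matches.jl. It assumes (without having checked the package source) that `random_dual!` fills each parameter's dual part with a standard-normal sample v and that `step!` moves the parameters by the learning rate times (dual part of the loss) × v. The `Dual` type and the `loss_dual` and `train` functions are hypothetical names for illustration only, and the model is plain linear regression rather than the MLP above.

```julia
using Random  # MersenneTwister, randn

# Hypothetical minimal dual number re + du*ε with ε^2 = 0 (not Matches.jl's type).
struct Dual
    re::Float64  # primal value
    du::Float64  # tangent part: carries the directional derivative
end

Base.:+(a::Dual, b::Dual) = Dual(a.re + b.re, a.du + b.du)
Base.:-(a::Dual, b::Dual) = Dual(a.re - b.re, a.du - b.du)
Base.:*(a::Dual, b::Dual) = Dual(a.re * b.re, a.re * b.du + a.du * b.re)
Base.:*(a::Float64, b::Dual) = Dual(a * b.re, a * b.du)

# Mean squared error of the linear model X*w, evaluated on dual weights w + v*ε.
# The result's `re` is the loss and its `du` is the directional derivative ∇f(w)⋅v.
function loss_dual(w::Vector{Float64}, v::Vector{Float64},
                   X::Matrix{Float64}, y::Vector{Float64})
    n, d = size(X)
    total = Dual(0.0, 0.0)
    for i in 1:n
        pred = Dual(0.0, 0.0)
        for j in 1:d
            pred = pred + X[i, j] * Dual(w[j], v[j])  # seed dual part with v[j]
        end
        r = pred - Dual(y[i], 0.0)
        total = total + r * r
    end
    return Dual(total.re / n, total.du / n)
end

# Gradient descent driven by forward gradients: one forward pass per epoch, no backward pass.
function train(X::Matrix{Float64}, y::Vector{Float64};
               lr = 0.02, epochs = 5_000, seed = 1)
    rng = MersenneTwister(seed)
    w = zeros(size(X, 2))
    for _ in 1:epochs
        v = randn(rng, length(w))    # random direction (the "random dual" step)
        l = loss_dual(w, v, X, y)    # forward pass gives loss and ∇f(w)⋅v
        g = l.du .* v                # forward gradient (∇f⋅v) v, unbiased for ∇f
        w .-= lr .* g                # ordinary descent step
    end
    return w
end
```

For example, with `X = randn(100, 3)` and `y` generated exactly from weights `[2.0, -1.0, 0.5]`, `train(X, y)` should recover those weights closely. Each step only sees the gradient's projection onto one random direction, so the updates are noisier than with exact gradients, and a small learning rate with many epochs may be needed.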

Also see example.jl. To download and run it:

$ wget https://raw.githubusercontent.com/ForceBru/Matches.jl/master/example.jl
$ julia example.jl

References

  1. Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022). Gradients without Backpropagation (Version 1). arXiv. https://doi.org/10.48550/ARXIV.2202.08587
