On Wed, Nov 15, 2023, bangxiangyong wrote:
The NN with an Erf output activation can occasionally output values well beyond the boundary [-1, 1]:
from jax import random
from neural_tangents import stax
import neural_tangents as nt
import numpy as np  # needed for np.array below
import random as rd
# Infinite-width network ending in an Erf readout.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(1),
    stax.Relu(),
    stax.Dense(1),
    stax.Relu(),
    stax.Dense(1),
    stax.Relu(),
    stax.Dense(1),
    stax.Erf(),
)
key1, key2 = random.split(random.PRNGKey(777))
x1 = random.normal(key1, (100, 10))
x2 = random.normal(key2, (100, 10))
x_train, x_test = x1, x2
y_train = [rd.choice([-1, 1]) for _ in range(100)]  # random ±1 labels
y_train = np.array(y_train)[:, np.newaxis]
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)
y_test_nngp = predict_fn(x_test=x_test, get="nngp")
print(y_test_nngp.max()) ## 1.6560178
print(y_test_nngp.min()) ## -2.244388
Is this intended, or have I missed something?
I don't have a good answer yet, but it appears that bad conditioning of the kernel_fn(x_train, x_train).nngp matrix (which is inverted to make predictions) is causing the numerical issues. Two ways to improve it: use higher-dimensional inputs (e.g. 1000 features instead of 10; with 10 features the input covariance is rank-10, which appears to produce a badly conditioned output covariance), and/or pass diag_reg=1e-3 (or another value) when calling gradient_descent_mse_ensemble, which adds a small diagonal matrix to kernel_fn(x_train, x_train).nngp before inversion.
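For reference, a minimal sketch of the suggested workaround, reusing kernel_fn, x_train, y_train, and x_test from the snippet above (diag_reg=1e-3 is just an example value; tune it as needed):

# Regularize the train-train NNGP kernel before inversion to improve conditioning.
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, diag_reg=1e-3
)
y_test_nngp = predict_fn(x_test=x_test, get="nngp")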