Initialization of lambda incorrect #5
Hi @ozppupbg, nice catch!
That chart was generated by this notebook: https://gist.github.com/proger/cd6ee302661034b7b8d4685dcad8cc3d
The paper describes: "We initialize Λ such that a^c is uniformly distributed between 0.9 and 0.999." We can then solve for Λ given that a = sigmoid(Λ). I develop in jax/flax, so apologies if this can't be checked. That being said, couldn't we define the cell like the following?

import jax
import jax.numpy as jnp
import flax.linen as nn


class RGLRUCell(nn.Module):
    feature: int
    c: float = 8.0
    gate_fn: callable = nn.sigmoid

    def setup(self):
        # k parameterizes a^c, sampled uniformly in [0.9, 0.999] as in the paper
        self.k = self.param(
            "k",
            lambda rng, s: jax.random.uniform(rng, s, minval=0.9, maxval=0.999),
            (self.feature,),
        )
        # joint projection producing the recurrence gate r and the input gate i
        self.ri = nn.Dense(self.feature * 2)

    def __call__(self, carry, x):
        (h_prev, _) = carry
        r, i = jnp.split(self.ri(x), 2, axis=-1)
        r = self.gate_fn(r)
        i = self.gate_fn(i)
        # Λ = -log(1 / k^(1/c) - 1), i.e. sigmoid(Λ) = k^(1/c)
        lam = -jnp.log((1.0 / (self.k ** (1.0 / self.c))) - 1.0)
        # a_t = sigmoid(Λ)^(c * r_t), which equals k^(r_t)
        a = nn.sigmoid(lam) ** (self.c * r)
        h_new = a * h_prev + jnp.sqrt(1 - a**2) * (i * x)
        return (h_new, h_new), h_new
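If it helps, here is a minimal sketch of how such a cell could be run over a sequence with flax's nn.scan. The batch/sequence/feature sizes and the zero initial state are made up for illustration; this only exercises the proposal above and is not code from the repository.

import jax
import jax.numpy as jnp
import flax.linen as nn

batch, seq_len, feature = 2, 16, 64
x = jnp.ones((batch, seq_len, feature))

# Share parameters across time steps and scan over axis 1 (the sequence axis).
ScannedRGLRU = nn.scan(
    RGLRUCell,
    variable_broadcast="params",
    split_rngs={"params": False},
    in_axes=1,
    out_axes=1,
)

cell = ScannedRGLRU(feature=feature)
h0 = jnp.zeros((batch, feature))
variables = cell.init(jax.random.PRNGKey(0), (h0, h0), x)
(_, _), hs = cell.apply(variables, (h0, h0), x)
print(hs.shape)  # (2, 16, 64)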
I'm trying to make sense of where this line and its values are coming from: Line 45 in 92c94f5.
Hello,
I noticed a deviation from the Griffin paper in your code.
The Griffin paper states in the second part of chapter 2.4 that Λ should be initialized such that a^c is uniformly distributed between 0.9 and 0.999, and that a = sigmoid(Λ).
So actually, the initialization for Lambda should be calculated as
Λ = -log((1 / u^(1/c)) - 1)
with u (the paper's a^c) drawn uniformly between 0.9 and 0.999. Here, Lambda is initialized as:
hippogriff/hippogriff.py, line 45 in 7bf5732,
which is neither random nor uniform: linspace is deterministic rather than random, and because sigmoid is not linear, evenly spaced Λ values do not produce a uniform distribution for a (or for a^c).
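As a quick numerical sanity check of the formula above (a standalone sketch with made-up sizes, assuming jax.numpy; not code from this repository), sampling a^c uniformly and solving for Λ round-trips correctly through the sigmoid:

import jax
import jax.numpy as jnp

c = 8.0
key = jax.random.PRNGKey(0)

# u plays the role of a^c, sampled uniformly in [0.9, 0.999].
u = jax.random.uniform(key, (4096,), minval=0.9, maxval=0.999)

# Λ = -log(1 / u^(1/c) - 1) = logit(u^(1/c)), so sigmoid(Λ) = u^(1/c).
lam = -jnp.log(1.0 / (u ** (1.0 / c)) - 1.0)

# Round trip: sigmoid(Λ)^c should recover the uniform samples.
assert jnp.allclose(jax.nn.sigmoid(lam) ** c, u, atol=1e-5)

By contrast, a sigmoid applied to a linspace over Λ gives a fixed, unevenly spaced set of values rather than uniform random samples.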