Initialization of lambda incorrect #5

Open
ozppupbg opened this issue Mar 21, 2024 · 2 comments

Comments

@ozppupbg

Hello,

I noticed a deviation from the Griffin paper in your code.

The Griffin paper states in the second part of chapter 2.4:

We initialize Λ such that a^c is uniformly distributed between 0.9 and 0.999 at the start of training,

and a = sigmoid(Λ).

So the initialization for Λ should actually be calculated as Λ = -log((1 / u^(1/c)) - 1), with u sampled uniformly between 0.9 and 0.999 (u being the target value of a^c).

Here, Lambda is initialized as:

self.forget_base = nn.Parameter(torch.linspace(-4.323, -9, hidden))

which is neither random nor uniform, as linspace is not random and the sigmoid operation is not linear.
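The initialization described above could be sketched as follows. This is a minimal, standard-library-only illustration rather than the repository's code: the helper name `init_lambda`, the seed, and `hidden=1024` are hypothetical, and `c = 8` is assumed (the value that appears later in this thread).

```python
import math
import random


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def init_lambda(hidden: int, c: float = 8.0, lo: float = 0.9, hi: float = 0.999,
                seed: int = 0) -> list[float]:
    """Sample Lambda so that sigmoid(Lambda) ** c is uniform in [lo, hi]."""
    rng = random.Random(seed)
    lambdas = []
    for _ in range(hidden):
        u = rng.uniform(lo, hi)   # target value of a ** c
        a = u ** (1.0 / c)        # a = u^(1/c)
        # logit(a) is the inverse of sigmoid, i.e. -log(1 / a - 1)
        lambdas.append(-math.log(1.0 / a - 1.0))
    return lambdas


lambdas = init_lambda(hidden=1024)           # hidden=1024 is illustrative
decays = [sigmoid(lam) ** 8.0 for lam in lambdas]
print(min(decays), max(decays))              # both should lie inside [0.9, 0.999]
```

Mapping the samples back through `sigmoid(Λ) ** c` recovers the uniform draws, which is the property the quoted sentence from the paper asks for.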

@proger
Owner

proger commented Mar 22, 2024

Hi @ozppupbg, nice catch!

forget_base is initialized so that values of (-alpha_log_scale.exp() * softplus(forget_base)).exp() are in the range of 0.9...0.999 but are exponentially biased towards 1. This makes the activation distribution more similar to Mamba's forget gates. I'll be adding sweeps to ablate this decision later.

image

That chart was generated by this notebook: https://gist.github.com/proger/cd6ee302661034b7b8d4685dcad8cc3d
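To see the range and the bias described above without opening the notebook, the expression can be evaluated along the same `linspace` grid. This is a rough sketch under two assumptions that are not confirmed by the thread: `alpha_log_scale.exp()` equals 8, and `hidden` is 1024. It uses only the standard library, so `linspace` and `softplus` here are small stand-ins for the PyTorch functions.

```python
import math


def softplus(x: float) -> float:
    return math.log1p(math.exp(x))


def linspace(start: float, stop: float, n: int) -> list[float]:
    """Evenly spaced values from start to stop inclusive (stand-in for torch.linspace)."""
    step = (stop - start) / (n - 1)
    return [start + i * step for i in range(n)]


c = 8.0                                      # assumption: alpha_log_scale.exp() == 8
forget_base = linspace(-4.323, -9.0, 1024)   # assumption: hidden == 1024
decay = [math.exp(-c * softplus(x)) for x in forget_base]

median = sorted(decay)[len(decay) // 2]
print(f"first={decay[0]:.4f} last={decay[-1]:.4f} median={median:.4f}")
```

Under these assumptions the first and last values land close to 0.9 and 0.999, while the median sits well above the midpoint of that interval (0.9495), which is the skew towards 1 mentioned in the comment.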

@BeeGass

BeeGass commented Mar 31, 2024

The paper states: "We initialize Λ such that $a^{c}$ is uniformly distributed between 0.9 and 0.999". Let $a^{c} = k$, where $k$ is uniformly distributed between 0.9 and 0.999. Given this, we solve for $a$:

$$ \begin{align} a^{c} &= k \\ \log(a^{c}) &= \log(k) \\ c \cdot \log(a) &= \log(k) \\ \frac{c \cdot \log(a)}{c} &= \frac{\log(k)}{c} \\ \log(a) &= \frac{\log(k)}{c} \\ \log(a) &= \frac{1}{c} \log(k) \\ \log(a) &= \log(k^{\frac{1}{c}}) \\ a &= k^{\frac{1}{c}} \end{align} $$

We then solve for $\Lambda$ using this redefined $a$, writing $a = \sigma(-\Lambda)$.

$$ \begin{align} a &= \sigma (-\Lambda) \\ a &= \frac{1}{(1 + e^{-(-\Lambda)})} \\ a &= \frac{1}{(1 + e^{(\Lambda)})} \\ a (1 + e^{-(-\Lambda)}) &= \frac{(1 + e^{(\Lambda)})}{(1 + e^{(\Lambda)})} \\ a (1 + e^{(\Lambda)}) &= 1 \\ \frac{a (1 + e^{(\Lambda)})}{a} &= \frac{1}{a} \\ 1 + e^{(\Lambda)} &= \frac{1}{a} \\ 1 + e^{(\Lambda)} - 1 &= \frac{1}{a} - 1 \\ e^{(\Lambda)} &= \frac{1}{a} - 1 \\ \ln{e^{(\Lambda)}} &= \ln{(\frac{1}{a} - 1)} \\ \Lambda &= \ln{(\frac{1}{a} - 1)} \\ \Lambda &= \ln{(\frac{1}{k^{\frac{1}{c}}} - 1)} \\ \end{align} $$

Given that we can perform $a^{c \cdot r} \rightarrow k^{\frac{1}{c \cdot r}}$, we can say:

$$ \Lambda = \ln{(\frac{1}{k^{\frac{1}{c \cdot r}}} - 1)} $$
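The closed form for $\Lambda$ derived above (the version without $r$, under the convention $a = \sigma(-\Lambda)$) can be checked numerically. The sketch below is only an illustration: the function names are hypothetical and `c = 8` is assumed.

```python
import math


def lambda_from_k(k: float, c: float = 8.0) -> float:
    """Lambda = ln(1 / k^(1/c) - 1), under the convention a = sigmoid(-Lambda)."""
    a = k ** (1.0 / c)
    return math.log(1.0 / a - 1.0)


def decay_from_lambda(lam: float, c: float = 8.0) -> float:
    """Recover a^c, where a = sigmoid(-Lambda) = 1 / (1 + e^Lambda)."""
    a = 1.0 / (1.0 + math.exp(lam))
    return a ** c


for k in (0.9, 0.95, 0.999):
    lam = lambda_from_k(k)
    print(f"k={k}: Lambda={lam:.3f}, a^c={decay_from_lambda(lam):.4f}")
```

With `c = 8`, `k = 0.9` maps to a $\Lambda$ of about -4.32 and `k = 0.999` to about -8.99, which is close to the endpoints of the `torch.linspace(-4.323, -9, hidden)` call quoted in this thread.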

I develop in jax/flax, so apologies if this can't be checked. That being said, couldn't we define the cell like the following:

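from typing import Callable

import flax.linen as nn
import jax
import jax.numpy as jnp


class RGLRUCell(nn.Module):

    feature: int
    c: float = 8.0
    gate_fn: Callable = nn.sigmoid

    def setup(self):
        # k is drawn uniformly from [0.9, 0.999], following the paper's initialization.
        self.k = self.param(
            "k",
            lambda rng, s: jax.random.uniform(rng, s, minval=0.9, maxval=0.999),
            (self.feature,),
        )
        # One projection produces both the recurrence gate r and the input gate i.
        self.ri = nn.Dense(self.feature * 2)

    def __call__(self, carry, x):
        (h_prev, _) = carry
        r, i = jnp.split(self.ri(x), 2, axis=-1)
        r = self.gate_fn(r)
        i = self.gate_fn(i)

        # Lambda as derived above; the gate is a = sigmoid(-Lambda), not exp(Lambda).
        lam = jnp.log((1 / (self.k ** (1 / (self.c * r)))) - 1)
        a = nn.sigmoid(-lam)

        h_new = a * h_prev + jnp.sqrt(1 - a**2) * (i * x)

        return (h_new, h_new), h_new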

I'm trying to make sense of where this line and its values come from.

self.forget_base = nn.Parameter(torch.linspace(-4.323, -9, hidden))
