This is an experimental project where I've replaced the exponential time decay in the RWKV language model with polynomial decay. The goal of the project is to improve the model's long-term memory.
RWKV is an RNN language model with Transformer-level performance. Recently, there has been a large focus on training the model with longer context lengths, as the model is not very effective at generalizing to context lengths beyond those used during training. However, training with very long context lengths is difficult and expensive, so there is also a need for other ways to improve the model's ability to utilize long-term memory.
RWKV's memory consists of several channels, each with a specific decay factor which specifies how much emphasis the channel should put on recent data compared to older data. For a channel with decay factor $w < 0$, the attention given to a data point that is $t$ steps old is proportional to $e^{wt}$.
I believe that the main reason why RWKV struggles with generalizing to larger context lengths is that the attention decays exponentially, because this means that unless the decay factor is extremely close to zero, the attention given to data points more than a few hundred steps old becomes vanishingly small.
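To get a feel for how quickly exponential attention vanishes, here is a small illustration; the decay factors are arbitrary example values rather than values taken from a trained model.

```python
import math

# Attention weight e^{w*t} on a data point that is t steps old, for two
# example decay factors. Even with a fairly slow decay, data points a few
# thousand steps back receive essentially no attention.
for w in (-0.1, -0.01):
    for t in (10, 100, 1_000, 10_000):
        print(f"w = {w:5}, age = {t:>6}: weight = {math.exp(w * t):.2e}")
```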
Bo Peng, the main author of RWKV, has previously suggested modifying the decay schedule by adding a constant term, i.e. using a weight of the form $e^{wt} + c$ so that the attention never decays all the way to zero.
My proposal is to replace the exponential decay, $e^{wt}$, with a polynomial decay, $(1+t)^{w}$, which falls off far more slowly for old data points.
Let us take a look at how the attention weight is distributed as the context length tends to infinity when using a polynomial decay instead of an exponential one. The behavior depends on the value of $w$:
- $w \geq -1$: In this case the attention given to the most recent, say, $1000$ data points will tend to $0$ as the context width tends to infinity. This is because $\int_{1}^{\infty} x^w \,\mathrm{d}x = \infty$, so the bulk of the weight will be on old data points.
- $w < -1$: In this case the attention given to the most recent $1000$ data points will not tend to $0$, because $\int_{1}^{\infty} x^w \,\mathrm{d}x$ converges. This means that new data will still be able to meaningfully affect the channel's value regardless of how large the context length is. However, the contribution from old data doesn't become negligible either, because the weight only decays polynomially, so if the model saw some important piece of information long ago it will still be capable of remembering it.
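A quick numerical check of the two regimes, using the weight $(1+t)^w$ for a data point of age $t$; the values of $w$ and the context lengths are arbitrary examples.

```python
# Share of the total attention that the most recent 1000 data points receive
# under polynomial decay (1 + t)^w, as the context length grows.
for w in (-0.5, -2.0):
    for ctx in (10_000, 100_000, 1_000_000):
        weights = [(1 + t) ** w for t in range(ctx)]
        share = sum(weights[:1000]) / sum(weights)
        print(f"w = {w:4}, ctx = {ctx:>9,}: share on newest 1000 = {share:.3f}")
```

For $w = -0.5$ the share of attention on recent data keeps shrinking as the context grows, while for $w = -2$ it stays close to $1$.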
In essence, I think polynomial decay strikes a good balance between attending to data points of different ages. With polynomial decay, the total attention given to data of ages in the range $[T, 2T]$ decays only gradually as $T$ grows, and for $w = -1$ it is exactly the same for every $T$ (see the short calculation after this list). For example, when the model is predicting the next word in a book, it should pay attention to
- the preceding words in the sentence,
- the preceding sentences in the chapter, and
- the preceding chapters in the book,
and the total attention given to each of these should be in the same ballpark.
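As a concrete instance, take $w = -1$ (purely as an example value) and approximate the sum over ages by an integral:

$$\sum_{t=T}^{2T} (1+t)^{-1} \;\approx\; \int_{T}^{2T} \frac{\mathrm{d}x}{x} \;=\; \ln 2 ,$$

independently of $T$, so every doubling of the age range receives about the same total attention. Under exponential decay the corresponding total instead shrinks geometrically as $T$ grows.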
The problem with polynomial decay is that it is more challenging to compute, as it doesn't suit the recursive structure of RWKV as well as the exponential decay does. The exponential attention that is currently used in RWKV can be computed easily in a recursive manner. Suppose we want to compute the weighted sums $S_t = \sum_{i \leq t} e^{w (t - i)} x_i$: they satisfy the recursion $S_t = e^{w} S_{t-1} + x_t$, so each new data point only costs a constant amount of work and a constant amount of state.
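A minimal sketch of this recursion in isolation; the real RWKV kernel also keeps a normalizing denominator and gives a separate bonus weight to the current token, both of which are omitted here.

```python
import math

def exponential_weighted_sums(xs, w):
    """Compute S_t = sum_{i <= t} e^{w * (t - i)} * x_i for every t,
    using the recursion S_t = e^w * S_{t-1} + x_t.

    Each step costs O(1) time and O(1) state, which is what makes
    exponential decay such a good fit for RWKV's recurrent formulation.
    """
    decay = math.exp(w)  # per-step decay factor (w < 0)
    s = 0.0
    sums = []
    for x in xs:
        s = decay * s + x
        sums.append(s)
    return sums

# Sanity check against the direct O(T^2) computation.
xs = [0.3, -1.2, 0.7, 2.0, -0.4]
w = -0.5
direct = [sum(math.exp(w * (t - i)) * xs[i] for i in range(t + 1)) for t in range(len(xs))]
assert all(abs(a - b) < 1e-12 for a, b in zip(exponential_weighted_sums(xs, w), direct))
```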
However, we can still get the benefits of polynomial decay even if we don't compute it exactly! And it is possible to approximate a polynomial decay to high accuracy with a weighted sum of exponential decays, each of which fits the recursive structure just as well as before.
For any $w < 0$ we can choose decay rates $\lambda_1, \dots, \lambda_K > 0$ and coefficients $c_1, \dots, c_K$ such that $\sum_{k=1}^{K} c_k e^{-\lambda_k t} \approx (1+t)^{w}$ over the whole range of ages $t$ that the model will encounter.
Define one exponentially decayed accumulator per term, updated with the standard RWKV recursion $S_t^{(k)} = e^{-\lambda_k} S_{t-1}^{(k)} + x_t$. The polynomially decayed sum is then approximated by the linear combination $\sum_{k=1}^{K} c_k S_t^{(k)}$, so the cost per step is still constant, just $K$ times larger.
Now that we know how to compute a close approximation of the polynomial decay within RWKV's recursive framework, the modification amounts to replacing each channel's single exponential accumulator with this small collection of accumulators.
Here is a comparison of the exact polynomial decay and the sum-of-exponentials approximation:
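The following is a minimal sketch of how such an approximation can be fitted and checked numerically, assuming the polynomial decay $(1+t)^w$ from above, log-spaced decay rates, and a plain least-squares fit of the coefficients. The number of terms and the fitting procedure are illustrative choices, not necessarily what this project uses.

```python
import numpy as np

w = -1.0                    # polynomial decay exponent (example value)
t = np.arange(0, 4096)      # ages over which the approximation should hold
target = (1.0 + t) ** w     # exact polynomial decay (1 + t)^w

# A fixed set of log-spaced exponential decay rates; each term e^{-rate * t}
# is something RWKV can maintain recursively.
K = 16
rates = np.logspace(-4, 0, K)
basis = np.exp(-np.outer(t, rates))      # shape (len(t), K)

# Fit the mixture coefficients by least squares on the relative error,
# so the approximation stays accurate even where the target is tiny.
A = basis / target[:, None]
coeffs, *_ = np.linalg.lstsq(A, np.ones_like(target), rcond=None)
approx = basis @ coeffs

print("max relative error:", np.abs(approx / target - 1).max())
```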
So far I have only been able to validate that this works on a small scale. I think a large-scale training run is the only way to know for sure what effects the proposed change would have, because small models trained on standard datasets aren't sophisticated enough to make full use of long context lengths, and in that case it doesn't matter whether the decay is exponential or polynomial.
However, I have seen promising results when training small models on relatively simple memory-focused tasks. For my test bed I'm using generalized fizz buzz, a version of the well-known fizz buzz game with more divisors beyond just $3$ and $5$: each divisor is mapped to its own made-up word, a number is replaced by the concatenation of the words for all of its divisors that have one, and numbers with no matching divisor are written out as-is.
Example of what the generated fizz buzz data can look like:

```
1
Zyzz
Lazz
ZyzzWyzz
Qizz
ZyzzLazzZyzz
Wezz
ZyzzWyzzSezz
LazzLuzz
ZyzzQizzXezz
Pizz
ZyzzLazzWyzzZyzz
13
ZyzzWezz
LazzQizzQazz
ZyzzWyzzSezz
17
ZyzzLazzZyzzLuzzCuzz
19
ZyzzWyzzQizzXezzJizz
LazzWezzVezz
ZyzzPizz
23
...
```
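To make the task concrete, here is a minimal sketch of a generator for this kind of data. The divisor range, the word generator, and the number of divisors are illustrative assumptions; the actual data produced by the `--data_type "fizzbuzz"` loader used below may differ in its details.

```python
import random

CONSONANTS = "BCDFGHJKLMNPQRSTVWXZ"
VOWELS = "aeiouy"

def make_fizzbuzz_lines(n_numbers, n_divisors=12, seed=0):
    """Generate one generalized fizz buzz sequence as a list of lines.

    Each divisor is assigned a random made-up word of the form
    consonant + vowel + "zz" (matching the style of the example above).
    A number becomes the concatenation of the words of every divisor that
    divides it, or stays a plain number if none of them do.
    """
    rng = random.Random(seed)
    divisors = sorted(rng.sample(range(2, 30), n_divisors))
    words = {d: rng.choice(CONSONANTS) + rng.choice(VOWELS) + "zz" for d in divisors}

    lines = []
    for n in range(1, n_numbers + 1):
        line = "".join(words[d] for d in divisors if n % d == 0)
        lines.append(line if line else str(n))
    return lines

print("\n".join(make_fizzbuzz_lines(23)))
```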
Training two ctx256 L6 D512 models, one with exponential decay and one with polynomial decay, shows a clear difference in how well the two models handle this problem.
Training commands
The training commands I used for the above runs are as follows.
Exponential decay
```
cd RWKV-v4neo && python train.py --load_model "" --wandb "" --proj_dir "out" \
--data_file "" --data_type "fizzbuzz" --vocab_size 0 \
--ctx_len 256 --epoch_steps 100 --epoch_count 500 --epoch_begin 0 --epoch_save 10 \
--micro_bsz 16 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
--lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision tf32 --strategy ddp_find_unused_parameters_false --grad_cp 0 \
--decay "exponential"
```
Polynomial decay
```
cd RWKV-v4neo && python train.py --load_model "" --wandb "" --proj_dir "out" \
--data_file "" --data_type "fizzbuzz" --vocab_size 0 \
--ctx_len 256 --epoch_steps 100 --epoch_count 500 --epoch_begin 0 --epoch_save 10 \
--micro_bsz 16 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
--lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision tf32 --strategy ddp_find_unused_parameters_false --grad_cp 0 \
--decay "polynomial"
```
If I use larger models, the model with exponential decay is also able to achieve a lower loss, but the model with polynomial decay appears to make more effective use of the channels it has at its disposal.
Of course, this is just a toy problem, and there is no guarantee that polynomial decay's advantage generalizes to larger models and real world datasets. To find out, we'll need to try it! 🙂️