Associative Recall task #10

tristandeleu opened this issue Oct 13, 2015 · 8 comments

@tristandeleu
Collaborator

Associative Recall task

I will gather all the progress on the Associative Recall task in this issue. I will likely update this issue regularly (hopefully), so you may want to unsubscribe from this issue if you don't want to get all the spam.

@tristandeleu
Collaborator Author

Training the NTM with items of length one

Like in the previous tasks, I begin the tests on the Associative Recall task with small examples to check that everything goes well. This time, though, the NTM does not seem to perform well even on this small data. The examples for this task vary along two dimensions (the length of each item and the number of items in the sequence), so I decided to start with sequences of length-one items.

The training procedure seems to get stuck in a local minimum within the first few iterations. As in the Copy task, this may be a minimum where the NTM figures out that it has to write on the second half of the output, but does not know yet what to write. The NTM was trained on sequences of 2 to 6 items of length 1. Here is the output of the NTM on some test data:
assoc-recall-fail-01
Something concerning is that the NTM seems to write every item (besides the first one) to the same address, which inevitably results in some loss of information. This overwriting appears to be item-wise rather than vector-wise: it writes every item to the same set of addresses. This is confirmed by this test on items of length 3:
assoc-recall-fail-02
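To make the information loss concrete, here is a minimal NumPy sketch of the erase/add write rule from the original NTM paper (Graves et al., 2014). The memory size, the two items and the fully sharp write weighting on a single address are hypothetical values chosen for illustration, not the settings of this experiment.

```python
import numpy as np

def ntm_write(memory, w, erase, add):
    """One NTM write step: M_t = M_{t-1} * (1 - w e^T) + w a^T.
    memory: (N, M) matrix, w: (N,) write weights, erase/add: (M,) vectors."""
    memory = memory * (1.0 - np.outer(w, erase))  # erase step
    return memory + np.outer(w, add)              # add step

N, M = 8, 4
memory = np.zeros((N, M))

# Hypothetical write weighting focused entirely on address 1 for every item
w = np.zeros(N)
w[1] = 1.0
item_a = np.array([1.0, 0.0, 1.0, 0.0])
item_b = np.array([0.0, 1.0, 1.0, 1.0])
erase = np.ones(M)  # full erase before each add

memory = ntm_write(memory, w, erase, item_a)
memory = ntm_write(memory, w, erase, item_b)
print(memory[1])  # only item_b survives: item_a has been overwritten
```

With a sharp weighting and a full erase vector, each new item replaces the previous one at that address, so only the last item written can ever be recalled from it.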

Learning curve

assoc-recall-learning-curve

@tristandeleu
Collaborator Author

Training the NTM (partial training)

In this experiment, I trained the NTM on sequences of 2 to 6 items of length 1 to 3. Earlier experiments on such data showed that the training procedure often gets stuck in a local minimum (which I suspect is when the NTM figures out where to write its output). At test time, this experiment shows that the NTM did learn to write its inputs sequentially in memory this time, but it is still not able to perform consistently well on test data (even on data similar to what it was trained on).
assoc-recall-01
However, most of the time the NTM is not able to retrieve the correct item.
assoc-recall-04
It seems that the model is able to correctly retrieve the item when the query is the 2nd item (and correctly outputs the 3rd item). When reading the query, we can see that the NTM reads the memory addresses corresponding to the 2nd item. I assume that the NTM has learned a function that does something like:

for every input vector
    write its representation in memory
    go to the next address in memory
if the query is equal to the 2nd item
    output the 3rd item
else
    output a mixture of all the items (except 1st & 3rd)
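The pseudocode above can be written as a small Python function to make the hypothesis explicit. This is only a sketch of the suspected behavior, with items represented as NumPy arrays (an assumption); `hypothesized_recall` is a hypothetical name and this is not the function the NTM actually computes.

```python
import numpy as np

def hypothesized_recall(items, query):
    """Hypothetical model of the learned behavior: store the items
    sequentially, answer correctly only when the query is the 2nd item,
    and otherwise output a blend of the remaining items."""
    memory = []
    for item in items:  # write each item at the next address
        memory.append(np.asarray(item, dtype=float))
    if len(memory) > 2 and np.array_equal(query, memory[1]):
        return memory[2]  # query is the 2nd item: output the 3rd item
    # otherwise: mixture of all the items except the 1st and the 3rd
    others = [m for i, m in enumerate(memory) if i not in (0, 2)]
    return np.mean(others, axis=0)
```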

Moreover, when the query corresponds to the 2nd item, the NTM starts by reading the addresses corresponding to the 2nd item in memory; after the first few vectors of the query (maybe once the first 1 or 2 vectors of the query match those of the 2nd item), it reads the 127th address. We can notice the same behavior when the NTM first sees the 2nd item, so maybe that is how it remembers that the query really corresponds to the 2nd item. This seems to be confirmed by the generalization example (see below).
assoc-recall-06
In terms of generalization, the model already fails to produce correct outputs on data similar to what it was trained on. When testing with longer items (of length 5), we can see that the NTM is not able to retrieve the correct item, even in the otherwise successful case of querying the 2nd item. An analysis of the read weights shows that this must be due to information leaking over multiple addresses while reading (the NTM does try to read the memory locations where the representation of the 3rd item is stored).
assoc-recall-07
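The "information leakage" while reading can be illustrated with the content-based addressing rule from the original NTM paper. In this hypothetical NumPy sketch (random memory, a query equal to one stored row), a small key strength `beta` spreads the read weights over many addresses, so the read vector becomes a blend of several rows, while a large `beta` concentrates the weights on the matching address. The values are illustrative and not taken from the experiment.

```python
import numpy as np

def content_weights(memory, key, beta, eps=1e-8):
    """Content-based addressing (Graves et al., 2014):
    w_c(i) = softmax_i(beta * cosine(key, memory[i]))."""
    sim = memory @ key / (np.linalg.norm(memory, axis=1) * np.linalg.norm(key) + eps)
    z = beta * sim
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(0)
memory = rng.uniform(-1.0, 1.0, size=(16, 6))  # hypothetical 16 x 6 memory
key = memory[5].copy()                         # query matching address 5

for beta in (1.0, 10.0, 100.0):
    w = content_weights(memory, key, beta)
    read = w @ memory                          # read vector r = sum_i w(i) M(i)
    error = np.linalg.norm(read - memory[5])
    print(beta, int(w.argmax()), round(float(w.max()), 3), round(float(error), 3))
```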

Learning curve

learning-curve

Parameters of the experiment

I made a lot of changes relative to #10 (comment) for this experiment:

  • The activation functions: add and key use tanh, erase and gate use sigmoid, beta uses softplus and gamma uses 1 + softplus.
  • I switched the optimizer from Graves' RMSprop to Adam (with learning_rate=1e-3).
  • I did not learn the parameters for key, beta and gate of the write head (not needed for this task), so the write head directly uses w_g = w_tm1.
  • I clipped the gradients to [-10, 10].
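A minimal NumPy sketch of the head-parameter activations, the fixed write-head interpolation and the gradient clipping listed above. The dict layout and function names are hypothetical, and element-wise clipping is an assumption about how "clipped to [-10, 10]" was applied; the experiment's own code is not shown here.

```python
import numpy as np

def softplus(x):
    return np.logaddexp(0.0, x)  # log(1 + exp(x)), numerically stable

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def head_params(raw):
    """Map raw controller outputs (hypothetical dict of pre-activations)
    to head parameters using the activations listed above."""
    return {
        "key": np.tanh(raw["key"]),
        "add": np.tanh(raw["add"]),
        "erase": sigmoid(raw["erase"]),
        "gate": sigmoid(raw["gate"]),
        "beta": softplus(raw["beta"]),          # key strength > 0
        "gamma": 1.0 + softplus(raw["gamma"]),  # sharpening exponent > 1
    }

def write_interpolation(w_tm1, w_c=None, g=None):
    """w_g = g * w_c + (1 - g) * w_tm1. Without learned key/beta/gate,
    the write head simply reuses its previous weights: w_g = w_tm1."""
    if w_c is None or g is None:
        return w_tm1
    return g * w_c + (1.0 - g) * w_tm1

def clip_gradients(grads, bound=10.0):
    """Element-wise clipping of every gradient array to [-bound, bound]."""
    return [np.clip(g, -bound, bound) for g in grads]
```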

@tristandeleu
Collaborator Author

In [1], the authors suggest that the Associative Recall task is significantly harder to train than the Copy task and does not always converge:

For associative recall, we can see that outliers are produced much more frequently when loss significantly reduces, and we rarely observe convergence of original NTM

[1] Wei Zhang, Yang Yu, Structured Memory for Neural Turing Machines [arXiv]

@tristandeleu
Collaborator Author

Training the NTM with more iterations (partial training)

Initializing the model with the weights obtained in experiment #10 (comment), I decided to keep the same setup and let the training procedure run longer than just 500k iterations. I let it run for 2M+ iterations to see whether the model is currently inherently unable to solve this task or whether it is simply a matter of training time. The NTM was again trained on sequences of 2 to 6 items, each of length 1 to 3. The performance in this experiment is far better than in the previous one, even though the cross-entropy loss improved by only 1 order of magnitude (from ~1e-2 down to ~1e-3). Here are some examples:
assoc-recall_02-06
assoc-recall_02-15
assoc-recall_02-14
Contrary to DeepMind's results, which state:

This implies the following memory-access algorithm: when each item delimiter is presented, the controller writes a compressed representation of the previous three time slices of the item.

In this experiment, the NTM does not seem to write any such compressed representation for each delimiter; instead, it overwrites the last vector of the item when the delimiter is presented. In that case it may rely heavily on the read weights (which are active even while the items are presented), and it may not scale as well as the behavior described by DeepMind.

However, even though most of the test examples appear to be almost perfectly predicted, the NTM still sometimes fails to make correct predictions, even on test samples similar to the training examples (same range for the number of items and for the length of the items).
assoc-recall_02-08
In terms of generalization, the NTM performs surprisingly well and is able to make correct predictions. I tested generalization along both degrees of freedom:

  • When the number of items is between 8 and 12, the predictions are for the most part not as sharp as they were in the previous tests, with a lot of predictions in (0, 1). Still, thresholding at 0.5 mostly leads to an error of at most 1 bps (bits per sequence). This is a purely qualitative observation; there is currently no performance metric to confirm it.
    assoc-recall_02-21
    assoc-recall_02-29
    Here is another example, where this time the generalization fails. It appears that the NTM actually retrieved the right item in memory (the 6th), but suffers from the "information leakage" issue observed in Associative Recall task #10 (comment).
    assoc-recall_02-25_fail
  • When the length of each item is 5, predictions seem to be a lot sharper than for generalization in the number of items.
    assoc-recall_02-34
    assoc-recall_02-38
    Here is an example where the generalization fails:
    assoc-recall_02-32_fail
  • When the number of items is between 8 and 12 and the length of each item is 5, the performance is still surprisingly good, even though not as good as in the previous tests. Here we are hitting the limits of generalization, and the "information leakage" issue appears more frequently.
    assoc-recall_02-43
    assoc-recall_02-41_fail
    assoc-recall_02-42_fail
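The thresholding described in the first bullet can be turned into a bits-per-sequence count. This is a sketch under the assumption that predictions are per-bit probabilities and targets are binary arrays of the same shape; the function names are hypothetical, and it is not necessarily how the metrics reported later in this thread are computed.

```python
import numpy as np

def bits_per_sequence(predictions, targets, threshold=0.5):
    """Number of wrong bits in one output sequence after thresholding
    the predicted probabilities at `threshold`."""
    binary = (np.asarray(predictions) > threshold).astype(int)
    return int(np.sum(binary != np.asarray(targets)))

def average_bps(batch_predictions, batch_targets, threshold=0.5):
    """Average bits-per-sequence error over a batch of sequences."""
    errors = [bits_per_sequence(p, t, threshold)
              for p, t in zip(batch_predictions, batch_targets)]
    return float(np.mean(errors))
```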
Learning curve

learning-curve_02

Parameters of the experiment

Same parameters as in #10 (comment); I initialized the model with the weights obtained in #10 (comment), more or less as if I had left the training procedure running across the two experiments.

The training stopped here at 2.1M iterations due to #11.

@tristandeleu
Collaborator Author

Training the NTM with 'Copy task' initialization

Similarly to #10 (comment), for this experiment I decided to tune the initialization and used a set of parameters learned for the Copy task (#6) as initial weights. The intuition is that in the previous experiment, the NTM was overwriting the last element of every item with the delimiter, instead of creating a compressed representation of the whole item at another location in memory; the goal here was to nudge the NTM toward writing all its inputs sequentially in memory.

I trained the model on sequences of 2 to 6 items, each of length 1 to 3. It converged quickly compared to previous experiments, even though it sometimes struggled to stay in good minima. The error decreased by around 2 orders of magnitude compared to the previous experiment. The performance is a lot better than what was observed previously, and the behavior of the NTM matches the results from DeepMind. Here are some examples:
assoc-recall-12
assoc-recall-14
assoc-recall-15
The green line in the weights plots represents the beginning of the query and the red line represents the beginning of the output/the end of the input.

As it reads the query, the NTM knows almost instantly where to look in memory: it reads both the query and the answer, and returns the answer. A particular behavior here (which seems to be consistent across all the tests) is that when the model first sees the delimiter for the query, it reads the memory at location 124 and then starts reading the query item in memory. I need to investigate what is really happening at that location.

The generalization performance is very good along both dimensions:

  • When the number of items is between 8 and 15, the outputs of the NTM still match the targets almost perfectly.
    assoc-recall-24
    assoc-recall-28
    However, it still occasionally suffers from information leakage as in the previous example (though not as badly as in experiment Associative Recall task #10 (comment)).
    jrszj
  • When the length of each item is 5, the NTM is once again able to match the outputs almost perfectly.
    assoc-recall-44
    assoc-recall-42
  • The generalization performance is the same when we increase both dimensions at the same time.
Performance metrics

I now have metrics to evaluate how well the NTM generalizes on this task. For now I only checked the performance as a function of the number of items, with items of length 3. As we can see, the model makes less than 1 bit-per-sequence error on sequences of up to 30 items. The error then explodes because we are actually reaching the limit in the number of memory locations (128 locations, compared to 30 * (3 + 1) = 120 for all the items and delimiters).
generalization-performance
generalization-performance-02
Average generalization error as a function of the number of items in the input sequence, in bits-per-sequence
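The capacity argument above is simple arithmetic. This sketch assumes, as in the computation above, that each item vector and each delimiter occupies one memory location; the helper names are hypothetical.

```python
def memory_slots_needed(num_items, item_length):
    """Locations needed if every item vector and every delimiter takes one address."""
    return num_items * (item_length + 1)

def max_items(memory_size, item_length):
    """Largest number of items whose vectors and delimiters fit in memory."""
    return memory_size // (item_length + 1)

print(memory_slots_needed(30, 3))  # 120 locations for 30 items of length 3
print(max_items(128, 3))           # at most 32 items fit in 128 locations
```

Under this assumption, a 128-location memory can hold at most 32 items of length 3, which is consistent with the error exploding just beyond 30 items.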

Here we are actually outperforming DeepMind's results, since they got a 1-bps error for sequences of 15 items and ~7-bps for sequences of 20 items.

Learning curve

The learning curve suggests that the learning rate (currently 1e-4) is too high. We may have to reduce it in later experiments, or decay it as training goes on.
learning-curve
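Decaying the learning rate during training could follow one of the two generic schedules below. Both are sketches: the function names, decay rates and step counts are placeholders, not values tested in this thread.

```python
def exponential_decay(initial_lr, step, decay_rate=0.5, decay_steps=100_000):
    """Smooth decay: lr0 * decay_rate ** (step / decay_steps)."""
    return initial_lr * decay_rate ** (step / decay_steps)

def step_decay(initial_lr, step, drop=0.1, every=500_000):
    """Piecewise-constant decay: multiply the rate by `drop` every `every` steps."""
    return initial_lr * drop ** (step // every)
```

With `initial_lr=1e-4`, the exponential schedule halves the rate every 100k iterations, while the step schedule keeps it constant for 500k iterations before dividing it by 10.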

Parameters of the experiment

I switched back for the most part to the setup I used in #8.

  • I now use rectify and hard_sigmoid activations instead of tanh, sigmoid and softplus.
  • I still neither use nor learn the weight matrices and biases of key, beta and gate for the write head.
  • I used the Adam optimizer with learning_rate=1e-4.
  • I used parameters learned for the Copy task as initial parameters.
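For reference, the two activations can be written in NumPy as below. The hard sigmoid uses slope 0.2 and offset 0.5, which to my knowledge matches Theano's `hard_sigmoid`; treat these constants as an assumption. Which head parameter uses which activation is not specified above, so the sketch does not assign them.

```python
import numpy as np

def rectify(x):
    """ReLU: max(0, x)."""
    return np.maximum(0.0, x)

def hard_sigmoid(x, slope=0.2, shift=0.5):
    """Piecewise-linear sigmoid approximation: clip(slope * x + shift, 0, 1)."""
    return np.clip(slope * np.asarray(x) + shift, 0.0, 1.0)

x = np.array([-5.0, -1.0, 0.0, 1.0, 5.0])
print(rectify(x))       # negative inputs are zeroed
print(hard_sigmoid(x))  # linear around 0, saturating at 0 and 1
```

Both functions are cheaper than tanh, sigmoid or softplus, and the hard sigmoid saturates exactly at 0 and 1 instead of only approaching them.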

@tristandeleu
Collaborator Author

The Adam optimizer seems to be crucial for making the model converge much faster on this task. I have not had any success getting the NTM to converge with RMSProp yet. A smaller learning rate for the Adam optimizer still led to a behavior similar to #10 (comment).
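Since Adam appears to matter so much here, this is the standard Adam update (Kingma & Ba) written out in NumPy. The learning rates used in this thread were 1e-3 and 1e-4; the other hyperparameters below are the defaults from the Adam paper and are an assumption. This is a generic sketch, not the optimizer code used for the experiments.

```python
import numpy as np

def adam_step(param, grad, m, v, t, learning_rate=1e-3,
              beta1=0.9, beta2=0.999, epsilon=1e-8):
    """One Adam update. `t` is the 1-based step count.
    Returns the updated parameter and moment estimates."""
    m = beta1 * m + (1.0 - beta1) * grad       # first moment (running mean)
    v = beta2 * v + (1.0 - beta2) * grad ** 2  # second moment (uncentered)
    m_hat = m / (1.0 - beta1 ** t)             # bias corrections for the
    v_hat = v / (1.0 - beta2 ** t)             # zero-initialized moments
    param = param - learning_rate * m_hat / (np.sqrt(v_hat) + epsilon)
    return param, m, v
```

As a toy usage example, minimizing f(x) = (x - 3)^2 from x = 0 with `learning_rate=0.1` drives x close to 3 within a few hundred steps. The per-parameter normalization by `sqrt(v_hat)` is one plausible reason Adam copes better than plain SGD with gradients of very different scales across the controller and head parameters, but that is a hypothesis, not something tested here.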

@tristandeleu
Collaborator Author

Training the NTM with 'Copy task' initialization (2)

I performed the same experiment as in #10 (comment), but this time I trained the whole model, including the key, beta and gate parameters of the write head. Even though the performance is not as good as in the previous experiment, it is still very good. However, the NTM occasionally fails to retrieve the item in generalization. Here is a failure example:
assoc-recall-01_fail
In terms of generalization, the performance is still better than DeepMind's, with less than 3-bps error for sequences of 30 items, compared to their 7-bps error for sequences of 20 items. It is worth noting, though, that the behavior observed in the example above happens consistently when reaching the limits of the memory. The NTM often seems to write something at locations above ~105, which could mislead item retrieval afterwards; most of the generalization errors likely happen when the item to retrieve appears near the end of the input sequence.
generalization

Learning curve

learning-curve

@tristandeleu
Collaborator Author

The experiments above use only 1 read head and 1 write head, whereas the original paper uses 4 heads (4 of each?).
