Skip to content

Commit

Permalink
Issue one time keys in upload order
Browse files Browse the repository at this point in the history
Currently, one-time-keys are issued in a somewhat random order. (In practice,
they are issued according to the lexicographical order of their key IDs.) That
can lead to a situation where a client gives up hope of a given OTK ever
being used, whilst it is still on the server.

Fixes: element-hq/element-meta#2356
  • Loading branch information
richvdh committed Nov 4, 2024
1 parent c705bee commit 699d893
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 8 deletions.
2 changes: 1 addition & 1 deletion synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ async def claim_local_one_time_keys(
3. Attempt to fetch fallback keys from the database.
Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm).
local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
always_include_fallback_keys: True to always include fallback keys.
Returns:
Expand Down
15 changes: 13 additions & 2 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ def __init__(
unique=True,
)

self.db_pool.updates.register_background_index_update(
update_name="add_otk_ts_added_index",
index_name="e2e_one_time_keys_json_user_id_device_id_algorithm_ts_added_idx",
table="e2e_one_time_keys_json",
columns=("user_id", "device_id", "algorithm", "ts_added_ms"),
)


class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
def __init__(
Expand Down Expand Up @@ -1122,7 +1129,7 @@ async def claim_e2e_one_time_keys(
"""Take a list of one time keys out of the database.
Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
query_list: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
Returns:
A tuple (results, missing) of:
Expand Down Expand Up @@ -1313,6 +1320,7 @@ def _claim_e2e_one_time_key_simple(
sql = """
SELECT key_id, key_json FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
ORDER BY ts_added_ms
LIMIT ?
"""

Expand Down Expand Up @@ -1360,7 +1368,10 @@ def _claim_e2e_one_time_keys_bulk(
), ranked_keys AS (
SELECT
user_id, device_id, algorithm, key_id, claim_count,
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
ROW_NUMBER() OVER (
PARTITION BY (user_id, device_id, algorithm)
ORDER BY ts_added_ms
) AS r
FROM e2e_one_time_keys_json
JOIN claims USING (user_id, device_id, algorithm)
)
Expand Down
18 changes: 18 additions & 0 deletions synapse/storage/schema/main/delta/88/03_add_otk_ts_added_index.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.


-- Add an index on (user_id, device_id, algorithm, ts_added_ms) on e2e_one_time_keys_json, so that OTKs can
-- efficiently be issued in the same order they were uploaded.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(8803, 'add_otk_ts_added_index', '{}');
79 changes: 74 additions & 5 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,18 +151,30 @@ def test_change_one_time_keys(self) -> None:
def test_claim_one_time_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {"alg1:k1": "key1"}

res = self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
local_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}}
)
)
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
)

res2 = self.get_success(
# Keys should be returned in the order they were uploaded. To test, advance time
# a little, then upload a second key with an earlier key ID; it should get
# returned second.
self.reactor.advance(1)
res = self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg1:k0": "key0"}}
)
)
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 2, "signed_curve25519": 0}}
)

# now claim both keys back. They should be in the same order
res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
Expand All @@ -171,12 +183,27 @@ def test_claim_one_time_key(self) -> None:
)
)
self.assertEqual(
res2,
res,
{
"failures": {},
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
},
)
res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
res,
{
"failures": {},
"one_time_keys": {local_user: {device_id: {"alg1:k0": "key0"}}},
},
)

def test_claim_one_time_key_bulk(self) -> None:
"""Like test_claim_one_time_key but claims multiple keys in one handler call."""
Expand Down Expand Up @@ -336,6 +363,48 @@ def assertAllOtks(user_id: str, device_id: str, *alg_key_pairs: str) -> None:
counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
)

def test_claim_one_time_key_bulk_ordering(self) -> None:
"""Keys returned by the bulk claim call should be returned in the correct order"""

# Alice has lots of keys, uploaded in a specific order
alice = f"@alice:{self.hs.hostname}"
alice_dev = "alice_dev_1"

self.get_success(
self.handler.upload_keys_for_user(
alice,
alice_dev,
{"one_time_keys": {"alg1:k20": 20, "alg1:k21": 21, "alg1:k22": 22}},
)
)
self.get_success(
self.handler.upload_keys_for_user(
alice,
alice_dev,
{"one_time_keys": {"alg1:k10": 10, "alg1:k11": 11, "alg1:k12": 12}},
)
)

# Now claim some, and check we get the right ones.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{alice: {alice_dev: {"alg1": 2}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
# We should get the first-uploaded keys, even though they have later key ids
self.assertEqual(
claim_res,
{
"failures": {},
"one_time_keys": {
"@alice:test": {"alice_dev_1": {"alg1:k20": 20, "alg1:k21": 21}}
},
},
)

def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
Expand Down

0 comments on commit 699d893

Please sign in to comment.