Skip to content

Commit

Permalink
Use a sortedset instead
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed May 8, 2024
1 parent 202a09c commit a2adc9c
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 19 deletions.
28 changes: 9 additions & 19 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import attr
from prometheus_client import Counter, Gauge
from sortedcontainers import SortedSet

from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import StoreError
Expand Down Expand Up @@ -373,25 +374,15 @@ def _get_chain_links(

# We fetch the links in batches. Separate batches will likely fetch the
# same set of links (e.g. they'll always pull in the links to create
# event). To try and minimize the amount of redundant links, we sort the
# chain IDs in reverse, as there will be a correlation between the order
# of chain IDs and links (i.e., higher chain IDs are more likely to
# depend on lower chain IDs than vice versa).
# event). To try and minimize the amount of redundant links, we query
# the chain IDs in reverse order, as there will be a correlation between
# the order of chain IDs and links (i.e., higher chain IDs are more
# likely to depend on lower chain IDs than vice versa).
BATCH_SIZE = 1000
chains_to_fetch_list = list(chains_to_fetch)
chains_to_fetch_list.sort(reverse=True)

seen_chains: Set[int] = set()
while chains_to_fetch_list:
batch2 = [
c for c in chains_to_fetch_list[-BATCH_SIZE:] if c not in seen_chains
]
chains_to_fetch_list = chains_to_fetch_list[:-BATCH_SIZE]
while len(batch2) < BATCH_SIZE and chains_to_fetch_list:
chain_id = chains_to_fetch_list.pop()
if chain_id not in seen_chains:
batch2.append(chain_id)
chains_to_fetch_sorted = SortedSet(chains_to_fetch)

while chains_to_fetch_sorted:
batch2 = list(chains_to_fetch_sorted.islice(-BATCH_SIZE))
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
Expand All @@ -409,8 +400,7 @@ def _get_chain_links(
(origin_sequence_number, target_chain_id, target_sequence_number)
)

seen_chains.update(links)
seen_chains.update(batch2)
chains_to_fetch_sorted.difference_update(links)

yield links

Expand Down
74 changes: 74 additions & 0 deletions tests/storage/test_purge.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from synapse.server import HomeServer
from synapse.util import Clock

from tests.test_utils.event_injection import inject_event
from tests.unittest import HomeserverTestCase


Expand Down Expand Up @@ -128,3 +129,76 @@ def test_purge_room(self) -> None:
self.store._invalidate_local_get_event_cache(create_event.event_id)
self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)

def test_state_groups_state_decreases(self) -> None:
    """Check that purging history reduces the number of rows in
    `state_groups_state` for the room.

    We build a room with many state changes, forcing state resolution by
    creating pairs of conflicting state events, then purge up to
    successively later points in the room and assert that the row count
    never grows and has strictly decreased by the end.
    """
    response = self.helper.send(self.room_id, body="first")
    first_event_id = response["event_id"]

    # Topological tokens at which we will purge, in chronological order.
    batches = []

    previous_event_id = first_event_id
    for i in range(50):
        # Two conflicting state events with the same (type, state_key) and
        # the same prev event, so the event that follows them has to
        # state-resolve the fork, creating extra state groups.
        state_event1 = self.get_success(
            inject_event(
                self.hs,
                type="test.state",
                sender=self.user_id,
                state_key="",
                room_id=self.room_id,
                content={"key": i, "e": 1},
                prev_event_ids=[previous_event_id],
                origin_server_ts=1,
            )
        )

        state_event2 = self.get_success(
            inject_event(
                self.hs,
                type="test.state",
                sender=self.user_id,
                state_key="",
                room_id=self.room_id,
                content={"key": i, "e": 2},
                prev_event_ids=[previous_event_id],
                origin_server_ts=2,
            )
        )

        # Merge the two forks back together.
        message_event = self.get_success(
            inject_event(
                self.hs,
                type="dummy_event",
                sender=self.user_id,
                room_id=self.room_id,
                content={},
                prev_event_ids=[state_event1.event_id, state_event2.event_id],
            )
        )

        token = self.get_success(
            self.store.get_topological_token_for_event(state_event1.event_id)
        )
        batches.append(token)

        previous_event_id = message_event.event_id

    self.helper.send(self.room_id, body="last event")

    def count_state_groups() -> int:
        """Return the number of `state_groups_state` rows for the room."""
        sql = "SELECT COUNT(*) FROM state_groups_state WHERE room_id = ?"
        rows = self.get_success(
            self.store.db_pool.execute("count_state_groups", sql, self.room_id)
        )
        return rows[0][0]

    initial_count = count_state_groups()
    previous_count = initial_count
    for token in batches:
        token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
        self.get_success(
            self._storage_controllers.purge_events.purge_history(
                self.room_id, token_str, False
            )
        )

        # Purging history must never *add* state rows.
        new_count = count_state_groups()
        self.assertLessEqual(new_count, previous_count)
        previous_count = new_count

    # Across all the purges, the amount of stored state must have shrunk.
    self.assertLess(previous_count, initial_count)

0 comments on commit a2adc9c

Please sign in to comment.