Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewFerr committed Jul 30, 2024
1 parent e0d2d76 commit 6b06bf0
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 52 deletions.
8 changes: 2 additions & 6 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,7 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# MSC4140: Delayed events
# The maximum allowed duration for delayed events.
try:
self.msc4140_max_delay = int(
experimental["msc4140_max_delay"]
)
self.msc4140_max_delay = int(experimental["msc4140_max_delay"])
if self.msc4140_max_delay < 0:
raise ValueError
except ValueError:
Expand All @@ -457,9 +455,7 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
("experimental", "msc4140_max_delay"),
)
except KeyError:
self.msc4140_max_delay = (
10 * 365 * 24 * 60 * 60 * 1000
) # 10 years
self.msc4140_max_delay = 10 * 365 * 24 * 60 * 60 * 1000 # 10 years

# MSC4151: Report room API (Client-Server API)
self.msc4151_enabled: bool = experimental.get("msc4151_enabled", False)
Expand Down
81 changes: 56 additions & 25 deletions synapse/handlers/delayed_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)

import attr

from twisted.internet.interfaces import IDelayedCall

from synapse.api.constants import EventTypes
Expand All @@ -38,14 +39,21 @@
from synapse.logging.opentracing import set_tag
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.delayed_events import (
EventType,
Delay,
DelayID,
EventType,
StateKey,
Timestamp,
UserLocalpart,
)
from synapse.types import JsonDict, Requester, RoomID, StateMap, UserID, create_requester
from synapse.types import (
JsonDict,
Requester,
RoomID,
StateMap,
UserID,
create_requester,
)
from synapse.util.async_helpers import Linearizer, ReadWriteLock
from synapse.util.stringutils import random_string

Expand Down Expand Up @@ -89,7 +97,9 @@ def __init__(self, hs: "HomeServer"):

async def _schedule_db_events() -> None:
# TODO: Sync all state first, so that affected delayed state events will be cancelled
events, remaining_timeout_delays = await self.store.process_all_delays(self._get_current_ts())
events, remaining_timeout_delays = await self.store.process_all_delays(
self._get_current_ts()
)
for args in events:
await self._send_event(*args)

Expand All @@ -104,7 +114,9 @@ async def _schedule_db_events() -> None:
"_schedule_db_events", _schedule_db_events
)

async def on_new_event(self, event: EventBase, _state_events: StateMap[EventBase]) -> None:
async def on_new_event(
self, event: EventBase, _state_events: StateMap[EventBase]
) -> None:
"""
Checks if a received event is a state event, and if so,
cancels any delayed events that target the same state.
Expand Down Expand Up @@ -209,7 +221,7 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None
except ValueError:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
f"'action' is not one of {', '.join(map(lambda m: m.value, _UpdateDelayedEventAction))}",
f"'action' is not one of {', '.join(m.value for m in _UpdateDelayedEventAction)}",
Codes.INVALID_PARAM,
)

Expand All @@ -220,7 +232,9 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None

async with self._get_delay_context(delay_id, user_localpart):
if enum_action == _UpdateDelayedEventAction.CANCEL:
for removed_timeout_delay_id in await self.store.remove(delay_id, user_localpart):
for removed_timeout_delay_id in await self.store.remove(
delay_id, user_localpart
):
self._unschedule(removed_timeout_delay_id, user_localpart)

elif enum_action == _UpdateDelayedEventAction.RESTART:
Expand All @@ -234,22 +248,29 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None
self._schedule(delay_id, user_localpart, delay)

elif enum_action == _UpdateDelayedEventAction.SEND:
args, removed_timeout_delay_ids = await self.store.pop_event(delay_id, user_localpart)
args, removed_timeout_delay_ids = await self.store.pop_event(
delay_id, user_localpart
)

for timeout_delay_id in removed_timeout_delay_ids:
self._unschedule(timeout_delay_id, user_localpart)
await self._send_event(user_localpart, *args)

async def _send_on_timeout(self, delay_id: DelayID, user_localpart: UserLocalpart) -> None:
async def _send_on_timeout(
self, delay_id: DelayID, user_localpart: UserLocalpart
) -> None:
del self._delayed_calls[_DelayedCallKey(delay_id, user_localpart)]

async with self._get_delay_context(delay_id, user_localpart):
try:
args, removed_timeout_delay_ids = await self.store.pop_event(delay_id, user_localpart)
args, removed_timeout_delay_ids = await self.store.pop_event(
delay_id, user_localpart
)
except NotFoundError:
logger.debug(
"delay_id %s for local user %s was removed after it timed out, but before it was sent on timeout",
delay_id, user_localpart,
delay_id,
user_localpart,
)
return

Expand All @@ -268,27 +289,36 @@ def _schedule(
delay_sec = delay / 1000

logger.info(
"Scheduling delayed event %s for local user %s to be sent in %.3fs", delay_id, user_localpart, delay_sec
)

self._delayed_calls[_DelayedCallKey(delay_id, user_localpart)] = self.clock.call_later(
delay_sec,
run_as_background_process,
"_send_on_timeout",
self._send_on_timeout,
"Scheduling delayed event %s for local user %s to be sent in %.3fs",
delay_id,
user_localpart,
delay_sec,
)

self._delayed_calls[_DelayedCallKey(delay_id, user_localpart)] = (
self.clock.call_later(
delay_sec,
run_as_background_process,
"_send_on_timeout",
self._send_on_timeout,
delay_id,
user_localpart,
)
)

def _unschedule(self, delay_id: DelayID, user_localpart: UserLocalpart) -> None:
delayed_call = self._delayed_calls.pop(_DelayedCallKey(delay_id, user_localpart))
delayed_call = self._delayed_calls.pop(
_DelayedCallKey(delay_id, user_localpart)
)
self.clock.cancel_call_later(delayed_call)

async def get_all_for_user(self, requester: Requester) -> List[JsonDict]:
"""Return all pending delayed events requested by the given user."""
await self.request_ratelimiter.ratelimit(requester)
await self._initialized_from_db
return await self.store.get_all_for_user(UserLocalpart(requester.user.localpart))
return await self.store.get_all_for_user(
UserLocalpart(requester.user.localpart)
)

async def _send_event(
self,
Expand Down Expand Up @@ -350,13 +380,14 @@ def _get_current_ts(self) -> Timestamp:
return Timestamp(self.clock.time_msec())

@asynccontextmanager
async def _get_delay_context(self, delay_id: DelayID, user_localpart: UserLocalpart) -> AsyncIterator[None]:
async def _get_delay_context(
self, delay_id: DelayID, user_localpart: UserLocalpart
) -> AsyncIterator[None]:
await self._initialized_from_db
# TODO: Use parenthesized context manager once the minimum supported Python version is 3.10
async with\
self._state_lock.read(_STATE_LOCK_KEY),\
self._linearizer.queue(_DelayedCallKey(delay_id, user_localpart))\
:
async with self._state_lock.read(_STATE_LOCK_KEY), self._linearizer.queue(
_DelayedCallKey(delay_id, user_localpart)
):
yield

def _get_state_context(self) -> AsyncContextManager:
Expand Down
53 changes: 32 additions & 21 deletions synapse/storage/databases/main/delayed_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,15 @@

from binascii import crc32
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
NewType,
Optional,
Set,
Tuple,
from typing import TYPE_CHECKING, Any, Dict, List, NewType, Optional, Set, Tuple

from synapse.api.errors import (
Codes,
InvalidAPICallError,
NotFoundError,
StoreError,
SynapseError,
)

from synapse.api.errors import Codes, InvalidAPICallError, NotFoundError, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
Expand Down Expand Up @@ -127,8 +124,13 @@ def add_txn(txn: LoggingTransaction) -> DelayID:
txn.execute(
sql,
(
delay_id, user_localpart, current_ts,
room_id.to_string(), event_type, state_key, origin_server_ts,
delay_id,
user_localpart,
current_ts,
room_id.to_string(),
event_type,
state_key,
origin_server_ts,
json_encoder.encode(content),
),
)
Expand All @@ -151,7 +153,10 @@ def add_txn(txn: LoggingTransaction) -> DelayID:
FROM delayed_events
WHERE delay_id = ? AND user_localpart = ?
""",
(delay_id, user_localpart,)
(
delay_id,
user_localpart,
),
)
row = txn.fetchone()
assert row is not None
Expand Down Expand Up @@ -185,7 +190,8 @@ def add_txn(txn: LoggingTransaction) -> DelayID:
""",
(
delay_rowid,
parent_id, user_localpart,
parent_id,
user_localpart,
),
)
except Exception as e:
Expand Down Expand Up @@ -355,14 +361,19 @@ def process_all_delays_txn(txn: LoggingTransaction) -> Tuple[
""",
(current_ts,),
)
remaining_timeout_delays = [(
DelayID(row[0]),
UserLocalpart(row[1]),
Delay(row[2]),
) for row in txn]
remaining_timeout_delays = [
(
DelayID(row[0]),
UserLocalpart(row[1]),
Delay(row[2]),
)
for row in txn
]
return events, remaining_timeout_delays

return await self.db_pool.runInteraction("process_all_delays", process_all_delays_txn)
return await self.db_pool.runInteraction(
"process_all_delays", process_all_delays_txn
)

async def pop_event(
self,
Expand Down

0 comments on commit 6b06bf0

Please sign in to comment.