Skip to content

Commit

Permalink
Reuse pushservice and websockets (#203)
Browse files Browse the repository at this point in the history
Co-authored-by: Gabriel Féron <[email protected]>
  • Loading branch information
boxdot and gferon authored Nov 16, 2023
1 parent bdc195e commit f291581
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 102 deletions.
21 changes: 5 additions & 16 deletions presage/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,38 +23,27 @@ impl<T: Clone> Default for CacheCell<T> {
}

impl<T: Clone> CacheCell<T> {
pub fn get<E>(&self, factory: impl FnOnce() -> Result<T, E>) -> Result<T, E> {
pub fn get(&self, factory: impl FnOnce() -> T) -> T {
let value = match self.cell.replace(None) {
Some(value) => value,
None => factory()?,
None => factory(),
};
self.cell.set(Some(value.clone()));
Ok(value)
value
}
}

#[cfg(test)]
mod tests {
use super::*;

use std::convert::Infallible;

#[test]
fn test_cache_cell() {
let cache: CacheCell<String> = Default::default();

let value = cache
.get(|| Ok::<_, Infallible>("Hello, World!".to_string()))
.unwrap();
assert_eq!(value, "Hello, World!");
let value = cache
.get(|| -> Result<String, Infallible> { panic!("I should not run") })
.unwrap();
let value = cache.get(|| ("Hello, World!".to_string()));
assert_eq!(value, "Hello, World!");

let value = cache
.get(|| -> Result<String, Infallible> { panic!("I should not run") })
.unwrap();
let value = cache.get(|| panic!("I should not run"));
assert_eq!(value, "Hello, World!");
}
}
43 changes: 18 additions & 25 deletions presage/src/manager/confirmation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use log::trace;
use rand::rngs::StdRng;
use rand::{RngCore, SeedableRng};

use crate::cache::CacheCell;
use crate::manager::registered::RegistrationData;
use crate::store::Store;
use crate::{Error, Manager};
Expand Down Expand Up @@ -122,31 +121,25 @@ impl<S: Store> Manager<S, Confirmation> {
let mut manager = Manager {
rng,
store: self.store,
state: Registered {
push_service_cache: CacheCell::default(),
identified_websocket: Default::default(),
unidentified_websocket: Default::default(),
unidentified_sender_certificate: Default::default(),
data: RegistrationData {
signal_servers: self.state.signal_servers,
device_name: None,
phone_number,
service_ids: ServiceIds {
aci: registered.uuid,
pni: registered.pni,
},
password,
signaling_key,
device_id: None,
registration_id,
pni_registration_id: Some(pni_registration_id),
aci_private_key: aci_identity_key_pair.private_key,
aci_public_key: aci_identity_key_pair.public_key,
pni_private_key: Some(pni_identity_key_pair.private_key),
pni_public_key: Some(pni_identity_key_pair.public_key),
profile_key,
state: Registered::with_data(RegistrationData {
signal_servers: self.state.signal_servers,
device_name: None,
phone_number,
service_ids: ServiceIds {
aci: registered.uuid,
pni: registered.pni,
},
},
password,
signaling_key,
device_id: None,
registration_id,
pni_registration_id: Some(pni_registration_id),
aci_private_key: aci_identity_key_pair.private_key,
aci_public_key: aci_identity_key_pair.public_key,
pni_private_key: Some(pni_identity_key_pair.private_key),
pni_public_key: Some(pni_identity_key_pair.public_key),
profile_key,
}),
};

manager.store.save_registration_data(&manager.state.data)?;
Expand Down
154 changes: 93 additions & 61 deletions presage/src/manager/registered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use libsignal_service::attachment_cipher::decrypt_in_place;
use libsignal_service::configuration::{ServiceConfiguration, SignalServers, SignalingKey};
use libsignal_service::content::{Content, ContentBody, DataMessageFlags, Metadata};
use libsignal_service::groups_v2::{decrypt_group, Group, GroupsManager, InMemoryCredentialsCache};
use libsignal_service::messagepipe::{Incoming, ServiceCredentials};
use libsignal_service::messagepipe::{Incoming, MessagePipe, ServiceCredentials};
use libsignal_service::models::Contact;
use libsignal_service::prelude::phonenumber::PhoneNumber;
use libsignal_service::prelude::Uuid;
Expand Down Expand Up @@ -53,7 +53,8 @@ type MessageSender<S> = libsignal_service::prelude::MessageSender<HyperPushServi
/// Manager state when the client is registered and can send and receive messages from Signal
#[derive(Clone)]
pub struct Registered {
pub(crate) push_service_cache: CacheCell<HyperPushService>,
pub(crate) identified_push_service: CacheCell<HyperPushService>,
pub(crate) unidentified_push_service: CacheCell<HyperPushService>,
pub(crate) identified_websocket: Arc<Mutex<Option<SignalWebSocket>>>,
pub(crate) unidentified_websocket: Arc<Mutex<Option<SignalWebSocket>>>,
pub(crate) unidentified_sender_certificate: Option<SenderCertificate>,
Expand All @@ -70,7 +71,8 @@ impl fmt::Debug for Registered {
impl Registered {
pub(crate) fn with_data(data: RegistrationData) -> Self {
Self {
push_service_cache: CacheCell::default(),
identified_push_service: CacheCell::default(),
unidentified_push_service: CacheCell::default(),
identified_websocket: Default::default(),
unidentified_websocket: Default::default(),
unidentified_sender_certificate: Default::default(),
Expand Down Expand Up @@ -172,10 +174,76 @@ impl<S: Store> Manager<S, Registered> {
&self.state.data
}

/// Returns a clone of a cached push service (with credentials).
///
/// If no service is yet cached, it will create and cache one.
fn identified_push_service(&self) -> HyperPushService {
self.state.identified_push_service.get(|| {
HyperPushService::new(
self.state.service_configuration(),
self.credentials(),
crate::USER_AGENT.to_string(),
)
})
}

/// Returns a clone of a cached push service (without credentials).
///
/// If no service is yet cached, it will create and cache one.
fn unidentified_push_service(&self) -> HyperPushService {
self.state.unidentified_push_service.get(|| {
HyperPushService::new(
self.state.service_configuration(),
None,
crate::USER_AGENT.to_string(),
)
})
}

/// Returns the current identified websocket, or creates a new one
async fn identified_websocket(&self) -> Result<SignalWebSocket, Error<S::Error>> {
let mut identified_ws = self.state.identified_websocket.lock();

Check warning on line 205 in presage/src/manager/registered.rs

View workflow job for this annotation

GitHub Actions / clippy

this `MutexGuard` is held across an `await` point
match identified_ws.clone() {
Some(ws) => Ok(ws),
None => {
let keep_alive = true;
let headers = &[("X-Signal-Receive-Stories", "false")];
let ws = self
.identified_push_service()
.ws("/v1/websocket/", headers, self.credentials(), keep_alive)
.await?;
identified_ws.replace(ws.clone());
debug!("initialized identified websocket");

Ok(ws)
}
}
}

async fn unidentified_websocket(&self) -> Result<SignalWebSocket, Error<S::Error>> {
let mut unidentified_ws = self.state.unidentified_websocket.lock();

Check warning on line 224 in presage/src/manager/registered.rs

View workflow job for this annotation

GitHub Actions / clippy

this `MutexGuard` is held across an `await` point
match unidentified_ws.clone() {
Some(ws) => Ok(ws),
None => {
let keep_alive = true;
let ws = self
.unidentified_push_service()
.ws("/v1/websocket/", &[], None, keep_alive)
.await?;
unidentified_ws.replace(ws.clone());
debug!("initialized unidentified websocket");

Ok(ws)
}
}
}

pub(crate) async fn register_pre_keys(&mut self) -> Result<(), Error<S::Error>> {
trace!("registering pre keys");
let mut account_manager =
AccountManager::new(self.push_service()?, Some(self.state.data.profile_key));
let mut account_manager = AccountManager::new(
self.identified_push_service(),
Some(self.state.data.profile_key),
);

let (pre_keys_offset_id, next_signed_pre_key_id, next_pq_pre_key_id) = account_manager
.update_pre_key_bundle(
Expand All @@ -199,8 +267,10 @@ impl<S: Store> Manager<S, Registered> {

pub(crate) async fn set_account_attributes(&mut self) -> Result<(), Error<S::Error>> {
trace!("setting account attributes");
let mut account_manager =
AccountManager::new(self.push_service()?, Some(self.state.data.profile_key));
let mut account_manager = AccountManager::new(
self.identified_push_service(),
Some(self.state.data.profile_key),
);

let pni_registration_id =
if let Some(pni_registration_id) = self.state.data.pni_registration_id {
Expand Down Expand Up @@ -251,7 +321,7 @@ impl<S: Store> Manager<S, Registered> {
&mut self,
mut messages: impl Stream<Item = Content> + Unpin,
) -> Result<(), Error<S::Error>> {
let mut message_receiver = MessageReceiver::new(self.push_service()?);
let mut message_receiver = MessageReceiver::new(self.identified_push_service());
while let Some(Content { body, .. }) = messages.next().await {
if let ContentBody::SynchronizeMessage(SyncMessage {
contacts: Some(contacts),
Expand Down Expand Up @@ -333,7 +403,7 @@ impl<S: Store> Manager<S, Registered> {

if needs_renewal(self.state.unidentified_sender_certificate.as_ref()) {
let sender_certificate = self
.push_service()?
.identified_push_service()
.get_uuid_only_sender_certificate()
.await?;

Expand All @@ -354,7 +424,7 @@ impl<S: Store> Manager<S, Registered> {
token: &str,
captcha: &str,
) -> Result<(), Error<S::Error>> {
let mut account_manager = AccountManager::new(self.push_service()?, None);
let mut account_manager = AccountManager::new(self.identified_push_service(), None);
account_manager
.submit_recaptcha_challenge(token, captcha)
.await?;
Expand All @@ -363,7 +433,7 @@ impl<S: Store> Manager<S, Registered> {

/// Fetches basic information on the registered device.
pub async fn whoami(&self) -> Result<WhoAmIResponse, Error<S::Error>> {
Ok(self.push_service()?.whoami().await?)
Ok(self.identified_push_service().whoami().await?)
}

/// Fetches the profile (name, about, status emoji) of the registered user.
Expand All @@ -383,7 +453,8 @@ impl<S: Store> Manager<S, Registered> {
return Ok(profile);
}

let mut account_manager = AccountManager::new(self.push_service()?, Some(profile_key));
let mut account_manager =
AccountManager::new(self.identified_push_service(), Some(profile_key));

let profile = account_manager.retrieve_profile(uuid.into()).await?;

Expand All @@ -404,23 +475,8 @@ impl<S: Store> Manager<S, Registered> {
&mut self,
) -> Result<impl Stream<Item = Result<Incoming, ServiceError>>, Error<S::Error>> {
let credentials = self.credentials().ok_or(Error::NotYetRegisteredError)?;
let allow_stories = false;
let pipe = MessageReceiver::new(self.push_service()?)
.create_message_pipe(credentials, allow_stories)
.await?;

let service_configuration = self.state.service_configuration();
let mut unidentified_push_service =
HyperPushService::new(service_configuration, None, crate::USER_AGENT.to_string());
let unidentified_ws = unidentified_push_service
.ws("/v1/websocket/", &[], None, false)
.await?;
self.state.identified_websocket.lock().replace(pipe.ws());
self.state
.unidentified_websocket
.lock()
.replace(unidentified_ws);

let ws = self.identified_websocket().await?;
let pipe = MessagePipe::from_socket(ws, credentials);
Ok(pipe.stream())
}

Expand Down Expand Up @@ -449,7 +505,7 @@ impl<S: Store> Manager<S, Registered> {
let groups_credentials_cache = InMemoryCredentialsCache::default();
let groups_manager = GroupsManager::new(
self.state.data.service_ids.clone(),
self.push_service()?,
self.identified_push_service(),
groups_credentials_cache,
server_public_params,
);
Expand All @@ -472,13 +528,15 @@ impl<S: Store> Manager<S, Registered> {

let init = StreamState {
encrypted_messages: Box::pin(self.receive_messages_encrypted().await?),
message_receiver: MessageReceiver::new(self.push_service()?),
message_receiver: MessageReceiver::new(self.identified_push_service()),
service_cipher: self.new_service_cipher()?,
store: self.store.clone(),
groups_manager: self.groups_manager()?,
mode,
};

debug!("starting to consume incoming message stream");

Ok(futures::stream::unfold(init, |mut state| async move {
loop {
match state.encrypted_messages.next().await {
Expand Down Expand Up @@ -763,7 +821,7 @@ impl<S: Store> Manager<S, Registered> {
&self,
attachment_pointer: &AttachmentPointer,
) -> Result<Vec<u8>, Error<S::Error>> {
let mut service = self.push_service()?;
let mut service = self.identified_push_service();
let mut attachment_stream = service.get_attachment(attachment_pointer).await?;

// We need the whole file for the crypto to check out
Expand Down Expand Up @@ -804,45 +862,19 @@ impl<S: Store> Manager<S, Registered> {
})
}

/// Returns a clone of a cached push service.
///
/// If no service is yet cached, it will create and cache one.
fn push_service(&self) -> Result<HyperPushService, Error<S::Error>> {
self.state.push_service_cache.get(|| {
Ok(HyperPushService::new(
self.state.service_configuration(),
self.credentials(),
crate::USER_AGENT.to_string(),
))
})
}

/// Creates a new message sender.
async fn new_message_sender(&self) -> Result<MessageSender<S>, Error<S::Error>> {
let local_addr = ServiceAddress {
uuid: self.state.data.service_ids.aci,
};

let identified_websocket = self
.state
.identified_websocket
.lock()
.clone()
.ok_or(Error::MessagePipeNotStarted)?;

let mut unidentified_push_service = HyperPushService::new(
self.state.service_configuration(),
None,
crate::USER_AGENT.to_string(),
);
let unidentified_websocket = unidentified_push_service
.ws("/v1/websocket/", &[], None, false)
.await?;
let identified_websocket = self.identified_websocket().await?;
let unidentified_websocket = self.unidentified_websocket().await?;

Ok(MessageSender::new(
identified_websocket,
unidentified_websocket,
self.push_service()?,
self.identified_push_service(),
self.new_service_cipher()?,
self.rng.clone(),
self.store.clone(),
Expand Down

0 comments on commit f291581

Please sign in to comment.