diff --git a/presage/src/cache.rs b/presage/src/cache.rs index e60c3c65b..89c485fed 100644 --- a/presage/src/cache.rs +++ b/presage/src/cache.rs @@ -23,13 +23,13 @@ impl Default for CacheCell { } impl CacheCell { - pub fn get(&self, factory: impl FnOnce() -> Result) -> Result { + pub fn get(&self, factory: impl FnOnce() -> T) -> T { let value = match self.cell.replace(None) { Some(value) => value, - None => factory()?, + None => factory(), }; self.cell.set(Some(value.clone())); - Ok(value) + value } } @@ -37,24 +37,13 @@ impl CacheCell { mod tests { use super::*; - use std::convert::Infallible; - #[test] fn test_cache_cell() { let cache: CacheCell = Default::default(); - let value = cache - .get(|| Ok::<_, Infallible>("Hello, World!".to_string())) - .unwrap(); - assert_eq!(value, "Hello, World!"); - let value = cache - .get(|| -> Result { panic!("I should not run") }) - .unwrap(); + let value = cache.get(|| ("Hello, World!".to_string())); assert_eq!(value, "Hello, World!"); - - let value = cache - .get(|| -> Result { panic!("I should not run") }) - .unwrap(); + let value = cache.get(|| panic!("I should not run")); assert_eq!(value, "Hello, World!"); } } diff --git a/presage/src/manager/confirmation.rs b/presage/src/manager/confirmation.rs index d5eb27848..2085da30d 100644 --- a/presage/src/manager/confirmation.rs +++ b/presage/src/manager/confirmation.rs @@ -12,7 +12,6 @@ use log::trace; use rand::rngs::StdRng; use rand::{RngCore, SeedableRng}; -use crate::cache::CacheCell; use crate::manager::registered::RegistrationData; use crate::store::Store; use crate::{Error, Manager}; diff --git a/presage/src/manager/registered.rs b/presage/src/manager/registered.rs index 460674ced..9f3695242 100644 --- a/presage/src/manager/registered.rs +++ b/presage/src/manager/registered.rs @@ -9,7 +9,7 @@ use libsignal_service::attachment_cipher::decrypt_in_place; use libsignal_service::configuration::{ServiceConfiguration, SignalServers, SignalingKey}; use libsignal_service::content::{Content, ContentBody, DataMessageFlags, Metadata}; use libsignal_service::groups_v2::{decrypt_group, Group, GroupsManager, InMemoryCredentialsCache}; -use libsignal_service::messagepipe::{Incoming, ServiceCredentials}; +use libsignal_service::messagepipe::{Incoming, MessagePipe, ServiceCredentials}; use libsignal_service::models::Contact; use libsignal_service::prelude::phonenumber::PhoneNumber; use libsignal_service::prelude::Uuid; @@ -53,7 +53,8 @@ type MessageSender = libsignal_service::prelude::MessageSender, + pub(crate) identified_push_service: CacheCell, + pub(crate) unidentified_push_service: CacheCell, pub(crate) identified_websocket: Arc>>, pub(crate) unidentified_websocket: Arc>>, pub(crate) unidentified_sender_certificate: Option, @@ -70,7 +71,8 @@ impl fmt::Debug for Registered { impl Registered { pub(crate) fn with_data(data: RegistrationData) -> Self { Self { - push_service_cache: CacheCell::default(), + identified_push_service: CacheCell::default(), + unidentified_push_service: CacheCell::default(), identified_websocket: Default::default(), unidentified_websocket: Default::default(), unidentified_sender_certificate: Default::default(), @@ -172,10 +174,76 @@ impl Manager { &self.state.data } + /// Returns a clone of a cached push service (with credentials). + /// + /// If no service is yet cached, it will create and cache one. + fn identified_push_service(&self) -> HyperPushService { + self.state.identified_push_service.get(|| { + HyperPushService::new( + self.state.service_configuration(), + self.credentials(), + crate::USER_AGENT.to_string(), + ) + }) + } + + /// Returns a clone of a cached push service (without credentials). + /// + /// If no service is yet cached, it will create and cache one. + fn unidentified_push_service(&self) -> HyperPushService { + self.state.unidentified_push_service.get(|| { + HyperPushService::new( + self.state.service_configuration(), + None, + crate::USER_AGENT.to_string(), + ) + }) + } + + /// Returns the current identified websocket, or creates a new one + async fn identified_websocket(&self) -> Result> { + let mut identified_ws = self.state.identified_websocket.lock(); + match identified_ws.clone() { + Some(ws) => Ok(ws), + None => { + let keep_alive = true; + let headers = &[("X-Signal-Receive-Stories", "false")]; + let ws = self + .identified_push_service() + .ws("/v1/websocket/", headers, self.credentials(), keep_alive) + .await?; + identified_ws.replace(ws.clone()); + debug!("initialized identified websocket"); + + Ok(ws) + } + } + } + + async fn unidentified_websocket(&self) -> Result> { + let mut unidentified_ws = self.state.unidentified_websocket.lock(); + match unidentified_ws.clone() { + Some(ws) => Ok(ws), + None => { + let keep_alive = true; + let ws = self + .unidentified_push_service() + .ws("/v1/websocket/", &[], None, keep_alive) + .await?; + unidentified_ws.replace(ws.clone()); + debug!("initialized unidentified websocket"); + + Ok(ws) + } + } + } + pub(crate) async fn register_pre_keys(&mut self) -> Result<(), Error> { trace!("registering pre keys"); - let mut account_manager = - AccountManager::new(self.push_service()?, Some(self.state.data.profile_key)); + let mut account_manager = AccountManager::new( + self.identified_push_service(), + Some(self.state.data.profile_key), + ); let (pre_keys_offset_id, next_signed_pre_key_id, next_pq_pre_key_id) = account_manager .update_pre_key_bundle( @@ -199,8 +267,10 @@ impl Manager { pub(crate) async fn set_account_attributes(&mut self) -> Result<(), Error> { trace!("setting account attributes"); - let mut account_manager = - AccountManager::new(self.push_service()?, Some(self.state.data.profile_key)); + let mut account_manager = AccountManager::new( + self.identified_push_service(), + Some(self.state.data.profile_key), + ); let pni_registration_id = if let Some(pni_registration_id) = self.state.data.pni_registration_id { @@ -333,7 +403,7 @@ impl Manager { if needs_renewal(self.state.unidentified_sender_certificate.as_ref()) { let sender_certificate = self - .push_service()? + .identified_push_service() .get_uuid_only_sender_certificate() .await?; @@ -354,7 +424,7 @@ impl Manager { token: &str, captcha: &str, ) -> Result<(), Error> { - let mut account_manager = AccountManager::new(self.push_service()?, None); + let mut account_manager = AccountManager::new(self.identified_push_service(), None); account_manager .submit_recaptcha_challenge(token, captcha) .await?; @@ -363,7 +433,7 @@ impl Manager { /// Fetches basic information on the registered device. pub async fn whoami(&self) -> Result> { - Ok(self.push_service()?.whoami().await?) + Ok(self.identified_push_service().whoami().await?) } /// Fetches the profile (name, about, status emoji) of the registered user. @@ -383,7 +453,8 @@ impl Manager { return Ok(profile); } - let mut account_manager = AccountManager::new(self.push_service()?, Some(profile_key)); + let mut account_manager = + AccountManager::new(self.identified_push_service(), Some(profile_key)); let profile = account_manager.retrieve_profile(uuid.into()).await?; @@ -404,23 +475,8 @@ impl Manager { &mut self, ) -> Result>, Error> { let credentials = self.credentials().ok_or(Error::NotYetRegisteredError)?; - let allow_stories = false; - let pipe = MessageReceiver::new(self.push_service()?) - .create_message_pipe(credentials, allow_stories) - .await?; - - let service_configuration = self.state.service_configuration(); - let mut unidentified_push_service = - HyperPushService::new(service_configuration, None, crate::USER_AGENT.to_string()); - let unidentified_ws = unidentified_push_service - .ws("/v1/websocket/", &[], None, false) - .await?; - self.state.identified_websocket.lock().replace(pipe.ws()); - self.state - .unidentified_websocket - .lock() - .replace(unidentified_ws); - + let ws = self.identified_websocket().await?; + let pipe = MessagePipe::from_socket(ws, credentials); Ok(pipe.stream()) } @@ -449,7 +505,7 @@ impl Manager { let groups_credentials_cache = InMemoryCredentialsCache::default(); let groups_manager = GroupsManager::new( self.state.data.service_ids.clone(), - self.push_service()?, + self.identified_push_service(), groups_credentials_cache, server_public_params, ); @@ -472,13 +528,15 @@ impl Manager { let init = StreamState { encrypted_messages: Box::pin(self.receive_messages_encrypted().await?), - message_receiver: MessageReceiver::new(self.push_service()?), + message_receiver: MessageReceiver::new(self.identified_push_service()), service_cipher: self.new_service_cipher()?, store: self.store.clone(), groups_manager: self.groups_manager()?, mode, }; + debug!("starting to consume incoming message stream"); + Ok(futures::stream::unfold(init, |mut state| async move { loop { match state.encrypted_messages.next().await { @@ -763,7 +821,7 @@ impl Manager { &self, attachment_pointer: &AttachmentPointer, ) -> Result, Error> { - let mut service = self.push_service()?; + let mut service = self.identified_push_service(); let mut attachment_stream = service.get_attachment(attachment_pointer).await?; // We need the whole file for the crypto to check out @@ -804,45 +862,19 @@ impl Manager { }) } - /// Returns a clone of a cached push service. - /// - /// If no service is yet cached, it will create and cache one. - fn push_service(&self) -> Result> { - self.state.push_service_cache.get(|| { - Ok(HyperPushService::new( - self.state.service_configuration(), - self.credentials(), - crate::USER_AGENT.to_string(), - )) - }) - } - /// Creates a new message sender. async fn new_message_sender(&self) -> Result, Error> { let local_addr = ServiceAddress { uuid: self.state.data.service_ids.aci, }; - let identified_websocket = self - .state - .identified_websocket - .lock() - .clone() - .ok_or(Error::MessagePipeNotStarted)?; - - let mut unidentified_push_service = HyperPushService::new( - self.state.service_configuration(), - None, - crate::USER_AGENT.to_string(), - ); - let unidentified_websocket = unidentified_push_service - .ws("/v1/websocket/", &[], None, false) - .await?; + let identified_websocket = self.identified_websocket().await?; + let unidentified_websocket = self.unidentified_websocket().await?; Ok(MessageSender::new( identified_websocket, unidentified_websocket, - self.push_service()?, + self.identified_push_service(), self.new_service_cipher()?, self.rng.clone(), self.store.clone(),