From d2049acb5569ad9d5ba244939bcce50d536002c7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= <g@leirbag.net>
Date: Thu, 26 Oct 2023 23:53:20 +0200
Subject: [PATCH] More refactoring to avoid small mistakes

---
 Cargo.toml             |   6 +-
 presage/src/manager.rs | 182 ++++++++++++++++++++---------------------
 2 files changed, 90 insertions(+), 98 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 19c49fa58..8e49fdad3 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -5,6 +5,6 @@ resolver = "2"
 [patch.crates-io]
 curve25519-dalek = { git = 'https://github.com/signalapp/curve25519-dalek', tag = 'signal-curve25519-4.0.0' }
 
-# [patch."https://github.com/whisperfish/libsignal-service-rs.git"]
-# libsignal-service = { path = "../libsignal-service-rs/libsignal-service" }
-# libsignal-service-hyper = { path = "../libsignal-service-rs/libsignal-service-hyper" }
+[patch."https://github.com/whisperfish/libsignal-service-rs.git"]
+libsignal-service = { path = "../libsignal-service-rs/libsignal-service" }
+libsignal-service-hyper = { path = "../libsignal-service-rs/libsignal-service-hyper" }
diff --git a/presage/src/manager.rs b/presage/src/manager.rs
index 8d19cbbc7..81e6c52a9 100644
--- a/presage/src/manager.rs
+++ b/presage/src/manager.rs
@@ -16,7 +16,6 @@ use rand::{
 use serde::{Deserialize, Serialize};
 use url::Url;
 
-use libsignal_service::push_service::{RegistrationMethod, VerificationTransport};
 use libsignal_service::{
     attachment_cipher::decrypt_in_place,
     cipher,
@@ -44,6 +43,10 @@ use libsignal_service::{
     AccountManager, Profile, ServiceAddress,
 };
 use libsignal_service::{messagepipe::Incoming, proto::EditMessage};
+use libsignal_service::{
+    messagepipe::MessagePipe,
+    push_service::{RegistrationMethod, VerificationTransport},
+};
 use libsignal_service_hyper::push_service::HyperPushService;
 
 use crate::cache::CacheCell;
@@ -391,9 +394,9 @@ impl<C: Store> Manager<C, Linking> {
         manager.config_store.save_state(&manager.state)?;
 
         match (
-            manager.register_pre_keys().await,
             manager.set_account_attributes().await,
-            manager.sync_contacts().await,
+            manager.register_pre_keys().await,
+            manager.request_initial_sync().await,
         ) {
             (Err(e), _, _) | (_, Err(e), _) => {
                 // clear the entire store on any error, there's no possible recovery here
@@ -558,13 +561,23 @@ impl<C: Store> Manager<C, Registered> {
             manager.set_account_attributes().await?;
         }
 
+        let credentials = manager.credentials()?;
+        manager.state.identified_websocket.lock().replace(
+            manager
+                .identified_push_service()?
+                .ws("/v1/websocket/", &[], Some(credentials), true)
+                .await?,
+        );
+
         Ok(manager)
     }
 
     async fn register_pre_keys(&mut self) -> Result<(), Error<C::Error>> {
         trace!("registering pre keys");
-        let mut account_manager =
-            AccountManager::new(self.push_service()?, Some(self.state.profile_key));
+        let mut account_manager = AccountManager::new(
+            self.identified_push_service()?,
+            Some(self.state.profile_key),
+        );
 
         let (pre_keys_offset_id, next_signed_pre_key_id, next_pq_pre_key_id) = account_manager
             .update_pre_key_bundle(
@@ -590,8 +603,10 @@ impl<C: Store> Manager<C, Registered> {
 
     async fn set_account_attributes(&mut self) -> Result<(), Error<C::Error>> {
         trace!("setting account attributes");
-        let mut account_manager =
-            AccountManager::new(self.push_service()?, Some(self.state.profile_key));
+        let mut account_manager = AccountManager::new(
+            self.identified_push_service()?,
+            Some(self.state.profile_key),
+        );
 
         let pni_registration_id = if let Some(pni_registration_id) = self.state.pni_registration_id
         {
@@ -637,14 +652,8 @@ impl<C: Store> Manager<C, Registered> {
         Ok(())
     }
 
-    async fn sync_contacts(&mut self) -> Result<(), Error<C::Error>> {
-        let messages = self
-            .receive_messages_stream(ReceivingMode::InitialSync)
-            .await?;
-        pin_mut!(messages);
-        while let Some(_msg) = messages.next().await {}
-
-        self.request_configuration_sync().await?;
+    async fn request_initial_sync(&mut self) -> Result<(), Error<C::Error>> {
+        self.request_keys_sync().await?;
         self.request_block_sync().await?;
         self.request_contacts_sync().await?;
 
@@ -666,44 +675,23 @@ impl<C: Store> Manager<C, Registered> {
     /// processed when they're received using the `MessageReceiver`.
     pub async fn request_contacts_sync(&mut self) -> Result<(), Error<C::Error>> {
         trace!("requesting contacts sync");
-        let sync_message = SyncMessage {
-            request: Some(sync_message::Request {
-                r#type: Some(sync_message::request::Type::Contacts as i32),
-            }),
-            ..Default::default()
-        };
-
-        self.send_message(self.state.service_ids.aci, sync_message)
+        self.send_message(self.state.service_ids.aci, SyncMessage::request_contacts())
             .await?;
 
         Ok(())
     }
 
-    async fn request_block_sync(&mut self) -> Result<(), Error<C::Error>> {
-        trace!("requesting blocked sync");
-        let sync_message = SyncMessage {
-            request: Some(sync_message::Request {
-                r#type: Some(sync_message::request::Type::Blocked as i32),
-            }),
-            ..Default::default()
-        };
-
-        self.send_message(self.state.service_ids.aci, sync_message)
+    async fn request_keys_sync(&mut self) -> Result<(), Error<C::Error>> {
+        trace!("requesting keys sync");
+        self.send_message(self.state.service_ids.aci, SyncMessage::request_keys())
             .await?;
 
         Ok(())
     }
 
-    async fn request_configuration_sync(&mut self) -> Result<(), Error<C::Error>> {
-        trace!("requesting configuration sync");
-        let sync_message = SyncMessage {
-            request: Some(sync_message::Request {
-                r#type: Some(sync_message::request::Type::Configuration as i32),
-            }),
-            ..Default::default()
-        };
-
-        self.send_message(self.state.service_ids.aci, sync_message)
+    async fn request_block_sync(&mut self) -> Result<(), Error<C::Error>> {
+        trace!("requesting blocked sync");
+        self.send_message(self.state.service_ids.aci, SyncMessage::request_blocked())
             .await?;
 
         Ok(())
@@ -729,7 +717,7 @@ impl<C: Store> Manager<C, Registered> {
 
         if needs_renewal(self.state.unidentified_sender_certificate.as_ref()) {
             let sender_certificate = self
-                .push_service()?
+                .identified_push_service()?
                 .get_uuid_only_sender_certificate()
                 .await?;
 
@@ -750,7 +738,7 @@ impl<C: Store> Manager<C, Registered> {
         token: &str,
         captcha: &str,
     ) -> Result<(), Error<C::Error>> {
-        let mut account_manager = AccountManager::new(self.push_service()?, None);
+        let mut account_manager = AccountManager::new(self.identified_push_service()?, None);
         account_manager
             .submit_recaptcha_challenge(token, captcha)
             .await?;
@@ -764,7 +752,7 @@ impl<C: Store> Manager<C, Registered> {
 
     /// Fetches basic information on the registered device.
     pub async fn whoami(&self) -> Result<WhoAmIResponse, Error<C::Error>> {
-        Ok(self.push_service()?.whoami().await?)
+        Ok(self.identified_push_service()?.whoami().await?)
     }
 
     /// Fetches the profile (name, about, status emoji) of the registered user.
@@ -784,7 +772,8 @@ impl<C: Store> Manager<C, Registered> {
             return Ok(profile);
         }
 
-        let mut account_manager = AccountManager::new(self.push_service()?, Some(profile_key));
+        let mut account_manager =
+            AccountManager::new(self.identified_push_service()?, Some(profile_key));
 
         let profile = account_manager.retrieve_profile(uuid.into()).await?;
 
@@ -840,30 +829,14 @@ impl<C: Store> Manager<C, Registered> {
     async fn receive_messages_encrypted(
         &mut self,
     ) -> Result<impl Stream<Item = Result<Incoming, ServiceError>>, Error<C::Error>> {
-        let credentials = self.credentials()?.ok_or(Error::NotYetRegisteredError)?;
-        let allow_stories = false;
-        let pipe = MessageReceiver::new(self.push_service()?)
-            .create_message_pipe(credentials, allow_stories)
-            .await?;
-
-        let service_configuration: ServiceConfiguration = self.state.signal_servers.into();
-        let mut unidentified_push_service =
-            HyperPushService::new(service_configuration, None, crate::USER_AGENT.to_string());
-        let unidentified_ws = unidentified_push_service
-            .ws("/v1/websocket/", &[], None, false)
-            .await?;
-        self.state.identified_websocket.lock().replace(pipe.ws());
-        self.state
-            .unidentified_websocket
-            .lock()
-            .replace(unidentified_ws);
-
-        Ok(pipe.stream())
+        let credentials: ServiceCredentials = self.credentials()?;
+        let identified_ws = self.identified_websocket().await?;
+        Ok(MessagePipe::from_socket(identified_ws, credentials).stream())
     }
 
     /// Starts receiving and storing messages.
     ///
-    /// * `stop_on_initial_sync` [unstable API] - receive messages until the initial sync is over, or forever.
+    /// * `stop_on_initial_sync` [unstable API] - receive messages until the initial sync is over, or forever
     ///    It is essential to synchronize the client once before you try to send, to make sure you have all the updated keys and sessions.
     ///
     /// Returns a [futures::Stream] of messages to consume. Messages will also be stored by the implementation of the [Store].
@@ -883,7 +856,7 @@ impl<C: Store> Manager<C, Registered> {
         let groups_credentials_cache = InMemoryCredentialsCache::default();
         let groups_manager = GroupsManager::new(
             self.state.service_ids.clone(),
-            self.push_service()?,
+            self.identified_push_service()?,
             groups_credentials_cache,
             server_public_params,
         );
@@ -906,7 +879,7 @@ impl<C: Store> Manager<C, Registered> {
 
         let init = StreamState {
             encrypted_messages: Box::pin(self.receive_messages_encrypted().await?),
-            message_receiver: MessageReceiver::new(self.push_service()?),
+            message_receiver: MessageReceiver::new(self.identified_push_service()?),
             service_cipher: self.new_service_cipher()?,
             config_store: self.config_store.clone(),
             groups_manager: self.groups_manager()?,
@@ -1100,7 +1073,7 @@ impl<C: Store> Manager<C, Registered> {
 
     /// Uploads attachments prior to linking them in a message.
     pub async fn upload_attachments(
-        &self,
+        &mut self,
         attachments: Vec<(AttachmentSpec, Vec<u8>)>,
     ) -> Result<Vec<Result<AttachmentPointer, AttachmentUploadError>>, Error<C::Error>> {
         if attachments.is_empty() {
@@ -1210,7 +1183,7 @@ impl<C: Store> Manager<C, Registered> {
         &self,
         attachment_pointer: &AttachmentPointer,
     ) -> Result<Vec<u8>, Error<C::Error>> {
-        let mut service = self.push_service()?;
+        let mut service = self.identified_push_service()?;
         let mut attachment_stream = service.get_attachment(attachment_pointer).await?;
 
         // We need the whole file for the crypto to check out
@@ -1240,56 +1213,75 @@ impl<C: Store> Manager<C, Registered> {
         Ok(())
     }
 
-    fn credentials(&self) -> Result<Option<ServiceCredentials>, Error<C::Error>> {
-        Ok(Some(ServiceCredentials {
+    fn credentials(&self) -> Result<ServiceCredentials, Error<C::Error>> {
+        Ok(ServiceCredentials {
             uuid: Some(self.state.service_ids.aci),
             phonenumber: self.state.phone_number.clone(),
             password: Some(self.state.password.clone()),
             signaling_key: Some(self.state.signaling_key),
             device_id: self.state.device_id,
-        }))
+        })
     }
 
     /// Returns a clone of a cached push service.
     ///
     /// If no service is yet cached, it will create and cache one.
-    fn push_service(&self) -> Result<HyperPushService, Error<C::Error>> {
+    fn identified_push_service(&self) -> Result<HyperPushService, Error<C::Error>> {
         self.state.push_service_cache.get(|| {
             let credentials = self.credentials()?;
             let service_configuration: ServiceConfiguration = self.state.signal_servers.into();
-
             Ok(HyperPushService::new(
                 service_configuration,
-                credentials,
+                Some(credentials),
                 crate::USER_AGENT.to_string(),
             ))
         })
     }
 
+    fn unidentified_push_service(&self) -> HyperPushService {
+        let service_configuration: ServiceConfiguration = self.state.signal_servers.into();
+        HyperPushService::new(service_configuration, None, crate::USER_AGENT.to_string())
+    }
+
+    async fn identified_websocket(&mut self) -> Result<SignalWebSocket, Error<C::Error>> {
+        let mut lock = self.state.identified_websocket.lock();
+        if let Some(identified_ws) = lock.as_ref() {
+            Ok(identified_ws.clone())
+        } else {
+            let credentials = self.credentials()?;
+            let ws = self
+                .identified_push_service()?
+                .ws("/v1/websocket/", &[], Some(credentials), true)
+                .await?;
+            lock.replace(ws.clone());
+            Ok(ws)
+        }
+    }
+
+    async fn unidentified_websocket(&mut self) -> Result<SignalWebSocket, Error<C::Error>> {
+        let mut lock = self.state.unidentified_websocket.lock();
+        if let Some(unidentified_ws) = lock.as_ref() {
+            Ok(unidentified_ws.clone())
+        } else {
+            let ws = self
+                .unidentified_push_service()
+                .ws("/v1/websocket/", &[], None, true)
+                .await?;
+            lock.replace(ws.clone());
+            Ok(ws)
+        }
+    }
+
     /// Creates a new message sender.
-    async fn new_message_sender(&self) -> Result<MessageSender<C>, Error<C::Error>> {
+    async fn new_message_sender(&mut self) -> Result<MessageSender<C>, Error<C::Error>> {
         let local_addr = ServiceAddress {
             uuid: self.state.service_ids.aci,
         };
 
-        let identified_websocket = self
-            .state
-            .identified_websocket
-            .lock()
-            .clone()
-            .ok_or(Error::MessagePipeNotStarted)?;
-
-        let service_configuration: ServiceConfiguration = self.state.signal_servers.into();
-        let mut unidentified_push_service =
-            HyperPushService::new(service_configuration, None, crate::USER_AGENT.to_string());
-        let unidentified_websocket = unidentified_push_service
-            .ws("/v1/websocket/", &[], None, false)
-            .await?;
-
         Ok(MessageSender::new(
-            identified_websocket,
-            unidentified_websocket,
-            self.push_service()?,
+            self.identified_websocket().await?,
+            self.unidentified_websocket().await?,
+            self.identified_push_service()?,
             self.new_service_cipher()?,
             self.rng.clone(),
             self.config_store.clone(),