Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable keep-alive on unidentified websocket #175

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 4 additions & 24 deletions presage-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use core::fmt;
use std::convert::TryInto;
use std::path::Path;
use std::path::PathBuf;
use std::time::Duration;
use std::time::UNIX_EPOCH;

use anyhow::{anyhow, bail, Context as _};
Expand All @@ -12,6 +11,7 @@ use directories::ProjectDirs;
use env_logger::Env;
use futures::StreamExt;
use futures::{channel::oneshot, future, pin_mut};
use log::warn;
use log::{debug, error, info};
use notify_rust::Notification;
use presage::libsignal_service::content::Reaction;
Expand All @@ -30,8 +30,6 @@ use presage::{
use presage_store_sled::MigrationConflictStrategy;
use presage_store_sled::SledStore;
use tempfile::Builder;
use tokio::task;
use tokio::time::sleep;
use tokio::{
fs,
io::{self, AsyncBufReadExt, BufReader},
Expand Down Expand Up @@ -218,27 +216,9 @@ async fn send<C: Store + 'static>(
..Default::default()
});

let local = task::LocalSet::new();

local
.run_until(async move {
let mut receiving_manager = manager.clone();
task::spawn_local(async move {
if let Err(e) = receive(&mut receiving_manager, false).await {
error!("error while receiving stuff: {e}");
}
});

sleep(Duration::from_secs(5)).await;

manager
.send_message(*uuid, message, timestamp)
.await
.unwrap();

sleep(Duration::from_secs(5)).await;
})
.await;
if let Err(error) = manager.send_message(*uuid, message, timestamp).await {
warn!("possible failure when sending message: {error}");
}

Ok(())
}
Expand Down
142 changes: 82 additions & 60 deletions presage/src/manager.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::{
cell::RefCell,
fmt,
ops::RangeBounds,
sync::Arc,
time::{Duration, SystemTime, UNIX_EPOCH},
};

use futures::{channel::mpsc, channel::oneshot, future, pin_mut, AsyncReadExt, Stream, StreamExt};
use log::{debug, error, info, trace, warn};
use parking_lot::Mutex;
use rand::{
distributions::{Alphanumeric, DistString},
rngs::StdRng,
Expand All @@ -16,14 +15,13 @@ use rand::{
use serde::{Deserialize, Serialize};
use url::Url;

use libsignal_service::push_service::{RegistrationMethod, VerificationTransport};
use libsignal_service::{
attachment_cipher::decrypt_in_place,
cipher,
configuration::{ServiceConfiguration, SignalServers, SignalingKey},
content::{ContentBody, DataMessage, DataMessageFlags, Metadata, SyncMessage},
groups_v2::{decrypt_group, Group, GroupsManager, InMemoryCredentialsCache},
messagepipe::ServiceCredentials,
groups_v2::{Group, GroupsManager, InMemoryCredentialsCache},
messagepipe::{MessagePipe, ServiceCredentials},
models::Contact,
prelude::{phonenumber::PhoneNumber, Content, ProfileKey, PushService, Uuid},
proto::{
Expand All @@ -46,6 +44,10 @@ use libsignal_service::{
websocket::SignalWebSocket,
AccountManager, Profile, ServiceAddress,
};
use libsignal_service::{
groups_v2::decrypt_group,
push_service::{RegistrationMethod, VerificationTransport},
};
use libsignal_service_hyper::push_service::HyperPushService;

use crate::cache::CacheCell;
Expand Down Expand Up @@ -95,11 +97,13 @@ pub struct Confirmation {
#[derive(Clone, Serialize, Deserialize)]
pub struct Registered {
#[serde(skip)]
push_service_cache: CacheCell<HyperPushService>,
identified_push_service: CacheCell<HyperPushService>,
#[serde(skip)]
unidentified_push_service: CacheCell<HyperPushService>,
#[serde(skip)]
identified_websocket: Arc<Mutex<Option<SignalWebSocket>>>,
identified_websocket: RefCell<Option<SignalWebSocket>>,
#[serde(skip)]
unidentified_websocket: Arc<Mutex<Option<SignalWebSocket>>>,
unidentified_websocket: RefCell<Option<SignalWebSocket>>,
#[serde(skip)]
unidentified_sender_certificate: Option<SenderCertificate>,

Expand Down Expand Up @@ -130,7 +134,8 @@ pub struct Registered {
impl fmt::Debug for Registered {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Registered")
.field("websocket", &self.identified_websocket.lock().is_some())
.field("signal_servers", &self.signal_servers)
.field("phone_number", &self.phone_number)
.finish_non_exhaustive()
}
}
Expand Down Expand Up @@ -344,7 +349,8 @@ impl<C: Store> Manager<C, Linking> {
{
log::info!("successfully registered device {}", &service_ids);
Ok(Registered {
push_service_cache: CacheCell::default(),
identified_push_service: CacheCell::default(),
unidentified_push_service: CacheCell::default(),
identified_websocket: Default::default(),
unidentified_websocket: Default::default(),
unidentified_sender_certificate: Default::default(),
Expand Down Expand Up @@ -495,7 +501,8 @@ impl<C: Store> Manager<C, Confirmation> {
rng,
config_store: self.config_store,
state: Registered {
push_service_cache: CacheCell::default(),
identified_push_service: CacheCell::default(),
unidentified_push_service: CacheCell::default(),
identified_websocket: Default::default(),
unidentified_websocket: Default::default(),
unidentified_sender_certificate: Default::default(),
Expand Down Expand Up @@ -555,8 +562,10 @@ impl<C: Store> Manager<C, Registered> {

async fn register_pre_keys(&mut self) -> Result<(), Error<C::Error>> {
trace!("registering pre keys");
let mut account_manager =
AccountManager::new(self.push_service()?, Some(self.state.profile_key));
let mut account_manager = AccountManager::new(
self.identified_push_service()?,
Some(self.state.profile_key),
);

let (pre_keys_offset_id, next_signed_pre_key_id, next_pq_pre_key_id) = account_manager
.update_pre_key_bundle(
Expand All @@ -582,8 +591,10 @@ impl<C: Store> Manager<C, Registered> {

async fn set_account_attributes(&mut self) -> Result<(), Error<C::Error>> {
trace!("setting account attributes");
let mut account_manager =
AccountManager::new(self.push_service()?, Some(self.state.profile_key));
let mut account_manager = AccountManager::new(
self.identified_push_service()?,
Some(self.state.profile_key),
);

let pni_registration_id = if let Some(pni_registration_id) = self.state.pni_registration_id
{
Expand Down Expand Up @@ -633,7 +644,7 @@ impl<C: Store> Manager<C, Registered> {
&mut self,
mut messages: impl Stream<Item = Content> + Unpin,
) -> Result<(), Error<C::Error>> {
let mut message_receiver = MessageReceiver::new(self.push_service()?);
let mut message_receiver = MessageReceiver::new(self.identified_push_service()?);
while let Some(Content { body, .. }) = messages.next().await {
if let ContentBody::SynchronizeMessage(SyncMessage {
contacts: Some(contacts),
Expand Down Expand Up @@ -715,7 +726,7 @@ impl<C: Store> Manager<C, Registered> {

if needs_renewal(self.state.unidentified_sender_certificate.as_ref()) {
let sender_certificate = self
.push_service()?
.identified_push_service()?
.get_uuid_only_sender_certificate()
.await?;

Expand All @@ -736,7 +747,7 @@ impl<C: Store> Manager<C, Registered> {
token: &str,
captcha: &str,
) -> Result<(), Error<C::Error>> {
let mut account_manager = AccountManager::new(self.push_service()?, None);
let mut account_manager = AccountManager::new(self.identified_push_service()?, None);
account_manager
.submit_recaptcha_challenge(token, captcha)
.await?;
Expand All @@ -750,7 +761,7 @@ impl<C: Store> Manager<C, Registered> {

/// Fetches basic information on the registered device.
pub async fn whoami(&self) -> Result<WhoAmIResponse, Error<C::Error>> {
Ok(self.push_service()?.whoami().await?)
Ok(self.identified_push_service()?.whoami().await?)
}

/// Fetches the profile (name, about, status emoji) of the registered user.
Expand All @@ -770,7 +781,8 @@ impl<C: Store> Manager<C, Registered> {
return Ok(profile);
}

let mut account_manager = AccountManager::new(self.push_service()?, Some(profile_key));
let mut account_manager =
AccountManager::new(self.identified_push_service()?, Some(profile_key));

let profile = account_manager.retrieve_profile(uuid.into()).await?;

Expand Down Expand Up @@ -827,23 +839,8 @@ impl<C: Store> Manager<C, Registered> {
&mut self,
) -> Result<impl Stream<Item = Result<Envelope, ServiceError>>, Error<C::Error>> {
let credentials = self.credentials()?.ok_or(Error::NotYetRegisteredError)?;
let allow_stories = false;
let pipe = MessageReceiver::new(self.push_service()?)
.create_message_pipe(credentials, allow_stories)
.await?;

let service_configuration: ServiceConfiguration = self.state.signal_servers.into();
let mut unidentified_push_service =
HyperPushService::new(service_configuration, None, crate::USER_AGENT.to_string());
let unidentified_ws = unidentified_push_service
.ws("/v1/websocket/", &[], None, false)
.await?;
self.state.identified_websocket.lock().replace(pipe.ws());
self.state
.unidentified_websocket
.lock()
.replace(unidentified_ws);

let ws = self.identified_websocket().await?;
let pipe = MessagePipe::from_socket(ws, credentials);
Ok(pipe.stream())
}

Expand All @@ -865,7 +862,7 @@ impl<C: Store> Manager<C, Registered> {
let groups_credentials_cache = InMemoryCredentialsCache::default();
let groups_manager = GroupsManager::new(
self.state.service_ids.clone(),
self.push_service()?,
self.identified_push_service()?,
groups_credentials_cache,
server_public_params,
);
Expand Down Expand Up @@ -1161,7 +1158,7 @@ impl<C: Store> Manager<C, Registered> {
&self,
attachment_pointer: &AttachmentPointer,
) -> Result<Vec<u8>, Error<C::Error>> {
let mut service = self.push_service()?;
let mut service = self.identified_push_service()?;
let mut attachment_stream = service.get_attachment(attachment_pointer).await?;

// We need the whole file for the crypto to check out
Expand Down Expand Up @@ -1202,14 +1199,13 @@ impl<C: Store> Manager<C, Registered> {
}))
}

/// Returns a clone of a cached push service.
/// Return a clone of a cached push service.
///
/// If no service is yet cached, it will create and cache one.
fn push_service(&self) -> Result<HyperPushService, Error<C::Error>> {
self.state.push_service_cache.get(|| {
fn identified_push_service(&self) -> Result<HyperPushService, Error<C::Error>> {
self.state.identified_push_service.get(|| {
let credentials = self.credentials()?;
let service_configuration: ServiceConfiguration = self.state.signal_servers.into();

Ok(HyperPushService::new(
service_configuration,
credentials,
Expand All @@ -1218,30 +1214,56 @@ impl<C: Store> Manager<C, Registered> {
})
}

/// Return a clone of a cached _unidentified_ push service.
fn unidentified_push_service(&self) -> Result<HyperPushService, Error<C::Error>> {
self.state.unidentified_push_service.get(|| {
let service_configuration: ServiceConfiguration = self.state.signal_servers.into();
Ok(HyperPushService::new(
service_configuration,
None,
crate::USER_AGENT.to_string(),
))
})
}

async fn identified_websocket(&self) -> Result<SignalWebSocket, Error<C::Error>> {
let socket = self.state.identified_websocket.borrow().clone();
if let Some(ws) = socket {
return Ok(ws);
}

let ws = self
.identified_push_service()?
.ws("/v1/websocket/", &[], self.credentials()?, true)
.await?;
self.state.identified_websocket.replace(Some(ws.clone()));
Ok(ws)
}

async fn unidentified_websocket(&self) -> Result<SignalWebSocket, Error<C::Error>> {
let socket = self.state.unidentified_websocket.borrow().clone();
if let Some(ws) = socket {
Ok(ws)
} else {
let ws = self
.unidentified_push_service()?
.ws("/v1/websocket/", &[], None, true)
.await?;
self.state.unidentified_websocket.replace(Some(ws.clone()));
Ok(ws)
}
}

/// Creates a new message sender.
async fn new_message_sender(&self) -> Result<MessageSender<C>, Error<C::Error>> {
let local_addr = ServiceAddress {
uuid: self.state.service_ids.aci,
};

let identified_websocket = self
.state
.identified_websocket
.lock()
.clone()
.ok_or(Error::MessagePipeNotStarted)?;

let service_configuration: ServiceConfiguration = self.state.signal_servers.into();
let mut unidentified_push_service =
HyperPushService::new(service_configuration, None, crate::USER_AGENT.to_string());
let unidentified_websocket = unidentified_push_service
.ws("/v1/websocket/", &[], None, false)
.await?;

Ok(MessageSender::new(
identified_websocket,
unidentified_websocket,
self.push_service()?,
self.identified_websocket().await?,
self.unidentified_websocket().await?,
self.identified_push_service()?,
self.new_service_cipher()?,
self.rng.clone(),
self.config_store.clone(),
Expand Down