Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pkarr): async_compat in tokio runtim + cleanup relay inflight request #127

Merged
merged 8 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkarr/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pkarr"
version = "3.3.1"
version = "3.3.3"
authors = ["Nuh <[email protected]>"]
edition = "2021"
description = "Public-Key Addressable Resource Records (Pkarr); publish and resolve DNS records over Mainline DHT"
Expand Down
41 changes: 27 additions & 14 deletions pkarr/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod tests_web;

use futures_lite::{Stream, StreamExt};
use pubky_timestamp::Timestamp;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::{hash::Hash, num::NonZeroUsize};
Expand Down Expand Up @@ -225,12 +226,7 @@ impl Client {
signed_packet: &SignedPacket,
cas: Option<Timestamp>,
) -> Result<(), PublishError> {
#[cfg(not(wasm_browser))]
{
async_compat::Compat::new(self.publish_inner(signed_packet, cas)).await
}
#[cfg(wasm_browser)]
self.publish_inner(signed_packet, cas).await
async_compat_if_necessary(self.publish_inner(signed_packet, cas)).await
}

// === Resolve ===
Expand All @@ -242,12 +238,7 @@ impl Client {
/// If you want to get the most recent version of a [SignedPacket],
/// you should use [Self::resolve_most_recent].
pub async fn resolve(&self, public_key: &PublicKey) -> Option<SignedPacket> {
#[cfg(not(wasm_browser))]
{
async_compat::Compat::new(self.resolve_inner(public_key)).await
}
#[cfg(wasm_browser)]
self.resolve_inner(public_key).await
async_compat_if_necessary(self.resolve_inner(public_key)).await
}

/// Returns the most recent [SignedPacket] found after querying all
Expand Down Expand Up @@ -342,11 +333,19 @@ impl Client {

#[cfg(all(dht, relays))]
return if dht_future.is_some() && relays_future.is_some() {
futures_lite::future::or(
let result = futures_lite::future::or(
dht_future.expect("infallible"),
relays_future.expect("infallible"),
)
.await
.await;

self.0
.relays
.as_ref()
.expect("infallible")
.cancel_publish(&signed_packet.public_key());

result
} else if dht_future.is_some() {
dht_future.expect("infallible").await
} else {
Expand Down Expand Up @@ -662,3 +661,17 @@ impl From<PutMutableError> for PublishError {
}
}
}

async fn async_compat_if_necessary<T, O>(fut: T) -> O
where
T: Future<Output = O>,
{
#[cfg(not(wasm_browser))]
{
if tokio::runtime::Handle::try_current().is_err() {
return async_compat::Compat::new(fut).await;
}
}

fut.await
}
102 changes: 61 additions & 41 deletions pkarr/src/client/relays.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ impl RelaysClient {
.expect("infallible")
}

/// Cancel an inflight publish request.
pub fn cancel_publish(&self, public_key: &PublicKey) {
self.inflight_publish.cancel_request(public_key);
}

#[cfg(not(wasm_browser))]
pub fn resolve(
&self,
Expand Down Expand Up @@ -208,6 +213,15 @@ impl InflightPublishRequests {
Ok(())
}

pub fn cancel_request(&self, public_key: &PublicKey) {
let mut inflight = self
.requests
.write()
.expect("InflightPublishRequests write lock");

inflight.remove(public_key);
}

pub fn add_result(
&mut self,
public_key: &PublicKey,
Expand All @@ -219,28 +233,31 @@ impl InflightPublishRequests {
}
}

/// Returns true if request is done.
fn add_success(&self, public_key: &PublicKey) -> Result<bool, PublishError> {
let mut inflight = self
.requests
.write()
.expect("InflightPublishRequests write lock");

let request = inflight.get_mut(public_key).expect("infallible");
let majority = (self.relays_count / 2) + 1;
if let Some(request) = inflight.get_mut(public_key) {
let majority = (self.relays_count / 2) + 1;

request.success_count += 1;
request.success_count += 1;

if self.done(request) {
return Ok(true);
} else if request.success_count >= majority {
inflight.remove(public_key);
if self.done(request) || request.success_count >= majority {
inflight.remove(public_key);

return Ok(true);
Ok(true)
} else {
Ok(false)
}
} else {
Ok(true)
}

Ok(false)
}

/// Returns true if request is done.
fn add_error(
&mut self,
public_key: &PublicKey,
Expand All @@ -251,47 +268,50 @@ impl InflightPublishRequests {
.write()
.expect("InflightPublishRequests write lock");

let request = inflight.get_mut(public_key).expect("infallible");
let majority = (self.relays_count / 2) + 1;
if let Some(request) = inflight.get_mut(public_key) {
let majority = (self.relays_count / 2) + 1;

// Add error, and return early error if necessary.
{
let count = request.errors.get(&error).unwrap_or(&0) + 1;

if count >= majority
&& matches!(
error,
PublishError::Concurrency(ConcurrencyError::NotMostRecent)
) | matches!(
error,
PublishError::Concurrency(ConcurrencyError::CasFailed)
)
// Add error, and return early error if necessary.
{
inflight.remove(public_key);
let count = request.errors.get(&error).unwrap_or(&0) + 1;

if count >= majority
&& matches!(
error,
PublishError::Concurrency(ConcurrencyError::NotMostRecent)
) | matches!(
error,
PublishError::Concurrency(ConcurrencyError::CasFailed)
)
{
inflight.remove(public_key);

return Err(error);
}

return Err(error);
request.errors.insert(error, count);
}

request.errors.insert(error, count);
}
if self.done(request) {
let request = inflight.remove(public_key).expect("infallible");

if self.done(request) {
let request = inflight.remove(public_key).expect("infallible");
if request.success_count >= majority {
Ok(true)
} else {
let most_common_error = request
.errors
.into_iter()
.max_by_key(|&(_, count)| count)
.map(|(error, _)| error)
.expect("infallible");

if request.success_count >= majority {
Ok(true)
Err(most_common_error)
}
} else {
let most_common_error = request
.errors
.into_iter()
.max_by_key(|&(_, count)| count)
.map(|(error, _)| error)
.expect("infallible");

Err(most_common_error)
Ok(false)
}
} else {
Ok(false)
Ok(true)
}
}

Expand Down
29 changes: 29 additions & 0 deletions pkarr/src/client/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -564,3 +564,32 @@ async fn zero_cache_size(#[case] networks: Networks) {
let resolved = b.resolve(&keypair.public_key()).await.unwrap();
assert_eq!(resolved.as_bytes(), signed_packet.as_bytes());
}

#[rstest]
#[case::both_networks(Networks::Both)]
#[cfg_attr(feature = "relays", case::relays(Networks::Relays))]
#[tokio::test]
async fn clear_inflight_requests(#[case] networks: Networks) {
let testnet = mainline::Testnet::new(10).unwrap();
let relay = Relay::run_test(&testnet).await.unwrap();

let client = builder(&relay, &testnet, networks).build().unwrap();

let keypair = Keypair::random();

let signed_packet_builder =
SignedPacket::builder().txt("foo".try_into().unwrap(), "bar".try_into().unwrap(), 30);

client
.publish(&signed_packet_builder.clone().sign(&keypair).unwrap(), None)
.await
.unwrap();

tokio::time::sleep(Duration::from_millis(200)).await;

// If there was a memory leak, we would get a `ConflictRisk` error instead.
client
.publish(&signed_packet_builder.sign(&keypair).unwrap(), None)
.await
.unwrap();
}
2 changes: 1 addition & 1 deletion pkarr/src/extra/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl ServerCertVerifier for CertVerifier {
// This won't be necessary if Reqwest enabled us to create a rustls configuration
// per connection.
//
// TODO: update this Reqwest enabled this.
// TODO: update after Reqwest enables this.
let stream = self.0.resolve_https_endpoints(&qname);
pin!(stream);
for endpoint in block_on(stream) {
Expand Down
Loading
Loading