Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement recursion #1603

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion examples/chain/chain.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use async_trait::async_trait;
use lychee_lib::{chain::RequestChain, ChainResult, ClientBuilder, Handler, Result, Status};
use lychee_lib::{chain::RequestChain, ChainResult, ClientBuilder, Handler, Result, Status, Uri};
use reqwest::{Method, Request};

#[derive(Debug)]
Expand All @@ -16,6 +16,10 @@ impl Handler<Request, Status> for MyHandler {

ChainResult::Next(request)
}

fn subsequent_uris(&self) -> Vec<Uri> {
vec![]
}
}

#[tokio::main]
Expand Down
7 changes: 7 additions & 0 deletions lychee-bin/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ pub(crate) fn create(cfg: &Config, cookie_jar: Option<&Arc<CookieStoreMutex>>) -
.cookie_jar(cookie_jar.cloned())
.include_fragments(cfg.include_fragments)
.fallback_extensions(cfg.fallback_extensions.clone())
.recursive_domains(match (cfg.recursive, cfg.recursed_domains.clone()) {
(true, domains) if domains.is_empty() => {
todo!("please specify --recursed-domains for now")
}
(true, domains) => domains,
_ => vec![],
})
.build()
.client()
.context("Failed to create request client")
Expand Down
148 changes: 129 additions & 19 deletions lychee-bin/src/commands/check.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
use std::collections::HashSet;
use std::io::{self, Write};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;

use futures::StreamExt;
use futures::{task, StreamExt};
use indicatif::ProgressBar;
use indicatif::ProgressStyle;
use reqwest::Url;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

use lychee_lib::{Client, ErrorKind, Request, Response, Uri};
use lychee_lib::{BasicAuthCredentials, Client, ErrorKind, Request, Response, Uri};
use lychee_lib::{InputSource, Result};
use lychee_lib::{ResponseBody, Status};

Expand All @@ -30,9 +31,10 @@ where
S: futures::Stream<Item = Result<Request>>,
{
// Setup
let (send_req, recv_req) = mpsc::channel(params.cfg.max_concurrency);
let (send_resp, recv_resp) = mpsc::channel(params.cfg.max_concurrency);
let max_concurrency = params.cfg.max_concurrency;
let (send_req, recv_req) = mpsc::channel(max_concurrency);
let (send_resp, recv_resp) = mpsc::channel(max_concurrency);
let remaining_requests = Arc::new(AtomicUsize::new(0));

// Measure check time
let start = std::time::Instant::now();
Expand All @@ -56,28 +58,34 @@ where
};

// Start receiving requests
tokio::spawn(request_channel_task(
tokio::spawn(request_to_response(
recv_req,
send_resp,
max_concurrency,
client,
cache,
cache.clone(),
cache_exclude_status,
accept,
params.cfg.recursive,
));

let formatter = get_response_formatter(&params.cfg.mode);

let show_results_task = tokio::spawn(progress_bar_task(
let show_results_task = tokio::spawn(receive_responses(
recv_resp,
send_req.clone(),
remaining_requests.clone(),
params.cfg.max_depth,
cache,
params.cfg.recursive,
params.cfg.verbose,
pb.clone(),
formatter,
stats,
));

// Wait until all messages are sent
send_inputs_loop(params.requests, send_req, pb).await?;
// Fill the request channel with the initial requests
send_requests(params.requests, send_req, pb, remaining_requests).await?;

// Wait until all responses are received
let result = show_results_task.await?;
Expand Down Expand Up @@ -158,47 +166,117 @@ async fn suggest_archived_links(
// drops the `send_req` channel on exit
// required for the receiver task to end, which closes send_resp, which allows
// the show_results_task to finish
async fn send_inputs_loop<S>(
async fn send_requests<S>(
requests: S,
send_req: mpsc::Sender<Result<Request>>,
bar: Option<ProgressBar>,
remaining_requests: Arc<AtomicUsize>,
) -> Result<()>
where
S: futures::Stream<Item = Result<Request>>,
{
tokio::pin!(requests);
println!("--- INITIAL REQUESTS ---");
let mut i = 0;
while let Some(request) = requests.next().await {
// println!("#{} starting request", i);
i += 1;
let request = request?;

if let Some(pb) = &bar {
pb.inc_length(1);
pb.set_message(request.to_string());
};
remaining_requests.fetch_add(1, Ordering::Relaxed);
let uri = request.uri.clone();
// println!("sending request to queue for {}", uri);
send_req
.send(Ok(request))
.await
.expect("Cannot send request");
// println!("sent request to queue for {}", uri);
}
println!("--- END OF INITIAL REQUESTS ---");
Ok(())
}

/// Reads from the request channel and updates the progress bar status
async fn progress_bar_task(
/// Reads from the response channel, updates the progress bar status and (if recursing) sends new requests.
Copy link

@nobkd nobkd Jan 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: recursing -> recurring / recursive

async fn receive_responses(
mut recv_resp: mpsc::Receiver<Response>,
req_send: mpsc::Sender<Result<Request>>,
remaining_requests: Arc<AtomicUsize>,
max_recursion_depth: Option<usize>,
cache: Arc<Cache>,
recurse: bool,
verbose: Verbosity,
pb: Option<ProgressBar>,
formatter: Box<dyn ResponseFormatter>,
mut stats: ResponseStats,
) -> Result<(Option<ProgressBar>, ResponseStats)> {
let mut i = 0;
while let Some(response) = recv_resp.recv().await {
// println!(
// "starting response #{} out of {}",
// i,
// remaining_requests.load(Ordering::Relaxed),
// );
// println!("#{} received response from queue for {}", i, response.1.uri);
i += 1;
show_progress(
&mut io::stderr(),
pb.as_ref(),
&response,
formatter.as_ref(),
&verbose,
)?;
stats.add(response);

if recurse
&& max_recursion_depth
.map(|limit| response.1.recursion_level < limit)
.unwrap_or(true)
{
println!(
"recursing: {} has depth {} < {}",
response.1.uri,
response.1.recursion_level,
max_recursion_depth.unwrap()
);
tokio::spawn((|requests: Vec<Request>,
req_send: mpsc::Sender<Result<Request>>,
remaining_requests: Arc<AtomicUsize>,
pb: Option<ProgressBar>| async move {
for request in requests {
let uri = request.uri.clone().to_string();
req_send
.send(Ok(request))
.await
.expect("Cannot send request");
remaining_requests.fetch_add(1, Ordering::Relaxed);

if let Some(bar) = &pb {
bar.inc_length(1);
bar.set_message(uri);
}
}
})(
response.subsequent_requests(|uri| cache.contains_key(uri), None),
req_send.clone(),
remaining_requests.clone(),
pb.clone(),
));
}

remaining_requests.fetch_sub(1, Ordering::Relaxed);
let remaining_now = remaining_requests.load(Ordering::Relaxed);
// println!("remaining requests: {}", remaining_now);
if remaining_now == 0 {
break;
}

stats.insert(response);
// println!("finished response #{}", i);
}
// println!("Processed {} responses", i);
Ok((pb, stats))
}

Expand All @@ -215,20 +293,29 @@ fn init_progress_bar(initial_message: &'static str) -> ProgressBar {
bar
}

async fn request_channel_task(
async fn request_to_response(
recv_req: mpsc::Receiver<Result<Request>>,
send_resp: mpsc::Sender<Response>,
max_concurrency: usize,
client: Client,
cache: Arc<Cache>,
cache_exclude_status: HashSet<u16>,
accept: HashSet<u16>,
recursive: bool,
) {
// while let Some(request) = recv_req.recv().await {
StreamExt::for_each_concurrent(
ReceiverStream::new(recv_req),
max_concurrency,
|request: Result<Request>| async {
let request = request.expect("cannot read request");
if recursive && cache.contains_key(&request.uri) {
return;
}
// let uri = request.uri.clone();
// println!("handling request {}", uri);
// let uri = request.uri.clone();
// println!("received request for {}", uri);
let response = handle(
&client,
cache.clone(),
Expand All @@ -238,10 +325,18 @@ async fn request_channel_task(
)
.await;

// println!("sending response to queue for {}", uri);
send_resp
.send(response)
.await
.expect("cannot send response to queue");
.expect("Cannot send response");
// if let Err(_) = timeout(Duration::from_millis(500), send_resp.send(response)).await {
// println!(
// "Timeout occurred while sending response to queue for {}",
// uri
// );
// }
// println!("sent response to queue for {}", uri);
},
)
.await;
Expand All @@ -256,12 +351,15 @@ async fn check_url(client: &Client, request: Request) -> Response {
// Request was not cached; run a normal check
let uri = request.uri.clone();
let source = request.source.clone();
let depth = request.recursion_level;
client.check(request).await.unwrap_or_else(|e| {
log::error!("Error checking URL {}: Cannot parse URL to URI: {}", uri, e);
Response::new(
uri.clone(),
Status::Error(ErrorKind::InvalidURI(uri.clone())),
source,
vec![],
depth,
)
})
}
Expand All @@ -288,7 +386,15 @@ async fn handle(
// code.
Status::from_cache_status(v.value().status, &accept)
};
return Response::new(uri.clone(), status, request.source);
// TODO: not too sure about it, we never recurse on cached requests
// println!("Found cached response for {}", uri);
return Response::new(
uri.clone(),
status,
request.source,
vec![],
request.recursion_level,
);
}

// Request was not cached; run a normal check
Expand Down Expand Up @@ -366,7 +472,7 @@ fn get_failed_urls(stats: &mut ResponseStats) -> Vec<(InputSource, Url)> {
.iter()
.flat_map(|(source, set)| {
set.iter()
.map(move |ResponseBody { uri, status: _ }| (source, uri))
.map(move |ResponseBody { uri, .. }| (source, uri))
})
.filter_map(|(source, uri)| {
if uri.is_data() || uri.is_mail() || uri.is_file() {
Expand Down Expand Up @@ -397,6 +503,8 @@ mod tests {
Uri::try_from("http://127.0.0.1").unwrap(),
Status::Cached(CacheStatus::Ok(200)),
InputSource::Stdin,
vec![],
0,
);
let formatter = get_response_formatter(&options::OutputMode::Plain);
show_progress(
Expand All @@ -419,6 +527,8 @@ mod tests {
Uri::try_from("http://127.0.0.1").unwrap(),
Status::Cached(CacheStatus::Ok(200)),
InputSource::Stdin,
vec![],
0,
);
let formatter = get_response_formatter(&options::OutputMode::Plain);
show_progress(
Expand All @@ -439,9 +549,9 @@ mod tests {
async fn test_invalid_url() {
let client = ClientBuilder::builder().build().client().unwrap();
let uri = Uri::try_from("http://\"").unwrap();
let response = client.check_website(&uri, None).await.unwrap();
let (status, _) = client.check_website(&uri, None).await.unwrap();
assert!(matches!(
response,
status,
Status::Unsupported(ErrorKind::BuildRequestClient(_))
));
}
Expand Down
2 changes: 2 additions & 0 deletions lychee-bin/src/formatters/response/color.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ mod tests {
ResponseBody {
uri: Uri::try_from(uri).unwrap(),
status,
subsequent_uris: vec![],
recursion_level: 0,
}
}

Expand Down
2 changes: 2 additions & 0 deletions lychee-bin/src/formatters/response/emoji.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ mod emoji_tests {
ResponseBody {
uri: Uri::try_from(uri).unwrap(),
status,
subsequent_uris: vec![],
recursion_level: 0,
}
}

Expand Down
2 changes: 2 additions & 0 deletions lychee-bin/src/formatters/response/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ mod plain_tests {
ResponseBody {
uri: Uri::try_from(uri).unwrap(),
status,
subsequent_uris: vec![],
recursion_level: 0,
}
}

Expand Down
Loading
Loading