Skip to content

Commit

Permalink
Implement --max-depth and fix starvation issue!
Browse files Browse the repository at this point in the history
  • Loading branch information
gwennlbh committed Jan 7, 2025
1 parent 82782cb commit a88f9fa
Show file tree
Hide file tree
Showing 13 changed files with 181 additions and 68 deletions.
167 changes: 106 additions & 61 deletions lychee-bin/src/commands/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;

use futures::StreamExt;
use futures::{task, StreamExt};
use indicatif::ProgressBar;
use indicatif::ProgressStyle;
use reqwest::Url;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

use lychee_lib::{Client, ErrorKind, Request, Response, Uri};
use lychee_lib::{BasicAuthCredentials, Client, ErrorKind, Request, Response, Uri};
use lychee_lib::{InputSource, Result};
use lychee_lib::{ResponseBody, Status};

Expand All @@ -34,6 +34,7 @@ where
let max_concurrency = params.cfg.max_concurrency;
let (send_req, recv_req) = mpsc::channel(max_concurrency);
let (send_resp, recv_resp) = mpsc::channel(max_concurrency);
let remaining_requests = Arc::new(AtomicUsize::new(0));

// Measure check time
let start = std::time::Instant::now();
Expand All @@ -56,16 +57,6 @@ where
Some(init_progress_bar("Extracting links"))
};

// Fill the request channel with the initial requests
let remaining_requests = Arc::new(AtomicUsize::new(0));
send_inputs_loop(
params.requests,
send_req.clone(),
pb.clone(),
remaining_requests.clone(),
)
.await?;

// Start receiving requests
tokio::spawn(request_channel_task(
recv_req,
Expand All @@ -81,14 +72,25 @@ where

let show_results_task = tokio::spawn(response_receive_task(
recv_resp,
send_req,
remaining_requests,
send_req.clone(),
remaining_requests.clone(),
params.cfg.max_depth,
params.cfg.verbose,
pb,
pb.clone(),
formatter,
stats,
));

// Fill the request channel with the initial requests
send_inputs_loop(
params.requests,
send_req,
pb,
remaining_requests,
params.cfg.max_depth,
)
.await?;

// Wait until all responses are received
let result = show_results_task.await?;
let (pb, mut stats) = result?;
Expand Down Expand Up @@ -173,22 +175,38 @@ async fn send_inputs_loop<S>(
send_req: mpsc::Sender<Result<Request>>,
bar: Option<ProgressBar>,
remaining_requests: Arc<AtomicUsize>,
max_recursion_depth: Option<usize>,
) -> Result<()>
where
S: futures::Stream<Item = Result<Request>>,
{
tokio::pin!(requests);
println!("--- INITIAL REQUESTS ---");
let mut i = 0;
while let Some(request) = requests.next().await {
// println!("#{} starting request", i);
i += 1;
let request = request?;

if max_recursion_depth
.map(|limit| request.recursion_level > limit)
.unwrap_or(false)
{
continue;
}

if let Some(pb) = &bar {
pb.inc_length(1);
pb.set_message(request.to_string());
};
remaining_requests.fetch_add(1, Ordering::Relaxed);
let uri = request.uri.clone();
// println!("sending request to queue for {}", uri);
send_req
.send(Ok(request))
.await
.expect("Cannot send request");
// println!("sent request to queue for {}", uri);
}
Ok(())
}
Expand All @@ -198,19 +216,22 @@ async fn response_receive_task(
mut recv_resp: mpsc::Receiver<Response>,
req_send: mpsc::Sender<Result<Request>>,
remaining_requests: Arc<AtomicUsize>,
max_recursion_depth: Option<usize>,
verbose: Verbosity,
pb: Option<ProgressBar>,
formatter: Box<dyn ResponseFormatter>,
mut stats: ResponseStats,
) -> Result<(Option<ProgressBar>, ResponseStats)> {
// let mut i = 0;
let mut i = 0;
while let Some(response) = recv_resp.recv().await {
// i += 1;
// println!(
// "starting response #{} out of {}",
// i,
// remaining_requests.load(Ordering::Relaxed),
// );
// println!("#{} received response from queue for {}", i, response.1.uri);
println!("{:?}", max_recursion_depth);
i += 1;
show_progress(
&mut io::stderr(),
pb.as_ref(),
Expand All @@ -219,19 +240,37 @@ async fn response_receive_task(
&verbose,
)?;

for uri in &response.body().subsequent_uris {
let request = Request::try_from(uri.clone())?;
req_send
.send(Ok(request))
.await
.expect("Cannot send request");
remaining_requests.fetch_add(1, Ordering::Relaxed);

if let Some(bar) = &pb {
bar.inc_length(1);
bar.set_message(uri.clone().to_string());
}
println!("rec={:?}", response.1.recursion_level);

if max_recursion_depth
.map(|limit| response.1.recursion_level <= limit)
.unwrap_or(true)
{
tokio::spawn((|requests: Vec<Request>,
req_send: mpsc::Sender<Result<Request>>,
remaining_requests: Arc<AtomicUsize>,
pb: Option<ProgressBar>| async move {
for request in requests {
let uri = request.uri.clone().to_string();
req_send
.send(Ok(request))
.await
.expect("Cannot send request");
remaining_requests.fetch_add(1, Ordering::Relaxed);

if let Some(bar) = &pb {
bar.inc_length(1);
bar.set_message(uri);
}
}
})(
response.subsequent_requests(None),
req_send.clone(),
remaining_requests.clone(),
pb.clone(),
));
}

remaining_requests.fetch_sub(1, Ordering::Relaxed);
let remaining_now = remaining_requests.load(Ordering::Relaxed);
// println!("remaining requests: {}", remaining_now);
Expand Down Expand Up @@ -260,44 +299,47 @@ fn init_progress_bar(initial_message: &'static str) -> ProgressBar {
}

/// Drains `recv_req`, runs every request through `handle`, and forwards each
/// resulting `Response` on `send_resp`.
///
/// NOTE(review): this listing is a rendered commit diff without +/- markers, so
/// the two `recv_req` parameter lines and the two loop bodies below appear to be
/// the removed and added versions shown back to back — confirm against the
/// actual file before relying on this text.
///
/// NOTE(review): the `while let` version awaits each `handle` call before
/// receiving the next request and never reads `max_concurrency`, so requests
/// look to be processed one at a time there — confirm this is intended.
///
/// # Panics
///
/// Panics if a received request is an `Err` ("cannot read request") or if
/// `send_resp.send` returns an error ("Cannot send response").
async fn request_channel_task(
recv_req: mpsc::Receiver<Result<Request>>,
mut recv_req: mpsc::Receiver<Result<Request>>,
send_resp: mpsc::Sender<Response>,
max_concurrency: usize,
client: Client,
cache: Arc<Cache>,
cache_exclude_status: HashSet<u16>,
accept: HashSet<u16>,
) {
// Stream-based version: wraps the receiver in a `ReceiverStream` and passes
// `max_concurrency` to `for_each_concurrent` as its limit.
StreamExt::for_each_concurrent(
ReceiverStream::new(recv_req),
max_concurrency,
|request: Result<Request>| async {
let request = request.expect("cannot read request");
// let uri = request.uri.clone();
// println!("received request for {}", uri);
let response = handle(
&client,
cache.clone(),
cache_exclude_status.clone(),
request,
accept.clone(),
)
.await;

send_resp
.send(response)
.await
.expect("Cannot send response");
// if let Err(_) = timeout(Duration::from_millis(500), send_resp.send(response)).await {
// println!(
// "Timeout occurred while sending response to queue for {}",
// uri
// );
// }
// println!("sent response to queue for {}", uri);
},
)
.await;
// Loop-based version: takes one request from the channel per iteration and
// finishes handling and sending it before calling `recv` again.
while let Some(request) = recv_req.recv().await {
// StreamExt::for_each_concurrent(
// ReceiverStream::new(recv_req),
// max_concurrency,
// |request: Result<Request>| async {

let request = request.expect("cannot read request");
// let uri = request.uri.clone();
// println!("handling request {}", uri);
// let uri = request.uri.clone();
// println!("received request for {}", uri);
let response = handle(
&client,
cache.clone(),
cache_exclude_status.clone(),
request,
accept.clone(),
)
.await;

// println!("sending response to queue for {}", uri);
send_resp
.send(response)
.await
.expect("Cannot send response");
// if let Err(_) = timeout(Duration::from_millis(500), send_resp.send(response)).await {
// println!(
// "Timeout occurred while sending response to queue for {}",
// uri
// );
// }
// println!("sent response to queue for {}", uri);
}
}

/// Check a URL and return a response.
Expand All @@ -316,6 +358,7 @@ async fn check_url(client: &Client, request: Request) -> Response {
Status::Error(ErrorKind::InvalidURI(uri.clone())),
source,
vec![],
0,
)
})
}
Expand Down Expand Up @@ -344,7 +387,7 @@ async fn handle(
};
// TODO: not sure this is right: a cache hit returns early with an empty `vec![]`, so we never recurse on cached requests
// println!("Found cached response for {}", uri);
return Response::new(uri.clone(), status, request.source, vec![]);
return Response::new(uri.clone(), status, request.source, vec![], 0);
}

// Request was not cached; run a normal check
Expand Down Expand Up @@ -454,6 +497,7 @@ mod tests {
Status::Cached(CacheStatus::Ok(200)),
InputSource::Stdin,
vec![],
0,
);
let formatter = get_response_formatter(&options::OutputMode::Plain);
show_progress(
Expand All @@ -477,6 +521,7 @@ mod tests {
Status::Cached(CacheStatus::Ok(200)),
InputSource::Stdin,
vec![],
0,
);
let formatter = get_response_formatter(&options::OutputMode::Plain);
show_progress(
Expand Down
1 change: 1 addition & 0 deletions lychee-bin/src/formatters/response/color.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ mod tests {
uri: Uri::try_from(uri).unwrap(),
status,
subsequent_uris: vec![],
recursion_level: 0,
}
}

Expand Down
1 change: 1 addition & 0 deletions lychee-bin/src/formatters/response/emoji.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ mod emoji_tests {
uri: Uri::try_from(uri).unwrap(),
status,
subsequent_uris: vec![],
recursion_level: 0,
}
}

Expand Down
1 change: 1 addition & 0 deletions lychee-bin/src/formatters/response/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ mod plain_tests {
uri: Uri::try_from(uri).unwrap(),
status,
subsequent_uris: vec![],
recursion_level: 0,
}
}

Expand Down
3 changes: 3 additions & 0 deletions lychee-bin/src/formatters/stats/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,19 +133,22 @@ mod tests {
uri: Uri::from(Url::parse("https://example.com").unwrap()),
status: Status::Ok(StatusCode::OK),
subsequent_uris: vec![],
recursion_level: 0,
}]),
);

let err1 = ResponseBody {
uri: Uri::try_from("https://github.com/mre/idiomatic-rust-doesnt-exist-man").unwrap(),
status: Status::Ok(StatusCode::NOT_FOUND),
subsequent_uris: vec![],
recursion_level: 0,
};

let err2 = ResponseBody {
uri: Uri::try_from("https://github.com/mre/boom").unwrap(),
status: Status::Ok(StatusCode::INTERNAL_SERVER_ERROR),
subsequent_uris: vec![],
recursion_level: 0,
};

let mut error_map: HashMap<InputSource, HashSet<ResponseBody>> = HashMap::new();
Expand Down
2 changes: 2 additions & 0 deletions lychee-bin/src/formatters/stats/detailed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,14 @@ mod tests {
uri: Uri::try_from("https://github.com/mre/idiomatic-rust-doesnt-exist-man").unwrap(),
status: Status::Ok(StatusCode::NOT_FOUND),
subsequent_uris: vec![],
recursion_level: 0,
};

let err2 = ResponseBody {
uri: Uri::try_from("https://github.com/mre/boom").unwrap(),
status: Status::Ok(StatusCode::INTERNAL_SERVER_ERROR),
subsequent_uris: vec![],
recursion_level: 0,
};

let mut error_map: HashMap<InputSource, HashSet<ResponseBody>> = HashMap::new();
Expand Down
4 changes: 4 additions & 0 deletions lychee-bin/src/formatters/stats/markdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ mod tests {
uri: Uri::try_from("http://example.com").unwrap(),
status: Status::Ok(StatusCode::OK),
subsequent_uris: vec![],
recursion_level: 0,
};
let markdown = markdown_response(&response).unwrap();
assert_eq!(
Expand All @@ -184,6 +185,7 @@ mod tests {
uri: Uri::try_from("http://example.com").unwrap(),
status: Status::Cached(CacheStatus::Ok(200)),
subsequent_uris: vec![],
recursion_level: 0,
};
let markdown = markdown_response(&response).unwrap();
assert_eq!(
Expand All @@ -198,6 +200,7 @@ mod tests {
uri: Uri::try_from("http://example.com").unwrap(),
status: Status::Cached(CacheStatus::Error(Some(400))),
subsequent_uris: vec![],
recursion_level: 0,
};
let markdown = markdown_response(&response).unwrap();
assert_eq!(
Expand Down Expand Up @@ -230,6 +233,7 @@ mod tests {
Status::Cached(CacheStatus::Error(Some(404))),
InputSource::Stdin,
vec![],
0,
);
stats.add(response);
stats
Expand Down
Loading

0 comments on commit a88f9fa

Please sign in to comment.