Skip to content

Commit

Permalink
refactor keep-alive timer
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Sep 28, 2018
1 parent e95babf commit 4aac3d6
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 95 deletions.
4 changes: 2 additions & 2 deletions src/client/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ type SslConnector = Arc<ClientConfig>;
feature = "alpn",
feature = "ssl",
feature = "tls",
feature = "rust-tls",
),))]
feature = "rust-tls"
)))]
type SslConnector = ();

use server::IoStream;
Expand Down
176 changes: 116 additions & 60 deletions src/server/h1.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::VecDeque;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use std::time::Instant;

use bytes::BytesMut;
use futures::{Async, Future, Poll};
Expand Down Expand Up @@ -49,7 +49,14 @@ pub(crate) struct Http1<T: IoStream, H: HttpHandler + 'static> {
payload: Option<PayloadType>,
buf: BytesMut,
tasks: VecDeque<Entry<H>>,
keepalive_timer: Option<Delay>,
ka_enabled: bool,
ka_expire: Instant,
ka_timer: Option<Delay>,
}

struct Entry<H: HttpHandler> {
pipe: EntryPipe<H>,
flags: EntryFlags,
}

enum EntryPipe<H: HttpHandler> {
Expand Down Expand Up @@ -78,11 +85,6 @@ impl<H: HttpHandler> EntryPipe<H> {
}
}

struct Entry<H: HttpHandler> {
pipe: EntryPipe<H>,
flags: EntryFlags,
}

impl<T, H> Http1<T, H>
where
T: IoStream,
Expand All @@ -92,6 +94,15 @@ where
settings: WorkerSettings<H>, stream: T, addr: Option<SocketAddr>, buf: BytesMut,
is_eof: bool, keepalive_timer: Option<Delay>,
) -> Self {
let ka_enabled = settings.keep_alive_enabled();
let (ka_expire, ka_timer) = if let Some(delay) = keepalive_timer {
(delay.deadline(), Some(delay))
} else if let Some(delay) = settings.keep_alive_timer() {
(delay.deadline(), Some(delay))
} else {
(settings.now(), None)
};

Http1 {
flags: if is_eof {
Flags::READ_DISCONNECTED
Expand All @@ -105,7 +116,9 @@ where
addr,
buf,
settings,
keepalive_timer,
ka_timer,
ka_expire,
ka_enabled,
}
}

Expand Down Expand Up @@ -143,9 +156,6 @@ where
for task in &mut self.tasks {
task.pipe.disconnected();
}

// kill keepalive
self.keepalive_timer.take();
}

fn read_disconnected(&mut self) {
Expand All @@ -163,16 +173,9 @@ where

#[inline]
pub fn poll(&mut self) -> Poll<(), ()> {
// keep-alive timer
if let Some(ref mut timer) = self.keepalive_timer {
match timer.poll() {
Ok(Async::Ready(_)) => {
trace!("Keep-alive timeout, close connection");
self.flags.insert(Flags::SHUTDOWN);
}
Ok(Async::NotReady) => (),
Err(_) => unreachable!(),
}
// check connection keep-alive
if !self.poll_keep_alive() {
return Ok(Async::Ready(()));
}

// shutdown
Expand Down Expand Up @@ -203,11 +206,70 @@ where
self.flags.insert(Flags::SHUTDOWN);
return self.poll();
}
Async::NotReady => return Ok(Async::NotReady),
Async::NotReady => {
// deal with keep-alive and steam eof (client-side write shutdown)
if self.tasks.is_empty() {
// handle stream eof
if self.flags.contains(Flags::READ_DISCONNECTED) {
self.flags.insert(Flags::SHUTDOWN);
return self.poll();
}
// no keep-alive
if self.flags.contains(Flags::ERROR)
|| (!self.flags.contains(Flags::KEEPALIVE)
|| !self.ka_enabled)
&& self.flags.contains(Flags::STARTED)
{
self.flags.insert(Flags::SHUTDOWN);
return self.poll();
}
}
return Ok(Async::NotReady);
}
}
}
}

/// keep-alive timer. returns `true` is keep-alive, otherwise drop
fn poll_keep_alive(&mut self) -> bool {
let timer = if let Some(ref mut timer) = self.ka_timer {
match timer.poll() {
Ok(Async::Ready(_)) => {
if timer.deadline() >= self.ka_expire {
// check for any outstanding request handling
if self.tasks.is_empty() {
// if we get timer during shutdown, just drop connection
if self.flags.contains(Flags::SHUTDOWN) {
return false;
} else {
trace!("Keep-alive timeout, close connection");
self.flags.insert(Flags::SHUTDOWN);
None
}
} else {
self.settings.keep_alive_timer()
}
} else {
Some(Delay::new(self.ka_expire))
}
}
Ok(Async::NotReady) => None,
Err(e) => {
error!("Timer error {:?}", e);
return false;
}
}
} else {
None
};

if let Some(mut timer) = timer {
let _ = timer.poll();
self.ka_timer = Some(timer);
}
true
}

#[inline]
/// read data from stream
pub fn poll_io(&mut self) {
Expand Down Expand Up @@ -283,6 +345,11 @@ where
}
// no more IO for this iteration
Ok(Async::NotReady) => {
// check if we need timer
if self.ka_timer.is_some() && self.stream.upgrade() {
self.ka_timer.take();
}

// check if previously read backpressure was enabled
if self.can_read() && !retry {
return Ok(Async::Ready(true));
Expand Down Expand Up @@ -348,32 +415,6 @@ where
}
}

// deal with keep-alive and steam eof (client-side write shutdown)
if self.tasks.is_empty() {
// handle stream eof
if self.flags.contains(Flags::READ_DISCONNECTED) {
return Ok(Async::Ready(false));
}
// no keep-alive
if self.flags.contains(Flags::ERROR)
|| (!self.flags.contains(Flags::KEEPALIVE)
|| !self.settings.keep_alive_enabled())
&& self.flags.contains(Flags::STARTED)
{
return Ok(Async::Ready(false));
}

// start keep-alive timer
let keep_alive = self.settings.keep_alive();
if self.keepalive_timer.is_none() && keep_alive > 0 {
trace!("Start keep-alive timer");
let mut timer =
Delay::new(Instant::now() + Duration::from_secs(keep_alive));
// register timer
let _ = timer.poll();
self.keepalive_timer = Some(timer);
}
}
Ok(Async::NotReady)
}

Expand All @@ -385,9 +426,12 @@ where
}

pub fn parse(&mut self) {
let mut updated = false;

'outer: loop {
match self.decoder.decode(&mut self.buf, &self.settings) {
Ok(Some(Message::Message { mut msg, payload })) => {
updated = true;
self.flags.insert(Flags::STARTED);

if payload {
Expand All @@ -403,9 +447,6 @@ where
// set remote addr
msg.inner_mut().addr = self.addr;

// stop keepalive timer
self.keepalive_timer.take();

// search handler for request
match self.settings.handler().handle(msg) {
Ok(mut pipe) => {
Expand All @@ -430,7 +471,7 @@ where
}
continue 'outer;
}
Ok(Async::NotReady) => {}
Ok(Async::NotReady) => (),
Err(err) => {
error!("Unhandled error: {}", err);
self.flags.insert(Flags::ERROR);
Expand Down Expand Up @@ -460,6 +501,7 @@ where
self.push_response_entry(StatusCode::NOT_FOUND);
}
Ok(Some(Message::Chunk(chunk))) => {
updated = true;
if let Some(ref mut payload) = self.payload {
payload.feed_data(chunk);
} else {
Expand All @@ -470,6 +512,7 @@ where
}
}
Ok(Some(Message::Eof)) => {
updated = true;
if let Some(mut payload) = self.payload.take() {
payload.feed_eof();
} else {
Expand All @@ -489,6 +532,7 @@ where
break;
}
Err(e) => {
updated = false;
self.flags.insert(Flags::ERROR);
if let Some(mut payload) = self.payload.take() {
let e = match e {
Expand All @@ -504,6 +548,12 @@ where
}
}
}

if self.ka_timer.is_some() && updated {
if let Some(expire) = self.settings.keep_alive_expire() {
self.ka_expire = expire;
}
}
}
}

Expand All @@ -512,7 +562,9 @@ mod tests {
use std::net::Shutdown;
use std::{cmp, io, time};

use actix::System;
use bytes::{Buf, Bytes, BytesMut};
use futures::future;
use http::{Method, Version};
use tokio_io::{AsyncRead, AsyncWrite};

Expand Down Expand Up @@ -647,15 +699,19 @@ mod tests {

#[test]
fn test_req_parse_err() {
let buf = Buffer::new("GET /test HTTP/1\r\n\r\n");
let readbuf = BytesMut::new();
let settings = wrk_settings();
let mut sys = System::new("test");
sys.block_on(future::lazy(|| {
let buf = Buffer::new("GET /test HTTP/1\r\n\r\n");
let readbuf = BytesMut::new();
let settings = wrk_settings();

let mut h1 = Http1::new(settings.clone(), buf, None, readbuf, false, None);
h1.poll_io();
h1.poll_io();
assert!(h1.flags.contains(Flags::ERROR));
assert_eq!(h1.tasks.len(), 1);
let mut h1 = Http1::new(settings.clone(), buf, None, readbuf, false, None);
h1.poll_io();
h1.poll_io();
assert!(h1.flags.contains(Flags::ERROR));
assert_eq!(h1.tasks.len(), 1);
future::ok::<_, ()>(())
}));
}

#[test]
Expand Down
4 changes: 4 additions & 0 deletions src/server/h1writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
self.flags.insert(Flags::DISCONNECTED);
}

pub fn upgrade(&self) -> bool {
self.flags.contains(Flags::UPGRADE)
}

pub fn keepalive(&self) -> bool {
self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE)
}
Expand Down
21 changes: 10 additions & 11 deletions src/server/h2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::VecDeque;
use std::io::{Read, Write};
use std::net::SocketAddr;
use std::rc::Rc;
use std::time::{Duration, Instant};
use std::time::Instant;
use std::{cmp, io, mem};

use bytes::{Buf, Bytes};
Expand Down Expand Up @@ -232,16 +232,15 @@ where
// start keep-alive timer
if self.tasks.is_empty() {
if self.settings.keep_alive_enabled() {
let keep_alive = self.settings.keep_alive();
if keep_alive > 0 && self.keepalive_timer.is_none() {
trace!("Start keep-alive timer");
let mut timeout = Delay::new(
Instant::now()
+ Duration::new(keep_alive, 0),
);
// register timeout
let _ = timeout.poll();
self.keepalive_timer = Some(timeout);
if self.keepalive_timer.is_none() {
if let Some(ka) = self.settings.keep_alive() {
trace!("Start keep-alive timer");
let mut timeout =
Delay::new(Instant::now() + ka);
// register timeout
let _ = timeout.poll();
self.keepalive_timer = Some(timeout);
}
}
} else {
// keep-alive disable, drop connection
Expand Down
Loading

0 comments on commit 4aac3d6

Please sign in to comment.