diff --git a/crates/dekaf/src/main.rs b/crates/dekaf/src/main.rs index e204baaec8..20fac234ce 100644 --- a/crates/dekaf/src/main.rs +++ b/crates/dekaf/src/main.rs @@ -17,7 +17,7 @@ use std::{ path::{Path, PathBuf}, sync::Arc, }; -use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing_subscriber::{filter::LevelFilter, EnvFilter}; use url::Url; @@ -261,41 +261,43 @@ where S: AsyncRead + AsyncWrite + Unpin, { tracing::info!("accepted client connection"); + + let (r, mut w) = tokio::io::split(socket); + let mut r = tokio_util::codec::FramedRead::new( + r, + tokio_util::codec::LengthDelimitedCodec::builder() + .big_endian() + .length_field_length(4) + .max_frame_length(1 << 27) // 128 MiB + .new_codec(), + ); + let mut out = bytes::BytesMut::new(); + let mut raw_sasl_auth = false; + metrics::gauge!("dekaf_total_connections").increment(1); + let result = async { - let (r, mut w) = split(socket); - - let mut r = tokio_util::codec::FramedRead::new( - r, - tokio_util::codec::LengthDelimitedCodec::builder() - .big_endian() - .length_field_length(4) - .max_frame_length(1 << 27) // 128 MiB - .new_codec(), - ); - - let mut out = bytes::BytesMut::new(); - let mut raw_sasl_auth = false; - let mut res = Ok(()); - while let Some(frame) = r.try_next().await? { - if let err @ Err(_) = - dekaf::dispatch_request_frame(&mut session, &mut raw_sasl_auth, frame, &mut out) - .await - { - // Close the connection on error - w.shutdown().await?; - res = err; - } + loop { + let Some(frame) = tokio::time::timeout(SESSION_TIMEOUT, r.try_next()) + .await + .context("timeout waiting for next session request")? + .context("failed to read next session request")? + else { + return Ok(()); + }; + + dekaf::dispatch_request_frame(&mut session, &mut raw_sasl_auth, frame, &mut out) + .await?; + () = w.write_all(&mut out).await?; out.clear(); } - - res } .await; metrics::gauge!("dekaf_total_connections").decrement(1); + w.shutdown().await?; result } @@ -334,3 +336,5 @@ fn validate_certificate_name( } return Ok(false); } + +const SESSION_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);