From e540d10ea83ed5fd39c4906d6280537c362c32c0 Mon Sep 17 00:00:00 2001 From: Zhang Jingqiang <127708320+zh-jq-b@users.noreply.github.com> Date: Thu, 10 Oct 2024 15:11:56 +0800 Subject: [PATCH] g3proxy: block websocket upgrade request early (#348) --- g3proxy/src/inspect/http/v1/upgrade/mod.rs | 44 ++++++++++++++++--- .../src/inspect/http/v2/connect/extended.rs | 19 ++++++++ lib/g3-dpi/src/config/mod.rs | 7 +++ 3 files changed, 64 insertions(+), 6 deletions(-) diff --git a/g3proxy/src/inspect/http/v1/upgrade/mod.rs b/g3proxy/src/inspect/http/v1/upgrade/mod.rs index 26b6aca4b..fcb63acc2 100644 --- a/g3proxy/src/inspect/http/v1/upgrade/mod.rs +++ b/g3proxy/src/inspect/http/v1/upgrade/mod.rs @@ -149,17 +149,34 @@ where } } - pub(super) async fn forward_icap( + async fn check_blocked(&mut self, clt_w: &mut CW) -> ServerTaskResult<()> + where + CW: AsyncWrite + Unpin, + { + if self.ctx.websocket_inspect_policy().is_block() { + let rsp = HttpProxyClientResponse::forbidden(self.req.version); + self.should_close = true; + if rsp.reply_err_to_request(clt_w).await.is_ok() { + self.http_notes.rsp_status = rsp.status(); + } + Err(ServerTaskError::InternalAdapterError(anyhow!( + "websocket blocked by inspection policy" + ))) + } else { + Ok(()) + } + } + + pub(super) async fn forward_original( &mut self, rsp_io: &mut HttpResponseIo, - reqmod_client: &IcapReqmodClient, ) -> Option<(HttpUpgradeToken, UpstreamAddr)> where CW: AsyncWrite + Unpin, UR: AsyncRead + Unpin, UW: AsyncWrite + Unpin, { - match self.do_forward(rsp_io, reqmod_client).await { + match self.do_forward_original(rsp_io).await { Ok(v) => { intercept_log!(self, &v, "ok"); v @@ -174,16 +191,30 @@ where } } - pub(super) async fn forward_original( + pub(super) async fn do_forward_original( &mut self, rsp_io: &mut HttpResponseIo, + ) -> ServerTaskResult> + where + CW: AsyncWrite + Unpin, + UR: AsyncRead + Unpin, + UW: AsyncWrite + Unpin, + { + self.check_blocked(&mut rsp_io.clt_w).await?; + self.send_request(None, rsp_io).await + } + + pub(super) async fn forward_icap( + &mut self, + rsp_io: &mut HttpResponseIo, + reqmod_client: &IcapReqmodClient, ) -> Option<(HttpUpgradeToken, UpstreamAddr)> where CW: AsyncWrite + Unpin, UR: AsyncRead + Unpin, UW: AsyncWrite + Unpin, { - match self.send_request(None, rsp_io).await { + match self.do_forward_icap(rsp_io, reqmod_client).await { Ok(v) => { intercept_log!(self, &v, "ok"); v @@ -198,7 +229,7 @@ where } } - async fn do_forward( + async fn do_forward_icap( &mut self, rsp_io: &mut HttpResponseIo, reqmod_client: &IcapReqmodClient, @@ -208,6 +239,7 @@ where UR: AsyncRead + Unpin, UW: AsyncWrite + Unpin, { + self.check_blocked(&mut rsp_io.clt_w).await?; match reqmod_client .h1_adapter( self.ctx.server_config.limited_copy_config(), diff --git a/g3proxy/src/inspect/http/v2/connect/extended.rs b/g3proxy/src/inspect/http/v2/connect/extended.rs index fc7b7a9a6..0b904d952 100644 --- a/g3proxy/src/inspect/http/v2/connect/extended.rs +++ b/g3proxy/src/inspect/http/v2/connect/extended.rs @@ -109,6 +109,19 @@ where } } + fn reply_forbidden(&mut self, mut clt_send_rsp: SendResponse) { + if let Ok(rsp) = Response::builder() + .status(StatusCode::FORBIDDEN) + .version(Version::HTTP_2) + .body(()) + { + let rsp_status = rsp.status().as_u16(); + if clt_send_rsp.send_response(rsp, true).is_ok() { + self.http_notes.rsp_status = rsp_status; + } + } + } + pub(crate) async fn into_running( mut self, clt_req: Request, @@ -162,6 +175,12 @@ where } }; + if self.ctx.websocket_inspect_policy().is_block() { + self.reply_forbidden(clt_send_rsp); + intercept_log!(self, "websocket blocked by inspection policy"); + return; + } + let mut ws_notes = WebSocketNotes::new(clt_req.uri().clone()); for (name, value) in clt_req.headers() { ws_notes.append_request_header(name, value); diff --git a/lib/g3-dpi/src/config/mod.rs b/lib/g3-dpi/src/config/mod.rs index 9dca5f856..d87309990 100644 --- a/lib/g3-dpi/src/config/mod.rs +++ b/lib/g3-dpi/src/config/mod.rs @@ -40,6 +40,13 @@ pub enum ProtocolInspectPolicy { Block, } +impl ProtocolInspectPolicy { + #[inline] + pub fn is_block(&self) -> bool { + matches!(self, ProtocolInspectPolicy::Block) + } +} + impl FromStr for ProtocolInspectPolicy { type Err = ();