Skip to content

Commit

Permalink
improve inspect policy code based upon zh-jq's feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
GlenDC committed Oct 10, 2024
1 parent eb0b21b commit 78b6975
Show file tree
Hide file tree
Showing 25 changed files with 146 additions and 151 deletions.
8 changes: 4 additions & 4 deletions g3proxy/src/auth/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use governor::{clock::DefaultClock, state::InMemoryState, state::NotKeyed, RateL
use tokio::time::Instant;

use g3_io_ext::{GlobalDatagramLimiter, GlobalLimitGroup, GlobalStreamLimiter};
use g3_types::acl::{AclAction, AclNetworkRule, ActionContract};
use g3_types::acl::{AclAction, AclNetworkRule};
use g3_types::acl_set::AclDstHostRuleSet;
use g3_types::auth::UserAuthError;
use g3_types::limit::{GaugeSemaphore, GaugeSemaphorePermit};
Expand Down Expand Up @@ -605,7 +605,7 @@ impl User {
forbid_stats.add_dest_denied();
return action;
};
default_action = default_action.restrict(&action);
default_action = default_action.restrict(action);
}

if let Some(filter) = &self.dst_host_filter {
Expand All @@ -614,7 +614,7 @@ impl User {
forbid_stats.add_dest_denied();
return action;
}
default_action = default_action.restrict(&action);
default_action = default_action.restrict(action);
}

if default_action.forbid_early() {
Expand All @@ -636,7 +636,7 @@ impl User {
forbid_stats.add_ua_blocked();
return Some(action);
}
default_action = default_action.restrict(&action);
default_action = default_action.restrict(action);
}
}
Some(default_action)
Expand Down
7 changes: 2 additions & 5 deletions g3proxy/src/inspect/http/v1/upgrade/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,8 @@ where
CW: AsyncWrite + Unpin,
{
let policy_action = match self.req.host.as_ref() {
Some(upstream) => {
let (_, policy_action) = self.ctx.websocket_inspect_policy().check(upstream.host());
policy_action
}
None => self.ctx.websocket_inspect_policy().missing_action(),
Some(upstream) => self.ctx.websocket_inspect_action(upstream.host()),
None => self.ctx.websocket_inspect_missing_action(),
};
let block_websocket = policy_action == ProtocolInspectAction::Block;

Expand Down
7 changes: 2 additions & 5 deletions g3proxy/src/inspect/http/v2/connect/extended.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,8 @@ where
};

let policy_action = match self.upstream.as_ref() {
Some(upstream) => {
let (_, policy_action) = self.ctx.websocket_inspect_policy().check(upstream.host());
policy_action
}
None => self.ctx.websocket_inspect_policy().missing_action(),
Some(upstream) => self.ctx.websocket_inspect_action(upstream.host()),
None => self.ctx.websocket_inspect_missing_action(),
};
if policy_action == ProtocolInspectAction::Block {
self.reply_forbidden(clt_send_rsp);
Expand Down
3 changes: 1 addition & 2 deletions g3proxy/src/inspect/http/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ where
SC: ServerConfig + Send + Sync + 'static,
{
pub(crate) async fn intercept(mut self) -> ServerTaskResult<()> {
let (_, inspect_action) = self.ctx.h2_inspect_policy().check(self.upstream.host());
let r = match inspect_action {
let r = match self.ctx.h2_inspect_action(self.upstream.host()) {
ProtocolInspectAction::Intercept => self
.do_intercept()
.await
Expand Down
3 changes: 1 addition & 2 deletions g3proxy/src/inspect/imap/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ where
}

pub(crate) async fn intercept(mut self) -> ServerTaskResult<Option<StreamInspection<SC>>> {
let (_, inspect_action) = self.ctx.imap_inspect_policy().check(self.upstream.host());
let r = match inspect_action {
let r = match self.ctx.imap_inspect_action(self.upstream.host()) {
ProtocolInspectAction::Intercept => self.do_intercept().await,
#[cfg(feature = "quic")]
ProtocolInspectAction::Detour => self.do_detour().await.map(|_| None),
Expand Down
57 changes: 47 additions & 10 deletions g3proxy/src/inspect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ use uuid::Uuid;
use g3_daemon::server::ServerQuitPolicy;
use g3_dpi::{
H1InterceptionConfig, H2InterceptionConfig, ImapInterceptionConfig, MaybeProtocol,
ProtocolInspectPolicy, ProtocolInspector, SmtpInterceptionConfig,
ProtocolInspectAction, ProtocolInspector, SmtpInterceptionConfig,
};
use g3_types::net::OpensslClientConfig;
use g3_types::net::{Host, OpensslClientConfig};

use crate::audit::AuditHandle;
use crate::auth::{User, UserForbiddenStats, UserSite};
Expand Down Expand Up @@ -263,8 +263,17 @@ impl<SC: ServerConfig> StreamInspectContext<SC> {
}

#[inline]
fn h2_inspect_policy(&self) -> &ProtocolInspectPolicy {
self.audit_handle.h2_inspect_policy()
fn h2_inspect_action(&self, host: &Host) -> ProtocolInspectAction {
match self.audit_handle.h2_inspect_policy().check(host) {
(true, policy_action) => policy_action,
(false, missing_policy_action) => missing_policy_action,
}
}

#[inline]
#[allow(dead_code)]
fn h2_inspect_missing_action(&self) -> ProtocolInspectAction {
self.audit_handle.h2_inspect_policy().missing_action()
}

#[inline]
Expand All @@ -281,13 +290,32 @@ impl<SC: ServerConfig> StreamInspectContext<SC> {
}

#[inline]
fn websocket_inspect_policy(&self) -> &ProtocolInspectPolicy {
self.audit_handle.websocket_inspect_policy()
fn websocket_inspect_action(&self, host: &Host) -> ProtocolInspectAction {
match self.audit_handle.websocket_inspect_policy().check(host) {
(true, policy_action) => policy_action,
(false, missing_policy_action) => missing_policy_action,
}
}

#[inline]
fn websocket_inspect_missing_action(&self) -> ProtocolInspectAction {
self.audit_handle
.websocket_inspect_policy()
.missing_action()
}

#[inline]
fn smtp_inspect_policy(&self) -> &ProtocolInspectPolicy {
self.audit_handle.smtp_inspect_policy()
fn smtp_inspect_action(&self, host: &Host) -> ProtocolInspectAction {
match self.audit_handle.smtp_inspect_policy().check(host) {
(true, policy_action) => policy_action,
(false, missing_policy_action) => missing_policy_action,
}
}

#[inline]
#[allow(dead_code)]
fn smtp_inspect_missing_action(&self) -> ProtocolInspectAction {
self.audit_handle.smtp_inspect_policy().missing_action()
}

#[inline]
Expand All @@ -296,8 +324,17 @@ impl<SC: ServerConfig> StreamInspectContext<SC> {
}

#[inline]
fn imap_inspect_policy(&self) -> &ProtocolInspectPolicy {
self.audit_handle.imap_inspect_policy()
fn imap_inspect_action(&self, host: &Host) -> ProtocolInspectAction {
match self.audit_handle.imap_inspect_policy().check(host) {
(true, policy_action) => policy_action,
(false, missing_policy_action) => missing_policy_action,
}
}

#[inline]
#[allow(dead_code)]
fn imap_inspect_missing_action(&self) -> ProtocolInspectAction {
self.audit_handle.imap_inspect_policy().missing_action()
}

#[inline]
Expand Down
3 changes: 1 addition & 2 deletions g3proxy/src/inspect/smtp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ where
}

pub(crate) async fn intercept(mut self) -> ServerTaskResult<Option<StreamInspection<SC>>> {
let (_, inspect_action) = self.ctx.smtp_inspect_policy().check(self.upstream.host());
let r = match inspect_action {
let r = match self.ctx.smtp_inspect_action(self.upstream.host()) {
ProtocolInspectAction::Intercept => self.do_intercept().await,
#[cfg(feature = "quic")]
ProtocolInspectAction::Detour => self.do_detour().await.map(|_| None),
Expand Down
12 changes: 6 additions & 6 deletions g3proxy/src/inspect/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,14 @@ impl<SC: ServerConfig> TlsInterceptObject<SC> {

fn retain_alpn_protocol(&self, p: &[u8]) -> bool {
if p == AlpnProtocol::Http2.identification_sequence() {
let (_, inspect_policy) = self.ctx.h2_inspect_policy().check(self.upstream.host());
return inspect_policy != ProtocolInspectAction::Block;
return ProtocolInspectAction::Block
!= self.ctx.h2_inspect_action(self.upstream.host());
} else if p == AlpnProtocol::Smtp.identification_sequence() {
let (_, inspect_policy) = self.ctx.smtp_inspect_policy().check(self.upstream.host());
return inspect_policy != ProtocolInspectAction::Block;
return ProtocolInspectAction::Block
!= self.ctx.smtp_inspect_action(self.upstream.host());
} else if p == AlpnProtocol::Imap.identification_sequence() {
let (_, inspect_policy) = self.ctx.imap_inspect_policy().check(self.upstream.host());
return inspect_policy != ProtocolInspectAction::Block;
return ProtocolInspectAction::Block
!= self.ctx.imap_inspect_action(self.upstream.host());
}
true
}
Expand Down
6 changes: 1 addition & 5 deletions g3proxy/src/inspect/websocket/h1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,7 @@ impl<SC: ServerConfig> H1WebsocketInterceptObject<SC> {
}

pub(crate) async fn intercept(mut self) -> ServerTaskResult<()> {
let (_, inspect_action) = self
.ctx
.websocket_inspect_policy()
.check(self.upstream.host());
let r = match inspect_action {
let r = match self.ctx.websocket_inspect_action(self.upstream.host()) {
ProtocolInspectAction::Intercept => self.do_intercept().await,
#[cfg(feature = "quic")]
ProtocolInspectAction::Detour => self.do_detour().await,
Expand Down
6 changes: 1 addition & 5 deletions g3proxy/src/inspect/websocket/h2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,7 @@ impl<SC: ServerConfig> H2WebsocketInterceptObject<SC> {
ups_r: RecvStream,
ups_w: SendStream<Bytes>,
) {
let (_, inspect_action) = self
.ctx
.websocket_inspect_policy()
.check(self.upstream.host());
let r = match inspect_action {
let r = match self.ctx.websocket_inspect_action(self.upstream.host()) {
ProtocolInspectAction::Intercept => self.do_intercept(clt_r, clt_w, ups_r, ups_w).await,
#[cfg(feature = "quic")]
ProtocolInspectAction::Detour => self.do_detour(clt_r, clt_w, ups_r, ups_w).await,
Expand Down
6 changes: 3 additions & 3 deletions g3proxy/src/serve/http_proxy/task/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use slog::Logger;

use g3_daemon::server::ClientConnectionInfo;
use g3_icap_client::reqmod::h1::HttpAdapterErrorResponse;
use g3_types::acl::{AclAction, ActionContract};
use g3_types::acl::AclAction;
use g3_types::acl_set::AclDstHostRuleSet;
use g3_types::net::{OpensslClientConfig, UpstreamAddr};

Expand Down Expand Up @@ -73,15 +73,15 @@ impl CommonTaskContext {
if found && action.forbid_early() {
return action;
};
default_action = default_action.restrict(&action);
default_action = default_action.restrict(action);
}

if let Some(filter) = &self.dst_host_filter {
let (found, action) = filter.check(upstream.host());
if found && action.forbid_early() {
return action;
}
default_action = default_action.restrict(&action);
default_action = default_action.restrict(action);
}

default_action
Expand Down
6 changes: 3 additions & 3 deletions g3proxy/src/serve/socks_proxy/task/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use slog::Logger;
use tokio::net::UdpSocket;

use g3_daemon::server::ClientConnectionInfo;
use g3_types::acl::{AclAction, AclNetworkRule, ActionContract};
use g3_types::acl::{AclAction, AclNetworkRule};
use g3_types::acl_set::AclDstHostRuleSet;
use g3_types::net::UpstreamAddr;

Expand Down Expand Up @@ -70,15 +70,15 @@ impl CommonTaskContext {
if found && action.forbid_early() {
return action;
};
default_action = default_action.restrict(&action);
default_action = default_action.restrict(action);
}

if let Some(filter) = &self.dst_host_filter {
let (found, action) = filter.check(upstream.host());
if found && action.forbid_early() {
return action;
}
default_action = default_action.restrict(&action);
default_action = default_action.restrict(action);
}

default_action
Expand Down
34 changes: 8 additions & 26 deletions lib/g3-dpi/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
* limitations under the License.
*/

use std::fmt;
use std::str::FromStr;
use std::time::Duration;
use std::{fmt, str::FromStr};

use g3_types::acl::ActionContract;
use g3_types::acl_set::AclDstHostRuleSet;

mod size_limit;

use g3_types::acl::ActionContract;
pub use size_limit::ProtocolInspectionSizeLimit;

mod http;
Expand All @@ -31,9 +34,9 @@ pub use smtp::SmtpInterceptionConfig;
mod imap;
pub use imap::ImapInterceptionConfig;

pub type ProtocolInspectPolicy = g3_types::acl_set::AclDstHostRuleSet<ProtocolInspectAction>;
pub type ProtocolInspectPolicy = AclDstHostRuleSet<ProtocolInspectAction>;

#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd, Hash)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub enum ProtocolInspectAction {
Intercept,
#[cfg(feature = "quic")]
Expand Down Expand Up @@ -64,7 +67,7 @@ impl FromStr for ProtocolInspectAction {
}
}

impl g3_types::acl::ActionContract for ProtocolInspectAction {
impl ActionContract for ProtocolInspectAction {
fn default_forbid() -> Self {
Self::Block
}
Expand All @@ -73,27 +76,6 @@ impl g3_types::acl::ActionContract for ProtocolInspectAction {
Self::Intercept
}

fn restrict(&self, other: &ProtocolInspectAction) -> ProtocolInspectAction {
if other > self {
*other
} else {
*self
}
}

fn strict_than(&self, other: &ProtocolInspectAction) -> bool {
self.gt(other)
}

fn forbid_early(&self) -> bool {
match self {
Self::Block => true,
Self::Intercept | Self::Bypass => false,
#[cfg(feature = "quic")]
Self::Detour => false,
}
}

fn serialize(&self) -> &'static str {
match self {
Self::Intercept => "intercept",
Expand Down
4 changes: 2 additions & 2 deletions lib/g3-types/src/acl/a_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ where
Q: Hash + Eq + ?Sized,
{
if let Some(action) = self.inner.get(node) {
(true, action.clone())
(true, *action)
} else {
(false, self.missed_action.clone())
(false, self.missed_action)
}
}
}
10 changes: 5 additions & 5 deletions lib/g3-types/src/acl/exact_host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ impl<Action: ActionContract> AclExactHostRule<Action> {
#[inline]
pub fn new(missed_action: Action) -> Self {
AclExactHostRule {
missed_action: missed_action.clone(),
domain: AclAHashRule::new(missed_action.clone()),
missed_action,
domain: AclAHashRule::new(missed_action),
ip: AclAHashRule::new(missed_action),
}
}
Expand All @@ -56,14 +56,14 @@ impl<Action: ActionContract> AclExactHostRule<Action> {

#[inline]
pub fn set_missed_action(&mut self, action: Action) {
self.missed_action = action.clone();
self.domain.set_missed_action(action.clone());
self.missed_action = action;
self.domain.set_missed_action(action);
self.ip.set_missed_action(action);
}

#[inline]
pub fn missed_action(&self) -> Action {
self.missed_action.clone()
self.missed_action
}

#[inline]
Expand Down
Loading

0 comments on commit 78b6975

Please sign in to comment.