Skip to content

Commit

Permalink
chore(asm): update api sec sampling mechanism (#8459)
Browse files Browse the repository at this point in the history
Implementation of the new sampling mechanism for API Security described
in RFC `Api Security Sampling Algorithm RFC`

- Based on the APM head based sampling mechanism 
- Only compute API Security schema once every 30s for each endpoints
- Do not use api security sampling rate anymore
- add unit tests for the new mechanism
- add the ability to change span priority for testapp used for unit
tests


## Checklist

- [x] Change(s) are motivated and described in the PR description
- [x] Testing strategy is described if automated tests are not included
in the PR
- [x] Risks are described (performance impact, potential for breakage,
maintainability)
- [x] Change is maintainable (easy to change, telemetry, documentation)
- [x] [Library release note
guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
are followed or label `changelog/no-changelog` is set
- [x] Documentation is included (in-code, generated user docs, [public
corp docs](https://github.com/DataDog/documentation/))
- [x] Backport labels are set (if
[applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))
- [x] If this PR changes the public interface, I've notified
`@DataDog/apm-tees`.
- [x] If change touches code that signs or publishes builds or packages,
or handles credentials of any kind, I've requested a review from
`@DataDog/security-design-and-guidance`.

## Reviewer Checklist

- [ ] Title is accurate
- [ ] All changes are related to the pull request's stated goal
- [ ] Description motivates each change
- [ ] Avoids breaking
[API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces)
changes
- [ ] Testing strategy adequately addresses listed risks
- [ ] Change is maintainable (easy to change, telemetry, documentation)
- [ ] Release note makes sense to a user of the library
- [ ] Author has acknowledged and discussed the performance implications
of this PR as reported in the benchmarks PR comment
- [ ] Backport labels are set in a manner that is consistent with the
[release branch maintenance
policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)
  • Loading branch information
christophe-papazian authored Mar 1, 2024
1 parent 6bc3fd9 commit 8f58003
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 43 deletions.
86 changes: 54 additions & 32 deletions ddtrace/appsec/_api_security/api_manager.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,36 @@
import base64
import collections
import gzip
import json
import sys
from typing import TYPE_CHECKING # noqa:F401
import time
from typing import Optional

from ddtrace import constants
from ddtrace._trace._limits import MAX_SPAN_META_VALUE_LEN
from ddtrace.appsec import _processor as appsec_processor
from ddtrace.appsec._asm_request_context import add_context_callback
from ddtrace.appsec._asm_request_context import call_waf_callback
from ddtrace.appsec._asm_request_context import remove_context_callback
from ddtrace.appsec._constants import API_SECURITY
from ddtrace.appsec._constants import SPAN_DATA_NAMES
import ddtrace.constants as constants
from ddtrace.internal.logger import get_logger
from ddtrace.internal.metrics import Metrics
from ddtrace.internal.service import Service
from ddtrace.settings.asm import config as asm_config


if TYPE_CHECKING:
from typing import Optional # noqa:F401


log = get_logger(__name__)
metrics = Metrics(namespace="datadog.api_security")
_sentinel = object()

# Delay in seconds to avoid sampling the same route too often
DELAY = 30.0

# Max number of endpoint hashes to keep in the hashtable
MAX_HASHTABLE_SIZE = 4096

M_INFINITY = float("-inf")


class TooLargeSchemaException(Exception):
pass
Expand All @@ -42,13 +47,10 @@ class APIManager(Service):
("RESPONSE_BODY", API_SECURITY.RESPONSE_BODY, lambda f: f()),
]

_instance = None # type: Optional[APIManager]

SAMPLE_START_VALUE = 1.0 - sys.float_info.epsilon
_instance: Optional["APIManager"] = None

@classmethod
def enable(cls):
# type: () -> None
def enable(cls) -> None:
if cls._instance is not None:
log.debug("%s already enabled", cls.__name__)
return
Expand All @@ -60,8 +62,7 @@ def enable(cls):
log.debug("%s enabled", cls.__name__)

@classmethod
def disable(cls):
# type: () -> None
def disable(cls) -> None:
if cls._instance is None:
log.debug("%s not enabled", cls.__name__)
return
Expand All @@ -72,38 +73,49 @@ def disable(cls):
metrics.disable()
log.debug("%s disabled", cls.__name__)

def __init__(self):
# type: () -> None
def __init__(self) -> None:
super(APIManager, self).__init__()

self.current_sampling_value = self.SAMPLE_START_VALUE
self._schema_meter = metrics.get_meter("schema")
log.debug("%s initialized", self.__class__.__name__)
self._hashtable: collections.OrderedDict[int, float] = collections.OrderedDict()

def _stop_service(self):
# type: () -> None
def _stop_service(self) -> None:
remove_context_callback(self._schema_callback, global_callback=True)
self._hashtable.clear()

def _start_service(self):
# type: () -> None
def _start_service(self) -> None:
add_context_callback(self._schema_callback, global_callback=True)

def _should_collect_schema(self, env, priority):
sample_rate = asm_config._api_security_sample_rate
def _should_collect_schema(self, env, priority: int) -> bool:
# Rate limit per route
self.current_sampling_value += sample_rate
if priority <= 0:
return False

method = env.waf_addresses.get(SPAN_DATA_NAMES.REQUEST_METHOD)
route = env.waf_addresses.get(SPAN_DATA_NAMES.REQUEST_ROUTE)
status = env.waf_addresses.get(SPAN_DATA_NAMES.RESPONSE_STATUS)
# Framework is not fully supported
if not method or not route:
log.debug("unsupported groupkey for api security [method %s] [route %s]", bool(method), bool(route))
if method is None or route is None or status is None:
log.debug(
"unsupported groupkey for api security [method %s] [route %s] [status %s]",
bool(method),
bool(route),
bool(status),
)
return False
# Keep most of manual keep spans and auto keep spans. Other spans are not considered.
if self.current_sampling_value >= 1.0 and (priority == constants.USER_KEEP or priority == constants.AUTO_KEEP):
self.current_sampling_value -= 1.0
return True
return False
end_point_hash = hash((route, method, status))
current_time = time.monotonic()
previous_time = self._hashtable.get(end_point_hash, M_INFINITY)
if previous_time >= current_time - DELAY:
return False
if previous_time is M_INFINITY:
if len(self._hashtable) >= MAX_HASHTABLE_SIZE:
self._hashtable.popitem(last=False)
else:
self._hashtable.move_to_end(end_point_hash)
self._hashtable[end_point_hash] = current_time
return True

def _schema_callback(self, env):
from ddtrace.appsec._utils import _appsec_apisec_features_is_active
Expand All @@ -115,10 +127,20 @@ def _schema_callback(self, env):
return

try:
if not self._should_collect_schema(env, root.context.sampling_priority):
# check both current span and root span for sampling priority
# if any of them is set to USER_KEEP or USER_REJECT, we should respect it
priorities = (root.context.sampling_priority or 0, env.span.context.sampling_priority or 0)
if constants.USER_KEEP in priorities:
priority = constants.USER_KEEP
elif constants.USER_REJECT in priorities:
priority = constants.USER_REJECT
else:
priority = max(priorities)
if not self._should_collect_schema(env, priority):
return
except Exception:
log.warning("Failed to sample request for schema generation", exc_info=True)
return

# we need the request content type on the span
try:
Expand Down
2 changes: 1 addition & 1 deletion ddtrace/appsec/_remoteconfiguration.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def _add_rules_to_list(features: Mapping[str, Any], feature: str, message: str,
def _appsec_callback(features: Mapping[str, Any], test_tracer: Optional[Tracer] = None) -> None:
config = features.get("config", {})
_appsec_1click_activation(config, test_tracer)
_appsec_api_security_settings(config, test_tracer)
_appsec_rules_data(config, test_tracer)


Expand Down Expand Up @@ -235,6 +234,7 @@ def _appsec_1click_activation(features: Mapping[str, Any], test_tracer: Optional

def _appsec_api_security_settings(features: Mapping[str, Any], test_tracer: Optional[Tracer] = None) -> None:
"""
Deprecated
Update API Security settings from remote config
Actually: Update sample rate
"""
Expand Down
2 changes: 1 addition & 1 deletion ddtrace/appsec/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _appsec_rc_features_is_enabled() -> bool:


def _appsec_apisec_features_is_active() -> bool:
return asm_config._asm_enabled and asm_config._api_security_enabled and asm_config._api_security_sample_rate > 0.0
return asm_config._asm_enabled and asm_config._api_security_enabled


def _safe_userid(user_id):
Expand Down
2 changes: 1 addition & 1 deletion ddtrace/settings/asm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ASMConfig(Env):
_user_model_email_field = Env.var(str, APPSEC.USER_MODEL_EMAIL_FIELD, default="")
_user_model_name_field = Env.var(str, APPSEC.USER_MODEL_NAME_FIELD, default="")
_api_security_enabled = Env.var(bool, API_SECURITY.ENV_VAR_ENABLED, default=True)
_api_security_sample_rate = Env.var(float, API_SECURITY.SAMPLE_RATE, validator=_validate_sample_rate, default=0.1)
_api_security_sample_rate = 0.0
_api_security_parse_response_body = Env.var(bool, API_SECURITY.PARSE_RESPONSE_BODY, default=True)
_asm_libddwaf = build_libddwaf_filename()
_asm_libddwaf_available = os.path.exists(_asm_libddwaf)
Expand Down
6 changes: 6 additions & 0 deletions tests/appsec/contrib_appsec/django_app/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from django.views.decorators.csrf import csrf_exempt

from ddtrace import tracer
import ddtrace.constants


# django.conf.urls.url was deprecated in django 3 and removed in django 4
Expand Down Expand Up @@ -40,6 +41,11 @@ def multi_view(request, param_int=0, param_str=""):
}
status = int(query_params.get("status", "200"))
headers_query = query_params.get("headers", "").split(",")
priority = query_params.get("priority", None)
if priority in ("keep", "drop"):
tracer.current_span().set_tag(
ddtrace.constants.MANUAL_KEEP_KEY if priority == "keep" else ddtrace.constants.MANUAL_DROP_KEY
)
response_headers = {}
for header in headers_query:
vk = header.split("=")
Expand Down
21 changes: 20 additions & 1 deletion tests/appsec/contrib_appsec/fastapi_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from ddtrace import tracer
import ddtrace.constants


fake_secret_token = "DataDog"

Expand Down Expand Up @@ -51,6 +54,11 @@ async def multi_view(param_int: int, param_str: str, request: Request): # noqa:
}
status = int(query_params.get("status", "200"))
headers_query = query_params.get("headers", "").split(",")
priority = query_params.get("priority", None)
if priority in ("keep", "drop"):
tracer.current_span().set_tag(
ddtrace.constants.MANUAL_KEEP_KEY if priority == "keep" else ddtrace.constants.MANUAL_DROP_KEY
)
response_headers = {}
for header in headers_query:
vk = header.split("=")
Expand All @@ -71,7 +79,18 @@ async def multi_view_no_param(request: Request): # noqa: B008
"method": request.method,
}
status = int(query_params.get("status", "200"))
return JSONResponse(body, status_code=status)
headers_query = query_params.get("headers", "").split(",")
priority = query_params.get("priority", None)
if priority in ("keep", "drop"):
tracer.current_span().set_tag(
ddtrace.constants.MANUAL_KEEP_KEY if priority == "keep" else ddtrace.constants.MANUAL_DROP_KEY
)
response_headers = {}
for header in headers_query:
vk = header.split("=")
if len(vk) == 2:
response_headers[vk[0]] = vk[1]
return JSONResponse(body, status_code=status, headers=response_headers)

@app.get("/new_service/{service_name:str}/")
@app.post("/new_service/{service_name:str}/")
Expand Down
6 changes: 6 additions & 0 deletions tests/appsec/contrib_appsec/flask_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from flask import request

from ddtrace import tracer
import ddtrace.constants
from tests.webclient import PingFilter


Expand Down Expand Up @@ -36,6 +37,11 @@ def multi_view(param_int=0, param_str=""):
}
status = int(query_params.get("status", "200"))
headers_query = query_params.get("headers", "").split(",")
priority = query_params.get("priority", None)
if priority in ("keep", "drop"):
tracer.current_span().set_tag(
ddtrace.constants.MANUAL_KEEP_KEY if priority == "keep" else ddtrace.constants.MANUAL_DROP_KEY
)
response_headers = {}
for header in headers_query:
vk = header.split("=")
Expand Down
45 changes: 38 additions & 7 deletions tests/appsec/contrib_appsec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,6 @@ def test_nested_appsec_events(
({"User-Agent": "AllOK"}, False, False),
],
)
@pytest.mark.parametrize("sample_rate", [0.0, 1.0])
def test_api_security_schemas(
self,
interface: Interface,
Expand All @@ -897,16 +896,13 @@ def test_api_security_schemas(
headers,
event,
blocked,
sample_rate,
):
import base64
import gzip

from ddtrace.ext import http

with override_global_config(
dict(_asm_enabled=True, _api_security_enabled=apisec_enabled, _api_security_sample_rate=sample_rate)
):
with override_global_config(dict(_asm_enabled=True, _api_security_enabled=apisec_enabled)):
self.update_tracer(interface)
response = interface.client.post(
"/asm/324/huj/?x=1&y=2",
Expand All @@ -916,7 +912,6 @@ def test_api_security_schemas(
content_type="application/json",
)
assert asm_config._api_security_enabled == apisec_enabled
assert asm_config._api_security_sample_rate == sample_rate

assert self.status(response) == 403 if blocked else 200
assert get_tag(http.STATUS_CODE) == "403" if blocked else "200"
Expand All @@ -925,7 +920,7 @@ def test_api_security_schemas(
else:
assert get_triggers(root_span()) is None
value = get_tag(name)
if apisec_enabled and sample_rate:
if apisec_enabled:
assert value, name
api = json.loads(gzip.decompress(base64.b64decode(value)).decode())
assert api, name
Expand Down Expand Up @@ -976,6 +971,42 @@ def test_api_security_scanners(self, interface: Interface, get_tag, apisec_enabl
else:
assert value is None

@pytest.mark.parametrize("apisec_enabled", [True, False])
@pytest.mark.parametrize("priority", ["keep", "drop"])
def test_api_security_sampling(self, interface: Interface, get_tag, apisec_enabled, priority):
from ddtrace.ext import http

payload = {"mastercard": "5123456789123456"}
with override_global_config(dict(_asm_enabled=True, _api_security_enabled=apisec_enabled)):
self.update_tracer(interface)
response = interface.client.post(
f"/asm/?priority={priority}",
data=json.dumps(payload),
content_type="application/json",
)
assert self.status(response) == 200
assert get_tag(http.STATUS_CODE) == "200"
assert asm_config._api_security_enabled == apisec_enabled

value = get_tag("_dd.appsec.s.req.body")
if apisec_enabled and priority == "keep":
assert value
else:
assert value is None
# second request must be ignored
self.update_tracer(interface)
response = interface.client.post(
f"/asm/?priority={priority}",
data=json.dumps(payload),
content_type="application/json",
)
assert self.status(response) == 200
assert get_tag(http.STATUS_CODE) == "200"
assert asm_config._api_security_enabled == apisec_enabled

value = get_tag("_dd.appsec.s.req.body")
assert value is None

def test_request_invalid_rule_file(self, interface):
"""
When the rule file is invalid, the tracer should not crash or prevent normal behavior
Expand Down

0 comments on commit 8f58003

Please sign in to comment.