Skip to content

Commit

Permalink
chore(iast): taint parameter name and header name in fastapi (#12009)
Browse files Browse the repository at this point in the history
Code security: Taint FastAPI parameter name and header name in each
request.

## Checklist
- [x] PR author has checked that all the criteria below are met
- The PR description includes an overview of the change
- The PR description articulates the motivation for the change
- The change includes tests OR the PR description describes a testing
strategy
- The PR description notes risks associated with the change, if any
- Newly-added code is easy to change
- The change follows the [library release note
guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
- The change includes or references documentation updates if necessary
- Backport labels are set (if
[applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))

## Reviewer Checklist
- [x] Reviewer has checked that all the criteria below are met 
- Title is accurate
- All changes are related to the pull request's stated goal
- Avoids breaking
[API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces)
changes
- Testing strategy adequately addresses listed risks
- Newly-added code is easy to change
- Release note makes sense to a user of the library
- If necessary, author has acknowledged and discussed the performance
implications of this PR as reported in the benchmarks PR comment
- Backport labels are set in a manner that is consistent with the
[release branch maintenance
policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)
  • Loading branch information
avara1986 authored Jan 22, 2025
1 parent 922c71b commit 6a10743
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 9 deletions.
38 changes: 37 additions & 1 deletion ddtrace/appsec/_iast/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ def if_iast_taint_yield_tuple_for(origins, wrapped, instance, args, kwargs):

def if_iast_taint_returned_object_for(origin, wrapped, instance, args, kwargs):
value = wrapped(*args, **kwargs)

if _is_iast_enabled() and is_iast_request_enabled():
try:
if not is_pyobject_tainted(value):
Expand All @@ -310,6 +309,29 @@ def if_iast_taint_returned_object_for(origin, wrapped, instance, args, kwargs):
return value


def if_iast_taint_starlette_datastructures(origin, wrapped, instance, args, kwargs):
value = wrapped(*args, **kwargs)
if _is_iast_enabled() and is_iast_request_enabled():
try:
res = []
for element in value:
if not is_pyobject_tainted(element):
res.append(
taint_pyobject(
pyobject=element,
source_name=origin_to_str(origin),
source_value=element,
source_origin=origin,
)
)
else:
res.append(element)
return res
except Exception:
log.debug("Unexpected exception while tainting pyobject", exc_info=True)
return value


def _on_iast_fastapi_patch():
# Cookies sources
try_wrap_function_wrapper(
Expand All @@ -333,6 +355,13 @@ def _on_iast_fastapi_patch():
)
_set_metric_iast_instrumented_source(OriginType.PARAMETER)

try_wrap_function_wrapper(
"starlette.datastructures",
"QueryParams.keys",
functools.partial(if_iast_taint_starlette_datastructures, OriginType.PARAMETER_NAME),
)
_set_metric_iast_instrumented_source(OriginType.PARAMETER_NAME)

# Header sources
try_wrap_function_wrapper(
"starlette.datastructures",
Expand All @@ -346,6 +375,13 @@ def _on_iast_fastapi_patch():
)
_set_metric_iast_instrumented_source(OriginType.HEADER)

try_wrap_function_wrapper(
"starlette.datastructures",
"Headers.keys",
functools.partial(if_iast_taint_starlette_datastructures, OriginType.HEADER_NAME),
)
_set_metric_iast_instrumented_source(OriginType.HEADER_NAME)

# Path source
try_wrap_function_wrapper("starlette.datastructures", "URL.__init__", _iast_instrument_starlette_url)
_set_metric_iast_instrumented_source(OriginType.PATH)
Expand Down
74 changes: 66 additions & 8 deletions tests/contrib/fastapi/test_fastapi_appsec_iast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from ddtrace.appsec._iast import oce
from ddtrace.appsec._iast._handlers import _on_iast_fastapi_patch
from ddtrace.appsec._iast._patch_modules import patch_iast
from ddtrace.appsec._iast._taint_tracking import origin_to_str
from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
from ddtrace.appsec._iast.constants import VULN_HEADER_INJECTION
from ddtrace.appsec._iast.constants import VULN_INSECURE_COOKIE
from ddtrace.appsec._iast.constants import VULN_NO_HTTPONLY_COOKIE
Expand All @@ -34,8 +36,6 @@
TEST_FILE_PATH = "tests/contrib/fastapi/test_fastapi_appsec_iast.py"

fastapi_version = tuple([int(v) for v in _fastapi_version.split(".")])
if sys.version_info > (3, 12):
pytest.skip(reason="IAST only supports Py3.12 and older", allow_module_level=True)


def _aux_appsec_prepare_tracer(tracer):
Expand Down Expand Up @@ -78,9 +78,6 @@ def check_native_code_exception_in_each_fastapi_test(request, caplog, telemetry_
def test_query_param_source(fastapi_application, client, tracer, test_spans):
@fastapi_application.get("/index.html")
async def test_route(request: Request):
from ddtrace.appsec._iast._taint_tracking import origin_to_str
from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges

query_params = request.query_params.get("iast_queryparam")
ranges_result = get_tainted_ranges(query_params)

Expand Down Expand Up @@ -110,12 +107,41 @@ async def test_route(request: Request):
assert result["ranges_origin"] == "http.request.parameter"


def test_header_value_source(fastapi_application, client, tracer, test_spans):
def test_query_param_name_source(fastapi_application, client, tracer, test_spans):
@fastapi_application.get("/index.html")
async def test_route(request: Request):
from ddtrace.appsec._iast._taint_tracking import origin_to_str
from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
query_params = [k for k in request.query_params.keys() if k == "iast_queryparam"][0]
ranges_result = get_tainted_ranges(query_params)

return JSONResponse(
{
"result": query_params,
"is_tainted": len(ranges_result),
"ranges_start": ranges_result[0].start,
"ranges_length": ranges_result[0].length,
"ranges_origin": origin_to_str(ranges_result[0].source.origin),
}
)

with override_global_config(dict(_iast_enabled=True, _iast_request_sampling=100.0)):
# disable callback
_aux_appsec_prepare_tracer(tracer)
resp = client.get(
"/index.html?iast_queryparam=test1234",
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 200
result = json.loads(get_response_body(resp))
assert result["result"] == "iast_queryparam"
assert result["is_tainted"] == 1
assert result["ranges_start"] == 0
assert result["ranges_length"] == 15
assert result["ranges_origin"] == "http.request.parameter.name"


def test_header_value_source(fastapi_application, client, tracer, test_spans):
@fastapi_application.get("/index.html")
async def test_route(request: Request):
query_params = request.headers.get("iast_header")
ranges_result = get_tainted_ranges(query_params)

Expand Down Expand Up @@ -145,6 +171,38 @@ async def test_route(request: Request):
assert result["ranges_origin"] == "http.request.header"


def test_header_name_source(fastapi_application, client, tracer, test_spans):
@fastapi_application.get("/index.html")
async def test_route(request: Request):
query_params = [k for k in request.headers.keys() if k == "iast_header"][0]
ranges_result = get_tainted_ranges(query_params)

return JSONResponse(
{
"result": query_params,
"is_tainted": len(ranges_result),
"ranges_start": ranges_result[0].start,
"ranges_length": ranges_result[0].length,
"ranges_origin": origin_to_str(ranges_result[0].source.origin),
}
)

with override_global_config(dict(_iast_enabled=True, _iast_request_sampling=100.0)):
# disable callback
_aux_appsec_prepare_tracer(tracer)
resp = client.get(
"/index.html",
headers={"iast_header": "test1234"},
)
assert resp.status_code == 200
result = json.loads(get_response_body(resp))
assert result["result"] == "iast_header"
assert result["is_tainted"] == 1
assert result["ranges_start"] == 0
assert result["ranges_length"] == 11
assert result["ranges_origin"] == "http.request.header.name"


@pytest.mark.skipif(sys.version_info < (3, 9), reason="typing.Annotated was introduced on 3.9")
@pytest.mark.skipif(fastapi_version < (0, 95, 0), reason="Header annotation doesn't work on fastapi 94 or lower")
def test_header_value_source_typing_param(fastapi_application, client, tracer, test_spans):
Expand Down

0 comments on commit 6a10743

Please sign in to comment.