Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POC] auth support for paasta APIs #1005

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions tests/api/auth_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest

from tron.api.auth import AuthorizationFilter
from tron.api.auth import AuthorizationOutcome


@pytest.fixture
def mock_auth_filter():
with patch("tron.api.auth.requests"):
yield AuthorizationFilter("http://localhost:31337/whatever", True)


def mock_request(path: str, token: str, method: str):
res = MagicMock(path=path.encode(), method=method.encode())
res.getHeader.return_value = token
return res


def test_is_request_authorized(mock_auth_filter):
mock_auth_filter.session.post.return_value.json.return_value = {
"result": {"allowed": True, "reason": "User allowed"}
}
assert mock_auth_filter.is_request_authorized(
mock_request("/allowed", "aaa.bbb.ccc", "get")
) == AuthorizationOutcome(True, "User allowed")
mock_auth_filter.session.post.assert_called_once_with(
url="http://localhost:31337/whatever",
json={
"input": {
"path": "/allowed",
"backend": "tron",
"token": "aaa.bbb.ccc",
"method": "get",
}
},
timeout=2,
)


def test_is_request_authorized_fail(mock_auth_filter):
mock_auth_filter.session.post.side_effect = Exception
assert mock_auth_filter.is_request_authorized(
mock_request("/allowed", "eee.ddd.fff", "get")
) == AuthorizationOutcome(False, "Auth backend error")


def test_is_request_authorized_malformed(mock_auth_filter):
mock_auth_filter.session.post.return_value.json.return_value = {"foo": "bar"}
assert mock_auth_filter.is_request_authorized(
mock_request("/allowed", "eee.ddd.fff", "post")
) == AuthorizationOutcome(False, "Malformed auth response")


def test_is_request_authorized_no_enforce(mock_auth_filter):
mock_auth_filter.session.post.return_value.json.return_value = {
"result": {"allowed": False, "reason": "Missing token"}
}
with patch.object(mock_auth_filter, "enforce", False):
assert mock_auth_filter.is_request_authorized(mock_request("/foobar", "", "post")) == AuthorizationOutcome(
True, "Auth dry-run"
)


def test_is_request_authorized_disabled(mock_auth_filter):
mock_auth_filter.session.post.return_value.json.return_value = {
"result": {"allowed": False, "reason": "Missing token"}
}
with patch.object(mock_auth_filter, "endpoint", None):
assert mock_auth_filter.is_request_authorized(mock_request("/buzz", "", "post")) == AuthorizationOutcome(
True, "Auth not enabled"
)
94 changes: 94 additions & 0 deletions tron/api/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import logging
import os
from functools import lru_cache
from typing import NamedTuple

import cachetools.func
import requests
from twisted.web.server import Request


logger = logging.getLogger(__name__)
AUTH_CACHE_SIZE = 50000
AUTH_CACHE_TTL = 30 * 60


class AuthorizationOutcome(NamedTuple):
authorized: bool
reason: str


class AuthorizationFilter:
"""API request authorization via external system"""

def __init__(self, endpoint: str, enforce: bool):
"""Constructor

:param str endpoint: HTTP endpoint of external authorization system
:param bool enforce: whether to enforce authorization decisions
"""
self.endpoint = endpoint
self.enforce = enforce
self.session = requests.Session()

@classmethod
@lru_cache(maxsize=1)
def get_from_env(cls) -> "AuthorizationFilter":
return cls(
endpoint=os.getenv("API_AUTH_ENDPOINT", ""),
enforce=bool(os.getenv("API_AUTH_ENFORCE", "")),
)

def is_request_authorized(self, request: Request) -> AuthorizationOutcome:
"""Check if API request is authorized

:param Request request: API request object
:return: auth outcome
"""
if not self.endpoint:
return AuthorizationOutcome(True, "Auth not enabled")
token = (request.getHeader("Authorization") or "").strip()
token = token.split()[-1] if token else "" # removes "Bearer" prefix
auth_outcome = self._is_request_authorized_impl(
# path and method are byte arrays in twisted
path=request.path.decode(),
token=token,
method=request.method.decode(),
)
return auth_outcome if self.enforce else AuthorizationOutcome(True, "Auth dry-run")

@cachetools.func.ttl_cache(maxsize=AUTH_CACHE_SIZE, ttl=AUTH_CACHE_TTL)
def _is_request_authorized_impl(self, path: str, token: str, method: str) -> AuthorizationOutcome:
"""Check if API request is authorized

:param str path: API path
:param str token: authentication token
:param str method: http method
:return: auth outcome
"""
try:
response = self.session.post(
url=self.endpoint,
json={
"input": {
"path": path,
"backend": "tron",
"token": token,
"method": method.lower(),
},
},
timeout=2,
).json()
except Exception as e:
logger.exception(f"Issue communicating with auth endpoint: {e}")
return AuthorizationOutcome(False, "Auth backend error")

if "result" not in response or "allowed" not in response["result"]:
return AuthorizationOutcome(False, "Malformed auth response")

if not response["result"]["allowed"]:
reason = response["result"].get("reason", "Denied")
return AuthorizationOutcome(False, reason)

reason = response["result"].get("reason", "Ok")
return AuthorizationOutcome(True, reason)
13 changes: 13 additions & 0 deletions tron/api/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tron.api import adapter, controller
from tron.api import requestargs
from tron.api.async_resource import AsyncResource
from tron.api.auth import AuthorizationFilter
from tron.metrics import view_all_metrics
from tron.metrics import meter
from tron.utils import maybe_decode
Expand Down Expand Up @@ -514,6 +515,18 @@ def render_GET(self, request):
}
return respond(request=request, response=response)

def render(self, request):
"""Overriding base `render` method to support auth"""
auth_outcome = AuthorizationFilter.get_from_env().is_request_authorized(request)
if not auth_outcome.authorized:
return respond(
request=request,
response={"reason": auth_outcome.reason},
code=http.FORBIDDEN,
headers={"X-Auth-Failure-Reason": auth_outcome.reason},
)
return super().render(request)


class RootResource(resource.Resource):
def __init__(self, mcp, web_path):
Expand Down
13 changes: 13 additions & 0 deletions tron/commands/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,22 @@ class RequestError(ValueError):
}


def get_sso_auth_token() -> str:
"""Generate an authentication token for the calling user from the Single Sign On provider, if configured"""
from okta_auth import get_and_cache_jwt_default # type: ignore
from tron.commands.cmd_utils import get_client_config

client_id = get_client_config().get("auth_sso_oidc_client_id")
return get_and_cache_jwt_default(client_id) if client_id else "" # type: ignore


def build_url_request(uri, data, headers=None, method=None):
headers = headers or default_headers
enc_data = urllib.parse.urlencode(data).encode() if data else None
if os.getenv("TRONCTL_API_AUTH") and (data or method.upper() == "POST"):
token = get_sso_auth_token()
if token:
headers["Authorization"] = f"Bearer {token}"
return urllib.request.Request(uri, enc_data, headers=headers, method=method)


Expand Down
3 changes: 3 additions & 0 deletions yelp_package/extra_requirements_yelp.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ clusterman-metrics==2.2.1 # used by tron for pre-scaling for Spark runs
dateglob==1.1.1 # required by yelp-logging
geogrid==2.1.0 # required by yelp-logging
monk==3.0.4 # required by yelp-clog
okta-auth==1.0.1 # used for API auth
ply==3.11 # required by thriftpy2
pyjwt==2.9.0 # required by okta-auth
saml-helper==2.3.3 # required by okta-auth
scribereader==1.1.1 # used by tron to get tronjob logs
simplejson==3.19.2 # required by yelp-logging
srv-configs==1.3.4 # required by monk
Expand Down
Loading