Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement session logout in OAuth2Authenticator #18753

Merged
merged 5 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.trino.server.security.oauth2.OAuth2ServerConfigProvider.OAuth2ServerConfig;
import jakarta.ws.rs.core.UriBuilder;

import java.net.MalformedURLException;
import java.net.URI;
Expand Down Expand Up @@ -101,6 +102,7 @@ public class NimbusOAuth2Client
private URI authUrl;
private URI tokenUrl;
private Optional<URI> userinfoUrl;
private Optional<URI> endSessionUrl;
private JWSKeySelector<SecurityContext> jwsKeySelector;
private JWTProcessor<SecurityContext> accessTokenProcessor;
private AuthorizationCodeFlow flow;
Expand Down Expand Up @@ -128,13 +130,14 @@ public NimbusOAuth2Client(OAuth2Config oauthConfig, OAuth2ServerConfigProvider s
public void load()
{
OAuth2ServerConfig config = serverConfigurationProvider.get();
this.authUrl = config.getAuthUrl();
this.tokenUrl = config.getTokenUrl();
this.userinfoUrl = config.getUserinfoUrl();
this.authUrl = config.authUrl();
this.tokenUrl = config.tokenUrl();
this.userinfoUrl = config.userinfoUrl();
this.endSessionUrl = config.endSessionUrl();
try {
jwsKeySelector = new JWSVerificationKeySelector<>(
Stream.concat(JWSAlgorithm.Family.RSA.stream(), JWSAlgorithm.Family.EC.stream()).collect(toImmutableSet()),
JWKSourceBuilder.create(config.getJwksUrl().toURL(), httpClient).build());
JWKSourceBuilder.create(config.jwksUrl().toURL(), httpClient).build());
}
catch (MalformedURLException e) {
throw new RuntimeException(e);
Expand All @@ -148,7 +151,7 @@ public void load()
DefaultJWTClaimsVerifier<SecurityContext> accessTokenVerifier = new DefaultJWTClaimsVerifier<>(
accessTokenAudiences,
new JWTClaimsSet.Builder()
.issuer(config.getAccessTokenIssuer().orElse(issuer.getValue()))
.issuer(config.accessTokenIssuer().orElse(issuer.getValue()))
.build(),
ImmutableSet.of(principalField),
ImmutableSet.of());
Expand Down Expand Up @@ -189,6 +192,18 @@ public Response refreshTokens(String refreshToken)
return flow.refreshTokens(refreshToken);
}

@Override
public Optional<URI> getLogoutEndpoint(Optional<String> idToken, URI callbackUrl)
{
if (endSessionUrl.isPresent()) {
UriBuilder builder = UriBuilder.fromUri(endSessionUrl.get());
idToken.ifPresent(token -> builder.queryParam("id_token_hint", token));
builder.queryParam("post_logout_redirect_uri", callbackUrl);
return Optional.of(builder.build());
}
return Optional.empty();
}

private interface AuthorizationCodeFlow
{
Request createAuthorizationRequest(String state, URI callbackUri);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ Response getOAuth2Response(String code, URI callbackUri, Optional<String> nonce)
Response refreshTokens(String refreshToken)
throws ChallengeFailedException;

Optional<URI> getLogoutEndpoint(Optional<String> idToken, URI callbackUrl);

class Request
{
private final URI authorizationUri;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,46 +22,16 @@ public interface OAuth2ServerConfigProvider
{
OAuth2ServerConfig get();

class OAuth2ServerConfig
record OAuth2ServerConfig(Optional<String> accessTokenIssuer, URI authUrl, URI tokenUrl, URI jwksUrl, Optional<URI> userinfoUrl, Optional<URI> endSessionUrl)
{
private final Optional<String> accessTokenIssuer;
private final URI authUrl;
private final URI tokenUrl;
private final URI jwksUrl;
private final Optional<URI> userinfoUrl;

public OAuth2ServerConfig(Optional<String> accessTokenIssuer, URI authUrl, URI tokenUrl, URI jwksUrl, Optional<URI> userinfoUrl)
{
this.accessTokenIssuer = requireNonNull(accessTokenIssuer, "accessTokenIssuer is null");
this.authUrl = requireNonNull(authUrl, "authUrl is null");
this.tokenUrl = requireNonNull(tokenUrl, "tokenUrl is null");
this.jwksUrl = requireNonNull(jwksUrl, "jwksUrl is null");
this.userinfoUrl = requireNonNull(userinfoUrl, "userinfoUrl is null");
}

public Optional<String> getAccessTokenIssuer()
{
return accessTokenIssuer;
}

public URI getAuthUrl()
{
return authUrl;
}

public URI getTokenUrl()
{
return tokenUrl;
}

public URI getJwksUrl()
{
return jwksUrl;
}

public Optional<URI> getUserinfoUrl()
public OAuth2ServerConfig
{
return userinfoUrl;
requireNonNull(accessTokenIssuer, "accessTokenIssuer is null");
requireNonNull(authUrl, "authUrl is null");
requireNonNull(tokenUrl, "tokenUrl is null");
requireNonNull(jwksUrl, "jwksUrl is null");
requireNonNull(userinfoUrl, "userinfoUrl is null");
requireNonNull(endSessionUrl, "endSessionUrl is null");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtParser;
import io.trino.server.ui.OAuth2WebUiInstalled;
import io.trino.server.ui.OAuthIdTokenCookie;
import io.trino.server.ui.OAuthWebUiCookie;
import jakarta.ws.rs.core.Response;

Expand Down Expand Up @@ -167,29 +168,31 @@ public Response finishOAuth2Challenge(String state, String code, URI callbackUri
// fetch access token
OAuth2Client.Response oauth2Response = client.getOAuth2Response(code, callbackUri, nonce);

Instant cookieExpirationTime = tokenExpiration
.map(expiration -> Instant.now().plus(expiration))
.orElse(oauth2Response.getExpiration());
if (handlerState.isEmpty()) {
return Response
Response.ResponseBuilder builder = Response
.seeOther(URI.create(UI_LOCATION))
.cookie(
OAuthWebUiCookie.create(
tokenPairSerializer.serialize(
fromOAuth2Response(oauth2Response)),
tokenExpiration
.map(expiration -> Instant.now().plus(expiration))
.orElse(oauth2Response.getExpiration())),
NonceCookie.delete())
.build();
OAuthWebUiCookie.create(tokenPairSerializer.serialize(fromOAuth2Response(oauth2Response)), cookieExpirationTime),
NonceCookie.delete());
if (oauth2Response.getIdToken().isPresent()) {
builder.cookie(OAuthIdTokenCookie.create(oauth2Response.getIdToken().get(), cookieExpirationTime));
}
return builder.build();
}

tokenHandler.setAccessToken(handlerState.get(), tokenPairSerializer.serialize(fromOAuth2Response(oauth2Response)));

Response.ResponseBuilder builder = Response.ok(getSuccessHtml());
if (webUiOAuthEnabled) {
builder.cookie(
OAuthWebUiCookie.create(
tokenPairSerializer.serialize(fromOAuth2Response(oauth2Response)),
tokenExpiration.map(expiration -> Instant.now().plus(expiration))
.orElse(oauth2Response.getExpiration())));
OAuthWebUiCookie.create(tokenPairSerializer.serialize(fromOAuth2Response(oauth2Response)), cookieExpirationTime));

if (oauth2Response.getIdToken().isPresent()) {
builder.cookie(OAuthIdTokenCookie.create(oauth2Response.getIdToken().get(), cookieExpirationTime));
}
}
return builder.cookie(NonceCookie.delete()).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import static io.airlift.http.client.HttpStatus.TOO_MANY_REQUESTS;
import static io.trino.server.security.oauth2.StaticOAuth2ServerConfiguration.ACCESS_TOKEN_ISSUER;
import static io.trino.server.security.oauth2.StaticOAuth2ServerConfiguration.AUTH_URL;
import static io.trino.server.security.oauth2.StaticOAuth2ServerConfiguration.END_SESSION_URL;
import static io.trino.server.security.oauth2.StaticOAuth2ServerConfiguration.JWKS_URL;
import static io.trino.server.security.oauth2.StaticOAuth2ServerConfiguration.TOKEN_URL;
import static io.trino.server.security.oauth2.StaticOAuth2ServerConfiguration.USERINFO_URL;
Expand Down Expand Up @@ -114,6 +115,7 @@ private OAuth2ServerConfig readConfiguration(String body)
else {
userinfoEndpoint = Optional.empty();
}
Optional<URI> endSessionEndpoint = Optional.of(getRequiredField("end_session_endpoint", metadata.getEndSessionEndpointURI(), END_SESSION_URL, Optional.empty()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should provide an override value for END_SESSION_URL which comes from the OidcDiscoveryConfig. Similar to the other URL values. This allows non-standard endpoint values.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point but OidcDiscoveryConfig actually allows providing urls explicitly (effectively overriding those obtained from metadata discovery) for backward compatibility, allowing for non-standard endpoint values wasn't the objective. We intend to drop these properties #18101
That said, is it something you think we should support? Is this something you use these properties for?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This END_SESSION_URL, Optional.empty()) becomes this END_SESSION_URL, logoutUrl), where logoutUrl is set in the constructor

logoutUrl = requireNonNull(oidcConfig.getEndSessionUrl(), "logoutUrl is null");

I had to make these changes for a customer, who is using Ping. The "standard" endpoint found with end_session_endpoint is not being found because the well-known endpoint is returning ping_end_session_endpoint. So the customer had to be able to specify the logout URL explicitly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if we disable the OIDC discovery then we could inject them via StaticConfigurationProvider - we could provide custom end-points right ? cc : @lukasz-walkiewicz Are we removing that endpoint also ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, StaticConfigurationProvider is here to stay. I think we ultimately would like to remove this overriding some day and users should provide configuration explicitly if the IdP does not conform OIDC spec.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thorbjornsen In this case we could still use StaticOAuth2ServerConfiguration but we need to specify the configuration explicitly.

return new OAuth2ServerConfig(
// AD FS server can include "access_token_issuer" field in OpenID Provider Metadata.
// It's not a part of the OIDC standard thus have to be handled separately.
Expand All @@ -122,7 +124,8 @@ private OAuth2ServerConfig readConfiguration(String body)
getRequiredField("authorization_endpoint", metadata.getAuthorizationEndpointURI(), AUTH_URL, authUrl),
getRequiredField("token_endpoint", metadata.getTokenEndpointURI(), TOKEN_URL, tokenUrl),
getRequiredField("jwks_uri", metadata.getJWKSetURI(), JWKS_URL, jwksUrl),
userinfoEndpoint.map(URI::create));
userinfoEndpoint.map(URI::create),
endSessionEndpoint);
}
catch (JsonProcessingException e) {
throw new ParseException("Invalid JSON value", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public class StaticConfigurationProvider
URI.create(config.getAuthUrl()),
URI.create(config.getTokenUrl()),
URI.create(config.getJwksUrl()),
config.getUserinfoUrl().map(URI::create));
config.getUserinfoUrl().map(URI::create),
config.getEndSessionUrl().map(URI::create));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ public class StaticOAuth2ServerConfiguration
public static final String TOKEN_URL = "http-server.authentication.oauth2.token-url";
public static final String JWKS_URL = "http-server.authentication.oauth2.jwks-url";
public static final String USERINFO_URL = "http-server.authentication.oauth2.userinfo-url";
public static final String END_SESSION_URL = "http-server.authentication.oauth2.end-session-url";

private Optional<String> accessTokenIssuer = Optional.empty();
private String authUrl;
private String tokenUrl;
private String jwksUrl;
private Optional<String> userinfoUrl = Optional.empty();
private Optional<String> endSessionUrl = Optional.empty();

@NotNull
public Optional<String> getAccessTokenIssuer()
Expand Down Expand Up @@ -101,4 +103,17 @@ public StaticOAuth2ServerConfiguration setUserinfoUrl(String userinfoUrl)
this.userinfoUrl = Optional.ofNullable(userinfoUrl);
return this;
}

public Optional<String> getEndSessionUrl()
{
return endSessionUrl;
}

@Config(END_SESSION_URL)
@ConfigDescription("URL of the end session endpoint")
public StaticOAuth2ServerConfiguration setEndSessionUrl(String endSessionUrl)
{
this.endSessionUrl = Optional.ofNullable(endSessionUrl);
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import static io.trino.server.ui.FormWebUiAuthenticationFilter.DISABLED_LOCATION;
import static io.trino.server.ui.FormWebUiAuthenticationFilter.DISABLED_LOCATION_URI;
import static io.trino.server.ui.FormWebUiAuthenticationFilter.TRINO_FORM_LOGIN;
import static io.trino.server.ui.OAuthIdTokenCookie.ID_TOKEN_COOKIE;
import static io.trino.server.ui.OAuthWebUiCookie.OAUTH2_COOKIE;
import static jakarta.ws.rs.core.Response.Status.UNAUTHORIZED;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -163,9 +164,14 @@ private void redirectForNewToken(ContainerRequestContext request, String refresh
{
OAuth2Client.Response response = client.refreshTokens(refreshToken);
String serializedToken = tokenPairSerializer.serialize(TokenPair.fromOAuth2Response(response));
request.abortWith(Response.temporaryRedirect(request.getUriInfo().getRequestUri())
.cookie(OAuthWebUiCookie.create(serializedToken, tokenExpiration.map(expiration -> Instant.now().plus(expiration)).orElse(response.getExpiration())))
.build());
Instant newExpirationTime = tokenExpiration.map(expiration -> Instant.now().plus(expiration)).orElse(response.getExpiration());
Response.ResponseBuilder builder = Response.temporaryRedirect(request.getUriInfo().getRequestUri())
.cookie(OAuthWebUiCookie.create(serializedToken, newExpirationTime));

OAuthIdTokenCookie.read(request.getCookies().get(ID_TOKEN_COOKIE))
.ifPresent(idToken -> builder.cookie(OAuthIdTokenCookie.create(idToken, newExpirationTime)));

request.abortWith(builder.build());
}

private void handleAuthenticationFailure(ContainerRequestContext request)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,63 @@
package io.trino.server.ui;

import com.google.common.io.Resources;
import com.google.inject.Inject;
import io.trino.server.security.ResourceSecurity;
import io.trino.server.security.oauth2.OAuth2Client;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.SecurityContext;
import jakarta.ws.rs.core.UriBuilder;
import jakarta.ws.rs.core.UriInfo;

import java.io.IOException;
import java.net.URI;
import java.util.Optional;

import static io.trino.server.security.ResourceSecurity.AccessType.PUBLIC;
import static io.trino.server.security.ResourceSecurity.AccessType.WEB_UI;
import static io.trino.server.ui.FormWebUiAuthenticationFilter.UI_LOGOUT;
import static io.trino.server.ui.OAuthIdTokenCookie.ID_TOKEN_COOKIE;
import static io.trino.server.ui.OAuthWebUiCookie.delete;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;

@Path(UI_LOGOUT)
public class OAuth2WebUiLogoutResource
{
private final OAuth2Client auth2Client;

@Inject
public OAuth2WebUiLogoutResource(OAuth2Client auth2Client)
{
this.auth2Client = requireNonNull(auth2Client, "auth2Client is null");
}

@ResourceSecurity(WEB_UI)
@GET
public Response logout(@Context HttpHeaders httpHeaders, @Context UriInfo uriInfo, @Context SecurityContext securityContext)
throws IOException
{
Optional<String> idToken = OAuthIdTokenCookie.read(httpHeaders.getCookies().get(ID_TOKEN_COOKIE));
URI callBackUri = UriBuilder.fromUri(uriInfo.getAbsolutePath())
.path("logout.html")
.build();

return Response.seeOther(auth2Client.getLogoutEndpoint(idToken, callBackUri).orElse(callBackUri))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to call sso logout ourselves behind the scenes instead of depending on user browser's doing that?
Otherwise we could end up in situation where the user doesn't follow that redirect for any reason and the value is still valid.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The spec here mentions that

An RP requests that the OP log out the End-User by redirecting the End-User's User Agent to the OP's Logout Endpoint. This URL is normally obtained via the end_session_endpoint element of the OP's Discovery response or may be learned via other mechanisms.

So I guess we should redirect on the User browser instead of handling it. Additionally each IdP can have its own way to redirecting the user right ?

.cookie(delete(), OAuthIdTokenCookie.delete())
.build();
}

@ResourceSecurity(PUBLIC)
@GET
@Path("/logout.html")
public Response logoutPage(@Context HttpHeaders httpHeaders, @Context UriInfo uriInfo, @Context SecurityContext securityContext)
throws IOException
{
return Response.ok(Resources.toString(Resources.getResource(getClass(), "/oauth2/logout.html"), UTF_8))
.cookie(delete())
.build();
}
}
Loading