Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Invalid Tokens from cache when found #1333

Merged
merged 7 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,20 @@
/**
* This class is used to check the validity of a http request. It has methods that extract the
* bearer token, check if the token is empty or null, and if the token is valid. For example,
* expired tokens, empty tokens, or tokens not signed by our private key, will be invalid.
* expired tokens, empty tokens, or tokens not signed by our private key, will be invalid. Tokens
* are cached on first use, and removed if invalid.
*/
public class AuthRequestValidator {

private static final AuthRequestValidator INSTANCE = new AuthRequestValidator();

@Inject private AuthEngine jwtEngine;
@Inject private Cache keyCache;
@Inject Cache keyCache;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems this is the only inject we actually use in this class. Would a small doc be good?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure? I see jwtEngine, secrets, and logger being used below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some extra comments on the class. I think all the injects are used unless I'm misunderstanding what you mean.

@Inject private Secrets secrets;
@Inject private Logger logger;

String ourPublicKey = "trusted-intermediary-public-key-" + ApplicationContext.getEnvironment();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want this public key name hardcoded here? And if so do we need a #pragma allow secret

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It previously existed, but I moved it to a higher level so it's accessible in tests. I can add the #pragma allow secrets if we want but I also don't think it got flagged before as a secret.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool


private AuthRequestValidator() {}

public static AuthRequestValidator getInstance() {
Expand All @@ -49,12 +52,12 @@ public boolean isValidAuthenticatedRequest(DomainRequest request)
return true;
} catch (InvalidTokenException e) {
logger.logError("Invalid bearer token!", e);
this.keyCache.remove(ourPublicKey);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a log statement before removing the key from the cache to aid in debugging and operational monitoring. [important]

return false;
}
}

protected String retrievePublicKey() throws SecretRetrievalException {
var ourPublicKey = "trusted-intermediary-public-key-" + ApplicationContext.getEnvironment();
String key = this.keyCache.get(ourPublicKey);
if (key != null) {
return key;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import gov.hhs.cdc.trustedintermediary.external.inmemory.KeyCache
import gov.hhs.cdc.trustedintermediary.external.jjwt.JjwtEngine
import gov.hhs.cdc.trustedintermediary.wrappers.AuthEngine
import gov.hhs.cdc.trustedintermediary.wrappers.Cache
import gov.hhs.cdc.trustedintermediary.wrappers.InvalidTokenException
import gov.hhs.cdc.trustedintermediary.wrappers.Secrets
import spock.lang.Specification

Expand Down Expand Up @@ -227,21 +226,17 @@ class AuthRequestValidatorTest extends Specification{
def validator = AuthRequestValidator.getInstance()
def token = "fake-token-here"
def header = Map.of("Authorization", "Bearer " + token)
def mockEngine = Mock(JjwtEngine)
def mockCache = Mock(KeyCache)
def request = new DomainRequest()
def expected = false
TestApplicationContext.register(Cache, mockCache)
TestApplicationContext.register(AuthEngine, mockEngine)
TestApplicationContext.register(Cache, KeyCache.getInstance())
TestApplicationContext.injectRegisteredImplementations()

when:
request.setHeaders(header)
mockCache.get(_ as String) >> {"my-fake-private-key"}
mockEngine.validateToken(_ as String, _ as String) >> { throw new InvalidTokenException(new Throwable("fake exception"))}
def actual = validator.isValidAuthenticatedRequest(request)

then:
actual == expected
validator.keyCache.get(validator.ourPublicKey) == null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,9 @@ public void put(String key, String value) {
public String get(String key) {
return keys.get(key);
}

@Override
public void remove(String key) {
keys.remove(key);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ public interface Cache {
void put(String key, String value);

String get(String key);

void remove(String key);
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,18 @@ class KeyCacheTest extends Specification {
keys.values().toSet().size() == 1 // all entries have same value, threads had to wait on the lock

}

def "keyCache removal works"() {
given:
def cache = KeyCache.getInstance()
def value = "fake_key"
def key = "report_stream"
def expected = null
when:
cache.put(key, value)
cache.remove(key)
def actual = cache.get(key)
then:
actual == expected
}
}