Skip to content

Commit

Permalink
Pass ACI to captcha checker
Browse files Browse the repository at this point in the history
  • Loading branch information
ameya-signal committed Oct 31, 2024
1 parent ce0ccf4 commit 190f2a7
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;
import javax.ws.rs.BadRequestException;
import org.slf4j.Logger;
Expand Down Expand Up @@ -42,6 +44,7 @@ public CaptchaChecker(
/**
* Check if a solved captcha should be accepted
*
* @param maybeAci optional account UUID of the user solving the captcha
* @param expectedAction the {@link Action} for which this captcha solution is intended
* @param input expected to contain a prefix indicating the captcha scheme, sitekey, token, and action. The
* expected format is {@code version-prefix.sitekey.action.token}
Expand All @@ -53,6 +56,7 @@ public CaptchaChecker(
* @throws BadRequestException if input is not in the expected format
*/
public AssessmentResult verify(
final Optional<UUID> maybeAci,
final Action expectedAction,
final String input,
final String ip,
Expand Down Expand Up @@ -100,7 +104,7 @@ public AssessmentResult verify(
throw new BadRequestException("invalid captcha site-key");
}

final AssessmentResult result = client.verify(siteKey, parsedAction, token, ip, userAgent);
final AssessmentResult result = client.verify(maybeAci, siteKey, parsedAction, token, ip, userAgent);
Metrics.counter(ASSESSMENTS_COUNTER_NAME,
"action", action,
"score", result.getScoreString(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
package org.whispersystems.textsecuregcm.captcha;

import java.io.IOException;
import java.util.Collections;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;

public interface CaptchaClient {

Expand All @@ -27,6 +27,7 @@ public interface CaptchaClient {
/**
* Verify a provided captcha solution
*
* @param maybeAci optional account service identifier of the user
* @param siteKey identifying string for the captcha service
* @param action an action indicating the purpose of the captcha
* @param token the captcha solution that will be verified
Expand All @@ -36,6 +37,7 @@ public interface CaptchaClient {
* @throws IOException if the underlying captcha provider returns an error
*/
AssessmentResult verify(
final Optional<UUID> maybeAci,
final String siteKey,
final Action action,
final String token,
Expand All @@ -55,7 +57,7 @@ public Set<String> validSiteKeys(final Action action) {
}

@Override
public AssessmentResult verify(final String siteKey, final Action action, final String token, final String ip,
public AssessmentResult verify(final Optional<UUID> maybeAci, final String siteKey, final Action action, final String token, final String ip,
final String userAgent) throws IOException {
return AssessmentResult.alwaysValid();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import java.io.IOException;
import java.util.Optional;
import java.util.UUID;

public class RegistrationCaptchaManager {

Expand All @@ -17,10 +18,10 @@ public RegistrationCaptchaManager(final CaptchaChecker captchaChecker) {
}

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public Optional<AssessmentResult> assessCaptcha(final Optional<String> captcha, final String sourceHost, final String userAgent)
public Optional<AssessmentResult> assessCaptcha(final Optional<UUID> aci, final Optional<String> captcha, final String sourceHost, final String userAgent)
throws IOException {
return captcha.isPresent()
? Optional.of(captchaChecker.verify(Action.REGISTRATION, captcha.get(), sourceHost, userAgent))
? Optional.of(captchaChecker.verify(aci, Action.REGISTRATION, captcha.get(), sourceHost, userAgent))
: Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ private VerificationSession handleCaptcha(
try {

assessmentResult = registrationCaptchaManager.assessCaptcha(
Optional.empty(),
Optional.of(updateVerificationSessionRequest.captcha()), sourceHost, userAgent)
.orElseThrow(() -> new ServerErrorException(Response.Status.INTERNAL_SERVER_ERROR));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import static com.codahale.metrics.MetricRegistry.name;

import com.google.i18n.phonenumbers.NumberParseException;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
Expand All @@ -16,12 +19,15 @@
import org.whispersystems.textsecuregcm.captcha.Action;
import org.whispersystems.textsecuregcm.captcha.CaptchaChecker;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.ChallengeType;
import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Util;
import javax.ws.rs.ServerErrorException;
import javax.ws.rs.core.Response;

public class RateLimitChallengeManager {

Expand Down Expand Up @@ -67,7 +73,7 @@ public boolean answerCaptchaChallenge(final Account account, final String captch

rateLimiters.getCaptchaChallengeAttemptLimiter().validate(account.getUuid());

final boolean challengeSuccess = captchaChecker.verify(Action.CHALLENGE, captcha, mostRecentProxyIp, userAgent).isValid(scoreThreshold);
final boolean challengeSuccess = captchaChecker.verify(Optional.of(account.getUuid()), Action.CHALLENGE, captcha, mostRecentProxyIp, userAgent).isValid(scoreThreshold);

final Tags tags = Tags.of(
Tag.of(SOURCE_COUNTRY_TAG_NAME, Util.getCountryCode(account.getNumber())),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Stream;
import javax.ws.rs.BadRequestException;

import com.google.i18n.phonenumbers.NumberParseException;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
Expand All @@ -36,6 +41,7 @@ public class CaptchaCheckerTest {
private static final String PREFIX_A = "prefix-a";
private static final String PREFIX_B = "prefix-b";
private static final String USER_AGENT = "user-agent";
private static final UUID ACI = UUID.randomUUID();

static Stream<Arguments> parseInputToken() {
return Stream.of(
Expand Down Expand Up @@ -67,7 +73,7 @@ private static CaptchaClient mockClient(final String prefix) throws IOException
when(captchaClient.scheme()).thenReturn(prefix);
when(captchaClient.validSiteKeys(eq(Action.CHALLENGE))).thenReturn(Collections.singleton(CHALLENGE_SITE_KEY));
when(captchaClient.validSiteKeys(eq(Action.REGISTRATION))).thenReturn(Collections.singleton(REG_SITE_KEY));
when(captchaClient.verify(any(), any(), any(), any(), any())).thenReturn(AssessmentResult.invalid());
when(captchaClient.verify(any(), any(), any(), any(), any(), any())).thenReturn(AssessmentResult.invalid());
return captchaClient;
}

Expand All @@ -80,8 +86,8 @@ void parseInputToken(
final String siteKey,
final Action expectedAction) throws IOException {
final CaptchaClient captchaClient = mockClient(PREFIX);
new CaptchaChecker(null, PREFIX -> captchaClient).verify(expectedAction, input, null, USER_AGENT);
verify(captchaClient, times(1)).verify(eq(siteKey), eq(expectedAction), eq(expectedToken), any(), eq(USER_AGENT));
new CaptchaChecker(null, PREFIX -> captchaClient).verify(Optional.empty(), expectedAction, input, null, USER_AGENT);
verify(captchaClient, times(1)).verify(any(), eq(siteKey), eq(expectedAction), eq(expectedToken), any(), eq(USER_AGENT));
}

@ParameterizedTest
Expand Down Expand Up @@ -109,11 +115,11 @@ public void choose() throws IOException {
final CaptchaClient b = mockClient(PREFIX_B);
final Map<String, CaptchaClient> captchaClientMap = Map.of(PREFIX_A, a, PREFIX_B, b);

new CaptchaChecker(null, captchaClientMap::get).verify(Action.CHALLENGE, ainput, null, USER_AGENT);
verify(a, times(1)).verify(any(), any(), any(), any(), any());
new CaptchaChecker(null, captchaClientMap::get).verify(Optional.of(ACI), Action.CHALLENGE, ainput, null, USER_AGENT);
verify(a, times(1)).verify(any(), any(), any(), any(), any(), any());

new CaptchaChecker(null, captchaClientMap::get).verify(Action.CHALLENGE, binput, null, USER_AGENT);
verify(b, times(1)).verify(any(), any(), any(), any(), any());
new CaptchaChecker(null, captchaClientMap::get).verify(Optional.of(ACI), Action.CHALLENGE, binput, null, USER_AGENT);
verify(b, times(1)).verify(any(), any(), any(), any(), any(), any());
}

static Stream<Arguments> badArgs() {
Expand All @@ -134,7 +140,7 @@ static Stream<Arguments> badArgs() {
public void badArgs(final String input) throws IOException {
final CaptchaClient cc = mockClient(PREFIX);
assertThrows(BadRequestException.class,
() -> new CaptchaChecker(null, prefix -> PREFIX.equals(prefix) ? cc : null).verify(Action.CHALLENGE, input, null, USER_AGENT));
() -> new CaptchaChecker(null, prefix -> PREFIX.equals(prefix) ? cc : null).verify(Optional.of(ACI), Action.CHALLENGE, input, null, USER_AGENT));

}

Expand All @@ -144,8 +150,8 @@ public void testShortened() throws IOException {
final ShortCodeExpander retriever = mock(ShortCodeExpander.class);
when(retriever.retrieve("abc")).thenReturn(Optional.of(TOKEN));
final String input = String.join(SEPARATOR, PREFIX + "-short", REG_SITE_KEY, "registration", "abc");
new CaptchaChecker(retriever, ignored -> captchaClient).verify(Action.REGISTRATION, input, null, USER_AGENT);
verify(captchaClient, times(1)).verify(eq(REG_SITE_KEY), eq(Action.REGISTRATION), eq(TOKEN), any(), any());
new CaptchaChecker(retriever, ignored -> captchaClient).verify(Optional.of(ACI), Action.REGISTRATION, input, null, USER_AGENT);
verify(captchaClient, times(1)).verify(any(), eq(REG_SITE_KEY), eq(Action.REGISTRATION), eq(TOKEN), any(), any());

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ void patchSessionCaptchaInvalid() throws Exception {
Collections.emptyList(), null, null, false, clock.millis(), clock.millis(),
registrationServiceSession.expiration()))));

when(registrationCaptchaManager.assessCaptcha(any(), any(), any()))
when(registrationCaptchaManager.assessCaptcha(any(), any(), any(), any()))
.thenReturn(Optional.of(AssessmentResult.invalid()));

when(verificationSessionManager.update(any(), any()))
Expand Down Expand Up @@ -637,7 +637,7 @@ void patchSessionCaptchaSuccess() throws Exception {
Collections.emptyList(), null, null, false, clock.millis(), clock.millis(),
registrationServiceSession.expiration()))));

when(registrationCaptchaManager.assessCaptcha(any(), any(), any()))
when(registrationCaptchaManager.assessCaptcha(any(), any(), any(), any()))
.thenReturn(Optional.of(AssessmentResult.alwaysValid()));

when(verificationSessionManager.update(any(), any()))
Expand Down Expand Up @@ -685,7 +685,7 @@ void patchSessionPushAndCaptchaSuccess() throws Exception {
Collections.emptyList(), null, null, false, clock.millis(), clock.millis(),
registrationServiceSession.expiration()))));

when(registrationCaptchaManager.assessCaptcha(any(), any(), any()))
when(registrationCaptchaManager.assessCaptcha(any(), any(), any(), any()))
.thenReturn(Optional.of(AssessmentResult.alwaysValid()));

when(verificationSessionManager.update(any(), any()))
Expand Down Expand Up @@ -732,7 +732,7 @@ void patchSessionTokenUpdatedCaptchaError() throws Exception {
Collections.emptyList(), null, null, false, clock.millis(), clock.millis(),
registrationServiceSession.expiration()))));

when(registrationCaptchaManager.assessCaptcha(any(), any(), any()))
when(registrationCaptchaManager.assessCaptcha(any(), any(), any(), any()))
.thenThrow(new IOException("expected service error"));

when(verificationSessionManager.update(any(), any()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ void answerCaptchaChallenge(Optional<Float> scoreThreshold, float actualScore, b
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());

when(captchaChecker.verify(eq(Action.CHALLENGE), any(), any(), any()))
when(captchaChecker.verify(any(), eq(Action.CHALLENGE), any(), any(), any()))
.thenReturn(AssessmentResult.fromScore(actualScore, DEFAULT_SCORE_THRESHOLD));

when(rateLimiters.getCaptchaChallengeAttemptLimiter()).thenReturn(mock(RateLimiter.class));
Expand Down

0 comments on commit 190f2a7

Please sign in to comment.