diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java index 8e67cddd2..9c17d9661 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java @@ -43,18 +43,13 @@ import java.io.OutputStream; import java.lang.reflect.Type; import java.util.ArrayList; +import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.function.Supplier; import java.util.stream.Collectors; -/** - * items: - * - we don't want to use `String` for the error identifier. Let's create an `ErrorName` class. - * - re-consider using a map for the deserializersForEndpointBaseType field. is there a more direct way to get this info - */ - /** Package private internal API. */ final class ConjureBodySerDe implements BodySerDe { @@ -65,7 +60,7 @@ final class ConjureBodySerDe implements BodySerDe { private final Deserializer> optionalBinaryInputStreamDeserializer; private final Deserializer emptyBodyDeserializer; private final LoadingCache> serializers; - private final LoadingCache> deserializers; + private final LoadingCache> deserializers; private final EmptyContainerDeserializer emptyContainerDeserializer; /** @@ -75,7 +70,6 @@ final class ConjureBodySerDe implements BodySerDe { */ ConjureBodySerDe( List rawEncodings, - ErrorDecoder errorDecoder, EmptyContainerDeserializer emptyContainerDeserializer, CaffeineSpec cacheSpec) { List encodings = decorateEncodings(rawEncodings); @@ -83,24 +77,40 @@ final class ConjureBodySerDe implements BodySerDe { Preconditions.checkArgument(encodings.size() > 0, "At least one Encoding is required"); this.defaultEncoding = encodings.get(0).encoding(); this.emptyContainerDeserializer = emptyContainerDeserializer; - this.binaryInputStreamDeserializer = new EncodingDeserializerRegistry<>( + this.binaryInputStreamDeserializer = new EncodingDeserializerForEndpointRegistry<>( ImmutableList.of(BinaryEncoding.INSTANCE), - errorDecoder, emptyContainerDeserializer, - BinaryEncoding.MARKER); - this.optionalBinaryInputStreamDeserializer = new EncodingDeserializerRegistry<>( + BinaryEncoding.MARKER, + DeserializerArgs.builder() + .withBaseType(BinaryEncoding.MARKER) + .withExpectedResult(BinaryEncoding.MARKER) + .build()); + this.optionalBinaryInputStreamDeserializer = new EncodingDeserializerForEndpointRegistry<>( ImmutableList.of(BinaryEncoding.INSTANCE), - errorDecoder, emptyContainerDeserializer, - BinaryEncoding.OPTIONAL_MARKER); - this.emptyBodyDeserializer = new EmptyBodyDeserializer(errorDecoder); + BinaryEncoding.OPTIONAL_MARKER, + DeserializerArgs.>builder() + .withBaseType(BinaryEncoding.OPTIONAL_MARKER) + .withExpectedResult(BinaryEncoding.OPTIONAL_MARKER) + .build()); + this.emptyBodyDeserializer = + new EmptyBodyDeserializer(new EndpointErrorDecoder<>(Collections.emptyMap(), Optional.empty())); // Class unloading: Not supported, Jackson keeps strong references to the types // it sees: https://github.com/FasterXML/jackson-databind/issues/489 this.serializers = Caffeine.from(cacheSpec) .build(type -> new EncodingSerializerRegistry<>(defaultEncoding, TypeMarker.of(type))); - this.deserializers = Caffeine.from(cacheSpec) - .build(type -> new EncodingDeserializerRegistry<>( - encodingsSortedByWeight, errorDecoder, emptyContainerDeserializer, TypeMarker.of(type))); + this.deserializers = Caffeine.from(cacheSpec).build(type -> buildCacheEntry(TypeMarker.of(type))); + } + + private EncodingDeserializerForEndpointRegistry buildCacheEntry(TypeMarker typeMarker) { + return new EncodingDeserializerForEndpointRegistry<>( + encodingsSortedByWeight, + emptyContainerDeserializer, + typeMarker, + DeserializerArgs.builder() + .withBaseType(typeMarker) + .withExpectedResult(typeMarker) + .build()); } private static List decorateEncodings(List input) { @@ -235,108 +245,7 @@ private static final class EncodingSerializerContainer { } } - private static final class EncodingDeserializerRegistry implements Deserializer { - - private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerRegistry.class); - private final ImmutableList> encodings; - private final ErrorDecoder errorDecoder; - private final Optional acceptValue; - private final Supplier> emptyInstance; - private final TypeMarker token; - - EncodingDeserializerRegistry( - List encodings, - ErrorDecoder errorDecoder, - EmptyContainerDeserializer empty, - TypeMarker token) { - this.encodings = encodings.stream() - .map(encoding -> new EncodingDeserializerContainer<>(encoding, token)) - .collect(ImmutableList.toImmutableList()); - this.errorDecoder = errorDecoder; - this.token = token; - this.emptyInstance = Suppliers.memoize(() -> empty.tryGetEmptyInstance(token)); - // Encodings are applied to the accept header in the order of preference based on the provided list. - this.acceptValue = - Optional.of(encodings.stream().map(Encoding::getContentType).collect(Collectors.joining(", "))); - } - - @Override - public T deserialize(Response response) { - boolean closeResponse = true; - try { - if (errorDecoder.isError(response)) { - throw errorDecoder.decode(response); - } else if (response.code() == 204) { - // TODO(dfox): what if we get a 204 for a non-optional type??? - // TODO(dfox): support http200 & body=null - // TODO(dfox): what if we were expecting an empty list but got {}? - Optional maybeEmptyInstance = emptyInstance.get(); - if (maybeEmptyInstance.isPresent()) { - return maybeEmptyInstance.get(); - } - throw new SafeRuntimeException( - "Unable to deserialize non-optional response type from 204", SafeArg.of("type", token)); - } - - Optional contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE); - if (!contentType.isPresent()) { - throw new SafeIllegalArgumentException( - "Response is missing Content-Type header", - SafeArg.of("received", response.headers().keySet())); - } - Encoding.Deserializer deserializer = getResponseDeserializer(contentType.get()); - T deserialized = deserializer.deserialize(response.body()); - // deserializer has taken on responsibility for closing the response body - closeResponse = false; - return deserialized; - } catch (IOException e) { - throw new SafeRuntimeException( - "Failed to deserialize response stream", - e, - SafeArg.of("contentType", response.getFirstHeader(HttpHeaders.CONTENT_TYPE)), - SafeArg.of("type", token)); - } finally { - if (closeResponse) { - response.close(); - } - } - } - - @Override - public Optional accepts() { - return acceptValue; - } - - /** Returns the {@link EncodingDeserializerContainer} to use to deserialize the request body. */ - @SuppressWarnings("ForLoopReplaceableByForEach") - // performance sensitive code avoids iterator allocation - Encoding.Deserializer getResponseDeserializer(String contentType) { - for (int i = 0; i < encodings.size(); i++) { - EncodingDeserializerContainer container = encodings.get(i); - if (container.encoding.supportsContentType(contentType)) { - return container.deserializer; - } - } - return throwingDeserializer(contentType); - } - - private Encoding.Deserializer throwingDeserializer(String contentType) { - return input -> { - try { - input.close(); - } catch (RuntimeException | IOException e) { - log.warn("Failed to close InputStream", e); - } - throw new SafeRuntimeException( - "Unsupported Content-Type", - SafeArg.of("received", contentType), - SafeArg.of("supportedEncodings", encodings)); - }; - } - } - private static final class EncodingDeserializerForEndpointRegistry implements Deserializer { - private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerForEndpointRegistry.class); private final ImmutableList> encodings; private final EndpointErrorDecoder endpointErrorDecoder; @@ -353,8 +262,11 @@ private static final class EncodingDeserializerForEndpointRegistry implements .map(encoding -> new EncodingDeserializerContainer<>( encoding, deserializersForEndpoint.expectedResultType())) .collect(ImmutableList.toImmutableList()); - this.endpointErrorDecoder = - new EndpointErrorDecoder<>(deserializersForEndpoint.errorNameToTypeMarker(), encodings); + this.endpointErrorDecoder = new EndpointErrorDecoder<>( + deserializersForEndpoint.errorNameToTypeMarker(), + encodings.stream() + .filter(encoding -> encoding.supportsContentType("application/json")) + .findAny()); this.token = token; this.emptyInstance = Suppliers.memoize(() -> empty.tryGetEmptyInstance(token)); // Encodings are applied to the accept header in the order of preference based on the provided list. @@ -367,7 +279,6 @@ public T deserialize(Response response) { boolean closeResponse = true; try { if (endpointErrorDecoder.isError(response)) { - // TODO(pm): This needs to return T for the new deserializer API, but throw an exception for the old return endpointErrorDecoder.decode(response); } else if (response.code() == 204) { Optional maybeEmptyInstance = emptyInstance.get(); @@ -439,7 +350,6 @@ public T deserialize(InputStream input) { } /** Effectively just a pair. */ - // TODO(pm): what does saving the deserializer do for us? static final class EncodingDeserializerContainer { private final Encoding encoding; @@ -457,10 +367,10 @@ public String toString() { } private static final class EmptyBodyDeserializer implements Deserializer { - private final ErrorDecoder errorDecoder; + private final EndpointErrorDecoder endpointErrorDecoder; - EmptyBodyDeserializer(ErrorDecoder errorDecoder) { - this.errorDecoder = errorDecoder; + EmptyBodyDeserializer(EndpointErrorDecoder endpointErrorDecoder) { + this.endpointErrorDecoder = endpointErrorDecoder; } @Override @@ -468,8 +378,8 @@ private static final class EmptyBodyDeserializer implements Deserializer { public Void deserialize(Response response) { // We should not fail if a server that previously returned nothing starts returning a response try (Response unused = response) { - if (errorDecoder.isError(response)) { - throw errorDecoder.decode(response); + if (endpointErrorDecoder.isError(response)) { + endpointErrorDecoder.decode(response); } return null; } diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultConjureRuntime.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultConjureRuntime.java index 3e4766fda..befd4c350 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultConjureRuntime.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultConjureRuntime.java @@ -45,7 +45,6 @@ public final class DefaultConjureRuntime implements ConjureRuntime { private DefaultConjureRuntime(Builder builder) { this.bodySerDe = new ConjureBodySerDe( builder.encodings.isEmpty() ? DEFAULT_ENCODINGS : builder.encodings, - ErrorDecoder.INSTANCE, Encodings.emptyContainerDeserializer(), DEFAULT_SERDE_CACHE_SPEC); } diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java index f2ecfdfcc..c93776e6d 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java @@ -29,6 +29,7 @@ import com.palantir.conjure.java.api.errors.RemoteException; import com.palantir.conjure.java.api.errors.SerializableError; import com.palantir.conjure.java.api.errors.UnknownRemoteException; +import com.palantir.conjure.java.dialogue.serde.Encoding.Deserializer; import com.palantir.conjure.java.serialization.ObjectMappers; import com.palantir.dialogue.Response; import com.palantir.dialogue.TypeMarker; @@ -48,21 +49,33 @@ import java.net.URL; import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.stream.Collectors; -// TODO(pm): public because maybe we need to expose this in the dialogue annotations. What does that do? -// T is the base type of the endpoint response. It's a union of the result type and all of the error types. -public final class EndpointErrorDecoder { +/** + * Extracts the error from a {@link Response}. + *

If the error's name is in the {@link #errorNameToJsonDeserializerMap}, this class attempts to deserialize the + * {@link Response} body as JSON, to the error type. Otherwise, a {@link RemoteException} is thrown. If the + * {@link Response} does not adhere to the expected format, an {@link UnknownRemoteException} is thrown. + * + * @param the base type of the endpoint response. It's a union of the result type and all the error types. + */ +final class EndpointErrorDecoder { private static final SafeLogger log = SafeLoggerFactory.get(EndpointErrorDecoder.class); private static final ObjectMapper MAPPER = ObjectMappers.newClientObjectMapper(); - private final Map> errorNameToTypeMap; - private final List encodings; - - public EndpointErrorDecoder(Map> errorNameToTypeMap, List encodings) { - this.errorNameToTypeMap = errorNameToTypeMap; - this.encodings = encodings; + private final Map> errorNameToJsonDeserializerMap; + + EndpointErrorDecoder( + Map> errorNameToTypeMap, Optional maybeJsonEncoding) { + this.errorNameToJsonDeserializerMap = maybeJsonEncoding + .>>map( + jsonEncoding -> errorNameToTypeMap.entrySet().stream() + .collect(Collectors.toMap( + Map.Entry::getKey, entry -> jsonEncoding.deserializer(entry.getValue())))) + .orElseGet(Collections::emptyMap); } public boolean isError(Response response) { @@ -76,15 +89,12 @@ public T decode(Response response) { try { return decodeInternal(response); } catch (Exception e) { - // TODO(pm): do we want to add the diagnostic information to the result type as well? e.addSuppressed(diagnostic(response)); throw e; } } - // performance sensitive code avoids iterator allocation - @SuppressWarnings({"checkstyle:CyclomaticComplexity", "ForLoopReplaceableByForEach"}) - private T decodeInternal(Response response) { + Optional checkCode(Response response) { int code = response.code(); switch (code) { case 308: @@ -95,7 +105,7 @@ private T decodeInternal(Response response) { UnknownRemoteException remoteException = new UnknownRemoteException(code, ""); remoteException.initCause( QosException.retryOther(qosReason(response), new URL(locationHeader))); - throw remoteException; + return Optional.of(remoteException); } catch (MalformedURLException e) { log.error( "Failed to parse location header for QosException.RetryOther", @@ -108,15 +118,23 @@ private T decodeInternal(Response response) { } break; case 429: - throw response.getFirstHeader(HttpHeaders.RETRY_AFTER) + return Optional.of(response.getFirstHeader(HttpHeaders.RETRY_AFTER) .map(Longs::tryParse) .map(Duration::ofSeconds) .map(duration -> QosException.throttle(qosReason(response), duration)) - .orElseGet(() -> QosException.throttle(qosReason(response))); + .orElseGet(() -> QosException.throttle(qosReason(response)))); case 503: - throw QosException.unavailable(qosReason(response)); + return Optional.of(QosException.unavailable(qosReason(response))); } + return Optional.empty(); + } + private T decodeInternal(Response response) { + Optional maybeQosException = checkCode(response); + if (maybeQosException.isPresent()) { + throw maybeQosException.get(); + } + int code = response.code(); String body; try { body = toString(response.body()); @@ -127,27 +145,25 @@ private T decodeInternal(Response response) { } Optional contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE); - // Use a factory: given contentType, create the deserailizer. - // We need Encoding.Deserializer here. That depends on the encoding. - if (contentType.isPresent() && Encodings.matchesContentType("application/json", contentType.get())) { + String jsonContentType = "application/json"; + if (contentType.isPresent() && Encodings.matchesContentType(jsonContentType, contentType.get())) { try { JsonNode node = MAPPER.readTree(body); - if (node.get("errorName") != null) { - // TODO(pm): Update this to use some struct instead of errorName. - TypeMarker container = Optional.ofNullable( - errorNameToTypeMap.get(node.get("errorName").asText())) - .orElseThrow(); - for (int i = 0; i < encodings.size(); i++) { - Encoding encoding = encodings.get(i); - if (encoding.supportsContentType(contentType.get())) { - return encoding.deserializer(container) - .deserialize(new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8))); - } - } - } else { - SerializableError serializableError = MAPPER.readValue(body, SerializableError.class); - throw new RemoteException(serializableError, code); + JsonNode errorNameNode = node.get("errorName"); + if (errorNameNode == null) { + throwRemoteException(body, code); } + Optional> maybeDeserializer = + Optional.ofNullable(errorNameToJsonDeserializerMap.get(errorNameNode.asText())); + if (maybeDeserializer.isEmpty()) { + throwRemoteException(body, code); + } + return maybeDeserializer + .get() + .deserialize(new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8))); + } catch (RemoteException remoteException) { + // rethrow the created remote exception + throw remoteException; } catch (Exception e) { throw new UnknownRemoteException(code, body); } @@ -156,17 +172,22 @@ private T decodeInternal(Response response) { throw new UnknownRemoteException(code, body); } - private static String toString(InputStream body) throws IOException { + private static void throwRemoteException(String body, int code) throws IOException { + SerializableError serializableError = MAPPER.readValue(body, SerializableError.class); + throw new RemoteException(serializableError, code); + } + + static String toString(InputStream body) throws IOException { try (Reader reader = new InputStreamReader(body, StandardCharsets.UTF_8)) { return CharStreams.toString(reader); } } - private static ResponseDiagnostic diagnostic(Response response) { + static ResponseDiagnostic diagnostic(Response response) { return new ResponseDiagnostic(diagnosticArgs(response)); } - private static ImmutableList> diagnosticArgs(Response response) { + static ImmutableList> diagnosticArgs(Response response) { ImmutableList.Builder> args = ImmutableList.>builder().add(SafeArg.of("status", response.code())); recordHeader(HttpHeaders.SERVER, response, args); recordHeader(HttpHeaders.CONTENT_TYPE, response, args); diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoder.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoder.java index 50642aea0..127f9d0ad 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoder.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoder.java @@ -17,35 +17,16 @@ package com.palantir.conjure.java.dialogue.serde; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.collect.ImmutableList; -import com.google.common.io.CharStreams; import com.google.common.net.HttpHeaders; -import com.google.common.primitives.Longs; -import com.palantir.conjure.java.api.errors.QosException; -import com.palantir.conjure.java.api.errors.QosReason; -import com.palantir.conjure.java.api.errors.QosReasons; -import com.palantir.conjure.java.api.errors.QosReasons.QosResponseDecodingAdapter; import com.palantir.conjure.java.api.errors.RemoteException; import com.palantir.conjure.java.api.errors.SerializableError; import com.palantir.conjure.java.api.errors.UnknownRemoteException; import com.palantir.conjure.java.serialization.ObjectMappers; import com.palantir.dialogue.Response; -import com.palantir.logsafe.Arg; -import com.palantir.logsafe.SafeArg; -import com.palantir.logsafe.SafeLoggable; -import com.palantir.logsafe.UnsafeArg; -import com.palantir.logsafe.exceptions.SafeExceptions; import com.palantir.logsafe.logger.SafeLogger; import com.palantir.logsafe.logger.SafeLoggerFactory; import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.io.Reader; -import java.net.MalformedURLException; -import java.net.URL; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.List; +import java.util.Collections; import java.util.Optional; /** @@ -59,17 +40,19 @@ public enum ErrorDecoder { private static final SafeLogger log = SafeLoggerFactory.get(ErrorDecoder.class); private static final ObjectMapper MAPPER = ObjectMappers.newClientObjectMapper(); + private static final EndpointErrorDecoder ENDPOINT_ERROR_DECODER = + new EndpointErrorDecoder<>(Collections.emptyMap(), Optional.empty()); public boolean isError(Response response) { - return 300 <= response.code() && response.code() <= 599; + return ENDPOINT_ERROR_DECODER.isError(response); } public RuntimeException decode(Response response) { if (log.isDebugEnabled()) { - log.debug("Received an error response", diagnosticArgs(response)); + log.debug("Received an error response", EndpointErrorDecoder.diagnosticArgs(response)); } RuntimeException result = decodeInternal(response); - result.addSuppressed(diagnostic(response)); + result.addSuppressed(EndpointErrorDecoder.diagnostic(response)); return result; } @@ -77,41 +60,15 @@ private RuntimeException decodeInternal(Response response) { // TODO(rfink): What about HTTP/101 switching protocols? // TODO(rfink): What about HEAD requests? - int code = response.code(); - switch (code) { - case 308: - Optional location = response.getFirstHeader(HttpHeaders.LOCATION); - if (location.isPresent()) { - String locationHeader = location.get(); - try { - UnknownRemoteException remoteException = new UnknownRemoteException(code, ""); - remoteException.initCause( - QosException.retryOther(qosReason(response), new URL(locationHeader))); - return remoteException; - } catch (MalformedURLException e) { - log.error( - "Failed to parse location header for QosException.RetryOther", - UnsafeArg.of("locationHeader", locationHeader), - e); - } - } else { - log.error("Retrieved HTTP status code 308 without Location header, cannot perform " - + "redirect. This appears to be a server-side protocol violation."); - } - break; - case 429: - return response.getFirstHeader(HttpHeaders.RETRY_AFTER) - .map(Longs::tryParse) - .map(Duration::ofSeconds) - .map(duration -> QosException.throttle(qosReason(response), duration)) - .orElseGet(() -> QosException.throttle(qosReason(response))); - case 503: - return QosException.unavailable(qosReason(response)); + Optional maybeQosException = ENDPOINT_ERROR_DECODER.checkCode(response); + if (maybeQosException.isPresent()) { + return maybeQosException.get(); } + int code = response.code(); String body; try { - body = toString(response.body()); + body = EndpointErrorDecoder.toString(response.body()); } catch (NullPointerException | IOException e) { UnknownRemoteException exception = new UnknownRemoteException(code, ""); exception.initCause(e); @@ -130,75 +87,4 @@ private RuntimeException decodeInternal(Response response) { return new UnknownRemoteException(code, body); } - - private static String toString(InputStream body) throws IOException { - try (Reader reader = new InputStreamReader(body, StandardCharsets.UTF_8)) { - return CharStreams.toString(reader); - } - } - - private static ResponseDiagnostic diagnostic(Response response) { - return new ResponseDiagnostic(diagnosticArgs(response)); - } - - private static ImmutableList> diagnosticArgs(Response response) { - ImmutableList.Builder> args = ImmutableList.>builder().add(SafeArg.of("status", response.code())); - recordHeader(HttpHeaders.SERVER, response, args); - recordHeader(HttpHeaders.CONTENT_TYPE, response, args); - recordHeader(HttpHeaders.CONTENT_LENGTH, response, args); - recordHeader(HttpHeaders.CONNECTION, response, args); - recordHeader(HttpHeaders.DATE, response, args); - recordHeader("x-envoy-response-flags", response, args); - recordHeader("x-envoy-response-code-details", response, args); - recordHeader("Response-Flags", response, args); - recordHeader("Response-Code-Details", response, args); - return args.build(); - } - - private static void recordHeader(String header, Response response, ImmutableList.Builder> args) { - response.getFirstHeader(header).ifPresent(server -> args.add(SafeArg.of(header, server))); - } - - private static final class ResponseDiagnostic extends RuntimeException implements SafeLoggable { - - private static final String SAFE_MESSAGE = "Response Diagnostic Information"; - - private final ImmutableList> args; - - ResponseDiagnostic(ImmutableList> args) { - super(SafeExceptions.renderMessage(SAFE_MESSAGE, args.toArray(new Arg[0]))); - this.args = args; - } - - @Override - public String getLogMessage() { - return SAFE_MESSAGE; - } - - @Override - public List> getArgs() { - return args; - } - - @Override - @SuppressWarnings("UnsynchronizedOverridesSynchronized") // nop - public Throwable fillInStackTrace() { - // no-op: stack trace generation is expensive, this type exists - // to simply associate diagnostic information with a failure. - return this; - } - } - - private static QosReason qosReason(Response response) { - return QosReasons.parseFromResponse(response, DialogueQosResponseDecodingAdapter.INSTANCE); - } - - private enum DialogueQosResponseDecodingAdapter implements QosResponseDecodingAdapter { - INSTANCE; - - @Override - public Optional getFirstHeader(Response response, String headerName) { - return response.getFirstHeader(headerName); - } - } } diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/BinaryEncodingTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/BinaryEncodingTest.java index 0d13b0b4a..fe8a9bcd2 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/BinaryEncodingTest.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/BinaryEncodingTest.java @@ -34,7 +34,6 @@ public void testBinary() throws IOException { TestResponse response = new TestResponse().code(200).contentType("application/octet-stream"); BodySerDe serializers = new ConjureBodySerDe( ImmutableList.of(WeightedEncoding.of(new ConjureBodySerDeTest.StubEncoding("application/json"))), - ErrorDecoder.INSTANCE, Encodings.emptyContainerDeserializer(), DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); InputStream deserialized = serializers.inputStreamDeserializer().deserialize(response); @@ -58,7 +57,6 @@ public void testBinary_optional_present() throws IOException { TestResponse response = new TestResponse().code(200).contentType("application/octet-stream"); BodySerDe serializers = new ConjureBodySerDe( ImmutableList.of(WeightedEncoding.of(new ConjureBodySerDeTest.StubEncoding("application/json"))), - ErrorDecoder.INSTANCE, Encodings.emptyContainerDeserializer(), DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); Optional maybe = diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java index 684c5acdd..03879e807 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java @@ -52,8 +52,6 @@ public class ConjureBodySerDeTest { private static final TypeMarker TYPE = new TypeMarker() {}; private static final TypeMarker> OPTIONAL_TYPE = new TypeMarker>() {}; - private ErrorDecoder errorDecoder = ErrorDecoder.INSTANCE; - @Test public void testRequestContentType() throws IOException { @@ -76,7 +74,6 @@ private ConjureBodySerDe conjureBodySerDe(String... contentTypes) { Arrays.stream(contentTypes) .map(c -> WeightedEncoding.of(new StubEncoding(c))) .collect(ImmutableList.toImmutableList()), - errorDecoder, Encodings.emptyContainerDeserializer(), DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); } @@ -115,7 +112,6 @@ public void testAcceptBasedOnWeight() throws IOException { BodySerDe serializers = new ConjureBodySerDe( ImmutableList.of(WeightedEncoding.of(plain, .5), WeightedEncoding.of(json, 1)), - ErrorDecoder.INSTANCE, Encodings.emptyContainerDeserializer(), DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); // first encoding is default @@ -174,7 +170,6 @@ public void if_deserialize_throws_response_is_still_closed() { TestResponse response = new TestResponse().code(200).contentType("application/json"); BodySerDe serializers = new ConjureBodySerDe( ImmutableList.of(WeightedEncoding.of(BrokenEncoding.INSTANCE)), - ErrorDecoder.INSTANCE, Encodings.emptyContainerDeserializer(), DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); assertThatThrownBy(() -> serializers.deserializer(TYPE).deserialize(response)) diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/DefaultClientsTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/DefaultClientsTest.java index a8ac9ab0d..58f1e4631 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/DefaultClientsTest.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/DefaultClientsTest.java @@ -80,7 +80,6 @@ public final class DefaultClientsTest { private Response response = new TestResponse(); private BodySerDe bodySerde = new ConjureBodySerDe( DefaultConjureRuntime.DEFAULT_ENCODINGS, - ErrorDecoder.INSTANCE, Encodings.emptyContainerDeserializer(), DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); private final SettableFuture responseFuture = SettableFuture.create(); diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorTestUtils.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorTestUtils.java index 69bcb6949..6f45ffbc4 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorTestUtils.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorTestUtils.java @@ -20,6 +20,7 @@ import com.palantir.conjure.java.api.errors.CheckedServiceException; import com.palantir.dialogue.TypeMarker; import com.palantir.logsafe.Arg; +import com.palantir.logsafe.Safe; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -30,6 +31,26 @@ final class EndpointErrorTestUtils { private EndpointErrorTestUtils() {} + abstract static class EndpointError { + @Safe + String errorCode; + + @Safe + String errorName; + + @Safe + String errorInstanceId; + + T args; + + EndpointError(String errorCode, String errorName, String errorInstanceId, T args) { + this.errorCode = errorCode; + this.errorName = errorName; + this.errorInstanceId = errorInstanceId; + this.args = args; + } + } + record ConjureError( @JsonProperty("errorCode") String errorCode, @JsonProperty("errorName") String errorName, diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java index 296bb193d..a08c84806 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java @@ -17,6 +17,7 @@ package com.palantir.conjure.java.dialogue.serde; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -24,7 +25,10 @@ import com.google.common.collect.ImmutableList; import com.palantir.conjure.java.api.errors.CheckedServiceException; import com.palantir.conjure.java.api.errors.ErrorType; +import com.palantir.conjure.java.api.errors.RemoteException; +import com.palantir.conjure.java.api.errors.SerializableError; import com.palantir.conjure.java.dialogue.serde.EndpointErrorTestUtils.ConjureError; +import com.palantir.conjure.java.dialogue.serde.EndpointErrorTestUtils.EndpointError; import com.palantir.conjure.java.dialogue.serde.EndpointErrorTestUtils.TypeReturningStubEncoding; import com.palantir.conjure.java.serialization.ObjectMappers; import com.palantir.dialogue.BodySerDe; @@ -48,48 +52,31 @@ @ExtendWith(MockitoExtension.class) public class EndpointErrorsConjureBodySerDeTest { private static final ObjectMapper MAPPER = ObjectMappers.newServerObjectMapper(); - private ErrorDecoder errorDecoder = ErrorDecoder.INSTANCE; @Generated("by conjure-java") - private sealed interface EndpointReturnBaseType permits StringReturn, ErrorForEndpoint {} + private sealed interface EndpointReturnBaseType permits ExpectedReturnValue, ErrorReturnValue {} @Generated("by conjure-java") - record StringReturn(String value) implements EndpointReturnBaseType { + record ExpectedReturnValue(String value) implements EndpointReturnBaseType { @JsonCreator - public static StringReturn create(String value) { - return new StringReturn(Preconditions.checkArgumentNotNull(value, "value cannot be null")); + public static ExpectedReturnValue create(String value) { + return new ExpectedReturnValue(Preconditions.checkArgumentNotNull(value, "value cannot be null")); } } - abstract static class EndpointError { - @Safe - String errorCode; - - @Safe - String errorName; - - @Safe - String errorInstanceId; - - T args; - - EndpointError(String errorCode, String errorName, String errorInstanceId, T args) { - this.errorCode = errorCode; - this.errorName = errorName; - this.errorInstanceId = errorInstanceId; - this.args = args; - } - } + @Generated("by conjure-java") + record ComplexArg(int foo, String bar) {} + @Generated("by conjure-java") record ErrorForEndpointArgs( @JsonProperty("arg") @Safe String arg, @JsonProperty("unsafeArg") @Unsafe String unsafeArg, @JsonProperty("complexArg") @Safe ComplexArg complexArg, @JsonProperty("optionalArg") @Safe Optional optionalArg) {} - static final class ErrorForEndpoint extends EndpointError implements EndpointReturnBaseType { + static final class ErrorReturnValue extends EndpointError implements EndpointReturnBaseType { @JsonCreator - ErrorForEndpoint( + ErrorReturnValue( @JsonProperty("errorCode") String errorCode, @JsonProperty("errorName") String errorName, @JsonProperty("errorInstanceId") String errorInstanceId, @@ -98,9 +85,6 @@ static final class ErrorForEndpoint extends EndpointError } } - @Generated("by conjure-java") - record ComplexArg(int foo, String bar) {} - @Generated("by conjure-java") public static final class TestEndpointError extends CheckedServiceException { private TestEndpointError( @@ -120,30 +104,34 @@ private TestEndpointError( } @Test - public void testDeserializeCustomErrors() throws IOException { + public void testDeserializeCustomError() throws IOException { + // Given TestEndpointError errorThrownByEndpoint = new TestEndpointError("value", "unsafeValue", new ComplexArg(1, "bar"), Optional.of(2), null); - - ErrorForEndpoint expectedErrorForEndpoint = new ErrorForEndpoint( - "FAILED_PRECONDITION", - "Default:FailedPrecondition", - errorThrownByEndpoint.getErrorInstanceId(), - new ErrorForEndpointArgs("value", "unsafeValue", new ComplexArg(1, "bar"), Optional.of(2))); - String responseBody = MAPPER.writeValueAsString(ConjureError.fromCheckedServiceException(errorThrownByEndpoint)); + TestResponse response = TestResponse.withBody(responseBody) .contentType("application/json") .code(500); BodySerDe serializers = conjureBodySerDe("application/json", "text/plain"); DeserializerArgs deserializerArgs = DeserializerArgs.builder() .withBaseType(new TypeMarker<>() {}) - .withExpectedResult(new TypeMarker() {}) - .withErrorType("Default:FailedPrecondition", new TypeMarker() {}) + .withExpectedResult(new TypeMarker() {}) + .withErrorType("Default:FailedPrecondition", new TypeMarker() {}) .build(); + + // When EndpointErrorsConjureBodySerDeTest.EndpointReturnBaseType value = serializers.deserializer(deserializerArgs).deserialize(response); + // Then + ErrorReturnValue expectedErrorForEndpoint = new ErrorReturnValue( + ErrorType.FAILED_PRECONDITION.code().name(), + ErrorType.FAILED_PRECONDITION.name(), + errorThrownByEndpoint.getErrorInstanceId(), + new ErrorForEndpointArgs("value", "unsafeValue", new ComplexArg(1, "bar"), Optional.of(2))); + assertThat(value).isInstanceOf(ErrorReturnValue.class); assertThat(value) .extracting("errorCode", "errorName", "errorInstanceId", "args") .containsExactly( @@ -153,8 +141,49 @@ public void testDeserializeCustomErrors() throws IOException { expectedErrorForEndpoint.args); } + // When an error is deserialized, but the error type is not registered, the error should be deserialized as a + // SerializableError and a RemoteException should be thrown. + @Test + public void testDeserializingUndefinedErrorFallsbackToSerializableError() throws IOException { + TestEndpointError errorThrownByEndpoint = + new TestEndpointError("value", "unsafeValue", new ComplexArg(1, "bar"), Optional.of(2), null); + String responseBody = + MAPPER.writeValueAsString(ConjureError.fromCheckedServiceException(errorThrownByEndpoint)); + + TestResponse response = TestResponse.withBody(responseBody) + .contentType("application/json") + .code(500); + BodySerDe serializers = conjureBodySerDe("application/json", "text/plain"); + DeserializerArgs deserializerArgs = DeserializerArgs.builder() + .withBaseType(new TypeMarker<>() {}) + .withExpectedResult(new TypeMarker() {}) + // Note: no error types are registered. + .build(); + + // Then + assertThatExceptionOfType(RemoteException.class) + .isThrownBy(() -> { + serializers.deserializer(deserializerArgs).deserialize(response); + }) + .satisfies(exception -> { + SerializableError error = exception.getError(); + assertThat(error.errorCode()) + .isEqualTo(ErrorType.FAILED_PRECONDITION.code().name()); + assertThat(error.errorInstanceId()).isEqualTo(errorThrownByEndpoint.getErrorInstanceId()); + assertThat(error.errorName()).isEqualTo(ErrorType.FAILED_PRECONDITION.name()); + assertThat(error.parameters()) + .extracting("arg", "unsafeArg", "complexArg", "optionalArg") + .containsExactly( + "value", + "unsafeValue", + MAPPER.writeValueAsString(new ComplexArg(1, "bar")), + MAPPER.writeValueAsString(Optional.of(2))); + }); + } + @Test public void testDeserializeExpectedValue() { + // Given String expectedString = "expectedString"; TestResponse response = TestResponse.withBody(String.format("\"%s\"", expectedString)) .contentType("application/json") @@ -162,12 +191,14 @@ public void testDeserializeExpectedValue() { BodySerDe serializers = conjureBodySerDe("application/json", "text/plain"); DeserializerArgs deserializerArgs = DeserializerArgs.builder() .withBaseType(new TypeMarker<>() {}) - .withExpectedResult(new TypeMarker() {}) - .withErrorType("Default:FailedPrecondition", new TypeMarker() {}) + .withExpectedResult(new TypeMarker() {}) + .withErrorType("Default:FailedPrecondition", new TypeMarker() {}) .build(); + // When EndpointReturnBaseType value = serializers.deserializer(deserializerArgs).deserialize(response); - assertThat(value).isEqualTo(new StringReturn(expectedString)); + // Then + assertThat(value).isEqualTo(new ExpectedReturnValue(expectedString)); } private ConjureBodySerDe conjureBodySerDe(String... contentTypes) { @@ -175,7 +206,6 @@ private ConjureBodySerDe conjureBodySerDe(String... contentTypes) { Arrays.stream(contentTypes) .map(c -> WeightedEncoding.of(new TypeReturningStubEncoding(c))) .collect(ImmutableList.toImmutableList()), - errorDecoder, Encodings.emptyContainerDeserializer(), DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); } diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java index 510914481..b809ac3a6 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java @@ -17,6 +17,7 @@ package com.palantir.conjure.java.dialogue.serde; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.fail; import com.fasterxml.jackson.core.JsonProcessingException; @@ -38,8 +39,13 @@ import com.palantir.logsafe.Preconditions; import com.palantir.logsafe.SafeArg; import java.time.Duration; +import java.util.Collections; +import java.util.Optional; +import java.util.function.Consumer; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.junit.jupiter.MockitoExtension; @ExtendWith(MockitoExtension.class) @@ -63,16 +69,17 @@ private static String createServiceException(ServiceException exception) { } private static final ErrorDecoder decoder = ErrorDecoder.INSTANCE; + private static final EndpointErrorDecoder endpointErrorDecoder = + new EndpointErrorDecoder<>(Collections.emptyMap(), Optional.empty()); - @Test - public void extractsRemoteExceptionForAllErrorCodes() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void extractsRemoteExceptionForAllErrorCodes(boolean isLegacyErrorDecoder) { for (int code : ImmutableList.of(300, 400, 404, 500)) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION).code(code).contentType("application/json"); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result).isInstanceOfSatisfying(RemoteException.class, exception -> { + Consumer validationFunction = exception -> { assertThat(exception.getCause()).isNull(); assertThat(exception.getStatus()).isEqualTo(code); assertThat(exception.getError().errorCode()) @@ -91,117 +98,224 @@ public void extractsRemoteExceptionForAllErrorCodes() { + " (" + ErrorType.FAILED_PRECONDITION.name() + ")"); - }); + }; + + if (isLegacyErrorDecoder) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RemoteException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RemoteException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } } - @Test - public void testQos503() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testQos503(boolean isLegacyErrorDecoder) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION).code(503); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result).isInstanceOfSatisfying(QosException.Unavailable.class, exception -> { - assertThat(exception.getReason()).isEqualTo(QOS_REASON); - }); + Consumer validationFunction = exception -> { + assertThat(exception).isInstanceOfSatisfying(QosException.Unavailable.class, qosException -> { + assertThat(qosException.getReason()).isEqualTo(QOS_REASON); + }); + }; + + if (isLegacyErrorDecoder) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void testQos503WithMetadata() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testQos503WithMetadata(boolean isLegacyErrorDecoder) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION) .code(503) .withHeader("Qos-Retry-Hint", "do-not-retry") .withHeader("Qos-Due-To", "custom"); - assertThat(decoder.isError(response)).isTrue(); - - RuntimeException result = decoder.decode(response); - assertThat(result).isInstanceOfSatisfying(QosException.Unavailable.class, exception -> { - assertThat(exception.getReason()) - .isEqualTo(QosReason.builder() - .from(QOS_REASON) - .dueTo(DueTo.CUSTOM) - .retryHint(RetryHint.DO_NOT_RETRY) - .build()); - }); + + Consumer validationFunction = exception -> { + assertThat(exception).isInstanceOfSatisfying(QosException.Unavailable.class, qosException -> { + assertThat(qosException.getReason()) + .isEqualTo(QosReason.builder() + .from(QOS_REASON) + .dueTo(DueTo.CUSTOM) + .retryHint(RetryHint.DO_NOT_RETRY) + .build()); + }); + }; + + if (isLegacyErrorDecoder) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void testQos429() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testQos429(boolean isLegacyErrorDecoder) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION).code(429); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result).isInstanceOfSatisfying(QosException.Throttle.class, exception -> { - assertThat(exception.getReason()).isEqualTo(QOS_REASON); - assertThat(exception.getRetryAfter()).isEmpty(); - }); + Consumer validationFunction = exception -> { + assertThat(exception).isInstanceOfSatisfying(QosException.Throttle.class, qosException -> { + assertThat(qosException.getReason()).isEqualTo(QOS_REASON); + assertThat(qosException.getRetryAfter()).isEmpty(); + }); + }; + + if (isLegacyErrorDecoder) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void testQos429_retryAfter() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testQos429_retryAfter(boolean isLegacyErrorDecoder) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION).code(429).withHeader(HttpHeaders.RETRY_AFTER, "3"); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result).isInstanceOfSatisfying(QosException.Throttle.class, exception -> { - assertThat(exception.getReason()).isEqualTo(QOS_REASON); - assertThat(exception.getRetryAfter()).hasValue(Duration.ofSeconds(3)); - }); + Consumer validationFunction = exception -> { + assertThat(exception).isInstanceOfSatisfying(QosException.Throttle.class, qosException -> { + assertThat(qosException.getReason()).isEqualTo(QOS_REASON); + assertThat(qosException.getRetryAfter()).hasValue(Duration.ofSeconds(3)); + }); + }; + + if (isLegacyErrorDecoder) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void testQos429_retryAfter_invalid() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testQos429_retryAfter_invalid(boolean isLegacyErrorDecoder) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION).code(429).withHeader(HttpHeaders.RETRY_AFTER, "bad"); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result).isInstanceOfSatisfying(QosException.Throttle.class, exception -> { - assertThat(exception.getReason()).isEqualTo(QOS_REASON); - assertThat(exception.getRetryAfter()).isEmpty(); - }); + Consumer validationFunction = exception -> { + assertThat(exception).isInstanceOfSatisfying(QosException.Throttle.class, qosException -> { + assertThat(qosException.getReason()).isEqualTo(QOS_REASON); + assertThat(qosException.getRetryAfter()).isEmpty(); + }); + }; + + if (isLegacyErrorDecoder) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void testQos308_noLocation() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testQos308_noLocation(boolean isLegacyErrorDecoder) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION).code(308); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result) - .isInstanceOfSatisfying(UnknownRemoteException.class, exception -> assertThat(exception.getStatus()) - .isEqualTo(308)); + Consumer validationFunction = exception -> { + assertThat(exception).isInstanceOfSatisfying(UnknownRemoteException.class, unknownException -> { + assertThat(unknownException.getStatus()).isEqualTo(308); + }); + }; + + if (isLegacyErrorDecoder) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void testQos308_invalidLocation() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testQos308_invalidLocation(boolean isLegacyErrorDecoder) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION).code(308).withHeader(HttpHeaders.LOCATION, "invalid"); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result) - .isInstanceOfSatisfying(UnknownRemoteException.class, exception -> assertThat(exception.getStatus()) - .isEqualTo(308)); + Consumer validationFunction = exception -> { + assertThat(exception).isInstanceOfSatisfying(UnknownRemoteException.class, unknownException -> { + assertThat(unknownException.getStatus()).isEqualTo(308); + }); + }; + + if (isLegacyErrorDecoder) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void testQos308() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testQos308(boolean isLegacyErrorDecoder) { String expectedLocation = "https://localhost"; Response response = TestResponse.withBody(SERIALIZED_EXCEPTION) .code(308) .withHeader(HttpHeaders.LOCATION, expectedLocation); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result) - .isInstanceOf(UnknownRemoteException.class) - .getRootCause() - .isInstanceOfSatisfying(QosException.RetryOther.class, exception -> { - assertThat(exception.getReason()).isEqualTo(QOS_REASON); - assertThat(exception.getRedirectTo()).asString().isEqualTo(expectedLocation); - }); + Consumer validationFunction = exception -> { + assertThat(exception) + .isInstanceOf(UnknownRemoteException.class) + .getRootCause() + .isInstanceOfSatisfying(QosException.RetryOther.class, qosException -> { + assertThat(qosException.getReason()).isEqualTo(QOS_REASON); + assertThat(qosException.getRedirectTo()).asString().isEqualTo(expectedLocation); + }); + }; + + if (isLegacyErrorDecoder) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } @Test @@ -217,16 +331,31 @@ public void cannotDecodeNonJsonMediaTypes() { TestResponse.withBody(SERIALIZED_EXCEPTION).code(500).contentType("text/plain"))) .isInstanceOf(UnknownRemoteException.class) .hasMessage("Response status: 500"); + + assertThatExceptionOfType(UnknownRemoteException.class) + .isThrownBy(() -> endpointErrorDecoder.decode( + TestResponse.withBody(SERIALIZED_EXCEPTION).code(500).contentType("text/plain"))) + .satisfies(exception -> assertThat(exception.getMessage()).isEqualTo("Response status: 500")); } - @Test - public void doesNotHandleUnparseableBody() { - assertThat(decoder.decode(TestResponse.withBody("not json").code(500).contentType("application/json/"))) - .isInstanceOfSatisfying(UnknownRemoteException.class, expected -> { - assertThat(expected.getStatus()).isEqualTo(500); - assertThat(expected.getBody()).isEqualTo("not json"); - assertThat(expected.getMessage()).isEqualTo("Response status: 500"); - }); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void doesNotHandleUnparseableBody(boolean isLegacyErrorDecoder) { + Response response = TestResponse.withBody("not json").code(500).contentType("application/json/"); + + Consumer validationFunction = exception -> { + assertThat(exception.getStatus()).isEqualTo(500); + assertThat(exception.getBody()).isEqualTo("not json"); + }; + + if (isLegacyErrorDecoder) { + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(UnknownRemoteException.class, validationFunction); + } else { + assertThatExceptionOfType(UnknownRemoteException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } @Test @@ -235,32 +364,57 @@ public void doesNotHandleNullBody() { assertThat(decoder.decode(TestResponse.withBody(null).code(500).contentType("application/json"))) .isInstanceOf(UnknownRemoteException.class) .hasMessage("Response status: 500"); + + assertThatExceptionOfType(UnknownRemoteException.class) + .isThrownBy(() -> endpointErrorDecoder.decode( + TestResponse.withBody(null).code(500).contentType("application/json"))) + .satisfies(exception -> assertThat(exception.getMessage()).isEqualTo("Response status: 500")); } - @Test - public void handlesUnexpectedJson() { - assertThat(decoder.decode(TestResponse.withBody("{\"error\":\"some-unknown-json\"}") - .code(502) - .contentType("application/json"))) - .isInstanceOfSatisfying(UnknownRemoteException.class, expected -> { - assertThat(expected.getStatus()).isEqualTo(502); - assertThat(expected.getBody()).isEqualTo("{\"error\":\"some-unknown-json\"}"); - assertThat(expected.getMessage()).isEqualTo("Response status: 502"); - }); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void handlesUnexpectedJson(boolean isLegacyErrorDecoder) { + Response response = TestResponse.withBody("{\"error\":\"some-unknown-json\"}") + .code(502) + .contentType("application/json"); + + Consumer validationFunction = expected -> { + assertThat(expected.getStatus()).isEqualTo(502); + assertThat(expected.getBody()).isEqualTo("{\"error\":\"some-unknown-json\"}"); + assertThat(expected.getMessage()).isEqualTo("Response status: 502"); + }; + if (isLegacyErrorDecoder) { + assertThat(decoder.decode(response)) + .isInstanceOfSatisfying(UnknownRemoteException.class, validationFunction); + } else { + assertThatExceptionOfType(UnknownRemoteException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void handlesJsonWithEncoding() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void handlesJsonWithEncoding(boolean isLegacyErrorDecoder) { int code = 500; - RuntimeException result = decoder.decode( - TestResponse.withBody(SERIALIZED_EXCEPTION).code(code).contentType("application/json; charset=utf-8")); - assertThat(result).isInstanceOfSatisfying(RemoteException.class, exception -> { + Response response = + TestResponse.withBody(SERIALIZED_EXCEPTION).code(code).contentType("application/json; charset=utf-8"); + + Consumer validationFunction = exception -> { assertThat(exception.getCause()).isNull(); assertThat(exception.getStatus()).isEqualTo(code); assertThat(exception.getError().errorCode()) .isEqualTo(ErrorType.FAILED_PRECONDITION.code().name()); assertThat(exception.getError().errorName()).isEqualTo(ErrorType.FAILED_PRECONDITION.name()); - }); + }; + + if (isLegacyErrorDecoder) { + assertThat(decoder.decode(response)).isInstanceOfSatisfying(RemoteException.class, validationFunction); + } else { + assertThatExceptionOfType(RemoteException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } private static RemoteException encodeAndDecode(Exception exception) {