diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java index b738b8529..853bd0fad 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java @@ -26,6 +26,7 @@ import com.palantir.dialogue.BinaryRequestBody; import com.palantir.dialogue.BodySerDe; import com.palantir.dialogue.Deserializer; +import com.palantir.dialogue.DeserializerArgs; import com.palantir.dialogue.RequestBody; import com.palantir.dialogue.Response; import com.palantir.dialogue.Serializer; @@ -43,11 +44,19 @@ import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Comparator; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.function.Supplier; import java.util.stream.Collectors; +/** + * items: + * - we don't want to use `String` for the error identifier. Let's create an `ErrorName` class. + * - re-consider using a map for the deserializersForEndpointBaseType field. is there a more direct way to get this info + */ + /** Package private internal API. */ final class ConjureBodySerDe implements BodySerDe { @@ -58,7 +67,10 @@ final class ConjureBodySerDe implements BodySerDe { private final Deserializer> optionalBinaryInputStreamDeserializer; private final Deserializer emptyBodyDeserializer; private final LoadingCache> serializers; - private final LoadingCache> deserializers; + private final LoadingCache> deserializers; + private final Map> baseTypeToDeserializerArgs; + private final LoadingCache> endpointWithErrorsDeserializers; + private final EmptyContainerDeserializer emptyContainerDeserializer; /** * Selects the first (based on input order) of the provided encodings that @@ -74,6 +86,7 @@ final class ConjureBodySerDe implements BodySerDe { this.encodingsSortedByWeight = sortByWeight(encodings); Preconditions.checkArgument(encodings.size() > 0, "At least one Encoding is required"); this.defaultEncoding = encodings.get(0).encoding(); + this.emptyContainerDeserializer = emptyContainerDeserializer; this.binaryInputStreamDeserializer = new EncodingDeserializerRegistry<>( ImmutableList.of(BinaryEncoding.INSTANCE), errorDecoder, @@ -92,6 +105,8 @@ final class ConjureBodySerDe implements BodySerDe { this.deserializers = Caffeine.from(cacheSpec) .build(type -> new EncodingDeserializerRegistry<>( encodingsSortedByWeight, errorDecoder, emptyContainerDeserializer, TypeMarker.of(type))); + this.baseTypeToDeserializerArgs = new HashMap<>(); + this.endpointWithErrorsDeserializers = Caffeine.from(cacheSpec).build(this::buildCacheEntry); } private static List decorateEncodings(List input) { @@ -122,6 +137,24 @@ public Deserializer deserializer(TypeMarker token) { return (Deserializer) deserializers.get(token.getType()); } + @Override + @SuppressWarnings("unchecked") + public Deserializer deserializer(DeserializerArgs deserializerArgs) { + Type baseType = deserializerArgs.baseType().getType(); + this.baseTypeToDeserializerArgs.put(baseType, deserializerArgs); + return (Deserializer) endpointWithErrorsDeserializers.get(baseType); + } + + @SuppressWarnings("unchecked") + private EncodingDeserializerForEndpointRegistry buildCacheEntry(Type baseType) { + return new EncodingDeserializerForEndpointRegistry<>( + encodingsSortedByWeight, + emptyContainerDeserializer, + (TypeMarker) TypeMarker.of(baseType), + (DeserializerArgs) Optional.ofNullable(baseTypeToDeserializerArgs.get(baseType)) + .orElseThrow()); + } + @Override public Deserializer emptyBodyDeserializer() { return emptyBodyDeserializer; @@ -301,6 +334,105 @@ Encoding.Deserializer getResponseDeserializer(String contentType) { return throwingDeserializer(contentType); } + private Encoding.Deserializer throwingDeserializer(String contentType) { + return input -> { + try { + input.close(); + } catch (RuntimeException | IOException e) { + log.warn("Failed to close InputStream", e); + } + throw new SafeRuntimeException( + "Unsupported Content-Type", + SafeArg.of("received", contentType), + SafeArg.of("supportedEncodings", encodings)); + }; + } + } + + private static final class EncodingDeserializerForEndpointRegistry implements Deserializer { + + private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerForEndpointRegistry.class); + private final ImmutableList> encodings; + private final EndpointErrorDecoder endpointErrorDecoder; + private final Optional acceptValue; + private final Supplier> emptyInstance; + private final TypeMarker token; + + EncodingDeserializerForEndpointRegistry( + List encodings, + EmptyContainerDeserializer empty, + TypeMarker token, + DeserializerArgs deserializersForEndpoint) { + this.encodings = encodings.stream() + .map(encoding -> new EncodingDeserializerContainer<>( + encoding, deserializersForEndpoint.expectedResultType())) + .collect(ImmutableList.toImmutableList()); + this.endpointErrorDecoder = + new EndpointErrorDecoder<>(deserializersForEndpoint.errorNameToTypeMarker(), encodings); + this.token = token; + this.emptyInstance = Suppliers.memoize(() -> empty.tryGetEmptyInstance(token)); + // Encodings are applied to the accept header in the order of preference based on the provided list. + this.acceptValue = + Optional.of(encodings.stream().map(Encoding::getContentType).collect(Collectors.joining(", "))); + } + + @Override + public T deserialize(Response response) { + boolean closeResponse = true; + try { + if (endpointErrorDecoder.isError(response)) { + return endpointErrorDecoder.decode(response); + } else if (response.code() == 204) { + Optional maybeEmptyInstance = emptyInstance.get(); + if (maybeEmptyInstance.isPresent()) { + return maybeEmptyInstance.get(); + } + throw new SafeRuntimeException( + "Unable to deserialize non-optional response type from 204", SafeArg.of("type", token)); + } + + Optional contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE); + if (!contentType.isPresent()) { + throw new SafeIllegalArgumentException( + "Response is missing Content-Type header", + SafeArg.of("received", response.headers().keySet())); + } + Encoding.Deserializer deserializer = getResponseDeserializer(contentType.get()); + T deserialized = deserializer.deserialize(response.body()); + // deserializer has taken on responsibility for closing the response body + closeResponse = false; + return deserialized; + } catch (IOException e) { + throw new SafeRuntimeException( + "Failed to deserialize response stream", + e, + SafeArg.of("contentType", response.getFirstHeader(HttpHeaders.CONTENT_TYPE)), + SafeArg.of("type", token)); + } finally { + if (closeResponse) { + response.close(); + } + } + } + + @Override + public Optional accepts() { + return acceptValue; + } + + /** Returns the {@link EncodingDeserializerContainer} to use to deserialize the request body. */ + @SuppressWarnings("ForLoopReplaceableByForEach") + // performance sensitive code avoids iterator allocation + Encoding.Deserializer getResponseDeserializer(String contentType) { + for (int i = 0; i < encodings.size(); i++) { + EncodingDeserializerContainer container = encodings.get(i); + if (container.encoding.supportsContentType(contentType)) { + return container.deserializer; + } + } + return throwingDeserializer(contentType); + } + private Encoding.Deserializer throwingDeserializer(String contentType) { return new Encoding.Deserializer() { @Override @@ -320,7 +452,8 @@ public T deserialize(InputStream input) { } /** Effectively just a pair. */ - private static final class EncodingDeserializerContainer { + // TODO(pm): saving the deserializer actually isn't doing much for us. + static final class EncodingDeserializerContainer { private final Encoding encoding; private final Encoding.Deserializer deserializer; @@ -330,6 +463,10 @@ private static final class EncodingDeserializerContainer { this.deserializer = encoding.deserializer(token); } + public Encoding.Deserializer getDeserializer() { + return deserializer; + } + @Override public String toString() { return "EncodingDeserializerContainer{encoding=" + encoding + ", deserializer=" + deserializer + '}'; diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java new file mode 100644 index 000000000..34b1261cf --- /dev/null +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java @@ -0,0 +1,229 @@ +/* + * (c) Copyright 2024 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.conjure.java.dialogue.serde; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.io.CharStreams; +import com.google.common.net.HttpHeaders; +import com.google.common.primitives.Longs; +import com.palantir.conjure.java.api.errors.QosException; +import com.palantir.conjure.java.api.errors.QosReason; +import com.palantir.conjure.java.api.errors.QosReasons; +import com.palantir.conjure.java.api.errors.QosReasons.QosResponseDecodingAdapter; +import com.palantir.conjure.java.api.errors.RemoteException; +import com.palantir.conjure.java.api.errors.SerializableError; +import com.palantir.conjure.java.api.errors.UnknownRemoteException; +import com.palantir.conjure.java.serialization.ObjectMappers; +import com.palantir.dialogue.Response; +import com.palantir.dialogue.TypeMarker; +import com.palantir.logsafe.Arg; +import com.palantir.logsafe.SafeArg; +import com.palantir.logsafe.SafeLoggable; +import com.palantir.logsafe.UnsafeArg; +import com.palantir.logsafe.exceptions.SafeExceptions; +import com.palantir.logsafe.logger.SafeLogger; +import com.palantir.logsafe.logger.SafeLoggerFactory; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.Reader; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +// TODO(pm): public because maybe we need to expose this in the dialogue annotations. What does that do? +// T is the base type of the endpoint response. It's a union of the result type and all of the error types. +public final class EndpointErrorDecoder { + private static final SafeLogger log = SafeLoggerFactory.get(EndpointErrorDecoder.class); + private static final ObjectMapper MAPPER = ObjectMappers.newClientObjectMapper(); + private final Map> errorNameToTypeMap; + private final List encodings; + + public EndpointErrorDecoder(Map> errorNameToTypeMap, List encodings) { + this.errorNameToTypeMap = errorNameToTypeMap; + this.encodings = encodings; + } + + public boolean isError(Response response) { + return 300 <= response.code() && response.code() <= 599; + } + + public T decode(Response response) { + if (log.isDebugEnabled()) { + log.debug("Received an error response", diagnosticArgs(response)); + } + try { + return decodeInternal(response); + } catch (Exception e) { + // TODO(pm): do we want to add the diagnostic information to the result type as well? + e.addSuppressed(diagnostic(response)); + throw e; + } + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + private T decodeInternal(Response response) { + int code = response.code(); + switch (code) { + case 308: + Optional location = response.getFirstHeader(HttpHeaders.LOCATION); + if (location.isPresent()) { + String locationHeader = location.get(); + try { + UnknownRemoteException remoteException = new UnknownRemoteException(code, ""); + remoteException.initCause( + QosException.retryOther(qosReason(response), new URL(locationHeader))); + throw remoteException; + } catch (MalformedURLException e) { + log.error( + "Failed to parse location header for QosException.RetryOther", + UnsafeArg.of("locationHeader", locationHeader), + e); + } + } else { + log.error("Retrieved HTTP status code 308 without Location header, cannot perform " + + "redirect. This appears to be a server-side protocol violation."); + } + break; + case 429: + throw response.getFirstHeader(HttpHeaders.RETRY_AFTER) + .map(Longs::tryParse) + .map(Duration::ofSeconds) + .map(duration -> QosException.throttle(qosReason(response), duration)) + .orElseGet(() -> QosException.throttle(qosReason(response))); + case 503: + throw QosException.unavailable(qosReason(response)); + } + + String body; + try { + body = toString(response.body()); + } catch (NullPointerException | IOException e) { + UnknownRemoteException exception = new UnknownRemoteException(code, ""); + exception.initCause(e); + throw exception; + } + + Optional contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE); + // Use a factory: given contentType, create the deserailizer. + // We need Encoding.Deserializer here. That depends on the encoding. + if (contentType.isPresent() && Encodings.matchesContentType("application/json", contentType.get())) { + try { + // TODO(pm): figure out if we can avoid double parsing. + JsonNode node = MAPPER.readTree(body); + if (node.get("errorName") != null) { + // TODO(pm): Update this to use some struct instead of errorName. + TypeMarker container = Optional.ofNullable( + errorNameToTypeMap.get(node.get("errorName").asText())) + .orElseThrow(); + // make this a normal for + for (Encoding encoding : encodings) { + if (encoding.supportsContentType(contentType.get())) { + return encoding.deserializer(container) + .deserialize(new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8))); + } + } + } else { + SerializableError serializableError = MAPPER.readValue(body, SerializableError.class); + throw new RemoteException(serializableError, code); + } + } catch (Exception e) { + throw new UnknownRemoteException(code, body); + } + } + + throw new UnknownRemoteException(code, body); + } + + private static String toString(InputStream body) throws IOException { + try (Reader reader = new InputStreamReader(body, StandardCharsets.UTF_8)) { + return CharStreams.toString(reader); + } + } + + private static ResponseDiagnostic diagnostic(Response response) { + return new ResponseDiagnostic(diagnosticArgs(response)); + } + + private static ImmutableList> diagnosticArgs(Response response) { + ImmutableList.Builder> args = ImmutableList.>builder().add(SafeArg.of("status", response.code())); + recordHeader(HttpHeaders.SERVER, response, args); + recordHeader(HttpHeaders.CONTENT_TYPE, response, args); + recordHeader(HttpHeaders.CONTENT_LENGTH, response, args); + recordHeader(HttpHeaders.CONNECTION, response, args); + recordHeader(HttpHeaders.DATE, response, args); + recordHeader("x-envoy-response-flags", response, args); + recordHeader("x-envoy-response-code-details", response, args); + recordHeader("Response-Flags", response, args); + recordHeader("Response-Code-Details", response, args); + return args.build(); + } + + private static void recordHeader(String header, Response response, ImmutableList.Builder> args) { + response.getFirstHeader(header).ifPresent(server -> args.add(SafeArg.of(header, server))); + } + + private static final class ResponseDiagnostic extends RuntimeException implements SafeLoggable { + + private static final String SAFE_MESSAGE = "Response Diagnostic Information"; + + private final ImmutableList> args; + + ResponseDiagnostic(ImmutableList> args) { + super(SafeExceptions.renderMessage(SAFE_MESSAGE, args.toArray(new Arg[0]))); + this.args = args; + } + + @Override + public String getLogMessage() { + return SAFE_MESSAGE; + } + + @Override + public List> getArgs() { + return args; + } + + @Override + @SuppressWarnings("UnsynchronizedOverridesSynchronized") // nop + public Throwable fillInStackTrace() { + // no-op: stack trace generation is expensive, this type exists + // to simply associate diagnostic information with a failure. + return this; + } + } + + private static QosReason qosReason(Response response) { + return QosReasons.parseFromResponse(response, DialogueQosResponseDecodingAdapter.INSTANCE); + } + + private enum DialogueQosResponseDecodingAdapter implements QosResponseDecodingAdapter { + INSTANCE; + + @Override + public Optional getFirstHeader(Response response, String headerName) { + return response.getFirstHeader(headerName); + } + } +} diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoder.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoder.java index 50642aea0..cc2dba5e2 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoder.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoder.java @@ -80,6 +80,7 @@ private RuntimeException decodeInternal(Response response) { int code = response.code(); switch (code) { case 308: + // Permanent redirect Optional location = response.getFirstHeader(HttpHeaders.LOCATION); if (location.isPresent()) { String locationHeader = location.get(); @@ -100,6 +101,7 @@ private RuntimeException decodeInternal(Response response) { } break; case 429: + // Too many requests return response.getFirstHeader(HttpHeaders.RETRY_AFTER) .map(Longs::tryParse) .map(Duration::ofSeconds) diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java index da7ea260c..cb877f0c3 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java @@ -328,4 +328,6 @@ public String toString() { return "StubEncoding{" + contentType + '}'; } } + + } diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorTestUtils.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorTestUtils.java new file mode 100644 index 000000000..69bcb6949 --- /dev/null +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorTestUtils.java @@ -0,0 +1,100 @@ +/* + * (c) Copyright 2024 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.conjure.java.dialogue.serde; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.palantir.conjure.java.api.errors.CheckedServiceException; +import com.palantir.dialogue.TypeMarker; +import com.palantir.logsafe.Arg; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalDouble; +import java.util.OptionalInt; +import java.util.OptionalLong; + +final class EndpointErrorTestUtils { + private EndpointErrorTestUtils() {} + + record ConjureError( + @JsonProperty("errorCode") String errorCode, + @JsonProperty("errorName") String errorName, + @JsonProperty("errorInstanceId") String errorInstanceId, + @JsonProperty("parameters") Map parameters) { + static ConjureError fromCheckedServiceException(CheckedServiceException exception) { + Map parameters = new HashMap<>(); + for (Arg arg : exception.getArgs()) { + if (shouldIncludeArgInParameters(arg)) { + parameters.put(arg.getName(), arg.getValue()); + } + } + return new ConjureError( + exception.getErrorType().code().name(), + exception.getErrorType().name(), + exception.getErrorInstanceId(), + parameters); + } + + private static boolean shouldIncludeArgInParameters(Arg arg) { + Object obj = arg.getValue(); + return obj != null + && (!(obj instanceof Optional) || ((Optional) obj).isPresent()) + && (!(obj instanceof OptionalInt) || ((OptionalInt) obj).isPresent()) + && (!(obj instanceof OptionalLong) || ((OptionalLong) obj).isPresent()) + && (!(obj instanceof OptionalDouble) || ((OptionalDouble) obj).isPresent()); + } + } + + /** Deserializes requests as the type. */ + public static final class TypeReturningStubEncoding implements Encoding { + + private final String contentType; + + TypeReturningStubEncoding(String contentType) { + this.contentType = contentType; + } + + @Override + public Encoding.Serializer serializer(TypeMarker _type) { + return (_value, _output) -> { + // nop + }; + } + + @Override + public Encoding.Deserializer deserializer(TypeMarker type) { + return input -> { + return (T) Encodings.json().deserializer(type).deserialize(input); + }; + } + + @Override + public String getContentType() { + return contentType; + } + + @Override + public boolean supportsContentType(String input) { + return contentType.equals(input); + } + + @Override + public String toString() { + return "TypeReturningStubEncoding{" + contentType + '}'; + } + } +} diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java new file mode 100644 index 000000000..c04f0ad7d --- /dev/null +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java @@ -0,0 +1,185 @@ +/* + * (c) Copyright 2024 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.conjure.java.dialogue.serde; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.google.common.collect.ImmutableList; +import com.palantir.conjure.java.api.errors.CheckedServiceException; +import com.palantir.conjure.java.api.errors.ErrorType; +import com.palantir.conjure.java.dialogue.serde.EndpointErrorTestUtils.ConjureError; +import com.palantir.conjure.java.dialogue.serde.EndpointErrorTestUtils.TypeReturningStubEncoding; +import com.palantir.conjure.java.serialization.ObjectMappers; +import com.palantir.dialogue.BodySerDe; +import com.palantir.dialogue.DeserializerArgs; +import com.palantir.dialogue.TestResponse; +import com.palantir.dialogue.TypeMarker; +import com.palantir.logsafe.Preconditions; +import com.palantir.logsafe.Safe; +import com.palantir.logsafe.SafeArg; +import com.palantir.logsafe.Unsafe; +import com.palantir.logsafe.UnsafeArg; +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.Optional; +import javax.annotation.Nullable; +import javax.annotation.processing.Generated; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class EndpointErrorsConjureBodySerDeTest { + private static final ObjectMapper MAPPER = ObjectMappers.newServerObjectMapper(); + private ErrorDecoder errorDecoder = ErrorDecoder.INSTANCE; + + @Generated("by conjure-java") + private sealed interface EndpointReturnBaseType permits StringReturn, ErrorForEndpoint {} + + @Generated("by conjure-java") + record StringReturn(String value) implements EndpointReturnBaseType { + @JsonCreator + public static StringReturn create(String value) { + return new StringReturn(Preconditions.checkArgumentNotNull(value, "value cannot be null")); + } + } + + @Generated("by conjure-java") + @JsonDeserialize(using = ErrorForEndpoint.Deserializer.class) + record ErrorForEndpoint( + @Safe String errorCode, + @Safe String errorName, + @Safe String errorInstanceId, + @Safe String arg, + @Unsafe String unsafeArg, + @Safe ComplexArg complexArg, + @Safe Optional optionalArg) + implements EndpointReturnBaseType { + private static final ObjectMapper CLIENT_OBJECT_MAPPER = ObjectMappers.newClientObjectMapper(); + + static final class Deserializer extends JsonDeserializer { + @Override + public ErrorForEndpoint deserialize(JsonParser parser, DeserializationContext _ctx) throws IOException { + JsonNode node = parser.getCodec().readTree(parser); + JsonNode params = node.get("parameters"); + ComplexArg complex = CLIENT_OBJECT_MAPPER.treeToValue(params.get("complexArg"), ComplexArg.class); + Optional optionalArg = Optional.ofNullable(params.get("optionalArg")); + Optional optArg = optionalArg.isPresent() + ? Optional.of(CLIENT_OBJECT_MAPPER.treeToValue(optionalArg.get(), Integer.class)) + : Optional.empty(); + + return new ErrorForEndpoint( + Preconditions.checkArgumentNotNull(node.get("errorCode")) + .asText(), + Preconditions.checkArgumentNotNull(node.get("errorName")) + .asText(), + Preconditions.checkArgumentNotNull(node.get("errorInstanceId")) + .asText(), + CLIENT_OBJECT_MAPPER.treeToValue( + Preconditions.checkArgumentNotNull(params.get("arg")), String.class), + CLIENT_OBJECT_MAPPER.treeToValue( + Preconditions.checkArgumentNotNull(params.get("unsafeArg")), String.class), + complex, + optArg); + } + } + } + + @Generated("by conjure-java") + record ComplexArg(int foo, String bar) {} + + @Generated("by conjure-java") + public static final class TestEndpointError extends CheckedServiceException { + private TestEndpointError( + @Safe String arg, + @Unsafe String unsafeArg, + @Safe ComplexArg complexArg, + @Safe Optional optionalArg, + @Nullable Throwable cause) { + super( + ErrorType.FAILED_PRECONDITION, + cause, + SafeArg.of("arg", arg), + UnsafeArg.of("unsafeArg", unsafeArg), + SafeArg.of("complexArg", complexArg), + SafeArg.of("optionalArg", optionalArg)); + } + } + + @Test + public void testDeserializeCustomErrors() throws IOException { + TestEndpointError errorThrownByEndpoint = + new TestEndpointError("value", "unsafeValue", new ComplexArg(1, "bar"), Optional.of(2), null); + + ErrorForEndpoint expectedErrorForEndpoint = new ErrorForEndpoint( + "FAILED_PRECONDITION", + "Default:FailedPrecondition", + errorThrownByEndpoint.getErrorInstanceId(), + "value", + "unsafeValue", + new ComplexArg(1, "bar"), + Optional.of(2)); + + String responseBody = + MAPPER.writeValueAsString(ConjureError.fromCheckedServiceException(errorThrownByEndpoint)); + TestResponse response = TestResponse.withBody(responseBody) + .contentType("application/json") + .code(500); + BodySerDe serializers = conjureBodySerDe("application/json", "text/plain"); + EndpointReturnBaseType value = serializers + .deserializer(new DeserializerArgs<>( + /* baseType */ new TypeMarker() {}, + /* expectedResultType */ new TypeMarker() {}, + Map.of("Default:FailedPrecondition", new TypeMarker() {}))) + .deserialize(response); + assertThat(value).isEqualTo(expectedErrorForEndpoint); + } + + @Test + public void testDeserializeExpectedValue() { + String expectedString = "expectedString"; + TestResponse response = TestResponse.withBody(String.format("\"%s\"", expectedString)) + .contentType("application/json") + .code(200); + BodySerDe serializers = conjureBodySerDe("application/json", "text/plain"); + EndpointReturnBaseType value = serializers + .deserializer(new DeserializerArgs<>( + /* baseType */ new TypeMarker() {}, + /* expectedResultType */ new TypeMarker() {}, + Map.of("Default:FailedPrecondition", new TypeMarker() {}))) + .deserialize(response); + assertThat(value).isEqualTo(new StringReturn(expectedString)); + } + + private ConjureBodySerDe conjureBodySerDe(String... contentTypes) { + return new ConjureBodySerDe( + Arrays.stream(contentTypes) + .map(c -> WeightedEncoding.of(new TypeReturningStubEncoding(c))) + .collect(ImmutableList.toImmutableList()), + errorDecoder, + Encodings.emptyContainerDeserializer(), + DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); + } +} diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java index 9f26a7aa6..510914481 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java @@ -54,7 +54,8 @@ public final class ErrorDecoderTest { private static String createServiceException(ServiceException exception) { try { - return SERVER_MAPPER.writeValueAsString(SerializableError.forException(exception)); + String ret = SERVER_MAPPER.writeValueAsString(SerializableError.forException(exception)); + return ret; } catch (JsonProcessingException e) { fail("failed to serialize"); return ""; diff --git a/dialogue-target/src/main/java/com/palantir/dialogue/BodySerDe.java b/dialogue-target/src/main/java/com/palantir/dialogue/BodySerDe.java index 8801f0c44..f7c69d264 100644 --- a/dialogue-target/src/main/java/com/palantir/dialogue/BodySerDe.java +++ b/dialogue-target/src/main/java/com/palantir/dialogue/BodySerDe.java @@ -28,6 +28,8 @@ public interface BodySerDe { /** Creates a {@link Deserializer} for the requested type. Deserializer instances should be reused. */ Deserializer deserializer(TypeMarker type); + Deserializer deserializer(DeserializerArgs deserializerArgs); + /** * Returns a {@link Deserializer} that fails if a non-empty reponse body is presented and returns null otherwise. */ diff --git a/dialogue-target/src/main/java/com/palantir/dialogue/DeserializerArgs.java b/dialogue-target/src/main/java/com/palantir/dialogue/DeserializerArgs.java new file mode 100644 index 000000000..30b77a4d7 --- /dev/null +++ b/dialogue-target/src/main/java/com/palantir/dialogue/DeserializerArgs.java @@ -0,0 +1,47 @@ +/* + * (c) Copyright 2024 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.dialogue; + +import java.util.Map; + +// TODO(pm): add builder +public final class DeserializerArgs { + private final TypeMarker baseType; + private final TypeMarker expectedResultType; + private final Map> errorNameToTypeMarker; + + public DeserializerArgs( + TypeMarker baseType, + TypeMarker expectedResultType, + Map> errorNameToTypeMarker) { + this.baseType = baseType; + this.expectedResultType = expectedResultType; + this.errorNameToTypeMarker = errorNameToTypeMarker; + } + + public TypeMarker baseType() { + return baseType; + } + + public TypeMarker expectedResultType() { + return expectedResultType; + } + + public Map> errorNameToTypeMarker() { + return errorNameToTypeMarker; + } +}