Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Pritham Marupaka committed Jan 8, 2025
1 parent d74dea3 commit 4f922a8
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 284 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.palantir.dialogue.Response;

// TODO(pm): use the new EndpointErrorDecoder
public final class ConjureErrorDecoder implements ErrorDecoder {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import java.io.OutputStream;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
Expand All @@ -52,7 +53,6 @@
/**
* items:
* - we don't want to use `String` for the error identifier. Let's create an `ErrorName` class.
* - re-consider using a map for the deserializersForEndpointBaseType field. is there a more direct way to get this info
*/

/** Package private internal API. */
Expand All @@ -65,7 +65,7 @@ final class ConjureBodySerDe implements BodySerDe {
private final Deserializer<Optional<InputStream>> optionalBinaryInputStreamDeserializer;
private final Deserializer<Void> emptyBodyDeserializer;
private final LoadingCache<Type, Serializer<?>> serializers;
private final LoadingCache<Type, EncodingDeserializerRegistry<?>> deserializers;
private final LoadingCache<Type, EncodingDeserializerForEndpointRegistry<?>> deserializers;
private final EmptyContainerDeserializer emptyContainerDeserializer;

/**
Expand All @@ -75,32 +75,49 @@ final class ConjureBodySerDe implements BodySerDe {
*/
ConjureBodySerDe(
List<WeightedEncoding> rawEncodings,
ErrorDecoder errorDecoder,
ErrorDecoder _errorDecoder,
EmptyContainerDeserializer emptyContainerDeserializer,
CaffeineSpec cacheSpec) {
List<WeightedEncoding> encodings = decorateEncodings(rawEncodings);
this.encodingsSortedByWeight = sortByWeight(encodings);
Preconditions.checkArgument(encodings.size() > 0, "At least one Encoding is required");
// note(pm): why do the weighted encoding thing? can we just pass in the default encoding?
this.defaultEncoding = encodings.get(0).encoding();
this.emptyContainerDeserializer = emptyContainerDeserializer;
this.binaryInputStreamDeserializer = new EncodingDeserializerRegistry<>(
this.binaryInputStreamDeserializer = new EncodingDeserializerForEndpointRegistry<>(
ImmutableList.of(BinaryEncoding.INSTANCE),
errorDecoder,
emptyContainerDeserializer,
BinaryEncoding.MARKER);
this.optionalBinaryInputStreamDeserializer = new EncodingDeserializerRegistry<>(
BinaryEncoding.MARKER,
DeserializerArgs.<InputStream>builder()
.withBaseType(BinaryEncoding.MARKER)
.withExpectedResult(BinaryEncoding.MARKER)
.build());
this.optionalBinaryInputStreamDeserializer = new EncodingDeserializerForEndpointRegistry<>(
ImmutableList.of(BinaryEncoding.INSTANCE),
errorDecoder,
emptyContainerDeserializer,
BinaryEncoding.OPTIONAL_MARKER);
this.emptyBodyDeserializer = new EmptyBodyDeserializer(errorDecoder);
BinaryEncoding.OPTIONAL_MARKER,
DeserializerArgs.<Optional<InputStream>>builder()
.withBaseType(BinaryEncoding.OPTIONAL_MARKER)
.withExpectedResult(BinaryEncoding.OPTIONAL_MARKER)
.build());
this.emptyBodyDeserializer =
new EmptyBodyDeserializer(new EndpointErrorDecoder<>(Collections.emptyMap(), Optional.empty()));
// Class unloading: Not supported, Jackson keeps strong references to the types
// it sees: https://github.com/FasterXML/jackson-databind/issues/489
this.serializers = Caffeine.from(cacheSpec)
.build(type -> new EncodingSerializerRegistry<>(defaultEncoding, TypeMarker.of(type)));
this.deserializers = Caffeine.from(cacheSpec)
.build(type -> new EncodingDeserializerRegistry<>(
encodingsSortedByWeight, errorDecoder, emptyContainerDeserializer, TypeMarker.of(type)));
this.deserializers = Caffeine.from(cacheSpec).build(type -> buildCacheEntry(TypeMarker.of(type)));
}

private <T> EncodingDeserializerForEndpointRegistry<?> buildCacheEntry(TypeMarker<T> typeMarker) {
return new EncodingDeserializerForEndpointRegistry<>(
encodingsSortedByWeight,
emptyContainerDeserializer,
typeMarker,
DeserializerArgs.<T>builder()
.withBaseType(typeMarker)
.withExpectedResult(typeMarker)
.build());
}

private static List<WeightedEncoding> decorateEncodings(List<WeightedEncoding> input) {
Expand Down Expand Up @@ -235,108 +252,7 @@ private static final class EncodingSerializerContainer<T> {
}
}

private static final class EncodingDeserializerRegistry<T> implements Deserializer<T> {

private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerRegistry.class);
private final ImmutableList<EncodingDeserializerContainer<T>> encodings;
private final ErrorDecoder errorDecoder;
private final Optional<String> acceptValue;
private final Supplier<Optional<T>> emptyInstance;
private final TypeMarker<T> token;

EncodingDeserializerRegistry(
List<Encoding> encodings,
ErrorDecoder errorDecoder,
EmptyContainerDeserializer empty,
TypeMarker<T> token) {
this.encodings = encodings.stream()
.map(encoding -> new EncodingDeserializerContainer<>(encoding, token))
.collect(ImmutableList.toImmutableList());
this.errorDecoder = errorDecoder;
this.token = token;
this.emptyInstance = Suppliers.memoize(() -> empty.tryGetEmptyInstance(token));
// Encodings are applied to the accept header in the order of preference based on the provided list.
this.acceptValue =
Optional.of(encodings.stream().map(Encoding::getContentType).collect(Collectors.joining(", ")));
}

@Override
public T deserialize(Response response) {
boolean closeResponse = true;
try {
if (errorDecoder.isError(response)) {
throw errorDecoder.decode(response);
} else if (response.code() == 204) {
// TODO(dfox): what if we get a 204 for a non-optional type???
// TODO(dfox): support http200 & body=null
// TODO(dfox): what if we were expecting an empty list but got {}?
Optional<T> maybeEmptyInstance = emptyInstance.get();
if (maybeEmptyInstance.isPresent()) {
return maybeEmptyInstance.get();
}
throw new SafeRuntimeException(
"Unable to deserialize non-optional response type from 204", SafeArg.of("type", token));
}

Optional<String> contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE);
if (!contentType.isPresent()) {
throw new SafeIllegalArgumentException(
"Response is missing Content-Type header",
SafeArg.of("received", response.headers().keySet()));
}
Encoding.Deserializer<T> deserializer = getResponseDeserializer(contentType.get());
T deserialized = deserializer.deserialize(response.body());
// deserializer has taken on responsibility for closing the response body
closeResponse = false;
return deserialized;
} catch (IOException e) {
throw new SafeRuntimeException(
"Failed to deserialize response stream",
e,
SafeArg.of("contentType", response.getFirstHeader(HttpHeaders.CONTENT_TYPE)),
SafeArg.of("type", token));
} finally {
if (closeResponse) {
response.close();
}
}
}

@Override
public Optional<String> accepts() {
return acceptValue;
}

/** Returns the {@link EncodingDeserializerContainer} to use to deserialize the request body. */
@SuppressWarnings("ForLoopReplaceableByForEach")
// performance sensitive code avoids iterator allocation
Encoding.Deserializer<T> getResponseDeserializer(String contentType) {
for (int i = 0; i < encodings.size(); i++) {
EncodingDeserializerContainer<T> container = encodings.get(i);
if (container.encoding.supportsContentType(contentType)) {
return container.deserializer;
}
}
return throwingDeserializer(contentType);
}

private Encoding.Deserializer<T> throwingDeserializer(String contentType) {
return input -> {
try {
input.close();
} catch (RuntimeException | IOException e) {
log.warn("Failed to close InputStream", e);
}
throw new SafeRuntimeException(
"Unsupported Content-Type",
SafeArg.of("received", contentType),
SafeArg.of("supportedEncodings", encodings));
};
}
}

private static final class EncodingDeserializerForEndpointRegistry<T> implements Deserializer<T> {

private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerForEndpointRegistry.class);
private final ImmutableList<EncodingDeserializerContainer<? extends T>> encodings;
private final EndpointErrorDecoder<T> endpointErrorDecoder;
Expand All @@ -353,8 +269,11 @@ private static final class EncodingDeserializerForEndpointRegistry<T> implements
.map(encoding -> new EncodingDeserializerContainer<>(
encoding, deserializersForEndpoint.expectedResultType()))
.collect(ImmutableList.toImmutableList());
this.endpointErrorDecoder =
new EndpointErrorDecoder<>(deserializersForEndpoint.errorNameToTypeMarker(), encodings);
this.endpointErrorDecoder = new EndpointErrorDecoder<>(
deserializersForEndpoint.errorNameToTypeMarker(),
encodings.stream()
.filter(encoding -> encoding.supportsContentType("application/json"))
.findAny());
this.token = token;
this.emptyInstance = Suppliers.memoize(() -> empty.tryGetEmptyInstance(token));
// Encodings are applied to the accept header in the order of preference based on the provided list.
Expand All @@ -367,7 +286,6 @@ public T deserialize(Response response) {
boolean closeResponse = true;
try {
if (endpointErrorDecoder.isError(response)) {
// TODO(pm): This needs to return T for the new deserializer API, but throw an exception for the old
return endpointErrorDecoder.decode(response);
} else if (response.code() == 204) {
Optional<T> maybeEmptyInstance = emptyInstance.get();
Expand Down Expand Up @@ -457,19 +375,19 @@ public String toString() {
}

private static final class EmptyBodyDeserializer implements Deserializer<Void> {
private final ErrorDecoder errorDecoder;
private final EndpointErrorDecoder<?> endpointErrorDecoder;

EmptyBodyDeserializer(ErrorDecoder errorDecoder) {
this.errorDecoder = errorDecoder;
EmptyBodyDeserializer(EndpointErrorDecoder<?> endpointErrorDecoder) {
this.endpointErrorDecoder = endpointErrorDecoder;
}

@Override
@SuppressWarnings("NullAway") // empty body is a special case
public Void deserialize(Response response) {
// We should not fail if a server that previously returned nothing starts returning a response
try (Response unused = response) {
if (errorDecoder.isError(response)) {
throw errorDecoder.decode(response);
if (endpointErrorDecoder.isError(response)) {
endpointErrorDecoder.decode(response);
}
return null;
}
Expand Down
Loading

0 comments on commit 4f922a8

Please sign in to comment.