Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Pritham Marupaka committed Dec 4, 2024
1 parent 6d14c96 commit 6f894f2
Show file tree
Hide file tree
Showing 8 changed files with 539 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.palantir.dialogue.BinaryRequestBody;
import com.palantir.dialogue.BodySerDe;
import com.palantir.dialogue.Deserializer;
import com.palantir.dialogue.DeserializerArgs;
import com.palantir.dialogue.RequestBody;
import com.palantir.dialogue.Response;
import com.palantir.dialogue.Serializer;
Expand All @@ -43,11 +44,19 @@
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
* items:
* - we don't want to use `String` for the error identifier. Let's create an `ErrorName` class.
* - re-consider using a map for the deserializersForEndpointBaseType field. is there a more direct way to get this info
*/

/** Package private internal API. */
final class ConjureBodySerDe implements BodySerDe {

Expand All @@ -58,7 +67,10 @@ final class ConjureBodySerDe implements BodySerDe {
private final Deserializer<Optional<InputStream>> optionalBinaryInputStreamDeserializer;
private final Deserializer<Void> emptyBodyDeserializer;
private final LoadingCache<Type, Serializer<?>> serializers;
private final LoadingCache<Type, Deserializer<?>> deserializers;
private final LoadingCache<Type, EncodingDeserializerRegistry<?>> deserializers;
private final Map<Type, DeserializersForEndpoint<?>> deserializersForEndpointBaseType;
private final LoadingCache<Type, EncodingDeserializerForEndpointRegistry<?>> endpointWithErrorsDeserializers;
private final EmptyContainerDeserializer emptyContainerDeserializer;

/**
* Selects the first (based on input order) of the provided encodings that
Expand All @@ -74,6 +86,7 @@ final class ConjureBodySerDe implements BodySerDe {
this.encodingsSortedByWeight = sortByWeight(encodings);
Preconditions.checkArgument(encodings.size() > 0, "At least one Encoding is required");
this.defaultEncoding = encodings.get(0).encoding();
this.emptyContainerDeserializer = emptyContainerDeserializer;
this.binaryInputStreamDeserializer = new EncodingDeserializerRegistry<>(
ImmutableList.of(BinaryEncoding.INSTANCE),
errorDecoder,
Expand All @@ -92,6 +105,9 @@ final class ConjureBodySerDe implements BodySerDe {
this.deserializers = Caffeine.from(cacheSpec)
.build(type -> new EncodingDeserializerRegistry<>(
encodingsSortedByWeight, errorDecoder, emptyContainerDeserializer, TypeMarker.of(type)));
// TODO(pm): revisit storing this in a map.
this.deserializersForEndpointBaseType = new HashMap<>();
this.endpointWithErrorsDeserializers = Caffeine.from(cacheSpec).build(this::buildCacheEntry);
}

private static List<WeightedEncoding> decorateEncodings(List<WeightedEncoding> input) {
Expand Down Expand Up @@ -122,6 +138,33 @@ public <T> Deserializer<T> deserializer(TypeMarker<T> token) {
return (Deserializer<T>) deserializers.get(token.getType());
}

@Override
@SuppressWarnings("unchecked")
public <T> Deserializer<T> deserializer(DeserializerArgs<T> deserializerArgs) {
Map<String, Deserializer<? extends T>> deserializersForErrors =
deserializerArgs.errorNameToTypeMarker().entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, entry -> (Deserializer<? extends T>)
deserializers.get(entry.getValue().getType())));
Deserializer<? extends T> resType = (Deserializer<? extends T>)
deserializers.get(deserializerArgs.expectedResultType().getType());
DeserializersForEndpoint<T> deserializersForEndpoint = new DeserializersForEndpoint<>(
resType, deserializersForErrors, deserializerArgs.errorNameToTypeMarker());

Type baseType = deserializerArgs.baseType().getType();
this.deserializersForEndpointBaseType.put(baseType, deserializersForEndpoint);
return (Deserializer<T>) endpointWithErrorsDeserializers.get(baseType);
}

@SuppressWarnings("unchecked")
private <T> EncodingDeserializerForEndpointRegistry<T> buildCacheEntry(Type baseType) {
return new EncodingDeserializerForEndpointRegistry<>(
encodingsSortedByWeight,
emptyContainerDeserializer,
(TypeMarker<T>) TypeMarker.of(baseType),
(DeserializersForEndpoint<T>) Optional.ofNullable(deserializersForEndpointBaseType.get(baseType))
.orElseThrow());
}

@Override
public Deserializer<Void> emptyBodyDeserializer() {
return emptyBodyDeserializer;
Expand Down Expand Up @@ -301,6 +344,106 @@ Encoding.Deserializer<T> getResponseDeserializer(String contentType) {
return throwingDeserializer(contentType);
}

private Encoding.Deserializer<T> throwingDeserializer(String contentType) {
return input -> {
try {
input.close();
} catch (RuntimeException | IOException e) {
log.warn("Failed to close InputStream", e);
}
throw new SafeRuntimeException(
"Unsupported Content-Type",
SafeArg.of("received", contentType),
SafeArg.of("supportedEncodings", encodings));
};
}
}

private static final class EncodingDeserializerForEndpointRegistry<T> implements Deserializer<T> {

private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerForEndpointRegistry.class);
private final ImmutableList<EncodingDeserializerContainer<T>> encodings;
private final EndpointErrorDecoder<T> endpointErrorDecoder;
private final Optional<String> acceptValue;
private final Supplier<Optional<T>> emptyInstance;
private final TypeMarker<T> token;

EncodingDeserializerForEndpointRegistry(
List<Encoding> encodings,
EmptyContainerDeserializer empty,
TypeMarker<T> token,
DeserializersForEndpoint<T> deserializersForEndpoint) {
this.encodings = encodings.stream()
.map(encoding -> new EncodingDeserializerContainer<>(encoding, token))
.collect(ImmutableList.toImmutableList());
this.endpointErrorDecoder = new EndpointErrorDecoder<>(deserializersForEndpoint, encodings);
this.token = token;
this.emptyInstance = Suppliers.memoize(() -> empty.tryGetEmptyInstance(token));
// Encodings are applied to the accept header in the order of preference based on the provided list.
this.acceptValue =
Optional.of(encodings.stream().map(Encoding::getContentType).collect(Collectors.joining(", ")));
}

@Override
public T deserialize(Response response) {
boolean closeResponse = true;
try {
if (endpointErrorDecoder.isError(response)) {
return endpointErrorDecoder.decode(response);
} else if (response.code() == 204) {
// TODO(dfox): what if we get a 204 for a non-optional type???
// TODO(dfox): support http200 & body=null
// TODO(dfox): what if we were expecting an empty list but got {}?
Optional<T> maybeEmptyInstance = emptyInstance.get();
if (maybeEmptyInstance.isPresent()) {
return maybeEmptyInstance.get();
}
throw new SafeRuntimeException(
"Unable to deserialize non-optional response type from 204", SafeArg.of("type", token));
}

Optional<String> contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE);
if (!contentType.isPresent()) {
throw new SafeIllegalArgumentException(
"Response is missing Content-Type header",
SafeArg.of("received", response.headers().keySet()));
}
Encoding.Deserializer<T> deserializer = getResponseDeserializer(contentType.get());
T deserialized = deserializer.deserialize(response.body());
// deserializer has taken on responsibility for closing the response body
closeResponse = false;
return deserialized;
} catch (IOException e) {
throw new SafeRuntimeException(
"Failed to deserialize response stream",
e,
SafeArg.of("contentType", response.getFirstHeader(HttpHeaders.CONTENT_TYPE)),
SafeArg.of("type", token));
} finally {
if (closeResponse) {
response.close();
}
}
}

@Override
public Optional<String> accepts() {
return acceptValue;
}

/** Returns the {@link EncodingDeserializerContainer} to use to deserialize the request body. */
@SuppressWarnings("ForLoopReplaceableByForEach")
// performance sensitive code avoids iterator allocation
Encoding.Deserializer<T> getResponseDeserializer(String contentType) {
for (int i = 0; i < encodings.size(); i++) {
EncodingDeserializerContainer<T> container = encodings.get(i);
if (container.encoding.supportsContentType(contentType)) {
return container.deserializer;
}
}
return throwingDeserializer(contentType);
}

private Encoding.Deserializer<T> throwingDeserializer(String contentType) {
return new Encoding.Deserializer<T>() {
@Override
Expand All @@ -320,7 +463,8 @@ public T deserialize(InputStream input) {
}

/** Effectively just a pair. */
private static final class EncodingDeserializerContainer<T> {
// TODO(pm): saving the deserializer actually isn't doing much for us.
static final class EncodingDeserializerContainer<T> {

private final Encoding encoding;
private final Encoding.Deserializer<T> deserializer;
Expand All @@ -330,6 +474,10 @@ private static final class EncodingDeserializerContainer<T> {
this.deserializer = encoding.deserializer(token);
}

public Encoding.Deserializer<T> getDeserializer() {
return deserializer;
}

@Override
public String toString() {
return "EncodingDeserializerContainer{encoding=" + encoding + ", deserializer=" + deserializer + '}';
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* (c) Copyright 2024 Palantir Technologies Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.palantir.conjure.java.dialogue.serde;

import com.palantir.dialogue.Deserializer;
import com.palantir.dialogue.TypeMarker;
import java.util.Map;

public record DeserializersForEndpoint<T>(
Deserializer<? extends T> expectedResultType,
Map<String, Deserializer<? extends T>> errorNameToDeserializer,
Map<String, TypeMarker<? extends T>> errorNameToType) {}
Loading

0 comments on commit 6f894f2

Please sign in to comment.