NIFI-13799 Improved Replicated Cluster Response Handling (apache#9312)
- Returned the remote Response Stream in the Replicated Response for unknown or large Content-Length values (see the sketch below)
- Buffered smaller responses
- Addressed code warnings in ThreadPoolRequestReplicator
exceptionfactory authored Sep 25, 2024
1 parent 1fb8498 commit 2e7a39d
Showing 3 changed files with 135 additions and 68 deletions.
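
The change replaces fully buffered response bodies with a size-aware strategy: a response with a small, known Content-Length is buffered, while a response with an unknown (-1) or large Content-Length is handed back as the remote stream. The following is a minimal Java sketch of that decision only; the ResponseBodySketch class, the bodyFor method, and the BUFFER_THRESHOLD value are illustrative assumptions and are not part of the code in this commit.

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

/** Illustrative sketch: buffer small bodies, stream unknown or large ones. */
final class ResponseBodySketch {

    // Hypothetical cutoff for in-memory buffering (assumed value, not from the commit)
    private static final int BUFFER_THRESHOLD = 1_048_576;

    // Sentinel for a missing Content-Length header, matching the diff below
    private static final int CONTENT_LENGTH_NOT_FOUND = -1;

    /**
     * Returns an in-memory copy when the Content-Length is known and small;
     * otherwise returns the remote stream unchanged so large or unbounded
     * bodies are never held fully in memory. In the streaming case the caller
     * is responsible for closing the returned stream.
     */
    static InputStream bodyFor(final InputStream remoteStream, final int contentLength) throws IOException {
        if (contentLength == CONTENT_LENGTH_NOT_FOUND || contentLength > BUFFER_THRESHOLD) {
            return remoteStream;
        }
        try (InputStream in = remoteStream) {
            return new ByteArrayInputStream(in.readNBytes(contentLength));
        }
    }
}

In this commit, the corresponding inputs — the response stream, the parsed Content-Length, and a close callback — are passed to the new ReplicatedResponse in StandardHttpReplicationClient (second file below).
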
ThreadPoolRequestReplicator.java
@@ -106,7 +106,6 @@ public class ThreadPoolRequestReplicator implements RequestReplicator, Closeable
private final EventReporter eventReporter;
private final RequestCompletionCallback callback;
private final ClusterCoordinator clusterCoordinator;
private final NiFiProperties nifiProperties;

private final ThreadPoolExecutor executorService;
private final ScheduledExecutorService maintenanceExecutor;
@@ -145,7 +144,6 @@ public ThreadPoolRequestReplicator(final int maxPoolSize, final int maxConcurren
this.responseMapper = new StandardHttpResponseMapper(nifiProperties);
this.eventReporter = eventReporter;
this.callback = callback;
this.nifiProperties = nifiProperties;
this.httpClient = client;

final AtomicInteger threadId = new AtomicInteger(0);
@@ -468,7 +466,7 @@ AsyncClusterResponse replicate(final Set<NodeIdentifier> nodeIds, final String m
final Function<NodeIdentifier, NodeHttpRequest> requestFactory =
nodeId -> new NodeHttpRequest(request, nodeId, createURI(uri, nodeId), nodeCompletionCallback, finalResponse);

submitAsyncRequest(nodeIds, uri.getScheme(), uri.getPath(), requestFactory, updatedHeaders);
submitAsyncRequest(nodeIds, requestFactory);

return response;
} catch (final Throwable t) {
@@ -541,17 +539,14 @@ public void onCompletion(final NodeResponse nodeResponse) {
try {
final Map<String, String> cancelLockHeaders = new HashMap<>(headers);
cancelLockHeaders.put(RequestReplicationHeader.CANCEL_TRANSACTION.getHeader(), "true");
final Thread cancelLockThread = new Thread(new Runnable() {
@Override
public void run() {
logger.debug("Found {} dissenting nodes for {} {}; canceling claim request", dissentingCount, method, uri.getPath());
final Thread cancelLockThread = new Thread(() -> {
logger.debug("Found {} dissenting nodes for {} {}; canceling claim request", dissentingCount, method, uri.getPath());

final PreparedRequest request = httpClient.prepareRequest(method, cancelLockHeaders, entity);
final Function<NodeIdentifier, NodeHttpRequest> requestFactory =
nodeId -> new NodeHttpRequest(request, nodeId, createURI(uri, nodeId), null, clusterResponse);
final PreparedRequest request = httpClient.prepareRequest(method, cancelLockHeaders, entity);
final Function<NodeIdentifier, NodeHttpRequest> requestFactory =
nodeId -> new NodeHttpRequest(request, nodeId, createURI(uri, nodeId), null, clusterResponse);

submitAsyncRequest(nodeIds, uri.getScheme(), uri.getPath(), requestFactory, cancelLockHeaders);
}
submitAsyncRequest(nodeIds, requestFactory);
});
cancelLockThread.setName("Cancel Flow Locks");
cancelLockThread.start();
@@ -627,30 +622,23 @@ public void run() {
nodeId -> new NodeHttpRequest(request, nodeId, createURI(uri, nodeId), completionCallback, clusterResponse);

// replicate the 'verification request' to all nodes
submitAsyncRequest(nodeIds, uri.getScheme(), uri.getPath(), requestFactory, validationHeaders);
submitAsyncRequest(nodeIds, requestFactory);
}


@Override
public AsyncClusterResponse getClusterResponse(final String identifier) {
final AsyncClusterResponse response = responseMap.get(identifier);
if (response == null) {
return null;
}

return response;
return responseMap.get(identifier);
}

// Visible for testing - overriding this method makes it easy to verify behavior without actually making any web requests
protected NodeResponse replicateRequest(final PreparedRequest request, final NodeIdentifier nodeId, final URI uri, final String requestId,
final StandardAsyncClusterResponse clusterResponse) throws IOException {

final Response response;
final long startNanos = System.nanoTime();
logger.debug("Replicating request to {} {}, request ID = {}, headers = {}", request.getMethod(), uri, requestId, request.getHeaders());

// invoke the request
response = httpClient.replicate(request, uri);
final Response response = httpClient.replicate(request, uri);

final long nanos = System.nanoTime() - startNanos;
clusterResponse.addTiming("Perform HTTP Request", nodeId.toString(), nanos);
@@ -669,14 +657,10 @@ protected NodeResponse replicateRequest(final PreparedRequest request, final Nod
}

private boolean isMutableRequest(final String method) {
switch (method.toUpperCase()) {
case HttpMethod.GET:
case HttpMethod.HEAD:
case HttpMethod.OPTIONS:
return false;
default:
return true;
}
return switch (method.toUpperCase()) {
case HttpMethod.GET, HttpMethod.HEAD, HttpMethod.OPTIONS -> false;
default -> true;
};
}

private boolean isDeleteComponent(final String method, final String uriPath) {
@@ -689,7 +673,7 @@ private boolean isDeleteComponent(final String method, final String uriPath) {
// This is because we do need to allow deletion of asynchronous requests, such as updating parameters, querying provenance, etc.
// which create a request, poll until the request completes, and then deletes it. Additionally, we want to allow terminating
// Processors, which is done by issuing a request to DELETE /processors/<id>/threads
final boolean componentUri = ConnectionEndpointMerger.CONNECTION_URI_PATTERN.matcher(uriPath).matches()
return ConnectionEndpointMerger.CONNECTION_URI_PATTERN.matcher(uriPath).matches()
|| ProcessorEndpointMerger.PROCESSOR_URI_PATTERN.matcher(uriPath).matches()
|| FunnelEndpointMerger.FUNNEL_URI_PATTERN.matcher(uriPath).matches()
|| PortEndpointMerger.INPUT_PORT_URI_PATTERN.matcher(uriPath).matches()
@@ -704,8 +688,6 @@ private boolean isDeleteComponent(final String method, final String uriPath) {
|| ParameterProviderEndpointMerger.PARAMETER_PROVIDER_URI_PATTERN.matcher(uriPath).matches()
|| FlowRegistryClientEndpointMerger.CONTROLLER_REGISTRY_URI_PATTERN.matcher(uriPath).matches()
|| SNIPPET_URI_PATTERN.matcher(uriPath).matches();

return componentUri;
}

/**
@@ -754,18 +736,20 @@ private void onResponseConsumed(final String requestId) {
*/
private void onCompletedResponse(final String requestId) {
final AsyncClusterResponse response = responseMap.get(requestId);
if (response == null) {
logger.info("Replicated Request [{}] not found", requestId);
return;
}

if (response != null && callback != null) {
if (callback != null) {
try {
callback.afterRequest(response.getURIPath(), response.getMethod(), response.getCompletedNodeResponses());
} catch (final Exception e) {
logger.warn("Completed request {} {} but failed to properly handle the Request Completion Callback due to {}",
response.getMethod(), response.getURIPath(), e.toString());
logger.warn("", e);
logger.warn("Completed request {} {} but failed to properly handle the Request Completion Callback", response.getMethod(), response.getURIPath(), e);
}
}

if (response != null && logger.isDebugEnabled()) {
if (logger.isDebugEnabled()) {
logTimingInfo(response);
}

@@ -811,8 +795,7 @@ private void logTimingInfo(final AsyncClusterResponse response) {
}


private void submitAsyncRequest(final Set<NodeIdentifier> nodeIds, final String scheme, final String path,
final Function<NodeIdentifier, NodeHttpRequest> callableFactory, final Map<String, String> headers) {
private void submitAsyncRequest(final Set<NodeIdentifier> nodeIds, final Function<NodeIdentifier, NodeHttpRequest> callableFactory) {

if (nodeIds.isEmpty()) {
return; // return quickly for trivial case
@@ -887,18 +870,18 @@ public void run() {
}
}

private static interface NodeRequestCompletionCallback {
private interface NodeRequestCompletionCallback {
void onCompletion(NodeResponse nodeResponse);
}

private synchronized int purgeExpiredRequests() {
final Set<String> expiredRequestIds = ThreadPoolRequestReplicator.this.responseMap.entrySet().stream()
.filter(entry -> entry.getValue().isOlderThan(30, TimeUnit.SECONDS)) // older than 30 seconds
.filter(entry -> entry.getValue().isComplete()) // is complete
.map(entry -> entry.getKey()) // get the request id
.map(Map.Entry::getKey) // get the request id
.collect(Collectors.toSet());

expiredRequestIds.forEach(id -> onResponseConsumed(id));
expiredRequestIds.forEach(this::onResponseConsumed);
return responseMap.size();
}

StandardHttpReplicationClient.java
@@ -26,7 +26,7 @@
import org.apache.nifi.cluster.coordination.http.replication.HttpReplicationClient;
import org.apache.nifi.cluster.coordination.http.replication.PreparedRequest;
import org.apache.nifi.cluster.coordination.http.replication.io.EntitySerializer;
import org.apache.nifi.cluster.coordination.http.replication.io.JacksonResponse;
import org.apache.nifi.cluster.coordination.http.replication.io.ReplicatedResponse;
import org.apache.nifi.cluster.coordination.http.replication.io.JsonEntitySerializer;
import org.apache.nifi.cluster.coordination.http.replication.io.XmlEntitySerializer;
import org.apache.nifi.web.client.api.HttpEntityHeaders;
@@ -62,6 +62,8 @@ public class StandardHttpReplicationClient implements HttpReplicationClient {

private static final Set<String> DISALLOWED_HEADERS = Set.of("connection", "content-length", "expect", "host", "upgrade");

private static final int CONTENT_LENGTH_NOT_FOUND = -1;

private static final char PSEUDO_HEADER_PREFIX = ':';

private static final String GZIP_ENCODING = "gzip";
@@ -199,17 +201,25 @@ private Response replicate(final StandardPreparedRequest preparedRequest, final
private Response replicate(final HttpRequestBodySpec httpRequestBodySpec, final String method, final URI location) throws IOException {
final long started = System.currentTimeMillis();

try (HttpResponseEntity responseEntity = httpRequestBodySpec.retrieve()) {
final int statusCode = responseEntity.statusCode();
final HttpEntityHeaders headers = responseEntity.headers();
final MultivaluedMap<String, String> responseHeaders = getResponseHeaders(headers);
final byte[] responseBody = getResponseBody(responseEntity.body(), headers);
final HttpResponseEntity responseEntity = httpRequestBodySpec.retrieve();
final int statusCode = responseEntity.statusCode();
final HttpEntityHeaders headers = responseEntity.headers();
final MultivaluedMap<String, String> responseHeaders = getResponseHeaders(headers);
final int contentLength = getContentLength(headers);

final InputStream responseBody = getResponseBody(responseEntity.body(), headers);
final Runnable closeCallback = () -> {
try {
responseEntity.close();
} catch (final IOException e) {
logger.warn("Close failed for Replicated {} {} HTTP {}", method, location, statusCode, e);
}
};

final long elapsed = System.currentTimeMillis() - started;
logger.debug("Replicated {} {} HTTP {} in {} ms", method, location, statusCode, elapsed);
final long elapsed = System.currentTimeMillis() - started;
logger.debug("Replicated {} {} HTTP {} in {} ms", method, location, statusCode, elapsed);

return new JacksonResponse(objectMapper, responseBody, responseHeaders, location, statusCode, null);
}
return new ReplicatedResponse(objectMapper, responseBody, responseHeaders, location, statusCode, contentLength, closeCallback);
}

private URI getRequestUri(final StandardPreparedRequest preparedRequest, final URI location) {
@@ -288,14 +298,32 @@ private MultivaluedMap<String, String> getResponseHeaders(final HttpEntityHeader
return headers;
}

private byte[] getResponseBody(final InputStream inputStream, final HttpEntityHeaders responseHeaders) throws IOException {
private InputStream getResponseBody(final InputStream inputStream, final HttpEntityHeaders responseHeaders) throws IOException {
final boolean gzipEncoded = isGzipEncoded(responseHeaders);
return gzipEncoded ? new GZIPInputStream(inputStream) : inputStream;
}

final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
try (InputStream responseBodyStream = gzipEncoded ? new GZIPInputStream(inputStream) : inputStream) {
responseBodyStream.transferTo(outputStream);
private int getContentLength(final HttpEntityHeaders headers) {
final Optional<String> contentLengthFound = headers.getHeaderNames()
.stream()
.filter(PreparedRequestHeader.CONTENT_LENGTH.getHeader()::equalsIgnoreCase)
.findFirst()
.flatMap(headers::getFirstHeader);

int contentLength;
if (contentLengthFound.isPresent()) {
final String contentLengthHeader = contentLengthFound.get();
try {
contentLength = Integer.parseInt(contentLengthHeader);
} catch (final NumberFormatException e) {
logger.warn("Replicated Header Content-Length [{}] parsing failed", contentLengthHeader, e);
contentLength = CONTENT_LENGTH_NOT_FOUND;
}
} else {
contentLength = CONTENT_LENGTH_NOT_FOUND;
}
return outputStream.toByteArray();

return contentLength;
}

private byte[] getRequestBody(final Object requestEntity, final Map<String, String> headers) {
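
Because the response body may now be a live stream, StandardHttpReplicationClient wraps responseEntity.close() in a Runnable close callback and passes it to the ReplicatedResponse, so the underlying HTTP entity can be released after the body is consumed. A hypothetical consumer of such a streamed body might follow the pattern below; the StreamedBodyConsumerSketch class and copyThenClose method are assumptions for illustration and do not appear in this commit.

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

final class StreamedBodyConsumerSketch {

    /**
     * Copies a streamed response body to a destination and always runs the
     * provided close callback afterwards, so the remote connection is
     * released even when reading fails.
     */
    static void copyThenClose(final InputStream responseBody, final OutputStream destination, final Runnable closeCallback) throws IOException {
        try {
            responseBody.transferTo(destination);
        } finally {
            closeCallback.run();
        }
    }
}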