Skip to content

Commit

Permalink
apacheGH-41262: [Java][FlightSQL] Implement stateless prepared statem…
Browse files Browse the repository at this point in the history
…ents (apache#41237)

### Rationale for this change

Expand the number of implemented languages for stateless prepared statements to include Java.

### What changes are included in this PR?

Update FlightSqlClient and include a stateless server implementation example with tests.

### Are these changes tested?

Yes, tests are added to cover a stateless server implementation.

### Are there any user-facing changes?

There is a modified FlightSqlClient that is required to enable use of stateless prepared statements.

* GitHub Issue: apache#41262

Lead-authored-by: Steve Lord <[email protected]>
Co-authored-by: Mateusz Rzeszutek <[email protected]>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
stevelorddremio and mateuszrzeszutek authored Jun 3, 2024
1 parent 813fe25 commit 1598782
Show file tree
Hide file tree
Showing 6 changed files with 474 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import org.apache.arrow.flight.SetSessionOptionsResult;
import org.apache.arrow.flight.SyncPutListener;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.sql.impl.FlightSql;
import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery;
import org.apache.arrow.flight.sql.util.TableRef;
Expand Down Expand Up @@ -1048,15 +1049,35 @@ private Schema deserializeSchema(final ByteString bytes) {
public FlightInfo execute(final CallOption... options) {
checkOpen();

final FlightDescriptor descriptor = FlightDescriptor
FlightDescriptor descriptor = FlightDescriptor
.command(Any.pack(CommandPreparedStatementQuery.newBuilder()
.setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle())
.build())
.toByteArray());

if (parameterBindingRoot != null && parameterBindingRoot.getRowCount() > 0) {
try (final SyncPutListener listener = putParameters(descriptor, options)) {
listener.getResult();
try (final SyncPutListener putListener = putParameters(descriptor, options)) {
if (getParameterSchema().getFields().size() > 0 &&
parameterBindingRoot != null &&
parameterBindingRoot.getRowCount() > 0) {
final PutResult read = putListener.read();
if (read != null) {
try (final ArrowBuf metadata = read.getApplicationMetadata()) {
final FlightSql.DoPutPreparedStatementResult doPutPreparedStatementResult =
FlightSql.DoPutPreparedStatementResult.parseFrom(metadata.nioBuffer());
descriptor = FlightDescriptor
.command(Any.pack(CommandPreparedStatementQuery.newBuilder()
.setPreparedStatementHandle(
doPutPreparedStatementResult.getPreparedStatementHandle())
.build())
.toByteArray());
}
}
}
} catch (final InterruptedException | ExecutionException e) {
throw CallStatus.CANCELLED.withCause(e).toRuntimeException();
} catch (final InvalidProtocolBufferException e) {
throw CallStatus.INVALID_ARGUMENT.withCause(e).toRuntimeException();
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.flight.sql.example;

import java.io.Serializable;

public class DoPutPreparedStatementResultPOJO implements Serializable {
private String query;
private byte[] parameters;

public DoPutPreparedStatementResultPOJO(String query, byte[] parameters) {
this.query = query;
this.parameters = parameters.clone();
}

public String getQuery() {
return query;
}

public byte[] getParameters() {
return parameters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,21 +156,22 @@
* supports all current features of Flight SQL.
*/
public class FlightSqlExample implements FlightSqlProducer, AutoCloseable {
private static final String DATABASE_URI = "jdbc:derby:target/derbyDB";
private static final Logger LOGGER = getLogger(FlightSqlExample.class);
private static final Calendar DEFAULT_CALENDAR = JdbcToArrowUtils.getUtcCalendar();
protected static final Calendar DEFAULT_CALENDAR = JdbcToArrowUtils.getUtcCalendar();
public static final String DB_NAME = "derbyDB";
private final String databaseUri;
// ARROW-15315: Use ExecutorService to simulate an async scenario
private final ExecutorService executorService = Executors.newFixedThreadPool(10);
private final Location location;
private final PoolingDataSource<PoolableConnection> dataSource;
private final BufferAllocator rootAllocator = new RootAllocator();
protected final PoolingDataSource<PoolableConnection> dataSource;
protected final BufferAllocator rootAllocator = new RootAllocator();
private final Cache<ByteString, StatementContext<PreparedStatement>> preparedStatementLoadingCache;
private final Cache<ByteString, StatementContext<Statement>> statementLoadingCache;
private final SqlInfoBuilder sqlInfoBuilder;

public static void main(String[] args) throws Exception {
Location location = Location.forGrpcInsecure("localhost", 55555);
final FlightSqlExample example = new FlightSqlExample(location);
final FlightSqlExample example = new FlightSqlExample(location, DB_NAME);
Location listenLocation = Location.forGrpcInsecure("0.0.0.0", 55555);
try (final BufferAllocator allocator = new RootAllocator();
final FlightServer server = FlightServer.builder(allocator, listenLocation, example).build()) {
Expand All @@ -179,13 +180,14 @@ public static void main(String[] args) throws Exception {
}
}

public FlightSqlExample(final Location location) {
public FlightSqlExample(final Location location, final String dbName) {
// TODO Constructor should not be doing work.
checkState(
removeDerbyDatabaseIfExists() && populateDerbyDatabase(),
removeDerbyDatabaseIfExists(dbName) && populateDerbyDatabase(dbName),
"Failed to reset Derby database!");
databaseUri = "jdbc:derby:target/" + dbName;
final ConnectionFactory connectionFactory =
new DriverManagerConnectionFactory(DATABASE_URI, new Properties());
new DriverManagerConnectionFactory(databaseUri, new Properties());
final PoolableConnectionFactory poolableConnectionFactory =
new PoolableConnectionFactory(connectionFactory, null);
final ObjectPool<PoolableConnection> connectionPool = new GenericObjectPool<>(poolableConnectionFactory);
Expand Down Expand Up @@ -248,9 +250,9 @@ public FlightSqlExample(final Location location) {

}

private static boolean removeDerbyDatabaseIfExists() {
public static boolean removeDerbyDatabaseIfExists(final String dbName) {
boolean wasSuccess;
final Path path = Paths.get("target" + File.separator + "derbyDB");
final Path path = Paths.get("target" + File.separator + dbName);

try (final Stream<Path> walk = Files.walk(path)) {
/*
Expand All @@ -262,7 +264,7 @@ private static boolean removeDerbyDatabaseIfExists() {
* this not expected.
*/
wasSuccess = walk.sorted(Comparator.reverseOrder()).map(Path::toFile).map(File::delete)
.reduce(Boolean::logicalAnd).orElseThrow(IOException::new);
.reduce(Boolean::logicalAnd).orElseThrow(IOException::new);
} catch (IOException e) {
/*
* The only acceptable scenario for an `IOException` to be thrown here is if
Expand All @@ -277,9 +279,12 @@ private static boolean removeDerbyDatabaseIfExists() {
return wasSuccess;
}

private static boolean populateDerbyDatabase() {
try (final Connection connection = DriverManager.getConnection("jdbc:derby:target/derbyDB;create=true");
private static boolean populateDerbyDatabase(final String dbName) {
try (final Connection connection = DriverManager.getConnection("jdbc:derby:target/" + dbName + ";create=true");
Statement statement = connection.createStatement()) {

dropTable(statement, "intTable");
dropTable(statement, "foreignTable");
statement.execute("CREATE TABLE foreignTable (" +
"id INT not null primary key GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), " +
"foreignName varchar(100), " +
Expand All @@ -302,6 +307,18 @@ private static boolean populateDerbyDatabase() {
return true;
}

private static void dropTable(final Statement statement, final String tableName) throws SQLException {
try {
statement.execute("DROP TABLE " + tableName);
} catch (SQLException e) {
// sql error code for "object does not exist"; which is fine, we're trying to delete the table
// see https://db.apache.org/derby/docs/10.17/ref/rrefexcept71493.html
if (!"42Y55".equals(e.getSQLState())) {
throw e;
}
}
}

private static ArrowType getArrowTypeFromJdbcType(final int jdbcDataType, final int precision, final int scale) {
try {
return JdbcToArrowUtils.getArrowTypeFromJdbcType(new JdbcFieldInfo(jdbcDataType, precision, scale),
Expand Down Expand Up @@ -778,7 +795,7 @@ public void createPreparedStatement(final ActionCreatePreparedStatementRequest r
// Running on another thread
Future<?> unused = executorService.submit(() -> {
try {
final ByteString preparedStatementHandle = copyFrom(randomUUID().toString().getBytes(StandardCharsets.UTF_8));
final ByteString preparedStatementHandle = copyFrom(request.getQuery().getBytes(StandardCharsets.UTF_8));
// Ownership of the connection will be passed to the context. Do NOT close!
final Connection connection = dataSource.getConnection();
final PreparedStatement preparedStatement = connection.prepareStatement(request.getQuery(),
Expand Down Expand Up @@ -882,7 +899,7 @@ public Runnable acceptPutPreparedStatementUpdate(CommandPreparedStatementUpdate
while (binder.next()) {
preparedStatement.addBatch();
}
int[] recordCounts = preparedStatement.executeBatch();
final int[] recordCounts = preparedStatement.executeBatch();
recordCount = Arrays.stream(recordCounts).sum();
}

Expand Down Expand Up @@ -928,6 +945,7 @@ public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery co
.toRuntimeException());
return;
}

ackStream.onCompleted();
};
}
Expand Down Expand Up @@ -1035,7 +1053,7 @@ public void getStreamTables(final CommandGetTables command, final CallContext co
final String[] tableTypes =
protocolSize == 0 ? null : protocolStringList.toArray(new String[protocolSize]);

try (final Connection connection = DriverManager.getConnection(DATABASE_URI);
try (final Connection connection = DriverManager.getConnection(databaseUri);
final VectorSchemaRoot vectorSchemaRoot = getTablesRoot(
connection.getMetaData(),
rootAllocator,
Expand Down Expand Up @@ -1086,7 +1104,7 @@ public void getStreamPrimaryKeys(final CommandGetPrimaryKeys command, final Call
final String schema = command.hasDbSchema() ? command.getDbSchema() : null;
final String table = command.getTable();

try (Connection connection = DriverManager.getConnection(DATABASE_URI)) {
try (Connection connection = DriverManager.getConnection(databaseUri)) {
final ResultSet primaryKeys = connection.getMetaData().getPrimaryKeys(catalog, schema, table);

final VarCharVector catalogNameVector = new VarCharVector("catalog_name", rootAllocator);
Expand Down Expand Up @@ -1140,7 +1158,7 @@ public void getStreamExportedKeys(final CommandGetExportedKeys command, final Ca
String schema = command.hasDbSchema() ? command.getDbSchema() : null;
String table = command.getTable();

try (Connection connection = DriverManager.getConnection(DATABASE_URI);
try (Connection connection = DriverManager.getConnection(databaseUri);
ResultSet keys = connection.getMetaData().getExportedKeys(catalog, schema, table);
VectorSchemaRoot vectorSchemaRoot = createVectors(keys)) {
listener.start(vectorSchemaRoot);
Expand All @@ -1165,7 +1183,7 @@ public void getStreamImportedKeys(final CommandGetImportedKeys command, final Ca
String schema = command.hasDbSchema() ? command.getDbSchema() : null;
String table = command.getTable();

try (Connection connection = DriverManager.getConnection(DATABASE_URI);
try (Connection connection = DriverManager.getConnection(databaseUri);
ResultSet keys = connection.getMetaData().getImportedKeys(catalog, schema, table);
VectorSchemaRoot vectorSchemaRoot = createVectors(keys)) {
listener.start(vectorSchemaRoot);
Expand Down Expand Up @@ -1193,7 +1211,7 @@ public void getStreamCrossReference(CommandGetCrossReference command, CallContex
final String pkTable = command.getPkTable();
final String fkTable = command.getFkTable();

try (Connection connection = DriverManager.getConnection(DATABASE_URI);
try (Connection connection = DriverManager.getConnection(databaseUri);
ResultSet keys = connection.getMetaData()
.getCrossReference(pkCatalog, pkSchema, pkTable, fkCatalog, fkSchema, fkTable);
VectorSchemaRoot vectorSchemaRoot = createVectors(keys)) {
Expand Down Expand Up @@ -1280,7 +1298,7 @@ public void getStreamStatement(final TicketStatementQuery ticketStatementQuery,
}
}

private <T extends Message> FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor,
protected <T extends Message> FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor,
final Schema schema) {
final Ticket ticket = new Ticket(pack(request).toByteArray());
// TODO Support multiple endpoints.
Expand Down
Loading

0 comments on commit 1598782

Please sign in to comment.