Skip to content

Commit

Permalink
refactor: cleanup AbstractJdbcQueries ensuring no memory leak
Browse files Browse the repository at this point in the history
  • Loading branch information
fhussonnois committed Dec 19, 2024
1 parent 29c366b commit f39b9f7
Showing 1 changed file with 85 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,30 @@
import io.kestra.core.models.property.Property;
import io.kestra.core.runners.RunContext;
import io.kestra.core.serializers.FileSerde;
import lombok.*;
import io.kestra.core.utils.Rethrow;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.experimental.SuperBuilder;
import org.slf4j.Logger;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.sql.*;
import java.util.*;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Savepoint;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand All @@ -30,15 +44,6 @@ public abstract class AbstractJdbcQueries extends AbstractJdbcBaseQuery implemen
@Builder.Default
protected Property<Boolean> transaction = Property.of(Boolean.TRUE);

@Getter(AccessLevel.NONE)
private Connection conn = null;

@Getter(AccessLevel.NONE)
private PreparedStatement stmt = null;

@Getter(AccessLevel.NONE)
private Savepoint savepoint = null;

public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Exception {
Logger logger = runContext.logger();
AbstractCellConverter cellConverter = getCellConverter(this.zoneId());
Expand All @@ -47,72 +52,90 @@ public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Ex
long totalSize = 0L;
List<AbstractJdbcQuery.Output> outputList = new LinkedList<>();

try {
//Create connection in not autocommit mode to enable rollback on error
conn = this.connection(runContext);
conn.setAutoCommit(false);
savepoint = initializeSavepoint(conn);
//Create connection in not autocommit mode to enable rollback on error
Connection connection = null;
Savepoint savepoint = null;
try {
connection = this.connection(runContext);
savepoint = initializeSavepoint(connection);

connection.setAutoCommit(false);

String sqlRendered = runContext.render(this.sql, this.additionalVars);
String[] queries = sqlRendered.split(";[^']");

for(String query : queries) {
for (String query : queries) {
//Create statement, execute
stmt = createPreparedStatementAndPopulateParameters(runContext, conn, query);
stmt.setFetchSize(this.getFetchSize());
logger.debug("Starting query: {}", query);
stmt.execute();

if(!isTransactional) {
conn.commit();
try (PreparedStatement stmt = prepareStatement(runContext, connection, query)) {
stmt.setFetchSize(this.getFetchSize());
logger.debug("Starting query: {}", query);
stmt.execute();
if (!isTransactional) {
connection.commit();
}
totalSize = extractResultsFromResultSet(connection, stmt, runContext, cellConverter, totalSize, outputList);
}
totalSize = extractResultsFromResultSet(runContext, cellConverter, totalSize, outputList);
}
conn.commit();

connection.commit();
runContext.metric(Counter.of("fetch.size", totalSize, this.tags()));

return MultiQueryOutput.builder().outputs(outputList).build();
} catch (Exception e) {
rollbackIfTransactional(isTransactional);
rollbackIfTransactional(connection, savepoint, isTransactional);
throw new RuntimeException(e);
} finally {
closeConnectionAndStatement(runContext);
safelyCloseConnection(runContext, connection);
}
}

private long extractResultsFromResultSet(RunContext runContext, AbstractCellConverter cellConverter, long totalSize, List<Output> outputList) throws SQLException, IOException {
try(ResultSet rs = stmt.getResultSet()) {
private static void safelyCloseConnection(final RunContext runContext, final Connection connection) {
try {
if (connection != null) {
connection.close();
}
} catch (SQLException e) {
runContext.logger().warn("Issue when closing the connection : {}", e.getMessage());
}
}

private long extractResultsFromResultSet(final Connection connection,
final PreparedStatement stmt,
final RunContext runContext,
final AbstractCellConverter cellConverter,
long totalSize,
final List<Output> outputList) throws SQLException, IOException {
try (ResultSet rs = stmt.getResultSet()) {
//When sql is not a select statement skip output creation
if(rs != null) {
if (rs != null) {
Output.OutputBuilder<?, ?> output = Output.builder();
//Populate result fro result set
long size = 0L;
switch (this.getFetchType()) {
case FETCH_ONE -> {
size = 1L;
output
.row(fetchResult(rs, cellConverter, conn))
.row(fetchResult(rs, cellConverter, connection))
.size(size);
}
case STORE -> {
File tempFile = runContext.workingDir().createTempFile(".ion").toFile();
try (BufferedWriter fileWriter = new BufferedWriter(new FileWriter(tempFile), FileSerde.BUFFER_SIZE)) {
size = fetchToFile(stmt, rs, fileWriter, cellConverter, conn);
size = fetchToFile(stmt, rs, fileWriter, cellConverter, connection);
}
output
.uri(runContext.storage().putFile(tempFile))
.size(size);
}
case FETCH -> {
List<Map<String, Object>> maps = new ArrayList<>();
size = fetchResults(stmt, rs, maps, cellConverter, conn);
size = fetchResults(stmt, rs, maps, cellConverter, connection);
output
.rows(maps)
.size(size);
}
case NONE -> runContext.logger().info("fetchType is set to NONE, no output will be returned");
default -> throw new IllegalArgumentException("fetchType must be either FETCH, FETCH_ONE, STORE, or NONE");
default ->
throw new IllegalArgumentException("fetchType must be either FETCH, FETCH_ONE, STORE, or NONE");
}
totalSize += size;
outputList.add(output.build());
Expand All @@ -121,26 +144,19 @@ private long extractResultsFromResultSet(RunContext runContext, AbstractCellConv
return totalSize;
}

private void rollbackIfTransactional(boolean isTransactional) throws SQLException {
if(isTransactional && conn != null) {
if(savepoint != null) {
conn.rollback(savepoint);
private static void rollbackIfTransactional(final Connection connection,
final Savepoint savepoint,
final boolean isTransactional) throws SQLException {
if (isTransactional) {
if (savepoint != null) {
connection.rollback(savepoint);
return;
}
conn.rollback();
connection.rollback();
}
}

private void closeConnectionAndStatement(RunContext runContext) {
try {
if(conn != null && !conn.isClosed()) { conn.close(); }
if(stmt != null && !stmt.isClosed()) { stmt.close(); }
} catch (SQLException e) {
runContext.logger().warn("Issue when closing the connection : {}", e.getMessage());
}
}

private Savepoint initializeSavepoint(Connection conn) throws SQLException {
private static Savepoint initializeSavepoint(final Connection conn) {
try {
return conn.setSavepoint();
} catch (SQLException e) {
Expand Down Expand Up @@ -168,37 +184,43 @@ public static class MultiQueryOutput implements io.kestra.core.models.tasks.Outp
List<AbstractJdbcQuery.Output> outputs;
}

private PreparedStatement createPreparedStatementAndPopulateParameters(RunContext runContext, Connection conn, String sql) throws SQLException, IllegalVariableEvaluationException {
//Inject named parameters (ex: ':param')
Map<String, Object> namedParamsRendered = this.getParameters() == null ? null : this.getParameters().asMap(runContext, String.class, Object.class);
private PreparedStatement prepareStatement(final RunContext runContext,
final Connection conn,
final String sql) throws SQLException, IllegalVariableEvaluationException {

if(namedParamsRendered == null || namedParamsRendered.isEmpty()) {
// Inject named parameters (ex: ':param')
Optional<Map<String, Object>> namedParamsRendered = Optional
.ofNullable(this.getParameters())
.map(Rethrow.throwFunction(it -> it.asMap(runContext, String.class, Object.class)));

if (namedParamsRendered.isEmpty()) {
return createPreparedStatement(conn, sql);
}

//Extract parameters in orders and replace them with '?'
String preparedSql = sql;
Pattern pattern = Pattern.compile(":\\w+");
Pattern pattern = Pattern.compile(":\\w+");
Matcher matcher = pattern.matcher(preparedSql);

List<String> params = new LinkedList<>();

while (matcher.find()) {
String param = matcher.group();
params.add(param.substring(1));
preparedSql = matcher.replaceFirst( "?");
preparedSql = matcher.replaceFirst("?");
matcher = pattern.matcher(preparedSql);
}
stmt = createPreparedStatement(conn, preparedSql);

for(int i=0; i<params.size(); i++) {
stmt.setObject(i+1, namedParamsRendered.get(params.get(i)));
PreparedStatement stmt = createPreparedStatement(conn, preparedSql);

for (int i = 0; i < params.size(); i++) {
stmt.setObject(i + 1, namedParamsRendered.get().get(params.get(i)));
}

return stmt;
}

protected PreparedStatement createPreparedStatement(Connection conn, String preparedSql) throws SQLException {
return conn.prepareStatement(preparedSql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
protected PreparedStatement createPreparedStatement(final Connection conn, final String sql) throws SQLException {
return conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
}
}

0 comments on commit f39b9f7

Please sign in to comment.