From f39b9f7d4282cee5df766254b00cfd6a08ad84bc Mon Sep 17 00:00:00 2001 From: Florian Hussonnois Date: Thu, 19 Dec 2024 15:06:14 +0100 Subject: [PATCH] refactor: cleanup AbstractJdbcQueries ensuring no memory leak --- .../plugin/jdbc/AbstractJdbcQueries.java | 148 ++++++++++-------- 1 file changed, 85 insertions(+), 63 deletions(-) diff --git a/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQueries.java b/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQueries.java index 67d39aa6..ec44daf0 100644 --- a/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQueries.java +++ b/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQueries.java @@ -5,7 +5,12 @@ import io.kestra.core.models.property.Property; import io.kestra.core.runners.RunContext; import io.kestra.core.serializers.FileSerde; -import lombok.*; +import io.kestra.core.utils.Rethrow; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.ToString; import lombok.experimental.SuperBuilder; import org.slf4j.Logger; @@ -13,8 +18,17 @@ import java.io.File; import java.io.FileWriter; import java.io.IOException; -import java.sql.*; -import java.util.*; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Savepoint; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.function.Consumer; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -30,15 +44,6 @@ public abstract class AbstractJdbcQueries extends AbstractJdbcBaseQuery implemen @Builder.Default protected Property transaction = Property.of(Boolean.TRUE); - @Getter(AccessLevel.NONE) - private Connection conn = null; - - @Getter(AccessLevel.NONE) - private PreparedStatement stmt = null; - - @Getter(AccessLevel.NONE) - private Savepoint savepoint = null; - public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Exception { Logger logger = runContext.logger(); AbstractCellConverter cellConverter = getCellConverter(this.zoneId()); @@ -47,44 +52,61 @@ public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Ex long totalSize = 0L; List outputList = new LinkedList<>(); - try { - //Create connection in not autocommit mode to enable rollback on error - conn = this.connection(runContext); - conn.setAutoCommit(false); - savepoint = initializeSavepoint(conn); + //Create connection in not autocommit mode to enable rollback on error + Connection connection = null; + Savepoint savepoint = null; + try { + connection = this.connection(runContext); + savepoint = initializeSavepoint(connection); + + connection.setAutoCommit(false); String sqlRendered = runContext.render(this.sql, this.additionalVars); String[] queries = sqlRendered.split(";[^']"); - for(String query : queries) { + for (String query : queries) { //Create statement, execute - stmt = createPreparedStatementAndPopulateParameters(runContext, conn, query); - stmt.setFetchSize(this.getFetchSize()); - logger.debug("Starting query: {}", query); - stmt.execute(); - - if(!isTransactional) { - conn.commit(); + try (PreparedStatement stmt = prepareStatement(runContext, connection, query)) { + stmt.setFetchSize(this.getFetchSize()); + logger.debug("Starting query: {}", query); + stmt.execute(); + if (!isTransactional) { + connection.commit(); + } + totalSize = extractResultsFromResultSet(connection, stmt, runContext, cellConverter, totalSize, outputList); } - totalSize = extractResultsFromResultSet(runContext, cellConverter, totalSize, outputList); } - conn.commit(); - + connection.commit(); runContext.metric(Counter.of("fetch.size", totalSize, this.tags())); return MultiQueryOutput.builder().outputs(outputList).build(); } catch (Exception e) { - rollbackIfTransactional(isTransactional); + rollbackIfTransactional(connection, savepoint, isTransactional); throw new RuntimeException(e); } finally { - closeConnectionAndStatement(runContext); + safelyCloseConnection(runContext, connection); } } - private long extractResultsFromResultSet(RunContext runContext, AbstractCellConverter cellConverter, long totalSize, List outputList) throws SQLException, IOException { - try(ResultSet rs = stmt.getResultSet()) { + private static void safelyCloseConnection(final RunContext runContext, final Connection connection) { + try { + if (connection != null) { + connection.close(); + } + } catch (SQLException e) { + runContext.logger().warn("Issue when closing the connection : {}", e.getMessage()); + } + } + + private long extractResultsFromResultSet(final Connection connection, + final PreparedStatement stmt, + final RunContext runContext, + final AbstractCellConverter cellConverter, + long totalSize, + final List outputList) throws SQLException, IOException { + try (ResultSet rs = stmt.getResultSet()) { //When sql is not a select statement skip output creation - if(rs != null) { + if (rs != null) { Output.OutputBuilder output = Output.builder(); //Populate result fro result set long size = 0L; @@ -92,13 +114,13 @@ private long extractResultsFromResultSet(RunContext runContext, AbstractCellConv case FETCH_ONE -> { size = 1L; output - .row(fetchResult(rs, cellConverter, conn)) + .row(fetchResult(rs, cellConverter, connection)) .size(size); } case STORE -> { File tempFile = runContext.workingDir().createTempFile(".ion").toFile(); try (BufferedWriter fileWriter = new BufferedWriter(new FileWriter(tempFile), FileSerde.BUFFER_SIZE)) { - size = fetchToFile(stmt, rs, fileWriter, cellConverter, conn); + size = fetchToFile(stmt, rs, fileWriter, cellConverter, connection); } output .uri(runContext.storage().putFile(tempFile)) @@ -106,13 +128,14 @@ private long extractResultsFromResultSet(RunContext runContext, AbstractCellConv } case FETCH -> { List> maps = new ArrayList<>(); - size = fetchResults(stmt, rs, maps, cellConverter, conn); + size = fetchResults(stmt, rs, maps, cellConverter, connection); output .rows(maps) .size(size); } case NONE -> runContext.logger().info("fetchType is set to NONE, no output will be returned"); - default -> throw new IllegalArgumentException("fetchType must be either FETCH, FETCH_ONE, STORE, or NONE"); + default -> + throw new IllegalArgumentException("fetchType must be either FETCH, FETCH_ONE, STORE, or NONE"); } totalSize += size; outputList.add(output.build()); @@ -121,26 +144,19 @@ private long extractResultsFromResultSet(RunContext runContext, AbstractCellConv return totalSize; } - private void rollbackIfTransactional(boolean isTransactional) throws SQLException { - if(isTransactional && conn != null) { - if(savepoint != null) { - conn.rollback(savepoint); + private static void rollbackIfTransactional(final Connection connection, + final Savepoint savepoint, + final boolean isTransactional) throws SQLException { + if (isTransactional) { + if (savepoint != null) { + connection.rollback(savepoint); return; } - conn.rollback(); + connection.rollback(); } } - private void closeConnectionAndStatement(RunContext runContext) { - try { - if(conn != null && !conn.isClosed()) { conn.close(); } - if(stmt != null && !stmt.isClosed()) { stmt.close(); } - } catch (SQLException e) { - runContext.logger().warn("Issue when closing the connection : {}", e.getMessage()); - } - } - - private Savepoint initializeSavepoint(Connection conn) throws SQLException { + private static Savepoint initializeSavepoint(final Connection conn) { try { return conn.setSavepoint(); } catch (SQLException e) { @@ -168,17 +184,22 @@ public static class MultiQueryOutput implements io.kestra.core.models.tasks.Outp List outputs; } - private PreparedStatement createPreparedStatementAndPopulateParameters(RunContext runContext, Connection conn, String sql) throws SQLException, IllegalVariableEvaluationException { - //Inject named parameters (ex: ':param') - Map namedParamsRendered = this.getParameters() == null ? null : this.getParameters().asMap(runContext, String.class, Object.class); + private PreparedStatement prepareStatement(final RunContext runContext, + final Connection conn, + final String sql) throws SQLException, IllegalVariableEvaluationException { - if(namedParamsRendered == null || namedParamsRendered.isEmpty()) { + // Inject named parameters (ex: ':param') + Optional> namedParamsRendered = Optional + .ofNullable(this.getParameters()) + .map(Rethrow.throwFunction(it -> it.asMap(runContext, String.class, Object.class))); + + if (namedParamsRendered.isEmpty()) { return createPreparedStatement(conn, sql); } //Extract parameters in orders and replace them with '?' String preparedSql = sql; - Pattern pattern = Pattern.compile(":\\w+"); + Pattern pattern = Pattern.compile(":\\w+"); Matcher matcher = pattern.matcher(preparedSql); List params = new LinkedList<>(); @@ -186,19 +207,20 @@ private PreparedStatement createPreparedStatementAndPopulateParameters(RunContex while (matcher.find()) { String param = matcher.group(); params.add(param.substring(1)); - preparedSql = matcher.replaceFirst( "?"); + preparedSql = matcher.replaceFirst("?"); matcher = pattern.matcher(preparedSql); } - stmt = createPreparedStatement(conn, preparedSql); - for(int i=0; i