diff --git a/pom.xml b/pom.xml index 0cb7cfe..cdf21ee 100755 --- a/pom.xml +++ b/pom.xml @@ -77,6 +77,14 @@ 5.7.0 test + + + com.h2database + h2 + 1.4.200 + test + + diff --git a/src/main/java/com/github/collinalpert/java2db/database/DBConnection.java b/src/main/java/com/github/collinalpert/java2db/database/DBConnection.java index f27ec37..c6edbf8 100755 --- a/src/main/java/com/github/collinalpert/java2db/database/DBConnection.java +++ b/src/main/java/com/github/collinalpert/java2db/database/DBConnection.java @@ -1,11 +1,8 @@ package com.github.collinalpert.java2db.database; -import com.github.collinalpert.java2db.exceptions.ConnectionFailedException; import com.github.collinalpert.java2db.mappers.FieldMapper; import com.github.collinalpert.java2db.queries.*; import com.github.collinalpert.java2db.queries.async.*; -import com.mysql.cj.exceptions.CJCommunicationsException; -import com.mysql.cj.jdbc.exceptions.CommunicationsException; import java.io.Closeable; import java.sql.*; @@ -16,7 +13,7 @@ import static com.github.collinalpert.java2db.utilities.Utilities.supplierHandling; /** - * Wrapper around {@link Connection} which eases use of connecting to a database and running queries. + * Wrapper around {@link Connection} which eases use of running queries. * Also supports running functions and stored procedures. * * @author Collin Alpert @@ -28,27 +25,9 @@ public class DBConnection implements Closeable { */ public static boolean LOG_QUERIES = true; - private Connection underlyingConnection; + private final Connection underlyingConnection; private boolean isConnectionValid; - public DBConnection(ConnectionConfiguration configuration) { - try { - var connectionString = String.format("jdbc:mysql://%s:%d/%s?rewriteBatchedStatements=true", configuration.getHost(), configuration.getPort(), configuration.getDatabase()); - Class.forName("com.mysql.cj.jdbc.Driver"); - System.setProperty("user", configuration.getUsername()); - System.setProperty("password", configuration.getPassword()); - DriverManager.setLoginTimeout(configuration.getTimeout()); - underlyingConnection = DriverManager.getConnection(connectionString, System.getProperties()); - isConnectionValid = true; - } catch (CJCommunicationsException | CommunicationsException e) { - isConnectionValid = false; - throw new ConnectionFailedException(); - } catch (ClassNotFoundException | SQLException e) { - e.printStackTrace(); - isConnectionValid = false; - } - } - public DBConnection(Connection underlyingConnection) { this.underlyingConnection = underlyingConnection; this.isConnectionValid = true; diff --git a/src/main/java/com/github/collinalpert/java2db/database/MySQLDriverManagerDataSource.java b/src/main/java/com/github/collinalpert/java2db/database/MySQLDriverManagerDataSource.java new file mode 100644 index 0000000..ff3fb11 --- /dev/null +++ b/src/main/java/com/github/collinalpert/java2db/database/MySQLDriverManagerDataSource.java @@ -0,0 +1,80 @@ +package com.github.collinalpert.java2db.database; + +import javax.sql.DataSource; +import java.io.PrintWriter; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.logging.Logger; + +/** + * {@link DataSource} implementation hard-coded to support only MySQL databases. + * Obtains connections directly from {@link java.sql.DriverManager} + * + * @author Tyler Sharpe + */ +public class MySQLDriverManagerDataSource implements DataSource { + + static { + try { + Class.forName("com.mysql.cj.jdbc.Driver"); + } catch (ClassNotFoundException e) { + throw new ExceptionInInitializerError(e); + } + } + + private final ConnectionConfiguration configuration; + + public MySQLDriverManagerDataSource(ConnectionConfiguration configuration) { + this.configuration = configuration; + } + + @Override + public Connection getConnection() throws SQLException { + return DriverManager.getConnection( + String.format("jdbc:mysql://%s:%d/%s?rewriteBatchedStatements=true", configuration.getHost(), configuration.getPort(), configuration.getDatabase()), + configuration.getUsername(), + configuration.getPassword() + ); + } + + @Override + public Connection getConnection(String username, String password) { + throw new UnsupportedOperationException(); + } + + @Override + public PrintWriter getLogWriter() { + return null; + } + + @Override + public void setLogWriter(PrintWriter out) { + throw new UnsupportedOperationException(); + } + + @Override + public void setLoginTimeout(int seconds) { + throw new UnsupportedOperationException(); + } + + @Override + public int getLoginTimeout() { + return configuration.getTimeout(); + } + + @Override + public Logger getParentLogger() { + return null; + } + + @Override + public T unwrap(Class interfaceType) { + return null; + } + + @Override + public boolean isWrapperFor(Class interfaceType) { + return false; + } +} diff --git a/src/main/java/com/github/collinalpert/java2db/database/TransactionManager.java b/src/main/java/com/github/collinalpert/java2db/database/TransactionManager.java new file mode 100644 index 0000000..43d8c3c --- /dev/null +++ b/src/main/java/com/github/collinalpert/java2db/database/TransactionManager.java @@ -0,0 +1,76 @@ +package com.github.collinalpert.java2db.database; + +import com.github.collinalpert.java2db.utilities.ThrowableConsumer; +import com.github.collinalpert.java2db.utilities.ThrowableFunction; + +import javax.sql.DataSource; +import java.sql.Connection; +import java.sql.SQLException; + +/** + * Allows to execute code within a database transaction. + * + * This class maintains the notion of a 'current' database connection, which is bound to the currently + * executing thread via a {@link ThreadLocal}. The first call in the stack to execute code within a + * transaction opens a new connection and binds it to this thread local variable. Subsequent calls within + * the same thread which wish to participate within the transaction will then re-use this connection. + * + * @author Tyler Sharpe + */ +public class TransactionManager { + + private static final ThreadLocal CURRENT_THREAD_CONNECTION = new ThreadLocal<>(); + + private final DataSource dataSource; + + public TransactionManager(DataSource dataSource) { + this.dataSource = dataSource; + } + + /** + * Run some code inside of a database transaction, creating one if it does not already exist. + */ + public void transact(ThrowableConsumer action) throws SQLException { + transactAndReturn(connection -> { + action.consume(connection); + return null; + }); + } + + /** + * Run some code inside of a database transaction, creating one if it does not already exist, and then return a value. + * @param action Action to run + * @param Type returned from the action lambda + * @return + * @throws SQLException + */ + public T transactAndReturn(ThrowableFunction action) throws SQLException { + if (CURRENT_THREAD_CONNECTION.get() != null) { + return action.run(CURRENT_THREAD_CONNECTION.get()); + } + + try (Connection rawConnection = dataSource.getConnection()) { + rawConnection.setAutoCommit(false); + DBConnection dbConnection = new DBConnection(rawConnection); + CURRENT_THREAD_CONNECTION.set(dbConnection); + + try { + T result = action.run(dbConnection); + rawConnection.commit(); + return result; + } catch (Exception exception) { + // rollback transaction on error + try { + rawConnection.rollback(); + } catch (Exception rollbackException) { + exception.addSuppressed(rollbackException); + } + + throw new SQLException(exception); + } + } finally { + CURRENT_THREAD_CONNECTION.remove(); + } + } + +} diff --git a/src/main/java/com/github/collinalpert/java2db/queries/EntityProjectionQuery.java b/src/main/java/com/github/collinalpert/java2db/queries/EntityProjectionQuery.java index 0fd3660..843f593 100644 --- a/src/main/java/com/github/collinalpert/java2db/queries/EntityProjectionQuery.java +++ b/src/main/java/com/github/collinalpert/java2db/queries/EntityProjectionQuery.java @@ -23,24 +23,26 @@ public class EntityProjectionQuery implements Queryable private final Class returnType; private final IQueryBuilder queryBuilder; private final QueryParameters queryParameters; - private final ConnectionConfiguration connectionConfiguration; + private final TransactionManager transactionManager; - public EntityProjectionQuery(Class returnType, IQueryBuilder queryBuilder, QueryParameters queryParameters, ConnectionConfiguration connectionConfiguration) { + public EntityProjectionQuery(Class returnType, IQueryBuilder queryBuilder, QueryParameters queryParameters, TransactionManager transactionManager) { this.returnType = returnType; this.queryBuilder = queryBuilder; this.queryParameters = queryParameters; - this.connectionConfiguration = connectionConfiguration; + this.transactionManager = transactionManager; } @Override public Optional first() { - try (var connection = new DBConnection(this.connectionConfiguration); - var result = connection.execute(getQuery())) { - if (result.next()) { - return Optional.ofNullable(result.getObject(1, this.returnType)); - } - - return Optional.empty(); + try { + return transactionManager.transactAndReturn(connection -> { + var result = connection.execute(getQuery()); + if (result.next()) { + return Optional.ofNullable(result.getObject(1, this.returnType)); + } else { + return Optional.empty(); + } + }); } catch (SQLException e) { e.printStackTrace(); return Optional.empty(); @@ -106,13 +108,15 @@ public String getQuery() { * @return A data structure containing a {@code ResultSet}s data. */ private T resultHandling(D dataStructure, BiConsumer valueConsumer, T defaultValue, Function valueMapping) { - try (var connection = new DBConnection(this.connectionConfiguration); - var result = connection.execute(getQuery())) { - while (result.next()) { - valueConsumer.accept(dataStructure, result.getObject(1, this.returnType)); - } + try { + return transactionManager.transactAndReturn(connection -> { + var result = connection.execute(getQuery()); + while (result.next()) { + valueConsumer.accept(dataStructure, result.getObject(1, this.returnType)); + } - return valueMapping.apply(dataStructure); + return valueMapping.apply(dataStructure); + }); } catch (SQLException e) { e.printStackTrace(); return defaultValue; diff --git a/src/main/java/com/github/collinalpert/java2db/queries/EntityQuery.java b/src/main/java/com/github/collinalpert/java2db/queries/EntityQuery.java index f15e293..fce7300 100644 --- a/src/main/java/com/github/collinalpert/java2db/queries/EntityQuery.java +++ b/src/main/java/com/github/collinalpert/java2db/queries/EntityQuery.java @@ -26,7 +26,7 @@ public class EntityQuery implements Queryable { private static final TableModule tableModule = TableModule.getInstance(); - protected final ConnectionConfiguration connectionConfiguration; + protected final TransactionManager transactionManager; protected final IQueryBuilder queryBuilder; protected final QueryParameters queryParameters; private final Class type; @@ -39,9 +39,9 @@ public class EntityQuery implements Queryable { * @param type The entity to query. */ - public EntityQuery(Class type, ConnectionConfiguration connectionConfiguration) { + public EntityQuery(Class type, TransactionManager transactionManager) { this.type = type; - this.connectionConfiguration = connectionConfiguration; + this.transactionManager = transactionManager; this.queryParameters = new QueryParameters<>(); this.mapper = IoC.resolveMapper(type, new EntityMapper<>(type)); this.queryBuilder = new EntityQueryBuilder<>(type); @@ -284,7 +284,7 @@ public Queryable project(SqlFunction projection) { @SuppressWarnings("unchecked") var returnType = (Class) LambdaExpression.parse(projection).getBody().getResultType(); var queryBuilder = new ProjectionQueryBuilder<>(projection, this.getTableName(), (QueryBuilder) this.queryBuilder); - return new EntityProjectionQuery<>(returnType, queryBuilder, this.queryParameters, this.connectionConfiguration); + return new EntityProjectionQuery<>(returnType, queryBuilder, this.queryParameters, this.transactionManager); } /** @@ -296,8 +296,10 @@ public Queryable project(SqlFunction projection) { @Override public Optional first() { this.limit(1); - try (var connection = new DBConnection(this.connectionConfiguration)) { - return this.mapper.map(connection.execute(getQuery())); + try { + return transactionManager.transactAndReturn(connection -> { + return this.mapper.map(connection.execute(getQuery())); + }); } catch (SQLException e) { e.printStackTrace(); return Optional.empty(); @@ -311,8 +313,10 @@ public Optional first() { */ @Override public List toList() { - try (var connection = new DBConnection(this.connectionConfiguration)) { - return this.mapper.mapToList(connection.execute(getQuery())); + try { + return transactionManager.transactAndReturn(connection -> { + return this.mapper.mapToList(connection.execute(getQuery())); + }); } catch (SQLException e) { e.printStackTrace(); return Collections.emptyList(); @@ -326,8 +330,10 @@ public List toList() { */ @Override public Stream toStream() { - try (var connection = new DBConnection(this.connectionConfiguration)) { - return this.mapper.mapToStream(connection.execute(getQuery())); + try { + return transactionManager.transactAndReturn(connection -> { + return this.mapper.mapToStream(connection.execute(getQuery())); + }); } catch (SQLException e) { e.printStackTrace(); return Stream.empty(); @@ -342,8 +348,10 @@ public Stream toStream() { @Override @SuppressWarnings("unchecked") public E[] toArray() { - try (var connection = new DBConnection(this.connectionConfiguration)) { - return this.mapper.mapToArray(connection.execute(getQuery())); + try { + return transactionManager.transactAndReturn(connection -> { + return this.mapper.mapToArray(connection.execute(getQuery())); + }); } catch (SQLException e) { e.printStackTrace(); return (E[]) Array.newInstance(this.type, 0); @@ -359,8 +367,10 @@ public E[] toArray() { */ @Override public Map toMap(Function keyMapping, Function valueMapping) { - try (var connection = new DBConnection(this.connectionConfiguration)) { - return this.mapper.mapToMap(connection.execute(getQuery()), keyMapping, valueMapping); + try { + return transactionManager.transactAndReturn(connection -> { + return this.mapper.mapToMap(connection.execute(getQuery()), keyMapping, valueMapping); + }); } catch (SQLException e) { e.printStackTrace(); return Collections.emptyMap(); @@ -374,8 +384,10 @@ public Map toMap(Function keyMapping, Function valueMap */ @Override public Set toSet() { - try (var connection = new DBConnection(this.connectionConfiguration)) { - return this.mapper.mapToSet(connection.execute(getQuery())); + try { + return transactionManager.transactAndReturn(connection -> { + return this.mapper.mapToSet(connection.execute(getQuery())); + }); } catch (SQLException e) { e.printStackTrace(); diff --git a/src/main/java/com/github/collinalpert/java2db/queries/SingleEntityProjectionQuery.java b/src/main/java/com/github/collinalpert/java2db/queries/SingleEntityProjectionQuery.java index 582273c..e17020f 100644 --- a/src/main/java/com/github/collinalpert/java2db/queries/SingleEntityProjectionQuery.java +++ b/src/main/java/com/github/collinalpert/java2db/queries/SingleEntityProjectionQuery.java @@ -22,13 +22,13 @@ public class SingleEntityProjectionQuery implements Que private final Class returnType; private final IQueryBuilder queryBuilder; private final QueryParameters queryParameters; - private final ConnectionConfiguration connectionConfiguration; + private final TransactionManager transactionManager; - public SingleEntityProjectionQuery(Class returnType, ProjectionQueryBuilder queryBuilder, QueryParameters queryParameters, ConnectionConfiguration connectionConfiguration) { + public SingleEntityProjectionQuery(Class returnType, ProjectionQueryBuilder queryBuilder, QueryParameters queryParameters, TransactionManager transactionManager) { this.returnType = returnType; this.queryBuilder = queryBuilder; this.queryParameters = queryParameters; - this.connectionConfiguration = connectionConfiguration; + this.transactionManager = transactionManager; } @Override @@ -106,13 +106,15 @@ public String getQuery() { } private D resultHandling(Function valueConsumer, Supplier defaultValueFactory) { - try (var connection = new DBConnection(this.connectionConfiguration); - var result = connection.execute(getQuery())) { - if (result.next()) { - return valueConsumer.apply(result.getObject(1, this.returnType)); - } + try { + return transactionManager.transactAndReturn(connection -> { + var result = connection.execute(getQuery()); + if (result.next()) { + return valueConsumer.apply(result.getObject(1, this.returnType)); + } - return defaultValueFactory.get(); + return defaultValueFactory.get(); + }); } catch (SQLException e) { e.printStackTrace(); return defaultValueFactory.get(); diff --git a/src/main/java/com/github/collinalpert/java2db/queries/SingleEntityQuery.java b/src/main/java/com/github/collinalpert/java2db/queries/SingleEntityQuery.java index 929f381..98e7eb3 100644 --- a/src/main/java/com/github/collinalpert/java2db/queries/SingleEntityQuery.java +++ b/src/main/java/com/github/collinalpert/java2db/queries/SingleEntityQuery.java @@ -23,16 +23,16 @@ public class SingleEntityQuery implements Queryable { private static final TableModule tableModule = TableModule.getInstance(); protected final QueryParameters queryParameters; protected final IQueryBuilder queryBuilder; - protected final ConnectionConfiguration connectionConfiguration; + protected final TransactionManager transactionManager; private final Class type; private final Mappable mapper; - public SingleEntityQuery(Class type, ConnectionConfiguration connectionConfiguration) { + public SingleEntityQuery(Class type, TransactionManager transactionManager) { this.type = type; this.queryParameters = new QueryParameters<>(); this.mapper = IoC.resolveMapper(type, new EntityMapper<>(type)); this.queryBuilder = new SingleEntityQueryBuilder<>(type); - this.connectionConfiguration = connectionConfiguration; + this.transactionManager = transactionManager; } //region Configuration @@ -70,7 +70,7 @@ public Queryable project(SqlFunction projection) { @SuppressWarnings("unchecked") var returnType = (Class) LambdaExpression.parse(projection).getBody().getResultType(); var queryBuilder = new ProjectionQueryBuilder<>(projection, this.getTableName(), (QueryBuilder) this.queryBuilder); - return new SingleEntityProjectionQuery<>(returnType, queryBuilder, this.queryParameters, this.connectionConfiguration); + return new SingleEntityProjectionQuery<>(returnType, queryBuilder, this.queryParameters, this.transactionManager); } //endregion @@ -83,8 +83,10 @@ public Queryable project(SqlFunction projection) { */ @Override public Optional first() { - try (var connection = new DBConnection(this.connectionConfiguration)) { - return this.mapper.map(connection.execute(getQuery())); + try { + return transactionManager.transactAndReturn(connection -> { + return this.mapper.map(connection.execute(getQuery())); + }); } catch (SQLException e) { e.printStackTrace(); return Optional.empty(); @@ -98,9 +100,11 @@ public Optional first() { */ @Override public List toList() { - try (var connection = new DBConnection(this.connectionConfiguration)) { - var mappedValue = this.mapper.map(connection.execute(getQuery())); - return mappedValue.map(Collections::singletonList).orElse(Collections.emptyList()); + try { + return transactionManager.transactAndReturn(connection -> { + var mappedValue = this.mapper.map(connection.execute(getQuery())); + return mappedValue.map(Collections::singletonList).orElse(Collections.emptyList()); + }); } catch (SQLException e) { e.printStackTrace(); return Collections.emptyList(); @@ -114,9 +118,11 @@ public List toList() { */ @Override public Stream toStream() { - try (var connection = new DBConnection(this.connectionConfiguration)) { - var mappedValue = this.mapper.map(connection.execute(getQuery())); - return mappedValue.stream(); + try { + return transactionManager.transactAndReturn(connection -> { + var mappedValue = this.mapper.map(connection.execute(getQuery())); + return mappedValue.stream(); + }); } catch (SQLException e) { e.printStackTrace(); return Stream.empty(); @@ -131,15 +137,17 @@ public Stream toStream() { @Override @SuppressWarnings("unchecked") public E[] toArray() { - try (var connection = new DBConnection(this.connectionConfiguration)) { - var mappedValue = this.mapper.map(connection.execute(getQuery())); + try { + return transactionManager.transactAndReturn(connection -> { + var mappedValue = this.mapper.map(connection.execute(getQuery())); - return mappedValue.map(v -> { - var array = (E[]) Array.newInstance(this.type, 1); - array[0] = v; + return mappedValue.map(v -> { + var array = (E[]) Array.newInstance(this.type, 1); + array[0] = v; - return array; - }).orElse((E[]) Array.newInstance(this.type, 0)); + return array; + }).orElse((E[]) Array.newInstance(this.type, 0)); + }); } catch (SQLException e) { e.printStackTrace(); @@ -156,9 +164,11 @@ public E[] toArray() { */ @Override public Map toMap(Function keyMapping, Function valueMapping) { - try (var connection = new DBConnection(this.connectionConfiguration)) { - var mappedValue = this.mapper.map(connection.execute(getQuery())); - return mappedValue.map(v -> Collections.singletonMap(keyMapping.apply(v), valueMapping.apply(v))).orElse(Collections.emptyMap()); + try { + return transactionManager.transactAndReturn(connection -> { + var mappedValue = this.mapper.map(connection.execute(getQuery())); + return mappedValue.map(v -> Collections.singletonMap(keyMapping.apply(v), valueMapping.apply(v))).orElse(Collections.emptyMap()); + }); } catch (SQLException e) { e.printStackTrace(); return Collections.emptyMap(); @@ -172,9 +182,11 @@ public Map toMap(Function keyMapping, Function valueMap */ @Override public Set toSet() { - try (var connection = new DBConnection(this.connectionConfiguration)) { - var mappedValue = this.mapper.map(connection.execute(getQuery())); - return mappedValue.map(Collections::singleton).orElse(Collections.emptySet()); + try { + return transactionManager.transactAndReturn(connection -> { + var mappedValue = this.mapper.map(connection.execute(getQuery())); + return mappedValue.map(Collections::singleton).orElse(Collections.emptySet()); + }); } catch (SQLException e) { e.printStackTrace(); return Collections.emptySet(); diff --git a/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncEntityProjectionQuery.java b/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncEntityProjectionQuery.java index cd3c87a..2a51fcd 100644 --- a/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncEntityProjectionQuery.java +++ b/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncEntityProjectionQuery.java @@ -1,6 +1,6 @@ package com.github.collinalpert.java2db.queries.async; -import com.github.collinalpert.java2db.database.ConnectionConfiguration; +import com.github.collinalpert.java2db.database.TransactionManager; import com.github.collinalpert.java2db.entities.BaseEntity; import com.github.collinalpert.java2db.queries.*; import com.github.collinalpert.java2db.queries.builder.IQueryBuilder; @@ -10,7 +10,7 @@ */ public class AsyncEntityProjectionQuery extends EntityProjectionQuery implements AsyncQueryable { - public AsyncEntityProjectionQuery(Class returnType, IQueryBuilder queryBuilder, QueryParameters queryParameters, ConnectionConfiguration connectionConfiguration) { - super(returnType, queryBuilder, queryParameters, connectionConfiguration); + public AsyncEntityProjectionQuery(Class returnType, IQueryBuilder queryBuilder, QueryParameters queryParameters, TransactionManager transactionManager) { + super(returnType, queryBuilder, queryParameters, transactionManager); } } diff --git a/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncEntityQuery.java b/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncEntityQuery.java index 0b2bc91..8989f93 100644 --- a/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncEntityQuery.java +++ b/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncEntityQuery.java @@ -1,7 +1,7 @@ package com.github.collinalpert.java2db.queries.async; import com.github.collinalpert.expressions.expression.LambdaExpression; -import com.github.collinalpert.java2db.database.ConnectionConfiguration; +import com.github.collinalpert.java2db.database.TransactionManager; import com.github.collinalpert.java2db.entities.BaseEntity; import com.github.collinalpert.java2db.queries.EntityQuery; import com.github.collinalpert.java2db.queries.builder.*; @@ -21,8 +21,8 @@ public class AsyncEntityQuery extends EntityQuery imple * * @param type The entity to query. */ - public AsyncEntityQuery(Class type, ConnectionConfiguration connectionConfiguration) { - super(type, connectionConfiguration); + public AsyncEntityQuery(Class type, TransactionManager transactionManager) { + super(type, transactionManager); } /** @@ -161,6 +161,6 @@ public AsyncQueryable project(SqlFunction projection) { @SuppressWarnings("unchecked") var returnType = (Class) LambdaExpression.parse(projection).getBody().getResultType(); var queryBuilder = new ProjectionQueryBuilder<>(projection, super.getTableName(), (QueryBuilder) super.queryBuilder); - return new AsyncEntityProjectionQuery<>(returnType, queryBuilder, super.queryParameters, super.connectionConfiguration); + return new AsyncEntityProjectionQuery<>(returnType, queryBuilder, super.queryParameters, super.transactionManager); } } diff --git a/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncSingleEntityProjectionQuery.java b/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncSingleEntityProjectionQuery.java index 35058ee..fba0ed5 100644 --- a/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncSingleEntityProjectionQuery.java +++ b/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncSingleEntityProjectionQuery.java @@ -1,6 +1,6 @@ package com.github.collinalpert.java2db.queries.async; -import com.github.collinalpert.java2db.database.ConnectionConfiguration; +import com.github.collinalpert.java2db.database.TransactionManager; import com.github.collinalpert.java2db.entities.BaseEntity; import com.github.collinalpert.java2db.queries.*; import com.github.collinalpert.java2db.queries.builder.ProjectionQueryBuilder; @@ -10,7 +10,7 @@ */ public class AsyncSingleEntityProjectionQuery extends SingleEntityProjectionQuery implements AsyncQueryable { - public AsyncSingleEntityProjectionQuery(Class returnType, ProjectionQueryBuilder queryBuilder, QueryParameters queryParameters, ConnectionConfiguration connectionConfiguration) { - super(returnType, queryBuilder, queryParameters, connectionConfiguration); + public AsyncSingleEntityProjectionQuery(Class returnType, ProjectionQueryBuilder queryBuilder, QueryParameters queryParameters, TransactionManager transactionManager) { + super(returnType, queryBuilder, queryParameters, transactionManager); } } diff --git a/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncSingleEntityQuery.java b/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncSingleEntityQuery.java index fd91e7c..595c0d5 100644 --- a/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncSingleEntityQuery.java +++ b/src/main/java/com/github/collinalpert/java2db/queries/async/AsyncSingleEntityQuery.java @@ -1,7 +1,7 @@ package com.github.collinalpert.java2db.queries.async; import com.github.collinalpert.expressions.expression.LambdaExpression; -import com.github.collinalpert.java2db.database.ConnectionConfiguration; +import com.github.collinalpert.java2db.database.TransactionManager; import com.github.collinalpert.java2db.entities.BaseEntity; import com.github.collinalpert.java2db.queries.*; import com.github.collinalpert.java2db.queries.builder.*; @@ -12,8 +12,8 @@ */ public class AsyncSingleEntityQuery extends SingleEntityQuery implements AsyncQueryable { - public AsyncSingleEntityQuery(Class type, ConnectionConfiguration connectionConfiguration) { - super(type, connectionConfiguration); + public AsyncSingleEntityQuery(Class type, TransactionManager transactionManager) { + super(type, transactionManager); } /** @@ -52,6 +52,6 @@ public AsyncQueryable project(SqlFunction projection) { @SuppressWarnings("unchecked") var returnType = (Class) LambdaExpression.parse(projection).getBody().getResultType(); var queryBuilder = new ProjectionQueryBuilder<>(projection, super.getTableName(), (QueryBuilder) super.queryBuilder); - return new AsyncSingleEntityProjectionQuery<>(returnType, queryBuilder, super.queryParameters, super.connectionConfiguration); + return new AsyncSingleEntityProjectionQuery<>(returnType, queryBuilder, super.queryParameters, super.transactionManager); } } diff --git a/src/main/java/com/github/collinalpert/java2db/services/AsyncBaseService.java b/src/main/java/com/github/collinalpert/java2db/services/AsyncBaseService.java index 2685971..bf2c199 100644 --- a/src/main/java/com/github/collinalpert/java2db/services/AsyncBaseService.java +++ b/src/main/java/com/github/collinalpert/java2db/services/AsyncBaseService.java @@ -1,6 +1,6 @@ package com.github.collinalpert.java2db.services; -import com.github.collinalpert.java2db.database.ConnectionConfiguration; +import com.github.collinalpert.java2db.database.TransactionManager; import com.github.collinalpert.java2db.entities.BaseEntity; import com.github.collinalpert.java2db.queries.async.*; import com.github.collinalpert.java2db.queries.ordering.OrderTypes; @@ -24,8 +24,8 @@ */ public class AsyncBaseService extends BaseService { - protected AsyncBaseService(ConnectionConfiguration connectionConfiguration) { - super(connectionConfiguration); + protected AsyncBaseService(TransactionManager transactionManager) { + super(transactionManager); } //region Create @@ -253,11 +253,11 @@ public CompletableFuture hasDuplicatesAsync(SqlFunction column, Cons //region Query protected AsyncEntityQuery createAsyncQuery() { - return new AsyncEntityQuery<>(super.type, super.connectionConfiguration); + return new AsyncEntityQuery<>(super.type, super.transactionManager); } protected AsyncSingleEntityQuery createAsyncSingleQuery() { - return new AsyncSingleEntityQuery<>(super.type, super.connectionConfiguration); + return new AsyncSingleEntityQuery<>(super.type, super.transactionManager); } /** diff --git a/src/main/java/com/github/collinalpert/java2db/services/BaseCodeAndDescriptionDeletableService.java b/src/main/java/com/github/collinalpert/java2db/services/BaseCodeAndDescriptionDeletableService.java index 17a3c96..429b036 100644 --- a/src/main/java/com/github/collinalpert/java2db/services/BaseCodeAndDescriptionDeletableService.java +++ b/src/main/java/com/github/collinalpert/java2db/services/BaseCodeAndDescriptionDeletableService.java @@ -1,6 +1,6 @@ package com.github.collinalpert.java2db.services; -import com.github.collinalpert.java2db.database.ConnectionConfiguration; +import com.github.collinalpert.java2db.database.TransactionManager; import com.github.collinalpert.java2db.entities.BaseCodeAndDescriptionDeletableEntity; import com.github.collinalpert.java2db.queries.EntityQuery; @@ -15,8 +15,8 @@ */ public class BaseCodeAndDescriptionDeletableService extends BaseDeletableService { - protected BaseCodeAndDescriptionDeletableService(ConnectionConfiguration connectionConfiguration) { - super(connectionConfiguration); + protected BaseCodeAndDescriptionDeletableService(TransactionManager transactionManager) { + super(transactionManager); } /** diff --git a/src/main/java/com/github/collinalpert/java2db/services/BaseCodeAndDescriptionService.java b/src/main/java/com/github/collinalpert/java2db/services/BaseCodeAndDescriptionService.java index 5d0afec..ea5b0b5 100755 --- a/src/main/java/com/github/collinalpert/java2db/services/BaseCodeAndDescriptionService.java +++ b/src/main/java/com/github/collinalpert/java2db/services/BaseCodeAndDescriptionService.java @@ -1,6 +1,6 @@ package com.github.collinalpert.java2db.services; -import com.github.collinalpert.java2db.database.ConnectionConfiguration; +import com.github.collinalpert.java2db.database.TransactionManager; import com.github.collinalpert.java2db.entities.BaseCodeAndDescriptionEntity; import com.github.collinalpert.java2db.queries.EntityQuery; @@ -13,8 +13,8 @@ */ public class BaseCodeAndDescriptionService extends BaseService { - protected BaseCodeAndDescriptionService(ConnectionConfiguration connectionConfiguration) { - super(connectionConfiguration); + protected BaseCodeAndDescriptionService(TransactionManager transactionManager) { + super(transactionManager); } /** diff --git a/src/main/java/com/github/collinalpert/java2db/services/BaseDeletableService.java b/src/main/java/com/github/collinalpert/java2db/services/BaseDeletableService.java index 2e4b380..c754ae1 100644 --- a/src/main/java/com/github/collinalpert/java2db/services/BaseDeletableService.java +++ b/src/main/java/com/github/collinalpert/java2db/services/BaseDeletableService.java @@ -19,8 +19,8 @@ public class BaseDeletableService extends BaseSer private final SqlFunction isDeletedFunc = BaseDeletableEntity::isDeleted; - protected BaseDeletableService(ConnectionConfiguration connectionConfiguration) { - super(connectionConfiguration); + protected BaseDeletableService(TransactionManager transactionManager) { + super(transactionManager); } /** @@ -59,9 +59,9 @@ public void delete(List entities) throws SQLException { } var joinedIds = joiner.toString(); - try (var connection = new DBConnection(super.connectionConfiguration)) { + transactionManager.transact(connection -> { connection.update(String.format("update `%s` set %s = 1 where `%s`.`id` in %s", this.tableName, Lambda2Sql.toSql(this.isDeletedFunc, this.tableName), this.tableName, joinedIds)); - } + }); } /** @@ -101,8 +101,8 @@ public void delete(int... ids) throws SQLException { @Override public void delete(SqlPredicate predicate) throws SQLException { var query = String.format("update %s set %s = 1 where %s", super.tableName, Lambda2Sql.toSql(this.isDeletedFunc, super.tableName), Lambda2Sql.toSql(predicate)); - try (var connection = new DBConnection(super.connectionConfiguration)) { + transactionManager.transact(connection -> { connection.update(query); - } + }); } } diff --git a/src/main/java/com/github/collinalpert/java2db/services/BaseService.java b/src/main/java/com/github/collinalpert/java2db/services/BaseService.java index 495d101..d864ed3 100755 --- a/src/main/java/com/github/collinalpert/java2db/services/BaseService.java +++ b/src/main/java/com/github/collinalpert/java2db/services/BaseService.java @@ -56,15 +56,15 @@ public class BaseService { private final String idAccess; /** - * The properties this service needs to access the database. + * The transaction manager for the data source this service operates on. */ - protected final ConnectionConfiguration connectionConfiguration; + protected final TransactionManager transactionManager; /** * Constructor for the base class of all services. It is not possible to create instances of it. */ - protected BaseService(ConnectionConfiguration connectionConfiguration) { - this.connectionConfiguration = connectionConfiguration; + protected BaseService(TransactionManager transactionManager) { + this.transactionManager = transactionManager; this.type = getGenericType(); this.tableName = tableModule.getTableName(this.type); @@ -98,11 +98,11 @@ public int create(E instance) throws SQLException { //For using the default database setting for the id. joiner.add("default"); insertQuery.append(joiner.toString()); - try (var connection = new DBConnection(this.connectionConfiguration)) { + return transactionManager.transactAndReturn(connection -> { var id = connection.update(insertQuery.toString()); instance.setId(id); return id; - } + }); } /** @@ -155,9 +155,9 @@ public void create(List instances) throws SQLException { } insertQuery.append(String.join(", ", rows)); - try (var connection = new DBConnection(this.connectionConfiguration)) { + transactionManager.transact(connection -> { connection.update(insertQuery.toString()); - } + }); } //endregion @@ -183,14 +183,15 @@ public int count() { * @return The number of rows matching the condition. */ public int count(SqlPredicate predicate) { - try (var connection = new DBConnection(this.connectionConfiguration)) { - try (var result = connection.execute(String.format("select count(%s) from `%s` where %s;", this.idAccess, this.tableName, Lambda2Sql.toSql(predicate, this.tableName)))) { + try { + return transactionManager.transactAndReturn(connection -> { + var result = connection.execute(String.format("select count(%s) from `%s` where %s;", this.idAccess, this.tableName, Lambda2Sql.toSql(predicate, this.tableName))); if (result.next()) { return result.getInt(String.format("count(%s)", this.idAccess)); } return 0; - } + }); } catch (SQLException e) { throw new IllegalArgumentException(String.format("Could not get amount of rows in table %s for this predicate.", this.tableName), e); } @@ -216,14 +217,15 @@ public boolean any() { * @return {@code True} if the predicate matches one or more records, {@code false} if not. */ public boolean any(SqlPredicate predicate) { - try (var connection = new DBConnection(this.connectionConfiguration)) { - try (var result = connection.execute(String.format("select exists(select %s from `%s` where %s limit 1) as result;", this.idAccess, this.tableName, Lambda2Sql.toSql(predicate, this.tableName)))) { + try { + return transactionManager.transactAndReturn(connection -> { + var result = connection.execute(String.format("select exists(select %s from `%s` where %s limit 1) as result;", this.idAccess, this.tableName, Lambda2Sql.toSql(predicate, this.tableName))); if (result.next()) { return result.getInt("result") == 1; } return false; - } + }); } catch (SQLException e) { throw new IllegalArgumentException(String.format("Could not check if a row matches this condition on table %s.", this.tableName), e); } @@ -253,8 +255,9 @@ public T max(SqlFunction column) { * @return The maximum value of the column. */ public T max(SqlFunction column, SqlPredicate predicate) { - try (var connection = new DBConnection(this.connectionConfiguration)) { - try (var result = connection.execute(String.format("select max(%s) from `%s` where %s;", Lambda2Sql.toSql(column, this.tableName), this.tableName, Lambda2Sql.toSql(predicate, this.tableName)))) { + try { + return transactionManager.transactAndReturn(connection -> { + var result = connection.execute(String.format("select max(%s) from `%s` where %s;", Lambda2Sql.toSql(column, this.tableName), this.tableName, Lambda2Sql.toSql(predicate, this.tableName))); if (result.next()) { // This is needed to find the generic type at runtime. @@ -262,7 +265,7 @@ public T max(SqlFunction column, SqlPredicate predicate) { } return null; - } + }); } catch (SQLException e) { throw new IllegalArgumentException(String.format("Could not get maximum value of column %s in table %s.", Lambda2Sql.toSql(column), this.tableName), e); } @@ -294,14 +297,15 @@ public T min(SqlFunction column) { * @return The minimum value of the column. */ public T min(SqlFunction column, SqlPredicate predicate) { - try (var connection = new DBConnection(this.connectionConfiguration)) { - try (var result = connection.execute(String.format("select min(%s) from `%s` where %s;", Lambda2Sql.toSql(column, this.tableName), this.tableName, Lambda2Sql.toSql(predicate, this.tableName)))) { + try { + return transactionManager.transactAndReturn(connection -> { + var result = connection.execute(String.format("select min(%s) from `%s` where %s;", Lambda2Sql.toSql(column, this.tableName), this.tableName, Lambda2Sql.toSql(predicate, this.tableName))); if (result.next()) { return (T) result.getObject(1); } return null; - } + }); } catch (SQLException e) { throw new IllegalArgumentException(String.format("Could not get minimum value of column %s in table %s.", Lambda2Sql.toSql(column), this.tableName), e); } @@ -319,10 +323,11 @@ public T min(SqlFunction column, SqlPredicate predicate) { */ public boolean hasDuplicates(SqlFunction column) { var sqlColumn = Lambda2Sql.toSql(column, this.tableName); - try (var connection = new DBConnection(this.connectionConfiguration)) { - try (var result = connection.execute(String.format("select %s from `%s` group by %s having count(%s) > 1", sqlColumn, this.tableName, sqlColumn, sqlColumn))) { + try { + return transactionManager.transactAndReturn(connection -> { + var result = connection.execute(String.format("select %s from `%s` group by %s having count(%s) > 1", sqlColumn, this.tableName, sqlColumn, sqlColumn)); return result.next(); - } + }); } catch (SQLException e) { throw new IllegalArgumentException(String.format("Could not check if duplicate values exist in column %s on table %s.", sqlColumn, this.tableName), e); } @@ -338,7 +343,7 @@ public boolean hasDuplicates(SqlFunction column) { * {@link #getSingle(SqlPredicate)}, {@link #getMultiple(SqlPredicate)} or {@link #getAll()} methods. */ protected EntityQuery createQuery() { - return new EntityQuery<>(this.type, this.connectionConfiguration); + return new EntityQuery<>(this.type, this.transactionManager); } /** @@ -347,7 +352,7 @@ protected EntityQuery createQuery() { * {@link #getSingle(SqlPredicate)}, {@link #getMultiple(SqlPredicate)} or {@link #getAll()} methods. */ protected SingleEntityQuery createSingleQuery() { - return new SingleEntityQuery<>(this.type, this.connectionConfiguration); + return new SingleEntityQuery<>(this.type, this.transactionManager); } /** @@ -517,9 +522,9 @@ public CacheablePaginationResult createPagination(SqlPredicate predicate, * i.e. non-existing default value for field or an incorrect data type. */ public void update(E instance) throws SQLException { - try (var connection = new DBConnection(this.connectionConfiguration)) { + this.transactionManager.transact(connection -> { connection.update(updateQuery(instance)); - } + }); } /** @@ -544,11 +549,11 @@ public void update(E... instances) throws SQLException { * i.e. non-existing default value for field or an incorrect data type. */ public void update(List instances) throws SQLException { - try (var connection = new DBConnection(this.connectionConfiguration)) { + transactionManager.transact(connection -> { for (E instance : instances) { connection.update(updateQuery(instance)); } - } + }); } private String updateQuery(E instance) { @@ -577,9 +582,9 @@ public void update(int entityId, SqlFunction column, SqlFunction SqlPredicate predicate = x -> x.getId() == entityId; var query = String.format("update `%s` set %s = %s where %s", this.tableName, Lambda2Sql.toSql(column, this.tableName), Lambda2Sql.toSql(newValueFunction, this.tableName), Lambda2Sql.toSql(predicate, this.tableName)); - try (var connection = new DBConnection(this.connectionConfiguration)) { + transactionManager.transact(connection -> { connection.update(query); - } + }); } /** @@ -611,9 +616,9 @@ public void update(int entityId, SqlFunction column, R newValue) throw public void update(SqlPredicate condition, SqlFunction column, R newValue) throws SQLException { var query = String.format("update `%s` set %s = %s where %s;", this.tableName, Lambda2Sql.toSql(column, this.tableName), convertToSql(newValue), Lambda2Sql.toSql(condition, this.tableName)); - try (var connection = new DBConnection(this.connectionConfiguration)) { + transactionManager.transact(connection -> { connection.update(query); - } + }); } //endregion @@ -654,9 +659,9 @@ public void delete(List entities) throws SQLException { } var joinedIds = joiner.toString(); - try (var connection = new DBConnection(this.connectionConfiguration)) { + transactionManager.transact(connection -> { connection.update(String.format("delete from `%s` where %s in %s", this.tableName, this.idAccess, joinedIds)); - } + }); } /** @@ -693,9 +698,9 @@ public void delete(int... ids) throws SQLException { * @throws SQLException in case the condition cannot be applied or if a foreign key constraint fails. */ public void delete(SqlPredicate predicate) throws SQLException { - try (var connection = new DBConnection(this.connectionConfiguration)) { + transactionManager.transact(connection -> { connection.update(String.format("delete from `%s` where %s;", this.tableName, Lambda2Sql.toSql(predicate, this.tableName))); - } + }); } /** @@ -706,9 +711,9 @@ public void delete(SqlPredicate predicate) throws SQLException { * command and cannot check if rows are being referenced or not. */ public void truncateTable() throws SQLException { - try (var connection = new DBConnection(this.connectionConfiguration)) { + transactionManager.transact(connection -> { connection.update(String.format("truncate table `%s`;", this.tableName)); - } + }); } //endregion diff --git a/src/test/java/com/github/collinalpert/java2db/sandbox/Order.java b/src/test/java/com/github/collinalpert/java2db/sandbox/Order.java new file mode 100644 index 0000000..91eb444 --- /dev/null +++ b/src/test/java/com/github/collinalpert/java2db/sandbox/Order.java @@ -0,0 +1,29 @@ +package com.github.collinalpert.java2db.sandbox; + +import com.github.collinalpert.java2db.entities.BaseEntity; + +public class Order extends BaseEntity { + String product; + int amount; + + public Order(String product, int amount) { + this.product = product; + this.amount = amount; + } + + public String getProduct() { + return product; + } + + public void setProduct(String product) { + this.product = product; + } + + public int getAmount() { + return amount; + } + + public void setAmount(int amount) { + this.amount = amount; + } +} diff --git a/src/test/java/com/github/collinalpert/java2db/sandbox/OrderService.java b/src/test/java/com/github/collinalpert/java2db/sandbox/OrderService.java new file mode 100644 index 0000000..c219725 --- /dev/null +++ b/src/test/java/com/github/collinalpert/java2db/sandbox/OrderService.java @@ -0,0 +1,10 @@ +package com.github.collinalpert.java2db.sandbox; + +import com.github.collinalpert.java2db.database.TransactionManager; +import com.github.collinalpert.java2db.services.BaseService; + +public class OrderService extends BaseService { + protected OrderService(TransactionManager transactionManager) { + super(transactionManager); + } +} diff --git a/src/test/java/com/github/collinalpert/java2db/sandbox/TestMain.java b/src/test/java/com/github/collinalpert/java2db/sandbox/TestMain.java new file mode 100644 index 0000000..6bae796 --- /dev/null +++ b/src/test/java/com/github/collinalpert/java2db/sandbox/TestMain.java @@ -0,0 +1,113 @@ +package com.github.collinalpert.java2db.sandbox; + +import com.github.collinalpert.java2db.database.TransactionManager; + +import javax.sql.DataSource; +import java.io.PrintWriter; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.util.logging.Logger; + +public class TestMain { + + public static void main(String[] args) throws Exception { + Class.forName("org.h2.Driver"); + + DataSource h2DataSource = createH2DataSource(); + createSchema(h2DataSource); + + TransactionManager transactionManager = new TransactionManager(h2DataSource); + + UserService userService = new UserService(transactionManager); + OrderService orderService = new OrderService(transactionManager); + + try { + transactionManager.transact(connection -> { + orderService.create(new Order("Lamp", 1)); + orderService.create(new Order("Desk", 1)); + userService.create(new User("John")); + + throw new RuntimeException("this should rollback the transaction"); + }); + } catch (Exception e) { + e.printStackTrace(); // expected + } + + transactionManager.transact(connection -> { + var allUsers = userService.getAll(); + assert allUsers.isEmpty(); + }); + } + + private static void createSchema(DataSource dataSource) throws SQLException { + try (Connection conn = dataSource.getConnection()) { + var statement = conn.createStatement(); + + statement.execute( + "CREATE TABLE `order` (" + + " id INT AUTO_INCREMENT PRIMARY KEY," + + " product VARCHAR(100)," + + " amount SMALLINT" + + ")" + ); + + statement.execute( + "CREATE TABLE user (" + + " id INT AUTO_INCREMENT PRIMARY KEY," + + " name VARCHAR(50)" + + ")" + ); + } + } + + private static DataSource createH2DataSource() { + return new DataSource() { + @Override + public Connection getConnection() throws SQLException { + return DriverManager.getConnection("jdbc:h2:mem:testDb;DB_CLOSE_DELAY=-1"); + } + + @Override + public Connection getConnection(String username, String password) throws SQLException { + return null; + } + + @Override + public PrintWriter getLogWriter() throws SQLException { + return null; + } + + @Override + public void setLogWriter(PrintWriter out) { + + } + + @Override + public void setLoginTimeout(int seconds) throws SQLException { + + } + + @Override + public int getLoginTimeout() throws SQLException { + return 0; + } + + @Override + public T unwrap(Class iface) throws SQLException { + return null; + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return false; + } + + @Override + public Logger getParentLogger() throws SQLFeatureNotSupportedException { + return null; + } + }; + } +} diff --git a/src/test/java/com/github/collinalpert/java2db/sandbox/User.java b/src/test/java/com/github/collinalpert/java2db/sandbox/User.java new file mode 100644 index 0000000..eb69bb0 --- /dev/null +++ b/src/test/java/com/github/collinalpert/java2db/sandbox/User.java @@ -0,0 +1,19 @@ +package com.github.collinalpert.java2db.sandbox; + +import com.github.collinalpert.java2db.entities.BaseEntity; + +public class User extends BaseEntity { + private String name; + + public User(String name) { + this.name = name; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } +} diff --git a/src/test/java/com/github/collinalpert/java2db/sandbox/UserService.java b/src/test/java/com/github/collinalpert/java2db/sandbox/UserService.java new file mode 100644 index 0000000..668bbba --- /dev/null +++ b/src/test/java/com/github/collinalpert/java2db/sandbox/UserService.java @@ -0,0 +1,11 @@ +package com.github.collinalpert.java2db.sandbox; + +import com.github.collinalpert.java2db.database.TransactionManager; +import com.github.collinalpert.java2db.services.BaseService; + +public class UserService extends BaseService { + + protected UserService(TransactionManager transactionManager) { + super(transactionManager); + } +}