Skip to content

Commit

Permalink
feat: support stream in spring data repository
Browse files Browse the repository at this point in the history
  • Loading branch information
shouwn committed Dec 9, 2024
1 parent 3a58ff8 commit 4c4ae8b
Show file tree
Hide file tree
Showing 8 changed files with 809 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import org.springframework.data.domain.PageRequest
import org.springframework.data.domain.Sort
import org.springframework.transaction.annotation.Transactional
import java.time.OffsetDateTime
import java.util.stream.Collectors

@Transactional
@SpringBootTest
Expand Down Expand Up @@ -89,6 +90,30 @@ class SelectExample : WithAssertions {
assertThat(actual).isEqualTo(listOf(Isbn("04"), Isbn("05"), Isbn("06")))
}

@Test
fun `the stream of books`() {
// given
val pageable = PageRequest.of(1, 3, Sort.by(Sort.Direction.ASC, "isbn"))

// when
val actual = bookRepository.findStream(pageable) {
select(
path(Book::isbn),
).from(
entity(Book::class),
)
}

// then
assertThat(actual.collect(Collectors.toList())).isEqualTo(
listOf(
Isbn("04"),
Isbn("05"),
Isbn("06"),
),
)
}

@Test
fun `the page of books`() {
// given
Expand Down Expand Up @@ -302,6 +327,39 @@ class SelectExample : WithAssertions {
)
}

@Test
fun the_number_of_employees_per_department_stream() {
// given
data class Row(
val departmentId: Long,
val count: Long,
)

// when
val actual = employeeRepository.findStream {
selectNew<Row>(
path(EmployeeDepartment::departmentId),
count(Employee::employeeId),
).from(
entity(Employee::class),
join(Employee::departments),
).groupBy(
path(EmployeeDepartment::departmentId),
).orderBy(
path(EmployeeDepartment::departmentId).asc(),
)
}

// then
assertThat(actual.collect(Collectors.toList())).isEqualTo(
listOf(
Row(1, 6),
Row(2, 15),
Row(3, 18),
),
)
}

@Test
fun `the number of employees who belong to more than one department`() {
// when
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,24 @@ class SelectExample : WithAssertions {
assertThat(actual.hasNext()).isTrue
}

@Test
fun `the stream of books`() {
// given
val pageable = PageRequest.of(1, 3, Sort.by(Sort.Direction.ASC, "isbn"))

// when
val actual = bookRepository.findStream(pageable) {
select(
path(Book::isbn),
).from(
entity(Book::class),
)
}

// then
assertThat(actual.toList()).isEqualTo(listOf(Isbn("04"), Isbn("05"), Isbn("06")))
}

@Test
fun `the book with the most authors`() {
// when
Expand Down Expand Up @@ -321,6 +339,39 @@ class SelectExample : WithAssertions {
)
}

@Test
fun the_number_of_employees_per_department_stream() {
// given
data class Row(
val departmentId: Long,
val count: Long,
)

// when
val actual = employeeRepository.findStream {
selectNew<Row>(
path(EmployeeDepartment::departmentId),
count(Employee::employeeId),
).from(
entity(Employee::class),
join(Employee::departments),
).groupBy(
path(EmployeeDepartment::departmentId),
).orderBy(
path(EmployeeDepartment::departmentId).asc(),
)
}

// then
assertThat(actual.toList()).isEqualTo(
listOf(
Row(1, 6),
Row(2, 15),
Row(3, 18),
),
)
}

@Test
fun `the number of employees who belong to more than one department`() {
// given
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.springframework.data.domain.Page
import org.springframework.data.domain.Pageable
import org.springframework.data.domain.Slice
import org.springframework.data.repository.NoRepositoryBean
import java.util.stream.Stream

@NoRepositoryBean
@SinceJdsl("3.0.0")
Expand Down Expand Up @@ -134,6 +135,67 @@ interface KotlinJdslJpqlExecutor {
init: DSL.() -> JpqlQueryable<SelectQuery<T>>,
): Slice<T?>

/**
* Returns all results of the select query.
*/
@SinceJdsl("3.5.4")
fun <T : Any> findStream(
offset: Int? = null,
limit: Int? = null,
init: Jpql.() -> JpqlQueryable<SelectQuery<T>>,
): Stream<T?>

/**
* Returns all results of the select query.
*/
@SinceJdsl("3.5.4")
fun <T : Any, DSL : JpqlDsl> findStream(
dsl: JpqlDsl.Constructor<DSL>,
offset: Int? = null,
limit: Int? = null,
init: DSL.() -> JpqlQueryable<SelectQuery<T>>,
): Stream<T?>

/**
* Returns all results of the select query.
*/
@SinceJdsl("3.5.4")
fun <T : Any, DSL : JpqlDsl> findStream(
dsl: DSL,
offset: Int? = null,
limit: Int? = null,
init: DSL.() -> JpqlQueryable<SelectQuery<T>>,
): Stream<T?>

/**
* Returns all results of the select query.
*/
@SinceJdsl("3.5.4")
fun <T : Any> findStream(
pageable: Pageable,
init: Jpql.() -> JpqlQueryable<SelectQuery<T>>,
): Stream<T?>

/**
* Returns all results of the select query.
*/
@SinceJdsl("3.5.4")
fun <T : Any, DSL : JpqlDsl> findStream(
dsl: JpqlDsl.Constructor<DSL>,
pageable: Pageable,
init: DSL.() -> JpqlQueryable<SelectQuery<T>>,
): Stream<T?>

/**
* Returns all results of the select query.
*/
@SinceJdsl("3.5.4")
fun <T : Any, DSL : JpqlDsl> findStream(
dsl: DSL,
pageable: Pageable,
init: DSL.() -> JpqlQueryable<SelectQuery<T>>,
): Stream<T?>

/**
* Execute the update query.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.springframework.data.jpa.repository.support.QueryHints
import org.springframework.data.repository.NoRepositoryBean
import org.springframework.data.support.PageableExecutionUtilsAdaptor
import org.springframework.transaction.annotation.Transactional
import java.util.stream.Stream
import javax.persistence.EntityManager
import javax.persistence.LockModeType
import javax.persistence.Query
Expand Down Expand Up @@ -60,10 +61,7 @@ open class KotlinJdslJpqlExecutorImpl(
init: DSL.() -> JpqlQueryable<SelectQuery<T>>,
): List<T?> {
val query: SelectQuery<T> = jpql(dsl, init)
val jpaQuery = createJpaQuery(query, query.returnType).apply {
offset?.let { setFirstResult(it) }
limit?.let { setMaxResults(it) }
}
val jpaQuery = createJpaQuery(query, query.returnType, offset, limit)

return jpaQuery.resultList
}
Expand All @@ -90,7 +88,7 @@ open class KotlinJdslJpqlExecutorImpl(
): List<T?> {
val query: SelectQuery<T> = jpql(dsl, init)

return createList(query, query.returnType, pageable)
return createSortedQuery(query, query.returnType, pageable).resultList
}

override fun <T : Any> findPage(
Expand Down Expand Up @@ -143,6 +141,60 @@ open class KotlinJdslJpqlExecutorImpl(
return createSlice(query, query.returnType, pageable)
}

override fun <T : Any> findStream(
offset: Int?,
limit: Int?,
init: Jpql.() -> JpqlQueryable<SelectQuery<T>>,
): Stream<T?> {
return findStream(Jpql, offset = offset, limit = limit, init)
}

override fun <T : Any, DSL : JpqlDsl> findStream(
dsl: JpqlDsl.Constructor<DSL>,
offset: Int?,
limit: Int?,
init: DSL.() -> JpqlQueryable<SelectQuery<T>>,
): Stream<T?> {
return findStream(dsl.newInstance(), offset = offset, limit = limit, init)
}

override fun <T : Any, DSL : JpqlDsl> findStream(
dsl: DSL,
offset: Int?,
limit: Int?,
init: DSL.() -> JpqlQueryable<SelectQuery<T>>,
): Stream<T?> {
val query: SelectQuery<T> = jpql(dsl, init)
val jpaQuery = createJpaQuery(query, query.returnType, offset, limit)

return jpaQuery.resultStream
}

override fun <T : Any> findStream(
pageable: Pageable,
init: Jpql.() -> JpqlQueryable<SelectQuery<T>>,
): Stream<T?> {
return findStream(Jpql, pageable, init)
}

override fun <T : Any, DSL : JpqlDsl> findStream(
dsl: JpqlDsl.Constructor<DSL>,
pageable: Pageable,
init: DSL.() -> JpqlQueryable<SelectQuery<T>>,
): Stream<T?> {
return findStream(dsl.newInstance(), pageable, init)
}

override fun <T : Any, DSL : JpqlDsl> findStream(
dsl: DSL,
pageable: Pageable,
init: DSL.() -> JpqlQueryable<SelectQuery<T>>,
): Stream<T?> {
val query: SelectQuery<T> = jpql(dsl, init)

return this.createSortedQuery(query, query.returnType, pageable).resultStream
}

@Transactional
override fun <T : Any> update(
init: Jpql.() -> JpqlQueryable<UpdateQuery<T>>,
Expand Down Expand Up @@ -198,9 +250,13 @@ open class KotlinJdslJpqlExecutorImpl(
private fun <T : Any> createJpaQuery(
query: JpqlQuery<*>,
returnType: KClass<T>,
offset: Int?,
limit: Int?,
): TypedQuery<T> {
return JpqlEntityManagerUtils.createQuery(entityManager, query, returnType, renderContext).apply {
setMetadata(this, metadata)
offset?.let { setFirstResult(it) }
limit?.let { setMaxResults(it) }
}
}

Expand Down Expand Up @@ -258,11 +314,11 @@ open class KotlinJdslJpqlExecutorImpl(
}
}

private fun <T : Any> createList(
private fun <T : Any> createSortedQuery(
query: JpqlQuery<*>,
returnType: KClass<T>,
pageable: Pageable,
): List<T?> {
): TypedQuery<T> {
val enhancedQuery = createJpaEnhancedQuery(query, returnType, pageable.sort)

val sortedQuery = enhancedQuery.sortedQuery
Expand All @@ -272,7 +328,7 @@ open class KotlinJdslJpqlExecutorImpl(
sortedQuery.maxResults = pageable.pageSize
}

return sortedQuery.resultList
return sortedQuery
}

private fun <T : Any> createPage(
Expand Down
Loading

0 comments on commit 4c4ae8b

Please sign in to comment.