Skip to content

Commit

Permalink
feat: Add new filter for checking externalComputationId, and checking…
Browse files Browse the repository at this point in the history
… externalComputationId if view is COMPUTATION or COMPUTATION_STATS (#1753)

feat: Add new filter for checking externalComputationId, and checking
externalComputationId if view is COMPUTATION or COMPUTATION_STATS
  • Loading branch information
tristanvuong2021 authored Nov 7, 2024
1 parent edd90c9 commit a1d5e05
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class StreamMeasurements(
return MeasurementReader(view).apply {
this.orderByClause = orderByClause
fillStatementBuilder {
appendWhereClause(requestFilter)
appendWhereClause(view, requestFilter)
appendClause(orderByClause)
if (limit > 0) {
appendClause("LIMIT @$LIMIT_PARAM")
Expand All @@ -84,9 +84,20 @@ class StreamMeasurements(
}
}

private fun Statement.Builder.appendWhereClause(filter: StreamMeasurementsRequest.Filter) {
private fun Statement.Builder.appendWhereClause(
view: Measurement.View,
filter: StreamMeasurementsRequest.Filter,
) {
val conjuncts = mutableListOf<String>()

if (
filter.hasExternalComputationId ||
view == Measurement.View.COMPUTATION ||
view == Measurement.View.COMPUTATION_STATS
) {
conjuncts.add("ExternalComputationId IS NOT NULL")
}

if (filter.externalMeasurementConsumerId != 0L) {
conjuncts.add("ExternalMeasurementConsumerId = @$EXTERNAL_MEASUREMENT_CONSUMER_ID_PARAM")
bind(EXTERNAL_MEASUREMENT_CONSUMER_ID_PARAM to filter.externalMeasurementConsumerId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,14 +555,6 @@ private fun MeasurementKt.Dsl.fillComputationView(struct: Struct) {
val requisitionsStructs = struct.getStructList("Requisitions")
val dataProvidersCount = requisitionsStructs.size

if (struct.isNull("ExternalComputationId")) {
for (requisitionStruct in requisitionsStructs) {
requisitions +=
RequisitionReader.buildRequisition(struct, requisitionStruct, mapOf(), dataProvidersCount)
}
return
}

val externalMeasurementId = ExternalId(struct.getLong("ExternalMeasurementId"))
val externalMeasurementConsumerId = ExternalId(struct.getLong("ExternalMeasurementConsumerId"))
val externalComputationId = ExternalId(struct.getLong("ExternalComputationId"))
Expand Down Expand Up @@ -600,10 +592,6 @@ private fun MeasurementKt.Dsl.fillComputationView(struct: Struct) {
private fun MeasurementKt.Dsl.fillComputationStatsView(struct: Struct) {
fillMeasurementCommon(struct)

if (struct.isNull("ExternalComputationId")) {
return
}

val externalMeasurementId = ExternalId(struct.getLong("ExternalMeasurementId"))
val externalMeasurementConsumerId = ExternalId(struct.getLong("ExternalMeasurementConsumerId"))
val externalComputationId = ExternalId(struct.getLong("ExternalComputationId"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,142 @@ abstract class MeasurementsServiceTest<T : MeasurementsCoroutineImplBase> {
.inOrder()
}

@Test
fun `streamMeasurements with hasExternalComputationId filter only gets computations`(): Unit =
runBlocking {
val measurementConsumer =
population.createMeasurementConsumer(measurementConsumersService, accountsService)

val measurement1 =
measurementsService.createMeasurement(
createMeasurementRequest {
measurement =
MEASUREMENT.copy {
externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId
externalMeasurementConsumerCertificateId =
measurementConsumer.certificate.externalCertificateId
}
}
)
measurementsService.createMeasurement(
createMeasurementRequest {
measurement =
MEASUREMENT.copy {
externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId
externalMeasurementConsumerCertificateId =
measurementConsumer.certificate.externalCertificateId
details =
details.copy {
protocolConfig = protocolConfig {
direct = ProtocolConfig.Direct.getDefaultInstance()
}
clearDuchyProtocolConfig()
}
}
}
)
val measurement3 =
measurementsService.createMeasurement(
createMeasurementRequest { measurement = measurement1 }
)

val streamMeasurementsRequest = streamMeasurementsRequest {
limit = 2
filter = filter {
hasExternalComputationId = true
externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId
}
}

val responses: List<Measurement> =
measurementsService.streamMeasurements(streamMeasurementsRequest).toList()

val computationMeasurement1 =
measurementsService.getMeasurement(
getMeasurementRequest {
externalMeasurementConsumerId = measurement1.externalMeasurementConsumerId
externalMeasurementId = measurement1.externalMeasurementId
}
)
val computationMeasurement3 =
measurementsService.getMeasurement(
getMeasurementRequest {
externalMeasurementConsumerId = measurement3.externalMeasurementConsumerId
externalMeasurementId = measurement3.externalMeasurementId
}
)
assertThat(responses)
.containsExactly(computationMeasurement1, computationMeasurement3)
.inOrder()
}

@Test
fun `streamMeasurements with COMPUTATION_STATS view only gets computations`(): Unit =
runBlocking {
val measurementConsumer =
population.createMeasurementConsumer(measurementConsumersService, accountsService)

val measurement1 =
measurementsService.createMeasurement(
createMeasurementRequest {
measurement =
MEASUREMENT.copy {
externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId
externalMeasurementConsumerCertificateId =
measurementConsumer.certificate.externalCertificateId
}
}
)
measurementsService.createMeasurement(
createMeasurementRequest {
measurement =
MEASUREMENT.copy {
externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId
externalMeasurementConsumerCertificateId =
measurementConsumer.certificate.externalCertificateId
details =
details.copy {
protocolConfig = protocolConfig {
direct = ProtocolConfig.Direct.getDefaultInstance()
}
clearDuchyProtocolConfig()
}
}
}
)
val measurement3 =
measurementsService.createMeasurement(
createMeasurementRequest { measurement = measurement1 }
)

val streamMeasurementsRequest = streamMeasurementsRequest {
limit = 3
filter = filter {
externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId
}
measurementView = Measurement.View.COMPUTATION_STATS
}

val responses: List<Measurement> =
measurementsService.streamMeasurements(streamMeasurementsRequest).toList()

val computationMeasurement1 =
measurementsService.getMeasurementByComputationId(
getMeasurementByComputationIdRequest {
externalComputationId = measurement1.externalComputationId
}
)
val computationMeasurement3 =
measurementsService.getMeasurementByComputationId(
getMeasurementByComputationIdRequest {
externalComputationId = measurement3.externalComputationId
}
)
assertThat(responses)
.containsExactly(computationMeasurement1, computationMeasurement3)
.inOrder()
}

@Test
fun `streamMeasurements with computation view only returns failure log`(): Unit = runBlocking {
val measurementConsumer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ service Measurements {
//
// Which key is used for the ordering depends on which view is specified in
// the request.
//
// If the view is `Measurement.View.COMPUTATION` or
// `Measurement.View.COMPUTATION_STATS`, it is guaranteed that only
// Measurements with `external_computation_id` set are in a successful
// response.
rpc StreamMeasurements(StreamMeasurementsRequest)
returns (stream Measurement);

Expand Down Expand Up @@ -88,10 +93,11 @@ message StreamMeasurementsRequest {
int64 external_measurement_consumer_certificate_id = 2;
repeated Measurement.State states = 3;
google.protobuf.Timestamp updated_after = 4;
string externalDuchyId = 7;
string external_duchy_id = 7;
google.protobuf.Timestamp updated_before = 8;
google.protobuf.Timestamp created_before = 9;
google.protobuf.Timestamp created_after = 11;
bool has_external_computation_id = 12;

message After {
google.protobuf.Timestamp update_time = 1;
Expand Down

0 comments on commit a1d5e05

Please sign in to comment.