From a1d5e05247932bfa7ca99c4570585ca16a831416 Mon Sep 17 00:00:00 2001 From: Tristan Vuong <85768771+tristanvuong2021@users.noreply.github.com> Date: Thu, 7 Nov 2024 15:49:08 -0800 Subject: [PATCH] feat: Add new filter for checking externalComputationId, and checking externalComputationId if view is COMPUTATION or COMPUTATION_STATS (#1753) feat: Add new filter for checking externalComputationId, and checking externalComputationId if view is COMPUTATION or COMPUTATION_STATS --- .../spanner/queries/StreamMeasurements.kt | 15 +- .../spanner/readers/MeasurementReader.kt | 12 -- .../testing/MeasurementsServiceTest.kt | 136 ++++++++++++++++++ .../kingdom/measurements_service.proto | 8 +- 4 files changed, 156 insertions(+), 15 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/queries/StreamMeasurements.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/queries/StreamMeasurements.kt index 21ef2aa11c..5054b89a91 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/queries/StreamMeasurements.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/queries/StreamMeasurements.kt @@ -62,7 +62,7 @@ class StreamMeasurements( return MeasurementReader(view).apply { this.orderByClause = orderByClause fillStatementBuilder { - appendWhereClause(requestFilter) + appendWhereClause(view, requestFilter) appendClause(orderByClause) if (limit > 0) { appendClause("LIMIT @$LIMIT_PARAM") @@ -84,9 +84,20 @@ class StreamMeasurements( } } - private fun Statement.Builder.appendWhereClause(filter: StreamMeasurementsRequest.Filter) { + private fun Statement.Builder.appendWhereClause( + view: Measurement.View, + filter: StreamMeasurementsRequest.Filter, + ) { val conjuncts = mutableListOf() + if ( + filter.hasExternalComputationId || + view == Measurement.View.COMPUTATION || + view == Measurement.View.COMPUTATION_STATS + ) { + conjuncts.add("ExternalComputationId IS NOT NULL") + } + if (filter.externalMeasurementConsumerId != 0L) { conjuncts.add("ExternalMeasurementConsumerId = @$EXTERNAL_MEASUREMENT_CONSUMER_ID_PARAM") bind(EXTERNAL_MEASUREMENT_CONSUMER_ID_PARAM to filter.externalMeasurementConsumerId) diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/readers/MeasurementReader.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/readers/MeasurementReader.kt index 871c0cd5d2..c094bff7df 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/readers/MeasurementReader.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/readers/MeasurementReader.kt @@ -555,14 +555,6 @@ private fun MeasurementKt.Dsl.fillComputationView(struct: Struct) { val requisitionsStructs = struct.getStructList("Requisitions") val dataProvidersCount = requisitionsStructs.size - if (struct.isNull("ExternalComputationId")) { - for (requisitionStruct in requisitionsStructs) { - requisitions += - RequisitionReader.buildRequisition(struct, requisitionStruct, mapOf(), dataProvidersCount) - } - return - } - val externalMeasurementId = ExternalId(struct.getLong("ExternalMeasurementId")) val externalMeasurementConsumerId = ExternalId(struct.getLong("ExternalMeasurementConsumerId")) val externalComputationId = ExternalId(struct.getLong("ExternalComputationId")) @@ -600,10 +592,6 @@ private fun MeasurementKt.Dsl.fillComputationView(struct: Struct) { private fun MeasurementKt.Dsl.fillComputationStatsView(struct: Struct) { fillMeasurementCommon(struct) - if (struct.isNull("ExternalComputationId")) { - return - } - val externalMeasurementId = ExternalId(struct.getLong("ExternalMeasurementId")) val externalMeasurementConsumerId = ExternalId(struct.getLong("ExternalMeasurementConsumerId")) val externalComputationId = ExternalId(struct.getLong("ExternalComputationId")) diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt index fbeb8b791f..7d5022261b 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt @@ -1499,6 +1499,142 @@ abstract class MeasurementsServiceTest { .inOrder() } + @Test + fun `streamMeasurements with hasExternalComputationId filter only gets computations`(): Unit = + runBlocking { + val measurementConsumer = + population.createMeasurementConsumer(measurementConsumersService, accountsService) + + val measurement1 = + measurementsService.createMeasurement( + createMeasurementRequest { + measurement = + MEASUREMENT.copy { + externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId + externalMeasurementConsumerCertificateId = + measurementConsumer.certificate.externalCertificateId + } + } + ) + measurementsService.createMeasurement( + createMeasurementRequest { + measurement = + MEASUREMENT.copy { + externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId + externalMeasurementConsumerCertificateId = + measurementConsumer.certificate.externalCertificateId + details = + details.copy { + protocolConfig = protocolConfig { + direct = ProtocolConfig.Direct.getDefaultInstance() + } + clearDuchyProtocolConfig() + } + } + } + ) + val measurement3 = + measurementsService.createMeasurement( + createMeasurementRequest { measurement = measurement1 } + ) + + val streamMeasurementsRequest = streamMeasurementsRequest { + limit = 2 + filter = filter { + hasExternalComputationId = true + externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId + } + } + + val responses: List = + measurementsService.streamMeasurements(streamMeasurementsRequest).toList() + + val computationMeasurement1 = + measurementsService.getMeasurement( + getMeasurementRequest { + externalMeasurementConsumerId = measurement1.externalMeasurementConsumerId + externalMeasurementId = measurement1.externalMeasurementId + } + ) + val computationMeasurement3 = + measurementsService.getMeasurement( + getMeasurementRequest { + externalMeasurementConsumerId = measurement3.externalMeasurementConsumerId + externalMeasurementId = measurement3.externalMeasurementId + } + ) + assertThat(responses) + .containsExactly(computationMeasurement1, computationMeasurement3) + .inOrder() + } + + @Test + fun `streamMeasurements with COMPUTATION_STATS view only gets computations`(): Unit = + runBlocking { + val measurementConsumer = + population.createMeasurementConsumer(measurementConsumersService, accountsService) + + val measurement1 = + measurementsService.createMeasurement( + createMeasurementRequest { + measurement = + MEASUREMENT.copy { + externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId + externalMeasurementConsumerCertificateId = + measurementConsumer.certificate.externalCertificateId + } + } + ) + measurementsService.createMeasurement( + createMeasurementRequest { + measurement = + MEASUREMENT.copy { + externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId + externalMeasurementConsumerCertificateId = + measurementConsumer.certificate.externalCertificateId + details = + details.copy { + protocolConfig = protocolConfig { + direct = ProtocolConfig.Direct.getDefaultInstance() + } + clearDuchyProtocolConfig() + } + } + } + ) + val measurement3 = + measurementsService.createMeasurement( + createMeasurementRequest { measurement = measurement1 } + ) + + val streamMeasurementsRequest = streamMeasurementsRequest { + limit = 3 + filter = filter { + externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId + } + measurementView = Measurement.View.COMPUTATION_STATS + } + + val responses: List = + measurementsService.streamMeasurements(streamMeasurementsRequest).toList() + + val computationMeasurement1 = + measurementsService.getMeasurementByComputationId( + getMeasurementByComputationIdRequest { + externalComputationId = measurement1.externalComputationId + } + ) + val computationMeasurement3 = + measurementsService.getMeasurementByComputationId( + getMeasurementByComputationIdRequest { + externalComputationId = measurement3.externalComputationId + } + ) + assertThat(responses) + .containsExactly(computationMeasurement1, computationMeasurement3) + .inOrder() + } + @Test fun `streamMeasurements with computation view only returns failure log`(): Unit = runBlocking { val measurementConsumer = diff --git a/src/main/proto/wfa/measurement/internal/kingdom/measurements_service.proto b/src/main/proto/wfa/measurement/internal/kingdom/measurements_service.proto index c63f603739..dc21d78fc9 100644 --- a/src/main/proto/wfa/measurement/internal/kingdom/measurements_service.proto +++ b/src/main/proto/wfa/measurement/internal/kingdom/measurements_service.proto @@ -39,6 +39,11 @@ service Measurements { // // Which key is used for the ordering depends on which view is specified in // the request. + // + // If the view is `Measurement.View.COMPUTATION` or + // `Measurement.View.COMPUTATION_STATS`, it is guaranteed that only + // Measurements with `external_computation_id` set are in a successful + // response. rpc StreamMeasurements(StreamMeasurementsRequest) returns (stream Measurement); @@ -88,10 +93,11 @@ message StreamMeasurementsRequest { int64 external_measurement_consumer_certificate_id = 2; repeated Measurement.State states = 3; google.protobuf.Timestamp updated_after = 4; - string externalDuchyId = 7; + string external_duchy_id = 7; google.protobuf.Timestamp updated_before = 8; google.protobuf.Timestamp created_before = 9; google.protobuf.Timestamp created_after = 11; + bool has_external_computation_id = 12; message After { google.protobuf.Timestamp update_time = 1;