From edd90c97c3b802a8a7012cc8b0a896a5560c482e Mon Sep 17 00:00:00 2001 From: Tristan Vuong <85768771+tristanvuong2021@users.noreply.github.com> Date: Thu, 7 Nov 2024 10:27:30 -0800 Subject: [PATCH] feat: Add new Bigquery Table for Computation Participant Stages (#1805) feat: Add new Bigquery Table for Computation Participant Stages --- src/main/k8s/dev/kingdom_gke.cue | 2 + .../gcloud/job/OperationalMetricsExport.kt | 257 +++++++++ .../gcloud/job/OperationalMetricsExportJob.kt | 21 + .../operational_metrics_dataset.proto | 31 ++ .../terraform/gcloud/modules/kingdom/main.tf | 102 ++++ .../job/OperationalMetricsExportTest.kt | 521 +++++++++++++++++- 6 files changed, 929 insertions(+), 5 deletions(-) diff --git a/src/main/k8s/dev/kingdom_gke.cue b/src/main/k8s/dev/kingdom_gke.cue index 975ec3f2102..21311c353ac 100644 --- a/src/main/k8s/dev/kingdom_gke.cue +++ b/src/main/k8s/dev/kingdom_gke.cue @@ -119,6 +119,8 @@ kingdom: #Kingdom & { "--latest-measurement-read-table=latest_measurement_read", "--requisitions-table=requisitions", "--latest-requisition-read-table=latest_requisition_read", + "--computation-participant-stages-table=computation_participant_stages", + "--latest-computation-read-table=latest_computation_read", "--tls-cert-file=/var/run/secrets/files/kingdom_tls.pem", "--tls-key-file=/var/run/secrets/files/kingdom_tls.key", "--cert-collection-file=/var/run/secrets/files/kingdom_root.pem", diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/job/OperationalMetricsExport.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/job/OperationalMetricsExport.kt index 1afdccfb49b..2ad6bc70a91 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/job/OperationalMetricsExport.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/job/OperationalMetricsExport.kt @@ -30,6 +30,7 @@ import com.google.protobuf.util.Timestamps import com.google.rpc.Code import io.grpc.StatusException import java.time.Duration +import java.util.concurrent.ExecutionException import java.util.logging.Logger import kotlinx.coroutines.flow.catch import org.jetbrains.annotations.Blocking @@ -38,21 +39,28 @@ import org.wfanet.measurement.api.v2alpha.MeasurementSpec import org.wfanet.measurement.common.identity.apiIdToExternalId import org.wfanet.measurement.common.identity.externalIdToApiId import org.wfanet.measurement.common.toInstant +import org.wfanet.measurement.internal.kingdom.DuchyMeasurementLogEntry import org.wfanet.measurement.internal.kingdom.Measurement import org.wfanet.measurement.internal.kingdom.MeasurementsGrpcKt import org.wfanet.measurement.internal.kingdom.Requisition import org.wfanet.measurement.internal.kingdom.RequisitionsGrpcKt import org.wfanet.measurement.internal.kingdom.StreamMeasurementsRequestKt import org.wfanet.measurement.internal.kingdom.StreamRequisitionsRequestKt +import org.wfanet.measurement.internal.kingdom.bigquerytables.ComputationParticipantStagesTableRow +import org.wfanet.measurement.internal.kingdom.bigquerytables.LatestComputationReadTableRow import org.wfanet.measurement.internal.kingdom.bigquerytables.LatestMeasurementReadTableRow import org.wfanet.measurement.internal.kingdom.bigquerytables.LatestRequisitionReadTableRow import org.wfanet.measurement.internal.kingdom.bigquerytables.MeasurementType import org.wfanet.measurement.internal.kingdom.bigquerytables.MeasurementsTableRow import org.wfanet.measurement.internal.kingdom.bigquerytables.RequisitionsTableRow +import 
org.wfanet.measurement.internal.kingdom.bigquerytables.computationParticipantStagesTableRow +import org.wfanet.measurement.internal.kingdom.bigquerytables.copy +import org.wfanet.measurement.internal.kingdom.bigquerytables.latestComputationReadTableRow import org.wfanet.measurement.internal.kingdom.bigquerytables.latestMeasurementReadTableRow import org.wfanet.measurement.internal.kingdom.bigquerytables.latestRequisitionReadTableRow import org.wfanet.measurement.internal.kingdom.bigquerytables.measurementsTableRow import org.wfanet.measurement.internal.kingdom.bigquerytables.requisitionsTableRow +import org.wfanet.measurement.internal.kingdom.computationKey import org.wfanet.measurement.internal.kingdom.copy import org.wfanet.measurement.internal.kingdom.measurementKey import org.wfanet.measurement.internal.kingdom.streamMeasurementsRequest @@ -67,13 +75,16 @@ class OperationalMetricsExport( private val datasetId: String, private val latestMeasurementReadTableId: String, private val latestRequisitionReadTableId: String, + private val latestComputationReadTableId: String, private val measurementsTableId: String, private val requisitionsTableId: String, + private val computationParticipantStagesTableId: String, private val streamWriterFactory: StreamWriterFactory = StreamWriterFactoryImpl(), ) { suspend fun execute() { exportMeasurements() exportRequisitions() + exportComputationParticipants() } private suspend fun exportMeasurements() { @@ -434,6 +445,243 @@ class OperationalMetricsExport( } } + private suspend fun exportComputationParticipants() { + var computationsQueryResponseSize: Int + + val query = + """ + SELECT update_time, external_computation_id + FROM `$datasetId.$latestComputationReadTableId` + ORDER BY update_time DESC, external_computation_id DESC + LIMIT 1 + """ + .trimIndent() + + val queryJobConfiguration: QueryJobConfiguration = + QueryJobConfiguration.newBuilder(query).build() + + val results = bigQuery.query(queryJobConfiguration).iterateAll() + logger.info("Retrieved latest computation read info from BigQuery") + + val latestComputationReadFromPreviousJob: FieldValueList? 
= results.firstOrNull()
+
+    var streamComputationsRequest = streamMeasurementsRequest {
+      measurementView = Measurement.View.COMPUTATION_STATS
+      limit = BATCH_SIZE
+      filter =
+        StreamMeasurementsRequestKt.filter {
+          states += Measurement.State.SUCCEEDED
+          states += Measurement.State.FAILED
+          if (latestComputationReadFromPreviousJob != null) {
+            after =
+              StreamMeasurementsRequestKt.FilterKt.after {
+                updateTime =
+                  Timestamps.fromNanos(
+                    latestComputationReadFromPreviousJob.get("update_time").longValue
+                  )
+                computation = computationKey {
+                  externalComputationId =
+                    latestComputationReadFromPreviousJob.get("external_computation_id").longValue
+                }
+              }
+          }
+        }
+    }
+
+    DataWriter(
+        projectId = projectId,
+        datasetId = datasetId,
+        tableId = computationParticipantStagesTableId,
+        client = bigQueryWriteClient,
+        protoSchema =
+          ProtoSchemaConverter.convert(ComputationParticipantStagesTableRow.getDescriptor()),
+        streamWriterFactory = streamWriterFactory,
+      )
+      .use { computationParticipantStagesDataWriter ->
+        DataWriter(
+            projectId = projectId,
+            datasetId = datasetId,
+            tableId = latestComputationReadTableId,
+            client = bigQueryWriteClient,
+            protoSchema =
+              ProtoSchemaConverter.convert(LatestComputationReadTableRow.getDescriptor()),
+            streamWriterFactory = streamWriterFactory,
+          )
+          .use { latestComputationReadDataWriter ->
+            do {
+              computationsQueryResponseSize = 0
+
+              val computationParticipantStagesProtoRowsBuilder: ProtoRows.Builder =
+                ProtoRows.newBuilder()
+              var latestComputation: Measurement = Measurement.getDefaultInstance()
+
+              measurementsClient
+                .streamMeasurements(streamComputationsRequest)
+                .catch { e ->
+                  if (e is StatusException) {
+                    logger.warning("Failed to retrieve Computations")
+                    throw e
+                  }
+                }
+                .collect { measurement ->
+                  computationsQueryResponseSize++
+                  latestComputation = measurement
+
+                  if (measurement.externalComputationId != 0L) {
+                    val measurementType =
+                      getMeasurementType(
+                        measurement.details.measurementSpec,
+                        measurement.details.apiVersion,
+                      )
+
+                    val measurementConsumerId =
+                      externalIdToApiId(measurement.externalMeasurementConsumerId)
+                    val measurementId = externalIdToApiId(measurement.externalMeasurementId)
+                    val computationId = externalIdToApiId(measurement.externalComputationId)
+
+                    val baseComputationParticipantStagesTableRow =
+                      computationParticipantStagesTableRow {
+                        this.measurementConsumerId = measurementConsumerId
+                        this.measurementId = measurementId
+                        this.computationId = computationId
+                        this.measurementType = measurementType
+                      }
+
+                    // Map of ExternalDuchyId to log entries.
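+                    // For example, the two log entries that duchy "0" writes for stage 1
+                    // and stage 2 land in one list under key "0", so consecutive stages
+                    // can be paired off below.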
+                    val logEntriesMap: Map<String, MutableList<DuchyMeasurementLogEntry>> =
+                      buildMap {
+                        for (logEntry in measurement.logEntriesList) {
+                          val logEntries = getOrPut(logEntry.externalDuchyId) { mutableListOf() }
+                          logEntries.add(logEntry)
+                        }
+                      }
+
+                    for (computationParticipant in measurement.computationParticipantsList) {
+                      val sortedStageLogEntries =
+                        logEntriesMap[computationParticipant.externalDuchyId]?.sortedBy {
+                          it.details.stageAttempt.stage
+                        } ?: emptyList()
+
+                      if (sortedStageLogEntries.isEmpty()) {
+                        continue
+                      }
+
+                      sortedStageLogEntries.zipWithNext { logEntry, nextLogEntry ->
+                        if (logEntry.details.stageAttempt.stageName.isNotBlank()) {
+                          computationParticipantStagesProtoRowsBuilder.addSerializedRows(
+                            baseComputationParticipantStagesTableRow
+                              .copy {
+                                duchyId = computationParticipant.externalDuchyId
+                                result = ComputationParticipantStagesTableRow.Result.SUCCEEDED
+                                stageName = logEntry.details.stageAttempt.stageName
+                                stageStartTime = logEntry.details.stageAttempt.stageStartTime
+                                completionDurationSeconds =
+                                  Duration.between(
+                                      logEntry.details.stageAttempt.stageStartTime.toInstant(),
+                                      nextLogEntry.details.stageAttempt.stageStartTime.toInstant(),
+                                    )
+                                    .seconds
+                                completionDurationSecondsSquared =
+                                  completionDurationSeconds * completionDurationSeconds
+                              }
+                              .toByteString()
+                          )
+                        }
+                      }
+
+                      val logEntry = sortedStageLogEntries.last()
+                      if (logEntry.details.stageAttempt.stageName.isBlank()) {
+                        continue
+                      }
+
+                      if (measurement.state == Measurement.State.SUCCEEDED) {
+                        computationParticipantStagesProtoRowsBuilder.addSerializedRows(
+                          baseComputationParticipantStagesTableRow
+                            .copy {
+                              duchyId = computationParticipant.externalDuchyId
+                              result = ComputationParticipantStagesTableRow.Result.SUCCEEDED
+                              stageName = logEntry.details.stageAttempt.stageName
+                              stageStartTime = logEntry.details.stageAttempt.stageStartTime
+                              completionDurationSeconds =
+                                Duration.between(
+                                    logEntry.details.stageAttempt.stageStartTime.toInstant(),
+                                    measurement.updateTime.toInstant(),
+                                  )
+                                  .seconds
+                              completionDurationSecondsSquared =
+                                completionDurationSeconds * completionDurationSeconds
+                            }
+                            .toByteString()
+                        )
+                      } else if (measurement.state == Measurement.State.FAILED) {
+                        computationParticipantStagesProtoRowsBuilder.addSerializedRows(
+                          baseComputationParticipantStagesTableRow
+                            .copy {
+                              duchyId = computationParticipant.externalDuchyId
+                              result = ComputationParticipantStagesTableRow.Result.FAILED
+                              stageName = logEntry.details.stageAttempt.stageName
+                              stageStartTime = logEntry.details.stageAttempt.stageStartTime
+                              completionDurationSeconds =
+                                Duration.between(
+                                    logEntry.details.stageAttempt.stageStartTime.toInstant(),
+                                    measurement.updateTime.toInstant(),
+                                  )
+                                  .seconds
+                              completionDurationSecondsSquared =
+                                completionDurationSeconds * completionDurationSeconds
+                            }
+                            .toByteString()
+                        )
+                      }
+                    }
+                  }
+                }
+
+              logger.info("Computations read from the Kingdom Internal Server")
+
+              if (computationParticipantStagesProtoRowsBuilder.serializedRowsCount > 0) {
+                computationParticipantStagesDataWriter.appendRows(
+                  computationParticipantStagesProtoRowsBuilder.build()
+                )
+
+                logger.info("Computation Participant Stages Metrics written to BigQuery")
+                // Possible for there to be no stages because all measurements in response are
+                // direct.
+ } else if (computationsQueryResponseSize == 0) { + logger.info("No more Computations to process") + break + } + + val latestComputationReadTableRow = latestComputationReadTableRow { + updateTime = Timestamps.toNanos(latestComputation.updateTime) + externalComputationId = latestComputation.externalComputationId + } + + latestComputationReadDataWriter.appendRows( + ProtoRows.newBuilder() + .addSerializedRows(latestComputationReadTableRow.toByteString()) + .build() + ) + + streamComputationsRequest = + streamComputationsRequest.copy { + filter = + filter.copy { + after = + StreamMeasurementsRequestKt.FilterKt.after { + updateTime = latestComputation.updateTime + computation = computationKey { + externalComputationId = + latestComputationReadTableRow.externalComputationId + } + } + } + } + } while (computationsQueryResponseSize == BATCH_SIZE) + } + } + } + companion object { private val logger: Logger = Logger.getLogger(this::class.java.name) private const val BATCH_SIZE = 3000 @@ -515,10 +763,19 @@ class OperationalMetricsExport( break } } catch (e: AppendSerializationError) { + logger.warning("Logging serialization errors") for (value in e.rowIndexToErrorMessage.values) { logger.warning(value) } throw e + } catch (e: ExecutionException) { + if (e.cause is AppendSerializationError) { + logger.warning("Logging serialization errors") + for (value in (e.cause as AppendSerializationError).rowIndexToErrorMessage.values) { + logger.warning(value) + } + } + throw e } } } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/job/OperationalMetricsExportJob.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/job/OperationalMetricsExportJob.kt index 6086e9c9829..c5f92ba55d7 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/job/OperationalMetricsExportJob.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/job/OperationalMetricsExportJob.kt @@ -75,8 +75,11 @@ private fun run( val datasetId = operationalMetricsFlags.bigQueryDataSet val measurementsTableId = operationalMetricsFlags.measurementsTable val requisitionsTableId = operationalMetricsFlags.requisitionsTable + val computationParticipantStagesTableId = + operationalMetricsFlags.computationParticipantStagesTable val latestMeasurementReadTableId = operationalMetricsFlags.latestMeasurementReadTable val latestRequisitionReadTableId = operationalMetricsFlags.latestRequisitionReadTable + val latestComputationReadTableId = operationalMetricsFlags.latestComputationReadTable BigQueryWriteClient.create().use { bigQueryWriteClient -> val operationalMetricsExport = @@ -89,8 +92,10 @@ private fun run( datasetId = datasetId, latestMeasurementReadTableId = latestMeasurementReadTableId, latestRequisitionReadTableId = latestRequisitionReadTableId, + latestComputationReadTableId = latestComputationReadTableId, measurementsTableId = measurementsTableId, requisitionsTableId = requisitionsTableId, + computationParticipantStagesTableId = computationParticipantStagesTableId, ) operationalMetricsExport.execute() @@ -133,6 +138,14 @@ class OperationalMetricsFlags { lateinit var requisitionsTable: String private set + @CommandLine.Option( + names = ["--computation-participant-stages-table"], + description = ["Computation Participant Stages table ID"], + required = true, + ) + lateinit var computationParticipantStagesTable: String + private set + @CommandLine.Option( names = ["--latest-measurement-read-table"], description = ["Latest Measurement Read table ID"], @@ -148,4 +161,12 @@ class 
OperationalMetricsFlags { ) lateinit var latestRequisitionReadTable: String private set + + @CommandLine.Option( + names = ["--latest-computation-read-table"], + description = ["Latest Computation Read table ID"], + required = true, + ) + lateinit var latestComputationReadTable: String + private set } diff --git a/src/main/proto/wfa/measurement/internal/kingdom/bigquerytables/operational_metrics_dataset.proto b/src/main/proto/wfa/measurement/internal/kingdom/bigquerytables/operational_metrics_dataset.proto index 642aab49efc..22381e412e5 100644 --- a/src/main/proto/wfa/measurement/internal/kingdom/bigquerytables/operational_metrics_dataset.proto +++ b/src/main/proto/wfa/measurement/internal/kingdom/bigquerytables/operational_metrics_dataset.proto @@ -74,6 +74,30 @@ message RequisitionsTableRow { int64 completion_duration_seconds_squared = 11; } +message ComputationParticipantStagesTableRow { + string measurement_consumer_id = 1; + string measurement_id = 2; + string computation_id = 3; + string duchy_id = 4; + + MeasurementType measurement_type = 5; + + enum Result { + RESULT_UNSPECIFIED = 0; + SUCCEEDED = 1; + FAILED = 2; + } + Result result = 6; + + string stage_name = 7; + + google.protobuf.Timestamp stage_start_time = 8; + + // Difference between start of this stage and start of next stage. + int64 completion_duration_seconds = 9; + int64 completion_duration_seconds_squared = 10; +} + message LatestMeasurementReadTableRow { // Since BigQuery timestamp is microsecond precision, this is in nanoseconds // and is stored as an integer. @@ -89,3 +113,10 @@ message LatestRequisitionReadTableRow { int64 external_data_provider_id = 2; int64 external_requisition_id = 3; } + +message LatestComputationReadTableRow { + // Since BigQuery timestamp is microsecond precision, this is in nanoseconds + // and is stored as an integer. 
+  int64 update_time = 1;
+  int64 external_computation_id = 2;
+}
diff --git a/src/main/terraform/gcloud/modules/kingdom/main.tf b/src/main/terraform/gcloud/modules/kingdom/main.tf
index 830c679872b..1e4f1e3d2d0 100644
--- a/src/main/terraform/gcloud/modules/kingdom/main.tf
+++ b/src/main/terraform/gcloud/modules/kingdom/main.tf
@@ -202,6 +202,79 @@ EOF
 }
 
+resource "google_bigquery_table" "computation_participant_stages" {
+  dataset_id = google_bigquery_dataset.operational_metrics.dataset_id
+  table_id   = "computation_participant_stages"
+
+  deletion_protection = true
+
+  time_partitioning {
+    field = "stage_start_time"
+    type  = "MONTH"
+  }
+
+  schema = <<EOF

          MEASUREMENTS_TABLE_ID -> measurementsStreamWriterMock
          REQUISITIONS_TABLE_ID -> requisitionsStreamWriterMock
+         COMPUTATION_PARTICIPANT_STAGES_TABLE_ID -> computationParticipantStagesStreamWriterMock
          LATEST_MEASUREMENT_READ_TABLE_ID -> latestMeasurementReadStreamWriterMock
          LATEST_REQUISITION_READ_TABLE_ID -> latestRequisitionReadStreamWriterMock
+         LATEST_COMPUTATION_READ_TABLE_ID -> latestComputationReadStreamWriterMock
          else -> mock {}
        }
      }
@@ -204,6 +235,8 @@ class OperationalMetricsExportTest {
         measurementsTableId = MEASUREMENTS_TABLE_ID,
         latestRequisitionReadTableId = LATEST_REQUISITION_READ_TABLE_ID,
         requisitionsTableId = REQUISITIONS_TABLE_ID,
+        latestComputationReadTableId = LATEST_COMPUTATION_READ_TABLE_ID,
+        computationParticipantStagesTableId = COMPUTATION_PARTICIPANT_STAGES_TABLE_ID,
         streamWriterFactory = streamWriterFactoryTestImpl,
       )
 
@@ -315,6 +348,93 @@ class OperationalMetricsExportTest {
       )
     }
 
+    with(argumentCaptor<ProtoRows>()) {
+      verify(computationParticipantStagesStreamWriterMock).append(capture())
+
+      val protoRows: ProtoRows = allValues.first()
+      assertThat(protoRows.serializedRowsList).hasSize(4)
+
+      val stageOneTableRow =
+        ComputationParticipantStagesTableRow.parseFrom(protoRows.serializedRowsList[0])
+
+      assertThat(stageOneTableRow)
+        .isEqualTo(
+          computationParticipantStagesTableRow {
+            measurementConsumerId =
+              externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementConsumerId)
+            measurementId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementId)
+            computationId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalComputationId)
+            duchyId = "0"
+            measurementType = MeasurementType.REACH_AND_FREQUENCY
+            result = ComputationParticipantStagesTableRow.Result.SUCCEEDED
+            stageName = STAGE_ONE
+            stageStartTime = timestamp { seconds = 100 }
+            completionDurationSeconds = 200
+            completionDurationSecondsSquared = completionDurationSeconds * completionDurationSeconds
+          }
+        )
+
+      val stageTwoTableRow =
+        ComputationParticipantStagesTableRow.parseFrom(protoRows.serializedRowsList[1])
+
+      assertThat(stageTwoTableRow)
+        .isEqualTo(
+          computationParticipantStagesTableRow {
+            measurementConsumerId =
+              externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementConsumerId)
+            measurementId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementId)
+            computationId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalComputationId)
+            duchyId = "0"
+            measurementType = MeasurementType.REACH_AND_FREQUENCY
+            result = ComputationParticipantStagesTableRow.Result.SUCCEEDED
+            stageName = STAGE_TWO
+            stageStartTime = timestamp { seconds = 300 }
+            completionDurationSeconds = 300
+            completionDurationSecondsSquared = completionDurationSeconds * completionDurationSeconds
+          }
+        )
+
+      val stageOneTableRow2 =
+        ComputationParticipantStagesTableRow.parseFrom(protoRows.serializedRowsList[2])
+
+      assertThat(stageOneTableRow2)
+        .isEqualTo(
+          computationParticipantStagesTableRow {
+            measurementConsumerId =
+              externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementConsumerId)
+            measurementId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementId)
+            computationId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalComputationId)
+            duchyId = "1"
+            measurementType = MeasurementType.REACH_AND_FREQUENCY
+            result = ComputationParticipantStagesTableRow.Result.SUCCEEDED
+            stageName = STAGE_ONE
+            stageStartTime = timestamp { seconds = 100 }
+            completionDurationSeconds = 200
+            completionDurationSecondsSquared = completionDurationSeconds * completionDurationSeconds
+          }
+        )
+
+      val stageTwoTableRow2 =
+        ComputationParticipantStagesTableRow.parseFrom(protoRows.serializedRowsList[3])
+
+      assertThat(stageTwoTableRow2)
+        .isEqualTo(
+          computationParticipantStagesTableRow {
+            measurementConsumerId =
+              externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementConsumerId)
+            measurementId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementId)
+            computationId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalComputationId)
+            duchyId = "1"
+            measurementType = MeasurementType.REACH_AND_FREQUENCY
+            result = ComputationParticipantStagesTableRow.Result.SUCCEEDED
+            stageName = STAGE_TWO
+            stageStartTime = timestamp { seconds = 300 }
+            completionDurationSeconds = 300
+            completionDurationSecondsSquared = completionDurationSeconds * completionDurationSeconds
+          }
+        )
+    }
 
     with(argumentCaptor<ProtoRows>()) {
       verify(latestMeasurementReadStreamWriterMock).append(capture())
@@ -350,6 +470,143 @@ class OperationalMetricsExportTest {
         }
       )
     }
+
+    with(argumentCaptor<ProtoRows>()) {
+      verify(latestComputationReadStreamWriterMock).append(capture())
+
+      val protoRows: ProtoRows = allValues.first()
+      assertThat(protoRows.serializedRowsList).hasSize(1)
+
+      val latestComputationReadTableRow =
+        LatestComputationReadTableRow.parseFrom(protoRows.serializedRowsList.first())
+      assertThat(latestComputationReadTableRow)
+        .isEqualTo(
+          latestComputationReadTableRow {
+            updateTime = Timestamps.toNanos(COMPUTATION_MEASUREMENT.updateTime)
+            externalComputationId = COMPUTATION_MEASUREMENT.externalComputationId
+          }
+        )
+    }
+  }
+
+  @Test
+  fun `job successfully creates proto for stages when measurement failed`() = runBlocking {
+    val tableResultMock: TableResult = mock { tableResult ->
+      whenever(tableResult.iterateAll()).thenReturn(emptyList())
+    }
+
+    val bigQueryMock: BigQuery = mock { bigQuery ->
+      whenever(bigQuery.query(any())).thenReturn(tableResultMock)
+    }
+
+    whenever(measurementsMock.streamMeasurements(any()))
+      .thenReturn(flowOf(COMPUTATION_MEASUREMENT.copy { state = Measurement.State.FAILED }))
+
+    val operationalMetricsExport =
+      OperationalMetricsExport(
+        measurementsClient = measurementsClient,
+        requisitionsClient = requisitionsClient,
+        bigQuery = bigQueryMock,
+        bigQueryWriteClient = bigQueryWriteClientMock,
+        projectId = PROJECT_ID,
+        datasetId = DATASET_ID,
+        latestMeasurementReadTableId = LATEST_MEASUREMENT_READ_TABLE_ID,
+        measurementsTableId = MEASUREMENTS_TABLE_ID,
+        latestRequisitionReadTableId = LATEST_REQUISITION_READ_TABLE_ID,
+        requisitionsTableId = REQUISITIONS_TABLE_ID,
+        latestComputationReadTableId = LATEST_COMPUTATION_READ_TABLE_ID,
+        computationParticipantStagesTableId = COMPUTATION_PARTICIPANT_STAGES_TABLE_ID,
+        streamWriterFactory = streamWriterFactoryTestImpl,
+      )
+
+    operationalMetricsExport.execute()
+
+    with(argumentCaptor<ProtoRows>()) {
+      verify(computationParticipantStagesStreamWriterMock).append(capture())
+
+      val protoRows: ProtoRows =
allValues.first() + assertThat(protoRows.serializedRowsList).hasSize(4) + + val stageOneTableRow = + ComputationParticipantStagesTableRow.parseFrom(protoRows.serializedRowsList[0]) + + assertThat(stageOneTableRow) + .isEqualTo( + computationParticipantStagesTableRow { + measurementConsumerId = + externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementConsumerId) + measurementId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementId) + computationId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalComputationId) + duchyId = "0" + measurementType = MeasurementType.REACH_AND_FREQUENCY + result = ComputationParticipantStagesTableRow.Result.SUCCEEDED + stageName = STAGE_ONE + stageStartTime = timestamp { seconds = 100 } + completionDurationSeconds = 200 + completionDurationSecondsSquared = completionDurationSeconds * completionDurationSeconds + } + ) + + val stageTwoTableRow = + ComputationParticipantStagesTableRow.parseFrom(protoRows.serializedRowsList[1]) + + assertThat(stageTwoTableRow) + .isEqualTo( + computationParticipantStagesTableRow { + measurementConsumerId = + externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementConsumerId) + measurementId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementId) + computationId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalComputationId) + duchyId = "0" + measurementType = MeasurementType.REACH_AND_FREQUENCY + result = ComputationParticipantStagesTableRow.Result.FAILED + stageName = STAGE_TWO + stageStartTime = timestamp { seconds = 300 } + completionDurationSeconds = 300 + completionDurationSecondsSquared = completionDurationSeconds * completionDurationSeconds + } + ) + + val stageOneTableRow2 = + ComputationParticipantStagesTableRow.parseFrom(protoRows.serializedRowsList[2]) + + assertThat(stageOneTableRow2) + .isEqualTo( + computationParticipantStagesTableRow { + measurementConsumerId = + externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementConsumerId) + measurementId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementId) + computationId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalComputationId) + duchyId = "1" + measurementType = MeasurementType.REACH_AND_FREQUENCY + result = ComputationParticipantStagesTableRow.Result.SUCCEEDED + stageName = STAGE_ONE + stageStartTime = timestamp { seconds = 100 } + completionDurationSeconds = 200 + completionDurationSecondsSquared = completionDurationSeconds * completionDurationSeconds + } + ) + + val stageTwoTableRow2 = + ComputationParticipantStagesTableRow.parseFrom(protoRows.serializedRowsList[3]) + + assertThat(stageTwoTableRow2) + .isEqualTo( + computationParticipantStagesTableRow { + measurementConsumerId = + externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementConsumerId) + measurementId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalMeasurementId) + computationId = externalIdToApiId(COMPUTATION_MEASUREMENT.externalComputationId) + duchyId = "1" + measurementType = MeasurementType.REACH_AND_FREQUENCY + result = ComputationParticipantStagesTableRow.Result.FAILED + stageName = STAGE_TWO + stageStartTime = timestamp { seconds = 300 } + completionDurationSeconds = 300 + completionDurationSecondsSquared = completionDurationSeconds * completionDurationSeconds + } + ) + } } @Test @@ -406,13 +663,15 @@ class OperationalMetricsExportTest { measurementsTableId = MEASUREMENTS_TABLE_ID, latestRequisitionReadTableId = LATEST_REQUISITION_READ_TABLE_ID, requisitionsTableId = REQUISITIONS_TABLE_ID, + latestComputationReadTableId = 
LATEST_COMPUTATION_READ_TABLE_ID,
+        computationParticipantStagesTableId = COMPUTATION_PARTICIPANT_STAGES_TABLE_ID,
         streamWriterFactory = streamWriterFactoryTestImpl,
       )
 
     operationalMetricsExport.execute()
 
     with(argumentCaptor<StreamMeasurementsRequest>()) {
-      verify(measurementsMock).streamMeasurements(capture())
+      verify(measurementsMock, times(2)).streamMeasurements(capture())
       val streamMeasurementsRequest = allValues.first()
 
       assertThat(streamMeasurementsRequest)
@@ -470,6 +729,7 @@ class OperationalMetricsExportTest {
             )
           )
         )
+        .thenReturn(emptyList())
     }
 
     whenever(requisitionsMock.streamRequisitions(any())).thenReturn(flowOf(REQUISITION_2))
@@ -490,6 +750,8 @@ class OperationalMetricsExportTest {
         measurementsTableId = MEASUREMENTS_TABLE_ID,
         latestRequisitionReadTableId = LATEST_REQUISITION_READ_TABLE_ID,
         requisitionsTableId = REQUISITIONS_TABLE_ID,
+        latestComputationReadTableId = LATEST_COMPUTATION_READ_TABLE_ID,
+        computationParticipantStagesTableId = COMPUTATION_PARTICIPANT_STAGES_TABLE_ID,
         streamWriterFactory = streamWriterFactoryTestImpl,
       )
 
@@ -520,6 +782,143 @@ class OperationalMetricsExportTest {
     }
   }
 
+  @Test
+  fun `job can process the next batch of computations without starting at the beginning`() =
+    runBlocking {
+      val computationMeasurement = COMPUTATION_MEASUREMENT
+
+      val updateTimeFieldValue: FieldValue =
+        FieldValue.of(
+          FieldValue.Attribute.PRIMITIVE,
+          "${Timestamps.toNanos(computationMeasurement.updateTime)}",
+        )
+      val externalComputationIdFieldValue: FieldValue =
+        FieldValue.of(
+          FieldValue.Attribute.PRIMITIVE,
+          "${computationMeasurement.externalComputationId}",
+        )
+
+      val tableResultMock: TableResult = mock { tableResult ->
+        whenever(tableResult.iterateAll())
+          .thenReturn(emptyList())
+          .thenReturn(emptyList())
+          .thenReturn(
+            listOf(
+              FieldValueList.of(
+                mutableListOf(updateTimeFieldValue, externalComputationIdFieldValue),
+                LATEST_COMPUTATION_FIELD_LIST,
+              )
+            )
+          )
+          .thenReturn(emptyList())
+      }
+
+      whenever(measurementsMock.streamMeasurements(any()))
+        .thenReturn(flowOf(COMPUTATION_MEASUREMENT))
+
+      val bigQueryMock: BigQuery = mock { bigQuery ->
+        whenever(bigQuery.query(any())).thenReturn(tableResultMock)
+      }
+
+      val operationalMetricsExport =
+        OperationalMetricsExport(
+          measurementsClient = measurementsClient,
+          requisitionsClient = requisitionsClient,
+          bigQuery = bigQueryMock,
+          bigQueryWriteClient = bigQueryWriteClientMock,
+          projectId = PROJECT_ID,
+          datasetId = DATASET_ID,
+          latestMeasurementReadTableId = LATEST_MEASUREMENT_READ_TABLE_ID,
+          measurementsTableId = MEASUREMENTS_TABLE_ID,
+          latestRequisitionReadTableId = LATEST_REQUISITION_READ_TABLE_ID,
+          requisitionsTableId = REQUISITIONS_TABLE_ID,
+          latestComputationReadTableId = LATEST_COMPUTATION_READ_TABLE_ID,
+          computationParticipantStagesTableId = COMPUTATION_PARTICIPANT_STAGES_TABLE_ID,
+          streamWriterFactory = streamWriterFactoryTestImpl,
+        )
+
+      operationalMetricsExport.execute()
+
+      with(argumentCaptor<StreamMeasurementsRequest>()) {
+        verify(measurementsMock, times(2)).streamMeasurements(capture())
+        val streamMeasurementsRequest = allValues.last()
+
+        assertThat(streamMeasurementsRequest)
+          .ignoringRepeatedFieldOrder()
+          .isEqualTo(
+            streamMeasurementsRequest {
+              measurementView = Measurement.View.COMPUTATION_STATS
+              filter =
+                StreamMeasurementsRequestKt.filter {
+                  states += Measurement.State.SUCCEEDED
+                  states += Measurement.State.FAILED
+                  after =
+                    StreamMeasurementsRequestKt.FilterKt.after {
+                      updateTime = computationMeasurement.updateTime
+                      computation = computationKey {
+                        externalComputationId =
computationMeasurement.externalComputationId
+                      }
+                    }
+                }
+              limit = 3000
+            }
+          )
+      }
+    }
+
+  @Test
+  fun `job skips direct measurements when attempting to export stages`() = runBlocking {
+    whenever(measurementsMock.streamMeasurements(any()))
+      .thenReturn(flowOf(DIRECT_MEASUREMENT, COMPUTATION_MEASUREMENT))
+      .thenReturn(
+        buildList {
+            for (i in 1..3000) {
+              add(DIRECT_MEASUREMENT)
+            }
+          }
+          .asFlow()
+      )
+      .thenReturn(flowOf(DIRECT_MEASUREMENT, COMPUTATION_MEASUREMENT))
+
+    val tableResultMock: TableResult = mock { tableResult ->
+      whenever(tableResult.iterateAll()).thenReturn(emptyList())
+    }
+
+    val bigQueryMock: BigQuery = mock { bigQuery ->
+      whenever(bigQuery.query(any())).thenReturn(tableResultMock)
+    }
+
+    val operationalMetricsExport =
+      OperationalMetricsExport(
+        measurementsClient = measurementsClient,
+        requisitionsClient = requisitionsClient,
+        bigQuery = bigQueryMock,
+        bigQueryWriteClient = bigQueryWriteClientMock,
+        projectId = PROJECT_ID,
+        datasetId = DATASET_ID,
+        latestMeasurementReadTableId = LATEST_MEASUREMENT_READ_TABLE_ID,
+        measurementsTableId = MEASUREMENTS_TABLE_ID,
+        latestRequisitionReadTableId = LATEST_REQUISITION_READ_TABLE_ID,
+        requisitionsTableId = REQUISITIONS_TABLE_ID,
+        latestComputationReadTableId = LATEST_COMPUTATION_READ_TABLE_ID,
+        computationParticipantStagesTableId = COMPUTATION_PARTICIPANT_STAGES_TABLE_ID,
+        streamWriterFactory = streamWriterFactoryTestImpl,
+      )
+
+    operationalMetricsExport.execute()
+
+    with(argumentCaptor<StreamMeasurementsRequest>()) {
+      verify(measurementsMock, times(3)).streamMeasurements(capture())
+    }
+
+    with(argumentCaptor<ProtoRows>()) {
+      verify(computationParticipantStagesStreamWriterMock).append(capture())
+
+      val protoRows: ProtoRows = allValues.first()
+      assertThat(protoRows.serializedRowsList).hasSize(4)
+    }
+  }
 
   @Test
   fun `job recreates streamwriter if it is closed`() = runBlocking {
     val tableResultMock: TableResult = mock { tableResult ->
       whenever(tableResult.iterateAll()).thenReturn(emptyList())
@@ -545,6 +944,8 @@ class OperationalMetricsExportTest {
         measurementsTableId = MEASUREMENTS_TABLE_ID,
         latestRequisitionReadTableId = LATEST_REQUISITION_READ_TABLE_ID,
         requisitionsTableId = REQUISITIONS_TABLE_ID,
+        latestComputationReadTableId = LATEST_COMPUTATION_READ_TABLE_ID,
+        computationParticipantStagesTableId = COMPUTATION_PARTICIPANT_STAGES_TABLE_ID,
         streamWriterFactory = streamWriterFactoryTestImpl,
       )
 
@@ -583,6 +984,8 @@ class OperationalMetricsExportTest {
         measurementsTableId = MEASUREMENTS_TABLE_ID,
         latestRequisitionReadTableId = LATEST_REQUISITION_READ_TABLE_ID,
         requisitionsTableId = REQUISITIONS_TABLE_ID,
+        latestComputationReadTableId = LATEST_COMPUTATION_READ_TABLE_ID,
+        computationParticipantStagesTableId = COMPUTATION_PARTICIPANT_STAGES_TABLE_ID,
         streamWriterFactory = streamWriterFactoryTestImpl,
       )
 
@@ -615,6 +1018,8 @@ class OperationalMetricsExportTest {
         measurementsTableId = MEASUREMENTS_TABLE_ID,
         latestRequisitionReadTableId = LATEST_REQUISITION_READ_TABLE_ID,
         requisitionsTableId = REQUISITIONS_TABLE_ID,
+        latestComputationReadTableId = LATEST_COMPUTATION_READ_TABLE_ID,
+        computationParticipantStagesTableId = COMPUTATION_PARTICIPANT_STAGES_TABLE_ID,
         streamWriterFactory = streamWriterFactoryTestImpl,
       )
 
@@ -648,6 +1053,8 @@ class OperationalMetricsExportTest {
         measurementsTableId = MEASUREMENTS_TABLE_ID,
         latestRequisitionReadTableId = LATEST_REQUISITION_READ_TABLE_ID,
         requisitionsTableId = REQUISITIONS_TABLE_ID,
+        latestComputationReadTableId = LATEST_COMPUTATION_READ_TABLE_ID,
+        computationParticipantStagesTableId = COMPUTATION_PARTICIPANT_STAGES_TABLE_ID,
         streamWriterFactory =
streamWriterFactoryTestImpl, ) @@ -684,6 +1091,15 @@ class OperationalMetricsExportTest { ) ) + whenever(computationParticipantStagesStreamWriterMock.append(any())) + .thenReturn( + ApiFutures.immediateFuture( + AppendRowsResponse.newBuilder() + .setError(Status.newBuilder().setCode(Code.INTERNAL_VALUE).build()) + .build() + ) + ) + val operationalMetricsExport = OperationalMetricsExport( measurementsClient = measurementsClient, @@ -696,6 +1112,8 @@ class OperationalMetricsExportTest { measurementsTableId = MEASUREMENTS_TABLE_ID, latestRequisitionReadTableId = LATEST_REQUISITION_READ_TABLE_ID, requisitionsTableId = REQUISITIONS_TABLE_ID, + latestComputationReadTableId = LATEST_COMPUTATION_READ_TABLE_ID, + computationParticipantStagesTableId = COMPUTATION_PARTICIPANT_STAGES_TABLE_ID, streamWriterFactory = streamWriterFactoryTestImpl, ) @@ -735,6 +1153,8 @@ class OperationalMetricsExportTest { measurementsTableId = MEASUREMENTS_TABLE_ID, latestRequisitionReadTableId = LATEST_REQUISITION_READ_TABLE_ID, requisitionsTableId = REQUISITIONS_TABLE_ID, + latestComputationReadTableId = LATEST_COMPUTATION_READ_TABLE_ID, + computationParticipantStagesTableId = COMPUTATION_PARTICIPANT_STAGES_TABLE_ID, streamWriterFactory = streamWriterFactoryTestImpl, ) @@ -768,6 +1188,8 @@ class OperationalMetricsExportTest { measurementsTableId = MEASUREMENTS_TABLE_ID, latestRequisitionReadTableId = LATEST_REQUISITION_READ_TABLE_ID, requisitionsTableId = REQUISITIONS_TABLE_ID, + latestComputationReadTableId = LATEST_COMPUTATION_READ_TABLE_ID, + computationParticipantStagesTableId = COMPUTATION_PARTICIPANT_STAGES_TABLE_ID, streamWriterFactory = streamWriterFactoryTestImpl, ) @@ -780,8 +1202,10 @@ class OperationalMetricsExportTest { private const val DATASET_ID = "dataset" private const val MEASUREMENTS_TABLE_ID = "measurements" private const val REQUISITIONS_TABLE_ID = "requisitions" + private const val COMPUTATION_PARTICIPANT_STAGES_TABLE_ID = "computation_participant_stages" private const val LATEST_MEASUREMENT_READ_TABLE_ID = "latest_measurement_read" private const val LATEST_REQUISITION_READ_TABLE_ID = "latest_requisition_read" + private const val LATEST_COMPUTATION_READ_TABLE_ID = "latest_computation_read" private val API_VERSION = Version.V2_ALPHA.toString() @@ -817,6 +1241,9 @@ class OperationalMetricsExportTest { } } + private val STAGE_ONE = "stage_one" + private val STAGE_TWO = "stage_two" + private val COMPUTATION_MEASUREMENT = MEASUREMENT.copy { externalMeasurementId = 123 @@ -824,16 +1251,92 @@ class OperationalMetricsExportTest { providedMeasurementId = "computation-participant" state = Measurement.State.SUCCEEDED createTime = timestamp { seconds = 200 } - updateTime = timestamp { - seconds = 300 - nanos = 100 - } + updateTime = timestamp { seconds = 600 } details = details.copy { protocolConfig = protocolConfig { liquidLegionsV2 = ProtocolConfig.LiquidLegionsV2.getDefaultInstance() } } + + requisitions += requisition { + externalDataProviderId = 432 + externalRequisitionId = 433 + state = Requisition.State.FULFILLED + updateTime = timestamp { + seconds = 500 + nanos = 100 + } + } + + computationParticipants += computationParticipant { + externalDuchyId = "0" + state = ComputationParticipant.State.READY + updateTime = timestamp { seconds = 300 } + failureLogEntry = duchyMeasurementLogEntry { + logEntry = measurementLogEntry { + createTime = timestamp { seconds = 350 } + details = measurementLogEntryDetails { error = measurementLogEntryError {} } + } + details = duchyMeasurementLogEntryDetails { + 
stageAttempt = duchyMeasurementLogEntryStageAttempt { stageName = STAGE_TWO } + } + } + } + computationParticipants += computationParticipant { + externalDuchyId = "1" + state = ComputationParticipant.State.READY + updateTime = timestamp { seconds = 400 } + } + computationParticipants += computationParticipant { + externalDuchyId = "2" + state = ComputationParticipant.State.READY + updateTime = timestamp { seconds = 500 } + } + logEntries += duchyMeasurementLogEntry { + externalDuchyId = "0" + logEntry = measurementLogEntry { createTime = timestamp { seconds = 300 } } + details = duchyMeasurementLogEntryDetails { + stageAttempt = duchyMeasurementLogEntryStageAttempt { + stage = 1 + stageStartTime = timestamp { seconds = 100 } + stageName = STAGE_ONE + } + } + } + logEntries += duchyMeasurementLogEntry { + externalDuchyId = "0" + logEntry = measurementLogEntry { createTime = timestamp { seconds = 300 } } + details = duchyMeasurementLogEntryDetails { + stageAttempt = duchyMeasurementLogEntryStageAttempt { + stage = 2 + stageStartTime = timestamp { seconds = 300 } + stageName = STAGE_TWO + } + } + } + logEntries += duchyMeasurementLogEntry { + externalDuchyId = "1" + logEntry = measurementLogEntry { createTime = timestamp { seconds = 300 } } + details = duchyMeasurementLogEntryDetails { + stageAttempt = duchyMeasurementLogEntryStageAttempt { + stage = 1 + stageStartTime = timestamp { seconds = 100 } + stageName = STAGE_ONE + } + } + } + logEntries += duchyMeasurementLogEntry { + externalDuchyId = "1" + logEntry = measurementLogEntry { createTime = timestamp { seconds = 300 } } + details = duchyMeasurementLogEntryDetails { + stageAttempt = duchyMeasurementLogEntryStageAttempt { + stage = 2 + stageStartTime = timestamp { seconds = 300 } + stageName = STAGE_TWO + } + } + } } private val DIRECT_MEASUREMENT = @@ -914,5 +1417,13 @@ class OperationalMetricsExportTest { Field.of("external_requisition_id", LegacySQLTypeName.INTEGER), ) ) + + private val LATEST_COMPUTATION_FIELD_LIST: FieldList = + FieldList.of( + listOf( + Field.of("update_time", LegacySQLTypeName.INTEGER), + Field.of("external_computation_id", LegacySQLTypeName.INTEGER), + ) + ) } }
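
The exporter derives each stage row's duration purely from the ordering of stage start times: per duchy, log entries are sorted by stage number, zipWithNext pairs each stage with its successor, and the final stage is closed out against the measurement's update time. A minimal, self-contained sketch of that pairing; StageAttempt and stageDurationsSeconds here are hypothetical stand-ins for the stage_attempt fields of DuchyMeasurementLogEntry, not the real protos:

    import java.time.Duration
    import java.time.Instant

    // Hypothetical stand-in for the stage_attempt fields the exporter reads.
    data class StageAttempt(val stage: Int, val stageName: String, val stageStartTime: Instant)

    // Each stage's duration is the gap between its own start and the next stage's
    // start; the last stage has no successor, so it is measured against the
    // measurement's final update time (mirroring the SUCCEEDED/FAILED branches).
    fun stageDurationsSeconds(
      entries: List<StageAttempt>,
      measurementUpdateTime: Instant,
    ): Map<String, Long> = buildMap {
      val sorted = entries.sortedBy { it.stage }
      sorted.zipWithNext { current, next ->
        put(current.stageName, Duration.between(current.stageStartTime, next.stageStartTime).seconds)
      }
      sorted.lastOrNull()?.let {
        put(it.stageName, Duration.between(it.stageStartTime, measurementUpdateTime).seconds)
      }
    }

    fun main() {
      val stages =
        listOf(
          StageAttempt(1, "stage_one", Instant.ofEpochSecond(100)),
          StageAttempt(2, "stage_two", Instant.ofEpochSecond(300)),
        )
      // With updateTime = 600s this prints {stage_one=200, stage_two=300}, matching
      // the expected rows in the tests above.
      println(stageDurationsSeconds(stages, Instant.ofEpochSecond(600)))
    }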
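Resumability follows the same cursor pattern as the measurements and requisitions exports: after each batch the job appends a row to latest_computation_read holding the last computation's update time (as int64 nanoseconds, since a BigQuery TIMESTAMP column is only microsecond-precision) and its external ID, and the next run reads back just the most recent row to rebuild the `after` filter. A rough sketch of the round trip; ComputationCursor and resumePoint are hypothetical helpers, while Timestamps.fromNanos is the real protobuf utility the job uses:

    import com.google.protobuf.Timestamp
    import com.google.protobuf.util.Timestamps

    // Hypothetical mirror of LatestComputationReadTableRow: nanoseconds survive the
    // round trip because they are stored as a plain integer, not a TIMESTAMP column.
    data class ComputationCursor(val updateTimeNanos: Long, val externalComputationId: Long)

    // Rebuilds the inputs to the streamMeasurements `after` filter: resume strictly
    // after (updateTime, externalComputationId), the same composite ordering the
    // bootstrap query's ORDER BY uses to pick the latest row.
    fun ComputationCursor.resumePoint(): Pair<Timestamp, Long> =
      Timestamps.fromNanos(updateTimeNanos) to externalComputationId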
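Each stage row also carries completion_duration_seconds_squared alongside the duration itself. Presumably this is so downstream queries can aggregate with plain SUM()s and still recover both the mean and the variance of stage durations (Var[X] = E[X squared] - E[X] squared) without scanning raw rows; that intended use is an assumption, but the arithmetic the squared column enables looks like this:

    // Given only n, SUM(x) and SUM(x * x), the aggregates the squared column makes
    // cheap in SQL, recover the mean and (population) variance of stage durations.
    fun meanAndVariance(count: Long, sum: Double, sumOfSquares: Double): Pair<Double, Double> {
      val mean = sum / count
      return mean to (sumOfSquares / count - mean * mean)
    }

    // e.g. durations 200s and 300s: meanAndVariance(2, 500.0, 130000.0) == (250.0, 2500.0)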
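Finally, the new catch block in the exporter exists because an AppendSerializationError can surface either directly or wrapped in an ExecutionException from the append future; both paths log the per-row messages before rethrowing. A generic version of that unwrap, where findCause is a hypothetical helper rather than part of the codebase:

    // Walks the cause chain so a wrapped cause (e.g. inside an ExecutionException)
    // is handled the same way as a bare one.
    inline fun <reified T : Throwable> Throwable.findCause(): T? =
      generateSequence(this) { it.cause }.filterIsInstance<T>().firstOrNull()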