diff --git a/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/genomics/enriched/Variants.scala b/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/genomics/enriched/Variants.scala
index 1a94fec2..f193f173 100644
--- a/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/genomics/enriched/Variants.scala
+++ b/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/genomics/enriched/Variants.scala
@@ -9,6 +9,7 @@ import bio.ferlab.datalake.spark3.implicits.DatasetConfImplicits._
 import bio.ferlab.datalake.spark3.implicits.GenomicImplicits._
 import bio.ferlab.datalake.spark3.implicits.GenomicImplicits.columns.{locus, locusColumnNames}
 import bio.ferlab.datalake.spark3.implicits.SparkUtils.firstAs
+import org.apache.spark.SparkContext
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.{Column, DataFrame, SparkSession}
@@ -26,9 +27,11 @@ import java.time.LocalDateTime
  * @param extraAggregations extra aggregations to be computed when grouping occurrences by locus. Will be added to the root of the data
  * @param rc the etl context
  */
-case class Variants(rc: RuntimeETLContext, participantId: Column = col("participant_id"), affectedStatus: Column = col("affected_status"), filterSnv: Option[Column] = Some(col("has_alt")), snvDatasetId: String, splits: Seq[OccurrenceSplit], extraAggregations: Seq[Column] = Nil) extends SimpleSingleETL(rc) {
-
+case class Variants(rc: RuntimeETLContext, participantId: Column = col("participant_id"), affectedStatus: Column = col("affected_status"), filterSnv: Option[Column] = Some(col("has_alt")), snvDatasetId: String, splits: Seq[OccurrenceSplit], extraAggregations: Seq[Column] = Nil, checkpoint: Boolean = false) extends SimpleSingleETL(rc) {
   override val mainDestination: DatasetConf = conf.getDataset("enriched_variants")
+  if (checkpoint) {
+    spark.sparkContext.setCheckpointDir(s"${mainDestination.rootPath}/checkpoints")
+  }
   protected val thousand_genomes: DatasetConf = conf.getDataset("normalized_1000_genomes")
   protected val topmed_bravo: DatasetConf = conf.getDataset("normalized_topmed_bravo")
   protected val gnomad_genomes_v2: DatasetConf = conf.getDataset("normalized_gnomad_genomes_v2_1_1")
@@ -74,7 +77,9 @@ case class Variants(rc: RuntimeETLContext, participantId: Column = col("particip
       .withColumn("dna_change", concat_ws(">", col("reference"), col("alternate")))
       .withColumn("assembly_version", lit("GRCh38"))

-    variants
+    val variantsCheckpoint = if (checkpoint) variants.checkpoint() else variants
+
+    variantsCheckpoint
       .withFrequencies(participantId, affectedStatus, snv, splits)
       .withPopulations(data(thousand_genomes.id), data(topmed_bravo.id), data(gnomad_genomes_v2.id), data(gnomad_exomes_v2.id), data(gnomad_genomes_v3.id))
       .withDbSNP(data(dbsnp.id))
@@ -235,14 +240,14 @@ object Variants {
     val w = Window.partitionBy(locus: _*).orderBy($"sample_mutated".desc)

     val cmc = cosmic.selectLocus(
-      $"mutation_url",
-      $"shared_aa",
-      $"genomic_mutation_id" as "cosmic_id",
-      $"cosmic_sample_mutated" as "sample_mutated",
-      $"cosmic_sample_tested" as "sample_tested",
-      $"mutation_significance_tier" as "tier",
-      $"cosmic_sample_mutated".divide($"cosmic_sample_tested") as "sample_ratio"
-    )
+        $"mutation_url",
+        $"shared_aa",
+        $"genomic_mutation_id" as "cosmic_id",
+        $"cosmic_sample_mutated" as "sample_mutated",
+        $"cosmic_sample_tested" as "sample_tested",
+        $"mutation_significance_tier" as "tier",
+        $"cosmic_sample_mutated".divide($"cosmic_sample_tested") as "sample_ratio"
+      )
       // Deduplicate
       .withColumn("rn", row_number().over(w))
       .filter($"rn" === 1)
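
A minimal, self-contained sketch of the Spark checkpointing pattern the new checkpoint flag relies on. The SparkSession setup, object name, sample data, and /tmp path are illustrative assumptions; the ETL itself writes to s"${mainDestination.rootPath}/checkpoints" as shown in the diff.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object CheckpointSketch extends App {
  val spark = SparkSession.builder()
    .appName("checkpoint-sketch")
    .master("local[*]") // local master for illustration only
    .getOrCreate()
  import spark.implicits._

  // Reliable checkpoints require a checkpoint directory (HDFS/S3 in
  // production); a local path stands in here.
  spark.sparkContext.setCheckpointDir("/tmp/checkpoints")

  val variants = Seq(("1", 1000L, "A", "T"))
    .toDF("chromosome", "start", "reference", "alternate")
    .withColumn("dna_change", concat_ws(">", $"reference", $"alternate"))

  // checkpoint() is eager by default: it materializes the Dataset and
  // truncates its lineage, so the long enrichment chain that follows is
  // planned from the checkpointed files instead of the full upstream DAG.
  val stable = variants.checkpoint()

  stable.groupBy("dna_change").count().show()
}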
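
The "// Deduplicate" block in the last hunk is the standard row_number-over-window idiom: rank rows within each locus by sample_mutated descending and keep the first. A standalone sketch under assumed column names and made-up data (the real code partitions by the shared locus columns):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object DedupSketch extends App {
  val spark = SparkSession.builder()
    .appName("dedup-sketch")
    .master("local[*]")
    .getOrCreate()
  import spark.implicits._

  val cosmic = Seq(
    ("1", 1000L, "A", "T", 12),
    ("1", 1000L, "A", "T", 7) // duplicate locus with fewer mutated samples
  ).toDF("chromosome", "start", "reference", "alternate", "sample_mutated")

  // Rank rows within each locus and keep only the most mutated entry.
  val w = Window
    .partitionBy($"chromosome", $"start", $"reference", $"alternate")
    .orderBy($"sample_mutated".desc)

  cosmic
    .withColumn("rn", row_number().over(w))
    .filter($"rn" === 1)
    .drop("rn")
    .show()
}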