CLIN-3037 (#235)
* perf: CLIN-3037 Split Enriched SpliceAI in indel and snv

* refactor: CLIN-3037 Move withSpliceAI from genes to variants
Laura Bégin authored Aug 21, 2024
1 parent fa52e6e commit 94924fd
Showing 12 changed files with 245 additions and 123 deletions.
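At a high level, the single enriched SpliceAI job is replaced by one run per variant type, each reading its own normalized source and writing its own Delta table. A minimal sketch of the new entry points (names taken from ImportPublicTable below; `rc` is assumed to be an existing RuntimeETLContext):

    // Before: one job unioned the indel and snv sources into /public/spliceai/enriched
    enriched.SpliceAi.run(rc)
    // After: one job per variant type, each with its own table, view and path
    enriched.SpliceAi.run(rc, "indel") // -> table variant.spliceai_enriched_indel
    enriched.SpliceAi.run(rc, "snv")   // -> table variant.spliceai_enriched_snv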
43 changes: 39 additions & 4 deletions datalake-spark3/src/main/resources/reference_kf.conf
@@ -1112,13 +1112,13 @@ datalake {
},
{
format=DELTA
id="enriched_spliceai"
id="enriched_spliceai_indel"
keys=[]
loadtype=OverWrite
partitionby=[
chromosome
]
path="/public/spliceai/enriched"
path="/public/spliceai/enriched/indel"
readoptions {}
repartition {
column-names=[
@@ -1131,11 +1131,46 @@ datalake {
storageid="public_database"
table {
database=variant
name="spliceai_enriched"
name="spliceai_enriched_indel"
}
view {
database="variant_live"
name="spliceai_enriched"
name="spliceai_enriched_indel"
}
writeoptions {
"created_on_column"="created_on"
"is_current_column"="is_current"
"updated_on_column"="updated_on"
"valid_from_column"="valid_from"
"valid_to_column"="valid_to"
}
},
{
format=DELTA
id="enriched_spliceai_snv"
keys=[]
loadtype=OverWrite
partitionby=[
chromosome
]
path="/public/spliceai/enriched/snv"
readoptions {}
repartition {
column-names=[
chromosome,
start
]
kind=RepartitionByRange
sort-columns=[]
}
storageid="public_database"
table {
database=variant
name="spliceai_enriched_snv"
}
view {
database="variant_live"
name="spliceai_enriched_snv"
}
writeoptions {
"created_on_column"="created_on"
datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/genomics/enriched/Variants.scala
@@ -12,6 +12,7 @@ import bio.ferlab.datalake.spark3.implicits.SparkUtils.firstAs
import org.apache.spark.SparkContext
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType}
import org.apache.spark.sql.{Column, DataFrame, SparkSession}

import java.time.LocalDateTime
@@ -25,9 +26,13 @@ import java.time.LocalDateTime
* @param snvDatasetId the id of the dataset containing the SNV variants
* @param frequencies the frequencies to calculate. See [[FrequencyOperations.freq]]
* @param extraAggregations extra aggregations to be computed when grouping occurrences by locus. Will be added to the root of the data
* @param spliceAi whether to join variants with SpliceAI scores. Defaults to true.
* @param rc the etl context
*/
case class Variants(rc: RuntimeETLContext, participantId: Column = col("participant_id"), affectedStatus: Column = col("affected_status"), filterSnv: Option[Column] = Some(col("has_alt")), snvDatasetId: String, splits: Seq[OccurrenceSplit], extraAggregations: Seq[Column] = Nil, checkpoint: Boolean = false) extends SimpleSingleETL(rc) {
case class Variants(rc: RuntimeETLContext, participantId: Column = col("participant_id"),
affectedStatus: Column = col("affected_status"), filterSnv: Option[Column] = Some(col("has_alt")),
snvDatasetId: String, splits: Seq[OccurrenceSplit], extraAggregations: Seq[Column] = Nil,
checkpoint: Boolean = false, spliceAi: Boolean = true) extends SimpleSingleETL(rc) {
override val mainDestination: DatasetConf = conf.getDataset("enriched_variants")
if (checkpoint) {
spark.sparkContext.setCheckpointDir(s"${mainDestination.rootPath}/checkpoints")
@@ -41,6 +46,8 @@ case class Variants(rc: RuntimeETLContext, participantId: Column = col("particip
protected val clinvar: DatasetConf = conf.getDataset("normalized_clinvar")
protected val genes: DatasetConf = conf.getDataset("enriched_genes")
protected val cosmic: DatasetConf = conf.getDataset("normalized_cosmic_mutation_set")
protected val spliceai_indel: DatasetConf = conf.getDataset("enriched_spliceai_indel")
protected val spliceai_snv: DatasetConf = conf.getDataset("enriched_spliceai_snv")

override def extract(lastRunValue: LocalDateTime = minValue,
currentRunValue: LocalDateTime = LocalDateTime.now()): Map[String, DataFrame] = {
@@ -54,6 +61,8 @@ case class Variants(rc: RuntimeETLContext, participantId: Column = col("particip
clinvar.id -> clinvar.read,
genes.id -> genes.read,
cosmic.id -> cosmic.read,
spliceai_indel.id -> (if (spliceAi) spliceai_indel.read else spark.emptyDataFrame),
spliceai_snv.id -> (if (spliceAi) spliceai_snv.read else spark.emptyDataFrame),
snvDatasetId -> conf.getDataset(snvDatasetId).read
)
}
@@ -84,14 +93,13 @@ case class Variants(rc: RuntimeETLContext, participantId: Column = col("particip
.withClinvar(data(clinvar.id))
.withGenes(data(genes.id))
.withCosmic(data(cosmic.id))
.withSpliceAi(snv = data(spliceai_snv.id), indel = data(spliceai_indel.id), compute = spliceAi)
.withGeneExternalReference
.withVariantExternalReference
.withColumn("locus", concat_ws("-", locus: _*))
.withColumn("hash", sha1(col("locus")))
.drop("genes_symbol")
}


}

object Variants {
@@ -235,6 +243,53 @@ object Variants {

df.joinAndMerge(cmc, "cmc", "left")
}

def withSpliceAi(snv: DataFrame, indel: DataFrame, compute: Boolean = true)(implicit spark: SparkSession): DataFrame = {
import spark.implicits._

def joinAndMergeIntoGenes(variants: DataFrame, spliceai: DataFrame): DataFrame = {
if (!variants.isEmpty) {
variants
.select($"*", explode_outer($"genes") as "gene", $"gene.symbol" as "symbol") // explode_outer since genes can be null
.join(spliceai, locusColumnNames :+ "symbol", "left")
.drop("symbol") // only used for joining
.withColumn("gene", struct($"gene.*", $"spliceai")) // add spliceai struct as nested field of gene struct
.groupByLocus()
.agg(
first(struct(variants.drop("genes")("*"))) as "variant",
collect_list("gene") as "genes" // re-create genes list for each locus, now containing spliceai struct
)
.select("variant.*", "genes")
} else variants
}

if (compute) {
val spliceAiSnvPrepared = snv
.selectLocus($"symbol", $"max_score" as "spliceai")

val spliceAiIndelPrepared = indel
.selectLocus($"symbol", $"max_score" as "spliceai")

val snvVariants = df
.where($"variant_class" === "SNV")

val otherVariants = df
.where($"variant_class" =!= "SNV")

val snvVariantsWithSpliceAi = joinAndMergeIntoGenes(snvVariants, spliceAiSnvPrepared)
val otherVariantsWithSpliceAi = joinAndMergeIntoGenes(otherVariants, spliceAiIndelPrepared)

snvVariantsWithSpliceAi.unionByName(otherVariantsWithSpliceAi, allowMissingColumns = true)
} else {
// Add empty spliceai struct
df
.withColumn("genes", transform($"genes", g => g.withField("spliceai", lit(null).cast(
StructType(Seq(
StructField("ds", DoubleType),
StructField("type", ArrayType(StringType))
))))))
}
}
}
}
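Since withSpliceAi is exposed through the DataFrameOps implicit class on the Variants companion object (imported as such in the spec below), it chains directly on a variants DataFrame. A minimal sketch, assuming an implicit SparkSession in scope plus illustrative variantsDf, snvDf and indelDf frames and an existing Variants instance etl:

    import bio.ferlab.datalake.spark3.genomics.enriched.Variants.DataFrameOps

    // SNV variants are joined against the snv scores, every other variant class
    // against the indel scores; the per-gene spliceai struct ends up nested in
    // each element of the `genes` array.
    val withScores = variantsDf.withSpliceAi(snv = snvDf, indel = indelDf)

    // Opting out: with spliceAi = false, extract() supplies empty frames and a
    // null-valued spliceai field is added to every gene instead of joining.
    val withoutSpliceAi = etl.copy(spliceAi = false)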

datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/publictables/ImportPublicTable.scala
@@ -73,7 +73,10 @@ object ImportPublicTable {
def spliceai_snv(rc: RuntimeETLContext): Unit = SpliceAi.run(rc, "snv")

@main
def spliceai_enriched(rc: RuntimeETLContext): Unit = enriched.SpliceAi.run(rc)
def spliceai_enriched_indel(rc: RuntimeETLContext): Unit = enriched.SpliceAi.run(rc, "indel")

@main
def spliceai_enriched_snv(rc: RuntimeETLContext): Unit = enriched.SpliceAi.run(rc, "snv")

@main
def topmed_bravo(rc: RuntimeETLContext): Unit = TopMed.run(rc)
PublicDatasets.scala
@@ -60,10 +60,11 @@ case class PublicDatasets(alias: String, tableDatabase: Option[String], viewData
DatasetConf("normalized_spliceai_snv" , alias, "/public/spliceai/snv" , DELTA, OverWrite , partitionby = List("chromosome"), table = table("spliceai_snv") , view = view("spliceai_snv")),

// enriched
DatasetConf("enriched_genes" , alias, "/public/genes" , DELTA, OverWrite , partitionby = List() , table = table("genes") , view = view("genes")),
DatasetConf("enriched_dbnsfp" , alias, "/public/dbnsfp/scores" , DELTA, OverWrite , partitionby = List("chromosome"), table = table("dbnsfp_original") , view = view("dbnsfp_original")),
DatasetConf("enriched_spliceai" , alias, "/public/spliceai/enriched" , DELTA, OverWrite , partitionby = List("chromosome"), repartition= Some(RepartitionByRange(columnNames = Seq("chromosome", "start"))), table = table("spliceai_enriched") , view = view("spliceai_enriched")),
DatasetConf("enriched_rare_variant" , alias, "/public/rare_variant/enriched", DELTA, OverWrite , partitionby = List("chromosome", "is_rare"), table = table("rare_variant_enriched"), view = view("rare_variant_enriched"))
DatasetConf("enriched_genes" , alias, "/public/genes" , DELTA, OverWrite , partitionby = List() , table = table("genes") , view = view("genes")),
DatasetConf("enriched_dbnsfp" , alias, "/public/dbnsfp/scores" , DELTA, OverWrite , partitionby = List("chromosome"), table = table("dbnsfp_original") , view = view("dbnsfp_original")),
DatasetConf("enriched_spliceai_indel" , alias, "/public/spliceai/enriched/indel", DELTA, OverWrite , partitionby = List("chromosome"), repartition= Some(RepartitionByRange(columnNames = Seq("chromosome", "start"))), table = table("spliceai_enriched_indel"), view = view("spliceai_enriched_indel")),
DatasetConf("enriched_spliceai_snv" , alias, "/public/spliceai/enriched/snv" , DELTA, OverWrite , partitionby = List("chromosome"), repartition= Some(RepartitionByRange(columnNames = Seq("chromosome", "start"))), table = table("spliceai_enriched_snv") , view = view("spliceai_enriched_snv")),
DatasetConf("enriched_rare_variant" , alias, "/public/rare_variant/enriched" , DELTA, OverWrite , partitionby = List("chromosome", "is_rare"), table = table("rare_variant_enriched"), view = view("rare_variant_enriched"))

)
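Downstream jobs resolve these entries by id through the shared configuration. A minimal sketch (assuming a configured `conf` and the DatasetConfImplicits import, as in the ETL classes above):

    val snvScores: DatasetConf = conf.getDataset("enriched_spliceai_snv")
    val snvDf = snvScores.read // Delta table under /public/spliceai/enriched/snv, partitioned by chromosome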

datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/publictables/enriched/Genes.scala
@@ -21,7 +21,6 @@ case class Genes(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
val ddd_gene_set: DatasetConf = conf.getDataset("normalized_ddd_gene_set")
val cosmic_gene_set: DatasetConf = conf.getDataset("normalized_cosmic_gene_set")
val gnomad_constraint: DatasetConf = conf.getDataset("normalized_gnomad_constraint_v2_1_1")
val spliceai: DatasetConf = conf.getDataset("enriched_spliceai")

override def extract(lastRunValue: LocalDateTime,
currentRunValue: LocalDateTime): Map[String, DataFrame] = {
@@ -32,8 +31,7 @@ case class Genes(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
human_genes.id -> human_genes.read,
ddd_gene_set.id -> ddd_gene_set.read,
cosmic_gene_set.id -> cosmic_gene_set.read,
gnomad_constraint.id -> gnomad_constraint.read,
spliceai.id -> spliceai.read
gnomad_constraint.id -> gnomad_constraint.read
)
}

@@ -58,7 +56,6 @@ case class Genes(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
.withDDD(data(ddd_gene_set.id))
.withCosmic(data(cosmic_gene_set.id))
.withGnomadConstraint(data(gnomad_constraint.id))
.withSpliceAi(data(spliceai.id))
}

override def defaultRepartition: DataFrame => DataFrame = Coalesce()
@@ -138,20 +135,6 @@ object Genes {
.withColumn("hpo_term_label", concat(col("hpo_term_name"), lit(" ("), col("hpo_term_id"), lit(")")))
df.joinAndMergeWith(hpoPrepared, Seq("entrez_gene_id"), "hpo", broadcastOtherDf = true)
}

def withSpliceAi(spliceai: DataFrame)(implicit spark: SparkSession): DataFrame = {
import spark.implicits._

val spliceAiPrepared = spliceai
.groupBy("symbol")
.agg(first("max_score") as "max_score")
.select($"symbol", $"max_score.*")
.withColumn("type", when($"ds" === 0, null).otherwise($"type"))

df
.joinAndMergeWith(spliceAiPrepared, Seq("symbol"), "spliceai", aggFirst = true, broadcastOtherDf = true)
.withColumn("spliceai", when($"spliceai.ds".isNull and $"spliceai.type".isNull, null).otherwise($"spliceai"))
}
}
}

datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/publictables/enriched/SpliceAi.scala
@@ -3,35 +3,29 @@ package bio.ferlab.datalake.spark3.publictables.enriched
import bio.ferlab.datalake.commons.config.{DatasetConf, RepartitionByRange, RuntimeETLContext}
import bio.ferlab.datalake.spark3.etl.v4.SimpleSingleETL
import bio.ferlab.datalake.spark3.implicits.DatasetConfImplicits.DatasetConfOperations
import mainargs.{ParserForMethods, main}
import mainargs.{ParserForMethods, arg, main}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Column, DataFrame, functions}

import java.time.LocalDateTime

case class SpliceAi(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
case class SpliceAi(rc: RuntimeETLContext, variantType: String) extends SimpleSingleETL(rc) {

override val mainDestination: DatasetConf = conf.getDataset("enriched_spliceai")
val spliceai_indel: DatasetConf = conf.getDataset("normalized_spliceai_indel")
val spliceai_snv: DatasetConf = conf.getDataset("normalized_spliceai_snv")
override val mainDestination: DatasetConf = conf.getDataset(s"enriched_spliceai_$variantType")
val normalized_spliceai: DatasetConf = conf.getDataset(s"normalized_spliceai_$variantType")

override def extract(lastRunValue: LocalDateTime,
currentRunValue: LocalDateTime): Map[String, DataFrame] = {
Map(
spliceai_indel.id -> spliceai_indel.read,
spliceai_snv.id -> spliceai_snv.read
)
Map(normalized_spliceai.id -> normalized_spliceai.read)
}

override def transformSingle(data: Map[String, DataFrame],
lastRunValue: LocalDateTime,
currentRunValue: LocalDateTime): DataFrame = {
import spark.implicits._

val spliceai_snvDf = data(spliceai_snv.id)
val spliceai_indelDf = data(spliceai_indel.id)

val originalColumns = spliceai_snvDf.columns.map(col)
val df = data(normalized_spliceai.id)
val originalColumns = df.columns.map(col)

val getDs: Column => Column = _.getItem(0).getField("ds") // Get delta score
val scoreColumnNames = Array("AG", "AL", "DG", "DL")
@@ -43,8 +37,7 @@ case class SpliceAi(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
.otherwise(c2)
}

spliceai_snvDf
.union(spliceai_indelDf)
df
.select(
originalColumns :+
$"ds_ag".as("AG") :+ // acceptor gain
@@ -57,17 +50,17 @@ case class SpliceAi(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
getDs($"max_score_temp") as "ds",
functions.transform($"max_score_temp", c => c.getField("type")) as "type")
)
.withColumn("max_score", $"max_score".withField("type", when($"max_score.ds" === 0, null).otherwise($"max_score.type")))
.select(originalColumns :+ $"max_score": _*)
}

override def defaultRepartition: DataFrame => DataFrame = RepartitionByRange(columnNames = Seq("chromosome", "start"), n = Some(500))

}

object SpliceAi {
@main
def run(rc: RuntimeETLContext): Unit = {
SpliceAi(rc).run()
def run(rc: RuntimeETLContext, @arg(name = "variant_type", short = 'v', doc = "Variant Type") variantType: String): Unit = {
SpliceAi(rc, variantType).run()
}

def main(args: Array[String]): Unit = ParserForMethods(this).runOrThrow(args)
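Per variant-gene pair, the job keeps the highest of the four delta scores (AG = acceptor gain, AL = acceptor loss, DG = donor gain, DL = donor loss) along with every score type attaining that maximum, and nulls the type list when the maximum is 0. Worked examples, matching the expectations in SpliceAiSpec below:

    // ds_ag=1.0, ds_al=2.0, ds_dg=0.0, ds_dl=0.0 -> max_score = (ds = 2.0, type = ["AL"])
    // ds_ag=0.0, ds_al=0.0, ds_dg=0.0, ds_dl=0.0 -> max_score = (ds = 0.0, type = null)
    // ds_ag=1.0, ds_al=1.0, ds_dg=1.0, ds_dl=1.0 -> max_score = (ds = 1.0, type = ["AG", "AL", "DG", "DL"])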
43 changes: 39 additions & 4 deletions datalake-spark3/src/test/resources/config/reference_kf.conf
@@ -1112,13 +1112,13 @@ datalake {
},
{
format=DELTA
id="enriched_spliceai"
id="enriched_spliceai_indel"
keys=[]
loadtype=OverWrite
partitionby=[
chromosome
]
path="/public/spliceai/enriched"
path="/public/spliceai/enriched/indel"
readoptions {}
repartition {
column-names=[
@@ -1131,11 +1131,46 @@ datalake {
storageid="public_database"
table {
database=variant
name="spliceai_enriched"
name="spliceai_enriched_indel"
}
view {
database="variant_live"
name="spliceai_enriched"
name="spliceai_enriched_indel"
}
writeoptions {
"created_on_column"="created_on"
"is_current_column"="is_current"
"updated_on_column"="updated_on"
"valid_from_column"="valid_from"
"valid_to_column"="valid_to"
}
},
{
format=DELTA
id="enriched_spliceai_snv"
keys=[]
loadtype=OverWrite
partitionby=[
chromosome
]
path="/public/spliceai/enriched/snv"
readoptions {}
repartition {
column-names=[
chromosome,
start
]
kind=RepartitionByRange
sort-columns=[]
}
storageid="public_database"
table {
database=variant
name="spliceai_enriched_snv"
}
view {
database="variant_live"
name="spliceai_enriched_snv"
}
writeoptions {
"created_on_column"="created_on"
datalake-spark3/src/test/scala/bio/ferlab/datalake/spark3/genomics/enriched/EnrichedVariantsSpec.scala
@@ -3,13 +3,14 @@ package bio.ferlab.datalake.spark3.genomics.enriched
import bio.ferlab.datalake.commons.config.DatasetConf
import bio.ferlab.datalake.spark3.genomics.enriched.Variants.DataFrameOps
import bio.ferlab.datalake.spark3.genomics.{FrequencySplit, SimpleAggregation}
import bio.ferlab.datalake.spark3.implicits.GenomicImplicits._
import bio.ferlab.datalake.spark3.testutils.WithTestConfig
import bio.ferlab.datalake.testutils.models.enriched.EnrichedVariant.CMC
import bio.ferlab.datalake.testutils.models.enriched.{EnrichedGenes, EnrichedVariant}
import bio.ferlab.datalake.testutils.models.enriched.EnrichedVariant.{CMC, GENES, SPLICEAI}
import bio.ferlab.datalake.testutils.models.enriched.{EnrichedGenes, EnrichedSpliceAi, EnrichedVariant, MAX_SCORE}
import bio.ferlab.datalake.testutils.models.normalized._
import bio.ferlab.datalake.testutils.{SparkSpec, TestETLContext}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, collect_set, max}
import org.apache.spark.sql.functions.{col, collect_set, max, transform}

class EnrichedVariantsSpec extends SparkSpec with WithTestConfig {

@@ -26,6 +27,8 @@ class EnrichedVariantsSpec extends SparkSpec with WithTestConfig {
val clinvar: DatasetConf = conf.getDataset("normalized_clinvar")
val genes: DatasetConf = conf.getDataset("enriched_genes")
val cosmic: DatasetConf = conf.getDataset("normalized_cosmic_mutation_set")
val spliceai_snv: DatasetConf = conf.getDataset("enriched_spliceai_snv")
val spliceai_indel: DatasetConf = conf.getDataset("enriched_spliceai_indel")

val occurrencesDf: DataFrame = Seq(
NormalizedSNV(`participant_id` = "PA0001", study_id = "S1"),
@@ -41,6 +44,8 @@ class EnrichedVariantsSpec extends SparkSpec with WithTestConfig {
val clinvarDf: DataFrame = Seq(NormalizedClinvar(chromosome = "1", start = 69897, reference = "T", alternate = "C")).toDF
val genesDf: DataFrame = Seq(EnrichedGenes()).toDF()
val cosmicDf: DataFrame = Seq(NormalizedCosmicMutationSet(chromosome = "1", start = 69897, reference = "T", alternate = "C")).toDF()
val spliceAiSnvDf: DataFrame = Seq(EnrichedSpliceAi(chromosome = "1", start = 69897, reference = "T", alternate = "C", symbol = "OR4F5", ds_ag = 0.01, `max_score` = MAX_SCORE(ds = 0.1, `type` = Some(Seq("AG", "AL", "DG", "DL"))))).toDF()
val spliceAiIndelDf: DataFrame = Seq(EnrichedSpliceAi(chromosome = "1", start = 69897, reference = "TTG", alternate = "C", symbol = "OR4F5", ds_ag = 0.01, `max_score` = MAX_SCORE(ds = 0.2, `type` = Some(Seq("AG"))))).toDF()

val etl = Variants(TestETLContext(), snvDatasetId = snvKeyId, splits = Seq(FrequencySplit("frequency", extraAggregations = Seq(SimpleAggregation(name = "zygosities", c = col("zygosity"))))))

@@ -54,9 +59,24 @@ class EnrichedVariantsSpec extends SparkSpec with WithTestConfig {
dbsnp.id -> dbsnpDf,
clinvar.id -> clinvarDf,
genes.id -> genesDf,
cosmic.id -> cosmicDf
cosmic.id -> cosmicDf,
spliceai_snv.id -> spliceAiSnvDf,
spliceai_indel.id -> spliceAiIndelDf
)

it should "not join with SpliceAI if it is set to false" in {
val noSpliceAiETL = etl.copy(spliceAi = false)

val result = noSpliceAiETL.transformSingle(data).cache()

result
.as[EnrichedVariant]
.collect() should contain theSameElementsAs Seq(
EnrichedVariant(genes = List(GENES(spliceai = None)),
gene_external_reference = List("HPO", "Orphanet", "OMIM", "DDD", "Cosmic", "gnomAD"))
)
}

"transformSingle" should "return expected result" in {
val df = etl.transformSingle(data)

@@ -96,4 +116,44 @@ class EnrichedVariantsSpec extends SparkSpec with WithTestConfig {
val result = df.select("participant_ids", "latest_study").as[(Set[String], String)].collect()
result.head shouldBe(Set("PA0001", "PA0002"), "S2")
}

"withSpliceAi" should "enrich variants with SpliceAi scores" in {
val variants = Seq(
EnrichedVariant(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C" , variant_class = "SNV", `genes` = List(GENES(`symbol` = Some("gene1")), GENES(`symbol` = Some("gene2")))),
EnrichedVariant(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "AT" , variant_class = "Insertion"),
EnrichedVariant(`chromosome` = "2", `start` = 1, `end` = 2, `reference` = "A", `alternate` = "T" , variant_class = "SNV"),
EnrichedVariant(`chromosome` = "3", `start` = 1, `end` = 2, `reference` = "C", `alternate` = "A" , variant_class = "SNV" , genes = List(null)),
).toDF()

// Remove spliceai nested field from variants df
val variantsWithoutSpliceAi = variants.withColumn("genes", transform($"genes", g => g.dropFields("spliceai")))

val spliceAiSnv = Seq(
// snv
EnrichedSpliceAi(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C", `symbol` = "gene1", `max_score` = MAX_SCORE(`ds` = 2.0, `type` = Some(Seq("AL")))),
EnrichedSpliceAi(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C", `symbol` = "gene2", `max_score` = MAX_SCORE(`ds` = 0.0, `type` = None)),
EnrichedSpliceAi(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C", `symbol` = "gene3", `max_score` = MAX_SCORE(`ds` = 0.0, `type` = None)),
).toDF()

val spliceAiIndel = Seq(
// indel
EnrichedSpliceAi(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "AT", `symbol` = "OR4F5", `max_score` = MAX_SCORE(`ds` = 1.0, `type` = Some(Seq("AG", "AL"))))
).toDF()

val result = variantsWithoutSpliceAi.withSpliceAi(spliceAiSnv, spliceAiIndel)

val expected = Seq(
EnrichedVariant(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C", variant_class = "SNV", `genes` = List(
GENES(`symbol` = Some("gene1"), `spliceai` = Some(SPLICEAI(`ds` = 2.0, `type` = Some(Seq("AL"))))),
GENES(`symbol` = Some("gene2"), `spliceai` = Some(SPLICEAI(`ds` = 0.0, `type` = None))),
)),
EnrichedVariant(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "AT", variant_class = "Insertion", `genes` = List(GENES(`spliceai` = Some(SPLICEAI(`ds` = 1.0, `type` = Some(Seq("AG", "AL"))))))),
EnrichedVariant(`chromosome` = "2", `start` = 1, `end` = 2, `reference` = "A", `alternate` = "T", variant_class = "SNV" , `genes` = List(GENES(`spliceai` = None))),
EnrichedVariant(`chromosome` = "3", `start` = 1, `end` = 2, `reference` = "C", `alternate` = "A", variant_class = "SNV" , `genes` = List(null))
).toDF().selectLocus($"genes.spliceai").collect()

result
.selectLocus($"genes.spliceai")
.collect() should contain theSameElementsAs expected
}
}
datalake-spark3/src/test/scala/bio/ferlab/datalake/spark3/publictables/enriched/GenesSpec.scala
@@ -2,7 +2,6 @@ package bio.ferlab.datalake.spark3.publictables.enriched

import bio.ferlab.datalake.commons.config.DatasetConf
import bio.ferlab.datalake.spark3.implicits.DatasetConfImplicits._
import bio.ferlab.datalake.spark3.publictables.enriched.Genes._
import bio.ferlab.datalake.spark3.testutils.WithTestConfig
import bio.ferlab.datalake.testutils.models.enriched._
import bio.ferlab.datalake.testutils.models.normalized._
@@ -22,7 +21,6 @@ class GenesSpec extends SparkSpec with WithTestConfig with CreateDatabasesBefore
val ddd_gene_set: DatasetConf = conf.getDataset("normalized_ddd_gene_set")
val cosmic_gene_set: DatasetConf = conf.getDataset("normalized_cosmic_gene_set")
val gnomad_constraint: DatasetConf = conf.getDataset("normalized_gnomad_constraint_v2_1_1")
val spliceai: DatasetConf = conf.getDataset("enriched_spliceai")

private val inputData = Map(
omim_gene_set.id -> Seq(
@@ -36,8 +34,7 @@ class GenesSpec extends SparkSpec with WithTestConfig with CreateDatabasesBefore
gnomad_constraint.id -> Seq(
NormalizedGnomadConstraint(chromosome = "1", start = 69897, symbol = "OR4F5", `pLI` = 1.0f, oe_lof_upper = 0.01f),
NormalizedGnomadConstraint(chromosome = "1", start = 69900, symbol = "OR4F5", `pLI` = 0.9f, oe_lof_upper = 0.054f)
).toDF(),
spliceai.id -> Seq(EnrichedSpliceAi(`symbol` = "OR4F5")).toDF()
).toDF()
)

val job = new Genes(TestETLContext())
@@ -62,7 +59,6 @@ class GenesSpec extends SparkSpec with WithTestConfig with CreateDatabasesBefore
functions.size(col("orphanet")),
functions.size(col("ddd")),
functions.size(col("cosmic"))).as[(Long, Long, Long)].collect().head shouldBe(0, 0, 0)

}

it should "write data into genes table" in {
@@ -80,36 +76,5 @@ class GenesSpec extends SparkSpec with WithTestConfig with CreateDatabasesBefore
resultDF.where("symbol='OR4F5'").as[EnrichedGenes].collect().head shouldBe
EnrichedGenes(`orphanet` = expectedOrphanet, `omim` = expectedOmim, `cosmic` = expectedCosmic)
}

"withSpliceAi" should "enrich genes with SpliceAi scores" in {
val genes = Seq(
EnrichedGenes(`chromosome` = "1", `symbol` = "gene1"),
EnrichedGenes(`chromosome` = "1", `symbol` = "gene2"),
EnrichedGenes(`chromosome` = "2", `symbol` = "gene3"),
EnrichedGenes(`chromosome` = "3", `symbol` = "gene4"),
).toDF().drop("spliceai")

val spliceai = Seq(
// snv
EnrichedSpliceAi(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C", `symbol` = "gene1", `max_score` = MAX_SCORE(`ds` = 2.0, `type` = Seq("AL"))),
EnrichedSpliceAi(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C", `symbol` = "gene2", `max_score` = MAX_SCORE(`ds` = 0.0, `type` = Seq("AG", "AL", "DG", "DL"))),

// indel
EnrichedSpliceAi(`chromosome` = "2", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "AT", `symbol` = "gene3", `max_score` = MAX_SCORE(`ds` = 1.0, `type` = Seq("AG", "AL")))
).toDF()

val result = genes.withSpliceAi(spliceai)

val expected = Seq(
EnrichedGenes(`chromosome` = "1", `symbol` = "gene1", `spliceai` = Some(SPLICEAI(`ds` = 2.0, `type` = Some(List("AL"))))),
EnrichedGenes(`chromosome` = "1", `symbol` = "gene2", `spliceai` = Some(SPLICEAI(`ds` = 0.0, `type` = None))),
EnrichedGenes(`chromosome` = "2", `symbol` = "gene3", `spliceai` = Some(SPLICEAI(`ds` = 1.0, `type` = Some(List("AG", "AL"))))),
EnrichedGenes(`chromosome` = "3", `symbol` = "gene4", `spliceai` = None),
)

result
.as[EnrichedGenes]
.collect() should contain theSameElementsAs expected
}
}

datalake-spark3/src/test/scala/bio/ferlab/datalake/spark3/publictables/enriched/SpliceAiSpec.scala
@@ -1,62 +1,58 @@
package bio.ferlab.datalake.spark3.publictables.enriched

import bio.ferlab.datalake.commons.config.DatasetConf
import bio.ferlab.datalake.spark3.testutils.WithTestConfig
import bio.ferlab.datalake.testutils.models.enriched.{EnrichedSpliceAi, MAX_SCORE}
import bio.ferlab.datalake.testutils.models.normalized.NormalizedSpliceAi
import bio.ferlab.datalake.spark3.testutils.WithTestConfig
import bio.ferlab.datalake.testutils.{SparkSpec, TestETLContext}

class SpliceAiSpec extends SparkSpec with WithTestConfig {

import spark.implicits._

val job = new SpliceAi(TestETLContext())
val source: DatasetConf = conf.getDataset("normalized_spliceai_snv")
val destination: DatasetConf = conf.getDataset("enriched_spliceai_snv")

val spliceai_indel: DatasetConf = job.spliceai_indel
val spliceai_snv: DatasetConf = job.spliceai_snv
val destination: DatasetConf = job.mainDestination
val job = SpliceAi(TestETLContext(), variantType = "snv")

"transformSingle" should "transform NormalizedSpliceAi to EnrichedSpliceAi" in {
val inputData = Map(
spliceai_snv.id -> Seq(NormalizedSpliceAi("1")).toDF(),
spliceai_indel.id -> Seq(NormalizedSpliceAi("2")).toDF(),
)
val inputData = Map(source.id -> Seq(NormalizedSpliceAi("1"), NormalizedSpliceAi("2")).toDF())

val resultDF = job.transformSingle(inputData)

// ClassGenerator
// .writeCLassFile(
// "bio.ferlab.datalake.testutils.models.enriched",
// "EnrichedSpliceAi",
// resultDF,
// "datalake-spark3/src/test/scala/")
// ClassGenerator
// .writeCLassFile(
// "bio.ferlab.datalake.testutils.models.enriched",
// "EnrichedSpliceAi",
// resultDF,
// "datalake-spark3/src/test/scala/")

val expected = Seq(EnrichedSpliceAi("1"), EnrichedSpliceAi("2"))
resultDF.as[EnrichedSpliceAi].collect() shouldBe expected
}

"transformSingle" should "compute max score for each variant-gene" in {
"transformSingle" should "compute max score for each variant-gene" in {
val inputData = Map(
spliceai_snv.id -> Seq(
source.id -> Seq(
NormalizedSpliceAi(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C", `symbol` = "gene1", `ds_ag` = 1.0, `ds_al` = 2.00, `ds_dg` = 0.0, `ds_dl` = 0.0),
NormalizedSpliceAi(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C", `symbol` = "gene2", `ds_ag` = 0.0, `ds_al` = 0.00, `ds_dg` = 0.0, `ds_dl` = 0.0),
).toDF(),
spliceai_indel.id -> Seq(
NormalizedSpliceAi(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "AT", `symbol` = "gene1", `ds_ag` = 1.0, `ds_al` = 1.00, `ds_dg` = 0.0, `ds_dl` = 0.0),
).toDF(),
NormalizedSpliceAi(`chromosome` = "2", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C", `symbol` = "gene1", `ds_ag` = 1.0, `ds_al` = 1.00, `ds_dg` = 0.0, `ds_dl` = 0.0),
NormalizedSpliceAi(`chromosome` = "3", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C", `symbol` = "gene1", `ds_ag` = 1.0, `ds_al` = 1.00, `ds_dg` = 1.0, `ds_dl` = 1.0),
).toDF()
)

val resultDF = job.transformSingle(inputData)
resultDF.show(false)

val expected = Seq(
MAX_SCORE(`ds` = 2.00, `type` = Seq("AL")),
MAX_SCORE(`ds` = 0.00, `type` = Seq("AG", "AL", "DG", "DL")),
MAX_SCORE(`ds` = 1.00, `type` = Seq("AG", "AL")),
MAX_SCORE(`ds` = 2.00, `type` = Some(Seq("AL"))),
MAX_SCORE(`ds` = 0.00, `type` = None),
MAX_SCORE(`ds` = 1.00, `type` = Some(Seq("AG", "AL"))),
MAX_SCORE(`ds` = 1.00, `type` = Some(Seq("AG", "AL", "DG", "DL"))),
)

resultDF
.select("max_score.*")
.as[MAX_SCORE].collect() shouldBe expected
}

}
datalake-spark3/src/test/scala/bio/ferlab/datalake/testutils/models/enriched/EnrichedGenes.scala
@@ -19,8 +19,7 @@ case class EnrichedGenes(`symbol`: String = "OR4F5",
`chromosome`: String = "1",
`ddd`: List[DDD] = List(DDD()),
`cosmic`: List[COSMIC] = List(COSMIC()),
gnomad: Option[GNOMAD] = Some(GNOMAD()),
spliceai: Option[SPLICEAI] = Some(SPLICEAI()))
gnomad: Option[GNOMAD] = Some(GNOMAD()))

case class ORPHANET(`disorder_id`: Long = 17827,
`panel`: String = "Immunodeficiency due to a classical component pathway complement deficiency",
@@ -41,6 +40,3 @@ case class COSMIC(`tumour_types_germline`: List[String] = List("breast", "colon"

case class GNOMAD(pli: Float = 1.0f,
loeuf: Float = 0.054f)

case class SPLICEAI(ds: Double = 0.1,
`type`: Option[Seq[String]] = Some(Seq("AG", "AL", "DG", "DL")))
datalake-spark3/src/test/scala/bio/ferlab/datalake/testutils/models/enriched/EnrichedSpliceAi.scala
@@ -22,4 +22,4 @@ case class EnrichedSpliceAi(`chromosome`: String = "1",
`max_score`: MAX_SCORE = MAX_SCORE())

case class MAX_SCORE(`ds`: Double = 0.1,
`type`: Seq[String] = Seq("AG", "AL", "DG", "DL"))
`type`: Option[Seq[String]] = Some(Seq("AG", "AL", "DG", "DL")))
