From fa52e6ea39f5ac992f40b12ed014b66a7d41be07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20B=C3=A9gin?= Date: Thu, 1 Aug 2024 11:46:02 -0400 Subject: [PATCH] perf(enriched_genes): UNIC-3037 Broadcast and agg smaller gene set dfs (#234) --- .../spark3/publictables/enriched/Genes.scala | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/publictables/enriched/Genes.scala b/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/publictables/enriched/Genes.scala index fe134aae..097f402e 100644 --- a/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/publictables/enriched/Genes.scala +++ b/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/publictables/enriched/Genes.scala @@ -3,8 +3,6 @@ package bio.ferlab.datalake.spark3.publictables.enriched import bio.ferlab.datalake.commons.config.{Coalesce, DatasetConf, RuntimeETLContext} import bio.ferlab.datalake.spark3.etl.v4.SimpleSingleETL import bio.ferlab.datalake.spark3.implicits.DatasetConfImplicits._ -import bio.ferlab.datalake.spark3.implicits.GenomicImplicits._ -import bio.ferlab.datalake.spark3.implicits.GenomicImplicits.columns.locusColumnNames import bio.ferlab.datalake.spark3.implicits.SparkUtils.removeEmptyObjectsIn import bio.ferlab.datalake.spark3.publictables.enriched.Genes._ import mainargs.{ParserForMethods, main} @@ -76,17 +74,18 @@ object Genes { implicit class DataFrameOps(df: DataFrame) { - def joinAndMergeWith(gene_set: DataFrame, - joinOn: Seq[String], - asColumnName: String, - aggFirst: Boolean = false): DataFrame = { + def joinAndMergeWith(other: DataFrame, + joinOn: Seq[String], + asColumnName: String, + aggFirst: Boolean = false, + broadcastOtherDf: Boolean = false): DataFrame = { val aggFn: Column => Column = c => if (aggFirst) first(c) else collect_list(c) val aggDF = df - .join(gene_set, joinOn, "left") + .join(if (broadcastOtherDf) broadcast(other) else other, joinOn, "left") .groupBy("symbol") .agg( first(struct(df("*"))) as "hg", - aggFn(struct(gene_set.drop(joinOn: _*)("*"))) as asColumnName, + aggFn(struct(other.drop(joinOn: _*)("*"))) as asColumnName, ) .select(col("hg.*"), col(asColumnName)) if (aggFirst) @@ -102,14 +101,14 @@ object Genes { max("pLI") as "pli", max("oe_lof_upper") as "loeuf" ) - df.joinAndMergeWith(gnomadConstraint, Seq("chromosome", "symbol"), "gnomad", aggFirst = true) + df.joinAndMergeWith(gnomadConstraint, Seq("chromosome", "symbol"), "gnomad", aggFirst = true, broadcastOtherDf = true) } def withOrphanet(orphanet: DataFrame): DataFrame = { val orphanetPrepared = orphanet .select(col("gene_symbol") as "symbol", col("disorder_id"), col("name") as "panel", col("type_of_inheritance") as "inheritance") - df.joinAndMergeWith(orphanetPrepared, Seq("symbol"), "orphanet") + df.joinAndMergeWith(orphanetPrepared, Seq("symbol"), "orphanet", broadcastOtherDf = true) } def withOmim(omim: DataFrame): DataFrame = { @@ -120,35 +119,37 @@ object Genes { col("phenotype.omim_id") as "omim_id", col("phenotype.inheritance") as "inheritance", col("phenotype.inheritance_code") as "inheritance_code") - df.joinAndMergeWith(omimPrepared, Seq("omim_gene_id"), "omim") + df.joinAndMergeWith(omimPrepared, Seq("omim_gene_id"), "omim", broadcastOtherDf = true) } def withDDD(ddd: DataFrame): DataFrame = { val dddPrepared = ddd.select("disease_name", "symbol") - df.joinAndMergeWith(dddPrepared, Seq("symbol"), "ddd") + df.joinAndMergeWith(dddPrepared, Seq("symbol"), "ddd", broadcastOtherDf = true) } def withCosmic(cosmic: DataFrame): DataFrame = { val cosmicPrepared = cosmic.select("symbol", "tumour_types_germline") - df.joinAndMergeWith(cosmicPrepared, Seq("symbol"), "cosmic") + df.joinAndMergeWith(cosmicPrepared, Seq("symbol"), "cosmic", broadcastOtherDf = true) } def withHPO(hpo: DataFrame): DataFrame = { val hpoPrepared = hpo.select(col("entrez_gene_id"), col("hpo_term_id"), col("hpo_term_name")) .distinct() .withColumn("hpo_term_label", concat(col("hpo_term_name"), lit(" ("), col("hpo_term_id"), lit(")"))) - df.joinAndMergeWith(hpoPrepared, Seq("entrez_gene_id"), "hpo") + df.joinAndMergeWith(hpoPrepared, Seq("entrez_gene_id"), "hpo", broadcastOtherDf = true) } def withSpliceAi(spliceai: DataFrame)(implicit spark: SparkSession): DataFrame = { import spark.implicits._ val spliceAiPrepared = spliceai + .groupBy("symbol") + .agg(first("max_score") as "max_score") .select($"symbol", $"max_score.*") .withColumn("type", when($"ds" === 0, null).otherwise($"type")) df - .joinAndMergeWith(spliceAiPrepared, Seq("symbol"), "spliceai", aggFirst = true) + .joinAndMergeWith(spliceAiPrepared, Seq("symbol"), "spliceai", aggFirst = true, broadcastOtherDf = true) .withColumn("spliceai", when($"spliceai.ds".isNull and $"spliceai.type".isNull, null).otherwise($"spliceai")) } }