Skip to content

Commit

Permalink
[SYSTEMDS-3606] Performance shuffle-based spark quaternary operations
Browse files Browse the repository at this point in the history
This patch significantly improves the performance of shuffle-based
spark quaternary operations, where more than one input is an RDD
(too large to broadcast). Instead of replicating the factor blocks, we
now use custom join keys enabling spark to perform more efficient
1:M joins. With appropriate function abstractions, the implementation
also got simpler and thus, easier to maintain.

On the scenario mentioned in the JIRA task, the original implementation
did not finish any task of the first shuffle phase after >9000s, while
with the new implementation the entire script (with two shuffle-based
quaternary operators) finishes in 1276s. Here are the stats:

SystemDS Statistics:
Total elapsed time:             1276.917 sec.
Total compilation time:         2.338 sec.
Total execution time:           1274.578 sec.
Number of compiled Spark inst:  4.
Number of executed Spark inst:  4.
Cache hits (Mem/Li/WB/FS/HDFS): 13/2/0/1/0.
Cache writes (Li/WB/FS/HDFS):   4/6/4/1.
Cache times (ACQr/m, RLS, EXP): 1209.517/0.001/10.926/8.589 sec.
HOP DAGs recompiled (PRED, SB): 0/1.
HOP DAGs recompile time:        0.006 sec.
Functions recompiled:           1.
Functions recompile time:       0.011 sec.
Spark ctx create time (lazy):   19.302 sec.
Spark trans counts (par,bc,col):0/3/1.
Spark trans times (par,bc,col): 0.000/13.671/644.719 secs.
Spark async. count (pf,bc,op):  0/0/0.
Total JIT compile time:         73.677 sec.
Total JVM GC count:             188.
Total JVM GC time:              23.182 sec.
Heavy hitter instructions:
  1  m_pnmf        714.304      1
  2  r'            653.012      5
  3  uak+          560.027      2
  4  sp_redwdivmm   42.446      2
  5  rand            9.414      4
  6  *               3.544      1
  7  /               3.491      1
  8  uack+           3.466      1
  9  uark+           2.146      1
 10  rmvar           0.246     15
  • Loading branch information
mboehm7 committed Aug 10, 2023
1 parent e4f988d commit 33de453
Showing 1 changed file with 59 additions and 115 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.sysds.runtime.instructions.spark;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
Expand All @@ -44,7 +45,6 @@
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.functions.ReplicateBlockFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
Expand Down Expand Up @@ -198,11 +198,6 @@ public void processInstruction(ExecutionContext ec) {
JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() );
JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;

DataCharacteristics inMc = sec.getDataCharacteristics( input1.getName() );
long rlen = inMc.getRows();
long clen = inMc.getCols();
int blen = inMc.getBlocksize();

//pre-filter empty blocks (ultra-sparse matrices) for full aggregates
//(map/redwsloss, map/redwcemm); safe because theses ops produce a scalar
if( qop.wtype1 != null || qop.wtype4 != null ) {
Expand Down Expand Up @@ -237,42 +232,25 @@ public void processInstruction(ExecutionContext ec) {
JavaPairRDD<MatrixIndexes,MatrixBlock> inW = (qop.hasFourInputs() && !_input4.isLiteral()) ?
sec.getBinaryMatrixBlockRDDHandleForVariable( _input4.getName() ) : null;

//preparation of transposed and replicated U
if( inU != null )
inU = inU.flatMapToPair(new ReplicateBlockFunction(clen, blen, true));

//preparation of transposed and replicated V
if( inV != null )
inV = inV.mapToPair(new TransposeFactorIndexesFunction())
.flatMapToPair(new ReplicateBlockFunction(rlen, blen, false));
//join X and W on original indexes if W existing
JavaPairRDD<MatrixIndexes,MatrixBlock[]> tmp = (inW != null) ?
in.join(inW).mapToPair(new ToArray()) :
in.mapValues(mb -> new MatrixBlock[]{mb, null});

//functions calls w/ two rdd inputs
if( inU != null && inV == null && inW == null )
out = in.join(inU)
.mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
else if( inU == null && inV != null && inW == null )
out = in.join(inV)
.mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
else if( inU == null && inV == null && inW != null )
out = in.join(inW)
.mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
//function calls w/ three rdd inputs
else if( inU != null && inV != null && inW == null )
out = in.join(inU).join(inV)
.mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
else if( inU != null && inV == null && inW != null )
out = in.join(inU).join(inW)
.mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
else if( inU == null && inV != null && inW != null )
out = in.join(inV).join(inW)
.mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
else if( inU == null && inV == null && inW == null ) {
out = in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bc1, bc2), false);
}
//function call w/ four rdd inputs
else //need keys in case of wdivmm
out = in.join(inU).join(inV).join(inW)
.mapToPair(new RDDQuaternaryFunction4(qop));
//join lhs U on row-block indexes of X/W
tmp = ( inU != null ) ?
tmp.mapToPair(new ExtractIndexWith(true))
.join(inU.mapToPair(new ExtractIndex(true))).mapToPair(new Unpack()) :
tmp.mapValues(mb -> ArrayUtils.add(mb, null));

//join rhs V on column-block indexes X/W (note V transposed input, so rows)
tmp = ( inV != null ) ?
tmp.mapToPair(new ExtractIndexWith(false))
.join(inV.mapToPair(new ExtractIndex(true))).mapToPair(new Unpack()) :
tmp.mapValues(mb -> ArrayUtils.add(mb, null));

//execute quaternary block operations on joined inputs
out = tmp.mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));

//keep variable names for lineage maintenance
if( inU == null ) bcVars.add(input2.getName()); else rddVars.add(input2.getName());
Expand Down Expand Up @@ -374,12 +352,11 @@ public RDDQuaternaryPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock
protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) {
MatrixIndexes ixIn = arg._1();
MatrixBlock blkIn = arg._2();
MatrixBlock blkOut = new MatrixBlock();
MatrixBlock mbU = _pmU.getBlock((int)ixIn.getRowIndex(), 1);
MatrixBlock mbV = _pmV.getBlock((int)ixIn.getColumnIndex(), 1);

//execute core operation
blkIn.quaternaryOperations(_qop, mbU, mbV, null, blkOut);
MatrixBlock blkOut = blkIn.quaternaryOperations(_qop, mbU, mbV, null, new MatrixBlock());

//create return tuple
MatrixIndexes ixOut = createOutputIndexes(ixIn);
Expand All @@ -389,7 +366,7 @@ protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, M
}

private static class RDDQuaternaryFunction2 extends RDDQuaternaryBaseFunction //two rdd input
implements PairFunction<Tuple2<MatrixIndexes, Tuple2<MatrixBlock,MatrixBlock>>, MatrixIndexes, MatrixBlock>
implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock[]>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID = 7493974462943080693L;

Expand All @@ -398,100 +375,67 @@ public RDDQuaternaryFunction2( QuaternaryOperator qop, PartitionedBroadcast<Matr
}

@Override
public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg0) {
public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock[]> arg0) {
MatrixIndexes ixIn = arg0._1();
MatrixBlock blkIn1 = arg0._2()._1();
MatrixBlock blkIn2 = arg0._2()._2();
MatrixBlock blkOut = new MatrixBlock();
MatrixBlock mbU = (_pmU!=null)?_pmU.getBlock((int)ixIn.getRowIndex(), 1) : blkIn2;
MatrixBlock mbV = (_pmV!=null)?_pmV.getBlock((int)ixIn.getColumnIndex(), 1) : blkIn2;
MatrixBlock mbW = (_qop.hasFourInputs()) ? blkIn2 : null;
MatrixBlock[] blks = arg0._2();
MatrixBlock mbU = (_pmU!=null)?_pmU.getBlock((int)ixIn.getRowIndex(), 1) : blks[2];
MatrixBlock mbV = (_pmV!=null)?_pmV.getBlock((int)ixIn.getColumnIndex(), 1) : blks[3];
MatrixBlock mbW = (_qop.hasFourInputs()) ? blks[1] : null;

//execute core operation
blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut);
MatrixBlock blkOut = blks[0].quaternaryOperations(_qop, mbU, mbV, mbW, new MatrixBlock());

//create return tuple
MatrixIndexes ixOut = createOutputIndexes(ixIn);
return new Tuple2<>(ixOut, blkOut);
}
}

private static class RDDQuaternaryFunction3 extends RDDQuaternaryBaseFunction //three rdd input
implements PairFunction<Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock>>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID = -2294086455843773095L;

public RDDQuaternaryFunction3( QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV ) {
super(qop, bcU, bcV);
private static class ExtractIndex implements PairFunction<Tuple2<MatrixIndexes,MatrixBlock>, Long, MatrixBlock> {
private static final long serialVersionUID = -6542246824481788376L;
private final boolean _row;
public ExtractIndex(boolean row) {
_row = row;
}

@Override
public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>> arg0) {
MatrixIndexes ixIn = arg0._1();
MatrixBlock blkIn1 = arg0._2()._1()._1();
MatrixBlock blkIn2 = arg0._2()._1()._2();
MatrixBlock blkIn3 = arg0._2()._2();
MatrixBlock blkOut = new MatrixBlock();
MatrixBlock mbU = (_pmU!=null)?_pmU.getBlock((int)ixIn.getRowIndex(), 1) : blkIn2;
MatrixBlock mbV = (_pmV!=null)?_pmV.getBlock((int)ixIn.getColumnIndex(), 1) :
(_pmU!=null)? blkIn2 : blkIn3;
MatrixBlock mbW = (_qop.hasFourInputs())? blkIn3 : null;

//execute core operation
blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut);

//create return tuple
MatrixIndexes ixOut = createOutputIndexes(ixIn);
return new Tuple2<>(ixOut, blkOut);
public Tuple2<Long, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
return new Tuple2<>(_row?arg._1().getRowIndex():arg._1().getColumnIndex(), arg._2());
}
}

/**
* Note: never called for wsigmoid/wdivmm (only wsloss)
*/
private static class RDDQuaternaryFunction4 extends RDDQuaternaryBaseFunction //four rdd input
implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock>,MatrixBlock>>,MatrixIndexes,MatrixBlock>
{
private static final long serialVersionUID = 7328911771600289250L;

public RDDQuaternaryFunction4( QuaternaryOperator qop ) {
super(qop, null, null);
private static class ExtractIndexWith implements PairFunction<Tuple2<MatrixIndexes,MatrixBlock[]>, Long, Tuple2<MatrixIndexes,MatrixBlock[]>> {
private static final long serialVersionUID = -966212318512764461L;
private final boolean _row;
public ExtractIndexWith(boolean row) {
_row = row;
}
@Override
public Tuple2<Long, Tuple2<MatrixIndexes, MatrixBlock[]>> call(Tuple2<MatrixIndexes, MatrixBlock[]> arg)
throws Exception
{
return new Tuple2<>(_row?arg._1().getRowIndex():arg._1().getColumnIndex(), arg);
}
}

private static class ToArray implements PairFunction<Tuple2<MatrixIndexes,Tuple2<MatrixBlock,MatrixBlock>>, MatrixIndexes, MatrixBlock[]> {
private static final long serialVersionUID = -4856316007590144978L;

@Override
public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>, MatrixBlock>> arg0)
public Tuple2<MatrixIndexes, MatrixBlock[]> call(Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg)
throws Exception
{
MatrixIndexes ixIn1 = arg0._1();
MatrixBlock blkIn1 = arg0._2()._1()._1()._1();
MatrixBlock mbU = arg0._2()._1()._1()._2();
MatrixBlock mbV = arg0._2()._1()._2();
MatrixBlock mbW = arg0._2()._2();
MatrixBlock blkOut = new MatrixBlock();

//execute core operation
blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut);

//create return tuple
MatrixIndexes ixOut = createOutputIndexes(ixIn1);
return new Tuple2<>(ixOut, blkOut);
return new Tuple2<>(arg._1(), new MatrixBlock[]{arg._2()._1(),arg._2()._2()});
}
}

private static class TransposeFactorIndexesFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID = -2571724736131823708L;

private static class Unpack implements PairFunction<Tuple2<Long, Tuple2<Tuple2<MatrixIndexes,MatrixBlock[]>,MatrixBlock>>, MatrixIndexes, MatrixBlock[]> {
private static final long serialVersionUID = 3812660351709830714L;
@Override
public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) {
MatrixIndexes ixIn = arg0._1();
MatrixBlock blkIn = arg0._2();

//swap the matrix indexes
MatrixIndexes ixOut = new MatrixIndexes(ixIn.getColumnIndex(), ixIn.getRowIndex());
MatrixBlock blkOut = new MatrixBlock(blkIn);

//output new tuple
return new Tuple2<>(ixOut,blkOut);
public Tuple2<MatrixIndexes, MatrixBlock[]> call(
Tuple2<Long, Tuple2<Tuple2<MatrixIndexes, MatrixBlock[]>, MatrixBlock>> arg) throws Exception
{
return new Tuple2<>(arg._2()._1()._1(), //matrix indexes
ArrayUtils.addAll(arg._2()._1()._2(), arg._2()._2())); //array of matrix blocks
}
}
}

0 comments on commit 33de453

Please sign in to comment.