Caffe2::Operator
Caffe2 has the concept of an operator, which corresponds to TensorFlow's Op. Unlike an Op, an operator is usually accompanied by a gradient operator (GradientOp).
Let us take ReluOp and ReluGradientOp as an example. All operators are classes derived from Operator<Context>.
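```cpp
#include "caffe2/core/operator.h"

namespace caffe2 {

// Forward operator: computes Y = max(X, 0).
template <typename T, class Context>
class ReluOp final : public Operator<Context> {
 public:
  // Defines a constructor that forwards (OperatorDef, Workspace*) to
  // Operator<Context>, plus a trivial destructor.
  USE_SIMPLE_CTOR_DTOR(ReluOp);
  // Brings context_, Input(), Output(), etc. from the dependent base into scope.
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  bool RunOnDevice() override;
};

// Gradient operator: computes dX, the gradient with respect to X.
template <typename T, class Context>
class ReluGradientOp final : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(ReluGradientOp);
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  bool RunOnDevice() override;
};

}  // namespace caffe2
```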
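The declarations above leave RunOnDevice undefined. As a rough sketch of what the two operators compute, the helpers below work on plain std::vector buffers on the CPU. ReluForward and ReluBackward are hypothetical names, not Caffe2 functions, and the sketch assumes the gradient operator receives the forward output Y and the output gradient dY.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not Caffe2 API): forward ReLU, Y[i] = max(X[i], 0).
template <typename T>
std::vector<T> ReluForward(const std::vector<T>& X) {
  std::vector<T> Y(X.size());
  for (std::size_t i = 0; i < X.size(); ++i) {
    Y[i] = X[i] > T(0) ? X[i] : T(0);
  }
  return Y;
}

// Hypothetical helper (not Caffe2 API): backward ReLU.
// dX[i] passes dY[i] through where the forward output Y[i] was positive,
// and is 0 elsewhere.
template <typename T>
std::vector<T> ReluBackward(const std::vector<T>& Y, const std::vector<T>& dY) {
  assert(Y.size() == dY.size());
  std::vector<T> dX(Y.size());
  for (std::size_t i = 0; i < Y.size(); ++i) {
    dX[i] = Y[i] > T(0) ? dY[i] : T(0);
  }
  return dX;
}
```

Because Y is positive exactly where X is positive, the backward pass can use Y alone and does not need to keep X around.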
Operator<Context> has a data member Context context_, which records the current device (for example, which GPU the operator runs on). The constructor of Operator initializes context_ from the OperatorDef proto message passed to it. Then Operator::Operator calls context_.SwitchToDevice(0).
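The construction sequence can be mimicked with toy types. ToyDeviceOption, ToyOperatorDef, ToyContext, and ToyOperator are hypothetical stand-ins, not Caffe2 classes; the sketch only shows context_ being built from the definition and then switched to stream 0, presumably so that work done in subclass constructors already runs on the right device.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for a device option carried by the operator definition.
struct ToyDeviceOption {
  int device_id;  // which device (e.g. GPU index) the operator should use
};

// Hypothetical stand-in for the OperatorDef proto message.
struct ToyOperatorDef {
  std::string type;
  ToyDeviceOption device_option;
};

// Hypothetical context that records its device and current stream.
class ToyContext {
 public:
  explicit ToyContext(const ToyDeviceOption& option)
      : device_id_(option.device_id) {}
  // Only records the stream; a GPU context would presumably also make
  // its device current here.
  void SwitchToDevice(int stream_id) { stream_id_ = stream_id; }
  int device_id() const { return device_id_; }
  int stream_id() const { return stream_id_; }

 private:
  int device_id_;
  int stream_id_ = -1;  // -1 means SwitchToDevice has not been called yet
};

template <class Context>
class ToyOperator {
 public:
  // Build context_ from the definition, then switch to stream 0.
  explicit ToyOperator(const ToyOperatorDef& def)
      : context_(def.device_option) {
    context_.SwitchToDevice(0);
  }
  virtual ~ToyOperator() = default;
  const Context& context() const { return context_; }

 protected:
  Context context_;
};
```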
Operator<Context> has three virtual functions:
- RunOnDevice() = 0 is the pure virtual function that you override;
- Run(stream_id) calls context_.SwitchToDevice(stream_id), RunOnDevice, and context_.FinishDeviceComputation; and
- RunAsync(stream_id) calls context_.SwitchToDevice(stream_id) and RunOnDevice.
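The difference between Run and RunAsync is easiest to see by logging the calls. In this sketch LoggingContext, ToyOperator, and ToyReluOp are hypothetical, and FinishDeviceComputation merely records that it was called, since what the real Context::FinishDeviceComputation does is not settled on this page.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical context that logs every call so the order is observable.
class LoggingContext {
 public:
  void SwitchToDevice(int stream_id) {
    log.push_back("SwitchToDevice(" + std::to_string(stream_id) + ")");
  }
  bool FinishDeviceComputation() {
    log.push_back("FinishDeviceComputation");
    return true;
  }
  std::vector<std::string> log;
};

template <class Context>
class ToyOperator {
 public:
  virtual ~ToyOperator() = default;

  // Switch device, run the operator body, then finish device computation.
  virtual bool Run(int stream_id = 0) {
    context_.SwitchToDevice(stream_id);
    bool started = RunOnDevice();
    bool finished = context_.FinishDeviceComputation();
    return started && finished;
  }

  // Switch device and run the operator body, without the finishing step.
  virtual bool RunAsync(int stream_id = 0) {
    context_.SwitchToDevice(stream_id);
    return RunOnDevice();
  }

  // The only function a concrete operator must implement.
  virtual bool RunOnDevice() = 0;

  const Context& context() const { return context_; }

 protected:
  Context context_;
};

class ToyReluOp final : public ToyOperator<LoggingContext> {
 public:
  bool RunOnDevice() override {
    context_.log.push_back("RunOnDevice");
    return true;
  }
};
```

Calling Run(2) on a ToyReluOp logs SwitchToDevice(2), RunOnDevice, FinishDeviceComputation in that order, while RunAsync stops after RunOnDevice.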
[TODO: Check what Context::FinishDeviceComputation does.]
Operator<Context> also lets a user-overridden RunOnDevice access inputs and outputs through:
Operator<Context> derives from class OperatorBase.
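One plausible reason for a non-template OperatorBase is that operators with different Context types can then be stored and run through a single interface. The toy hierarchy below sketches that idea; ToyOperatorBase, ToyOperator, ToyCPUContext, ToyGPUContext, and CountingOp are all hypothetical, and the rationale is an assumption about the design rather than a description of Caffe2's code.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical non-template base: a common interface for every operator.
class ToyOperatorBase {
 public:
  explicit ToyOperatorBase(std::string type) : type_(std::move(type)) {}
  virtual ~ToyOperatorBase() = default;
  virtual bool Run(int stream_id = 0) = 0;
  const std::string& type() const { return type_; }

 private:
  std::string type_;
};

// Hypothetical contexts; neither does any real device work.
struct ToyCPUContext {
  void SwitchToDevice(int /*stream_id*/) {}
  bool FinishDeviceComputation() { return true; }
};
struct ToyGPUContext {  // placeholder only, no GPU code
  void SwitchToDevice(int /*stream_id*/) {}
  bool FinishDeviceComputation() { return true; }
};

// Context-specific layer: owns context_ and implements Run on top of it.
template <class Context>
class ToyOperator : public ToyOperatorBase {
 public:
  explicit ToyOperator(std::string type) : ToyOperatorBase(std::move(type)) {}
  bool Run(int stream_id = 0) override {
    context_.SwitchToDevice(stream_id);
    bool started = RunOnDevice();
    bool finished = context_.FinishDeviceComputation();
    return started && finished;
  }
  virtual bool RunOnDevice() = 0;

 protected:
  Context context_;
};

// Concrete operator that just counts how many times it ran.
template <class Context>
class CountingOp final : public ToyOperator<Context> {
 public:
  CountingOp(std::string type, int* counter)
      : ToyOperator<Context>(std::move(type)), counter_(counter) {}
  bool RunOnDevice() override {
    ++*counter_;
    return true;
  }

 private:
  int* counter_;
};
```

A container of std::unique_ptr<ToyOperatorBase> can then hold a CountingOp<ToyCPUContext> next to a CountingOp<ToyGPUContext> and call Run on each without knowing its Context type.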