Skip to content

Commit

Permalink
Avoid Copying Fit Models (#219)
Browse files Browse the repository at this point in the history
* Modify Prediction to avoid copying the model and fit when used
in a chained call such as:

  model.fit(dataset).predict(features);

* use _t helpers
  • Loading branch information
akleeman authored Mar 28, 2020
1 parent 8605e60 commit 81b8704
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 24 deletions.
17 changes: 17 additions & 0 deletions include/albatross/src/core/declarations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ using mapbox::util::variant;

namespace albatross {

/*
* We frequently inspect for definitions of functions which
* must be defined for const references to objects
* (so that repeated evaluations return the same thing
* and so the computations are not repeatedly copying.)
* This type conversion utility will turn a type `T` into `const T&`
*/
template <class T> struct const_ref {
typedef std::add_lvalue_reference_t<std::add_const_t<T>> type;
};

template <typename T> using const_ref_t = typename const_ref<T>::type;

/*
* Model
*/
Expand All @@ -35,6 +48,10 @@ template <typename T> struct PredictTypeIdentity;
template <typename ModelType, typename FeatureType, typename FitType>
class Prediction;

template <typename ModelType, typename FeatureType, typename FitType>
using PredictionReference =
Prediction<const_ref_t<ModelType>, FeatureType, const_ref_t<FitType>>;

template <typename ModelType, typename FitType> class FitModel;

template <typename Derived> class Fit {};
Expand Down
23 changes: 16 additions & 7 deletions include/albatross/src/core/fit_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,32 @@ template <typename ModelType, typename Fit> class FitModel {
FitModel(const ModelType &model, Fit &&fit)
: model_(model), fit_(std::move(fit)) {}

// When FitModel is an lvalue we store a reference to the fit
// inside the resulting Prediction class.
template <typename PredictFeatureType>
const PredictionReference<ModelType, PredictFeatureType, Fit>
predict(const std::vector<PredictFeatureType> &features) const & {
return PredictionReference<ModelType, PredictFeatureType, Fit>(model_, fit_,
features);
}

// When FitModel is an rvalue the Fit will be a temporary so
// we move it into the Prediction class to be stored there.
template <typename PredictFeatureType>
Prediction<ModelType, PredictFeatureType, Fit>
predict(const std::vector<PredictFeatureType> &features) const {
return Prediction<ModelType, PredictFeatureType, Fit>(model_, fit_,
features);
predict(const std::vector<PredictFeatureType> &features) && {
return Prediction<ModelType, PredictFeatureType, Fit>(
std::move(model_), std::move(fit_), features);
}

template <typename PredictFeatureType>
Prediction<ModelType, Measurement<PredictFeatureType>, Fit>
predict_with_measurement_noise(
auto predict_with_measurement_noise(
const std::vector<PredictFeatureType> &features) const {
std::vector<Measurement<PredictFeatureType>> measurements;
for (const auto &f : features) {
measurements.emplace_back(Measurement<PredictFeatureType>(f));
}
return Prediction<ModelType, Measurement<PredictFeatureType>, Fit>(
model_, fit_, measurements);
return predict(measurements);
}

Fit get_fit() const { return fit_; }
Expand Down
11 changes: 7 additions & 4 deletions include/albatross/src/core/prediction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,17 @@ class JointPredictor {
template <typename ModelType, typename FeatureType, typename FitType>
class Prediction {

using PlainModelType = typename std::decay<ModelType>::type;
using PlainFitType = typename std::decay<FitType>::type;

public:
Prediction(const ModelType &model, const FitType &fit,
Prediction(const PlainModelType &model, const PlainFitType &fit,
const std::vector<FeatureType> &features)
: model_(model), fit_(fit), features_(features) {}

Prediction(const ModelType &model, const FitType &fit,
std::vector<FeatureType> &&features)
: model_(model), fit_(fit), features_(std::move(features)) {}
Prediction(PlainModelType &&model, PlainFitType &&fit,
const std::vector<FeatureType> &features)
: model_(std::move(model)), fit_(std::move(fit)), features_(features) {}

// Mean
template <typename DummyType = FeatureType,
Expand Down
13 changes: 0 additions & 13 deletions include/albatross/src/details/traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,6 @@

namespace albatross {

/*
* We frequently inspect for definitions of functions which
* must be defined for const references to objects
* (so that repeated evaluations return the same thing
* and so the computations are not repeatedly copying.)
* This type conversion utility will turn a type `T` into `const T&`
*/
template <class T> struct const_ref {
typedef
typename std::add_lvalue_reference<typename std::add_const<T>::type>::type
type;
};

/*
* This little trick was borrowed from cereal, you can think of it as
* a function that will always return false ... but that doesn't
Expand Down

0 comments on commit 81b8704

Please sign in to comment.