Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/issue 587 expose hessians #590

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
30 changes: 15 additions & 15 deletions rstan/rstan/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export(
stan_model,
stanc,
stanc_builder,
stan_version,
stan_version,
stan,
stan_rdump,
read_rdump,
Expand All @@ -35,12 +35,12 @@ export(
rstan_options,
As.mcmc.list,
set_cppo,
stan_plot,
stan_trace,
stan_hist,
stan_dens,
stan_plot,
stan_trace,
stan_hist,
stan_dens,
stan_scat,
stan_ac,
stan_ac,
stan_diag,
stan_par,
stan_rhat,
Expand All @@ -54,8 +54,8 @@ export(
cpp_object_initializer,
# get_rstan.options
check_hmc_diagnostics,
check_divergences,
check_treedepth,
check_divergences,
check_treedepth,
check_energy,
get_divergent_iterations,
get_max_treedepth_iterations,
Expand All @@ -70,21 +70,21 @@ export(
)

exportClasses(
stanmodel, stanfit
)
stanmodel, stanfit
)
exportMethods(
# print, plot,
# extract,
# print, plot,
# extract,
optimizing, vb,
get_cppcode, get_cxxflags, # for stanmodel
get_cppcode, get_cxxflags, # for stanmodel
show, sampling, summary, extract,
traceplot, plot, get_stancode, get_inits, get_seed, get_cppo_mode,
log_prob, grad_log_prob,
log_prob, grad_log_prob, hessian_log_prob, hessian_times_vector_log_prob,
unconstrain_pars, constrain_pars, get_num_upars,
get_seeds,
get_adaptation_info,
get_sampler_params,
get_logposterior,
get_logposterior,
get_posterior_mean,
get_elapsed_time,
get_stanmodel,
Expand Down
24 changes: 24 additions & 0 deletions rstan/rstan/R/stanfit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,30 @@ setMethod("grad_log_prob", signature = "stanfit",
[email protected]$stan_fit_instance$grad_log_prob(upars, adjust_transform)
})

if (!isGeneric("hessian_log_prob")) {
setGeneric(name = "hessian_log_prob",
def = function(object, ...) { standardGeneric("hessian_log_prob") })
}

setMethod("hessian_log_prob", signature = "stanfit",
function(object, upars, adjust_transform = TRUE) {
if (!is_sfinstance_valid(object))
stop("the model object is not created or not valid")
[email protected]$stan_fit_instance$hessian_log_prob(upars, adjust_transform)
})

if (!isGeneric("hessian_times_vector_log_prob")) {
setGeneric(name = "hessian_times_vector_log_prob",
def = function(object, ...) { standardGeneric("hessian_times_vector_log_prob") })
}

setMethod("hessian_times_vector_log_prob", signature = "stanfit",
function(object, upars, v, adjust_transform = TRUE) {
if (!is_sfinstance_valid(object))
stop("the model object is not created or not valid")
[email protected]$stan_fit_instance$hessian_times_vector_log_prob(upars, v, adjust_transform)
})

setMethod("traceplot", signature = "stanfit",
function(object, pars, include = TRUE, unconstrain = FALSE,
inc_warmup = FALSE, window = NULL, nrow = NULL, ncol = NULL,
Expand Down
4 changes: 4 additions & 0 deletions rstan/rstan/inst/include/rstan/rcpp_module_def_for_rstan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ RCPP_MODULE(stan_fit4%model_name%_mod){
&rstan::stan_fit<%model_name%_namespace::%model_name%, boost::random::ecuyer1988>::param_oi_tidx)
.method("grad_log_prob",
&rstan::stan_fit<%model_name%_namespace::%model_name%, boost::random::ecuyer1988>::grad_log_prob)
.method("hessian_times_vector_log_prob",
&rstan::stan_fit<%model_name%_namespace::%model_name%, boost::random::ecuyer1988>::hessian_times_vector_log_prob)
.method("hessian_log_prob",
&rstan::stan_fit<%model_name%_namespace::%model_name%, boost::random::ecuyer1988>::hessian_log_prob)
.method("log_prob",
&rstan::stan_fit<%model_name%_namespace::%model_name%, boost::random::ecuyer1988>::log_prob)
.method("unconstrain_pars",
Expand Down
114 changes: 113 additions & 1 deletion rstan/rstan/inst/include/rstan/stan_fit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
// void R_CheckUserInterrupt(void);


// TODO: where are these includes actually supposed to go?
#include <stan/model/hessian.hpp>
#include <stan/model/hessian_times_vector.hpp>
#include <RcppEigen.h>

// REF: cmdstan: src/cmdstan/command.hpp
#include <stan/callbacks/interrupt.hpp>
#include <stan/callbacks/stream_logger.hpp>
Expand Down Expand Up @@ -1001,7 +1006,7 @@ class stan_fit {
get_all_flatnames(names_oi_, dims_oi_, fnames_oi_, true);
// get_all_indices_col2row(dims_, midx_for_col2row);
}

stan_fit(SEXP data, SEXP seed, SEXP cxxf) :
data_(data),
model_(data_, Rcpp::as<boost::uint32_t>(seed), &rstan::io::rcout),
Expand Down Expand Up @@ -1177,6 +1182,113 @@ class stan_fit {
END_RCPP
}


/**
* Expose the hessian_log_prob of the model to stan_fit so R user
* can call this function.
*
* @param upar The real parameters on the unconstrained
* space.
* @param jacobian_adjust_transform TRUE/FALSE, whether
* we add the term due to the transform from constrained
* space to unconstrained space implicitly done in Stan.
*/
SEXP hessian_log_prob(SEXP upar, SEXP jacobian_adjust_transform) {
BEGIN_RCPP
std::vector<double> std_par_r = Rcpp::as<std::vector<double> >(upar);
Eigen::Matrix<double, Eigen::Dynamic, 1> par_r =
Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, 1>>(
std_par_r.data(), std_par_r.size());

if (par_r.size() != model_.num_params_r()) {
std::stringstream msg;
msg << "Number of unconstrained parameters does not match "
"that of the model ("
<< par_r.size() << " vs "
<< model_.num_params_r()
<< ").";
throw std::domain_error(msg.str());
}
//std::vector<int> par_i(model_.num_params_i(), 0);
Eigen::Matrix<double, Eigen::Dynamic, 1> grad_f;
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> hess_f;

double lp;
if (Rcpp::as<bool>(jacobian_adjust_transform))
stan::model::log_prob_hessian<true,true>(
model_, par_r, lp, grad_f, hess_f, &rstan::io::rcout);
else
stan::model::log_prob_hessian<true,false>(
model_, par_r, lp, grad_f, hess_f, &rstan::io::rcout);

Rcpp::NumericVector grad = Rcpp::wrap(grad_f);
Rcpp::NumericMatrix hess = Rcpp::wrap(hess_f);

hess.attr("log_prob") = lp;
hess.attr("gradient") = grad;
SEXP __sexp_result;
PROTECT(__sexp_result = Rcpp::wrap(hess));
UNPROTECT(1);
return __sexp_result;
END_RCPP
}

/**
* Expose the hessian_log_prob of the model to stan_fit so R user
* can call this function.
*
* @param upar The real parameters on the unconstrained
* space.
* @param jacobian_adjust_transform TRUE/FALSE, whether
* we add the term due to the transform from constrained
* space to unconstrained space implicitly done in Stan.
*/
SEXP hessian_times_vector_log_prob(
SEXP upar, SEXP v, SEXP jacobian_adjust_transform) {

BEGIN_RCPP

// Is there a way to convert directly from SEXP to an Eigen matrix?
std::vector<double> std_par_r = Rcpp::as<std::vector<double> >(upar);
Eigen::Matrix<double, Eigen::Dynamic, 1> par_r =
Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, 1>>(
std_par_r.data(), std_par_r.size());

std::vector<double> v_r = Rcpp::as<std::vector<double> >(v);
Eigen::Matrix<double, Eigen::Dynamic, 1> v_vec =
Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, 1>>(
v_r.data(), v_r.size());

if (par_r.size() != model_.num_params_r()) {
std::stringstream msg;
msg << "Number of unconstrained parameters does not match "
"that of the model ("
<< par_r.size() << " vs "
<< model_.num_params_r()
<< ").";
throw std::domain_error(msg.str());
}
//std::vector<int> par_i(model_.num_params_i(), 0);
Eigen::Matrix<double, Eigen::Dynamic, 1> hess_f_dot_v;

double lp;
if (Rcpp::as<bool>(jacobian_adjust_transform))
stan::model::log_prob_hessian_times_vector<true,true>(
model_, par_r, v_vec, lp, hess_f_dot_v, &rstan::io::rcout);
else
stan::model::log_prob_hessian_times_vector<true,false>(
model_, par_r, v_vec, lp, hess_f_dot_v, &rstan::io::rcout);
Rcpp::NumericVector hess_f_dot_v_r = Rcpp::wrap(hess_f_dot_v);

// TODO: do this differently
hess_f_dot_v_r.attr("log_prob") = lp;
SEXP __sexp_result;
PROTECT(__sexp_result = Rcpp::wrap(hess_f_dot_v_r));
UNPROTECT(1);
return __sexp_result;
END_RCPP
}

/**
* Return the number of unconstrained parameters
*/
Expand Down