Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

printErrors method for HybridFactorGraph #1670

Merged
merged 13 commits into from
Jan 7, 2024
30 changes: 30 additions & 0 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,36 @@ HybridValues HybridBayesNet::sample() const {
return sample(&kRandomNumberGenerator);
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> result(0.0);

// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, compute error for all assignments.
result = result + gm->errorTree(continuousValues);

} else if (auto gc = conditional->asGaussian()) {
// If continuous, get the error and add it to the result
double error = gc->error(continuousValues);
// Add the computed error to every leaf of the result tree.
result = result.apply(
[error](double leaf_value) { return leaf_value + error; });

} else if (auto dc = conditional->asDiscrete()) {
// If discrete, add the discrete error in the right branch
result = result.apply(
[dc](const Assignment<Key> &assignment, double leaf_value) {
return leaf_value + dc->error(DiscreteValues(assignment));
});
}
}

return result;
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
const VectorValues &continuousValues) const {
Expand Down
17 changes: 17 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,23 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;

/**
* @brief Error method using HybridValues which returns specific error for
* assignment.
*/
using Base::error;

/**
* @brief Compute log probability for each discrete assignment,
* and return as a tree.
*
* @param continuousValues Continuous values at which
* to compute the log probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;

Expand Down
79 changes: 79 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,85 @@ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
}

/* ************************************************************************ */
void HybridGaussianFactorGraph::printErrors(
const HybridValues &values, const std::string &str,
const KeyFormatter &keyFormatter,
const std::function<bool(const Factor * /*factor*/,
double /*whitenedError*/, size_t /*index*/)>
&printCondition) const {
std::cout << str << "size: " << size() << std::endl << std::endl;

std::stringstream ss;

for (size_t i = 0; i < factors_.size(); i++) {
auto &&factor = factors_[i];
std::cout << "Factor " << i << ": ";

// Clear the stringstream
ss.str(std::string());

if (auto gmf = std::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
gmf->errorTree(values.continuous()).print("", keyFormatter);
std::cout << std::endl;
}
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);

if (hc->isContinuous()) {
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
} else if (hc->isDiscrete()) {
std::cout << "error = ";
hc->asDiscrete()->errorTree().print("", keyFormatter);
std::cout << "\n";
} else {
// Is hybrid
std::cout << "error = ";
hc->asMixture()->errorTree(values.continuous()).print();
std::cout << "\n";
}
}
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
const double errorValue = (factor != nullptr ? gf->error(values) : .0);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass

if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << errorValue << "\n";
}
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
df->errorTree().print("", keyFormatter);
}

} else {
continue;
}

std::cout << "\n";
}
std::cout.flush();
}

/* ************************************************************************ */
static GaussianFactorGraphTree addGaussian(
const GaussianFactorGraphTree &gfgTree,
Expand Down
16 changes: 13 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
/// @{

// TODO(dellaert): customize print and equals.
// void print(const std::string& s = "HybridGaussianFactorGraph",
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const
// override;
// void print(
// const std::string& s = "HybridGaussianFactorGraph",
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;

void printErrors(
const HybridValues& values,
const std::string& str = "HybridGaussianFactorGraph: ",
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const std::function<bool(const Factor* /*factor*/,
double /*whitenedError*/, size_t /*index*/)>&
printCondition =
[](const Factor*, double, size_t) { return true; }) const;

// bool equals(const This& fg, double tol = 1e-9) const override;

/// @}
Expand Down
92 changes: 92 additions & 0 deletions gtsam/hybrid/HybridNonlinearFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,98 @@ void HybridNonlinearFactorGraph::print(const std::string& s,
}
}

/* ************************************************************************* */
void HybridNonlinearFactorGraph::printErrors(
const HybridValues& values, const std::string& str,
const KeyFormatter& keyFormatter,
const std::function<bool(const Factor* /*factor*/, double /*whitenedError*/,
size_t /*index*/)>& printCondition) const {
std::cout << str << "size: " << size() << std::endl << std::endl;

std::stringstream ss;

for (size_t i = 0; i < factors_.size(); i++) {
auto&& factor = factors_[i];
std::cout << "Factor " << i << ": ";

// Clear the stringstream
ss.str(std::string());

if (auto mf = std::dynamic_pointer_cast<MixtureFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
mf->errorTree(values.nonlinear()).print("", keyFormatter);
std::cout << std::endl;
}
} else if (auto gmf =
std::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
gmf->errorTree(values.continuous()).print("", keyFormatter);
std::cout << std::endl;
}
} else if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
gm->errorTree(values.continuous()).print("", keyFormatter);
std::cout << std::endl;
}
} else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
const double errorValue = (factor != nullptr ? nf->error(values) : .0);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass

if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << errorValue << "\n";
}
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
const double errorValue = (factor != nullptr ? gf->error(values) : .0);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass

if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << errorValue << "\n";
}
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
df->errorTree().print("", keyFormatter);
std::cout << std::endl;
}

} else {
continue;
}

std::cout << "\n";
}
std::cout.flush();
}

/* ************************************************************************* */
HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
const Values& continuousValues) const {
Expand Down
12 changes: 11 additions & 1 deletion gtsam/hybrid/HybridNonlinearFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
protected:
public:
using Base = HybridFactorGraph;
using This = HybridNonlinearFactorGraph; ///< this class
using This = HybridNonlinearFactorGraph; ///< this class
using shared_ptr = std::shared_ptr<This>; ///< shared_ptr to This

using Values = gtsam::Values; ///< backwards compatibility
Expand Down Expand Up @@ -63,6 +63,16 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
const std::string& s = "HybridNonlinearFactorGraph",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;

/** print errors along with factors*/
void printErrors(
const HybridValues& values,
const std::string& str = "HybridNonlinearFactorGraph: ",
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const std::function<bool(const Factor* /*factor*/,
double /*whitenedError*/, size_t /*index*/)>&
printCondition =
[](const Factor*, double, size_t) { return true; }) const;

/// @}
/// @name Standard Interface
/// @{
Expand Down
39 changes: 39 additions & 0 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,45 @@ TEST(HybridBayesNet, Choose) {
*gbn.at(3)));
}

/* ****************************************************************************/
// Test error for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
TEST(HybridBayesNet, Error) {
const auto continuousConditional = GaussianConditional::sharedMeanAndStddev(
X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0);

const SharedDiagonal model0 = noiseModel::Diagonal::Sigmas(Vector1(2.0)),
model1 = noiseModel::Diagonal::Sigmas(Vector1(3.0));

const auto conditional0 = std::make_shared<GaussianConditional>(
X(1), Vector1::Constant(5), I_1x1, model0),
conditional1 = std::make_shared<GaussianConditional>(
X(1), Vector1::Constant(2), I_1x1, model1);

auto gm =
new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1});
// Create hybrid Bayes net.
HybridBayesNet bayesNet;
bayesNet.push_back(continuousConditional);
bayesNet.emplace_back(gm);
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));

// Create values at which to evaluate.
HybridValues values;
values.insert(asiaKey, 0);
values.insert(X(0), Vector1(-6));
values.insert(X(1), Vector1(1));

AlgebraicDecisionTree<Key> actual_errors =
bayesNet.errorTree(values.continuous());

// Regression.
// Manually added all the error values from the 3 conditional types.
AlgebraicDecisionTree<Key> expected_errors(
{Asia}, std::vector<double>{2.33005033585, 5.38619084965});

EXPECT(assert_equal(expected_errors, actual_errors));
}

/* ****************************************************************************/
// Test Bayes net optimize
TEST(HybridBayesNet, OptimizeAssignment) {
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
}

/* ****************************************************************************/
// Check that the factor graph unnormalized probability is proportional to the
// Check that the bayes net unnormalized probability is proportional to the
// Bayes net probability for the given measurements.
bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
const HybridBayesNet &posterior, size_t num_samples = 100) {
Expand Down
Loading