Skip to content

Commit

Permalink
Reverted to per-class sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
pierotofy committed Mar 18, 2023
1 parent 65dcf86 commit 13f781b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pctrain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ int main(int argc, char **argv){
("s,scales", "Number of scales to compute", cxxopts::value<int>()->default_value(MKSTR(NUM_SCALES)))
("t,trees", "Number of trees in the forest", cxxopts::value<int>()->default_value(MKSTR(N_TREES)))
("d,depth", "Maximum depth of trees", cxxopts::value<int>()->default_value(MKSTR(MAX_DEPTH)))
("m,max-samples", "Maximum number of samples for each input point cloud", cxxopts::value<int>()->default_value("1000000"))
("m,max-samples", "Approximate maximum number of samples for each input point cloud", cxxopts::value<int>()->default_value("100000"))
("radius", "Radius size to use for neighbor search (meters)", cxxopts::value<double>()->default_value(MKSTR(RADIUS)))
("e,eval", "Labeled point cloud to use for model accuracy evaluation", cxxopts::value<std::string>()->default_value(""))
("h,help", "Print usage")
Expand Down
22 changes: 7 additions & 15 deletions randomforest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,28 +58,20 @@ RandomForest *train(const std::vector<std::string> filenames,
}
}

size_t totalSamples = 0;
std::vector<size_t> samplesPerLabel(count.size(), 0);
for (size_t i = 0; i < count.size(); i++) totalSamples += count[i];
if (totalSamples == 0) return rtrees;

// Allocate samples based on their relative distribution
for (size_t i = 0; i < samplesPerLabel.size(); i++){
float perc = static_cast<float>(count[i]) / static_cast<float>(totalSamples);
samplesPerLabel[i] = std::min<size_t>(
std::ceil<size_t>(
perc * static_cast<float>(std::min<size_t>(maxSamples, totalSamples))
),
count[i]);
size_t samplesPerLabel = std::numeric_limits<size_t>::max();
for (std::size_t i = 0; i < labels.size(); i++){
if (count[i] > 0) samplesPerLabel = std::min(count[i], samplesPerLabel);
}

samplesPerLabel = std::min<size_t>(samplesPerLabel, maxSamples);
std::vector<std::size_t> added (labels.size(), 0);

std::cout << "Samples per label: " << samplesPerLabel << std::endl;
std::random_shuffle ( idxes.begin(), idxes.end() );

for (const auto &p : idxes){
size_t idx = p.first;
int g = p.second;
if (added[std::size_t(g)] < samplesPerLabel[std::size_t(g)]){
if (added[std::size_t(g)] < samplesPerLabel){
for (std::size_t f = 0; f < features.size(); f++){
ft.push_back(features[f]->getValue(idx));
}
Expand Down

0 comments on commit 13f781b

Please sign in to comment.