From 1dacb912d122c8c2277b9f1c44dc5b76bdcb3993 Mon Sep 17 00:00:00 2001 From: Kevin Chung Date: Wed, 10 Apr 2024 11:55:28 -0700 Subject: [PATCH] main workflow: split CollectSamples for TrainEQP. --- include/main_workflow.hpp | 5 +++- include/sample_generator.hpp | 15 ++++++++--- src/main_workflow.cpp | 39 ++++++++++++----------------- src/sample_generator.cpp | 48 +++++++++++++++++++++++++++++++----- 4 files changed, 73 insertions(+), 34 deletions(-) diff --git a/include/main_workflow.hpp b/include/main_workflow.hpp index 1225f1fe..5541ddc3 100644 --- a/include/main_workflow.hpp +++ b/include/main_workflow.hpp @@ -18,10 +18,13 @@ SampleGenerator* InitSampleGenerator(MPI_Comm comm); std::vector GetGlobalBasisTagList(const TopologyHandlerMode &topol_mode, const TrainMode &train_mode, bool separate_variable_basis); void GenerateSamples(MPI_Comm comm); +void CollectSamples(SampleGenerator *sample_generator); void BuildROM(MPI_Comm comm); void TrainROM(MPI_Comm comm); -// supremizer-enrichment, hypre-reduction optimization, etc.. +// supremizer-enrichment etc.. void AuxiliaryTrainROM(MPI_Comm comm, SampleGenerator *sample_generator); +// EQP training, could include hypre-reduction optimization. +void TrainEQP(MPI_Comm comm); // Input parsing routine to list out all snapshot files for training a basis. void FindSnapshotFilesForBasis(const std::string &basis_tag, const std::string &default_filename, std::vector &file_list); // return relative error if comparing solution. diff --git a/include/sample_generator.hpp b/include/sample_generator.hpp index 89df2e73..71ede8c5 100644 --- a/include/sample_generator.hpp +++ b/include/sample_generator.hpp @@ -96,11 +96,18 @@ class SampleGenerator void ReportStatus(const int &sample_idx); - // Perform SVD over snapshot for basis_tag. Calculate the coverage for ref_num_basis (optional). - void FormReducedBasis(const std::string &basis_prefix, + /* + Collect snapshot matrices from the file list to the specified basis tag. + */ + void CollectSnapshots(const std::string &basis_prefix, const std::string &basis_tag, - const std::vector &file_list, - const int &ref_num_basis = -1); + const std::vector &file_list); + /* + Perform SVD over snapshot for basis_tag. + Calculate the energy fraction for num_basis. + CollectSnapshots must be executed before this. + */ + void FormReducedBasis(const std::string &basis_prefix); private: const int GetDimFromSnapshots(const std::string &filename); diff --git a/src/main_workflow.cpp b/src/main_workflow.cpp index 08756d63..eb83c4ac 100644 --- a/src/main_workflow.cpp +++ b/src/main_workflow.cpp @@ -224,9 +224,9 @@ void GenerateSamples(MPI_Comm comm) config.dict_ = dict0; } -void TrainROM(MPI_Comm comm) +void CollectSamples(SampleGenerator *sample_generator) { - SampleGenerator *sample_generator = InitSampleGenerator(comm); + assert(sample_generator); TopologyHandlerMode topol_mode = SetTopologyHandlerMode(); TrainMode train_mode = SetTrainMode(); @@ -239,7 +239,6 @@ void TrainROM(MPI_Comm comm) YAML::Node basis_list = config.FindNode("basis/tags"); std::string basis_prefix = config.GetOption("basis/prefix", "basis"); - const int num_basis_default = config.GetOption("basis/number_of_basis", -1); // loop over the required basis tag list. for (int p = 0; p < basis_tags.size(); p++) @@ -250,29 +249,18 @@ void TrainROM(MPI_Comm comm) FindSnapshotFilesForBasis(basis_tags[p], default_filename, file_list); assert(file_list.size() > 0); - int num_basis; + sample_generator->CollectSnapshots(basis_prefix, basis_tags[p], file_list); + } // for (int p = 0; p < basis_tags.size(); p++) +} - // if optional inputs are specified, parse them first. - if (basis_list) - { - // Find if additional inputs are specified for basis_tags[p]. - YAML::Node basis_tag_input = config.LookUpFromDict("name", basis_tags[p], basis_list); - - // If basis_tags[p] has additional inputs, parse them. - // parse tag-specific number of basis. - if (basis_tag_input) - num_basis = config.GetOptionFromDict("number_of_basis", num_basis_default, basis_tag_input); - else - num_basis = num_basis_default; - } - else - // if additional inputs are not specified, use default number of basis. - num_basis = num_basis_default; +void TrainROM(MPI_Comm comm) +{ + SampleGenerator *sample_generator = InitSampleGenerator(comm); - assert(num_basis > 0); + std::string basis_prefix = config.GetOption("basis/prefix", "basis"); + CollectSamples(sample_generator); - sample_generator->FormReducedBasis(basis_prefix, basis_tags[p], file_list, num_basis); - } // for (int p = 0; p < basis_tags.size(); p++) + sample_generator->FormReducedBasis(basis_prefix); AuxiliaryTrainROM(comm, sample_generator); @@ -330,6 +318,11 @@ void AuxiliaryTrainROM(MPI_Comm comm, SampleGenerator *sample_generator) } } +void TrainEQP(MPI_Comm comm) +{ + +} + void FindSnapshotFilesForBasis(const std::string &basis_tag, const std::string &default_filename, std::vector &file_list) { file_list.clear(); diff --git a/src/sample_generator.cpp b/src/sample_generator.cpp index 9e99682c..1b02648f 100644 --- a/src/sample_generator.cpp +++ b/src/sample_generator.cpp @@ -243,10 +243,9 @@ void SampleGenerator::ReportStatus(const int &sample_idx) printf("==============================================\n"); } -void SampleGenerator::FormReducedBasis(const std::string &basis_prefix, +void SampleGenerator::CollectSnapshots(const std::string &basis_prefix, const std::string &basis_tag, - const std::vector &file_list, - const int &ref_num_basis) + const std::vector &file_list) { // Get dimension from the first snapshot file. const int fom_num_vdof = GetDimFromSnapshots(file_list[0]); @@ -265,9 +264,46 @@ void SampleGenerator::FormReducedBasis(const std::string &basis_prefix, for (int s = 0; s < file_list.size(); s++) basis_generator->loadSamples(file_list[s], "snapshot", 1e9, CAROM::Database::formats::HDF5_MPIO); +} + +void SampleGenerator::FormReducedBasis(const std::string &basis_prefix) +{ + assert(snapshot_generators.Size() > 0); + assert(snapshot_generators.Size() == basis_tags.size()); + + const int num_basis_default = config.GetOption("basis/number_of_basis", -1); + int num_basis; + std::string basis_name; - basis_generator->endSamples(); // save the merged basis file - SaveSV(basis_generator, basis_name, ref_num_basis); + // tag-specific optional inputs. + YAML::Node basis_list = config.FindNode("basis/tags"); + for (int k = 0; k < snapshot_generators.Size(); k++) + { + assert(snapshot_generators[k]); + assert(snapshot_generators[k]->getNumSamples() > 0); + snapshot_generators[k]->endSamples(); + + // if optional inputs are specified, parse them first. + if (basis_list) + { + // Find if additional inputs are specified for basis_tags[p]. + YAML::Node basis_tag_input = config.LookUpFromDict("name", basis_tags[k], basis_list); + + // If basis_tags[p] has additional inputs, parse them. + // parse tag-specific number of basis. + if (basis_tag_input) + num_basis = config.GetOptionFromDict("number_of_basis", num_basis_default, basis_tag_input); + else + num_basis = num_basis_default; + } + else + // if additional inputs are not specified, use default number of basis. + num_basis = num_basis_default; + + assert(num_basis > 0); + basis_name = GetBaseFilename(basis_prefix, basis_tags[k]); + SaveSV(snapshot_generators[k], basis_name, num_basis); + } } const int SampleGenerator::GetDimFromSnapshots(const std::string &filename) @@ -324,7 +360,7 @@ void SampleGenerator::SaveSV(CAROM::BasisGenerator *basis_generator, const std:: } if (rom_sv->dim() == ref_num_basis) coverage = total; coverage /= total; - printf("Coverage: %.7f%%\n", coverage * 100.0); + printf("Energy fraction with %d basis: %.7f%%\n", ref_num_basis, coverage * 100.0); // TODO: hdf5 format + parallel case. std::string filename = prefix + "_sv.txt";