diff --git a/include/main_workflow.hpp b/include/main_workflow.hpp index 19f6faa5..2ddc3c19 100644 --- a/include/main_workflow.hpp +++ b/include/main_workflow.hpp @@ -22,6 +22,8 @@ void BuildROM(MPI_Comm comm); void TrainROM(MPI_Comm comm); // supremizer-enrichment, hypre-reduction optimization, etc.. void AuxiliaryTrainROM(MPI_Comm comm); +// Input parsing routine to list out all snapshot files for training a basis. +void FindSnapshotFilesForBasis(const std::string &basis_tag, const std::string &default_filename, std::vector &file_list); // return relative error if comparing solution. double SingleRun(MPI_Comm comm, const std::string output_file = ""); diff --git a/src/main_workflow.cpp b/src/main_workflow.cpp index 22aa122c..d1d56555 100644 --- a/src/main_workflow.cpp +++ b/src/main_workflow.cpp @@ -240,6 +240,11 @@ void TrainROM(MPI_Comm comm) for (int p = 0; p < basis_tags.size(); p++) { std::vector file_list(0); + std::string default_filename = sample_generator->GetBaseFilename(sample_generator->GetSamplePrefix(), basis_tags[p]); + default_filename += "_snapshot"; + FindSnapshotFilesForBasis(basis_tags[p], default_filename, file_list); + assert(file_list.size() > 0); + int num_basis; // if optional inputs are specified, parse them first. @@ -249,21 +254,9 @@ void TrainROM(MPI_Comm comm) YAML::Node basis_tag_input = config.LookUpFromDict("name", basis_tags[p], basis_list); // If basis_tags[p] has additional inputs, parse them. + // parse tag-specific number of basis. if (basis_tag_input) - { - // parse tag-specific number of basis. num_basis = config.GetOptionFromDict("number_of_basis", num_basis_default, basis_tag_input); - - // parse the sample snapshot file list. - file_list = config.GetOptionFromDict>( - "snapshot_files", std::vector(0), basis_tag_input); - YAML::Node snapshot_format = config.FindNodeFromDict("snapshot_format", basis_tag_input); - if (snapshot_format) - { - FilenameParam snapshot_param("", snapshot_format); - snapshot_param.ParseFilenames(file_list); - } - } // if (basis_tag_input) else num_basis = num_basis_default; } @@ -273,14 +266,6 @@ void TrainROM(MPI_Comm comm) assert(num_basis > 0); - // if additional inputs are not specified for snapshot files, set default snapshot file name. - if (file_list.size() == 0) - { - std::string filename = sample_generator->GetBaseFilename(sample_generator->GetSamplePrefix(), basis_tags[p]); - filename += "_snapshot"; - file_list.push_back(filename); - } - sample_generator->FormReducedBasis(basis_prefix, basis_tags[p], file_list, num_basis); } // for (int p = 0; p < basis_tags.size(); p++) @@ -314,7 +299,60 @@ void AuxiliaryTrainROM(MPI_Comm comm) delete solver; } - // TODO: EQP weight optimization procedure. + /* EQP NNLS procedure */ + std::string eqp_str = config.GetOption("model_reduction/nonlinear_handling", "none"); + if (eqp_str == "eqp") + { + MultiBlockSolver *test = NULL; + test = InitSolver(); + test->InitVariables(); + + if (!test->IsNonlinear()) + { + delete test; + return; + } + + if (!test->UseRom()) mfem_error("ROM must be enabled for EQP training!\n"); + + test->LoadReducedBasis(); + + delete test; + } +} + +void FindSnapshotFilesForBasis(const std::string &basis_tag, const std::string &default_filename, std::vector &file_list) +{ + file_list.clear(); + + // tag-specific optional inputs. + YAML::Node basis_list = config.FindNode("basis/tags"); + + // if optional inputs are specified, parse them first. + if (basis_list) + { + // Find if additional inputs are specified for basis_tag. + YAML::Node basis_tag_input = config.LookUpFromDict("name", basis_tag, basis_list); + + // If basis_tag has additional inputs, parse them. + if (basis_tag_input) + { + // parse the sample snapshot file list. + file_list = config.GetOptionFromDict>( + "snapshot_files", std::vector(0), basis_tag_input); + YAML::Node snapshot_format = config.FindNodeFromDict("snapshot_format", basis_tag_input); + // if file list is specified with a format, parse through the format. + if (snapshot_format) + { + FilenameParam snapshot_param("", snapshot_format); + snapshot_param.ParseFilenames(file_list); + } + } // if (basis_tag_input) + } + + // if additional inputs are not specified for snapshot files, set default snapshot file name. + if (file_list.size() == 0) + file_list.push_back(default_filename); } void BuildROM(MPI_Comm comm)