Skip to content

Commit

Permalink
FindSnapshotFilesForBasis routine for parsing snapshot file names.
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamer2368 committed Feb 8, 2024
1 parent 7d1ca41 commit 36746ef
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 22 deletions.
2 changes: 2 additions & 0 deletions include/main_workflow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ void BuildROM(MPI_Comm comm);
void TrainROM(MPI_Comm comm);
// supremizer-enrichment, hypre-reduction optimization, etc..
void AuxiliaryTrainROM(MPI_Comm comm);
// Input parsing routine to list out all snapshot files for training a basis.
void FindSnapshotFilesForBasis(const std::string &basis_tag, const std::string &default_filename, std::vector<std::string> &file_list);
// return relative error if comparing solution.
double SingleRun(MPI_Comm comm, const std::string output_file = "");

Expand Down
82 changes: 60 additions & 22 deletions src/main_workflow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ void TrainROM(MPI_Comm comm)
for (int p = 0; p < basis_tags.size(); p++)
{
std::vector<std::string> file_list(0);
std::string default_filename = sample_generator->GetBaseFilename(sample_generator->GetSamplePrefix(), basis_tags[p]);
default_filename += "_snapshot";
FindSnapshotFilesForBasis(basis_tags[p], default_filename, file_list);
assert(file_list.size() > 0);

int num_basis;

// if optional inputs are specified, parse them first.
Expand All @@ -249,21 +254,9 @@ void TrainROM(MPI_Comm comm)
YAML::Node basis_tag_input = config.LookUpFromDict("name", basis_tags[p], basis_list);

// If basis_tags[p] has additional inputs, parse them.
// parse tag-specific number of basis.
if (basis_tag_input)
{
// parse tag-specific number of basis.
num_basis = config.GetOptionFromDict<int>("number_of_basis", num_basis_default, basis_tag_input);

// parse the sample snapshot file list.
file_list = config.GetOptionFromDict<std::vector<std::string>>(
"snapshot_files", std::vector<std::string>(0), basis_tag_input);
YAML::Node snapshot_format = config.FindNodeFromDict("snapshot_format", basis_tag_input);
if (snapshot_format)
{
FilenameParam snapshot_param("", snapshot_format);
snapshot_param.ParseFilenames(file_list);
}
} // if (basis_tag_input)
else
num_basis = num_basis_default;
}
Expand All @@ -273,14 +266,6 @@ void TrainROM(MPI_Comm comm)

assert(num_basis > 0);

// if additional inputs are not specified for snapshot files, set default snapshot file name.
if (file_list.size() == 0)
{
std::string filename = sample_generator->GetBaseFilename(sample_generator->GetSamplePrefix(), basis_tags[p]);
filename += "_snapshot";
file_list.push_back(filename);
}

sample_generator->FormReducedBasis(basis_prefix, basis_tags[p], file_list, num_basis);
} // for (int p = 0; p < basis_tags.size(); p++)

Expand Down Expand Up @@ -314,7 +299,60 @@ void AuxiliaryTrainROM(MPI_Comm comm)
delete solver;
}

// TODO: EQP weight optimization procedure.
/* EQP NNLS procedure */
std::string eqp_str = config.GetOption<std::string>("model_reduction/nonlinear_handling", "none");
if (eqp_str == "eqp")
{
MultiBlockSolver *test = NULL;
test = InitSolver();
test->InitVariables();

if (!test->IsNonlinear())
{
delete test;
return;
}

if (!test->UseRom()) mfem_error("ROM must be enabled for EQP training!\n");

test->LoadReducedBasis();

delete test;
}
}

void FindSnapshotFilesForBasis(const std::string &basis_tag, const std::string &default_filename, std::vector<std::string> &file_list)
{
file_list.clear();

// tag-specific optional inputs.
YAML::Node basis_list = config.FindNode("basis/tags");

// if optional inputs are specified, parse them first.
if (basis_list)
{
// Find if additional inputs are specified for basis_tag.
YAML::Node basis_tag_input = config.LookUpFromDict("name", basis_tag, basis_list);

// If basis_tag has additional inputs, parse them.
if (basis_tag_input)
{
// parse the sample snapshot file list.
file_list = config.GetOptionFromDict<std::vector<std::string>>(
"snapshot_files", std::vector<std::string>(0), basis_tag_input);
YAML::Node snapshot_format = config.FindNodeFromDict("snapshot_format", basis_tag_input);
// if file list is specified with a format, parse through the format.
if (snapshot_format)
{
FilenameParam snapshot_param("", snapshot_format);
snapshot_param.ParseFilenames(file_list);
}
} // if (basis_tag_input)
}

// if additional inputs are not specified for snapshot files, set default snapshot file name.
if (file_list.size() == 0)
file_list.push_back(default_filename);
}

void BuildROM(MPI_Comm comm)
Expand Down

0 comments on commit 36746ef

Please sign in to comment.