Skip to content

Commit

Permalink
main workflow: split CollectSamples for TrainEQP.
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamer2368 committed Apr 10, 2024
1 parent b500c52 commit 1dacb91
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 34 deletions.
5 changes: 4 additions & 1 deletion include/main_workflow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ SampleGenerator* InitSampleGenerator(MPI_Comm comm);
std::vector<std::string> GetGlobalBasisTagList(const TopologyHandlerMode &topol_mode, const TrainMode &train_mode, bool separate_variable_basis);

void GenerateSamples(MPI_Comm comm);
void CollectSamples(SampleGenerator *sample_generator);
void BuildROM(MPI_Comm comm);
void TrainROM(MPI_Comm comm);
// supremizer-enrichment, hypre-reduction optimization, etc..
// supremizer-enrichment etc..
void AuxiliaryTrainROM(MPI_Comm comm, SampleGenerator *sample_generator);
// EQP training, could include hypre-reduction optimization.
void TrainEQP(MPI_Comm comm);
// Input parsing routine to list out all snapshot files for training a basis.
void FindSnapshotFilesForBasis(const std::string &basis_tag, const std::string &default_filename, std::vector<std::string> &file_list);
// return relative error if comparing solution.
Expand Down
15 changes: 11 additions & 4 deletions include/sample_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,18 @@ class SampleGenerator

void ReportStatus(const int &sample_idx);

// Perform SVD over snapshot for basis_tag. Calculate the coverage for ref_num_basis (optional).
void FormReducedBasis(const std::string &basis_prefix,
/*
Collect snapshot matrices from the file list to the specified basis tag.
*/
void CollectSnapshots(const std::string &basis_prefix,
const std::string &basis_tag,
const std::vector<std::string> &file_list,
const int &ref_num_basis = -1);
const std::vector<std::string> &file_list);
/*
Perform SVD over snapshot for basis_tag.
Calculate the energy fraction for num_basis.
CollectSnapshots must be executed before this.
*/
void FormReducedBasis(const std::string &basis_prefix);

private:
const int GetDimFromSnapshots(const std::string &filename);
Expand Down
39 changes: 16 additions & 23 deletions src/main_workflow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,9 @@ void GenerateSamples(MPI_Comm comm)
config.dict_ = dict0;
}

void TrainROM(MPI_Comm comm)
void CollectSamples(SampleGenerator *sample_generator)
{
SampleGenerator *sample_generator = InitSampleGenerator(comm);
assert(sample_generator);

TopologyHandlerMode topol_mode = SetTopologyHandlerMode();
TrainMode train_mode = SetTrainMode();
Expand All @@ -239,7 +239,6 @@ void TrainROM(MPI_Comm comm)
YAML::Node basis_list = config.FindNode("basis/tags");

std::string basis_prefix = config.GetOption<std::string>("basis/prefix", "basis");
const int num_basis_default = config.GetOption<int>("basis/number_of_basis", -1);

// loop over the required basis tag list.
for (int p = 0; p < basis_tags.size(); p++)
Expand All @@ -250,29 +249,18 @@ void TrainROM(MPI_Comm comm)
FindSnapshotFilesForBasis(basis_tags[p], default_filename, file_list);
assert(file_list.size() > 0);

int num_basis;
sample_generator->CollectSnapshots(basis_prefix, basis_tags[p], file_list);
} // for (int p = 0; p < basis_tags.size(); p++)
}

// if optional inputs are specified, parse them first.
if (basis_list)
{
// Find if additional inputs are specified for basis_tags[p].
YAML::Node basis_tag_input = config.LookUpFromDict("name", basis_tags[p], basis_list);

// If basis_tags[p] has additional inputs, parse them.
// parse tag-specific number of basis.
if (basis_tag_input)
num_basis = config.GetOptionFromDict<int>("number_of_basis", num_basis_default, basis_tag_input);
else
num_basis = num_basis_default;
}
else
// if additional inputs are not specified, use default number of basis.
num_basis = num_basis_default;
void TrainROM(MPI_Comm comm)
{
SampleGenerator *sample_generator = InitSampleGenerator(comm);

assert(num_basis > 0);
std::string basis_prefix = config.GetOption<std::string>("basis/prefix", "basis");
CollectSamples(sample_generator);

sample_generator->FormReducedBasis(basis_prefix, basis_tags[p], file_list, num_basis);
} // for (int p = 0; p < basis_tags.size(); p++)
sample_generator->FormReducedBasis(basis_prefix);

AuxiliaryTrainROM(comm, sample_generator);

Expand Down Expand Up @@ -330,6 +318,11 @@ void AuxiliaryTrainROM(MPI_Comm comm, SampleGenerator *sample_generator)
}
}

void TrainEQP(MPI_Comm comm)
{

}

void FindSnapshotFilesForBasis(const std::string &basis_tag, const std::string &default_filename, std::vector<std::string> &file_list)
{
file_list.clear();
Expand Down
48 changes: 42 additions & 6 deletions src/sample_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,9 @@ void SampleGenerator::ReportStatus(const int &sample_idx)
printf("==============================================\n");
}

void SampleGenerator::FormReducedBasis(const std::string &basis_prefix,
void SampleGenerator::CollectSnapshots(const std::string &basis_prefix,
const std::string &basis_tag,
const std::vector<std::string> &file_list,
const int &ref_num_basis)
const std::vector<std::string> &file_list)
{
// Get dimension from the first snapshot file.
const int fom_num_vdof = GetDimFromSnapshots(file_list[0]);
Expand All @@ -265,9 +264,46 @@ void SampleGenerator::FormReducedBasis(const std::string &basis_prefix,

for (int s = 0; s < file_list.size(); s++)
basis_generator->loadSamples(file_list[s], "snapshot", 1e9, CAROM::Database::formats::HDF5_MPIO);
}

void SampleGenerator::FormReducedBasis(const std::string &basis_prefix)
{
assert(snapshot_generators.Size() > 0);
assert(snapshot_generators.Size() == basis_tags.size());

const int num_basis_default = config.GetOption<int>("basis/number_of_basis", -1);
int num_basis;
std::string basis_name;

basis_generator->endSamples(); // save the merged basis file
SaveSV(basis_generator, basis_name, ref_num_basis);
// tag-specific optional inputs.
YAML::Node basis_list = config.FindNode("basis/tags");
for (int k = 0; k < snapshot_generators.Size(); k++)
{
assert(snapshot_generators[k]);
assert(snapshot_generators[k]->getNumSamples() > 0);
snapshot_generators[k]->endSamples();

// if optional inputs are specified, parse them first.
if (basis_list)
{
// Find if additional inputs are specified for basis_tags[p].
YAML::Node basis_tag_input = config.LookUpFromDict("name", basis_tags[k], basis_list);

// If basis_tags[p] has additional inputs, parse them.
// parse tag-specific number of basis.
if (basis_tag_input)
num_basis = config.GetOptionFromDict<int>("number_of_basis", num_basis_default, basis_tag_input);
else
num_basis = num_basis_default;
}
else
// if additional inputs are not specified, use default number of basis.
num_basis = num_basis_default;

assert(num_basis > 0);
basis_name = GetBaseFilename(basis_prefix, basis_tags[k]);
SaveSV(snapshot_generators[k], basis_name, num_basis);
}
}

const int SampleGenerator::GetDimFromSnapshots(const std::string &filename)
Expand Down Expand Up @@ -324,7 +360,7 @@ void SampleGenerator::SaveSV(CAROM::BasisGenerator *basis_generator, const std::
}
if (rom_sv->dim() == ref_num_basis) coverage = total;
coverage /= total;
printf("Coverage: %.7f%%\n", coverage * 100.0);
printf("Energy fraction with %d basis: %.7f%%\n", ref_num_basis, coverage * 100.0);

// TODO: hdf5 format + parallel case.
std::string filename = prefix + "_sv.txt";
Expand Down

0 comments on commit 1dacb91

Please sign in to comment.