Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 716330031
  • Loading branch information
jmr authored and copybara-github committed Feb 17, 2025
1 parent 17fcdb4 commit 964b14e
Show file tree
Hide file tree
Showing 19 changed files with 171 additions and 204 deletions.
11 changes: 4 additions & 7 deletions nisaba/interim/grm2/rewrite/base_rule_cascade.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ class BaseRuleCascade {
using MutableTransducer = ::fst::VectorFst<Arc>;

// TODO: Add symbol table support, or YAGNI?
explicit BaseRuleCascade(
::fst::TokenType token_type = ::fst::TokenType::BYTE)
explicit BaseRuleCascade(::fst::TokenType token_type = ::fst::TokenType::BYTE)
: compiler_(token_type), token_type_(token_type) {}

virtual ~BaseRuleCascade() = default;
Expand All @@ -68,9 +67,7 @@ class BaseRuleCascade {

// Returns the symbol table for generated output symbols when available, or
// nullptr otherwise.
virtual const ::fst::SymbolTable *GeneratedSymbols() const {
return nullptr;
}
virtual const ::fst::SymbolTable *GeneratedSymbols() const { return nullptr; }

// TopRewrite() computes one top rewrite, returning false if composition
// fails. Requires a semiring with the path property.
Expand Down Expand Up @@ -251,8 +248,8 @@ bool BaseRuleCascade<Arc>::OneTopRewrite(absl::string_view input,
if (!RewriteLattice(input, output)) return false;
LatticeToDfa(output, /*optimal_only=*/true);
// Make sure there is only one path.
for (::fst::StateIterator<MutableTransducer> siter(*output);
!siter.Done(); siter.Next()) {
for (::fst::StateIterator<MutableTransducer> siter(*output); !siter.Done();
siter.Next()) {
if (output->NumArcs(siter.Value()) > 1) return false;
}
return true;
Expand Down
10 changes: 4 additions & 6 deletions nisaba/interim/grm2/rewrite/parentheses.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ void MakeParenthesesVector(
for (::fst::StateIterator<::fst::Fst<Arc>> siter(parens_transducer);
!siter.Done(); siter.Next()) {
const auto state = siter.Value();
for (::fst::ArcIterator<::fst::Fst<Arc>> aiter(parens_transducer,
state);
for (::fst::ArcIterator<::fst::Fst<Arc>> aiter(parens_transducer, state);
!aiter.Done(); aiter.Next()) {
const auto &arc = aiter.Value();
if (!arc.ilabel && !arc.olabel) {
Expand Down Expand Up @@ -82,11 +81,10 @@ void MakeAssignmentsVector(
std::vector<typename Arc::Label> *assignments) {
using Label = typename Arc::Label;
std::map<Label, Label> assignment_map;
for (::fst::StateIterator<::fst::Fst<Arc>> siter(
assignments_transducer);
for (::fst::StateIterator<::fst::Fst<Arc>> siter(assignments_transducer);
!siter.Done(); siter.Next()) {
for (::fst::ArcIterator<::fst::Fst<Arc>> aiter(
assignments_transducer, siter.Value());
for (::fst::ArcIterator<::fst::Fst<Arc>> aiter(assignments_transducer,
siter.Value());
!aiter.Done(); aiter.Next()) {
const auto &arc = aiter.Value();
if (!arc.ilabel && !arc.olabel) {
Expand Down
92 changes: 40 additions & 52 deletions nisaba/interim/grm2/rewrite/rewrite.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,9 @@ inline bool CheckNonEmptyAndCleanup(::fst::MutableFst<Arc> *lattice) {
//
// Callers may wish to arc-sort the input side of the rule ahead of time.
template <class Arc>
bool RewriteLattice(const ::fst::Fst<Arc> &input,
const ::fst::Fst<Arc> &rule,
bool RewriteLattice(const ::fst::Fst<Arc> &input, const ::fst::Fst<Arc> &rule,
::fst::MutableFst<Arc> *lattice) {
static const ::fst::ComposeOptions opts(true,
::fst::ALT_SEQUENCE_FILTER);
static const ::fst::ComposeOptions opts(true, ::fst::ALT_SEQUENCE_FILTER);
::fst::Compose(input, rule, lattice, opts);
return internal::CheckNonEmptyAndCleanup(lattice);
}
Expand All @@ -87,8 +85,8 @@ bool RewriteLattice(
::fst::MutableFst<Arc> *lattice,
const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
&pdt_parens) {
static const ::fst::PdtComposeOptions opts(
true, ::fst::PdtComposeFilter::EXPAND);
static const ::fst::PdtComposeOptions opts(true,
::fst::PdtComposeFilter::EXPAND);
::fst::Compose(input, rule, pdt_parens, lattice, opts);
return internal::CheckNonEmptyAndCleanup(lattice);
}
Expand All @@ -101,8 +99,8 @@ bool RewriteLattice(
const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
&pdt_parens,
const std::vector<typename Arc::Label> &mpdt_assignments) {
static const ::fst::MPdtComposeOptions opts(
true, ::fst::PdtComposeFilter::EXPAND);
static const ::fst::MPdtComposeOptions opts(true,
::fst::PdtComposeFilter::EXPAND);
::fst::Compose(input, rule, pdt_parens, mpdt_assignments, lattice, opts);
return internal::CheckNonEmptyAndCleanup(lattice);
}
Expand All @@ -124,8 +122,8 @@ void LatticeToDfa(::fst::MutableFst<Arc> *lattice, bool optimal_only,
using Weight = typename Arc::Weight;
const auto &weight_threshold = optimal_only ? Weight::One() : Weight::Zero();
const StateId state_threshold = 256 + state_multiplier * lattice->NumStates();
const ::fst::DeterminizeOptions<Arc> opts(
::fst::kDelta, weight_threshold, state_threshold);
const ::fst::DeterminizeOptions<Arc> opts(::fst::kDelta, weight_threshold,
state_threshold);
::fst::Determinize(*lattice, lattice, opts);
// Warns if we actually hit the state threshold; if so, we do not have the
// full set of (optimal) rewrites; there may be cycles of unweighted
Expand All @@ -140,8 +138,7 @@ void LatticeToDfa(::fst::MutableFst<Arc> *lattice, bool optimal_only,
// RewriteLattice), extracts n-shortest unique strings. This is only valid in a
// semiring with the path property.
template <class Arc>
void LatticeToShortest(::fst::MutableFst<Arc> *lattice,
int32_t nshortest = 1) {
void LatticeToShortest(::fst::MutableFst<Arc> *lattice, int32_t nshortest = 1) {
::fst::VectorFst<Arc> shortest;
// By requesting unique solutions we request on-the-fly determinization.
::fst::ShortestPath(*lattice, &shortest, nshortest, /*unique=*/true);
Expand Down Expand Up @@ -175,12 +172,11 @@ bool LatticeToTopString(const ::fst::Fst<Arc> &lattice, std::string *output,
// warning and returning false if there's a tie. This is only valid in a
// semiring with the path property.
template <class Arc>
bool LatticeToOneTopString(
const ::fst::Fst<Arc> &lattice, std::string *output,
::fst::TokenType ttype = ::fst::TokenType::BYTE,
const ::fst::SymbolTable *syms = nullptr) {
bool LatticeToOneTopString(const ::fst::Fst<Arc> &lattice, std::string *output,
::fst::TokenType ttype = ::fst::TokenType::BYTE,
const ::fst::SymbolTable *syms = nullptr) {
::fst::StringPathIterator<Arc> paths(lattice, ttype, syms,
/*check_acyclic=*/false);
/*check_acyclic=*/false);
if (paths.Error() || paths.Done()) return false;
*output = paths.OString();
// Checks for uniqueness.
Expand All @@ -195,12 +191,12 @@ bool LatticeToOneTopString(

// Same as above but overloaded to also compute the path weight as a float.
template <class Arc>
bool LatticeToOneTopString(
const ::fst::Fst<Arc> &lattice, std::string *output, float *weight,
::fst::TokenType ttype = ::fst::TokenType::BYTE,
const ::fst::SymbolTable *syms = nullptr) {
bool LatticeToOneTopString(const ::fst::Fst<Arc> &lattice, std::string *output,
float *weight,
::fst::TokenType ttype = ::fst::TokenType::BYTE,
const ::fst::SymbolTable *syms = nullptr) {
::fst::StringPathIterator<Arc> paths(lattice, ttype, syms,
/*check_acyclic=*/false);
/*check_acyclic=*/false);
if (paths.Error() || paths.Done()) return false;
*output = paths.OString();
*weight = paths.Weight().Value();
Expand Down Expand Up @@ -240,7 +236,7 @@ bool LatticeToStrings(const ::fst::Fst<Arc> &lattice,
}
// Input token type and symbol table will be ignored.
::fst::StringPathIterator<Arc> paths(lattice, ttype, syms,
/*check_acyclic=*/false);
/*check_acyclic=*/false);
if (paths.Error()) return false;
for (; !paths.Done(); paths.Next()) {
// Constructs these in-place.
Expand All @@ -266,7 +262,7 @@ bool LatticeToStrings(const ::fst::Fst<Arc> &lattice,
}
// Input token type and symbol table will be ignored.
::fst::StringPathIterator<Arc> paths(lattice, ttype, syms,
/*check_acyclic=*/false);
/*check_acyclic=*/false);
if (paths.Error() || paths.Done()) return false;
for (; !paths.Done(); paths.Next()) output->Add(paths.OString());
return true;
Expand All @@ -288,7 +284,7 @@ bool LatticeToStrings(const ::fst::Fst<Arc> &lattice,
}
// Input token type and symbol table will be ignored.
::fst::StringPathIterator<Arc> paths(lattice, ttype, syms,
/*check_acyclic=*/false);
/*check_acyclic=*/false);
if (paths.Error()) return false;
for (; !paths.Done(); paths.Next()) {
output->emplace_back(paths.OString(), paths.Weight().Value());
Expand All @@ -298,8 +294,8 @@ bool LatticeToStrings(const ::fst::Fst<Arc> &lattice,

// Top rewrite.
template <class Arc>
bool TopRewrite(const ::fst::Fst<Arc> &input,
const ::fst::Fst<Arc> &rule, std::string *output,
bool TopRewrite(const ::fst::Fst<Arc> &input, const ::fst::Fst<Arc> &rule,
std::string *output,
::fst::TokenType ttype = ::fst::TokenType::BYTE,
const ::fst::SymbolTable *syms = nullptr) {
::fst::VectorFst<Arc> lattice;
Expand All @@ -310,9 +306,8 @@ bool TopRewrite(const ::fst::Fst<Arc> &input,
// Top rewrite. Same as above but overloaded to also compute the path weight as
// a float.
template <class Arc>
bool TopRewrite(const ::fst::Fst<Arc> &input,
const ::fst::Fst<Arc> &rule, std::string *output,
float *weight,
bool TopRewrite(const ::fst::Fst<Arc> &input, const ::fst::Fst<Arc> &rule,
std::string *output, float *weight,
::fst::TokenType ttype = ::fst::TokenType::BYTE,
const ::fst::SymbolTable *syms = nullptr) {
::fst::VectorFst<Arc> lattice;
Expand All @@ -322,8 +317,8 @@ bool TopRewrite(const ::fst::Fst<Arc> &input,

// Top rewrite, returning false and logging if there's a tie.
template <class Arc>
bool OneTopRewrite(const ::fst::Fst<Arc> &input,
const ::fst::Fst<Arc> &rule, std::string *output,
bool OneTopRewrite(const ::fst::Fst<Arc> &input, const ::fst::Fst<Arc> &rule,
std::string *output,
::fst::TokenType ttype = ::fst::TokenType::BYTE,
const ::fst::SymbolTable *syms = nullptr,
typename Arc::StateId state_multiplier = 4) {
Expand All @@ -335,9 +330,8 @@ bool OneTopRewrite(const ::fst::Fst<Arc> &input,

// Same as above but overloaded to also compute the path weight as a float.
template <class Arc>
bool OneTopRewrite(const ::fst::Fst<Arc> &input,
const ::fst::Fst<Arc> &rule, std::string *output,
float *weight,
bool OneTopRewrite(const ::fst::Fst<Arc> &input, const ::fst::Fst<Arc> &rule,
std::string *output, float *weight,
::fst::TokenType ttype = ::fst::TokenType::BYTE,
const ::fst::SymbolTable *syms = nullptr,
typename Arc::StateId state_multiplier = 4) {
Expand Down Expand Up @@ -390,8 +384,7 @@ bool Rewrites(const ::fst::Fst<Arc> &input, const ::fst::Fst<Arc> &rule,

// All optimal rewrites.
template <class Arc>
bool TopRewrites(const ::fst::Fst<Arc> &input,
const ::fst::Fst<Arc> &rule,
bool TopRewrites(const ::fst::Fst<Arc> &input, const ::fst::Fst<Arc> &rule,
std::vector<std::string> *output,
::fst::TokenType ttype = ::fst::TokenType::BYTE,
const ::fst::SymbolTable *syms = nullptr,
Expand All @@ -405,8 +398,7 @@ bool TopRewrites(const ::fst::Fst<Arc> &input,
#ifndef NO_GOOGLE
// Same as above but overloaded to write into a repeated proto string field.
template <class Arc>
bool TopRewrites(const ::fst::Fst<Arc> &input,
const ::fst::Fst<Arc> &rule,
bool TopRewrites(const ::fst::Fst<Arc> &input, const ::fst::Fst<Arc> &rule,
google::protobuf::RepeatedPtrField<std::string> *output,
::fst::TokenType ttype = ::fst::TokenType::BYTE,
const ::fst::SymbolTable *syms = nullptr,
Expand All @@ -420,8 +412,7 @@ bool TopRewrites(const ::fst::Fst<Arc> &input,

// Same as above but overloaded to also compute the path weights as floats.
template <class Arc>
bool TopRewrites(const ::fst::Fst<Arc> &input,
const ::fst::Fst<Arc> &rule,
bool TopRewrites(const ::fst::Fst<Arc> &input, const ::fst::Fst<Arc> &rule,
std::vector<std::pair<std::string, float>> *output,
::fst::TokenType ttype = ::fst::TokenType::BYTE,
const ::fst::SymbolTable *syms = nullptr,
Expand All @@ -434,9 +425,8 @@ bool TopRewrites(const ::fst::Fst<Arc> &input,

// The top n rewrites.
template <class Arc>
bool TopRewrites(const ::fst::Fst<Arc> &input,
const ::fst::Fst<Arc> &rule, int32_t nshortest,
std::vector<std::string> *output,
bool TopRewrites(const ::fst::Fst<Arc> &input, const ::fst::Fst<Arc> &rule,
int32_t nshortest, std::vector<std::string> *output,
::fst::TokenType ttype = ::fst::TokenType::BYTE,
const ::fst::SymbolTable *syms = nullptr) {
::fst::VectorFst<Arc> lattice;
Expand All @@ -448,8 +438,8 @@ bool TopRewrites(const ::fst::Fst<Arc> &input,
#ifndef NO_GOOGLE
// Same as above but overloaded to write into a repeated proto string field.
template <class Arc>
bool TopRewrites(const ::fst::Fst<Arc> &input,
const ::fst::Fst<Arc> &rule, int32_t nshortest,
bool TopRewrites(const ::fst::Fst<Arc> &input, const ::fst::Fst<Arc> &rule,
int32_t nshortest,
google::protobuf::RepeatedPtrField<std::string> *output,
::fst::TokenType ttype = ::fst::TokenType::BYTE,
const ::fst::SymbolTable *syms = nullptr) {
Expand All @@ -462,8 +452,8 @@ bool TopRewrites(const ::fst::Fst<Arc> &input,

// Same as above but overloaded to also compute the path weights as floats.
template <class Arc>
bool TopRewrites(const ::fst::Fst<Arc> &input,
const ::fst::Fst<Arc> &rule, int32_t nshortest,
bool TopRewrites(const ::fst::Fst<Arc> &input, const ::fst::Fst<Arc> &rule,
int32_t nshortest,
std::vector<std::pair<std::string, float>> *output,
::fst::TokenType ttype = ::fst::TokenType::BYTE,
const ::fst::SymbolTable *syms = nullptr) {
Expand All @@ -475,15 +465,13 @@ bool TopRewrites(const ::fst::Fst<Arc> &input,

// Determines whether a rule allows an input/output pair.
template <class Arc>
bool Matches(const ::fst::Fst<Arc> &input,
const ::fst::Fst<Arc> &output,
bool Matches(const ::fst::Fst<Arc> &input, const ::fst::Fst<Arc> &output,
const ::fst::Fst<Arc> &rule) {
::fst::VectorFst<Arc> lattice;
if (!RewriteLattice(input, rule, &lattice)) return false;
static const ::fst::OLabelCompare<Arc> ocomp;
::fst::ArcSort(&lattice, ocomp);
static const ::fst::IntersectOptions opts(true,
::fst::SEQUENCE_FILTER);
static const ::fst::IntersectOptions opts(true, ::fst::SEQUENCE_FILTER);
::fst::Intersect(lattice, output, &lattice, opts);
return lattice.Start() != ::fst::kNoStateId;
}
Expand Down
6 changes: 2 additions & 4 deletions nisaba/interim/grm2/rewrite/rewrite_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ class RewriteManager {

// Do not use the manager until FSTs are loaded with Load.
// TODO: Add symbol table support, or YAGNI?
explicit RewriteManager(
::fst::TokenType token_type = ::fst::TokenType::BYTE)
explicit RewriteManager(::fst::TokenType token_type = ::fst::TokenType::BYTE)
: compiler_(token_type), token_type_(token_type) {}

// FAR IO.
Expand Down Expand Up @@ -354,8 +353,7 @@ bool RewriteManager<Arc>::Matches(
::fst::ArcSort(&lattice, ocomp);
MutableTransducer output_fst;
compiler_(output, &output_fst);
static const ::fst::IntersectOptions opts(true,
::fst::SEQUENCE_FILTER);
static const ::fst::IntersectOptions opts(true, ::fst::SEQUENCE_FILTER);
::fst::Intersect(lattice, output_fst, &lattice, opts);
return lattice.Start() != ::fst::kNoStateId;
}
Expand Down
3 changes: 1 addition & 2 deletions nisaba/interim/grm2/rewrite/rule_cascade.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ class RuleCascade : public BaseRuleCascade<Arc> {

// Do not use the manager until FSTs are loaded (with Load) and rules are
// set (with SetRules).
explicit RuleCascade(
::fst::TokenType token_type = ::fst::TokenType::BYTE)
explicit RuleCascade(::fst::TokenType token_type = ::fst::TokenType::BYTE)
: BaseRuleCascade<Arc>(token_type), manager_(token_type) {}

// Loads rules from a FAR.
Expand Down
Loading

0 comments on commit 964b14e

Please sign in to comment.