Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an option to not encode sentencepiece during training/decoding al… #1003

Merged
merged 4 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## [Unreleased]

- Added `--no-spm-encode` option, allowing the model to use vocabulary IDs directly to train/decode.
## [1.12.0] - 2023-02-20

### Added
Expand Down
6 changes: 6 additions & 0 deletions src/common/config_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,9 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
"Maximum lines to train SentencePiece vocabulary, selected with sampling from all data. "
"When set to 0 all lines are going to be used.",
2000000);
cli.add<bool>("--no-spm-encode",
"Assume the input has already had sentencepiece applied before decoding. "
"Expects spm pieces, like the ones produced by spm_encode's default format.");
#endif
// scheduling options

Expand Down Expand Up @@ -698,6 +701,9 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
#ifdef USE_SENTENCEPIECE
cli.add<bool>("--no-spm-decode",
"Keep the output segmented into SentencePiece subwords");
cli.add<bool>("--no-spm-encode",
"Assume the input has already had sentencepiece applied before decoding. "
"Expects spm pieces, like the ones produced by spm_encode's default format.");
#endif

addSuboptionsInputLength(cli);
Expand Down
31 changes: 22 additions & 9 deletions src/data/sentencepiece_vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class SentencePieceVocab : public IVocab {
// Keeps sentences segmented into subword units
bool keepEncoded_{false};

// Assume sentencepiece has already been applied and we are expecting spm pieces as input
bool noEncode_{false};

// Contains control characters added to vocab due to byte-fallback
std::vector<Word> controlChars_;

Expand Down Expand Up @@ -127,7 +130,8 @@ class SentencePieceVocab : public IVocab {
: options_(options),
batchIndex_(batchIndex),
generator_((uint32_t)Config::seed),
keepEncoded_(options->get<bool>("no-spm-decode", false)) {
keepEncoded_(options->get<bool>("no-spm-decode", false)),
noEncode_(options->get<bool>("no-spm-encode", false)) {
if(options_->has("sentencepiece-alphas")) {
auto alphas = options_->get<std::vector<float>>("sentencepiece-alphas");
if(alphas.size() <= batchIndex)
Expand Down Expand Up @@ -221,15 +225,24 @@ class SentencePieceVocab : public IVocab {
}

Words encode(const std::string& line, bool addEOS, bool inference) const override {
std::vector<int> spmIds;
if(inference || alpha_ == 0)
spm_->Encode(line, &spmIds);
else
spm_->SampleEncode(line, -1, alpha_, &spmIds);
Words words;
if (noEncode_) {
auto lineTokens = utils::split(line, " ");
words.reserve(lineTokens.size() + addEOS);
for (auto&& token : lineTokens) {
words.push_back((*this)[token]);
}
} else {
std::vector<int> spmIds;
if(inference || alpha_ == 0)
spm_->Encode(line, &spmIds);
else
spm_->SampleEncode(line, -1, alpha_, &spmIds);

Words words; words.reserve(spmIds.size() + addEOS);
for (auto&& spmId : spmIds)
words.push_back(Word::fromWordIndex(spmId));
words.reserve(spmIds.size() + addEOS);
for (auto&& spmId : spmIds)
words.push_back(Word::fromWordIndex(spmId));
}

if(addEOS)
words.push_back(getEosId());
Expand Down
Loading