diff --git a/CHANGELOG.md b/CHANGELOG.md
index f70f73ab2..a50945fb5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
 ## [Unreleased]
 
 ### Added
+- Added `--no-spm-encode` option, allowing the model to use vocabulary IDs directly to train/decode.
 - Added --custom-fallbacks option that allows to specify a list of option sets that get traversed for subsequent fallbacks upon divergence
 - Added --overwrite-checkpoint option that (when set to false) can be used to dump checkpoints with iteration numbers.
 - Implementations of COMET-20 (reference-based) and BLEURT-20 for inference with conversion scripts.
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index bad9904f9..a6e38792f 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -411,6 +411,9 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
       "Maximum lines to train SentencePiece vocabulary, selected with sampling from all data. "
       "When set to 0 all lines are going to be used.",
       2000000);
+  cli.add<bool>("--no-spm-encode",
+      "Assume the input has already had sentencepiece applied before decoding. "
+      "Expects spm pieces, like the ones produced by spm_encode's default format.");
 #endif
 
   // scheduling options
@@ -752,6 +755,9 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
 #ifdef USE_SENTENCEPIECE
   cli.add<bool>("--no-spm-decode",
       "Keep the output segmented into SentencePiece subwords");
+  cli.add<bool>("--no-spm-encode",
+      "Assume the input has already had sentencepiece applied before decoding. "
+      "Expects spm pieces, like the ones produced by spm_encode's default format.");
 #endif
 
   addSuboptionsInputLength(cli);
diff --git a/src/data/sentencepiece_vocab.cpp b/src/data/sentencepiece_vocab.cpp
index dc06cc17b..548b95a46 100644
--- a/src/data/sentencepiece_vocab.cpp
+++ b/src/data/sentencepiece_vocab.cpp
@@ -39,6 +39,9 @@ class SentencePieceVocab : public IVocab {
   // Keeps sentences segmented into subword units
   bool keepEncoded_{false};
 
+  // Assume sentencepiece has already been applied and we are expecting spm pieces as input
+  bool noEncode_{false};
+
   // Contains control characters added to vocab due to byte-fallback
   std::vector<char> controlChars_;
 
@@ -127,7 +130,8 @@ class SentencePieceVocab : public IVocab {
       : options_(options),
         batchIndex_(batchIndex),
         generator_((uint32_t)Config::seed),
-        keepEncoded_(options->get<bool>("no-spm-decode", false)) {
+        keepEncoded_(options->get<bool>("no-spm-decode", false)),
+        noEncode_(options->get<bool>("no-spm-encode", false)) {
     if(options_->has("sentencepiece-alphas")) {
       auto alphas = options_->get<std::vector<float>>("sentencepiece-alphas");
       if(alphas.size() <= batchIndex)
@@ -221,16 +225,24 @@
   }
 
   Words encode(const std::string& line, bool addEOS, bool inference) const override {
-    std::vector<int> spmIds;
-    if(inference || alpha_ == 0)
-      spm_->Encode(line, &spmIds);
-    else
-      spm_->SampleEncode(line, -1, alpha_, &spmIds);
-
     Words words;
-    words.reserve(spmIds.size() + addEOS);
-    for (auto&& spmId : spmIds)
-      words.push_back(Word::fromWordIndex(spmId));
+    if (noEncode_) {
+      auto lineTokens = utils::split(line, " ");
+      words.reserve(lineTokens.size() + addEOS);
+      for (auto&& token : lineTokens) {
+        words.push_back((*this)[token]);
+      }
+    } else {
+      std::vector<int> spmIds;
+      if(inference || alpha_ == 0)
+        spm_->Encode(line, &spmIds);
+      else
+        spm_->SampleEncode(line, -1, alpha_, &spmIds);
+
+      words.reserve(spmIds.size() + addEOS);
+      for (auto&& spmId : spmIds)
+        words.push_back(Word::fromWordIndex(spmId));
+    }
     if(addEOS)
       words.push_back(getEosId());
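To make the new path concrete: with `--no-spm-encode`, `encode()` no longer calls `spm_->Encode()`; it splits the line on spaces and looks each piece up in the vocabulary, so the input must already be SentencePiece pieces in `spm_encode`'s default output format (e.g. `▁Hello ▁world`). Below is a minimal standalone sketch of that behaviour, not Marian code; the function name `encodePretokenized` and the `piece2id` map are hypothetical stand-ins for `SentencePieceVocab::operator[]`.

```cpp
// Sketch only: mirrors the noEncode_ branch above (whitespace split + vocab lookup),
// with a toy piece-to-id map standing in for the real SentencePiece vocabulary.
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

using WordIndex = unsigned int;

std::vector<WordIndex> encodePretokenized(const std::string& line,
                                          const std::unordered_map<std::string, WordIndex>& piece2id,
                                          WordIndex unkId, WordIndex eosId, bool addEOS = true) {
  std::vector<WordIndex> words;
  std::istringstream in(line);
  std::string piece;
  while(in >> piece) {                // split on whitespace, as utils::split(line, " ") does
    auto it = piece2id.find(piece);   // direct vocabulary lookup, no spm_->Encode() call
    words.push_back(it != piece2id.end() ? it->second : unkId);
  }
  if(addEOS)
    words.push_back(eosId);           // EOS handling is unchanged in both branches
  return words;
}

int main() {
  // Toy ids; real ids come from the .spm model's piece table.
  std::unordered_map<std::string, WordIndex> piece2id = {{"▁Hello", 100}, {"▁world", 101}};
  for(WordIndex id : encodePretokenized("▁Hello ▁world", piece2id, /*unkId=*/0, /*eosId=*/2))
    std::cout << id << ' ';
  std::cout << '\n';  // prints: 100 101 2
}
```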