diff --git a/.gitignore b/.gitignore index 5070d68c7..09be72c15 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ doc/ data/ceeaus data/breast-cancer data/housing +data/cranfield biicode.conf bii/ bin/ diff --git a/.travis.yml b/.travis.yml index 054b00c3a..7dbec261b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -65,36 +65,50 @@ matrix: - gcc-6 - g++-6 - # Linux/Clang 3.6 + # Linux/Clang 3.8 - os: linux - env: COMPILER=clang CLANG_VERSION=3.6 + env: COMPILER=clang CLANG_VERSION=3.8 addons: apt: sources: - ubuntu-toolchain-r-test - - llvm-toolchain-precise-3.6 + - llvm-toolchain-precise-3.8 packages: - *default-packages - - clang-3.6 - - llvm-3.6-dev + - clang-3.8 - # OS X 10.9 + Xcode 6.1 - - os: osx - env: COMPILER=clang + # Linux/Clang 3.8 + libc++-3.9 + # (I want this to be 3.9 across the board, but the apt source is not + # yet whitelisted for llvm 3.9) + - os: linux + env: + - COMPILER=clang + - CLANG_VERSION=3.8 + - LLVM_TAG=RELEASE_390 + - LIBCXX_EXTRA_CMAKE_FLAGS=-DLIBCXX_INSTALL_EXPERIMENTAL_LIBRARY=On + - CMAKE_VERSION=3.4.3 + addons: + apt: + sources: + - ubuntu-toolchain-r-test + - llvm-toolchain-precise-3.8 + packages: + - *default-packages + - clang-3.8 - # OS X 10.10 + Xcode 6.4 + # OS X 10.10 + Xcode 7.1.1 - os: osx - osx_image: xcode6.4 + osx_image: xcode7.1 env: COMPILER=clang - # OS X 10.10 + Xcode 7.1.1 + # OS X 10.11 + Xcode 7.3 - os: osx - osx_image: xcode7.1 + osx_image: xcode7.3 env: COMPILER=clang - # OS X 10.11 + Xcode 7.2 + # OS X 10.11 + Xcode 8 - os: osx - osx_image: xcode7.2 + osx_image: xcode8 env: COMPILER=clang # OS X/GCC 6 diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a5c4cf69..15a5eec81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,119 @@ +# [v3.0.0][3.0.0] +## New features +- Add an `embedding_analyzer` that represents documents with their averaged word + vectors. +- Add a `parallel::reduction` algorithm designed for parallelizing complex + accumulation operations (like an E step in an EM algorithm) +- Parallelize feature counting in feature selector using the new + `parallel::reduction` +- Add a `parallel::for_each_block` algorithm to run functions on + (relatively) equal sub-ranges of an iterator range in parallel +- Add a parallel merge sort as `parallel::sort` +- Add a `util/traits.h` header for general useful traits +- Add a Markov model implementation in `sequence::markov_model` +- Add a generic unsupervised HMM implementation. This implementation + supports HMMs with discrete observations (what is used most often) and + sequence observations (useful for log mining applications). The + forward-backward algorithm is implemented using both the scaling method + and the log-space method. The scaling method is used by default, but the + log-space method is useful for HMMs with sequence observations to avoid + underflow issues when the output probabilities themselves are very small. +- Add the KL-divergence retrieval function using pseudo-relevance feedback + with the two-component mixture-model approach of Zhai and Lafferty, + called `kl_divergence_prf`. This ranker internally can use any + `language_model_ranker` subclass like `dirichlet_prior` or + `jelinek_mercer` to perform the ranking of the feedback set and the + result documents with respect to the modified query. + + The EM algorithm used for the two-component mixture model is provided as + the `index::feedback::unigram_mixture` free function and returns the + feedback model. +- Add the Rocchio algorithm (`rocchio`) for pseudo-relevance feedback in + the vector space model. 
+- **Breaking Change.** To facilitate the above two changes, we have also + broken the `ranker` hierarchy into one more level. At the top we have + `ranker`, which has a pure virtual function `rank()` that can be + overridden to provide entirely custom ranking behavior. This is the class + the KL-divergence and Rocchio methods derive from, as we need to + re-define what it means to rank documents (first retrieving a feedback + set, then ranking documents with respect to an updated query). + + Most of the time, however, you will want to derive from the second level + `ranking_function`, which is what was called `ranker` before. This class + provides a definition of `rank()` to perform document-at-a-time ranking, + and expects deriving classes to instead provide `initial_score()` and + `score_one()` implementations to define the scoring function used for + each document. **Existing code that derived from `ranker` prior to this + version of MeTA likely needs to be changed to instead derive from + `ranking_function`.** (A minimal migration sketch appears later in this + changeset, just after the `okapi_bm25` change.) +- Add the `util::transform_iterator` class and `util::make_transform_iterator` + function for providing iterators that transform their output according to + a unary function. +- **Breaking Change.** `whitespace_tokenizer` now emits *only* word tokens + by default, suppressing all whitespace tokens. The old default was to + emit tokens containing whitespace in addition to actual word tokens. The + old behavior can be obtained by passing `false` to its constructor, or + setting `suppress-whitespace = false` in its configuration group in + `config.toml`. (Note that whitespace tokens are still needed if using a + `sentence_boundary` filter but, in nearly all circumstances, + `icu_tokenizer` should be preferred.) +- **Breaking Change.** Co-occurrence counting for embeddings now uses + history that crosses sentence boundaries by default. The old behavior + (clearing the history when starting a new sentence) can be obtained by + ensuring that a tokenizer is being used that emits sentence boundary tags + and by setting `break-on-tags = true` in the `[embeddings]` table of + `config.toml`. +- **Breaking Change.** All references in the embeddings library to "coocur" + have changed to "cooccur". This means that some files and binaries + have been renamed. Much of the co-occurrence counting part of the + embeddings library has also been moved to the public API. +- Co-occurrence counting is now performed in parallel. Behavior of its + merge strategy can be configured with the new `[embeddings]` config + parameter `merge-fanout = n`, which specifies the maximum number of + on-disk chunks to allow before kicking off a multi-way merge (default 8). + +## Enhancements +- Add additional `packed_write` and `packed_read` overloads: for + `std::pair`, `stats::dirichlet`, `stats::multinomial`, + `util::dense_matrix`, and `util::sparse_vector` +- Additional functions have been added to `ranker_factory` to allow + construction/loading of language_model_ranker subclasses (useful for the + `kl_divergence_prf` implementation) +- Add a `util::make_fixed_heap` helper function to simplify the declaration + of `util::fixed_heap` classes with lambda function comparators. +- Add regression tests for rankers' MAP and NDCG scores. This adds a new + dataset `cranfield` that contains non-binary relevance judgments to + facilitate these new tests. +- Bump bundled version of ICU to 58.2.
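To illustrate the shape of the new `parallel::reduction` API listed above: it takes an iterator range, a factory for a thread-local accumulator, a function that folds one element into that accumulator, and a function that combines the per-thread results, returning the combined value. The snippet below is only a sketch (`sum_weights` and its input are invented for illustration); the call shape mirrors the real usage in `feature_selector::calc_probs` later in this changeset.

```cpp
#include <vector>

#include "meta/parallel/algorithm.h"

// Sums a collection of weights in parallel: each worker thread folds
// elements into its own local total, and the totals are combined at the end.
double sum_weights(const std::vector<double>& weights)
{
    return meta::parallel::reduction(
        weights.begin(), weights.end(),
        []() { return 0.0; },                        // thread-local initial value
        [](double& local, double w) { local += w; }, // fold one element
        [](double& total, const double& local) {     // combine per-thread totals
            total += local;
        });
}
```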
+ +## Bug Fixes +- Fix bug in NDCG calculation (ideal-DCG was computed using the wrong + sorting order for non-binary judgments) +- Fix bug where the final chunks to be merged in index creation were not + being deleted when merging completed +- Fix bug where GloVe training would allocate the embedding matrix before + starting the shuffling process, causing it to exceed the "max-ram" + config parameter. +- Fix bug with consuming MeTA from a build directory with `cmake` when + building a static ICU library. `meta-utf` is now forced to be a shared + library, which (1) should save on binary sizes and (2) ensures that the + statically build ICU is linked into the `libmeta-utf.so` library to avoid + undefined references to ICU functions. +- Fix bug with consuming Release-mode MeTA libraries from another project + being built in Debug mode. Before, `identifiers.h` would change behavior + based on the `NDEBUG` macro's setting. This behavior has been removed, + and opaque identifiers are always on. + +## Deprecation +- `disk_index::doc_name` and `disk_index::doc_path` have been deprecated in + favor of the more general (and less confusing) `metadata()`. They will be + removed in a future major release. +- Support for 32-bit architectures is provided on a best-effort basis. MeTA + makes heavy use of memory mapping, which is best paired with a 64-bit + address space. Please move to a 64-bit platform for using MeTA if at all + possible (most consumer machines should support 64-bit if they were made + in the last 5 years or so). + # [v2.4.2][2.4.2] ## Bug Fixes - Properly shuffle documents when doing an even-split classification test @@ -493,7 +609,8 @@ # [v1.0][1.0] - Initial release. -[unreleased]: https://github.com/meta-toolkit/meta/compare/v2.4.2...develop +[unreleased]: https://github.com/meta-toolkit/meta/compare/v3.0.0...develop +[3.0.0]: https://github.com/meta-toolkit/meta/compare/v2.4.2...v3.0.0 [2.4.2]: https://github.com/meta-toolkit/meta/compare/v2.4.1...v2.4.2 [2.4.1]: https://github.com/meta-toolkit/meta/compare/v2.4.0...v2.4.1 [2.4.0]: https://github.com/meta-toolkit/meta/compare/v2.3.0...v2.4.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index b314bebdb..e4d6421f2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,11 +47,10 @@ endif() list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/deps/findicu) list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/deps/meta-cmake/) -# We require Unicode 8 for the unit tests, which was added in ICU 56.1 FindOrBuildICU( - VERSION 57.1 - URL http://download.icu-project.org/files/icu4c/57.1/icu4c-57_1-src.tgz - URL_HASH MD5=976734806026a4ef8bdd17937c8898b9 + VERSION 58.2 + URL http://download.icu-project.org/files/icu4c/58.2/icu4c-58_2-src.tgz + URL_HASH MD5=fac212b32b7ec7ab007a12dff1f3aea1 ) add_library(meta-definitions INTERFACE) @@ -143,6 +142,14 @@ add_subdirectory(src) add_subdirectory(tests) add_subdirectory(deps/cpptoml EXCLUDE_FROM_ALL) +# Warn users that are using a 32-bit system +if (CMAKE_SIZEOF_VOID_P LESS 8) + message(WARNING "You appear to be running on a 32-bit system. 
Support \ + for 32-bit systems is provided on a best-effort basis; if at all \ + possible, we strongly recommend that you use MeTA on a 64-bit \ + platform.") +endif() + # install our targets defined in this file install(TARGETS meta-definitions EXPORT meta-exports diff --git a/README.md b/README.md index 8d2912dae..fcdb27116 100644 --- a/README.md +++ b/README.md @@ -622,9 +622,12 @@ you should run the following commands to download dependencies and related software needed for building: ```bash -pacman -Syu git make mingw-w64-x86_64-{gcc,cmake,icu,jemalloc,zlib} +pacman -Syu git make mingw-w64-x86_64-{gcc,cmake,icu,jemalloc,zlib} --force ``` +(the `--force` is needed to work around a bug with the latest MSYS2 +installer as of the time of writing.) + Then, exit the shell and launch the "MinGW-w64 Win64" shell. You can obtain the toolkit and get started with: diff --git a/config.toml b/config.toml index 28b723508..9819acc54 100644 --- a/config.toml +++ b/config.toml @@ -96,7 +96,7 @@ test-sections = [23, 23] [embeddings] prefix = "word-embeddings" -filter = [{type = "icu-tokenizer"}, {type = "lowercase"}] +filter = [{type = "icu-tokenizer", suppress-tags = true}, {type = "lowercase"}] vector-size = 50 [embeddings.vocab] min-count = 10 diff --git a/deps/cpptoml b/deps/cpptoml index 4fd49e3f5..6b780c98c 160000 --- a/deps/cpptoml +++ b/deps/cpptoml @@ -1 +1 @@ -Subproject commit 4fd49e3f5c4fa00467ad478b12ad2189d881a27a +Subproject commit 6b780c98c767cf1a9f36b06070db8cf07243354f diff --git a/deps/meta-cmake b/deps/meta-cmake index f21b5fc5b..06539eca8 160000 --- a/deps/meta-cmake +++ b/deps/meta-cmake @@ -1 +1 @@ -Subproject commit f21b5fc5b9ad678bcf1fc7af8244e03147ed8a68 +Subproject commit 06539eca8c1cd8abd4f6ce4c570ffc23f7ff7bc7 diff --git a/include/meta/analyzers/tokenizers/whitespace_tokenizer.h b/include/meta/analyzers/tokenizers/whitespace_tokenizer.h index d17201c22..d11b22fb5 100644 --- a/include/meta/analyzers/tokenizers/whitespace_tokenizer.h +++ b/include/meta/analyzers/tokenizers/whitespace_tokenizer.h @@ -9,6 +9,7 @@ #ifndef META_WHITESPACE_TOKENIZER_H_ #define META_WHITESPACE_TOKENIZER_H_ +#include "meta/analyzers/filter_factory.h" #include "meta/analyzers/token_stream.h" #include "meta/util/clonable.h" #include "meta/util/string_view.h" @@ -39,8 +40,10 @@ class whitespace_tokenizer : public util::clonable +std::unique_ptr + make_tokenizer(const cpptoml::table& config); } } } diff --git a/include/meta/classify/classifier/classifier.h b/include/meta/classify/classifier/classifier.h index 17021c4d7..e1008f971 100644 --- a/include/meta/classify/classifier/classifier.h +++ b/include/meta/classify/classifier/classifier.h @@ -109,6 +109,7 @@ confusion_matrix cross_validate(Creator&& creator, docs, docs.begin(), docs.begin() + static_cast(step_size)}; auto m = cls->test(test_view); + matrix.add_fold_accuracy(m.accuracy()); matrix += m; docs.rotate(step_size); } diff --git a/include/meta/classify/confusion_matrix.h b/include/meta/classify/confusion_matrix.h index aa27156f4..133843198 100644 --- a/include/meta/classify/confusion_matrix.h +++ b/include/meta/classify/confusion_matrix.h @@ -41,6 +41,16 @@ class confusion_matrix void add(const predicted_label& predicted, const class_label& actual, size_t times = 1); + /** + * @param Accuracy to add + */ + void add_fold_accuracy(double acc); + + /** + * @return the list of added accuracies + */ + std::vector fold_accuracy() const; + /** * Prints this matrix's statistics to out. 
* @@ -160,6 +170,9 @@ class confusion_matrix /// Total number of classification attempts size_t total_; + + /// Keeps track of accuracies between folds + std::vector fold_acc_; }; } } diff --git a/include/meta/classify/models/linear_model.tcc b/include/meta/classify/models/linear_model.tcc index e8fc25a85..b0e02d757 100644 --- a/include/meta/classify/models/linear_model.tcc +++ b/include/meta/classify/models/linear_model.tcc @@ -122,17 +122,15 @@ template auto linear_model::best_class( FeatureVector&& features) const -> class_id { - return best_class(std::forward(features), [](const class_id&) - { - return true; - }); + return best_class(std::forward(features), + [](const class_id&) { return true; }); } template template auto linear_model::best_classes( - FeatureVector&& features, uint64_t num, - Filter&& filter) const -> scored_classes + FeatureVector&& features, uint64_t num, Filter&& filter) const + -> scored_classes { weight_vector class_scores; for (const auto& feat : features) @@ -153,12 +151,10 @@ auto linear_model::best_classes( } } - auto comp = [](const scored_class& lhs, const scored_class& rhs) - { - return lhs.second > rhs.second; - }; - - util::fixed_heap heap{num, comp}; + auto heap = util::make_fixed_heap( + num, [](const scored_class& lhs, const scored_class& rhs) { + return lhs.second > rhs.second; + }); for (const auto& score : class_scores) { auto cid = score.first; @@ -175,10 +171,7 @@ auto linear_model::best_classes( FeatureVector&& features, uint64_t num) const -> scored_classes { return best_classes(std::forward(features), num, - [](const class_id&) - { - return true; - }); + [](const class_id&) { return true; }); } template @@ -230,8 +223,8 @@ void linear_model::condense(bool log) } template -auto linear_model::weights() const -> const - weight_vectors & +auto linear_model::weights() const + -> const weight_vectors& { return weights_; } diff --git a/include/meta/classify/multiclass_dataset.h b/include/meta/classify/multiclass_dataset.h index f3958def0..ee9e6ce33 100644 --- a/include/meta/classify/multiclass_dataset.h +++ b/include/meta/classify/multiclass_dataset.h @@ -122,8 +122,9 @@ class multiclass_dataset : public learn::labeled_dataset * feature_vector and a conversion operator to a class_label. 
*/ template - multiclass_dataset(ForwardIterator begin, ForwardIterator end) - : labeled_dataset{begin, end} + multiclass_dataset(ForwardIterator begin, ForwardIterator end, + size_type total_features) + : labeled_dataset{begin, end, total_features} { // build label_id_mapping for (; begin != end; ++begin) diff --git a/include/meta/classify/multiclass_dataset_view.h b/include/meta/classify/multiclass_dataset_view.h index 56c0dfbd6..8533ccb06 100644 --- a/include/meta/classify/multiclass_dataset_view.h +++ b/include/meta/classify/multiclass_dataset_view.h @@ -38,13 +38,31 @@ class multiclass_dataset_view : public learn::dataset_view // nothing } - multiclass_dataset_view(const multiclass_dataset_view& mdv, iterator begin, - iterator end) + multiclass_dataset_view(const multiclass_dataset_view& mdv, + const_iterator begin, const_iterator end) : dataset_view{mdv, begin, end} { // nothing } + multiclass_dataset_view(const multiclass_dataset& dset, + multiclass_dataset::const_iterator begin, + multiclass_dataset::const_iterator end) + : dataset_view{dset, begin, end} + { + // nothing + } + + template + multiclass_dataset_view(const multiclass_dataset& dset, + multiclass_dataset::const_iterator begin, + multiclass_dataset::const_iterator end, + RandomEngine&& rng) + : dataset_view{dset, begin, end, std::forward(rng)} + { + // nothing + } + multiclass_dataset_view(const multiclass_dataset_view& mdv, std::vector&& indices) : dataset_view{mdv, std::move(indices)} diff --git a/include/meta/config.h.in b/include/meta/config.h.in index 63ed61b16..5ad5e3bab 100644 --- a/include/meta/config.h.in +++ b/include/meta/config.h.in @@ -1,6 +1,20 @@ #ifndef META_CONFIG_H_ #define META_CONFIG_H_ +#if __cplusplus > 201103L +#define META_DEPRECATED(reason) [[deprecated(reason)]] +#elif defined(__clang__) +#define META_DEPRECATED(reason) __attribute__((deprecated(reason))) +#elif defined(__GNUG__) +#define META_DEPRECATED(reason) __attribute__((deprecated)) +#elif defined(_MSC_VER) +#if _MSC_VER < 1910 +#define META_DEPRECATED(reason) __declspec(deprecated) +#else +#define META_DEPRECATED(reason) [[deprecated(reason)]] +#endif +#endif + #include "meta/kludges.h" // OS X diff --git a/include/meta/corpus/corpus.h b/include/meta/corpus/corpus.h index 554187732..f036c8c64 100644 --- a/include/meta/corpus/corpus.h +++ b/include/meta/corpus/corpus.h @@ -11,6 +11,7 @@ #define META_CORPUS_H_ #include +#include #include #include "cpptoml.h" @@ -18,7 +19,9 @@ #include "meta/corpus/document.h" #include "meta/corpus/metadata_parser.h" #include "meta/meta.h" +#include "meta/parallel/thread_pool.h" #include "meta/util/optional.h" +#include "meta/util/progress.h" namespace meta { @@ -131,7 +134,46 @@ class corpus_exception : public std::runtime_error public: using std::runtime_error::runtime_error; }; + +/** + * Consumes each document in a corpus using a pool of threads. 
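+ *
+ * Example (an illustrative sketch, not part of this changeset: `docs`,
+ * `pool`, and the `stream` token_stream prototype are assumed to already
+ * exist):
+ * @code
+ * corpus::parallel_consume(
+ *     docs, pool,
+ *     [&]() { return stream->clone(); }, // one token_stream per thread
+ *     [&](std::unique_ptr<analyzers::token_stream>& ts,
+ *         const corpus::document& doc) {
+ *         ts->set_content(analyzers::get_content(doc));
+ *         while (*ts)
+ *             ts->next(); // consume the document's tokens
+ *     });
+ * @endcode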
+ * @param docs The corpus to consume + * @param pool The thread pool to use + * @param ls_fn A function to create thread-specific storage + * @param consume_fn A function to consume a document + */ +template +void parallel_consume(corpus& docs, parallel::thread_pool& pool, + LocalStorage&& ls_fn, ConsumeFunction&& consume_fn) +{ + std::mutex mutex; + auto task = [&]() { + auto local_storage = ls_fn(); + while (true) + { + util::optional doc; + { + std::lock_guard lock{mutex}; + + if (!docs.has_next()) + return; + + doc = docs.next(); + } + + consume_fn(local_storage, *doc); + } + }; + + std::vector> futures; + futures.reserve(pool.size()); + for (std::size_t i = 0; i < pool.size(); ++i) + { + futures.emplace_back(pool.submit_task(task)); + } + for (auto& fut : futures) + fut.get(); +} } } - #endif diff --git a/include/meta/embeddings/analyzers/embedding_analyzer.h b/include/meta/embeddings/analyzers/embedding_analyzer.h new file mode 100644 index 000000000..5617c1389 --- /dev/null +++ b/include/meta/embeddings/analyzers/embedding_analyzer.h @@ -0,0 +1,87 @@ +/** + * @file embedding_analyzer.h + * @author Sean Massung + * + * All files in META are released under the MIT license. For more details, + * consult the file LICENSE in the root of the project. + */ + +#ifndef META_EMBEDDINGS_EMBEDDING_ANALYZER_H_ +#define META_EMBEDDINGS_EMBEDDING_ANALYZER_H_ + +#include "meta/analyzers/analyzer.h" +#include "meta/analyzers/analyzer_factory.h" +#include "meta/embeddings/word_embeddings.h" +#include "meta/util/clonable.h" +#include + +namespace meta +{ +namespace analyzers +{ + +/** + * Analyzes documents by averaging word embeddings for each token. This analyzer + * should only be used with forward_index since it stores double feature values. + * + * Required config parameters: + * ~~~toml + * [[analyzers]] + * method = "embedding" # this analyzer + * filter = # use same filter type that embeddings were learned with + * prefix = "path/to/embedding/model/" + * ~~~ + */ +class embedding_analyzer : public util::clonable +{ + public: + /** + * Constructor. + * @param stream The stream to read tokens from. + */ + embedding_analyzer(const cpptoml::table& config, + std::unique_ptr stream); + + /** + * Copy constructor. + * @param other The other embedding_analyzer to copy from + */ + embedding_analyzer(const embedding_analyzer& other); + + /// Identifier for this analyzer. + const static util::string_view id; + + private: + virtual void tokenize(const corpus::document& doc, + featurizer& counts) override; + + /// The token stream to be used for extracting tokens + std::unique_ptr stream_; + + /// Learned word embeddings + std::shared_ptr embeddings_; + + /// Path to the embedding model files + std::string prefix_; + + /// Storage for the aggregated word embeddings per document + std::vector features_; +}; + +/** + * Specialization of the factory method for creating embedding_analyzers. + */ +template <> +std::unique_ptr +make_analyzer(const cpptoml::table&, const cpptoml::table&); +} + +namespace embeddings +{ +/** + * Registers analyzers provided by the meta-embeddings library. 
+ */ +void register_analyzers(); +} +} +#endif diff --git a/include/meta/embeddings/coocur_iterator.h b/include/meta/embeddings/cooccur_iterator.h similarity index 51% rename from include/meta/embeddings/coocur_iterator.h rename to include/meta/embeddings/cooccur_iterator.h index 85f7840a1..f6a23a183 100644 --- a/include/meta/embeddings/coocur_iterator.h +++ b/include/meta/embeddings/cooccur_iterator.h @@ -1,5 +1,5 @@ /** - * @file coocur_iterator.h + * @file cooccur_iterator.h * @author Chase Geigle * * All files in META are dual-licensed under the MIT and NCSA licenses. For more @@ -7,11 +7,11 @@ * project. */ -#ifndef META_EMBEDDINGS_COOCUR_ITERATOR_H_ -#define META_EMBEDDINGS_COOCUR_ITERATOR_H_ +#ifndef META_EMBEDDINGS_COOCCUR_ITERATOR_H_ +#define META_EMBEDDINGS_COOCCUR_ITERATOR_H_ #include "meta/config.h" -#include "meta/embeddings/coocur_record.h" +#include "meta/embeddings/cooccur_record.h" #include "meta/util/multiway_merge.h" namespace meta @@ -19,10 +19,12 @@ namespace meta namespace embeddings { /** - * An iterator over coocur_record's that live in a packed file on disk. + * An iterator over cooccur_records that live in a packed file on disk. * Satisfies the ChunkIterator concept for multiway_merge support. */ -using coocur_iterator = util::chunk_iterator; +using cooccur_iterator = util::chunk_iterator; +using destructive_cooccur_iterator + = util::destructive_chunk_iterator; } } #endif diff --git a/include/meta/embeddings/coocur_record.h b/include/meta/embeddings/cooccur_record.h similarity index 62% rename from include/meta/embeddings/coocur_record.h rename to include/meta/embeddings/cooccur_record.h index bced6b276..fde298bb6 100644 --- a/include/meta/embeddings/coocur_record.h +++ b/include/meta/embeddings/cooccur_record.h @@ -1,5 +1,5 @@ /** - * @file coocur_record.h + * @file cooccur_record.h * @author Chase Geigle * * All files in META are dual-licensed under the MIT and NCSA licenses. For more @@ -7,8 +7,8 @@ * project. */ -#ifndef META_EMBEDDINGS_COOCUR_RECORD_H_ -#define META_EMBEDDINGS_COOCUR_RECORD_H_ +#ifndef META_EMBEDDINGS_COOCCUR_RECORD_H_ +#define META_EMBEDDINGS_COOCCUR_RECORD_H_ #include @@ -20,38 +20,38 @@ namespace meta namespace embeddings { /** - * Represents an entry in the coocurrence matrix. Satisfies the Record + * Represents an entry in the cooccurrence matrix. Satisfies the Record * concept for multiway_merge support. 
*/ -struct coocur_record +struct cooccur_record { uint64_t target; uint64_t context; double weight; - void merge_with(coocur_record&& other) + void merge_with(cooccur_record&& other) { weight += other.weight; } }; -bool operator==(const coocur_record& a, const coocur_record& b) +inline bool operator==(const cooccur_record& a, const cooccur_record& b) { return std::tie(a.target, a.context) == std::tie(b.target, b.context); } -bool operator!=(const coocur_record& a, const coocur_record& b) +inline bool operator!=(const cooccur_record& a, const cooccur_record& b) { return !(a == b); } -bool operator<(const coocur_record& a, const coocur_record& b) +inline bool operator<(const cooccur_record& a, const cooccur_record& b) { return std::tie(a.target, a.context) < std::tie(b.target, b.context); } template -uint64_t packed_write(OutputStream& os, const coocur_record& record) +uint64_t packed_write(OutputStream& os, const cooccur_record& record) { using io::packed::write; return write(os, record.target) + write(os, record.context) @@ -59,7 +59,7 @@ uint64_t packed_write(OutputStream& os, const coocur_record& record) } template -uint64_t packed_read(InputStream& is, coocur_record& record) +uint64_t packed_read(InputStream& is, cooccur_record& record) { using io::packed::read; return read(is, record.target) + read(is, record.context) diff --git a/include/meta/embeddings/cooccurrence_counter.h b/include/meta/embeddings/cooccurrence_counter.h new file mode 100644 index 000000000..ac629cf24 --- /dev/null +++ b/include/meta/embeddings/cooccurrence_counter.h @@ -0,0 +1,225 @@ +/** + * @file cooccurrence_counter.h + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#ifndef META_EMBEDDINGS_COOCCURRENCE_COUNTER_H_ +#define META_EMBEDDINGS_COOCCURRENCE_COUNTER_H_ + +#include + +#include "meta/config.h" + +#include "meta/analyzers/token_stream.h" +#include "meta/corpus/corpus.h" +#include "meta/embeddings/cooccur_record.h" +#include "meta/hashing/probe_map.h" +#include "meta/io/packed.h" +#include "meta/parallel/semaphore.h" + +namespace meta +{ +namespace embeddings +{ + +/** + * A (target, context) pair used as the key in a cooccurrence hash table. + */ +struct cooccurrence_key +{ + constexpr cooccurrence_key(uint64_t targ, uint64_t ctx) + : target{targ}, context{ctx} + { + // nothing + } + + uint64_t target; + uint64_t context; +}; + +inline bool operator==(const cooccurrence_key& a, const cooccurrence_key& b) +{ + return std::tie(a.target, a.context) == std::tie(b.target, b.context); +} + +inline bool operator<(const cooccurrence_key& a, const cooccurrence_key& b) +{ + return std::tie(a.target, a.context) < std::tie(b.target, b.context); +} + +template +uint64_t packed_write(OutputStream& os, const cooccurrence_key& key) +{ + auto bytes = io::packed::write(os, key.target); + return bytes + io::packed::write(os, key.context); +} +} + +namespace hashing +{ +template <> +struct key_traits +{ + static constexpr bool inlineable = true; + constexpr static embeddings::cooccurrence_key sentinel() + { + return {key_traits::sentinel(), + key_traits::sentinel()}; + } +}; + +template <> +struct is_contiguously_hashable +{ + const static constexpr bool value = true; +}; +} + +namespace embeddings +{ + +/** + * A chunk of cooccurrence records on disk. 
+ */ +struct cooccurrence_chunk +{ + cooccurrence_chunk(const std::string& file, uint64_t bytes) + : path{file}, size{bytes} + { + // nothing + } + + std::string path; + uint64_t size; +}; + +inline bool operator<(const cooccurrence_chunk& a, const cooccurrence_chunk& b) +{ + // merge smaller chunks first + return a.size > b.size; +} + +/** + * An iterator adhering to the ChunkIterator concept for multiway_merge + * support on in-memory cooccurrence data. + */ +class memory_cooccur_iterator +{ + public: + using map_type = hashing::probe_map; + using memory_chunk_type = map_type::storage_type::vector_type; + using count_type = std::pair; + + memory_cooccur_iterator() = default; + + memory_cooccur_iterator(memory_chunk_type&& items) + : items_{std::move(items)}, idx_{0} + { + // nothing + } + + memory_cooccur_iterator& operator++() + { + ++idx_; + if (idx_ >= items_.size()) + { + items_.clear(); + idx_ = 0; + } + + return *this; + } + + cooccur_record operator*() const + { + const auto& item = items_[idx_]; + return {item.first.target, item.first.context, item.second}; + } + + uint64_t total_bytes() const + { + return sizeof(count_type) * items_.size(); + } + + uint64_t bytes_read() const + { + return sizeof(count_type) * idx_; + } + + bool operator==(const memory_cooccur_iterator& other) const + { + return items_.empty() && other.items_.empty(); + } + + bool operator!=(const memory_cooccur_iterator& other) const + { + return !(*this == other); + } + + private: + memory_chunk_type items_; + std::size_t idx_{0}; +}; + +/** + * Management class for cooccurrence counting. This class maintains the + * shared state across all threads used for parallel cooccurrence counting. + */ +class cooccurrence_counter +{ + public: + using memory_chunk_type = memory_cooccur_iterator::memory_chunk_type; + + struct configuration + { + std::string prefix; + std::size_t max_ram = 4096u * 1024u * 1024u; // 4GB + std::size_t merge_fanout = 8; + std::size_t window_size = 15; + bool break_on_tags = false; + }; + + cooccurrence_counter(configuration config, parallel::thread_pool& pool); + + ~cooccurrence_counter(); + + void count(corpus::corpus& docs, + const analyzers::token_stream& stream); + + private: + void flush_chunk(memory_chunk_type&& chunk); + void memory_merge_chunks(); + void maybe_merge(); + + friend class cooccurrence_buffer; + const std::string prefix_; + std::size_t max_ram_; + const std::size_t merge_fanout_; + const std::size_t window_size_; + const bool break_on_tags_; + const hashing::probe_map vocab_; + parallel::thread_pool& pool_; + std::size_t chunk_num_{0}; + std::atomic_size_t num_tokenizing_{0}; + std::size_t num_pending_{0}; + std::vector memory_chunks_; + std::priority_queue chunks_; + std::mutex chunk_mutex_; + std::condition_variable chunk_cond_; + std::mutex io_mutex_; +}; + +class cooccurrence_exception : public std::runtime_error +{ + public: + using std::runtime_error::runtime_error; +}; +} + + +} +#endif diff --git a/include/meta/embeddings/word_embeddings.h b/include/meta/embeddings/word_embeddings.h index a450680d3..4ae876af6 100644 --- a/include/meta/embeddings/word_embeddings.h +++ b/include/meta/embeddings/word_embeddings.h @@ -85,6 +85,11 @@ class word_embeddings std::vector top_k(util::array_view query, std::size_t k = 100) const; + /** + * @return the number of dimensions for each word + */ + std::size_t vector_size() const; + private: util::array_view vector(std::size_t tid); diff --git a/include/meta/features/feature_selector.h b/include/meta/features/feature_selector.h 
index f12b62e02..f1e890490 100644 --- a/include/meta/features/feature_selector.h +++ b/include/meta/features/feature_selector.h @@ -20,6 +20,7 @@ #include "meta/index/disk_index.h" #include "meta/io/filesystem.h" #include "meta/learn/instance.h" +#include "meta/parallel/algorithm.h" #include "meta/stats/multinomial.h" #include "meta/succinct/sarray.h" #include "meta/util/progress.h" @@ -222,30 +223,59 @@ class feature_selector template void calc_probs(const LabeledDatasetContainer& docs) { - uint64_t num_processed = 0; - - printing::progress prog{" > Calculating feature probs: ", docs.size()}; + using co_occur_t = decltype(co_occur_); + using term_prob_t = decltype(term_prob_); - for (const auto& instance : docs) + // local struct to encapsulate the reduced objects + struct prob_counts { - std::stringstream ss; - ss << docs.label(instance); - class_label lbl{ss.str()}; - - class_prob_.increment(lbl, 1); - - for (const auto& count : instance.weights) + prob_counts() = default; + prob_counts(const co_occur_t& p_co_occur, + const term_prob_t& p_term_prob) + : co_occur{p_co_occur}, term_prob{p_term_prob} { - term_id tid{count.first}; - - term_prob_.increment(tid, count.second); - co_occur_.increment(std::make_pair(lbl, tid), count.second); + // nothing } + prob_counts& operator+=(const prob_counts& other) + { + co_occur += other.co_occur; + term_prob += other.term_prob; + return *this; + } + co_occur_t co_occur; + term_prob_t term_prob; + }; - prog(++num_processed); - } + uint64_t num_processed = 0; + std::mutex prog_cls_mutex; + printing::progress prog{" > Calculating feature probs: ", docs.size()}; + auto counts = parallel::reduction( + docs.begin(), docs.end(), [&]() { return prob_counts{}; }, + [&](prob_counts& counts, + const typename LabeledDatasetContainer::instance_type& + instance) { + std::stringstream ss; + ss << docs.label(instance); + class_label lbl{ss.str()}; + for (const auto& w : instance.weights) + { + term_id tid{w.first}; + counts.term_prob.increment(tid, w.second); + counts.co_occur.increment(std::make_pair(lbl, tid), + w.second); + } + std::lock_guard lock{prog_cls_mutex}; + prog(++num_processed); + class_prob_.increment(lbl, 1); + }, + [&](prob_counts& result, const prob_counts& temp) { + result += temp; + }); prog.end(); + + term_prob_ = std::move(counts.term_prob); + co_occur_ = std::move(counts.co_occur); } /** diff --git a/include/meta/features/selector_factory.h b/include/meta/features/selector_factory.h index 36ffd01c9..7801fda32 100644 --- a/include/meta/features/selector_factory.h +++ b/include/meta/features/selector_factory.h @@ -103,8 +103,8 @@ make_selector(const cpptoml::table& config, const LabeledDatasetContainer& docs) throw selector_factory_exception{ "feature selection method required in [features] table"}; - auto features_per_class = static_cast( - table->get_as("features-per-class").value_or(20)); + auto features_per_class + = table->get_as("features-per-class").value_or(20); auto selector = selector_factory::get().create( *method, *table, docs.total_labels(), docs.total_features()); diff --git a/include/meta/hashing/hash_storage.h b/include/meta/hashing/hash_storage.h index 68794fa43..a901cdd9a 100644 --- a/include/meta/hashing/hash_storage.h +++ b/include/meta/hashing/hash_storage.h @@ -458,9 +458,10 @@ class storage_base * @param key The key to look for * @param hc The hash code for the key */ - uint64_t get_idx(const key_type& key, std::size_t hc) const + std::size_t get_idx(const key_type& key, + typename hash_type::result_type hc) const { - 
probing_strategy strategy{hc, as_derived().capacity()}; + probing_strategy strategy(hc, as_derived().capacity()); auto idx = strategy.probe(); while (as_derived().occupied(idx) && !as_derived().equal(idx, hc, key)) { diff --git a/include/meta/hashing/hashes/farm_hash.h b/include/meta/hashing/hashes/farm_hash.h index cd4c1544e..7f97c2a09 100644 --- a/include/meta/hashing/hashes/farm_hash.h +++ b/include/meta/hashing/hashes/farm_hash.h @@ -247,7 +247,7 @@ class farm_hash } public: - using result_type = std::size_t; + using result_type = uint64_t; farm_hash() : buf_pos_{reinterpret_cast(buffer_.data())}, mixed_{false} @@ -356,8 +356,7 @@ class farm_hash_seeded : public farm_hash inline explicit operator result_type() { - uint64_t result - = static_cast(static_cast(*this)); + auto result = static_cast(static_cast(*this)); return farm::hash_len_16(result - seed_.low, seed_.high); } }; diff --git a/include/meta/hashing/hashes/murmur_hash.h b/include/meta/hashing/hashes/murmur_hash.h index 461398ada..caa5472bc 100644 --- a/include/meta/hashing/hashes/murmur_hash.h +++ b/include/meta/hashing/hashes/murmur_hash.h @@ -91,10 +91,10 @@ class murmur_hash<4> } public: - using result_type = std::size_t; + using result_type = uint32_t; - murmur_hash(std::size_t seed) - : out_{static_cast(seed)}, buflen_{0}, total_length_{0} + murmur_hash(result_type seed) + : out_{seed}, buflen_{0}, total_length_{0} { } @@ -132,7 +132,7 @@ class murmur_hash<4> } } - explicit operator std::size_t() + explicit operator result_type() { uint32_t k1 = 0; switch (buflen_ & 3) @@ -197,7 +197,7 @@ class murmur_hash<8> } public: - using result_type = std::size_t; + using result_type = uint64_t; murmur_hash(uint64_t seed) : h1_{seed}, h2_{seed}, buflen_{0}, total_length_{0} @@ -239,7 +239,7 @@ class murmur_hash<8> } } - explicit operator std::size_t() + explicit operator result_type() { uint64_t k1 = 0; uint64_t k2 = 0; diff --git a/include/meta/hashing/perfect_hash.h b/include/meta/hashing/perfect_hash.h index 2fc57cfdf..40ca9b051 100644 --- a/include/meta/hashing/perfect_hash.h +++ b/include/meta/hashing/perfect_hash.h @@ -43,7 +43,7 @@ class perfect_hash using meta::hashing::hash_append; farm_hash_seeded hasher{bucket_seed_}; hash_append(hasher, key); - auto hash = static_cast(hasher); + auto hash = static_cast(hasher); auto bucket_id = hash % seeds_.size(); auto seed = seeds_[bucket_id]; auto pos = farm::hash_len_16(hash, seed) % num_bins_; diff --git a/include/meta/hashing/perfect_hash_builder.h b/include/meta/hashing/perfect_hash_builder.h index c15dbe8ca..5b50361f8 100644 --- a/include/meta/hashing/perfect_hash_builder.h +++ b/include/meta/hashing/perfect_hash_builder.h @@ -112,10 +112,10 @@ class perfect_hash_builder struct hashed_key { - std::size_t idx; + uint64_t idx; K key; - hashed_key(std::size_t index, const K& akey) : idx{index}, key{akey} + hashed_key(uint64_t index, const K& akey) : idx{index}, key{akey} { // nothing } diff --git a/include/meta/hashing/perfect_hash_builder.tcc b/include/meta/hashing/perfect_hash_builder.tcc index 6a8527f33..14f920512 100644 --- a/include/meta/hashing/perfect_hash_builder.tcc +++ b/include/meta/hashing/perfect_hash_builder.tcc @@ -34,7 +34,7 @@ namespace mph template struct bucket_record { - std::size_t idx; + uint64_t idx; std::vector keys; void merge_with(bucket_record&& other) @@ -84,12 +84,12 @@ template using chunk_iterator = util::chunk_iterator>; template -std::size_t hash(const K& key, uint64_t seed) +farm_hash_seeded::result_type hash(const K& key, uint64_t seed) { using 
meta::hashing::hash_append; farm_hash_seeded hasher{seed}; hash_append(hasher, key); - return static_cast(hasher); + return static_cast(hasher); } } @@ -308,10 +308,10 @@ void perfect_hash_builder::merge_chunks_by_bucket_size() namespace mph { template -std::vector hashes_for_bucket(const mph::bucket_record& bucket, - std::size_t seed) +std::vector hashes_for_bucket(const mph::bucket_record& bucket, + uint64_t seed) { - std::vector hashes(bucket.keys.size()); + std::vector hashes(bucket.keys.size()); std::transform(bucket.keys.begin(), bucket.keys.end(), hashes.begin(), [&](const K& key) { @@ -325,16 +325,16 @@ std::vector hashes_for_bucket(const mph::bucket_record& bucket, template void hashes_to_indices(ForwardIterator begin, ForwardIterator end, - OutputIterator output, std::size_t seed, std::size_t mod) + OutputIterator output, uint64_t seed, std::size_t mod) { - std::transform(begin, end, output, [&](const std::size_t& key) + std::transform(begin, end, output, [&](uint64_t key) { return farm::hash_len_16(key, seed) % mod; }); } -inline bool insert_bucket(std::vector& indices, - std::vector& occupied_slots, std::size_t idx, +inline bool insert_bucket(std::vector& indices, + std::vector& occupied_slots, uint64_t idx, uint16_t seed, util::disk_vector& seeds) { auto iit = indices.begin(); @@ -384,13 +384,13 @@ void perfect_hash_builder::construct_perfect_hash() auto hashes = mph::hashes_for_bucket(bucket, bucket_seed_); - std::vector indices(bucket.keys.size()); + std::vector indices(bucket.keys.size()); bool success = false; const uint16_t max_probes = std::numeric_limits::max(); for (uint16_t i = 0; i < max_probes && !success; ++i) { - auto seed = static_cast(i); + auto seed = static_cast(i); mph::hashes_to_indices(hashes.begin(), hashes.end(), indices.begin(), seed, num_bins); diff --git a/include/meta/hashing/probing.h b/include/meta/hashing/probing.h index 30898a84c..bcbda9996 100644 --- a/include/meta/hashing/probing.h +++ b/include/meta/hashing/probing.h @@ -27,7 +27,7 @@ namespace probing class linear { public: - linear(uint64_t hash, uint64_t capacity) : hash_{hash}, capacity_{capacity} + linear(std::size_t hash, std::size_t capacity) : hash_{hash}, capacity_{capacity} { hash_ %= capacity_; } @@ -35,20 +35,20 @@ class linear /** * @return the next index to probe in the table */ - uint64_t probe() + std::size_t probe() { return hash_++ % capacity_; } private: - uint64_t hash_; - uint64_t capacity_; + std::size_t hash_; + std::size_t capacity_; }; class linear_nomod { public: - linear_nomod(uint64_t hash, uint64_t capacity) + linear_nomod(std::size_t hash, std::size_t capacity) : hash_{hash}, max_{capacity - 1} { hash_ %= capacity; @@ -57,7 +57,7 @@ class linear_nomod /** * @return the next index to probe in the table */ - uint64_t probe() + std::size_t probe() { hash_++; if (hash_ > max_) @@ -66,14 +66,14 @@ class linear_nomod } private: - uint64_t hash_; - uint64_t max_; + std::size_t hash_; + std::size_t max_; }; class binary { public: - binary(uint64_t hash, uint64_t capacity) + binary(std::size_t hash, std::size_t capacity) : hash_{hash}, step_{0}, capacity_{capacity} { hash_ %= capacity; @@ -82,7 +82,7 @@ class binary /** * @return the next index to probe in the table */ - uint64_t probe() + std::size_t probe() { // discard hashes that fall off of the table for (; (hash_ ^ step_) >= capacity_; ++step_) @@ -91,9 +91,9 @@ class binary } private: - uint64_t hash_; - uint64_t step_; - uint64_t capacity_; + std::size_t hash_; + std::size_t step_; + std::size_t capacity_; }; 
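For context on how these probing strategies are consumed: `hash_storage::get_idx()` earlier in this diff constructs a strategy from a hash code and the table capacity, then calls `probe()` until it lands on a usable slot. A minimal sketch of that pattern follows (the `find_free_slot` helper and its occupancy predicate are hypothetical):

```cpp
#include <cstddef>

#include "meta/hashing/probing.h"

// Probes a table of the given capacity until the predicate reports an
// unoccupied slot, returning that slot's index.
template <class IsOccupied>
std::size_t find_free_slot(std::size_t hash_code, std::size_t capacity,
                           IsOccupied&& occupied)
{
    meta::hashing::probing::linear strategy{hash_code, capacity};
    auto idx = strategy.probe();
    while (occupied(idx))
        idx = strategy.probe();
    return idx;
}
```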
template @@ -104,9 +104,9 @@ class binary_hybrid static_assert(Alignment > sizeof(probe_entry), "Alignment should be larger than sizeof(T)"); - const static uint64_t block_size = Alignment / sizeof(probe_entry); + const static std::size_t block_size = Alignment / sizeof(probe_entry); - binary_hybrid(uint64_t hash, uint64_t capacity) + binary_hybrid(std::size_t hash, std::size_t capacity) : hash_{hash}, step_{0}, max_{capacity - 1} { hash_ %= capacity; @@ -126,7 +126,7 @@ } } - uint64_t probe() + std::size_t probe() { if (META_LIKELY(step_ < block_size)) { @@ -141,10 +141,10 @@ } private: - uint64_t hash_; - uint64_t step_; - uint64_t idx_; - uint64_t max_; + std::size_t hash_; + std::size_t step_; + std::size_t idx_; + std::size_t max_; }; // http://stackoverflow.com/questions/2348187 @@ -152,7 +152,7 @@ class binary_hybrid class quadratic { public: - quadratic(uint64_t hash, uint64_t capacity) + quadratic(std::size_t hash, std::size_t capacity) : hash_{hash}, capacity_{capacity}, step_{0} { hash_ &= (capacity_ - 1); @@ -162,7 +162,7 @@ * @note This strategy only will work for power-of-2 capacities! * @return the next index to probe in the table */ - uint64_t probe() + std::size_t probe() { auto next = (hash_ + (step_ * (step_ + 1)) / 2) & (capacity_ - 1); ++step_; @@ -170,9 +170,9 @@ } private: - uint64_t hash_; - uint64_t capacity_; - uint64_t step_; + std::size_t hash_; + std::size_t capacity_; + std::size_t step_; }; } } diff --git a/include/meta/index/chunk_reader.h b/include/meta/index/chunk_reader.h index 33e0c4792..71ed2edae 100644 --- a/include/meta/index/chunk_reader.h +++ b/include/meta/index/chunk_reader.h @@ -98,10 +98,12 @@ class postings_record * Represents an on-disk chunk to be merged with multi-way merge sort. Each * chunk_reader stores the file it's reading from, the total bytes needed * to be read, and the current number of bytes read, as well as buffers in - * one postings_record. + * one postings_record. When it reaches the end of its file, the file will be + * destroyed.
*/ template -using chunk_reader = util::chunk_iterator>; +using chunk_reader + = util::destructive_chunk_iterator>; /** * Performs a multi-way merge sort of all of the provided chunks, writing diff --git a/include/meta/index/disk_index.h b/include/meta/index/disk_index.h index 3dcf4986c..12b699997 100644 --- a/include/meta/index/disk_index.h +++ b/include/meta/index/disk_index.h @@ -72,12 +72,14 @@ class disk_index * @param d_id * @return the actual name of this document */ + META_DEPRECATED("use metadata() instead") std::string doc_name(doc_id d_id) const; /** * @param d_id * @return the path to the file containing this document */ + META_DEPRECATED("use metadata() instead") std::string doc_path(doc_id d_id) const; /** @@ -134,6 +136,17 @@ class disk_index */ corpus::metadata metadata(doc_id d_id) const; + /** + * @param d_id The document to fetch the metadata field for + * @param name The name of the metadata field to be returned + * @return the metadata field value, if it exists + */ + template + util::optional metadata(doc_id d_id, const std::string& name) const + { + return metadata(d_id).get(name); + } + /** * @param d_id * @return the number of unique terms in d_id diff --git a/include/meta/index/postings_buffer.h b/include/meta/index/postings_buffer.h index 5f5446dcf..a43005f16 100644 --- a/include/meta/index/postings_buffer.h +++ b/include/meta/index/postings_buffer.h @@ -299,7 +299,7 @@ template void hash_append(HashAlgorithm& h, const postings_buffer& pb) { - using util::hash_append; + using hashing::hash_append; hash_append(h, pb.primary_key()); } } diff --git a/include/meta/index/postings_data.h b/include/meta/index/postings_data.h index c7c78fad2..992d2e5cf 100644 --- a/include/meta/index/postings_data.h +++ b/include/meta/index/postings_data.h @@ -157,7 +157,8 @@ class postings_data * @param in The stream to read from * @return the number of bytes read in consuming this postings data */ - uint64_t read_packed(std::istream& in); + template + uint64_t read_packed(InputStream& in); /** * @return the term_id for this postings_data diff --git a/include/meta/index/postings_data.tcc b/include/meta/index/postings_data.tcc index 846be2b3f..5a915d65c 100644 --- a/include/meta/index/postings_data.tcc +++ b/include/meta/index/postings_data.tcc @@ -194,13 +194,15 @@ uint64_t length(const T& elem, } template +template uint64_t postings_data::read_packed( - std::istream& in) + InputStream& in) { - if (in.get() == EOF) + if (in.peek() == EOF) + { + in.get(); return 0; - else - in.unget(); + } auto bytes = io::packed::read(in, p_id_); diff --git a/include/meta/index/ranker/all.h b/include/meta/index/ranker/all.h index a08593c38..8a1fe0e04 100644 --- a/include/meta/index/ranker/all.h +++ b/include/meta/index/ranker/all.h @@ -5,3 +5,5 @@ #include "meta/index/ranker/lm_ranker.h" #include "meta/index/ranker/okapi_bm25.h" #include "meta/index/ranker/pivoted_length.h" +#include "meta/index/ranker/kl_divergence_prf.h" +#include "meta/index/ranker/rocchio.h" diff --git a/include/meta/index/ranker/kl_divergence_prf.h b/include/meta/index/ranker/kl_divergence_prf.h new file mode 100644 index 000000000..2a9e9f6c3 --- /dev/null +++ b/include/meta/index/ranker/kl_divergence_prf.h @@ -0,0 +1,101 @@ +/** + * @file kl_divergence_prf.h + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. 
+ */ + +#ifndef META_INDEX_KL_DIVERGENCE_PRF_H_ +#define META_INDEX_KL_DIVERGENCE_PRF_H_ + +#include "meta/index/ranker/lm_ranker.h" +#include "meta/index/ranker/ranker_factory.h" + +namespace meta +{ +namespace index +{ + +/** + * Implements the two-component mixture model for pseudo-relevance + * feedback in the KL-divergence retrieval model. + * + * @see http://dl.acm.org/citation.cfm?id=502654 + * + * Required config parameters: + * ~~~toml + * [ranker] + * method = "kl-divergence-prf" + * ~~~ + * + * Optional config parameters: + * ~~~toml + * alpha = 0.5 # query interpolation parameter + * lambda = 0.5 # mixture model interpolation parameter + * k = 10 # number of feedback documents to retrieve + * max-terms = 50 # maximum number of feedback terms to use + * + * [ranker.feedback] + * method = "dirichlet-prior" # the initial model used to retrieve documents + * # other parameters for that initial retrieval method + * ~~~ + */ +class kl_divergence_prf : public ranker +{ + public: + /// Identifier for this ranker. + const static util::string_view id; + + /// Default value of alpha, the query interpolation parameter + const static constexpr float default_alpha = 0.5; + + /// Default value for lambda, the mixture model interpolation parameter + const static constexpr float default_lambda = 0.5; + + /// Default value for k, the number of feedback documents to retrieve + const static constexpr uint64_t default_k = 10; + + /** + * Default value for max_terms, the number of feedback terms to + * interpolate into the query model. + */ + const static constexpr uint64_t default_max_terms = 50; + + kl_divergence_prf(std::shared_ptr fwd); + + kl_divergence_prf(std::shared_ptr fwd, + std::unique_ptr&& initial_ranker, + float alpha = default_alpha, + float lambda = default_lambda, uint64_t k = default_k, + uint64_t max_terms = default_max_terms); + + kl_divergence_prf(std::istream& in); + + void save(std::ostream& out) const override; + + std::vector + rank(ranker_context& ctx, uint64_t num_results, + const filter_function_type& filter) override; + + private: + std::shared_ptr fwd_; + std::unique_ptr initial_ranker_; + const float alpha_; + const float lambda_; + const uint64_t k_; + const uint64_t max_terms_; +}; + +/** + * Specialization of the factory method used to create kl_divergence_prf + * rankers. + */ +template <> +std::unique_ptr +make_ranker(const cpptoml::table& global, + const cpptoml::table& local); +} +} +#endif diff --git a/include/meta/index/ranker/lm_ranker.h b/include/meta/index/ranker/lm_ranker.h index 89e9568bc..08a18f3bb 100644 --- a/include/meta/index/ranker/lm_ranker.h +++ b/include/meta/index/ranker/lm_ranker.h @@ -22,7 +22,7 @@ namespace index * scoring methods described in "A Study of Smoothing Methods for Language * Models Applied to Ad Hoc Information Retrieval" by Zhai and Lafferty, 2001. */ -class language_model_ranker : public ranker +class language_model_ranker : public ranking_function { public: /// The identifier for this ranker. diff --git a/include/meta/index/ranker/okapi_bm25.h b/include/meta/index/ranker/okapi_bm25.h index f5b7a846a..6cb33bd70 100644 --- a/include/meta/index/ranker/okapi_bm25.h +++ b/include/meta/index/ranker/okapi_bm25.h @@ -33,7 +33,7 @@ namespace index * k3 = 500.0 * ~~~ */ -class okapi_bm25 : public ranker +class okapi_bm25 : public ranking_function { public: /// The identifier for this ranker. 
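The change just above (`okapi_bm25` now deriving from `ranking_function`) is the same one-line migration that downstream scorers need, per the changelog's breaking-change note. Below is a minimal, hypothetical sketch of such a migrated scorer; the class, its identifier string, and the toy formula are illustrative only, and a real ranker would also register itself via `register_ranker` and compute its score from the full set of `score_data` statistics.

```cpp
#include <ostream>
#include <string>

#include "meta/index/ranker/ranker.h"
#include "meta/index/score_data.h"
#include "meta/io/packed.h"

class my_tf_scorer : public meta::index::ranking_function
{
  public:
    void save(std::ostream& out) const override
    {
        // the identifier written here is illustrative
        meta::io::packed::write(out, std::string{"my-tf-scorer"});
    }

    float score_one(const meta::index::score_data& sd) override
    {
        // toy scoring: credit a matched term by its raw count in the document
        return static_cast<float>(sd.doc_term_count);
    }
};
```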
diff --git a/include/meta/index/ranker/pivoted_length.h b/include/meta/index/ranker/pivoted_length.h index d5cf4b45b..b3c12bd55 100644 --- a/include/meta/index/ranker/pivoted_length.h +++ b/include/meta/index/ranker/pivoted_length.h @@ -33,7 +33,7 @@ namespace index * s = 0.2 * ~~~ */ -class pivoted_length : public ranker +class pivoted_length : public ranking_function { public: /// Identifier for this ranker. diff --git a/include/meta/index/ranker/ranker.h b/include/meta/index/ranker/ranker.h index fa2f4fea2..44fda16f7 100644 --- a/include/meta/index/ranker/ranker.h +++ b/include/meta/index/ranker/ranker.h @@ -12,8 +12,8 @@ #include #include -#include "meta/meta.h" #include "meta/index/inverted_index.h" +#include "meta/meta.h" namespace meta { @@ -78,6 +78,28 @@ struct postings_context } }; +inline term_id get_term_id(disk_index& inv, const std::string& term) +{ + return inv.get_term_id(term); +} + +inline term_id get_term_id(disk_index&, term_id tid) +{ + return tid; +} +} + +/** + * Stores a list of postings_stream and other relevant information for + * performing document-at-a-time ranking. You should not generally have to + * interact with this class unless implementing a new feedback method, in + * which case you should only have to construct it and pass it off to + * ranker::rank() directly afterward. + * + * ForwardIterator must dereference to a pair type (either std::pair or + * hashing::kv_pair) which has a key type of either std::string or term_id + * and a value type convertible to float. + */ struct ranker_context { template @@ -96,7 +118,7 @@ struct ranker_context typename std::decay::type>; query_length += kv_traits::value(count); - auto term = idx.get_term_id(kv_traits::key(count)); + auto term = detail::get_term_id(inv, kv_traits::key(count)); auto pstream = idx.stream_for(term); if (!pstream) continue; @@ -116,11 +138,10 @@ struct ranker_context } inverted_index& idx; - std::vector postings; + std::vector postings; float query_length; doc_id cur_doc; }; -} /** * Exception class for ranker interactions. @@ -159,7 +180,7 @@ class ranker score(inverted_index& idx, ForwardIterator begin, ForwardIterator end, uint64_t num_results = 10, Function&& filter = passthrough) { - detail::ranker_context ctx{idx, begin, end, filter}; + ranker_context ctx{idx, begin, end, filter}; return rank(ctx, num_results, filter); } @@ -170,15 +191,41 @@ class ranker * @param filter A filtering function to apply to each doc_id; returns * true if the document should be included in results */ - std::vector score(inverted_index& idx, - const corpus::document& query, - uint64_t num_results = 10, - const filter_function_type& filter - = [](doc_id) - { - return true; - }); + std::vector + score(inverted_index& idx, const corpus::document& query, + uint64_t num_results = 10, + const filter_function_type& filter = [](doc_id) { return true; }); + /** + * Default destructor. + */ + virtual ~ranker() = default; + + /** + * Saves the ranker to a stream. This should save the ranker's id, + * followed by any parameters needed for reconstruction. + */ + virtual void save(std::ostream& out) const = 0; + + /** + * Scores a query using a document-at-a-time strategy. You should not + * override this unless you desire a completely different ranking + * strategy than document-at-a-time, which might be the case if you are + * implementing a new pseudo-relevance feedback method. 
+ * + * @param ctx The ranker_context holding the postings lists + * @param num_results The number of search results to return + * @param filter The filter function to be used + */ + virtual std::vector rank(ranker_context& ctx, + uint64_t num_results, + const filter_function_type& filter) + = 0; +}; + +class ranking_function : public ranker +{ + public: /** * Computes the contribution to the score of a document for a matched * query term. @@ -193,23 +240,10 @@ class ranker */ virtual float initial_score(const score_data& sd) const; - /** - * Default destructor. - */ - virtual ~ranker() = default; - - /** - * Saves the ranker to a stream. This should save the ranker's id, - * followed by any parameters needed for reconstruction. - */ - virtual void save(std::ostream& out) const = 0; - - private: - std::vector rank(detail::ranker_context& ctx, - uint64_t num_results, - const filter_function_type& filter); + virtual std::vector + rank(ranker_context& ctx, uint64_t num_results, + const filter_function_type& filter) override final; }; } } - #endif diff --git a/include/meta/index/ranker/ranker_factory.h b/include/meta/index/ranker/ranker_factory.h index 6dea704db..368d2cbfc 100644 --- a/include/meta/index/ranker/ranker_factory.h +++ b/include/meta/index/ranker/ranker_factory.h @@ -9,6 +9,7 @@ #ifndef META_RANKER_FACTORY_H_ #define META_RANKER_FACTORY_H_ +#include "meta/index/ranker/lm_ranker.h" #include "meta/index/ranker/ranker.h" #include "meta/util/factory.h" #include "meta/util/shim.h" @@ -29,11 +30,27 @@ namespace index * class directly to add their own rankers. */ class ranker_factory - : public util::factory + : public util::factory { + public: /// Friend the base ranker factory friend base_factory; + std::unique_ptr + create_lm(util::string_view identifier, const cpptoml::table& global, + const cpptoml::table& local) + { + auto rnk = base_factory::create(identifier, global, local); + if (auto der = dynamic_cast(rnk.get())) + { + rnk.release(); + return std::unique_ptr{der}; + } + throw std::invalid_argument{identifier.to_string() + + " is not a language_model_ranker"}; + } + private: /** * Constructor. @@ -52,10 +69,33 @@ class ranker_factory */ std::unique_ptr make_ranker(const cpptoml::table&); +/** + * Convenience method for creating a ranker using the factory. + * @param global The global configuration group (containing the index path) + * @param local The ranker configuration group itself + */ +std::unique_ptr make_ranker(const cpptoml::table& global, + const cpptoml::table& local); + +/** + * Convenience method for creating a language_model_ranker using the + * factory. + */ +std::unique_ptr make_lm_ranker(const cpptoml::table&); + +/** + * Convenience method for creating a language_model_ranker using the factory. + * @param global The global configuration group (containing the index path) + * @param local The ranker configuration group itself + */ +std::unique_ptr +make_lm_ranker(const cpptoml::table& global, const cpptoml::table& local); + /** * Factory method for creating a ranker. This should be specialized if * your given ranker requires special construction behavior (e.g., - * reading parameters). + * reading parameters) that requires only the ranker-specific configuration + * (this will be the case almost all of the time). */ template std::unique_ptr make_ranker(const cpptoml::table&) @@ -63,6 +103,20 @@ std::unique_ptr make_ranker(const cpptoml::table&) return make_unique(); } +/** + * Factory method for creating a ranker. 
This should be specialized if your + * given ranker requires special construction behavior that includes + * reading parameter values from the global configuration as well as the + * ranker-specific configuration. + */ +template +std::unique_ptr make_ranker(const cpptoml::table& global, + const cpptoml::table& local) +{ + (void)global; + return make_ranker(local); +} + /** * Factory that is responsible for loading rankers from streams. Clients * should use the register_ranker method instead of this class directly to @@ -70,8 +124,22 @@ std::unique_ptr make_ranker(const cpptoml::table&) */ class ranker_loader : public util::factory { + public: friend base_factory; + std::unique_ptr + create_lm(util::string_view identifier, std::istream& in) + { + auto rnk = base_factory::create(identifier, in); + if (auto lmr = dynamic_cast(rnk.get())) + { + rnk.release(); + return std::unique_ptr{lmr}; + } + throw std::invalid_argument{ + "loaded ranker is not a language_model_ranker"}; + } + private: /** * Constructor for setting up the singleton ranker_loader. @@ -90,6 +158,11 @@ class ranker_loader : public util::factory */ std::unique_ptr load_ranker(std::istream&); +/** + * Convenience method for loading a language_model_ranker using the factory. + */ +std::unique_ptr load_lm_ranker(std::istream&); + /** * Factory method for loading a ranker. This should be specialized if your * given ranker requires special construction behavior. Otherwise, it is @@ -108,7 +181,10 @@ std::unique_ptr load_ranker(std::istream& in) template void register_ranker() { - ranker_factory::get().add(Ranker::id, make_ranker); + ranker_factory::get().add(Ranker::id, [](const cpptoml::table& global, + const cpptoml::table& local) { + return make_ranker(global, local); + }); ranker_loader::get().add(Ranker::id, load_ranker); } } diff --git a/include/meta/index/ranker/rocchio.h b/include/meta/index/ranker/rocchio.h new file mode 100644 index 000000000..8d4894bf8 --- /dev/null +++ b/include/meta/index/ranker/rocchio.h @@ -0,0 +1,101 @@ +/** + * @file rocchio.h + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#ifndef META_INDEX_ROCCHIO_H_ +#define META_INDEX_ROCCHIO_H_ + +#include "meta/index/ranker/ranker_factory.h" + +namespace meta +{ +namespace index +{ + +/** + * Implements the Rocchio algorithm for pseudo-relevance feedback. This + * implementation considers only positive documents for feedback. The top + * `max_terms` from the centroid of the feedback set are selected according + * to their weights provided by the wrapped ranker's `score_one` function. + * These are then interpolated into the query in *count space*, and then + * the results from running the wrapped ranker on the new query are + * returned. + * + * Required config parameters: + * ~~~toml + * [ranker] + * method = "rocchio" + * ~~~ + * + * Optional config parameters: + * ~~~toml + * alpha = 1.0 # original query weight parameter + * beta = 1.0 # feedback document weight parameter + * k = 10 # number of feedback documents to retrieve + * max-terms = 50 # maximum number of feedback terms to use + * [ranker.feedback] + * method = # whatever ranker method you want to wrap + * # other parameters for that ranker + * ~~~ + * + * @see https://en.wikipedia.org/wiki/Rocchio_algorithm + */ +class rocchio : public ranker +{ + public: + /// Identifier for this ranker. 
+ const static util::string_view id; + + /// Default value of alpha, the original query weight parameter + const static constexpr float default_alpha = 1.0f; + + /// Default value of beta, the positive document weight parameter + const static constexpr float default_beta = 0.8f; + + /// Default value for k, the number of feedback documents to retrieve + const static constexpr uint64_t default_k = 10; + + /** + * Default value for max_terms, the number of new terms to add to the + * new query. + */ + const static constexpr uint64_t default_max_terms = 50; + + rocchio(std::shared_ptr fwd); + + rocchio(std::shared_ptr fwd, + std::unique_ptr&& initial_ranker, + float alpha = default_alpha, float beta = default_beta, + uint64_t k = default_k, uint64_t max_terms = default_max_terms); + + rocchio(std::istream& in); + + void save(std::ostream& out) const override; + + std::vector + rank(ranker_context& ctx, uint64_t num_results, + const filter_function_type& filter) override; + + private: + std::shared_ptr fwd_; + std::unique_ptr initial_ranker_; + const float alpha_; + const float beta_; + const uint64_t k_; + const uint64_t max_terms_; +}; + +/** + * Specialization of the factory method used to create rocchio rankers. + */ +template <> +std::unique_ptr make_ranker(const cpptoml::table& global, + const cpptoml::table& local); +} +} +#endif diff --git a/include/meta/index/ranker/unigram_mixture.h b/include/meta/index/ranker/unigram_mixture.h new file mode 100644 index 000000000..29bc1e709 --- /dev/null +++ b/include/meta/index/ranker/unigram_mixture.h @@ -0,0 +1,111 @@ +/** + * @file unigram_mixture.h + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#ifndef UNIGRAM_MIXTURE_H_ +#define UNIGRAM_MIXTURE_H_ + +#include +#include +#include + +#include "meta/config.h" +#include "meta/learn/dataset_view.h" +#include "meta/stats/multinomial.h" + +namespace meta +{ +namespace index +{ +namespace feedback +{ + +/** + * @param dset A collection of documents to fit a language model to + * @return the maximum likelihood estimate for the language model + */ +stats::multinomial maximum_likelihood(const learn::dataset_view& dset) +{ + stats::multinomial model; + for (const auto& inst : dset) + { + for (const auto& weight : inst.weights) + { + model.increment(weight.first, weight.second); + } + } + return model; +} + +struct training_options +{ + /// The fixed probability of the background model + double lambda = 0.5; + /// The maximum number of iterations for running EM + uint64_t max_iter = 50; + /// The convergence threshold as the relative change in log likelihood + double delta = 1e-5; +}; + +/** + * Learns the feedback model component of a two-component unigram mixture + * model. The BackgroundModel is a unary function that returns the + * probability of a term. This is used as the first component of the + * mixture model, which has fixed probability options.lambda of being + * selected. This function used the EM algorithm to fit the second + * component language model and returns it. 
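As a hedged usage sketch (the `learn::feature_id` event type and the pre-built background multinomial are assumptions), fitting the feedback component against a fixed background model might look like this:

~~~cpp
#include "meta/index/ranker/unigram_mixture.h"

meta::stats::multinomial<meta::learn::feature_id>
fit_feedback_model(const meta::learn::dataset_view& feedback_docs,
                   const meta::stats::multinomial<meta::learn::feature_id>& background)
{
    using namespace meta::index;

    feedback::training_options opts;
    opts.lambda = 0.5;  // fixed probability of the background component
    opts.max_iter = 50; // EM iteration cap
    opts.delta = 1e-5;  // relative log likelihood convergence threshold

    // the background model is supplied as a unary function over term ids
    return feedback::unigram_mixture(
        [&](meta::learn::feature_id tid) { return background.probability(tid); },
        feedback_docs, opts);
}
~~~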
+ * + * @param background The background language model + * @param dset The feedback documents to fit the feedback model to + * @param options The training options for the EM algorithm + * @return the feedback model + */ +template +stats::multinomial +unigram_mixture(BackgroundModel&& background, const learn::dataset_view& dset, + const training_options& options = {}) +{ + auto feedback = maximum_likelihood(dset); + auto old_ll = std::numeric_limits::lowest(); + auto relative_change = std::numeric_limits::max(); + + for (uint64_t i = 1; + i <= options.max_iter && relative_change >= options.delta; ++i) + { + stats::multinomial model; + double ll = 0; + + for (const auto& inst : dset) + { + for (const auto& weight : inst.weights) + { + auto p_wc = background(weight.first); + auto p_wf = feedback.probability(weight.first); + + auto numerator = options.lambda * p_wc; + auto denominator = numerator + (1.0 - options.lambda) * p_wf; + + auto p_zw = numerator / denominator; + + model.increment(weight.first, (1.0 - p_zw) * weight.second); + ll += weight.second * std::log(denominator); + } + } + + feedback = model; + assert(ll > old_ll); + relative_change = (old_ll - ll) / old_ll; + old_ll = ll; + } + + return feedback; +} +} +} +} +#endif diff --git a/include/meta/index/score_data.h b/include/meta/index/score_data.h index 2f9e18750..e7434e0bf 100644 --- a/include/meta/index/score_data.h +++ b/include/meta/index/score_data.h @@ -80,15 +80,16 @@ struct score_data * @param p_avg_dl The average doc length in the index * @param p_num_docs The number of docs in the index * @param p_total_terms The total number of terms in the index - * @param p_query The current query + * @param p_query_length The current query length (e.g. the total number of + * words in the query) */ score_data(inverted_index& p_idx, float p_avg_dl, uint64_t p_num_docs, - uint64_t p_total_terms, float p_length) + uint64_t p_total_terms, float p_query_length) : idx(p_idx), // gcc no non-const ref init from brace init list avg_dl{p_avg_dl}, num_docs{p_num_docs}, total_terms{p_total_terms}, - query_length{p_length} + query_length{p_query_length} { /* nothing */ } diff --git a/include/meta/io/mmap_file.h b/include/meta/io/mmap_file.h index 42b73828d..23a1b63c5 100644 --- a/include/meta/io/mmap_file.h +++ b/include/meta/io/mmap_file.h @@ -14,6 +14,7 @@ #include #include "meta/config.h" +#include "meta/util/optional.h" namespace meta { @@ -99,6 +100,27 @@ class mmap_file_exception : public std::runtime_error public: using std::runtime_error::runtime_error; }; + +/** + * A stream for use with io::packed that reads from a memory mapped file. + */ +class mmap_ifstream +{ + public: + mmap_ifstream() = default; + mmap_ifstream(mmap_ifstream&&) = default; + mmap_ifstream& operator=(mmap_ifstream&&) = default; + mmap_ifstream(const std::string& filename); + + bool is_open() const; + int peek() const; + int get(); + void close(); + + private: + util::optional file_; + std::size_t pos_; +}; } } diff --git a/include/meta/io/packed.h b/include/meta/io/packed.h index a91b03712..5c1c0f888 100644 --- a/include/meta/io/packed.h +++ b/include/meta/io/packed.h @@ -177,6 +177,19 @@ uint64_t packed_write(OutputStream& stream, return packed_write(stream, static_cast(value)); } +/** + * Writes a pair type in a packed representation. 
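For example (a small sketch; it assumes the usual `io::packed::write`/`io::packed::read` entry points dispatch to these pair overloads), a pair can be round-tripped through a string stream:

~~~cpp
#include <cassert>
#include <cstdint>
#include <sstream>
#include <utility>

#include "meta/io/packed.h"

int main()
{
    std::stringstream ss;
    std::pair<uint64_t, double> original{42, 0.5};

    auto bytes_written = meta::io::packed::write(ss, original);

    std::pair<uint64_t, double> loaded;
    auto bytes_read = meta::io::packed::read(ss, loaded);

    assert(bytes_written == bytes_read);
    assert(loaded == original);
    return 0;
}
~~~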
+ * + * @param os The stream to write to + * @param value The value to write + * @return the number of bytes used to write out the value + */ +template +uint64_t packed_write(OutputSteam& os, const std::pair& pr) +{ + return packed_write(os, pr.first) + packed_write(os, pr.second); +} + /** * Writes a vector type in a packed representation. * @@ -342,6 +355,19 @@ uint64_t packed_read(InputStream& stream, util::identifier& value) return packed_read(stream, static_cast(value)); } +/** + * Reads a pair type from a packed representation. + * + * @param is The stream to read from + * @param value The value to write + * @return the number of bytes read + */ +template +uint64_t packed_read(InputStream& is, std::pair& pr) +{ + return packed_read(is, pr.first) + packed_read(is, pr.second); +} + /** * Reads a vector type from a packed representation. * diff --git a/include/meta/learn/dataset.h b/include/meta/learn/dataset.h index 0c18a3255..cef67f0b3 100644 --- a/include/meta/learn/dataset.h +++ b/include/meta/learn/dataset.h @@ -34,17 +34,18 @@ class dataset public: using instance_type = instance; using const_iterator = std::vector::const_iterator; - using iterator = const_iterator; + using iterator = std::vector::iterator; using size_type = std::vector::size_type; /** * Creates an in-memory dataset from a forward_index and a range of * doc_ids, represented as iterators. */ - template + template dataset(std::shared_ptr idx, ForwardIterator begin, - ForwardIterator end) - : total_features_{idx->unique_terms()} + ForwardIterator end, ProgressTrait = ProgressTrait{}) + : total_features_(idx->unique_terms()) { auto size = static_cast(std::distance(begin, end)); @@ -53,7 +54,8 @@ class dataset instances_.reserve(size); - printing::progress progress{" > Loading instances into memory: ", size}; + typename ProgressTrait::type progress{ + " > Loading instances into memory: ", size}; for (auto doc = 0_inst_id; begin != end; ++begin, ++doc) { progress(doc); @@ -70,15 +72,17 @@ class dataset * the knn classifier. The id field of the instance_types stored within * the dataset is a document_id. 
*/ - template + template dataset(std::shared_ptr idx, ForwardIterator begin, - ForwardIterator end) - : total_features_{idx->unique_terms()} + ForwardIterator end, ProgressTrait = ProgressTrait{}) + : total_features_(idx->unique_terms()) { auto size = static_cast(std::distance(begin, end)); instances_.reserve(size); - printing::progress progress{" > Loading instances into memory: ", size}; + typename ProgressTrait::type progress{ + " > Loading instances into memory: ", size}; for (uint64_t pos = 0; begin != end; ++begin, ++pos) { progress(pos); @@ -120,7 +124,15 @@ class dataset /** * @return an iterator to the first instance */ - iterator begin() const + const_iterator begin() const + { + return instances_.begin(); + } + + /** + * @return an iterator to the first instance + */ + iterator begin() { return instances_.begin(); } @@ -128,7 +140,15 @@ class dataset /** * @return an iterator to one past the end of the dataset */ - iterator end() const + const_iterator end() const + { + return instances_.end(); + } + + /** + * @return an iterator to one past the end of the dataset + */ + iterator end() { return instances_.end(); } diff --git a/include/meta/learn/dataset_view.h b/include/meta/learn/dataset_view.h index 031296328..fad119143 100644 --- a/include/meta/learn/dataset_view.h +++ b/include/meta/learn/dataset_view.h @@ -38,6 +38,7 @@ class dataset_view using size_type = dataset::size_type; class iterator; + using const_iterator = iterator; dataset_view(const dataset& dset) : dataset_view{dset, std::mt19937_64{std::random_device{}()}} @@ -45,17 +46,36 @@ class dataset_view // nothing } + dataset_view(const dataset& dset, dataset::const_iterator begin, + dataset::const_iterator end) + : dataset_view{dset, begin, end, + std::mt19937_64{std::random_device{}()}} + { + // nothing + } + template dataset_view(const dataset& dset, RandomEngine&& rng) + : dataset_view{dset, dset.begin(), dset.end(), + std::forward(rng)} + { + // nothing + } + + template + dataset_view(const dataset& dset, dataset::const_iterator begin, + dataset::const_iterator end, RandomEngine&& rng) : dset_{&dset}, - indices_(dset.size()), + indices_(static_cast(std::distance(begin, end))), rng_(std::forward(rng)) { - std::iota(indices_.begin(), indices_.end(), 0); + std::iota(indices_.begin(), indices_.end(), + std::distance(dset.begin(), begin)); } // subset constructor - dataset_view(const dataset_view& dv, iterator first, iterator last) + dataset_view(const dataset_view& dv, const_iterator first, + const_iterator last) : dset_{dv.dset_}, rng_{dv.rng_} { assert(first <= last); @@ -175,7 +195,6 @@ class dataset_view const dataset* dset_; std::vector::const_iterator it_; }; - using const_iterator = iterator; iterator begin() const { diff --git a/include/meta/learn/instance.h b/include/meta/learn/instance.h index bf864b412..a11e8cfb8 100644 --- a/include/meta/learn/instance.h +++ b/include/meta/learn/instance.h @@ -60,7 +60,7 @@ struct instance /// the id within the dataset that contains this instance instance_id id; /// the weights of the features in this instance - const feature_vector weights; + feature_vector weights; }; } } diff --git a/include/meta/learn/transform.h b/include/meta/learn/transform.h new file mode 100644 index 000000000..be29fdb05 --- /dev/null +++ b/include/meta/learn/transform.h @@ -0,0 +1,161 @@ +/** + * @file dataset.h + * @author Chase Geigle + * + * All files in META are released under the MIT license. For more details, + * consult the file LICENSE in the root of the project. 
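Taken together with the dataset changes above, a hedged end-to-end sketch of the transforms defined in this header (the `make_index` calls, the `docs()` accessor, and the choice of `okapi_bm25` as the weighting ranker are assumptions about the surrounding library):

~~~cpp
#include "cpptoml.h"
#include "meta/index/forward_index.h"
#include "meta/index/inverted_index.h"
#include "meta/index/make_index.h"
#include "meta/index/ranker/okapi_bm25.h"
#include "meta/learn/dataset.h"
#include "meta/learn/transform.h"

void build_tfidf_dataset(const cpptoml::table& config)
{
    auto inv = meta::index::make_index<meta::index::inverted_index>(config);
    auto fwd = meta::index::make_index<meta::index::forward_index>(config);

    // load every document's term-frequency vector into memory
    auto docs = fwd->docs();
    meta::learn::dataset dset{fwd, docs.begin(), docs.end()};

    // rewrite tf vectors as BM25-weighted tf-idf vectors, then L2-normalize
    meta::index::okapi_bm25 bm25;
    meta::learn::tfidf_transform(dset, *inv, bm25);
    meta::learn::l2norm_transform(dset);
}
~~~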
+ */ + +#ifndef META_LEARN_TRANSFORM_H_ +#define META_LEARN_TRANSFORM_H_ + +#include "meta/index/ranker/ranker.h" +#include "meta/index/score_data.h" +#include "meta/learn/dataset.h" + +namespace meta +{ +namespace learn +{ + +/** + * Transformer for converting term frequency vectors into tf-idf weight + * vectors. This transformation is performed with respect to a specific + * index::inverted_index that defines the term statistics, and with respect + * to an index::ranking_function that defines the "tf-idf" weight (via its + * score_one() function). + * + * For example, one can construct a tfidf_transformer with an + * inverted index and an okapi_bm25 ranker to get tf-idf vectors using + * Okapi BM25's definitions of tf and idf. + * + * Some caveats to be aware of: + * + * 1. if your ranker uses extra information that isn't present in score_data + * (e.g. by using score_data.d_id and querying something), this will only + * work if your instance ids directly correspond to doc ids in the + * inverted index + * + * 2. tf-idf values are computed using statistics from the inverted_index. + * If this index contains your test set, the statistics are going to be + * computed including documents in your test set. If this is + * undesirable, create an inverted_index on just your training data and + * use that instead of one created on both the training and testing + * data. + * + * 3. This transformation only makes sense if your instances' weight + * vectors are actually term frequency vectors. If they aren't, the + * assumptions here that every entry in every weight vector can be + * safely converted to an integral value without rounding is violated. + */ +class tfidf_transformer +{ + public: + /** + * @param idx The index to use for term statistics + * @param r The ranker to use for defining the weights + */ + tfidf_transformer(index::inverted_index& idx, index::ranking_function& r) + : idx_(idx), + rnk_(r), + sdata_(idx, idx.avg_doc_length(), idx.num_docs(), + idx.total_corpus_terms(), 1) + { + sdata_.query_term_weight = 1.0f; + } + + /** + * @param inst The instance to transform + */ + void operator()(learn::instance& inst) + { + sdata_.d_id = doc_id{inst.id}; + sdata_.doc_size = static_cast(std::accumulate( + inst.weights.begin(), inst.weights.end(), 0.0, + [](double accum, const std::pair& val) { + return accum + val.second; + })); + sdata_.doc_unique_terms = inst.weights.size(); + for (auto& pr : inst.weights) + { + sdata_.t_id = term_id{pr.first}; + sdata_.doc_count = idx_.doc_freq(sdata_.t_id); + sdata_.corpus_term_count = idx_.total_num_occurences(sdata_.t_id); + sdata_.doc_term_count = static_cast(pr.second); + + pr.second = rnk_.score_one(sdata_); + } + } + + private: + index::inverted_index& idx_; + index::ranking_function& rnk_; + index::score_data sdata_; +}; + +/** + * Transformer to normalize all unit vectors to unit length. + */ +class l2norm_transformer +{ + public: + void operator()(learn::instance& inst) const + { + auto norm = std::sqrt(std::accumulate( + inst.weights.begin(), inst.weights.end(), 0.0, + [](double accum, const std::pair& val) { + return accum + val.second * val.second; + })); + for (auto& pr : inst.weights) + pr.second /= norm; + } +}; + +/** + * Transforms the feature vectors of a dataset **in place** using the given + * transformation function. TransformFunction must have an operator() that + * takes a learn::instance by mutable reference and changes its + * feature values in-place. 
For example, a simple TransformFunction might + * be one that normalizes all of the feature vectors to be unit length. + * + * @param dset The dataset to be transformed + * @param trans The transformation function to be applied to all + * feature_vectors in dset + */ +template +void transform(dataset& dset, TransformFunction&& trans) +{ + for (auto& inst : dset) + trans(inst); +} + +/** + * Transforms the feature vectors of a dataset **in place** to be tf-idf + * features using the given index for term statistics and ranker for + * tf-idf weight definitions. + * + * @param dset The dataset to be transformed + * @param idx The inverted_index to use for term statistics like df + * @param rnk The ranker to use to define tf-idf weights (via its + * score_one()) + */ +void tfidf_transform(dataset& dset, index::inverted_index& idx, + index::ranking_function& rnk) +{ + tfidf_transformer transformer{idx, rnk}; + transform(dset, transformer); +} + +/** + * Transforms the feature vectors of a dataset **in place** to be unit + * length according to their L2 norm. + * + * @param dset The dataset to be transformed + */ +void l2norm_transform(dataset& dset) +{ + return transform(dset, l2norm_transformer{}); +} +} +} +#endif diff --git a/include/meta/lm/static_probe_map.h b/include/meta/lm/static_probe_map.h index e93b02549..184a8f55c 100644 --- a/include/meta/lm/static_probe_map.h +++ b/include/meta/lm/static_probe_map.h @@ -68,7 +68,8 @@ class static_probe_map /** * Helper function to create hasher and hash a list of word ids */ - uint64_t hash(const std::vector& tokens) const; + hashing::murmur_hash<>::result_type + hash(const std::vector& tokens) const; /// Helper function to find a node given the hash value util::optional find_hash(uint64_t hashed) const; diff --git a/include/meta/parallel/algorithm.h b/include/meta/parallel/algorithm.h new file mode 100644 index 000000000..f3222b80d --- /dev/null +++ b/include/meta/parallel/algorithm.h @@ -0,0 +1,145 @@ +/** + * @file algorithm.h + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#ifndef META_PARALLEL_ALGORITHM_H_ +#define META_PARALLEL_ALGORITHM_H_ + +#include + +#include "meta/config.h" +#include "meta/parallel/parallel_for.h" +#include "meta/parallel/thread_pool.h" + +namespace meta +{ +namespace parallel +{ + +/** + * Performs a reduction across a set of mapped values in parallel. This + * algorithm has three distinct phases: + * + * 1. Initialization: each thread invokes the LocalStorage functor, which + * should return the local storage needed to perform the reduction + * across the set of values that will be assigned to a particular + * thread. This is done *within* the thread to ensure that memory + * allocations occur within the worker thread (so it can take advantage + * of thread-local heap structures in, for example, jemalloc). + * + * 2. Mapping: each thread invokes the MappingFunction functor, which is a + * *binary* operator that takes a mutable reference to the thread's + * local storage that was created using the LocalStorage functor as its + * first argument and the element in the iterator range (by const ref) + * as its second argument. It is *not* expected to return anything as + * the calculation results should be being placed in the thread's local + * storage. + * + * 3. 
Reduction: finally, the main thread will compute the final value of + * the reduction by applying ReductionFunction across the local storage + * for each of the threads. ReductionFunction is a *binary* functor that + * takes the return type of LocalStorage by *mutable reference* as the + * first argument and a *const reference* to an object of the same type + * as the second argument. It is *not* expected to return anything and + * instead should compute the reduction by modifying the first argument. + */ +template +typename std::result_of::type +reduction(Iterator begin, Iterator end, thread_pool& pool, LocalStorage&& ls_fn, + MappingFunction&& map_fn, ReductionFunction&& red_fn) +{ + using value_type = typename std::iterator_traits::value_type; + + auto futures + = for_each_block(begin, end, pool, [&](Iterator tbegin, Iterator tend) { + auto local_storage = ls_fn(); + std::for_each(tbegin, tend, [&](const value_type& val) { + map_fn(local_storage, val); + }); + return local_storage; + }); + + // reduction phase + auto local_storage = futures[0].get(); + for (auto it = ++futures.begin(); it != futures.end(); ++it) + { + red_fn(local_storage, it->get()); + } + return local_storage; +} + +template +typename std::result_of::type +reduction(Iterator begin, Iterator end, LocalStorage&& ls_fn, + MappingFunction&& map_fn, ReductionFunction&& red_fn) +{ + parallel::thread_pool pool; + return reduction(begin, end, pool, ls_fn, map_fn, red_fn); +} + +namespace detail +{ +template +void merge_sort(RandomIt begin, RandomIt end, thread_pool& pool, + std::size_t avail_threads, Compare comp) +{ + auto len = std::distance(begin, end); + if (avail_threads < 2 || len <= 1024) + { + std::sort(begin, end, comp); + return; + } + + auto mid = std::next(begin, len / 2); + auto t1 = pool.submit_task([&]() { + merge_sort(begin, mid, pool, avail_threads / 2 + avail_threads % 2, + comp); + }); + merge_sort(mid, end, pool, avail_threads / 2, comp); + t1.get(); + std::inplace_merge(begin, mid, end, comp); +} +} + +/** + * Runs a parallel merge sort, deferring to std::sort at small problem + * sizes. + * + * @param begin The beginning of the range + * @param end The end of the range + * @param pool The thread pool to use for running the sort + * @param comp The comparison function for the sort + */ +template +void sort(RandomIt begin, RandomIt end, thread_pool& pool, Compare comp) +{ + auto fut = pool.submit_task( + [&]() { detail::merge_sort(begin, end, pool, pool.size(), comp); }); + fut.get(); +} + +/** + * Runs a parallel merge sort, deferring to std::sort at small problem + * sizes. + * + * @param begin The beginning of the range + * @param end The end of the range + * @param pool The thread pool to use for running the sort + * @param comp The comparison function for the sort + */ +template +void sort(RandomIt begin, RandomIt end, thread_pool& pool) +{ + using value_type = typename std::iterator_traits::value_type; + return sort(begin, end, pool, std::less{}); +} +} +} +#endif diff --git a/include/meta/parallel/parallel_for.h b/include/meta/parallel/parallel_for.h index 17240eeaf..cf9a9dbdb 100644 --- a/include/meta/parallel/parallel_for.h +++ b/include/meta/parallel/parallel_for.h @@ -24,31 +24,22 @@ namespace parallel { /** - * Runs the given function on the range denoted by begin and end in parallel. 
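A small usage sketch of `parallel::reduction` (the histogram task and all values are illustrative): one local map per worker thread, folded together on the main thread.

~~~cpp
#include <cstdint>
#include <unordered_map>
#include <vector>

#include "meta/parallel/algorithm.h"

std::unordered_map<int, uint64_t> histogram(const std::vector<int>& values)
{
    using map_type = std::unordered_map<int, uint64_t>;

    meta::parallel::thread_pool pool;
    return meta::parallel::reduction(
        values.begin(), values.end(), pool,
        // 1. initialization: per-thread local storage, created in the worker
        []() { return map_type{}; },
        // 2. mapping: fold one element into the thread-local map
        [](map_type& local, int val) { ++local[val]; },
        // 3. reduction: merge one thread's partial result into the total
        [](map_type& total, const map_type& partial) {
            for (const auto& kv : partial)
                total[kv.first] += kv.second;
        });
}
~~~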
- * @param begin The first element to operate on - * @param end One past the last element to operate on - * @param func The function to perform on each element - */ -template -void parallel_for(Iterator begin, Iterator end, Function func) -{ - thread_pool pool; - parallel_for(begin, end, pool, func); -} - -/** - * Runs the given function on the range denoted by begin and end in parallel. - * @param begin The first element to operate on - * @param end One past the last element to operate on - * @param pool The thread pool to use - * @param func The function to perform on each element + * Runs the given function on sub-ranges of [begin, end) in parallel. + * @param begin The beginning of the range + * @param end The ending of the range + * @param pool The thread_pool to run on + * @param fn The binary function that operates over iterator ranges */ template -void parallel_for(Iterator begin, Iterator end, thread_pool& pool, - Function func) +std::vector::type>> +for_each_block(Iterator begin, Iterator end, thread_pool& pool, Function&& fn) { using difference_type = typename std::iterator_traits::difference_type; + using result_type = + typename std::result_of::type; + auto pool_size = static_cast(pool.size()); auto block_size = std::distance(begin, end) / pool_size; @@ -63,19 +54,50 @@ void parallel_for(Iterator begin, Iterator end, thread_pool& pool, block_size = 1; } - std::vector> futures; + std::vector> futures; // first p - 1 groups for (; begin != last; std::advance(begin, block_size)) { futures.emplace_back(pool.submit_task([=]() { auto mylast = begin; std::advance(mylast, block_size); - std::for_each(begin, mylast, func); + return fn(begin, mylast); })); } // last group - futures.emplace_back( - pool.submit_task([=]() { std::for_each(begin, end, func); })); + futures.emplace_back(pool.submit_task([=]() { return fn(begin, end); })); + + return futures; +} + +/** + * Runs the given function on the range denoted by begin and end in parallel. + * @param begin The first element to operate on + * @param end One past the last element to operate on + * @param func The function to perform on each element + */ +template +void parallel_for(Iterator begin, Iterator end, Function func) +{ + thread_pool pool; + parallel_for(begin, end, pool, func); +} + +/** + * Runs the given function on the range denoted by begin and end in parallel. + * @param begin The first element to operate on + * @param end One past the last element to operate on + * @param pool The thread pool to use + * @param func The function to perform on each element + */ +template +void parallel_for(Iterator begin, Iterator end, thread_pool& pool, + Function func) +{ + auto futures + = for_each_block(begin, end, pool, [&](Iterator tbegin, Iterator tend) { + std::for_each(tbegin, tend, func); + }); for (auto& fut : futures) fut.get(); } diff --git a/include/meta/parallel/semaphore.h b/include/meta/parallel/semaphore.h index 7f6e2822a..7f2f7fb25 100644 --- a/include/meta/parallel/semaphore.h +++ b/include/meta/parallel/semaphore.h @@ -31,7 +31,7 @@ class semaphore /** * Constructs the semaphore to allow count number of threads at a time. 
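And a sketch combining the reworked `parallel_for` with the new `parallel::sort` (the doubling step and the comparator are only examples):

~~~cpp
#include <vector>

#include "meta/parallel/algorithm.h"
#include "meta/parallel/parallel_for.h"

void scale_and_sort(std::vector<double>& values)
{
    meta::parallel::thread_pool pool;

    // element-wise work, distributed over the pool via for_each_block
    meta::parallel::parallel_for(values.begin(), values.end(), pool,
                                 [](double& v) { v *= 2.0; });

    // parallel merge sort, deferring to std::sort on small sub-ranges
    meta::parallel::sort(values.begin(), values.end(), pool,
                         [](double a, double b) { return a < b; });
}
~~~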
*/ - semaphore(unsigned count) : count_{count} + semaphore(std::size_t count) : count_{count} { // nothing } @@ -67,7 +67,7 @@ class semaphore friend wait_guard; private: - unsigned count_; + std::size_t count_; std::mutex mutex_; std::condition_variable cond_; }; diff --git a/include/meta/parser/sr_parser.h b/include/meta/parser/sr_parser.h index 5e3542042..47f7ad031 100644 --- a/include/meta/parser/sr_parser.h +++ b/include/meta/parser/sr_parser.h @@ -82,7 +82,7 @@ class sr_parser /** * How many threads to use for training. */ - uint64_t num_threads = std::thread::hardware_concurrency(); + std::size_t num_threads = std::thread::hardware_concurrency(); /** * The algorithm to use for training. Defaults to diff --git a/include/meta/sequence/hmm/discrete_observations.h b/include/meta/sequence/hmm/discrete_observations.h new file mode 100644 index 000000000..e141e5236 --- /dev/null +++ b/include/meta/sequence/hmm/discrete_observations.h @@ -0,0 +1,151 @@ +/** + * @file word_observations.h + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#ifndef META_SEQUENCE_HMM_WORD_OBS_H_ +#define META_SEQUENCE_HMM_WORD_OBS_H_ + +#include "meta/io/packed.h" +#include "meta/meta.h" +#include "meta/sequence/hmm/hmm.h" +#include "meta/stats/multinomial.h" +#include "meta/util/traits.h" + +namespace meta +{ +namespace sequence +{ +namespace hmm +{ + +/** + * A multinomial observation distribution for HMMs. + */ +template +class discrete_observations +{ + public: + using observation_type = ObservationType; + using conditional_distribution_type = stats::multinomial; + + /** + * E-step scratch space for computing expected counts. + */ + class expected_counts_type + { + public: + friend discrete_observations; + + expected_counts_type(uint64_t num_states, + stats::dirichlet prior) + : obs_dist_(num_states, prior) + { + // nothing + } + + void increment(const observation_type& obs, state_id s_i, double count) + { + obs_dist_[s_i].increment(obs, count); + } + + expected_counts_type& operator+=(const expected_counts_type& other) + { + for (state_id s_i{0}; s_i < obs_dist_.size(); ++s_i) + obs_dist_[s_i] += other.obs_dist_[s_i]; + return *this; + } + + private: + std::vector obs_dist_; + }; + + /** + * Initializes each multinomial distribution for each hidden state + * randomly by using the provided random number generator. + */ + template + discrete_observations(uint64_t num_states, uint64_t num_observations, + Generator&& rng, + stats::dirichlet&& prior) + : obs_dist_(num_states, prior) + { + for (auto& dist : obs_dist_) + { + for (observation_type obs{0}; obs < num_observations; ++obs) + { + auto rnd = random::bounded_rand(rng, 65536); + auto val = (rnd / 65536.0) / num_observations; + + dist.increment(obs, val); + } + } + } + + /** + * Re-estimates the multinomials given expected_counts. + */ + discrete_observations(expected_counts_type&& counts) + : obs_dist_(std::move(counts.obs_dist_)) + { + // nothing + } + + /** + * Loads a discrete observation distribution from an input stream. + */ + template > + discrete_observations(InputStream& is) + { + if (io::packed::read(is, obs_dist_) == 0) + throw hmm_exception{"failed to load hmm observation distribution"}; + } + + /** + * Obtains an expected_counts_type suitable for re-estimating this + * distribution. 
+ */ + expected_counts_type expected_counts() const + { + return {num_states(), obs_dist_.front().prior()}; + } + + uint64_t num_states() const + { + return obs_dist_.size(); + } + + double probability(observation_type obs, state_id s_i) const + { + return obs_dist_[s_i].probability(obs); + } + + double log_probability(ObservationType obs, state_id s_i) const + { + return std::log(probability(obs, s_i)); + } + + const conditional_distribution_type& distribution(state_id s_i) const + { + return obs_dist_[s_i]; + } + + template + void save(OutputStream& os) const + { + io::packed::write(os, obs_dist_); + } + + private: + std::vector obs_dist_; +}; +} +} +} +#endif diff --git a/include/meta/sequence/hmm/forward_backward.h b/include/meta/sequence/hmm/forward_backward.h new file mode 100644 index 000000000..a5eff1f78 --- /dev/null +++ b/include/meta/sequence/hmm/forward_backward.h @@ -0,0 +1,388 @@ +/** + * @file forward_backward.h + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#ifndef META_SEQUENCE_HMM_FORWARD_BACKWARD_H_ +#define META_SEQUENCE_HMM_FORWARD_BACKWARD_H_ + +#include "meta/config.h" +#include "meta/sequence/markov_model.h" +#include "meta/sequence/trellis.h" + +namespace meta +{ +namespace sequence +{ +namespace hmm +{ + +/** + * Encapsulates the forward-backward algorithm using the scaling method + * from the original Rabiner paper. + * + * @see http://www.ece.ucsb.edu/Faculty/Rabiner/ece259/Reprints/tutorial%20on%20hmm%20and%20applications.pdf + * @see http://sifaka.cs.uiuc.edu/course/498cxz06s/hmm.pdf + */ +struct scaling_forward_backward +{ + template + static util::dense_matrix + output_probabilities(const HMM& hmm, const typename HMM::sequence_type& seq) + { + const auto& obs_dist = hmm.observation_distribution(); + util::dense_matrix output_probs{seq.size(), hmm.num_states()}; + + for (uint64_t t = 0; t < seq.size(); ++t) + { + for (state_id s_i{0}; s_i < hmm.num_states(); ++s_i) + { + output_probs(t, s_i) = obs_dist.probability(seq[t], s_i); + } + } + return output_probs; + } + + template + static forward_trellis + forward(const HMM& hmm, const typename HMM::sequence_type& seq, + const util::dense_matrix& output_probs) + { + forward_trellis fwd{seq.size(), hmm.num_states()}; + + // initialize the first column of the trellis + for (label_id l{0}; l < hmm.num_states(); ++l) + { + state_id s{l}; + fwd.probability(0, l, hmm.init_prob(s) * output_probs(0, s)); + } + // normalize to avoid underflow + fwd.normalize(0); + + // compute remaining columns using the recursive formulation + for (uint64_t t = 1; t < seq.size(); ++t) + { + for (label_id i{0}; i < hmm.num_states(); ++i) + { + state_id s_i{i}; + double sum = 0; + for (label_id j{0}; j < hmm.num_states(); ++j) + { + state_id s_j{j}; + sum += fwd.probability(t - 1, j) * hmm.trans_prob(s_j, s_i); + } + fwd.probability(t, i, sum * output_probs(t, s_i)); + } + // normalize to avoid underflow + fwd.normalize(t); + } + + return fwd; + } + + template + static trellis backward(const HMM& hmm, + const typename HMM::sequence_type& seq, + const forward_trellis& fwd, + const util::dense_matrix& output_probs) + { + trellis bwd{seq.size(), hmm.num_states()}; + + // initialize the last column of the trellis + for (label_id i{0}; i < hmm.num_states(); ++i) + { + bwd.probability(seq.size() - 1, i, 1); + } + + // fill in the remaining columns of the trellis from back to front + for 
(uint64_t k = 1; k < seq.size(); ++k) + { + assert(seq.size() - 1 >= k); + uint64_t t = seq.size() - 1 - k; + + for (label_id i{0}; i < hmm.num_states(); ++i) + { + state_id s_i{i}; + + double sum = 0; + for (label_id j{0}; j < hmm.num_states(); ++j) + { + state_id s_j{j}; + + sum += bwd.probability(t + 1, j) * hmm.trans_prob(s_i, s_j) + * output_probs(t + 1, s_j); + } + auto norm = fwd.normalizer(t + 1); + bwd.probability(t, i, norm * sum); + } + } + + return bwd; + } + + template + static util::dense_matrix + posterior_state_membership(const HMM& hmm, const forward_trellis& fwd, + const trellis& bwd) + { + util::dense_matrix gamma{fwd.size(), hmm.num_states()}; + for (uint64_t t = 0; t < fwd.size(); ++t) + { + double norm = 0; + for (label_id i{0}; i < hmm.num_states(); ++i) + { + state_id s_i{i}; + gamma(t, s_i) = fwd.probability(t, i) * bwd.probability(t, i); + norm += gamma(t, s_i); + } + std::transform(gamma.begin(t), gamma.end(t), gamma.begin(t), + [&](double val) { return val / norm; }); + // gamma(t, ) = prob. dist over possible states at time t + } + return gamma; + } + + template + static void increment_counts(const HMM& hmm, ExpectedCounts& counts, + const typename HMM::sequence_type& seq, + const forward_trellis& fwd, const trellis& bwd, + const util::dense_matrix& gamma, + const util::dense_matrix& output_probs) + { + // add expected counts to the new parameters + for (label_id i{0}; i < hmm.num_states(); ++i) + { + state_id s_i{i}; + + // add expected counts for initial state probabilities + counts.model_counts.increment_initial(s_i, gamma(0, s_i)); + + // add expected counts for transition probabilities + for (label_id j{0}; j < hmm.num_states(); ++j) + { + state_id s_j{j}; + + for (uint64_t t = 0; t < seq.size() - 1; ++t) + { + auto xi_tij + = (gamma(t, s_i) * hmm.trans_prob(s_i, s_j) + * output_probs(t + 1, s_j) * fwd.normalizer(t + 1) + * bwd.probability(t + 1, j)) + / bwd.probability(t, i); + + counts.model_counts.increment_transition(s_i, s_j, xi_tij); + } + } + + // add expected counts for observation probabilities + for (uint64_t t = 0; t < seq.size(); ++t) + { + counts.obs_counts.increment(seq[t], s_i, gamma(t, s_i)); + } + } + + // compute contribution to the log likelihood from the forward + // trellis scaling factors for this sequence + for (uint64_t t = 0; t < seq.size(); ++t) + { + // L = \prod_o \prod_t 1 / scale(t) + // log L = \sum_o \sum_t \log (1 / scale(t)) + // log L = \sum_o \sum_t - \log scale(t) + counts.log_likelihood += -std::log(fwd.normalizer(t)); + } + } +}; + +/** + * Encapsulates the forward-backward algorithm using calculations in log + * space. This is typically slower than the scaling method, but may be + * necessary in some cases (like for observations that themselves are + * sequences and have vanishingly small probabilities). 
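The log-space variant rests on the log-sum-exp trick used throughout the struct below; a standalone illustration of why the shift by the maximum matters:

~~~cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// log(sum_i exp(x_i)), computed stably by shifting by the maximum element
double log_sum_exp(const std::vector<double>& xs)
{
    double max_val = *std::max_element(xs.begin(), xs.end());
    double sum = 0.0;
    for (double x : xs)
        sum += std::exp(x - max_val); // never underflows for the max term
    return max_val + std::log(sum);
}

int main()
{
    // each term underflows to zero in linear space, yet the log-space sum
    // is perfectly representable
    std::vector<double> log_probs{-1000.0, -1001.0, -1002.0};
    std::printf("%f\n", log_sum_exp(log_probs)); // approximately -999.59
    return 0;
}
~~~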
+ */ +struct logarithm_forward_backward +{ + template + static double log_sum_exp(ForwardIterator begin, ForwardIterator end) + { + auto max_it = std::max_element(begin, end); + + auto shifted_sum_exp + = std::accumulate(begin, end, 0.0, [=](double accum, double val) { + return accum + std::exp(val - *max_it); + }); + + return *max_it + std::log(shifted_sum_exp); + } + + template + static util::dense_matrix + output_probabilities(const HMM& hmm, const typename HMM::sequence_type& seq) + { + const auto& obs_dist = hmm.observation_distribution(); + util::dense_matrix output_probs{seq.size(), hmm.num_states()}; + + for (uint64_t t = 0; t < seq.size(); ++t) + { + for (state_id s_i{0}; s_i < hmm.num_states(); ++s_i) + { + output_probs(t, s_i) = obs_dist.log_probability(seq[t], s_i); + } + } + return output_probs; + } + + template + static trellis forward(const HMM& hmm, + const typename HMM::sequence_type& seq, + const util::dense_matrix& output_log_probs) + { + trellis fwd{seq.size(), hmm.num_states()}; + + // initialize the first column of the trellis + for (label_id l{0}; l < hmm.num_states(); ++l) + { + state_id s{l}; + fwd.probability(0, l, std::log(hmm.init_prob(s)) + + output_log_probs(0, s)); + } + + std::vector scratch(hmm.num_states()); + // compute remaining columns using the recursive formulation + for (uint64_t t = 1; t < seq.size(); ++t) + { + for (label_id i{0}; i < hmm.num_states(); ++i) + { + state_id s_i{i}; + for (label_id j{0}; j < hmm.num_states(); ++j) + { + state_id s_j{j}; + scratch[j] = fwd.probability(t - 1, j) + + std::log(hmm.trans_prob(s_j, s_i)); + } + fwd.probability(t, i, + log_sum_exp(scratch.begin(), scratch.end()) + + output_log_probs(t, s_i)); + } + } + + return fwd; + } + + template + static trellis backward(const HMM& hmm, + const typename HMM::sequence_type& seq, + const trellis&, + const util::dense_matrix& output_log_probs) + { + trellis bwd{seq.size(), hmm.num_states()}; + + // initialize the last column of the trellis + for (label_id i{0}; i < hmm.num_states(); ++i) + { + bwd.probability(seq.size() - 1, i, 0); + } + + std::vector scratch(hmm.num_states()); + // fill in the remaining columns of the trellis from back to front + for (uint64_t k = 1; k < seq.size(); ++k) + { + assert(seq.size() - 1 >= k); + uint64_t t = seq.size() - 1 - k; + + for (label_id i{0}; i < hmm.num_states(); ++i) + { + state_id s_i{i}; + for (label_id j{0}; j < hmm.num_states(); ++j) + { + state_id s_j{j}; + scratch[j] = bwd.probability(t + 1, j) + + std::log(hmm.trans_prob(s_i, s_j)) + + output_log_probs(t + 1, s_j); + } + bwd.probability(t, i, + log_sum_exp(scratch.begin(), scratch.end())); + } + } + + return bwd; + } + + template + static util::dense_matrix + posterior_state_membership(const HMM& hmm, const trellis& fwd, + const trellis& bwd) + { + util::dense_matrix gamma{fwd.size(), hmm.num_states()}; + std::vector scratch(hmm.num_states()); + for (uint64_t t = 0; t < fwd.size(); ++t) + { + for (label_id i{0}; i < hmm.num_states(); ++i) + { + state_id s_i{i}; + gamma(t, s_i) = fwd.probability(t, i) + bwd.probability(t, i); + } + auto norm = log_sum_exp(gamma.begin(t), gamma.end(t)); + std::transform(gamma.begin(t), gamma.end(t), gamma.begin(t), + [=](double val) { return val - norm; }); + } + return gamma; + } + + template + static void + increment_counts(const HMM& hmm, ExpectedCounts& counts, + const typename HMM::sequence_type& seq, const trellis& fwd, + const trellis& bwd, + const util::dense_matrix& log_gamma, + const util::dense_matrix& output_log_probs) + { + for 
(label_id i{0}; i < hmm.num_states(); ++i) + { + state_id s_i{i}; + + // add expected counts for initial state probabilities + counts.model_counts.increment_initial(s_i, + std::exp(log_gamma(0, s_i))); + + // add expected counts for transition probabilities + for (label_id j{0}; j < hmm.num_states(); ++j) + { + state_id s_j{j}; + + for (uint64_t t = 0; t < seq.size() - 1; ++t) + { + auto log_xi_tij + = log_gamma(t, s_i) + std::log(hmm.trans_prob(s_i, s_j)) + + output_log_probs(t + 1, s_j) + + bwd.probability(t + 1, j) - bwd.probability(t, i); + + counts.model_counts.increment_transition( + s_i, s_j, std::exp(log_xi_tij)); + } + } + + // add expected counts for observation probabilities + for (uint64_t t = 0; t < seq.size(); ++t) + { + counts.obs_counts.increment(seq[t], s_i, + std::exp(log_gamma(t, s_i))); + } + } + + // compute contribution to the log likelihood + std::vector scratch(hmm.num_states()); + for (label_id i{0}; i < hmm.num_states(); ++i) + { + scratch[i] = fwd.probability(seq.size() - 1, i); + } + counts.log_likelihood += log_sum_exp(scratch.begin(), scratch.end()); + } +}; +} +} +} +#endif diff --git a/include/meta/sequence/hmm/hmm.h b/include/meta/sequence/hmm/hmm.h new file mode 100644 index 000000000..42cb97b1b --- /dev/null +++ b/include/meta/sequence/hmm/hmm.h @@ -0,0 +1,312 @@ +/** + * @file hmm.h + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#ifndef META_SEQUENCE_HMM_H_ +#define META_SEQUENCE_HMM_H_ + +#include + +#include "meta/config.h" +#include "meta/logging/logger.h" +#include "meta/parallel/algorithm.h" +#include "meta/sequence/hmm/forward_backward.h" +#include "meta/sequence/markov_model.h" +#include "meta/sequence/trellis.h" +#include "meta/stats/multinomial.h" +#include "meta/util/identifiers.h" +#include "meta/util/progress.h" +#include "meta/util/random.h" +#include "meta/util/time.h" +#include "meta/util/traits.h" + +namespace meta +{ +namespace sequence +{ +namespace hmm +{ + +class hmm_exception : public std::runtime_error +{ + public: + using std::runtime_error::runtime_error; +}; + +template +struct hmm_traits +{ + using observation_type = typename ObsDist::observation_type; + using sequence_type = std::vector; + using training_data_type = std::vector; + using forward_backward_type = scaling_forward_backward; +}; + +/** + * A generic Hidden Markov Model implementation for unsupervised sequence + * labeling tasks. + */ +template +class hidden_markov_model +{ + public: + using traits_type = hmm_traits; + using observation_type = typename traits_type::observation_type; + using sequence_type = typename traits_type::sequence_type; + using training_data_type = typename traits_type::training_data_type; + using forward_backward_type = typename traits_type::forward_backward_type; + + struct training_options + { + /** + * The convergence threshold. When the difference in log likelihood + * between iterations falls below this value, training will stop. + */ + double delta = 1e-5; + + /** + * The maximum number of iterations. If the difference in log + * likelihood has not reached the convergence threshold after this + * many iterations, stop training. + */ + uint64_t max_iters = std::numeric_limits::max(); + }; + + /** + * Constructs a new Hidden Markov Model with random initialization + * using the provided random number generator. 
The observation + * distribution must be provided and is not initialized by the + * constructor (so you should initialize it yourself using an + * appropriate constructor for it). + * + * @param num_states The number of hidden states in the HMM + * @param gen The random number generator to use for initialization + * @param obs_dist The observation distribution + * @param trans_prior The Dirichlet prior over the transitions + */ + template + hidden_markov_model(uint64_t num_states, Generator&& rng, + ObsDist&& obs_dist, + stats::dirichlet trans_prior) + : obs_dist_{std::move(obs_dist)}, model_{num_states, rng, trans_prior} + { + if (obs_dist_.num_states() != num_states) + throw hmm_exception{"The observation distribution and HMM have " + "differing numbers of hidden states"}; + } + + /** + * Constructs a new Hidden Markov Model with uniform initialization of + * initial state and transition distributions. The observation + * distribution must be provided and is not initialized by the + * constructor (so you should initialize it yourself using an + * appropriate constructor for it). The initialization of the + * observation distribution is quite important as this is the only + * distribution that distinguishes states from one another when this + * constructor is used, so it is recommended to use a random + * initialization for it if possible. + * + * @param num_states The number of hidden states in the HMM + * @param obs_dist The observation distribution + * @param trans_prior The Dirichlet prior over the transitions + */ + hidden_markov_model(uint64_t num_states, ObsDist&& obs_dist, + stats::dirichlet trans_prior) + : obs_dist_{std::move(obs_dist)}, model_{num_states, trans_prior} + { + if (obs_dist_.num_states() != num_states) + throw hmm_exception{"The observation distribution and HMM have " + "differing numbers of hidden states"}; + } + + /** + * Loads a hidden Markov model from an input stream. + */ + template > + hidden_markov_model(InputStream& is) : obs_dist_{is}, model_{is} + { + // nothing + } + + /** + * @param instances The training data to fit the model to + * @param options The training options + * @return the log likelihood of the data + */ + double fit(const training_data_type& instances, parallel::thread_pool& pool, + training_options options) + { + double old_ll = std::numeric_limits::lowest(); + for (uint64_t iter = 1; iter <= options.max_iters; ++iter) + { + double log_likelihood = 0; + + auto em_time = common::time([&]() { + printing::progress progress{"> Iteration " + + std::to_string(iter) + ": ", + instances.size()}; + log_likelihood + = expectation_maximization(instances, pool, progress); + }); + + auto relative_change = (old_ll - log_likelihood) / old_ll; + LOG(info) << "Took " << em_time.count() / 1000.0 << "s" << ENDLG; + + if (iter > 1) + { + LOG(info) << "Log likelihood: " << log_likelihood << " (+" + << relative_change << " relative change)" << ENDLG; + } + else + { + LOG(info) << "Log log_likelihood: " << log_likelihood << ENDLG; + } + + assert(old_ll <= log_likelihood); + + if (iter > 1 && relative_change < options.delta) + { + LOG(info) << "Converged! 
(" << relative_change << " < " + << options.delta << ")" << ENDLG; + return log_likelihood; + } + + old_ll = log_likelihood; + } + + return old_ll; + } + + uint64_t num_states() const + { + return model_.num_states(); + } + + double trans_prob(state_id from, state_id to) const + { + return model_.transition_probability(from, to); + } + + double init_prob(state_id s) const + { + return model_.initial_probability(s); + } + + const ObsDist& observation_distribution() const + { + return obs_dist_; + } + + const typename ObsDist::conditional_distribution_type& + observation_distribution(state_id s) const + { + return obs_dist_.distribution(s); + } + + template + void save(OutputStream& os) const + { + obs_dist_.save(os); + model_.save(os); + } + + /** + * Temporary storage for expected counts for the different model types, + * plus the data log likelihood computed during the forward-backward + * algorithm + */ + struct expected_counts + { + expected_counts(const hidden_markov_model& hmm) + : obs_counts{hmm.obs_dist_.expected_counts()}, + model_counts{hmm.model_.expected_counts()} + { + // nothing + } + + expected_counts& operator+=(const expected_counts& other) + { + obs_counts += other.obs_counts; + model_counts += other.model_counts; + log_likelihood += other.log_likelihood; + return *this; + } + + typename ObsDist::expected_counts_type obs_counts; + markov_model::expected_counts_type model_counts; + double log_likelihood = 0.0; + }; + + /** + * Computes expected counts using the forward-backward algorithm. + */ + expected_counts forward_backward(const sequence_type& seq) + { + expected_counts ec{*this}; + forward_backward(seq, ec); + return ec; + } + + private: + void forward_backward(const sequence_type& seq, expected_counts& counts) + { + using fwdbwd = forward_backward_type; + // cache b_i(o_t) since this could be computed with an + // arbitrarily complex model + auto output_probs = fwdbwd::output_probabilities(*this, seq); + + // run forward-backward + auto fwd = fwdbwd::forward(*this, seq, output_probs); + auto bwd = fwdbwd::backward(*this, seq, fwd, output_probs); + + // compute the probability of being in a given state at a given + // time from the trellises + auto gamma = fwdbwd::posterior_state_membership(*this, fwd, bwd); + + // increment expected counts + fwdbwd::increment_counts(*this, counts, seq, fwd, bwd, gamma, + output_probs); + } + + double expectation_maximization(const training_data_type& instances, + parallel::thread_pool& pool, + printing::progress& progress) + { + uint64_t seq_id = 0; + // compute expected counts across all instances in parallel + std::mutex progress_mutex; + auto counts = parallel::reduction( + instances.begin(), instances.end(), pool, + [&]() { return expected_counts{*this}; }, + [&](expected_counts& counts, const sequence_type& seq) { + { + std::lock_guard lock{progress_mutex}; + progress(seq_id++); + } + forward_backward(seq, counts); + }, + [&](expected_counts& result, const expected_counts& temp) { + result += temp; + }); + + // normalize and replace old parameters + obs_dist_ = ObsDist{std::move(counts.obs_counts)}; + model_ = markov_model{std::move(counts.model_counts)}; + + return counts.log_likelihood; + } + + ObsDist obs_dist_; + markov_model model_; +}; +} +} +} +#endif diff --git a/include/meta/sequence/hmm/sequence_observations.h b/include/meta/sequence/hmm/sequence_observations.h new file mode 100644 index 000000000..5ad2ad1c4 --- /dev/null +++ b/include/meta/sequence/hmm/sequence_observations.h @@ -0,0 +1,141 @@ +/** + * @file 
sequence_observations.h + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#ifndef META_SEQUENCE_HMM_SEQUENCE_OBS_H_ +#define META_SEQUENCE_HMM_SEQUENCE_OBS_H_ + +#include "meta/sequence/hmm/hmm.h" +#include "meta/sequence/markov_model.h" +#include "meta/stats/multinomial.h" +#include "meta/util/traits.h" + +namespace meta +{ +namespace sequence +{ +namespace hmm +{ + +/** + * A Markov Model observation distribution for HMMs. Each observation is + * assumed to be a sequence of states. Each *HMM* state is modeled via a + * separate Markov model. + */ +class sequence_observations +{ + public: + using observation_type = std::vector; + using conditional_distribution_type = markov_model; + + /** + * E-step scratch space for computing expected counts. + */ + class expected_counts_type + { + public: + friend sequence_observations; + + expected_counts_type(uint64_t num_hmm_states, + uint64_t num_markov_states, + stats::dirichlet prior); + + void increment(const observation_type& seq, state_id s_i, + double amount); + + expected_counts_type& operator+=(const expected_counts_type& other); + + private: + std::vector counts_; + }; + + /** + * Initializes each state's Markov model randomly using the provided + * random number generator. + */ + template + sequence_observations(uint64_t num_hmm_states, uint64_t num_markov_states, + Generator&& gen, stats::dirichlet prior) + { + models_.reserve(num_hmm_states); + for (uint64_t h = 0; h < num_hmm_states; ++h) + models_.emplace_back(num_markov_states, + std::forward(gen), prior); + } + + /** + * Default initializes each state's Markov model. This is only useful + * when setting values manually by using increment(). + */ + sequence_observations(uint64_t num_hmm_states, uint64_t num_markov_states, + stats::dirichlet prior); + + /** + * Re-estimates the Markov models given expected counts. + */ + sequence_observations(expected_counts_type&& counts); + + /** + * Loads a sequence observation distribution from an input stream. + */ + template > + sequence_observations(InputStream& is) + { + uint64_t size; + if (io::packed::read(is, size) == 0) + throw hmm_exception{ + "failed to load sequence_observations from stream"}; + + models_.reserve(size); + for (uint64_t i = 0; i < size; ++i) + models_.emplace_back(is); + } + + /** + * Obtains an expected_counts_type suitable for re-estimating this + * distribution. + */ + expected_counts_type expected_counts() const; + + uint64_t num_states() const; + + double probability(const observation_type& obs, state_id s_i) const; + + double log_probability(const observation_type& obs, state_id s_i) const; + + const markov_model& distribution(state_id s_i) const; + + /** + * Saves a sequence observation distribution to a stream. 
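Putting the pieces together, a hedged sketch of fitting an unsupervised HMM with discrete observations (the `term_id` observation type, state counts, prior strengths, and training data are all placeholders):

~~~cpp
#include <random>
#include <vector>

#include "meta/parallel/thread_pool.h"
#include "meta/sequence/hmm/discrete_observations.h"
#include "meta/sequence/hmm/hmm.h"

void fit_example_hmm(const std::vector<std::vector<meta::term_id>>& sequences)
{
    using namespace meta;
    using obs_dist = sequence::hmm::discrete_observations<term_id>;

    uint64_t num_states = 4;
    uint64_t vocab_size = 10000; // placeholder vocabulary size
    std::mt19937_64 rng{42};

    // one randomly initialized multinomial per hidden state
    obs_dist obs{num_states, vocab_size, rng,
                 stats::dirichlet<term_id>{0.1, vocab_size}};

    sequence::hmm::hidden_markov_model<obs_dist> hmm{
        num_states, rng, std::move(obs),
        stats::dirichlet<sequence::state_id>{0.1, num_states}};

    sequence::hmm::hidden_markov_model<obs_dist>::training_options opts;
    opts.delta = 1e-4;
    opts.max_iters = 50;

    parallel::thread_pool pool;
    auto log_likelihood = hmm.fit(sequences, pool, opts);
    (void)log_likelihood;
}
~~~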
+ */ + template + void save(OutputStream& os) const + { + io::packed::write(os, models_.size()); + for (const auto& model : models_) + model.save(os); + } + + private: + std::vector models_; +}; + +template <> +struct hmm_traits +{ + using observation_type = sequence_observations::observation_type; + using sequence_type = std::vector; + using training_data_type = std::vector; + using forward_backward_type = logarithm_forward_backward; +}; +} +} +} +#endif diff --git a/include/meta/sequence/markov_model.h b/include/meta/sequence/markov_model.h new file mode 100644 index 000000000..4e245080c --- /dev/null +++ b/include/meta/sequence/markov_model.h @@ -0,0 +1,179 @@ +/** + * @file markov_model.h + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#ifndef META_SEQUENCE_MARKOV_MODEL_H_ +#define META_SEQUENCE_MARKOV_MODEL_H_ + +#include "meta/stats/dirichlet.h" +#include "meta/util/dense_matrix.h" +#include "meta/util/identifiers.h" +#include "meta/util/random.h" +#include "meta/util/traits.h" + +namespace meta +{ +namespace sequence +{ + +MAKE_NUMERIC_IDENTIFIER(state_id, uint64_t) + +/** + * Represents a Markov model over a set of states. + */ +class markov_model +{ + public: + /** + * Represents expected counts for re-estimating a markov_model. + */ + class expected_counts_type + { + public: + friend markov_model; + + expected_counts_type(uint64_t num_states, + stats::dirichlet prior); + + void increment(const std::vector& seq, double amount); + void increment_initial(state_id s, double amount); + void increment_transition(state_id from, state_id to, double amount); + + expected_counts_type& operator+=(const expected_counts_type& other); + + private: + std::vector initial_count_; + util::dense_matrix trans_count_; + stats::dirichlet prior_; + }; + + /** + * Constructs a new Markov Model with random initialization using the + * provided random number generator. + */ + template + markov_model(uint64_t num_states, Generator&& rng, + stats::dirichlet prior) + : initial_prob_(num_states), + trans_prob_{num_states, num_states}, + prior_{std::move(prior)} + { + double inorm = 0; + for (state_id s_i{0}; s_i < num_states; ++s_i) + { + auto rnd = random::bounded_rand(rng, 65536); + auto val = (rnd / 65536.0) / num_states; + initial_prob_[s_i] = val; + inorm += val; + + double tnorm = 0; + for (state_id s_j{0}; s_j < num_states; ++s_j) + { + auto rnd = random::bounded_rand(rng, 65536); + auto val = (rnd / 65536.0) / num_states; + trans_prob_(s_i, s_j) = val; + tnorm += val; + } + for (state_id s_j{0}; s_j < num_states; ++s_j) + { + trans_prob_(s_i, s_j) + = (trans_prob_(s_i, s_j) + prior_.pseudo_counts(s_j)) + / (tnorm + prior_.pseudo_counts()); + } + } + + for (state_id s_i{0}; s_i < num_states; ++s_i) + { + initial_prob_[s_i] + = (initial_prob_[s_i] + prior_.pseudo_counts(s_i)) + / (inorm + prior_.pseudo_counts()); + } + } + + /** + * Loads a Markov model from a file. 
+ */ + template > + markov_model(InputStream& is) + { + if (io::packed::read(is, initial_prob_) == 0) + throw std::runtime_error{"failed to read markov model from stream"}; + if (io::packed::read(is, trans_prob_) == 0) + throw std::runtime_error{"failed to read markov model from stream"}; + if (io::packed::read(is, prior_) == 0) + throw std::runtime_error{"failed to read markov model from stream"}; + } + + /** + * Constructs a new Markov model with uniform initialization of + * initial state and transition distibutions. + */ + markov_model(uint64_t num_states, stats::dirichlet prior); + + /** + * Constructs a new Markov model from a set of expected counts. + */ + markov_model(expected_counts_type&& counts); + + /** + * Obtains an expected_counts_type suitable for re-estimating this + * Markov model. + */ + expected_counts_type expected_counts() const; + + /** + * Obtains a reference to the prior used for the model. + */ + const stats::dirichlet& prior() const; + + /** + * @return the number of states in the Markov model + */ + uint64_t num_states() const; + + /** + * @return \f$\log P(\mathbf{s} \mid \theta)\f$ + */ + double log_probability(const std::vector& seq) const; + + /** + * @return \f$P(\mathbf{s} \mid \theta)\f$ + */ + double probability(const std::vector& seq) const; + + /** + * @return \f$P(s_{t} \mid s_{f}, \theta)\f$ + */ + double transition_probability(state_id from, state_id to) const; + + /** + * @return \f$P(s \mid \theta)\f$ + */ + double initial_probability(state_id s) const; + + /** + * Saves a Markov model to a stream. + */ + template + void save(OutputStream& os) const + { + io::packed::write(os, initial_prob_); + io::packed::write(os, trans_prob_); + io::packed::write(os, prior_); + } + + private: + std::vector initial_prob_; + util::dense_matrix trans_prob_; + stats::dirichlet prior_; +}; +} +} +#endif diff --git a/include/meta/stats/dirichlet.h b/include/meta/stats/dirichlet.h index 7ee11c436..1533bc71b 100644 --- a/include/meta/stats/dirichlet.h +++ b/include/meta/stats/dirichlet.h @@ -13,6 +13,7 @@ #include #include "meta/config.h" +#include "meta/io/packed.h" #include "meta/util/sparse_vector.h" namespace meta @@ -29,6 +30,11 @@ template class dirichlet { public: + /** + * Constructs an empty (0, 0) Dirichlet. + */ + dirichlet(); + /** * Constructs a symmetric Dirichlet with concentration parameter * \f$\alpha\f$ and dimension \f$n\f$. 
@@ -102,6 +108,60 @@ class dirichlet */ void load(std::istream& in); + template + friend uint64_t packed_write(OutputStream& os, const dirichlet& dist) + { + auto bytes = io::packed::write(os, static_cast(dist.type_)); + switch (dist.type_) + { + case type::SYMMETRIC: + { + bytes += io::packed::write(os, dist.params_.fixed_alpha_); + bytes += io::packed::write( + os, static_cast(dist.alpha_sum_ + / dist.params_.fixed_alpha_)); + break; + } + case type::ASYMMETRIC: + { + bytes += io::packed::write(os, dist.params_.sparse_alpha_); + break; + } + } + return bytes; + } + + template + friend uint64_t packed_read(InputStream& is, dirichlet& dist) + { + uint64_t typ; + auto bytes = io::packed::read(is, typ); + if (bytes == 0) + return 0; + + type read_type = static_cast(typ); + switch (read_type) + { + case type::SYMMETRIC: + { + double alpha; + bytes += io::packed::read(is, alpha); + uint64_t n; + bytes += io::packed::read(is, n); + dist = dirichlet{alpha, n}; + break; + } + case type::ASYMMETRIC: + { + std::vector> vec; + bytes += io::packed::read(is, vec); + dist = dirichlet{vec.begin(), vec.end()}; + break; + } + } + return bytes; + } + private: enum class type { diff --git a/include/meta/stats/dirichlet.tcc b/include/meta/stats/dirichlet.tcc index cb780c23d..c705470f6 100644 --- a/include/meta/stats/dirichlet.tcc +++ b/include/meta/stats/dirichlet.tcc @@ -3,16 +3,22 @@ * @author Chase Geigle */ +#include "meta/io/packed.h" #include "meta/stats/dirichlet.h" #include "meta/util/identifiers.h" #include "meta/util/shim.h" -#include "meta/io/packed.h" namespace meta { namespace stats { +template +dirichlet::dirichlet() : dirichlet{0.0, 0} +{ + // nothing +} + template dirichlet::dirichlet(double alpha, uint64_t n) : type_{type::SYMMETRIC}, params_{alpha}, alpha_sum_{n * alpha} @@ -26,11 +32,9 @@ dirichlet::dirichlet(Iter begin, Iter end) : type_{type::ASYMMETRIC}, params_{begin, end} { using pair_type = typename Iter::value_type; - alpha_sum_ - = std::accumulate(begin, end, 0.0, [](double accum, const pair_type& b) - { - return accum + b.second; - }); + alpha_sum_ = std::accumulate( + begin, end, 0.0, + [](double accum, const pair_type& b) { return accum + b.second; }); } template @@ -145,67 +149,13 @@ void dirichlet::swap(dirichlet& other) template void dirichlet::save(std::ostream& out) const { - io::packed::write(out, static_cast(type_)); - switch (type_) - { - case type::SYMMETRIC: - { - io::packed::write(out, params_.fixed_alpha_); - io::packed::write( - out, static_cast(alpha_sum_ / params_.fixed_alpha_)); - break; - } - case type::ASYMMETRIC: - { - io::packed::write(out, params_.sparse_alpha_.size()); - for (const auto& alpha : params_.sparse_alpha_) - { - io::packed::write(out, alpha.first); - io::packed::write(out, alpha.second); - } - break; - } - } + io::packed::write(out, *this); } template void dirichlet::load(std::istream& in) { - uint64_t typ; - auto bytes = io::packed::read(in, typ); - if (bytes == 0) - return; - - type read_type = static_cast(typ); - switch (read_type) - { - case type::SYMMETRIC: - { - double alpha; - io::packed::read(in, alpha); - uint64_t n; - io::packed::read(in, n); - *this = dirichlet{alpha, n}; - break; - } - case type::ASYMMETRIC: - { - uint64_t size; - io::packed::read(in, size); - std::vector> vec; - vec.reserve(size); - for (uint64_t i = 0; i < size; ++i) - { - T event; - io::packed::read(in, event); - double count; - io::packed::read(in, count); - vec.emplace_back(std::move(event), count); - } - *this = dirichlet{vec.begin(), vec.end()}; - break; - } - 
} + io::packed::read(in, *this); } } } diff --git a/include/meta/stats/multinomial.h b/include/meta/stats/multinomial.h index b59eb7c2b..836cd81d8 100644 --- a/include/meta/stats/multinomial.h +++ b/include/meta/stats/multinomial.h @@ -13,6 +13,7 @@ #include #include "meta/config.h" +#include "meta/io/packed.h" #include "meta/stats/dirichlet.h" #include "meta/util/sparse_vector.h" @@ -134,6 +135,34 @@ class multinomial */ void load(std::istream& in); + template + friend uint64_t packed_write(OutputStream& os, const multinomial& dist) + { + using io::packed::write; + return write(os, dist.total_counts_) + write(os, dist.counts_) + + write(os, dist.prior_); + } + + template + friend uint64_t packed_read(InputStream& is, multinomial& dist) + { + dist.clear(); + using io::packed::read; + auto bytes = io::packed::read(is, dist.total_counts_); + if (bytes == 0) + return 0; + + auto count_bytes = io::packed::read(is, dist.counts_); + if (count_bytes == 0) + return 0; + + auto prior_bytes = io::packed::read(is, dist.prior_); + if (prior_bytes == 0) + return 0; + + return bytes + count_bytes + prior_bytes; + } + private: util::sparse_vector counts_; double total_counts_; diff --git a/include/meta/stats/multinomial.tcc b/include/meta/stats/multinomial.tcc index 5cd7a298d..aa3d6a0a8 100644 --- a/include/meta/stats/multinomial.tcc +++ b/include/meta/stats/multinomial.tcc @@ -114,36 +114,13 @@ multinomial& multinomial::operator+=(const multinomial& rhs) template void multinomial::save(std::ostream& out) const { - io::packed::write(out, total_counts_); - io::packed::write(out, counts_.size()); - for (const auto& count : counts_) - { - io::packed::write(out, count.first); - io::packed::write(out, count.second); - } - prior_.save(out); + io::packed::write(out, *this); } template void multinomial::load(std::istream& in) { - clear(); - double total_counts; - auto bytes = io::packed::read(in, total_counts); - uint64_t size; - bytes += io::packed::read(in, size); - if (bytes == 0) - return; - - total_counts_ = total_counts; - counts_.reserve(size); - for (uint64_t i = 0; i < size; ++i) - { - T event; - io::packed::read(in, event); - io::packed::read(in, counts_[event]); - } - prior_.load(in); + io::packed::read(in, *this); } } } diff --git a/include/meta/stats/running_stats.h b/include/meta/stats/running_stats.h index fe7d53fc4..315b817cb 100644 --- a/include/meta/stats/running_stats.h +++ b/include/meta/stats/running_stats.h @@ -52,6 +52,11 @@ class running_stats */ double variance() const; + /** + * @return the total number of items seen thus far + */ + std::size_t size() const; + private: /// the current running mean double m_k_; diff --git a/include/meta/topics/lda_cvb.h b/include/meta/topics/lda_cvb.h index 6f303f9e8..bdb4f3acb 100644 --- a/include/meta/topics/lda_cvb.h +++ b/include/meta/topics/lda_cvb.h @@ -42,7 +42,7 @@ class lda_cvb : public lda_model * @param beta The hyperparameter for the Dirichlet prior over * \f$\theta\f$ */ - lda_cvb(std::shared_ptr idx, uint64_t num_topics, + lda_cvb(std::shared_ptr idx, std::size_t num_topics, double alpha, double beta); /** diff --git a/include/meta/topics/lda_gibbs.h b/include/meta/topics/lda_gibbs.h index c001ae914..5c5299ad3 100644 --- a/include/meta/topics/lda_gibbs.h +++ b/include/meta/topics/lda_gibbs.h @@ -43,7 +43,7 @@ class lda_gibbs : public lda_model * @param beta The hyperparameter for the Dirichlet prior over * \f$\theta\f$ */ - lda_gibbs(std::shared_ptr idx, uint64_t num_topics, + lda_gibbs(std::shared_ptr idx, std::size_t num_topics, double 
alpha, double beta); /** diff --git a/include/meta/topics/lda_model.h b/include/meta/topics/lda_model.h index af5a131f3..01218c88c 100644 --- a/include/meta/topics/lda_model.h +++ b/include/meta/topics/lda_model.h @@ -45,7 +45,8 @@ class lda_model * @param idx The index containing the documents to use for the model * @param num_topics The number of topics to find */ - lda_model(std::shared_ptr idx, uint64_t num_topics); + lda_model(std::shared_ptr idx, + std::size_t num_topics); /** * Destructor. Made virtual to allow for deletion through pointer to @@ -133,12 +134,12 @@ class lda_model /** * The number of topics. */ - size_t num_topics_; + std::size_t num_topics_; /** * The number of total unique words. */ - size_t num_words_; + std::size_t num_words_; }; } } diff --git a/include/meta/topics/lda_scvb.h b/include/meta/topics/lda_scvb.h index 552b2ab4d..f637663f3 100644 --- a/include/meta/topics/lda_scvb.h +++ b/include/meta/topics/lda_scvb.h @@ -44,7 +44,7 @@ class lda_scvb : public lda_model * @param minibatch_size The number of documents to consider in a * minibatch */ - lda_scvb(std::shared_ptr idx, uint64_t num_topics, + lda_scvb(std::shared_ptr idx, std::size_t num_topics, double alpha, double beta, uint64_t minibatch_size = 100); /** diff --git a/include/meta/util/dense_matrix.h b/include/meta/util/dense_matrix.h index 91fe98583..270f9da65 100644 --- a/include/meta/util/dense_matrix.h +++ b/include/meta/util/dense_matrix.h @@ -14,6 +14,7 @@ #include #include "meta/config.h" +#include "meta/io/packed.h" namespace meta { @@ -130,6 +131,20 @@ class dense_matrix */ uint64_t columns() const; + template + friend uint64_t packed_write(OutputStream& os, const dense_matrix& mat) + { + return io::packed::write(os, mat.storage_) + + io::packed::write(os, mat.columns_); + } + + template + friend uint64_t packed_read(InputStream& is, dense_matrix& mat) + { + return io::packed::read(is, mat.storage_) + + io::packed::read(is, mat.columns_); + } + private: /// the underlying storage for the matrix std::vector storage_; diff --git a/include/meta/util/fixed_heap.h b/include/meta/util/fixed_heap.h index d3324c14f..79d7a2c8c 100644 --- a/include/meta/util/fixed_heap.h +++ b/include/meta/util/fixed_heap.h @@ -91,6 +91,18 @@ class fixed_heap Comp comp_; std::vector pq_; }; + +/** + * Constructs a fixed_heap from a maximum size and binary comparison + * function. 
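// Editor's sketch (illustrative only, not part of this changeset): keeping the
// three largest scores with the make_fixed_heap helper that follows. The
// push()/extract_top() calls assume the existing fixed_heap interface.
#include <vector>
#include "meta/util/fixed_heap.h"

using namespace meta;

std::vector<double> top_three(const std::vector<double>& scores)
{
    // the comparator orders candidates from best to worst
    auto heap = util::make_fixed_heap<double>(
        3, [](double a, double b) { return a > b; });
    for (auto score : scores)
        heap.push(score);
    return heap.extract_top();
}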
+ */ +template +fixed_heap make_fixed_heap(uint64_t max_elems, + BinaryFunction&& bf) +{ + return fixed_heap(max_elems, + std::forward(bf)); +} } } diff --git a/include/meta/util/identifiers.h b/include/meta/util/identifiers.h index e66adb812..3814713e0 100644 --- a/include/meta/util/identifiers.h +++ b/include/meta/util/identifiers.h @@ -169,7 +169,7 @@ struct identifier : public comparable> template void hash_append(HashAlgorithm& h, const identifier& id) { - using util::hash_append; + using hashing::hash_append; hash_append(h, static_cast(id)); } @@ -346,39 +346,18 @@ struct hash> using ident_name \ = meta::util::numerical_identifier; -#if !defined NDEBUG && !defined NUSE_OPAQUE_IDENTIFIERS #define MAKE_IDENTIFIER(ident_name, base_type) \ MAKE_OPAQUE_IDENTIFIER(ident_name, base_type) -#else -#define MAKE_IDENTIFIER(ident_name, base_type) using ident_name = base_type; -#endif -#if !defined NDEBUG && !defined NUSE_OPAQUE_IDENTIFIERS #define MAKE_NUMERIC_IDENTIFIER(ident_name, base_type) \ MAKE_OPAQUE_NUMERIC_IDENTIFIER(ident_name, base_type) -#else -#define MAKE_NUMERIC_IDENTIFIER(ident_name, base_type) \ - using ident_name = base_type; -#endif -#if !defined NDEBUG && !defined NUSE_OPAQUE_IDENTIFIERS #define MAKE_IDENTIFIER_UDL(ident_name, base_type, suffix) \ MAKE_OPAQUE_IDENTIFIER(ident_name, base_type) \ MAKE_USER_DEFINED_LITERAL(ident_name, base_type, suffix) -#else -#define MAKE_IDENTIFIER_UDL(ident_name, base_type, suffix) \ - using ident_name = base_type; \ - MAKE_USER_DEFINED_LITERAL(ident_name, base_type, suffix) -#endif -#if !defined NDEBUG && !defined NUSE_OPAQUE_IDENTIFIERS #define MAKE_NUMERIC_IDENTIFIER_UDL(ident_name, base_type, suffix) \ MAKE_OPAQUE_NUMERIC_IDENTIFIER(ident_name, base_type) \ MAKE_USER_DEFINED_NUMERIC_LITERAL(ident_name, base_type, suffix) -#else -#define MAKE_NUMERIC_IDENTIFIER_UDL(ident_name, base_type, suffix) \ - using ident_name = base_type; \ - MAKE_USER_DEFINED_NUMERIC_LITERAL(ident_name, base_type, suffix) -#endif #endif diff --git a/include/meta/util/iterator.h b/include/meta/util/iterator.h new file mode 100644 index 000000000..1bf78dd58 --- /dev/null +++ b/include/meta/util/iterator.h @@ -0,0 +1,137 @@ +/** + * @file iterator.h + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#ifndef META_UTIL_ITERATOR_H_ +#define META_UTIL_ITERATOR_H_ + +#include +#include + +#include "meta/config.h" +#include "meta/util/comparable.h" + +namespace meta +{ +namespace util +{ + +/** + * Wrapper around an Iterator that, when dereferenced, returns f(*it) + * where `it` is the wrapped Iterator and `f` is a UnaryFunction. 
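// Editor's sketch (illustrative only, not part of this changeset): applying a
// function lazily while iterating, via the make_transform_iterator helper
// defined at the end of this header.
#include <numeric>
#include <vector>
#include "meta/util/iterator.h"

using namespace meta;

double sum_of_squares(const std::vector<double>& xs)
{
    auto square = [](double x) { return x * x; };
    auto begin = util::make_transform_iterator(xs.begin(), square);
    auto end = util::make_transform_iterator(xs.end(), square);
    // each dereference computes square(*it); no temporary vector is built
    return std::accumulate(begin, end, 0.0);
}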
+ */ +template +class transform_iterator + : public comparable> +{ + public: + using traits_type = std::iterator_traits; + using difference_type = typename traits_type::difference_type; + using value_type = typename std::result_of::type; + using pointer = typename std::add_pointer::type; + using reference = + typename std::add_lvalue_reference::type; + using iterator_category = typename traits_type::iterator_category; + + transform_iterator(Iterator it, UnaryFunction fun) : it_{it}, fun_(fun) + { + // nothing + } + + transform_iterator& operator++() + { + ++it_; + return *this; + } + + transform_iterator operator++(int) + { + auto tmp = *this; + ++it_; + return tmp; + } + + transform_iterator& operator--() + { + --it_; + return *this; + } + + transform_iterator operator--(int) + { + auto tmp = *this; + --it_; + return *tmp; + } + + transform_iterator& operator+=(difference_type diff) + { + it_ += diff; + return *this; + } + + transform_iterator operator+(difference_type diff) const + { + auto tmp = *this; + tmp += diff; + return tmp; + } + + transform_iterator& operator-=(difference_type diff) + { + it_ -= diff; + return *this; + } + + transform_iterator operator-(difference_type diff) const + { + auto tmp = *this; + tmp -= diff; + return tmp; + } + + difference_type operator-(transform_iterator other) const + { + return it_ - other.it_; + } + + reference operator[](difference_type diff) const + { + return fun_(it_[diff]); + } + + bool operator<(transform_iterator other) const + { + return it_ < other.it_; + } + + value_type operator*() const + { + return fun_(*it_); + } + + private: + Iterator it_; + UnaryFunction fun_; +}; + +/** + * Helper function to construct a transform_iterator from an Iterator and + * a UnaryFunction to transform the values of that Iterator. + */ +template +transform_iterator +make_transform_iterator(Iterator it, UnaryFunction&& fun) +{ + return transform_iterator( + it, std::forward(fun)); +} +} +} +#endif diff --git a/include/meta/util/multiway_merge.h b/include/meta/util/multiway_merge.h index 20b80967f..ee9b17ae4 100644 --- a/include/meta/util/multiway_merge.h +++ b/include/meta/util/multiway_merge.h @@ -18,6 +18,7 @@ #include "meta/config.h" #include "meta/io/filesystem.h" +#include "meta/io/mmap_file.h" #include "meta/io/moveable_stream.h" #include "meta/io/packed.h" #include "meta/util/progress.h" @@ -84,14 +85,21 @@ namespace util * A unary function that is called once per every unique Record after * merging. * + * - ProgressTrait: + * A traits class whose type indicates the progress reporting object to + * use. By default, this is meta::printing::default_progress_trait, but + * progress reporting can be silenced using + * meta::printing::no_progress_trait. 
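// Editor's sketch (illustrative only, not part of this changeset): merging
// pre-sorted on-disk chunks of cooccur_records with the default comparison and
// merge criteria and with progress reporting silenced, mirroring how the new
// cooccurrence_counter drives multiway_merge later in this changeset.
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>
#include "meta/embeddings/cooccur_iterator.h"
#include "meta/io/packed.h"
#include "meta/util/multiway_merge.h"
#include "meta/util/progress.h"

using namespace meta;

uint64_t merge_cooccurrence_chunks(const std::vector<std::string>& chunk_paths,
                                   const std::string& out_path)
{
    std::vector<embeddings::cooccur_iterator> chunks;
    chunks.reserve(chunk_paths.size());
    for (const auto& path : chunk_paths)
        chunks.emplace_back(path);

    std::ofstream output{out_path, std::ios::binary};
    return util::multiway_merge(chunks.begin(), chunks.end(),
                                [&](embeddings::cooccur_record&& record) {
                                    io::packed::write(output, record);
                                },
                                printing::no_progress_trait{});
}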
+ * * @return the total number of unique Records that were written to the * OutputStream */ template + class ShouldMerge, + class ProgressTrait = printing::default_progress_trait> uint64_t multiway_merge(ForwardIterator begin, ForwardIterator end, Compare&& record_comp, ShouldMerge&& should_merge, - RecordHandler&& output) + RecordHandler&& output, ProgressTrait = ProgressTrait{}) { using ChunkIterator = typename ForwardIterator::value_type; @@ -100,7 +108,7 @@ uint64_t multiway_merge(ForwardIterator begin, ForwardIterator end, return acc + chunk.total_bytes(); }); - printing::progress progress{" > Merging: ", to_read}; + typename ProgressTrait::type progress{" > Merging: ", to_read}; uint64_t total_read = std::accumulate( begin, end, 0ul, [](uint64_t acc, const ChunkIterator& chunk) { @@ -162,16 +170,17 @@ uint64_t multiway_merge(ForwardIterator begin, ForwardIterator end, * A simplified wrapper for multiway_merge that uses the default comparison * (operator<) and merge criteria (operator==). */ -template +template uint64_t multiway_merge(ForwardIterator begin, ForwardIterator end, - RecordHandler&& output) + RecordHandler&& output, ProgressTrait = ProgressTrait{}) { using Record = typename std::remove_reference::type; auto record_comp = [](const Record& a, const Record& b) { return a < b; }; auto record_equal = [](const Record& a, const Record& b) { return a == b; }; return multiway_merge(begin, end, record_comp, record_equal, - std::forward(output)); + std::forward(output), ProgressTrait{}); } /** @@ -190,7 +199,7 @@ class chunk_iterator * @param filename The file to read from */ chunk_iterator(const std::string& filename) - : input_{filename, std::ios::binary}, + : input_{filename}, bytes_read_{0}, total_bytes_{filesystem::file_size(filename)} { @@ -206,15 +215,15 @@ class chunk_iterator */ chunk_iterator& operator++() { - if (input_.stream().peek() == EOF) + if (input_.peek() == EOF) { - input_.stream().close(); + input_.close(); assert(*this == chunk_iterator{}); return *this; } - bytes_read_ += io::packed::read(input_.stream(), record_); + bytes_read_ += io::packed::read(input_, record_); return *this; } @@ -247,11 +256,11 @@ class chunk_iterator */ bool operator==(const chunk_iterator& other) const { - return !input_.stream().is_open() && !other.input_.stream().is_open(); + return !input_.is_open() && !other.input_.is_open(); } private: - io::mifstream input_; + io::mmap_ifstream input_; Record record_; uint64_t bytes_read_; uint64_t total_bytes_; @@ -263,6 +272,48 @@ bool operator!=(const chunk_iterator& a, { return !(a == b); } + +/** + * A simple implementation of the ChunkIterator concept that reads Records + * from a binary file using io::packed::read and deletes the underlying + * file when it reaches EOF. 
+ */ +template +class destructive_chunk_iterator : public chunk_iterator +{ + public: + using base_iterator = chunk_iterator; + + destructive_chunk_iterator() = default; + + destructive_chunk_iterator(const std::string& filename) + : base_iterator(filename), filename_{filename} + { + // nothing + } + + destructive_chunk_iterator& operator++() + { + ++base(); + if (base() == base_iterator{}) + filesystem::delete_file(filename_); + + return *this; + } + + const std::string& filename() const + { + return filename_; + } + + private: + base_iterator& base() + { + return static_cast(*this); + } + + const std::string filename_; +}; } } #endif diff --git a/include/meta/util/progress.h b/include/meta/util/progress.h index e1fa0b597..909c3f53b 100644 --- a/include/meta/util/progress.h +++ b/include/meta/util/progress.h @@ -18,6 +18,7 @@ #include #include "meta/config.h" +#include "meta/util/string_view.h" namespace meta { @@ -74,7 +75,7 @@ class progress /** * Clears the last line the progress bar wrote. */ - void clear() const; + static void clear(); private: void print(); @@ -101,6 +102,36 @@ class progress /// Whether or not we should print an endline when done. bool endline_; }; + +/** + * Class adhering to the progress API that can be substituted for it when + * no progress output is desired. + */ +class null_progress +{ + public: + null_progress(util::string_view prefix, uint64_t length, int interval = 500) + { + (void)prefix; + (void)length; + (void)interval; + } + + void operator()(uint64_t iter) + { + (void)iter; + } +}; + +struct default_progress_trait +{ + using type = progress; +}; + +struct no_progress_trait +{ + using type = null_progress; +}; } } #endif diff --git a/include/meta/util/sparse_vector.h b/include/meta/util/sparse_vector.h index 39f9187da..c06b83c4c 100644 --- a/include/meta/util/sparse_vector.h +++ b/include/meta/util/sparse_vector.h @@ -15,6 +15,7 @@ #include #include "meta/config.h" +#include "meta/io/packed.h" namespace meta { @@ -185,6 +186,18 @@ class sparse_vector sparse_vector& operator+=(const sparse_vector& rhs); sparse_vector& operator-=(const sparse_vector& rhs); + template + friend uint64_t packed_write(OutputSteam& os, const sparse_vector& sv) + { + return io::packed::write(os, sv.storage_); + } + + template + friend uint64_t packed_read(InputStream& is, sparse_vector& sv) + { + return io::packed::read(is, sv.storage_); + } + private: /** * Internal storage for the sparse vector: a sorted vector of pairs. diff --git a/include/meta/util/traits.h b/include/meta/util/traits.h new file mode 100644 index 000000000..ae3cf1b1c --- /dev/null +++ b/include/meta/util/traits.h @@ -0,0 +1,25 @@ +/** + * @file traits.h + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. 
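// Editor's sketch (illustrative only, not part of this changeset): the friend
// packed_write/packed_read hooks added above for dirichlet, multinomial,
// dense_matrix, and sparse_vector let io::packed round-trip these types
// directly, which is what the simplified save()/load() members now rely on.
#include <cstdint>
#include <fstream>
#include "meta/io/packed.h"
#include "meta/stats/dirichlet.h"

using namespace meta;

void dirichlet_roundtrip()
{
    stats::dirichlet<uint64_t> before{0.01, 100}; // symmetric prior
    {
        std::ofstream out{"prior.bin", std::ios::binary};
        io::packed::write(out, before);
    }

    stats::dirichlet<uint64_t> after; // the new empty (0, 0) Dirichlet
    std::ifstream in{"prior.bin", std::ios::binary};
    io::packed::read(in, after);
    // `after` is now a symmetric Dirichlet with alpha = 0.01 and n = 100
}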
+ */ + +#ifndef META_UTIL_TRAITS_H_ +#define META_UTIL_TRAITS_H_ + +#include + +namespace meta +{ +namespace util +{ +template +using disable_if_same_or_derived_t = typename std:: + enable_if::type>:: + value>::type; +} +} +#endif diff --git a/src/analyzers/filters/length_filter.cpp b/src/analyzers/filters/length_filter.cpp index 522c27c51..6a5a02b55 100644 --- a/src/analyzers/filters/length_filter.cpp +++ b/src/analyzers/filters/length_filter.cpp @@ -86,15 +86,13 @@ std::unique_ptr make_filter(std::unique_ptr src, const cpptoml::table& config) { - auto min = config.get_as("min"); + auto min = config.get_as("min"); if (!min) throw token_stream_exception{"min required for length filter config"}; - auto max = config.get_as("max"); + auto max = config.get_as("max"); if (!max) throw token_stream_exception{"max required for length filter config"}; - return make_unique(std::move(src), - static_cast(*min), - static_cast(*max)); + return make_unique(std::move(src), *min, *max); } } } diff --git a/src/analyzers/tokenizers/whitespace_tokenizer.cpp b/src/analyzers/tokenizers/whitespace_tokenizer.cpp index abd925cac..fa7b62c9a 100644 --- a/src/analyzers/tokenizers/whitespace_tokenizer.cpp +++ b/src/analyzers/tokenizers/whitespace_tokenizer.cpp @@ -19,14 +19,25 @@ namespace tokenizers const util::string_view whitespace_tokenizer::id = "whitespace-tokenizer"; -whitespace_tokenizer::whitespace_tokenizer() : idx_{0} +whitespace_tokenizer::whitespace_tokenizer(bool suppress_whitespace) + : suppress_whitespace_{suppress_whitespace} { + // nothing } void whitespace_tokenizer::set_content(std::string&& content) { content_ = std::move(content); - idx_ = 0; + it_ = content_.begin(); + if (suppress_whitespace_) + consume_adjacent_whitespace(); +} + +void whitespace_tokenizer::consume_adjacent_whitespace() +{ + it_ = std::find_if_not(it_, content_.cend(), [](char c) { + return std::isspace(c); + }); } std::string whitespace_tokenizer::next() @@ -34,26 +45,47 @@ std::string whitespace_tokenizer::next() if (!*this) throw token_stream_exception{"next() called with no tokens left"}; - std::string ret; - // all whitespace chars are their own token - if (std::isspace(content_[idx_])) + if (std::isspace(*it_)) { - ret.push_back(content_[idx_++]); - } - // otherwise, concatenate all non-whitespace chars together until we - // find a whitespace char - else - { - while (*this && !std::isspace(content_[idx_])) - ret.push_back(content_[idx_++]); + if (suppress_whitespace_) + { + consume_adjacent_whitespace(); + } + else + { + // all whitespace chars are their own token + return std::string(1, *it_++); + } } + + // otherwise, find the next whitespace character and emit the sequence + // of consecutive non-whitespace characters as a token + auto begin = it_; + it_ = std::find_if(it_, content_.cend(), [](char c) { + return std::isspace(c); + }); + std::string ret{begin, it_}; assert(!ret.empty()); + + if (suppress_whitespace_) + consume_adjacent_whitespace(); + return ret; } whitespace_tokenizer::operator bool() const { - return idx_ < content_.size(); + return !content_.empty() && it_ != content_.cend(); +} + +template <> +std::unique_ptr +make_tokenizer(const cpptoml::table& config) +{ + auto suppress_whitespace + = config.get_as("suppress-whitespace").value_or(true); + + return make_unique(suppress_whitespace); } } } diff --git a/src/classify/classifier/knn.cpp b/src/classify/classifier/knn.cpp index c0e976347..f7b5ebb94 100644 --- a/src/classify/classifier/knn.cpp +++ b/src/classify/classifier/knn.cpp @@ -74,7 +74,7 @@ 
class_label knn::classify(const feature_vector& instance) const "k must be smaller than the " "number of documents in the index (training documents)"}; - analyzers::feature_map query{instance.size()}; + analyzers::feature_map query(instance.size()); for (const auto& count : instance) query[inv_idx_->term_text(count.first)] += count.second; assert(query.size() > 0); diff --git a/src/classify/confusion_matrix.cpp b/src/classify/confusion_matrix.cpp index 78f81c79d..37a43dcf9 100644 --- a/src/classify/confusion_matrix.cpp +++ b/src/classify/confusion_matrix.cpp @@ -22,6 +22,16 @@ confusion_matrix::confusion_matrix() /* nothing */ } +void confusion_matrix::add_fold_accuracy(double acc) +{ + fold_acc_.push_back(acc); +} + +std::vector confusion_matrix::fold_accuracy() const +{ + return fold_acc_; +} + void confusion_matrix::add(const predicted_label& predicted, const class_label& actual, size_t times) { diff --git a/src/classify/tools/CMakeLists.txt b/src/classify/tools/CMakeLists.txt index db5016ba4..030dd940c 100644 --- a/src/classify/tools/CMakeLists.txt +++ b/src/classify/tools/CMakeLists.txt @@ -1,9 +1,11 @@ add_executable(classify classify.cpp) target_link_libraries(classify meta-classify meta-sequence-analyzers - meta-parser-analyzers) + meta-parser-analyzers + meta-embeddings-analyzers) add_executable(online-classify online_classify.cpp) target_link_libraries(online-classify meta-classify meta-sequence-analyzers - meta-parser-analyzers) + meta-parser-analyzers + meta-embeddings-analyzers) diff --git a/src/classify/tools/classify.cpp b/src/classify/tools/classify.cpp index 97bc66ecf..aae72dc14 100644 --- a/src/classify/tools/classify.cpp +++ b/src/classify/tools/classify.cpp @@ -11,6 +11,7 @@ #include "meta/caching/all.h" #include "meta/classify/classifier/all.h" +#include "meta/embeddings/analyzers/embedding_analyzer.h" #include "meta/index/forward_index.h" #include "meta/index/ranker/all.h" #include "meta/parser/analyzers/tree_analyzer.h" @@ -73,6 +74,7 @@ int main(int argc, char* argv[]) // Register additional analyzers parser::register_analyzers(); sequence::register_analyzers(); + embeddings::register_analyzers(); auto config = cpptoml::parse_file(argv[1]); auto class_config = config->get_table("classifier"); @@ -102,7 +104,6 @@ int main(int argc, char* argv[]) } else { - creator = [&](classify::multiclass_dataset_view fold) { return classify::make_classifier(*class_config, std::move(fold)); }; diff --git a/src/classify/tools/online_classify.cpp b/src/classify/tools/online_classify.cpp index a47cf69d7..5d8a3d685 100644 --- a/src/classify/tools/online_classify.cpp +++ b/src/classify/tools/online_classify.cpp @@ -6,8 +6,8 @@ #include #include "meta/classify/batch_training.h" -#include "meta/classify/classifier_factory.h" #include "meta/classify/classifier/online_classifier.h" +#include "meta/classify/classifier_factory.h" #include "meta/logging/logger.h" #include "meta/parser/analyzers/tree_analyzer.h" #include "meta/sequence/analyzers/ngram_pos_analyzer.h" @@ -38,14 +38,14 @@ int main(int argc, char* argv[]) return 1; } - auto batch_size = config->get_as("batch-size"); + auto batch_size = config->get_as("batch-size"); if (!batch_size) { std::cerr << "Missing batch-size in " << argv[1] << std::endl; return 1; } - auto test_start = config->get_as("test-start"); + auto test_start = config->get_as("test-start"); if (!test_start) { std::cerr << "Missing test-start in " << argv[1] << std::endl; @@ -54,7 +54,7 @@ int main(int argc, char* argv[]) auto f_idx = index::make_index(*config); - if 
(static_cast(*test_start) > f_idx->num_docs()) + if (*test_start > f_idx->num_docs()) { std::cerr << "The start of the test set is more than the number of " "docs in the index." @@ -81,24 +81,22 @@ int main(int argc, char* argv[]) } auto docs = f_idx->docs(); - auto test_begin = docs.begin() + *test_start; + auto test_begin = docs.begin() + static_cast(*test_start); std::vector training_set{docs.begin(), test_begin}; std::vector test_set{test_begin, docs.end()}; - auto dur = common::time( - [&]() - { - classify::batch_train(f_idx, *online_classifier, training_set, - static_cast(*batch_size)); + auto dur = common::time([&]() { + classify::batch_train(f_idx, *online_classifier, training_set, + *batch_size); - classify::multiclass_dataset test_data{f_idx, test_set.begin(), - test_set.end()}; + classify::multiclass_dataset test_data{f_idx, test_set.begin(), + test_set.end()}; - auto mtrx = classifier->test(test_data); - mtrx.print(); - mtrx.print_stats(); - }); + auto mtrx = classifier->test(test_data); + mtrx.print(); + mtrx.print_stats(); + }); std::cout << "Took " << dur.count() / 1000.0 << "s" << std::endl; diff --git a/src/corpus/gz_corpus.cpp b/src/corpus/gz_corpus.cpp index 02959f2f9..19d9cbf4c 100644 --- a/src/corpus/gz_corpus.cpp +++ b/src/corpus/gz_corpus.cpp @@ -63,7 +63,7 @@ std::unique_ptr make_corpus(util::string_view prefix, { auto encoding = config.get_as("encoding").value_or("utf-8"); - auto num_docs = config.get_as("num-docs"); + auto num_docs = config.get_as("num-docs"); if (!num_docs) throw corpus_exception{"num-docs config param required for gz_corpus"}; @@ -75,8 +75,7 @@ std::unique_ptr make_corpus(util::string_view prefix, filename.append(dataset.data(), dataset.size()); filename += ".dat"; - return make_unique(filename, encoding, - static_cast(*num_docs)); + return make_unique(filename, encoding, *num_docs); } } } diff --git a/src/corpus/libsvm_corpus.cpp b/src/corpus/libsvm_corpus.cpp index 1b4b53630..344548aa6 100644 --- a/src/corpus/libsvm_corpus.cpp +++ b/src/corpus/libsvm_corpus.cpp @@ -109,12 +109,11 @@ std::unique_ptr make_corpus(util::string_view prefix, } } - auto lines = config.get_as("num-docs"); + auto lines = config.get_as("num-docs"); if (!lines) return make_unique(filename, lbl_type); else - return make_unique(filename, lbl_type, - static_cast(*lines)); + return make_unique(filename, lbl_type, *lines); } } } diff --git a/src/corpus/line_corpus.cpp b/src/corpus/line_corpus.cpp index 57df4803d..5fca39660 100644 --- a/src/corpus/line_corpus.cpp +++ b/src/corpus/line_corpus.cpp @@ -84,12 +84,11 @@ std::unique_ptr make_corpus(util::string_view prefix, filename.append(dataset.data(), dataset.size()); filename += ".dat"; - auto lines = config.get_as("num-docs"); + auto lines = config.get_as("num-docs"); if (!lines) return make_unique(filename, encoding); else - return make_unique(filename, encoding, - static_cast(*lines)); + return make_unique(filename, encoding, *lines); } } } diff --git a/src/embeddings/CMakeLists.txt b/src/embeddings/CMakeLists.txt index 701f94402..25441f3be 100644 --- a/src/embeddings/CMakeLists.txt +++ b/src/embeddings/CMakeLists.txt @@ -1,9 +1,10 @@ project(meta-embeddings) add_subdirectory(tools) +add_subdirectory(analyzers) -add_library(meta-embeddings word_embeddings.cpp) -target_link_libraries(meta-embeddings cpptoml meta-util) +add_library(meta-embeddings cooccurrence_counter.cpp word_embeddings.cpp) +target_link_libraries(meta-embeddings cpptoml meta-analyzers meta-util meta-io) install(TARGETS meta-embeddings EXPORT meta-exports 
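// Editor's sketch (illustrative only, not part of this changeset): the corpus
// and tool changes above now read unsigned configuration values at their
// natural type instead of fetching int64_t and casting. A typical call site
// (the exact get_as template argument here is an assumption) looks like:
#include <cstdint>
#include <stdexcept>
#include "cpptoml.h"

uint64_t read_num_docs(const cpptoml::table& config)
{
    auto num_docs = config.get_as<uint64_t>("num-docs");
    if (!num_docs)
        throw std::runtime_error{"num-docs config param required"};
    return *num_docs;
}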
diff --git a/src/embeddings/analyzers/CMakeLists.txt b/src/embeddings/analyzers/CMakeLists.txt new file mode 100644 index 000000000..9c62f076a --- /dev/null +++ b/src/embeddings/analyzers/CMakeLists.txt @@ -0,0 +1,8 @@ +project(meta-embeddings-analyzers) + +add_library(meta-embeddings-analyzers embedding_analyzer.cpp) +target_link_libraries(meta-embeddings-analyzers meta-analyzers meta-embeddings) + +install(TARGETS meta-embeddings-analyzers + EXPORT meta-exports + DESTINATION lib) diff --git a/src/embeddings/analyzers/embedding_analyzer.cpp b/src/embeddings/analyzers/embedding_analyzer.cpp new file mode 100644 index 000000000..4b32cd55b --- /dev/null +++ b/src/embeddings/analyzers/embedding_analyzer.cpp @@ -0,0 +1,76 @@ +/** + * @file embedding_analyzer.cpp + * @author Sean Massung + */ + +#include "meta/analyzers/token_stream.h" +#include "meta/corpus/document.h" +#include "meta/embeddings/analyzers/embedding_analyzer.h" +#include "meta/math/vector.h" + +namespace meta +{ +namespace analyzers +{ + +const util::string_view embedding_analyzer::id = "embedding"; + +embedding_analyzer::embedding_analyzer(const cpptoml::table& config, + std::unique_ptr stream) + : stream_{std::move(stream)}, + embeddings_{std::make_shared( + embeddings::load_embeddings(config))}, + prefix_{*config.get_as("prefix")}, + features_(embeddings_->vector_size(), 0.0) +{ + // nothing +} + +embedding_analyzer::embedding_analyzer(const embedding_analyzer& other) + : stream_{other.stream_->clone()}, + embeddings_{other.embeddings_}, + prefix_{other.prefix_}, + features_{other.features_} +{ + // nothing +} + +void embedding_analyzer::tokenize(const corpus::document& doc, + featurizer& counts) +{ + using namespace math::operators; + stream_->set_content(get_content(doc)); + features_.assign(embeddings_->vector_size(), 0.0); + uint64_t num_seen = 0; + while (*stream_) + { + auto token = stream_->next(); + features_ = std::move(features_) + embeddings_->at(token).v; + ++num_seen; + } + + // average each feature and record it + uint64_t cur_dim = 0; + for (const auto& val : features_) + counts(prefix_ + std::to_string(cur_dim++), val / num_seen); +} + +template <> +std::unique_ptr +make_analyzer(const cpptoml::table& global, + const cpptoml::table& config) +{ + auto filts = load_filters(global, config); + return make_unique(config, std::move(filts)); +} +} + +namespace embeddings +{ +void register_analyzers() +{ + using namespace analyzers; + register_analyzer(); +} +} +} diff --git a/src/embeddings/cooccurrence_counter.cpp b/src/embeddings/cooccurrence_counter.cpp new file mode 100644 index 000000000..fbeb1cca6 --- /dev/null +++ b/src/embeddings/cooccurrence_counter.cpp @@ -0,0 +1,349 @@ +/** + * @file cooccurrence_counter.cpp + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. 
+ */ + +#include "meta/embeddings/cooccurrence_counter.h" +#include "meta/analyzers/analyzer.h" +#include "meta/embeddings/cooccur_iterator.h" +#include "meta/logging/logger.h" +#include "meta/util/printing.h" +#include "meta/util/progress.h" + +namespace meta +{ +namespace embeddings +{ + +class cooccurrence_buffer +{ + public: + using count_t = std::pair; + using map_t = meta::hashing::probe_map; + + cooccurrence_buffer(cooccurrence_counter* counter, std::size_t max_bytes, + const analyzers::token_stream& stream) + : counter_{counter}, + max_bytes_{max_bytes}, + cooccurrences_{ + static_cast(max_bytes_ / sizeof(count_t))}, + stream_{stream.clone()} + { + // nothing + } + + cooccurrence_buffer(cooccurrence_buffer&&) = default; + cooccurrence_buffer& operator=(cooccurrence_buffer&&) = default; + + ~cooccurrence_buffer() + { + flush(); + --counter_->num_tokenizing_; + } + + void flush() + { + if (!cooccurrences_.empty()) + { + std::lock_guard lock{counter_->io_mutex_}; + printing::progress::clear(); + LOG(info) << "Flushing hash table of size: " + << printing::bytes_to_units(cooccurrences_.bytes_used()) + << " with " << cooccurrences_.size() << " unique pairs" + << ENDLG; + } + + { + auto items = std::move(cooccurrences_).extract(); + std::sort(items.begin(), items.end(), + [](const count_t& a, const count_t& b) { + return a.first < b.first; + }); + + counter_->flush_chunk(std::move(items)); + } + + cooccurrences_ + = map_t{static_cast(max_bytes_ / sizeof(count_t))}; + } + + void operator()(uint64_t target, uint64_t context, double weight) + { + cooccurrence_key key{target, context}; + auto it = cooccurrences_.find(key); + if (it == cooccurrences_.end()) + { + maybe_flush(); + cooccurrences_[key] = weight; + } + else + { + it->value() += weight; + } + } + + private: + void maybe_flush() + { + // check if inserting a new cooccurrence would cause a resize + if (cooccurrences_.next_load_factor() + >= cooccurrences_.max_load_factor()) + { + // see if the newly resized table would fit in ram + auto bytes_used + = cooccurrences_.bytes_used() * cooccurrences_.resize_ratio(); + + if (bytes_used >= max_bytes_) + { + flush(); + } + } + } + + friend class cooccurrence_counter; + + cooccurrence_counter* counter_; + const std::size_t max_bytes_; + map_t cooccurrences_; + std::unique_ptr stream_; +}; + +namespace +{ +hashing::probe_map +load_vocab(const std::string& filename) +{ + using map_type = hashing::probe_map; + + std::ifstream input{filename, std::ios::binary}; + auto size = io::packed::read(input); + auto reserve_size = static_cast( + std::ceil(size / map_type::default_max_load_factor())); + + printing::progress progress{" > Loading vocab: ", size}; + map_type vocab{reserve_size}; + for (uint64_t tid{0}; tid < size; ++tid) + { + progress(tid); + auto word = io::packed::read(input); + io::packed::read(input); + + vocab[word] = tid; + } + + return vocab; +} +} + +cooccurrence_counter::cooccurrence_counter(configuration config, + parallel::thread_pool& pool) + : prefix_{std::move(config.prefix)}, + max_ram_{config.max_ram}, + merge_fanout_{config.merge_fanout}, + window_size_{config.window_size}, + break_on_tags_{config.break_on_tags}, + vocab_{load_vocab(prefix_ + "/vocab.bin")}, + pool_(pool) +{ + LOG(info) << "Loaded vocabulary of size " << vocab_.size() << " occupying " + << printing::bytes_to_units(vocab_.bytes_used()) << ENDLG; + + if (vocab_.bytes_used() > max_ram_) + throw cooccurrence_exception{"RAM limit too restrictive"}; + max_ram_ -= vocab_.bytes_used(); +} + +void 
cooccurrence_counter::count(corpus::corpus& docs, + const analyzers::token_stream& stream) +{ + if (chunk_num_ != 0) + throw cooccurrence_exception{ + "cooccurrence_counters may not be re-used"}; + + num_tokenizing_ = pool_.size(); + printing::progress progress{" > Counting cooccurrences: ", docs.size()}; + corpus::parallel_consume( + docs, pool_, + [&]() { + return cooccurrence_buffer{this, max_ram_ / pool_.size(), stream}; + }, + [&](cooccurrence_buffer& buffer, const corpus::document& doc) { + { + std::lock_guard lock{io_mutex_}; + progress(doc.id()); + } + + buffer.stream_->set_content(analyzers::get_content(doc)); + + std::deque history; + while (*buffer.stream_) + { + auto tok = buffer.stream_->next(); + + if (tok == "" && break_on_tags_) + { + history.clear(); + } + else if (tok == "" && break_on_tags_) + { + continue; + } + else + { + // ignore out-of-vocabulary words + auto it = vocab_.find(tok); + if (it == vocab_.end()) + continue; + + auto tid = it->value(); + + // everything in history is a left-context of tid. + // Likewise, tid is a right-context of everything in + // history. + for (auto it = history.begin(), end = history.end(); + it != end; ++it) + { + auto dist = std::distance(it, end); + buffer(tid, *it, 1.0 / dist); + buffer(*it, tid, 1.0 / dist); + } + + history.push_back(tid); + if (history.size() > window_size_) + history.pop_front(); + } + } + }); +} + +void cooccurrence_counter::flush_chunk(memory_chunk_type&& chunk) +{ + std::unique_lock lock{chunk_mutex_}; + + if (!chunk.empty()) + memory_chunks_.emplace_back(std::move(chunk)); + + ++num_pending_; + + // If this thread added the last expected in-memory chunk, it performs + // the merging + if (num_pending_ == num_tokenizing_) + { + memory_merge_chunks(); + --num_pending_; + chunk_cond_.notify_all(); + lock.unlock(); + + // co-opt this thread to start merging on-disk chunks if there are + // enough to start the mergesort + maybe_merge(); + } + // otherwise, this thread will wait until the merger thread completes + else + { + chunk_cond_.wait(lock, [&]() { return memory_chunks_.empty(); }); + --num_pending_; + } +} + +void cooccurrence_counter::memory_merge_chunks() +{ + if (!memory_chunks_.empty()) + { + auto filename = prefix_ + "/chunk-" + std::to_string(chunk_num_++); + uint64_t total_bytes = 0; + { + std::ofstream output{filename, std::ios::binary}; + printing::progress::clear(); + LOG(info) << "Merging " << memory_chunks_.size() + << " in-memory chunks..." << ENDLG; + util::multiway_merge(memory_chunks_.begin(), memory_chunks_.end(), + [&](cooccur_record&& record) { + total_bytes + += io::packed::write(output, record); + }, + printing::no_progress_trait{}); + } + + if (total_bytes > 0) + { + chunks_.emplace(filename, total_bytes); + } + else + { + filesystem::delete_file(filename); + } + + memory_chunks_.clear(); + } +} + +void cooccurrence_counter::maybe_merge() +{ + std::unique_lock lock{chunk_mutex_}; + if (chunks_.size() < merge_fanout_) + return; + + --num_tokenizing_; + + std::vector chunks; + chunks.reserve(merge_fanout_); + for (std::size_t i = 0; i < merge_fanout_; ++i) + { + chunks.emplace_back(chunks_.top().path); + chunks_.pop(); + } + + auto filename = prefix_ + "/chunk-" + std::to_string(chunk_num_++); + uint64_t total_bytes = 0; + { + std::ofstream output{filename, std::ios::binary}; + printing::progress::clear(); + LOG(info) << "Merging " << chunks.size() << " on-disk chunks..." 
+ << ENDLG; + lock.unlock(); + + util::multiway_merge(chunks.begin(), chunks.end(), + [&](cooccur_record&& record) { + total_bytes + += io::packed::write(output, record); + }, + printing::no_progress_trait{}); + } + + lock.lock(); + chunks_.emplace(filename, total_bytes); + ++num_tokenizing_; + printing::progress::clear(); + LOG(info) << "On-disk merge complete" << ENDLG; +} + +cooccurrence_counter::~cooccurrence_counter() +{ + std::vector chunks; + + chunks.reserve(chunks_.size()); + while (!chunks_.empty()) + { + chunks.emplace_back(chunks_.top().path); + chunks_.pop(); + } + + std::ofstream output{prefix_ + "/cooccur.bin", std::ios::binary}; + auto num_records = util::multiway_merge( + chunks.begin(), chunks.end(), + [&](cooccur_record&& record) { io::packed::write(output, record); }); + chunks.clear(); + + LOG(info) << "Cooccurrence matrix elements: " << num_records << ENDLG; + LOG(info) << "Cooccurrence matrix size: " + << printing::bytes_to_units( + filesystem::file_size(prefix_ + "/cooccur.bin")) + << ENDLG; +} +} +} diff --git a/src/embeddings/tools/CMakeLists.txt b/src/embeddings/tools/CMakeLists.txt index 3eda80486..df2b64ae1 100644 --- a/src/embeddings/tools/CMakeLists.txt +++ b/src/embeddings/tools/CMakeLists.txt @@ -1,8 +1,8 @@ add_executable(embedding-vocab embedding_vocab.cpp) target_link_libraries(embedding-vocab meta-analyzers meta-util meta-io) -add_executable(embedding-coocur embedding_coocur.cpp) -target_link_libraries(embedding-coocur meta-analyzers meta-util meta-io) +add_executable(embedding-cooccur embedding_cooccur.cpp) +target_link_libraries(embedding-cooccur meta-embeddings) add_executable(glove glove.cpp) target_link_libraries(glove meta-util diff --git a/src/embeddings/tools/embedding_cooccur.cpp b/src/embeddings/tools/embedding_cooccur.cpp new file mode 100644 index 000000000..debb45389 --- /dev/null +++ b/src/embeddings/tools/embedding_cooccur.cpp @@ -0,0 +1,81 @@ +/** + * @file embedding_cooccur.cpp + * @author Chase Geigle + * + * This tool builds the weighted cooccurrence matrix for the GloVe training + * method. 
+ */ + +#include + +#include "cpptoml.h" +#include "meta/analyzers/analyzer.h" +#include "meta/corpus/corpus_factory.h" +#include "meta/embeddings/cooccurrence_counter.h" +#include "meta/io/filesystem.h" +#include "meta/logging/logger.h" + +using namespace meta; + +int main(int argc, char** argv) +{ + if (argc < 2) + { + std::cerr << "Usage: " << argv[0] << " config.toml" << std::endl; + return 1; + } + + logging::set_cerr_logging(); + + auto config = cpptoml::parse_file(argv[1]); + + // extract building parameters + auto embed_cfg = config->get_table("embeddings"); + auto prefix = *embed_cfg->get_as("prefix"); + auto vocab_filename = prefix + "/vocab.bin"; + auto window_size + = embed_cfg->get_as("window-size").value_or(15); + auto max_ram = embed_cfg->get_as("max-ram").value_or(4096) + * 1024 * 1024; + auto merge_fanout + = embed_cfg->get_as("merge-fanout").value_or(8); + auto break_on_tags + = embed_cfg->get_as("break-on-tags").value_or(false); + + if (!filesystem::file_exists(vocab_filename)) + { + LOG(fatal) << "Vocabulary file has not yet been generated, please do " + "this before building the cooccurrence table" + << ENDLG; + return 1; + } + + auto stream = analyzers::load_filters(*config, *embed_cfg); + if (!stream) + { + LOG(fatal) << "Failed to find an ngram-word analyzer configuration in " + << argv[1] << ENDLG; + return 1; + } + + auto num_threads + = embed_cfg->get_as("num-threads") + .value_or(std::max(1u, std::thread::hardware_concurrency())); + + { + embeddings::cooccurrence_counter::configuration cooccur_config; + cooccur_config.prefix = prefix; + cooccur_config.max_ram = max_ram; + cooccur_config.merge_fanout = merge_fanout; + cooccur_config.window_size = window_size; + cooccur_config.break_on_tags = break_on_tags; + + parallel::thread_pool pool{num_threads}; + embeddings::cooccurrence_counter counter{cooccur_config, pool}; + + auto docs = corpus::make_corpus(*config); + counter.count(*docs, *stream); + } + + return 0; +} diff --git a/src/embeddings/tools/embedding_coocur.cpp b/src/embeddings/tools/embedding_coocur.cpp deleted file mode 100644 index 028a32179..000000000 --- a/src/embeddings/tools/embedding_coocur.cpp +++ /dev/null @@ -1,298 +0,0 @@ -/** - * @file embedding_coocur.cpp - * @author Chase Geigle - * - * This tool builds the weighted coocurrence matrix for the GloVe training - * method. 
- */ - -#include - -#include "cpptoml.h" -#include "meta/analyzers/all.h" -#include "meta/analyzers/token_stream.h" -#include "meta/corpus/corpus_factory.h" -#include "meta/embeddings/coocur_iterator.h" -#include "meta/hashing/probe_map.h" -#include "meta/io/packed.h" -#include "meta/logging/logger.h" -#include "meta/util/multiway_merge.h" -#include "meta/util/progress.h" -#include "meta/util/printing.h" - -using namespace meta; - -namespace meta -{ -namespace hashing -{ -template -struct key_traits> -{ - static constexpr bool inlineable - = key_traits::inlineable && key_traits::inlineable; - - constexpr static std::pair sentinel() - { - return {key_traits::sentinel(), key_traits::sentinel()}; - } -}; -} -} - -class coocur_buffer -{ - public: - coocur_buffer(std::size_t max_ram, util::string_view prefix) - : max_bytes_{max_ram}, - prefix_{prefix.to_string()}, - coocur_{static_cast(max_bytes_ / sizeof(count_t))} - { - // nothing - } - - void flush() - { - LOG(info) << "\nFlushing buffer of size: " - << printing::bytes_to_units(coocur_.bytes_used()) << " with " - << coocur_.size() << " unique pairs" << ENDLG; - - { - auto items = std::move(coocur_).extract(); - std::sort(items.begin(), items.end(), - [](const count_t& a, const count_t& b) - { - return a.first < b.first; - }); - - std::ofstream output{prefix_ + "/chunk-" - + std::to_string(chunk_num_), - std::ios::binary}; - for (const auto& pr : items) - { - io::packed::write(output, pr.first.first); - io::packed::write(output, pr.first.second); - io::packed::write(output, pr.second); - } - } - - coocur_ = map_t{static_cast(max_bytes_ / sizeof(count_t))}; - ++chunk_num_; - } - - void operator()(uint64_t target, uint64_t context, double weight) - { - auto it = coocur_.find(std::make_pair(target, context)); - if (it == coocur_.end()) - { - maybe_flush(); - coocur_[std::make_pair(target, context)] = weight; - } - else - { - it->value() += weight; - } - } - - std::size_t num_chunks() const - { - return chunk_num_; - } - - uint64_t merge_chunks() - { - coocur_ = map_t{}; - std::vector chunks; - chunks.reserve(num_chunks()); - - for (std::size_t i = 0; i < num_chunks(); ++i) - chunks.emplace_back(prefix_ + "/chunk-" + std::to_string(i)); - - std::ofstream output{prefix_ + "/coocur.bin", std::ios::binary}; - auto num_records - = util::multiway_merge(chunks.begin(), chunks.end(), - [&](embeddings::coocur_record&& record) - { - io::packed::write(output, record); - }); - chunks.clear(); - - // clean up temporary files - for (std::size_t i = 0; i < num_chunks(); ++i) - { - filesystem::delete_file(prefix_ + "/chunk-" + std::to_string(i)); - } - - return num_records; - } - - private: - void maybe_flush() - { - // check if inserting a new coocurrence would cause a resize - if (coocur_.next_load_factor() >= coocur_.max_load_factor()) - { - // see if the newly resized table would fit in ram - auto bytes_used = coocur_.bytes_used() * coocur_.resize_ratio(); - - if (bytes_used >= max_bytes_) - { - flush(); - } - } - } - - using count_t = std::pair, double>; - using map_t - = meta::hashing::probe_map, double>; - const std::size_t max_bytes_; - const std::string prefix_; - map_t coocur_; - std::size_t chunk_num_ = 0; -}; - -hashing::probe_map -load_vocab(const std::string& filename) -{ - using map_type = hashing::probe_map; - - std::ifstream input{filename, std::ios::binary}; - auto size = io::packed::read(input); - auto reserve_size = static_cast( - std::ceil(size / map_type::default_max_load_factor())); - - printing::progress progress{" > Loading vocab: ", 
size}; - map_type vocab{reserve_size}; - for (uint64_t tid{0}; tid < size; ++tid) - { - progress(tid); - auto word = io::packed::read(input); - io::packed::read(input); - - vocab[word] = tid; - } - - return vocab; -} - -int main(int argc, char** argv) -{ - if (argc < 2) - { - std::cerr << "Usage: " << argv[0] << " config.toml" << std::endl; - return 1; - } - - logging::set_cerr_logging(); - - auto config = cpptoml::parse_file(argv[1]); - - // extract building parameters - auto embed_cfg = config->get_table("embeddings"); - auto prefix = *embed_cfg->get_as("prefix"); - auto vocab_filename = prefix + "/vocab.bin"; - auto window_size = static_cast( - embed_cfg->get_as("window-size").value_or(15)); - auto max_ram = static_cast( - embed_cfg->get_as("max-ram").value_or(4096)) - * 1024 * 1024; - - if (!filesystem::file_exists(vocab_filename)) - { - LOG(fatal) << "Vocabulary file has not yet been generated, please do " - "this before building the coocurrence table" - << ENDLG; - return 1; - } - - auto vocab = load_vocab(vocab_filename); - LOG(info) << "Loaded vocabulary of size " << vocab.size() << " occupying " - << printing::bytes_to_units(vocab.bytes_used()) << ENDLG; - - if (max_ram <= vocab.bytes_used()) - { - LOG(fatal) << "RAM limit too restrictive" << ENDLG; - return 1; - } - - max_ram -= vocab.bytes_used(); - if (max_ram < 1024 * 1024) - { - LOG(fatal) << "RAM limit too restrictive" << ENDLG; - return 1; - } - - auto stream = analyzers::load_filters(*config, *embed_cfg); - if (!stream) - { - LOG(fatal) << "Failed to find an ngram-word analyzer configuration in " - << argv[1] << ENDLG; - return 1; - } - - coocur_buffer coocur{max_ram, prefix}; - - { - auto docs = corpus::make_corpus(*config); - printing::progress progress{" > Counting coocurrences: ", docs->size()}; - for (uint64_t i = 0; docs->has_next(); ++i) - { - progress(i); - auto doc = docs->next(); - stream->set_content(analyzers::get_content(doc)); - - std::deque history; - while (*stream) - { - auto tok = stream->next(); - - if (tok == "") - { - history.clear(); - } - else if (tok == "") - { - continue; - } - else - { - // ignore out-of-vocabulary words - auto it = vocab.find(tok); - if (it == vocab.end()) - continue; - - auto tid = it->value(); - - // everything in history is a left-context of tid. - // Likewise, tid is a right-context of everything in - // history. - for (auto it = history.begin(), end = history.end(); - it != end; ++it) - { - auto dist = std::distance(it, end); - coocur(tid, *it, 1.0 / dist); - coocur(*it, tid, 1.0 / dist); - } - - history.push_back(tid); - if (history.size() > window_size) - history.pop_front(); - } - } - } - } - - // flush any remaining elements - coocur.flush(); - - // merge all on-disk chunks - auto uniq = coocur.merge_chunks(); - - LOG(info) << "Coocurrence matrix elements: " << uniq << ENDLG; - LOG(info) << "Coocurrence matrix size: " - << printing::bytes_to_units( - filesystem::file_size(prefix + "/coocur.bin")) - << ENDLG; - - return 0; -} diff --git a/src/embeddings/tools/glove.cpp b/src/embeddings/tools/glove.cpp index a112710e8..8fc044fa8 100644 --- a/src/embeddings/tools/glove.cpp +++ b/src/embeddings/tools/glove.cpp @@ -2,7 +2,7 @@ * @file glove.cpp * @author Chase Geigle * - * This tool builds word embedding vectors from a weighted coocurrence + * This tool builds word embedding vectors from a weighted cooccurrence * matrix using the GloVe model. 
* * @see http://nlp.stanford.edu/projects/glove/ @@ -11,15 +11,15 @@ #include #include "cpptoml.h" -#include "meta/embeddings/coocur_iterator.h" +#include "meta/embeddings/cooccur_iterator.h" #include "meta/io/filesystem.h" #include "meta/io/packed.h" #include "meta/logging/logger.h" #include "meta/parallel/thread_pool.h" #include "meta/util/aligned_allocator.h" #include "meta/util/array_view.h" -#include "meta/util/progress.h" #include "meta/util/printing.h" +#include "meta/util/progress.h" #include "meta/util/random.h" #include "meta/util/time.h" @@ -30,64 +30,61 @@ std::size_t shuffle_partition(const std::string& prefix, std::size_t max_ram, { using namespace embeddings; - using vec_type = std::vector; + using vec_type = std::vector; using diff_type = vec_type::iterator::difference_type; std::mt19937 engine{std::random_device{}()}; - vec_type records(max_ram / sizeof(coocur_record)); + vec_type records(max_ram / sizeof(cooccur_record)); // read in RAM sized chunks and shuffle in memory and write out to disk std::vector chunk_sizes; std::size_t total_records = 0; - coocur_iterator input{prefix + "/coocur.bin"}; + cooccur_iterator input{prefix + "/cooccur.bin"}; - auto elapsed = common::time( - [&]() + auto elapsed = common::time([&]() { + printing::progress progress{" > Shuffling (pass 1): ", + input.total_bytes()}; + while (input != cooccur_iterator{}) { - printing::progress progress{" > Shuffling (pass 1): ", - input.total_bytes()}; - while (input != coocur_iterator{}) + std::size_t i = 0; + for (; i < records.size() && input != cooccur_iterator{}; + ++i, ++input) { - std::size_t i = 0; - for (; i < records.size() && input != coocur_iterator{}; - ++i, ++input) - { - progress(input.bytes_read()); - records[i] = *input; - } + progress(input.bytes_read()); + records[i] = *input; + } - std::shuffle(records.begin(), - records.begin() + static_cast(i), - engine); + std::shuffle(records.begin(), + records.begin() + static_cast(i), engine); - std::ofstream output{prefix + "/coocur-shuf." - + std::to_string(chunk_sizes.size()) - + ".tmp", - std::ios::binary}; + std::ofstream output{prefix + "/cooccur-shuf." + + std::to_string(chunk_sizes.size()) + + ".tmp", + std::ios::binary}; - total_records += i; - chunk_sizes.push_back(i); - for (std::size_t j = 0; j < i; ++j) - io::packed::write(output, records[j]); - } - }); + total_records += i; + chunk_sizes.push_back(i); + for (std::size_t j = 0; j < i; ++j) + io::packed::write(output, records[j]); + } + }); LOG(info) << "Shuffling pass 1 took " << elapsed.count() / 1000.0 << " seconds" << ENDLG; - std::vector chunks; + std::vector chunks; chunks.reserve(chunk_sizes.size()); for (std::size_t i = 0; i < chunk_sizes.size(); ++i) { - chunks.emplace_back(prefix + "/coocur-shuf." + std::to_string(i) + chunks.emplace_back(prefix + "/cooccur-shuf." + std::to_string(i) + ".tmp"); } std::vector outputs(num_partitions); for (std::size_t i = 0; i < outputs.size(); ++i) { - outputs[i].open(prefix + "/coocur-shuf." + std::to_string(i) + ".bin", + outputs[i].open(prefix + "/cooccur-shuf." 
+ std::to_string(i) + ".bin", std::ios::binary); } @@ -109,7 +106,7 @@ std::size_t shuffle_partition(const std::string& prefix, std::size_t max_ram, for (std::size_t n = 0; n < to_write; ++n) { - if (chunks[j] == coocur_iterator{} || i == records.size()) + if (chunks[j] == cooccur_iterator{} || i == records.size()) break; records[i] = *chunks[j]; ++chunks[j]; @@ -133,7 +130,7 @@ std::size_t shuffle_partition(const std::string& prefix, std::size_t max_ram, // delete temporary files for (std::size_t i = 0; i < chunk_sizes.size(); ++i) { - filesystem::delete_file(prefix + "/coocur-shuf." + std::to_string(i) + filesystem::delete_file(prefix + "/cooccur-shuf." + std::to_string(i) + ".tmp"); } @@ -153,26 +150,23 @@ class glove_trainer { // extract building parameters auto prefix = *embed_cfg.get_as("prefix"); - auto max_ram = static_cast( - embed_cfg.get_as("max-ram").value_or(4096)) + auto max_ram = embed_cfg.get_as("max-ram").value_or(4096) * 1024 * 1024; - vector_size_ = static_cast( - embed_cfg.get_as("vector-size").value_or(50)); + vector_size_ + = embed_cfg.get_as("vector-size").value_or(50); - auto num_threads = static_cast( - embed_cfg.get_as("num-threads") - .value_or(std::max(1u, std::thread::hardware_concurrency()))); + auto num_threads + = embed_cfg.get_as("num-threads") + .value_or(std::max(1u, std::thread::hardware_concurrency())); - auto iters = static_cast( - embed_cfg.get_as("max-iter").value_or(25)); + auto iters = embed_cfg.get_as("max-iter").value_or(25); learning_rate_ = embed_cfg.get_as("learning-rate").value_or(0.05); xmax_ = embed_cfg.get_as("xmax").value_or(100.0); scale_ = embed_cfg.get_as("scale").value_or(0.75); - auto num_rare = static_cast( - embed_cfg.get_as("unk-num-avg").value_or(100)); + auto num_rare = embed_cfg.get_as("unk-num-avg").value_or(100); if (!filesystem::file_exists(prefix + "/vocab.bin")) { @@ -182,15 +176,19 @@ class glove_trainer throw glove_exception{"no vocabulary file found in " + prefix}; } - if (!filesystem::file_exists(prefix + "/coocur.bin")) + if (!filesystem::file_exists(prefix + "/cooccur.bin")) { LOG(fatal) - << "Coocurrence matrix has not yet been generated, please " + << "Cooccurrence matrix has not yet been generated, please " "do this before learning word embeddings" << ENDLG; - throw glove_exception{"no coocurrence matrix found in " + prefix}; + throw glove_exception{"no cooccurrence matrix found in " + prefix}; } + // shuffle the data and partition it into equal parts for each + // thread + auto total_records = shuffle_partition(prefix, max_ram, num_threads); + std::size_t num_words = 0; { std::ifstream vocab{prefix + "/vocab.bin", std::ios::binary}; @@ -206,28 +204,23 @@ class glove_trainer // randomly initialize the word embeddings and biases { std::mt19937 engine{std::random_device{}()}; - std::generate(weights_.begin(), weights_.end(), [&]() - { - // use the word2vec style initialization - // I'm not entirely sure why, but this seems - // to do better than initializing the vectors - // to lie in the unit cube. Maybe scaling? - auto rnd = random::bounded_rand(engine, 65536); - return (rnd / 65536.0 - 0.5) / (vector_size_ + 1); - }); + std::generate(weights_.begin(), weights_.end(), [&]() { + // use the word2vec style initialization + // I'm not entirely sure why, but this seems + // to do better than initializing the vectors + // to lie in the unit cube. Maybe scaling? 
+ auto rnd = random::bounded_rand(engine, 65536); + return (rnd / 65536.0 - 0.5) / (vector_size_ + 1); + }); } - // shuffle the data and partition it into equal parts for each - // thread - auto total_records = shuffle_partition(prefix, max_ram, num_threads); - // train using the specified number of threads train(prefix, num_threads, iters, total_records); - // delete the temporary shuffled coocurrence files + // delete the temporary shuffled cooccurrence files for (std::size_t i = 0; i < num_threads; ++i) - filesystem::delete_file(prefix + "/coocur-shuf." + std::to_string(i) - + ".bin"); + filesystem::delete_file(prefix + "/cooccur-shuf." + + std::to_string(i) + ".bin"); // save the target and context word embeddings save(prefix, num_words, num_rare); @@ -322,19 +315,16 @@ class glove_trainer futures.reserve(num_threads); for (std::size_t t = 0; t < num_threads; ++t) { - futures.emplace_back(pool.submit_task( - [&, t]() - { - return train_thread(prefix, t, progress, records); - })); + futures.emplace_back(pool.submit_task([&, t]() { + return train_thread(prefix, t, progress, records); + })); } double total_cost = 0.0; - auto elapsed = common::time([&]() - { - for (auto& fut : futures) - total_cost += fut.get(); - }); + auto elapsed = common::time([&]() { + for (auto& fut : futures) + total_cost += fut.get(); + }); progress.end(); LOG(progress) << "> Iteration " << i << "/" << iters @@ -344,9 +334,9 @@ class glove_trainer } } - double cost_weight(double coocur) const + double cost_weight(double cooccur) const { - return (coocur < xmax_) ? std::pow(coocur / xmax_, scale_) : 1.0; + return (cooccur < xmax_) ? std::pow(cooccur / xmax_, scale_) : 1.0; } void update_weight(double* weight, double* gradsq, double grad) @@ -362,12 +352,12 @@ class glove_trainer { using namespace embeddings; - coocur_iterator iter{prefix + "/coocur-shuf." - + std::to_string(thread_id) + ".bin"}; + cooccur_iterator iter{prefix + "/cooccur-shuf." 
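For context on the training math, a small sketch of the GloVe weighting function shown above together with one common AdaGrad-style parameter update. The body of `update_weight` is not visible in this hunk, so the update below is an assumption about its general form rather than the exact implementation.

~~~cpp
#include <cmath>

// GloVe's weighting f(x) = (x / xmax)^scale for x < xmax, else 1, which
// down-weights rare cooccurrences and caps frequent ones.
double cost_weight(double cooccur, double xmax, double scale)
{
    return (cooccur < xmax) ? std::pow(cooccur / xmax, scale) : 1.0;
}

// Assumed AdaGrad-style update: scale the step by the root of the accumulated
// squared gradient, then add the new squared gradient to the accumulator.
void adagrad_update(double& weight, double& gradsq, double grad)
{
    weight -= grad / std::sqrt(gradsq);
    gradsq += grad * grad;
}
~~~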
+ + std::to_string(thread_id) + ".bin"}; double cost = 0.0; - for (; iter != coocur_iterator{}; ++iter) + for (; iter != cooccur_iterator{}; ++iter) { progress(records++); auto record = *iter; @@ -426,10 +416,7 @@ class glove_trainer num_words}; io::packed::write(output, vector_size_); save_embeddings(output, num_words, num_rare, progress, - [&](uint64_t term) - { - return target_vector(term); - }); + [&](uint64_t term) { return target_vector(term); }); } // context embeddings @@ -440,11 +427,9 @@ class glove_trainer printing::progress progress{" > Saving context embeddings: ", num_words}; io::packed::write(output, vector_size_); - save_embeddings(output, num_words, num_rare, progress, - [&](uint64_t term) - { - return context_vector(term); - }); + save_embeddings( + output, num_words, num_rare, progress, + [&](uint64_t term) { return context_vector(term); }); } } @@ -467,8 +452,7 @@ class glove_trainer const auto& vec = vf(tid); std::transform(unk_vec.begin(), unk_vec.end(), vec.begin(), unk_vec.begin(), - [=](double unkweight, double vecweight) - { + [=](double unkweight, double vecweight) { return unkweight + vecweight / num_to_average; }); } @@ -479,15 +463,13 @@ class glove_trainer void write_normalized(ForwardIterator begin, ForwardIterator end, std::ofstream& output) const { - auto len = std::sqrt(std::accumulate(begin, end, 0.0, - [](double accum, double weight) - { - return accum + weight * weight; - })); - std::for_each(begin, end, [&](double weight) - { - io::packed::write(output, weight / len); - }); + auto len = std::sqrt( + std::accumulate(begin, end, 0.0, [](double accum, double weight) { + return accum + weight * weight; + })); + std::for_each(begin, end, [&](double weight) { + io::packed::write(output, weight / len); + }); } util::aligned_vector weights_; diff --git a/src/embeddings/tools/meta_to_glove.cpp b/src/embeddings/tools/meta_to_glove.cpp index 016297cd4..eca4adfca 100644 --- a/src/embeddings/tools/meta_to_glove.cpp +++ b/src/embeddings/tools/meta_to_glove.cpp @@ -1,15 +1,15 @@ /** - * @file embedding_coocur.cpp + * @file embedding_cooccur.cpp * @author Chase Geigle * - * This tool decompresses the MeTA vocabulary and coocurrence matrix files + * This tool decompresses the MeTA vocabulary and cooccurrence matrix files * to input that the original GloVe tool can read. * * (This is mainly for sanity checking.) 
*/ #include "cpptoml.h" -#include "meta/embeddings/coocur_iterator.h" +#include "meta/embeddings/cooccur_iterator.h" #include "meta/io/binary.h" #include "meta/logging/logger.h" #include "meta/util/progress.h" @@ -52,11 +52,11 @@ int main(int argc, char** argv) } { - coocur_iterator iter{prefix + "/coocur.bin"}; - printing::progress progress{" > Decompressing coocurrence matrix: ", + cooccur_iterator iter{prefix + "/cooccur.bin"}; + printing::progress progress{" > Decompressing cooccurrence matrix: ", iter.total_bytes()}; - std::ofstream output{"coocur-glove.bin", std::ios::binary}; - for (; iter != coocur_iterator{}; ++iter) + std::ofstream output{"cooccur-glove.bin", std::ios::binary}; + for (; iter != cooccur_iterator{}; ++iter) { progress(iter.bytes_read()); auto record = *iter; diff --git a/src/embeddings/word_embeddings.cpp b/src/embeddings/word_embeddings.cpp index 23b08f62f..f73f3e1dc 100644 --- a/src/embeddings/word_embeddings.cpp +++ b/src/embeddings/word_embeddings.cpp @@ -8,6 +8,7 @@ */ #include "meta/embeddings/word_embeddings.h" +#include "meta/io/filesystem.h" #include "meta/io/packed.h" #include "meta/math/vector.h" #include "meta/util/fixed_heap.h" @@ -39,10 +40,8 @@ word_embeddings::word_embeddings(std::istream& vocab, std::istream& vectors) progress(tid); auto vec = vector(tid); - std::generate(vec.begin(), vec.end(), [&]() - { - return io::packed::read(vectors); - }); + std::generate(vec.begin(), vec.end(), + [&]() { return io::packed::read(vectors); }); } } @@ -73,16 +72,13 @@ word_embeddings::word_embeddings(std::istream& vocab, std::istream& first, progress(tid); auto vec = vector(tid); - std::generate(vec.begin(), vec.end(), [&]() - { - return (io::packed::read(first) - + io::packed::read(second)); - }); + std::generate(vec.begin(), vec.end(), [&]() { + return (io::packed::read(first) + + io::packed::read(second)); + }); auto len = math::operators::l2norm(vec); - std::transform(vec.begin(), vec.end(), vec.begin(), [=](double weight) - { - return weight / len; - }); + std::transform(vec.begin(), vec.end(), vec.begin(), + [=](double weight) { return weight / len; }); } } @@ -139,11 +135,10 @@ std::vector word_embeddings::top_k(util::array_view query, std::size_t k) const { - auto comp = [](const scored_embedding& a, const scored_embedding& b) - { - return a.score > b.score; - }; - util::fixed_heap results{k, comp}; + auto results = util::make_fixed_heap( + k, [](const scored_embedding& a, const scored_embedding& b) { + return a.score > b.score; + }); // +1 for for (std::size_t tid = 0; tid < id_to_term_.size() + 1; ++tid) @@ -159,6 +154,11 @@ word_embeddings::top_k(util::array_view query, return results.extract_top(); } +std::size_t word_embeddings::vector_size() const +{ + return vector_size_; +} + word_embeddings load_embeddings(const cpptoml::table& config) { auto prefix = config.get_as("prefix"); @@ -166,6 +166,10 @@ word_embeddings load_embeddings(const cpptoml::table& config) throw word_embeddings_exception{ "missing prefix key in configuration file"}; + if (!filesystem::exists(*prefix)) + throw word_embeddings_exception{"embeddings directory does not exist: " + + *prefix}; + std::ifstream vocab{*prefix + "/vocab.bin", std::ios::binary}; if (!vocab) throw word_embeddings_exception{"missing vocabulary file in: " diff --git a/src/features/odds_ratio.cpp b/src/features/odds_ratio.cpp index f2da9efca..da72f7e97 100644 --- a/src/features/odds_ratio.cpp +++ b/src/features/odds_ratio.cpp @@ -19,7 +19,7 @@ double odds_ratio::score(const class_label& lbl, term_id tid) 
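The `util::make_fixed_heap` calls introduced above (in `word_embeddings::top_k` and the rankers) all follow the same bounded top-k pattern. Below is a hedged sketch of that idea using a plain `std::priority_queue`; it does not reproduce the real `fixed_heap` API.

~~~cpp
#include <algorithm>
#include <cstdint>
#include <queue>
#include <utility>
#include <vector>

// Sketch: keep only the k best-scoring items. The comparison is reversed so
// the lowest-scoring retained item sits on top of the heap and is evicted
// first when a better item arrives.
std::vector<std::pair<uint64_t, double>>
top_k(const std::vector<std::pair<uint64_t, double>>& scored, std::size_t k)
{
    auto worse = [](const std::pair<uint64_t, double>& a,
                    const std::pair<uint64_t, double>& b) { return a.second > b.second; };
    std::priority_queue<std::pair<uint64_t, double>,
                        std::vector<std::pair<uint64_t, double>>, decltype(worse)>
        heap{worse};
    for (const auto& item : scored)
    {
        heap.push(item);
        if (heap.size() > k)
            heap.pop(); // drop the current worst
    }
    std::vector<std::pair<uint64_t, double>> results;
    while (!heap.empty())
    {
        results.push_back(heap.top());
        heap.pop();
    }
    std::reverse(results.begin(), results.end()); // best score first
    return results;
}
~~~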
const double denominator = (1.0 - p_tc) * p_tnc; // avoid divide by zero - if (denominator == 0.0) + if (denominator <= 1e-20) return 0.0; return std::log(numerator / denominator); diff --git a/src/index/disk_index.cpp b/src/index/disk_index.cpp index d59f34cb0..63b5d3a77 100644 --- a/src/index/disk_index.cpp +++ b/src/index/disk_index.cpp @@ -105,7 +105,7 @@ uint64_t disk_index::num_docs() const std::string disk_index::doc_name(doc_id d_id) const { - auto path = doc_path(d_id); + auto path = metadata(d_id, "path").value_or("[none]"); return path.substr(path.find_last_of("/") + 1); } diff --git a/src/index/eval/ir_eval.cpp b/src/index/eval/ir_eval.cpp index b4642b501..a93a8c410 100644 --- a/src/index/eval/ir_eval.cpp +++ b/src/index/eval/ir_eval.cpp @@ -145,7 +145,7 @@ double ir_eval::ndcg(const std::vector& results, query_id q_id, std::vector rels; for (const auto& s : ht->second) rels.push_back(s.second); - std::sort(rels.begin(), rels.end()); + std::sort(rels.begin(), rels.end(), std::greater{}); double idcg = 0.0; i = 1; diff --git a/src/index/forward_index.cpp b/src/index/forward_index.cpp index 17fab146d..5d7edd041 100644 --- a/src/index/forward_index.cpp +++ b/src/index/forward_index.cpp @@ -3,10 +3,7 @@ * @author Sean Massung */ -#include "cpptoml.h" #include "meta/analyzers/analyzer.h" -#include "meta/corpus/corpus.h" -#include "meta/corpus/corpus_factory.h" #include "meta/corpus/libsvm_corpus.h" #include "meta/hashing/probe_map.h" #include "meta/index/chunk_reader.h" @@ -17,19 +14,11 @@ #include "meta/index/postings_file.h" #include "meta/index/postings_file_writer.h" #include "meta/index/postings_inverter.h" -#include "meta/index/string_list.h" -#include "meta/index/string_list_writer.h" -#include "meta/index/vocabulary_map.h" #include "meta/index/vocabulary_map_writer.h" #include "meta/io/libsvm_parser.h" #include "meta/logging/logger.h" -#include "meta/parallel/thread_pool.h" -#include "meta/util/disk_vector.h" -#include "meta/util/mapping.h" #include "meta/util/pimpl.tcc" #include "meta/util/printing.h" -#include "meta/util/shim.h" -#include "meta/util/time.h" namespace meta { @@ -53,7 +42,7 @@ class forward_index::impl * merged. 
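The `ir_eval::ndcg` change above sorts the relevance judgments in descending order before computing the ideal DCG. A brief sketch of why this matters, using the common (2^rel - 1) / log2(rank + 1) gain; the exact gain function in `ir_eval` may differ, but the normalization issue is the same.

~~~cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

double dcg(const std::vector<double>& rels)
{
    double result = 0.0;
    for (std::size_t i = 0; i < rels.size(); ++i)
        result += (std::pow(2.0, rels[i]) - 1.0) / std::log2(static_cast<double>(i) + 2.0);
    return result;
}

double ndcg(std::vector<double> ranked_rels)
{
    auto actual = dcg(ranked_rels);
    // The ideal DCG puts the most relevant documents first, hence the
    // descending sort. An ascending sort computes the worst-case DCG instead,
    // which can push NDCG above 1.
    std::sort(ranked_rels.begin(), ranked_rels.end(), std::greater<double>{});
    return actual / dcg(ranked_rels);
}
~~~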
*/ void tokenize_docs(corpus::corpus& corpus, metadata_writer& mdata_writer, - uint64_t ram_budget, uint64_t num_threads); + uint64_t ram_budget, std::size_t num_threads); /** * Merges together num_chunks number of intermediate chunks, using the @@ -224,8 +213,8 @@ void forward_index::create_index(const cpptoml::table& config, } else { - auto ram_budget = static_cast( - config.get_as("indexer-ram-budget").value_or(1024)); + auto ram_budget + = config.get_as("indexer-ram-budget").value_or(1024); if (config.get_as("uninvert").value_or(false)) { @@ -255,9 +244,8 @@ void forward_index::create_index(const cpptoml::table& config, impl_->load_labels(docs.size()); auto max_threads = std::thread::hardware_concurrency(); - auto num_threads = static_cast( - config.get_as("indexer-num-threads") - .value_or(max_threads)); + auto num_threads = config.get_as("indexer-num-threads") + .value_or(max_threads); if (num_threads > max_threads) { num_threads = max_threads; @@ -292,60 +280,69 @@ void forward_index::create_index(const cpptoml::table& config, LOG(info) << "Done creating index: " << index_name() << ENDLG; } +namespace +{ +struct local_storage +{ + local_storage(const std::string& chunk_path, + const std::unique_ptr& analyzer) + : chunk_{chunk_path, std::ios::binary}, analyzer_{analyzer->clone()} + { + // nothing + } + + io::mofstream chunk_; + std::unique_ptr analyzer_; +}; +} + void forward_index::impl::tokenize_docs(corpus::corpus& docs, metadata_writer& mdata_writer, uint64_t ram_budget, - uint64_t num_threads) + std::size_t num_threads) { std::mutex io_mutex; - std::mutex corpus_mutex; std::mutex vocab_mutex; printing::progress progress{" > Tokenizing Docs: ", docs.size()}; hashing::probe_map vocab; bool exceeded_budget = false; - auto task = [&](size_t chunk_id) - { - std::ofstream chunk{idx_->index_name() + "/chunk-" - + std::to_string(chunk_id), - std::ios::binary}; - auto analyzer = analyzer_->clone(); - while (true) - { - util::optional doc; - { - std::lock_guard lock{corpus_mutex}; - - if (!docs.has_next()) - return; + std::atomic_size_t chunk_id{0}; - doc = docs.next(); - } + parallel::thread_pool pool{num_threads}; + corpus::parallel_consume( + docs, pool, + [&]() { + auto cid = chunk_id.fetch_add(1); + return local_storage{idx_->index_name() + "/chunk-" + + std::to_string(cid), + analyzer_}; + }, + [&](local_storage& ls, const corpus::document& doc) { { std::lock_guard lock{io_mutex}; - progress(doc->id()); + progress(doc.id()); } - auto counts = analyzer->analyze(*doc); + auto counts = ls.analyzer_->analyze(doc); // warn if there is an empty document if (counts.empty()) { std::lock_guard lock{io_mutex}; LOG(progress) << '\n' << ENDLG; - LOG(warning) << "Empty document (id = " << doc->id() + LOG(warning) << "Empty document (id = " << doc.id() << ") generated!" 
<< ENDLG; } auto length = std::accumulate( counts.begin(), counts.end(), 0ul, - [](uint64_t acc, const std::pair& count) - { + [](uint64_t acc, const std::pair& count) { return acc + std::round(count.second); }); - mdata_writer.write(doc->id(), length, counts.size(), doc->mdata()); - idx_->impl_->set_label(doc->id(), doc->label()); + mdata_writer.write(doc.id(), length, counts.size(), doc.mdata()); + idx_->impl_->set_label(doc.id(), doc.label()); forward_index::postings_data_type::count_t pd_counts; pd_counts.reserve(counts.size()); @@ -372,20 +369,10 @@ void forward_index::impl::tokenize_docs(corpus::corpus& docs, } } - forward_index::postings_data_type pdata{doc->id()}; + forward_index::postings_data_type pdata{doc.id()}; pdata.set_counts(std::move(pd_counts)); - pdata.write_packed(chunk); - } - }; - - parallel::thread_pool pool{num_threads}; - std::vector> futures; - futures.reserve(num_threads); - for (size_t i = 0; i < num_threads; ++i) - futures.emplace_back(pool.submit_task(std::bind(task, i))); - - for (auto& fut : futures) - fut.get(); + pdata.write_packed(ls.chunk_); + }); progress.end(); @@ -437,8 +424,7 @@ void forward_index::impl::merge_chunks( } util::multiway_merge(chunks.begin(), chunks.end(), - [&](forward_index::postings_data_type&& to_write) - { + [&](forward_index::postings_data_type&& to_write) { // renumber the postings forward_index::postings_data_type::count_t counts; counts.reserve(to_write.counts().size()); @@ -575,9 +561,12 @@ void forward_index::impl::uninvert(const inverted_index& inv_idx, { postings_inverter handler{idx_->index_name()}; { + printing::progress progress{" > Uninverting postings: ", + inv_idx.unique_terms()}; auto producer = handler.make_producer(ram_budget); for (term_id t_id{0}; t_id < inv_idx.unique_terms(); ++t_id) { + progress(t_id); auto pdata = inv_idx.search_primary(t_id); producer(pdata->primary_key(), pdata->counts()); } diff --git a/src/index/inverted_index.cpp b/src/index/inverted_index.cpp index d5b60d1cb..88caeb68f 100644 --- a/src/index/inverted_index.cpp +++ b/src/index/inverted_index.cpp @@ -4,25 +4,16 @@ * @author Chase Geigle */ -#include "meta/analyzers/analyzer.h" -#include "meta/corpus/corpus.h" -#include "meta/corpus/corpus_factory.h" -#include "meta/corpus/metadata_parser.h" #include "meta/index/disk_index_impl.h" #include "meta/index/inverted_index.h" #include "meta/index/metadata_writer.h" #include "meta/index/postings_file.h" #include "meta/index/postings_file_writer.h" #include "meta/index/postings_inverter.h" -#include "meta/index/vocabulary_map.h" #include "meta/index/vocabulary_map_writer.h" #include "meta/logging/logger.h" -#include "meta/parallel/thread_pool.h" -#include "meta/util/mapping.h" #include "meta/util/pimpl.tcc" #include "meta/util/printing.h" -#include "meta/util/progress.h" -#include "meta/util/shim.h" namespace meta { @@ -58,7 +49,7 @@ class inverted_index::impl void tokenize_docs(corpus::corpus& docs, postings_inverter& inverter, metadata_writer& mdata_writer, uint64_t ram_budget, - uint64_t num_threads); + std::size_t num_threads); /** * Compresses the large postings file. 
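The indexing changes above replace the hand-rolled thread pool loop with `corpus::parallel_consume` plus a per-thread `local_storage` struct. Here is a hedged sketch of that pattern in isolation; the names and signature are illustrative, not MeTA's actual implementation, and a plain vector stands in for the corpus.

~~~cpp
#include <cstddef>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

// Sketch: each worker builds its own thread-local state once (analyzer clone,
// output chunk, ...), then repeatedly claims the next item under a lock and
// processes it without further synchronization.
template <class LocalStorage, class Item>
void parallel_consume(const std::vector<Item>& items, std::size_t num_threads,
                      const std::function<LocalStorage()>& make_local,
                      const std::function<void(LocalStorage&, const Item&)>& consume)
{
    std::mutex mtx;
    std::size_t next = 0;
    std::vector<std::thread> workers;
    for (std::size_t t = 0; t < num_threads; ++t)
    {
        workers.emplace_back([&]() {
            auto local = make_local(); // per-thread analyzer / output chunk
            while (true)
            {
                std::size_t i;
                {
                    std::lock_guard<std::mutex> lock{mtx};
                    if (next == items.size())
                        return;
                    i = next++;
                }
                consume(local, items[i]); // no locking needed here
            }
        });
    }
    for (auto& w : workers)
        w.join();
}
~~~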
@@ -74,7 +65,8 @@ class inverted_index::impl std::unique_ptr analyzer_; util::optional> postings_; + inverted_index::secondary_key_type>> + postings_; /// the total number of term occurrences in the entire corpus uint64_t total_corpus_terms_; @@ -126,14 +118,14 @@ void inverted_index::create_index(const cpptoml::table& config, LOG(info) << "Creating index: " << index_name() << ENDLG; - auto ram_budget = static_cast( - config.get_as("indexer-ram-budget").value_or(1024)); - auto max_writers = static_cast( - config.get_as("indexer-max-writers").value_or(8)); + auto ram_budget + = config.get_as("indexer-ram-budget").value_or(1024); + auto max_writers + = config.get_as("indexer-max-writers").value_or(8); auto max_threads = std::thread::hardware_concurrency(); - auto num_threads = static_cast( - config.get_as("indexer-num-threads").value_or(max_threads)); + auto num_threads = config.get_as("indexer-num-threads") + .value_or(max_threads); if (num_threads > max_threads) { num_threads = max_threads; @@ -187,66 +179,69 @@ void inverted_index::load_index() inv_impl_->load_postings(); } +namespace +{ +struct local_storage +{ + local_storage(uint64_t ram_budget, + postings_inverter& inverter, + const std::unique_ptr& analyzer) + : producer_{inverter.make_producer(ram_budget)}, + analyzer_{analyzer->clone()} + { + // nothing + } + + postings_inverter::producer producer_; + std::unique_ptr analyzer_; +}; +} + void inverted_index::impl::tokenize_docs( corpus::corpus& docs, postings_inverter& inverter, - metadata_writer& mdata_writer, uint64_t ram_budget, uint64_t num_threads) + metadata_writer& mdata_writer, uint64_t ram_budget, std::size_t num_threads) { - std::mutex mutex; + std::mutex io_mutex; printing::progress progress{" > Tokenizing Docs: ", docs.size()}; + uint64_t local_budget = ram_budget / num_threads; - auto task = [&](uint64_t ram_budget) - { - auto producer = inverter.make_producer(ram_budget); - auto analyzer = analyzer_->clone(); - while (true) - { - util::optional doc; - { - std::lock_guard lock{mutex}; + parallel::thread_pool pool{num_threads}; - if (!docs.has_next()) - return; // destructor for producer will write - // any intermediate chunks - doc = docs.next(); - progress(doc->id()); + corpus::parallel_consume( + docs, pool, + [&]() { + return local_storage{local_budget, inverter, analyzer_}; + }, + [&](local_storage& ls, const corpus::document& doc) { + { + std::lock_guard lock{io_mutex}; + progress(doc.id()); } - auto counts = analyzer->analyze(*doc); + auto counts = ls.analyzer_->analyze(doc); // warn if there is an empty document if (counts.empty()) { - std::lock_guard lock{mutex}; + std::lock_guard lock{io_mutex}; LOG(progress) << '\n' << ENDLG; - LOG(warning) << "Empty document (id = " << doc->id() + LOG(warning) << "Empty document (id = " << doc.id() << ") generated!" 
<< ENDLG; } auto length = std::accumulate( counts.begin(), counts.end(), 0ul, - [](uint64_t acc, const std::pair& count) - { + [](uint64_t acc, + const std::pair& count) { return acc + count.second; }); - mdata_writer.write(doc->id(), length, counts.size(), doc->mdata()); - idx_->impl_->set_label(doc->id(), doc->label()); + mdata_writer.write(doc.id(), length, counts.size(), doc.mdata()); + idx_->impl_->set_label(doc.id(), doc.label()); // update chunk - producer(doc->id(), counts); - } - }; - - parallel::thread_pool pool{num_threads}; - std::vector> futures; - for (size_t i = 0; i < num_threads; ++i) - { - futures.emplace_back( - pool.submit_task(std::bind(task, ram_budget / num_threads))); - } - - for (auto& fut : futures) - fut.get(); + ls.producer_(doc.id(), counts); + }); } void inverted_index::impl::compress(const std::string& filename, @@ -312,13 +307,7 @@ uint64_t inverted_index::total_corpus_terms() uint64_t inverted_index::total_num_occurences(term_id t_id) const { - auto pdata = search_primary(t_id); - - double sum = 0; - for (auto& c : pdata->counts()) - sum += c.second; - - return static_cast(sum); + return stream_for(t_id)->total_counts(); } float inverted_index::avg_doc_length() @@ -334,7 +323,7 @@ inverted_index::tokenize(const corpus::document& doc) uint64_t inverted_index::doc_freq(term_id t_id) const { - return search_primary(t_id)->counts().size(); + return stream_for(t_id)->size(); } auto inverted_index::search_primary(term_id t_id) const diff --git a/src/index/ranker/CMakeLists.txt b/src/index/ranker/CMakeLists.txt index 2ffc52915..20518f751 100644 --- a/src/index/ranker/CMakeLists.txt +++ b/src/index/ranker/CMakeLists.txt @@ -6,6 +6,8 @@ add_library(meta-ranker absolute_discount.cpp lm_ranker.cpp okapi_bm25.cpp pivoted_length.cpp + kl_divergence_prf.cpp + rocchio.cpp ranker.cpp ranker_factory.cpp) target_link_libraries(meta-ranker meta-index) diff --git a/src/index/ranker/kl_divergence_prf.cpp b/src/index/ranker/kl_divergence_prf.cpp new file mode 100644 index 000000000..5e2d6c4fa --- /dev/null +++ b/src/index/ranker/kl_divergence_prf.cpp @@ -0,0 +1,163 @@ +/** + * @file kl_divergence_prf.cpp + * @author Chase Geigle + */ + +#include + +#include "cpptoml.h" +#include "meta/index/ranker/dirichlet_prior.h" +#include "meta/index/ranker/kl_divergence_prf.h" +#include "meta/index/ranker/unigram_mixture.h" +#include "meta/index/score_data.h" +#include "meta/io/packed.h" +#include "meta/logging/logger.h" +#include "meta/util/fixed_heap.h" +#include "meta/util/iterator.h" +#include "meta/util/shim.h" + +namespace meta +{ +namespace index +{ + +const util::string_view kl_divergence_prf::id = "kl-divergence-prf"; +const constexpr float kl_divergence_prf::default_alpha; +const constexpr float kl_divergence_prf::default_lambda; +const constexpr uint64_t kl_divergence_prf::default_k; +const constexpr uint64_t kl_divergence_prf::default_max_terms; + +kl_divergence_prf::kl_divergence_prf(std::shared_ptr fwd) + : fwd_{std::move(fwd)}, + initial_ranker_{make_unique()}, + alpha_{default_alpha}, + lambda_{default_lambda}, + k_{default_k}, + max_terms_{default_max_terms} +{ + // nothing +} + +kl_divergence_prf::kl_divergence_prf( + std::shared_ptr fwd, + std::unique_ptr&& initial_ranker, float alpha, + float lambda, uint64_t k, uint64_t max_terms) + : fwd_{std::move(fwd)}, + initial_ranker_{std::move(initial_ranker)}, + alpha_{alpha}, + lambda_{lambda}, + k_{k}, + max_terms_{max_terms} +{ + // nothing +} + +kl_divergence_prf::kl_divergence_prf(std::istream& in) + : fwd_{[&]() { + auto 
path = io::packed::read(in); + auto cfg = cpptoml::parse_file(path + "/config.toml"); + return make_index(*cfg); + }()}, + initial_ranker_{load_lm_ranker(in)}, + alpha_{io::packed::read(in)}, + lambda_{io::packed::read(in)}, + k_{io::packed::read(in)}, + max_terms_{io::packed::read(in)} +{ + // nothing +} + +void kl_divergence_prf::save(std::ostream& out) const +{ + io::packed::write(out, id); + io::packed::write(out, fwd_->index_name()); + initial_ranker_->save(out); + io::packed::write(out, alpha_); + io::packed::write(out, lambda_); + io::packed::write(out, k_); + io::packed::write(out, max_terms_); +} + +std::vector +kl_divergence_prf::rank(ranker_context& ctx, uint64_t num_results, + const filter_function_type& filter) +{ + auto fb_docs = initial_ranker_->rank(ctx, k_, filter); + auto extract_docid = [](const search_result& sr) { return sr.d_id; }; + + // construct feedback document set + learn::dataset fb_dset{ + fwd_, util::make_transform_iterator(fb_docs.begin(), extract_docid), + util::make_transform_iterator(fb_docs.end(), extract_docid), + printing::no_progress_trait{}}; + + // learn the feedback model using the EM algorithm + feedback::training_options options; + options.lambda = lambda_; + auto fb_model = feedback::unigram_mixture( + [&](term_id tid) { + float term_count = ctx.idx.total_num_occurences(tid); + return term_count / ctx.idx.total_corpus_terms(); + }, + fb_dset, options); + + // extract only the top max_terms from the feedback model + using scored_term = std::pair; + auto heap = util::make_fixed_heap( + max_terms_, [&](const scored_term& a, const scored_term& b) { + return a.second > b.second; + }); + fb_model.each_seen_event( + [&](term_id tid) { heap.emplace(tid, fb_model.probability(tid)); }); + + // interpolate the old query with the top terms from the feedback model + hashing::probe_map new_query; + for (const auto& pr : heap.extract_top()) + { + new_query[pr.first] += alpha_ * pr.second; + } + for (const auto& postings_ctx : ctx.postings) + { + auto p_wq = postings_ctx.query_term_weight / ctx.query_length; + new_query[postings_ctx.t_id] += (1.0f - alpha_) * p_wq; + } + + // construct a new ranker_context from the new query + ranker_context new_ctx{ctx.idx, new_query.begin(), new_query.end(), filter}; + + // return ranking results based on the new query + return initial_ranker_->rank(new_ctx, num_results, filter); +} + +template <> +std::unique_ptr +make_ranker(const cpptoml::table& global, + const cpptoml::table& local) +{ + if (global.begin() == global.end()) + { + LOG(fatal) << "Global configuration group was empty in construction of " + "kl_divergence_prf ranker" + << ENDLG; + LOG(fatal) << "Did you mean to call index::make_ranker(global, local) " + "instead of index::make_ranker(local)?" 
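The `feedback::unigram_mixture` call above fits the feedback model with EM. Below is a hedged sketch of that two-component mixture EM in the style of Zhai and Lafferty, where `lambda_bg` is the weight of the fixed background (collection) model; the parameter convention may not match `training_options::lambda` exactly, and the sketch assumes the background model covers every feedback term.

~~~cpp
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <utility>

// Sketch: estimate p(w | feedback topic) from feedback term counts, treating
// each occurrence as drawn from a mixture of the topic model and a fixed
// background model with weight lambda_bg.
std::unordered_map<uint64_t, double>
unigram_mixture_em(const std::unordered_map<uint64_t, double>& fb_counts,
                   const std::unordered_map<uint64_t, double>& background, double lambda_bg,
                   std::size_t iters = 50)
{
    // initialize the topic model proportionally to the raw counts
    double total = 0.0;
    for (const auto& pr : fb_counts)
        total += pr.second;
    std::unordered_map<uint64_t, double> topic;
    for (const auto& pr : fb_counts)
        topic[pr.first] = pr.second / total;

    for (std::size_t it = 0; it < iters; ++it)
    {
        std::unordered_map<uint64_t, double> new_topic;
        double norm = 0.0;
        for (const auto& pr : fb_counts)
        {
            auto w = pr.first;
            // E step: probability that an occurrence of w came from the topic
            auto p_topic = (1.0 - lambda_bg) * topic[w];
            auto p_bg = lambda_bg * background.at(w);
            auto z = p_topic / (p_topic + p_bg);
            // M step numerator: expected count attributed to the topic
            new_topic[w] = pr.second * z;
            norm += new_topic[w];
        }
        for (auto& pr : new_topic)
            pr.second /= norm;
        topic = std::move(new_topic);
    }
    return topic;
}
~~~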
+ << ENDLG; + throw ranker_exception{"empty global configuration provided to " + "construction of kl_divergence_prf ranker"}; + } + + auto alpha = local.get_as("alpha").value_or( + kl_divergence_prf::default_alpha); + auto lambda = local.get_as("lambda").value_or( + kl_divergence_prf::default_lambda); + auto k = local.get_as("k").value_or(kl_divergence_prf::default_k); + auto max_terms = local.get_as("max-terms") + .value_or(kl_divergence_prf::default_max_terms); + auto init_cfg = local.get_table("feedback"); + auto f_idx = make_index(global); + return make_unique(std::move(f_idx), + make_lm_ranker(global, *init_cfg), + alpha, lambda, k, max_terms); +} +} +} diff --git a/src/index/ranker/ranker.cpp b/src/index/ranker/ranker.cpp index 6bff701b8..fae2cedda 100644 --- a/src/index/ranker/ranker.cpp +++ b/src/index/ranker/ranker.cpp @@ -4,7 +4,6 @@ * @author Chase Geigle */ -#include #include "meta/corpus/document.h" #include "meta/index/inverted_index.h" #include "meta/index/postings_data.h" @@ -18,27 +17,26 @@ namespace index { std::vector - ranker::score(inverted_index& idx, const corpus::document& query, - uint64_t num_results /* = 10 */, - const filter_function_type& filter /* return true */) +ranker::score(inverted_index& idx, const corpus::document& query, + uint64_t num_results /* = 10 */, + const filter_function_type& filter /* return true */) { auto counts = idx.tokenize(query); return score(idx, counts.begin(), counts.end(), num_results, filter); } -std::vector ranker::rank(detail::ranker_context& ctx, - uint64_t num_results, - const filter_function_type& filter) +std::vector +ranking_function::rank(ranker_context& ctx, uint64_t num_results, + const filter_function_type& filter) { score_data sd{ctx.idx, ctx.idx.avg_doc_length(), ctx.idx.num_docs(), ctx.idx.total_corpus_terms(), ctx.query_length}; - auto comp = [](const search_result& a, const search_result& b) - { - // comparison is reversed since we want a min-heap - return a.score > b.score; - }; - util::fixed_heap results{num_results, comp}; + auto results = util::make_fixed_heap( + num_results, [](const search_result& a, const search_result& b) { + // comparison is reversed since we want a min-heap + return a.score > b.score; + }); doc_id next_doc{ctx.idx.num_docs()}; while (ctx.cur_doc < ctx.idx.num_docs()) @@ -89,7 +87,7 @@ std::vector ranker::rank(detail::ranker_context& ctx, return results.extract_top(); } -float ranker::initial_score(const score_data&) const +float ranking_function::initial_score(const score_data&) const { return 0.0; } diff --git a/src/index/ranker/ranker_factory.cpp b/src/index/ranker/ranker_factory.cpp index ed2efb817..86c1069af 100644 --- a/src/index/ranker/ranker_factory.cpp +++ b/src/index/ranker/ranker_factory.cpp @@ -15,7 +15,10 @@ namespace index template void ranker_factory::reg() { - add(Ranker::id, make_ranker); + add(Ranker::id, + [](const cpptoml::table& global, const cpptoml::table& local) { + return make_ranker(global, local); + }); } ranker_factory::ranker_factory() @@ -26,15 +29,47 @@ ranker_factory::ranker_factory() reg(); reg(); reg(); + reg(); + reg(); } std::unique_ptr make_ranker(const cpptoml::table& config) { - auto function = config.get_as("method"); + // pass a blank configuration group as the first argument to the + // factory method + static auto blank = cpptoml::make_table(); + return make_ranker(*blank, config); +} + +std::unique_ptr make_ranker(const cpptoml::table& global, + const cpptoml::table& local) +{ + auto function = local.get_as("method"); + if (!function) + throw 
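For clarity, a minimal sketch of the query-model interpolation performed at the end of `kl_divergence_prf::rank` above: the feedback model is mixed with the length-normalized original query using the weight alpha. `term_id` is simplified to `uint64_t` and plain `std::unordered_map` stands in for `hashing::probe_map`.

~~~cpp
#include <cstdint>
#include <unordered_map>

std::unordered_map<uint64_t, float>
interpolate_query(const std::unordered_map<uint64_t, float>& feedback_model,
                  const std::unordered_map<uint64_t, float>& query_weights,
                  float query_length, float alpha)
{
    std::unordered_map<uint64_t, float> new_query;
    for (const auto& pr : feedback_model)
        new_query[pr.first] += alpha * pr.second; // p(w | feedback model)
    for (const auto& pr : query_weights)
        new_query[pr.first] += (1.0f - alpha) * (pr.second / query_length); // p(w | query)
    return new_query;
}
~~~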
ranker_factory::exception{ + "method key required in [ranker] to construct a ranker"}; + + return ranker_factory::get().create(*function, global, local); +} + +std::unique_ptr +make_lm_ranker(const cpptoml::table& config) +{ + // pass a blank configuration group as the first argument to the + // factory method + static auto blank = cpptoml::make_table(); + return make_lm_ranker(*blank, config); +} + +std::unique_ptr +make_lm_ranker(const cpptoml::table& global, const cpptoml::table& local) +{ + auto function = local.get_as("method"); if (!function) throw ranker_factory::exception{ - "ranking-function required to construct a ranker"}; - return ranker_factory::get().create(*function, config); + "method key required in [ranker] to construct a ranker"}; + + return ranker_factory::get().create_lm(*function, global, local); } template @@ -51,6 +86,8 @@ ranker_loader::ranker_loader() reg(); reg(); reg(); + reg(); + reg(); } std::unique_ptr load_ranker(std::istream& in) @@ -59,5 +96,12 @@ std::unique_ptr load_ranker(std::istream& in) io::packed::read(in, method); return ranker_loader::get().create(method, in); } + +std::unique_ptr load_lm_ranker(std::istream& in) +{ + std::string method; + io::packed::read(in, method); + return ranker_loader::get().create_lm(method, in); +} } } diff --git a/src/index/ranker/rocchio.cpp b/src/index/ranker/rocchio.cpp new file mode 100644 index 000000000..2ef916022 --- /dev/null +++ b/src/index/ranker/rocchio.cpp @@ -0,0 +1,159 @@ +/** + * @file rocchio.cpp + * @author Chase Geigle + */ + +#include "cpptoml.h" + +#include "meta/hashing/probe_map.h" +#include "meta/index/forward_index.h" +#include "meta/index/ranker/okapi_bm25.h" +#include "meta/index/ranker/rocchio.h" +#include "meta/index/score_data.h" +#include "meta/io/packed.h" +#include "meta/logging/logger.h" +#include "meta/util/fixed_heap.h" +#include "meta/util/shim.h" + +namespace meta +{ +namespace index +{ + +const util::string_view rocchio::id = "rocchio"; +const constexpr float rocchio::default_alpha; +const constexpr float rocchio::default_beta; +const constexpr uint64_t rocchio::default_k; +const constexpr uint64_t rocchio::default_max_terms; + +rocchio::rocchio(std::shared_ptr fwd) + : fwd_{std::move(fwd)}, + initial_ranker_{make_unique()}, + alpha_{default_alpha}, + beta_{default_beta}, + k_{default_k}, + max_terms_{default_max_terms} +{ + // nothing +} + +rocchio::rocchio(std::shared_ptr fwd, + std::unique_ptr&& initial_ranker, float alpha, + float beta, uint64_t k, uint64_t max_terms) + : fwd_{std::move(fwd)}, + initial_ranker_{std::move(initial_ranker)}, + alpha_{alpha}, + beta_{beta}, + k_{k}, + max_terms_{max_terms} +{ + // nothing +} + +rocchio::rocchio(std::istream& in) + : fwd_{[&]() { + auto path = io::packed::read(in); + auto cfg = cpptoml::parse_file(path + "/config.toml"); + return make_index(*cfg); + }()}, + initial_ranker_{load_ranker(in)}, + alpha_{io::packed::read(in)}, + beta_{io::packed::read(in)}, + k_{io::packed::read(in)}, + max_terms_{io::packed::read(in)} +{ + // nothing +} + +void rocchio::save(std::ostream& out) const +{ + io::packed::write(out, id); + io::packed::write(out, fwd_->index_name()); + initial_ranker_->save(out); + io::packed::write(out, alpha_); + io::packed::write(out, beta_); + io::packed::write(out, k_); + io::packed::write(out, max_terms_); +} + +std::vector rocchio::rank(ranker_context& ctx, + uint64_t num_results, + const filter_function_type& filter) +{ + auto fb_docs = initial_ranker_->rank(ctx, k_, filter); + + // compute the centroid in both 
count-space and tf-idf space + hashing::probe_map term_scores; + hashing::probe_map centroid; + + score_data sd{ctx.idx, ctx.idx.avg_doc_length(), ctx.idx.num_docs(), + ctx.idx.total_corpus_terms(), 1.0f}; + sd.query_term_weight = 1.0f; + for (const auto& sr : fb_docs) + { + sd.d_id = sr.d_id; + sd.doc_size = ctx.idx.doc_size(sd.d_id); + sd.doc_unique_terms = ctx.idx.unique_terms(sd.d_id); + + auto stream = *fwd_->stream_for(sd.d_id); + for (const auto& weight : stream) + { + sd.t_id = weight.first; + sd.doc_count = ctx.idx.doc_freq(sd.t_id); + sd.corpus_term_count = ctx.idx.total_num_occurences(sd.t_id); + sd.doc_term_count = static_cast(weight.second); + + auto& rnk = dynamic_cast(*initial_ranker_); + term_scores[sd.t_id] += rnk.score_one(sd) / k_; + centroid[sd.t_id] += weight.second / k_; + } + } + + // extract the top max_terms_ feedback terms according to their scores + // in tf-idf space + using scored_term = std::pair; + auto heap = util::make_fixed_heap( + max_terms_, [](const scored_term& a, const scored_term& b) { + return a.second > b.second; + }); + for (const auto& pr : term_scores) + { + heap.emplace(pr.key(), pr.value()); + } + + // construct a new interpolated query in count-space from these top terms + hashing::probe_map new_query; + for (const auto& pr : heap.extract_top()) + { + new_query[pr.first] += beta_ * centroid[pr.first]; + } + for (const auto& postings_ctx : ctx.postings) + { + new_query[postings_ctx.t_id] += alpha_ * postings_ctx.query_term_weight; + } + + // construct a new ranker_context from the new query + ranker_context new_ctx{ctx.idx, new_query.begin(), new_query.end(), filter}; + + // return ranking results based on the new query + return initial_ranker_->rank(new_ctx, num_results, filter); +} + +template <> +std::unique_ptr make_ranker(const cpptoml::table& global, + const cpptoml::table& local) +{ + auto alpha = local.get_as("alpha").value_or(rocchio::default_alpha); + auto beta = local.get_as("beta").value_or(rocchio::default_beta); + auto k = local.get_as("k").value_or(rocchio::default_k); + auto max_terms = local.get_as("max-terms") + .value_or(rocchio::default_max_terms); + + auto init_cfg = local.get_table("feedback"); + auto f_idx = make_index(global); + return make_unique(std::move(f_idx), + make_ranker(global, *init_cfg), alpha, beta, + k, max_terms); +} +} +} diff --git a/src/index/tools/interactive_search.cpp b/src/index/tools/interactive_search.cpp index 3ba580bdc..362a97552 100644 --- a/src/index/tools/interactive_search.cpp +++ b/src/index/tools/interactive_search.cpp @@ -44,7 +44,7 @@ int main(int argc, char* argv[]) auto group = config->get_table("ranker"); if (!group) throw std::runtime_error{"\"ranker\" group needed in config file!"}; - auto ranker = index::make_ranker(*group); + auto ranker = index::make_ranker(*config, *group); // Find the path prefix to each document so we can print out the contents. std::string prefix = *config->get_as("prefix") + "/" @@ -66,10 +66,8 @@ int main(int argc, char* argv[]) // Use the ranker to score the query over the index. 
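The core of `rocchio::rank` above is the classic Rocchio update: the new query is alpha times the original query vector plus beta times the centroid of the top-k feedback documents (with only the highest-scoring feedback terms retained). A hedged sketch of that update in count space, again with `term_id` simplified to `uint64_t`:

~~~cpp
#include <cstdint>
#include <unordered_map>

std::unordered_map<uint64_t, float>
rocchio_update(const std::unordered_map<uint64_t, float>& query,
               const std::unordered_map<uint64_t, float>& feedback_centroid,
               float alpha, float beta)
{
    std::unordered_map<uint64_t, float> new_query;
    for (const auto& pr : query)
        new_query[pr.first] += alpha * pr.second; // original query weight
    for (const auto& pr : feedback_centroid)
        new_query[pr.first] += beta * pr.second; // centroid of feedback docs
    return new_query;
}
~~~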
std::vector ranking; - auto time = common::time([&]() - { - ranking = ranker->score(*idx, query, 5); - }); + auto time + = common::time([&]() { ranking = ranker->score(*idx, query, 5); }); std::cout << "Showing top 5 results (" << time.count() << "ms)" << std::endl; @@ -77,13 +75,13 @@ int main(int argc, char* argv[]) uint64_t result_num = 1; for (auto& result : ranking) { - std::string path{idx->doc_path(result.d_id)}; + auto mdata = idx->metadata(result.d_id); + auto path = mdata.get("path").value_or("[none]"); auto output = printing::make_bold(std::to_string(result_num) + ". " + path) + " (score = " + std::to_string(result.score) + ", docid = " + std::to_string(result.d_id) + ")"; std::cout << output << std::endl; - auto mdata = idx->metadata(result.d_id); if (auto content = mdata.get("content")) { auto len diff --git a/src/index/tools/query_runner.cpp b/src/index/tools/query_runner.cpp index ee6358fe6..ec90fc601 100644 --- a/src/index/tools/query_runner.cpp +++ b/src/index/tools/query_runner.cpp @@ -27,12 +27,12 @@ template void print_results(const Index& idx, const SearchResult& result, uint64_t result_num) { - std::string path{idx->doc_path(result.d_id)}; + auto mdata = idx->metadata(result.d_id); + auto path = mdata.template get("path").value_or("[none]"); auto output = printing::make_bold(std::to_string(result_num) + ". " + path) + " (score = " + std::to_string(result.score) + ", docid = " + std::to_string(result.d_id) + ")"; std::cout << output << std::endl; - auto mdata = idx->metadata(result.d_id); if (auto content = mdata.template get("content")) { auto len = std::min(std::string::size_type{77}, content->size()); @@ -85,7 +85,7 @@ int main(int argc, char* argv[]) auto group = config->get_table("ranker"); if (!group) throw std::runtime_error{"\"ranker\" group needed in config"}; - auto ranker = index::make_ranker(*group); + auto ranker = index::make_ranker(*config, *group); // Get the config group with options specific to this executable. auto query_group = config->get_table("query-runner"); @@ -103,10 +103,9 @@ int main(int argc, char* argv[]) // Read the rest of the options for this executable. auto trec_format = query_group->get_as("trec-format").value_or(false); - auto max_results = static_cast( - query_group->get_as("max-results").value_or(10)); - auto q_id = static_cast( - query_group->get_as("query-id-start").value_or(1)); + auto max_results + = query_group->get_as("max-results").value_or(10); + auto q_id = query_group->get_as("query-id-start").value_or(1); // create the IR evaluation scorer if necessary std::unique_ptr eval; diff --git a/src/index/tools/search.cpp b/src/index/tools/search.cpp index 3ab4a037f..5fbb1f3d5 100644 --- a/src/index/tools/search.cpp +++ b/src/index/tools/search.cpp @@ -6,6 +6,7 @@ #include #include #include + #include "meta/analyzers/analyzer.h" #include "meta/caching/all.h" #include "meta/corpus/document.h" @@ -42,59 +43,58 @@ int main(int argc, char* argv[]) auto config = cpptoml::parse_file(argv[1]); auto idx = index::make_index(*config); - // Create a ranking class based on the config file. auto group = config->get_table("ranker"); if (!group) throw std::runtime_error{"\"ranker\" group needed in config file!"}; - auto ranker = index::make_ranker(*group); + auto ranker = index::make_ranker(*config, *group); // Use UTF-8 for the default encoding unless otherwise specified. auto encoding = config->get_as("encoding").value_or("utf-8"); // Time how long it takes to create the index. 
By default, common::time's // unit of measurement is milliseconds. - auto elapsed = common::time( - [&]() + auto elapsed = common::time([&]() { + // Get a std::vector of doc_ids that have been indexed. + auto docs = idx->docs(); + + // Search for up to the first 20 documents; we hope that the first + // result is the original document itself since we're querying with + // documents that are already indexed. + for (size_t i = 0; i < 20 && i < idx->num_docs(); ++i) { - // Get a std::vector of doc_ids that have been indexed. - auto docs = idx->docs(); - - // Search for up to the first 20 documents; we hope that the first - // result is the original document itself since we're querying with - // documents that are already indexed. - for (size_t i = 0; i < 20 && i < idx->num_docs(); ++i) + auto path = idx->metadata(docs[i], "path") + .value_or("[none]"); + // Create a document and specify its path; its content will be + // filled by the analyzer. + corpus::document query{doc_id{docs[i]}}; + query.content(filesystem::file_text(path), encoding); + + std::cout << "Ranking query " << (i + 1) << ": " << path + << std::endl; + + // Use the ranker to score the query over the index. By default, + // the + // ranker returns 10 documents, so we will display the "top 10 + // of + // 10" docs. + auto ranking = ranker->score(*idx, query); + std::cout << "Showing top 10 results." << std::endl; + + uint64_t result_num = 1; + for (auto& result : ranking) { - auto path = idx->doc_path(docs[i]); - // Create a document and specify its path; its content will be - // filled by the analyzer. - corpus::document query{doc_id{docs[i]}}; - query.content(filesystem::file_text(path), encoding); - - std::cout << "Ranking query " << (i + 1) << ": " << path - << std::endl; - - // Use the ranker to score the query over the index. By default, - // the - // ranker returns 10 documents, so we will display the "top 10 - // of - // 10" docs. - auto ranking = ranker->score(*idx, query); - std::cout << "Showing top 10 results." << std::endl; - - uint64_t result_num = 1; - for (auto& result : ranking) - { - std::cout << result_num << ". " - << idx->doc_name(result.d_id) << " " - << result.score << std::endl; - if (result_num++ == 10) - break; - } - - std::cout << std::endl; + std::cout << result_num << ". 
" + << idx->metadata(result.d_id, "name") + .value_or("[none]") + << " " << result.score << std::endl; + if (result_num++ == 10) + break; } - }); + + std::cout << std::endl; + } + }); std::cout << "Elapsed time: " << elapsed.count() / 1000.0 << " seconds" << std::endl; diff --git a/src/io/mmap_file.cpp b/src/io/mmap_file.cpp index 6c2a22830..1f50db5cf 100644 --- a/src/io/mmap_file.cpp +++ b/src/io/mmap_file.cpp @@ -9,8 +9,8 @@ #include "meta/io/mman-win32/mman.h" #endif -#include #include +#include #include #include "meta/io/filesystem.h" @@ -96,5 +96,35 @@ mmap_file::~mmap_file() close(file_descriptor_); } } + +mmap_ifstream::mmap_ifstream(const std::string& filename) + : file_(mmap_file(filename)), pos_{0} +{ + // nothing +} + +bool mmap_ifstream::is_open() const +{ + return static_cast(file_); +} + +int mmap_ifstream::peek() const +{ + if (!is_open() || pos_ >= file_->size()) + return EOF; + return static_cast((*file_)[pos_]); +} + +int mmap_ifstream::get() +{ + if (!is_open() || pos_ >= file_->size()) + return EOF; + return static_cast((*file_)[pos_++]); +} + +void mmap_ifstream::close() +{ + file_ = util::nullopt; +} } } diff --git a/src/lm/diff.cpp b/src/lm/diff.cpp index 1b3936a9e..8d3c1a9c6 100644 --- a/src/lm/diff.cpp +++ b/src/lm/diff.cpp @@ -21,15 +21,15 @@ diff::diff(const cpptoml::table& config) : lm_{config} if (!table) throw diff_exception{"missing [diff] table from config"}; - auto nval = table->get_as("n-value"); + auto nval = table->get_as("n-value"); if (!nval) throw diff_exception{"n-value not specified in config"}; - n_val_ = static_cast(*nval); + n_val_ = *nval; - auto edits = table->get_as("max-edits"); + auto edits = table->get_as("max-edits"); if (!edits) throw diff_exception{"max-edits not specified in config"}; - max_edits_ = static_cast(*edits); + max_edits_ = *edits; auto lambda = table->get_as("lambda"); lambda_ = lambda ? 
*lambda : 0.5; @@ -41,8 +41,7 @@ diff::diff(const cpptoml::table& config) : lm_{config} substitute_penalty_ = table->get_as("substitute-penalty").value_or(0.0); remove_penalty_ = table->get_as("remove-penalty").value_or(0.0); - max_cand_size_ = static_cast( - table->get_as("max-candidates").value_or(20)); + max_cand_size_ = table->get_as("max-candidates").value_or(20); lm_generate_ = table->get_as("lm-generate").value_or(false); set_stems(*table); @@ -54,12 +53,10 @@ diff::candidates(const sentence& sent, bool use_lm /* = false */) { use_lm_ = use_lm; using pair_t = std::pair; - auto comp = [](const pair_t& a, const pair_t& b) - { - return a.second < b.second; - }; - util::fixed_heap candidates{max_cand_size_, comp}; + auto candidates = util::make_fixed_heap( + max_cand_size_, + [](const pair_t& a, const pair_t& b) { return a.second < b.second; }); seen_.clear(); add(candidates, sent); step(sent, candidates, 0); diff --git a/src/lm/language_model.cpp b/src/lm/language_model.cpp index dfb4fb93c..d205530b4 100644 --- a/src/lm/language_model.cpp +++ b/src/lm/language_model.cpp @@ -95,9 +95,9 @@ language_model::top_k(const sentence& prev, size_t k) const { // this is horribly inefficient due to this LM's structure using pair_t = std::pair; - auto comp - = [](const pair_t& a, const pair_t& b) { return a.second > b.second; }; - util::fixed_heap candidates{k, comp}; + auto candidates = util::make_fixed_heap( + k, + [](const pair_t& a, const pair_t& b) { return a.second > b.second; }); token_list candidate{prev, vocabulary_}; candidate.push_back(0_tid); diff --git a/src/lm/static_probe_map.cpp b/src/lm/static_probe_map.cpp index 6b8809db2..b1954358d 100644 --- a/src/lm/static_probe_map.cpp +++ b/src/lm/static_probe_map.cpp @@ -3,8 +3,8 @@ * @author Sean Massung */ -#include "meta/hashing/hash.h" #include "meta/lm/static_probe_map.h" +#include "meta/hashing/hash.h" namespace meta { @@ -61,11 +61,13 @@ util::optional static_probe_map::find_hash(uint64_t hashed) const } } -uint64_t static_probe_map::hash(const std::vector& tokens) const +hashing::murmur_hash<>::result_type +static_probe_map::hash(const std::vector& tokens) const { - hashing::murmur_hash<> hasher{seed_}; + hashing::murmur_hash<> hasher{ + static_cast::result_type>(seed_)}; hash_append(hasher, tokens); - return static_cast(hasher); + return static_cast::result_type>(hasher); } } } diff --git a/src/parser/sr_parser.cpp b/src/parser/sr_parser.cpp index d022f1b5e..1613fced0 100644 --- a/src/parser/sr_parser.cpp +++ b/src/parser/sr_parser.cpp @@ -171,7 +171,7 @@ void sr_parser::train(std::vector& trees, training_options options) start += options.batch_size) { progress(start); - auto end = std::min(start + options.batch_size, + auto end = std::min(start + options.batch_size, data.size()); auto result diff --git a/src/sequence/CMakeLists.txt b/src/sequence/CMakeLists.txt index f803fbeb5..3c439d85e 100644 --- a/src/sequence/CMakeLists.txt +++ b/src/sequence/CMakeLists.txt @@ -2,12 +2,14 @@ project(meta-sequence) add_subdirectory(analyzers) add_subdirectory(crf) +add_subdirectory(hmm) add_subdirectory(tools) add_library(meta-sequence observation.cpp sequence.cpp sequence_analyzer.cpp trellis.cpp + markov_model.cpp io/ptb_parser.cpp) target_link_libraries(meta-sequence meta-io meta-utf) diff --git a/src/sequence/hmm/CMakeLists.txt b/src/sequence/hmm/CMakeLists.txt new file mode 100644 index 000000000..5f16043f9 --- /dev/null +++ b/src/sequence/hmm/CMakeLists.txt @@ -0,0 +1,10 @@ +project(meta-hmm) + +add_subdirectory(tools) + 
+add_library(meta-hmm sequence_observations.cpp) +target_link_libraries(meta-hmm meta-sequence) + +install(TARGETS meta-hmm + EXPORT meta-exports + DESTINATION lib) diff --git a/src/sequence/hmm/sequence_observations.cpp b/src/sequence/hmm/sequence_observations.cpp new file mode 100644 index 000000000..13e235148 --- /dev/null +++ b/src/sequence/hmm/sequence_observations.cpp @@ -0,0 +1,94 @@ +/** + * @file sequence_observations.cpp + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#include "meta/sequence/hmm/sequence_observations.h" + +namespace meta +{ +namespace sequence +{ +namespace hmm +{ + +sequence_observations::expected_counts_type::expected_counts_type( + uint64_t num_hmm_states, uint64_t num_markov_states, + stats::dirichlet prior) +{ + counts_.reserve(num_hmm_states); + for (state_id s_i{0}; s_i < num_hmm_states; ++s_i) + counts_.emplace_back(num_markov_states, prior); +} + +void sequence_observations::expected_counts_type::increment( + const observation_type& seq, state_id s_i, double amount) +{ + counts_[s_i].increment(seq, amount); +} + +auto sequence_observations::expected_counts_type:: +operator+=(const expected_counts_type& other) -> expected_counts_type& +{ + for (state_id s_i{0}; s_i < counts_.size(); ++s_i) + { + counts_[s_i] += other.counts_[s_i]; + } + return *this; +} + +sequence_observations::sequence_observations(uint64_t num_hmm_states, + uint64_t num_markov_states, + stats::dirichlet prior) +{ + models_.reserve(num_hmm_states); + for (uint64_t h = 0; h < num_hmm_states; ++h) + models_.emplace_back(num_markov_states, prior); +} + +sequence_observations::sequence_observations(expected_counts_type&& counts) + : models_{[&]() { + std::vector models; + models.reserve(counts.counts_.size()); + for (auto& ec : counts.counts_) + models.emplace_back(std::move(ec)); + return models; + }()} +{ + // nothing +} + +auto sequence_observations::expected_counts() const -> expected_counts_type +{ + return {num_states(), models_.front().num_states(), + models_.front().prior()}; +} + +uint64_t sequence_observations::num_states() const +{ + return models_.size(); +} + +double sequence_observations::probability(const observation_type& obs, + state_id s_i) const +{ + return models_[s_i].probability(obs); +} + +double sequence_observations::log_probability(const observation_type& obs, + state_id s_i) const +{ + return models_[s_i].log_probability(obs); +} + +const markov_model& sequence_observations::distribution(state_id s_i) const +{ + return models_[s_i]; +} +} +} +} diff --git a/src/sequence/hmm/tools/CMakeLists.txt b/src/sequence/hmm/tools/CMakeLists.txt new file mode 100644 index 000000000..6beb9a124 --- /dev/null +++ b/src/sequence/hmm/tools/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(hmm-train hmm_train.cpp) +target_link_libraries(hmm-train meta-hmm cpptoml) diff --git a/src/sequence/hmm/tools/hmm_train.cpp b/src/sequence/hmm/tools/hmm_train.cpp new file mode 100644 index 000000000..7464fe00b --- /dev/null +++ b/src/sequence/hmm/tools/hmm_train.cpp @@ -0,0 +1,166 @@ +/** + * @file hmm_train.cpp + * @author Chase Geigle + */ + +#include + +#include "cpptoml.h" +#include "meta/hashing/probe_map.h" +#include "meta/io/filesystem.h" +#include "meta/io/gzstream.h" +#include "meta/logging/logger.h" +#include "meta/sequence/hmm/discrete_observations.h" +#include "meta/sequence/hmm/hmm.h" +#include 
"meta/sequence/io/ptb_parser.h" +#include "meta/util/progress.h" + +using namespace meta; + +std::string two_digit(uint8_t num) +{ + std::stringstream ss; + ss << std::setw(2) << std::setfill('0') << static_cast(num); + return ss.str(); +} +/** + * Required config parameters: + * ~~~toml + * prefix = "global-data-prefix" + * + * [hmm] + * prefix = "path-to-model" + * treebank = "penn-treebank" # relative to data prefix + * corpus = "wsj" + * section-size = 99 + * train-sections = [0, 18] + * dev-sections = [19, 21] + * test-sections = [22, 24] + * ~~~ + * + * Optional config parameters: none + */ +int main(int argc, char** argv) +{ + if (argc < 2) + { + std::cerr << "Usage: " << argv[0] << " config.toml" << std::endl; + return 1; + } + + logging::set_cerr_logging(); + + auto config = cpptoml::parse_file(argv[1]); + + auto prefix = config->get_as("prefix"); + if (!prefix) + { + LOG(fatal) << "Global configuration must have a prefix key" << ENDLG; + return 1; + } + + auto seq_grp = config->get_table("hmm"); + if (!seq_grp) + { + LOG(fatal) << "Configuration must contain a [hmm] group" << ENDLG; + return 1; + } + + auto seq_prefix = seq_grp->get_as("prefix"); + if (!seq_prefix) + { + LOG(fatal) << "[hmm] group must contain a prefix to store model files" + << ENDLG; + return 1; + } + + auto treebank = seq_grp->get_as("treebank"); + if (!treebank) + { + LOG(fatal) << "[hmm] group must contain a treebank path" << ENDLG; + return 1; + } + + auto corpus = seq_grp->get_as("corpus"); + if (!corpus) + { + LOG(fatal) << "[hmm] group must contain a corpus" << ENDLG; + return 1; + } + + auto train_sections = seq_grp->get_array("train-sections"); + if (!train_sections) + { + LOG(fatal) << "[hmm] group must contain train-sections" << ENDLG; + return 1; + } + + auto section_size = seq_grp->get_as("section-size"); + if (!section_size) + { + LOG(fatal) << "[hmm] group must contain section-size" << ENDLG; + return 1; + } + + std::string path + = *prefix + "/" + *treebank + "/treebank-2/tagged/" + *corpus; + + hashing::probe_map vocab; + std::vector> training; + { + auto begin = train_sections->at(0)->as()->get(); + auto end = train_sections->at(1)->as()->get(); + printing::progress progress( + " > Reading training data: ", + static_cast((end - begin + 1) * *section_size)); + for (auto i = static_cast(begin); i <= end; ++i) + { + auto folder = two_digit(i); + for (uint8_t j = 0; j <= *section_size; ++j) + { + progress(static_cast(i - begin) * 99 + j); + auto file = *corpus + "_" + folder + two_digit(j) + ".pos"; + auto filename = path + "/" + folder + "/" + file; + auto sequences = sequence::extract_sequences(filename); + for (auto& seq : sequences) + { + std::vector instance; + instance.reserve(seq.size()); + for (const auto& obs : seq) + { + auto it = vocab.find(obs.symbol()); + if (it == vocab.end()) + it = vocab.insert(obs.symbol(), + term_id{vocab.size()}); + instance.push_back(it->value()); + } + training.emplace_back(std::move(instance)); + } + } + } + } + + using namespace sequence; + using namespace hmm; + + std::mt19937 rng{47}; + discrete_observations<> obs_dist{ + 30, vocab.size(), rng, stats::dirichlet{1e-6, vocab.size()}}; + + parallel::thread_pool pool; + hidden_markov_model> hmm{ + 30, rng, std::move(obs_dist), stats::dirichlet{1e-6, 30}}; + + decltype(hmm)::training_options options; + options.delta = 1e-5; + options.max_iters = 50; + hmm.fit(training, pool, options); + + filesystem::make_directories(*seq_prefix); + { + io::gzofstream file{*seq_prefix + "/model.gz"}; + hmm.save(file); + } + + 
return 0; +} diff --git a/src/sequence/markov_model.cpp b/src/sequence/markov_model.cpp new file mode 100644 index 000000000..b227294c3 --- /dev/null +++ b/src/sequence/markov_model.cpp @@ -0,0 +1,148 @@ +/** + * @file markov_model.cpp + * @author Chase Geigle + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#include "meta/sequence/markov_model.h" + +namespace meta +{ +namespace sequence +{ + +markov_model::expected_counts_type::expected_counts_type( + uint64_t num_states, stats::dirichlet prior) + : initial_count_(num_states), + trans_count_{num_states, num_states}, + prior_{std::move(prior)} +{ + // nothing +} + +void markov_model::expected_counts_type::increment( + const std::vector& seq, double amount) +{ + increment_initial(seq[0], amount); + for (uint64_t t = 1; t < seq.size(); ++t) + increment_transition(seq[t - 1], seq[t], amount); +} + +void markov_model::expected_counts_type::increment_initial(state_id s, + double amount) +{ + initial_count_[s] += amount; +} + +void markov_model::expected_counts_type::increment_transition(state_id from, + state_id to, + double amount) +{ + trans_count_(from, to) += amount; +} + +auto markov_model::expected_counts_type:: +operator+=(const expected_counts_type& other) -> expected_counts_type& +{ + std::transform(initial_count_.begin(), initial_count_.end(), + other.initial_count_.begin(), initial_count_.begin(), + [](double mic, double oic) { return mic + oic; }); + + for (state_id s_i{0}; s_i < trans_count_.rows(); ++s_i) + { + std::transform(trans_count_.begin(s_i), trans_count_.end(s_i), + other.trans_count_.begin(s_i), trans_count_.begin(s_i), + [](double mtc, double otc) { return mtc + otc; }); + } + + return *this; +} + +markov_model::markov_model(uint64_t num_states, + stats::dirichlet prior) + : initial_prob_(num_states), + trans_prob_{num_states, num_states}, + prior_{std::move(prior)} +{ + for (state_id s_i{0}; s_i < num_states; ++s_i) + { + initial_prob_[s_i] = (1.0 + prior_.pseudo_counts(s_i)) + / (num_states + prior_.pseudo_counts()); + + for (state_id s_j{0}; s_j < num_states; ++s_j) + { + trans_prob_(s_i, s_j) = (1.0 + prior_.pseudo_counts(s_j)) + / (num_states + prior.pseudo_counts()); + } + } +} + +markov_model::markov_model(expected_counts_type&& counts) + : initial_prob_{std::move(counts.initial_count_)}, + trans_prob_{std::move(counts.trans_count_)}, + prior_{std::move(counts.prior_)} +{ + // normalize probability estimates + auto inorm + = std::accumulate(initial_prob_.begin(), initial_prob_.end(), 0.0); + for (state_id s_i{0}; s_i < num_states(); ++s_i) + { + initial_prob_[s_i] = (initial_prob_[s_i] + prior_.pseudo_counts(s_i)) + / (inorm + prior_.pseudo_counts()); + + auto tnorm = std::accumulate(trans_prob_.begin(s_i), + trans_prob_.end(s_i), 0.0); + for (state_id s_j{0}; s_j < num_states(); ++s_j) + { + trans_prob_(s_i, s_j) + = (trans_prob_(s_i, s_j) + prior_.pseudo_counts(s_i)) + / (tnorm + prior_.pseudo_counts()); + } + } +} + +auto markov_model::expected_counts() const -> expected_counts_type +{ + return {num_states(), prior_}; +} + +const stats::dirichlet& markov_model::prior() const +{ + return prior_; +} + +uint64_t markov_model::num_states() const +{ + return initial_prob_.size(); +} + +double markov_model::log_probability(const std::vector& seq) const +{ + assert(seq.size() > 0); + double log_prob = std::log(initial_prob_[seq[0]]); + for (uint64_t t = 1; t < seq.size(); ++t) + { + 
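Since the hunk above lost its template arguments, it is worth spelling out the counts-then-rebuild pattern that the HMM's M-step relies on: accumulate soft counts into an `expected_counts_type`, then construct a fresh, prior-smoothed `markov_model` from them. A hedged sketch, assuming the prior is `stats::dirichlet<double>` and sequences are `std::vector<state_id>`:

~~~cpp
// Illustrative sketch only; not part of the patch. Assumes a
// stats::dirichlet<double> prior and std::vector<state_id> sequences.
#include <vector>

#include "meta/sequence/markov_model.h"

using namespace meta;
using namespace meta::sequence;

int main()
{
    // a 3-state chain smoothed with a symmetric Dirichlet prior
    markov_model model{3, stats::dirichlet<double>{0.1, 3}};

    std::vector<state_id> seq{state_id{0}, state_id{1}, state_id{1},
                              state_id{2}};

    // E-step style accumulation: soft counts for an observed sequence
    auto counts = model.expected_counts();
    counts.increment(seq, 1.0);

    // M-step style re-estimation: probabilities become the normalized,
    // prior-smoothed expected counts
    markov_model updated{std::move(counts)};

    // the Dirichlet pseudo-counts keep every probability strictly positive
    return (updated.probability(seq) > 0
            && updated.transition_probability(state_id{1}, state_id{2}) > 0)
               ? 0
               : 1;
}
~~~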
log_prob += std::log(trans_prob_(seq[t - 1], seq[t])); + } + return log_prob; +} + +double markov_model::probability(const std::vector& seq) const +{ + return std::exp(log_probability(seq)); +} + +double markov_model::transition_probability(state_id from, state_id to) const +{ + return trans_prob_(from, to); +} + +double markov_model::initial_probability(state_id s) const +{ + return initial_prob_[s]; +} +} +} diff --git a/src/stats/running_stats.cpp b/src/stats/running_stats.cpp index 164965633..0adbfb264 100644 --- a/src/stats/running_stats.cpp +++ b/src/stats/running_stats.cpp @@ -39,5 +39,10 @@ double running_stats::variance() const { return s_k_ / (num_items_ - 1); } + +std::size_t running_stats::size() const +{ + return num_items_; +} } } diff --git a/src/tools/top_k.cpp b/src/tools/top_k.cpp index a9fa28c23..575f12c3c 100644 --- a/src/tools/top_k.cpp +++ b/src/tools/top_k.cpp @@ -3,19 +3,19 @@ * @author Sean Massung */ -#include -#include -#include -#include -#include #include "cpptoml.h" -#include "meta/corpus/corpus.h" -#include "meta/corpus/corpus_factory.h" #include "meta/analyzers/analyzer.h" #include "meta/analyzers/filters/all.h" -#include "meta/util/progress.h" -#include "meta/util/fixed_heap.h" +#include "meta/corpus/corpus.h" +#include "meta/corpus/corpus_factory.h" #include "meta/logging/logger.h" +#include "meta/util/fixed_heap.h" +#include "meta/util/progress.h" +#include +#include +#include +#include +#include using namespace meta; @@ -26,7 +26,8 @@ int main(int argc, char* argv[]) std::cerr << "Usage: " << argv[0] << " config.toml k" << std::endl; std::cerr << "Prints out the top k most frequent terms in the corpus " "according to the filter chain specified in the config " - "file." << std::endl; + "file." + << std::endl; return 1; } @@ -57,11 +58,9 @@ int main(int argc, char* argv[]) prog.end(); using pair_t = std::pair; - auto comp = [](const pair_t& a, const pair_t& b) - { - return a.second > b.second; - }; - util::fixed_heap terms{k, comp}; + auto terms = util::make_fixed_heap( + k, + [](const pair_t& a, const pair_t& b) { return a.second > b.second; }); for (auto& term : counts) terms.emplace(term); diff --git a/src/topics/lda_cvb.cpp b/src/topics/lda_cvb.cpp index 74d42b8a6..69ad39b22 100644 --- a/src/topics/lda_cvb.cpp +++ b/src/topics/lda_cvb.cpp @@ -3,19 +3,19 @@ * @author Chase Geigle */ -#include +#include "meta/topics/lda_cvb.h" #include "meta/index/postings_data.h" #include "meta/logging/logger.h" -#include "meta/topics/lda_cvb.h" #include "meta/util/progress.h" +#include namespace meta { namespace topics { -lda_cvb::lda_cvb(std::shared_ptr idx, uint64_t num_topics, - double alpha, double beta) +lda_cvb::lda_cvb(std::shared_ptr idx, + std::size_t num_topics, double alpha, double beta) : lda_model{std::move(idx), num_topics} { gamma_.resize(idx_->num_docs()); diff --git a/src/topics/lda_gibbs.cpp b/src/topics/lda_gibbs.cpp index 15be38b5e..8e2af5465 100644 --- a/src/topics/lda_gibbs.cpp +++ b/src/topics/lda_gibbs.cpp @@ -17,7 +17,7 @@ namespace topics { lda_gibbs::lda_gibbs(std::shared_ptr idx, - uint64_t num_topics, double alpha, double beta) + std::size_t num_topics, double alpha, double beta) : lda_model{std::move(idx), num_topics} { doc_word_topic_.resize(idx_->num_docs()); diff --git a/src/topics/lda_model.cpp b/src/topics/lda_model.cpp index 39e7a5551..45fe36f11 100644 --- a/src/topics/lda_model.cpp +++ b/src/topics/lda_model.cpp @@ -11,10 +11,10 @@ namespace topics { lda_model::lda_model(std::shared_ptr idx, - uint64_t num_topics) + std::size_t 
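In standard notation, the two computations in `markov_model` above are the chain log-likelihood (`log_probability`) and the Dirichlet-smoothed re-estimation performed by the constructor that consumes expected counts. A sketch of both; the exact per-state indexing of the pseudo-counts is whatever the code above uses:

~~~latex
% Log-likelihood of a state sequence under initial distribution \pi and
% transition matrix A (markov_model::log_probability):
\log P(s_1, \dots, s_T) = \log \pi_{s_1} + \sum_{t=2}^{T} \log A_{s_{t-1}\,s_t}

% Re-estimation from expected counts c with Dirichlet pseudo-counts \alpha
% (the constructor taking expected_counts_type&&):
\pi_i = \frac{c_i + \alpha_i}{\sum_j c_j + \sum_j \alpha_j}
\qquad
A_{ij} = \frac{c_{ij} + \alpha_j}{\sum_k c_{ik} + \sum_k \alpha_k}
~~~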
num_topics) : idx_{std::move(idx)}, num_topics_{num_topics}, - num_words_{idx_->unique_terms()} + num_words_(idx_->unique_terms()) { /* nothing */ } diff --git a/src/topics/lda_scvb.cpp b/src/topics/lda_scvb.cpp index dc81bafd7..719bb265d 100644 --- a/src/topics/lda_scvb.cpp +++ b/src/topics/lda_scvb.cpp @@ -14,7 +14,7 @@ namespace topics { lda_scvb::lda_scvb(std::shared_ptr idx, - uint64_t num_topics, double alpha, double beta, + std::size_t num_topics, double alpha, double beta, uint64_t minibatch_size) : lda_model{std::move(idx), num_topics}, alpha_{alpha}, diff --git a/src/topics/tools/lda.cpp b/src/topics/tools/lda.cpp index f73433e90..f1a5fa293 100644 --- a/src/topics/tools/lda.cpp +++ b/src/topics/tools/lda.cpp @@ -2,10 +2,10 @@ #include #include -#include "meta/topics/lda_gibbs.h" -#include "meta/topics/parallel_lda_gibbs.h" #include "meta/topics/lda_cvb.h" +#include "meta/topics/lda_gibbs.h" #include "meta/topics/lda_scvb.h" +#include "meta/topics/parallel_lda_gibbs.h" #include "cpptoml.h" @@ -16,7 +16,7 @@ using namespace meta; template -int run_lda(Index& idx, uint64_t num_iters, uint64_t topics, double alpha, +int run_lda(Index& idx, uint64_t num_iters, std::size_t topics, double alpha, double beta, const std::string& save_prefix) { Model model{idx, topics, alpha, beta}; @@ -60,11 +60,10 @@ int run_lda(const std::string& config_file) return 1; auto type = *lda_group->get_as("inference"); - auto iters - = static_cast(*lda_group->get_as("max-iters")); + auto iters = *lda_group->get_as("max-iters"); auto alpha = *lda_group->get_as("alpha"); auto beta = *lda_group->get_as("beta"); - auto topics = static_cast(*lda_group->get_as("topics")); + auto topics = *lda_group->get_as("topics"); auto save_prefix = *lda_group->get_as("model-prefix"); auto f_idx diff --git a/src/topics/tools/lda_topics.cpp b/src/topics/tools/lda_topics.cpp index 302b83fa7..aaefffeae 100644 --- a/src/topics/tools/lda_topics.cpp +++ b/src/topics/tools/lda_topics.cpp @@ -3,8 +3,8 @@ * @author Chase Geigle */ -#include #include +#include #include #include @@ -44,13 +44,11 @@ int print_topics(const std::string& config_file, const std::string& filename, std::cout << "Topic " << topic << ":" << std::endl; std::cout << "-----------------------" << std::endl; - auto comp = [](const std::pair& first, - const std::pair& second) - { - return first.second > second.second; - }; - util::fixed_heap, decltype(comp)> pairs{ - num_words, comp}; + using scored_term = std::pair; + auto pairs = util::make_fixed_heap( + num_words, [](const scored_term& a, const scored_term& b) { + return a.second > b.second; + }); while (stream) { diff --git a/src/utf/CMakeLists.txt b/src/utf/CMakeLists.txt index 6191d8f04..5935c7dc5 100644 --- a/src/utf/CMakeLists.txt +++ b/src/utf/CMakeLists.txt @@ -2,7 +2,11 @@ project(meta-utf) add_subdirectory(tools) -add_library(meta-utf segmenter.cpp transformer.cpp utf.cpp) +if (META_STATIC_UTF) + add_library(meta-utf STATIC segmenter.cpp transformer.cpp utf.cpp) +else() + add_library(meta-utf SHARED segmenter.cpp transformer.cpp utf.cpp) +endif() target_link_libraries(meta-utf PUBLIC meta-definitions) target_link_libraries(meta-utf PRIVATE ${ICU_LIBRARIES}) target_include_directories(meta-utf PRIVATE SYSTEM ${ICU_INCLUDE_DIRS}) diff --git a/src/util/progress.cpp b/src/util/progress.cpp index c6aeeb0d5..bc577a13e 100644 --- a/src/util/progress.cpp +++ b/src/util/progress.cpp @@ -56,6 +56,8 @@ void progress::print() auto end = it + static_cast(max_len * percent); std::fill(it, end, '='); *end = '>'; + if (end < 
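Both `top_k.cpp` and `lda_topics.cpp` now build their bounded heaps with `util::make_fixed_heap`, letting the comparator type be deduced from the lambda instead of being spelled out. A minimal sketch of the idiom; only `emplace` is visible in these hunks, so the `extract_top()` accessor used below is an assumption about the `fixed_heap` interface:

~~~cpp
// Sketch of the make_fixed_heap idiom; not part of the patch. extract_top()
// (assumed to return the retained elements best-first) does not appear in the
// hunks above and is an assumption here.
#include <cstdint>
#include <string>
#include <utility>

#include "meta/util/fixed_heap.h"

using namespace meta;

int main()
{
    using pair_t = std::pair<std::string, uint64_t>;

    // keep only the 2 largest counts; the comparator orders "better" first
    auto terms = util::make_fixed_heap<pair_t>(
        2, [](const pair_t& a, const pair_t& b) { return a.second > b.second; });

    terms.emplace("the", 120);
    terms.emplace("cat", 7);
    terms.emplace("mat", 3); // evicted: only the top 2 survive

    auto top = terms.extract_top();
    return (top.size() == 2 && top[0].first == "the") ? 0 : 1;
}
~~~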
barend) + std::fill(end + 1, barend, ' '); it = barend; *it++ = ']'; *it++ = ' '; @@ -101,7 +103,7 @@ void progress::end() } } -void progress::clear() const +void progress::clear() { LOG(progress) << '\r' << std::string(80, ' ') << '\r' << ENDLG; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c340954c0..94db25712 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -25,6 +25,15 @@ ExternalProject_Add(housing BUILD_COMMAND "" INSTALL_COMMAND "") +ExternalProject_Add(cranfield + SOURCE_DIR ${meta_BINARY_DIR}/../../data/cranfield + DOWNLOAD_DIR ${meta_BINARY_DIR}/../downloads + URL https://meta-toolkit.org/data/2016-11-10/cranfield.tar.gz + URL_HASH "SHA256=507b6f4f133bc1a65d140780cbd7060a3ca159410b772e5eb1e2c12b215d72b4" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "") + # Ignore sign warnings when expanding bandit's macros. file(GLOB BANDIT_SOURCE_FILES *.cpp) set_property(SOURCE ${BANDIT_SOURCE_FILES} APPEND PROPERTY COMPILE_FLAGS diff --git a/tests/dataset_transform_test.cpp b/tests/dataset_transform_test.cpp new file mode 100644 index 000000000..14ad3f1f8 --- /dev/null +++ b/tests/dataset_transform_test.cpp @@ -0,0 +1,76 @@ +/** + * @file dataset_transform_test.cpp + * @author Chase Geigle + */ + +#include "bandit/bandit.h" +#include "create_config.h" +#include "meta/classify/multiclass_dataset.h" +#include "meta/index/ranker/okapi_bm25.h" +#include "meta/learn/transform.h" + +using namespace bandit; +using namespace meta; + +go_bandit([]() { + describe("[learn] dataset l2 transformer", []() { + it("should normalize feature vectors to unit length", []() { + + std::vector vectors(2); + + vectors[0].emplace_back(0_tid, 12); + vectors[0].emplace_back(1_tid, 10); + vectors[0].emplace_back(2_tid, 5); + + vectors[1].emplace_back(1_tid, 1); + vectors[1].emplace_back(3_tid, 4); + vectors[1].emplace_back(5_tid, 9); + + learn::dataset dset{vectors.begin(), vectors.end(), 6}; + learn::l2norm_transform(dset); + + for (const auto& inst : dset) + { + auto norm = std::sqrt(std::accumulate( + inst.weights.begin(), inst.weights.end(), 0.0, + [](double accum, const std::pair& val) { + return accum + val.second * val.second; + })); + AssertThat(norm, EqualsWithDelta(1, 1e-12)); + } + }); + }); + + describe("[learn] dataset tf-idf transformer", []() { + it("should produce tf-idf vectors", []() { + auto config = tests::create_config("line"); + config->insert("uninvert", true); + filesystem::remove_all("ceeaus"); + + // make both indexes + auto inv = index::make_index(*config); + auto fwd = index::make_index(*config); + + // convert the data into a dataset + classify::multiclass_dataset dset{fwd}; + + // make tf-idf vectors + index::okapi_bm25 ranker; + learn::tfidf_transform(dset, *inv, ranker); + + // check that we get the same scores for a particular word + std::vector> query + = {{"charact", 1.0}}; + + auto ranking = ranker.score(*inv, query.begin(), query.end()); + + auto tid = inv->get_term_id("charact"); + for (const auto& result : ranking) + { + const auto& weights = dset(result.d_id).weights; + AssertThat(weights.at(tid), + EqualsWithDelta(result.score, 1e-5)); + } + }); + }); +}); diff --git a/tests/farm_hash_test.h b/tests/farm_hash_test.h index c1ce41ec5..972aa5b54 100644 --- a/tests/farm_hash_test.h +++ b/tests/farm_hash_test.h @@ -478,6 +478,7 @@ bool test(uint8_t data[], int offset, int len = 0) { using meta::hashing::farm_hash; using meta::hashing::farm_hash_seeded; + using result_type = farm_hash::result_type; static int index = 0; auto check = 
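The `progress::print` hunk above blanks the slack between the arrow head and the end of the bar, so redrawing into the shared line buffer can never leave characters from a previous, wider draw behind. A standalone sketch of the same rendering pattern (hypothetical buffer layout and function name, not the actual `progress` members):

~~~cpp
// Standalone illustration of the redraw fix; not MeTA's progress class.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>

// Render "[=====>    ]" into a reusable buffer. Because the same buffer is
// rewritten on every tick, everything after the '>' must be blanked
// explicitly, or characters from an earlier draw would survive.
void render_bar(std::string& buffer, std::size_t width, double percent)
{
    buffer.resize(width + 2);
    auto it = buffer.begin();
    *it++ = '[';

    auto barend = it + static_cast<std::ptrdiff_t>(width);
    auto end = it + static_cast<std::ptrdiff_t>(width * percent);
    std::fill(it, end, '=');
    if (end != barend)
    {
        *end = '>';
        std::fill(end + 1, barend, ' '); // the fix: clear the remainder
    }
    it = barend;
    *it++ = ']';
}

int main()
{
    std::string buf;
    render_bar(buf, 10, 0.8); // "[========> ]"
    render_bar(buf, 10, 0.3); // without the fill, stale '='s would linger
    assert(buf == "[===>      ]");
    return 0;
}
~~~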
[&](uint32_t actual) @@ -493,21 +494,21 @@ bool test(uint8_t data[], int offset, int len = 0) farm_hash_seeded hasher{create_seed(offset, 0), create_seed(offset, 1)}; hasher(data, static_cast(len++)); - uint64_t h = static_cast(hasher); + auto h = static_cast(hasher); alive += (h >> 32) > 0; alive += ((h << 32) >> 32) > 0; } { farm_hash_seeded hasher{create_seed(offset, -1)}; hasher(data, static_cast(len++)); - uint64_t h = static_cast(hasher); + auto h = static_cast(hasher); alive += (h >> 32) > 0; alive += ((h << 32) >> 32) > 0; } { farm_hash hasher; hasher(data, static_cast(len++)); - uint64_t h = static_cast(hasher); + auto h = static_cast(hasher); alive += (h >> 32) > 0; alive += ((h << 32) >> 32) > 0; } @@ -517,21 +518,21 @@ bool test(uint8_t data[], int offset, int len = 0) { farm_hash_seeded hasher{create_seed(offset, 0), create_seed(offset, 1)}; hasher(data + offset, static_cast(len)); - uint64_t h = static_cast(hasher); + auto h = static_cast(hasher); check(h >> 32); check((h << 32) >> 32); } { farm_hash_seeded hasher{create_seed(offset, -1)}; hasher(data + offset, static_cast(len)); - uint64_t h = static_cast(hasher); + auto h = static_cast(hasher); check(h >> 32); check((h << 32) >> 32); } { farm_hash hasher; hasher(data + offset, static_cast(len)); - uint64_t h = static_cast(hasher); + auto h = static_cast(hasher); check(h >> 32); check((h << 32) >> 32); } diff --git a/tests/hashing_test.cpp b/tests/hashing_test.cpp index 0e52225e6..b8750bfcd 100644 --- a/tests/hashing_test.cpp +++ b/tests/hashing_test.cpp @@ -4,11 +4,11 @@ */ #include -#include #include +#include #include -#include #include +#include #include #include "bandit/bandit.h" @@ -28,9 +28,9 @@ namespace { * Checks that a probing strategy probes each element in a range exactly once. 
*/ template -void check_range_at(uint64_t hash, uint64_t size) { - std::vector checker(size, 0); - const std::vector gold(size, 1); +void check_range_at(std::size_t hash, std::size_t size) { + std::vector checker(size, 0); + const std::vector gold(size, 1); Strategy strat{hash, size}; for (uint64_t i = 0; i < checker.size(); ++i) ++checker[strat.probe()]; @@ -40,8 +40,8 @@ void check_range_at(uint64_t hash, uint64_t size) { template void check_range() { - std::vector sizes = {2, 4, 8, 32, 64}; - std::vector weird_sizes = {3, 5, 7, 22, 100, 125}; + std::vector sizes = {2, 4, 8, 32, 64}; + std::vector weird_sizes = {3, 5, 7, 22, 100, 125}; if (!std::is_same::value) sizes.insert(sizes.end(), weird_sizes.begin(), weird_sizes.end()); @@ -110,20 +110,24 @@ void count(Map& map, const std::vector& tokens) { } template -void check_hash(uint64_t seed, util::string_view key, uint64_t expected) { +void check_hash(typename HashAlgorithm::result_type seed, util::string_view key, + typename HashAlgorithm::result_type expected) { HashAlgorithm hash{seed}; hash(key.data(), key.size()); - AssertThat(static_cast(hash), Equals(expected)); + AssertThat(static_cast(hash), + Equals(expected)); } template -void check_incremental_hash(uint64_t seed, util::string_view key, - uint64_t expected) { +void check_incremental_hash(typename HashAlgorithm::result_type seed, + util::string_view key, + typename HashAlgorithm::result_type expected) { HashAlgorithm hash{seed}; hash(key.data(), key.size() / 2); hash(key.data() + key.size() / 2, key.size() - key.size() / 2 - 1); hash(key.data() + key.size() - 1, 1); - AssertThat(static_cast(hash), Equals(expected)); + AssertThat(static_cast(hash), + Equals(expected)); } } @@ -252,9 +256,8 @@ go_bandit([]() { }); describe("[hashing] farm_hash x64", []() { - it("should match test vectors from FarmHash", []() { - farm_hash_self_test(); - }); + it("should match test vectors from FarmHash", + []() { farm_hash_self_test(); }); }); describe("[hashing] ints", []() { diff --git a/tests/ir_eval_test.cpp b/tests/ir_eval_test.cpp index 3d2c55cef..2de9f4bb8 100644 --- a/tests/ir_eval_test.cpp +++ b/tests/ir_eval_test.cpp @@ -47,7 +47,7 @@ go_bandit([]() { index::ir_eval eval{*file_cfg}; // sanity test bounds for (size_t i = 0; i < 5; ++i) { - auto path = idx->doc_path(doc_id{i}); + auto path = *idx->metadata(doc_id{i}, "path"); corpus::document query{doc_id{0}}; query.content(filesystem::file_text(path)); diff --git a/tests/language_model_test.cpp b/tests/language_model_test.cpp index e8a87cdc9..64ccd1660 100644 --- a/tests/language_model_test.cpp +++ b/tests/language_model_test.cpp @@ -28,7 +28,7 @@ void run_test(const cpptoml::table& line_cfg) { AssertThat(s4.size(), Equals(5ul)); // log_prob values calculated with KenLM - const double delta = 0.0000001; + const double delta = 1e-5; AssertThat(model.log_prob(s1), EqualsWithDelta(-5.0682507, delta)); AssertThat(model.log_prob(s2), EqualsWithDelta(-11.7275571, delta)); AssertThat(model.log_prob(s3), EqualsWithDelta(-11.07649517, delta)); diff --git a/tests/ranker_regression_test.cpp b/tests/ranker_regression_test.cpp new file mode 100644 index 000000000..ae58c173f --- /dev/null +++ b/tests/ranker_regression_test.cpp @@ -0,0 +1,183 @@ +/** + * @file ranker_regression_test.cpp + * @author Chase Geigle + */ + +#include "bandit/bandit.h" +#include "create_config.h" +#include "meta/corpus/document.h" +#include "meta/index/eval/ir_eval.h" +#include "meta/index/forward_index.h" +#include "meta/index/ranker/all.h" + +using namespace bandit; +using 
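`check_range_at` asserts that a probing strategy visits every slot of a table exactly once, i.e. that its probe sequence is a permutation of `[0, size)`. A minimal, self-contained version of that property check; `linear_probe` below is a stand-in for the `Strategy` template parameter the test is instantiated with, not one of MeTA's probing classes:

~~~cpp
// Standalone sketch of the "probes each slot exactly once" property.
#include <cstdint>
#include <vector>

// stand-in strategy: start at hash % size and walk the table with wrap-around
struct linear_probe
{
    linear_probe(std::uint64_t hash, std::uint64_t size)
        : idx_{hash % size}, size_{size}
    {
    }

    std::uint64_t probe()
    {
        auto result = idx_;
        idx_ = (idx_ + 1) % size_;
        return result;
    }

    std::uint64_t idx_;
    std::uint64_t size_;
};

// true if the first `size` probes hit every slot exactly once
bool probes_every_slot_once(std::uint64_t hash, std::uint64_t size)
{
    std::vector<int> hits(size, 0);
    linear_probe strat{hash, size};
    for (std::uint64_t i = 0; i < size; ++i)
        ++hits[strat.probe()];

    for (int h : hits)
        if (h != 1)
            return false;
    return true;
}

int main()
{
    for (std::uint64_t size : {2, 3, 8, 100, 125})
        if (!probes_every_slot_once(12345, size))
            return 1;
    return 0;
}
~~~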
namespace meta; + +namespace { +struct ret_perf { + double map; + double avg_ndcg; +}; + +ret_perf retrieval_performance(index::ranker& r, index::inverted_index& idx, + const cpptoml::table& cfg) { + index::ir_eval eval{cfg}; + + std::ifstream queries{*cfg.get_as("query-path")}; + std::string line; + + double cumulative_ndcg = 0.0; + uint64_t num_queries = 0; + for (query_id qid{1}; std::getline(queries, line); ++qid, ++num_queries) { + corpus::document query; + query.content(line); + auto results = r.score(idx, query, 1); + eval.avg_p(results, qid, results.size()); + cumulative_ndcg += eval.ndcg(results, qid, results.size()); + } + + ret_perf perf; + perf.map = eval.map(); + perf.avg_ndcg = cumulative_ndcg / num_queries; + return perf; +} +} + +go_bandit([]() { + + describe("[ranker regression]", []() { + auto cfg = tests::create_config("line"); + cfg->insert("dataset", "cranfield"); + cfg->insert("query-judgements", + "../data/cranfield/cranfield-qrels.txt"); + cfg->insert("index", "cranfield-idx"); + cfg->insert("query-path", "../data/cranfield/cranfield-queries.txt"); + + auto anas = cfg->get_table_array("analyzers"); + auto ana = anas->get()[0]; + ana->insert("filter", "default-unigram-chain"); + + filesystem::remove_all("cranfield-idx"); + auto idx = index::make_index(*cfg); + + it("should obtain expected performance with absolute discounting", + [&]() { + index::absolute_discount r; + auto perf = retrieval_performance(r, *idx, *cfg); + AssertThat(perf.map, IsGreaterThan(0.34)); + AssertThat(perf.avg_ndcg, IsGreaterThan(0.22)); + }); + + it("should obtain expected performance with Dirichlet prior", [&]() { + index::dirichlet_prior r; + auto perf = retrieval_performance(r, *idx, *cfg); + AssertThat(perf.map, IsGreaterThan(0.30)); + AssertThat(perf.avg_ndcg, IsGreaterThan(0.21)); + }); + + it("should obtain expected performance with Jelinek-Mercer", [&]() { + index::jelinek_mercer r; + auto perf = retrieval_performance(r, *idx, *cfg); + AssertThat(perf.map, IsGreaterThan(0.34)); + AssertThat(perf.avg_ndcg, IsGreaterThan(0.23)); + }); + + it("should obtain expected performance with Okapi BM25", [&]() { + index::okapi_bm25 r; + auto perf = retrieval_performance(r, *idx, *cfg); + AssertThat(perf.map, IsGreaterThan(0.33)); + AssertThat(perf.avg_ndcg, IsGreaterThan(0.22)); + }); + + it("should obtain expected performance with pivoted length", [&]() { + index::pivoted_length r; + auto perf = retrieval_performance(r, *idx, *cfg); + AssertThat(perf.map, IsGreaterThan(0.32)); + AssertThat(perf.avg_ndcg, IsGreaterThan(0.21)); + }); + + it("should obtain expected performance with KL-divergence PRF", [&]() { + index::kl_divergence_prf r{ + index::make_index(*cfg)}; + auto perf = retrieval_performance(r, *idx, *cfg); + AssertThat(perf.map, IsGreaterThan(0.33)); + AssertThat(perf.avg_ndcg, IsGreaterThan(0.22)); + }); + + it("should obtain expected performance with Rocchio", [&]() { + index::rocchio r{index::make_index(*cfg)}; + auto perf = retrieval_performance(r, *idx, *cfg); + AssertThat(perf.map, IsGreaterThan(0.34)); + AssertThat(perf.avg_ndcg, IsGreaterThan(0.23)); + }); + + it("should get better performance than Dirichlet prior when using " + "KL-divergence PRF", + [&]() { + index::kl_divergence_prf kl_div{ + index::make_index(*cfg)}; + auto kl_perf = retrieval_performance(kl_div, *idx, *cfg); + + index::dirichlet_prior dp; + auto dp_perf = retrieval_performance(dp, *idx, *cfg); + + AssertThat(kl_perf.map, IsGreaterThanOrEqualTo(dp_perf.map)); + AssertThat(kl_perf.avg_ndcg, + 
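The thresholds in this regression test are on mean average precision and an averaged nDCG, both computed by `index::ir_eval` over the Cranfield queries. For reference, the textbook definitions these metrics follow (a sketch; the exact cutoff and gain/discount variant is whatever `ir_eval` implements):

~~~latex
% Average precision for one query with relevant set R and ranked list d_1..d_n,
% and its mean over the query set Q:
\mathrm{AP} = \frac{1}{|R|} \sum_{k=1}^{n} P@k \cdot \mathbf{1}[d_k \in R]
\qquad
\mathrm{MAP} = \frac{1}{|Q|} \sum_{q \in Q} \mathrm{AP}(q)

% Discounted cumulative gain with graded relevance rel_k, normalized by the
% DCG of the ideal ordering:
\mathrm{DCG}@n = \sum_{k=1}^{n} \frac{2^{rel_k} - 1}{\log_2(k + 1)}
\qquad
\mathrm{nDCG}@n = \frac{\mathrm{DCG}@n}{\mathrm{IDCG}@n}
~~~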
IsGreaterThanOrEqualTo(dp_perf.avg_ndcg)); + }); + + it("should get better performance than Jelinek-Mercer when using " + "KL-divergence PRF", + [&]() { + index::kl_divergence_prf kl_div{ + index::make_index(*cfg), + make_unique()}; + auto kl_perf = retrieval_performance(kl_div, *idx, *cfg); + + index::jelinek_mercer jm; + auto jm_perf = retrieval_performance(jm, *idx, *cfg); + + AssertThat(kl_perf.map, IsGreaterThanOrEqualTo(jm_perf.map)); + AssertThat(kl_perf.avg_ndcg, + IsGreaterThanOrEqualTo(jm_perf.avg_ndcg)); + }); + + it("should get better performance than Okapi BM25 when using Rocchio", + [&]() { + index::rocchio rocchio{ + index::make_index(*cfg), + make_unique()}; + + auto rocchio_perf = retrieval_performance(rocchio, *idx, *cfg); + + index::okapi_bm25 bm25; + auto bm25_perf = retrieval_performance(bm25, *idx, *cfg); + + AssertThat(rocchio_perf.map, + IsGreaterThanOrEqualTo(bm25_perf.map)); + AssertThat(rocchio_perf.avg_ndcg, + IsGreaterThanOrEqualTo(bm25_perf.avg_ndcg)); + }); + + it("should get better performance than pivoted length when using " + "Rocchio", + [&]() { + index::rocchio rocchio{ + index::make_index(*cfg), + make_unique()}; + + auto rocchio_perf = retrieval_performance(rocchio, *idx, *cfg); + + index::pivoted_length pl; + auto pl_perf = retrieval_performance(pl, *idx, *cfg); + + AssertThat(rocchio_perf.map, + IsGreaterThanOrEqualTo(pl_perf.map)); + AssertThat(rocchio_perf.avg_ndcg, + IsGreaterThanOrEqualTo(pl_perf.avg_ndcg)); + }); + + idx = nullptr; + filesystem::remove_all("cranfield-idx"); + }); +}); diff --git a/tests/ranker_test.cpp b/tests/ranker_test.cpp index 1ac39511e..579c8be40 100644 --- a/tests/ranker_test.cpp +++ b/tests/ranker_test.cpp @@ -7,18 +7,22 @@ #include "create_config.h" #include "meta/corpus/document.h" #include "meta/index/ranker/all.h" +#include "meta/index/forward_index.h" using namespace bandit; using namespace meta; -namespace { +namespace +{ template -void test_rank(Ranker& r, Index& idx, const std::string& encoding) { +void test_rank(Ranker& r, Index& idx, const std::string& encoding) +{ // exhaustive search for each document - for (size_t i = 0; i < idx.num_docs(); ++i) { + for (size_t i = 0; i < idx.num_docs(); ++i) + { auto d_id = idx.docs()[i]; - auto path = idx.doc_path(d_id); + auto path = *idx.template metadata(d_id, "path"); corpus::document query{doc_id{i}}; query.content(filesystem::file_text(path), encoding); @@ -28,7 +32,8 @@ void test_rank(Ranker& r, Index& idx, const std::string& encoding) { // since we're searching for a document already in the index, the same // document should be ranked first, but there are a few duplicate // documents...... 
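The "should get better performance than ..." cases encode the expectation that pseudo-relevance feedback should not hurt: Rocchio layered on a vector-space ranker and KL-divergence PRF layered on a language-model ranker are each required to match or beat the underlying ranker alone. Conceptually, the Rocchio ranker moves the query toward the centroid of the top-ranked documents (standard form shown below; MeTA's parameter names and defaults may differ, and the negative-feedback term is dropped because pseudo-relevance feedback has no judged non-relevant documents):

~~~latex
% Rocchio pseudo-relevance feedback: interpolate the original query vector
% with the centroid of the top-k retrieved (pseudo-relevant) documents D_r.
\vec{q}\,' = \alpha\,\vec{q} \;+\; \beta\,\frac{1}{|D_r|} \sum_{\vec{d} \in D_r} \vec{d}
~~~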
- if (ranking[0].d_id != i) { + if (ranking[0].d_id != i) + { AssertThat(ranking[1].d_id, Equals(i)); AssertThat(ranking[0].score, EqualsWithDelta(ranking[1].score, 0.0001)); @@ -44,7 +49,8 @@ void test_rank(Ranker& r, Index& idx, const std::string& encoding) { AssertThat(ranking[0].score, Is().GreaterThan(ranking.back().score)); // check for sorted-ness of ranking - for (uint64_t i = 1; i < ranking.size(); ++i) { + for (uint64_t i = 1; i < ranking.size(); ++i) + { AssertThat(ranking[i - 1].score, Is().GreaterThanOrEqualTo(ranking[i].score)); } @@ -87,6 +93,14 @@ go_bandit([]() { test_rank(r, *idx, encoding); }); + it("should be able to rank with KL-divergence pseudo-relevance " + "feedback", + [&]() { + index::kl_divergence_prf r{ + index::make_index(*config)}; + test_rank(r, *idx, encoding); + }); + idx = nullptr; filesystem::remove_all("ceeaus"); }); diff --git a/tests/tokenizer_filter_test.cpp b/tests/tokenizer_filter_test.cpp index 92323aab7..fa2c948bd 100644 --- a/tests/tokenizer_filter_test.cpp +++ b/tests/tokenizer_filter_test.cpp @@ -6,22 +6,24 @@ #include -#include "meta/analyzers/tokenizers/whitespace_tokenizer.h" -#include "meta/analyzers/tokenizers/icu_tokenizer.h" -#include "meta/analyzers/tokenizers/character_tokenizer.h" -#include "meta/analyzers/filters/all.h" #include "bandit/bandit.h" -#include "meta/corpus/document.h" #include "create_config.h" +#include "meta/analyzers/filters/all.h" +#include "meta/analyzers/tokenizers/character_tokenizer.h" +#include "meta/analyzers/tokenizers/icu_tokenizer.h" +#include "meta/analyzers/tokenizers/whitespace_tokenizer.h" +#include "meta/corpus/document.h" #include "meta/util/shim.h" using namespace bandit; using namespace meta; -namespace { +namespace +{ void check_expected(analyzers::token_stream& filter, - std::vector& expected) { + std::vector& expected) +{ AssertThat(static_cast(filter), IsTrue()); for (const auto& s : expected) AssertThat(filter.next(), Equals(s)); @@ -59,8 +61,8 @@ go_bandit([]() { it("should work on easy sentences", [&]() { norm->set_content("\"This \t\n\f\ris a quote,'' said Dr. Smith."); std::vector expected - = {"``", "This", " ", "is", " ", "a", " ", "quote", ",", - "''", " ", "said", " ", "Dr", ".", " ", "Smith", "."}; + = {"``", "This", "is", "a", "quote", ",", + "''", "said", "Dr", ".", "Smith", "."}; check_expected(*norm, expected); }); @@ -69,11 +71,9 @@ go_bandit([]() { "What about when we don't want to knee-jerk? 
We'll " "have to do something."); std::vector expected - = {"What", " ", "about", " ", "when", " ", - "we", " ", "don", "'t", " ", "want", - " ", "to", " ", "knee-jerk", "?", " ", - "We", "'ll", " ", "have", " ", "to", - " ", "do", " ", "something", "."}; + = {"What", "about", "when", "we", "don", "'t", + "want", "to", "knee-jerk", "?", "We", "'ll", + "have", "to", "do", "something", "."}; check_expected(*norm, expected); }); }); @@ -85,7 +85,7 @@ go_bandit([]() { auto norm = make_unique(std::move(tok), "Katakana-Latin"); norm->set_content("キャンパス ハロ"); - std::vector expected = {"kyanpasu", " ", "haro"}; + std::vector expected = {"kyanpasu", "haro"}; check_expected(*norm, expected); }); @@ -95,8 +95,7 @@ go_bandit([]() { "Greek-Latin"); norm->set_content("τί φῄς γραφὴν σέ τις ὡς ἔοικε"); std::vector expected - = {"tí", " ", "phḗis", " ", "graphḕn", " ", "sé", - " ", "tis", " ", "hōs", " ", "éoike"}; + = {"tí", "phḗis", "graphḕn", "sé", "tis", "hōs", "éoike"}; check_expected(*norm, expected); }); @@ -148,8 +147,7 @@ go_bandit([]() { filters::list_filter::type::REJECT); norm->set_content("supposedly i am the octopus of the big house"); std::vector expected - = {"supposedly", " ", " ", " ", " ", "octopus", - " ", " ", " ", "big", " ", "house"}; + = {"supposedly", "octopus", "big", "house"}; check_expected(*norm, expected); }); }); @@ -161,8 +159,7 @@ go_bandit([]() { auto norm = make_unique(std::move(tok)); norm->set_content("A\tweIrd Punctuation casE IS HERE!"); std::vector expected - = {"a", "\t", "weird", " ", "punctuation", " ", - "case", " ", "is", " ", "here!"}; + = {"a", "weird", "punctuation", "case", "is", "here!"}; check_expected(*norm, expected); }); }); @@ -177,9 +174,9 @@ go_bandit([]() { // note that the comma on retrieval prevents the word // form being // stemmed - std::vector expected = { - "In", " ", "linguist", " ", "morpholog", " ", "and", " ", - "inform", " ", "retrieval,", " ", "stem"}; + std::vector expected + = {"In", "linguist", "morpholog", "and", + "inform", "retrieval,", "stem"}; check_expected(*norm, expected); }); }); @@ -207,7 +204,7 @@ go_bandit([]() { describe("[tokenizer-filter] sentence_boundary", [&]() { std::unique_ptr stream; - stream = make_unique(); + stream = make_unique(false); stream = make_unique(std::move(stream)); stream = make_unique(std::move(stream)); @@ -227,7 +224,7 @@ go_bandit([]() { auto stopwords_file = *config->get_as("stop-words"); std::unique_ptr stream; - stream = make_unique(); + stream = make_unique(false); stream = make_unique(std::move(stream)); stream = make_unique(std::move(stream)); stream = make_unique(std::move(stream)); diff --git a/travis/install_libcxx.sh b/travis/install_libcxx.sh index 779035b20..b1dbf288c 100755 --- a/travis/install_libcxx.sh +++ b/travis/install_libcxx.sh @@ -13,7 +13,8 @@ cd ../ mkdir build cd build -cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=$HOME ../ +cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=$HOME \ + $LIBCXX_EXTRA_CMAKE_FLAGS ../ make cxx make install-libcxx install-libcxxabi diff --git a/travis/install_linux.sh b/travis/install_linux.sh index 3352774f7..de9631481 100755 --- a/travis/install_linux.sh +++ b/travis/install_linux.sh @@ -5,8 +5,13 @@ mkdir $HOME/bin export PATH=$HOME/bin:$PATH mkdir $HOME/include export CPLUS_INCLUDE_PATH=$HOME/include:$CPLUS_INCLUDE_PATH -wget --no-check-certificate http://www.cmake.org/files/v3.2/cmake-3.2.2-Linux-x86_64.sh -sh cmake-3.2.2-Linux-x86_64.sh --prefix=$HOME --exclude-subdir + +CMAKE_VERSION="${CMAKE_VERSION:-3.2.3}" 
+CMAKE_VERSION_PARTS=( ${CMAKE_VERSION//./ } ) +CMAKE_MAJOR_MINOR="${CMAKE_VERSION_PARTS[0]}.${CMAKE_VERSION_PARTS[1]}" + +wget --no-check-certificate https://www.cmake.org/files/v$CMAKE_MAJOR_MINOR/cmake-$CMAKE_VERSION-Linux-x86_64.sh +sh cmake-$CMAKE_VERSION-Linux-x86_64.sh --prefix=$HOME --exclude-subdir # we have to manually set CC and CXX since travis 'helpfully' clobbers them if [ "$COMPILER" = "gcc" ]; then