diff --git a/lib/classifier-reborn/bayes.rb b/lib/classifier-reborn/bayes.rb index 428d0b9..fd64787 100644 --- a/lib/classifier-reborn/bayes.rb +++ b/lib/classifier-reborn/bayes.rb @@ -60,6 +60,8 @@ def initialize(*args) # b.train "that", "That text" # b.train "The other", "The other text" def train(category, text) + word_hash = Hasher.word_hash(text, @language, @enable_stemmer) + return if word_hash.empty? category = CategoryNamer.prepare_name(category) # Add the category dynamically or raise an error @@ -73,7 +75,7 @@ def train(category, text) @backend.update_category_training_count(category, 1) @backend.update_total_trainings(1) - Hasher.word_hash(text, @language, @enable_stemmer).each do |word, count| + word_hash.each do |word, count| @backend.update_category_word_frequency(category, word, count) @backend.update_category_word_count(category, count) @backend.update_total_words(count) @@ -88,10 +90,12 @@ def train(category, text) # b.train :this, "This text" # b.untrain :this, "This text" def untrain(category, text) + word_hash = Hasher.word_hash(text, @language, @enable_stemmer) + return if word_hash.empty? category = CategoryNamer.prepare_name(category) @backend.update_category_training_count(category, -1) @backend.update_total_trainings(-1) - Hasher.word_hash(text, @language, @enable_stemmer).each do |word, count| + word_hash.each do |word, count| next if @backend.total_words < 0 orig = @backend.category_word_frequency(category, word) || 0 @backend.update_category_word_frequency(category, word, -count) @@ -112,6 +116,12 @@ def untrain(category, text) def classifications(text) score = {} word_hash = Hasher.word_hash(text, @language, @enable_stemmer) + if word_hash.empty? + category_keys.each do |category| + score[category.to_s] = Float::INFINITY + end + return score + end category_keys.each do |category| score[category.to_s] = 0 total = (@backend.category_word_count(category) || 1).to_f diff --git a/lib/classifier-reborn/extensions/hasher.rb b/lib/classifier-reborn/extensions/hasher.rb index 398a46f..c1bf1de 100644 --- a/lib/classifier-reborn/extensions/hasher.rb +++ b/lib/classifier-reborn/extensions/hasher.rb @@ -21,7 +21,7 @@ def word_hash(str, language = 'en', enable_stemmer = true) # Return a word hash without extra punctuation or short symbols, just stemmed words def clean_word_hash(str, language = 'en', enable_stemmer = true) - word_hash_for_words str.gsub(/[^\p{WORD}\s]/, '').downcase.split, language, enable_stemmer + word_hash_for_words(str.gsub(/[^\p{WORD}\s]/, '').downcase.split, language, enable_stemmer) end def word_hash_for_words(words, language = 'en', enable_stemmer = true) diff --git a/test/bayes/bayesian_common_tests.rb b/test/bayes/bayesian_common_tests.rb index aea9468..38f26f6 100644 --- a/test/bayes/bayesian_common_tests.rb +++ b/test/bayes/bayesian_common_tests.rb @@ -71,7 +71,7 @@ def test_classification end def test_classification_with_threshold - b = threshold_classifier('Digit') + b = threshold_classifier('Number') assert_equal 1, b.categories.size refute b.threshold_enabled? @@ -79,15 +79,15 @@ def test_classification_with_threshold assert b.threshold_enabled? assert_equal 0.0, b.threshold # default - b.threshold = -7.0 + b.threshold = -4.0 - 10.times do |a_number| - b.train_digit(a_number.to_s) - b.train_digit(a_number.to_s) + ['one', 'two', 'three', 'four', 'five'].each do |a_number| + b.train_number(a_number) + b.train_number(a_number) end - 10.times do |a_number| - assert_equal 'Digit', b.classify(a_number.to_s) + ['one', 'two', 'three', 'four', 'five'].each do |a_number| + assert_equal 'Number', b.classify(a_number) end refute b.classify('xyzzy') diff --git a/test/bayes/bayesian_integration_test.rb b/test/bayes/bayesian_integration_test.rb index 8996ab6..a7664b8 100644 --- a/test/bayes/bayesian_integration_test.rb +++ b/test/bayes/bayesian_integration_test.rb @@ -57,7 +57,7 @@ def classification_scores(classifier) @testing_set.collect do |line| parts = line.strip.split("\t") result, score = classifier.classify_with_score(parts.last) - "#{result}:#{score}" + score.infinite? ? "irrelevant" : "#{result}:#{score}" end end end