From 9106d9a0a4e58d7326cb2ecafef9d2af9d1cc364 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 18 Aug 2024 07:59:48 +0000 Subject: [PATCH 1/3] Use `enabled=showprogress` --- src/lda.jl | 87 +++++++++++++++++++++++++----------------------------- 1 file changed, 41 insertions(+), 46 deletions(-) diff --git a/src/lda.jl b/src/lda.jl index 2497b9c..ad23df3 100644 --- a/src/lda.jl +++ b/src/lda.jl @@ -51,60 +51,55 @@ function lda( for wordid in 1:number_of_words nzdocs_idxs = nzrange(dtm.dtm, wordid) - for docid in dtm.dtm.rowval[nzdocs_idxs] - for _ in 1:dtm.dtm[docid, wordid] - topicid = rand(1:ntopics) - update_target_topic = topics[topicid] - update_target_topic.count += 1 - update_target_topic.wordcount[wordid] = get(update_target_topic.wordcount, wordid, 0) + 1 - topics[topicid] = update_target_topic - topic_base_document = docs[docid] - push!(topic_base_document.topic, topicid) - push!(topic_base_document.text, wordid) - topic_base_document.topicidcount[topicid] += 1 - end + for docid in dtm.dtm.rowval[nzdocs_idxs], + _ in 1:dtm.dtm[docid, wordid] + topicid = rand(1:ntopics) + update_target_topic = topics[topicid] + update_target_topic.count += 1 + update_target_topic.wordcount[wordid] = get(update_target_topic.wordcount, wordid, 0) + 1 + topics[topicid] = update_target_topic + topic_base_document = docs[docid] + push!(topic_base_document.topic, topicid) + push!(topic_base_document.text, wordid) + topic_base_document.topicidcount[topicid] += 1 end end probs = Vector{Float64}(undef, ntopics) - wait_time = showprogress ? 1.0 : Inf - # Gibbs sampling - @showprogress dt = wait_time for _ in 1:iteration - for doc in docs - for (i, word) in enumerate(doc.text) - topicid_current = doc.topic[i] - doc.topicidcount[topicid_current] -= 1 - topics[topicid_current].count -= 1 - topics[topicid_current].wordcount[word] -= 1 - document_lenth = length(doc.text) - 1 - - for target_topicid in 1:ntopics - topicprob = (doc.topicidcount[target_topicid] + beta) / (document_lenth + beta * ntopics) - topic = topics[target_topicid] - wordprob = (get(topic.wordcount, word, 0) + alpha) / (topic.count + alpha * number_of_words) - probs[target_topicid] = topicprob * wordprob - end - normalize_probs = sum(probs) - - # select new topic - select = rand() - sum_of_prob = 0.0 - new_topicid = 1 - for (selected_topicid, prob) in enumerate(probs) - sum_of_prob += prob / normalize_probs - if select < sum_of_prob - new_topicid = selected_topicid - break - end - end - doc.topic[i] = new_topicid - doc.topicidcount[new_topicid] = get(doc.topicidcount, new_topicid, 0) + 1 - topics[new_topicid].count += 1 - topics[new_topicid].wordcount[word] = get(topics[new_topicid].wordcount, word, 0) + 1 + @showprogress enabled=showprogress for _ in 1:iteration, + doc in docs, + (i, word) in enumerate(doc.text) + topicid_current = doc.topic[i] + doc.topicidcount[topicid_current] -= 1 + topics[topicid_current].count -= 1 + topics[topicid_current].wordcount[word] -= 1 + document_lenth = length(doc.text) - 1 + + for target_topicid in 1:ntopics + topicprob = (doc.topicidcount[target_topicid] + beta) / (document_lenth + beta * ntopics) + topic = topics[target_topicid] + wordprob = (get(topic.wordcount, word, 0) + alpha) / (topic.count + alpha * number_of_words) + probs[target_topicid] = topicprob * wordprob + end + normalize_probs = sum(probs) + + # select new topic + select = rand() + sum_of_prob = 0.0 + new_topicid = 1 + for (selected_topicid, prob) in enumerate(probs) + sum_of_prob += prob / normalize_probs + if select < sum_of_prob + new_topicid = selected_topicid + break end end + doc.topic[i] = new_topicid + doc.topicidcount[new_topicid] = get(doc.topicidcount, new_topicid, 0) + 1 + topics[new_topicid].count += 1 + topics[new_topicid].wordcount[word] = get(topics[new_topicid].wordcount, word, 0) + 1 end # ϕ From a33cefc0d1dfbde52553f55895de376866abe5f6 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Fri, 6 Sep 2024 18:27:53 +0330 Subject: [PATCH 2/3] revert changed loops --- src/lda.jl | 85 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 44 insertions(+), 41 deletions(-) diff --git a/src/lda.jl b/src/lda.jl index ad23df3..ab07398 100644 --- a/src/lda.jl +++ b/src/lda.jl @@ -51,55 +51,58 @@ function lda( for wordid in 1:number_of_words nzdocs_idxs = nzrange(dtm.dtm, wordid) - for docid in dtm.dtm.rowval[nzdocs_idxs], - _ in 1:dtm.dtm[docid, wordid] - topicid = rand(1:ntopics) - update_target_topic = topics[topicid] - update_target_topic.count += 1 - update_target_topic.wordcount[wordid] = get(update_target_topic.wordcount, wordid, 0) + 1 - topics[topicid] = update_target_topic - topic_base_document = docs[docid] - push!(topic_base_document.topic, topicid) - push!(topic_base_document.text, wordid) - topic_base_document.topicidcount[topicid] += 1 + for docid in dtm.dtm.rowval[nzdocs_idxs] + for _ in 1:dtm.dtm[docid, wordid] + topicid = rand(1:ntopics) + update_target_topic = topics[topicid] + update_target_topic.count += 1 + update_target_topic.wordcount[wordid] = get(update_target_topic.wordcount, wordid, 0) + 1 + topics[topicid] = update_target_topic + topic_base_document = docs[docid] + push!(topic_base_document.topic, topicid) + push!(topic_base_document.text, wordid) + topic_base_document.topicidcount[topicid] += 1 + end end end probs = Vector{Float64}(undef, ntopics) # Gibbs sampling - @showprogress enabled=showprogress for _ in 1:iteration, - doc in docs, - (i, word) in enumerate(doc.text) - topicid_current = doc.topic[i] - doc.topicidcount[topicid_current] -= 1 - topics[topicid_current].count -= 1 - topics[topicid_current].wordcount[word] -= 1 - document_lenth = length(doc.text) - 1 - - for target_topicid in 1:ntopics - topicprob = (doc.topicidcount[target_topicid] + beta) / (document_lenth + beta * ntopics) - topic = topics[target_topicid] - wordprob = (get(topic.wordcount, word, 0) + alpha) / (topic.count + alpha * number_of_words) - probs[target_topicid] = topicprob * wordprob - end - normalize_probs = sum(probs) - - # select new topic - select = rand() - sum_of_prob = 0.0 - new_topicid = 1 - for (selected_topicid, prob) in enumerate(probs) - sum_of_prob += prob / normalize_probs - if select < sum_of_prob - new_topicid = selected_topicid - break + @showprogress enabled = showprogress for _ in 1:iteration + for doc in docs + for (i, word) in enumerate(doc.text) + topicid_current = doc.topic[i] + doc.topicidcount[topicid_current] -= 1 + topics[topicid_current].count -= 1 + topics[topicid_current].wordcount[word] -= 1 + document_lenth = length(doc.text) - 1 + + for target_topicid in 1:ntopics + topicprob = (doc.topicidcount[target_topicid] + beta) / (document_lenth + beta * ntopics) + topic = topics[target_topicid] + wordprob = (get(topic.wordcount, word, 0) + alpha) / (topic.count + alpha * number_of_words) + probs[target_topicid] = topicprob * wordprob + end + normalize_probs = sum(probs) + + # select new topic + select = rand() + sum_of_prob = 0.0 + new_topicid = 1 + for (selected_topicid, prob) in enumerate(probs) + sum_of_prob += prob / normalize_probs + if select < sum_of_prob + new_topicid = selected_topicid + break + end + end + doc.topic[i] = new_topicid + doc.topicidcount[new_topicid] = get(doc.topicidcount, new_topicid, 0) + 1 + topics[new_topicid].count += 1 + topics[new_topicid].wordcount[word] = get(topics[new_topicid].wordcount, word, 0) + 1 end end - doc.topic[i] = new_topicid - doc.topicidcount[new_topicid] = get(doc.topicidcount, new_topicid, 0) + 1 - topics[new_topicid].count += 1 - topics[new_topicid].wordcount[word] = get(topics[new_topicid].wordcount, word, 0) + 1 end # ϕ From e07a66f991315f00654f972e641a69136bf5b7ba Mon Sep 17 00:00:00 2001 From: rssdev10 Date: Fri, 6 Sep 2024 19:02:02 +0300 Subject: [PATCH 3/3] Project.toml: bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a930550..f7760af 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "TextAnalysis" uuid = "a2db99b7-8b79-58f8-94bf-bbc811eef33d" license = "MIT" desc = "Julia package for text analysis" -version = "0.8.1" +version = "0.8.2" [deps] DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"