Skip to content

Commit

Permalink
make DTM type generic (#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen authored Mar 14, 2020
1 parent 09b998c commit e83a6e3
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/dtm.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
-mutable struct DocumentTermMatrix
+mutable struct DocumentTermMatrix{T}
dtm::SparseMatrixCSC{Int, Int}
-terms::Vector{String}
-column_indices::Dict{String, Int}
+terms::Vector{T}
+column_indices::Dict{T, Int}
end


Expand All @@ -10,8 +10,8 @@ end
Creates a column index lookup dictionary from a vector of terms.
"""
-function columnindices(terms::Vector{String})
-column_indices = Dict{String, Int}()
+function columnindices(terms::Vector{T}) where T
+column_indices = Dict{T, Int}()
for i in 1:length(terms)
term = terms[i]
column_indices[term] = i
Expand Down Expand Up @@ -54,7 +54,7 @@ julia> m.dtm
[2, 6] = 1
```
"""
-function DocumentTermMatrix(crps::Corpus, terms::Vector{String})
+function DocumentTermMatrix(crps::Corpus, terms::Vector{T}) where T
column_indices = columnindices(terms)

m = length(crps)
Expand Down Expand Up @@ -87,7 +87,7 @@ DocumentTermMatrix(crps::Corpus) = DocumentTermMatrix(crps, lexicon(crps))

DocumentTermMatrix(crps::Corpus, lex::AbstractDict) = DocumentTermMatrix(crps, sort(collect(keys(lex))))

-DocumentTermMatrix(dtm::SparseMatrixCSC{Int, Int},terms::Vector{String}) = DocumentTermMatrix(dtm, terms, columnindices(terms))
+DocumentTermMatrix(dtm::SparseMatrixCSC{Int, Int},terms::Vector{T}) where T = DocumentTermMatrix{T}(dtm, terms, columnindices(terms))

"""
dtm(crps::Corpus)
Expand Down Expand Up @@ -152,7 +152,7 @@ tdm(crps::Corpus) = dtm(crps)' #'
#
##############################################################################

-function dtm_entries(d::AbstractDocument, lex::Dict{String, Int})
+function dtm_entries(d::AbstractDocument, lex::Dict{T, Int}) where T
ngs = ngrams(d)
indices = Array{Int}(undef, 0)
values = Array{Int}(undef, 0)
Expand Down Expand Up @@ -183,7 +183,7 @@ julia> dtv(crps[1], lexicon(crps))
1 2 0 1 1 1
```
"""
-function dtv(d::AbstractDocument, lex::Dict{String, Int})
+function dtv(d::AbstractDocument, lex::Dict{T, Int}) where T
p = length(keys(lex))
row = zeros(Int, 1, p)
indices, values = dtm_entries(d, lex)
Expand Down

0 comments on commit e83a6e3

Please sign in to comment.