diff --git a/Project.toml b/Project.toml
index 5572fb7..08ab0bb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "AssignTaxonomy"
 uuid = "941572e2-79a6-4ec3-8380-7ad5c8caa571"
 authors = ["Arthur Newbury"]
-version = "0.1.2"
+version = "0.1.3"
 
 [deps]
 BioSequences = "7e6ae17a-c86d-528c-b3b9-7f778a29fe59"
diff --git a/src/classifier.jl b/src/classifier.jl
index 65e2975..a26716b 100644
--- a/src/classifier.jl
+++ b/src/classifier.jl
@@ -55,17 +55,35 @@ Tables.columnnames(m::ClassificationResult) = names(m)
 
 #### Underlying algorithm
+# Classify each query in `seqs` against genus-level k-mer models built from
+# `refs`, where `genera[i]` is the genus label of `refs[i]`.
+# Returns (assignments, confs, log_probs, gens): `assignments[i]` indexes into
+# `gens` (the sorted unique genera), `confs[i]` is the value returned by
+# `bootstrap` for that query, and `log_probs[j]` is the model for `gens[j]`.
+# Pass a previously returned `log_probs` as `lp` to skip training.
 function naieve_bayes(seqs::Vector,refs::Vector,k, n_bootstrap,lp,genera)
-t = time()
     N = length(refs)
     n = length(seqs)
+    gens = sort(unique(genera))
+    g = length(gens)
+    gen2ind = Dict(gens .=> 1:g)
+    # members[j] holds the indices of the references that belong to gens[j].
+    members = [Int[] for _ in 1:g]
+    for (i, gen) in enumerate(genera)
+        push!(members[gen2ind[gen]], i)
+    end
     assignments = Vector{Int64}(undef,n)
     confs = Vector{Float64}(undef,n)
     if lp == false
-        log_probs = [zeros(Float32,4^k) for _ in 1:N]
+        log_probs = [zeros(Float32,4^k) for _ in 1:g]
         priors, a =count_mers(refs)
         word_priors!(priors,N)
-        @batch for i in 1:N
-            @fastmath log_probs[i] .= log.(conditional_prob.(a[i],priors,1))
+        # Threads are split over genera, not references: each iteration owns
+        # log_probs[j], so references sharing a genus never race on one table.
+        @batch for j in 1:g
+            for i in members[j]
+                @fastmath log_probs[j] .+= conditional_prob.(a[i],priors,1)
+            end
+            @fastmath log_probs[j] .= log.(log_probs[j] ./ length(members[j]))
         end
     else
         log_probs = lp
@@ -75,14 +93,25 @@ t = time()
         assignment = assign(eachindex(kmer_array)[kmer_array],log_probs)
         assignments[i] =assignment
         sample_size = sum(kmer_array) ÷ k
-        confs[i] = bootstrap(vec(kmer_array),log_probs,genera[assignment],sample_size,n_bootstrap,genera)
+        confs[i] = bootstrap(vec(kmer_array),log_probs,gens[assignment],sample_size,n_bootstrap,gens)
     end
-    return assignments, confs, log_probs
+    return assignments, confs, log_probs, gens
 end
 
+# Taxonomy-table wrapper: the last column of `taxa` supplies the genus of each
+# reference. Returns (hcat(t, c), l) where row i of `t` is the first row of
+# `taxa` whose genus equals the genus assigned to `seqs[i]`, `c` holds the
+# bootstrap confidences and `l` the per-genus log-probability tables.
 function naieve_bayes(seqs::Vector,refs::Vector,taxa ::Array,k, n_bootstrap,lp=false)
-    a,c,l = naieve_bayes(seqs,refs,k, n_bootstrap,lp,taxa[:,end])
-    t = taxa[a,:]
+    genus_col = taxa[:,end]
+    a,c,l,g = naieve_bayes(seqs,refs,k, n_bootstrap,lp,genus_col)
+    # `a` indexes the sorted unique genera `g`, not rows of `taxa`; map each
+    # genus back to the first taxonomy row that carries it, scanning once.
+    first_row = Dict{eltype(genus_col),Int}()
+    for (row, gen) in enumerate(genus_col)
+        get!(first_row, gen, row)
+    end
+    t = taxa[[first_row[g[j]] for j in a],:]
     return hcat(t,c),l
 end
 