Skip to content

Commit

Permalink
fix: update ort (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkeenan38 authored Apr 1, 2024
1 parent 93f47a9 commit 0c2d9d5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 55 deletions.
55 changes: 4 additions & 51 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 12 additions & 4 deletions src/vad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl<const N: usize> VoiceActivityDetector<N> {
.unwrap()
.with_inter_threads(1)
.unwrap()
.with_model_from_memory(MODEL)
.commit_from_memory(MODEL)
.unwrap();

Ok(Self::with_session(session, sample_rate))
Expand Down Expand Up @@ -84,8 +84,16 @@ impl<const N: usize> VoiceActivityDetector<N> {
let outputs = self.session.run(inputs).unwrap();

// Update h and c recursively.
let hn = outputs.get("hn").unwrap().extract_tensor::<f32>().unwrap();
let cn = outputs.get("cn").unwrap().extract_tensor::<f32>().unwrap();
let hn = outputs
.get("hn")
.unwrap()
.try_extract_tensor::<f32>()
.unwrap();
let cn = outputs
.get("cn")
.unwrap()
.try_extract_tensor::<f32>()
.unwrap();

self.h.assign(&hn.view());
self.c.assign(&cn.view());
Expand All @@ -94,7 +102,7 @@ impl<const N: usize> VoiceActivityDetector<N> {
let output = outputs
.get("output")
.unwrap()
.extract_tensor::<f32>()
.try_extract_tensor::<f32>()
.unwrap();
let probability = output.view()[[0, 0]];

Expand Down

0 comments on commit 0c2d9d5

Please sign in to comment.