Skip to content

Commit

Permalink
Refactor (#28)
Browse files Browse the repository at this point in the history
* implement models

* add stream, gv, etc.

* complete model for pstream

* use in pstream

* fix gv_switch

* refactor sstream

* refactor p/g stream

* change names

* rename gstream to speech

* remove ModelSet

* refactor DurationEstimator and engine

* fix examples

* derive for Voice

* refactor labels

* move sequence to mask

* move metadata fetcher to VoiceSet

* move around model code

* export inner modules

* refactor parser tree

* rename file to model

* add more tests
  • Loading branch information
femshima authored Mar 24, 2024
1 parent 6553d2c commit 76709f8
Show file tree
Hide file tree
Showing 27 changed files with 1,376 additions and 1,295 deletions.
6 changes: 3 additions & 3 deletions benches/bonsais.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fn bonsai(bencher: &mut Bencher) {
let engine = Engine::load(&[MODEL_NITECH_ATR503.to_string()]).unwrap();

bencher.iter(|| {
engine.synthesize_from_strings(&lines);
engine.synthesize_from_strings(&lines).unwrap();
});
}

Expand Down Expand Up @@ -63,7 +63,7 @@ fn is_bonsai(bencher: &mut Bencher) {
let engine = Engine::load(&[MODEL_NITECH_ATR503.to_string()]).unwrap();

bencher.iter(|| {
engine.synthesize_from_strings(&lines);
engine.synthesize_from_strings(&lines).unwrap();
});
}

Expand Down Expand Up @@ -135,6 +135,6 @@ fn bonsai_letter(bencher: &mut Bencher) {
let engine = Engine::load(&[MODEL_NITECH_ATR503.to_string()]).unwrap();

bencher.iter(|| {
engine.synthesize_from_strings(&lines);
engine.synthesize_from_strings(&lines).unwrap();
});
}
5 changes: 2 additions & 3 deletions examples/genji/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
iw.set_parameter(1, Weights::new(&[0.5, 0.5])?)?;
iw.set_parameter(2, Weights::new(&[1.0, 0.0])?)?;

let gstream = engine.synthesize_from_strings(&lines);
let speech = gstream.get_speech();
let speech = engine.synthesize_from_strings(&lines)?;

println!(
"The synthesized voice has {} samples in total.",
Expand All @@ -31,7 +30,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// sample_format: hound::SampleFormat::Int,
// },
// )?;
// for &value in speech {
// for value in speech {
// let clamped = value.min(i16::MAX as f64).max(i16::MIN as f64);
// writer.write_sample(clamped as i16)?;
// }
Expand Down
5 changes: 2 additions & 3 deletions examples/is-bonsai/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let engine = Engine::load(&vec![
"models/hts_voice_nitech_jp_atr503_m001-1.05/nitech_jp_atr503_m001.htsvoice".to_string(),
])?;
let gstream = engine.synthesize_from_strings(&lines);
let speech = gstream.get_speech();
let speech = engine.synthesize_from_strings(&lines)?;

println!(
"The synthesized voice has {} samples in total.",
Expand All @@ -44,7 +43,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// sample_format: hound::SampleFormat::Int,
// },
// )?;
// for &value in speech {
// for value in speech {
// let clamped = value.min(i16::MAX as f64).max(i16::MIN as f64);
// writer.write_sample(clamped as i16)?;
// }
Expand Down
162 changes: 162 additions & 0 deletions src/duration.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
use crate::model::Models;

pub struct DurationEstimator;

impl DurationEstimator {
pub fn create(&self, models: &Models, speed: f64) -> Vec<usize> {
let duration_params = models.duration();

// determine frame length
let mut duration = Self::estimate_duration(&duration_params, 0.0);
if speed != 1.0 {
let length: usize = duration.iter().sum();
duration =
Self::estimate_duration_with_frame_length(&duration_params, length as f64 / speed);
}

duration
}

pub fn create_with_alignment(&self, models: &Models, times: &[(f64, f64)]) -> Vec<usize> {
let duration_params = models.duration();

// determine state duration
let mut duration = vec![];
// use duration set by user
let mut frame_count = 0;
let mut next_state = 0;
let mut state = 0;
for (i, (_start_frame, end_frame)) in times.iter().enumerate() {
if *end_frame >= 0.0 {
let curr_duration = Self::estimate_duration_with_frame_length(
&duration_params[next_state..state + models.nstate()],
end_frame - frame_count as f64,
);
frame_count += curr_duration.iter().sum::<usize>();
next_state = state + models.nstate();
duration.extend_from_slice(&curr_duration);
} else if i + 1 == times.len() {
eprintln!("HTS_SStreamSet_create: The time of final label is not specified.");
Self::estimate_duration(&duration_params[next_state..state + models.nstate()], 0.0);
}
state += models.nstate();
}

duration
}

/// Estimate state duration
fn estimate_duration(duration_params: &[(f64, f64)], rho: f64) -> Vec<usize> {
duration_params
.iter()
.map(|(mean, vari)| (mean + rho * vari).round().max(1.0) as usize)
.collect()
}
/// Estimate duration from state duration probability distribution and specified frame length
fn estimate_duration_with_frame_length(
duration_params: &[(f64, f64)],
frame_length: f64,
) -> Vec<usize> {
let size = duration_params.len();

// get the target frame length
let target_length: usize = frame_length.round().max(1.0) as usize;

// check the specified duration
if target_length <= size {
return vec![1; size];
}

// RHO calculation
let (mean, vari) = duration_params
.iter()
.fold((0.0, 0.0), |(mean, vari), curr| {
(mean + curr.0, vari + curr.1)
});
let rho = (target_length as f64 - mean) / vari;

let mut duration = Self::estimate_duration(duration_params, rho);

// loop estimation
let mut sum: usize = duration.iter().sum();
let calculate_cost =
|d: usize, (mean, vari): (f64, f64)| (rho - (d as f64 - mean) / vari).abs();
while target_length != sum {
// search flexible state and modify its duration
if target_length > sum {
let (found_duration, _) = duration
.iter_mut()
.zip(duration_params.iter())
.min_by(|(ad, ap), (bd, bp)| {
calculate_cost(**ad + 1, **ap).total_cmp(&calculate_cost(**bd + 1, **bp))
})
.unwrap();
*found_duration += 1;
sum += 1;
} else {
let (found_duration, _) = duration
.iter_mut()
.zip(duration_params.iter())
.filter(|(duration, _)| **duration > 1)
.min_by(|(ad, ap), (bd, bp)| {
calculate_cost(**ad - 1, **ap).total_cmp(&calculate_cost(**bd - 1, **bp))
})
.unwrap();
*found_duration -= 1;
sum -= 1;
}
}

duration
}
}

#[cfg(test)]
mod tests {
use crate::model::tests::load_models;

use super::DurationEstimator;

#[test]
fn without_alignment() {
let models = load_models();
assert_eq!(
DurationEstimator.create(&models, 1.0),
[
8, 17, 14, 25, 15, 3, 4, 2, 2, 2, 2, 3, 3, 3, 3, 4, 3, 2, 2, 2, 3, 3, 6, 3, 2, 3,
3, 3, 3, 2, 2, 1, 3, 2, 14, 22, 14, 26, 38, 5
]
);
assert_eq!(
DurationEstimator.create(&models, 1.2),
[
6, 12, 11, 19, 14, 3, 4, 2, 2, 2, 2, 3, 3, 3, 3, 4, 3, 2, 2, 2, 3, 3, 6, 3, 2, 3,
3, 3, 3, 2, 2, 1, 3, 2, 14, 18, 11, 16, 27, 4
]
);
}

#[test]
fn with_alignment() {
let models = load_models();
assert_eq!(
DurationEstimator.create_with_alignment(
&models,
&[
(0.0, 298.5),
(298.5, 334.5),
(334.5, 350.5),
(350.5, 362.5),
(362.5, 394.5),
(394.5, 416.5),
(416.5, 454.5),
(454.5, 606.5)
]
),
[
36, 86, 48, 102, 27, 7, 11, 6, 6, 6, 2, 4, 3, 4, 3, 3, 3, 2, 2, 2, 3, 6, 14, 6, 3,
4, 5, 6, 4, 3, 3, 1, 4, 4, 26, 28, 19, 42, 55, 8
]
);
}
}
Loading

0 comments on commit 76709f8

Please sign in to comment.