diff --git a/benches/bonsais.rs b/benches/bonsais.rs
index ef4dfcd..270ee20 100644
--- a/benches/bonsais.rs
+++ b/benches/bonsais.rs
@@ -25,7 +25,7 @@ fn bonsai(bencher: &mut Bencher) {
     let engine = Engine::load(&[MODEL_NITECH_ATR503.to_string()]).unwrap();
 
     bencher.iter(|| {
-        engine.synthesize_from_strings(&lines);
+        engine.synthesize_from_strings(&lines).unwrap();
     });
 }
 
@@ -63,7 +63,7 @@ fn is_bonsai(bencher: &mut Bencher) {
     let engine = Engine::load(&[MODEL_NITECH_ATR503.to_string()]).unwrap();
 
     bencher.iter(|| {
-        engine.synthesize_from_strings(&lines);
+        engine.synthesize_from_strings(&lines).unwrap();
     });
 }
 
@@ -135,6 +135,6 @@ fn bonsai_letter(bencher: &mut Bencher) {
     let engine = Engine::load(&[MODEL_NITECH_ATR503.to_string()]).unwrap();
 
     bencher.iter(|| {
-        engine.synthesize_from_strings(&lines);
+        engine.synthesize_from_strings(&lines).unwrap();
     });
 }
diff --git a/examples/genji/main.rs b/examples/genji/main.rs
index b4a46b8..1079647 100644
--- a/examples/genji/main.rs
+++ b/examples/genji/main.rs
@@ -14,8 +14,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     iw.set_parameter(1, Weights::new(&[0.5, 0.5])?)?;
     iw.set_parameter(2, Weights::new(&[1.0, 0.0])?)?;
 
-    let gstream = engine.synthesize_from_strings(&lines);
-    let speech = gstream.get_speech();
+    let speech = engine.synthesize_from_strings(&lines)?;
 
     println!(
         "The synthesized voice has {} samples in total.",
@@ -31,7 +30,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     //         sample_format: hound::SampleFormat::Int,
    //     },
    // )?;
-    // for &value in speech {
+    // for value in speech {
    //     let clamped = value.min(i16::MAX as f64).max(i16::MIN as f64);
    //     writer.write_sample(clamped as i16)?;
    // }
diff --git a/examples/is-bonsai/main.rs b/examples/is-bonsai/main.rs
index 41c273e..eadeaf7 100644
--- a/examples/is-bonsai/main.rs
+++ b/examples/is-bonsai/main.rs
@@ -27,8 +27,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     let engine = Engine::load(&vec![
         "models/hts_voice_nitech_jp_atr503_m001-1.05/nitech_jp_atr503_m001.htsvoice".to_string(),
     ])?;
-    let gstream = engine.synthesize_from_strings(&lines);
-    let speech = gstream.get_speech();
+    let speech = engine.synthesize_from_strings(&lines)?;
 
     println!(
         "The synthesized voice has {} samples in total.",
@@ -44,7 +43,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
    //         sample_format: hound::SampleFormat::Int,
    //     },
    // )?;
-    // for &value in speech {
+    // for value in speech {
    //     let clamped = value.min(i16::MAX as f64).max(i16::MIN as f64);
    //     writer.write_sample(clamped as i16)?;
    // }
diff --git a/src/duration.rs b/src/duration.rs
new file mode 100644
index 0000000..8496d25
--- /dev/null
+++ b/src/duration.rs
@@ -0,0 +1,162 @@
+use crate::model::Models;
+
+pub struct DurationEstimator;
+
+impl DurationEstimator {
+    pub fn create(&self, models: &Models, speed: f64) -> Vec<usize> {
+        let duration_params = models.duration();
+
+        // determine frame length
+        let mut duration = Self::estimate_duration(&duration_params, 0.0);
+        if speed != 1.0 {
+            let length: usize = duration.iter().sum();
+            duration =
+                Self::estimate_duration_with_frame_length(&duration_params, length as f64 / speed);
+        }
+
+        duration
+    }
+
+    pub fn create_with_alignment(&self, models: &Models, times: &[(f64, f64)]) -> Vec<usize> {
+        let duration_params = models.duration();
+
+        // determine state duration
+        let mut duration = vec![];
+        // use duration set by user
+        let mut frame_count = 0;
+        let mut next_state = 0;
+        let mut state = 0;
+        for (i, (_start_frame, end_frame)) in times.iter().enumerate() {
+            if *end_frame >= 0.0 {
+                let curr_duration = Self::estimate_duration_with_frame_length(
+                    &duration_params[next_state..state + models.nstate()],
+                    end_frame - frame_count as f64,
+                );
+                frame_count += curr_duration.iter().sum::<usize>();
+                next_state = state + models.nstate();
+                duration.extend_from_slice(&curr_duration);
+            } else if i + 1 == times.len() {
+                eprintln!("HTS_SStreamSet_create: The time of final label is not specified.");
+                Self::estimate_duration(&duration_params[next_state..state + models.nstate()], 0.0);
+            }
+            state += models.nstate();
+        }
+
+        duration
+    }
+
+    /// Estimate state duration
+    fn estimate_duration(duration_params: &[(f64, f64)], rho: f64) -> Vec<usize> {
+        duration_params
+            .iter()
+            .map(|(mean, vari)| (mean + rho * vari).round().max(1.0) as usize)
+            .collect()
+    }
+    /// Estimate duration from state duration probability distribution and specified frame length
+    fn estimate_duration_with_frame_length(
+        duration_params: &[(f64, f64)],
+        frame_length: f64,
+    ) -> Vec<usize> {
+        let size = duration_params.len();
+
+        // get the target frame length
+        let target_length: usize = frame_length.round().max(1.0) as usize;
+
+        // check the specified duration
+        if target_length <= size {
+            return vec![1; size];
+        }
+
+        // RHO calculation
+        let (mean, vari) = duration_params
+            .iter()
+            .fold((0.0, 0.0), |(mean, vari), curr| {
+                (mean + curr.0, vari + curr.1)
+            });
+        let rho = (target_length as f64 - mean) / vari;
+
+        let mut duration = Self::estimate_duration(duration_params, rho);
+
+        // loop estimation
+        let mut sum: usize = duration.iter().sum();
+        let calculate_cost =
+            |d: usize, (mean, vari): (f64, f64)| (rho - (d as f64 - mean) / vari).abs();
+        while target_length != sum {
+            // search flexible state and modify its duration
+            if target_length > sum {
+                let (found_duration, _) = duration
+                    .iter_mut()
+                    .zip(duration_params.iter())
+                    .min_by(|(ad, ap), (bd, bp)| {
+                        calculate_cost(**ad + 1, **ap).total_cmp(&calculate_cost(**bd + 1, **bp))
+                    })
+                    .unwrap();
+                *found_duration += 1;
+                sum += 1;
+            } else {
+                let (found_duration, _) = duration
+                    .iter_mut()
+                    .zip(duration_params.iter())
+                    .filter(|(duration, _)| **duration > 1)
+                    .min_by(|(ad, ap), (bd, bp)| {
+                        calculate_cost(**ad - 1, **ap).total_cmp(&calculate_cost(**bd - 1, **bp))
+                    })
+                    .unwrap();
+                *found_duration -= 1;
+                sum -= 1;
+            }
+        }
+
+        duration
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::model::tests::load_models;
+
+    use super::DurationEstimator;
+
+    #[test]
+    fn without_alignment() {
+        let models = load_models();
+        assert_eq!(
+            DurationEstimator.create(&models, 1.0),
+            [
+                8, 17, 14, 25, 15, 3, 4, 2, 2, 2, 2, 3, 3, 3, 3, 4, 3, 2, 2, 2, 3, 3, 6, 3, 2, 3,
+                3, 3, 3, 2, 2, 1, 3, 2, 14, 22, 14, 26, 38, 5
+            ]
+        );
+        assert_eq!(
+            DurationEstimator.create(&models, 1.2),
+            [
+                6, 12, 11, 19, 14, 3, 4, 2, 2, 2, 2, 3, 3, 3, 3, 4, 3, 2, 2, 2, 3, 3, 6, 3, 2, 3,
+                3, 3, 3, 2, 2, 1, 3, 2, 14, 18, 11, 16, 27, 4
+            ]
+        );
+    }
+
+    #[test]
+    fn with_alignment() {
+        let models = load_models();
+        assert_eq!(
+            DurationEstimator.create_with_alignment(
+                &models,
+                &[
+                    (0.0, 298.5),
+                    (298.5, 334.5),
+                    (334.5, 350.5),
+                    (350.5, 362.5),
+                    (362.5, 394.5),
+                    (394.5, 416.5),
+                    (416.5, 454.5),
+                    (454.5, 606.5)
+                ]
+            ),
+            [
+                36, 86, 48, 102, 27, 7, 11, 6, 6, 6, 2, 4, 3, 4, 3, 3, 3, 2, 2, 2, 3, 6, 14, 6, 3,
+                4, 5, 6, 4, 3, 3, 1, 4, 4, 26, 28, 19, 42, 55, 8
+            ]
+        );
+    }
+}
diff --git a/src/engine.rs b/src/engine.rs
index 11f75df..69f9d47 100644
--- a/src/engine.rs
+++ b/src/engine.rs
@@ -1,21 +1,24 @@
 use std::path::Path;
 use std::sync::Arc;
 
-use crate::constants::{DB, HALF_TONE, MAX_LF0, MIN_LF0};
-use crate::gstream::GenerateSpeechStreamSet;
-use crate::label::Label;
+use crate::constants::DB;
+use crate::duration::DurationEstimator;
+use crate::label::{LabelError, Labels};
+use crate::mlpg_adjust::MlpgAdjust;
 use crate::model::interporation_weight::InterporationWeight;
-use crate::model::{ModelError, ModelSet};
-use crate::pstream::ParameterStreamSet;
-use crate::sstream::StateStreamSet;
+use crate::model::{apply_additional_half_tone, ModelError, Models, VoiceSet};
+use crate::speech::SpeechGenerator;
 use crate::vocoder::Vocoder;
 
 #[derive(Debug, thiserror::Error)]
 pub enum EngineError {
-    #[error("Model error")]
+    #[error("Model error: {0}")]
     ModelError(#[from] ModelError),
     #[error("Failed to parse option {0}")]
     ParseOptionError(String),
+
+    #[error("Label error: {0}")]
+    LabelError(#[from] LabelError),
 }
 
 #[derive(Debug, Clone)]
@@ -70,18 +73,20 @@ impl Default for Condition {
 }
 
 impl Condition {
-    pub fn load_model(&mut self, ms: &ModelSet) -> Result<(), EngineError> {
-        let nvoices = ms.get_nvoices();
-        let nstream = ms.get_nstream();
+    pub fn load_model(&mut self, voices: &VoiceSet) -> Result<(), EngineError> {
+        let first = voices.first();
+        let metadata = &first.metadata;
+
+        let nstream = metadata.num_streams;
 
         /* global */
-        self.sampling_frequency = ms.get_sampling_frequency();
-        self.fperiod = ms.get_fperiod();
+        self.sampling_frequency = metadata.sampling_frequency;
+        self.fperiod = metadata.frame_period;
         self.msd_threshold = [0.5].repeat(nstream);
         self.gv_weight = [1.0].repeat(nstream);
 
         /* spectrum */
-        for option in ms.get_option(0).unwrap_or(&[]) {
+        for option in &first.stream_models[0].metadata.option {
             let Some((key, value)) = option.split_once('=') else {
                 eprintln!("Skipped unrecognized option {}.", option);
                 continue;
             };
@@ -103,7 +108,7 @@ impl Condition {
         }
 
         /* interpolation weights */
-        self.interporation_weight = InterporationWeight::new(nvoices, nstream);
+        self.interporation_weight = InterporationWeight::new(voices.len(), nstream);
 
         Ok(())
     }
@@ -215,85 +220,83 @@ impl Condition {
 
 pub struct Engine {
     pub condition: Condition,
-    pub ms: Arc<ModelSet>,
+    pub voices: VoiceSet,
 }
 
 impl Engine {
-    pub fn load<P: AsRef<Path>>(voices: &[P]) -> Result<Engine, EngineError> {
-        let ms = ModelSet::load_htsvoice_files(voices).unwrap();
+    #[cfg(feature = "htsvoice")]
+    pub fn load<P: AsRef<Path>>(voices: &[P]) -> Result<Self, EngineError> {
+        use crate::model::load_htsvoice_file;
+
+        let voices = voices
+            .iter()
+            .map(|path| Ok(Arc::new(load_htsvoice_file(path)?)))
+            .collect::<Result<Vec<_>, ModelError>>()?;
+        let voiceset = VoiceSet::new(voices)?;
+
         let mut condition = Condition::default();
-        condition.load_model(&ms)?;
-        Ok(Self::new(Arc::new(ms), condition))
-    }
-    pub fn new(ms: Arc<ModelSet>, condition: Condition) -> Engine {
-        Engine { condition, ms }
-    }
+        condition.load_model(&voiceset)?;
 
-    pub fn synthesize_from_strings(&self, lines: &[String]) -> GenerateSpeechStreamSet {
-        let labels = self.load_labels(lines);
-        let state_sequence = self.generate_state_sequence(&labels);
-        let parameter_sequence = self.generate_parameter_sequence(&state_sequence);
-        self.generate_sample_sequence(&parameter_sequence)
+        Ok(Self::new(voiceset, condition))
+    }
+    pub fn new(voices: VoiceSet, condition: Condition) -> Self {
+        Engine { voices, condition }
     }
 
-    fn load_labels(&self, lines: &[String]) -> Label {
-        Label::load_from_strings(
+    pub fn synthesize_from_strings<S: AsRef<str>>(
+        &self,
+        lines: &[S],
+    ) -> Result<Vec<f64>, EngineError> {
+        let labels = Labels::load_from_strings(
            self.condition.sampling_frequency,
            self.condition.fperiod,
            lines,
-        )
+        )?;
+        Ok(self.generate_speech(&labels))
     }
 
-    fn generate_state_sequence(&self, label: &Label) -> StateStreamSet {
-        let mut sss = StateStreamSet::create(
-            self.ms.clone(),
-            label,
-            self.condition.phoneme_alignment_flag,
-            self.condition.speed,
+    pub fn generate_speech(&self, labels: &Labels) -> Vec<f64> {
+        let models = Models::new(
+            labels.labels().to_vec(),
+            &self.voices,
             &self.condition.interporation_weight,
         );
-        self.apply_additional_half_tone(&mut sss);
-        sss
-    }
-    fn apply_additional_half_tone(&self, sss: &mut StateStreamSet) {
-        if self.condition.additional_half_tone == 0.0 {
-            return;
-        }
-        for i in 0..sss.get_total_state() {
-            let mut f = sss.get_mean(1, i, 0);
-            f += self.condition.additional_half_tone * HALF_TONE;
-            f = f.max(MIN_LF0).min(MAX_LF0);
-            sss.set_mean(1, i, 0, f);
-        }
-    }
+        let durations = if self.condition.phoneme_alignment_flag {
+            DurationEstimator.create_with_alignment(&models, labels.times())
+        } else {
+            DurationEstimator.create(&models, self.condition.speed)
+        };
 
-    fn generate_parameter_sequence(&self, state_sequence: &StateStreamSet) -> ParameterStreamSet {
-        ParameterStreamSet::create(
-            state_sequence,
-            &self.condition.msd_threshold,
-            &self.condition.gv_weight,
-        )
-    }
+        let initialize = |stream_index: usize| {
+            MlpgAdjust::new(
+                stream_index,
+                self.condition.gv_weight[stream_index],
+                self.condition.msd_threshold[stream_index],
+            )
+        };
+
+        let spectrum = initialize(0).create(models.stream(0), &models, &durations);
+        let lf0 = {
+            let mut lf0_params = models.stream(1);
+            apply_additional_half_tone(&mut lf0_params, self.condition.additional_half_tone);
+            initialize(1).create(lf0_params, &models, &durations)
+        };
+        let lpf = initialize(2).create(models.stream(2), &models, &durations);
 
-    fn generate_sample_sequence(
-        &self,
-        parameter_sequence: &ParameterStreamSet,
-    ) -> GenerateSpeechStreamSet {
         let vocoder = Vocoder::new(
-            self.ms.get_vector_length(0) - 1,
+            models.vector_length(0) - 1,
            self.condition.stage,
            self.condition.use_log_gain,
            self.condition.sampling_frequency,
            self.condition.fperiod,
         );
-        GenerateSpeechStreamSet::create(
-            parameter_sequence,
-            vocoder,
+        let generator = SpeechGenerator::new(
            self.condition.fperiod,
            self.condition.alpha,
            self.condition.beta,
            self.condition.volume,
-        )
+        );
+        generator.synthesize(vocoder, spectrum, lf0, Some(lpf))
     }
 }
diff --git a/src/gstream.rs b/src/gstream.rs
deleted file mode 100644
index c1067ed..0000000
--- a/src/gstream.rs
+++ /dev/null
@@ -1,63 +0,0 @@
-use crate::{pstream::ParameterStreamSet, vocoder::Vocoder};
-
-pub struct GenerateSpeechStreamSet {
-    speech: Vec<f64>,
-}
-
-impl GenerateSpeechStreamSet {
-    /// Generate speech
-    pub fn create(
-        pss: &ParameterStreamSet,
-        mut v: Vocoder,
-        fperiod: usize,
-        alpha: f64,
-        beta: f64,
-        volume: f64,
-    ) -> Self {
-        // check
-        if pss.get_nstream() != 2 && pss.get_nstream() != 3 {
-            panic!("The number of streams must be 2 or 3.");
-        }
-        if pss.get_vector_length(1) != 1 {
-            panic!("The size of lf0 static vector must be 1.");
-        }
-        if pss.get_nstream() >= 3 && pss.get_vector_length(2) % 2 == 0 {
-            panic!("The number of low-pass filter coefficient must be odd numbers.");
-        }
-
-        // create speech buffer
-        let total_frame = pss.get_total_frame();
-        let mut speech = vec![0.0; total_frame * fperiod];
-
-        // synthesize speech waveform
-        for i in 0..total_frame {
-            let lpf = if pss.get_nstream() >= 3 {
-                (0..pss.get_vector_length(2))
-                    .map(|vector_index| pss.get_parameter(2, i, vector_index))
-                    .collect()
-            } else {
-                vec![]
-            };
-            let spectrum: Vec<f64> = (0..pss.get_vector_length(0))
-                .map(|vector_index| pss.get_parameter(0, i, vector_index))
-                .collect();
-
-            v.synthesize(
-                pss.get_parameter(1, i, 0),
-                &spectrum,
-                &lpf,
-                alpha,
-                beta,
-                volume,
-                &mut speech[i * fperiod..(i + 1) * fperiod],
-            );
-        }
-
-        GenerateSpeechStreamSet { speech }
-    }
-
-    /// Get synthesized speech waveform
-    pub fn get_speech(&self) -> &[f64] {
-        &self.speech
-    }
-}
diff --git a/src/label.rs b/src/label.rs
index b60b7a9..fe39d5f 100644
--- a/src/label.rs
+++ b/src/label.rs
@@ -1,82 +1,146 @@
-use std::str::FromStr;
+#[derive(Debug, thiserror::Error)]
+pub enum LabelError {
+    #[error("jlabel failed to parse fullcontext-label: {0}")]
+    JLabelParse(#[from] jlabel::ParseError),
+    #[error("Expected a fullcontext-label in {0}")]
+    MissingLabel(String),
+    #[error("Failed to parse as floating-point number")]
+    FloatParse(#[from] std::num::ParseFloatError),
 
-struct LabelString {
-    content: jlabel::Label,
-    start: f64,
-    end: f64,
+    #[error("The length of `times` and `labels` must be the same")]
+    LengthMismatch,
 }
 
-impl LabelString {
-    fn parse(s: &str, rate: f64) -> Self {
-        Self::parse_digit_string(s, rate).unwrap_or(Self {
-            // TODO: remove this unwrap
-            content: jlabel::Label::from_str(s).unwrap(),
-            start: -1.0,
-            end: -1.0,
-        })
-    }
-    fn parse_digit_string(s: &str, rate: f64) -> Option<Self> {
-        let mut iter = s.splitn(3, ' ');
-        let start: f64 = iter.next().and_then(|s| s.parse().ok())?;
-        let end: f64 = iter.next().and_then(|s| s.parse().ok())?;
-        let content = iter.next()?.parse().ok()?;
-        Some(Self {
-            content,
-            start: rate * start,
-            end: rate * end,
-        })
-    }
+pub struct Labels {
+    labels: Vec<jlabel::Label>,
+    times: Vec<(f64, f64)>,
 }
 
-pub struct Label {
-    strings: Vec<LabelString>,
-}
+impl Labels {
+    pub fn load_from_strings<S: AsRef<str>>(
+        sampling_rate: usize,
+        fperiod: usize,
+        lines: &[S],
+    ) -> Result<Self, LabelError> {
+        let mut labels = Vec::with_capacity(lines.len());
+        let mut times = Vec::with_capacity(lines.len());
 
-impl Label {
-    pub fn load_from_strings(sampling_rate: usize, fperiod: usize, lines: &[String]) -> Self {
+        // start/end times are multiplied with 1e+7
         let rate = sampling_rate as f64 / (fperiod as f64 * 1e+7);
-        let mut strings = Vec::with_capacity(lines.len());
         for line in lines {
-            let Some(first_char) = line.chars().next() else {
-                break;
-            };
-            if !first_char.is_ascii_graphic() {
-                break;
-            }
+            let line = line.as_ref();
 
-            strings.push(LabelString::parse(line, rate));
-        }
+            let mut split = line.splitn(3, ' ');
+            let first = split
+                .next()
+                .expect("`splitn` is expected to always have at least one element.");
 
-        for i in 0..strings.len() {
-            if i + 1 < strings.len() {
-                if strings[i].end < 0.0 && strings[i + 1].start >= 0.0 {
-                    strings[i].end = strings[i + 1].start;
-                } else if strings[i].end >= 0.0 && strings[i + 1].start < 0.0 {
-                    strings[i + 1].start = strings[i].end;
-                }
-            }
-            if strings[i].start < 0.0 {
-                strings[i].start = -1.0;
-            }
-            if strings[i].end < 0.0 {
-                strings[i].end = -1.0;
+            if let Some(second) = split.next() {
+                let third = split
+                    .next()
+                    .ok_or_else(|| LabelError::MissingLabel(line.to_string()))?;
+
+                let mut start: f64 = first.parse()?;
+                let mut end: f64 = second.parse()?;
+
+                start *= rate;
+                end *= rate;
+
+                let label = third.parse()?;
+
+                times.push((start, end));
+                labels.push(label);
+            } else if first.is_empty() {
+                continue;
+            } else {
+                let label = first.parse()?;
+                times.push((-1.0, -1.0));
+                labels.push(label);
             }
         }
 
-        Self { strings }
+        Self::new(labels, Some(times))
     }
 
-    pub fn get_size(&self) -> usize {
-        self.strings.len()
+    pub fn new(
+        labels: Vec<jlabel::Label>,
+        times: Option<Vec<(f64, f64)>>,
+    ) -> Result<Self, LabelError> {
+        if let Some(mut times) = times {
+            if labels.len() != times.len() {
+                return Err(LabelError::LengthMismatch);
+            }
+
+            for i in 0..times.len() {
+                if i + 1 < times.len() {
+                    if times[i].1 < 0.0 && times[i + 1].0 >= 0.0 {
+                        times[i].1 = times[i + 1].0;
+                    } else if times[i].1 >= 0.0 && times[i + 1].0 < 0.0 {
+                        times[i + 1].0 = times[i].1;
+                    }
+                }
+
+                if times[i].0 < 0.0 {
+                    times[i].0 = -1.0;
+                }
+                if times[i].1 < 0.0 {
+                    times[i].1 = -1.0;
+                }
+            }
+
+            Ok(Self { times, labels })
+        } else {
+            Ok(Self {
+                times: vec![(-1.0, -1.0); labels.len()],
+                labels,
+            })
+        }
     }
 
-    pub fn get_label(&self, index: usize) -> &jlabel::Label {
-        &self.strings[index].content
+
+    pub fn labels(&self) -> &[jlabel::Label] {
+        &self.labels
     }
-    pub fn get_start_frame(&self, index: usize) -> f64 {
-        self.strings[index].start
+    pub fn times(&self) -> &[(f64, f64)] {
+        &self.times
     }
-    pub fn get_end_frame(&self, index: usize) -> f64 {
-        self.strings[index].end
+}
+
+#[cfg(test)]
+mod tests {
+    use super::Labels;
+
+    #[test]
+    fn with_alignment() {
+        let lines = [
+            "0 14925000 xx^xx-sil+b=o/A:xx+xx+xx/B:xx-xx_xx/C:xx_xx+xx/D:xx+xx_xx/E:xx_xx!xx_xx-xx/F:xx_xx#xx_xx@xx_xx|xx_xx/G:4_4%0_xx_xx/H:xx_xx/I:xx-xx@xx+xx&xx-xx|xx+xx/J:1_4/K:1+1-4",
+            "14925000 16725000 xx^sil-b+o=N/A:-3+1+4/B:xx-xx_xx/C:02_xx+xx/D:xx+xx_xx/E:xx_xx!xx_xx-xx/F:4_4#0_xx@1_1|1_4/G:xx_xx%xx_xx_xx/H:xx_xx/I:1-4@1+1&1-1|1+4/J:xx_xx/K:1+1-4",
+            "16725000 17525000 sil^b-o+N=s/A:-3+1+4/B:xx-xx_xx/C:02_xx+xx/D:xx+xx_xx/E:xx_xx!xx_xx-xx/F:4_4#0_xx@1_1|1_4/G:xx_xx%xx_xx_xx/H:xx_xx/I:1-4@1+1&1-1|1+4/J:xx_xx/K:1+1-4",
+            "17525000 18125000 b^o-N+s=a/A:-2+2+3/B:xx-xx_xx/C:02_xx+xx/D:xx+xx_xx/E:xx_xx!xx_xx-xx/F:4_4#0_xx@1_1|1_4/G:xx_xx%xx_xx_xx/H:xx_xx/I:1-4@1+1&1-1|1+4/J:xx_xx/K:1+1-4",
+            "18125000 19725000 o^N-s+a=i/A:-1+3+2/B:xx-xx_xx/C:02_xx+xx/D:xx+xx_xx/E:xx_xx!xx_xx-xx/F:4_4#0_xx@1_1|1_4/G:xx_xx%xx_xx_xx/H:xx_xx/I:1-4@1+1&1-1|1+4/J:xx_xx/K:1+1-4",
+            "19725000 20825000 N^s-a+i=sil/A:-1+3+2/B:xx-xx_xx/C:02_xx+xx/D:xx+xx_xx/E:xx_xx!xx_xx-xx/F:4_4#0_xx@1_1|1_4/G:xx_xx%xx_xx_xx/H:xx_xx/I:1-4@1+1&1-1|1+4/J:xx_xx/K:1+1-4",
+            "20825000 22725000 s^a-i+sil=xx/A:0+4+1/B:xx-xx_xx/C:02_xx+xx/D:xx+xx_xx/E:xx_xx!xx_xx-xx/F:4_4#0_xx@1_1|1_4/G:xx_xx%xx_xx_xx/H:xx_xx/I:1-4@1+1&1-1|1+4/J:xx_xx/K:1+1-4",
+            "22725000 30325000 a^i-sil+xx=xx/A:xx+xx+xx/B:xx-xx_xx/C:xx_xx+xx/D:xx+xx_xx/E:4_4!0_xx-xx/F:xx_xx#xx_xx@xx_xx|xx_xx/G:xx_xx%xx_xx_xx/H:1_4/I:xx-xx@xx+xx&xx-xx|xx+xx/J:xx_xx/K:1+1-4",
+        ];
+        let labels = Labels::load_from_strings(48000, 240, &lines).unwrap();
+        let times = labels.times();
+
+        let answer = [
+            (0.0, 298.5),
+            (298.5, 334.5),
+            (334.5, 350.5),
+            (350.5, 362.5),
+            (362.5, 394.5),
+            (394.5, 416.5),
+            (416.5, 454.5),
+            (454.5, 606.5),
+        ];
+
+        assert_eq!(times.len(), answer.len());
+
+        for (time, ans) in times.iter().zip(answer) {
+            approx::assert_ulps_eq!(time.0, ans.0);
+            approx::assert_ulps_eq!(time.1, ans.1);
+        }
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 21c39a3..3010ef4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,15 +1,13 @@
 mod constants;
 
+pub mod duration;
 pub mod engine;
-pub mod gstream;
 pub mod label;
+pub mod mlpg_adjust;
 pub mod model;
-pub mod pstream;
-pub mod sstream;
+pub mod speech;
 pub mod vocoder;
 
-pub mod sequence;
-
 #[cfg(test)]
 mod tests {
     use crate::{engine::Engine, model::interporation_weight::Weights};
@@ -37,8 +35,7 @@ mod tests {
 
         let engine = Engine::load(&[MODEL_NITECH_ATR503.to_string()]).unwrap();
 
-        let speech_stream = engine.synthesize_from_strings(&lines);
-        let speech = speech_stream.get_speech();
+        let speech = engine.synthesize_from_strings(&lines).unwrap();
 
         assert_eq!(speech.len(), 66480);
         approx::assert_abs_diff_eq!(speech[2000], 19.35141137623778, epsilon = 1.0e-10);
@@ -59,8 +56,7 @@ mod tests {
         iw.set_parameter(2, Weights::new(&[1.0, 0.0]).unwrap())
             .unwrap();
 
-        let speech_stream = engine.synthesize_from_strings(&lines);
-        let speech = speech_stream.get_speech();
+        let speech = engine.synthesize_from_strings(&lines).unwrap();
 
         assert_eq!(speech.len(), 74880);
         approx::assert_abs_diff_eq!(speech[2000], 2.3158134981607754e-5, epsilon = 1.0e-10);
@@ -102,8 +98,7 @@ mod tests {
 
         let engine = Engine::load(&[MODEL_NITECH_ATR503.to_string()]).unwrap();
 
-        let speech_stream = engine.synthesize_from_strings(&lines);
-        let speech = speech_stream.get_speech();
+        let speech = engine.synthesize_from_strings(&lines).unwrap();
 
         assert_eq!(speech.len(), 100800);
         approx::assert_abs_diff_eq!(speech[2000], 17.15977345625943, epsilon = 1.0e-10);
@@ -119,8 +114,7 @@ mod tests {
         let mut engine = Engine::load(&[MODEL_NITECH_ATR503.to_string()]).unwrap();
         engine.condition.set_speed(1.4);
 
-        let speech_stream = engine.synthesize_from_strings(&lines);
-        let speech = speech_stream.get_speech();
+        let speech = engine.synthesize_from_strings(&lines).unwrap();
 
         assert_eq!(speech.len(), 72000);
         approx::assert_abs_diff_eq!(speech[2000], 15.0481014871396, epsilon = 1.0e-10);
@@ -132,8 +126,7 @@
     #[test]
     fn empty() {
         let engine = Engine::load(&[MODEL_NITECH_ATR503.to_string()]).unwrap();
-        let speech_stream = engine.synthesize_from_strings(&[]);
-        let speech = speech_stream.get_speech();
+        let speech = engine.synthesize_from_strings::<String>(&[]).unwrap();
         assert_eq!(speech.len(), 0);
     }
 }
diff --git a/src/sequence.rs b/src/mlpg_adjust/mask.rs
similarity index 97%
rename from src/sequence.rs
rename to src/mlpg_adjust/mask.rs
index e479a7c..29052b6 100644
--- a/src/sequence.rs
+++ b/src/mlpg_adjust/mask.rs
@@ -2,7 +2,7 @@ pub struct Mask(Vec<bool>);
 
 impl FromIterator<bool> for Mask {
     fn from_iter<I: IntoIterator<Item = bool>>(iter: I) -> Self {
-        Self(iter.into_iter().collect())
+        Self::new(iter.into_iter().collect())
     }
 }
 
@@ -63,7 +63,7 @@ impl Mask {
 
 #[cfg(test)]
 mod tests {
-    use crate::sequence::Mask;
+    use super::Mask;
 
     #[test]
     fn fill() {
diff --git a/src/pstream/mlpg.rs b/src/mlpg_adjust/mlpg.rs
similarity index 99%
rename from src/pstream/mlpg.rs
rename to src/mlpg_adjust/mlpg.rs
index 9642b1a..25b3d56 100644
--- a/src/pstream/mlpg.rs
+++ b/src/mlpg_adjust/mlpg.rs
@@ -1,4 +1,4 @@
-use crate::model::window::Windows;
+use crate::model::Windows;
 
 const W1: f64 = 1.0;
 const W2: f64 = 1.0;
diff --git a/src/mlpg_adjust/mod.rs b/src/mlpg_adjust/mod.rs
new file mode 100644
index 0000000..869fc01
--- /dev/null
+++ b/src/mlpg_adjust/mod.rs
@@ -0,0 +1,121 @@
+use crate::{
+    constants::NODATA,
+    model::{Models, StreamParameter},
+};
+
+mod mask;
+mod mlpg;
+
+use self::{
+    mask::Mask,
+    mlpg::{MlpgGlobalVariance, MlpgMatrix},
+};
+
+pub struct MlpgAdjust {
+    stream_index: usize,
+    gv_weight: f64,
+    msd_threshold: f64,
+}
+
+impl MlpgAdjust {
+    pub fn new(stream_index: usize, gv_weight: f64, msd_threshold: f64) -> Self {
+        Self {
+            stream_index,
+            gv_weight,
+            msd_threshold,
+        }
+    }
+    /// Parameter generation using GV weight
+    pub fn create(
+        &self,
+        stream: StreamParameter,
+        models: &Models,
+        durations: &[usize],
+    ) -> Vec<Vec<f64>> {
+        let vector_length = models.vector_length(self.stream_index);
+
+        let msd_flag: Mask = stream
+            .iter()
+            .zip(durations)
+            .flat_map(|((_, msd), duration)| {
+                let flag = *msd > self.msd_threshold;
+                [flag].repeat(*duration)
+            })
+            .collect();
+
+        let msd_boundaries = msd_flag.boundary_distances();
+
+        let mut pars = Vec::with_capacity(vector_length);
+        for vector_index in 0..vector_length {
+            let parameters: Vec<Vec<(f64, f64)>> = models
+                .windows(self.stream_index)
+                .iter()
+                .enumerate()
+                .map(|(window_index, window)| {
+                    let m = vector_length * window_index + vector_index;
+
+                    let mut iter = msd_flag.mask().iter();
+                    stream
+                        .iter()
+                        .zip(durations)
+                        // get mean and ivar, and spread it to its duration
+                        .flat_map(|((curr_stream, _), duration)| {
+                            let (mean, vari) = curr_stream[m];
+                            let ivar = {
+                                if vari.abs() > 1e19 {
+                                    0.0
+                                } else if vari.abs() < 1e-19 {
+                                    1e38
+                                } else {
+                                    1.0 / vari
+                                }
+                            };
+                            [(mean, ivar)].repeat(*duration)
+                        })
+                        .zip(&msd_boundaries)
+                        .map(|((mean, ivar), (left, right))| {
+                            let is_left_msd_boundary = *left < window.left_width();
+                            let is_right_msd_boundary = *right < window.right_width();
+
+                            // If the window includes non-msd frames, set the ivar to 0.0
+                            if (is_left_msd_boundary || is_right_msd_boundary) && window_index != 0
+                            {
+                                (mean, 0.0)
+                            } else {
+                                (mean, ivar)
+                            }
+                        })
+                        .filter(|_| iter.next() == Some(&true))
+                        .collect()
+                })
+                .collect();
+
+            let mut mtx = MlpgMatrix::new();
+            mtx.calc_wuw_and_wum(models.windows(self.stream_index), parameters);
+
+            let par = if let Some((gv_param, gv_switch)) = models.gv(self.stream_index) {
+                let mtx_before = mtx.clone();
+                let par = mtx.solve();
+
+                let gv_mean = gv_param[vector_index].0 * self.gv_weight;
+                let gv_vari = gv_param[vector_index].1;
+
+                let mut iter = msd_flag.mask().iter();
+                let gv_switch: Vec<bool> = gv_switch
+                    .iter()
+                    .zip(durations)
+                    .flat_map(|(switch, duration)| [*switch].repeat(*duration))
+                    .filter(|_| iter.next() == Some(&true))
+                    .collect();
+
+                MlpgGlobalVariance::new(mtx_before, par, &gv_switch).apply_gv(gv_mean, gv_vari)
+            } else {
+                mtx.solve()
+            };
+
+            pars.push(msd_flag.fill(par, NODATA));
+        }
+
+        pars
+    }
+}
diff --git a/src/model/mod.rs b/src/model/mod.rs
index 0baf901..2e64919 100644
--- a/src/model/mod.rs
+++ b/src/model/mod.rs
@@ -1,16 +1,16 @@
-use std::{fmt::Display, path::Path};
+use std::{borrow::Cow, sync::Arc};
 
-use self::{
-    interporation_weight::Weights,
-    stream::{Model, ModelParameter, StreamModels},
-    window::Windows,
+use self::voice::model::ModelParameter;
+
+pub use self::{
+    interporation_weight::InterporationWeight,
+    voice::{window::Windows, GlobalModelMetadata, StreamModelMetadata, Voice},
 };
+
 use jlabel::Label;
 
 pub mod interporation_weight;
-pub mod question;
-pub mod stream;
-pub mod window;
+pub mod voice;
 
 #[cfg(feature = "htsvoice")]
 mod parser;
@@ -23,350 +23,428 @@ pub enum ModelError {
     MetadataError,
     #[error("Io failed: {0}")]
     Io(#[from] std::io::Error),
+
+    #[cfg(feature = "htsvoice")]
     #[error("Parser returned error:{0}")]
     ParserError(#[from] parser::ModelParseError),
 }
 
-pub struct ModelSet {
-    metadata: GlobalModelMetadata,
-    /// ensured to have at least one element
-    voices: Vec<Voice>,
-}
+pub type StreamParameter = Vec<(Vec<(f64, f64)>, f64)>;
+pub type GvParameter = (Vec<(f64, f64)>, Vec<bool>);
 
-impl Display for ModelSet {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        writeln!(f, "{}", self.metadata)?;
-        for (i, voice) in self.voices.iter().enumerate() {
-            writeln!(f, "Voice #{}:\n{}", i, voice)?;
-        }
-        Ok(())
-    }
-}
+pub struct Models<'a> {
+    labels: Vec