Commit d235ba0

Merge branch 'ar/fix-oversampling' into 'master'
Fix oversampling

See merge request machine-learning/modkit!209
ArtRand committed Sep 7, 2024
2 parents 5ed80b7 + 69d0449 commit d235ba0
Showing 9 changed files with 410 additions and 115 deletions.
2 changes: 1 addition & 1 deletion src/commands.rs
@@ -981,7 +981,7 @@ impl ModSummarize {
} else {
// calculate the filter thresholds at the requested percentile
let pct = (self.filter_percentile * 100f32).floor();
info!("calculating threshold at {pct}% percentile");
info!("calculating threshold at {pct}(th) percentile");
calc_thresholds_per_base(
&read_ids_to_base_mod_calls,
self.filter_percentile,
23 changes: 12 additions & 11 deletions src/entropy/mod.rs
@@ -1255,7 +1255,8 @@ impl DescriptiveStats {
) -> String {
use crate::util::TAB;

format!("\
format!(
"\
{chrom}{TAB}\
{start}{TAB}\
{end}{TAB}\
@@ -1270,16 +1271,16 @@ impl DescriptiveStats {
{}{TAB}\
{}{TAB}\
{}\n",
self.mean_entropy,
strand.to_char(),
self.median_entropy,
self.min_entropy,
self.max_entropy,
self.mean_num_reads,
self.min_num_reads,
self.max_num_reads,
self.successful_count,
self.failed_count
self.mean_entropy,
strand.to_char(),
self.median_entropy,
self.min_entropy,
self.max_entropy,
self.mean_num_reads,
self.min_num_reads,
self.max_num_reads,
self.successful_count,
self.failed_count
)
}
}
14 changes: 14 additions & 0 deletions src/interval_chunks.rs
@@ -440,6 +440,20 @@ impl ChromCoordinates {
pub(crate) fn len(&self) -> u32 {
self.end_pos.checked_sub(self.start_pos).unwrap_or(0u32)
}

pub(crate) fn merge(self, other: Self) -> Self {
match (&self.focus_positions, &other.focus_positions) {
(FocusPositions::AllPositions, FocusPositions::AllPositions) => {}
_ => todo!("must be 'AllPositions' to merge"),
}
assert_eq!(self.chrom_tid, other.chrom_tid);
Self {
chrom_tid: self.chrom_tid,
start_pos: std::cmp::min(self.start_pos, other.start_pos),
end_pos: std::cmp::max(self.end_pos, other.end_pos),
focus_positions: FocusPositions::AllPositions,
}
}
}

#[derive(new)]
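The new merge method joins two windows on the same contig into their spanning
interval, and is only defined when both sides cover all positions (merging
focus-position sets is left as a todo!). A minimal sketch of the intended
behavior, with hypothetical coordinates and the field layout shown in the
diff:

// Merging [100, 200) with [150, 300) on contig 0 yields the spanning
// interval [100, 300); mismatched contigs would trip the assert_eq!.
let a = ChromCoordinates {
    chrom_tid: 0,
    start_pos: 100,
    end_pos: 200,
    focus_positions: FocusPositions::AllPositions,
};
let b = ChromCoordinates {
    chrom_tid: 0,
    start_pos: 150,
    end_pos: 300,
    focus_positions: FocusPositions::AllPositions,
};
let merged = a.merge(b);
assert_eq!((merged.start_pos, merged.end_pos), (100, 300));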
27 changes: 27 additions & 0 deletions src/monoid.rs
@@ -212,6 +212,33 @@
}
}

impl<A> Moniod for FxHashMap<A, usize>
where
A: Eq + Hash,
{
fn zero() -> Self {
FxHashMap::default()
}

fn op(self, other: Self) -> Self {
let mut agg = self;
for (k, v) in other {
*agg.entry(k).or_insert(0usize) += v;
}
agg
}

fn op_mut(&mut self, other: Self) {
for (k, v) in other {
*self.entry(k).or_insert(0usize) += v;
}
}

fn len(&self) -> usize {
self.len()
}
}

impl<A, B> Moniod for FxHashMap<A, B>
where
A: Eq + Hash,
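This new impl makes per-key counters combinable as a monoid: op and op_mut add
the counts of keys present on both sides instead of overwriting them, and it
sits alongside the existing generic FxHashMap impl below. A usage sketch with
hypothetical tallies (assumes crate::monoid::Moniod is in scope):

use rustc_hash::FxHashMap;

// Counts for the shared key (1) are summed: 5 + 7 = 12.
let a: FxHashMap<u32, usize> = [(0u32, 10usize), (1, 5)].into_iter().collect();
let b: FxHashMap<u32, usize> = [(1u32, 7usize), (2, 3)].into_iter().collect();
let combined = a.op(b);
assert_eq!(combined.get(&1), Some(&12));
assert_eq!(combined.get(&0), Some(&10));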
13 changes: 3 additions & 10 deletions src/read_ids_to_base_mod_probs.rs
@@ -193,16 +193,9 @@ impl Moniod for ReadIdsToBaseModProbs {
}

fn op(self, other: Self) -> Self {
let mut acc = self.inner;
for (read_id, base_mod_calls) in other.inner {
if acc.contains_key(&read_id) {
continue;
} else {
acc.insert(read_id, base_mod_calls);
}
}

Self { inner: acc }
let mut this = self;
this.op_mut(other);
this
}

fn op_mut(&mut self, other: Self) {
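The rewritten op drops the duplicated first-writer-wins loop and delegates to
op_mut, so the read-id merge rules live in one place. Taken further, an op
that always delegates could be a provided trait method; a sketch under that
assumption (trait shape inferred from the impls in this commit, not modkit's
actual definition):

trait Moniod: Sized {
    fn zero() -> Self;
    fn op_mut(&mut self, other: Self);
    // Provided method: any impl defining op_mut gets op for free.
    fn op(mut self, other: Self) -> Self {
        self.op_mut(other);
        self
    }
    fn len(&self) -> usize;
}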
154 changes: 101 additions & 53 deletions src/reads_sampler/mod.rs
@@ -1,23 +1,30 @@
pub(crate) mod record_sampler;
pub(crate) mod sampling_schedule;
use std::path::PathBuf;

use anyhow::anyhow;
use indicatif::{MultiProgress, ProgressBar};
use itertools::Itertools;
use log::debug;
use prettytable::row;
use rayon::prelude::*;
use rust_htslib::bam::{self, Read};
use rustc_hash::FxHashMap;

use record_sampler::RecordSampler;

use crate::interval_chunks::{MultiChromCoordinates, ReferenceIntervalsFeeder};
use crate::interval_chunks::{ChromCoordinates, ReferenceIntervalsFeeder};
use crate::mod_bam::{CollapseMethod, EdgeFilter};
use crate::monoid::Moniod;
use crate::position_filter::StrandedPositionFilter;
use crate::reads_sampler::sampling_schedule::SamplingSchedule;
use crate::reads_sampler::sampling_schedule::{
CountOrSample, SamplingSchedule,
};
use crate::record_processor::{RecordProcessor, WithRecords};
use crate::util::{
get_master_progress_bar, get_subroutine_progress_bar, get_targets,
get_ticker, ReferenceRecord, Region,
get_master_progress_bar, get_targets, get_ticker, ReferenceRecord, Region,
};
use anyhow::anyhow;
use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar};
use log::debug;
use rayon::prelude::*;
use record_sampler::RecordSampler;
use rust_htslib::bam::{self, Read};
use std::path::PathBuf;

pub(crate) mod record_sampler;
pub(crate) mod sampling_schedule;

pub(crate) fn get_sampled_read_ids_to_base_mod_probs<P: RecordProcessor>(
bam_fp: &PathBuf,
@@ -177,6 +184,11 @@ where
})
.collect::<Vec<ReferenceRecord>>();

let contig_sizes = contigs
.iter()
.map(|rec| (rec.tid, rec.length))
.collect::<FxHashMap<u32, u32>>();

let feeder = ReferenceIntervalsFeeder::new(
contigs,
batch_size,
@@ -194,50 +206,57 @@
}
let tid_progress =
master_progress.add(get_master_progress_bar(feeder.total_length()));
tid_progress.set_message("genome positions");
tid_progress.set_message("total bp");

let sampled_items = master_progress.add(get_ticker());
sampled_items.set_message("base mod calls sampled");
// end prog bar stuff

let mut aggregator = <P::Output as Moniod>::zero();
let mut reads_sampled_per_chr = FxHashMap::default();
for super_batch in feeder {
let total_batch_length =
super_batch.iter().map(|c| c.total_length()).sum::<u64>();
let batch_progress =
master_progress.add(get_subroutine_progress_bar(super_batch.len()));
debug!("batch has total length {total_batch_length}");
batch_progress.set_message("interval batches in progress");
let super_batch_result = super_batch
.into_par_iter()
.progress_with(batch_progress)
.map(|multi_coords| {
run_batch::<P>(
bam_fp,
multi_coords,
total_batch_length as u32,
sampling_schedule,
collapse_method,
edge_filter,
position_filter,
only_mapped,
false,
None,
&sampled_items,
)
})
.flatten()
.reduce(|| <P::Output as Moniod>::zero(), |a, b| a.op(b));
let super_batch_with_counts = sampling_schedule
.accumulate_sample_counts(
super_batch,
&contig_sizes,
&reads_sampled_per_chr,
batch_size,
);
let (super_batch_result, chrom_counts_for_batch) =
super_batch_with_counts
.into_par_iter()
.map(|multi_coords| {
run_batch::<P>(
bam_fp,
multi_coords,
sampling_schedule,
collapse_method,
edge_filter,
position_filter,
only_mapped,
false,
None,
&sampled_items,
)
})
.reduce(
|| (<P::Output as Moniod>::zero(), FxHashMap::zero()),
|(a, x), (b, y)| (a.op(b), x.op(y)),
);
tid_progress.inc(total_batch_length);
aggregator.op_mut(super_batch_result);
reads_sampled_per_chr.op_mut(chrom_counts_for_batch);
}
log_sampled_reads(&reads_sampled_per_chr);

Ok(aggregator)
}
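The sampling loop now threads reads_sampled_per_chr through every super-batch:
accumulate_sample_counts can weigh each interval's quota against the contig
sizes and the reads already taken, and each batch's per-chromosome counts are
folded back into the running tally with op_mut, which is presumably the
oversampling fix named in the commit title. The accumulation step in
isolation (hypothetical counts, with the count-map op_mut written out as a
free function):

use rustc_hash::FxHashMap;

// Same merge rule as the FxHashMap<A, usize> Moniod impl in src/monoid.rs.
fn op_mut(acc: &mut FxHashMap<u32, usize>, other: FxHashMap<u32, usize>) {
    for (k, v) in other {
        *acc.entry(k).or_insert(0usize) += v;
    }
}

fn main() {
    let mut reads_sampled_per_chr: FxHashMap<u32, usize> = FxHashMap::default();
    let batches: Vec<FxHashMap<u32, usize>> = vec![
        [(0u32, 50usize), (1, 25)].into_iter().collect(),
        [(0u32, 30usize), (2, 10)].into_iter().collect(),
    ];
    for counts in batches {
        op_mut(&mut reads_sampled_per_chr, counts);
    }
    // Contig 0 contributed reads in both batches: 50 + 30.
    assert_eq!(reads_sampled_per_chr.get(&0), Some(&80));
}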

fn run_batch<P: RecordProcessor>(
bam_fp: &PathBuf,
batch: MultiChromCoordinates,
total_batch_length: u32,
batch: Vec<(ChromCoordinates, CountOrSample)>,
sampling_schedule: &SamplingSchedule,
collapse_method: Option<&CollapseMethod>,
edge_filter: Option<&EdgeFilter>,
@@ -246,12 +265,14 @@
allow_non_primary: bool,
kmer_size: Option<usize>,
sampled_items_counter: &ProgressBar,
) -> Vec<P::Output> {
) -> (P::Output, FxHashMap<u32, usize>)
where
P::Output: Moniod,
{
batch
.0
.into_par_iter()
.filter(|cc| sampling_schedule.chrom_has_reads(cc.chrom_tid))
.filter(|cc| {
.filter(|(cc, _)| sampling_schedule.chrom_has_reads(cc.chrom_tid))
.filter(|(cc, _)| {
position_filter
.map(|pf| {
pf.overlaps_not_stranded(
@@ -262,13 +283,15 @@
})
.unwrap_or(true)
})
.filter_map(|cc| {
let record_sampler = sampling_schedule.get_record_sampler(
cc.chrom_tid,
total_batch_length,
cc.start_pos,
cc.end_pos,
);
.filter_map(|(cc, counts_or_sample)| {
let record_sampler = match counts_or_sample {
CountOrSample::Count(x) => RecordSampler::new_num_reads(x),
CountOrSample::Sample(x) => {
RecordSampler::new_sample_frac(x as f64, None)
}
CountOrSample::All => RecordSampler::new_passthrough(),
};

match sample_reads_from_interval::<P>(
bam_fp,
cc.chrom_tid,
@@ -285,7 +308,7 @@
) {
Ok(res) => {
sampled_items_counter.inc(res.size());
Some(res)
Some((res, cc.chrom_tid))
}
Err(e) => {
debug!(
@@ -299,7 +322,17 @@
}
}
})
.collect()
.fold(
|| (<P::Output as Moniod>::zero(), FxHashMap::zero()),
|(agg, mut counter), (out, chrom_tid)| {
*counter.entry(chrom_tid).or_insert(0usize) += out.len();
(agg.op(out), counter)
},
)
.reduce(
|| (<P::Output as Moniod>::zero(), FxHashMap::zero()),
|(a, x), (b, y)| (a.op(b), x.op(y)),
)
}
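run_batch now returns a pair, the processor output plus a per-chromosome read
count, and rayon combines the pairs component-wise: fold builds one pair per
worker, reduce merges the per-worker pairs with the monoid op. A
self-contained sketch of that pattern on plain numbers (hypothetical data,
not modkit types):

use rayon::prelude::*;
use rustc_hash::FxHashMap;

fn main() {
    // (chrom_tid, reads sampled) pairs standing in for interval results.
    let items: Vec<(u32, usize)> = vec![(0, 3), (1, 2), (0, 5), (2, 1)];
    let (total, per_chrom) = items
        .into_par_iter()
        // fold builds one (sum, counter) accumulator per rayon worker...
        .fold(
            || (0usize, FxHashMap::<u32, usize>::default()),
            |(sum, mut counts), (tid, n)| {
                *counts.entry(tid).or_insert(0usize) += n;
                (sum + n, counts)
            },
        )
        // ...and reduce merges the per-worker accumulators component-wise.
        .reduce(
            || (0usize, FxHashMap::default()),
            |(a, mut x), (b, y)| {
                for (k, v) in y {
                    *x.entry(k).or_insert(0usize) += v;
                }
                (a + b, x)
            },
        );
    assert_eq!(total, 11);
    assert_eq!(per_chrom.get(&0), Some(&8));
}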

pub(crate) fn sample_reads_from_interval<P: RecordProcessor>(
@@ -339,3 +372,18 @@
kmer_size,
)
}

fn log_sampled_reads(sampled_reads_per_chr: &FxHashMap<u32, usize>) {
let mut tab = prettytable::Table::new();
tab.set_format(*prettytable::format::consts::FORMAT_CLEAN);
let mut total = 0usize;
sampled_reads_per_chr.iter().sorted_by(|(_, x), (_, y)| y.cmp(x)).for_each(
|(chr, count)| {
tab.add_row(row![chr, count]);
total += *count;
},
);

tab.add_row(row!["total", total]);
debug!("final mapped reads sampled:\n{tab}");
}
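prettytable's FORMAT_CLEAN renders a borderless table, so with hypothetical
counts the debug output above would look roughly like:

 0      120
 1      80
 total  200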