Skip to content

Commit

Permalink
Merge pull request #392 from wjdanswjddl/feature/dnnsp-mask_negcharge
Browse files Browse the repository at this point in the history
Feature/dnnsp mask negcharge
  • Loading branch information
HaiwangYu authored Mar 6, 2025
2 parents 5053f40 + f4f5bab commit be78c0a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
7 changes: 5 additions & 2 deletions pytorch/inc/WireCellPytorch/DNNROIFinding.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace WireCell {
// probability-like value. It is tuned to balance efficiency and
// noise reduction and strictly is best optimized for the given
// model.
double mask_thresh{0.7};
double mask_thresh{0.5};

// The IForward service to use
std::string forward{"TorchService"};
Expand Down Expand Up @@ -87,6 +87,9 @@ namespace WireCell {
std::string outtag{""};

int nchunks{1};

// if true, save the negative parts of the charge traces
bool save_negative_charge{false};
};

class DNNROIFinding : public Aux::Logger,
Expand Down Expand Up @@ -133,7 +136,7 @@ namespace WireCell {
IFrame::trace_summary_t get_summary_e(const IFrame::pointer& inframe, const std::string &tag) const;

// Convert dense array to (dense) traces
ITrace::shared_vector eigen_to_traces(const Array::array_xxf& arr);
ITrace::shared_vector eigen_to_traces(const Array::array_xxf& arr, bool save_negative_charge);

int m_save_count; // count frames saved
};
Expand Down
9 changes: 7 additions & 2 deletions pytorch/src/DNNROIFinding.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ void Pytorch::DNNROIFinding::configure(const WireCell::Configuration& cfg)
m_cfg.outtag = get(cfg, "outtag", m_cfg.outtag);
m_cfg.debugfile = get(cfg, "debugfile", m_cfg.debugfile);
m_cfg.nchunks = get(cfg, "nchunks", m_cfg.nchunks);
m_cfg.save_negative_charge = get(cfg, "save_negative_charge", m_cfg.save_negative_charge);

m_nrows = m_chlist.size();
m_ncols = m_cfg.nticks;
Expand Down Expand Up @@ -163,6 +164,7 @@ WireCell::Configuration Pytorch::DNNROIFinding::default_configuration() const
cfg["outtag"] = m_cfg.outtag;
cfg["debugfile"] = m_cfg.debugfile;
cfg["nchunks"] = m_cfg.nchunks;
cfg["save_negative_charge"] = m_cfg.save_negative_charge;
return cfg;
}

Expand Down Expand Up @@ -208,14 +210,17 @@ IFrame::trace_summary_t Pytorch::DNNROIFinding::get_summary_e(const IFrame::poin
return summary_e;
}

ITrace::shared_vector Pytorch::DNNROIFinding::eigen_to_traces(const Array::array_xxf& arr)
ITrace::shared_vector Pytorch::DNNROIFinding::eigen_to_traces(const Array::array_xxf& arr, bool save_negative_charge)
{
ITrace::vector traces;
ITrace::ChargeSequence charge(m_ncols, 0.0);
for (size_t irow = 0; irow < m_nrows; ++irow) {
auto wave = arr.row(irow);
for (size_t icol=0; icol<m_ncols; ++icol) {
charge[icol] = wave(icol);
if (!save_negative_charge) { // set negative charge to zero
charge[icol] = charge[icol] < 0 ? 0 : charge[icol];
}
}
const auto ch = m_chlist[irow];
traces.push_back(std::make_shared<Aux::SimpleTrace>(ch, 0, charge));
Expand Down Expand Up @@ -322,7 +327,7 @@ bool Pytorch::DNNROIFinding::operator()(const IFrame::pointer& inframe, IFrame::
#endif

// eigen to frame
auto traces = eigen_to_traces(sp_charge);
auto traces = eigen_to_traces(sp_charge, m_cfg.save_negative_charge);
Aux::SimpleFrame* sframe = new Aux::SimpleFrame(
inframe->ident(), inframe->time(),
traces,
Expand Down

0 comments on commit be78c0a

Please sign in to comment.