Skip to content

Commit

Permalink
conv axes changes fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Dec 6, 2023
1 parent d21db87 commit b3a0104
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .travis/cli-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ $TRACT_RUN $MODELS/hey_snips_v4_model17.pb \

$TRACT_RUN $MODELS/hey_snips_v4_model17.pb -i S,20,f32 \
dump -q \
--assert-op-count AddAxis 0
--assert-op-count AddAxis 1

$TRACT_RUN $MODELS/en_libri_real/model.onnx \
--output-node output \
Expand Down
13 changes: 12 additions & 1 deletion core/src/ops/cnn/conv/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -943,9 +943,20 @@ impl TypedOp for ConvUnary {
&self,
model: &TypedModel,
node: &TypedNode,
_io: InOut,
io: InOut,
change: &AxisOp,
) -> TractResult<Option<AxisChangeConsequence>> {
if io == InOut::In(2) {
if let &AxisOp::Rm(_) = change {
return Ok(Some(AxisChangeConsequence {
substitute_op: Some(Box::new(self.clone())),
wire_changes: tvec!(),
}));
}
}
if io != InOut::In(0) && io != InOut::Out(0) {
return Ok(None);
}
let full_input_shape = model.outlet_fact(node.inputs[0])?.shape.to_tvec();
let shape = self.pool_spec.data_format.shape(full_input_shape.clone())?;
// remove n
Expand Down

0 comments on commit b3a0104

Please sign in to comment.