Skip to content

Commit

Permalink
[InsertFIFO] Preserve onnx tensor dtype when inserting FIFOs
Browse files Browse the repository at this point in the history
  • Loading branch information
auphelia committed Sep 19, 2024
1 parent d575f4c commit ec5613c
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/finn/transformation/fpgadataflow/insert_fifo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

import numpy as np
import warnings
from onnx import TensorProto
from onnx import helper as oh
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.base import Transformation
Expand Down Expand Up @@ -114,6 +113,8 @@ def apply(self, model):
# determine fifo node attributes
fld_shape = n0.get_folded_output_shape()
dtype = n0.get_output_datatype()
n0_otensor = model.get_tensor_valueinfo(output_name)
n0_tensor_dtype = n0_otensor.type.tensor_type.elem_type

# check if folded_shape of output of first node and
# input of the second node is equal
Expand Down Expand Up @@ -145,7 +146,7 @@ def apply(self, model):
# or unless create_shallow_fifos is specified
fifo_output_tensor = oh.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
n0_tensor_dtype,
n0.get_normal_output_shape(),
)
graph.value_info.append(fifo_output_tensor)
Expand Down Expand Up @@ -196,13 +197,15 @@ def apply(self, model):
fld_shape = n0.get_folded_input_shape(inp_ind)
n_shape = n0.get_normal_input_shape(inp_ind)
dtype = n0.get_input_datatype(inp_ind)
n0_itensor = model.get_tensor_valueinfo(graph_in_name)
n0_tensor_dtype = n0_itensor.type.tensor_type.elem_type
fifo_depth = n0.get_nodeattr("inFIFODepths")[inp_ind]

if fifo_depth > 2 or self.create_shallow_fifos:
# create fifo node
fifo_output_tensor = oh.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
n0_tensor_dtype,
n0.get_normal_input_shape(inp_ind),
)
graph.value_info.append(fifo_output_tensor)
Expand Down Expand Up @@ -256,13 +259,15 @@ def apply(self, model):
fld_shape = n0.get_folded_output_shape(out_ind)
n_shape = n0.get_normal_output_shape(out_ind)
dtype = n0.get_output_datatype(out_ind)
n0_otensor = model.get_tensor_valueinfo(graph_out_name)
n0_tensor_dtype = n0_otensor.type.tensor_type.elem_type
fifo_depth = n0.get_nodeattr("outFIFODepths")[out_ind]

if fifo_depth > 2 or self.create_shallow_fifos:
# create fifo node
fifo_input_tensor = oh.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
n0_tensor_dtype,
n0.get_normal_output_shape(),
)
graph.value_info.append(fifo_input_tensor)
Expand Down

0 comments on commit ec5613c

Please sign in to comment.