Skip to content

Commit

Permalink
Merge pull request #812 from i-colbert/fix/min_acc_bw
Browse files Browse the repository at this point in the history
Fix: clean up MinimizeAccumulatorWidth logic for MVAU and VVAU
  • Loading branch information
auphelia authored Aug 4, 2023
2 parents db02614 + f713ab0 commit 99f61fc
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 88 deletions.
80 changes: 38 additions & 42 deletions src/finn/custom_op/fpgadataflow/matrixvectoractivation.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,11 +589,14 @@ def minimize_accumulator_width(self, model):
# for the bipolar case they need to be converted to bipolar
if self.get_nodeattr("binaryXnorMode"):
weights = 2 * weights - 1

thresholds = None
if len(self.onnx_node.input) > 2:
thresholds = model.get_initializer(self.onnx_node.input[2])
else:
thresholds = None

idt = self.get_input_datatype()

(acc_min, acc_max) = calculate_matvec_accumulator_range(weights, idt)
# if runtime-writeable weights, then the values of the weights can
# change and we need to use the worst-case values from the datatypes
if self.get_nodeattr("runtime_writeable_weights"):
Expand All @@ -604,11 +607,7 @@ def minimize_accumulator_width(self, model):
upper_range = calculate_matvec_accumulator_range(upper_worst, idt)
acc_min = min(min(lower_range), min(upper_range))
acc_max = max(max(upper_range), max(upper_range))
# if not runtime-writeable weights, then we can calculate the min
# and max values of the accumulation range using knowledge of the
# weights and input data types since they are fixed
else:
(acc_min, acc_max) = calculate_matvec_accumulator_range(weights, idt)

# if the thresholds can be used to determine range, then adjust the range
# according to the known values of the thresholds
if thresholds is not None:
Expand All @@ -617,53 +616,50 @@ def minimize_accumulator_width(self, model):
min_threshold = thresholds.min()
max_threshold = thresholds.max()
# clip threshold values
clip_upper = None
clip_lower = None
if max_threshold > acc_max + 1:
clip_upper = acc_max + 1
if min_threshold < acc_min:
clip_lower = acc_min
if (clip_lower is not None) or (clip_upper is not None):
if max_threshold > acc_max or min_threshold < acc_min:
warnings.warn("Clipping some thresholds in %s" % self.onnx_node.name)
thresholds = np.clip(thresholds, clip_lower, clip_upper)
thresholds = np.clip(thresholds, acc_min, acc_max)
model.set_initializer(self.onnx_node.input[2], thresholds)
threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
min_threshold = thresholds.min()
max_threshold = thresholds.max()
# get range required by threshold values
tdt_min = min(acc_min, min_threshold)
tdt_max = max(acc_max, max_threshold)
if tdt_min < 0:
if abs(tdt_min) > tdt_max:
tdt = DataType.get_smallest_possible(tdt_min)
else:
tdt = DataType.get_smallest_possible(-tdt_max - 1)
else:
tdt = DataType.get_smallest_possible(tdt_max)
assert np.vectorize(tdt.allowed)(
acc_min = min(min_threshold, acc_min)
acc_max = max(max_threshold, acc_max)

# if the acc_range is always non-negative (acc_min >= 0), then acc_max <= 2^P - 1
if acc_min >= 0:
acc_bit_width = np.log2(acc_max + 1)
acc_bit_width = math.ceil(acc_bit_width)
adt = DataType[f"UINT{acc_bit_width}"]
# if the acc_range is signed, then acc_min >= -2^{P-1} and acc_max <=
# 2^{P - 1} - 1, which means 2^{P - 1} >= max(-acc_min, 1 + acc_max)
else:
_acc_max = max(-acc_min, 1 + acc_max)
acc_bit_width = np.log2(_acc_max) + 1
acc_bit_width = math.ceil(acc_bit_width)
adt = DataType[f"INT{acc_bit_width}"]

# if activation, assert that the thresholds can be expressed with adt
if thresholds is not None:
assert np.vectorize(adt.allowed)(
threshold_tensor
).all(), "Thresholds in %s can't be expressed with type %s" % (
self.onnx_node.name,
str(tdt),
str(adt),
)
adt = tdt # Set activation datatype to the threshold datatype
else:
if acc_min < 0:
if abs(acc_min) > acc_max:
adt = DataType.get_smallest_possible(acc_min)
else:
adt = DataType.get_smallest_possible(-acc_max - 1)
else:
adt = DataType.get_smallest_possible(acc_max)
# if this is the last node in the graph, then ensure the datatype is
# divisible by 8 bits
if model.find_direct_successors(self.onnx_node) is None:
bw = roundup_to_integer_multiple(adt.bitwidth(), 8)
new_adt_name = adt.name.replace(str(adt.bitwidth()), str(bw))
adt = DataType[new_adt_name]

# if no activation, output and accumulator datatypes are the same
if self.get_nodeattr("noActivation"):
# if this is the last node in the graph, then ensure the datatype is
# divisible by 8 bits
if model.find_direct_successors(self.onnx_node) is None:
bw = roundup_to_integer_multiple(adt.bitwidth(), 8)
new_adt_name = adt.name.replace(str(adt.bitwidth()), str(bw))
adt = DataType[new_adt_name]
# for no-activation nodes, output dt = acc dt
self.set_nodeattr("outputDataType", adt.name)
self.set_nodeattr("accDataType", adt.name)

return DataType[self.get_nodeattr("accDataType")]

def minimize_weight_bit_width(self, model):
Expand Down
75 changes: 35 additions & 40 deletions src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def minimize_accumulator_width(self, model):
else:
thresholds = None
idt = self.get_input_datatype()

(acc_min, acc_max) = calculate_matvec_accumulator_range(weights, idt)
# if runtime-writeable weights, then the values of the weights can
# change and we need to use the worst-case values from the datatypes
if self.get_nodeattr("runtime_writeable_weights"):
Expand All @@ -131,11 +133,7 @@ def minimize_accumulator_width(self, model):
upper_range = calculate_matvec_accumulator_range(upper_worst, idt)
acc_min = min(min(lower_range), min(upper_range))
acc_max = max(max(upper_range), max(upper_range))
# if not runtime-writeable weights, then we can calculate the min
# and max values of the accumulation range using knowledge of the
# weights and input data types since they are fixed
else:
(acc_min, acc_max) = calculate_matvec_accumulator_range(weights, idt)

# if the thresholds can be used to determine range, then adjust the range
# according to the known values of the thresholds
if thresholds is not None:
Expand All @@ -144,53 +142,50 @@ def minimize_accumulator_width(self, model):
min_threshold = thresholds.min()
max_threshold = thresholds.max()
# clip threshold values
clip_upper = None
clip_lower = None
if max_threshold > acc_max + 1:
clip_upper = acc_max + 1
if min_threshold < acc_min:
clip_lower = acc_min
if (clip_lower is not None) or (clip_upper is not None):
if max_threshold > acc_max or min_threshold < acc_min:
warnings.warn("Clipping some thresholds in %s" % self.onnx_node.name)
thresholds = np.clip(thresholds, clip_lower, clip_upper)
thresholds = np.clip(thresholds, acc_min, acc_max)
model.set_initializer(self.onnx_node.input[2], thresholds)
threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
min_threshold = thresholds.min()
max_threshold = thresholds.max()
# get range required by threshold values
tdt_min = min(acc_min, min_threshold)
tdt_max = max(acc_max, max_threshold)
if tdt_min < 0:
if abs(tdt_min) > tdt_max:
tdt = DataType.get_smallest_possible(tdt_min)
else:
tdt = DataType.get_smallest_possible(-tdt_max - 1)
else:
tdt = DataType.get_smallest_possible(tdt_max)
assert np.vectorize(tdt.allowed)(
acc_min = min(min_threshold, acc_min)
acc_max = max(max_threshold, acc_max)

# if the acc_range is always non-negative (acc_min >= 0), then acc_max <= 2^P - 1
if acc_min >= 0:
acc_bit_width = np.log2(acc_max + 1)
acc_bit_width = math.ceil(acc_bit_width)
adt = DataType[f"UINT{acc_bit_width}"]
# if the acc_range is signed, then acc_min >= -2^{P-1} and acc_max <=
# 2^{P - 1} - 1, which means 2^{P - 1} >= max(-acc_min, 1 + acc_max)
else:
_acc_max = max(-acc_min, 1 + acc_max)
acc_bit_width = np.log2(_acc_max) + 1
acc_bit_width = math.ceil(acc_bit_width)
adt = DataType[f"INT{acc_bit_width}"]

# if activation, assert that the thresholds can be expressed with adt
if thresholds is not None:
assert np.vectorize(adt.allowed)(
threshold_tensor
).all(), "Thresholds in %s can't be expressed with type %s" % (
self.onnx_node.name,
str(tdt),
str(adt),
)
adt = tdt # Set activation datatype to the threshold datatype
else:
if acc_min < 0:
if abs(acc_min) > acc_max:
adt = DataType.get_smallest_possible(acc_min)
else:
adt = DataType.get_smallest_possible(-acc_max - 1)
else:
adt = DataType.get_smallest_possible(acc_max)
# if this is the last node in the graph, then ensure the datatype is
# divisible by 8 bits
if model.find_direct_successors(self.onnx_node) is None:
bw = roundup_to_integer_multiple(adt.bitwidth(), 8)
new_adt_name = adt.name.replace(str(adt.bitwidth()), str(bw))
adt = DataType[new_adt_name]

# if no activation, output and accumulator datatypes are the same
if self.get_nodeattr("noActivation"):
# if this is the last node in the graph, then ensure the datatype is
# divisible by 8 bits
if model.find_direct_successors(self.onnx_node) is None:
bw = roundup_to_integer_multiple(adt.bitwidth(), 8)
new_adt_name = adt.name.replace(str(adt.bitwidth()), str(bw))
adt = DataType[new_adt_name]
# for no-activation nodes, output dt = acc dt
self.set_nodeattr("outputDataType", adt.name)
self.set_nodeattr("accDataType", adt.name)

return DataType[self.get_nodeattr("accDataType")]

def minimize_weight_bit_width(self, model):
Expand Down
14 changes: 8 additions & 6 deletions tests/fpgadataflow/test_minimize_bit_width.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,14 @@ def test_minimize_accumulator_width(wdt: DataType, idt: DataType, tdt: DataType,
# less than or equal to this calculation
exp_adt = calculate_accumulator_bit_width(inst, model)
assert cur_adt.bitwidth() <= exp_adt.bitwidth(), "Mismatched accumulation data types"
if model.find_direct_successors(inst.onnx_node) is None:
assert (
cur_adt.bitwidth() % 8
) == 0, "bit width of last node needs to be divisible by 8"

# if there is no activation, outputDataType = accDataType and if it is the last node
# it needs to be divisible by 8
if inst.get_nodeattr("noActivation"):
assert (
cur_adt.bitwidth() == cur_odt.bitwidth()
), "outputDataType and accDataType should be equal"
else:
assert cur_odt.bitwidth() == idt.bitwidth(), "outputDataType should not be changed"
if model.find_direct_successors(inst.onnx_node) is None:
assert (
cur_adt.bitwidth() % 8
) == 0, "bit width of last node needs to be divisible by 8"

0 comments on commit 99f61fc

Please sign in to comment.