diff --git a/tests/fpgadataflow/test_minimize_bit_width.py b/tests/fpgadataflow/test_minimize_bit_width.py index 4be0a260b7..0e704230e7 100644 --- a/tests/fpgadataflow/test_minimize_bit_width.py +++ b/tests/fpgadataflow/test_minimize_bit_width.py @@ -296,8 +296,13 @@ def test_minimize_accumulator_width(wdt: DataType, idt: DataType, tdt: DataType, exp_adt = calculate_accumulator_bit_width(inst, model) assert cur_adt.bitwidth() <= exp_adt.bitwidth(), "Mismatched accumulation data types" - # if there is no activation, outputDataType = accDataType + # if there is no activation, outputDataType = accDataType and if it is the last node + # it needs to be divisible by 8 if inst.get_nodeattr("noActivation"): assert ( cur_adt.bitwidth() == cur_odt.bitwidth() ), "outputDataType and accDataType should be equal" + if model.find_direct_successors(inst.onnx_node) is None: + assert ( + cur_adt.bitwidth() % 8 + ) == 0, "bit width of last node needs to be divisible by 8"