
Relu merge optimizer pass #586

Open
wants to merge 8 commits into base: main
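In brief: this pass fuses a trailing ReLU activation into the preceding Dense/Conv2D (or Conv2DBatchnorm) layer, so the HLS-side dense/conv implementations can apply the ReLU in place, gated by the new merged_relu config flag, instead of keeping a separate Activation node in the graph. A minimal sketch of the kind of model the pass targets (the model and layer names below are illustrative, not taken from this PR):

```python
import tensorflow as tf

# A Conv2D and a Dense layer, each followed by a standalone ReLU Activation
# layer. After conversion, the merge_relu pass should absorb each ReLU into
# the preceding layer and drop the Activation node.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(4, (3, 3), name='conv1'),
    tf.keras.layers.Activation('relu', name='relu1'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, name='dense1'),
    tf.keras.layers.Activation('relu', name='relu2'),
])
```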
13 changes: 13 additions & 0 deletions hls4ml/backends/vivado/passes/convolution_templates.py
@@ -10,9 +10,11 @@
    static const unsigned n_out = {n_out};
    static const unsigned reuse_factor = {reuse};
    static const unsigned strategy = nnet::{strategy};
    static const bool merged_relu = {merged_relu};
    typedef {accum_t.name} accum_t;
    typedef {bias_t.name} bias_t;
    typedef {weight_t.name} weight_t;
    typedef {out_t} out_t;
    template<class x_T, class y_T>
    using product = nnet::product::{product_type}<x_T, y_T>;
}};\n"""
@@ -66,6 +68,8 @@ def format(self, node):
        mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width')
        mult_params['n_out'] = node.get_attr('n_filt')
        mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
        mult_params['merged_relu'] = "true" if node.get_merged_relu() else "false"
        mult_params['out_t'] = node.get_output_variable().type.name
        mult_config = self.mult_template.format(**mult_params)

        return mult_config + '\n' + conv_config
@@ -139,6 +143,15 @@ def format(self, node):
        mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_height') * node.get_attr('filt_width')
        mult_params['n_out'] = node.get_attr('n_filt')
        mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
        mult_params['merged_relu'] = "true" if node.get_merged_relu() else "false"
        print(f"My out_t Class = {type(node.intermediate_op.type)}")
        # TODO: Figure out when to append ::value_type (i.e. when
        # node.intermediate_op's type is an nnet::array) and when not to, and how
        # to detect that from a layer class. Only the io_stream IOType appears to
        # use PackedType (io_parallel does not), so the IOType could be read from
        # the layer's model config; that alone is not sufficient, though, and it is
        # unclear what else is needed. ReLU merging may also need to be added to dense_latency.h.
        mult_params['out_t'] = node.intermediate_op.type.name + '::value_type' if node.model.config.get_config_value('IOType') == 'io_stream' else node.intermediate_op.type.name
        mult_config = self.mult_template.format(**mult_params)

        return mult_config + '\n' + conv_config
4 changes: 4 additions & 0 deletions hls4ml/backends/vivado/passes/core_templates.py
@@ -13,11 +13,13 @@
    static const unsigned reuse_factor = {reuse};
    static const unsigned n_zeros = {nzeros};
    static const unsigned n_nonzeros = {nonzeros};
    static const bool merged_relu = {merged_relu};
    static const bool store_weights_in_bram = false;
    typedef {accum_t.name} accum_t;
    typedef {bias_t.name} bias_t;
    typedef {weight_t.name} weight_t;
    typedef {index_t.name} index_t;
    typedef {out_t} out_t;
    template<class x_T, class y_T>
    using product = nnet::product::{product_type}<x_T, y_T>;
}};\n"""
@@ -36,6 +38,8 @@ def format(self, node):
        params['nzeros'] = node.get_weights('weight').nzeros
        params['nonzeros'] = node.get_weights('weight').nonzeros
        params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
        params['merged_relu'] = "true" if node.get_merged_relu() else "false"
        params['out_t'] = node.get_output_variable().type.name

        return self.template.format(**params)

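For reference, a rough illustration of how the extended dense config template renders once {merged_relu} and {out_t} are filled in. The template and parameter values below are cut down and made up for illustration; the real values come from the layer's attributes and type names via DenseConfigTemplate.format():

```python
# Cut-down version of the dense config template above, with hypothetical values.
dense_config_template = """struct config{index} : nnet::dense_config {{
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    static const bool merged_relu = {merged_relu};
    typedef {out_t} out_t;
}};\n"""

params = {'index': 2, 'n_in': 16, 'n_out': 10, 'merged_relu': 'true', 'out_t': 'layer2_t'}
print(dense_config_template.format(**params))
# struct config2 : nnet::dense_config {
#     static const unsigned n_in = 16;
#     static const unsigned n_out = 10;
#     static const bool merged_relu = true;
#     typedef layer2_t out_t;
# };
```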
11 changes: 11 additions & 0 deletions hls4ml/model/layers.py
@@ -61,6 +61,8 @@ def __init__(self, model, name, attributes, inputs, outputs=None):
        accum_t = NamedType(*reversed(self.model.config.get_precision(self, 'accum')))
        self.set_attr('accum_t', accum_t)

        self.merged_relu = False

        layer_config = self.model.config.get_layer_config(self)
        for config_key, config_value in layer_config.items():
            if config_key in self.attributes:
@@ -234,6 +236,7 @@ def _default_config_params(self):
        params.update(self.attributes)
        params['iotype'] = self.model.config.get_config_value('IOType')
        params['reuse'] = self.get_attr('reuse_factor')
        params['merged_relu'] = "true" if self.get_merged_relu() else "false"

        return params

@@ -243,6 +246,12 @@ def get_layer_precision(self):
            precision[data_type.name] = data_type
        return precision

    def get_merged_relu(self):
        return self.merged_relu

    def set_merged_relu(self, merged_relu):
        self.merged_relu = merged_relu  # Bool flag indicating a following ReLU has been merged into this layer

    def get_numbers_cpp(self):
        numbers = ''
        for k, v in self.get_output_variable().get_shape():
@@ -300,6 +309,7 @@ def initialize(self):
        else:
            dims = ['N_LAYER_{}'.format(self.index)]
        self.add_output_variable(shape, dims)
        self.intermediate_op = self.get_output_variable()
        self.add_weights(quantizer=self.get_attr('weight_quantizer'), compression=self.model.config.get_compression(self))
        self.add_bias(quantizer=self.get_attr('bias_quantizer'))

@@ -416,6 +426,7 @@ def initialize(self):
            shape = [self.attributes['n_filt'], self.attributes['out_height'], self.attributes['out_width']]
            dims = ['N_FILT_{}'.format(self.index), 'OUT_HEIGHT_{}'.format(self.index), 'OUT_WIDTH_{}'.format(self.index)]
        self.add_output_variable(shape, dims)
        self.intermediate_op = self.get_output_variable()
        self.add_weights(quantizer=self.get_attr('weight_quantizer'))
        self.add_bias(quantizer=self.get_attr('bias_quantizer'))

4 changes: 2 additions & 2 deletions hls4ml/model/optimizer/__init__.py
@@ -14,10 +14,10 @@
try:
    import qkeras
    register_flow('convert', ['fuse_bias_add', 'remove_useless_transpose', 'output_rounding_saturation_mode', 'qkeras_factorize_alpha', 'extract_ternary_threshold', 'fuse_consecutive_batch_normalization']) # TODO Maybe not all QKeras optimizers belong here?
    register_flow('optimize', ['eliminate_linear_activation', 'fuse_consecutive_batch_normalization', 'fuse_batch_normalization', 'replace_multidimensional_dense_with_conv', 'set_precision_concat'], requires=['convert'])
    register_flow('optimize', ['eliminate_linear_activation', 'fuse_consecutive_batch_normalization', 'fuse_batch_normalization', 'replace_multidimensional_dense_with_conv', 'set_precision_concat', 'merge_relu'], requires=['convert'])
except:
    register_flow('convert', ['fuse_bias_add', 'remove_useless_transpose'])
    register_flow('optimize', ['eliminate_linear_activation', 'fuse_batch_normalization', 'replace_multidimensional_dense_with_conv', 'set_precision_concat'], requires=['convert'])
    register_flow('optimize', ['eliminate_linear_activation', 'fuse_batch_normalization', 'replace_multidimensional_dense_with_conv', 'set_precision_concat', 'merge_relu'], requires=['convert'])

del opt_path
del module_path
58 changes: 58 additions & 0 deletions hls4ml/model/optimizer/passes/merge_relu.py
@@ -0,0 +1,58 @@
from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.layers import Activation, Dense, Conv2D, Conv2DBatchnorm

class MergeRelu(OptimizerPass):
    def match(self, node):
        backends = ['VivadoAccelerator', 'Vivado']
        supported_layers = ['Dense', 'Conv2D', 'Conv2DBatchnorm']
        # By the time this optimization pass runs, the Layer nodes' class names
        # have been prepended with the name of the backend, e.g., a Conv2D
        # layer is renamed VivadoAcceleratorConv2D. Thus, we strip the backend
        # name for more streamlined matching.
        input_node_class = node.get_input_node().__class__.__name__
        curr_node_class = node.__class__.__name__
        for b in backends:
            input_node_class = input_node_class.replace(b, '')
            curr_node_class = curr_node_class.replace(b, '')

        is_match = input_node_class in supported_layers
        # ReLU layers are of class Activation; also check that the activation
        # function is actually 'relu' before merging.
        is_match = is_match and (curr_node_class == 'Activation') and (node.get_attr('activation') == 'relu')
        return is_match

    def transform(self, model, node):
        # Merge ReLU and Convolution/Dense layer
        previous_node = node.get_input_node()
        previous_node.set_merged_relu(True)  # Turn on merged_relu flag for this Conv/Dense layer
        if 'Conv2D' in previous_node.__class__.__name__:
            if previous_node.get_attr('data_format') == 'channels_last':
                shape = [previous_node.attributes['out_height'], previous_node.attributes['out_width'], previous_node.attributes['n_filt']]
                dims = ['OUT_HEIGHT_{}'.format(previous_node.index), 'OUT_WIDTH_{}'.format(previous_node.index), 'N_FILT_{}'.format(previous_node.index)]
            else:
                shape = [previous_node.attributes['n_filt'], previous_node.attributes['out_height'], previous_node.attributes['out_width']]
                dims = ['N_FILT_{}'.format(previous_node.index), 'OUT_HEIGHT_{}'.format(previous_node.index), 'OUT_WIDTH_{}'.format(previous_node.index)]
            activation_precision, _ = model.config.get_precision(node, var='result')
            previous_node.add_output_variable(shape, dims, precision=activation_precision)
            if not node.get_output_nodes():
                print("WARNING: {} is the output layer! No rewiring performed.".format(node.name))
                model.remove_node(node, rewire=False)
            else:
                model.remove_node(node, rewire=True)
            return True
        elif 'Dense' in previous_node.__class__.__name__:
            shape = previous_node.get_input_variable().shape[:]
            shape[-1] = previous_node.attributes['n_out']
            if len(shape) > 1:
                dims = ['N_LAYER_{}_{}'.format(i, previous_node.index) for i in range(1, len(shape) + 1)]
            else:
                dims = ['N_LAYER_{}'.format(previous_node.index)]
            print('shape: {}'.format(shape))
            print('dims: {}'.format(dims))
            activation_precision, _ = model.config.get_precision(node, var='result')
            previous_node.add_output_variable(shape, dims, precision=activation_precision)
            if not node.get_output_nodes():
                print("WARNING: {} is the output layer! No rewiring performed.".format(node.name))
                model.remove_node(node, rewire=False)
            else:
                model.remove_node(node, rewire=True)
            return True
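Not part of this diff, but a rough sketch of how the pass could be exercised end to end. The hls4ml entry points used here are the standard conversion helpers; the toy model, output directory, and the expectation in the final comment are illustrative assumptions:

```python
import tensorflow as tf
import hls4ml

# Toy model: a Dense layer followed by a separate ReLU Activation layer.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(16,)),
    tf.keras.layers.Dense(8, name='dense1'),
    tf.keras.layers.Activation('relu', name='relu1'),
])

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    model, hls_config=config, output_dir='hls_prj', io_type='io_parallel', backend='Vivado'
)

# Once the 'optimize' flow (which now includes merge_relu) has run, the
# standalone Activation node should be gone and the Dense node should have
# merged_relu set, with its output variable carrying the activation precision.
layer_names = [layer.name for layer in hls_model.get_layers()]
print(layer_names)  # expected: no 'relu1' entry
```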