Skip to content

Commit

Permalink
Merge pull request #61 from GiuseppeDiGuglielmo/main
Browse files Browse the repository at this point in the history
Fix pooling layers (io_stream)
  • Loading branch information
dgburnette authored Dec 9, 2024
2 parents 7e752d4 + 9069c75 commit 4a35ef4
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions hls4ml/templates/catapult/nnet_utils/nnet_pooling_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ void compute_pool_encoded_2d(

const unsigned sh_idx = pool_table_height[h_idx] * CONFIG_T::pool_width;
const unsigned wp_idx = w_idx * (data_T::size / CONFIG_T::n_filt);
#pragma hls_unroll
PixelLoop:
for (unsigned p = 0; p < data_T::size / CONFIG_T::n_filt; p++) {
//#pragma HLS PIPELINE
Expand All @@ -86,6 +87,7 @@ void compute_pool_encoded_2d(
if ((h_idx < nH) && (wp_idx + p < nW)) {
filt_mask = sh_idx + pool_table_width[wp_idx + p] + 1;
}
// #pragma hls_unroll
CopyDataFilt:
for (unsigned c = 0; c < CONFIG_T::n_filt; c++) {
if (filt_mask > 0)
Expand All @@ -94,6 +96,7 @@ void compute_pool_encoded_2d(
}

if (filt_mask == CONFIG_T::pool_height * CONFIG_T::pool_width) {
#pragma hls_unroll
FiltLoop:
for (unsigned c = 0; c < CONFIG_T::n_filt; c++) {
PoolLoop:
Expand Down Expand Up @@ -145,6 +148,7 @@ void pooling2d_encoded_cl(ac_channel<data_T> &data, ac_channel<res_T> &res) {

constexpr int pack_factor = (data_T::size / CONFIG_T::n_filt) * (res_T::size / CONFIG_T::n_filt == 1);
(void)pack_factor;
#pragma hls_pipeline_init_interval pack_factor
ReadInputHeight:
for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) {
ReadInputWidth:
Expand Down Expand Up @@ -175,7 +179,7 @@ void compute_pool_buffer_2d(const data_T &in_elem,
static int sX = 0; // stride X
static int sY = 0; // stride Y

typename data_T::value_type pool_window[CONFIG_T::pool_height * CONFIG_T::pool_width];
typename CONFIG_T::accum_t pool_window[CONFIG_T::pool_height * CONFIG_T::pool_width];
//#pragma HLS ARRAY_PARTITION variable=pool_window complete

static typename data_T::value_type kernel_data[CONFIG_T::pool_height * CONFIG_T::pool_width * CONFIG_T::n_filt];
Expand All @@ -187,12 +191,14 @@ void compute_pool_buffer_2d(const data_T &in_elem,
// Add pixel into line buffer, return pooling kernels
nnet::shift_line_buffer<data_T, CONFIG_T>(in_elem, line_buffer, kernel_data);

#pragma hls_unroll
// Can compute pooling output
if ((sX - lShiftX) == 0 && (sY - lShiftY) == 0 && pY > lShiftY - 1 && pX > lShiftX - 1) {
FiltLoop:
for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) {
//#pragma HLS PIPELINE

#pragma hls_unroll
// Retrieve data for current channel
PoolLoop:
for (unsigned i_ihw = 0; i_ihw < CONFIG_T::pool_height * CONFIG_T::pool_width; i_ihw++) {
Expand All @@ -201,7 +207,7 @@ void compute_pool_buffer_2d(const data_T &in_elem,

// Compute Pooling
res_pack[i_ic] =
reduce_pool<typename data_T::value_type, CONFIG_T::pool_height * CONFIG_T::pool_width, CONFIG_T>(
reduce_pool<typename CONFIG_T::accum_t, CONFIG_T::pool_height * CONFIG_T::pool_width, CONFIG_T>(
pool_window);
}

Expand Down Expand Up @@ -239,6 +245,7 @@ void pooling2d_buffer_cl(ac_channel<data_T> &data, ac_channel<res_T> &res) {
[CONFIG_T::n_filt];
//#pragma HLS ARRAY_PARTITION variable = line_buffer complete dim = 2

#pragma hls_pipeline_init_interval 1
ReadInputHeight:
for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) {
ReadInputWidth:
Expand All @@ -251,6 +258,7 @@ void pooling2d_buffer_cl(ac_channel<data_T> &data, ac_channel<res_T> &res) {
}
}

#pragma hls_design block
template <class data_T, class res_T, typename CONFIG_T> void pooling2d_cl(ac_channel<data_T> &data, ac_channel<res_T> &res) {
//#pragma HLS inline region
switch (CONFIG_T::implementation) {
Expand Down Expand Up @@ -410,7 +418,7 @@ void compute_pool_buffer_1d(const data_T &in_elem, ac_channel<res_T> &res) {
}

// Compute Pooling
res_pack[i_ic] = reduce_pool<typename data_T::value_type, CONFIG_T::pool_width, CONFIG_T>(pool_window);
res_pack[i_ic] = reduce_pool<typename CONFIG_T::accum_t, CONFIG_T::pool_width, CONFIG_T>(pool_window);
}

// Write to output
Expand Down Expand Up @@ -441,6 +449,7 @@ void pooling1d_buffer_cl(ac_channel<data_T> &data, ac_channel<res_T> &res) {
}
}

#pragma hls_design block
template <class data_T, class res_T, typename CONFIG_T> void pooling1d_cl(ac_channel<data_T> &data, ac_channel<res_T> &res) {
//#pragma HLS inline region
switch (CONFIG_T::implementation) {
Expand Down Expand Up @@ -474,19 +483,22 @@ template <class data_T, class res_T, typename CONFIG_T>
// Folds one packed stream element into the per-filter global-pooling state.
//
// in_elem     - packed input element. It is read at index
//               p * CONFIG_T::n_filt + c, for p in
//               [0, data_T::size / CONFIG_T::n_filt) and c in [0, CONFIG_T::n_filt).
// data_window - one running value per filter. Entry c is handed to
//               reduce_global_pool together with this element's values for
//               filter c and is overwritten with the returned result.
//
// NOTE(review): the actual reduction is implemented by reduce_global_pool,
// which is not visible from here; confirm its semantics there.
void compute_global_pool(const data_T &in_elem, typename CONFIG_T::accum_t data_window[CONFIG_T::n_filt]) {
// The unroll directive is written ahead of the loop label (not inside the loop
// body) so that it is attached to the loop that follows it.
// NOTE(review): confirm against the Catapult loop-pragma placement rules.
#pragma hls_unroll
PoolFilt:
    for (unsigned c = 0; c < CONFIG_T::n_filt; c++) {
        // Values of filter c taken from in_elem, converted to the accumulator type.
        typename CONFIG_T::accum_t data_pack[data_T::size / CONFIG_T::n_filt];
        //#pragma HLS ARRAY_PARTITION variable=data_pack complete dim=0

#pragma hls_unroll
    PixelLoop:
        for (unsigned p = 0; p < data_T::size / CONFIG_T::n_filt; p++) {
            // Stride of n_filt with offset c selects the entries belonging to filter c.
            data_pack[p] = in_elem[p * CONFIG_T::n_filt + c];
        }

        // Combine the previous running value with the newly gathered values.
        data_window[c] = reduce_global_pool<typename CONFIG_T::accum_t, data_T::size / CONFIG_T::n_filt, CONFIG_T>(
            data_window[c], data_pack);
    }
}

#pragma hls_design block
template <class data_T, class res_T, typename CONFIG_T>
void global_pooling2d_cl(ac_channel<data_T> &data, ac_channel<res_T> &res) {
assert(CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0);
Expand All @@ -503,6 +515,7 @@ void global_pooling2d_cl(ac_channel<data_T> &data, ac_channel<res_T> &res) {

PoolInitLoop:
for (unsigned i_init = 0; i_init < CONFIG_T::n_filt; i_init++) {
#pragma hls_unroll
data_window[i_init] = init;
}

Expand All @@ -524,6 +537,7 @@ void global_pooling2d_cl(ac_channel<data_T> &data, ac_channel<res_T> &res) {
//#pragma HLS DATA_PACK variable=res_pack
MaxPoolPack:
for (unsigned i_pack = 0; i_pack < res_T::size; i_pack++) {
#pragma hls_unroll
res_pack[i_pack] = data_window[i_pack];
}
res.write(res_pack);
Expand All @@ -537,13 +551,15 @@ void global_pooling2d_cl(ac_channel<data_T> &data, ac_channel<res_T> &res) {
//#pragma HLS DATA_PACK variable=res_pack
AvgPoolPack:
for (unsigned i_pack = 0; i_pack < res_T::size; i_pack++) {
#pragma hls_unroll
res_pack[i_pack] = data_window[i_pack] / (CONFIG_T::in_height * CONFIG_T::in_width);
}
res.write(res_pack);
}
}
}

#pragma hls_design block
template <class data_T, class res_T, typename CONFIG_T>
void global_pooling1d_cl(ac_channel<data_T> &data, ac_channel<res_T> &res) {
assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0);
Expand All @@ -560,6 +576,7 @@ void global_pooling1d_cl(ac_channel<data_T> &data, ac_channel<res_T> &res) {

PoolInitLoop:
for (unsigned i_init = 0; i_init < CONFIG_T::n_filt; i_init++) {
#pragma hls_unroll
data_window[i_init] = init;
}

Expand All @@ -578,6 +595,7 @@ void global_pooling1d_cl(ac_channel<data_T> &data, ac_channel<res_T> &res) {
//#pragma HLS DATA_PACK variable=res_pack
MaxPoolPack:
for (unsigned i_pack = 0; i_pack < res_T::size; i_pack++) {
#pragma hls_unroll
res_pack[i_pack] = data_window[i_pack];
}
res.write(res_pack);
Expand All @@ -591,6 +609,7 @@ void global_pooling1d_cl(ac_channel<data_T> &data, ac_channel<res_T> &res) {
//#pragma HLS DATA_PACK variable=res_pack
AvgPoolPack:
for (unsigned i_pack = 0; i_pack < res_T::size; i_pack++) {
#pragma hls_unroll
res_pack[i_pack] = data_window[i_pack] / CONFIG_T::n_in;
}
res.write(res_pack);
Expand Down

0 comments on commit 4a35ef4

Please sign in to comment.