Skip to content

Commit

Permalink
Batch Norm network config collision fix (#1859)
Browse files Browse the repository at this point in the history
* netconfig fix.
  • Loading branch information
Daniel Lowell committed Jun 27, 2019
1 parent 4298bd1 commit 7a8f787
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions src/ocl/batchnormocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,8 @@ void BatchNormForwardTraining(Handle& handle,
std::to_string(ygridsize) + "lx" + std::to_string(xlocalsize) + "ly" +
std::to_string(ylocalsize) + "rs" + std::to_string(static_cast<int>(resultsave)) +
"rr" + std::to_string(static_cast<int>(resultrunning)) + "segment" +
std::to_string(segment) + "n" + std::to_string(n) + "hw" + std::to_string(in_cstride);
std::to_string(segment) + "n" + std::to_string(n) + "c" + std::to_string(c) + "hw" +
std::to_string(in_cstride);

auto&& kernels = handle.GetKernels(algo_name, network_config);

Expand Down Expand Up @@ -761,8 +762,8 @@ void BatchNormForwardInference(Handle& handle,

std::string algo_name = "miopenBatchNormalizationForwardInference";
std::string network_config =
"n" + std::to_string(n) + "hw" + std::to_string(in_cstride) + "chw" +
std::to_string(in_nstride) + "segment" + std::to_string(segment) + "gx" +
"n" + std::to_string(n) + +"c" + std::to_string(c) + "hw" + std::to_string(in_cstride) +
"chw" + std::to_string(in_nstride) + "segment" + std::to_string(segment) + "gx" +
std::to_string(xgridsize) + "gy" + std::to_string(ygridsize) + "lx" +
std::to_string(xlocalsize) + "ly" + std::to_string(ylocalsize) + "fp16" +
std::to_string(static_cast<int>(bfp16parm)) + "fp32" +
Expand Down Expand Up @@ -1021,14 +1022,13 @@ void BatchNormBackward(Handle& handle,
}
std::string algo_name = "miopenBatchNormBackwardPropSpatial";
std::string network_config =
"variant" + std::to_string(variant) + "gx" + std::to_string(xgridsize) + "hw" +
std::to_string(in_cstride) + "gy" + std::to_string(ygridsize) + "lx" +
std::to_string(xlocalsize) + "ly" + std::to_string(ylocalsize) + "us" +
std::to_string(static_cast<int>(useSaved)) + "fp16" +
std::to_string(static_cast<int>(bfp16parm)) + "fp32" +
"variant" + std::to_string(variant) + "gx" + std::to_string(xgridsize) + "n" +
std::to_string(n) + "c" + std::to_string(c) + "hw" + std::to_string(in_cstride) + "gy" +
std::to_string(ygridsize) + "lx" + std::to_string(xlocalsize) + "ly" +
std::to_string(ylocalsize) + "us" + std::to_string(static_cast<int>(useSaved)) +
"fp16" + std::to_string(static_cast<int>(bfp16parm)) + "fp32" +
std::to_string(static_cast<int>(bfp32parm)) + "single" +
std::to_string(static_cast<int>(single)) + "c" + std::to_string(c) + "gcn" +
std::to_string(ldsgcn);
std::to_string(static_cast<int>(single)) + "gcn" + std::to_string(ldsgcn);

auto&& kernels = handle.GetKernels(algo_name, network_config);

Expand Down Expand Up @@ -1323,11 +1323,12 @@ void BatchNormBackward(Handle& handle,

std::string algo_name = "miopenBatchNormBackwardPropPerActivation";
std::string network_config =
std::to_string(xDesc.GetType()) + std::to_string(xgridsize) +
std::to_string(ygridsize) + std::to_string(xlocalsize) + std::to_string(ylocalsize) +
std::to_string(static_cast<int>(useSaved)) +
std::to_string(static_cast<int>(bfp16parm)) +
std::to_string(static_cast<int>(bfp32parm)) + std::to_string(in_nhw);
"gx" + std::to_string(xgridsize) + "gy" + std::to_string(ygridsize) + "lx" +
std::to_string(xlocalsize) + "ly" + std::to_string(ylocalsize) + "n" +
std::to_string(n) + "c" + std::to_string(c) + "hw" + std::to_string(in_cstride) + "u" +
std::to_string(static_cast<int>(useSaved)) + "fp16" +
std::to_string(static_cast<int>(bfp16parm)) + "fp32" +
std::to_string(static_cast<int>(bfp32parm)) + "nhw" + std::to_string(in_nhw);

auto&& kernels = handle.GetKernels(algo_name, network_config);

Expand Down

0 comments on commit 7a8f787

Please sign in to comment.