Skip to content

Commit

Permalink
Add mutext to ensure thread-safe gemm calls
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 committed Jan 27, 2018
1 parent 53b5140 commit f1fb88e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ GemmGeometry CreateMIOpenGemmGeometry(int M,

GemmGeometry GetGemmGeometry(std::string algorithm_name, std::string network_config)
{
auto guard = get_gemm_geo_map_lock();
auto gemm_iterator = gemm_geo_map().find(std::make_pair(algorithm_name, network_config));
if(gemm_iterator != gemm_geo_map().end())
{
Expand Down
8 changes: 7 additions & 1 deletion src/gemm_geometry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ std::unordered_map<GemmKey, GemmGeometry, SimpleHash>& gemm_geo_map()
return data;
}

std::unique_lock<std::mutex> get_gemm_geo_map_lock()
{
static std::mutex m{};
return std::unique_lock<std::mutex>{m};
}

void GemmGeometry::EnableBetaKernel(bool enable) { beta_kern_req = enable; }

void GemmGeometry::FindSolution(
Expand Down Expand Up @@ -135,7 +141,7 @@ void GemmGeometry::FindSolution(
vgd,
"");
}

auto guard = get_gemm_geo_map_lock();
gemm_geo_map()[std::make_pair(algorithm_name, network_config)] = *this;
}

Expand Down
3 changes: 3 additions & 0 deletions src/include/miopen/gemm_geometry.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <miopen/tensor.hpp>
#include <miopengemm/miogemm.hpp>

#include <mutex>

namespace miopen {

struct GemmGeometry
Expand Down Expand Up @@ -80,6 +82,7 @@ struct GemmGeometry

using GemmKey = std::pair<std::string, std::string>;
std::unordered_map<GemmKey, GemmGeometry, SimpleHash>& gemm_geo_map();
std::unique_lock<std::mutex> get_gemm_geo_map_lock();

} // namespace miopen

Expand Down

0 comments on commit f1fb88e

Please sign in to comment.