From 61126a83c6f827295acd29261a6bddf902c3fe91 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Sat, 3 Dec 2022 13:42:05 +0800
Subject: [PATCH 1/5] Fix default context (cpu)

---
 CMakeLists.txt             | 12 ++++++------
 k2/CMakeLists.txt          |  2 +-
 k2/csrc/CMakeLists.txt     |  3 +--
 k2/csrc/default_context.cu | 34 +++++++++++++++++++++++++++++++---
 4 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3faca3339..289688ff1 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -91,11 +91,11 @@ if(NOT K2_WITH_CUDA)
   set(K2_ENABLE_NVTX OFF CACHE BOOL "" FORCE)
 endif()
 
-if(NOT K2_USE_PYTORCH)
-  message(FATAL_ERROR "\
-    Please set K2_USE_PYTORCH to ON.
-    Support for other frameworks will be added later")
-endif()
+#if(NOT K2_USE_PYTORCH)
+  #message(FATAL_ERROR "\
+    #Please set K2_USE_PYTORCH to ON.
+    #Support for other frameworks will be added later")
+#endif()
 
 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
@@ -286,8 +286,8 @@ endif()
 
 list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
 
-include(pybind11)
 if(K2_USE_PYTORCH)
+  include(pybind11)
   add_definitions(-DK2_USE_PYTORCH)
   include(torch)
   configure_file(
diff --git a/k2/CMakeLists.txt b/k2/CMakeLists.txt
index 6a7839d0e..937078956 100644
--- a/k2/CMakeLists.txt
+++ b/k2/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_subdirectory(csrc)
-add_subdirectory(python)
 
 if(K2_USE_PYTORCH)
+  add_subdirectory(python)
   # We use K2_TORCH_VERSION instead of TORCH_VERSION
   # since TORCH_VERSION may contain something like "+cpu", "+cu113"
   if(K2_TORCH_VERSION VERSION_GREATER_EQUAL 1.8 OR NOT K2_WITH_CUDA)
diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt
index 736668e9b..a57191888 100644
--- a/k2/csrc/CMakeLists.txt
+++ b/k2/csrc/CMakeLists.txt
@@ -79,14 +79,13 @@ set(context_srcs
   thread_pool.cu
   timer.cu
   top_sort.cu
-  torch_util.cu
   utils.cu
   nbest.cu
 )
 
 if(K2_USE_PYTORCH)
-  list(APPEND context_srcs pytorch_context.cu)
+  list(APPEND context_srcs pytorch_context.cu torch_util.cu)
 else()
   list(APPEND context_srcs default_context.cu)
 endif()
diff --git a/k2/csrc/default_context.cu b/k2/csrc/default_context.cu
index d606a0d94..012c0e1b6 100644
--- a/k2/csrc/default_context.cu
+++ b/k2/csrc/default_context.cu
@@ -18,6 +18,7 @@
  */
 
 #include <cstdlib>
+#include <cstring>
 #include <mutex>  // NOLINT
 
 #include "k2/csrc/context.h"
@@ -28,11 +29,9 @@ namespace k2 {
 
 static constexpr std::size_t kAlignment = 64;
 
-// TODO(haowen): most of implementations below should be updated later.
 class CpuContext : public Context {
  public:
   CpuContext() = default;
-  ContextPtr GetCpuContext() override { return shared_from_this(); }
   DeviceType GetDeviceType() const override { return kCpu; }
 
   void *Allocate(std::size_t bytes, void **deleter_context) override {
@@ -52,11 +51,19 @@ class CpuContext : public Context {
   void Deallocate(void *data, void * /*deleter_context*/) override {
     free(data);
   }
+
+  void CopyDataTo(size_t num_bytes, const void *src, ContextPtr dst_context,
+                  void *dst) override {
+    DeviceType device_type = dst_context->GetDeviceType();
+    K2_CHECK_EQ(device_type, kCpu);
+    memcpy(dst, src, num_bytes);
+  };
 };
 
 class CudaContext : public Context {
  public:
   explicit CudaContext(int32_t gpu_id) : gpu_id_(gpu_id) {
+#ifdef K2_WITH_CUDA
     if (gpu_id_ != -1) {
       auto ret = cudaSetDevice(gpu_id_);
       K2_CHECK_CUDA_ERROR(ret);
@@ -65,42 +72,59 @@
     // and handle GPU ids from multiple machines.
     auto ret = cudaStreamCreate(&stream_);
     K2_CHECK_CUDA_ERROR(ret);
+#else
+    K2_LOG(FATAL) << "Unreachable code.";
+#endif
   }
-  ContextPtr GetCpuContext() override { return k2::GetCpuContext(); }
   DeviceType GetDeviceType() const override { return kCuda; }
   int32_t GetDeviceId() const override { return gpu_id_; }
 
   void *Allocate(std::size_t bytes, void **deleter_context) override {
     void *p = nullptr;
+#ifdef K2_WITH_CUDA
     if (bytes) {
       auto ret = cudaMalloc(&p, bytes);
       K2_CHECK_CUDA_ERROR(ret);
     }
 
     if (deleter_context != nullptr) *deleter_context = nullptr;
+#endif
     return p;
   }
 
+  void CopyDataTo(size_t num_bytes, const void *src, ContextPtr dst_context,
+                  void *dst) override{};
+
   bool IsCompatible(const Context &other) const override {
     return other.GetDeviceType() == kCuda && other.GetDeviceId() == gpu_id_;
   }
 
   void Deallocate(void *data, void * /*deleter_context*/) override {
+#ifdef K2_WITH_CUDA
     auto ret = cudaFree(data);
     K2_CHECK_CUDA_ERROR(ret);
+#endif
   }
 
   cudaStream_t GetCudaStream() const override {
+#ifdef K2_WITH_CUDA
     return g_stream_override.OverrideStream(stream_);
+#else
+    return cudaStream_t{};
+#endif
   }
 
   void Sync() const override {
+#ifdef K2_WITH_CUDA
     auto ret = cudaStreamSynchronize(stream_);
     K2_CHECK_CUDA_ERROR(ret);
+#endif
   }
 
   ~CudaContext() {
+#ifdef K2_WITH_CUDA
     auto ret = cudaStreamDestroy(stream_);
     K2_CHECK_CUDA_ERROR(ret);
+#endif
   }
 
  private:
@@ -111,6 +135,7 @@ class CudaContext : public Context {
 ContextPtr GetCpuContext() { return std::make_shared<CpuContext>(); }
 
 ContextPtr GetCudaContext(int32_t gpu_id /*= -1*/) {
+#ifdef K2_WITH_CUDA
   static std::once_flag has_cuda_init_flag;
   static bool has_cuda = false;
   std::call_once(has_cuda_init_flag, []() {
@@ -125,6 +150,9 @@ ContextPtr GetCudaContext(int32_t gpu_id /*= -1*/) {
   if (has_cuda) return std::make_shared<CudaContext>(gpu_id);
 
   return GetCpuContext();
+#else
+  return GetCpuContext();
+#endif
 }
 
 }  // namespace k2

From c76fa64fecf7171aced99bb6e9329ec85e34d530 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Mon, 5 Dec 2022 08:19:20 +0800
Subject: [PATCH 2/5] Recover default context

---
 CMakeLists.txt             |  6 ----
 k2/csrc/default_context.cu | 62 +++++++++++++++++++++++++++++++-------
 k2/csrc/pinned_context.cu  |  1 +
 3 files changed, 52 insertions(+), 17 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 289688ff1..8d3bb93f4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -91,12 +91,6 @@ if(NOT K2_WITH_CUDA)
   set(K2_ENABLE_NVTX OFF CACHE BOOL "" FORCE)
 endif()
 
-#if(NOT K2_USE_PYTORCH)
-  #message(FATAL_ERROR "\
-    #Please set K2_USE_PYTORCH to ON.
-    #Support for other frameworks will be added later")
-#endif()
-
 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
diff --git a/k2/csrc/default_context.cu b/k2/csrc/default_context.cu
index 012c0e1b6..4d6330516 100644
--- a/k2/csrc/default_context.cu
+++ b/k2/csrc/default_context.cu
@@ -22,6 +22,8 @@
 #include <mutex>  // NOLINT
 
 #include "k2/csrc/context.h"
+#include "k2/csrc/cub.h"
+#include "k2/csrc/device_guard.h"
 #include "k2/csrc/log.h"
 #include "k2/csrc/nvtx.h"
@@ -55,21 +57,40 @@ class CpuContext : public Context {
   void CopyDataTo(size_t num_bytes, const void *src, ContextPtr dst_context,
                   void *dst) override {
     DeviceType device_type = dst_context->GetDeviceType();
-    K2_CHECK_EQ(device_type, kCpu);
-    memcpy(dst, src, num_bytes);
-  };
+    switch (device_type) {
+      case kCpu:
+        memcpy(dst, src, num_bytes);
+        break;
+      case kCuda: {
+        // CPU -> CUDA
+        DeviceGuard guard(dst_context);
+        ContextPtr pinned_context = GetPinnedContext();
+        auto region = NewRegion(pinned_context, num_bytes);
+        memcpy(region->data, src, num_bytes);
+        pinned_context->CopyDataTo(num_bytes, region->data, dst_context, dst);
+        break;
+      }
+      default:
+        K2_LOG(FATAL) << "Unsupported device type: " << device_type;
+        break;
+    }
+  }
 };
 
 class CudaContext : public Context {
  public:
   explicit CudaContext(int32_t gpu_id) : gpu_id_(gpu_id) {
 #ifdef K2_WITH_CUDA
-    if (gpu_id_ != -1) {
+    if (gpu_id != -1) {
       auto ret = cudaSetDevice(gpu_id_);
       K2_CHECK_CUDA_ERROR(ret);
+    } else {
+      int current_gpu_id;
+      auto ret = cudaGetDevice(&current_gpu_id);
+      K2_CHECK_CUDA_ERROR(ret);
+      gpu_id_ = current_gpu_id;
     }
-    // TODO(haowen): choose one from available GPUs if gpu_id == -1?
-    // and handle GPU ids from multiple machines.
+
     auto ret = cudaStreamCreate(&stream_);
     K2_CHECK_CUDA_ERROR(ret);
 #else
@@ -92,7 +113,27 @@ class CudaContext : public Context {
   }
 
   void CopyDataTo(size_t num_bytes, const void *src, ContextPtr dst_context,
-                  void *dst) override{};
+                  void *dst) override{
+    DeviceType device_type = dst_context->GetDeviceType();
+    switch (device_type) {
+      case kCpu: {
+        cudaError_t ret =
+            cudaMemcpy(dst, src, num_bytes, cudaMemcpyDeviceToHost);
+        K2_CHECK_CUDA_ERROR(ret);
+        break;
+      }
+      case kCuda: {
+        cudaError_t ret =
+            cudaMemcpyAsync(dst, src, num_bytes, cudaMemcpyDeviceToDevice,
+                            dst_context->GetCudaStream());
+        K2_CHECK_CUDA_ERROR(ret);
+        break;
+      }
+      default:
+        K2_LOG(FATAL) << "Unsupported device type: " << device_type;
+        break;
+    }
+  };
 
   bool IsCompatible(const Context &other) const override {
     return other.GetDeviceType() == kCuda && other.GetDeviceId() == gpu_id_;
@@ -100,6 +141,7 @@ class CudaContext : public Context {
   void Deallocate(void *data, void * /*deleter_context*/) override {
 #ifdef K2_WITH_CUDA
+    DeviceGuard guard(gpu_id_);
     auto ret = cudaFree(data);
     K2_CHECK_CUDA_ERROR(ret);
 #endif
   }
@@ -114,17 +156,14 @@ class CudaContext : public Context {
   }
 
   void Sync() const override {
-#ifdef K2_WITH_CUDA
+    DeviceGuard guard(gpu_id_);
     auto ret = cudaStreamSynchronize(stream_);
     K2_CHECK_CUDA_ERROR(ret);
-#endif
   }
 
   ~CudaContext() {
-#ifdef K2_WITH_CUDA
     auto ret = cudaStreamDestroy(stream_);
     K2_CHECK_CUDA_ERROR(ret);
-#endif
   }
 
  private:
@@ -147,6 +186,7 @@ ContextPtr GetCudaContext(int32_t gpu_id /*= -1*/) {
     K2_LOG(WARNING) << "CUDA is not available. "
                        "Return a CPU context.";
   });
+  DeviceGuard guard(gpu_id);
   if (has_cuda) return std::make_shared<CudaContext>(gpu_id);
 
   return GetCpuContext();
diff --git a/k2/csrc/pinned_context.cu b/k2/csrc/pinned_context.cu
index b46270a87..92be8c494 100644
--- a/k2/csrc/pinned_context.cu
+++ b/k2/csrc/pinned_context.cu
@@ -26,6 +26,7 @@
 #include
 
 #include "k2/csrc/context.h"
+#include "k2/csrc/device_guard.h"
 #include "k2/csrc/log.h"
 #include "k2/csrc/macros.h"
 #include "k2/csrc/nvtx.h"

From a53baa0e7bfef1c8c4fe29d35e31b3fb90c47e47 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Wed, 7 Dec 2022 10:40:07 +0800
Subject: [PATCH 3/5] Fix zero bytes allocation

---
 k2/csrc/CMakeLists.txt          |  1 +
 k2/csrc/default_context.cu      | 26 ++++-----
 k2/csrc/default_context_test.cu | 96 +++++++++++++++++++++++++++++++++
 k2/csrc/pytorch_context.cu      |  2 +-
 4 files changed, 109 insertions(+), 16 deletions(-)
 create mode 100644 k2/csrc/default_context_test.cu

diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt
index a57191888..4f0955394 100644
--- a/k2/csrc/CMakeLists.txt
+++ b/k2/csrc/CMakeLists.txt
@@ -165,6 +165,7 @@ if(K2_ENABLE_TESTS)
   array_ops_test.cu
   array_test.cu
   connect_test.cu
+  default_context_test.cu
   dtype_test.cu
   fsa_algo_test.cu
   fsa_test.cu
diff --git a/k2/csrc/default_context.cu b/k2/csrc/default_context.cu
index 4d6330516..60ec577fa 100644
--- a/k2/csrc/default_context.cu
+++ b/k2/csrc/default_context.cu
@@ -90,9 +90,7 @@ class CudaContext : public Context {
       K2_CHECK_CUDA_ERROR(ret);
       gpu_id_ = current_gpu_id;
     }
-
-    auto ret = cudaStreamCreate(&stream_);
-    K2_CHECK_CUDA_ERROR(ret);
+    allocator_ = new cub::CachingDeviceAllocator();
 #else
     K2_LOG(FATAL) << "Unreachable code.";
 #endif
   }
@@ -103,10 +101,9 @@ class CudaContext : public Context {
   void *Allocate(std::size_t bytes, void **deleter_context) override {
     void *p = nullptr;
 #ifdef K2_WITH_CUDA
-    if (bytes) {
-      auto ret = cudaMalloc(&p, bytes);
-      K2_CHECK_CUDA_ERROR(ret);
-    }
+    DeviceGuard guard(gpu_id_);
+    auto ret = allocator_->DeviceAllocate(&p, bytes);  // the default stream is 0
+    K2_CHECK_CUDA_ERROR(ret);
 
     if (deleter_context != nullptr) *deleter_context = nullptr;
 #endif
     return p;
   }
@@ -133,7 +130,7 @@ class CudaContext : public Context {
         K2_LOG(FATAL) << "Unsupported device type: " << device_type;
         break;
     }
-  };
+  }
 
   bool IsCompatible(const Context &other) const override {
     return other.GetDeviceType() == kCuda && other.GetDeviceId() == gpu_id_;
@@ -142,33 +139,32 @@ class CudaContext : public Context {
   void Deallocate(void *data, void * /*deleter_context*/) override {
 #ifdef K2_WITH_CUDA
     DeviceGuard guard(gpu_id_);
-    auto ret = cudaFree(data);
+    auto ret = allocator_->DeviceFree(data);
     K2_CHECK_CUDA_ERROR(ret);
 #endif
   }
 
   cudaStream_t GetCudaStream() const override {
 #ifdef K2_WITH_CUDA
-    return g_stream_override.OverrideStream(stream_);
+    return g_stream_override.OverrideStream(0);
 #else
-    return cudaStream_t{};
+    return kCudaStreamInvalid;
 #endif
   }
 
   void Sync() const override {
     DeviceGuard guard(gpu_id_);
-    auto ret = cudaStreamSynchronize(stream_);
+    auto ret = cudaStreamSynchronize(GetCudaStream());
     K2_CHECK_CUDA_ERROR(ret);
   }
 
   ~CudaContext() {
-    auto ret = cudaStreamDestroy(stream_);
-    K2_CHECK_CUDA_ERROR(ret);
+    delete allocator_;
   }
 
  private:
   int32_t gpu_id_;
-  cudaStream_t stream_;
+  cub::CachingDeviceAllocator* allocator_;
 };
 
 ContextPtr GetCpuContext() { return std::make_shared<CpuContext>(); }
diff --git a/k2/csrc/default_context_test.cu b/k2/csrc/default_context_test.cu
new file mode 100644
index 000000000..47d4820dd
--- /dev/null
+++ b/k2/csrc/default_context_test.cu
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "gtest/gtest.h"
+#include "k2/csrc/test_utils.h"
+//
+#include "k2/csrc/array.h"
+#include "k2/csrc/device_guard.h"
+#include "k2/csrc/context.h"
+
+namespace k2 {
+
+// Use a separate function because there is a lambda function inside K2_EVAL().
+static void TestImpl() {
+  int num_devices;
+  auto ret = cudaGetDeviceCount(&num_devices);
+  K2_LOG(INFO) << "Number of devices: " << num_devices;
+
+  // Set the default device to 1
+  ret = cudaSetDevice(1);
+  K2_CHECK_CUDA_ERROR(ret);
+
+  int current_device;
+  ret = cudaGetDevice(&current_device);
+  K2_CHECK_CUDA_ERROR(ret);
+  EXPECT_EQ(current_device, 1);
+
+  ContextPtr c = GetCudaContext(0);
+  EXPECT_EQ(c->GetDeviceId(), 0);
+
+  {
+    std::vector<int32_t> data;
+    Array1<int32_t> src(c, data);
+    EXPECT_EQ(src.Dim(), 0);
+  }
+
+  // the default device should still be 1
+  ret = cudaGetDevice(&current_device);
+  K2_CHECK_CUDA_ERROR(ret);
+  EXPECT_EQ(current_device, 1);
+
+  Array1<int32_t> a(c, "[1 2]");
+  EXPECT_EQ(a.Context()->GetDeviceId(), 0);
+
+  // b uses the default device, which is 1
+  Array1<int32_t> b(GetCudaContext(), "[10 20]");
+  EXPECT_EQ(b.Context()->GetDeviceId(), 1);
+
+  int32_t *a_data = a.Data();
+  int32_t *b_data = b.Data();
+
+  {
+    DeviceGuard guard(0);
+    // a is on device 0
+    K2_EVAL(
+        a.Context(), 2, set_a, (int32_t i)->void { a_data[i] += 1; });
+    CheckArrayData(a, {2, 3});
+  }
+
+  {
+    DeviceGuard guard(1);
+    // b is on device 1
+    K2_EVAL(
+        b.Context(), 2, set_b, (int32_t i)->void { b_data[i] += 2; });
+
+    CheckArrayData(b, {12, 22});
+  }
+
+}
+
+
+TEST(DefaultContext, GetCudaContext) {
+  // skip this test is CUDA is not available
+  int n;
+  auto ret = cudaGetDeviceCount(&n);
+  if (ret == cudaSuccess && n > 1) {
+    TestImpl();
+  }
+}
+
+}  // namespace k2
diff --git a/k2/csrc/pytorch_context.cu b/k2/csrc/pytorch_context.cu
index 14cbdc6d9..c533b38a7 100644
--- a/k2/csrc/pytorch_context.cu
+++ b/k2/csrc/pytorch_context.cu
@@ -171,7 +171,7 @@ class PytorchCudaContext : public Context {
     return g_stream_override.OverrideStream(
         c10::cuda::getCurrentCUDAStream(gpu_id_));
 #else
-    return cudaStream_t{};
+    return kCudaStreamInvalid;
 #endif
   }

From 63be046554aca96cb7ecb74fb15e58fd5759dc38 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Wed, 7 Dec 2022 11:10:11 +0800
Subject: [PATCH 4/5] Minor fixes

---
 CMakeLists.txt                  | 3 +++
 k2/csrc/default_context.cu      | 1 +
 k2/csrc/default_context_test.cu | 1 +
 k2/csrc/pinned_context.cu       | 1 -
 4 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8d3bb93f4..206f52cd7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -281,6 +281,7 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
 
 if(K2_USE_PYTORCH)
+  message(STATUS "Build k2 with Pytorch.")
  include(pybind11)
  add_definitions(-DK2_USE_PYTORCH)
  include(torch)
  configure_file(
@@ -289,6 +290,8 @@ if(K2_USE_PYTORCH)
    ${PROJECT_SOURCE_DIR}/k2/python/k2/torch_version.py
    @ONLY
  )
  message(STATUS "Generated ${PROJECT_BINARY_DIR}/torch_version.py")
+else()
+  message(STATUS "Build k2 without Pytorch.")
 endif()
 
 if(K2_WITH_CUDA)
diff --git a/k2/csrc/default_context.cu b/k2/csrc/default_context.cu
index 60ec577fa..ffd1fa7a3 100644
--- a/k2/csrc/default_context.cu
+++ b/k2/csrc/default_context.cu
@@ -120,6 +120,7 @@ class CudaContext : public Context {
         break;
       }
       case kCuda: {
+        DeviceGuard guard(dst_context);
         cudaError_t ret =
             cudaMemcpyAsync(dst, src, num_bytes, cudaMemcpyDeviceToDevice,
                             dst_context->GetCudaStream());
diff --git a/k2/csrc/default_context_test.cu b/k2/csrc/default_context_test.cu
index 47d4820dd..1f7fab3d9 100644
--- a/k2/csrc/default_context_test.cu
+++ b/k2/csrc/default_context_test.cu
@@ -43,6 +43,7 @@ static void TestImpl() {
   ContextPtr c = GetCudaContext(0);
   EXPECT_EQ(c->GetDeviceId(), 0);
 
+  // Test zero byte allocation.
   {
     std::vector<int32_t> data;
     Array1<int32_t> src(c, data);
diff --git a/k2/csrc/pinned_context.cu b/k2/csrc/pinned_context.cu
index 92be8c494..b46270a87 100644
--- a/k2/csrc/pinned_context.cu
+++ b/k2/csrc/pinned_context.cu
@@ -26,7 +26,6 @@
 #include
 
 #include "k2/csrc/context.h"
-#include "k2/csrc/device_guard.h"
 #include "k2/csrc/log.h"
 #include "k2/csrc/macros.h"
 #include "k2/csrc/nvtx.h"

From a77e41b1c016b63cfdca1f7c7201b56313376b1a Mon Sep 17 00:00:00 2001
From: pkufool
Date: Wed, 7 Dec 2022 11:44:40 +0800
Subject: [PATCH 5/5] Fix style

---
 k2/csrc/default_context.cu      | 5 +++--
 k2/csrc/default_context_test.cu | 1 -
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/k2/csrc/default_context.cu b/k2/csrc/default_context.cu
index ffd1fa7a3..85d97be31 100644
--- a/k2/csrc/default_context.cu
+++ b/k2/csrc/default_context.cu
@@ -102,7 +102,8 @@ class CudaContext : public Context {
     void *p = nullptr;
 #ifdef K2_WITH_CUDA
     DeviceGuard guard(gpu_id_);
-    auto ret = allocator_->DeviceAllocate(&p, bytes);  // the default stream is 0
+    // the default stream is 0
+    auto ret = allocator_->DeviceAllocate(&p, bytes);
     K2_CHECK_CUDA_ERROR(ret);
     if (deleter_context != nullptr) *deleter_context = nullptr;
 #endif
@@ -110,7 +111,7 @@
     return p;
   }
 
   void CopyDataTo(size_t num_bytes, const void *src, ContextPtr dst_context,
-                  void *dst) override{
+                  void *dst) override {
     DeviceType device_type = dst_context->GetDeviceType();
     switch (device_type) {
       case kCpu: {
diff --git a/k2/csrc/default_context_test.cu b/k2/csrc/default_context_test.cu
index 1f7fab3d9..80b034dad 100644
--- a/k2/csrc/default_context_test.cu
+++ b/k2/csrc/default_context_test.cu
@@ -81,7 +81,6 @@ static void TestImpl() {
     CheckArrayData(b, {12, 22});
   }
-
 }
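
Note (not part of the patch series): the sketch below is an illustrative, untested example of how the PyTorch-free default context touched by these patches is expected to be used. It relies only on APIs visible in the diffs above (GetCpuContext(), GetCudaContext(), Context::Allocate()/Deallocate()/CopyDataTo()); the main() scaffolding and the assumption of a build configured with -DK2_USE_PYTORCH=OFF are added here for illustration and are not taken from the repository.

// Illustrative sketch only -- not part of the patches above.
// Assumes k2 was configured with -DK2_USE_PYTORCH=OFF, so that
// default_context.cu provides the CPU/CUDA contexts exercised below.
#include <cstdint>

#include "k2/csrc/context.h"
#include "k2/csrc/log.h"

int main() {
  using namespace k2;

  ContextPtr cpu = GetCpuContext();
  // Falls back to a CPU context when CUDA is unavailable
  // (see GetCudaContext() in default_context.cu).
  ContextPtr ctx = GetCudaContext();  // -1 means "use the current device"

  // Zero-byte allocation should be safe; this is the case exercised by
  // default_context_test.cu added in patch 3.
  void *deleter_context = nullptr;
  void *p = ctx->Allocate(0, &deleter_context);
  ctx->Deallocate(p, deleter_context);

  // Copy a small host buffer to whatever device `ctx` ended up on,
  // using the CopyDataTo() implementations filled in by patch 2.
  int32_t src[2] = {1, 2};
  void *dst = ctx->Allocate(sizeof(src), &deleter_context);
  cpu->CopyDataTo(sizeof(src), src, ctx, dst);
  ctx->Deallocate(dst, deleter_context);

  K2_LOG(INFO) << "Default context works; on CUDA: "
               << (ctx->GetDeviceType() == kCuda);
  return 0;
}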