From 878cff100d792427b5f990e137f79bddb6750d47 Mon Sep 17 00:00:00 2001
From: Dylan Lim
Date: Fri, 15 Nov 2024 17:09:37 -0800
Subject: [PATCH] code formatting and refactor

---
 lib/kernels/include/kernels/accessor.h        | 103 +++++++++++++----
 .../include/kernels/copy_tensor_accessor.h    |  19 ++++
 .../include/kernels/managed_ff_stream.h       |   2 +
 .../kernels/managed_per_device_ff_handle.h    |   2 +
 lib/kernels/src/accessor.cc                   | 104 +-----------------
 lib/kernels/src/copy_tensor_accessor.cc       |  48 ++++++++
 lib/kernels/src/cpu/replicate_kernels.cc      |   4 +-
 lib/kernels/src/cpu/reverse_kernels.cc        |  12 +-
 lib/kernels/src/cuda/ops/linear_kernels.cu    |  42 +++----
 lib/kernels/src/managed_ff_stream.cc          |   9 +-
 .../src/managed_per_device_ff_handle.cc       |  13 +--
 lib/kernels/test/src/test_attention_kernel.cc |   4 +-
 .../test/src/test_batch_matmul_kernel.cc      |   4 +-
 .../test/src/test_batch_norm_kernel.cc        |  13 ++-
 lib/kernels/test/src/test_combine_kernel.cc   |   4 +-
 lib/kernels/test/src/test_concat_kernel.cc    |   6 +-
 lib/kernels/test/src/test_dropout.cc          |   4 +-
 lib/kernels/test/src/test_flat_kernel.cc      |  13 ++-
 lib/kernels/test/src/test_gather_kernels.cc   |   4 +-
 .../test/src/test_layer_norm_kernels.cc       |  11 +-
 .../test/src/test_managed_ff_stream.cc        |  24 ++--
 .../src/test_managed_per_device_ff_handle.cc  |  26 +++--
 lib/kernels/test/src/test_partition_kernel.cc |  15 ++-
 lib/kernels/test/src/test_pool_2d_kernels.cc  |   7 +-
 lib/kernels/test/src/test_reduction_kernel.cc |   7 +-
 lib/kernels/test/src/test_replicate_kernel.cc |   8 +-
 lib/kernels/test/src/test_reshape_kernel.cc   |   4 +-
 lib/kernels/test/src/test_reverse_kernels.cc  |  11 +-
 lib/kernels/test/src/test_softmax_kernel.cc   |   4 +-
 lib/kernels/test/src/test_split_kernel.cc     |   9 +-
 lib/kernels/test/src/test_transpose_kernel.cc |   6 +-
 lib/kernels/test/src/test_utils.cc            |  24 ++--
 lib/kernels/test/src/test_utils.h             |   3 +-
 lib/local-execution/src/ops/pool_2d.cc        |  16 +--
 lib/local-execution/src/ops/reverse.cc        |  12 +-
 .../test/src/test_local_cost_estimator.cc     |   6 +-
 .../op-attrs/dim_ordered/dim_ordered.h        |   5 +-
 .../include/op-attrs/make_datatype_value.h    |  16 +++
 .../src/op-attrs/make_datatype_value.cc       |  25 +++++
 lib/op-attrs/src/op-attrs/ops/attention.cc    |   6 +-
 .../src/op-attrs/parallel_tensor_shape.cc     |   2 +-
 lib/pcg/src/pcg/computation_graph_builder.cc  |  25 +++--
 .../parallel_computation_graph_builder.cc     |   9 +-
 lib/runtime/src/ops/embedding.cc              |   2 +-
 44 files changed, 417 insertions(+), 276 deletions(-)
 create mode 100644 lib/kernels/include/kernels/copy_tensor_accessor.h
 create mode 100644 lib/kernels/src/copy_tensor_accessor.cc
 create mode 100644 lib/op-attrs/include/op-attrs/make_datatype_value.h
 create mode 100644 lib/op-attrs/src/op-attrs/make_datatype_value.cc

diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h
index 653c8db42d..487bc1f8f0 100644
--- a/lib/kernels/include/kernels/accessor.h
+++ b/lib/kernels/include/kernels/accessor.h
@@ -11,8 +11,6 @@
 
 namespace FlexFlow {
 
-struct Allocator;
-
 class GenericTensorAccessorR {
 public:
   template <DataType DT>
@@ -42,7 +40,7 @@ class GenericTensorAccessorR {
   bool operator!=(GenericTensorAccessorR const &) const;
 
   template <DataType DT>
-  real_type_t<DT> const &at(std::vector<size_t> const &indices) const {
+  real_type_t<DT> const &at(std::vector<size_t> const &indices) const {
     if (this->device_type != DeviceType::CPU) {
       throw mk_runtime_error("Calling at() on non-CPU allocated tensor");
     }
@@ -50,11 +48,31 @@ class GenericTensorAccessorR {
       throw mk_runtime_error(fmt::format(
           "Invalid access data type ({} != {})", this->data_type, DT));
     }
+    if (indices.size() != this->shape.num_dims()) {
+      throw mk_runtime_error(fmt::format("Number of indices ({}) does not "
+                                         "match the number of dimensions ({}).",
+                                         indices.size(),
+                                         this->shape.num_dims()));
+    }
 
     using T = real_type_t<DT>;
-    T const *data_ptr = static_cast<T const *>(this->ptr);
-    size_t offset = calculate_index_offset(indices);
+    T const *data_ptr = static_cast<T const *>(this->ptr);
+
+    int offset = 0;
+    int multiplier = 1;
+    for (int i = 0; i < this->shape.num_dims(); i++) {
+      if (indices.at(i) >= this->shape.at(legion_dim_t{i})) {
+        throw mk_runtime_error(
+            fmt::format("In {} dimension, attempting to access index {} "
+                        "when only {} indexes exist",
+                        i,
+                        indices.at(i),
+                        this->shape.at(legion_dim_t{i})));
+      }
+
+      offset += indices.at(i) * multiplier;
+      multiplier *= this->shape.at(legion_dim_t{i});
+    }
 
     return data_ptr[offset];
   }
@@ -71,8 +89,6 @@ class GenericTensorAccessorR {
            decltype(ptr) const &,
            decltype(device_type) const &>
       tie() const;
-
-  size_t calculate_index_offset(std::vector<size_t> const &indices) const;
 };
 
 std::string format_as(GenericTensorAccessorR const &);
@@ -109,7 +125,7 @@ class GenericTensorAccessorW {
   operator GenericTensorAccessorR() const;
 
   template <DataType DT>
-  real_type_t<DT> &at(std::vector<size_t> const &indices) {
+  real_type_t<DT> &at(std::vector<size_t> const &indices) {
     if (this->device_type != DeviceType::CPU) {
       throw mk_runtime_error("Calling at() on non-CPU allocated tensor");
     }
@@ -117,17 +133,37 @@ class GenericTensorAccessorW {
       throw mk_runtime_error(fmt::format(
           "Invalid access data type ({} != {})", this->data_type, DT));
     }
+    if (indices.size() != this->shape.num_dims()) {
+      throw mk_runtime_error(fmt::format("Number of indices ({}) does not "
+                                         "match the number of dimensions ({}).",
+                                         indices.size(),
+                                         this->shape.num_dims()));
+    }
 
     using T = real_type_t<DT>;
     T *data_ptr = static_cast<T *>(this->ptr);
-    size_t offset = calculate_index_offset(indices);
+    int offset = 0;
+    int multiplier = 1;
+    for (int i = 0; i < this->shape.num_dims(); i++) {
+      if (indices.at(i) >= this->shape.at(legion_dim_t{i})) {
+        throw mk_runtime_error(
+            fmt::format("In {} dimension, attempting to access index {} "
+                        "when only {} indexes exist",
+                        i,
+                        indices.at(i),
+                        this->shape.at(legion_dim_t{i})));
+      }
+
+      offset += indices.at(i) * multiplier;
+      multiplier *= this->shape.at(legion_dim_t{i});
+    }
 
     return data_ptr[offset];
   }
 
   template <DataType DT>
-  real_type_t<DT> const &at(std::vector<size_t> const &indices) const {
+  real_type_t<DT> const &at(std::vector<size_t> const &indices) const {
     if (this->device_type != DeviceType::CPU) {
       throw mk_runtime_error("Calling at() on non-CPU allocated tensor");
     }
@@ -135,11 +171,31 @@ class GenericTensorAccessorW {
       throw mk_runtime_error(fmt::format(
           "Invalid access data type ({} != {})", this->data_type, DT));
     }
+    if (indices.size() != this->shape.num_dims()) {
+      throw mk_runtime_error(fmt::format("Number of indices ({}) does not "
+                                         "match the number of dimensions ({}).",
+                                         indices.size(),
+                                         this->shape.num_dims()));
+    }
 
     using T = real_type_t<DT>;
     T const *data_ptr = static_cast<T const *>(this->ptr);
-    size_t offset = calculate_index_offset(indices);
+    int offset = 0;
+    int multiplier = 1;
+    for (int i = 0; i < this->shape.num_dims(); i++) {
+      if (indices.at(i) >= this->shape.at(legion_dim_t{i})) {
+        throw mk_runtime_error(
+            fmt::format("In {} dimension, attempting to access index {} "
+                        "when only {} indexes exist",
+                        i,
+                        indices.at(i),
+                        this->shape.at(legion_dim_t{i})));
+      }
+
+      offset += indices.at(i) * multiplier;
+      multiplier *= this->shape.at(legion_dim_t{i});
+    }
 
     return data_ptr[offset];
   }
@@ -156,8 +212,6 @@ class GenericTensorAccessorW {
            decltype(ptr) const &,
            decltype(device_type) const &>
       tie() const;
-
-  size_t calculate_index_offset(std::vector<size_t> const &indices) const;
 };
 
 std::string format_as(GenericTensorAccessorW const &);
@@ -213,6 +267,21 @@ std::vector<double *>
 std::vector<half *> get_half_ptrs(std::vector<GenericTensorAccessorR> const &);
 
+int32_t *get_int32_ptr(GenericTensorAccessorW const &);
+int64_t *get_int64_ptr(GenericTensorAccessorW const &);
+float *get_float_ptr(GenericTensorAccessorW const &);
+double *get_double_ptr(GenericTensorAccessorW const &);
+half *get_half_ptr(GenericTensorAccessorW const &);
+std::vector<int32_t *>
+    get_int32_ptrs(std::vector<GenericTensorAccessorW> const &);
+std::vector<int64_t *>
+    get_int64_ptrs(std::vector<GenericTensorAccessorW> const &);
+std::vector<float *>
+    get_float_ptrs(std::vector<GenericTensorAccessorW> const &);
+std::vector<double *>
+    get_double_ptrs(std::vector<GenericTensorAccessorW> const &);
+std::vector<half *> get_half_ptrs(std::vector<GenericTensorAccessorW> const &);
+
 template <DataType DT>
 std::vector<real_type_t<DT> const *>
     get(std::vector<GenericTensorAccessorR> const &accs) {
@@ -239,14 +308,6 @@ std::pair<TensorShape, DataType>
 void copy_accessor_data_to_l_from_r(GenericTensorAccessorW &dst_accessor,
                                     GenericTensorAccessorR const &src_accessor);
 
-GenericTensorAccessorR
-    copy_tensor_accessor_r(GenericTensorAccessorR const &src_accessor,
-                           Allocator &allocator);
-
-GenericTensorAccessorW
-    copy_tensor_accessor_w(GenericTensorAccessorW const &src_accessor,
-                           Allocator &allocator);
-
 } // namespace FlexFlow
 
 namespace FlexFlow {
diff --git a/lib/kernels/include/kernels/copy_tensor_accessor.h b/lib/kernels/include/kernels/copy_tensor_accessor.h
new file mode 100644
index 0000000000..da8af71e4f
--- /dev/null
+++ b/lib/kernels/include/kernels/copy_tensor_accessor.h
@@ -0,0 +1,19 @@
+#ifndef _FLEXFLOW_KERNELS_COPY_TENSOR_ACCESSOR_H
+#define _FLEXFLOW_KERNELS_COPY_TENSOR_ACCESSOR_H
+
+#include "kernels/accessor.h"
+#include "kernels/allocation.h"
+
+namespace FlexFlow {
+
+GenericTensorAccessorR
+    copy_tensor_accessor_r(GenericTensorAccessorR const &src_accessor,
+                           Allocator &allocator);
+
+GenericTensorAccessorW
+    copy_tensor_accessor_w(GenericTensorAccessorW const &src_accessor,
+                           Allocator &allocator);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/kernels/include/kernels/managed_ff_stream.h b/lib/kernels/include/kernels/managed_ff_stream.h
index 2f690b2eb3..26d5fb4911 100644
--- a/lib/kernels/include/kernels/managed_ff_stream.h
+++ b/lib/kernels/include/kernels/managed_ff_stream.h
@@ -19,6 +19,8 @@ struct ManagedFFStream {
 
   ffStream_t const &raw_stream() const;
 
+  void cleanup();
+
 private:
   ffStream_t *stream;
 };
diff --git a/lib/kernels/include/kernels/managed_per_device_ff_handle.h b/lib/kernels/include/kernels/managed_per_device_ff_handle.h
index f9f944c6ff..035ea574de 100644
--- a/lib/kernels/include/kernels/managed_per_device_ff_handle.h
+++ b/lib/kernels/include/kernels/managed_per_device_ff_handle.h
@@ -24,6 +24,8 @@ struct ManagedPerDeviceFFHandle {
 
   PerDeviceFFHandle const &raw_handle() const;
 
+  void cleanup();
+
 private:
   PerDeviceFFHandle *handle;
 };
diff --git a/lib/kernels/src/accessor.cc b/lib/kernels/src/accessor.cc
index 4cb5bd83a2..e56bded737 100644 --- a/lib/kernels/src/accessor.cc +++ b/lib/kernels/src/accessor.cc @@ -26,7 +26,7 @@ void copy_accessor_data_to_l_from_r( dst_accessor.ptr, src_accessor.ptr, num_bytes, cudaMemcpyDeviceToHost)); } else { assert(src_device_type == DeviceType::GPU); - assert(src_device_type == DeviceType::CPU); + assert(dst_device_type == DeviceType::GPU); checkCUDA(cudaMemcpy(dst_accessor.ptr, src_accessor.ptr, num_bytes, @@ -53,36 +53,6 @@ std::tupledata_type, this->shape, this->ptr, this->device_type); } -size_t GenericTensorAccessorW::calculate_index_offset( - std::vector const &indices) const { - - if (indices.size() != this->shape.num_dims()) { - throw mk_runtime_error(fmt::format( - "Number of indices ({}) does not match the number of dimensions ({}).", - indices.size(), - this->shape.num_dims())); - } - - size_t offset = 0; - size_t multiplier = 1; - - for (size_t i = 0; i < this->shape.num_dims(); i++) { - if (indices[i] >= this->shape.at(legion_dim_t(i))) { - throw mk_runtime_error( - fmt::format("In {} dimension, attempting to access index {} " - "when only {} indexes exist", - i, - indices[i], - this->shape.at(legion_dim_t(i)))); - } - - offset += indices[i] * multiplier; - multiplier *= this->shape.at(legion_dim_t(i)); - } - - return offset; -} - bool GenericTensorAccessorW::operator==( GenericTensorAccessorW const &other) const { return this->tie() == other.tie(); @@ -139,36 +109,6 @@ std::tupledata_type, this->shape, this->ptr, this->device_type); } -size_t GenericTensorAccessorR::calculate_index_offset( - std::vector const &indices) const { - - if (indices.size() != this->shape.num_dims()) { - throw mk_runtime_error(fmt::format( - "Number of indices ({}) does not match the number of dimensions ({}).", - indices.size(), - this->shape.num_dims())); - } - - ssize_t offset = 0; - size_t multiplier = 1; - - for (size_t i = 0; i < this->shape.num_dims(); i++) { - if (indices[i] >= this->shape.at(legion_dim_t(i))) { - throw mk_runtime_error( - fmt::format("In {} dimension, attempting to access index {} " - "when only {} indexes exist", - i, - indices[i], - this->shape.at(legion_dim_t(i)))); - } - - offset += indices[i] * multiplier; - multiplier *= this->shape.at(legion_dim_t(i)); - } - - return offset; -} - bool GenericTensorAccessorR::operator==( GenericTensorAccessorR const &other) const { return this->tie() == other.tie(); @@ -280,46 +220,4 @@ std::pair return std::make_pair(accessor.shape, accessor.data_type); } -template -struct CopyTensorAccessorW { - GenericTensorAccessorW operator()(GenericTensorAccessorW const &src_accessor, - Allocator &allocator) { - TensorShape shape = - get_tensor_shape(src_accessor.shape, src_accessor.data_type); - GenericTensorAccessorW dst_accessor = allocator.allocate_tensor(shape); - - copy_accessor_data_to_l_from_r(dst_accessor, src_accessor); - - return dst_accessor; - } -}; - -GenericTensorAccessorW - copy_tensor_accessor_w(GenericTensorAccessorW const &src_accessor, - Allocator &allocator) { - return DataTypeDispatch1{}( - src_accessor.data_type, src_accessor, allocator); -} - -template -struct CopyTensorAccessorR { - GenericTensorAccessorR operator()(GenericTensorAccessorR const &src_accessor, - Allocator &allocator) { - TensorShape shape = - get_tensor_shape(src_accessor.shape, src_accessor.data_type); - GenericTensorAccessorW dst_accessor = allocator.allocate_tensor(shape); - - copy_accessor_data_to_l_from_r(dst_accessor, src_accessor); - - return read_only_accessor_from_write_accessor(dst_accessor); - } -}; - 
-GenericTensorAccessorR
-    copy_tensor_accessor_r(GenericTensorAccessorR const &src_accessor,
-                           Allocator &allocator) {
-  return DataTypeDispatch1<CopyTensorAccessorR>{}(
-      src_accessor.data_type, src_accessor, allocator);
-}
-
 } // namespace FlexFlow
diff --git a/lib/kernels/src/copy_tensor_accessor.cc b/lib/kernels/src/copy_tensor_accessor.cc
new file mode 100644
index 0000000000..6a3ad8033a
--- /dev/null
+++ b/lib/kernels/src/copy_tensor_accessor.cc
@@ -0,0 +1,48 @@
+#include "kernels/copy_tensor_accessor.h"
+#include "kernels/datatype_dispatch.h"
+
+namespace FlexFlow {
+
+template <DataType DT>
+struct CopyTensorAccessorW {
+  GenericTensorAccessorW operator()(GenericTensorAccessorW const &src_accessor,
+                                    Allocator &allocator) {
+    TensorShape shape =
+        get_tensor_shape(src_accessor.shape, src_accessor.data_type);
+    GenericTensorAccessorW dst_accessor = allocator.allocate_tensor(shape);
+
+    copy_accessor_data_to_l_from_r(dst_accessor, src_accessor);
+
+    return dst_accessor;
+  }
+};
+
+GenericTensorAccessorW
+    copy_tensor_accessor_w(GenericTensorAccessorW const &src_accessor,
+                           Allocator &allocator) {
+  return DataTypeDispatch1<CopyTensorAccessorW>{}(
+      src_accessor.data_type, src_accessor, allocator);
+}
+
+template <DataType DT>
+struct CopyTensorAccessorR {
+  GenericTensorAccessorR operator()(GenericTensorAccessorR const &src_accessor,
+                                    Allocator &allocator) {
+    TensorShape shape =
+        get_tensor_shape(src_accessor.shape, src_accessor.data_type);
+    GenericTensorAccessorW dst_accessor = allocator.allocate_tensor(shape);
+
+    copy_accessor_data_to_l_from_r(dst_accessor, src_accessor);
+
+    return read_only_accessor_from_write_accessor(dst_accessor);
+  }
+};
+
+GenericTensorAccessorR
+    copy_tensor_accessor_r(GenericTensorAccessorR const &src_accessor,
+                           Allocator &allocator) {
+  return DataTypeDispatch1<CopyTensorAccessorR>{}(
+      src_accessor.data_type, src_accessor, allocator);
+}
+
+} // namespace FlexFlow
diff --git a/lib/kernels/src/cpu/replicate_kernels.cc b/lib/kernels/src/cpu/replicate_kernels.cc
index 25693b374d..cfcb44dac5 100644
--- a/lib/kernels/src/cpu/replicate_kernels.cc
+++ b/lib/kernels/src/cpu/replicate_kernels.cc
@@ -19,9 +19,9 @@ struct CPUBackwardKernel {
                   GenericTensorAccessorW &input,
                   size_t num_replicas) {
     using T = real_type_t<DT>;
-    for (size_t i = 0; i < input.shape.num_elements(); i++) {
+    for (int i = 0; i < input.shape.num_elements(); i++) {
       T cur_sum = 0;
-      for (size_t j = 0; j < num_replicas; j++) {
+      for (int j = 0; j < num_replicas; j++) {
         cur_sum += output.at<DT>({i, j});
       }
       input.at<DT>({i}) = cur_sum;
diff --git a/lib/kernels/src/cpu/reverse_kernels.cc b/lib/kernels/src/cpu/reverse_kernels.cc
index e5b3719d74..bc73c80e9e 100644
--- a/lib/kernels/src/cpu/reverse_kernels.cc
+++ b/lib/kernels/src/cpu/reverse_kernels.cc
@@ -11,13 +11,13 @@ struct CPUReverseForwardKernel {
                   GenericTensorAccessorW &output) {
     assert(input.data_type == DT && output.data_type == DT);
 
-    size_t num_out_blocks = input.shape.at(legion_dim_t(0));
-    size_t reverse_dim_size = input.shape.at(legion_dim_t(1));
-    size_t in_block_size = input.shape.at(legion_dim_t(2));
+    int num_out_blocks = input.shape.at(legion_dim_t(0));
+    int reverse_dim_size = input.shape.at(legion_dim_t(1));
+    int in_block_size = input.shape.at(legion_dim_t(2));
 
-    for (size_t block_idx = 0; block_idx < num_out_blocks; block_idx++) {
-      for (size_t rev_idx = 0; rev_idx < reverse_dim_size; rev_idx++) {
-        for (size_t i = 0; i < in_block_size; i++) {
+    for (int block_idx = 0; block_idx < num_out_blocks; block_idx++) {
+      for (int rev_idx = 0; rev_idx < reverse_dim_size; rev_idx++) {
+        for (int i = 0; i < in_block_size; i++) {
           output.at<DT>({block_idx, rev_idx, i}) =
               input.at<DT>
({num_out_blocks - 1 - block_idx, reverse_dim_size - 1 - rev_idx, diff --git a/lib/kernels/src/cuda/ops/linear_kernels.cu b/lib/kernels/src/cuda/ops/linear_kernels.cu index f13ebee67e..6b069218fa 100644 --- a/lib/kernels/src/cuda/ops/linear_kernels.cu +++ b/lib/kernels/src/cuda/ops/linear_kernels.cu @@ -135,14 +135,14 @@ void forward_kernel(cudaStream_t stream, batch_size, in_dim, &alpha, - reinterpret_cast(weight_ptr), + static_cast(weight_ptr), weight_type, in_dim, - reinterpret_cast(input_ptr), + static_cast(input_ptr), input_type, in_dim, &beta, - reinterpret_cast(output_ptr), + static_cast(output_ptr), output_type, out_dim, compute_type, @@ -156,14 +156,14 @@ void forward_kernel(cudaStream_t stream, batch_size, 1, &alpha, - reinterpret_cast(bias_ptr), + static_cast(bias_ptr), weight_type, 1, - reinterpret_cast(m.one_ptr), + static_cast(m.one_ptr), CUDA_R_32F, 1, &alpha, - reinterpret_cast(output_ptr), + static_cast(output_ptr), output_type, out_dim, compute_type, @@ -174,10 +174,10 @@ void forward_kernel(cudaStream_t stream, m.actiDesc, &alpha, m.outputTensor, - reinterpret_cast(output_ptr), + static_cast(output_ptr), &beta, m.outputTensor, - reinterpret_cast(output_ptr))); + static_cast(output_ptr))); } else if (m.activation == Activation::GELU) { size_t elements = size_t_from_int(out_dim) * size_t_from_int(batch_size); constexpr float B = 0.7978845608028654f; // sqrt(2.0/M_PI) @@ -217,14 +217,14 @@ void backward_kernel(cudaStream_t stream, if (m.activation.has_value()) { if (m.activation == Activation::RELU) { relu_backward_kernel(m.output_type, - reinterpret_cast(output_grad_ptr), - reinterpret_cast(output_ptr), + static_cast(output_grad_ptr), + static_cast(output_ptr), output_size, stream); } else if (m.activation == Activation::SIGMOID) { sigmoid_backward_kernel(m.output_type, - reinterpret_cast(output_grad_ptr), - reinterpret_cast(output_ptr), + static_cast(output_grad_ptr), + static_cast(output_ptr), output_size, stream); } else { @@ -241,14 +241,14 @@ void backward_kernel(cudaStream_t stream, out_dim, batch_size, &alpha, - reinterpret_cast(input_ptr), + static_cast(input_ptr), input_type, in_dim, - reinterpret_cast(output_grad_ptr), + static_cast(output_grad_ptr), output_type, out_dim, &alpha, - reinterpret_cast(kernel_grad_ptr), + static_cast(kernel_grad_ptr), weight_type, in_dim, compute_type, @@ -290,14 +290,14 @@ void backward_kernel(cudaStream_t stream, out_dim, batch_size, &alpha, - reinterpret_cast(m.one_ptr), + static_cast(m.one_ptr), CUDA_R_32F, 1, - reinterpret_cast(output_grad_ptr), + static_cast(output_grad_ptr), output_type, out_dim, &alpha, - reinterpret_cast(bias_grad_ptr), + static_cast(bias_grad_ptr), weight_type, 1, compute_type, @@ -313,14 +313,14 @@ void backward_kernel(cudaStream_t stream, batch_size, out_dim, &alpha, - reinterpret_cast(kernel_ptr), + static_cast(kernel_ptr), weight_type, in_dim, - reinterpret_cast(output_grad_ptr), + static_cast(output_grad_ptr), output_type, out_dim, &alpha, - reinterpret_cast(input_grad_ptr), + static_cast(input_grad_ptr), input_type, in_dim, compute_type, diff --git a/lib/kernels/src/managed_ff_stream.cc b/lib/kernels/src/managed_ff_stream.cc index a8b44dc1d3..f0348aa91c 100644 --- a/lib/kernels/src/managed_ff_stream.cc +++ b/lib/kernels/src/managed_ff_stream.cc @@ -12,16 +12,17 @@ ManagedFFStream::ManagedFFStream(ManagedFFStream &&other) noexcept ManagedFFStream &ManagedFFStream::operator=(ManagedFFStream &&other) noexcept { if (this != &other) { - if (this->stream != nullptr) { - 
checkCUDA(cudaStreamDestroy(*this->stream)); - delete stream; - } + this->cleanup(); this->stream = std::exchange(other.stream, nullptr); } return *this; } ManagedFFStream::~ManagedFFStream() { + this->cleanup(); +} + +void ManagedFFStream::cleanup() { if (this->stream != nullptr) { checkCUDA(cudaStreamDestroy(*this->stream)); delete this->stream; diff --git a/lib/kernels/src/managed_per_device_ff_handle.cc b/lib/kernels/src/managed_per_device_ff_handle.cc index 5bd49dc26f..9f1737240e 100644 --- a/lib/kernels/src/managed_per_device_ff_handle.cc +++ b/lib/kernels/src/managed_per_device_ff_handle.cc @@ -5,7 +5,7 @@ namespace FlexFlow { ManagedPerDeviceFFHandle::ManagedPerDeviceFFHandle( size_t workSpaceSize, bool allowTensorOpMathConversion) { - this->handle = new PerDeviceFFHandle; + this->handle = new PerDeviceFFHandle{}; this->handle->workSpaceSize = workSpaceSize; this->handle->allowTensorOpMathConversion = allowTensorOpMathConversion; @@ -21,18 +21,17 @@ ManagedPerDeviceFFHandle::ManagedPerDeviceFFHandle( ManagedPerDeviceFFHandle &ManagedPerDeviceFFHandle::operator=( ManagedPerDeviceFFHandle &&other) noexcept { if (this != &other) { - if (this->handle != nullptr) { - checkCUDNN(cudnnDestroy(this->handle->dnn)); - checkCUBLAS(cublasDestroy(this->handle->blas)); - checkCUDA(cudaFree(this->handle->workSpace)); - delete this->handle; - } + this->cleanup(); this->handle = std::exchange(other.handle, nullptr); } return *this; } ManagedPerDeviceFFHandle::~ManagedPerDeviceFFHandle() { + this->cleanup(); +} + +void ManagedPerDeviceFFHandle::cleanup() { if (this->handle != nullptr) { checkCUDNN(cudnnDestroy(this->handle->dnn)); checkCUBLAS(cublasDestroy(this->handle->blas)); diff --git a/lib/kernels/test/src/test_attention_kernel.cc b/lib/kernels/test/src/test_attention_kernel.cc index aae3676107..023233ecb0 100644 --- a/lib/kernels/test/src/test_attention_kernel.cc +++ b/lib/kernels/test/src/test_attention_kernel.cc @@ -13,7 +13,9 @@ TEST_SUITE(FF_TEST_SUITE) { size_t qoSeqLength = 20, kvSeqLength = 20; ManagedFFStream managed_stream{}; - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; Allocator allocator = create_local_cuda_memory_allocator(); diff --git a/lib/kernels/test/src/test_batch_matmul_kernel.cc b/lib/kernels/test/src/test_batch_matmul_kernel.cc index b87f3978b5..8a11a069f5 100644 --- a/lib/kernels/test/src/test_batch_matmul_kernel.cc +++ b/lib/kernels/test/src/test_batch_matmul_kernel.cc @@ -15,7 +15,9 @@ TEST_SUITE(FF_TEST_SUITE) { size_t seq_length = -1; ManagedFFStream managed_stream{}; - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; Allocator allocator = create_local_cuda_memory_allocator(); diff --git a/lib/kernels/test/src/test_batch_norm_kernel.cc b/lib/kernels/test/src/test_batch_norm_kernel.cc index a258a27a34..611069ac93 100644 --- a/lib/kernels/test/src/test_batch_norm_kernel.cc +++ b/lib/kernels/test/src/test_batch_norm_kernel.cc @@ -1,5 +1,6 @@ #include "doctest/doctest.h" #include "kernels/batch_norm_kernels.h" +#include "op-attrs/make_datatype_value.h" #include "test_utils.h" using namespace ::FlexFlow; @@ -9,7 +10,9 @@ TEST_SUITE(FF_TEST_SUITE) { size_t output_n = 1, output_c = 10, output_h = 10, output_w = 10; ManagedFFStream managed_stream{}; - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + 
ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; Allocator allocator = create_local_cuda_memory_allocator(); @@ -36,12 +39,12 @@ TEST_SUITE(FF_TEST_SUITE) { create_random_filled_accessor_w(input_shape, allocator); GenericTensorAccessorW output_accessor = create_random_filled_accessor_w(output_shape, allocator); - GenericTensorAccessorW scale_accessor = - create_filled_accessor_w(scale_shape, allocator, DataTypeValue(1.0f)); + GenericTensorAccessorW scale_accessor = create_filled_accessor_w( + scale_shape, allocator, make_float_data_type_value(1)); SUBCASE("forward_kernel") { - GenericTensorAccessorW bias_accessor = - create_filled_accessor_w(bias_shape, allocator, DataTypeValue(0.0f)); + GenericTensorAccessorW bias_accessor = create_filled_accessor_w( + bias_shape, allocator, make_float_data_type_value(0)); Kernels::BatchNorm::forward_kernel(managed_stream.raw_stream(), state, diff --git a/lib/kernels/test/src/test_combine_kernel.cc b/lib/kernels/test/src/test_combine_kernel.cc index 60179ee75b..a4688a1030 100644 --- a/lib/kernels/test/src/test_combine_kernel.cc +++ b/lib/kernels/test/src/test_combine_kernel.cc @@ -6,7 +6,9 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Call Combine Forward and Backward Kernels") { - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; Allocator allocator = create_local_cuda_memory_allocator(); diff --git a/lib/kernels/test/src/test_concat_kernel.cc b/lib/kernels/test/src/test_concat_kernel.cc index 841d53133c..b299f5dea8 100644 --- a/lib/kernels/test/src/test_concat_kernel.cc +++ b/lib/kernels/test/src/test_concat_kernel.cc @@ -8,9 +8,11 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Test concat kernel forward and backward") { size_t num_inputs = 2; size_t size_per_input = 10; - ff_dim_t concat_axis = ff_dim_t(1); + ff_dim_t concat_axis = ff_dim_t{1}; - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; TensorShape input_shape = diff --git a/lib/kernels/test/src/test_dropout.cc b/lib/kernels/test/src/test_dropout.cc index bee00d990d..4be2bdf7bb 100644 --- a/lib/kernels/test/src/test_dropout.cc +++ b/lib/kernels/test/src/test_dropout.cc @@ -18,7 +18,9 @@ TEST_SUITE(FF_TEST_SUITE) { TensorShape output_shape = input_shape; ManagedFFStream managed_stream{}; - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; Allocator allocator = create_local_cuda_memory_allocator(); diff --git a/lib/kernels/test/src/test_flat_kernel.cc b/lib/kernels/test/src/test_flat_kernel.cc index 9febf4bcc4..b8f128b761 100644 --- a/lib/kernels/test/src/test_flat_kernel.cc +++ b/lib/kernels/test/src/test_flat_kernel.cc @@ -1,5 +1,6 @@ #include "doctest/doctest.h" #include "kernels/flat_kernels.h" +#include "op-attrs/make_datatype_value.h" #include "test_utils.h" using namespace ::FlexFlow; @@ -7,7 +8,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Test Flat Kernel") { Allocator allocator = create_local_cuda_memory_allocator(); - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + 
/*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; TensorShape input_shape = @@ -16,7 +19,7 @@ TEST_SUITE(FF_TEST_SUITE) { GenericTensorAccessorR input_accessor = read_only_accessor_from_write_accessor(create_filled_accessor_w( - input_shape, allocator, DataTypeValue(2.0f))); + input_shape, allocator, make_float_data_type_value(2))); SUBCASE("forward_kernel") { GenericTensorAccessorW output_accessor = @@ -31,9 +34,9 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("backward_kernel") { GenericTensorAccessorW output_grad_accessor = create_filled_accessor_w( - output_shape, allocator, DataTypeValue(0.0f)); - GenericTensorAccessorW input_grad_accessor = - create_filled_accessor_w(input_shape, allocator, DataTypeValue(1.0f)); + output_shape, allocator, make_float_data_type_value(0)); + GenericTensorAccessorW input_grad_accessor = create_filled_accessor_w( + input_shape, allocator, make_float_data_type_value(1)); Kernels::Flat::backward_kernel(managed_stream.raw_stream(), input_accessor, diff --git a/lib/kernels/test/src/test_gather_kernels.cc b/lib/kernels/test/src/test_gather_kernels.cc index 4f9fa02a1a..7f97563217 100644 --- a/lib/kernels/test/src/test_gather_kernels.cc +++ b/lib/kernels/test/src/test_gather_kernels.cc @@ -5,7 +5,9 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Test Gather Forward and Backward Kernel") { - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; Allocator allocator = create_local_cuda_memory_allocator(); diff --git a/lib/kernels/test/src/test_layer_norm_kernels.cc b/lib/kernels/test/src/test_layer_norm_kernels.cc index 87fc88f081..7d7298f83d 100644 --- a/lib/kernels/test/src/test_layer_norm_kernels.cc +++ b/lib/kernels/test/src/test_layer_norm_kernels.cc @@ -1,5 +1,6 @@ #include "doctest/doctest.h" #include "kernels/layer_norm_kernels.h" +#include "op-attrs/make_datatype_value.h" #include "test_utils.h" using namespace ::FlexFlow; @@ -17,7 +18,9 @@ TEST_SUITE(FF_TEST_SUITE) { TensorShape feature_shape = make_tensor_shape_from_legion_dims({feature_size}, DataType::FLOAT); - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; Allocator allocator = create_local_cuda_memory_allocator(); @@ -32,14 +35,14 @@ TEST_SUITE(FF_TEST_SUITE) { GenericTensorAccessorR input_accessor = create_random_filled_accessor_r(input_shape, allocator); - GenericTensorAccessorW gamma_accessor = - create_filled_accessor_w(feature_shape, allocator, DataTypeValue(1.0f)); + GenericTensorAccessorW gamma_accessor = create_filled_accessor_w( + feature_shape, allocator, make_float_data_type_value(1)); SUBCASE("forward_kernel") { GenericTensorAccessorW output_accessor = allocator.allocate_tensor(output_shape); GenericTensorAccessorW beta_accessor = create_filled_accessor_w( - feature_shape, allocator, DataTypeValue(0.0f)); + feature_shape, allocator, make_float_data_type_value(0)); Kernels::LayerNorm::forward_kernel(managed_stream.raw_stream(), state, diff --git a/lib/kernels/test/src/test_managed_ff_stream.cc b/lib/kernels/test/src/test_managed_ff_stream.cc index ce8a808454..605aa6ffa1 100644 --- a/lib/kernels/test/src/test_managed_ff_stream.cc +++ b/lib/kernels/test/src/test_managed_ff_stream.cc @@ -4,26 +4,28 @@ using namespace ::FlexFlow; 
TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("Test Managed FF Stream") { + TEST_CASE("ManagedFFStream") { ManagedFFStream base_stream{}; ffStream_t const *base_stream_ptr = &base_stream.raw_stream(); - SUBCASE("Test ManagedFFStream Move Constructor") { + SUBCASE("move constructor") { ManagedFFStream new_stream(std::move(base_stream)); CHECK(&base_stream.raw_stream() == nullptr); CHECK(&new_stream.raw_stream() == base_stream_ptr); } - SUBCASE("Test ManagedFFStream Assignment Operator") { - ManagedFFStream new_stream{}; - new_stream = std::move(base_stream); - CHECK(&base_stream.raw_stream() == nullptr); - CHECK(&new_stream.raw_stream() == base_stream_ptr); - } + SUBCASE("move assignment operator") { + SUBCASE("move assign to other") { + ManagedFFStream new_stream{}; + new_stream = std::move(base_stream); + CHECK(&base_stream.raw_stream() == nullptr); + CHECK(&new_stream.raw_stream() == base_stream_ptr); + } - SUBCASE("Test Self-Assignment") { - base_stream = std::move(base_stream); - CHECK(&base_stream.raw_stream() == base_stream_ptr); + SUBCASE("move assign to self") { + base_stream = std::move(base_stream); + CHECK(&base_stream.raw_stream() == base_stream_ptr); + } } } } diff --git a/lib/kernels/test/src/test_managed_per_device_ff_handle.cc b/lib/kernels/test/src/test_managed_per_device_ff_handle.cc index d39da03ba9..de3e5b72b1 100644 --- a/lib/kernels/test/src/test_managed_per_device_ff_handle.cc +++ b/lib/kernels/test/src/test_managed_per_device_ff_handle.cc @@ -4,33 +4,35 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("Test Managed Per Device FF Handle") { + TEST_CASE("ManagedPerDeviceFFHandle") { ManagedPerDeviceFFHandle base_handle{1024 * 1024, true}; PerDeviceFFHandle const *base_handle_ptr = &base_handle.raw_handle(); - SUBCASE("Test ManagedPerDeviceFFHandle Constructor") { + SUBCASE("constructor") { CHECK(base_handle.raw_handle().workSpaceSize == 1024 * 1024); CHECK(base_handle.raw_handle().allowTensorOpMathConversion == true); } - SUBCASE("Test ManagedPerDeviceFFHandle Move Constructor") { + SUBCASE("move constructor") { ManagedPerDeviceFFHandle new_handle(std::move(base_handle)); CHECK(&base_handle.raw_handle() == nullptr); CHECK(&new_handle.raw_handle() == base_handle_ptr); } - SUBCASE("Test ManagedPerDeviceFFHandle Assignment Operator") { - ManagedPerDeviceFFHandle new_handle{1024 * 1024, true}; - new_handle = std::move(base_handle); + SUBCASE("move assignment operator") { + SUBCASE("move assign to other") { + ManagedPerDeviceFFHandle new_handle{1024 * 1024, true}; + new_handle = std::move(base_handle); - CHECK(&base_handle.raw_handle() == nullptr); - CHECK(&new_handle.raw_handle() == base_handle_ptr); - } + CHECK(&base_handle.raw_handle() == nullptr); + CHECK(&new_handle.raw_handle() == base_handle_ptr); + } - SUBCASE("Test Self-Assignment") { - base_handle = std::move(base_handle); - CHECK(&base_handle.raw_handle() == base_handle_ptr); + SUBCASE("move assign to self") { + base_handle = std::move(base_handle); + CHECK(&base_handle.raw_handle() == base_handle_ptr); + } } } } diff --git a/lib/kernels/test/src/test_partition_kernel.cc b/lib/kernels/test/src/test_partition_kernel.cc index 079af64a8c..4beae62553 100644 --- a/lib/kernels/test/src/test_partition_kernel.cc +++ b/lib/kernels/test/src/test_partition_kernel.cc @@ -1,12 +1,15 @@ #include "doctest/doctest.h" #include "kernels/partition_kernels.h" +#include "op-attrs/make_datatype_value.h" #include "test_utils.h" using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Test Partition Forward 
and Backward") { - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; Allocator allocator = create_local_cuda_memory_allocator(); @@ -19,8 +22,8 @@ TEST_SUITE(FF_TEST_SUITE) { TensorShape output_shape = input_shape; SUBCASE("forward_kernel") { - GenericTensorAccessorR input_accessor = - create_filled_accessor_r(input_shape, allocator, DataTypeValue(1.0f)); + GenericTensorAccessorR input_accessor = create_filled_accessor_r( + input_shape, allocator, make_float_data_type_value(1)); GenericTensorAccessorW output_accessor = allocator.allocate_tensor(output_shape); @@ -32,9 +35,9 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("backward_kernel") { GenericTensorAccessorR output_grad_accessor = create_filled_accessor_r( - output_shape, allocator, DataTypeValue(1.0f)); - GenericTensorAccessorW input_grad_accessor = - create_filled_accessor_w(input_shape, allocator, DataTypeValue(2.0f)); + output_shape, allocator, make_float_data_type_value(1)); + GenericTensorAccessorW input_grad_accessor = create_filled_accessor_w( + input_shape, allocator, make_float_data_type_value(2)); Kernels::Repartition::backward_kernel(managed_stream.raw_stream(), state, diff --git a/lib/kernels/test/src/test_pool_2d_kernels.cc b/lib/kernels/test/src/test_pool_2d_kernels.cc index 76b966ea15..2a4d3caf9a 100644 --- a/lib/kernels/test/src/test_pool_2d_kernels.cc +++ b/lib/kernels/test/src/test_pool_2d_kernels.cc @@ -1,5 +1,6 @@ #include "doctest/doctest.h" #include "kernels/pool_2d_kernels.h" +#include "op-attrs/make_datatype_value.h" #include "test_utils.h" using namespace ::FlexFlow; @@ -12,7 +13,9 @@ TEST_SUITE(FF_TEST_SUITE) { PoolOp pool_type = PoolOp::MAX; - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; Allocator allocator = create_local_cuda_memory_allocator(); @@ -57,7 +60,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("backward_kernel") { GenericTensorAccessorW output_grad_accessor = create_filled_accessor_w( - output_shape, allocator, DataTypeValue(1.0f)); + output_shape, allocator, make_float_data_type_value(1)); GenericTensorAccessorW input_grad_accessor = allocator.allocate_tensor(input_shape); diff --git a/lib/kernels/test/src/test_reduction_kernel.cc b/lib/kernels/test/src/test_reduction_kernel.cc index ddbe826c70..3c3e828049 100644 --- a/lib/kernels/test/src/test_reduction_kernel.cc +++ b/lib/kernels/test/src/test_reduction_kernel.cc @@ -1,5 +1,6 @@ #include "doctest/doctest.h" #include "kernels/reduction_kernels.h" +#include "op-attrs/make_datatype_value.h" #include "test_utils.h" using namespace ::FlexFlow; @@ -10,7 +11,9 @@ TEST_SUITE(FF_TEST_SUITE) { TensorShape input_shape = make_tensor_shape_from_legion_dims( {10, 10, 10, 10, 10}, DataType::FLOAT); - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; Allocator allocator = create_local_cuda_memory_allocator(); @@ -36,7 +39,7 @@ TEST_SUITE(FF_TEST_SUITE) { TensorShape output_shape = input_shape; GenericTensorAccessorR output_grad_accessor = create_filled_accessor_r( - output_shape, allocator, DataTypeValue(1.0f)); + output_shape, allocator, make_float_data_type_value(1)); GenericTensorAccessorW 
input_grad_accessor = allocator.allocate_tensor(input_shape); diff --git a/lib/kernels/test/src/test_replicate_kernel.cc b/lib/kernels/test/src/test_replicate_kernel.cc index 1d9e0677b7..27223cc7b5 100644 --- a/lib/kernels/test/src/test_replicate_kernel.cc +++ b/lib/kernels/test/src/test_replicate_kernel.cc @@ -13,7 +13,9 @@ TEST_SUITE(FF_TEST_SUITE) { TensorShape output_shape = make_tensor_shape_from_legion_dims({100}, DataType::FLOAT); - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; Allocator allocator = create_local_cuda_memory_allocator(); @@ -53,7 +55,9 @@ TEST_SUITE(FF_TEST_SUITE) { TensorShape output_shape = make_tensor_shape_from_legion_dims({5, num_replicas}, DataType::FLOAT); - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; Allocator gpu_allocator = create_local_cuda_memory_allocator(); diff --git a/lib/kernels/test/src/test_reshape_kernel.cc b/lib/kernels/test/src/test_reshape_kernel.cc index 41aaac9c3e..55797aeff6 100644 --- a/lib/kernels/test/src/test_reshape_kernel.cc +++ b/lib/kernels/test/src/test_reshape_kernel.cc @@ -5,7 +5,9 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Test Reshape Forward and Backward") { - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; Allocator allocator = create_local_cuda_memory_allocator(); diff --git a/lib/kernels/test/src/test_reverse_kernels.cc b/lib/kernels/test/src/test_reverse_kernels.cc index 436b788a99..4adf79847a 100644 --- a/lib/kernels/test/src/test_reverse_kernels.cc +++ b/lib/kernels/test/src/test_reverse_kernels.cc @@ -1,6 +1,7 @@ #include "doctest/doctest.h" #include "kernels/reverse_kernels.h" #include "kernels/reverse_kernels_cpu.h" +#include "op-attrs/make_datatype_value.h" #include "test_utils.h" using namespace ::FlexFlow; @@ -14,7 +15,9 @@ TEST_SUITE(FF_TEST_SUITE) { {num_out_blks, reverse_dim_size, in_blk_size}, DataType::FLOAT); TensorShape output_shape = input_shape; - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; Allocator allocator = create_local_cuda_memory_allocator(); @@ -22,7 +25,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("forward_kernel") { GenericTensorAccessorR input_accessor = read_only_accessor_from_write_accessor(create_filled_accessor_w( - input_shape, allocator, DataTypeValue(1.0f))); + input_shape, allocator, make_float_data_type_value(1))); GenericTensorAccessorW output_accessor = allocator.allocate_tensor(output_shape); @@ -65,7 +68,9 @@ TEST_SUITE(FF_TEST_SUITE) { {num_out_blks, reverse_dim_size, in_blk_size}, DataType::FLOAT); TensorShape output_shape = input_shape; - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; Allocator gpu_allocator = create_local_cuda_memory_allocator(); diff --git a/lib/kernels/test/src/test_softmax_kernel.cc b/lib/kernels/test/src/test_softmax_kernel.cc 
index b293d1ce75..bb6bcb949b 100644 --- a/lib/kernels/test/src/test_softmax_kernel.cc +++ b/lib/kernels/test/src/test_softmax_kernel.cc @@ -8,7 +8,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Test Softmax Kernel Operations") { int input_n = 1, input_c = 1, input_h = 1, input_w = 100, channels = 100; - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; Allocator allocator = create_local_cuda_memory_allocator(); diff --git a/lib/kernels/test/src/test_split_kernel.cc b/lib/kernels/test/src/test_split_kernel.cc index 114077d6ec..34993fa151 100644 --- a/lib/kernels/test/src/test_split_kernel.cc +++ b/lib/kernels/test/src/test_split_kernel.cc @@ -1,5 +1,6 @@ #include "doctest/doctest.h" #include "kernels/split_kernels.h" +#include "op-attrs/make_datatype_value.h" #include "test_utils.h" #include "utils/containers/repeat.h" @@ -12,7 +13,9 @@ TEST_SUITE(FF_TEST_SUITE) { coord_t in_blk_size = 100; coord_t num_blks = 1; - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; Allocator allocator = create_local_cuda_memory_allocator(); @@ -49,8 +52,8 @@ TEST_SUITE(FF_TEST_SUITE) { output_grad_ptrs[i] = output_grad_accessor.get_float_ptr(); } - GenericTensorAccessorW input_grad_accessor = - create_filled_accessor_w(input_shape, allocator, DataTypeValue(0.0f)); + GenericTensorAccessorW input_grad_accessor = create_filled_accessor_w( + input_shape, allocator, make_float_data_type_value(0)); Kernels::Split::backward_kernel(managed_stream.raw_stream(), input_grad_accessor.get_float_ptr(), diff --git a/lib/kernels/test/src/test_transpose_kernel.cc b/lib/kernels/test/src/test_transpose_kernel.cc index 5c5e9b31f8..b9ef82a764 100644 --- a/lib/kernels/test/src/test_transpose_kernel.cc +++ b/lib/kernels/test/src/test_transpose_kernel.cc @@ -7,9 +7,11 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Test Transpose Kernel Operations") { std::size_t num_dims = 2; - std::vector perm = {ff_dim_t(0), ff_dim_t(1)}; + std::vector perm = {ff_dim_t{0}, ff_dim_t{1}}; - ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true); + ManagedPerDeviceFFHandle managed_handle{ + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true}; ManagedFFStream managed_stream{}; Allocator allocator = create_local_cuda_memory_allocator(); diff --git a/lib/kernels/test/src/test_utils.cc b/lib/kernels/test/src/test_utils.cc index a59747b376..bfed1241ba 100644 --- a/lib/kernels/test/src/test_utils.cc +++ b/lib/kernels/test/src/test_utils.cc @@ -137,25 +137,35 @@ GenericTensorAccessorW } template -struct PrintCPUAccessorR { +struct Print2DCPUAccessorR { void operator()(GenericTensorAccessorR const &accessor, std::ostream &stream) { using T = real_type_t
<DT>;
     T const *data_ptr = accessor.get<DT>();
 
-    for (size_t i = 0; i < accessor.shape.num_elements(); i++) {
-      stream << data_ptr[i] << " ";
+    int rows = accessor.shape.at(legion_dim_t{0});
+    int cols = accessor.shape.at(legion_dim_t{1});
+
+    for (int i = 0; i < rows; i++) {
+      for (int j = 0; j < cols; j++) {
+        stream << data_ptr[i * cols + j];
+
+        if (j < cols - 1) {
+          stream << " ";
+        }
+      }
+      stream << std::endl;
     }
-    stream << "\n";
   }
 };
 
-void print_tensor_accessor_contents(GenericTensorAccessorR const &accessor,
-                                    std::ostream &stream) {
+void print_2d_tensor_accessor_contents(GenericTensorAccessorR const &accessor,
+                                       std::ostream &stream) {
   Allocator cpu_allocator = create_local_cpu_memory_allocator();
   GenericTensorAccessorR cpu_accessor =
       copy_accessor_r_to_cpu_if_necessary(accessor, cpu_allocator);
-  DataTypeDispatch1<PrintCPUAccessorR>{}(accessor.data_type, accessor, stream);
+  DataTypeDispatch1<Print2DCPUAccessorR>{}(
+      cpu_accessor.data_type, cpu_accessor, stream);
 }
 
 template <DataType DT>
diff --git a/lib/kernels/test/src/test_utils.h b/lib/kernels/test/src/test_utils.h
index efbbc90e08..d23b936cb0 100644
--- a/lib/kernels/test/src/test_utils.h
+++ b/lib/kernels/test/src/test_utils.h
@@ -1,6 +1,7 @@
 #ifndef _FLEXFLOW_KERNELS_TEST_UTILS
 #define _FLEXFLOW_KERNELS_TEST_UTILS
 
+#include "kernels/copy_tensor_accessor.h"
 #include "kernels/datatype_dispatch.h"
 #include "kernels/device.h"
 #include "kernels/local_cpu_allocator.h"
@@ -37,7 +38,7 @@ GenericTensorAccessorR
     copy_accessor_r_to_cpu_if_necessary(GenericTensorAccessorR const &accessor,
                                         Allocator &allocator);
 
-void print_tensor_accessor_contents(GenericTensorAccessorR const &accessor);
+void print_2d_tensor_accessor_contents(GenericTensorAccessorR const &accessor,
+                                       std::ostream &stream);
 
 bool accessors_are_equal(GenericTensorAccessorR const &accessor_a,
                          GenericTensorAccessorR const &accessor_b);
diff --git a/lib/local-execution/src/ops/pool_2d.cc b/lib/local-execution/src/ops/pool_2d.cc
index 33d62b713c..be51ea9526 100644
--- a/lib/local-execution/src/ops/pool_2d.cc
+++ b/lib/local-execution/src/ops/pool_2d.cc
@@ -30,14 +30,14 @@ static DeviceSpecificDeviceStates
   auto input = acc.get_tensor<Permissions::RO>(INPUT);
   auto output = acc.get_tensor<Permissions::WO>(OUTPUT);
 
-  int input_w = input.shape.at(ff_dim_t(0)) + 1;
-  int input_h = input.shape.at(ff_dim_t(1)) + 1;
-  int input_c = input.shape.at(ff_dim_t(2)) + 1;
-  int input_n = input.shape.at(ff_dim_t(3)) + 1;
-  int output_w = output.shape.at(ff_dim_t(0)) + 1;
-  int output_h = output.shape.at(ff_dim_t(1)) + 1;
-  int output_c = output.shape.at(ff_dim_t(2)) + 1;
-  int output_n = output.shape.at(ff_dim_t(3)) + 1;
+  int input_w = input.shape.at(ff_dim_t{0}) + 1;
+  int input_h = input.shape.at(ff_dim_t{1}) + 1;
+  int input_c = input.shape.at(ff_dim_t{2}) + 1;
+  int input_n = input.shape.at(ff_dim_t{3}) + 1;
+  int output_w = output.shape.at(ff_dim_t{0}) + 1;
+  int output_h = output.shape.at(ff_dim_t{1}) + 1;
+  int output_c = output.shape.at(ff_dim_t{2}) + 1;
+  int output_n = output.shape.at(ff_dim_t{3}) + 1;
 
   printf("init pool (input): n(%d) c(%d) h(%d) "
          "w(%d)\n",
diff --git a/lib/local-execution/src/ops/reverse.cc b/lib/local-execution/src/ops/reverse.cc
index 366a579bea..bb1b802edd 100644
--- a/lib/local-execution/src/ops/reverse.cc
+++ b/lib/local-execution/src/ops/reverse.cc
@@ -53,11 +53,11 @@ static std::optional<float>
   coord_t in_blk_size = 1, reverse_dim_size = 1, num_out_blks = 1;
   for (int i = 0; i < output.shape.get_dim(); i++) {
     if (i < axis.value) {
-      in_blk_size *= output.shape.at(ff_dim_t(i));
+      in_blk_size *= output.shape.at(ff_dim_t{i});
     } else if (i == axis.value) {
-      reverse_dim_size = output.shape.at(ff_dim_t(i));
+      reverse_dim_size = output.shape.at(ff_dim_t{i});
     } else {
-      num_out_blks *= output.shape.at(ff_dim_t(i));
+      num_out_blks *= output.shape.at(ff_dim_t{i});
     }
   }
 
@@ -83,11 +83,11 @@ static std::optional<float>
   coord_t in_blk_size = 1, reverse_dim_size = 1, num_out_blks = 1;
   for (int i = 0; i < input_grad.shape.get_dim(); i++) {
     if (i < axis) {
-      in_blk_size *= input_grad.shape.at(ff_dim_t(i));
+      in_blk_size *= input_grad.shape.at(ff_dim_t{i});
     } else if (i == axis) {
-      reverse_dim_size = input_grad.shape.at(ff_dim_t(i));
+      reverse_dim_size = input_grad.shape.at(ff_dim_t{i});
     } else {
-      num_out_blks *= input_grad.shape.at(ff_dim_t(i));
+      num_out_blks *= input_grad.shape.at(ff_dim_t{i});
     }
   }
 
diff --git a/lib/local-execution/test/src/test_local_cost_estimator.cc b/lib/local-execution/test/src/test_local_cost_estimator.cc
index 788ab52a7a..512c1ef33b 100644
--- a/lib/local-execution/test/src/test_local_cost_estimator.cc
+++ b/lib/local-execution/test/src/test_local_cost_estimator.cc
@@ -12,7 +12,11 @@
 // TEST_SUITE(FF_CUDA_TEST_SUITE) {
 //   TEST_CASE("Local Cost Estimator") {
 //     // local backing initialization
-//     ManagedPerDeviceFFHandle managed_handle(1024 * 1024, true);
+//     ManagedPerDeviceFFHandle managed_handle{
+//         /*workSpaceSize=*/1024 * 1024,
+//         /*allowTensorOpMathConversion=*/true};
 
 //     RuntimeArgConfig runtime_arg_config = RuntimeArgConfig{
 //         DeviceSpecific<PerDeviceFFHandle>::create(managed_handle.raw_handle()),
diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h
index 6aa23d40fc..19a6e62178 100644
--- a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h
+++ b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h
@@ -175,8 +175,9 @@ auto inner_to_outer(FFOrdered<T> const &ff_ordered)
 template <typename T>
 std::vector<ff_dim_t> inner_to_outer_idxs(FFOrdered<T> const &ff_ordered) {
   std::vector<ff_dim_t> idxs;
-  for (size_t i = 0; i < ff_ordered.size(); i++) {
-    idxs.push_back(ff_dim_t(ff_ordered.size() - i - 1));
+  int size = static_cast<int>(ff_ordered.size());
+  for (int i = 0; i < size; i++) {
+    idxs.push_back(ff_dim_t{size - i - 1});
   }
   return idxs;
 }
diff --git a/lib/op-attrs/include/op-attrs/make_datatype_value.h b/lib/op-attrs/include/op-attrs/make_datatype_value.h
new file mode 100644
index 0000000000..c3289c6309
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/make_datatype_value.h
@@ -0,0 +1,16 @@
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_MAKE_DATATYPE_VALUE_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_MAKE_DATATYPE_VALUE_H
+
+#include "op-attrs/datatype_value.dtg.h"
+
+namespace FlexFlow {
+
+DataTypeValue make_float_data_type_value(float value);
+DataTypeValue make_double_data_type_value(double value);
+DataTypeValue make_int32_data_type_value(int32_t value);
+DataTypeValue make_int64_data_type_value(int64_t value);
+DataTypeValue make_bool_data_type_value(bool value);
+
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_MAKE_DATATYPE_VALUE_H
diff --git a/lib/op-attrs/src/op-attrs/make_datatype_value.cc b/lib/op-attrs/src/op-attrs/make_datatype_value.cc
new file mode 100644
index 0000000000..bc402c433c
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/make_datatype_value.cc
@@ -0,0 +1,25 @@
+#include "op-attrs/make_datatype_value.h"
+
+namespace FlexFlow {
+
+DataTypeValue make_float_data_type_value(float value) {
+  return DataTypeValue{value};
+}
+
+DataTypeValue make_double_data_type_value(double value) {
+  return DataTypeValue{value};
+}
+
+DataTypeValue
make_int32_data_type_value(int32_t value) { + return DataTypeValue{value}; +} + +DataTypeValue make_int64_data_type_value(int64_t value) { + return DataTypeValue{value}; +} + +DataTypeValue make_bool_data_type_value(bool value) { + return DataTypeValue{value}; +} + +} diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index 483d832fee..8a806bcf9f 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -33,15 +33,15 @@ int get_oProjSize(MultiHeadAttentionAttrs const &attrs) { } int get_qSize(TensorShape const &query_shape) { - return dim_at_idx(query_shape, ff_dim_t(0)); + return dim_at_idx(query_shape, ff_dim_t{0}); } int get_kSize(TensorShape const &key_shape) { - return dim_at_idx(key_shape, ff_dim_t(0)); + return dim_at_idx(key_shape, ff_dim_t{0}); } int get_vSize(TensorShape const &value_shape) { - return dim_at_idx(value_shape, ff_dim_t(0)); + return dim_at_idx(value_shape, ff_dim_t{0}); } int get_qSize(MultiHeadAttentionParallelInputs const &inputs) { diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index dcc567e0ca..6ea29b1855 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -138,7 +138,7 @@ std::unordered_set get_parallel_tensor_dim_indices(ParallelTensorShape const &shape) { std::unordered_set indices; extend(indices, transform(range(num_shard_dims(shape.dims)), [](int idx) { - return parallel_tensor_dim_idx_t(ff_dim_t(idx)); + return parallel_tensor_dim_idx_t(ff_dim_t{idx}); })); indices.insert(parallel_tensor_dim_idx_t(ReplicaType::SUM)); indices.insert(parallel_tensor_dim_idx_t(ReplicaType::DISCARD_COPY)); diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index dff647f5a1..65ef214669 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -3,6 +3,7 @@ #include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/get_op_type.h" #include "op-attrs/get_output_shapes.h" +#include "op-attrs/make_datatype_value.h" #include "op-attrs/ops/attention.h" #include "op-attrs/ops/batch_norm.h" #include "op-attrs/ops/broadcast.h" @@ -609,14 +610,14 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( TensorShape gamma_shape = throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); - InitializerAttrs gamma_initializer = - InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; + InitializerAttrs gamma_initializer = InitializerAttrs{ + ConstantInitializerAttrs{make_float_data_type_value(1)}}; weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); TensorShape beta_shape = throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); - InitializerAttrs beta_initializer = - InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + InitializerAttrs beta_initializer = InitializerAttrs{ + ConstantInitializerAttrs{make_float_data_type_value(0)}}; weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); } @@ -688,8 +689,8 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( get_input_bias_shape(attrs, query_shape, key_shape, value_shape)); // initializer chosen based on // https://github.com/pytorch/pytorch/blob/31c4e0d37d8efc37a0697159e5b9121ec34d5141/torch/nn/modules/activation.py#L1120-L1121 - InitializerAttrs input_bias_initializer = - 
InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + InitializerAttrs input_bias_initializer = InitializerAttrs{ + ConstantInitializerAttrs{make_float_data_type_value(0)}}; weights.push_back( make_weight_attrs(input_bias_shape, input_bias_initializer)); @@ -698,8 +699,8 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( get_output_bias_shape(attrs, query_shape, key_shape, value_shape)); // initializer chosen based on // https://github.com/pytorch/pytorch/blob/31c4e0d37d8efc37a0697159e5b9121ec34d5141/torch/nn/modules/activation.py#L1120-L1121 - InitializerAttrs output_bias_initializer = - InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + InitializerAttrs output_bias_initializer = InitializerAttrs{ + ConstantInitializerAttrs{make_float_data_type_value(0)}}; weights.push_back( make_weight_attrs(output_bias_shape, output_bias_initializer)); @@ -870,14 +871,14 @@ tensor_guid_t ComputationGraphBuilder::layer_norm( TensorShape gamma_shape = throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); - InitializerAttrs gamma_initializer = - InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; + InitializerAttrs gamma_initializer = InitializerAttrs{ + ConstantInitializerAttrs{make_float_data_type_value(1)}}; weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); TensorShape beta_shape = throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); - InitializerAttrs beta_initializer = - InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + InitializerAttrs beta_initializer = InitializerAttrs{ + ConstantInitializerAttrs{make_float_data_type_value(0)}}; weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); } diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index f33b4dcd17..79ac43ae66 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -1,5 +1,6 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "op-attrs/get_incoming_tensor_roles.h" +#include "op-attrs/make_datatype_value.h" #include "op-attrs/ops/attention.h" #include "op-attrs/ops/batch_matmul.h" #include "op-attrs/ops/batch_norm.h" @@ -385,14 +386,14 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_norm( ParallelTensorShape gamma_shape = throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); - InitializerAttrs gamma_initializer = - InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; + InitializerAttrs gamma_initializer = InitializerAttrs{ + ConstantInitializerAttrs{make_float_data_type_value(1)}}; weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); ParallelTensorShape beta_shape = throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); - InitializerAttrs beta_initializer = - InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + InitializerAttrs beta_initializer = InitializerAttrs{ + ConstantInitializerAttrs{make_float_data_type_value(0)}}; weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); } diff --git a/lib/runtime/src/ops/embedding.cc b/lib/runtime/src/ops/embedding.cc index 2370739d58..296b9f443b 100644 --- a/lib/runtime/src/ops/embedding.cc +++ b/lib/runtime/src/ops/embedding.cc @@ -85,7 +85,7 @@ static std::optional 
                                          attrs.aggr,
                                          input.shape.get_dim(),
                                          output.shape.get_dim(),
-                                         input.shape.at(ff_dim_t(0)));
+                                         input.shape.at(ff_dim_t{0}));
 }
 
 TaskImplFunction get_embedding_fwd_task_impl() {
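-- 
Editorial note, not part of the patch: the new at() overloads in
kernels/accessor.h linearize a multi-dimensional index with legion_dim_t{0}
as the stride-1 (fastest-varying) dimension. Below is a minimal standalone
sketch of the same row-major offset computation with the same bounds checks;
the function name linear_offset and the plain std::vector types are
illustrative stand-ins, not FlexFlow APIs.

    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    // Mirrors the loop added to GenericTensorAccessorR/W::at(): dimension 0
    // has stride 1, each later dimension's stride is the product of all
    // earlier dimension sizes, and out-of-range indices throw.
    int linear_offset(std::vector<size_t> const &indices,
                      std::vector<size_t> const &dims) {
      if (indices.size() != dims.size()) {
        throw std::runtime_error("number of indices does not match rank");
      }
      int offset = 0;
      int multiplier = 1;
      for (size_t i = 0; i < dims.size(); i++) {
        if (indices.at(i) >= dims.at(i)) {
          throw std::out_of_range("index exceeds dimension size");
        }
        offset += static_cast<int>(indices.at(i)) * multiplier;
        multiplier *= static_cast<int>(dims.at(i));
      }
      return offset;
    }

    // Example: indices {1, 2} against dims {4, 5} give 1 + 2 * 4 = 9.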