branch merge and test fixes
oOTigger committed Jan 28, 2025
1 parent 24c7982 commit 54647c6
Showing 38 changed files with 263 additions and 436 deletions.
76 changes: 25 additions & 51 deletions lib/kernels/include/kernels/accessor.h
@@ -11,6 +11,28 @@

namespace FlexFlow {

inline int calculate_accessor_offset(std::vector<int> const &indices,
ArrayShape const &shape) {
int offset = 0;
int multiplier = 1;

for (int i = 0; i < shape.num_dims(); i++) {
if (indices.at(i) >= shape.at(legion_dim_t{i})) {
throw mk_runtime_error(
fmt::format("In {} dimension, attempting to access index {} "
"when only {} indexes exist",
i,
indices.at(i),
shape.at(legion_dim_t{i})));
}

offset += indices.at(i) * multiplier;
multiplier *= shape.at(legion_dim_t{i});
}

return offset;
}

class GenericTensorAccessorR {
public:
template <DataType DT>
@@ -57,23 +79,7 @@ class GenericTensorAccessorR {

using T = real_type_t<DT>;
T const *data_ptr = static_cast<T const *>(this->ptr);

int offset = 0;
int multiplier = 1;
for (int i = 0; i < this->shape.num_dims(); i++) {
if (indices.at(i) >= this->shape.at(legion_dim_t{i})) {
throw mk_runtime_error(
fmt::format("In {} dimension, attempting to access index {} "
"when only {} indexes exist",
i,
indices.at(i),
this->shape.at(legion_dim_t{i})));
}

offset += indices.at(i) * multiplier;
multiplier *= this->shape.at(legion_dim_t{i});
}

int offset = calculate_accessor_offset(indices, this->shape);
return data_ptr[offset];
}

@@ -141,24 +147,8 @@ class GenericTensorAccessorW {
}

using T = real_type_t<DT>;

T *data_ptr = static_cast<T *>(this->ptr);
int offset = 0;
int multiplier = 1;
for (int i = 0; i < this->shape.num_dims(); i++) {
if (indices.at(i) >= this->shape.at(legion_dim_t{i})) {
throw mk_runtime_error(
fmt::format("In {} dimension, attempting to access index {} "
"when only {} indexes exist",
i,
indices.at(i),
this->shape.at(legion_dim_t{i})));
}

offset += indices.at(i) * multiplier;
multiplier *= this->shape.at(legion_dim_t{i});
}

int offset = calculate_accessor_offset(indices, this->shape);
return data_ptr[offset];
}

@@ -179,24 +169,8 @@ class GenericTensorAccessorW {
}

using T = real_type_t<DT>;

T const *data_ptr = static_cast<T const *>(this->ptr);
int offset = 0;
int multiplier = 1;
for (int i = 0; i < this->shape.num_dims(); i++) {
if (indices.at(i) >= this->shape.at(legion_dim_t{i})) {
throw mk_runtime_error(
fmt::format("In {} dimension, attempting to access index {} "
"when only {} indexes exist",
i,
indices.at(i),
this->shape.at(legion_dim_t{i})));
}

offset += indices.at(i) * multiplier;
multiplier *= this->shape.at(legion_dim_t{i});
}

int offset = calculate_accessor_offset(indices, this->shape);
return data_ptr[offset];
}

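The three per-accessor copies of this indexing loop now share the single calculate_accessor_offset helper shown above. A standalone sketch of the same stride arithmetic, using std::vector in place of ArrayShape and omitting the bounds checks (the layout, with the first legion dimension varying fastest, is inferred from the loop, not stated in the commit):

#include <cstddef>
#include <vector>

// Earlier dimensions vary fastest, so each dimension's stride is the
// product of all earlier extents.
int offset_of(std::vector<int> const &indices, std::vector<int> const &dims) {
  int offset = 0;
  int multiplier = 1;
  for (std::size_t i = 0; i < dims.size(); i++) {
    offset += indices[i] * multiplier;
    multiplier *= dims[i];
  }
  return offset;
}

// e.g. offset_of({2, 1}, {4, 3}) == 2 * 1 + 1 * 4 == 6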
2 changes: 1 addition & 1 deletion lib/kernels/include/kernels/flat_kernels.h
@@ -10,7 +10,7 @@ void forward_kernel(ffStream_t stream,
GenericTensorAccessorR input,
float *output_ptr);

void backward_kernel(cudaStream_t stream,
void backward_kernel(ffStream_t stream,
GenericTensorAccessorR input,
float const *output_grad_ptr,
float *input_grad_ptr);
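Together with the matching pool_2d_kernels.h fix below, backward_kernel now takes ffStream_t like forward_kernel, rather than the CUDA-only cudaStream_t. Presumably ffStream_t is the project's backend-neutral stream alias; a sketch of the assumed definition (guard macro and header choices are illustrative, not taken from the repo):

// Assumed shape of the alias, selecting the backend at compile time.
#if defined(FF_USE_HIP_ROCM)
#include <hip/hip_runtime.h>
typedef hipStream_t ffStream_t;
#else
#include <cuda_runtime.h>
typedef cudaStream_t ffStream_t;
#endif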
2 changes: 1 addition & 1 deletion lib/kernels/include/kernels/loss_function_kernels.h
@@ -1,7 +1,7 @@
#ifndef _FLEXFLOW_KERNELS_INCLUDE_KERNELS_LOSS_FUNCTION_KERNELS_H
#define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_LOSS_FUNCTION_KERNELS_H

#include "device.h"
#include "kernels/device.h"

namespace FlexFlow {

1 change: 1 addition & 0 deletions lib/kernels/include/kernels/managed_ff_stream.h
@@ -19,6 +19,7 @@ struct ManagedFFStream {

ffStream_t const &raw_stream() const;

private:
void cleanup();

private:
1 change: 1 addition & 0 deletions lib/kernels/include/kernels/managed_per_device_ff_handle.h
@@ -24,6 +24,7 @@ struct ManagedPerDeviceFFHandle {

PerDeviceFFHandle const &raw_handle() const;

private:
void cleanup();

private:
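Both ManagedFFStream and ManagedPerDeviceFFHandle now keep cleanup() private, so teardown can only happen through the destructor and the move operations. A minimal sketch of how a private cleanup() typically backs move assignment (the handle member and its pointer type are assumed, not taken from the repo):

// Sketch only: assumes the class owns a heap-allocated PerDeviceFFHandle.
ManagedPerDeviceFFHandle &ManagedPerDeviceFFHandle::operator=(
    ManagedPerDeviceFFHandle &&other) noexcept {
  if (this != &other) {
    this->cleanup();             // release the currently owned handle, if any
    this->handle = other.handle; // take ownership from the source
    other.handle = nullptr;      // leave the source empty but destructible
  }
  return *this;
}

This matches the move-assignment test below, which expects the moved-from object's raw_handle() address to be null.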
6 changes: 3 additions & 3 deletions lib/kernels/include/kernels/metrics_kernels.h
@@ -2,20 +2,20 @@
#define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_METRICS_KERNELS_H

#include "kernels/perf_metrics.h"
#include "pcg/metric.h"
#include "pcg/metric_attrs.h"

namespace FlexFlow {

void update_metrics_sparse_label_kernel_wrapper(float const *logit_ptr,
int const *label_ptr,
MetricsAttrs const *me,
MetricsAttrs const &me,
int num_effective_samples,
int num_classes,
PerfMetrics &perf_zc);

void update_metrics_label_kernel_wrapper(float const *logit_ptr,
float const *label_ptr,
MetricsAttrs const *me,
MetricsAttrs const &me,
int num_samples,
int num_classes,
PerfMetrics &perf_zc);
2 changes: 1 addition & 1 deletion lib/kernels/include/kernels/pool_2d_kernels.h
@@ -67,7 +67,7 @@ void forward_kernel(ffStream_t stream,
void const *input_ptr,
void *output_ptr);

void backward_kernel(cudaStream_t stream,
void backward_kernel(ffStream_t stream,
Pool2DPerDeviceState const &m,
void const *output_ptr,
void const *output_grad_ptr,
10 changes: 5 additions & 5 deletions lib/kernels/src/cuda/metrics_functions.cu
@@ -16,7 +16,7 @@
#include "device.h"
#include "kernels/metrics_kernels.h"
#include "kernels/perf_metrics.h"
#include "pcg/metric.h"
#include "pcg/metric_attrs.h"

namespace FlexFlow {

@@ -163,7 +163,7 @@ __global__ void update_metrics_label_kernel(float const *logits,

void update_metrics_sparse_label_kernel_wrapper(float const *logit_ptr,
int const *label_ptr,
MetricsAttrs const *me,
MetricsAttrs const &me,
int num_effective_samples,
int num_classes,
PerfMetrics &perf_zc) {
@@ -179,7 +179,7 @@ void update_metrics_sparse_label_kernel_wrapper(float const *logit_ptr,
CUDA_NUM_THREADS,
0,
stream>>>(
logit_ptr, label_ptr, perf_cuda, *me, num_effective_samples, num_classes);
logit_ptr, label_ptr, perf_cuda, me, num_effective_samples, num_classes);
checkCUDA(cudaStreamSynchronize(stream));
checkCUDA(cudaMemcpy(
&perf, perf_cuda, sizeof(CUDAPerfMetrics), cudaMemcpyDeviceToHost));
@@ -188,7 +188,7 @@

void update_metrics_label_kernel_wrapper(float const *logit_ptr,
float const *label_ptr,
MetricsAttrs const *me,
MetricsAttrs const &me,
int num_samples,
int num_classes,
PerfMetrics &perf_zc) {
@@ -201,7 +201,7 @@ void update_metrics_label_kernel_wrapper(float const *logit_ptr,
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
update_metrics_label_kernel<<<GET_BLOCKS(num_samples), 256, 0, stream>>>(
logit_ptr, label_ptr, perf_cuda, *me, num_samples, num_classes);
logit_ptr, label_ptr, perf_cuda, me, num_samples, num_classes);
checkCUDA(cudaStreamSynchronize(stream));
checkCUDA(cudaMemcpy(
&perf, perf_cuda, sizeof(CUDAPerfMetrics), cudaMemcpyDeviceToHost));
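With MetricsAttrs now passed by const reference, call sites drop the address-of operator and the wrappers forward the struct to the device kernels by value (me instead of *me). A hypothetical call site before and after the change:

// before: the wrappers took MetricsAttrs const * and dereferenced internally
// update_metrics_label_kernel_wrapper(
//     logit_ptr, label_ptr, &attrs, num_samples, num_classes, perf);

// after: pass the attrs object directly
update_metrics_label_kernel_wrapper(
    logit_ptr, label_ptr, attrs, num_samples, num_classes, perf);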
30 changes: 15 additions & 15 deletions lib/kernels/src/hip/embedding_kernels.cpp
@@ -364,8 +364,8 @@ struct ForwardKernel {
weight.data_type == DataType::FLOAT ||
weight.data_type == DataType::DOUBLE);

if (aggr == AggregateOp::NONE) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_forward_no_aggr<TI, TD>),
if (aggr == AggregateOp::AVG || aggr == AggregateOp::SUM) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_forward_with_aggr<TI, TD>),
GET_BLOCKS(output.shape.get_volume()),
CUDA_NUM_THREADS,
0,
@@ -374,10 +374,11 @@ struct ForwardKernel {
output.get<TD>(),
weight.get<TD>(),
out_dim,
batch_size);
in_dim,
batch_size,
aggr);
} else {
assert(aggr == AggregateOp::AVG || aggr == AggregateOp::SUM);
hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_forward_with_aggr<TI, TD>),
hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_forward_no_aggr<TI, TD>),
GET_BLOCKS(output.shape.get_volume()),
CUDA_NUM_THREADS,
0,
@@ -386,9 +387,7 @@
output.get<TD>(),
weight.get<TD>(),
out_dim,
in_dim,
batch_size,
aggr);
batch_size);
}
}
}
@@ -408,8 +407,9 @@ struct BackwardKernel {
assert(output.data_type == DataType::HALF ||
output.data_type == DataType::FLOAT ||
output.data_type == DataType::DOUBLE);
if (aggr == AggregateOp::NONE) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_backward_no_aggr<TI, TD>),

if (aggr == AggregateOp::AVG || aggr == AggregateOp::SUM) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_backward_with_aggr<TI, TD>),
GET_BLOCKS(output.shape.get_volume()),
CUDA_NUM_THREADS,
0,
@@ -418,9 +418,11 @@ struct BackwardKernel {
output.get<TD>(),
weight_grad.get<TD>(),
out_dim,
batch_size);
in_dim,
batch_size,
aggr);
} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_backward_with_aggr<TI, TD>),
hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_backward_no_aggr<TI, TD>),
GET_BLOCKS(output.shape.get_volume()),
CUDA_NUM_THREADS,
0,
@@ -429,9 +431,7 @@ struct BackwardKernel {
output.get<TD>(),
weight_grad.get<TD>(),
out_dim,
in_dim,
batch_size,
aggr);
batch_size);
}
}
}
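The HIP embedding paths now test for AVG/SUM first and fall through to the no-aggregation kernels, which also makes the argument lists line up (the aggregated variants additionally take in_dim and aggr). As a reading aid, a serial sketch of what the aggregated forward pass is assumed to compute; the row-major layouts and the AVG scaling are my assumptions, not taken from the kernels:

// Assumed semantics: each of batch_size outputs reduces the embeddings of
// its in_dim token indices, dividing by in_dim for AVG.
template <typename TI, typename TD>
void embed_forward_with_aggr_ref(TI const *input,  // [batch_size * in_dim]
                                 TD *output,       // [batch_size * out_dim]
                                 TD const *weight, // [vocab_size * out_dim]
                                 int out_dim,
                                 int in_dim,
                                 int batch_size,
                                 bool average) {
  for (int b = 0; b < batch_size; b++) {
    for (int d = 0; d < out_dim; d++) {
      TD acc = static_cast<TD>(0);
      for (int i = 0; i < in_dim; i++) {
        acc += weight[input[b * in_dim + i] * out_dim + d];
      }
      output[b * out_dim + d] = average ? acc / static_cast<TD>(in_dim) : acc;
    }
  }
}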
2 changes: 1 addition & 1 deletion lib/kernels/test/src/test_batch_norm_kernel.cc
@@ -1,6 +1,6 @@
#include "doctest/doctest.h"
#include "kernels/batch_norm_kernels.h"
#include "op-attrs/make_datatype_value.h"
#include "op-attrs/datatype_value.h"
#include "test_utils.h"

using namespace ::FlexFlow;
Expand Down
2 changes: 1 addition & 1 deletion lib/kernels/test/src/test_concat_kernel.cc
@@ -8,7 +8,7 @@ TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("Test concat kernel forward and backward") {
size_t num_inputs = 2;
size_t size_per_input = 10;
ff_dim_t concat_axis = ff_dim_t{1};
ff_dim_t concat_axis = ff_dim_t{nonnegative_int{1}};

ManagedPerDeviceFFHandle managed_handle{
/*workSpaceSize=*/1024 * 1024,
2 changes: 1 addition & 1 deletion lib/kernels/test/src/test_flat_kernel.cc
@@ -1,6 +1,6 @@
#include "doctest/doctest.h"
#include "kernels/flat_kernels.h"
#include "op-attrs/make_datatype_value.h"
#include "op-attrs/datatype_value.h"
#include "test_utils.h"

using namespace ::FlexFlow;
2 changes: 1 addition & 1 deletion lib/kernels/test/src/test_layer_norm_kernels.cc
@@ -1,6 +1,6 @@
#include "doctest/doctest.h"
#include "kernels/layer_norm_kernels.h"
#include "op-attrs/make_datatype_value.h"
#include "op-attrs/datatype_value.h"
#include "test_utils.h"

using namespace ::FlexFlow;
7 changes: 5 additions & 2 deletions lib/kernels/test/src/test_managed_per_device_ff_handle.cc
@@ -5,7 +5,8 @@ using namespace ::FlexFlow;

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("ManagedPerDeviceFFHandle") {
ManagedPerDeviceFFHandle base_handle{1024 * 1024, true};
ManagedPerDeviceFFHandle base_handle{/*workSpaceSize=*/1024 * 1024,
/*allowTensorOpMathConversion=*/true};
PerDeviceFFHandle const *base_handle_ptr = &base_handle.raw_handle();

SUBCASE("constructor") {
@@ -22,7 +23,9 @@ TEST_SUITE(FF_TEST_SUITE) {

SUBCASE("move assignment operator") {
SUBCASE("move assign to other") {
ManagedPerDeviceFFHandle new_handle{1024 * 1024, true};
ManagedPerDeviceFFHandle new_handle{
/*workSpaceSize=*/1024 * 1024,
/*allowTensorOpMathConversion=*/true};
new_handle = std::move(base_handle);

CHECK(&base_handle.raw_handle() == nullptr);
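The handle constructions in this test now label their arguments inline. One practical upside (an assumption about tooling, not stated in the commit): clang-tidy's bugprone-argument-comment check can verify that each /*name=*/ label matches the declared parameter, e.g.:

// Given the declaration (inferred from usage):
//   ManagedPerDeviceFFHandle(size_t workSpaceSize,
//                            bool allowTensorOpMathConversion);
// the check flags a mismatched label such as /*workspaceSize=*/.
ManagedPerDeviceFFHandle handle{/*workSpaceSize=*/1024 * 1024,
                                /*allowTensorOpMathConversion=*/true};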
2 changes: 1 addition & 1 deletion lib/kernels/test/src/test_partition_kernel.cc
@@ -1,6 +1,6 @@
#include "doctest/doctest.h"
#include "kernels/partition_kernels.h"
#include "op-attrs/make_datatype_value.h"
#include "op-attrs/datatype_value.h"
#include "test_utils.h"

using namespace ::FlexFlow;
2 changes: 1 addition & 1 deletion lib/kernels/test/src/test_pool_2d_kernels.cc
@@ -1,6 +1,6 @@
#include "doctest/doctest.h"
#include "kernels/pool_2d_kernels.h"
#include "op-attrs/make_datatype_value.h"
#include "op-attrs/datatype_value.h"
#include "test_utils.h"

using namespace ::FlexFlow;
2 changes: 1 addition & 1 deletion lib/kernels/test/src/test_reduction_kernel.cc
@@ -1,6 +1,6 @@
#include "doctest/doctest.h"
#include "kernels/reduction_kernels.h"
#include "op-attrs/make_datatype_value.h"
#include "op-attrs/datatype_value.h"
#include "test_utils.h"

using namespace ::FlexFlow;
2 changes: 1 addition & 1 deletion lib/kernels/test/src/test_reverse_kernels.cc
@@ -1,7 +1,7 @@
#include "doctest/doctest.h"
#include "kernels/reverse_kernels.h"
#include "kernels/reverse_kernels_cpu.h"
#include "op-attrs/make_datatype_value.h"
#include "op-attrs/datatype_value.h"
#include "test_utils.h"

using namespace ::FlexFlow;
2 changes: 1 addition & 1 deletion lib/kernels/test/src/test_split_kernel.cc
@@ -1,6 +1,6 @@
#include "doctest/doctest.h"
#include "kernels/split_kernels.h"
#include "op-attrs/make_datatype_value.h"
#include "op-attrs/datatype_value.h"
#include "test_utils.h"
#include "utils/containers/repeat.h"

3 changes: 2 additions & 1 deletion lib/kernels/test/src/test_transpose_kernel.cc
@@ -7,7 +7,8 @@ TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("Test Transpose Kernel Operations") {
std::size_t num_dims = 2;

std::vector<ff_dim_t> perm = {ff_dim_t{0}, ff_dim_t{1}};
std::vector<ff_dim_t> perm = {ff_dim_t{nonnegative_int{0}},
ff_dim_t{nonnegative_int{1}}};

ManagedPerDeviceFFHandle managed_handle{
/*workSpaceSize=*/1024 * 1024,
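As in the concat test above, dimension indices are now built from nonnegative_int rather than a bare int literal. The assumed intent is that ff_dim_t no longer converts implicitly from int, so invalid indices fail at the boundary instead of deep inside a kernel; a sketch under that assumption:

ff_dim_t axis{nonnegative_int{1}}; // value checked when the wrapper is built
// ff_dim_t bad{1};                // assumed to no longer compile
// nonnegative_int n{-1};          // assumed to fail fast (throw or assert)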