Skip to content

Commit

Permalink
format check
Browse files Browse the repository at this point in the history
  • Loading branch information
oOTigger committed Jan 28, 2025
1 parent b444ed4 commit 24c7982
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 80 deletions.
1 change: 0 additions & 1 deletion lib/kernels/include/kernels/pool_2d_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ void backward_kernel(cudaStream_t stream,
void const *input_ptr,
void *input_grad_ptr);


} // namespace Kernels::Pool2D
} // namespace FlexFlow

Expand Down
5 changes: 1 addition & 4 deletions lib/kernels/src/cuda/metrics_functions.cu
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,7 @@ void update_metrics_label_kernel_wrapper(float const *logit_ptr,

cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
update_metrics_label_kernel<<<GET_BLOCKS(num_samples),
256,
0,
stream>>>(
update_metrics_label_kernel<<<GET_BLOCKS(num_samples), 256, 0, stream>>>(
logit_ptr, label_ptr, perf_cuda, *me, num_samples, num_classes);
checkCUDA(cudaStreamSynchronize(stream));
checkCUDA(cudaMemcpy(
Expand Down
76 changes: 35 additions & 41 deletions lib/kernels/src/cuda/optimizer_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -83,26 +83,23 @@ __host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op,
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));

const auto& state = meta->raw_variant;
ncclComm_t comm = std::visit([](const auto& s) -> ncclComm_t {
using T = std::decay_t<decltype(s)>;
if constexpr (std::is_same_v<T, FlexFlow::ElementUnaryPerDeviceState> ||
std::is_same_v<T, FlexFlow::ReshapePerDeviceState> ||
std::is_same_v<T, FlexFlow::TopKPerDeviceState> ||
std::is_same_v<T, FlexFlow::TransposePerDeviceState>) {
throw mk_runtime_error("State type does not support NCCL operations");
} else {
return s.handle.ncclComm;
}
}, state);

checkNCCL(ncclAllReduce(w_grad_ptr,
(float *)w_grad_ptr,
size,
ncclFloat,
ncclSum,
comm,
stream));
auto const &state = meta->raw_variant;
ncclComm_t comm = std::visit(
[](auto const &s) -> ncclComm_t {
using T = std::decay_t<decltype(s)>;
if constexpr (std::is_same_v<T, FlexFlow::ElementUnaryPerDeviceState> ||
std::is_same_v<T, FlexFlow::ReshapePerDeviceState> ||
std::is_same_v<T, FlexFlow::TopKPerDeviceState> ||
std::is_same_v<T, FlexFlow::TransposePerDeviceState>) {
throw mk_runtime_error("State type does not support NCCL operations");
} else {
return s.handle.ncclComm;
}
},
state);

checkNCCL(ncclAllReduce(
w_grad_ptr, (float *)w_grad_ptr, size, ncclFloat, ncclSum, comm, stream));

// fprintf(stderr, "weight(%p) After ncclAllReduce...\n", w_grad_ptr);
// print_tensor<float>((float*)w_grad_ptr, 16, "[After ncclAllReduce]");
Expand Down Expand Up @@ -205,27 +202,24 @@ __host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op,
// Use NCCL to sync gradients
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));

const auto& state = meta->raw_variant;
ncclComm_t comm = std::visit([](const auto& s) -> ncclComm_t {
using T = std::decay_t<decltype(s)>;
if constexpr (std::is_same_v<T, FlexFlow::ElementUnaryPerDeviceState> ||
std::is_same_v<T, FlexFlow::ReshapePerDeviceState> ||
std::is_same_v<T, FlexFlow::TopKPerDeviceState> ||
std::is_same_v<T, FlexFlow::TransposePerDeviceState>) {
throw mk_runtime_error("State type does not support NCCL operations");
} else {
return s.handle.ncclComm;
}
}, state);

checkNCCL(ncclAllReduce(w_grad_ptr,
(float *)w_grad_ptr,
size,
ncclFloat,
ncclSum,
comm,
stream));

auto const &state = meta->raw_variant;
ncclComm_t comm = std::visit(
[](auto const &s) -> ncclComm_t {
using T = std::decay_t<decltype(s)>;
if constexpr (std::is_same_v<T, FlexFlow::ElementUnaryPerDeviceState> ||
std::is_same_v<T, FlexFlow::ReshapePerDeviceState> ||
std::is_same_v<T, FlexFlow::TopKPerDeviceState> ||
std::is_same_v<T, FlexFlow::TransposePerDeviceState>) {
throw mk_runtime_error("State type does not support NCCL operations");
} else {
return s.handle.ncclComm;
}
},
state);

checkNCCL(ncclAllReduce(
w_grad_ptr, (float *)w_grad_ptr, size, ncclFloat, ncclSum, comm, stream));
// fprintf(stderr, "alpha = %.8lf alpha_t = %.8lf decay = %.8lf\n",
// op->alpha, op->alpha_t, op->weight_decay);
// Step 2: Adam update
Expand Down
5 changes: 2 additions & 3 deletions lib/pcg/include/pcg/metric.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#ifndef _FF_METRICS_H_
#define _FF_METRICS_H_

#include <unordered_set>
#include "utils/fmt.h"
#include "op-attrs/ops/loss_functions/loss_functions.h"
#include "utils/fmt.h"
#include <unordered_set>

namespace FlexFlow {

Expand Down Expand Up @@ -69,5 +69,4 @@ struct formatter<::FlexFlow::Metric> : formatter<string_view> {

} // namespace fmt


#endif
62 changes: 31 additions & 31 deletions lib/pcg/src/pcg/metric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,37 @@

namespace FlexFlow {
MetricsAttrs::MetricsAttrs(LossFunction _loss_type,
std::vector<Metric> const &metrics)
: loss_type(_loss_type), measure_accuracy(false),
measure_categorical_crossentropy(false),
measure_sparse_categorical_crossentropy(false),
measure_mean_squared_error(false), measure_root_mean_squared_error(false),
measure_mean_absolute_error(false) {
for (Metric const &m : metrics) {
switch (m) {
case Metric::ACCURACY:
measure_accuracy = true;
continue;
case Metric::CATEGORICAL_CROSSENTROPY:
measure_categorical_crossentropy = true;
continue;
case Metric::SPARSE_CATEGORICAL_CROSSENTROPY:
measure_sparse_categorical_crossentropy = true;
continue;
case Metric::MEAN_SQUARED_ERROR:
measure_mean_squared_error = true;
continue;
case Metric::ROOT_MEAN_SQUARED_ERROR:
measure_root_mean_squared_error = true;
continue;
case Metric::MEAN_ABSOLUTE_ERROR:
measure_mean_absolute_error = true;
continue;
default:
throw mk_runtime_error("Initializing MetricsAttrs with unrecogonized metrics type");
std::vector<Metric> const &metrics)
: loss_type(_loss_type), measure_accuracy(false),
measure_categorical_crossentropy(false),
measure_sparse_categorical_crossentropy(false),
measure_mean_squared_error(false), measure_root_mean_squared_error(false),
measure_mean_absolute_error(false) {
for (Metric const &m : metrics) {
switch (m) {
case Metric::ACCURACY:
measure_accuracy = true;
continue;
case Metric::CATEGORICAL_CROSSENTROPY:
measure_categorical_crossentropy = true;
continue;
case Metric::SPARSE_CATEGORICAL_CROSSENTROPY:
measure_sparse_categorical_crossentropy = true;
continue;
case Metric::MEAN_SQUARED_ERROR:
measure_mean_squared_error = true;
continue;
case Metric::ROOT_MEAN_SQUARED_ERROR:
measure_root_mean_squared_error = true;
continue;
case Metric::MEAN_ABSOLUTE_ERROR:
measure_mean_absolute_error = true;
continue;
default:
throw mk_runtime_error(
"Initializing MetricsAttrs with unrecogonized metrics type");
}
}
}
}


}
} // namespace FlexFlow

0 comments on commit 24c7982

Please sign in to comment.