From 24c79823a5461f64bfc71f3cd748168aa5e5f863 Mon Sep 17 00:00:00 2001 From: Dylan Lim Date: Thu, 21 Nov 2024 22:46:25 -0800 Subject: [PATCH] format check --- lib/kernels/include/kernels/pool_2d_kernels.h | 1 - lib/kernels/src/cuda/metrics_functions.cu | 5 +- lib/kernels/src/cuda/optimizer_kernels.cu | 76 +++++++++---------- lib/pcg/include/pcg/metric.h | 5 +- lib/pcg/src/pcg/metric.cc | 62 +++++++-------- 5 files changed, 69 insertions(+), 80 deletions(-) diff --git a/lib/kernels/include/kernels/pool_2d_kernels.h b/lib/kernels/include/kernels/pool_2d_kernels.h index c0e57e2c9a..ad0a52efb9 100644 --- a/lib/kernels/include/kernels/pool_2d_kernels.h +++ b/lib/kernels/include/kernels/pool_2d_kernels.h @@ -74,7 +74,6 @@ void backward_kernel(cudaStream_t stream, void const *input_ptr, void *input_grad_ptr); - } // namespace Kernels::Pool2D } // namespace FlexFlow diff --git a/lib/kernels/src/cuda/metrics_functions.cu b/lib/kernels/src/cuda/metrics_functions.cu index 2901f1d374..0250f829ec 100644 --- a/lib/kernels/src/cuda/metrics_functions.cu +++ b/lib/kernels/src/cuda/metrics_functions.cu @@ -200,10 +200,7 @@ void update_metrics_label_kernel_wrapper(float const *logit_ptr, cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - update_metrics_label_kernel<<>>( + update_metrics_label_kernel<<>>( logit_ptr, label_ptr, perf_cuda, *me, num_samples, num_classes); checkCUDA(cudaStreamSynchronize(stream)); checkCUDA(cudaMemcpy( diff --git a/lib/kernels/src/cuda/optimizer_kernels.cu b/lib/kernels/src/cuda/optimizer_kernels.cu index 237a277b21..1c6954a0b0 100644 --- a/lib/kernels/src/cuda/optimizer_kernels.cu +++ b/lib/kernels/src/cuda/optimizer_kernels.cu @@ -83,26 +83,23 @@ __host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op, cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - const auto& state = meta->raw_variant; - ncclComm_t comm = std::visit([](const auto& s) -> ncclComm_t { - using T = std::decay_t; - if constexpr (std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v) { - throw mk_runtime_error("State type does not support NCCL operations"); - } else { - return s.handle.ncclComm; - } - }, state); - - checkNCCL(ncclAllReduce(w_grad_ptr, - (float *)w_grad_ptr, - size, - ncclFloat, - ncclSum, - comm, - stream)); + auto const &state = meta->raw_variant; + ncclComm_t comm = std::visit( + [](auto const &s) -> ncclComm_t { + using T = std::decay_t; + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + throw mk_runtime_error("State type does not support NCCL operations"); + } else { + return s.handle.ncclComm; + } + }, + state); + + checkNCCL(ncclAllReduce( + w_grad_ptr, (float *)w_grad_ptr, size, ncclFloat, ncclSum, comm, stream)); // fprintf(stderr, "weight(%p) After ncclAllReduce...\n", w_grad_ptr); // print_tensor((float*)w_grad_ptr, 16, "[After ncclAllReduce]"); @@ -205,27 +202,24 @@ __host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op, // Use NCCL to sync gradients cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - - const auto& state = meta->raw_variant; - ncclComm_t comm = std::visit([](const auto& s) -> ncclComm_t { - using T = std::decay_t; - if constexpr (std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v) { - throw mk_runtime_error("State type does not support NCCL operations"); - } else { - return s.handle.ncclComm; - } - }, state); - - checkNCCL(ncclAllReduce(w_grad_ptr, - (float *)w_grad_ptr, - size, - ncclFloat, - ncclSum, - comm, - stream)); + + auto const &state = meta->raw_variant; + ncclComm_t comm = std::visit( + [](auto const &s) -> ncclComm_t { + using T = std::decay_t; + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + throw mk_runtime_error("State type does not support NCCL operations"); + } else { + return s.handle.ncclComm; + } + }, + state); + + checkNCCL(ncclAllReduce( + w_grad_ptr, (float *)w_grad_ptr, size, ncclFloat, ncclSum, comm, stream)); // fprintf(stderr, "alpha = %.8lf alpha_t = %.8lf decay = %.8lf\n", // op->alpha, op->alpha_t, op->weight_decay); // Step 2: Adam update diff --git a/lib/pcg/include/pcg/metric.h b/lib/pcg/include/pcg/metric.h index f56078772e..718919112f 100644 --- a/lib/pcg/include/pcg/metric.h +++ b/lib/pcg/include/pcg/metric.h @@ -1,9 +1,9 @@ #ifndef _FF_METRICS_H_ #define _FF_METRICS_H_ -#include -#include "utils/fmt.h" #include "op-attrs/ops/loss_functions/loss_functions.h" +#include "utils/fmt.h" +#include namespace FlexFlow { @@ -69,5 +69,4 @@ struct formatter<::FlexFlow::Metric> : formatter { } // namespace fmt - #endif diff --git a/lib/pcg/src/pcg/metric.cc b/lib/pcg/src/pcg/metric.cc index eb0d6bc5d0..69aba90d12 100644 --- a/lib/pcg/src/pcg/metric.cc +++ b/lib/pcg/src/pcg/metric.cc @@ -2,37 +2,37 @@ namespace FlexFlow { MetricsAttrs::MetricsAttrs(LossFunction _loss_type, - std::vector const &metrics) - : loss_type(_loss_type), measure_accuracy(false), - measure_categorical_crossentropy(false), - measure_sparse_categorical_crossentropy(false), - measure_mean_squared_error(false), measure_root_mean_squared_error(false), - measure_mean_absolute_error(false) { -for (Metric const &m : metrics) { - switch (m) { - case Metric::ACCURACY: - measure_accuracy = true; - continue; - case Metric::CATEGORICAL_CROSSENTROPY: - measure_categorical_crossentropy = true; - continue; - case Metric::SPARSE_CATEGORICAL_CROSSENTROPY: - measure_sparse_categorical_crossentropy = true; - continue; - case Metric::MEAN_SQUARED_ERROR: - measure_mean_squared_error = true; - continue; - case Metric::ROOT_MEAN_SQUARED_ERROR: - measure_root_mean_squared_error = true; - continue; - case Metric::MEAN_ABSOLUTE_ERROR: - measure_mean_absolute_error = true; - continue; - default: - throw mk_runtime_error("Initializing MetricsAttrs with unrecogonized metrics type"); + std::vector const &metrics) + : loss_type(_loss_type), measure_accuracy(false), + measure_categorical_crossentropy(false), + measure_sparse_categorical_crossentropy(false), + measure_mean_squared_error(false), measure_root_mean_squared_error(false), + measure_mean_absolute_error(false) { + for (Metric const &m : metrics) { + switch (m) { + case Metric::ACCURACY: + measure_accuracy = true; + continue; + case Metric::CATEGORICAL_CROSSENTROPY: + measure_categorical_crossentropy = true; + continue; + case Metric::SPARSE_CATEGORICAL_CROSSENTROPY: + measure_sparse_categorical_crossentropy = true; + continue; + case Metric::MEAN_SQUARED_ERROR: + measure_mean_squared_error = true; + continue; + case Metric::ROOT_MEAN_SQUARED_ERROR: + measure_root_mean_squared_error = true; + continue; + case Metric::MEAN_ABSOLUTE_ERROR: + measure_mean_absolute_error = true; + continue; + default: + throw mk_runtime_error( + "Initializing MetricsAttrs with unrecogonized metrics type"); + } } } -} - -} +} // namespace FlexFlow