add tutel support for top2
Signed-off-by: Xueshen Liu <[email protected]>
xenshinu committed Feb 10, 2025
1 parent 66d3d3e commit fc15332
Showing 13 changed files with 60 additions and 31 deletions.
5 changes: 3 additions & 2 deletions csrc/aio/common/deepspeed_aio_common.cpp
@@ -300,8 +300,9 @@ int regular_read(const char* filename, std::vector<char>& buffer)
    } while (r > 0);

    if (read_bytes != num_bytes) {
-        std::cerr << "read error " << " read_bytes (read) = " << read_bytes
-                  << " num_bytes (fstat) = " << num_bytes << std::endl;
+        std::cerr << "read error "
+                  << " read_bytes (read) = " << read_bytes << " num_bytes (fstat) = " << num_bytes
+                  << std::endl;
    }
    assert(read_bytes == num_bytes);
    close(fd);
10 changes: 6 additions & 4 deletions csrc/aio/py_lib/deepspeed_py_aio.cpp
@@ -71,8 +71,9 @@ int deepspeed_py_aio_write(const torch::Tensor& buffer,

    const std::chrono::duration<double> fn_time =
        std::chrono::high_resolution_clock::now() - start_time;
-    std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
-              << " call = " << fn_time.count() * 1e6 << std::endl;
+    std::cout << "Elapsed time(usec): "
+              << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6
+              << std::endl;
    return 0;
}

@@ -117,7 +118,8 @@ int deepspeed_py_aio_read(torch::Tensor& buffer,

    const std::chrono::duration<double> fn_time =
        std::chrono::high_resolution_clock::now() - start_time;
-    std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
-              << " call = " << fn_time.count() * 1e6 << std::endl;
+    std::cout << "Elapsed time(usec): "
+              << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6
+              << std::endl;
    return 0;
}
2 changes: 1 addition & 1 deletion csrc/aio/py_lib/deepspeed_py_copy.cpp
@@ -10,7 +10,7 @@ Functionality for swapping tensors to/from (NVMe) storage devices.
#include "deepspeed_py_copy.h"
#include <omp.h>

#define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))

#if defined(__AVX512__) or defined(__AVX256__)
union AVX_Data {
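A note on this hunk (and the identical ROUND_DOWN hunks in csrc/includes/simd.h and csrc/xpu/includes/simd.h below): only whitespace inside the macro changes, a clang-format artifact. The macro itself rounds size down to the nearest multiple of step by clearing the low-order bits, which is valid only when step is a power of two. A minimal sketch of the same bit trick, in Python for illustration:

def round_down(size: int, step: int) -> int:
    # Clears the low-order bits of `size`; mirrors the C macro
    # ROUND_DOWN(size, step) and is valid only for power-of-two `step`.
    assert step > 0 and step & (step - 1) == 0, "step must be a power of two"
    return size & ~(step - 1)

assert round_down(70, 16) == 64   # 70 = 0b1000110 -> 0b1000000
assert round_down(64, 16) == 64   # already aligned, unchanged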
10 changes: 6 additions & 4 deletions csrc/aio/py_lib/deepspeed_py_io_handle.cpp
@@ -95,8 +95,9 @@ int deepspeed_io_handle_t::read(torch::Tensor& buffer,
    if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); }
    const std::chrono::duration<double> fn_time =
        std::chrono::high_resolution_clock::now() - start_time;
-    std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
-              << " call = " << fn_time.count() * 1e6 << std::endl;
+    std::cout << "Elapsed time(usec): "
+              << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6
+              << std::endl;
    return 0;
}

@@ -131,8 +132,9 @@ int deepspeed_io_handle_t::write(const torch::Tensor& buffer,

    const std::chrono::duration<double> fn_time =
        std::chrono::high_resolution_clock::now() - start_time;
-    std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
-              << " call = " << fn_time.count() * 1e6 << std::endl;
+    std::cout << "Elapsed time(usec): "
+              << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6
+              << std::endl;
    return 0;
}

3 changes: 2 additions & 1 deletion csrc/deepspeed4science/evoformer_attn/gemm/mma_from_smem.h
@@ -1207,7 +1207,8 @@ template <typename WarpShape,
          typename RegularWarpIterator,
          typename Policy,
          typename Enable = void>
-struct DefaultWarpIteratorAFromSharedMemory {};
+struct DefaultWarpIteratorAFromSharedMemory {
+};

// TensorOp - Ampere half
template <typename RegularWarpIterator, typename Policy>
9 changes: 5 additions & 4 deletions csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h
@@ -125,10 +125,11 @@ struct CheckArch {
        std::cerr << #PTR " is not correctly aligned\n"; \
        return false;                                    \
    }
-#define EVOFORMER_CHECK(COND, ERR)                                                     \
-    if (!(COND)) {                                                                     \
-        std::cerr << "[Evoformer Attention]" << "'" #COND "' failed: " << ERR << "\n"; \
-        return false;                                                                  \
+#define EVOFORMER_CHECK(COND, ERR)                          \
+    if (!(COND)) {                                          \
+        std::cerr << "[Evoformer Attention]"                \
+                  << "'" #COND "' failed: " << ERR << "\n"; \
+        return false;                                       \
    }
#endif

@@ -12,7 +12,8 @@ namespace epilogue {
namespace threadblock {

template <class AccessType, class Enable = void>
-struct atomic_store {};
+struct atomic_store {
+};

template <class AccessType>
struct atomic_store<AccessType,
2 changes: 1 addition & 1 deletion csrc/includes/simd.h
@@ -27,7 +27,7 @@ inline void writeAs(void* dst, const T& val)
    std::memcpy(dst, &val, sizeof(T));
}

-#define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
+#define ROUND_DOWN(size, step) ((size) & ~((step)-1))

#if defined(__AVX512__)
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
3 changes: 2 additions & 1 deletion csrc/xpu/adam/multi_tensor_apply.dp.hpp
@@ -112,7 +112,8 @@ class multi_tensor_apply_kernel {
// to make sure multi_tensor_apply_kernel can be used in sycl::buffer
namespace sycl {
template <typename T, typename U, typename... ArgTypes>
-struct is_device_copyable<multi_tensor_apply_kernel<T, U, ArgTypes...>> : std::true_type {};
+struct is_device_copyable<multi_tensor_apply_kernel<T, U, ArgTypes...>> : std::true_type {
+};
}  // namespace sycl

template <int depth, typename T, typename... ArgTypes>
3 changes: 2 additions & 1 deletion csrc/xpu/common/custom_cuda_kernel.dp.cpp
@@ -21,7 +21,8 @@ inline void has_capability_or_fail(const sycl::device& dev,
                break;
            default:
#define __SYCL_ASPECT(ASPECT, ID) \
-    case sycl::aspect::ASPECT: return #ASPECT;
+    case sycl::aspect::ASPECT: \
+        return #ASPECT;
#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID)
#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE)
                auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string {
2 changes: 1 addition & 1 deletion csrc/xpu/includes/simd.h
@@ -13,7 +13,7 @@
#define TILE (128 * 1024 * 1024)
#if defined(__AVX512__) or defined(__AVX256__)

-#define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
+#define ROUND_DOWN(size, step) ((size) & ~((step)-1))

#if defined(__AVX512__)
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
10 changes: 5 additions & 5 deletions csrc/xpu/includes/type_shim.h
@@ -82,11 +82,11 @@
}

template <typename T>
-__inline__ __attribute__((always_inline)) T
-reduce_block_into_lanes(T* x,
-                        T val,
-                        int lanes = 1,
-                        bool share_result = false)  // lanes is intended to be <= 32.
+__inline__ __attribute__((always_inline)) T reduce_block_into_lanes(
+    T* x,
+    T val,
+    int lanes = 1,
+    bool share_result = false)  // lanes is intended to be <= 32.
{
    auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>();
    int tid = item_ct1.get_local_id(2) + item_ct1.get_local_id(1) * item_ct1.get_local_range(2);
29 changes: 24 additions & 5 deletions deepspeed/moe/sharded_moe.py
@@ -292,7 +292,8 @@ def top2gating(logits: Tensor,
               min_capacity: int,
               drop_tokens: bool = True,
               ep_group: Union[torch.distributed.ProcessGroup, None] = None,
-               top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+               top2_2nd_expert_sampling: bool = True,
+               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits."""
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)
@@ -313,8 +314,12 @@ def top2gating(logits: Tensor,
    mask2 = F.one_hot(indices2_s, num_classes=num_experts)

    # Compute locations in capacity buffer
-    locations1 = torch.cumsum(mask1, dim=0) - 1
-    locations2 = torch.cumsum(mask2, dim=0) - 1
+    if not use_tutel:
+        locations1 = torch.cumsum(mask1, dim=0) - 1
+        locations2 = torch.cumsum(mask2, dim=0) - 1
+    else:
+        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
+        locations2 = tutel_moe.fast_cumsum_sub_one(mask2)
    # Update 2nd's location by accounting for locations of 1st
    locations2 += torch.sum(mask1, dim=0, keepdim=True)
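For readers unfamiliar with Tutel: fast_cumsum_sub_one(mask) is Tutel's fused CUDA kernel for the same per-expert token position that torch.cumsum(mask, dim=0) - 1 computes, avoiding the slow generic cumsum. A minimal sanity-check sketch, assuming the tutel package is installed and a CUDA device is available:

# Hypothetical standalone check (not part of this commit): verify that
# Tutel's fused kernel matches the plain PyTorch cumsum path.
import torch
from tutel import moe as tutel_moe

# One-hot expert-assignment mask: 4 tokens, 2 experts.
mask = torch.tensor([[1, 0], [0, 1], [1, 0], [1, 0]], dtype=torch.int32, device='cuda')

locations_ref = torch.cumsum(mask, dim=0) - 1          # token's slot within its expert
locations_fast = tutel_moe.fast_cumsum_sub_one(mask)   # Tutel's fused equivalent
assert torch.equal(locations_ref, locations_fast.to(locations_ref.dtype))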

@@ -358,6 +363,19 @@ def top2gating(logits: Tensor,
    gates1_s /= denom_s
    gates2_s /= denom_s

+    if use_tutel:
+        # return critical information for tutel
+        return l_aux, capacity, num_experts, [
+            indices1_s,
+            indices2_s,
+        ], [
+            locations1_s,
+            locations2_s,
+        ], [
+            gates1_s,
+            gates2_s,
+        ], exp_counts
+
    # Calculate combine_weights and dispatch_mask
    gates1 = einsum("s,se->se", gates1_s, mask1_float)
    gates2 = einsum("s,se->se", gates2_s, mask2_float)
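With use_tutel=True, top2gating now skips the dense combine_weights/dispatch_mask einsums and instead returns the raw indices, locations, and gate values, two entries per list (one per selected expert). A hedged sketch of how a caller might feed these into Tutel's dispatcher, modeled on the existing top-1 Tutel path in MOELayer; the fast_dispatcher usage and keyword names are assumptions to verify against your installed Tutel version:

# Sketch under assumed API, not the commit's verbatim caller code.
import torch
from tutel import moe as tutel_moe
from deepspeed.moe.sharded_moe import top2gating

S, M, E = 8, 16, 4                           # tokens, model dim, experts
reshaped_input = torch.randn(S, M).cuda()
logits = torch.randn(S, E).cuda()

l_aux, C, E, indices_, locations_, gates_, exp_counts = top2gating(
    logits, capacity_factor=1.0, min_capacity=2, use_tutel=True)

dispatcher = tutel_moe.fast_dispatcher(E, int(C), M, dispatch_dtype=reshaped_input.dtype)
# indices_/locations_/gates_ are 2-element lists, one entry per selected expert.
dispatcher.update(indices_, locations_, gates_, capacity=int(C))
dispatched_input = dispatcher.encode(reshaped_input)  # (E * C, M) layout for expert compute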
@@ -517,7 +535,8 @@ def forward(self,

        elif self.k == 2:
            gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
-                                     self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)
+                                     self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling,
+                                     use_tutel)
        else:
            gate_output = topkgating(logits, self.k,
                                     self.capacity_factor if self.training else self.eval_capacity_factor,
@@ -568,7 +587,7 @@ def __init__(self,
        self.timers = SynchronizedWallClockTimer()
        self.wall_clock_breakdown = False

-        self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1
+        self.use_tutel = use_tutel and TUTEL_INSTALLED and (gate.k == 1 or gate.k == 2)

        if self.use_tutel:
            logger.info('Using Tutel optimizations.')
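Taken together with the gating changes above, this guard now enables Tutel's fast path for both top-1 and top-2 gates. A usage sketch, assuming DeepSpeed's MoE layer exposes the use_tutel flag (as in current releases) and that Tutel is installed; without Tutel, TUTEL_INSTALLED is False and the layer silently falls back to the einsum-based dispatch:

# Sketch: enabling the Tutel-optimized path for a top-2 MoE layer.
import torch
from deepspeed.moe.layer import MoE

expert = torch.nn.Linear(512, 512)
moe_layer = MoE(hidden_size=512,
                expert=expert,
                num_experts=8,
                k=2,               # top-2 gating, covered by Tutel as of this commit
                use_tutel=True)

hidden = torch.randn(4, 10, 512)
output, l_aux, exp_counts = moe_layer(hidden)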
