add primitives: net.batch_all_to_all_v(), net.batch_all_gather_v() (#221)
ghostplant authored Dec 29, 2023
1 parent 6638dfc commit c2b2271
Showing 5 changed files with 150 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tests/test_tutel.py
@@ -39,7 +39,7 @@ def run(
new_env['NCCL_SHM_DISABLE'] = '1'
"""Run helloworld example"""
if helloworld_file == 'helloworld':
command = 'python3 -m torch.distributed.launch --nproc_per_node=' + str(nproc_per_node) + ' tutel/examples/helloworld.py --top ' + str(top) + ' --dtype ' + dtype + ' --num_local_experts ' + str(num_local_experts) + ' --hidden_size ' + str(hidden_size) + ' --batch_size ' + str(batch_size) + ' --a2a_ffn_overlap_degree ' + str(a2a_ffn_overlap_degree) + ' --num_steps ' + str(num_steps) + ' --device ' + str(device) + ' --num_tokens 1024'
command = 'python3 -m torch.distributed.run --nproc_per_node=' + str(nproc_per_node) + ' tutel/examples/helloworld.py --top ' + str(top) + ' --dtype ' + dtype + ' --num_local_experts ' + str(num_local_experts) + ' --hidden_size ' + str(hidden_size) + ' --batch_size ' + str(batch_size) + ' --a2a_ffn_overlap_degree ' + str(a2a_ffn_overlap_degree) + ' --num_steps ' + str(num_steps) + ' --device ' + str(device) + ' --num_tokens 1024'
if use_model_parallel:
command += ' --parallel_type model'
else:
98 changes: 95 additions & 3 deletions tutel/custom/custom_kernel.cpp
@@ -323,12 +323,13 @@ template<typename dtype> static void invoke_cpu(const std::vector<torch::Tensor>

#if defined(USE_NCCL)

static ncclComm_t g_nccl_comm;
static ncclComm_t g_nccl_comm, shared_nccl_comm;
static std::vector<at::cuda::CUDAEvent> g_cuda_events;
static int g_world_size = 0;
static int g_world_rank = 0;
static int g_world_size = 0, shared_world_size = 0;
static int g_world_rank = 0, shared_world_rank = 0;
static int g_local_size = 0;
static int g_local_rank = 0;
static int __dtype_size[256];

// jit
static int mem_stride_copy_char_fd = -1;
@@ -349,6 +350,33 @@ static void get_nccl_unique_id(torch::Tensor &nccl_unique_id_tensor) {
memcpy((void *)nccl_unique_id_tensor.data_ptr(), &nccl_unique_id, sizeof(ncclUniqueId));
}

static void init_shared_nccl(
const torch::Tensor &nccl_unique_id_tensor,
int world_size,
int world_rank) {
ncclUniqueId nccl_unique_id;

CHECK_CPU(nccl_unique_id_tensor);
CHECK_EQ(nccl_unique_id_tensor.nbytes(), sizeof(ncclUniqueId));
memcpy(&nccl_unique_id, (void *)nccl_unique_id_tensor.data_ptr(), sizeof(ncclUniqueId));
CHECK_EQ(0, ncclGroupStart());
CHECK_EQ(0, ncclCommInitRank(&shared_nccl_comm, world_size, nccl_unique_id, world_rank));
CHECK_EQ(0, ncclGroupEnd());

shared_world_size = world_size;
shared_world_rank = world_rank;

__dtype_size[(int)torch::kFloat64] = 8;
__dtype_size[(int)torch::kInt64] = 8;
__dtype_size[(int)torch::kFloat32] = 4;
__dtype_size[(int)torch::kInt32] = 4;
__dtype_size[(int)torch::kFloat16] = 2;
__dtype_size[(int)torch::kInt16] = 2;
__dtype_size[(int)torch::kInt8] = 1;
__dtype_size[(int)torch::kUInt8] = 1;
__dtype_size[(int)torch::kBool] = 1;
}

static void init_nccl(
const torch::Tensor &nccl_unique_id_tensor,
int world_size,
@@ -431,6 +459,63 @@ static torch::Tensor& nccl_stream_acquire(torch::Tensor &tensor, int idx) {
return tensor;
}

static void batch_all_to_all_v(const std::vector<torch::Tensor> &ins, const std::vector<torch::Tensor> &outs, const torch::Tensor &in_sizes_, const torch::Tensor &out_sizes_) {
AT_ASSERTM(shared_world_size > 0, "Failed to initialize Shared NCCL");

auto in_sizes_cpu = in_sizes_.to(torch::kCPU).to(torch::kInt32);
auto out_sizes_cpu = out_sizes_.to(torch::kCPU).to(torch::kInt32);
auto* in_sizes = (unsigned int*)in_sizes_cpu.data_ptr();
auto* out_sizes = (unsigned int*)out_sizes_cpu.data_ptr();
auto stream = at::cuda::getCurrentCUDAStream();

ncclGroupStart();
for (int k = 0; k < ins.size(); ++k) {
auto* in_buff = ins[k].data_ptr();
auto* out_buff = outs[k].data_ptr();
auto dtype = ins[k].dtype();
int size = __dtype_size[*(unsigned short*)&dtype];
AT_ASSERTM(size > 0, "Data type of input tensors for batch_all_to_all_v are not recognized.");
AT_ASSERTM(k == 0 || ins[0].numel() == ins[k].numel(), "Tensor instances within batch_all_to_all_v are supposed to share same length.");

int in_offset = 0, out_offset = 0;
for (int i = 0; i < shared_world_size; ++i) {
ncclSend((char*)in_buff + in_offset, in_sizes[i] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
ncclRecv((char*)out_buff + out_offset, out_sizes[i] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
in_offset += in_sizes[i] * size;
out_offset += out_sizes[i] * size;
}
}
ncclGroupEnd();
}

static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std::vector<torch::Tensor> &outs, const torch::Tensor &out_sizes_) {
AT_ASSERTM(shared_world_size > 0, "Failed to initialize Shared NCCL");

auto out_sizes_cpu = out_sizes_.to(torch::kCPU).to(torch::kInt32);
auto* out_sizes = (unsigned int*)out_sizes_cpu.data_ptr();
auto stream = at::cuda::getCurrentCUDAStream();

ncclGroupStart();
for (int k = 0; k < ins.size(); ++k) {
auto* in_buff = ins[k].data_ptr();
auto* out_buff = outs[k].data_ptr();
auto dtype = ins[k].dtype();
int size = __dtype_size[*(unsigned short*)&dtype];
AT_ASSERTM(size > 0, "Data type of input tensors for batch_all_gather_v are not recognized.");
AT_ASSERTM(k == 0 || ins[0].numel() == ins[k].numel(), "Tensor instances within batch_all_gather_v are supposed to share same length.");

int out_offset = 0;
for (int i = 0; i < shared_world_size; ++i) {
if (out_sizes[shared_world_rank])
ncclSend((char*)in_buff, out_sizes[shared_world_rank] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
if (out_sizes[i])
ncclRecv((char*)out_buff + out_offset, out_sizes[i] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
out_offset += out_sizes[i] * size;
}
}
ncclGroupEnd();
}

static std::vector<torch::Tensor> nccl_all_to_all_scatter_async(
const torch::Tensor &input,
torch::IntArrayRef output_size,
@@ -686,6 +771,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&get_nccl_unique_id,
"Get ncclUniqueId for NCCL initialization"
);
m.def("init_shared_nccl",
&init_shared_nccl,
"NCCL initialization used for global world"
);
m.def("init_nccl",
&init_nccl,
"NCCL initialization"
@@ -718,6 +807,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&nccl_all_to_all_2d_async,
"NCCL AllToAll (2D Async, In-place if 2DH A2A is enabled)"
);

m.def("batch_all_to_all_v", &batch_all_to_all_v, "NCCL AllToAllV Batched.");
m.def("batch_all_gather_v", &batch_all_gather_v, "NCCL AllGatherV Batched.");
#endif
}

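For readers who want the semantics of batch_all_to_all_v without tracing the NCCL send/recv loop above: each input tensor is a flat buffer laid out as per-rank chunks, in_sizes[i] / out_sizes[i] count the elements exchanged with rank i, and the same split applies to every tensor in the batch. The sketch below restates that behaviour with torch.distributed.all_to_all_single; it is an illustration only, not the code path added by this commit, and it assumes torch.distributed has already been initialized with the NCCL backend.

# Illustration only: an element-count-based, per-tensor all-to-all that mirrors
# what batch_all_to_all_v() computes.  Assumes dist.init_process_group('nccl')
# has been called and every tensor lives on the current CUDA device.
import torch
import torch.distributed as dist

def reference_batch_all_to_all_v(tensors, in_sizes, out_sizes):
    # tensors: list of flat CUDA tensors, each laid out as [chunk for rank 0,
    #          chunk for rank 1, ...]; in_sizes[i] / out_sizes[i] are element
    #          counts sent to / received from rank i (same splits for all tensors).
    outputs = []
    for t in tensors:
        out = torch.empty(int(sum(out_sizes)), dtype=t.dtype, device=t.device)
        dist.all_to_all_single(out, t,
                               output_split_sizes=list(out_sizes),
                               input_split_sizes=list(in_sizes))
        outputs.append(out)
    return outputs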
50 changes: 50 additions & 0 deletions tutel/impls/communicate.py
@@ -33,6 +33,7 @@ def barrier(group=None):


TUTEL_GROUPING_CACHE = {}
TUTEL_SHARED_NCCL = False
TUTEL_SKIP_A2A = int(os.environ.get('SKIP_A2A', 0)) > 0

def create_groups_from_world(group_count, include_init=None):
@@ -128,6 +129,25 @@ class ParallelPropStorage:
result.is_distributed = is_distributed
result.dist_print = dist_print

global TUTEL_SHARED_NCCL
if is_distributed and not TUTEL_SHARED_NCCL and backend == 'nccl':
try:
world_size = get_world_size()
world_rank = get_world_rank()
nccl_unique_id_size = tutel_custom_kernel.get_nccl_unique_id_size()
nccl_unique_id = torch.zeros([nccl_unique_id_size], dtype=torch.int8).cpu()
if world_rank == 0:
tutel_custom_kernel.get_nccl_unique_id(nccl_unique_id)
nccl_unique_id = nccl_unique_id.cuda()
dist.broadcast(nccl_unique_id, 0, None)
tutel_custom_kernel.init_shared_nccl(
nccl_unique_id.cpu(),
world_size,
world_rank)
TUTEL_SHARED_NCCL = True
except:
pass

TUTEL_GROUPING_CACHE[original_group_count] = result
return result

@@ -186,6 +206,36 @@ def simple_all_gather(input, group=None):
dist.all_gather(tensor_list=tensor_list, tensor=input.view(1, -1), group=group)
return output.view([-1,] + list(input.shape[1:]))

def batch_all_to_all_v(datas, partition_sizes, group=None):
assert group is None, "batched_all_to_all_v() with non-default group is not implemented in this version."
assert type(datas) in (tuple, list), "data type for batch_all_to_all_v() is not a list of tensors"
in_sizes = partition_sizes
if type(in_sizes) != torch.Tensor:
in_sizes = torch.tensor(in_sizes, dtype=torch.int32, device=datas[0].device)
world_size = get_world_size(group)
assert in_sizes.numel() == world_size
if world_size == 1:
return list(datas), in_sizes
out_sizes = simple_all_to_all(in_sizes, group=group)
datas = [data.contiguous().view(-1).cuda() for data in datas]
outputs = [torch.zeros([out_sizes.sum()], dtype=data.dtype, device=data.device) for data in datas]
tutel_custom_kernel.batch_all_to_all_v(datas, outputs, in_sizes, out_sizes)
return outputs, out_sizes

def batch_all_gather_v(datas, group=None):
assert group is None, "batch_all_gather_v() with non-default group is not implemented in this version."
assert type(datas) in (tuple, list), "data type for batch_all_gather_v() is not a list of tensors"
datas = [data.contiguous().view(-1).cuda() for data in datas]
input_size = torch.tensor([int(datas[0].numel())], dtype=torch.int64, device=datas[0].device)
world_size = get_world_size(group)
if world_size == 1:
return list(datas), input_size
output_sizes = simple_all_gather(input_size)
size_int = int(output_sizes.sum())
outputs = [torch.zeros([size_int], dtype=data.dtype, device=data.device) for i, data in enumerate(datas)]
tutel_custom_kernel.batch_all_gather_v(datas, outputs, output_sizes)
return outputs, output_sizes

class AllToAllStatus:
initialized = False
num_split = 0
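A minimal usage sketch for the two new wrappers. This is not part of the commit: it assumes the script is launched with python3 -m torch.distributed.run --nproc_per_node=<N>, and the setup names (system.init_data_model_parallel, parallel_env.global_size, parallel_env.global_rank, parallel_env.local_device) follow the tutel helloworld example, which should also trigger the shared-NCCL initialization added above.

# Hedged usage sketch (not from this commit); launch with e.g.:
#   python3 -m torch.distributed.run --nproc_per_node=2 this_script.py
import torch
from tutel import system, net

# Standard tutel setup (names follow tutel/examples/helloworld.py); this is
# also where the shared NCCL communicator is expected to get initialized.
parallel_env = system.init_data_model_parallel(backend='nccl')
world_size, rank = parallel_env.global_size, parallel_env.global_rank

# Each rank sends (rank + 1) elements to every peer, so chunk sizes differ per rank.
chunk = rank + 1
local_data = torch.full([chunk * world_size], rank, dtype=torch.int32, device=parallel_env.local_device)
partition_sizes = [chunk] * world_size

outputs, out_sizes = net.batch_all_to_all_v([local_data], partition_sizes)
# outputs[0] concatenates the chunks received from all ranks;
# out_sizes[i] reports how many elements came from rank i (here: i + 1).

gathered, gathered_sizes = net.batch_all_gather_v([local_data])
# gathered[0] concatenates every rank's local_data in rank order.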
2 changes: 2 additions & 0 deletions tutel/impls/fast_dispatch.py
@@ -204,12 +204,14 @@ def get_dispatch_count(critial_data):
return critial_data[-1]

def fast_encode(data, critial_data, is_postscore=True):
assert data.is_contiguous(), "Input tensor for encode/decode should be in contiguous memory format."
num_global_experts = critial_data[0]
dispatcher = TutelMoeFastDispatcher(num_global_experts, 0, data.size(-1), data.dtype)
dispatcher.update(*critial_data[1:-1], is_postscore=is_postscore)
return dispatcher.encode(data).view(num_global_experts, -1, data.size(-1))

def fast_decode(data, critial_data, is_postscore=True):
assert data.is_contiguous(), "Input tensor for encode/decode should be in contiguous memory format."
num_global_experts = critial_data[0]
dispatcher = TutelMoeFastDispatcher(num_global_experts, 0, data.size(-1), data.dtype)
dispatcher.update(*critial_data[1:-1], is_postscore=is_postscore)
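The new asserts mean fast_encode/fast_decode now reject non-contiguous inputs up front instead of operating on an unexpected memory layout, so callers that transpose or slice their activations need an explicit .contiguous() first. A hedged illustration (shapes are made up, and crit stands in for the critial_data argument shown above):

# Hedged illustration of the new contiguity requirement; shapes are made up
# and `crit` stands in for the dispatcher's critical data (critial_data above).
import torch

x = torch.randn(4, 8, 16, device='cuda').transpose(0, 1)  # a non-contiguous view
assert not x.is_contiguous()
x = x.contiguous()  # required before calling fast_encode(x, crit) / fast_decode(x, crit)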
2 changes: 2 additions & 0 deletions tutel/net.py
@@ -8,6 +8,8 @@
from .impls.communicate import simple_all_reduce, simple_all_to_all,simple_split, simple_reduce_scatter, simple_all_gather
# Communication with Backward Compute
from .impls.communicate import all_to_all, all_to_all_single, all_gather, zero_gather, zero_scatter, spatial_split, reduce_scatter, allreduce_forward, allreduce_backward
# Communication with Batch-based Compute
from .impls.communicate import batch_all_to_all_v, batch_all_gather_v


class TutelDistributedOptimizer:
