diff --git a/tests/test_tutel.py b/tests/test_tutel.py
index 6716c45a..392e02d8 100644
--- a/tests/test_tutel.py
+++ b/tests/test_tutel.py
@@ -39,7 +39,7 @@ def run(
         new_env['NCCL_SHM_DISABLE'] = '1'
     """Run helloworld example"""
     if helloworld_file == 'helloworld':
-        command = 'python3 -m torch.distributed.launch --nproc_per_node=' + str(nproc_per_node) + ' tutel/examples/helloworld.py --top ' + str(top) + ' --dtype ' + dtype + ' --num_local_experts ' + str(num_local_experts) + ' --hidden_size ' + str(hidden_size) + ' --batch_size ' + str(batch_size) + ' --a2a_ffn_overlap_degree ' + str(a2a_ffn_overlap_degree) + ' --num_steps ' + str(num_steps) + ' --device ' + str(device) + ' --num_tokens 1024'
+        command = 'python3 -m torch.distributed.run --nproc_per_node=' + str(nproc_per_node) + ' tutel/examples/helloworld.py --top ' + str(top) + ' --dtype ' + dtype + ' --num_local_experts ' + str(num_local_experts) + ' --hidden_size ' + str(hidden_size) + ' --batch_size ' + str(batch_size) + ' --a2a_ffn_overlap_degree ' + str(a2a_ffn_overlap_degree) + ' --num_steps ' + str(num_steps) + ' --device ' + str(device) + ' --num_tokens 1024'
         if use_model_parallel:
             command += ' --parallel_type model'
         else:
diff --git a/tutel/custom/custom_kernel.cpp b/tutel/custom/custom_kernel.cpp
index 418dacb7..93447797 100644
--- a/tutel/custom/custom_kernel.cpp
+++ b/tutel/custom/custom_kernel.cpp
@@ -323,12 +323,13 @@ template static void invoke_cpu(const std::vector
 
 #if defined(USE_NCCL)
 
-static ncclComm_t g_nccl_comm;
+static ncclComm_t g_nccl_comm, shared_nccl_comm;
 static std::vector<at::cuda::CUDAEvent> g_cuda_events;
-static int g_world_size = 0;
-static int g_world_rank = 0;
+static int g_world_size = 0, shared_world_size = 0;
+static int g_world_rank = 0, shared_world_rank = 0;
 static int g_local_size = 0;
 static int g_local_rank = 0;
+static int __dtype_size[256];
 
 // jit
 static int mem_stride_copy_char_fd = -1;
@@ -349,6 +350,33 @@ static void get_nccl_unique_id(torch::Tensor &nccl_unique_id_tensor) {
   memcpy((void *)nccl_unique_id_tensor.data_ptr(), &nccl_unique_id, sizeof(ncclUniqueId));
 }
 
+static void init_shared_nccl(
+    const torch::Tensor &nccl_unique_id_tensor,
+    int world_size,
+    int world_rank) {
+  ncclUniqueId nccl_unique_id;
+
+  CHECK_CPU(nccl_unique_id_tensor);
+  CHECK_EQ(nccl_unique_id_tensor.nbytes(), sizeof(ncclUniqueId));
+  memcpy(&nccl_unique_id, (void *)nccl_unique_id_tensor.data_ptr(), sizeof(ncclUniqueId));
+  CHECK_EQ(0, ncclGroupStart());
+  CHECK_EQ(0, ncclCommInitRank(&shared_nccl_comm, world_size, nccl_unique_id, world_rank));
+  CHECK_EQ(0, ncclGroupEnd());
+
+  shared_world_size = world_size;
+  shared_world_rank = world_rank;
+
+  __dtype_size[(int)torch::kFloat64] = 8;
+  __dtype_size[(int)torch::kInt64] = 8;
+  __dtype_size[(int)torch::kFloat32] = 4;
+  __dtype_size[(int)torch::kInt32] = 4;
+  __dtype_size[(int)torch::kFloat16] = 2;
+  __dtype_size[(int)torch::kInt16] = 2;
+  __dtype_size[(int)torch::kInt8] = 1;
+  __dtype_size[(int)torch::kUInt8] = 1;
+  __dtype_size[(int)torch::kBool] = 1;
+}
+
 static void init_nccl(
     const torch::Tensor &nccl_unique_id_tensor,
     int world_size,
@@ -431,6 +459,63 @@ static torch::Tensor& nccl_stream_acquire(torch::Tensor &tensor, int idx) {
   return tensor;
 }
 
+static void batch_all_to_all_v(const std::vector<torch::Tensor> &ins, const std::vector<torch::Tensor> &outs, const torch::Tensor &in_sizes_, const torch::Tensor &out_sizes_) {
+  AT_ASSERTM(shared_world_size > 0, "Failed to initialize Shared NCCL");
+
+  auto in_sizes_cpu = in_sizes_.to(torch::kCPU).to(torch::kInt32);
+  auto out_sizes_cpu = out_sizes_.to(torch::kCPU).to(torch::kInt32);
+  auto* in_sizes = (unsigned int*)in_sizes_cpu.data_ptr();
+  auto* out_sizes = (unsigned int*)out_sizes_cpu.data_ptr();
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  ncclGroupStart();
+  for (int k = 0; k < ins.size(); ++k) {
+    auto* in_buff = ins[k].data_ptr();
+    auto* out_buff = outs[k].data_ptr();
+    auto dtype = ins[k].dtype();
+    int size = __dtype_size[*(unsigned short*)&dtype];
+    AT_ASSERTM(size > 0, "Data type of input tensors for batch_all_to_all_v are not recognized.");
+    AT_ASSERTM(k == 0 || ins[0].numel() == ins[k].numel(), "Tensor instances within batch_all_to_all_v are supposed to share same length.");
+
+    int in_offset = 0, out_offset = 0;
+    for (int i = 0; i < shared_world_size; ++i) {
+      ncclSend((char*)in_buff + in_offset, in_sizes[i] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
+      ncclRecv((char*)out_buff + out_offset, out_sizes[i] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
+      in_offset += in_sizes[i] * size;
+      out_offset += out_sizes[i] * size;
+    }
+  }
+  ncclGroupEnd();
+}
+
+static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std::vector<torch::Tensor> &outs, const torch::Tensor &out_sizes_) {
+  AT_ASSERTM(shared_world_size > 0, "Failed to initialize Shared NCCL");
+
+  auto out_sizes_cpu = out_sizes_.to(torch::kCPU).to(torch::kInt32);
+  auto* out_sizes = (unsigned int*)out_sizes_cpu.data_ptr();
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  ncclGroupStart();
+  for (int k = 0; k < ins.size(); ++k) {
+    auto* in_buff = ins[k].data_ptr();
+    auto* out_buff = outs[k].data_ptr();
+    auto dtype = ins[k].dtype();
+    int size = __dtype_size[*(unsigned short*)&dtype];
+    AT_ASSERTM(size > 0, "Data type of input tensors for batch_all_gather_v are not recognized.");
+    AT_ASSERTM(k == 0 || ins[0].numel() == ins[k].numel(), "Tensor instances within batch_all_gather_v are supposed to share same length.");
+
+    int out_offset = 0;
+    for (int i = 0; i < shared_world_size; ++i) {
+      if (out_sizes[shared_world_rank])
+        ncclSend((char*)in_buff, out_sizes[shared_world_rank] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
+      if (out_sizes[i])
+        ncclRecv((char*)out_buff + out_offset, out_sizes[i] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
+      out_offset += out_sizes[i] * size;
+    }
+  }
+  ncclGroupEnd();
+}
+
 static std::vector<torch::Tensor> nccl_all_to_all_scatter_async(
     const torch::Tensor &input,
     torch::IntArrayRef output_size,
@@ -686,6 +771,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         &get_nccl_unique_id,
         "Get ncclUniqueId for NCCL initialization"
     );
+    m.def("init_shared_nccl",
+        &init_shared_nccl,
+        "NCCL initialization used for global world"
+    );
     m.def("init_nccl",
         &init_nccl,
         "NCCL initialization"
@@ -718,6 +807,9 @@
         &nccl_all_to_all_2d_async,
         "NCCL AllToAll (2D Async, In-place if 2DH A2A is enabled)"
    );
+
+    m.def("batch_all_to_all_v", &batch_all_to_all_v, "NCCL AllToAllV Batched.");
+    m.def("batch_all_gather_v", &batch_all_gather_v, "NCCL AllGatherV Batched.");
 #endif
 }
diff --git a/tutel/impls/communicate.py b/tutel/impls/communicate.py
index 2fe8b89e..4c75d0da 100644
--- a/tutel/impls/communicate.py
+++ b/tutel/impls/communicate.py
@@ -33,6 +33,7 @@ def barrier(group=None):
 
 
 TUTEL_GROUPING_CACHE = {}
+TUTEL_SHARED_NCCL = False
 TUTEL_SKIP_A2A = int(os.environ.get('SKIP_A2A', 0)) > 0
 
 def create_groups_from_world(group_count, include_init=None):
@@ -128,6 +129,25 @@ class ParallelPropStorage:
     result.is_distributed = is_distributed
     result.dist_print = dist_print
 
+    global TUTEL_SHARED_NCCL
+    if is_distributed and not TUTEL_SHARED_NCCL and backend == 'nccl':
+        try:
+            world_size = get_world_size()
+            world_rank = get_world_rank()
+            nccl_unique_id_size = tutel_custom_kernel.get_nccl_unique_id_size()
+            nccl_unique_id = torch.zeros([nccl_unique_id_size], dtype=torch.int8).cpu()
+            if world_rank == 0:
+                tutel_custom_kernel.get_nccl_unique_id(nccl_unique_id)
+            nccl_unique_id = nccl_unique_id.cuda()
+            dist.broadcast(nccl_unique_id, 0, None)
+            tutel_custom_kernel.init_shared_nccl(
+                nccl_unique_id.cpu(),
+                world_size,
+                world_rank)
+            TUTEL_SHARED_NCCL = True
+        except:
+            pass
+
     TUTEL_GROUPING_CACHE[original_group_count] = result
 
     return result
@@ -186,6 +206,36 @@ def simple_all_gather(input, group=None):
     dist.all_gather(tensor_list=tensor_list, tensor=input.view(1, -1), group=group)
     return output.view([-1,] + list(input.shape[1:]))
 
+def batch_all_to_all_v(datas, partition_sizes, group=None):
+    assert group is None, "batched_all_to_all_v() with non-default group is not implemented in this version."
+    assert type(datas) in (tuple, list), "data type for batch_all_to_all_v() is not a list of tensors"
+    in_sizes = partition_sizes
+    if type(in_sizes) != torch.Tensor:
+        in_sizes = torch.tensor(in_sizes, dtype=torch.int32, device=datas[0].device)
+    world_size = get_world_size(group)
+    assert in_sizes.numel() == world_size
+    if world_size == 1:
+        return list(datas), in_sizes
+    out_sizes = simple_all_to_all(in_sizes, group=group)
+    datas = [data.contiguous().view(-1).cuda() for data in datas]
+    outputs = [torch.zeros([out_sizes.sum()], dtype=data.dtype, device=data.device) for data in datas]
+    tutel_custom_kernel.batch_all_to_all_v(datas, outputs, in_sizes, out_sizes)
+    return outputs, out_sizes
+
+def batch_all_gather_v(datas, group=None):
+    assert group is None, "batch_all_gather_v() with non-default group is not implemented in this version."
+    assert type(datas) in (tuple, list), "data type for batch_all_gather_v() is not a list of tensors"
+    datas = [data.contiguous().view(-1).cuda() for data in datas]
+    input_size = torch.tensor([int(datas[0].numel())], dtype=torch.int64, device=datas[0].device)
+    world_size = get_world_size(group)
+    if world_size == 1:
+        return list(datas), input_size
+    output_sizes = simple_all_gather(input_size)
+    size_int = int(output_sizes.sum())
+    outputs = [torch.zeros([size_int], dtype=data.dtype, device=data.device) for i, data in enumerate(datas)]
+    tutel_custom_kernel.batch_all_gather_v(datas, outputs, output_sizes)
+    return outputs, output_sizes
+
 class AllToAllStatus:
     initialized = False
     num_split = 0
diff --git a/tutel/impls/fast_dispatch.py b/tutel/impls/fast_dispatch.py
index 3fca138d..195a06fb 100644
--- a/tutel/impls/fast_dispatch.py
+++ b/tutel/impls/fast_dispatch.py
@@ -204,12 +204,14 @@ def get_dispatch_count(critial_data):
     return critial_data[-1]
 
 def fast_encode(data, critial_data, is_postscore=True):
+    assert data.is_contiguous(), "Input tensor for encode/decode should be in contiguous memory format."
     num_global_experts = critial_data[0]
     dispatcher = TutelMoeFastDispatcher(num_global_experts, 0, data.size(-1), data.dtype)
     dispatcher.update(*critial_data[1:-1], is_postscore=is_postscore)
     return dispatcher.encode(data).view(num_global_experts, -1, data.size(-1))
 
 def fast_decode(data, critial_data, is_postscore=True):
+    assert data.is_contiguous(), "Input tensor for encode/decode should be in contiguous memory format."
     num_global_experts = critial_data[0]
     dispatcher = TutelMoeFastDispatcher(num_global_experts, 0, data.size(-1), data.dtype)
     dispatcher.update(*critial_data[1:-1], is_postscore=is_postscore)
diff --git a/tutel/net.py b/tutel/net.py
index e472c169..d5994fc6 100644
--- a/tutel/net.py
+++ b/tutel/net.py
@@ -8,6 +8,8 @@
 from .impls.communicate import simple_all_reduce, simple_all_to_all, simple_split, simple_reduce_scatter, simple_all_gather
 # Communication with Backward Compute
 from .impls.communicate import all_to_all, all_to_all_single, all_gather, zero_gather, zero_scatter, spatial_split, reduce_scatter, allreduce_forward, allreduce_backward
+# Communication with Batch-based Compute
+from .impls.communicate import batch_all_to_all_v, batch_all_gather_v
 
 
 class TutelDistributedOptimizer:
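
Usage sketch (reviewer note, not part of the patch): the snippet below is one way the batch_all_to_all_v / batch_all_gather_v wrappers exported from tutel.net could be exercised, assuming the extension is built with USE_NCCL and the script is launched via torch.distributed.run. The system.init_data_model_parallel call and the parallel_env attributes mirror tutel/examples/helloworld.py; the tensor lengths and per-rank split sizes here are illustrative assumptions only.

import torch
from tutel import system, net

# Set up torch.distributed; with this patch, the shared NCCL communicator
# used by the batched collectives is expected to be initialized here as well.
parallel_env = system.init_data_model_parallel(backend='nccl')
world_size = parallel_env.global_size

# Two flat tensors of identical length on the local GPU; sizes[i] is how many
# elements this rank sends to rank i, so the entries must sum to the length.
length = 4 * world_size
ids = torch.arange(length, dtype=torch.int32, device=parallel_env.local_device)
vals = torch.randn(length, dtype=torch.float32, device=parallel_env.local_device)
sizes = [4] * world_size

# Both tensors are exchanged within a single batched all-to-all-v call;
# out_sizes reports how many elements arrived from each peer.
(recv_ids, recv_vals), out_sizes = net.batch_all_to_all_v([ids, vals], sizes)

# Batched all-gather-v: every rank receives the concatenation of all ranks'
# (possibly different-length) buffers as flattened 1-D tensors, plus lengths.
(all_ids, all_vals), gather_sizes = net.batch_all_gather_v([recv_ids, recv_vals])
parallel_env.dist_print('received per-peer sizes:', out_sizes.tolist())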