diff --git a/.github/workflows/fbgemm_docs.yml b/.github/workflows/fbgemm_docs.yml new file mode 100644 index 000000000..327e5dd6b --- /dev/null +++ b/.github/workflows/fbgemm_docs.yml @@ -0,0 +1,84 @@ +# This workflow builds the fbgemm_gpu docs and deploys them to gh-pages. +name: Generate documentation +on: + push: + branches: + - main +jobs: + build_docs_job: + runs-on: linux.2xlarge + steps: + # Checkout the repository to the GitHub Actions runner + - name: Checkout + uses: actions/checkout@v2 + with: + submodules: true + # Update references + # TODO: update the git submodule sync after the auto-sync part is fixed + - name: Git Submodule Update + run: | + git submodule init + git submodule update --remote --recursive + git log + - name: Update pip + run: | + sudo yum update -y + sudo yum -y install git python3-pip + sudo pip3 install --upgrade pip + - name: Setup conda + run: | + wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh + bash ~/miniconda.sh -b -p $HOME/miniconda + - name: setup Path + run: | + echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH + echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH + - name: create conda env + run: | + conda create --name build_binary python=3.9 + conda info + - name: check python version + run: | + conda run -n build_binary python --version + - name: Install gcc + shell: bash + run: | + sudo yum group install -y "Development Tools" + - name: Setup Path + run: | + echo /usr/local/bin >> $GITHUB_PATH + - name: Install PyTorch + shell: bash + run: | + conda run -n build_binary python -m pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - name: Test PyTorch Installation + run: | + conda run -n build_binary python -c "import torch.distributed" + echo "torch.distributed succeeded" + - name: Install fbgemm_gpu nightly + run: | + conda run -n build_binary python -m pip install fbgemm-gpu-nightly-cpu + - name: Test fbgemm_gpu installation + shell: bash + run: | + conda run -n build_binary \ + python -c "import fbgemm_gpu" + - name: Install Doxygen + run: | + conda install -n build_binary -c conda-forge doxygen + which doxygen + - name: Build the docset + run: | + conda run -n build_binary python -m pip install -r fbgemm_gpu/docs/requirements.txt + cd ./fbgemm_gpu/docs + conda run -n build_binary doxygen Doxyfile.in + conda run -n build_binary make html + cd .. + - name: Get output time + run: echo "The time was ${{ steps.build.outputs.time }}" + - name: Deploy + uses: JamesIves/github-pages-deploy-action@releases/v3 + with: + ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }} + BRANCH: gh-pages # The branch the action should deploy to. + FOLDER: fbgemm_gpu/docs/build/html # The folder the action should deploy.
diff --git a/.github/workflows/fbgemm_nightly_build.yml b/.github/workflows/fbgemm_nightly_build.yml index 552b5eb33..f70e20dae 100644 --- a/.github/workflows/fbgemm_nightly_build.yml +++ b/.github/workflows/fbgemm_nightly_build.yml @@ -117,6 +117,7 @@ jobs: --python-tag=${{ matrix.python-tag }} \ -DTORCH_CUDA_ARCH_LIST="'7.0;8.0'" \ --plat-name=manylinux1_x86_64 + ls -lt dist/*.whl - name: Upload wheel as GHA artifact uses: actions/upload-artifact@v2 with: @@ -271,4 +272,5 @@ jobs: --username __token__ \ --password "$PYPI_TOKEN" \ --skip-existing \ + --verbose \ fbgemm_gpu_nightly-*.whl diff --git a/.github/workflows/fbgemm_nightly_build_cpu.yml b/.github/workflows/fbgemm_nightly_build_cpu.yml index ebaeee2b1..8e6ba7291 100644 --- a/.github/workflows/fbgemm_nightly_build_cpu.yml +++ b/.github/workflows/fbgemm_nightly_build_cpu.yml @@ -105,6 +105,7 @@ jobs: --python-tag=${{ matrix.python-tag }} \ --cpu_only \ --plat-name=manylinux1_x86_64 + ls -lt dist/*.whl - name: Upload wheel as GHA artifact uses: actions/upload-artifact@v2 with: @@ -156,4 +157,5 @@ jobs: --username __token__ \ --password "$PYPI_TOKEN" \ --skip-existing \ + --verbose \ fbgemm_gpu/dist/fbgemm_gpu_nightly_cpu-*.whl diff --git a/.github/workflows/fbgemm_release_build.yml b/.github/workflows/fbgemm_release_build.yml index 1d8acca06..c39490c32 100644 --- a/.github/workflows/fbgemm_release_build.yml +++ b/.github/workflows/fbgemm_release_build.yml @@ -119,6 +119,7 @@ jobs: --python-tag=${{ matrix.python-tag }} \ -DTORCH_CUDA_ARCH_LIST="'7.0;8.0'" \ --plat-name=manylinux1_x86_64 + ls -lt dist/*.whl - name: Upload wheel as GHA artifact uses: actions/upload-artifact@v2 with: @@ -273,4 +274,5 @@ jobs: --username __token__ \ --password "$PYPI_TOKEN" \ --skip-existing \ + --verbose \ fbgemm_gpu-*.whl diff --git a/.github/workflows/fbgemm_release_build_cpu.yml b/.github/workflows/fbgemm_release_build_cpu.yml index 27c0c5888..99664baac 100644 --- a/.github/workflows/fbgemm_release_build_cpu.yml +++ b/.github/workflows/fbgemm_release_build_cpu.yml @@ -107,6 +107,7 @@ jobs: --python-tag=${{ matrix.python-tag }} \ --cpu_only \ --plat-name=manylinux1_x86_64 + ls -lt dist/*.whl - name: Upload wheel as GHA artifact uses: actions/upload-artifact@v2 with: @@ -158,4 +159,5 @@ jobs: --username __token__ \ --password "$PYPI_TOKEN" \ --skip-existing \ + --verbose \ fbgemm_gpu/dist/fbgemm_gpu_cpu-*.whl diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 092226d19..250544bed 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -248,6 +248,113 @@ jobs: python -c "import fbgemm_gpu" python -c "import fbgemm_gpu.split_embedding_codegen_lookup_invokers" + build_amd_gpu: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + + steps: + - name: Free space + run: sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia* /usr/local/lib/android + + - uses: actions/checkout@v2 + + - name: Install ROCm 5.1.1 + shell: bash + run: | + sudo update-alternatives --install /usr/bin/python python /usr/bin/python3 10 + wget https://repo.radeon.com/amdgpu-install/22.10.1/ubuntu/focal/amdgpu-install_22.10.1.50101-1_all.deb + export DEBIAN_FRONTEND=noninteractive + sudo apt install -y ./amdgpu-install_22.10.1.50101-1_all.deb + amdgpu-install -y --usecase=hiplibsdk,rocm --no-dkms + sudo rm amdgpu-install_22.10.1.50101-1_all.deb + + - name: Install dependencies + shell: bash + run: | + sudo apt-get update + 
sudo apt-get -y install git pip python3-dev mesa-common-dev clang comgr libopenblas-dev jp intel-mkl-full locales libnuma-dev + sudo apt-get install -y hipify-clang || true + sudo pip install cmake scikit-build ninja jinja2 numpy hypothesis --no-input + sudo apt-get clean + # Install pytorch 1.11 as required by fbgemm_gpu + sudo pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/rocm5.1.1/ + + - name: Checkout submodules + shell: bash + run: | + cd fbgemm_gpu + git submodule sync + git submodule update --init --recursive + + - name: Build fbgemm_gpu + shell: bash + run: | + sudo update-alternatives --install /usr/bin/python python /usr/bin/python3 10 + cd fbgemm_gpu + # build for MI250 only to save time. + sudo PYTORCH_ROCM_ARCH=gfx90a python3 setup.py build develop + + - name: Test fbgemm_gpu installation + shell: bash + run: | + cd fbgemm_gpu + cd test + python3 input_combine_test.py + python3 quantize_ops_test.py + python3 sparse_ops_test.py + python3 -c "import fbgemm_gpu" + python3 -c "import fbgemm_gpu.split_embedding_codegen_lookup_invokers" + + test_amd_gpu: + runs-on: rocm + strategy: + matrix: + os: [ubuntu-latest] + + steps: + - name: pre-checkout + shell: bash + run: | + if [ -d ${{ github.workspace }} ] + then + sudo chown -R $USER:$USER ${{ github.workspace }} + fi + sudo add-apt-repository ppa:git-core/ppa + sudo apt update + sudo apt -y install --only-upgrade git + + - uses: actions/checkout@v2 + with: + ref: ${{ github.ref }} + submodules: 'true' + + - name: build fbgemm_gpu and test + shell: bash + run: | + set -eux + env + ls -l + DOCKER_IMAGE=rocm/pytorch:rocm5.1.1_ubuntu20.04_py3.7_pytorch_staging_base + docker pull $DOCKER_IMAGE + JENKINS_REPO_DIR=fbgemm-private-jenkins + JENKINS_REPO_DIR_BAREMETAL=$PWD + JENKINS_REPO_DIR_DOCKER=/workspace/$JENKINS_REPO_DIR + DOCKER_OPTIONS="\ + --user 0 \ + --network=host \ + --ipc=host \ + --shm-size 16G \ + --group-add video \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --device=/dev/kfd \ + --device=/dev/dri \ + -v $JENKINS_REPO_DIR_BAREMETAL:$JENKINS_REPO_DIR_DOCKER + " + docker run $DOCKER_OPTIONS $DOCKER_IMAGE $JENKINS_REPO_DIR_DOCKER/.jenkins/rocm/build_and_test.sh $JENKINS_REPO_DIR_DOCKER + build_cpu_only: runs-on: ${{ matrix.os }} strategy: diff --git a/.jenkins/rocm/build_and_test.sh b/.jenkins/rocm/build_and_test.sh new file mode 100755 index 000000000..dadd1342c --- /dev/null +++ b/.jenkins/rocm/build_and_test.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +# exit immediately on failure, or if an undefined variable is used +set -eux + +FBGEMM_REPO_DIR=${1:-/workspace/FBGEMM} + +git config --global --add safe.directory "$FBGEMM_REPO_DIR" +git config --global --add safe.directory "$FBGEMM_REPO_DIR/third_party/asmjit" +git config --global --add safe.directory "$FBGEMM_REPO_DIR/third_party/cpuinfo" +git config --global --add safe.directory "$FBGEMM_REPO_DIR/third_party/googletest" +git config --global --add safe.directory "$FBGEMM_REPO_DIR/third_party/hipify_torch" + +# Install dependencies +apt-get update --allow-insecure-repositories && \ + apt-get install -y --allow-unauthenticated \ + git \ + jq \ + sshfs \ + sshpass \ + unzip + +apt-get install -y locales +locale-gen en_US.UTF-8 + +pip3 install click +pip3 install jinja2 +pip3 install ninja +pip3 install scikit-build +pip3 install --upgrade hypothesis +pip3 install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/rocm5.1.1/ + +pip3 list + +# Build fbgemm_gpu +cd 
"$FBGEMM_REPO_DIR/fbgemm_gpu" +MAX_JOBS="$(nproc)" +export MAX_JOBS +export PYTORCH_ROCM_ARCH="gfx908" +python setup.py build develop + +export FBGEMM_TEST_WITH_ROCM=1 + +# Test fbgemm_gpu +cd test + +python batched_unary_embeddings_test.py --verbose +python input_combine_test.py --verbose +python jagged_tensor_ops_test.py --verbose +python layout_transform_ops_test.py --verbose +python merge_pooled_embeddings_test.py --verbose +python metric_ops_test.py --verbose +python permute_pooled_embedding_modules_test.py --verbose +python quantize_ops_test.py --verbose +python sparse_ops_test.py --verbose +python split_embedding_inference_converter_test.py --verbose +python split_table_batched_embeddings_test.py --verbose +python uvm_test.py --verbose diff --git a/CMakeLists.txt b/CMakeLists.txt index 6be603dfa..d8b6b989c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -182,6 +182,10 @@ if(NOT TARGET asmjit) add_subdirectory("${ASMJIT_SRC_DIR}" "${FBGEMM_BINARY_DIR}/asmjit") set_property(TARGET asmjit PROPERTY POSITION_INDEPENDENT_CODE ON) + # add a flag required for mac build + if(NOT MSVC) + target_compile_options(asmjit PRIVATE "-Wno-sign-conversion") + endif() endif() if(NOT TARGET cpuinfo) @@ -293,6 +297,10 @@ endif() if(FBGEMM_BUILD_BENCHMARKS) add_subdirectory(bench) + # add a flag to enable Clang 14 + set_source_files_properties( + bench/GEMMsBenchmark.cc + PROPERTIES COMPILE_FLAGS "-Wno-unused-variable") endif() if(FBGEMM_BUILD_DOCS) diff --git a/README.md b/README.md index 48de40e7c..195f5e011 100644 --- a/README.md +++ b/README.md @@ -72,9 +72,8 @@ cd FBGEMM # if you are updating an existing checkout git submodule sync git submodule update --init --recursive -mkdir build && cd build -cmake .. -make +cmake -B build +make -C build ``` To run the tests after building FBGEMM (if tests are built), use the following diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 09c1669ac..2a1babf90 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -10,17 +10,14 @@ if(SKBUILD) message("The project is built using scikit-build") endif() -if(EXISTS "/usr/bin/nvidia-smi") - message("NVIDIA GPU detected.") - option(USE_CUDA "Use CUDA" ON) - option(USE_ROCM "Use ROCm" OFF) -elseif(EXISTS "/opt/rocm/bin/rocm-smi") +option(USE_CUDA "Use CUDA" ON) +option(USE_ROCM "Use ROCm" OFF) + +if(((EXISTS "/opt/rocm/") OR (EXISTS $ENV{ROCM_PATH})) + AND NOT (EXISTS "/bin/nvcc")) message("AMD GPU detected.") - option(USE_CUDA "Use CUDA" OFF) - option(USE_ROCM "Use ROCm" ON) -else() - message("Unable to detect GPU vendor") - message(FATAL_ERROR "") + set(USE_CUDA OFF) + set(USE_ROCM ON) endif() if(FBGEMM_CPU_ONLY) @@ -28,6 +25,7 @@ if(FBGEMM_CPU_ONLY) endif() message("${message_line}") +message(STATUS "USE_ROCM ${USE_ROCM}") if(FBGEMM_CPU_ONLY OR USE_ROCM) project( @@ -41,25 +39,20 @@ else() LANGUAGES CXX C CUDA) endif() -if(USE_CUDA) - set(default_cuda_architectures 60 61 70 75 80) - set(cuda_architectures_doc - "CUDA architectures to build for. Default is ${default_cuda_architectures}") - set(cuda_architectures - "${default_cuda_architectures}" - CACHE STRING "${cuda_architectures_doc}") +find_package(Torch REQUIRED) +find_package(PythonExtensions REQUIRED) - message("${message_line}") - message("fbgemm_gpu:") - message("Building for cuda_architectures = \"${cuda_architectures}\"") - message("${message_line}") +set(FBGEMM ${CMAKE_CURRENT_SOURCE_DIR}/..) 
+set(THIRDPARTY ${FBGEMM}/third_party) if(DEFINED GLIBCXX_USE_CXX11_ABI) if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1) set(CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") - message("${CMAKE_CXX_FLAGS}") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") endif() + message("${CMAKE_CXX_FLAGS}") endif() # @@ -68,49 +61,22 @@ endif() # constructor exists to convert from "int" to "__half" errors in # gen_embedding_forward_quantized_split_[un]weighted_codegen_cuda.cu # - set(TORCH_CUDA_OPTIONS - --expt-relaxed-constexpr - -D__CUDA_NO_HALF_OPERATORS__ - # -D__CUDA_NO_HALF_CONVERSIONS__ - -D__CUDA_NO_BFLOAT16_CONVERSIONS__ - -D__CUDA_NO_HALF2_OPERATORS__) -endif() -find_package(Torch REQUIRED) -find_package(PythonExtensions REQUIRED) - -set(FBGEMM ${CMAKE_CURRENT_SOURCE_DIR}/..) -set(THIRDPARTY ${FBGEMM}/third_party) +set(TORCH_CUDA_OPTIONS + --expt-relaxed-constexpr -D__CUDA_NO_HALF_OPERATORS__ + # -D__CUDA_NO_HALF_CONVERSIONS__ + -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__) if(USE_ROCM) - if(NOT DEFINED ENV{PYTORCH_ROCM_ARCH}) - SET(FBGEMM_ROCM_ARCH gfx900;gfx906;gfx908;gfx90a) - else() - SET(FBGEMM_ROCM_ARCH $ENV{PYTORCH_ROCM_ARCH}) - endif() - - list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake" "${THIRDPARTY}/hipify_torch/cmake") + list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake" + "${THIRDPARTY}/hipify_torch/cmake") include(Hip) - if(NOT FBGEMM_HAVE_HIP) - message(FATAL_ERROR "Not able to find HIP installation.") - endif() include(Hipify) - list (APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm) - set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH}) - - find_package(rocBLAS REQUIRED) - find_package(hipFFT REQUIRED) - find_package(hipRAND REQUIRED) - find_package(rocRAND REQUIRED) - find_package(hipSPARSE REQUIRED) - find_package(OpenMP REQUIRED) - find_package(rocPRIM REQUIRED) - + message("${message_line}") - message(STATUS "hip found ${ROCM_FOUND}") + message(STATUS "hip found ${HIP_FOUND}") endif() - # # GENERATED CUDA, CPP and Python code # @@ -197,49 +163,41 @@ set(codegen_dependencies ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_backward_template_helpers.cuh ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_common.h ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/fbgemm_cuda_utils.cuh - ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/quantize_ops_gpu.h ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/quantize_ops_utils.h ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/split_embeddings_utils.cuh - ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/sparse_ops_utils.h -) + ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/sparse_ops_utils.h) -if(USE_CUDA) - add_custom_command( - OUTPUT ${gen_cpu_source_files} ${gen_gpu_source_files} - ${gen_gpu_host_source_files} ${gen_python_files} +if(USE_ROCM) + execute_process( COMMAND "${PYTHON_EXECUTABLE}" "${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py" - "--opensource" - DEPENDS "${codegen_dependencies}") - - set_source_files_properties( - ${gen_cpu_source_files} PROPERTIES COMPILE_OPTIONS - "-mavx2;-mf16c;-mfma;-fopenmp") -elseif(USE_ROCM) - execute_process( - COMMAND - "${PYTHON_EXECUTABLE}" - "${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py" - "--opensource") + "--opensource" DEPENDS "${codegen_dependencies}") set(header_include_dir - ${CMAKE_CURRENT_SOURCE_DIR}/include - ${CMAKE_CURRENT_SOURCE_DIR}/src - ${CMAKE_CURRENT_SOURCE_DIR} - ) - hipify(CUDA_SOURCE_DIR ${PROJECT_SOURCE_DIR} HEADER_INCLUDE_DIR 
${header_include_dir}) + ${CMAKE_CURRENT_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}/src + ${CMAKE_CURRENT_SOURCE_DIR}) - set_source_files_properties( - ${gen_cpu_source_files} PROPERTIES COMPILE_OPTIONS - "-mavx2;-mf16c;-mfma") + hipify(CUDA_SOURCE_DIR ${PROJECT_SOURCE_DIR} HEADER_INCLUDE_DIR + ${header_include_dir}) +else() + add_custom_command( + OUTPUT ${gen_cpu_source_files} ${gen_gpu_source_files} + ${gen_gpu_host_source_files} ${gen_python_files} + COMMAND + "${PYTHON_EXECUTABLE}" + "${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py" "--opensource" + DEPENDS "${codegen_dependencies}") endif() +set_source_files_properties( + ${gen_cpu_source_files} PROPERTIES COMPILE_OPTIONS + "-mavx2;-mf16c;-mfma;-fopenmp") set_source_files_properties( ${gen_cpu_source_files} PROPERTIES INCLUDE_DIRECTORIES - "${CMAKE_CURRENT_SOURCE_DIR};${CMAKE_CURRENT_SOURCE_DIR}/include;${CMAKE_CURRENT_SOURCE_DIR}/../include" + "${CMAKE_CURRENT_SOURCE_DIR};${CMAKE_CURRENT_SOURCE_DIR}/include;${CMAKE_CURRENT_SOURCE_DIR}/../include;${THIRDPARTY}/asmjit/src" ) set_source_files_properties( @@ -257,8 +215,8 @@ set_source_files_properties(${gen_gpu_source_files} PROPERTIES COMPILE_OPTIONS "${TORCH_CUDA_OPTIONS}") if(NOT FBGEMM_CPU_ONLY) - set(gen_source_files ${gen_gpu_source_files} - ${gen_gpu_host_source_files} ${gen_cpu_source_files}) + set(gen_source_files ${gen_gpu_source_files} ${gen_gpu_host_source_files} + ${gen_cpu_source_files}) else() set(gen_source_files ${gen_cpu_source_files}) endif() @@ -285,14 +243,18 @@ set(cpp_fbgemm_files_avx2 "../src/EmbeddingSpMDMAvx2.cc" set_source_files_properties(${cpp_fbgemm_files_avx2} PROPERTIES COMPILE_OPTIONS "-mavx2;-mf16c;-mfma") -set(cpp_fbgemm_files ${cpp_fbgemm_files_normal} ${cpp_fbgemm_files_avx2}) set(cpp_fbgemm_files_avx512 "../src/EmbeddingSpMDMAvx512.cc") -if(USE_CUDA) - set_source_files_properties( - ${cpp_fbgemm_files_avx512} - PROPERTIES COMPILE_OPTIONS - "-mavx2;-mf16c;-mfma;-mavx512f;-mavx512bw;-mavx512dq;-mavx512vl") - list(APPEND cpp_fbgemm_files ${cpp_fbgemm_files_avx512}) + +set_source_files_properties( + ${cpp_fbgemm_files_avx512} + PROPERTIES COMPILE_OPTIONS + "-mavx2;-mf16c;-mfma;-mavx512f;-mavx512bw;-mavx512dq;-mavx512vl") + +if(USE_ROCM) + set(cpp_fbgemm_files ${cpp_fbgemm_files_normal} ${cpp_fbgemm_files_avx2}) +else() + set(cpp_fbgemm_files ${cpp_fbgemm_files_normal} ${cpp_fbgemm_files_avx2} + ${cpp_fbgemm_files_avx512}) endif() set(cpp_fbgemm_files_include_directories @@ -307,18 +269,15 @@ set_source_files_properties( # Actual static SOURCES # -# Ensure NVML_LIB_PATH is empty if it wasn't set and if the -# default lib path doesn't exist. +# Ensure NVML_LIB_PATH is empty if it wasn't set and if the default lib path +# doesn't exist. 
if(NOT NVML_LIB_PATH) set(DEFAULT_NVML_LIB_PATH - "${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so") + "${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so") if(EXISTS ${DEFAULT_NVML_LIB_PATH}) - message( - STATUS - "Setting NVML_LIB_PATH: \ - ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so" - ) + message(STATUS "Setting NVML_LIB_PATH: \ + ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so") set(NVML_LIB_PATH "${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so") endif() endif() @@ -336,7 +295,9 @@ set(fbgemm_gpu_sources_cpu src/sparse_ops_cpu.cpp) if(NOT FBGEMM_CPU_ONLY) - list(APPEND fbgemm_gpu_sources_cpu + list( + APPEND + fbgemm_gpu_sources_cpu codegen/embedding_forward_quantized_host.cpp codegen/embedding_backward_dense_host.cpp codegen/embedding_bounds_check_host.cpp @@ -347,33 +308,37 @@ if(NOT FBGEMM_CPU_ONLY) src/permute_pooled_embedding_ops_split_cpu.cpp src/quantize_ops_gpu.cpp src/sparse_ops_gpu.cpp - src/split_table_batched_embeddings.cpp) + src/split_table_batched_embeddings.cpp + src/metric_ops_host.cpp) - if(NVML_LIB_PATH) - list(APPEND fbgemm_gpu_sources_cpu - src/merge_pooled_embeddings_cpu.cpp - src/merge_pooled_embeddings_gpu.cpp) - endif() + if(NVML_LIB_PATH) + list(APPEND fbgemm_gpu_sources_cpu src/merge_pooled_embeddings_cpu.cpp + src/merge_pooled_embeddings_gpu.cpp) + endif() endif() -set(fbgemm_gpu_sources_cpu_option "-mavx;-mf16c;-mfma;-mavx2") -if(USE_CUDA) - set_source_files_properties( - ${fbgemm_gpu_sources_cpu} PROPERTIES COMPILE_OPTIONS - "${fbgemm_gpu_sources_cpu_option};-fopenmp") -endif() +set_source_files_properties( + ${fbgemm_gpu_sources_cpu} PROPERTIES COMPILE_OPTIONS + "-mavx;-mf16c;-mfma;-mavx2;-fopenmp") if(NOT FBGEMM_CPU_ONLY) set(fbgemm_gpu_sources_gpu - codegen/embedding_bounds_check.cu src/cumem_utils.cu - src/histogram_binning_calibration_ops.cu src/jagged_tensor_ops.cu - src/layout_transform_ops.cu src/permute_pooled_embedding_ops.cu - src/permute_pooled_embedding_ops_split.cu - src/quantize_ops.cu src/sparse_ops.cu src/split_embeddings_cache_cuda.cu - src/split_embeddings_utils.cu) - - set_source_files_properties(${fbgemm_gpu_sources_gpu} - PROPERTIES COMPILE_OPTIONS "${TORCH_CUDA_OPTIONS}") + codegen/embedding_bounds_check.cu + src/cumem_utils.cu + src/histogram_binning_calibration_ops.cu + src/jagged_tensor_ops.cu + src/layout_transform_ops.cu + src/permute_pooled_embedding_ops.cu + src/permute_pooled_embedding_ops_split.cu + src/quantize_ops.cu + src/sparse_ops.cu + src/split_embeddings_cache_cuda.cu + src/split_embeddings_utils.cu + src/metric_ops.cu) + + set_source_files_properties( + ${fbgemm_gpu_sources_gpu} PROPERTIES COMPILE_OPTIONS + "${TORCH_CUDA_OPTIONS}") # XXXUPS!!! 
Replace with real set_source_files_properties( @@ -394,38 +359,47 @@ endif() if(USE_ROCM) set(abspath_gen_source_files) foreach(filename_gen_source_file ${gen_source_files}) - list(APPEND abspath_gen_source_files "${CMAKE_BINARY_DIR}/${filename_gen_source_file}") + list(APPEND abspath_gen_source_files + "${CMAKE_BINARY_DIR}/${filename_gen_source_file}") endforeach() endif() -if(USE_CUDA) - add_library(fbgemm_gpu_py MODULE ${fbgemm_gpu_sources} ${gen_source_files} - ${cpp_asmjit_files} ${cpp_fbgemm_files}) - set_property(TARGET fbgemm_gpu_py PROPERTY CUDA_ARCHITECTURES - "${cuda_architectures}") - if(NOT FBGEMM_CPU_ONLY) - target_compile_definitions(fbgemm_gpu_py PRIVATE FBGEMM_CUB_USE_NAMESPACE) - endif() - set_property(TARGET fbgemm_gpu_py PROPERTY CXX_STANDARD 17) -elseif(USE_ROCM) +# +# MODULE +# + +if(USE_ROCM) get_hipified_list("${fbgemm_gpu_sources}" fbgemm_gpu_sources) get_hipified_list("${abspath_gen_source_files}" abspath_gen_source_files) get_hipified_list("${cpp_fbgemm_files}" cpp_fbgemm_files) - set(FBGEMM_ALL_HIP_FILES ${fbgemm_gpu_sources} ${abspath_gen_source_files} ${cpp_fbgemm_files}) - set_source_files_properties(${FBGEMM_ALL_HIP_FILES} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + set(FBGEMM_ALL_HIP_FILES ${fbgemm_gpu_sources} ${abspath_gen_source_files} + ${cpp_fbgemm_files}) + set_source_files_properties(${FBGEMM_ALL_HIP_FILES} + PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) hip_include_directories("${cpp_fbgemm_files_include_directories}") - - hip_add_library(fbgemm_gpu_py SHARED ${cpp_asmjit_files} ${FBGEMM_ALL_HIP_FILES} ${FBGEMM_HIP_HCC_LIBRARIES} - HIPCC_OPTIONS ${HIP_HCC_FLAGS}) - target_include_directories(fbgemm_gpu_py PUBLIC ${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE}) -endif() -list (GET TORCH_INCLUDE_DIRS 0 TORCH_PATH) -if(EXISTS "${TORCH_PATH}/ATen/cuda/CUDAGeneratorImpl.h") - target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_GENERATOR_PATH) -endif() -if(EXISTS "${TORCH_PATH}/ATen/cuda/Atomic.cuh") - target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_ATOMIC_PATH) + + hip_add_library( + fbgemm_gpu_py + SHARED + ${cpp_asmjit_files} + ${FBGEMM_ALL_HIP_FILES} + ${FBGEMM_HIP_HCC_LIBRARIES} + HIPCC_OPTIONS + ${HIP_HCC_FLAGS}) + target_include_directories( + fbgemm_gpu_py PUBLIC ${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} + ${ROCM_SMI_INCLUDE}) + list(GET TORCH_INCLUDE_DIRS 0 TORCH_PATH) +else() + add_library(fbgemm_gpu_py MODULE ${fbgemm_gpu_sources} ${gen_source_files} + ${cpp_asmjit_files} ${cpp_fbgemm_files}) + set_property(TARGET fbgemm_gpu_py PROPERTY CUDA_ARCHITECTURES + "${cuda_architectures}") + + if(NOT FBGEMM_CPU_ONLY) + target_compile_definitions(fbgemm_gpu_py PRIVATE FBGEMM_CUB_USE_NAMESPACE) + endif() endif() set_target_properties(fbgemm_gpu_py PROPERTIES PREFIX "") @@ -435,9 +409,7 @@ if(NVML_LIB_PATH) target_link_libraries(fbgemm_gpu_py ${NVML_LIB_PATH}) endif() target_include_directories(fbgemm_gpu_py PRIVATE ${TORCH_INCLUDE_DIRS}) -if(USE_CUDA) - set_property(TARGET fbgemm_gpu_py PROPERTY CXX_STANDARD 17) -endif() +set_property(TARGET fbgemm_gpu_py PROPERTY CXX_STANDARD 17) install(TARGETS fbgemm_gpu_py DESTINATION fbgemm_gpu) diff --git a/fbgemm_gpu/README.md b/fbgemm_gpu/README.md index 680b55f14..5f95d8c93 100644 --- a/fbgemm_gpu/README.md +++ b/fbgemm_gpu/README.md @@ -143,6 +143,11 @@ cd ../bench python split_table_batched_embeddings_benchmark.py ``` +To run the tests and benchmarks on a GPU-capable device in CPU-only mode use CUDA_VISIBLE_DEVICES=-1 +``` +CUDA_VISIBLE_DEVICES=-1 python split_table_batched_embeddings_test.py +``` 
+ ## How FBGEMM_GPU works For a high-level overview, design philosophy and brief descriptions of various parts of FBGEMM_GPU please see our Wiki (work in progress). @@ -151,6 +156,10 @@ parts of FBGEMM_GPU please see our Wiki (work in progress). We have extensively used comments in our source files. The best and up-do-date documentation is available in the source files. +# Building API Documentation + +See [docs/README.md](docs/README.md). + ## Join the FBGEMM community See the [`CONTRIBUTING`](../CONTRIBUTING.md) file for how to help out. diff --git a/fbgemm_gpu/bench/README.md b/fbgemm_gpu/bench/README.md new file mode 100644 index 000000000..4588ff4a9 --- /dev/null +++ b/fbgemm_gpu/bench/README.md @@ -0,0 +1,7 @@ +# Benchmarks + +## TorchRec FusedTableBatchedEmbeddingBags + +[Torchrec](https://pytorch.org/torchrec/) uses fbgemm_gpu embedding and embedding bag implementations for Fused, Batched, Quantized versions of embedding and embeddingbag (in addition to other kernels). +They have run benchmarks on FusedEmbeddingBagCollection, which is implemented with fbgemm_gpu's [`SplitTableBatchedEmbeddingBagsCodegen`](https://github.com/pytorch/FBGEMM/blob/253b8842eeb2b33e65f7e2a7cfb79923b0e46bd7/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py#L171). They benchmark utilizing UVM and UVM-caching. +The [results](https://github.com/pytorch/torchrec/tree/main/benchmarks) show between a 13x and 23x speedup for DLRM embedding sizes. diff --git a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py index 14dc9c359..08ca3f554 100644 --- a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py @@ -121,7 +121,7 @@ def main(batch_size, num_tables, num_tasks, repeats) -> None: param.detach().copy_(ref_emb.emb_modules[i].weight) output_ref = ref_emb(offsets, indices) output = unary_emb(offsets_tensor, indices_tensor) - torch.testing.assert_allclose(output_ref, output) + torch.testing.assert_close(output_ref, output) # backward d_output = torch.randn([num_tasks, batch_size, len(hash_sizes)]).to(device) * 0.1 output_ref.backward(d_output) @@ -131,7 +131,8 @@ def main(batch_size, num_tables, num_tasks, repeats) -> None: d_weight_ref.append(emb.weight.grad) d_weight_ref = torch.cat(d_weight_ref).view(num_tasks, -1) d_weight = unary_emb.weight.grad - torch.testing.assert_allclose(d_weight_ref, d_weight.squeeze()) + # pyre-fixme[16]: Optional type has no attribute `squeeze`. + torch.testing.assert_close(d_weight_ref, d_weight.squeeze()) # A100 40MB L2 cache elapse, _ = benchmark_torch_function(ref_emb, (offsets, indices), iters=repeats) diff --git a/fbgemm_gpu/bench/bench_utils.py b/fbgemm_gpu/bench/bench_utils.py index d9abf3626..6edf200bb 100644 --- a/fbgemm_gpu/bench/bench_utils.py +++ b/fbgemm_gpu/bench/bench_utils.py @@ -3,12 +3,22 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import itertools +import logging +import statistics import time -from typing import Tuple +from typing import Callable, List, Optional, Tuple +import numpy as np import torch +from fbgemm_gpu.split_table_batched_embeddings_ops import SparseType + +# pyre-fixme[21]: Could not find name `default_rng` in `numpy.random` (stubbed). +from numpy.random import default_rng from torch import Tensor +logging.basicConfig(level=logging.DEBUG) + def benchmark_torch_function( # pyre-fixme[2]: Parameter must be annotated.
@@ -46,3 +56,365 @@ def benchmark_torch_function( # pyre-fixme[61]: `output` is undefined, or not always defined. return float(elapsed_time) / iters, output + + +def round_up(a: int, b: int) -> int: + return int((a + b - 1) // b) * b + + +def get_device() -> torch.device: + # pyre-fixme[7]: Expected `device` but got `Union[int, device]`. + return ( + torch.cuda.current_device() + if torch.cuda.is_available() + else torch.device("cpu") + ) + + +# Merged indices with shape (T, B, L) -> (flattened indices with shape +# (T * B * L), offsets with shape (T * B + 1)) +def get_table_batched_offsets_from_dense( + merged_indices: Tensor, +) -> Tuple[Tensor, Tensor]: + (T, B, L) = merged_indices.size() + lengths = np.ones((T, B)) * L + flat_lengths = lengths.flatten() + return ( + merged_indices.long().contiguous().view(-1).to(get_device()), + torch.tensor(([0] + np.cumsum(flat_lengths).tolist())).long().to(get_device()), + ) + + +def get_offsets_from_dense(indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + (B, L) = indices.size() + return ( + indices.contiguous().view(-1), + torch.tensor( + np.cumsum(np.asarray([0] + [L for _ in range(B)])[:-1]).astype(np.int64) + ), + ) + + +def b_indices( + b: Callable[..., torch.Tensor], + x: torch.Tensor, + per_sample_weights: Optional[torch.Tensor] = None, + use_cpu: bool = False, + do_pooling: bool = True, +) -> torch.Tensor: + (indices, offsets) = get_offsets_from_dense(x) + if do_pooling: + return b( + indices.cuda(), + offsets.cuda(), + per_sample_weights=per_sample_weights, + ) + else: + return b(indices.cuda()) + + +def generate_requests( + iters: int, + B: int, + T: int, + L: int, + E: int, + # inter-batch indices reuse rate + reuse: float = 0.0, + # alpha <= 1.0: use uniform distribution + # alpha > 1.0: use zipf distribution + alpha: float = 1.0, + weights_precision: SparseType = SparseType.FP32, + weighted: bool = False, + requests_data_file: Optional[str] = None, + # Comma-separated list of table numbers + tables: Optional[str] = None, +) -> List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]]: + if requests_data_file is not None: + indices_tensor, offsets_tensor, lengths_tensor = torch.load(requests_data_file) + + average_L = 0 + if tables is not None: + emb_tables = tuple(int(x) for x in tables.split(",")) + indices = torch.zeros(0, dtype=indices_tensor.dtype) + offsets = torch.zeros(1, dtype=offsets_tensor.dtype) + total_L = 0 + for t in emb_tables: + t_offsets = offsets_tensor[B * t : B * (t + 1) + 1] + total_L += t_offsets[-1] - t_offsets[0] + indices = torch.cat( + (indices, indices_tensor[t_offsets[0] : t_offsets[-1]]) + ) + offsets = torch.cat( + ( + offsets, + t_offsets[1:] - t_offsets[0] + offsets[-1], + ) + ) + indices_tensor = indices + offsets_tensor = offsets + average_L = int(total_L / B) + + assert np.prod(offsets_tensor.size()) - 1 == np.prod((T, B)), ( + f"Requested tables: {emb_tables} " + f"does not conform to inputs (T, B) = ({T}, {B})." + ) + logging.warning( + f"Using (indices = {indices_tensor.size()}, offsets = {offsets_tensor.size()}) based " + f"on tables: {emb_tables}" + ) + else: + average_L = int((offsets_tensor[-1] - offsets_tensor[0]) / B) + assert (np.prod(offsets_tensor.size()) - 1) == np.prod((T, B)), ( + f"Data file (indices = {indices_tensor.size()}, " + f"offsets = {offsets_tensor.size()}, lengths = {lengths_tensor.size()}) " + f"does not conform to inputs (T, B) = ({T}, {B})." + ) + + assert ( + L == average_L + ), f"Requested L does not align with provided data file ({L} vs. 
{average_L})" + assert E > max(indices_tensor), ( + f"Number of embeddings is not enough to support maximum index " + f"provided by data file {E} vs. {max(indices_tensor)}" + ) + + weights_tensor = ( + None + if not weighted + else torch.randn(indices_tensor.size(), device=get_device()) + ) + rs = [] + for _ in range(iters): + rs.append( + ( + indices_tensor.to(get_device()), + offsets_tensor.to(get_device()), + weights_tensor, + ) + ) + return rs + + if alpha <= 1.0: + all_indices = torch.randint( + low=0, + high=E, + size=(iters, T, B, L), + device=get_device(), + dtype=torch.int32, + ) + # each bag is usually sorted + (all_indices, _) = torch.sort(all_indices) + all_indices = all_indices.reshape(iters, T, B * L) + else: + assert E >= L, "num-embeddings must be greater than equal to bag-size" + # oversample and then remove duplicates to obtain sampling without + # replacement + all_indices = (np.random.zipf(a=alpha, size=(iters, T, B, 3 * L)) - 1) % E + all_indices = torch.ops.fbgemm.bottom_unique_k_per_row( + torch.as_tensor(all_indices), L + ) + rng = default_rng() + permutation = torch.as_tensor( + rng.choice(E, size=all_indices.max().item() + 1, replace=False) + ) + all_indices = permutation.gather(0, all_indices.flatten()) + all_indices = all_indices.to(get_device()).int().reshape(iters, T, B * L) + for it in range(iters - 1): + for t in range(T): + reused_indices = torch.randperm(B * L, device=get_device())[ + : int(B * L * reuse) + ] + all_indices[it + 1, t, reused_indices] = all_indices[it, t, reused_indices] + + rs = [] + for it in range(iters): + weights_tensor = ( + None if not weighted else torch.randn(T * B * L, device=get_device()) + ) + rs.append( + get_table_batched_offsets_from_dense(all_indices[it].view(T, B, L)) + + (weights_tensor,) + ) + return rs + + +def benchmark_requests( + requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]], + func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor], + flush_gpu_cache_size_mb: int = 0, + check_median: bool = False, + num_warmups: int = 0, +) -> float: + times = [] + + if num_warmups > 0: + indices, offsets, weights = requests[0] + for _ in range(num_warmups): + func(indices, offsets, weights) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + for (indices, offsets, weights) in requests: + start_time = time.time() + if torch.cuda.is_available(): + if flush_gpu_cache_size_mb: + _ = torch.rand( + flush_gpu_cache_size_mb * 1024 * 1024 // 4, dtype=torch.float + ) + torch.cuda.synchronize() + start_event.record() + func(indices, offsets, weights) + if torch.cuda.is_available(): + end_event.record() + torch.cuda.synchronize() + it_time = start_event.elapsed_time(end_event) * 1.0e-3 + times.append(it_time) + else: + it_time = time.time() - start_time + times.append(it_time) + avg_time = sum(times) / len(requests) + median_time = statistics.median(times) + return median_time if check_median else avg_time + + +def benchmark_requests_refer( + requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]], + T: int, + B: int, + L: int, + E: int, + D: int, + pooling_mode: str, + weighted: bool, + flush_gpu_cache_size_mb: int = 0, + check_median: bool = False, +) -> float: + do_pooling = pooling_mode in ["sum", "mean"] + if do_pooling: + nn_embedding_list = [ + torch.nn.EmbeddingBag(E, D, mode=pooling_mode, sparse=True).cuda() + ] * T + else: + nn_embedding_list = [torch.nn.Embedding(E, D, 
sparse=True).cuda()] * T + + times = [] + if torch.cuda.is_available(): + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + for (indices, _, weights) in requests: + indices_list = indices.view(T, B, L).split(1) + + if weighted: + assert weights is not None + weights_list = weights.view(T, B, L).split(1) + + start_time = time.time() + if torch.cuda.is_available(): + if flush_gpu_cache_size_mb: + _ = torch.rand( + flush_gpu_cache_size_mb * 1024 * 1024 // 4, dtype=torch.float + ) + torch.cuda.synchronize() + start_event.record() + + nn_embedding_output = ( + [ + b_indices(nn_embedding, x, use_cpu=False, do_pooling=do_pooling) + for (nn_embedding, x) in zip(nn_embedding_list, indices_list) + ] + if not weighted + else [ + b_indices( + nn_embedding, + x, + per_sample_weights=xw.view(-1), + use_cpu=False, + do_pooling=do_pooling, + ) + for (nn_embedding, x, xw) in zip( + nn_embedding_list, + indices_list, + # pyre-fixme[61]: `weights_list` is undefined, or not always + # defined. + weights_list, + ) + ] + ) + if do_pooling: + final_output = torch.cat( + [f.view(B, -1) for f in nn_embedding_output], dim=1 + ) + else: + final_output = torch.cat(nn_embedding_output, dim=0).view(-1, D) + + if torch.cuda.is_available(): + end_event.record() + torch.cuda.synchronize() + it_time = start_event.elapsed_time(end_event) * 1.0e-3 + times.append(it_time) + else: + it_time = time.time() - start_time + times.append(it_time) + avg_time = sum(times) / len(requests) + median_time = statistics.median(times) + return median_time if check_median else avg_time + + +def benchmark_pipelined_requests( + requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]], + func1: Callable[[Tensor, Tensor, Optional[Tensor]], None], + func2: Callable[[Tensor, Tensor, Optional[Tensor]], None], + flush_gpu_cache_size_mb: int = 0, + check_median: bool = False, +) -> Tuple[float, float]: + torch.cuda.synchronize() + start_events = [ + (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) + for _ in requests + ] + end_events = [ + (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) + for _ in requests + ] + for ((indices, offsets, indices_weights), start_event, end_event) in zip( + requests, start_events, end_events + ): + if flush_gpu_cache_size_mb: + _ = torch.rand( + flush_gpu_cache_size_mb * 1024 * 1024 // 4, dtype=torch.float + ) + torch.cuda.synchronize() + start_event[0].record() + func1(indices, offsets, indices_weights) + end_event[0].record() + start_event[1].record() + func2(indices, offsets, indices_weights) + end_event[1].record() + torch.cuda.synchronize() + avg_time = ( + sum( + start_event[0].elapsed_time(end_event[0]) * 1.0e-3 + for start_event, end_event in zip(start_events, end_events) + ) + / len(requests), + sum( + start_event[1].elapsed_time(end_event[1]) * 1.0e-3 + for start_event, end_event in zip(start_events, end_events) + ) + / len(requests), + ) + median_time = ( + statistics.median( + start_event[0].elapsed_time(end_event[0]) * 1.0e-3 + for start_event, end_event in zip(start_events, end_events) + ), + statistics.median( + start_event[1].elapsed_time(end_event[1]) * 1.0e-3 + for start_event, end_event in zip(start_events, end_events) + ), + ) + return median_time if check_median else avg_time diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py index 95075597f..64e7177f3 100644 --- 
a/fbgemm_gpu/bench/jagged_tensor_benchmark.py +++ b/fbgemm_gpu/bench/jagged_tensor_benchmark.py @@ -50,6 +50,7 @@ def device( else torch.float32 ) + # pyre-fixme[6]: For 1st param expected `int` but got `Union[bool, float, int]`. values_2d = torch.rand(total_lengths, embedding_dim, dtype=dtype) if torch.cuda.is_available(): @@ -57,8 +58,7 @@ def device( values_2d = values_2d.cuda() time, output = benchmark_torch_function( - torch.ops.fbgemm.jagged_2d_to_dense, - (values_2d, offsets, max_len), + torch.ops.fbgemm.jagged_2d_to_dense, (values_2d, offsets, max_len), iters=1000 ) num_bytes = ( @@ -68,6 +68,19 @@ def device( ) logging.info(f"jagged_2d_to_dense {time} sec {num_bytes / time / 1e9} GB/s") + total_L = values_2d.size(0) + time, jagged_output = benchmark_torch_function( + torch.ops.fbgemm.dense_to_jagged, (output, [offsets], total_L), iters=1000 + ) + + # Recompute num_bytes to exclude the entire dense tensor + num_bytes = offsets.numel() * offsets.element_size() + 2 * ( + values_2d.numel() * values_2d.element_size() + ) + logging.info(f"dense_to_jagged (2d) {time} sec {num_bytes / time / 1e9} GB/s") + + # pyre-fixme[6]: For 1st param expected `Union[List[int], Size, + # typing.Tuple[int, ...]]` but got `Union[bool, float, int]`. values_1d = torch.rand(total_lengths) if torch.cuda.is_available(): values_1d = values_1d.cuda() @@ -77,6 +90,7 @@ def device( values_1d, offsets, max_len, padding_value=0 ), (), + iters=1000, ) num_bytes = ( @@ -86,6 +100,18 @@ def device( ) logging.info(f"jagged_1d_to_dense {time} sec {num_bytes / time / 1e9} GB/s") + total_L = values_1d.size(0) + output_1d = torch.unsqueeze(output, -1) + time, jagged_output = benchmark_torch_function( + torch.ops.fbgemm.dense_to_jagged, (output_1d, [offsets], total_L), iters=1000 + ) + + # Recompute num_bytes to exclude the entire dense tensor + num_bytes = offsets.numel() * offsets.element_size() + 2 * ( + values_1d.numel() * values_1d.element_size() + ) + logging.info(f"dense_to_jagged (1d) {time} sec {num_bytes / time / 1e9} GB/s") + if __name__ == "__main__": cli() diff --git a/fbgemm_gpu/bench/merge_embeddings_benchmark.py b/fbgemm_gpu/bench/merge_embeddings_benchmark.py index 0a820da14..8b56af31a 100644 --- a/fbgemm_gpu/bench/merge_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/merge_embeddings_benchmark.py @@ -9,7 +9,7 @@ import logging import signal -from typing import Tuple, List +from typing import List, Tuple import click import fbgemm_gpu @@ -26,14 +26,19 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings") + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu" + ) + from fbgemm_gpu.split_table_batched_embeddings_ops import ( - SparseType, BoundsCheckMode, - IntNBitTableBatchedEmbeddingBagsCodegen, EmbeddingLocation, + IntNBitTableBatchedEmbeddingBagsCodegen, + SparseType, ) -from torch.profiler import ProfilerActivity, profile +from torch.profiler import profile, ProfilerActivity def get_gpu_device(gpu_num) -> torch.device: diff --git a/fbgemm_gpu/bench/quantize_ops_benchmark.py b/fbgemm_gpu/bench/quantize_ops_benchmark.py index f15e1417c..e4f01ee7f 100644 --- a/fbgemm_gpu/bench/quantize_ops_benchmark.py +++ b/fbgemm_gpu/bench/quantize_ops_benchmark.py @@ -33,17 +33,7 @@ def cli() -> None: pass -@cli.command() -@click.option("--flush-gpu-cache-size-mb", default=0) -@click.option("--iters", default=100) -@click.option("--warmup-runs", default=2) -@settings(max_examples=10,
deadline=None) -# pyre-ignore -@given( - num_columns=st.sampled_from([2 ** n for n in range(4, 10)]), - num_rows=st.sampled_from([2 ** n for n in range(4, 10)]), -) -def bench( +def bench_impl( flush_gpu_cache_size_mb: int, iters: int, num_columns: int, @@ -57,11 +47,17 @@ def bench( "int2_quant": 0.0, "fp8_143_quant": 0.0, "fp8_152_quant": 0.0, + "fp16_quant": 0.0, + "bf16_quant_fbgemm": 0.0, + "bf16_quant_pytorch": 0.0, "int8_dequant": 0.0, "int4_dequant": 0.0, "int2_dequant": 0.0, "fp8_143_dequant": 0.0, "fp8_152_dequant": 0.0, + "fp16_dequant": 0.0, + "bf16_dequant_fbgemm": 0.0, + "bf16_dequant_pytorch": 0.0, } benchmark = functools.partial( @@ -72,6 +68,8 @@ def bench( ) input_data = torch.rand(num_rows, num_columns).float() + if torch.cuda.is_available(): + input_data = input_data.cuda() quant_data_8bit = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(input_data) quant_data_4bit = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf( @@ -86,9 +84,11 @@ def bench( quant_data_fp8_152 = torch.ops.fbgemm.FloatToHFP8Quantized( input_data, 5, 30, (2 - 2 ** (-2)) ) - - if torch.cuda.is_available(): - input_data = input_data.cuda() + quant_data_fp16 = input_data.half() + quant_data_bf16_fbgemm = torch.ops.fbgemm.FloatToBfloat16Quantized( + input_data.contiguous() + ) + quant_data_bf16_pytorch = input_data.bfloat16().view(torch.half) average_time["int8_quant"], _ = benchmark( torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized, @@ -98,7 +98,6 @@ def bench( torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf, (input_data, 4), ) - average_time["int2_quant"], _ = benchmark( torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf, (input_data, 2), @@ -111,12 +110,23 @@ def bench( torch.ops.fbgemm.FloatToHFP8Quantized, (input_data, 5, 30, (2 - 2 ** (-2))), ) + average_time["fp16_quant"], _ = benchmark( + lambda tensor: tensor.half(), + (input_data,), + ) + average_time["bf16_quant_fbgemm"], _ = benchmark( + torch.ops.fbgemm.FloatToBfloat16Quantized, + (input_data,), + ) + average_time["bf16_quant_pytorch"], _ = benchmark( + lambda tensor: tensor.bfloat16().view(torch.half), + (input_data,), + ) average_time["int8_dequant"], _ = benchmark( torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat, (quant_data_8bit,), ) - average_time["int4_dequant"], _ = benchmark( torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloat, (quant_data_4bit, 4), @@ -133,11 +143,75 @@ def bench( torch.ops.fbgemm.HFP8QuantizedToFloat, (quant_data_fp8_152, 5, 30), ) + average_time["fp16_dequant"], _ = benchmark( + lambda tensor: tensor.float(), + (quant_data_fp16,), + ) + average_time["bf16_dequant_fbgemm"], _ = benchmark( + torch.ops.fbgemm.Bfloat16QuantizedToFloat, + (quant_data_bf16_fbgemm,), + ) + average_time["bf16_dequant_pytorch"], _ = benchmark( + lambda tensor: tensor.view(torch.bfloat16).float(), + (quant_data_bf16_pytorch,), + ) + logging.info(f"-------------- ncols={num_columns}, nrows={num_rows}-------------") for k, t_time in average_time.items(): logging.info(f"{k} time per iter: {t_time * 1.0e6:.0f}us") +@settings(max_examples=10, deadline=None) +# pyre-ignore +@given( + num_columns=st.sampled_from([2**n for n in range(4, 10)]), + num_rows=st.sampled_from([2**n for n in range(4, 10)]), +) +def bench_spectrum( + flush_gpu_cache_size_mb: int, + iters: int, + num_columns: int, + num_rows: int, + warmup_runs: int, +) -> None: + bench_impl( + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + iters=iters, + num_columns=num_columns, + num_rows=num_rows, + warmup_runs=warmup_runs, + ) + + +@cli.command() 
+@click.option("--flush-gpu-cache-size-mb", default=0) +@click.option("--iters", default=100) +@click.option("--num-columns", default=-1) +@click.option("--num-rows", default=-1) +@click.option("--warmup-runs", default=2) +def bench( + flush_gpu_cache_size_mb: int, + iters: int, + num_columns: int, + num_rows: int, + warmup_runs: int, +) -> None: + if num_columns == -1 or num_rows == -1: + bench_spectrum( + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + iters=iters, + warmup_runs=warmup_runs, + ) + else: + bench_impl( + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + iters=iters, + num_columns=num_columns, + num_rows=num_rows, + warmup_runs=warmup_runs, + ) + + @cli.command() @click.option("--flush-gpu-cache-size-mb", default=0) @click.option("--iters", default=100) diff --git a/fbgemm_gpu/bench/sparse_ops_benchmark.py b/fbgemm_gpu/bench/sparse_ops_benchmark.py index 8ab67381e..9d4e53a8b 100644 --- a/fbgemm_gpu/bench/sparse_ops_benchmark.py +++ b/fbgemm_gpu/bench/sparse_ops_benchmark.py @@ -3,11 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import functools import logging import random import click import fbgemm_gpu +import numpy as np import torch logging.basicConfig(level=logging.DEBUG) @@ -69,5 +71,158 @@ def device( logging.info(f"expand_into_jagged_permute {time} sec {num_bytes / time / 1e9} GB/s") +@cli.command() +@click.option("--row-size", default=25600) +@click.option("--batch-size", default=4096) +@click.option("--unique-batch-size", default=1024) +@click.option("--input-precision", type=str, default="fp32") +def batch_reuse_index_select_device( + row_size: int, batch_size: int, unique_batch_size: int, input_precision: str +) -> None: + # A function for generating indices in batch_reuse + # pyre-fixme[11]: Annotation `array` is not defined as a type. + def gen_inverse_index(curr_size: int, final_size: int) -> np.array: + inverse_index = list(range(curr_size)) + np_arr = np.array(inverse_index) + for _ in range(final_size - curr_size): + inverse_index.append(np.random.randint(0, curr_size)) + np_arr = np.array(inverse_index) + np.random.shuffle(np_arr) + return np_arr + + dtype = torch.float + if input_precision == "fp32": + dtype = torch.float + elif input_precision == "fp16": + dtype = torch.half + else: + raise RuntimeError(f"Does not support data type {input_precision}") + + # pyre-fixme[16]: Module `cuda` has no attribute `IntTensor`. 
+ indices = torch.cuda.IntTensor(gen_inverse_index(unique_batch_size, batch_size)) + + input = torch.rand(unique_batch_size, row_size, dtype=dtype, device="cuda") + input.requires_grad = True + num_bytes = 2 * batch_size * row_size * input.element_size() + time, output = benchmark_torch_function( + torch.ops.fbgemm.index_select_dim0, (input, indices, 0, unique_batch_size) + ) + logging.info( + f"index_select_dim0 forward: {dtype}, {num_bytes} bytes read/write, {time * 1e3} ms, {num_bytes / time / 1e9} GB/s" + ) + + grad = torch.rand_like(output, dtype=dtype, device="cuda") + num_bytes = (input.numel() + output.numel()) * input.element_size() + time, _ = benchmark_torch_function( + functools.partial(output.backward, retain_graph=True), (grad,) + ) + logging.info( + f"index_select_dim0 backward: {dtype}, {num_bytes} bytes read/write, {time * 1e3} ms, {num_bytes / time / 1e9} GB/s" + ) + + +@cli.command() +@click.option("--max-seq-length", default=500) +@click.option("--batch-size", default=4096) +@click.option("--num-cols", default=256) +@click.option("--num-jagged-tensor-rows", default=4096) +@click.option("--num-zero-padding", default=1024) +@click.option("--index-dtype", type=click.Choice(["int", "long"]), default="int") +@click.option( + "--jagged-tensor-dtype", type=click.Choice(["float", "half"]), default="float" +) +def jagged_index_select_2d_bench( + max_seq_length: int, + batch_size: int, + num_cols: int, + num_jagged_tensor_rows: int, + num_zero_padding: int, + index_dtype: str, + jagged_tensor_dtype: str, +) -> None: + def jagged_index_select_2d_ref( + values: torch.Tensor, lengths: torch.Tensor, inverse_lookup: torch.Tensor + ) -> torch.Tensor: + offsets = torch.ops.fbgemm.asynchronous_exclusive_cumsum(lengths) + end_offsets = offsets + lengths + full_start_offset = torch.index_select(offsets, 0, inverse_lookup) + full_end_offset = torch.index_select(end_offsets, 0, inverse_lookup) + index_ranges = torch.stack( + (full_start_offset, full_end_offset), dim=0 + ).transpose(0, 1) + + to_be_merged_tensors = [] + for row in index_ranges: + to_be_merged_tensors.append(torch.arange(row[0], row[1], device="cuda")) + all_indices = torch.cat(to_be_merged_tensors, dim=0) + new_embeddings = torch.index_select(values, 0, all_indices) + return new_embeddings + + index_t = {"int": torch.int, "long": torch.long}[index_dtype] + scalar_t = {"float": torch.float, "half": torch.half}[jagged_tensor_dtype] + + lengths = torch.randint( + low=0, + high=max_seq_length, + size=(num_jagged_tensor_rows,), + dtype=index_t, + device="cuda", + ) + indices, _ = torch.sort( + torch.randint( + low=0, + high=num_jagged_tensor_rows, + size=(batch_size,), + dtype=index_t, + device="cuda", + ) + ) + values = torch.rand( + int(lengths.sum().item()), num_cols, dtype=scalar_t, device="cuda" + ) + values.requires_grad = True + + indices[batch_size - num_zero_padding :] = 0 + + time, (output, _) = benchmark_torch_function( + torch.ops.fbgemm.jagged_index_select, + (values, lengths, indices), + num_warmups=10, + iters=100, + ) + time_ref, output_ref = benchmark_torch_function( + jagged_index_select_2d_ref, + (values, lengths, indices), + num_warmups=10, + iters=100, + ) + logging.info( + f"jagged_index_select_2d_bench " + f"(max_seq_length={max_seq_length}, " + f"batch_size={batch_size}, " + f"num_cols={num_cols}, " + f"num_jagged_tensor_rows={num_jagged_tensor_rows}, " + f"num_zero_padding={num_zero_padding}, " + f"index_dtype={index_dtype}, " + f"jagged_tensor_dtype={jagged_tensor_dtype})" + ) + logging.info(f"forward: 
fbgemm {time * 1e3:.3f} ms, ref {time_ref * 1e3:.3f} ms") + + grad = torch.rand_like(output) + time, _ = benchmark_torch_function( + functools.partial(output.backward, retain_graph=True), + (grad,), + num_warmups=10, + iters=100, + ) + time_ref, _ = benchmark_torch_function( + functools.partial(output_ref.backward, retain_graph=True), + (grad,), + num_warmups=10, + iters=100, + ) + logging.info(f"backward: fbgemm {time * 1e3:.3f} ms, ref {time_ref * 1e3:.3f} ms") + + if __name__ == "__main__": cli() diff --git a/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py b/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py index 8beb07053..7f7f9a3c1 100644 --- a/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py +++ b/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py @@ -17,7 +17,7 @@ EmbeddingLocation, IntNBitTableBatchedEmbeddingBagsCodegen, ) -from torch import Tensor, nn +from torch import nn, Tensor logging.basicConfig(level=logging.DEBUG) diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index eb5f8845d..7b550c21f 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -5,12 +5,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import itertools + import logging import math import random import statistics -import time from typing import Callable, List, Optional, Tuple import click @@ -32,14 +31,14 @@ ComputeDevice, DenseTableBatchedEmbeddingBagsCodegen, EmbeddingLocation, - OptimType, - SparseType, - SplitTableBatchedEmbeddingBagsCodegen, IntNBitTableBatchedEmbeddingBagsCodegen, + OptimType, PoolingMode, + RecordCacheMetrics, rounded_row_size_in_bytes, + SparseType, + SplitTableBatchedEmbeddingBagsCodegen, ) -from numpy.random import default_rng from torch import Tensor # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 
@@ -47,364 +46,28 @@ if open_source: # pyre-ignore[21] - from bench_utils import benchmark_torch_function -else: - from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - - -logging.basicConfig(level=logging.DEBUG) - - -def round_up(a: int, b: int) -> int: - return int((a + b - 1) // b) * b - - -def get_device() -> torch.device: - return ( - torch.cuda.current_device() - if torch.cuda.is_available() - else torch.device("cpu") - ) - - -# Merged indices with shape (T, B, L) -> (flattened indices with shape -# (T * B * L), offsets with shape (T * B + 1)) -def get_table_batched_offsets_from_dense( - merged_indices: Tensor, -) -> Tuple[Tensor, Tensor]: - (T, B, L) = merged_indices.size() - lengths = np.ones((T, B)) * L - flat_lengths = lengths.flatten() - return ( - merged_indices.long().contiguous().view(-1).to(get_device()), - torch.tensor(([0] + np.cumsum(flat_lengths).tolist())).long().to(get_device()), + from bench_utils import ( + benchmark_pipelined_requests, + benchmark_requests, + benchmark_requests_refer, + benchmark_torch_function, + generate_requests, + get_device, + round_up, ) - - -def get_offsets_from_dense(indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - (B, L) = indices.size() - return ( - indices.contiguous().view(-1), - torch.tensor( - np.cumsum(np.asarray([0] + [L for _ in range(B)])[:-1]).astype(np.int64) - ), +else: + from fbgemm_gpu.bench.bench_utils import ( + benchmark_pipelined_requests, + benchmark_requests, + benchmark_requests_refer, + benchmark_torch_function, + generate_requests, + get_device, + round_up, ) -def b_indices( - b: Callable[..., torch.Tensor], - x: torch.Tensor, - per_sample_weights: Optional[torch.Tensor] = None, - use_cpu: bool = False, - do_pooling: bool = True, -) -> torch.Tensor: - (indices, offsets) = get_offsets_from_dense(x) - if do_pooling: - return b( - indices.cuda(), - offsets.cuda(), - per_sample_weights=per_sample_weights, - ) - else: - return b(indices.cuda()) - - -def generate_requests( - iters: int, - B: int, - T: int, - L: int, - E: int, - # inter-batch indices reuse rate - reuse: float = 0.0, - # alpha <= 1.0: use uniform distribution - # alpha > 1.0: use zipf distribution - alpha: float = 1.0, - weights_precision: SparseType = SparseType.FP32, - weighted: bool = False, - requests_data_file: Optional[str] = None, - # Comma-separated list of table numbers - tables: Optional[str] = None, -) -> List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]]: - if requests_data_file is not None: - indices_tensor, offsets_tensor, lengths_tensor = torch.load(requests_data_file) - - average_L = 0 - if tables is not None: - emb_tables = tuple(int(x) for x in tables.split(",")) - indices = torch.zeros(0, dtype=indices_tensor.dtype) - offsets = torch.zeros(1, dtype=offsets_tensor.dtype) - total_L = 0 - for t in emb_tables: - t_offsets = offsets_tensor[B * t : B * (t + 1) + 1] - total_L += t_offsets[-1] - t_offsets[0] - indices = torch.cat( - (indices, indices_tensor[t_offsets[0] : t_offsets[-1]]) - ) - offsets = torch.cat( - ( - offsets, - t_offsets[1:] - t_offsets[0] + offsets[-1], - ) - ) - indices_tensor = indices - offsets_tensor = offsets - average_L = int(total_L / B) - - assert np.prod(offsets_tensor.size()) - 1 == np.prod((T, B)), ( - f"Requested tables: {emb_tables} " - f"does not conform to inputs (T, B) = ({T}, {B})." 
- ) - logging.warning( - f"Using (indices = {indices_tensor.size()}, offsets = {offsets_tensor.size()}) based " - f"on tables: {emb_tables}" - ) - else: - average_L = int((offsets_tensor[-1] - offsets_tensor[0]) / B) - assert (np.prod(offsets_tensor.size()) - 1) == np.prod((T, B)), ( - f"Data file (indices = {indices_tensor.size()}, " - f"offsets = {offsets_tensor.size()}, lengths = {lengths_tensor.size()}) " - f"does not conform to inputs (T, B) = ({T}, {B})." - ) - - assert ( - L == average_L - ), f"Requested L does not align with provided data file ({L} vs. {average_L})" - assert E > max(indices_tensor), ( - f"Number of embeddings is not enough to support maximum index " - f"provided by data file {E} vs. {max(indices_tensor)}" - ) - - weights_tensor = ( - None - if not weighted - else torch.randn(indices_tensor.size(), device=get_device()) - ) - rs = [] - for _ in range(iters): - rs.append( - ( - indices_tensor.to(get_device()), - offsets_tensor.to(get_device()), - weights_tensor, - ) - ) - return rs - - if alpha <= 1.0: - all_indices = torch.randint( - low=0, - high=E, - size=(iters, T, B, L), - device=get_device(), - dtype=torch.int32, - ) - # each bag is usually sorted - (all_indices, _) = torch.sort(all_indices) - all_indices = all_indices.reshape(iters, T, B * L) - else: - assert E >= L, "num-embeddings must be greater than equal to bag-size" - # oversample and then remove duplicates to obtain sampling without - # replacement - all_indices = (np.random.zipf(a=alpha, size=(iters, T, B, 3 * L)) - 1) % E - for index_tuple in itertools.product(range(iters), range(T), range(B)): - # sample without replacement from - # https://stats.stackexchange.com/questions/20590/how-do-i-sample-without-replacement-using-a-sampling-with-replacement-function - r = set() - for x in all_indices[index_tuple]: - if x not in r: - r.add(x) - if len(r) == L: - break - assert (len(r)) == L, "too skewed distribution (alpha too big)" - all_indices[index_tuple][:L] = list(r) - # shuffle indices so we don't have unintended spatial locality - all_indices = torch.as_tensor(all_indices[:, :, :, :L]) - rng = default_rng() - permutation = torch.as_tensor( - rng.choice(E, size=all_indices.max().item() + 1, replace=False) - ) - all_indices = permutation.gather(0, all_indices.flatten()) - all_indices = all_indices.to(get_device()).int().reshape(iters, T, B * L) - for it in range(iters - 1): - for t in range(T): - reused_indices = torch.randperm(B * L, device=get_device())[ - : int(B * L * reuse) - ] - all_indices[it + 1, t, reused_indices] = all_indices[it, t, reused_indices] - - rs = [] - for it in range(iters): - weights_tensor = ( - None if not weighted else torch.randn(T * B * L, device=get_device()) - ) - rs.append( - get_table_batched_offsets_from_dense(all_indices[it].view(T, B, L)) - + (weights_tensor,) - ) - return rs - - -def benchmark_requests( - requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]], - func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor], - flush_gpu_cache_size_mb: int = 0, - check_median: bool = False, -) -> float: - times = [] - if torch.cuda.is_available(): - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - for (indices, offsets, weights) in requests: - start_time = time.time() - if torch.cuda.is_available(): - if flush_gpu_cache_size_mb: - _ = torch.rand( - flush_gpu_cache_size_mb * 1024 * 1024 // 4, dtype=torch.float - ) - torch.cuda.synchronize() - start_event.record() - 
func(indices, offsets, weights) - if torch.cuda.is_available(): - end_event.record() - torch.cuda.synchronize() - it_time = start_event.elapsed_time(end_event) * 1.0e-3 - times.append(it_time) - else: - it_time = time.time() - start_time - times.append(it_time) - avg_time = sum(times) / len(requests) - median_time = statistics.median(times) - return median_time if check_median else avg_time - - -def benchmark_requests_refer( - requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]], - T: int, - B: int, - L: int, - E: int, - D: int, - pooling_mode: str, - weighted: bool, - flush_gpu_cache_size_mb: int = 0, - check_median: bool = False, -) -> float: - do_pooling = pooling_mode in ["sum", "mean"] - if do_pooling: - nn_embedding_list = [ - torch.nn.EmbeddingBag(E, D, mode=pooling_mode, sparse=True).cuda() - ] * T - else: - nn_embedding_list = [torch.nn.Embedding(E, D, sparse=True).cuda()] * T - - times = [] - if torch.cuda.is_available(): - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - for (indices, _, weights) in requests: - indices_list = indices.view(T, B, L).split(1) - - if weighted: - assert weights is not None - weights_list = weights.view(T, B, L).split(1) - - start_time = time.time() - if torch.cuda.is_available(): - if flush_gpu_cache_size_mb: - _ = torch.rand( - flush_gpu_cache_size_mb * 1024 * 1024 // 4, dtype=torch.float - ) - torch.cuda.synchronize() - start_event.record() - - nn_embedding_output = ( - [ - b_indices(nn_embedding, x, use_cpu=False, do_pooling=do_pooling) - for (nn_embedding, x) in zip(nn_embedding_list, indices_list) - ] - if not weighted - else [ - b_indices( - nn_embedding, - x, - per_sample_weights=xw.view(-1), - use_cpu=False, - do_pooling=do_pooling, - ) - for (nn_embedding, x, xw) in zip( - nn_embedding_list, - indices_list, - # pyre-fixme[61]: `weights_list` is undefined, or not always - # defined. 
- weights_list, - ) - ] - ) - if do_pooling: - final_output = torch.cat( - [f.view(B, -1) for f in nn_embedding_output], dim=1 - ) - else: - final_output = torch.cat(nn_embedding_output, dim=0).view(-1, D) - - if torch.cuda.is_available(): - end_event.record() - torch.cuda.synchronize() - it_time = start_event.elapsed_time(end_event) * 1.0e-3 - times.append(it_time) - else: - it_time = time.time() - start_time - times.append(it_time) - avg_time = sum(times) / len(requests) - median_time = statistics.median(times) - return median_time if check_median else avg_time - - -def benchmark_pipelined_requests( - requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]], - func1: Callable[[Tensor, Tensor, Optional[Tensor]], None], - func2: Callable[[Tensor, Tensor, Optional[Tensor]], None], - flush_gpu_cache_size_mb: int = 0, -) -> Tuple[float, float]: - torch.cuda.synchronize() - start_events = [ - (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) - for _ in requests - ] - end_events = [ - (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) - for _ in requests - ] - for ((indices, offsets, indices_weights), start_event, end_event) in zip( - requests, start_events, end_events - ): - if flush_gpu_cache_size_mb: - _ = torch.rand( - flush_gpu_cache_size_mb * 1024 * 1024 // 4, dtype=torch.float - ) - torch.cuda.synchronize() - start_event[0].record() - func1(indices, offsets, indices_weights) - end_event[0].record() - start_event[1].record() - func2(indices, offsets, indices_weights) - end_event[1].record() - torch.cuda.synchronize() - return ( - sum( - start_event[0].elapsed_time(end_event[0]) * 1.0e-3 - for start_event, end_event in zip(start_events, end_events) - ) - / len(requests), - sum( - start_event[1].elapsed_time(end_event[1]) * 1.0e-3 - for start_event, end_event in zip(start_events, end_events) - ) - / len(requests), - ) +logging.basicConfig(level=logging.DEBUG) @click.group() @@ -428,7 +91,9 @@ def cli() -> None: @click.option("--reuse", default=0.0) @click.option("--row-wise/--no-row-wise", default=True) @click.option("--weighted", is_flag=True, default=False) +@click.option("--pooling", type=str, default="sum") @click.option("--weighted-num-requires-grad", type=int, default=None) +@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value) @click.option("--flush-gpu-cache-size-mb", default=0) @click.option("--dense", is_flag=True, default=False) @click.option("--output-dtype", type=SparseType, default=SparseType.FP32) @@ -449,7 +114,9 @@ def device( # noqa C901 reuse: float, row_wise: bool, weighted: bool, + pooling: str, weighted_num_requires_grad: Optional[int], + bounds_check_mode: int, flush_gpu_cache_size_mb: int, dense: bool, output_dtype: SparseType, @@ -496,6 +163,17 @@ def device( # noqa C901 else: managed_option = EmbeddingLocation.MANAGED + if pooling is None or pooling == "sum": + pooling = "sum" + pooling_mode = PoolingMode.SUM + do_pooling = True + elif pooling == "mean": + pooling_mode = PoolingMode.MEAN + do_pooling = True + else: # "none" + pooling_mode = PoolingMode.NONE + do_pooling = False + if dense: emb = DenseTableBatchedEmbeddingBagsCodegen( [ @@ -505,6 +183,7 @@ def device( # noqa C901 ) for d in Ds ], + pooling_mode=pooling_mode, use_cpu=not torch.cuda.is_available(), ) else: @@ -526,6 +205,8 @@ def device( # noqa C901 weights_precision=weights_precision, stochastic_rounding=stoc, output_dtype=output_dtype, + pooling_mode=pooling_mode, + 
bounds_check_mode=BoundsCheckMode(bounds_check_mode), ) emb = emb.to(get_device()) @@ -534,6 +215,18 @@ def device( # noqa C901 nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 + output_size_multiplier = output_dtype.bit_rate() / 8.0 + if do_pooling: + read_write_bytes = ( + output_size_multiplier * B * sum(Ds) + + param_size_multiplier * B * sum(Ds) * L + ) + else: + read_write_bytes = ( + output_size_multiplier * B * sum(Ds) * L + + param_size_multiplier * B * sum(Ds) * L + ) + logging.info( f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, " f"{nparams * param_size_multiplier / 1.0e9: .2f} GB" @@ -570,7 +263,7 @@ def device( # noqa C901 logging.info( f"Forward, B: {B}, " f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, " - f"BW: {param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 f"T: {time_per_iter * 1.0e6:.0f}us" ) @@ -578,7 +271,10 @@ def device( # noqa C901 # backward bench not representative return - grad_output = torch.randn(B, sum(Ds)).to(get_device()) + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) + else: + grad_output = torch.randn(B * T * L, D).to(get_device()) # backward time_per_iter = benchmark_requests( requests, @@ -592,7 +288,7 @@ def device( # noqa C901 ) logging.info( f"ForwardBackward, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, " - f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, " + f"BW: {3 * read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " f"T: {time_per_iter * 1.0e6:.0f}us" ) @@ -1068,6 +764,8 @@ def benchmark_cpu_requests( @click.option("--requests_data_file", type=str, default=None) @click.option("--tables", type=str, default=None) @click.option("--output-dtype", type=SparseType, default=SparseType.FP16) +@click.option("--fp8-exponent-bits", type=int, default=None) +@click.option("--fp8-exponent-bias", type=int, default=None) def nbit_cpu( # noqa C901 alpha: float, bag_size: int, @@ -1087,6 +785,8 @@ def nbit_cpu( # noqa C901 requests_data_file: Optional[str], tables: Optional[str], output_dtype: SparseType, + fp8_exponent_bits: Optional[int], + fp8_exponent_bias: Optional[int], ) -> None: np.random.seed(42) torch.manual_seed(42) @@ -1110,6 +810,8 @@ def nbit_cpu( # noqa C901 device="cpu", index_remapping=[torch.arange(E) for _ in Ds] if index_remapping else None, output_dtype=output_dtype, + fp8_exponent_bits=fp8_exponent_bits, + fp8_exponent_bias=fp8_exponent_bias, ).cpu() emb.fill_random_weights() @@ -1177,8 +879,7 @@ def nbit_cpu( # noqa C901 @click.option("--row-wise/--no-row-wise", default=True) @click.option("--weighted", is_flag=True, default=False) @click.option("--pooling", type=str, default="sum") -@click.option("--weighted-num-requires-grad", type=int, default=None) -@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.WARNING.value) +@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value) @click.option("--pruning-ratio", type=float, default=None) @click.option("--load-factor", default=0.75) @click.option("--use-array-for-index-remapping", is_flag=True, default=True) @@ -1191,6 +892,8 @@ def nbit_cpu( # noqa C901 @click.option("--run-reference", is_flag=True, default=False) @click.option("--requests_data_file", type=str, default=None) @click.option("--tables", type=str, default=None) +@click.option("--fp8-exponent-bits", type=int, default=None) 
+@click.option("--fp8-exponent-bias", type=int, default=None) def nbit_device( # noqa C901 alpha: float, bag_size: int, @@ -1206,7 +909,6 @@ def nbit_device( # noqa C901 row_wise: bool, weighted: bool, pooling: str, - weighted_num_requires_grad: Optional[int], bounds_check_mode: int, pruning_ratio: Optional[float], load_factor: float, @@ -1220,6 +922,8 @@ def nbit_device( # noqa C901 run_reference: bool, requests_data_file: Optional[str], tables: Optional[str], + fp8_exponent_bits: Optional[int], + fp8_exponent_bias: Optional[int], ) -> None: np.random.seed(42) torch.manual_seed(42) @@ -1280,6 +984,8 @@ def nbit_device( # noqa C901 use_array_for_index_remapping=use_array_for_index_remapping, output_dtype=output_dtype, pooling_mode=pooling_mode, + fp8_exponent_bits=fp8_exponent_bits, + fp8_exponent_bias=fp8_exponent_bias, ).cuda() emb.fill_random_weights() @@ -1455,6 +1161,8 @@ def nbit_device( # noqa C901 @click.option("--cache-algorithm", default="lru") @click.option("--cache-load-factor", default=0.2) @click.option("--enforce-hbm", is_flag=True, default=False) +@click.option("--fp8-exponent-bits", type=int, default=None) +@click.option("--fp8-exponent-bias", type=int, default=None) def nbit_uvm( alpha: bool, bag_size: int, @@ -1476,6 +1184,8 @@ def nbit_uvm( cache_algorithm: str, cache_load_factor: float, enforce_hbm: bool, + fp8_exponent_bits: Optional[int], + fp8_exponent_bias: Optional[int], ) -> None: np.random.seed(42) torch.manual_seed(42) @@ -1522,6 +1232,8 @@ def nbit_uvm( cache_load_factor=cache_load_factor, cache_algorithm=cache_alg, enforce_hbm=enforce_hbm, + fp8_exponent_bits=fp8_exponent_bits, + fp8_exponent_bias=fp8_exponent_bias, ).cuda() emb_uvm.fill_random_weights() @@ -1538,6 +1250,8 @@ def nbit_uvm( for d in Ds[T_uvm:] ], output_dtype=output_dtype, + fp8_exponent_bits=fp8_exponent_bits, + fp8_exponent_bias=fp8_exponent_bias, ).cuda() emb_gpu.fill_random_weights() @@ -1560,6 +1274,8 @@ def nbit_uvm( cache_load_factor=cache_load_factor, cache_algorithm=cache_alg, enforce_hbm=enforce_hbm, + fp8_exponent_bits=fp8_exponent_bits, + fp8_exponent_bias=fp8_exponent_bias, ).cuda() emb_mixed.fill_random_weights() @@ -1603,13 +1319,15 @@ def nbit_uvm( if T_gpu > 0: nparams_byte = sum(w.numel() for (w, _) in emb_mixed.split_embedding_weights()) logging.info( - f"{weights_precision} Embedding tables: {E * T + E_uvm * T_uvm} rows, {nparams_byte / param_size_multiplier / 1.0e9: .2f} GParam, " + f"{weights_precision} Embedding tables: {E * T_gpu + E_uvm * T_uvm} rows, {nparams_byte / param_size_multiplier / 1.0e9: .2f} GParam, " f"{nparams_byte / 1.0e9: .2f} GB" # IntN TBE use byte for storage ) logging.info( - f"Accessed weights per batch: {B * (T * L + T_uvm * L_uvm)} rows, " - f"{B * (T * L * sum(Ds[T_uvm:]) + T_uvm * L_uvm * sum(Ds[:T_uvm])) * param_size_multiplier / 1.0e9: .2f} GB" + f"Accessed weights per batch: {B * (T_gpu * L + T_uvm * L_uvm)} rows, " + f"{B * (L * sum(Ds[T_uvm:]) + L_uvm * sum(Ds[:T_uvm])) * param_size_multiplier / 1.0e9: .2f} GB" ) + torch.cuda.cudart().cudaProfilerStart() + torch.cuda.nvtx.range_push("uvm forward") time_per_iter = benchmark_requests( requests_uvm, @@ -1626,7 +1344,8 @@ def nbit_uvm( f"BW: {read_write_bytes_uvm / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 f"Time: {time_per_iter * 1.0e6:.0f}us" ) - + torch.cuda.nvtx.range_pop() + torch.cuda.cudart().cudaProfilerStop() if T_gpu > 0: requests = [] assert requests_gpu is not None @@ -1694,9 +1413,6 @@ def nbit_uvm( indices, offsets, ), - # pyre-fixme[6]: Expected `(Tensor, Tensor, 
Optional[Tensor]) -> None` for - # 3rd param but got `(indices: Any, offsets: Any, indices_weights: Any) -> - # Tensor`. lambda indices, offsets, indices_weights: emb_mixed.forward( indices, offsets, @@ -1734,6 +1450,10 @@ def nbit_uvm( @click.option("--flush-gpu-cache-size-mb", default=0) @click.option("--output-dtype", type=SparseType, default=SparseType.FP16) @click.option("--enforce-hbm", is_flag=True, default=False) +@click.option("--record-cache-miss-counter", is_flag=True, default=False) +@click.option("--record-tablewise-cache-miss", is_flag=True, default=False) +@click.option("--fp8-exponent-bits", type=int, default=None) +@click.option("--fp8-exponent-bias", type=int, default=None) def nbit_cache( # noqa C901 alpha: float, bag_size: int, @@ -1751,6 +1471,10 @@ def nbit_cache( # noqa C901 flush_gpu_cache_size_mb: int, output_dtype: SparseType, enforce_hbm: bool, + record_cache_miss_counter: bool, + record_tablewise_cache_miss: bool, + fp8_exponent_bits: Optional[int], + fp8_exponent_bias: Optional[int], ) -> None: np.random.seed(42) torch.manual_seed(42) @@ -1782,6 +1506,8 @@ def nbit_cache( # noqa C901 ], output_dtype=output_dtype, enforce_hbm=enforce_hbm, + fp8_exponent_bits=fp8_exponent_bits, + fp8_exponent_bias=fp8_exponent_bias, ).cuda() emb_nc.fill_random_weights() @@ -1796,10 +1522,15 @@ def nbit_cache( # noqa C901 ) for d in Ds ], + record_cache_metrics=RecordCacheMetrics( + record_cache_miss_counter, record_tablewise_cache_miss + ), cache_load_factor=cache_load_factor, cache_algorithm=cache_alg, output_dtype=output_dtype, enforce_hbm=enforce_hbm, + fp8_exponent_bits=fp8_exponent_bits, + fp8_exponent_bias=fp8_exponent_bias, ).cuda() emb.fill_random_weights() @@ -1807,7 +1538,11 @@ def nbit_cache( # noqa C901 param_size_multiplier = weights_precision.bit_rate() / 8.0 output_size_multiplier = output_dtype.bit_rate() / 8.0 read_write_bytes = ( - output_size_multiplier * B * sum(Ds) + param_size_multiplier * B * sum(Ds) * L + param_size_multiplier + * B + * sum(Ds) + * L + # output_size_multiplier * B * sum(Ds) + param_size_multiplier * B * sum(Ds) * L ) logging.info( f"{weights_precision} Embedding tables: {E * T} rows, {nparams_byte / param_size_multiplier / 1.0e9: .2f} GParam, " @@ -1837,14 +1572,20 @@ def nbit_cache( # noqa C901 f"T: {time_per_iter * 1.0e6:.0f}us" ) - # exchanged_cache_lines = [100] # warm up for indices, offsets, _ in warmup_requests: emb.forward(indices.int(), offsets.int()) + # get cache miss rate (forward only) and exchanged cache lines (prefetch) cache_misses = [] exchanged_cache_lines = [] + unique_indices = [] + input_indices = [] NOT_FOUND = -1 + # reset the cache miss counters after warmup + if record_cache_miss_counter or record_tablewise_cache_miss: + emb.reset_cache_miss_counter() + for indices, offsets, _ in requests: # pyre-fixme[29]: # `Union[BoundMethod[typing.Callable(Tensor.clone)[[Named(self, @@ -1862,6 +1603,14 @@ def nbit_cache( # noqa C901 (emb.lxu_cache_locations_list.top() == NOT_FOUND).sum().item() ) emb.forward(indices, offsets) + linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( + emb.cache_hash_size_cumsum, + indices, + offsets, + ) + unique_indices.append(len(torch.unique(linear_cache_indices, sorted=False))) + input_indices.append(len(indices)) + logging.info( f"Exchanged cache lines -- mean: {sum(exchanged_cache_lines)/len(requests): .2f}, " f"max: {max(exchanged_cache_lines)}, min: {min(exchanged_cache_lines)}" @@ -1870,20 +1619,35 @@ def nbit_cache( # noqa C901 f"Cache miss -- mean: 
{sum(cache_misses)/len(requests)}, " f"max: {max(cache_misses)}, min: {min(cache_misses)}" ) - + logging.info( + f"input_indices -- mean: {sum(input_indices)/len(requests)}, " + f"max: {max(input_indices)}, min: {min(input_indices)}" + ) + logging.info( + f"unique_indices -- mean: {sum(unique_indices)/len(requests)}, " + f"max: {max(unique_indices)}, min: {min(unique_indices)}" + ) + unique_miss_rate = [a / b for (a, b) in zip(exchanged_cache_lines, unique_indices)] + logging.info( + f"unique_miss_rate -- mean: {sum(unique_miss_rate)/len(requests)}, " + f"max: {max(unique_miss_rate)}, min: {min(unique_miss_rate)}" + ) + if record_cache_miss_counter or record_tablewise_cache_miss: + emb.print_cache_miss_counter() # benchmark prefetch - emb.reset_cache_states() + if record_cache_miss_counter or record_tablewise_cache_miss: + emb.reset_cache_states() for indices, offsets, _ in warmup_requests: emb.forward(indices, offsets) + + torch.cuda.cudart().cudaProfilerStart() + torch.cuda.nvtx.range_push("pipeline") prefetch_time, forward_time = benchmark_pipelined_requests( requests, lambda indices, offsets, indices_weights: emb.prefetch( indices, offsets, ), - # pyre-fixme[6]: Expected `(Tensor, Tensor, Optional[Tensor]) -> None` for - # 3rd param but got `(indices: Any, offsets: Any, indices_weights: Any) -> - # Tensor`. lambda indices, offsets, indices_weights: emb.forward( indices, offsets, @@ -1892,6 +1656,7 @@ def nbit_cache( # noqa C901 flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, ) e2e_time = prefetch_time + forward_time + torch.cuda.nvtx.range_pop() logging.info( f"Forward(LXU) {weights_precision}, reuse: {reuse}, alpha: {alpha}, B: {B}, " @@ -1903,6 +1668,7 @@ def nbit_cache( # noqa C901 f"TfwdTime: {forward_time * 1.0e6:.0f}us, " f"{read_write_bytes / forward_time / 1.0e9: .2f} GB/s" ) + torch.cuda.cudart().cudaProfilerStop() @cli.command() @@ -1956,7 +1722,7 @@ def hashtable( # noqa C901 ) hash_table_offsets = torch.tensor([0] + np.cumsum(capacities).tolist()).long() - assert hash_table.numel() * 4 < 2 ** 32 + assert hash_table.numel() * 4 < 2**32 # initialize hash_table[:, :] = -1 torch.ops.fbgemm.pruned_hashmap_insert( @@ -2154,6 +1920,8 @@ def bounds_check_indices( # noqa C901 @click.option("--weights-precision", type=SparseType, default=SparseType.INT4) @click.option("--output-dtype", type=SparseType, default=SparseType.FP16) @click.option("--iters", type=int, default=100) +@click.option("--fp8-exponent-bits", type=int, default=None) +@click.option("--fp8-exponent-bias", type=int, default=None) def emb_inplace_update( # noqa C901 num_tables: int, embedding_dim: int, @@ -2162,6 +1930,8 @@ def emb_inplace_update( # noqa C901 weights_precision: SparseType, output_dtype: SparseType, iters: int, + fp8_exponent_bits: Optional[int], + fp8_exponent_bias: Optional[int], ) -> None: if open_source: logging.warning( @@ -2206,6 +1976,8 @@ def emb_inplace_update( # noqa C901 embedding_specs=embedding_specs, output_dtype=output_dtype, device=torch.cuda.current_device(), + fp8_exponent_bits=fp8_exponent_bits, + fp8_exponent_bias=fp8_exponent_bias, ) # Initilize the random weights for int nbit table split embedding bag op.fill_random_weights() @@ -2241,6 +2013,8 @@ def emb_inplace_update( # noqa C901 high=255, size=(update_weight_size,), dtype=torch.uint8, + # pyre-fixme[6]: For 5th param expected `Union[None, str, device]` but got + # `int`. 
device=torch.cuda.current_device(), ) @@ -2282,16 +2056,22 @@ def emb_inplace_update( # noqa C901 update_table_idx = torch.tensor( update_table_idx, + # pyre-fixme[6]: For 2nd param expected `Union[None, str, device]` but got + # `int`. device=torch.cuda.current_device(), dtype=torch.int32, ) update_row_idx = torch.tensor( update_row_idx, + # pyre-fixme[6]: For 2nd param expected `Union[None, str, device]` but got + # `int`. device=torch.cuda.current_device(), dtype=torch.int32, ) update_offsets = torch.tensor( update_offsets, + # pyre-fixme[6]: For 2nd param expected `Union[None, str, device]` but got + # `int`. device=torch.cuda.current_device(), dtype=torch.int64, ) diff --git a/fbgemm_gpu/build.sh b/fbgemm_gpu/build.sh deleted file mode 100755 index f181dd6a9..000000000 --- a/fbgemm_gpu/build.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -export MAX_JOBS=32 -python3.6 setup.py build develop 2>&1 | tee build.log diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index cdc225e9d..7db59ea9f 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -5,7 +5,11 @@ IF(NOT DEFINED ENV{ROCM_PATH}) ELSE() SET(ROCM_PATH $ENV{ROCM_PATH}) ENDIF() - +if(NOT DEFINED ENV{ROCM_INCLUDE_DIRS}) + set(ROCM_INCLUDE_DIRS ${ROCM_PATH}/include) +else() + set(ROCM_INCLUDE_DIRS $ENV{ROCM_INCLUDE_DIRS}) +endif() # HIP_PATH IF(NOT DEFINED ENV{HIP_PATH}) SET(HIP_PATH ${ROCM_PATH}/hip) @@ -60,10 +64,10 @@ ELSE() ENDIF() # THRUST_PATH -IF(DEFINED ENV{THRUST_PATH}) - SET(THRUST_PATH $ENV{THRUST_PATH}) -ELSE() +IF(NOT DEFINED ENV{THRUST_PATH}) SET(THRUST_PATH ${ROCM_PATH}/include) +ELSE() + SET(THRUST_PATH $ENV{THRUST_PATH}) ENDIF() # HIPRAND_PATH @@ -94,12 +98,117 @@ set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH}) ADD_DEFINITIONS(-DNDEBUG) ADD_DEFINITIONS(-DUSE_ROCM) +IF(NOT DEFINED ENV{PYTORCH_ROCM_ARCH}) + SET(FBGEMM_ROCM_ARCH gfx900;gfx906;gfx908;gfx90a) +ELSE() + SET(FBGEMM_ROCM_ARCH $ENV{PYTORCH_ROCM_ARCH}) +ENDIF() + # Find the HIP Package find_package(HIP) IF(HIP_FOUND) set(FBGEMM_HAVE_HIP TRUE) + # Find ROCM version for checks + # ROCM 5.0 and later will have header api for version management + if(EXISTS ${ROCM_INCLUDE_DIRS}/rocm_version.h) + + set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}") + set(file "${PROJECT_BINARY_DIR}/detect_rocm_version.cc") + file(WRITE ${file} "" + "#include \n" + "#include \n" + + "#ifndef ROCM_VERSION_PATCH\n" + "#define ROCM_VERSION_PATCH 0\n" + "#endif\n" + "#define STRINGIFYHELPER(x) #x\n" + "#define STRINGIFY(x) STRINGIFYHELPER(x)\n" + "int main() {\n" + " printf(\"%d.%d.%s\", ROCM_VERSION_MAJOR, ROCM_VERSION_MINOR, STRINGIFY(ROCM_VERSION_PATCH));\n" + " return 0;\n" + "}\n" + ) + + try_run(run_result compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file} + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}" + RUN_OUTPUT_VARIABLE rocm_version_from_header + COMPILE_OUTPUT_VARIABLE output_var + ) + # We expect the compile to be successful if the include directory exists. 
+ if(NOT compile_result) + message(FATAL_ERROR "Caffe2: Couldn't determine version from header: " ${output_var}) + endif() + message(STATUS "Caffe2: Header version is: " ${rocm_version_from_header}) + set(ROCM_VERSION_DEV_RAW ${rocm_version_from_header}) + message("\n***** ROCm version from rocm_version.h ****\n") + endif() + + string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+).*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW}) + + if(ROCM_VERSION_DEV_MATCH) + set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) + set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) + set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) + set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") + math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") + endif() + + message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") + message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}") + message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}") + message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}") + message("ROCM_VERSION_DEV_INT: ${ROCM_VERSION_DEV_INT}") + + math(EXPR TORCH_HIP_VERSION "(${HIP_VERSION_MAJOR} * 100) + ${HIP_VERSION_MINOR}") + message("HIP_VERSION_MAJOR: ${HIP_VERSION_MAJOR}") + message("HIP_VERSION_MINOR: ${HIP_VERSION_MINOR}") + message("TORCH_HIP_VERSION: ${TORCH_HIP_VERSION}") + + message("\n***** Library versions from dpkg *****\n") + execute_process(COMMAND dpkg -l COMMAND grep rocm-dev COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep rocm-libs COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep hsakmt-roct COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep rocr-dev COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep -w hcc COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep hip-base COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep hip_hcc COMMAND awk "{print $2 \" VERSION: \" $3}") + + message("\n***** Library versions from cmake find_package *****\n") + + # As of ROCm 5.1.x, all *.cmake files are under /opt/rocm/lib/cmake/ + if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.1.0") + set(hip_DIR ${HIP_PATH}/lib/cmake/hip) + set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64) + set(AMDDeviceLibs_DIR ${ROCM_PATH}/lib/cmake/AMDDeviceLibs) + set(amd_comgr_DIR ${ROCM_PATH}/lib/cmake/amd_comgr) + set(rocrand_DIR ${ROCM_PATH}/lib/cmake/rocrand) + set(hiprand_DIR ${ROCM_PATH}/lib/cmake/hiprand) + set(rocblas_DIR ${ROCM_PATH}/lib/cmake/rocblas) + set(miopen_DIR ${ROCM_PATH}/lib/cmake/miopen) + set(rocfft_DIR ${ROCM_PATH}/lib/cmake/rocfft) + set(hipfft_DIR ${ROCM_PATH}/lib/cmake/hipfft) + set(hipsparse_DIR ${ROCM_PATH}/lib/cmake/hipsparse) + set(rccl_DIR ${ROCM_PATH}/lib/cmake/rccl) + set(rocprim_DIR ${ROCM_PATH}/lib/cmake/rocprim) + set(hipcub_DIR ${ROCM_PATH}/lib/cmake/hipcub) + set(rocthrust_DIR ${ROCM_PATH}/lib/cmake/rocthrust) + set(ROCclr_DIR ${ROCM_PATH}/rocclr/lib/cmake/rocclr) + set(ROCRAND_INCLUDE ${ROCM_PATH}/include) + set(ROCM_SMI_INCLUDE ${ROCM_PATH}/rocm_smi/include) + else() + message(FATAL_ERROR "\n***** The minimal ROCm version is 5.1.0 but have ${ROCM_VERSION_DEV} installed *****\n") + endif() + + find_package(hip REQUIRED) + find_package(rocblas REQUIRED) + find_package(hipfft REQUIRED) + find_package(hiprand REQUIRED) + find_package(rocrand 
REQUIRED) + find_package(hipsparse REQUIRED) + find_package(rocprim REQUIRED) + if(HIP_COMPILER STREQUAL clang) set(hip_library_name amdhip64) else() @@ -115,6 +224,7 @@ IF(HIP_FOUND) # list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_CONVERSIONS__=1) list(APPEND HIP_CXX_FLAGS -D__HIP_NO_BFLOAT16_CONVERSIONS__=1) list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF2_OPERATORS__=1) + list(APPEND HIP_CXX_FLAGS "${CMAKE_CXX_FLAGS}") list(APPEND HIP_CXX_FLAGS -mavx2) list(APPEND HIP_CXX_FLAGS -mf16c) list(APPEND HIP_CXX_FLAGS -mfma) @@ -126,34 +236,17 @@ IF(HIP_FOUND) list(APPEND HIP_HCC_FLAGS -fno-gpu-rdc) list(APPEND HIP_HCC_FLAGS -Wno-defaulted-function-deleted) foreach(fbgemm_rocm_arch ${FBGEMM_ROCM_ARCH}) - list(APPEND HIP_HCC_FLAGS --amdgpu-target=${fbgemm_rocm_arch}) + list(APPEND HIP_HCC_FLAGS --offload-arch=${fbgemm_rocm_arch}) endforeach() - set(hip_DIR ${HIP_PATH}/lib/cmake/hip) - set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64) - set(AMDDeviceLibs_DIR ${ROCM_PATH}/lib/cmake/AMDDeviceLibs) - set(amd_comgr_DIR ${ROCM_PATH}/lib/cmake/amd_comgr) - set(rocrand_DIR ${ROCRAND_PATH}/lib/cmake/rocrand) - set(hiprand_DIR ${HIPRAND_PATH}/lib/cmake/hiprand) - set(rocblas_DIR ${ROCBLAS_PATH}/lib/cmake/rocblas) - set(miopen_DIR ${MIOPEN_PATH}/lib/cmake/miopen) - set(rocfft_DIR ${ROCFFT_PATH}/lib/cmake/rocfft) - set(hipfft_DIR ${HIPFFT_PATH}/lib/cmake/hipfft) - set(hipsparse_DIR ${HIPSPARSE_PATH}/lib/cmake/hipsparse) - set(rccl_DIR ${RCCL_PATH}/lib/cmake/rccl) - set(rocprim_DIR ${ROCPRIM_PATH}/lib/cmake/rocprim) - set(hipcub_DIR ${HIPCUB_PATH}/lib/cmake/hipcub) - set(rocthrust_DIR ${ROCTHRUST_PATH}/lib/cmake/rocthrust) - set(ROCclr_DIR ${ROCM_PATH}/rocclr/lib/cmake/rocclr) - - find_package(hip REQUIRED) - - set(ROCRAND_INCLUDE ${ROCRAND_PATH}/include) - set(ROCM_SMI_INCLUDE ${ROCM_PATH}/rocm_smi/include) - set(FBGEMM_HIP_INCLUDE ${ROCM_PATH}/include ${FBGEMM_HIP_INCLUDE}) set(FBGEMM_HIP_INCLUDE ${hip_INCLUDE_DIRS} $ $ ${FBGEMM_HIP_INCLUDE}) hip_include_directories(${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE}) + list (APPEND CMAKE_PREFIX_PATH ${HIP_PATH} ${ROCM_PATH}) + set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH}) + +ELSE() + message("Not able to find HIP installation.") ENDIF() diff --git a/fbgemm_gpu/codegen/__init__.template b/fbgemm_gpu/codegen/__init__.template index 6f361ae93..de8bf21dd 100644 --- a/fbgemm_gpu/codegen/__init__.template +++ b/fbgemm_gpu/codegen/__init__.template @@ -13,9 +13,7 @@ import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_lars_sgd as loo import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_partial_rowwise_adam as lookup_partial_rowwise_adam # noqa: F401 import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_partial_rowwise_lamb as lookup_partial_rowwise_lamb # noqa: F401 import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_rowwise_adagrad as lookup_rowwise_adagrad # noqa: F401 -import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_rowwise_adagrad_with_weight_decay as lookup_rowwise_adagrad_with_weight_decay # noqa: F401 import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_sgd as lookup_sgd # noqa: F401 import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_approx_sgd as lookup_approx_sgd # noqa: F401 import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_approx_rowwise_adagrad as lookup_approx_rowwise_adagrad # noqa: F401 -import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_approx_rowwise_adagrad_with_weight_decay as 
lookup_approx_rowwise_adagrad_with_weight_decay # noqa: F401 import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_rowwise_weighted_adagrad as lookup_rowwise_weighted_adagrad # noqa: F401 diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py index 7b7b015d2..d06a0f0ca 100644 --- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py +++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py @@ -434,6 +434,9 @@ def rowwise_adagrad() -> None: } else if (weight_decay_mode == 2) { // Decoupled weight decay correction = 1.0 - learning_rate * weight_decay; + } else { + // default value + correction = 1.0; } } multiplier = shfl_sync(multiplier, 0); @@ -461,6 +464,9 @@ def rowwise_adagrad() -> None: } else if (weight_decay_mode == 2) { // Decoupled weight decay correction = 1.0 - learning_rate * weight_decay; + } else { + // default value + correction = 1.0; } for (int64_t d = 0; d < D; ++d) { host_weights_data[embedding_begin + d] = correction * host_weights_data[embedding_begin + d] - grad_buffer[d] * multiplier; @@ -549,6 +555,9 @@ def rowwise_adagrad_with_weight_decay() -> None: } else if (weight_decay_mode == 2) { // Decoupled weight decay correction = 1.0 - learning_rate * weight_decay; + } else { + // default value + correction = 1.0; } } multiplier = shfl_sync(multiplier, 0); @@ -576,6 +585,9 @@ def rowwise_adagrad_with_weight_decay() -> None: } else if (weight_decay_mode == 2) { // Decoupled weight decay correction = 1.0 - learning_rate * weight_decay; + } else { + // default value + correction = 1.0; } for (int64_t d = 0; d < D; ++d) { host_weights_data[embedding_begin + d] = correction * host_weights_data[embedding_begin + d] - grad_buffer[d] * multiplier; @@ -1236,13 +1248,16 @@ def forward_quantized() -> None: class elem_type: enum_name: str cpp_type_name: str + primitive_type: str + bit_width: int type_map = { - 32: elem_type("FP32", "float"), - 16: elem_type("FP16", "__half2"), - 8: elem_type("INT8", "uint32_t"), - 4: elem_type("INT4", "uint32_t"), - 2: elem_type("INT2", "uint32_t"), + "FP32": elem_type("FP32", "float", "FP", 32), + "FP16": elem_type("FP16", "__half2", "FP", 16), + "FP8": elem_type("FP8", "uint32_t", "FP", 8), + "INT8": elem_type("INT8", "uint32_t", "INT", 8), + "INT4": elem_type("INT4", "uint32_t", "INT", 4), + "INT2": elem_type("INT2", "uint32_t", "INT", 2), } template = env.get_template("embedding_forward_quantized_split_template.cu") @@ -1316,4 +1331,3 @@ def main() -> None: if __name__ == "__main__": main() - # hipify_gen() diff --git a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp b/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp index 0fba82286..7117db750 100644 --- a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp @@ -13,6 +13,7 @@ #include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; +using namespace fbgemm_gpu; Tensor dense_embedding_codegen_forward_unweighted_cuda( Tensor dev_weights, @@ -173,8 +174,12 @@ class SplitLookupFunction_Dense_Op using torch::autograd::Variable; auto grad_output = grad_outputs[0]; - if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0 || - grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) { + + // FIXME: to support aligned memory access in Vec4T load/store function + // 16 for FP32 and 8 for FP16 + if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0) { + grad_output = at::empty_like(grad_output).copy_(grad_output); + } else if 
(!grad_output.is_contiguous()) { grad_output = grad_output.contiguous(); } @@ -323,8 +328,11 @@ class SplitNoBagLookupFunction_Dense_Op using torch::autograd::Variable; auto grad_output = grad_outputs[0]; - if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0 || - grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) { + // FIXME: to support aligned memory access in Vec4T load/store function + // 16 for FP32 and 8 for FP16 + if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0) { + grad_output = at::empty_like(grad_output).copy_(grad_output); + } else if (!grad_output.is_contiguous()) { grad_output = grad_output.contiguous(); } diff --git a/fbgemm_gpu/codegen/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/embedding_backward_dense_host_cpu.cpp index ae01d965b..053c123e9 100644 --- a/fbgemm_gpu/codegen/embedding_backward_dense_host_cpu.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_dense_host_cpu.cpp @@ -13,6 +13,7 @@ #include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; +using namespace fbgemm_gpu; Tensor split_embedding_backward_codegen_dense_cpu( Tensor grad_output, diff --git a/fbgemm_gpu/codegen/embedding_backward_split_cpu_approx_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_cpu_approx_template.cpp index 28e50399b..024765225 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_cpu_approx_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_cpu_approx_template.cpp @@ -17,6 +17,7 @@ #include "fbgemm_gpu/embedding_common.h" using Tensor = at::Tensor; +using namespace fbgemm_gpu; namespace { template diff --git a/fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp index bc994ccb9..2fedcc39f 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp @@ -19,6 +19,7 @@ #include "fbgemm_gpu/cpu_utils.h" using Tensor = at::Tensor; +using namespace fbgemm_gpu; namespace internal { template @@ -61,8 +62,7 @@ void split_embedding_backward_exact_cpu_kernel( const bool has_weights = indice_weights.defined(); auto grad_stride = grad_output.size(1); - std::vector<::internal::BatchedHyperCompressedSparseColumn> batched_cscs( - num_tables); + std::vector<::internal::HyperCompressedSparseColumn> cscs(num_tables); auto get_hash_size = [&hash_size_cumsum_data](int feature_begin) { int64_t hash_size; @@ -83,8 +83,8 @@ void split_embedding_backward_exact_cpu_kernel( int feature_begin = table_to_feature_offset[t]; int64_t hash_size = get_hash_size(feature_begin); - ::internal::batched_csr2csc( - batched_cscs[t], + ::internal::csr2csc( + cscs[t], B, offsets.accessor(), indices.accessor(), @@ -95,16 +95,13 @@ void split_embedding_backward_exact_cpu_kernel( table_to_feature_offset + t, hash_size); } - // sort based csr2csc handles segment_ids differently - bool is_csr2csc_sort = batched_cscs[0].weights == nullptr; for (int t = 0; t < num_tables; ++t) { int feature_begin = table_to_feature_offset[t]; - int c_begin = batched_cscs[t].table_ptr[0]; - int c_end = batched_cscs[t].table_ptr[1]; - int* col_segment_ptr = batched_cscs[t].column_segment_ptr; - int* col_segment_indices = batched_cscs[t].column_segment_indices; + int num_non_zero_columns = cscs[t].num_non_zero_columns; + int* col_segment_ptr = cscs[t].column_segment_ptr; + int* col_segment_indices = cscs[t].column_segment_indices; auto hash_size = get_hash_size(feature_begin); @@ -127,7 +124,7 @@ void 
split_embedding_backward_exact_cpu_kernel( /*IndexType=*/int32_t, /*OffsetType=*/int32_t>( D, - batched_cscs[t].weights != nullptr, + cscs[t].weights != nullptr, /*normalize_by_lengths=*/false, /*prefetch=*/16, /*is_weight_positional=*/false, @@ -138,7 +135,7 @@ void split_embedding_backward_exact_cpu_kernel( fbgemm::GenerateSparseAdaGrad(D, /*rowwise=*/true); constexpr int C_BLOCK = 64; - at::parallel_for(c_begin, c_end, C_BLOCK, [&](int64_t c0, int64_t c1) { + at::parallel_for(0, num_non_zero_columns, C_BLOCK, [&](int64_t c0, int64_t c1) { grad_t grad_blocked_buffer[C_BLOCK * D]; for (int64_t c = c0; c < c1; c += C_BLOCK) { const int* offsets_begin_ptr = col_segment_ptr + c; @@ -149,11 +146,11 @@ void split_embedding_backward_exact_cpu_kernel( B, reinterpret_cast( grad_output_data + D_begin), - batched_cscs[t].row_indices + *offsets_begin_ptr, + cscs[t].row_indices + *offsets_begin_ptr, offsets_begin_ptr, - batched_cscs[t].weights == nullptr + cscs[t].weights == nullptr ? nullptr - : batched_cscs[t].weights + *offsets_begin_ptr, + : cscs[t].weights + *offsets_begin_ptr, reinterpret_cast(grad_blocked_buffer)); if (!success) { @@ -163,7 +160,7 @@ void split_embedding_backward_exact_cpu_kernel( c, c_block_end, col_segment_ptr, - batched_cscs[t].row_indices, + cscs[t].row_indices, hash_size, /*allow_minus_one=*/false); } @@ -195,29 +192,28 @@ void split_embedding_backward_exact_cpu_kernel( // TODO: to parallelize, we should easily identify segments belong to // the same column. at::acc_type grad_buffer[D]; - for (int c = c_begin; c < c_end; ++c) { + for (int c = 0; c < num_non_zero_columns; ++c) { int64_t idx = col_segment_indices[c]; - if (c == c_begin || col_segment_indices[c - 1] != idx) { + if (c == 0 || col_segment_indices[c - 1] != idx) { memset(grad_buffer, 0, D * sizeof(at::acc_type)); } const int64_t embedding_begin = table_begin + idx * D; for (int r = col_segment_ptr[c]; r < col_segment_ptr[c + 1]; ++r) { int D_offset = D_begin; if (is_shared_table) { - D_offset += - batched_cscs[t].column_segment_ids[is_csr2csc_sort ? 
r : c] * D; + D_offset += cscs[t].column_segment_ids[r] * D; } - int b = batched_cscs[t].row_indices[r]; + int b = cscs[t].row_indices[r]; for (int64_t d = 0; d < D; ++d) { - if (batched_cscs[t].weights != nullptr) { + if (cscs[t].weights != nullptr) { grad_buffer[d] += grad_output_data[b * grad_stride + D_offset + d] * - batched_cscs[t].weights[r]; + cscs[t].weights[r]; } else { grad_buffer[d] += grad_output_data[b * grad_stride + D_offset + d]; } } } - if (c == c_end - 1 || col_segment_indices[c + 1] != idx) { + if (c == num_non_zero_columns - 1 || col_segment_indices[c + 1] != idx) { {{ split_weight_update_cpu }} } } // for each c diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp index 7a6c4698e..e16d103d7 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp @@ -14,6 +14,9 @@ #include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; +using namespace fbgemm_gpu; + +/// @defgroup embedding-cpu Embedding CPU Operators void split_embedding_backward_codegen_{{ optimizer }}_cpu( Tensor grad_output, @@ -176,6 +179,7 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function< } }; +///@ingroup embedding-cpu Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_cpu( Tensor host_weights, Tensor weights_placements, diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp index bf2c3dee9..70fcd1f54 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp @@ -14,6 +14,9 @@ #include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; +using namespace fbgemm_gpu; + +/// @defgroup embedding-cuda Embedding CUDA Operators Tensor split_embedding_codegen_forward_unweighted_cuda( Tensor dev_weights, @@ -186,6 +189,7 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op : {% for (var, _) in args.saved_data %} ctx->saved_data["{{ var }}"] = {{ var }}; {% endfor %} + {% if not nobag %} #ifdef __HIP_PLATFORM_HCC__ constexpr int32_t BT_block_size = 64; @@ -270,9 +274,11 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op : using torch::autograd::Variable; auto grad_output = gradient_clipping ? 
clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0]; - if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0 || - grad_output.stride(1) != 1 || - grad_output.stride(0) % 4 != 0) { + // FIXME: to support aligned memory access in Vec4T load/store function + // 16 for FP32 and 8 for FP16 + if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0) { + grad_output = at::empty_like(grad_output).copy_(grad_output); + } else if (!grad_output.is_contiguous()) { grad_output = grad_output.contiguous(); } @@ -423,6 +429,7 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op : }; {% endfor %} +///@ingroup embedding-cuda Tensor split_embedding_codegen_lookup_{{ optimizer }}_function( Tensor placeholder_autograd_tensor, Tensor dev_weights, diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index a18ae700e..a8d9c4b70 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -968,6 +968,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ // over 48 KB per block are architecture-specific, as such they // must use dynamic shared memory (rather than statically sized // arrays) and require an explicit opt-in using cudaFuncSetAttribute()". + #ifndef __HIP_PLATFORM_HCC__ cudaFuncSetAttribute( split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_cta_per_row_1< diff --git a/fbgemm_gpu/codegen/embedding_bounds_check.cu b/fbgemm_gpu/codegen/embedding_bounds_check.cu index 8d7c5d196..4d77d2b50 100644 --- a/fbgemm_gpu/codegen/embedding_bounds_check.cu +++ b/fbgemm_gpu/codegen/embedding_bounds_check.cu @@ -13,9 +13,9 @@ template __device__ void adjust_offset_kernel( index_t& indices_start, index_t& indices_end, - index_t num_indices, - index_t* offset_acc_start, - index_t* offset_acc_end) { + const index_t num_indices, + index_t* const offset_acc_start, + index_t* const offset_acc_end) { indices_start = std::max(static_cast(0), std::min(indices_start, num_indices)); indices_end = std::max(indices_start, std::min(indices_end, num_indices)); @@ -29,7 +29,7 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel( rows_per_table, at::PackedTensorAccessor32 indices, at::PackedTensorAccessor32 offsets, - int64_t bounds_check_mode_, + const int64_t bounds_check_mode_, at::PackedTensorAccessor32 warning, FixedDivisor fd) { int32_t T = rows_per_table.size(0); @@ -84,10 +84,10 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel( &offsets[t * B + b + 1]); } - auto L = indices_end - indices_start; + const auto L = indices_end - indices_start; for (index_t i = static_cast(threadIdx.x); i < L; i += static_cast(fbgemm_gpu::kWarpSize)) { - auto idx = indices[indices_start + i]; + const auto idx = indices[indices_start + i]; if (idx == -1) { // -1 indicates pruned rows. 
continue; @@ -161,16 +161,17 @@ void bounds_check_indices_cuda( at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(rows_per_table.get_device()); - int32_t T = rows_per_table.size(0); - int32_t B = (offsets.size(0) - 1) / T; + const int32_t T = rows_per_table.size(0); + const int32_t B = (offsets.size(0) - 1) / T; if (B == 0 || T == 0) { return; } - auto bounds_check_mode = static_cast(bounds_check_mode_); + const auto bounds_check_mode = + static_cast(bounds_check_mode_); if (bounds_check_mode == BoundsCheckMode::WARNING) { warning.zero_(); } - int64_t num_indices = indices.size(0); + const int64_t num_indices = indices.size(0); TORCH_CHECK( offsets.size(0) == B * T + 1, diff --git a/fbgemm_gpu/codegen/embedding_bounds_check_host.cpp b/fbgemm_gpu/codegen/embedding_bounds_check_host.cpp index 8d2ead3e5..84575a336 100644 --- a/fbgemm_gpu/codegen/embedding_bounds_check_host.cpp +++ b/fbgemm_gpu/codegen/embedding_bounds_check_host.cpp @@ -14,6 +14,9 @@ using Tensor = at::Tensor; +///@defgroup embedding-cuda Embedding CUDA Operators + +///@ingroup embedding-cuda void bounds_check_indices_cuda( Tensor& rows_per_table, Tensor& indices, diff --git a/fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp b/fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp index 88d5893d9..a2dd19a75 100644 --- a/fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp +++ b/fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp @@ -12,6 +12,10 @@ #include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; +using namespace fbgemm_gpu; + +///@defgroup embedding-cpu Embedding CPU Operators +/// namespace { @@ -31,6 +35,7 @@ void adjust_offset_cpu( *offsets_acc_end = indices_end; } +///@addtogroup embedding-cpu void bounds_check_indices_cpu( Tensor& rows_per_table, Tensor& indices, diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp index 03db60ac1..5d2f10c51 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp @@ -19,6 +19,8 @@ #include #include +using namespace fbgemm_gpu; + namespace { using Tensor = at::Tensor; @@ -116,7 +118,8 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cpu( Tensor indice_weights, {% endif %} int64_t output_dtype, - int64_t unused + int64_t fp8_exponent_bits, + int64_t fp8_exponent_bias ) { TENSOR_ON_CPU(dev_weights); TENSOR_ON_CPU(uvm_weights); @@ -160,7 +163,7 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cpu( const auto* weights_tys_acc = weights_tys.data_ptr(); - DISPATCH_OUTPUT_TYPES(output.type(), "intn_split_embedding_codegen_forward_kernel", [&] { + DISPATCH_OUTPUT_TYPES(output.scalar_type(), "intn_split_embedding_codegen_forward_kernel", [&] { auto* output_acc = output.data_ptr(); {% if weighted %} const float* indice_weights_acc = indice_weights.data_ptr(); @@ -246,6 +249,26 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cpu( offsets_begin_ptr, indice_weights_ptr, reinterpret_cast(output_acc + D_start)); + } else if (weight_ty == SparseType::FP8) { + assert(fp8_exponent_bits > 0 && fp8_exponent_bias > 0); + auto kernel = fbgemm::GenerateEmbeddingSpMDMFP8WithStrides( + D, + normalize_by_lengths, + /*is_weight_positional=*/false, + /*use_offsets=*/true, + /*output_stride=*/total_D, + /*input_stride=*/D_bytes / sizeof(uint8_t), + /*exponent_bits=*/fp8_exponent_bits, + /*exponent_bias=*/fp8_exponent_bias); + success = kernel( + B, + 
index_size, + num_rows, + weights, + indices_acc + *offsets_begin_ptr, + offsets_begin_ptr, + indice_weights_ptr, + reinterpret_cast(output_acc + D_start)); } else if (weight_ty == SparseType::INT8) { auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides( D, diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp b/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp index 13cacd48a..be9d1b476 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp @@ -14,6 +14,10 @@ #include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; +using namespace fbgemm_gpu; + +///@defgroup embedding-cuda Embedding CUDA Operators +/// Tensor int_nbit_split_embedding_codegen_forward_unweighted_cuda( Tensor dev_weights, @@ -35,7 +39,9 @@ Tensor int_nbit_split_embedding_codegen_forward_unweighted_cuda( int64_t output_dtype, Tensor lxu_cache_weights, Tensor lxu_cache_locations, - int64_t unused); + int64_t max_float8_D, + int64_t fp8_exponent_bits, + int64_t fp8_exponent_bias); Tensor int_nbit_split_embedding_codegen_forward_weighted_cuda( Tensor dev_weights, @@ -58,7 +64,9 @@ Tensor int_nbit_split_embedding_codegen_forward_weighted_cuda( int64_t output_dtype, Tensor lxu_cache_weights, Tensor lxu_cache_locations, - int64_t unused); + int64_t max_float8_D, + int64_t fp8_exponent_bits, + int64_t fp8_exponent_bias); Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cuda( Tensor dev_weights, @@ -78,8 +86,11 @@ Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cuda( int64_t output_dtype, Tensor lxu_cache_weights, Tensor lxu_cache_locations, - int64_t unused); + int64_t max_float8_D, + int64_t fp8_exponent_bits, + int64_t fp8_exponent_bias); +///@ingroup embedding-cuda Tensor int_nbit_split_embedding_codegen_lookup_function( Tensor dev_weights, Tensor uvm_weights, @@ -100,10 +111,18 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( int64_t output_dtype, c10::optional lxu_cache_weights, c10::optional lxu_cache_locations, - c10::optional row_alignment) { + c10::optional row_alignment, + c10::optional max_float8_D, + c10::optional fp8_exponent_bits, + c10::optional fp8_exponent_bias) { if (static_cast(pooling_mode) == PoolingMode::NONE) { std::vector max_D_list{ - max_int2_D, max_int4_D, max_int8_D, max_float16_D, max_float32_D}; + max_int2_D, + max_int4_D, + max_int8_D, + max_float8_D ? *max_float8_D : 0, + max_float16_D, + max_float32_D}; int64_t max_D = *std::max_element(max_D_list.begin(), max_D_list.end()); return int_nbit_split_embedding_nobag_codegen_forward_unweighted_cuda( dev_weights, @@ -123,7 +142,9 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( output_dtype, lxu_cache_weights.value_or(at::empty({0, 0}, at::kByte)), lxu_cache_locations.value_or(at::empty({0}, at::kInt)), - 0); + max_float8_D ? *max_float8_D : 0, + fp8_exponent_bits ? *fp8_exponent_bits : -1, + fp8_exponent_bias ? *fp8_exponent_bias : -1); } if (!indice_weights) { return int_nbit_split_embedding_codegen_forward_unweighted_cuda( @@ -146,7 +167,9 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( output_dtype, lxu_cache_weights.value_or(at::empty({0, 0}, at::kByte)), lxu_cache_locations.value_or(at::empty({0}, at::kInt)), - 0); + max_float8_D ? *max_float8_D : 0, + fp8_exponent_bits ? *fp8_exponent_bits : -1, + fp8_exponent_bias ? 
*fp8_exponent_bias : -1); } return int_nbit_split_embedding_codegen_forward_weighted_cuda( dev_weights, @@ -169,15 +192,19 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( output_dtype, lxu_cache_weights.value_or(at::empty({0, 0}, at::kByte)), lxu_cache_locations.value_or(at::empty({0}, at::kInt)), - 0); + max_float8_D ? *max_float8_D : 0, + fp8_exponent_bits ? *fp8_exponent_bits : -1, + fp8_exponent_bias ? *fp8_exponent_bias : -1); } +///@ingroup embedding-cuda Tensor pruned_hashmap_lookup_unweighted_cuda( Tensor indices, Tensor offsets, Tensor hash_table, Tensor hash_table_offsets); +///@ingroup embedding-cuda Tensor pruned_array_lookup_cuda( Tensor indices, Tensor offsets, diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp b/fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp index 3100db743..82876fa70 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp @@ -20,6 +20,9 @@ using Tensor = at::Tensor; +///@defgroup embedding-cpu Embedding CPU Operators +/// + Tensor int_nbit_split_embedding_codegen_forward_unweighted_cpu( Tensor dev_weights, Tensor uvm_weights, @@ -33,7 +36,8 @@ Tensor int_nbit_split_embedding_codegen_forward_unweighted_cpu( int64_t pooling_mode, int64_t row_alignment, int64_t output_dtype, - int64_t unused); + int64_t fp8_exponent_bits, + int64_t fp8_exponent_bias); Tensor int_nbit_split_embedding_codegen_forward_weighted_cpu( Tensor dev_weights, @@ -49,8 +53,10 @@ Tensor int_nbit_split_embedding_codegen_forward_weighted_cpu( int64_t row_alignment, Tensor indice_weights, int64_t output_dtype, - int64_t unused); + int64_t fp8_exponent_bits, + int64_t fp8_exponent_bias); +///@ingroup embedding-cpu Tensor int_nbit_split_embedding_codegen_lookup_function_cpu( Tensor dev_weights, Tensor uvm_weights, // to match the interface of CUDA op using UVM @@ -73,7 +79,10 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu( lxu_cache_weights, // Not used, to match cache interface for CUDA op c10::optional lxu_cache_locations, // Not used, to match cache interface for CUDA op - c10::optional row_alignment) { + c10::optional row_alignment, + c10::optional max_float8_D, + c10::optional fp8_exponent_bits, + c10::optional fp8_exponent_bias) { if (!indice_weights) { return int_nbit_split_embedding_codegen_forward_unweighted_cpu( dev_weights, @@ -88,7 +97,8 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu( pooling_mode, row_alignment ? *row_alignment : 1, output_dtype, - 0); + fp8_exponent_bits ? *fp8_exponent_bits : -1, + fp8_exponent_bias ? *fp8_exponent_bias : -1); } return int_nbit_split_embedding_codegen_forward_weighted_cpu( dev_weights, @@ -104,9 +114,11 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu( row_alignment ? *row_alignment : 1, *indice_weights, output_dtype, - 0); + fp8_exponent_bits ? *fp8_exponent_bits : -1, + fp8_exponent_bias ? 
*fp8_exponent_bias : -1); } +///@ingroup embedding-cpu void pruned_hashmap_insert_unweighted_cpu( Tensor indices, Tensor dense_indices, @@ -114,12 +126,14 @@ void pruned_hashmap_insert_unweighted_cpu( Tensor hash_table, Tensor hash_table_offsets); +///@ingroup embedding-cpu Tensor pruned_hashmap_lookup_unweighted_cpu( Tensor indices, Tensor offsets, Tensor hash_table, Tensor hash_table_offsets); +///@ingroup embedding-cpu Tensor pruned_array_lookup_cpu( Tensor indices, Tensor offsets, @@ -128,7 +142,7 @@ Tensor pruned_array_lookup_cpu( TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( - "int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, int total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment = None) -> Tensor"); + "int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, int total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment = None, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1) -> Tensor"); DISPATCH_TO_CPU( "int_nbit_split_embedding_codegen_lookup_function", int_nbit_split_embedding_codegen_lookup_function_cpu); diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu index eea68c958..ab56c2028 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu @@ -20,6 +20,7 @@ constexpr int32_t kCacheLocationMissing = -1; __device__ inline int32_t padded_D(int32_t dim, SparseType weight_ty) { if (weight_ty == SparseType::FP32) { return dim; } if (weight_ty == SparseType::FP16) { return dim; } + if (weight_ty == SparseType::FP8) { return dim; } if (weight_ty == SparseType::INT8) { return dim + 4; } if (weight_ty == SparseType::INT4) { return dim + 8; } if (weight_ty == SparseType::INT2) { return dim + 16; } @@ -126,7 +127,7 @@ void cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { "Size is not supported"); unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); - int src_in_bytes = (pred_guard ? SizeInBytes : 0); + const int src_in_bytes = pred_guard ? SizeInBytes : 0; asm volatile( "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), @@ -150,10 +151,10 @@ void cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { {% for nobag in [True, False] %} {% if not nobag or not weighted %} // TODO: increase code sharing (templates for accumulator_ty, accumulation, outputs per thread, etc?) 
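Editor's note: the updated operator schema and host wrappers above thread three new optional FP8 arguments (max_float8_D, fp8_exponent_bits, fp8_exponent_bias) through to the kernels, unwrapping each to its schema default (0, -1, -1) when the caller omits it. A minimal standalone sketch of that defaulting pattern follows; it is illustrative only and not part of the patch, and it uses std::optional in place of c10::optional so it compiles on its own.

#include <cstdint>
#include <iostream>
#include <optional>

// Mirrors the `arg ? *arg : fallback` expressions used when forwarding the
// optional FP8 schema arguments to the CUDA/CPU forward functions.
static int64_t value_or_default(std::optional<int64_t> arg, int64_t fallback) {
  return arg ? *arg : fallback;
}

int main() {
  std::optional<int64_t> max_float8_D;           // caller configured no FP8 tables
  std::optional<int64_t> fp8_exponent_bits = 4;  // e.g. an e4m3-style layout
  std::cout << value_or_default(max_float8_D, 0) << '\n';        // prints 0
  std::cout << value_or_default(fp8_exponent_bits, -1) << '\n';  // prints 4
  return 0;
}

A caller that never configures FP8 tables therefore takes the same path as before this change, since max_float8_D defaults to 0 and the FP8 kernel launch is guarded by max_float8_D > 0.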
-{% for bit_width in [32, 16, 8, 4, 2] %} +{% for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %} template __launch_bounds__(WarpsPerBlock * kWarpSize) -__global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L( +__global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L( const at::PackedTensorAccessor64 dev_weights, const at::PackedTensorAccessor64 uvm_weights, const at::PackedTensorAccessor32 weights_placements, @@ -162,30 +163,34 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i {% if not nobag %} const at::PackedTensorAccessor32 D_offsets, {% else %} - int64_t D, + const int64_t D, {% endif %} const at::PackedTensorAccessor32 indices, const at::PackedTensorAccessor32 offsets, {% if not nobag %} - int64_t pooling_mode, + const int64_t pooling_mode, {% endif %} - int64_t row_alignment, + const int64_t row_alignment, {% if weighted %} at::PackedTensorAccessor32 indice_weights, {% endif %} + {% if type_map[emb_weight_type].enum_name == "FP8" %} + const int exponent_bits, + const int exponent_bias, + {% endif %} at::PackedTensorAccessor32 output, // [B][total_D], const at::PackedTensorAccessor64 lxu_cache_weights, const at::PackedTensorAccessor32 lxu_cache_locations ) { - int32_t T = weights_offsets.size(0); + const int32_t T = weights_offsets.size(0); {% if not nobag %} - int32_t B = output.size(0); + const int32_t B = output.size(0); {% else %} - int32_t B = (offsets.size(0) - 1) / T; + const int32_t B = (offsets.size(0) - 1) / T; {% endif %} - int32_t bb_t = blockIdx.x * blockDim.y + threadIdx.y; + const int32_t bb_t = blockIdx.x * blockDim.y + threadIdx.y; if (bb_t >= div_round_up(B, OutputRowsPerThread) * T) { return; } @@ -202,7 +207,7 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i const int32_t D = D_end - D_start; {% endif %} SparseType weight_ty = static_cast(weights_tys[t]); - if (weight_ty != SparseType::{{ type_map[bit_width].enum_name }}) { + if (weight_ty != SparseType::{{ type_map[emb_weight_type].enum_name }}) { return; } @@ -213,9 +218,9 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i return; } - uint32_t bb = bb_t % div_round_up(B, OutputRowsPerThread); + const uint32_t bb = bb_t % div_round_up(B, OutputRowsPerThread); - int64_t weights_offset = weights_offsets[t]; + const int64_t weights_offset = weights_offsets[t]; const int32_t D_total = padded_D(D, weight_ty); const int32_t D_padding = D_total - D; @@ -240,13 +245,13 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i } else { weights = &uvm_weights[weights_offset]; } - constexpr size_t kOutputsPerThread = {{ (32 // bit_width) }}; + constexpr size_t kOutputsPerThread = {{ (32 // type_map[emb_weight_type].bit_width) }}; constexpr uint32_t NumUint4PerRow = MaxNum128BRows * 128 / sizeof(uint4); const uint32_t uint4_loads_per_row = div_round_up(D_bytes, sizeof(uint4)); {% if not nobag %} - VecNT<{{ (32 // bit_width) }}> accumulators[OutputRowsPerThread][MaxNum128BRows]; + VecNT<{{ (32 // type_map[emb_weight_type].bit_width) }}, PrimitiveType::{{ type_map[emb_weight_type].primitive_type }}> accumulators[OutputRowsPerThread][MaxNum128BRows]; {% endif %} for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) { @@ -304,7 +309,7 @@ __global__ void {{ type_map[bit_width].enum_name 
}}_split_embedding{{ "_nobag" i // scale and bias are at the beginning of each row. // rationale: have scale/shift at start since these get loaded first // and then broadcasted around so it might speed up the first cache miss. - {% if bit_width in [8, 4, 2] %} + {% if type_map[emb_weight_type].primitive_type == "INT" %} half2 shift_scale = reinterpret_cast(row)[0]; {% endif %} @@ -312,16 +317,16 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx]; {% endif %} - using scalar_t = {{ type_map[bit_width].cpp_type_name }}; + using scalar_t = {{ type_map[emb_weight_type].cpp_type_name }}; {% if not nobag %} #pragma unroll MaxNum128BRows for (uint32_t j = 0; j < MaxNum128BRows; ++j) { scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; {% if weighted %} - accumulators[i][j].fma(v, {% if bit_width in [8, 4, 2] %} shift_scale, {% endif %} row_weight); + accumulators[i][j].fma(v, {% if type_map[emb_weight_type].primitive_type == "INT" %} shift_scale, {% elif type_map[emb_weight_type].enum_name == "FP8" %} exponent_bits, exponent_bias, {% endif %} row_weight); {% else %} - accumulators[i][j].add(v{% if bit_width in [8, 4, 2] %}, shift_scale {% endif %}); + accumulators[i][j].add(v{% if type_map[emb_weight_type].primitive_type == "INT" %}, shift_scale {% elif type_map[emb_weight_type].enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); {% endif %} } {% else %} @@ -336,8 +341,8 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; if (output_d >= 0 && output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // bit_width) }})); - VecNT<{{ (32 // bit_width) }}> acc(v{% if bit_width in [8, 4, 2] %}, shift_scale {% endif %}); + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // type_map[emb_weight_type].bit_width) }})); + VecNT<{{ (32 // type_map[emb_weight_type].bit_width) }}, PrimitiveType::{{ type_map[emb_weight_type].primitive_type }}> acc(v{% if type_map[emb_weight_type].primitive_type == "INT" %}, shift_scale {% elif type_map[emb_weight_type].enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); acc.store(&output[output_j][output_d], num_valid_outputs); } } @@ -351,10 +356,10 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i for (uint32_t j = 0; j < MaxNum128BRows; ++j) { int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - VecNT<{{ (32 // bit_width) }}> acc(v{% if bit_width in [8, 4, 2] %}, shift_scale {% endif %}); + VecNT<{{ (32 // type_map[emb_weight_type].bit_width) }}, PrimitiveType::{{ type_map[emb_weight_type].primitive_type }}> acc(v{% if type_map[emb_weight_type].primitive_type == "INT" %}, shift_scale {% elif type_map[emb_weight_type].enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); if (output_d >= 0 && output_d < D) { - thread_local_max = max(thread_local_max, float{{ (32 // bit_width) }}_max(acc.acc)); - thread_local_min = min(thread_local_min, float{{ (32 // bit_width) }}_min(acc.acc)); + thread_local_max = max(thread_local_max, float{{ (32 // type_map[emb_weight_type].bit_width) }}_max(acc.acc)); + thread_local_min = 
min(thread_local_min, float{{ (32 // type_map[emb_weight_type].bit_width) }}_min(acc.acc)); } } qparams = warp_find_qparams(thread_local_min, thread_local_max); @@ -363,8 +368,8 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; if (output_d >= 0 && output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // bit_width) }})); - VecNT<{{ (32 // bit_width) }}> acc(v{% if bit_width in [8, 4, 2] %}, shift_scale {% endif %}); + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // type_map[emb_weight_type].bit_width) }})); + VecNT<{{ (32 // type_map[emb_weight_type].bit_width) }}, PrimitiveType::{{ type_map[emb_weight_type].primitive_type }}> acc(v{% if type_map[emb_weight_type].primitive_type == "INT" %}, shift_scale {% elif type_map[emb_weight_type].enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); acc.store(&output[output_j][output_d], qparams, num_valid_outputs); } } @@ -393,7 +398,7 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i } if (output_d >= 0 && output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // bit_width) }})); + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // type_map[emb_weight_type].bit_width) }})); accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs); } @@ -411,19 +416,19 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i accumulators[i][j].mul(inv_L); } if (output_d >= 0 && output_d < D) { - thread_local_max = max(thread_local_max, float{{ (32 // bit_width) }}_max(accumulators[i][j].acc)); - thread_local_min = min(thread_local_min, float{{ (32 // bit_width) }}_min(accumulators[i][j].acc)); + thread_local_max = max(thread_local_max, float{{ (32 // type_map[emb_weight_type].bit_width) }}_max(accumulators[i][j].acc)); + thread_local_min = min(thread_local_min, float{{ (32 // type_map[emb_weight_type].bit_width) }}_min(accumulators[i][j].acc)); } } qparams = warp_find_qparams(thread_local_min, thread_local_max); - int output_D_start = D_start + t * 8; - int output_D_end = output_D_start + D; + const int output_D_start = D_start + t * 8; + const int output_D_end = output_D_start + D; #pragma unroll MaxNum128BRows for (uint32_t j = 0; j < MaxNum128BRows; ++j) { const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; if (output_d >= 0 && output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // bit_width) }})); + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // type_map[emb_weight_type].bit_width) }})); accumulators[i][j].store(&output[b][output_D_start + output_d], qparams, num_valid_outputs); } } @@ -436,7 +441,7 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i } {% endif %} } -{% endfor %} // for bit_width in [32, 16, 8, 4, 2] +{% endfor %} // for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] {% endif %} // if not nobag or not weighted {% endfor %} // for nobag in [True, False] @@ -455,23 +460,23 @@ __global__ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_ const at::PackedTensorAccessor32 offsets, const 
at::PackedTensorAccessor64 hash_table, const at::PackedTensorAccessor32 hash_table_offsets, - int32_t B, - int32_t T, + const int32_t B, + const int32_t T, at::PackedTensorAccessor32 dense_indices) { // uint32_t capacity = hash_table.size(0); - int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y; - int32_t t = b_t / B; - int32_t b = b_t % B; + const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y; + const int32_t t = b_t / B; + const int32_t b = b_t % B; if (b_t >= B * T) { return; } - int32_t indices_start = offsets[t * B + b]; - int32_t indices_end = offsets[t * B + b + 1]; - int32_t L = indices_end - indices_start; + const int32_t indices_start = offsets[t * B + b]; + const int32_t indices_end = offsets[t * B + b + 1]; + const int32_t L = indices_end - indices_start; - int64_t table_start = hash_table_offsets[t]; - int64_t table_end = hash_table_offsets[t + 1]; - int64_t capacity = table_end - table_start; + const int64_t table_start = hash_table_offsets[t]; + const int64_t table_end = hash_table_offsets[t + 1]; + const int64_t capacity = table_end - table_start; if (capacity == 0) { // No pruning applied on the indices associated with this table. @@ -481,21 +486,21 @@ __global__ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_ return; } - uint32_t subwarp_id = threadIdx.x / 4; - uint32_t subwarp_tid = threadIdx.x % 4; + const uint32_t subwarp_id = threadIdx.x / 4; + const uint32_t subwarp_tid = threadIdx.x % 4; #ifdef __HIP_PLATFORM_HCC__ - uint64_t subwarp_mask = static_cast(0xF) << (4 * subwarp_id); + const uint64_t subwarp_mask = static_cast(0xF) << (4 * subwarp_id); #else - uint32_t subwarp_mask = static_cast(0xF) << (4 * subwarp_id); + const uint32_t subwarp_mask = static_cast(0xF) << (4 * subwarp_id); #endif for (int32_t l_start = 0; l_start + subwarp_id < L; l_start += kWarpSize / 4) { - int32_t idx = indices[indices_start + l_start + subwarp_id]; + const int32_t idx = indices[indices_start + l_start + subwarp_id]; uint32_t slot_start = pruned_hash_function(static_cast(idx)) % capacity; while (true) { - uint32_t slot = (slot_start + subwarp_tid) % capacity; - int2 val = *reinterpret_cast(&hash_table[table_start + static_cast(slot)][0]); - int32_t slot_sparse_idx = val.x; - int32_t slot_dense_idx = val.y; + const uint32_t slot = (slot_start + subwarp_tid) % capacity; + const int2 val = *reinterpret_cast(&hash_table[table_start + static_cast(slot)][0]); + const int32_t slot_sparse_idx = val.x; + const int32_t slot_dense_idx = val.y; bool found = false; bool empty = false; @@ -505,20 +510,9 @@ __global__ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_ found = true; dense_indices[indices_start + l_start + subwarp_id] = slot_dense_idx; } -#ifdef __HIP_PLATFORM_HCC__ - // FIXME: __any_sync with mask isn't supported by HIP yet. - // See https://fburl.com/fvy7j0lq for the similar context. 
- // assert false here with https://fburl.com/pfm7enw2 - if (__any_sync(subwarp_mask, found)) { -#else if (__any_sync(subwarp_mask, found)) { -#endif break; -#ifdef __HIP_PLATFORM_HCC__ - } else if (__any_sync(subwarp_mask, empty)) { -#else } else if (__any_sync(subwarp_mask, empty)) { -#endif dense_indices[indices_start + l_start + subwarp_id] = -1; break; } @@ -533,22 +527,22 @@ __global__ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_ const at::PackedTensorAccessor32 offsets, const at::PackedTensorAccessor32 index_remappings, const at::PackedTensorAccessor32 index_remappings_offsets, - int32_t B, - int32_t T, + const int32_t B, + const int32_t T, at::PackedTensorAccessor32 dense_indices) { - int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y; - int32_t t = b_t / B; - int32_t b = b_t % B; + const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y; + const int32_t t = b_t / B; + const int32_t b = b_t % B; if (b_t >= B * T) { return; } - int32_t indices_start = offsets[t * B + b]; - int32_t indices_end = offsets[t * B + b + 1]; - int32_t L = indices_end - indices_start; + const int32_t indices_start = offsets[t * B + b]; + const int32_t indices_end = offsets[t * B + b + 1]; + const int32_t L = indices_end - indices_start; - int64_t index_remappings_start = index_remappings_offsets[t]; - int64_t index_remappings_end = index_remappings_offsets[t + 1]; - int64_t capacity = index_remappings_end - index_remappings_start; + const int64_t index_remappings_start = index_remappings_offsets[t]; + const int64_t index_remappings_end = index_remappings_offsets[t + 1]; + const int64_t capacity = index_remappings_end - index_remappings_start; for (int32_t l = threadIdx.x; l < L; l += blockDim.x) { int32_t idx = indices[indices_start + l]; @@ -569,28 +563,30 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ Tensor weights_tys, {% if not nobag %} Tensor D_offsets, - int64_t total_D, + const int64_t total_D, {% else %} - int64_t D, + const int64_t D, {% endif %} - int64_t max_int2_D, - int64_t max_int4_D, - int64_t max_int8_D, - int64_t max_float16_D, - int64_t max_float32_D, + const int64_t max_int2_D, + const int64_t max_int4_D, + const int64_t max_int8_D, + const int64_t max_float16_D, + const int64_t max_float32_D, Tensor indices, Tensor offsets, {% if not nobag %} - int64_t pooling_mode, + const int64_t pooling_mode, {% endif %} - int64_t row_alignment, + const int64_t row_alignment, {% if weighted %} Tensor indice_weights, {% endif %} - int64_t output_dtype, + const int64_t output_dtype, Tensor lxu_cache_weights, Tensor lxu_cache_locations, - int64_t unused + const int64_t max_float8_D, + const int64_t fp8_exponent_bits, + const int64_t fp8_exponent_bias ) { TENSOR_ON_CUDA_GPU(dev_weights); TENSOR_ON_CUDA_GPU(uvm_weights); @@ -612,14 +608,14 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ device_guard.set_index(dev_weights.get_device()); {% if not nobag %} - int32_t T = D_offsets.numel() - 1; + const int32_t T = D_offsets.numel() - 1; {% else %} - int32_t total_L = indices.numel(); - int32_t T = weights_offsets.numel(); + const int32_t total_L = indices.numel(); + const int32_t T = weights_offsets.numel(); {% endif %} TORCH_CHECK(T > 0); // offsets = [B x T + 1] - int32_t B = (offsets.size(0) - 1) / T; + const int32_t B = (offsets.size(0) - 1) / T; TORCH_CHECK(B >= 0); {% if not nobag %} @@ -685,7 +681,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ ); \ 
C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - DISPATCH_OUTPUT_TYPES(output.type(), "int2_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { + DISPATCH_OUTPUT_TYPES(output.scalar_type(), "int2_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_int2_D > 0) { auto max_int2_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_int2_D, SparseType::INT2, row_alignment), 128); TORCH_CHECK(max_int2_128b_rows <= 2); @@ -729,7 +725,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ ); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - DISPATCH_OUTPUT_TYPES(output.type(), "int4_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { + DISPATCH_OUTPUT_TYPES(output.scalar_type(), "int4_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_int4_D > 0) { auto max_int4_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_int4_D, SparseType::INT4, row_alignment), 128); TORCH_CHECK(max_int4_128b_rows <= 4); @@ -746,7 +742,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ })); #undef X - // launch 8-bit kernel + // launch 8-bit int kernel #define X(OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ nbit::INT8_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L<<< \ nbit::div_round_up(T * nbit::div_round_up(B, OutputRowsPerThread), kWarpsPerBlock), \ @@ -776,7 +772,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ ); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - DISPATCH_OUTPUT_TYPES(output.type(), "int8_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { + DISPATCH_OUTPUT_TYPES(output.scalar_type(), "int8_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_int8_D > 0) { auto max_int8_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_int8_D, SparseType::INT8, row_alignment), 128); TORCH_CHECK(max_int8_128b_rows <= 8); @@ -796,6 +792,58 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ })); #undef X + // launch 8-bit float kernel + #define X(OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + nbit::FP8_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L<<< \ + nbit::div_round_up(T * nbit::div_round_up(B, OutputRowsPerThread), kWarpsPerBlock), \ + dim3(kWarpSize, kWarpsPerBlock), \ + 0, \ + at::cuda::getCurrentCUDAStream()>>>( \ + dev_weights.packed_accessor64(), \ + uvm_weights.packed_accessor64(), \ + weights_placements.packed_accessor32(), \ + weights_offsets.packed_accessor32(), \ + weights_tys.packed_accessor32(), \ + {% if not nobag %} \ + D_offsets.packed_accessor32(), \ + {% else %} \ + D, \ + {% endif %} \ + indices.packed_accessor32(), \ + offsets.packed_accessor32(), \ + {% if not nobag %} \ + pooling_mode, \ + {% endif %} \ + row_alignment, \ + {% if weighted %} indice_weights.packed_accessor32(), {% endif %} \ + fp8_exponent_bits, \ + fp8_exponent_bias, \ + output.packed_accessor32(), \ + lxu_cache_weights.packed_accessor64(), \ + lxu_cache_locations.packed_accessor32() \ + ); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + + DISPATCH_OUTPUT_TYPES(output.scalar_type(), "fp8_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { + if (max_float8_D > 0) { + auto max_fp8_128b_rows = 
nbit::div_round_up(nbit::padded_row_size_in_bytes(max_float8_D, SparseType::FP8, row_alignment), 128); + TORCH_CHECK(max_fp8_128b_rows <= 8); + if (max_fp8_128b_rows > 0) { + X(2, 8, 0, 1); + } + if (max_fp8_128b_rows > 1) { + X(2, 4, 1, 2); + } + if (max_fp8_128b_rows > 2) { + X(2, 4, 2, 4); + } + if (max_fp8_128b_rows > 4) { + X(2, 4, 4, 8); + } + } + })); + #undef X + // launch 16-bit kernel #define X(OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ nbit::FP16_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L<<< \ @@ -826,7 +874,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ ); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - DISPATCH_OUTPUT_TYPES(output.type(), "fp16_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { + DISPATCH_OUTPUT_TYPES(output.scalar_type(), "fp16_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_float16_D > 0) { auto max_fp16_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_float16_D, SparseType::FP16, row_alignment), 128); TORCH_CHECK(max_fp16_128b_rows <= 16); @@ -876,7 +924,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ ); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - DISPATCH_OUTPUT_TYPES(output.type(), "fp32_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { + DISPATCH_OUTPUT_TYPES(output.scalar_type(), "fp32_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_float32_D > 0) { auto max_fp32_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_float32_D, SparseType::FP32, row_alignment), 128); TORCH_CHECK(max_fp32_128b_rows <= 32); @@ -906,8 +954,8 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cuda( at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(indices.get_device()); auto dense_indices = at::empty_like(indices); - int32_t T = hash_table_offsets.size(0) - 1; - int32_t B = (offsets.size(0) - 1) / T; + const int32_t T = hash_table_offsets.size(0) - 1; + const int32_t B = (offsets.size(0) - 1) / T; TORCH_CHECK(B > 0); TORCH_CHECK(hash_table.size(0) < std::numeric_limits::max()); constexpr size_t kForwardMaxThreads = 256; @@ -943,7 +991,7 @@ Tensor pruned_array_lookup_cuda( at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(indices.get_device()); auto dense_indices = at::empty_like(indices); - int32_t T = index_remappings_offsets.size(0) - 1; + const int32_t T = index_remappings_offsets.size(0) - 1; TORCH_CHECK( (offsets.size(0) - 1) % T == 0, "offsets.size() - 1 is not divisible by T! 
offsets.size: ", @@ -951,7 +999,7 @@ Tensor pruned_array_lookup_cuda( "T: ", T ); - int32_t B = (offsets.size(0) - 1) / T; + const int32_t B = (offsets.size(0) - 1) / T; TORCH_CHECK(B > 0, "offsets.size(): ", offsets.size(0), ", T: ", T, ", B: ", B); TORCH_CHECK(index_remappings.size(0) < std::numeric_limits::max()); TORCH_CHECK(indices.dim() == 1, "Tensor dim: ", indices.dim()); diff --git a/fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp b/fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp index c2eca730d..89c43fc4e 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp @@ -20,6 +20,7 @@ #include using Tensor = at::Tensor; +using namespace fbgemm_gpu; template void split_embedding_forward_cpu_kernel( @@ -318,267 +319,287 @@ Tensor split_embedding_codegen_grad_indice_weights_cpu( namespace internal { -template -void batched_csr2csc( - BatchedHyperCompressedSparseColumn& batched_csc, +namespace { + +template +void csr2csc_template_( + HyperCompressedSparseColumn& csc, int B, - // TODO: use accessor for the following 3 parameters - const at::TensorAccessor& batched_csr_offsets, - const at::TensorAccessor& batched_csr_indices, - const at::TensorAccessor& batched_csr_weights, + const at::TensorAccessor& csr_offsets, + const at::TensorAccessor& csr_indices, + const at::TensorAccessor& csr_weights, int64_t pooling_mode, const int* table_to_feature_offset, int64_t num_embeddings) { - int num_tables = 1; - batched_csc.num_tables = num_tables; - batched_csc.table_ptr = static_cast( - fbgemm::fbgemmAlignedAlloc(64, (num_tables + 1) * sizeof(int))); - batched_csc.table_ptr[0] = 0; - int64_t nnz = batched_csr_offsets[table_to_feature_offset[num_tables] * B] - - batched_csr_offsets[table_to_feature_offset[0] * B]; + csc.num_non_zero_columns = 0; + int64_t nnz = csr_offsets[table_to_feature_offset[1] * B] - + csr_offsets[table_to_feature_offset[0] * B]; if (nnz == 0) { - batched_csc.table_ptr[1] = 0; return; } - batched_csc.row_indices = + csc.row_indices = static_cast(fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(int))); - bool has_weights = batched_csr_weights.data() != nullptr; - if (has_weights || - static_cast(pooling_mode) == PoolingMode::MEAN) { - batched_csc.weights = static_cast( + bool has_weights = csr_weights.data() != nullptr; + if (IS_VALUE_PAIR) { + csc.weights = static_cast( fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(float))); } int column_ptr_curr = 0; - int t = 0; bool is_shared_table = - table_to_feature_offset[t + 1] > table_to_feature_offset[t] + 1; - auto NS = batched_csr_offsets[table_to_feature_offset[t + 1] * B] - - batched_csr_offsets[table_to_feature_offset[t] * B]; + table_to_feature_offset[1] > table_to_feature_offset[0] + 1; + auto NS = csr_offsets[table_to_feature_offset[1] * B] - + csr_offsets[table_to_feature_offset[0] * B]; int num_non_empty_segments = 0; - if (!batched_csc.weights) { - batched_csc.column_segment_ids = - static_cast(fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(int))); - - int* tmpBufKeys = - static_cast(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int))); - int* tmpBufValues = - static_cast(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int))); - int* tmpBuf1Keys = - static_cast(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int))); - int* tmpBuf1Values = - static_cast(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int))); - const auto FBo = batched_csr_offsets[table_to_feature_offset[t] * B]; - for (int feature = table_to_feature_offset[t]; - feature < table_to_feature_offset[t + 1]; - ++feature) { 
- const auto FBs = (feature - table_to_feature_offset[t]) * B; + + using pair_t = std::pair; + using value_t = typename std::conditional::type; + + csc.column_segment_ids = + static_cast(fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(int))); + int* tmpBufKeys = + static_cast(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int))); + value_t* tmpBufValues = static_cast( + fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(value_t))); + int* tmpBuf1Keys = + static_cast(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int))); + value_t* tmpBuf1Values = static_cast( + fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(value_t))); + + const auto FBo = csr_offsets[table_to_feature_offset[0] * B]; + for (int feature = table_to_feature_offset[0]; + feature < table_to_feature_offset[1]; + ++feature) { + const auto FBs = (feature - table_to_feature_offset[0]) * B; #pragma omp parallel for - for (int b = 0; b < B; ++b) { - const auto FBb = feature * B + b; - int64_t pool_begin = batched_csr_offsets[FBb]; - int64_t pool_end = batched_csr_offsets[FBb + 1]; - for (int64_t p = pool_begin; p < pool_end; ++p) { - tmpBufKeys[p - FBo] = batched_csr_indices[p]; - tmpBufValues[p - FBo] = FBs + b; + for (int b = 0; b < B; ++b) { + const auto FBb = feature * B + b; + int64_t pool_begin = csr_offsets[FBb]; + int64_t pool_end = csr_offsets[FBb + 1]; + int64_t L = pool_end - pool_begin; + // MEAN pooling will not work with indice_weights! + double scale_factor = + (static_cast(pooling_mode) == PoolingMode::MEAN && + !has_weights && L > 0) + ? 1.0 / L + : 1.0; + + for (int64_t p = pool_begin; p < pool_end; ++p) { + tmpBufKeys[p - FBo] = csr_indices[p]; + if (IS_VALUE_PAIR) { + reinterpret_cast(tmpBufValues)[p - FBo] = std::make_pair( + FBs + b, scale_factor * (has_weights ? csr_weights[p] : 1.0f)); + } else { + reinterpret_cast(tmpBufValues)[p - FBo] = FBs + b; } } } + } + + int* sorted_col_row_index_keys = nullptr; + value_t* sorted_col_row_index_values = nullptr; + + std::tie(sorted_col_row_index_keys, sorted_col_row_index_values) = + fbgemm_gpu::radix_sort_parallel( + tmpBufKeys, + tmpBufValues, + tmpBuf1Keys, + tmpBuf1Values, + NS, + num_embeddings); + + int max_thds = omp_get_max_threads(); + int num_uniq[max_thds][64]; - int* sorted_col_row_index_keys = nullptr; - int* sorted_col_row_index_values = nullptr; - std::tie(sorted_col_row_index_keys, sorted_col_row_index_values) = - fbgemm_gpu::radix_sort_parallel( - tmpBufKeys, - tmpBufValues, - tmpBuf1Keys, - tmpBuf1Values, - NS, - num_embeddings); - - int max_thds = omp_get_max_threads(); - int num_uniq[max_thds][64]; - int U = 0; - if (at::get_num_threads() > 1) { - // This block is not needed for single thread + int U = 0; + if (at::get_num_threads() > 1) { + // This block is not needed for single thread #pragma omp parallel - { - int tid = omp_get_thread_num(); - num_uniq[tid][0] = 0; + { + int tid = omp_get_thread_num(); + num_uniq[tid][0] = 0; #pragma omp for schedule(static) - for (int i = 1; i < NS; i++) { - if (sorted_col_row_index_keys[i] != - sorted_col_row_index_keys[i - 1]) { - num_uniq[tid][0]++; - } + for (int i = 1; i < NS; i++) { + if (sorted_col_row_index_keys[i] != sorted_col_row_index_keys[i - 1]) { + num_uniq[tid][0]++; } } - num_uniq[0][0] += 1; - for (int i = 1; i < max_thds; i++) - num_uniq[i][0] += num_uniq[i - 1][0]; - U = num_uniq[max_thds - 1][0]; } + num_uniq[0][0] += 1; + for (int i = 1; i < max_thds; i++) { + num_uniq[i][0] += num_uniq[i - 1][0]; + } + U = num_uniq[max_thds - 1][0]; + } - batched_csc.column_segment_ptr = static_cast( - 
fbgemm::fbgemmAlignedAlloc(64, (NS + 1) * sizeof(int))); - batched_csc.column_segment_indices = - static_cast(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int))); + csc.column_segment_ptr = + static_cast(fbgemm::fbgemmAlignedAlloc(64, (NS + 1) * sizeof(int))); + csc.column_segment_indices = + static_cast(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int))); + csc.column_segment_ptr[0] = 0; + const pair_t* sorted_col_row_index_values_pair = + reinterpret_cast(sorted_col_row_index_values); + const int* sorted_col_row_index_values_int = + reinterpret_cast(sorted_col_row_index_values); + if (IS_VALUE_PAIR) { + csc.row_indices[0] = sorted_col_row_index_values_pair[0].first % B; + csc.weights[0] = sorted_col_row_index_values_pair[0].second; + csc.column_segment_ids[0] = sorted_col_row_index_values_pair[0].first / B; + } else { + csc.row_indices[0] = sorted_col_row_index_values_int[0] % B; + csc.column_segment_ids[0] = sorted_col_row_index_values_int[0] / B; + } + csc.column_segment_indices[0] = sorted_col_row_index_keys[0]; - batched_csc.column_segment_ptr[0] = 0; - batched_csc.row_indices[0] = sorted_col_row_index_values[0] % B; - batched_csc.column_segment_indices[0] = sorted_col_row_index_keys[0]; - batched_csc.column_segment_ids[0] = sorted_col_row_index_values[0] / B; #pragma omp parallel - { - int tid = omp_get_thread_num(); - int* tstart = - (tid == 0 - ? batched_csc.column_segment_indices + 1 - : batched_csc.column_segment_indices + num_uniq[tid - 1][0]); - - int* t_offs = - (tid == 0 ? batched_csc.column_segment_ptr + 1 - : batched_csc.column_segment_ptr + num_uniq[tid - 1][0]); - - if (!is_shared_table) { - // For non shared table, no need for computing modulo. - // As an optimization, pointer swap instead of copying. + { + int tid = omp_get_thread_num(); + int* tstart = + (tid == 0 ? csc.column_segment_indices + 1 + : csc.column_segment_indices + num_uniq[tid - 1][0]); + + int* t_offs = + (tid == 0 ? csc.column_segment_ptr + 1 + : csc.column_segment_ptr + num_uniq[tid - 1][0]); + + if (!IS_VALUE_PAIR && !is_shared_table) { + // For non shared table, no need for computing modulo. + // As an optimization, pointer swap instead of copying. #pragma omp master - std::swap( - batched_csc.row_indices, - sorted_col_row_index_values == tmpBufValues ? tmpBufValues - : tmpBuf1Values); - } else { + std::swap( + csc.row_indices, + *reinterpret_cast( + sorted_col_row_index_values == tmpBufValues ? &tmpBufValues + : &tmpBuf1Values)); + } else { #ifdef FBCODE_CAFFE2 - libdivide::divider divisor(B); + libdivide::divider divisor(B); #endif #pragma omp for schedule(static) - for (int i = 1; i < NS; ++i) { - int v = sorted_col_row_index_values[i]; + for (int i = 1; i < NS; ++i) { + int v = IS_VALUE_PAIR ? 
sorted_col_row_index_values_pair[i].first + : sorted_col_row_index_values_int[i]; #ifdef FBCODE_CAFFE2 - int q = v / divisor; + int q = v / divisor; #else - int q = v / B; + int q = v / B; #endif - batched_csc.column_segment_ids[i] = q; - batched_csc.row_indices[i] = v - q * B; + csc.column_segment_ids[i] = q; + csc.row_indices[i] = v - q * B; + if (IS_VALUE_PAIR) { + csc.weights[i] = sorted_col_row_index_values_pair[i].second; } } + } #pragma omp for schedule(static) - for (int i = 1; i < NS; ++i) { - if (sorted_col_row_index_keys[i] != sorted_col_row_index_keys[i - 1]) { - *tstart = sorted_col_row_index_keys[i]; - *t_offs = i; - tstart++; - t_offs++; - } + for (int i = 1; i < NS; ++i) { + if (sorted_col_row_index_keys[i] != sorted_col_row_index_keys[i - 1]) { + *tstart = sorted_col_row_index_keys[i]; + *t_offs = i; + tstart++; + t_offs++; } + } - if (at::get_num_threads() == 1 && tid == 0) { - // Special handling of single thread case - U = t_offs - batched_csc.column_segment_ptr; - } - } // omp parallel - batched_csc.table_ptr[t + 1] = batched_csc.table_ptr[t] + U; - batched_csc.column_segment_ptr[U] = NS; - column_ptr_curr += NS; - fbgemm::fbgemmAlignedFree(tmpBufKeys); - fbgemm::fbgemmAlignedFree(tmpBufValues); - fbgemm::fbgemmAlignedFree(tmpBuf1Keys); - fbgemm::fbgemmAlignedFree(tmpBuf1Values); - } else { - // batched_csc.weights -#ifdef FBCODE_CAFFE2 - folly::F14FastMap< -#else - std::unordered_map< -#endif - int64_t, - std::vector>>> - non_empty_columns; - int f_begin = table_to_feature_offset[t]; - int f_end = table_to_feature_offset[t + 1]; - for (int feature = f_begin; feature < f_end; ++feature) { - for (int b = 0; b < B; ++b) { - int64_t pool_begin = batched_csr_offsets[feature * B + b]; - int64_t pool_end = batched_csr_offsets[feature * B + b + 1]; - int64_t L = pool_end - pool_begin; - // MEAN pooling will not work with indice_weights! - double scale_factor = - (static_cast(pooling_mode) == PoolingMode::MEAN && - !has_weights && L > 0) - ? 1.0 / L - : 1.0; - for (int64_t p = pool_begin; p < pool_end; ++p) { - auto itr = non_empty_columns.find(batched_csr_indices[p]); - if (itr == non_empty_columns.end()) { - itr = non_empty_columns - .emplace( - batched_csr_indices[p], - std::vector>>( - f_end - f_begin)) - .first; - } - if (itr->second[feature - f_begin].empty()) { - ++num_non_empty_segments; - } - itr->second[feature - f_begin].emplace_back( - b, scale_factor * (has_weights ? 
batched_csr_weights[p] : 1.0f)); - } - } - } // for each feature - - batched_csc.table_ptr[t + 1] = - batched_csc.table_ptr[t] + num_non_empty_segments; - batched_csc.column_segment_ptr = static_cast( - fbgemm::fbgemmAlignedAlloc(64, (NS + 1) * sizeof(int))); - batched_csc.column_segment_ptr[0] = 0; - batched_csc.column_segment_indices = - static_cast(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int))); - batched_csc.column_segment_ids = - static_cast(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int))); - int k = 1; - for (auto const& column : non_empty_columns) { - int feature = f_begin; - for (auto const& column_segment : column.second) { - if (!column_segment.empty()) { - batched_csc.column_segment_ptr[k] = - column_ptr_curr + column_segment.size(); - batched_csc.column_segment_indices[k - 1] = column.first; - batched_csc.column_segment_ids[k - 1] = feature - f_begin; - k++; - for (auto const& non_zero : column_segment) { - batched_csc.row_indices[column_ptr_curr] = non_zero.first; - batched_csc.weights[column_ptr_curr] = non_zero.second; - ++column_ptr_curr; - } - } - ++feature; - } // for each column segment - } // for each column - } // !batched_csc.weights.empty() + if (at::get_num_threads() == 1 && tid == 0) { + // Special handling of single thread case + U = t_offs - csc.column_segment_ptr; + } + + } // omp parallel + + csc.num_non_zero_columns = U; + csc.column_segment_ptr[U] = NS; + column_ptr_curr += NS; + + fbgemm::fbgemmAlignedFree(tmpBufKeys); + fbgemm::fbgemmAlignedFree(tmpBufValues); + fbgemm::fbgemmAlignedFree(tmpBuf1Keys); + fbgemm::fbgemmAlignedFree(tmpBuf1Values); assert(column_ptr_curr == nnz); } -template void batched_csr2csc( - BatchedHyperCompressedSparseColumn& batched_csc, +#define INSTANTIATE_BATCHED_CSR2CSC(SCALAR_T) \ + template void csr2csc_template_( \ + HyperCompressedSparseColumn & csc, \ + int B, \ + const at::TensorAccessor& csr_offsets, \ + const at::TensorAccessor& csr_indices, \ + const at::TensorAccessor& csr_weights, \ + int64_t pooling_mode, \ + const int* table_to_feature_offset, \ + int64_t num_embeddings); \ + \ + template void csr2csc_template_( \ + HyperCompressedSparseColumn & csc, \ + int B, \ + const at::TensorAccessor& csr_offsets, \ + const at::TensorAccessor& csr_indices, \ + const at::TensorAccessor& csr_weights, \ + int64_t pooling_mode, \ + const int* table_to_feature_offset, \ + int64_t num_embeddings); + +INSTANTIATE_BATCHED_CSR2CSC(float) +INSTANTIATE_BATCHED_CSR2CSC(double) +#undef INSTANTIATE_BATCHED_CSR2CSC + +} // namespace + +template +void csr2csc( + HyperCompressedSparseColumn& csc, + int B, + const at::TensorAccessor& csr_offsets, + const at::TensorAccessor& csr_indices, + const at::TensorAccessor& csr_weights, + int64_t pooling_mode, + const int* table_to_feature_offset, + int64_t num_embeddings) { + bool has_weights = csr_weights.data() != nullptr; + if (has_weights || + static_cast(pooling_mode) == PoolingMode::MEAN) { + csr2csc_template_( + csc, + B, + csr_offsets, + csr_indices, + csr_weights, + pooling_mode, + table_to_feature_offset, + num_embeddings); + } else { + csr2csc_template_( + csc, + B, + csr_offsets, + csr_indices, + csr_weights, + pooling_mode, + table_to_feature_offset, + num_embeddings); + } +} + +template void csr2csc( + HyperCompressedSparseColumn& csc, int B, - const at::TensorAccessor& batched_csr_offsets, - const at::TensorAccessor& batched_csr_indices, - const at::TensorAccessor& batched_csr_weights, + const at::TensorAccessor& csr_offsets, + const at::TensorAccessor& csr_indices, + const 
at::TensorAccessor& csr_weights, int64_t pooling_mode, const int* table_to_feature_offset, int64_t num_embeddings); -template void batched_csr2csc( - BatchedHyperCompressedSparseColumn& batched_csc, +template void csr2csc( + HyperCompressedSparseColumn& csc, int B, - const at::TensorAccessor& batched_csr_offsets, - const at::TensorAccessor& batched_csr_indices, - const at::TensorAccessor& batched_csr_weights, + const at::TensorAccessor& csr_offsets, + const at::TensorAccessor& csr_indices, + const at::TensorAccessor& csr_weights, int64_t pooling_mode, const int* table_to_feature_offset, int64_t num_embeddings); diff --git a/fbgemm_gpu/codegen/embedding_forward_split_cpu.h b/fbgemm_gpu/codegen/embedding_forward_split_cpu.h index ad2eaf02d..c8b7b25ca 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_cpu.h +++ b/fbgemm_gpu/codegen/embedding_forward_split_cpu.h @@ -32,26 +32,21 @@ at::Tensor split_embedding_codegen_grad_indice_weights_cpu( at::Tensor feature_requires_grad); namespace internal { -// A batch of compressed sparse row but each sparse matrix is hyper sparse +// A compressed sparse column representation that is hyper sparse, // meaning there can be many columns without any non-zeros. -struct BatchedHyperCompressedSparseColumn { - int num_tables; // # of matrices (or tables) - // pointers to the beginning of each table in column_ptr (length T + 1) - int* table_ptr = nullptr; +struct HyperCompressedSparseColumn { + int num_non_zero_columns; // pointers to the beginning of each column segment in row_indices - // (length table_ptr[T] + 1) + // (length num_non_zero_columns + 1) // For a shared table, a column can have multiple segments, each for a // feature sharing the table. In this case, the segments will have the // same column_segment_indices but different column_segment_ids.
int* column_segment_ptr = nullptr; - int* column_segment_indices = nullptr; // length table_ptr[T] - int* column_segment_ids = nullptr; // length table_ptr[T] - int* row_indices = nullptr; // length column_ptr[table_ptr[T]] - float* weights = nullptr; // length column_ptr[table_ptr[T]] - ~BatchedHyperCompressedSparseColumn() { - if (table_ptr) { - fbgemm::fbgemmAlignedFree(table_ptr); - } + int* column_segment_indices = nullptr; // length num_non_zero_columns + int* column_segment_ids = nullptr; // length num_non_zero_columns + int* row_indices = nullptr; // length column_ptr[num_non_zero_columns] + float* weights = nullptr; // length column_ptr[num_non_zero_columns] + ~HyperCompressedSparseColumn() { if (column_segment_ptr) { fbgemm::fbgemmAlignedFree(column_segment_ptr); fbgemm::fbgemmAlignedFree(column_segment_indices); @@ -65,12 +60,12 @@ struct BatchedHyperCompressedSparseColumn { }; template -void batched_csr2csc( - BatchedHyperCompressedSparseColumn& batched_csc, +void csr2csc( + HyperCompressedSparseColumn& csc, int B, - const at::TensorAccessor& batched_csr_offsets, - const at::TensorAccessor& batched_csr_indices, - const at::TensorAccessor& batched_csr_weights, + const at::TensorAccessor& csr_offsets, + const at::TensorAccessor& csr_indices, + const at::TensorAccessor& csr_weights, int64_t pooling_mode, const int* table_to_feature_offset, int64_t num_embeddings); diff --git a/fbgemm_gpu/codegen/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_template.cu index bd2c7e93e..c6f59f71f 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_split_template.cu @@ -26,6 +26,126 @@ constexpr size_t kForwardMaxThreads = 512; using Tensor = at::Tensor; using namespace fbgemm_gpu; +{% if not weighted %} +template < + typename emb_t, + typename cache_t, + {% if not dense %} + typename output_t, + {% endif %} + typename index_t, + size_t kThreadGroupSize + > +__launch_bounds__(kForwardMaxThreads) +__global__ void {{ "dense" if dense else "split" }}_embedding_nobag_codegen_forward_unweighted_small_kernel( + const at::PackedTensorAccessor64 dev_weights, + {% if not dense %} + const at::PackedTensorAccessor64 uvm_weights, + const at::PackedTensorAccessor64 + lxu_cache_weights, + const at::PackedTensorAccessor32 + weights_placements, + {% endif %} + const at::PackedTensorAccessor32 weights_offsets, + int64_t D, + const at::PackedTensorAccessor32 indices, + const at::PackedTensorAccessor32 offsets, + {% if not dense %} + const at::PackedTensorAccessor32 + lxu_cache_locations, + at::PackedTensorAccessor32 + output // [B][total_D], + {% else %} + at::PackedTensorAccessor32, 2, at::RestrictPtrTraits> + output // [B][total_D], + {% endif %} + ) { + int32_t T = weights_offsets.size(0); + int32_t B = (offsets.size(0) - 1) / T; + int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y; + int32_t t = b_t / B; + int32_t b = b_t % B; + + if (b_t >= B * T) { + return; + } + int64_t weights_offset = weights_offsets[t]; + index_t indices_start = offsets[t * B + b]; + index_t indices_end = offsets[t * B + b + 1]; + int32_t L = indices_end - indices_start; + const emb_t* __restrict__ weights; + {% if not dense %} + const auto placement = static_cast(weights_placements[t]); + if (placement == PlacementType::DEVICE) { + weights = &dev_weights[weights_offset]; + } else { + weights = &uvm_weights[weights_offset]; + } + {% else %} + weights = &dev_weights[weights_offset]; + {% endif %} + + int32_t D_emb = D; + if (std::is_same::value) { 
+ D_emb += kINT8QparamsBytes; + } + + constexpr int32_t kNumThreadGroup = kWarpSize / kThreadGroupSize; + const int32_t group_start = threadIdx.x / kThreadGroupSize * kThreadGroupSize; + const int32_t group_end = group_start + kThreadGroupSize; + const int32_t d = threadIdx.x % kThreadGroupSize * 4; + + for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) { + int32_t l = l_start + threadIdx.x; + int64_t idx = l < L ? indices[indices_start + l] : 0; + {% if not dense %} + int32_t cache_idx = (placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0; + {% endif %} + for (auto j = group_start; j < group_end && l_start + j < L; ++j) { + int64_t idx_j = shfl_sync(idx, j); + int64_t output_j = indices_start + l_start + j; + {% if not dense %} + int32_t cache_idx_j = shfl_sync(cache_idx, j); + {% endif %} + + {% if not dense %} + auto weight_row_cache = WeightRow( + const_cast(&weights[idx_j * D_emb]), + const_cast(&lxu_cache_weights[cache_idx_j][0]), + D, + nullptr); + float2 qparams_cache; // assume cache is fp16/fp32 which doesn't require qparams + + {% endif %} + auto weight_row_emb = WeightRow( + const_cast(&weights[idx_j * D_emb]), + nullptr, + D, + nullptr); + float2 qparams_emb; + if (std::is_same::value) { + qparams_emb = weight_row_emb.load_qparams(); + } + + if (d < D) { + {% if not dense %} + if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) { + Vec4T weight = weight_row_cache.load(d, qparams_cache); + weight.store(&output[output_j][d]); + } else { + Vec4T weight = weight_row_emb.load(d, qparams_emb); + weight.store(&output[output_j][d]); + } + {% else %} + Vec4T weight = weight_row_emb.load(d, qparams_emb); + weight.store(&output[output_j][d]); + {% endif %} + } + } + } +} +{% endif %} + {% for nobag in [True, False] %} {% if not nobag or not weighted %} template < @@ -361,7 +481,7 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" {% endif %} {% else %} {% if dense %} - if (dev_weights.type().scalarType() == at::kHalf || dev_weights.type().scalarType() == at::kByte) { + if (dev_weights.scalar_type() == at::kHalf || dev_weights.scalar_type() == at::kByte) { output = at::empty({B, total_D}, dev_weights.options().dtype(at::kFloat)); } else { output = at::empty({B, total_D}, dev_weights.options()); @@ -439,6 +559,45 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } {% endfor %} {% else %} + {% for kEmbeddingSize in [4, 8, 16, 32] %} + if (D <= {{ kEmbeddingSize }}) { + {% if not dense %} + split_embedding_nobag_codegen_forward_unweighted_small_kernel<<< + {% else %} + dense_embedding_nobag_codegen_forward_unweighted_small_kernel<<< + {% endif %} + div_round_up((B * T), kForwardMaxThreads / kWarpSize), + dim3(kWarpSize, kForwardMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + dev_weights.packed_accessor64<{{ "scalar_t" if dense else "emb_t" }}, 1, at::RestrictPtrTraits>(), + {% if not dense %} + uvm_weights.packed_accessor64(), + lxu_cache_weights.packed_accessor64(), + weights_placements.packed_accessor32(), + {% endif %} + weights_offsets.packed_accessor32(), + D, + indices.packed_accessor32(), + offsets.packed_accessor32(), + {% if not dense %} + lxu_cache_locations.packed_accessor32(), + output.packed_accessor32< + output_t, + 2, + at::RestrictPtrTraits>() + ); + {% else %} + output.packed_accessor32< + at::acc_type, + 2, + at::RestrictPtrTraits>() + ); + {% endif %} + + return; + } + {% endfor %} {% if not 
dense %} split_embedding_nobag_codegen_forward_unweighted_kernel<<< {% else %} diff --git a/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh b/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh index 4a7e517a6..24e76e8c7 100644 --- a/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh +++ b/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh @@ -10,11 +10,7 @@ #include #include #include -#if !defined(NEW_ATOMIC_PATH) -#include -#else #include -#endif // clang-format off #include "fbgemm_gpu/cub_namespace_prefix.cuh" diff --git a/fbgemm_gpu/docs/Doxyfile.in b/fbgemm_gpu/docs/Doxyfile.in new file mode 100644 index 000000000..f36821d10 --- /dev/null +++ b/fbgemm_gpu/docs/Doxyfile.in @@ -0,0 +1,2678 @@ +# Doxyfile 1.9.4 + +# This file describes the settings to be used by the documentation system +# doxygen (www.doxygen.org) for a project. +# +# All text after a double hash (##) is considered a comment and is placed in +# front of the TAG it is preceding. +# +# All text after a single hash (#) is considered a comment and will be ignored. +# The format is: +# TAG = value [value, ...] +# For lists, items can also be appended using: +# TAG += value [value, ...] +# Values that contain spaces should be placed between quotes (\" \"). +# +# Note: +# +# Use doxygen to compare the used configuration file with the template +# configuration file: +# doxygen -x [configFile] +# Use doxygen to compare the used configuration file with the template +# configuration file without replacing the environment variables: +# doxygen -x_noenv [configFile] + +#--------------------------------------------------------------------------- +# Project related configuration options +#--------------------------------------------------------------------------- + +# This tag specifies the encoding used for all characters in the configuration +# file that follow. The default is UTF-8 which is also the encoding used for all +# text before the first occurrence of this tag. Doxygen uses libiconv (or the +# iconv built into libc) for the transcoding. See +# https://www.gnu.org/software/libiconv/ for the list of possible encodings. +# The default value is: UTF-8. + +DOXYFILE_ENCODING = UTF-8 + +# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by +# double-quotes, unless you are using Doxywizard) that should identify the +# project for which the documentation is generated. This name is used in the +# title of most generated pages and in a few other places. +# The default value is: My Project. + +PROJECT_NAME = "fbgemm_gpu" + +# The PROJECT_NUMBER tag can be used to enter a project or revision number. This +# could be handy for archiving the generated documentation or if some version +# control system is used. + +PROJECT_NUMBER = + +# Using the PROJECT_BRIEF tag one can provide an optional one line description +# for a project that appears at the top of each page and should give viewer a +# quick idea about the purpose of the project. Keep the description short. + +PROJECT_BRIEF = + +# With the PROJECT_LOGO tag one can specify a logo or an icon that is included +# in the documentation. The maximum height of the logo should not exceed 55 +# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy +# the logo to the output directory. + +PROJECT_LOGO = + +# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path +# into which the generated documentation will be written. 
If a relative path is +# entered, it will be relative to the location where doxygen was started. If +# left blank the current directory will be used. + +OUTPUT_DIRECTORY = "build" + +# If the CREATE_SUBDIRS tag is set to YES then doxygen will create up to 4096 +# sub-directories (in 2 levels) under the output directory of each output format +# and will distribute the generated files over these directories. Enabling this +# option can be useful when feeding doxygen a huge amount of source files, where +# putting all generated files in the same directory would otherwise causes +# performance problems for the file system. Adapt CREATE_SUBDIRS_LEVEL to +# control the number of sub-directories. +# The default value is: NO. + +CREATE_SUBDIRS = NO + +# Controls the number of sub-directories that will be created when +# CREATE_SUBDIRS tag is set to YES. Level 0 represents 16 directories, and every +# level increment doubles the number of directories, resulting in 4096 +# directories at level 8 which is the default and also the maximum value. The +# sub-directories are organized in 2 levels, the first level always has a fixed +# numer of 16 directories. +# Minimum value: 0, maximum value: 8, default value: 8. +# This tag requires that the tag CREATE_SUBDIRS is set to YES. + +CREATE_SUBDIRS_LEVEL = 8 + +# If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII +# characters to appear in the names of generated files. If set to NO, non-ASCII +# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode +# U+3044. +# The default value is: NO. + +ALLOW_UNICODE_NAMES = NO + +# The OUTPUT_LANGUAGE tag is used to specify the language in which all +# documentation generated by doxygen is written. Doxygen will use this +# information to generate all constant output in the proper language. +# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Bulgarian, +# Catalan, Chinese, Chinese-Traditional, Croatian, Czech, Danish, Dutch, English +# (United States), Esperanto, Farsi (Persian), Finnish, French, German, Greek, +# Hindi, Hungarian, Indonesian, Italian, Japanese, Japanese-en (Japanese with +# English messages), Korean, Korean-en (Korean with English messages), Latvian, +# Lithuanian, Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, +# Romanian, Russian, Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, +# Swedish, Turkish, Ukrainian and Vietnamese. +# The default value is: English. + +OUTPUT_LANGUAGE = English + +# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member +# descriptions after the members that are listed in the file and class +# documentation (similar to Javadoc). Set to NO to disable this. +# The default value is: YES. + +BRIEF_MEMBER_DESC = NO + +# If the REPEAT_BRIEF tag is set to YES, doxygen will prepend the brief +# description of a member or function before the detailed description +# +# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the +# brief descriptions will be completely suppressed. +# The default value is: YES. + +REPEAT_BRIEF = NO + +# This tag implements a quasi-intelligent brief description abbreviator that is +# used to form the text in various listings. Each string in this list, if found +# as the leading text of the brief description, will be stripped from the text +# and the result, after processing the whole list, is used as the annotated +# text. Otherwise, the brief description is used as-is. 
If left blank, the +# following values are used ($name is automatically replaced with the name of +# the entity):The $name class, The $name widget, The $name file, is, provides, +# specifies, contains, represents, a, an and the. + +ABBREVIATE_BRIEF = "The $name class" \ + "The $name widget" \ + "The $name file" \ + is \ + provides \ + specifies \ + contains \ + represents \ + a \ + an \ + the + +# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then +# doxygen will generate a detailed section even if there is only a brief +# description. +# The default value is: NO. + +ALWAYS_DETAILED_SEC = NO + +# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all +# inherited members of a class in the documentation of that class as if those +# members were ordinary class members. Constructors, destructors and assignment +# operators of the base classes will not be shown. +# The default value is: NO. + +INLINE_INHERITED_MEMB = NO + +# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path +# before files name in the file list and in the header files. If set to NO the +# shortest path that makes the file name unique will be used +# The default value is: YES. + +FULL_PATH_NAMES = YES + +# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path. +# Stripping is only done if one of the specified strings matches the left-hand +# part of the path. The tag can be used to show relative paths in the file list. +# If left blank the directory from which doxygen is run is used as the path to +# strip. +# +# Note that you can specify absolute paths here, but also relative paths, which +# will be relative from the directory where doxygen is started. +# This tag requires that the tag FULL_PATH_NAMES is set to YES. + +STRIP_FROM_PATH = + +# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the +# path mentioned in the documentation of a class, which tells the reader which +# header file to include in order to use a class. If left blank only the name of +# the header file containing the class definition is used. Otherwise one should +# specify the list of include paths that are normally passed to the compiler +# using the -I flag. + +STRIP_FROM_INC_PATH = + +# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but +# less readable) file names. This can be useful is your file systems doesn't +# support long names like on DOS, Mac, or CD-ROM. +# The default value is: NO. + +SHORT_NAMES = NO + +# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the +# first line (until the first dot) of a Javadoc-style comment as the brief +# description. If set to NO, the Javadoc-style will behave just like regular Qt- +# style comments (thus requiring an explicit @brief command for a brief +# description.) +# The default value is: NO. + +JAVADOC_AUTOBRIEF = NO + +# If the JAVADOC_BANNER tag is set to YES then doxygen will interpret a line +# such as +# /*************** +# as being the beginning of a Javadoc-style comment "banner". If set to NO, the +# Javadoc-style will behave just like regular comments and it will not be +# interpreted by doxygen. +# The default value is: NO. + +JAVADOC_BANNER = NO + +# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first +# line (until the first dot) of a Qt-style comment as the brief description. 
If +# set to NO, the Qt-style will behave just like regular Qt-style comments (thus +# requiring an explicit \brief command for a brief description.) +# The default value is: NO. + +QT_AUTOBRIEF = NO + +# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a +# multi-line C++ special comment block (i.e. a block of //! or /// comments) as +# a brief description. This used to be the default behavior. The new default is +# to treat a multi-line C++ comment block as a detailed description. Set this +# tag to YES if you prefer the old behavior instead. +# +# Note that setting this tag to YES also means that rational rose comments are +# not recognized any more. +# The default value is: NO. + +MULTILINE_CPP_IS_BRIEF = NO + +# By default Python docstrings are displayed as preformatted text and doxygen's +# special commands cannot be used. By setting PYTHON_DOCSTRING to NO the +# doxygen's special commands can be used and the contents of the docstring +# documentation blocks is shown as doxygen documentation. +# The default value is: YES. + +PYTHON_DOCSTRING = YES + +# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the +# documentation from any documented member that it re-implements. +# The default value is: YES. + +INHERIT_DOCS = YES + +# If the SEPARATE_MEMBER_PAGES tag is set to YES then doxygen will produce a new +# page for each member. If set to NO, the documentation of a member will be part +# of the file/class/namespace that contains it. +# The default value is: NO. + +SEPARATE_MEMBER_PAGES = NO + +# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen +# uses this value to replace tabs by spaces in code fragments. +# Minimum value: 1, maximum value: 16, default value: 4. + +TAB_SIZE = 4 + +# This tag can be used to specify a number of aliases that act as commands in +# the documentation. An alias has the form: +# name=value +# For example adding +# "sideeffect=@par Side Effects:^^" +# will allow you to put the command \sideeffect (or @sideeffect) in the +# documentation, which will result in a user-defined paragraph with heading +# "Side Effects:". Note that you cannot put \n's in the value part of an alias +# to insert newlines (in the resulting output). You can put ^^ in the value part +# of an alias to insert a newline as if a physical newline was in the original +# file. When you need a literal { or } or , in the value part of an alias you +# have to escape them by means of a backslash (\), this can lead to conflicts +# with the commands \{ and \} for these it is advised to use the version @{ and +# @} or use a double escape (\\{ and \\}) + +ALIASES = + +# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources +# only. Doxygen will then generate output that is more tailored for C. For +# instance, some of the names that are used will be different. The list of all +# members will be omitted, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_FOR_C = NO + +# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or +# Python sources only. Doxygen will then generate output that is more tailored +# for that language. For instance, namespaces will be presented as packages, +# qualified scopes will look different, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_JAVA = NO + +# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran +# sources. Doxygen will then generate output that is tailored for Fortran. +# The default value is: NO. 
+ +OPTIMIZE_FOR_FORTRAN = NO + +# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL +# sources. Doxygen will then generate output that is tailored for VHDL. +# The default value is: NO. + +OPTIMIZE_OUTPUT_VHDL = NO + +# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice +# sources only. Doxygen will then generate output that is more tailored for that +# language. For instance, namespaces will be presented as modules, types will be +# separated into more groups, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_SLICE = NO + +# Doxygen selects the parser to use depending on the extension of the files it +# parses. With this tag you can assign which parser to use for a given +# extension. Doxygen has a built-in mapping, but you can override or extend it +# using this tag. The format is ext=language, where ext is a file extension, and +# language is one of the parsers supported by doxygen: IDL, Java, JavaScript, +# Csharp (C#), C, C++, Lex, D, PHP, md (Markdown), Objective-C, Python, Slice, +# VHDL, Fortran (fixed format Fortran: FortranFixed, free formatted Fortran: +# FortranFree, unknown formatted Fortran: Fortran. In the later case the parser +# tries to guess whether the code is fixed or free formatted code, this is the +# default for Fortran type files). For instance to make doxygen treat .inc files +# as Fortran files (default is PHP), and .f files as C (default is Fortran), +# use: inc=Fortran f=C. +# +# Note: For files without extension you can use no_extension as a placeholder. +# +# Note that for custom extensions you also need to set FILE_PATTERNS otherwise +# the files are not read by doxygen. When specifying no_extension you should add +# * to the FILE_PATTERNS. +# +# Note see also the list of default file extension mappings. + +EXTENSION_MAPPING = .cu=C++ \ + .cuh=C++ \ + +# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments +# according to the Markdown format, which allows for more readable +# documentation. See https://daringfireball.net/projects/markdown/ for details. +# The output of markdown processing is further processed by doxygen, so you can +# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in +# case of backward compatibilities issues. +# The default value is: YES. + +MARKDOWN_SUPPORT = YES + +# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up +# to that level are automatically included in the table of contents, even if +# they do not have an id attribute. +# Note: This feature currently applies only to Markdown headings. +# Minimum value: 0, maximum value: 99, default value: 5. +# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. + +TOC_INCLUDE_HEADINGS = 5 + +# When enabled doxygen tries to link words that correspond to documented +# classes, or namespaces to their corresponding documentation. Such a link can +# be prevented in individual cases by putting a % sign in front of the word or +# globally by setting AUTOLINK_SUPPORT to NO. +# The default value is: YES. + +AUTOLINK_SUPPORT = YES + +# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want +# to include (a tag file for) the STL sources as input, then you should set this +# tag to YES in order to let doxygen match functions declarations and +# definitions whose arguments contain STL classes (e.g. func(std::string); +# versus func(std::string) {}). 
This also makes the inheritance and collaboration +# diagrams that involve STL classes more complete and accurate. +# The default value is: NO. + +BUILTIN_STL_SUPPORT = YES + +# If you use Microsoft's C++/CLI language, you should set this option to YES to +# enable parsing support. +# The default value is: NO. + +CPP_CLI_SUPPORT = NO + +# Set the SIP_SUPPORT tag to YES if your project consists of sip (see: +# https://www.riverbankcomputing.com/software/sip/intro) sources only. Doxygen +# will parse them like normal C++ but will assume all classes use public instead +# of private inheritance when no explicit protection keyword is present. +# The default value is: NO. + +SIP_SUPPORT = NO + +# For Microsoft's IDL there are propget and propput attributes to indicate +# getter and setter methods for a property. Setting this option to YES will make +# doxygen replace the get and set methods by a property in the documentation. +# This will only work if the methods are indeed getting or setting a simple +# type. If this is not the case, or you want to show the methods anyway, you +# should set this option to NO. +# The default value is: YES. + +IDL_PROPERTY_SUPPORT = YES + +# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC +# tag is set to YES then doxygen will reuse the documentation of the first +# member in the group (if any) for the other members of the group. By default +# all members of a group must be documented explicitly. +# The default value is: NO. + +DISTRIBUTE_GROUP_DOC = NO + +# If one adds a struct or class to a group and this option is enabled, then also +# any nested class or struct is added to the same group. By default this option +# is disabled and one has to add nested compounds explicitly via \ingroup. +# The default value is: NO. + +GROUP_NESTED_COMPOUNDS = NO + +# Set the SUBGROUPING tag to YES to allow class member groups of the same type +# (for instance a group of public functions) to be put as a subgroup of that +# type (e.g. under the Public Functions section). Set it to NO to prevent +# subgrouping. Alternatively, this can be done per class using the +# \nosubgrouping command. +# The default value is: YES. + +SUBGROUPING = YES + +# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions +# are shown inside the group in which they are included (e.g. using \ingroup) +# instead of on a separate page (for HTML and Man pages) or section (for LaTeX +# and RTF). +# +# Note that this feature does not work in combination with +# SEPARATE_MEMBER_PAGES. +# The default value is: NO. + +INLINE_GROUPED_CLASSES = NO + +# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions +# with only public data fields or simple typedef fields will be shown inline in +# the documentation of the scope in which they are defined (i.e. file, +# namespace, or group documentation), provided this scope is documented. If set +# to NO, structs, classes, and unions are shown on a separate page (for HTML and +# Man pages) or section (for LaTeX and RTF). +# The default value is: NO. + +INLINE_SIMPLE_STRUCTS = NO + +# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or +# enum is documented as struct, union, or enum with the name of the typedef. So +# typedef struct TypeS {} TypeT, will appear in the documentation as a struct +# with name TypeT. When disabled the typedef will appear as a member of a file, +# namespace, or class. And the struct will be named TypeS.
This can typically be +# useful for C code in case the coding convention dictates that all compound +# types are typedef'ed and only the typedef is referenced, never the tag name. +# The default value is: NO. + +TYPEDEF_HIDES_STRUCT = NO + +# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This +# cache is used to resolve symbols given their name and scope. Since this can be +# an expensive process and often the same symbol appears multiple times in the +# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small +# doxygen will become slower. If the cache is too large, memory is wasted. The +# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range +# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536 +# symbols. At the end of a run doxygen will report the cache usage and suggest +# the optimal cache size from a speed point of view. +# Minimum value: 0, maximum value: 9, default value: 0. + +LOOKUP_CACHE_SIZE = 0 + +# The NUM_PROC_THREADS specifies the number of threads doxygen is allowed to use +# during processing. When set to 0 doxygen will based this on the number of +# cores available in the system. You can set it explicitly to a value larger +# than 0 to get more control over the balance between CPU load and processing +# speed. At this moment only the input processing can be done using multiple +# threads. Since this is still an experimental feature the default is set to 1, +# which effectively disables parallel processing. Please report any issues you +# encounter. Generating dot graphs in parallel is controlled by the +# DOT_NUM_THREADS setting. +# Minimum value: 0, maximum value: 32, default value: 1. + +NUM_PROC_THREADS = 1 + +#--------------------------------------------------------------------------- +# Build related configuration options +#--------------------------------------------------------------------------- + +# If the EXTRACT_ALL tag is set to YES, doxygen will assume all entities in +# documentation are documented, even if no documentation was available. Private +# class members and static file members will be hidden unless the +# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES. +# Note: This will also disable the warnings about undocumented members that are +# normally produced when WARNINGS is set to YES. +# The default value is: NO. + +EXTRACT_ALL = NO + +# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will +# be included in the documentation. +# The default value is: NO. + +EXTRACT_PRIVATE = NO + +# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual +# methods of a class will be included in the documentation. +# The default value is: NO. + +EXTRACT_PRIV_VIRTUAL = NO + +# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal +# scope will be included in the documentation. +# The default value is: NO. + +EXTRACT_PACKAGE = NO + +# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be +# included in the documentation. +# The default value is: NO. + +EXTRACT_STATIC = NO + +# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined +# locally in source files will be included in the documentation. If set to NO, +# only classes defined in header files are included. Does not have any effect +# for Java sources. +# The default value is: YES. + +EXTRACT_LOCAL_CLASSES = NO + +# This flag is only useful for Objective-C code. 
If set to YES, local methods, +# which are defined in the implementation section but not in the interface are +# included in the documentation. If set to NO, only methods in the interface are +# included. +# The default value is: NO. + +EXTRACT_LOCAL_METHODS = NO + +# If this flag is set to YES, the members of anonymous namespaces will be +# extracted and appear in the documentation as a namespace called +# 'anonymous_namespace{file}', where file will be replaced with the base name of +# the file that contains the anonymous namespace. By default anonymous namespace +# are hidden. +# The default value is: NO. + +EXTRACT_ANON_NSPACES = NO + +# If this flag is set to YES, the name of an unnamed parameter in a declaration +# will be determined by the corresponding definition. By default unnamed +# parameters remain unnamed in the output. +# The default value is: YES. + +RESOLVE_UNNAMED_PARAMS = YES + +# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all +# undocumented members inside documented classes or files. If set to NO these +# members will be included in the various overviews, but no documentation +# section is generated. This option has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + +HIDE_UNDOC_MEMBERS = YES + +# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all +# undocumented classes that are normally visible in the class hierarchy. If set +# to NO, these classes will be included in the various overviews. This option +# has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + +HIDE_UNDOC_CLASSES = YES + +# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend +# declarations. If set to NO, these declarations will be included in the +# documentation. +# The default value is: NO. + +HIDE_FRIEND_COMPOUNDS = YES + +# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any +# documentation blocks found inside the body of a function. If set to NO, these +# blocks will be appended to the function's detailed documentation block. +# The default value is: NO. + +HIDE_IN_BODY_DOCS = YES + +# The INTERNAL_DOCS tag determines if documentation that is typed after a +# \internal command is included. If the tag is set to NO then the documentation +# will be excluded. Set it to YES to include the internal documentation. +# The default value is: NO. + +INTERNAL_DOCS = NO + +# With the correct setting of option CASE_SENSE_NAMES doxygen will better be +# able to match the capabilities of the underlying filesystem. In case the +# filesystem is case sensitive (i.e. it supports files in the same directory +# whose names only differ in casing), the option must be set to YES to properly +# deal with such files in case they appear in the input. For filesystems that +# are not case sensitive the option should be set to NO to properly deal with +# output files written for symbols that only differ in casing, such as for two +# classes, one named CLASS and the other named Class, and to also support +# references to files without having to specify the exact matching casing. On +# Windows (including Cygwin) and MacOS, users should typically set this option +# to NO, whereas on Linux or other Unix flavors it should typically be set to +# YES. +# The default value is: system dependent. + +CASE_SENSE_NAMES = NO + +# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with +# their full class and namespace scopes in the documentation. If set to YES, the +# scope will be hidden. 
+# The default value is: NO. + +HIDE_SCOPE_NAMES = YES + +# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then doxygen will +# append additional text to a page's title, such as Class Reference. If set to +# YES the compound reference will be hidden. +# The default value is: NO. + +HIDE_COMPOUND_REFERENCE= NO + +# If the SHOW_HEADERFILE tag is set to YES then the documentation for a class +# will show which file needs to be included to use the class. +# The default value is: YES. + +SHOW_HEADERFILE = YES + +# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of +# the files that are included by a file in the documentation of that file. +# The default value is: YES. + +SHOW_INCLUDE_FILES = YES + +# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each +# grouped member an include statement to the documentation, telling the reader +# which file to include in order to use the member. +# The default value is: NO. + +SHOW_GROUPED_MEMB_INC = NO + +# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include +# files with double quotes in the documentation rather than with sharp brackets. +# The default value is: NO. + +FORCE_LOCAL_INCLUDES = NO + +# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the +# documentation for inline members. +# The default value is: YES. + +INLINE_INFO = YES + +# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the +# (detailed) documentation of file and class members alphabetically by member +# name. If set to NO, the members will appear in declaration order. +# The default value is: YES. + +SORT_MEMBER_DOCS = YES + +# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief +# descriptions of file, namespace and class members alphabetically by member +# name. If set to NO, the members will appear in declaration order. Note that +# this will also influence the order of the classes in the class list. +# The default value is: NO. + +SORT_BRIEF_DOCS = NO + +# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the +# (brief and detailed) documentation of class members so that constructors and +# destructors are listed first. If set to NO the constructors will appear in the +# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS. +# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief +# member documentation. +# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting +# detailed member documentation. +# The default value is: NO. + +SORT_MEMBERS_CTORS_1ST = NO + +# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy +# of group names into alphabetical order. If set to NO the group names will +# appear in their defined order. +# The default value is: NO. + +SORT_GROUP_NAMES = NO + +# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by +# fully-qualified names, including namespaces. If set to NO, the class list will +# be sorted only by class name, not including the namespace part. +# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES. +# Note: This option applies only to the class list, not to the alphabetical +# list. +# The default value is: NO. 
+ +SORT_BY_SCOPE_NAME = NO + +# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper +# type resolution of all parameters of a function it will reject a match between +# the prototype and the implementation of a member function even if there is +# only one candidate or it is obvious which candidate to choose by doing a +# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still +# accept a match between prototype and implementation in such cases. +# The default value is: NO. + +STRICT_PROTO_MATCHING = NO + +# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo +# list. This list is created by putting \todo commands in the documentation. +# The default value is: YES. + +GENERATE_TODOLIST = YES + +# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test +# list. This list is created by putting \test commands in the documentation. +# The default value is: YES. + +GENERATE_TESTLIST = YES + +# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug +# list. This list is created by putting \bug commands in the documentation. +# The default value is: YES. + +GENERATE_BUGLIST = YES + +# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO) +# the deprecated list. This list is created by putting \deprecated commands in +# the documentation. +# The default value is: YES. + +GENERATE_DEPRECATEDLIST= YES + +# The ENABLED_SECTIONS tag can be used to enable conditional documentation +# sections, marked by \if <section_label> ... \endif and \cond +# <section_label> ... \endcond blocks. + +ENABLED_SECTIONS = + +# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the +# initial value of a variable or macro / define can have for it to appear in the +# documentation. If the initializer consists of more lines than specified here +# it will be hidden. Use a value of 0 to hide initializers completely. The +# appearance of the value of individual variables and macros / defines can be +# controlled using \showinitializer or \hideinitializer command in the +# documentation regardless of this setting. +# Minimum value: 0, maximum value: 10000, default value: 30. + +MAX_INITIALIZER_LINES = 30 + +# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at +# the bottom of the documentation of classes and structs. If set to YES, the +# list will mention the files that were used to generate the documentation. +# The default value is: YES. + +SHOW_USED_FILES = YES + +# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This +# will remove the Files entry from the Quick Index and from the Folder Tree View +# (if specified). +# The default value is: YES. + +SHOW_FILES = YES + +# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces +# page. This will remove the Namespaces entry from the Quick Index and from the +# Folder Tree View (if specified). +# The default value is: YES. + +SHOW_NAMESPACES = NO + +# The FILE_VERSION_FILTER tag can be used to specify a program or script that +# doxygen should invoke to get the current version for each file (typically from +# the version control system). Doxygen will invoke the program by executing (via +# popen()) the command <command> <input-file>, where <command> is the value of the +# FILE_VERSION_FILTER tag, and <input-file> is the name of an input file provided +# by doxygen. Whatever the program writes to standard output is used as the file +# version. For an example see the documentation.
+ +FILE_VERSION_FILTER = + +# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed +# by doxygen. The layout file controls the global structure of the generated +# output files in an output format independent way. To create the layout file +# that represents doxygen's defaults, run doxygen with the -l option. You can +# optionally specify a file name after the option, if omitted DoxygenLayout.xml +# will be used as the name of the layout file. See also section "Changing the +# layout of pages" for information. +# +# Note that if you run doxygen from a directory containing a file called +# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE +# tag is left empty. + +LAYOUT_FILE = + +# The CITE_BIB_FILES tag can be used to specify one or more bib files containing +# the reference definitions. This must be a list of .bib files. The .bib +# extension is automatically appended if omitted. This requires the bibtex tool +# to be installed. See also https://en.wikipedia.org/wiki/BibTeX for more info. +# For LaTeX the style of the bibliography can be controlled using +# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the +# search path. See also \cite for info how to create references. + +CITE_BIB_FILES = + +#--------------------------------------------------------------------------- +# Configuration options related to warning and progress messages +#--------------------------------------------------------------------------- + +# The QUIET tag can be used to turn on/off the messages that are generated to +# standard output by doxygen. If QUIET is set to YES this implies that the +# messages are off. +# The default value is: NO. + +QUIET = NO + +# The WARNINGS tag can be used to turn on/off the warning messages that are +# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES +# this implies that the warnings are on. +# +# Tip: Turn warnings on while writing the documentation. +# The default value is: YES. + +WARNINGS = YES + +# If the WARN_IF_UNDOCUMENTED tag is set to YES then doxygen will generate +# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: YES. + +WARN_IF_UNDOCUMENTED = YES + +# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for +# potential errors in the documentation, such as documenting some parameters in +# a documented function twice, or documenting parameters that don't exist or +# using markup commands wrongly. +# The default value is: YES. + +WARN_IF_DOC_ERROR = YES + +# If WARN_IF_INCOMPLETE_DOC is set to YES, doxygen will warn about incomplete +# function parameter documentation. If set to NO, doxygen will accept that some +# parameters have no documentation without warning. +# The default value is: YES. + +WARN_IF_INCOMPLETE_DOC = YES + +# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that +# are documented, but have no documentation for their parameters or return +# value. If set to NO, doxygen will only warn about wrong parameter +# documentation, but not about the absence of documentation. If EXTRACT_ALL is +# set to YES then this flag will automatically be disabled. See also +# WARN_IF_INCOMPLETE_DOC +# The default value is: NO. + +WARN_NO_PARAMDOC = NO + +# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when +# a warning is encountered. 
If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS +# then doxygen will continue running as if WARN_AS_ERROR tag is set to NO, but +# at the end of the doxygen process doxygen will return with a non-zero status. +# Possible values are: NO, YES and FAIL_ON_WARNINGS. +# The default value is: NO. + +WARN_AS_ERROR = NO + +# The WARN_FORMAT tag determines the format of the warning messages that doxygen +# can produce. The string should contain the $file, $line, and $text tags, which +# will be replaced by the file and line number from which the warning originated +# and the warning text. Optionally the format may contain $version, which will +# be replaced by the version of the file (if it could be obtained via +# FILE_VERSION_FILTER) +# See also: WARN_LINE_FORMAT +# The default value is: $file:$line: $text. + +WARN_FORMAT = "$file:$line: $text" + +# In the $text part of the WARN_FORMAT command it is possible that a reference +# to a more specific place is given. To make it easier to jump to this place +# (outside of doxygen) the user can define a custom "cut" / "paste" string. +# Example: +# WARN_LINE_FORMAT = "'vi $file +$line'" +# See also: WARN_FORMAT +# The default value is: at line $line of file $file. + +WARN_LINE_FORMAT = "at line $line of file $file" + +# The WARN_LOGFILE tag can be used to specify a file to which warning and error +# messages should be written. If left blank the output is written to standard +# error (stderr). In case the file specified cannot be opened for writing the +# warning and error messages are written to standard error. When as file - is +# specified the warning and error messages are written to standard output +# (stdout). + +WARN_LOGFILE = + +#--------------------------------------------------------------------------- +# Configuration options related to the input files +#--------------------------------------------------------------------------- + +# The INPUT tag is used to specify the files and/or directories that contain +# documented source files. You may enter file names like myfile.cpp or +# directories like /usr/src/myproject. Separate the files or directories with +# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING +# Note: If this tag is empty the current directory is searched. + +INPUT = "../include/fbgemm_gpu" \ + "../src/" \ + "../codegen/" \ + "../include/fbgemm_gpu/layout_transform_ops.cuh" \ + "../include/fbgemm_gpu/permute_pooled_embedding_ops_split.h" \ + "../include/fbgemm_gpu/merge_pooled_embeddings.h" \ + "../include/fbgemm_gpu/sparse_ops.h" \ + "../src/quantize_ops.cu" \ + "../src/quantize_ops_cpu.cpp" \ + "../src/split_table_batched_embeddings.cpp" \ + "../src/jagged_tensor_ops.cu" \ + "../src/jagged_tensor_ops_cpu.cpp" \ + "../src/cumem_utils.h" \ + "../include/fbgemm_gpu/input_combine.h" \ + "../src/layout_transform_ops.cu" \ + "../src/layout_transform_ops_cpu.cpp" \ + "../codegen/embedding_backward_split_host_template.cpp" \ + "../codegen/embedding_backward_split_host_cpu_template.cpp" \ + "../codegen/embedding_forward_quantized_host.cpp" \ + "../codegen/embedding_forward_quantized_host_cpu.cpp" \ + "../codegen/embedding_bounds_check_host.cpp" \ + "../codegen/embedding_bounds_check_host_cpu.cpp" \ + "../codegen/embedding_backward_dense_host.cpp" \ + "../codegen/embedding_forward_template_helpers.cuh" \ + "../src/permute_pooled_embedding_ops_gpu.cpp" + +# This tag can be used to specify the character encoding of the source files +# that doxygen parses. Internally doxygen uses the UTF-8 encoding.
Doxygen uses +# libiconv (or the iconv built into libc) for the transcoding. See the libiconv +# documentation (see: +# https://www.gnu.org/software/libiconv/) for the list of possible encodings. +# The default value is: UTF-8. + +INPUT_ENCODING = UTF-8 + +# If the value of the INPUT tag contains directories, you can use the +# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and +# *.h) to filter out the source-files in the directories. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# read by doxygen. +# +# Note the list of default checked file patterns might differ from the list of +# default file extension mappings. +# +# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, +# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, +# *.hh, *.hxx, *.hpp, *.h++, *.l, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, +# *.inc, *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C +# comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f18, *.f, *.for, *.vhd, +# *.vhdl, *.ucf, *.qsf and *.ice. + +FILE_PATTERNS = *.c \ + *.cpp \ + *.cuh \ + *.cu \ + *.h \ + *.hpp \ + +# The RECURSIVE tag can be used to specify whether or not subdirectories should +# be searched for input files as well. +# The default value is: NO. + +RECURSIVE = YES + +# The EXCLUDE tag can be used to specify files and/or directories that should be +# excluded from the INPUT source files. This way you can easily exclude a +# subdirectory from a directory tree whose root is specified with the INPUT tag. +# +# Note that relative paths are relative to the directory from which doxygen is +# run. + +EXCLUDE = + +# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or +# directories that are symbolic links (a Unix file system feature) are excluded +# from the input. +# The default value is: NO. + +EXCLUDE_SYMLINKS = NO + +# If the value of the INPUT tag contains directories, you can use the +# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude +# certain files from those directories. +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories for example use the pattern */test/* + +EXCLUDE_PATTERNS = + +# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names +# (namespaces, classes, functions, etc.) that should be excluded from the +# output. The symbol name can be a fully qualified name, a word, or if the +# wildcard * is used, a substring. Examples: ANamespace, AClass, +# ANamespace::AClass, ANamespace::*Test +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories use the pattern */test/* + +EXCLUDE_SYMBOLS = + +# The EXAMPLE_PATH tag can be used to specify one or more files or directories +# that contain example code fragments that are included (see the \include +# command). + +EXAMPLE_PATH = + +# If the value of the EXAMPLE_PATH tag contains directories, you can use the +# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and +# *.h) to filter out the source-files in the directories. If left blank all +# files are included. + +EXAMPLE_PATTERNS = * + +# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be +# searched for input files to be used with the \include or \dontinclude commands +# irrespective of the value of the RECURSIVE tag. 
+# The default value is: NO. + +EXAMPLE_RECURSIVE = NO + +# The IMAGE_PATH tag can be used to specify one or more files or directories +# that contain images that are to be included in the documentation (see the +# \image command). + +IMAGE_PATH = + +# The INPUT_FILTER tag can be used to specify a program that doxygen should +# invoke to filter for each input file. Doxygen will invoke the filter program +# by executing (via popen()) the command: +# +# <filter> <input-file> +# +# where <filter> is the value of the INPUT_FILTER tag, and <input-file> is the +# name of an input file. Doxygen will then use the output that the filter +# program writes to standard output. If FILTER_PATTERNS is specified, this tag +# will be ignored. +# +# Note that the filter must not add or remove lines; it is applied before the +# code is scanned, but not when the output code is generated. If lines are added +# or removed, the anchors will not be placed correctly. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. + +INPUT_FILTER = + +# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern +# basis. Doxygen will compare the file name with each pattern and apply the +# filter if there is a match. The filters are a list of the form: pattern=filter +# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how +# filters are used. If the FILTER_PATTERNS tag is empty or if none of the +# patterns match the file name, INPUT_FILTER is applied. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. + +FILTER_PATTERNS = + +# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using +# INPUT_FILTER) will also be used to filter the input files that are used for +# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES). +# The default value is: NO. + +FILTER_SOURCE_FILES = NO + +# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file +# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and +# it is also possible to disable source filtering for a specific pattern using +# *.ext= (so without naming a filter). +# This tag requires that the tag FILTER_SOURCE_FILES is set to YES. + +FILTER_SOURCE_PATTERNS = + +# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that +# is part of the input, its contents will be placed on the main page +# (index.html). This can be useful if you have a project on for instance GitHub +# and want to reuse the introduction page also for the doxygen output. + +USE_MDFILE_AS_MAINPAGE = + +#--------------------------------------------------------------------------- +# Configuration options related to source browsing +#--------------------------------------------------------------------------- + +# If the SOURCE_BROWSER tag is set to YES then a list of source files will be +# generated. Documented entities will be cross-referenced with these sources. +# +# Note: To get rid of all source code in the generated output, make sure that +# also VERBATIM_HEADERS is set to NO. +# The default value is: NO. + +SOURCE_BROWSER = NO + +# Setting the INLINE_SOURCES tag to YES will include the body of functions, +# classes and enums directly into the documentation. +# The default value is: NO.
+ +INLINE_SOURCES = NO + +# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any +# special comment blocks from generated source code fragments. Normal C, C++ and +# Fortran comments will always remain visible. +# The default value is: YES. + +STRIP_CODE_COMMENTS = YES + +# If the REFERENCED_BY_RELATION tag is set to YES then for each documented +# entity all documented functions referencing it will be listed. +# The default value is: NO. + +REFERENCED_BY_RELATION = NO + +# If the REFERENCES_RELATION tag is set to YES then for each documented function +# all documented entities called/used by that function will be listed. +# The default value is: NO. + +REFERENCES_RELATION = NO + +# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set +# to YES then the hyperlinks from functions in REFERENCES_RELATION and +# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will +# link to the documentation. +# The default value is: YES. + +REFERENCES_LINK_SOURCE = YES + +# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the +# source code will show a tooltip with additional information such as prototype, +# brief description and links to the definition and documentation. Since this +# will make the HTML file larger and loading of large files a bit slower, you +# can opt to disable this feature. +# The default value is: YES. +# This tag requires that the tag SOURCE_BROWSER is set to YES. + +SOURCE_TOOLTIPS = YES + +# If the USE_HTAGS tag is set to YES then the references to source code will +# point to the HTML generated by the htags(1) tool instead of doxygen built-in +# source browser. The htags tool is part of GNU's global source tagging system +# (see https://www.gnu.org/software/global/global.html). You will need version +# 4.8.6 or higher. +# +# To use it do the following: +# - Install the latest version of global +# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file +# - Make sure the INPUT points to the root of the source tree +# - Run doxygen as normal +# +# Doxygen will invoke htags (and that will in turn invoke gtags), so these +# tools must be available from the command line (i.e. in the search path). +# +# The result: instead of the source browser generated by doxygen, the links to +# source code will now point to the output of htags. +# The default value is: NO. +# This tag requires that the tag SOURCE_BROWSER is set to YES. + +USE_HTAGS = NO + +# If the VERBATIM_HEADERS tag is set the YES then doxygen will generate a +# verbatim copy of the header file for each class for which an include is +# specified. Set to NO to disable this. +# See also: Section \class. +# The default value is: YES. + +VERBATIM_HEADERS = NO + +#--------------------------------------------------------------------------- +# Configuration options related to the alphabetical class index +#--------------------------------------------------------------------------- + +# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all +# compounds will be generated. Enable this if the project contains a lot of +# classes, structs, unions or interfaces. +# The default value is: YES. + +ALPHABETICAL_INDEX = YES + +# In case all classes in a project start with a common prefix, all classes will +# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag +# can be used to specify a prefix (or a list of prefixes) that should be ignored +# while generating the index headers. 
+# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. + +IGNORE_PREFIX = + +#--------------------------------------------------------------------------- +# Configuration options related to the HTML output +#--------------------------------------------------------------------------- + +# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output +# The default value is: YES. + +GENERATE_HTML = YES + +# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a +# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of +# it. +# The default directory is: html. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_OUTPUT = html + +# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each +# generated HTML page (for example: .htm, .php, .asp). +# The default value is: .html. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FILE_EXTENSION = .html + +# The HTML_HEADER tag can be used to specify a user-defined HTML header file for +# each generated HTML page. If the tag is left blank doxygen will generate a +# standard header. +# +# To get valid HTML the header file that includes any scripts and style sheets +# that doxygen needs, which is dependent on the configuration options used (e.g. +# the setting GENERATE_TREEVIEW). It is highly recommended to start with a +# default header using +# doxygen -w html new_header.html new_footer.html new_stylesheet.css +# YourConfigFile +# and then modify the file new_header.html. See also section "Doxygen usage" +# for information on how to generate the default header that doxygen normally +# uses. +# Note: The header is subject to change so you typically have to regenerate the +# default header when upgrading to a newer version of doxygen. For a description +# of the possible markers and block names see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_HEADER = + +# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each +# generated HTML page. If the tag is left blank doxygen will generate a standard +# footer. See HTML_HEADER for more information on how to generate a default +# footer and what special commands can be used inside the footer. See also +# section "Doxygen usage" for information on how to generate the default footer +# that doxygen normally uses. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FOOTER = + +# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style +# sheet that is used by each HTML page. It can be used to fine-tune the look of +# the HTML output. If left blank doxygen will generate a default style sheet. +# See also section "Doxygen usage" for information on how to generate the style +# sheet that doxygen normally uses. +# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as +# it is more robust and this tag (HTML_STYLESHEET) will in the future become +# obsolete. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_STYLESHEET = + +# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined +# cascading style sheets that are included after the standard style sheets +# created by doxygen. Using this option one can overrule certain style aspects. +# This is preferred over using HTML_STYLESHEET since it does not replace the +# standard style sheet and is therefore more robust against future updates. 
+# Doxygen will copy the style sheet files to the output directory. +# Note: The order of the extra style sheet files is of importance (e.g. the last +# style sheet in the list overrules the setting of the previous ones in the +# list). For an example see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_EXTRA_STYLESHEET = + +# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or +# other source files which should be copied to the HTML output directory. Note +# that these files will be copied to the base HTML output directory. Use the +# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these +# files. In the HTML_STYLESHEET file, use the file name only. Also note that the +# files will be copied as-is; there are no commands or markers available. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_EXTRA_FILES = + +# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen +# will adjust the colors in the style sheet and background images according to +# this color. Hue is specified as an angle on a color-wheel, see +# https://en.wikipedia.org/wiki/Hue for more information. For instance the value +# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 +# purple, and 360 is red again. +# Minimum value: 0, maximum value: 359, default value: 220. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_HUE = 220 + +# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors +# in the HTML output. For a value of 0 the output will use gray-scales only. A +# value of 255 will produce the most vivid colors. +# Minimum value: 0, maximum value: 255, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_SAT = 100 + +# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the +# luminance component of the colors in the HTML output. Values below 100 +# gradually make the output lighter, whereas values above 100 make the output +# darker. The value divided by 100 is the actual gamma applied, so 80 represents +# a gamma of 0.8, The value 220 represents a gamma of 2.2, and 100 does not +# change the gamma. +# Minimum value: 40, maximum value: 240, default value: 80. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_GAMMA = 80 + +# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML +# page will contain the date and time when the page was generated. Setting this +# to YES can help to show when doxygen was last run and thus if the +# documentation is up to date. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_TIMESTAMP = NO + +# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML +# documentation will contain a main index with vertical navigation menus that +# are dynamically created via JavaScript. If disabled, the navigation index will +# consists of multiple levels of tabs that are statically embedded in every HTML +# page. Disable this option to support browsers that do not have JavaScript, +# like the Qt help browser. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_DYNAMIC_MENUS = YES + +# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML +# documentation will contain sections that can be hidden and shown after the +# page has loaded. 
+# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_DYNAMIC_SECTIONS = NO + +# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries +# shown in the various tree structured indices initially; the user can expand +# and collapse entries dynamically later on. Doxygen will expand the tree to +# such a level that at most the specified number of entries are visible (unless +# a fully collapsed tree already exceeds this amount). So setting the number of +# entries 1 will produce a full collapsed tree by default. 0 is a special value +# representing an infinite number of entries and will result in a full expanded +# tree by default. +# Minimum value: 0, maximum value: 9999, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_INDEX_NUM_ENTRIES = 100 + +# If the GENERATE_DOCSET tag is set to YES, additional index files will be +# generated that can be used as input for Apple's Xcode 3 integrated development +# environment (see: +# https://developer.apple.com/xcode/), introduced with OSX 10.5 (Leopard). To +# create a documentation set, doxygen will generate a Makefile in the HTML +# output directory. Running make will produce the docset in that directory and +# running make install will install the docset in +# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at +# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy +# genXcode/_index.html for more information. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_DOCSET = NO + +# This tag determines the name of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# The default value is: Doxygen generated docs. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDNAME = "Doxygen generated docs" + +# This tag determines the URL of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDURL = + +# This tag specifies a string that should uniquely identify the documentation +# set bundle. This should be a reverse domain-name style string, e.g. +# com.mycompany.MyDocSet. Doxygen will append .docset to the name. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_BUNDLE_ID = org.doxygen.Project + +# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify +# the documentation publisher. This should be a reverse domain-name style +# string, e.g. com.mycompany.MyDocSet.documentation. +# The default value is: org.doxygen.Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_PUBLISHER_ID = org.doxygen.Publisher + +# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher. +# The default value is: Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_PUBLISHER_NAME = Publisher + +# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three +# additional HTML index files: index.hhp, index.hhc, and index.hhk. The +# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop +# on Windows. 
In the beginning of 2021 Microsoft took the original page, with +# a.o. the download links, offline the HTML help workshop was already many years +# in maintenance mode). You can download the HTML help workshop from the web +# archives at Installation executable (see: +# http://web.archive.org/web/20160201063255/http://download.microsoft.com/downlo +# ad/0/A/9/0A939EF6-E31C-430F-A3DF-DFAE7960D564/htmlhelp.exe). +# +# The HTML Help Workshop contains a compiler that can convert all HTML output +# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML +# files are now used as the Windows 98 help format, and will replace the old +# Windows help format (.hlp) on all Windows platforms in the future. Compressed +# HTML files also contain an index, a table of contents, and you can search for +# words in the documentation. The HTML workshop also contains a viewer for +# compressed HTML files. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_HTMLHELP = NO + +# The CHM_FILE tag can be used to specify the file name of the resulting .chm +# file. You can add a path in front of the file if the result should not be +# written to the html output directory. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +CHM_FILE = + +# The HHC_LOCATION tag can be used to specify the location (absolute path +# including file name) of the HTML help compiler (hhc.exe). If non-empty, +# doxygen will try to run the HTML help compiler on the generated index.hhp. +# The file has to be specified with full path. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +HHC_LOCATION = + +# The GENERATE_CHI flag controls if a separate .chi index file is generated +# (YES) or that it should be included in the main .chm file (NO). +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +GENERATE_CHI = NO + +# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc) +# and project file content. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +CHM_INDEX_ENCODING = + +# The BINARY_TOC flag controls whether a binary table of contents is generated +# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it +# enables the Previous and Next buttons. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +BINARY_TOC = NO + +# The TOC_EXPAND flag can be set to YES to add extra items for group members to +# the table of contents of the HTML help documentation and to the tree view. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +TOC_EXPAND = NO + +# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and +# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that +# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help +# (.qch) of the generated HTML documentation. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_QHP = NO + +# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify +# the file name of the resulting .qch file. The path specified is relative to +# the HTML output folder. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QCH_FILE = + +# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help +# Project output. 
+# For more information please see Qt Help Project / Namespace
+# (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace).
+# The default value is: org.doxygen.Project.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_NAMESPACE = org.doxygen.Project
+
+# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt
+# Help Project output. For more information please see Qt Help Project / Virtual
+# Folders (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual-folders).
+# The default value is: doc.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_VIRTUAL_FOLDER = doc
+
+# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom
+# filter to add. For more information please see Qt Help Project / Custom
+# Filters (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_CUST_FILTER_NAME =
+
+# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the
+# custom filter to add. For more information please see Qt Help Project / Custom
+# Filters (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_CUST_FILTER_ATTRS =
+
+# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this
+# project's filter section matches. For more information please see Qt Help
+# Project / Filter Attributes (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_SECT_FILTER_ATTRS =
+
+# The QHG_LOCATION tag can be used to specify the location (absolute path
+# including file name) of Qt's qhelpgenerator. If non-empty, doxygen will try to
+# run qhelpgenerator on the generated .qhp file.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHG_LOCATION =
+
+# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be
+# generated that, together with the HTML files, form an Eclipse help plugin. To
+# install this plugin and make it available under the help contents menu in
+# Eclipse, the contents of the directory containing the HTML and XML files need
+# to be copied into the plugins directory of Eclipse. The name of the directory
+# within the plugins directory should be the same as the ECLIPSE_DOC_ID value.
+# After copying, Eclipse needs to be restarted before the help appears.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_ECLIPSEHELP = NO
+
+# A unique identifier for the Eclipse help plugin. When installing the plugin,
+# the directory name containing the HTML and XML files should also have this
+# name. Each documentation set should have its own identifier.
+# The default value is: org.doxygen.Project.
+# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES.
+
+ECLIPSE_DOC_ID = org.doxygen.Project
+
+# If you want full control over the layout of the generated HTML pages it might
+# be necessary to disable the index and replace it with your own. The
+# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at the
+# top of each HTML page. A value of NO enables the index and the value YES
+# disables it. Since the tabs in the index contain the same information as the
+# navigation tree, you can set this option to YES if you also set
+# GENERATE_TREEVIEW to YES.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+DISABLE_INDEX = NO
+
+# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index
+# structure should be generated to display hierarchical information. If the tag
+# value is set to YES, a side panel will be generated containing a tree-like
+# index structure (just like the one that is generated for HTML Help). For this
+# to work a browser that supports JavaScript, DHTML, CSS and frames is required
+# (i.e. any modern browser). Windows users are probably better off using the
+# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can
+# further fine tune the look of the index (see "Fine-tuning the output"). As an
+# example, the default style sheet generated by doxygen has an example that
+# shows how to put an image at the root of the tree instead of the PROJECT_NAME.
+# Since the tree basically has the same information as the tab index, you could
+# consider setting DISABLE_INDEX to YES when enabling this option.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_TREEVIEW = NO
+
+# When both GENERATE_TREEVIEW and DISABLE_INDEX are set to YES, then the
+# FULL_SIDEBAR option determines if the side bar is limited to only the treeview
+# area (value NO) or if it should extend to the full height of the window (value
+# YES). Setting this to YES gives a layout similar to
+# https://docs.readthedocs.io with more room for contents, but less room for the
+# project logo, title, and description. If either GENERATE_TREEVIEW or
+# DISABLE_INDEX is set to NO, this option has no effect.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+FULL_SIDEBAR = NO
+
+# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that
+# doxygen will group on one line in the generated HTML documentation.
+#
+# Note that a value of 0 will completely suppress the enum values from appearing
+# in the overview section.
+# Minimum value: 0, maximum value: 20, default value: 4.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+ENUM_VALUES_PER_LINE = 4
+
+# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used
+# to set the initial width (in pixels) of the frame in which the tree is shown.
+# Minimum value: 0, maximum value: 1500, default value: 250.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+TREEVIEW_WIDTH = 250
+
+# If the EXT_LINKS_IN_WINDOW option is set to YES, doxygen will open links to
+# external symbols imported via tag files in a separate window.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+EXT_LINKS_IN_WINDOW = NO
+
+# If the OBFUSCATE_EMAILS tag is set to YES, doxygen will obfuscate email
+# addresses.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+OBFUSCATE_EMAILS = YES
+
+# If the HTML_FORMULA_FORMAT option is set to svg, doxygen will use the pdf2svg
+# tool (see https://github.com/dawbarton/pdf2svg) or inkscape (see
+# https://inkscape.org) to generate formulas as SVG images instead of PNGs for
+# the HTML output. These images will generally look nicer at scaled resolutions.
+# Possible values are: png (the default) and svg (looks nicer but requires the
+# pdf2svg or inkscape tool).
+# The default value is: png.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_FORMULA_FORMAT = png
+
+# Use this tag to change the font size of LaTeX formulas included as images in
+# the HTML documentation. When you change the font size after a successful
+# doxygen run you need to manually remove any form_*.png images from the HTML
+# output directory to force them to be regenerated.
+# Minimum value: 8, maximum value: 50, default value: 10.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+FORMULA_FONTSIZE = 10
+
+# Use the FORMULA_TRANSPARENT tag to determine whether or not the images
+# generated for formulas are transparent PNGs. Transparent PNGs are not
+# supported properly for IE 6.0, but are supported on all modern browsers.
+#
+# Note that when changing this option you need to delete any form_*.png files in
+# the HTML output directory before the changes take effect.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+FORMULA_TRANSPARENT = YES
+
+# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands
+# to create new LaTeX commands to be used in formulas as building blocks. See
+# the section "Including formulas" for details.
+
+FORMULA_MACROFILE =
+
+# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see
+# https://www.mathjax.org) which uses client side JavaScript for the rendering
+# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX
+# installed or if you want the formulas to look prettier in the HTML output.
+# When enabled you may also need to install MathJax separately and configure
+# the path to it using the MATHJAX_RELPATH option.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+USE_MATHJAX = NO
+
+# With MATHJAX_VERSION it is possible to specify the MathJax version to be used.
+# Note that the different versions of MathJax have different requirements with
+# regards to the different settings, so it is possible that other MathJax
+# settings also have to be changed when switching between the different MathJax
+# versions.
+# Possible values are: MathJax_2 and MathJax_3.
+# The default value is: MathJax_2.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_VERSION = MathJax_2
+
+# When MathJax is enabled you can set the default output format to be used for
+# the MathJax output. For more details about the output format see MathJax
+# version 2 (see:
+# http://docs.mathjax.org/en/v2.7-latest/output.html) and MathJax version 3
+# (see:
+# http://docs.mathjax.org/en/latest/web/components/output.html).
+# Possible values are: HTML-CSS (which is slower, but has the best
+# compatibility. This is the name for MathJax version 2; for MathJax version 3
+# this will be translated into chtml), NativeMML (i.e. MathML. Only supported
+# for MathJax 2. For MathJax version 3 chtml will be used instead.), chtml (This
+# is the name for MathJax version 3; for MathJax version 2 this will be
+# translated into HTML-CSS) and SVG.
+# The default value is: HTML-CSS.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_FORMAT = HTML-CSS
+
+# When MathJax is enabled you need to specify the location relative to the HTML
+# output directory using the MATHJAX_RELPATH option. The destination directory
+# should contain the MathJax.js script. For instance, if the mathjax directory
+# is located at the same level as the HTML output directory, then
+# MATHJAX_RELPATH should be ../mathjax.
+# The default value points to the MathJax
+# Content Delivery Network so you can quickly see the result without installing
+# MathJax. However, it is strongly recommended to install a local copy of
+# MathJax from https://www.mathjax.org before deployment. The default value is:
+# - in case of MathJax version 2: https://cdn.jsdelivr.net/npm/mathjax@2
+# - in case of MathJax version 3: https://cdn.jsdelivr.net/npm/mathjax@3
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_RELPATH =
+
+# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax
+# extension names that should be enabled during MathJax rendering. For example
+# for MathJax version 2 (see https://docs.mathjax.org/en/v2.7-latest/tex.html
+# #tex-and-latex-extensions):
+# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols
+# For example for MathJax version 3 (see
+# http://docs.mathjax.org/en/latest/input/tex/extensions/index.html):
+# MATHJAX_EXTENSIONS = ams
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_EXTENSIONS =
+
+# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces
+# of code that will be used on startup of the MathJax code. See the MathJax site
+# (see:
+# http://docs.mathjax.org/en/v2.7-latest/output.html) for more details. For an
+# example see the documentation.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_CODEFILE =
+
+# When the SEARCHENGINE tag is enabled doxygen will generate a search box for
+# the HTML output. The underlying search engine uses javascript and DHTML and
+# should work on any modern browser. Note that when using HTML help
+# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET)
+# there is already a search function so this one should typically be disabled.
+# For large projects the javascript based search engine can be slow; in that
+# case enabling SERVER_BASED_SEARCH may provide a better solution. It is
+# possible to search using the keyboard; to jump to the search box use
+# <access key> + S (what the <access key> is depends on the OS and browser, but
+# it is typically <CTRL>, <ALT>/