Add Windows pipeline & Sync with upstream code #5

Draft
wants to merge 10 commits into main

109 changes: 109 additions & 0 deletions .github/workflows/win_ci.yml
@@ -0,0 +1,109 @@
name: Windows_CI
on:
  push:
    branches:
      - main
      - rel-*
  pull_request:

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  Win32_debug_no_ort:
    runs-on: windows-2022
    permissions:
      actions: read
      contents: read
      security-events: write
    steps:
      - uses: actions/checkout@v4
      - name: Initialize CodeQL
        uses: github/codeql-action/init@v3
        with:
          config-file: ./.github/codeql/codeql-config.yml
          languages: 'cpp'
      - run: |
          cmake --workflow --preset windows_win32_debug_no_ort_workflow
      - name: Perform CodeQL Analysis
        uses: github/codeql-action/analyze@v3
        with:
          category: "/language:cpp"
          output: sarif-results
          upload: failure-only

      - name: filter-sarif
        uses: advanced-security/filter-sarif@v1
        with:
          patterns: |
            +**/*.cc
            +**/*.h
            -tests/**/*.*
            -build/**/*.*
          input: sarif-results/cpp.sarif
          output: sarif-results/cpp.sarif

      - name: Upload SARIF
        uses: github/codeql-action/upload-sarif@v3
        with:
          sarif_file: sarif-results/cpp.sarif

  Win32_release_no_ort:
    runs-on: windows-2022
    steps:
      - uses: actions/checkout@v4
      - run: |
          cmake --workflow --preset windows_win32_release_no_ort_workflow

  WinX64_debug_no_ort:
    runs-on: windows-2022
    permissions:
      actions: read
      contents: read
      security-events: write
    steps:
      - uses: actions/checkout@v4
      - name: Initialize CodeQL
        uses: github/codeql-action/init@v3
        with:
          config-file: ./.github/codeql/codeql-config.yml
          languages: 'cpp'
      - run: |
          cmake --workflow --preset windows_x64_debug_no_ort_workflow
      - name: Perform CodeQL Analysis
        uses: github/codeql-action/analyze@v3
        with:
          category: "/language:cpp"
          output: sarif-results
          upload: failure-only

      - name: filter-sarif
        uses: advanced-security/filter-sarif@v1
        with:
          patterns: |
            +**/*.cc
            +**/*.h
            -tests/**/*.*
            -build/**/*.*
          input: sarif-results/cpp.sarif
          output: sarif-results/cpp.sarif

      - name: Upload SARIF
        uses: github/codeql-action/upload-sarif@v3
        with:
          sarif_file: sarif-results/cpp.sarif

  WinX64_release_no_ort:
    runs-on: windows-2022
    steps:
      - uses: actions/checkout@v4
      - run: |
          cmake --workflow --preset windows_x64_release_no_ort_workflow

  WinX64_release:
    runs-on: windows-2022
    steps:
      - uses: actions/checkout@v4
      - run: |
          cmake --workflow --preset windows_x64_release_workflow
82 changes: 41 additions & 41 deletions include/mlas_qnbit.h
@@ -27,51 +27,50 @@ Module Name:
 * @brief Define compute types of block quantization, in order of decreasing accuracy.
 */
typedef enum {
-    CompUndef = 0,  /*!< undef */
-    CompFp32,       /*!< input fp32, accumulator fp32 */
-    CompFp16,       /*!< input fp16, accumulator fp16 */
-    CompBf16,       /*!< input bf16, accumulator fp32 */
-    CompInt8,       /*!< input int8, accumulator int32 */
-
-    // special values that should be the first and last actual values
-
-    CompMostAccurate = CompUndef,
-    CompLeastAccurate = CompInt8,
-} MLAS_SQNBIT_GEMM_COMPUTE_TYPE;
+    SQNBIT_CompFp32,   /*!< input fp32, accumulator fp32 */
+    HQNBIT_CompFp16,   /*!< input fp16, accumulator fp16 */
+    BHQNBIT_CompBf16,  /*!< input bf16, accumulator fp32 */
+    SQNBIT_CompInt8,   /*!< input int8, accumulator int32, input fp32 */
+    HQNBIT_CompInt8,   /*!< input int8, accumulator int32, input fp16 */
+} MLAS_QNBIT_GEMM_COMPUTE_TYPE;

/**
 * @brief Data parameters for float/n-bit quantized int GEMM routine.
+ *
+ * @tparam T data type of input A
 */
-struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
-    const float* A = nullptr;  ///< address of A (float32 matrix)
+template <typename T>
+struct MLAS_QNBIT_GEMM_DATA_PARAMS {
+    const T* A = nullptr;  ///< address of A (float32/16 matrix)
    size_t lda = 0;  ///< leading dimension of A
    const void* QuantBDataWorkspace;  ///< address of quantized B (quantized n-bit int values)
    const std::byte* PackedQuantBData = nullptr;  /// address of packed quantized B data
-    const float* QuantBScale = nullptr;  ///< address of scale values of quantized B, one per block
+    const T* QuantBScale = nullptr;  ///< address of scale values of quantized B, one per block
    const void* QuantBZeroPoint = nullptr;  ///< optional address of zero point values of quantized B, one per block
-    const float* QuantBBlkSum = nullptr;  ///< optional address of scale * zp, one per block
-    const float* Bias = nullptr;  ///< optional address of Bias, vector size N
-    float* C = nullptr;  ///< address of result matrix
+    const T* QuantBBlkSum = nullptr;  ///< optional address of scale * zp, one per block
+    const T* Bias = nullptr;  ///< optional address of Bias, vector size N
+    T* C = nullptr;  ///< address of result matrix
    size_t ldc = 0;  ///< leading dimension of C

    ///< optional post processing to apply to result matrix
-    MLAS_GEMM_POSTPROCESSOR<float>* PostProcessor = nullptr;
+    MLAS_GEMM_POSTPROCESSOR<T>* PostProcessor = nullptr;
};

/**
 * @brief Batched GEMM: C = A * B + Bias
- * A must be a float32 matrix
+ * A must be a float32/16 matrix
 * B must be a quantized and packed n-bit int matrix
 *
- * Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called.
+ * Call MlasIsQNBitGemmAvailable() with the same parameters to determine whether this function may be called.
 *
- * Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether
- * MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with
- * MlasSQNBitGemmPackQuantBData().
+ * Call MlasQNBitGemmPackQuantBDataSize() with the same parameters to determine whether
+ * MLAS_QNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with
+ * MlasQNBitGemmPackQuantBData().
 *
- * Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should
+ * Call MlasQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should
 * point to an intermediate workspace buffer.
 *
+ * @tparam T data type of input A
 * @param[in] M row size of matrix A and C
 * @param[in] N column size of matrix B and C
 * @param[in] K column size of matrix A and row size of matrix B
@@ -81,36 +80,37 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
 * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
 * @param[inout] DataParams An array (size BatchN) of parameter blocks
 * @param[in] Workspace Address of intermediate workspace buffer.
-                       If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a
+                       If MlasQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a
                        buffer with at least that many bytes. Otherwise, it may be nullptr.
 * @param[in] ThreadPool optional thread pool to use
 */
+template <typename T>
void MLASCALL
-MlasSQNBitGemmBatch(
+MlasQNBitGemmBatch(
    size_t M,
    size_t N,
    size_t K,
    size_t BatchN,
    size_t BlkBitWidth,
    size_t BlkLen,
-    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
-    const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
+    MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
+    const MLAS_QNBIT_GEMM_DATA_PARAMS<T>* DataParams,
    void* Workspace,
    MLAS_THREADPOOL* ThreadPool = nullptr
);
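
As a reference for reviewers, the call sequence documented above can be exercised roughly as follows. This is a minimal sketch rather than code from this PR: the shapes and block parameters are illustrative, it assumes the fp32 activation path with SQNBIT_CompInt8 compute and the usual mlas.h/mlas_qnbit.h includes, and the packing call is elided because its full parameter list is truncated in this diff.

```cpp
#include <cstddef>
#include <vector>

#include "mlas.h"        // assumed to provide MLAS_THREADPOOL etc.
#include "mlas_qnbit.h"  // the header modified in this PR

// Sketch of the documented call order for 4-bit blocks of length 32 with fp32 activations.
void QNBitGemmSketch(const float* A, const void* QuantBData, const float* QuantBScale, float* C)
{
    constexpr size_t BlkBitWidth = 4;
    constexpr size_t BlkLen = 32;
    constexpr size_t M = 1, N = 4096, K = 4096, BatchN = 1;

    // 1. Check that a kernel exists for this configuration.
    if (!MlasIsQNBitGemmAvailable(BlkBitWidth, BlkLen, SQNBIT_CompInt8)) {
        return;  // fall back to a non-MLAS path
    }

    // 2. Pack B if this configuration requires packed B data.
    const size_t PackedBSize =
        MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, SQNBIT_CompInt8);
    std::vector<std::byte> PackedB(PackedBSize);
    (void)QuantBData;  // would be packed into PackedB via MlasQNBitGemmPackQuantBData(...)

    // 3. Allocate the intermediate workspace if one is required.
    const size_t WorkspaceSize =
        MlasQNBitGemmBatchWorkspaceSize(M, N, K, BatchN, BlkBitWidth, BlkLen, SQNBIT_CompInt8);
    std::vector<std::byte> Workspace(WorkspaceSize);

    // 4. Fill the per-batch parameter block and run the GEMM.
    MLAS_QNBIT_GEMM_DATA_PARAMS<float> Params{};
    Params.A = A;                      // M x K fp32 activations
    Params.lda = K;
    Params.PackedQuantBData = PackedBSize != 0 ? PackedB.data() : nullptr;
    Params.QuantBScale = QuantBScale;  // one scale per quantized block of B
    Params.C = C;                      // M x N fp32 output
    Params.ldc = N;

    MlasQNBitGemmBatch<float>(M, N, K, BatchN, BlkBitWidth, BlkLen, SQNBIT_CompInt8,
                              &Params, WorkspaceSize != 0 ? Workspace.data() : nullptr);
}
```

An fp16 activation path would presumably use the fp16 instantiation of MLAS_QNBIT_GEMM_DATA_PARAMS together with the HQNBIT_* compute types.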

/**
- * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform.
+ * @brief Determines whether a float32/16 quantized n-bit int GEMM implementation is available on the current platform.
 *
 * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
 * @param[in] BlkLen number of quantized values per block
 * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
 */
bool MLASCALL
-MlasIsSQNBitGemmAvailable(
+MlasIsQNBitGemmAvailable(
    size_t BlkBitWidth,
    size_t BlkLen,
-    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
+    MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

/**
@@ -126,22 +126,22 @@ MlasIsSQNBitGemmAvailable(
 * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
 */
size_t MLASCALL
-MlasSQNBitGemmBatchWorkspaceSize(
+MlasQNBitGemmBatchWorkspaceSize(
    size_t M,
    size_t N,
    size_t K,
    size_t BatchN,
    size_t BlkBitWidth,
    size_t BlkLen,
-    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
+    MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

/**
 * @brief Gets the size in bytes of the packed quantized B data.
- * If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of
- * this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch().
- * If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to
- * MlasSQNBitGemmBatch().
+ * If non-zero, the quantized B data must first be packed by calling MlasQNBitGemmPackQuantBData() with a buffer of
+ * this size, and then that packed quantized B data buffer must be passed to MlasQNBitGemmBatch().
+ * If zero, MlasQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to
+ * MlasQNBitGemmBatch().
 *
 * @param[in] N column size of matrix B and C
 * @param[in] K column size of matrix A and row size of matrix B
@@ -150,12 +150,12 @@ MlasSQNBitGemmBatchWorkspaceSize(
 * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
 */
size_t MLASCALL
-MlasSQNBitGemmPackQuantBDataSize(
+MlasQNBitGemmPackQuantBDataSize(
    size_t N,
    size_t K,
    size_t BlkBitWidth,
    size_t BlkLen,
-    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
+    MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);
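
To make the zero/non-zero contract above concrete, here is a minimal sketch of the packing decision (an illustrative helper, not part of this PR; the packing call itself is omitted because its parameter list is truncated in this diff):

```cpp
#include <cstddef>
#include <vector>

#include "mlas.h"
#include "mlas_qnbit.h"

// Returns a packed-B buffer when this configuration requires packing, or an empty
// buffer when the raw quantized B data should be passed to MlasQNBitGemmBatch() directly.
std::vector<std::byte> MaybePackQuantB(size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen,
                                       MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType)
{
    const size_t PackedSize =
        MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType);
    if (PackedSize == 0) {
        return {};  // no packing for this configuration
    }
    std::vector<std::byte> Packed(PackedSize);
    // Fill `Packed` here with MlasQNBitGemmPackQuantBData(...) before calling MlasQNBitGemmBatch().
    return Packed;
}
```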

/**
@@ -186,12 +186,12 @@ MlasSQNBitGemmPackQuantBDataSize(
 * @param[in] ThreadPool thread pool to use (no parallel if nullptr)
 */
void MLASCALL
-MlasSQNBitGemmPackQuantBData(
+MlasQNBitGemmPackQuantBData(
    size_t N,
    size_t K,
    size_t BlkBitWidth,
    size_t BlkLen,
-    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
+    MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
    const void* QuantBData,
    void* PackedQuantBDataAndOrBlkSum,
    const void* QuantBScale,