-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Fp gemm and Softmax for Snitch platform (#31)
* add fp gemm and softmax * add new test for gemm_fp32_transb * add new test in CI.yml * increase size of the error between expected and actual * update CHANGELOG
- Loading branch information
1 parent
6758edc
commit feff1ef
Showing
17 changed files
with
477 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from Deeploy.DeeployTypes import NodeTemplate | ||
|
||
# Mako template that emits one FP32 GEMM kernel call for the Snitch cluster. | ||
# The generated C first reads the number of compute cores, then calls | ||
# gemm_fp32_transB_opt when `transB` is truthy and gemm_fp32_opt otherwise. | ||
# Both calls pass ${M} / compute_num as the row count and scale the strides | ||
# given for A and C by compute_num (${N} * compute_num, ${O} * compute_num); | ||
# the stride given for B is ${N} in the transB branch and ${O} otherwise. | ||
# NOTE(review): ${M} / compute_num is a C integer division, so the remainder | ||
# is dropped when M is not a multiple of the core count - confirm that the | ||
# caller guarantees divisibility. | ||
# NOTE(review): the trailing `1, 1` sit in the BETA / setup_SSR positions of | ||
# the kernel prototypes - confirm their meaning against the kernel source. | ||
referenceTemplate = NodeTemplate(""" | ||
uint32_t compute_num = snrt_cluster_compute_core_num(); | ||
% if transB: | ||
gemm_fp32_transB_opt(${M} / compute_num, ${O}, ${N}, ${A}, ${N} * compute_num, ${B}, ${N}, ${C}, ${O} * compute_num, ${data_out}, 1, 1 ); | ||
% else: | ||
gemm_fp32_opt(${M} / compute_num, ${O}, ${N}, ${A}, ${N} * compute_num, ${B}, ${O}, ${C}, ${O} * compute_num, ${data_out}, 1, 1 ); | ||
%endif | ||
""") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# ---------------------------------------------------------------------- | ||
# | ||
# File: FloatSoftmaxTemplate.py | ||
# | ||
# Last edited: 30.05.2024 | ||
# | ||
# Copyright (C) 2024, ETH Zurich and University of Bologna. | ||
# | ||
# Author: | ||
# - Victor Jung, [email protected], ETH Zurich | ||
# | ||
# ---------------------------------------------------------------------- | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the License); you may | ||
# not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an AS IS BASIS, WITHOUT | ||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Dict, List, Tuple | ||
|
||
from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation | ||
|
||
|
||
# NodeTemplate subclass that adds the fields required by the float Softmax | ||
# kernel call to the operator representation before code generation. | ||
class FloatSoftmaxTemplate(NodeTemplate): | ||
|
||
# Forwards templateStr unchanged to the NodeTemplate constructor. | ||
def __init__(self, templateStr): | ||
super().__init__(templateStr) | ||
|
||
# Looks up the buffer named by operatorRepresentation["data_in"] in ctxt and | ||
# stores "seq_len", "input_samples" and "kernelName" in operatorRepresentation. | ||
# Returns the context, the updated representation and an empty list. | ||
def alignToContext(self, ctxt: NetworkContext, | ||
operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: | ||
|
||
data_in = ctxt.lookup(operatorRepresentation["data_in"]) | ||
# NOTE(review): shape[2] needs an input of rank >= 3 - confirm callers guarantee it. | ||
operatorRepresentation["seq_len"] = data_in.shape[2] | ||
# The last dimension of the input is used as the sample count. | ||
operatorRepresentation["input_samples"] = data_in.shape[-1] | ||
|
||
# NOTE(review): the Snitch softmax prototype in this change is spelled | ||
# `softmax_fp32` (lowercase) - confirm a `Softmax_fp32` symbol taking these | ||
# seven arguments exists for the linker. | ||
operatorRepresentation["kernelName"] = "Softmax_fp32" | ||
|
||
return ctxt, operatorRepresentation, [] | ||
|
||
|
||
# Raw template text for the float Softmax kernel call. The generated C computes | ||
# batch_size = size / lastDimLength, ldI = compute_num * input_samples and | ||
# batch_offset = seq_len * input_samples, then calls ${kernelName} with the | ||
# input and output buffers followed by those values, seq_len and input_samples. | ||
# seq_len, input_samples and kernelName are set by | ||
# FloatSoftmaxTemplate.alignToContext. | ||
# NOTE(review): `size` and `lastDimLength` are not set there - presumably the | ||
# parser provides them; confirm. | ||
FloatSoftmaxTemplateStr = r""" | ||
uint32_t batch_size = ${size} / ${lastDimLength}; | ||
uint32_t compute_num = snrt_cluster_compute_core_num(); | ||
int32_t ldI = compute_num * ${input_samples}; | ||
int32_t batch_offset = ${seq_len} * ${input_samples}; | ||
${kernelName}(${data_in}, ${data_out}, ldI, batch_offset, batch_size, ${seq_len}, ${input_samples}); | ||
""" | ||
|
||
# Module-level template instance built from the string above. | ||
FloatSoftmax_Template = FloatSoftmaxTemplate(FloatSoftmaxTemplateStr) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#ifndef __DEEPLOY_MATH_GEMM_KERNEL_HEADER_ | ||
#define __DEEPLOY_MATH_GEMM_KERNEL_HEADER_ | ||
|
||
#include "DeeploySnitchMath.h" | ||
|
||
/* | ||
* TILING ONLY due to ssr loop | ||
* | ||
* | ||
* | ||
* FP32 GEMM with the following format: | ||
* A is an M x K matrix, B is a K x N matrix, and C is a M x N matrix | ||
* | ||
* A' = transpose(A) if transA else A | ||
* B' = transpose(B) if transB else B | ||
* | ||
* Y = A' * B' + C | ||
* | ||
*/ | ||
|
||
/* | ||
* | ||
* transposed A = no | ||
* transposed B = yes | ||
* multi-core = yes | ||
* unrolling = yes | ||
* simd = yes | ||
* parallelization = row-wise | ||
*/ | ||
|
||
/* | ||
 * Prototype only; the definition is not in this header. | ||
 * M, N, K follow the format comment at the top of this header | ||
 * (A is M x K, B is K x N, C and Y are M x N, Y = A' * B' + C). | ||
 * NOTE(review): ldA / ldB / ldC are presumably element strides between | ||
 * consecutive rows of A, B and C - confirm against the implementation. | ||
 * NOTE(review): BETA and setup_SSR are plain uint32_t values whose meaning | ||
 * is not documented here - confirm before changing call sites. | ||
 */ | ||
void gemm_fp32_transB_opt(uint32_t M, uint32_t N, uint32_t K, float32_t *A, | ||
uint32_t ldA, float32_t *B, uint32_t ldB, | ||
float32_t *C, uint32_t ldC, float32_t *Y, | ||
uint32_t BETA, uint32_t setup_SSR); | ||
|
||
/* | ||
* | ||
* transposed A = no | ||
* transposed B = no | ||
* multi-core = yes | ||
* unrolling = yes | ||
* simd = yes | ||
* parallelization = row-wise | ||
*/ | ||
|
||
/* | ||
 * Prototype only; the definition is not in this header. | ||
 * Takes the same parameter list as gemm_fp32_transB_opt; per the comment | ||
 * above, B is not transposed in this variant. | ||
 * NOTE(review): ldA / ldB / ldC are presumably element strides between | ||
 * consecutive rows - confirm against the implementation. | ||
 * NOTE(review): the meaning of BETA and setup_SSR is not documented here. | ||
 */ | ||
void gemm_fp32_opt(uint32_t M, uint32_t N, uint32_t K, float32_t *A, | ||
uint32_t ldA, float32_t *B, uint32_t ldB, float32_t *C, | ||
uint32_t ldC, float32_t *Y, uint32_t BETA, | ||
uint32_t setup_SSR); | ||
|
||
#endif //__DEEPLOY_MATH_GEMM_KERNEL_HEADER_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#include "DeeploySnitchMath.h" | ||
|
||
/* | ||
 * Prototype of the FP32 softmax kernel; the definition is not shown here. | ||
 * NOTE(review): the buffers are declared as `float`, while the GEMM kernel | ||
 * prototypes use `float32_t` - confirm which element type alias is intended. | ||
 * NOTE(review): the Python Softmax template emits a call to `Softmax_fp32` | ||
 * (capital S) - confirm the symbol name that call resolves to. | ||
 * NOTE(review): ldI, batch_offset, batch_size, seq_len and input_samples | ||
 * are presumably element counts - confirm units against the implementation. | ||
 */ | ||
void softmax_fp32(float *input, float *output, int32_t ldI, | ||
int32_t batch_offset, int32_t batch_size, int32_t seq_len, | ||
int32_t input_samples); |
Oops, something went wrong.