Skip to content

Commit

Permalink
Add CUDA impl
Browse files Browse the repository at this point in the history
  • Loading branch information
hdelan committed Jul 3, 2024
1 parent 22ccdc9 commit ae22a1d
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 0 deletions.
1 change: 1 addition & 0 deletions source/adapters/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ add_ur_adapter(${TARGET_NAME}
${CMAKE_CURRENT_SOURCE_DIR}/queue.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sampler.hpp
${CMAKE_CURRENT_SOURCE_DIR}/sampler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor_map.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tracing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/usm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/usm_p2p.cpp
Expand Down
142 changes: 142 additions & 0 deletions source/adapters/cuda/tensor_map.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
//===--------- tensor_map.cpp - CUDA Adapter ------------------------------===//
//
// Copyright (C) 2024 Intel Corporation
//
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
// Exceptions. See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <cuda.h>
#include <ur_api.h>

#include "context.hpp"

// Adapter-side object holding a CUDA driver tensor-map descriptor. The encode
// entry points in this file write the encoded descriptor into `Map` through a
// ur_exp_tensor_map_handle_t.
struct ur_exp_tensor_map_handle_t_ {
  // Opaque CUDA driver descriptor filled in by cuTensorMapEncode*.
  CUtensorMap Map;
};

// Returns CUTYPE from the enclosing function when the bit(s) in URTYPE are set
// in the enclosing function's `UrType` parameter.
//
// The body is wrapped in do { } while (0) so that `CONVERT(a, b);` is exactly
// one statement (safe inside an unbraced if/else), and both arguments are
// parenthesised so that compound expressions keep their intended precedence.
#define CONVERT(URTYPE, CUTYPE)                                                \
  do {                                                                         \
    if ((URTYPE) & (UrType))                                                   \
      return (CUTYPE);                                                         \
  } while (0)

/// Translates a UR tensor-map data-type flag into the matching CUDA driver
/// CUtensorMapDataType value.
///
/// The flags are tested in the order listed below and the first one whose bit
/// is set in \p UrType wins.
///
/// \param UrType UR data-type flag value to translate.
/// \return The CUDA enumerator paired with the first matching flag.
/// \throws ur_result_t UR_RESULT_ERROR_INVALID_ENUMERATION when none of the
///         known flags is set. A ur_result_t is thrown (rather than a string)
///         because that is the only exception type the API entry points in
///         this file catch.
inline CUtensorMapDataType
convertUrToCuDataType(ur_exp_tensor_map_data_type_flags_t UrType) {
  CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_UINT8,
          CU_TENSOR_MAP_DATA_TYPE_UINT8);
  CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_UINT16,
          CU_TENSOR_MAP_DATA_TYPE_UINT16);
  CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_UINT32,
          CU_TENSOR_MAP_DATA_TYPE_UINT32);
  CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_INT32,
          CU_TENSOR_MAP_DATA_TYPE_INT32);
  CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_UINT64,
          CU_TENSOR_MAP_DATA_TYPE_UINT64);
  CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_INT64,
          CU_TENSOR_MAP_DATA_TYPE_INT64);
  CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_FLOAT16,
          CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
  CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_FLOAT32,
          CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
  CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_FLOAT64,
          CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
  CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_BFLOAT16,
          CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
  CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_FLOAT32_FTZ,
          CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ);
  CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_TFLOAT32,
          CU_TENSOR_MAP_DATA_TYPE_TFLOAT32);
  CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_TFLOAT32_FTZ,
          CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ);
  // No recognised data-type bit was set.
  throw UR_RESULT_ERROR_INVALID_ENUMERATION;
}

/// Translates a UR tensor-map interleave flag into the matching CUDA driver
/// CUtensorMapInterleave value.
///
/// Flags are tested in the order NONE, 16B, 32B; the first one whose bit is
/// set in \p UrType wins.
///
/// \param UrType UR interleave flag value to translate.
/// \return The CUDA enumerator paired with the first matching flag.
/// \throws ur_result_t UR_RESULT_ERROR_INVALID_ENUMERATION when none of the
///         known flags is set, so that the API entry points in this file
///         (which catch ur_result_t) can report it instead of terminating.
CUtensorMapInterleave
convertUrToCuInterleave(ur_exp_tensor_map_interleave_flags_t UrType) {
  CONVERT(UR_EXP_TENSOR_MAP_INTERLEAVE_FLAG_NONE,
          CU_TENSOR_MAP_INTERLEAVE_NONE);
  CONVERT(UR_EXP_TENSOR_MAP_INTERLEAVE_FLAG_16B, CU_TENSOR_MAP_INTERLEAVE_16B);
  CONVERT(UR_EXP_TENSOR_MAP_INTERLEAVE_FLAG_32B, CU_TENSOR_MAP_INTERLEAVE_32B);
  // No recognised interleave bit was set.
  throw UR_RESULT_ERROR_INVALID_ENUMERATION;
}

/// Translates a UR tensor-map swizzle flag into the matching CUDA driver
/// CUtensorMapSwizzle value.
///
/// Flags are tested in the order NONE, 32B, 64B, 128B; the first one whose bit
/// is set in \p UrType wins.
///
/// \param UrType UR swizzle flag value to translate.
/// \return The CUDA enumerator paired with the first matching flag.
/// \throws ur_result_t UR_RESULT_ERROR_INVALID_ENUMERATION when none of the
///         known flags is set, so that the API entry points in this file
///         (which catch ur_result_t) can report it instead of terminating.
CUtensorMapSwizzle
convertUrToCuSwizzle(ur_exp_tensor_map_swizzle_flags_t UrType) {
  CONVERT(UR_EXP_TENSOR_MAP_SWIZZLE_FLAG_NONE, CU_TENSOR_MAP_SWIZZLE_NONE);
  CONVERT(UR_EXP_TENSOR_MAP_SWIZZLE_FLAG_32B, CU_TENSOR_MAP_SWIZZLE_32B);
  CONVERT(UR_EXP_TENSOR_MAP_SWIZZLE_FLAG_64B, CU_TENSOR_MAP_SWIZZLE_64B);
  CONVERT(UR_EXP_TENSOR_MAP_SWIZZLE_FLAG_128B, CU_TENSOR_MAP_SWIZZLE_128B);
  // No recognised swizzle bit was set.
  throw UR_RESULT_ERROR_INVALID_ENUMERATION;
}

/// Translates a UR tensor-map L2-promotion flag into the matching CUDA driver
/// CUtensorMapL2promotion value.
///
/// Flags are tested in the order NONE, 64B, 128B, 256B; the first one whose
/// bit is set in \p UrType wins.
///
/// \param UrType UR L2-promotion flag value to translate.
/// \return The CUDA enumerator paired with the first matching flag.
/// \throws ur_result_t UR_RESULT_ERROR_INVALID_ENUMERATION when none of the
///         known flags is set, so that the API entry points in this file
///         (which catch ur_result_t) can report it instead of terminating.
CUtensorMapL2promotion
convertUrToL2promotion(ur_exp_tensor_map_l2_promotion_flags_t UrType) {
  CONVERT(UR_EXP_TENSOR_MAP_L2_PROMOTION_FLAG_NONE,
          CU_TENSOR_MAP_L2_PROMOTION_NONE);
  CONVERT(UR_EXP_TENSOR_MAP_L2_PROMOTION_FLAG_64B,
          CU_TENSOR_MAP_L2_PROMOTION_L2_64B);
  CONVERT(UR_EXP_TENSOR_MAP_L2_PROMOTION_FLAG_128B,
          CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
  CONVERT(UR_EXP_TENSOR_MAP_L2_PROMOTION_FLAG_256B,
          CU_TENSOR_MAP_L2_PROMOTION_L2_256B);
  // No recognised L2-promotion bit was set.
  throw UR_RESULT_ERROR_INVALID_ENUMERATION;
}

/// Translates a UR tensor-map out-of-bounds fill flag into the matching CUDA
/// driver CUtensorMapFloatOOBfill value.
///
/// Flags are tested in the order NONE, REQUEST_ZERO_FMA; the first one whose
/// bit is set in \p UrType wins.
///
/// \param UrType UR OOB-fill flag value to translate.
/// \return The CUDA enumerator paired with the first matching flag.
/// \throws ur_result_t UR_RESULT_ERROR_INVALID_ENUMERATION when none of the
///         known flags is set, so that the API entry points in this file
///         (which catch ur_result_t) can report it instead of terminating.
CUtensorMapFloatOOBfill
convertUrToCuOOBfill(ur_exp_tensor_map_oob_fill_flags_t UrType) {
  CONVERT(UR_EXP_TENSOR_MAP_OOB_FILL_FLAG_NONE,
          CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  CONVERT(UR_EXP_TENSOR_MAP_OOB_FILL_FLAG_REQUEST_ZERO_FMA,
          CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA);
  // No recognised OOB-fill bit was set.
  throw UR_RESULT_ERROR_INVALID_ENUMERATION;
}

/// Encodes an im2col tensor map by forwarding the arguments to the CUDA driver
/// call cuTensorMapEncodeIm2col, with the UR flag arguments translated to
/// their CUDA enumerators first. The encoded descriptor is written into the
/// `Map` member of the object that `*hTensorMap` points to.
///
/// NOTE(review): this function does not allocate the tensor-map object; it
/// writes through the handle already stored in `*hTensorMap`. Presumably the
/// caller is expected to provide that storage — confirm against the callers.
///
/// \param hDevice Device whose context is made active for the driver call.
/// \param TensorMapType UR data-type flag of the tensor elements.
/// \param TensorRank Forwarded unchanged to the driver.
/// \param GlobalAddress Forwarded unchanged to the driver.
/// \param GlobalDim Forwarded unchanged to the driver.
/// \param GlobalStrides Forwarded unchanged to the driver.
/// \param PixelBoxLowerCorner Forwarded unchanged to the driver.
/// \param PixelBoxUpperCorner Forwarded unchanged to the driver.
/// \param ChannelsPerPixel Forwarded unchanged to the driver.
/// \param PixelsPerColumn Forwarded unchanged to the driver.
/// \param ElementStrides Forwarded unchanged to the driver.
/// \param Interleave UR interleave flag.
/// \param Swizzle UR swizzle flag.
/// \param L2Promotion UR L2-promotion flag.
/// \param OobFill UR out-of-bounds fill flag.
/// \param hTensorMap Pointer to the handle whose `Map` receives the result.
/// \return UR_RESULT_SUCCESS on success;
///         UR_RESULT_ERROR_INVALID_NULL_HANDLE if \p hDevice is null;
///         UR_RESULT_ERROR_INVALID_NULL_POINTER if \p hTensorMap or the handle
///         it holds is null; the thrown ur_result_t if context activation,
///         flag translation or the driver call throws one;
///         UR_RESULT_ERROR_UNKNOWN for any other exception.
UR_APIEXPORT ur_result_t UR_APICALL urTensorMapEncodeIm2ColExp(
    ur_device_handle_t hDevice,
    ur_exp_tensor_map_data_type_flags_t TensorMapType, uint32_t TensorRank,
    void *GlobalAddress, const uint64_t *GlobalDim,
    const uint64_t *GlobalStrides, const int *PixelBoxLowerCorner,
    const int *PixelBoxUpperCorner, uint32_t ChannelsPerPixel,
    uint32_t PixelsPerColumn, const uint32_t *ElementStrides,
    ur_exp_tensor_map_interleave_flags_t Interleave,
    ur_exp_tensor_map_swizzle_flags_t Swizzle,
    ur_exp_tensor_map_l2_promotion_flags_t L2Promotion,
    ur_exp_tensor_map_oob_fill_flags_t OobFill,
    ur_exp_tensor_map_handle_t *hTensorMap) {
  // Reject null handles up front; both are dereferenced below.
  if (hDevice == nullptr)
    return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
  if (hTensorMap == nullptr || *hTensorMap == nullptr)
    return UR_RESULT_ERROR_INVALID_NULL_POINTER;

  try {
    // Kept inside the try block so that a failure while activating the
    // context is reported as a result code instead of escaping the API.
    ScopedContext Active(hDevice);
    UR_CHECK_ERROR(cuTensorMapEncodeIm2col(
        &(*hTensorMap)->Map, convertUrToCuDataType(TensorMapType), TensorRank,
        GlobalAddress, GlobalDim, GlobalStrides, PixelBoxLowerCorner,
        PixelBoxUpperCorner, ChannelsPerPixel, PixelsPerColumn, ElementStrides,
        convertUrToCuInterleave(Interleave), convertUrToCuSwizzle(Swizzle),
        convertUrToL2promotion(L2Promotion), convertUrToCuOOBfill(OobFill)));
  } catch (ur_result_t Err) {
    return Err;
  } catch (...) {
    // No exception of any other type may leave this exported entry point.
    return UR_RESULT_ERROR_UNKNOWN;
  }
  return UR_RESULT_SUCCESS;
}
/// Encodes a tiled tensor map by forwarding the arguments to the CUDA driver
/// call cuTensorMapEncodeTiled, with the UR flag arguments translated to their
/// CUDA enumerators first. The encoded descriptor is written into the `Map`
/// member of the object that `*hTensorMap` points to.
///
/// NOTE(review): this function does not allocate the tensor-map object; it
/// writes through the handle already stored in `*hTensorMap`. Presumably the
/// caller is expected to provide that storage — confirm against the callers.
///
/// \param hDevice Device whose context is made active for the driver call.
/// \param TensorMapType UR data-type flag of the tensor elements.
/// \param TensorRank Forwarded unchanged to the driver.
/// \param GlobalAddress Forwarded unchanged to the driver.
/// \param GlobalDim Forwarded unchanged to the driver.
/// \param GlobalStrides Forwarded unchanged to the driver.
/// \param BoxDim Forwarded unchanged to the driver.
/// \param ElementStrides Forwarded unchanged to the driver.
/// \param Interleave UR interleave flag.
/// \param Swizzle UR swizzle flag.
/// \param L2Promotion UR L2-promotion flag.
/// \param OobFill UR out-of-bounds fill flag.
/// \param hTensorMap Pointer to the handle whose `Map` receives the result.
/// \return UR_RESULT_SUCCESS on success;
///         UR_RESULT_ERROR_INVALID_NULL_HANDLE if \p hDevice is null;
///         UR_RESULT_ERROR_INVALID_NULL_POINTER if \p hTensorMap or the handle
///         it holds is null; the thrown ur_result_t if context activation,
///         flag translation or the driver call throws one;
///         UR_RESULT_ERROR_UNKNOWN for any other exception.
UR_APIEXPORT ur_result_t UR_APICALL urTensorMapEncodeTiledExp(
    ur_device_handle_t hDevice,
    ur_exp_tensor_map_data_type_flags_t TensorMapType, uint32_t TensorRank,
    void *GlobalAddress, const uint64_t *GlobalDim,
    const uint64_t *GlobalStrides, const uint32_t *BoxDim,
    const uint32_t *ElementStrides,
    ur_exp_tensor_map_interleave_flags_t Interleave,
    ur_exp_tensor_map_swizzle_flags_t Swizzle,
    ur_exp_tensor_map_l2_promotion_flags_t L2Promotion,
    ur_exp_tensor_map_oob_fill_flags_t OobFill,
    ur_exp_tensor_map_handle_t *hTensorMap) {
  // Reject null handles up front; both are dereferenced below.
  if (hDevice == nullptr)
    return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
  if (hTensorMap == nullptr || *hTensorMap == nullptr)
    return UR_RESULT_ERROR_INVALID_NULL_POINTER;

  try {
    // Kept inside the try block so that a failure while activating the
    // context is reported as a result code instead of escaping the API.
    ScopedContext Active(hDevice);
    UR_CHECK_ERROR(cuTensorMapEncodeTiled(
        &(*hTensorMap)->Map, convertUrToCuDataType(TensorMapType), TensorRank,
        GlobalAddress, GlobalDim, GlobalStrides, BoxDim, ElementStrides,
        convertUrToCuInterleave(Interleave), convertUrToCuSwizzle(Swizzle),
        convertUrToL2promotion(L2Promotion), convertUrToCuOOBfill(OobFill)));
  } catch (ur_result_t Err) {
    return Err;
  } catch (...) {
    // No exception of any other type may leave this exported entry point.
    return UR_RESULT_ERROR_UNKNOWN;
  }
  return UR_RESULT_SUCCESS;
}

0 comments on commit ae22a1d

Please sign in to comment.