HIP: add option to embed static blockDim

Using blockDim in HIP kernels unfortunately incurs a large overhead,
because this (dynamic) information is stored in the dispatch packet,
which lives in a host-coherent memory region. Since vkFFT always knows
the work group size it is going to use, just replace uses of blockDim
with these compile-time values.
This avoids the load from non-cached memory, the dispatch pointer no
longer has to be loaded (which frees up 2 SGPRs), and some indexing
calculations may constant fold better.
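
For illustration only (this snippet is not part of the commit), the
generated indexing changes roughly as follows, assuming a hypothetical
64x1x1 work group:

```c
#include <hip/hip_runtime.h>

// Hypothetical kernels, not vkFFT code: both are launched with 64x1x1 blocks.
__global__ void scale_dynamic(float* data, float s) {
    // blockDim.x is read from the dispatch packet in host-coherent memory.
    unsigned int i = threadIdx.x + blockIdx.x * blockDim.x;
    data[i] *= s;
}

__global__ void scale_static(float* data, float s) {
    // With the known work group size embedded, the dispatch-packet load is
    // avoided and the multiply by a constant can fold into the address math.
    unsigned int i = threadIdx.x + blockIdx.x * 64u;
    data[i] *= s;
}
```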

The added option `useStaticWorkGroupSize` has three possible values:
- -1: Disable embedding blockDim sizes, effectively the old behavior
-  0: Automatically enable embedding when profitable (always except for RDNA2)
-  1: Always enable

Automatic embedding is left disabled on RDNA because it can sometimes
decrease performance there; the reason is not fully understood, see the
details at [1].

[1]: ROCm/hipamd#53
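
For reference, a minimal usage sketch (only `useStaticWorkGroupSize` is
added by this commit; the surrounding configuration is abbreviated and
illustrative):

```c
// Sketch: forcing the new behavior; leaving the field at 0 keeps the automatic default.
VkFFTConfiguration configuration = {};
VkFFTApplication app = {};
configuration.FFTdim = 1;
configuration.size[0] = 4096;
configuration.useStaticWorkGroupSize = 1; // always embed the block dimensions
// ... set the usual HIP-backend fields (device, buffer sizes, stream, ...) ...
VkFFTResult resFFT = initializeVkFFT(&app, configuration);
```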

Co-authored-by: [email protected]
Gergely Meszaros committed Feb 14, 2023
1 parent 41a4808 commit d9fb2cf
Showing 1 changed file with 40 additions and 31 deletions.
vkFFT/vkFFT.h: 40 additions & 31 deletions
@@ -283,6 +283,7 @@ typedef struct {
 uint64_t streamCounter;//Filled at app creation
 uint64_t streamID;//Filled at app creation
 int64_t useStrict32BitAddress; // guarantee 32 bit addresses in bytes instead of number of elements. This results in fewer instructions generated. -1: Disable, 0: Infer based on size, 1: enable. Has no effect with useUint64.
+int64_t useStaticWorkGroupSize; // Embed the compile time known block dimensions into kernels instead of using blockDim, for potentially better performance. -1: Disable, 0: Automatically enable where beneficial, 1: Always enable.
 #elif(VKFFT_BACKEND==3)
 cl_command_queue* commandQueue;
 #elif(VKFFT_BACKEND==4)
@@ -813,7 +814,8 @@ typedef struct {
 uint64_t performBufferSetUpdate;
 uint64_t useUint64;
 #if(VKFFT_BACKEND==2)
-int64_t useStrict32BitAddress;
+int64_t useStrict32BitAddress;
+int64_t useStaticWorkGroupSize;
 #endif
 uint64_t disableSetLocale;

@@ -25942,38 +25944,39 @@ static inline VkFFTResult shaderGenVkFFT_R2C_decomposition(char* output, VkFFTSp
 if (!strcmp(floatTypeOutputMemory, "half")) sprintf(vecTypeOutput, "f16vec2");
 if (!strcmp(floatTypeOutputMemory, "float")) sprintf(vecTypeOutput, "float2");
 if (!strcmp(floatTypeOutputMemory, "double")) sprintf(vecTypeOutput, "double2");
-sprintf(sc->gl_LocalInvocationID_x, "threadIdx.x");
-sprintf(sc->gl_LocalInvocationID_y, "threadIdx.y");
-sprintf(sc->gl_LocalInvocationID_z, "threadIdx.z");
+sprintf(sc->gl_LocalInvocationID_x, sc->localSize[0] > 1 ? "threadIdx.x" : "0u");
+sprintf(sc->gl_LocalInvocationID_y, sc->localSize[1] > 1 ? "threadIdx.y" : "0u");
+sprintf(sc->gl_LocalInvocationID_z, sc->localSize[2] > 1 ? "threadIdx.z" : "0u");
 switch (sc->swapComputeWorkGroupID) {
 case 0:
-sprintf(sc->gl_GlobalInvocationID_x, "(threadIdx.x + blockIdx.x * blockDim.x)");
-sprintf(sc->gl_GlobalInvocationID_y, "(threadIdx.y + blockIdx.y * blockDim.y)");
-sprintf(sc->gl_GlobalInvocationID_z, "(threadIdx.z + blockIdx.z * blockDim.z)");
 sprintf(sc->gl_WorkGroupID_x, "blockIdx.x");
 sprintf(sc->gl_WorkGroupID_y, "blockIdx.y");
 sprintf(sc->gl_WorkGroupID_z, "blockIdx.z");
 break;
 case 1:
-sprintf(sc->gl_GlobalInvocationID_x, "(threadIdx.x + blockIdx.y * blockDim.x)");
-sprintf(sc->gl_GlobalInvocationID_y, "(threadIdx.y + blockIdx.x * blockDim.y)");
-sprintf(sc->gl_GlobalInvocationID_z, "(threadIdx.z + blockIdx.z * blockDim.z)");
 sprintf(sc->gl_WorkGroupID_x, "blockIdx.y");
 sprintf(sc->gl_WorkGroupID_y, "blockIdx.x");
 sprintf(sc->gl_WorkGroupID_z, "blockIdx.z");
 break;
 case 2:
-sprintf(sc->gl_GlobalInvocationID_x, "(threadIdx.x + blockIdx.z * blockDim.x)");
-sprintf(sc->gl_GlobalInvocationID_y, "(threadIdx.y + blockIdx.y * blockDim.y)");
-sprintf(sc->gl_GlobalInvocationID_z, "(threadIdx.z + blockIdx.x * blockDim.z)");
 sprintf(sc->gl_WorkGroupID_x, "blockIdx.z");
 sprintf(sc->gl_WorkGroupID_y, "blockIdx.y");
 sprintf(sc->gl_WorkGroupID_z, "blockIdx.x");
 break;
 }
-sprintf(sc->gl_WorkGroupSize_x, "blockDim.x");
-sprintf(sc->gl_WorkGroupSize_y, "blockDim.y");
-sprintf(sc->gl_WorkGroupSize_z, "blockDim.z");
+if(sc->useStaticWorkGroupSize > 0) {
+sprintf(sc->gl_WorkGroupSize_x, "%" PRIu64 "u", sc->localSize[0]);
+sprintf(sc->gl_WorkGroupSize_y, "%" PRIu64 "u", sc->localSize[1]);
+sprintf(sc->gl_WorkGroupSize_z, "%" PRIu64 "u", sc->localSize[2]);
+}
+else {
+sprintf(sc->gl_WorkGroupSize_x, "blockDim.x");
+sprintf(sc->gl_WorkGroupSize_y, "blockDim.y");
+sprintf(sc->gl_WorkGroupSize_z, "blockDim.z");
+}
+sprintf(sc->gl_GlobalInvocationID_x, "(%s + %s * %s)", sc->gl_LocalInvocationID_x, sc->gl_WorkGroupID_x, sc->gl_WorkGroupSize_x);
+sprintf(sc->gl_GlobalInvocationID_y, "(%s + %s * %s)", sc->gl_LocalInvocationID_y, sc->gl_WorkGroupID_y, sc->gl_WorkGroupSize_y);
+sprintf(sc->gl_GlobalInvocationID_z, "(%s + %s * %s)", sc->gl_LocalInvocationID_z, sc->gl_WorkGroupID_z, sc->gl_WorkGroupSize_z);
 sprintf(sc->gl_SubgroupInvocationID, "(threadIdx.x %% %" PRIu64 ")", sc->warpSize);
 sprintf(sc->gl_SubgroupID, "(threadIdx.x / %" PRIu64 ")", sc->warpSize);
 if (!strcmp(floatType, "double")) sprintf(LFending, "l");
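
To make the effect of the substitutions above concrete, here is a small
standalone sketch (not vkFFT code) that reproduces the string building for a
hypothetical 64x1x1 work group with useStaticWorkGroupSize > 0 and
swapComputeWorkGroupID == 0:

```c
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    uint64_t localSize[3] = { 64, 1, 1 };
    char lid_y[16], wgSize_y[24], gid_y[64];

    /* localSize[1] == 1, so the y thread index is replaced by a literal 0u. */
    sprintf(lid_y, localSize[1] > 1 ? "threadIdx.y" : "0u");
    /* The compile-time known size is embedded instead of blockDim.y. */
    sprintf(wgSize_y, "%" PRIu64 "u", localSize[1]);
    /* Assembled the same way as gl_GlobalInvocationID_y above. */
    sprintf(gid_y, "(%s + %s * %s)", lid_y, "blockIdx.y", wgSize_y);

    /* Prints "(0u + blockIdx.y * 1u)", which the compiler folds to blockIdx.y. */
    printf("%s\n", gid_y);
    return 0;
}
```
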
@@ -26833,38 +26836,39 @@ static inline VkFFTResult shaderGenVkFFT(char* output, VkFFTSpecializationConsta
 if (!strcmp(floatTypeOutputMemory, "half")) sprintf(vecTypeOutput, "f16vec2");
 if (!strcmp(floatTypeOutputMemory, "float")) sprintf(vecTypeOutput, "float2");
 if (!strcmp(floatTypeOutputMemory, "double")) sprintf(vecTypeOutput, "double2");
-sprintf(sc->gl_LocalInvocationID_x, "threadIdx.x");
-sprintf(sc->gl_LocalInvocationID_y, "threadIdx.y");
-sprintf(sc->gl_LocalInvocationID_z, "threadIdx.z");
+sprintf(sc->gl_LocalInvocationID_x, sc->localSize[0] > 1 ? "threadIdx.x" : "0u");
+sprintf(sc->gl_LocalInvocationID_y, sc->localSize[1] > 1 ? "threadIdx.y" : "0u");
+sprintf(sc->gl_LocalInvocationID_z, sc->localSize[2] > 1 ? "threadIdx.z" : "0u");
 switch (sc->swapComputeWorkGroupID) {
 case 0:
-sprintf(sc->gl_GlobalInvocationID_x, "(threadIdx.x + blockIdx.x * blockDim.x)");
-sprintf(sc->gl_GlobalInvocationID_y, "(threadIdx.y + blockIdx.y * blockDim.y)");
-sprintf(sc->gl_GlobalInvocationID_z, "(threadIdx.z + blockIdx.z * blockDim.z)");
 sprintf(sc->gl_WorkGroupID_x, "blockIdx.x");
 sprintf(sc->gl_WorkGroupID_y, "blockIdx.y");
 sprintf(sc->gl_WorkGroupID_z, "blockIdx.z");
 break;
 case 1:
-sprintf(sc->gl_GlobalInvocationID_x, "(threadIdx.x + blockIdx.y * blockDim.x)");
-sprintf(sc->gl_GlobalInvocationID_y, "(threadIdx.y + blockIdx.x * blockDim.y)");
-sprintf(sc->gl_GlobalInvocationID_z, "(threadIdx.z + blockIdx.z * blockDim.z)");
 sprintf(sc->gl_WorkGroupID_x, "blockIdx.y");
 sprintf(sc->gl_WorkGroupID_y, "blockIdx.x");
 sprintf(sc->gl_WorkGroupID_z, "blockIdx.z");
 break;
 case 2:
-sprintf(sc->gl_GlobalInvocationID_x, "(threadIdx.x + blockIdx.z * blockDim.x)");
-sprintf(sc->gl_GlobalInvocationID_y, "(threadIdx.y + blockIdx.y * blockDim.y)");
-sprintf(sc->gl_GlobalInvocationID_z, "(threadIdx.z + blockIdx.x * blockDim.z)");
 sprintf(sc->gl_WorkGroupID_x, "blockIdx.z");
 sprintf(sc->gl_WorkGroupID_y, "blockIdx.y");
 sprintf(sc->gl_WorkGroupID_z, "blockIdx.x");
 break;
 }
-sprintf(sc->gl_WorkGroupSize_x, "blockDim.x");
-sprintf(sc->gl_WorkGroupSize_y, "blockDim.y");
-sprintf(sc->gl_WorkGroupSize_z, "blockDim.z");
+if(sc->useStaticWorkGroupSize > 0) {
+sprintf(sc->gl_WorkGroupSize_x, "%" PRIu64 "u", sc->localSize[0]);
+sprintf(sc->gl_WorkGroupSize_y, "%" PRIu64 "u", sc->localSize[1]);
+sprintf(sc->gl_WorkGroupSize_z, "%" PRIu64 "u", sc->localSize[2]);
+}
+else {
+sprintf(sc->gl_WorkGroupSize_x, "blockDim.x");
+sprintf(sc->gl_WorkGroupSize_y, "blockDim.y");
+sprintf(sc->gl_WorkGroupSize_z, "blockDim.z");
+}
+sprintf(sc->gl_GlobalInvocationID_x, "(%s + %s * %s)", sc->gl_LocalInvocationID_x, sc->gl_WorkGroupID_x, sc->gl_WorkGroupSize_x);
+sprintf(sc->gl_GlobalInvocationID_y, "(%s + %s * %s)", sc->gl_LocalInvocationID_y, sc->gl_WorkGroupID_y, sc->gl_WorkGroupSize_y);
+sprintf(sc->gl_GlobalInvocationID_z, "(%s + %s * %s)", sc->gl_LocalInvocationID_z, sc->gl_WorkGroupID_z, sc->gl_WorkGroupSize_z);
 sprintf(sc->gl_SubgroupInvocationID, "(threadIdx.x %% %" PRIu64 ")", sc->warpSize);
 sprintf(sc->gl_SubgroupID, "(threadIdx.x / %" PRIu64 ")", sc->warpSize);
 #elif((VKFFT_BACKEND==3)||(VKFFT_BACKEND==4))
@@ -33973,6 +33977,7 @@ static inline VkFFTResult VkFFTPlanR2CMultiUploadDecomposition(VkFFTApplication*
 axis->specializationConstants.useUint64 = app->configuration.useUint64;
 #if(VKFFT_BACKEND==2)
 axis->specializationConstants.useStrict32BitAddress = app->configuration.useStrict32BitAddress;
+axis->specializationConstants.useStaticWorkGroupSize = app->configuration.useStaticWorkGroupSize;
 #endif
 axis->specializationConstants.disableSetLocale = app->configuration.disableSetLocale;

@@ -35735,6 +35740,7 @@ static inline VkFFTResult VkFFTPlanAxis(VkFFTApplication* app, VkFFTPlan* FFTPla
 axis->specializationConstants.useUint64 = app->configuration.useUint64;
 #if(VKFFT_BACKEND==2)
 axis->specializationConstants.useStrict32BitAddress = app->configuration.useStrict32BitAddress;
+axis->specializationConstants.useStaticWorkGroupSize = app->configuration.useStaticWorkGroupSize;
 #endif
 axis->specializationConstants.disableSetLocale = app->configuration.disableSetLocale;

@@ -39681,6 +39687,9 @@ static inline VkFFTResult initializeVkFFT(VkFFTApplication* app, VkFFTConfigurat
 return VKFFT_ERROR_FAILED_TO_GET_ATTRIBUTE;
 }
 app->configuration.warpSize = value;
+if(inputLaunchConfiguration.useStaticWorkGroupSize != 0) app->configuration.useStaticWorkGroupSize = inputLaunchConfiguration.useStaticWorkGroupSize;
+else if (app->configuration.warpSize == 32) app->configuration.useStaticWorkGroupSize = -1; // Embedding the work group size slows down kernels on RDNA
+else app->configuration.useStaticWorkGroupSize = 1;
 app->configuration.sharedMemorySizePow2 = (uint64_t)pow(2, (uint64_t)log2(app->configuration.sharedMemorySize));
 app->configuration.useRaderUintLUT = 0;
 if (app->configuration.num_streams > 1) {
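
For context, the warpSize used in the check above is the value obtained from
the device attribute query a few lines earlier; an equivalent standalone query
(illustrative only, not vkFFT's exact code) could look like this:

```c
#include <hip/hip_runtime.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    // RDNA devices report a wavefront size of 32, GCN/CDNA devices report 64.
    int warp_size = 0;
    if (hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, 0) != hipSuccess)
        return 1;
    // Mirrors the default above: leave embedding disabled on wave32 (RDNA) GPUs.
    int64_t useStaticWorkGroupSize = (warp_size == 32) ? -1 : 1;
    printf("warpSize = %d -> useStaticWorkGroupSize default = %lld\n",
           warp_size, (long long)useStaticWorkGroupSize);
    return 0;
}
```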
