Skip to content

Commit

Permalink
Optimize flash attention kernel functions
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Aug 25, 2024
1 parent df32a07 commit 8e4eb3f
Show file tree
Hide file tree
Showing 6 changed files with 454 additions and 600 deletions.
3 changes: 2 additions & 1 deletion Seq2SeqSharp/Applications/GPT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public GPT(Seq2SeqOptions options, Vocab tgtVocab = null)
compilerOptions: options.CompilerOptions, runValidEveryUpdates: options.RunValidEveryUpdates, updateFreq: options.UpdateFreq,
startToRunValidAfterUpdates: options.StartValidAfterUpdates, maxDegressOfParallelism: options.TaskParallelism, mklInstructions: options.MKLInstructions, weightsUpdateCount: options.WeightsUpdateCount,
enableTensorCore: options.EnableTensorCore, cudaMemoryAllocatorType: options.CudaMemoryAllocatorType, elementType: options.AMP ? DType.Float16 : DType.Float32, randomSeed: options.RandomSeed,
saveModelEveryUpdats: options.SaveModelEveryUpdates, saveGPUMemoryMode: options.SaveGPUMemoryMode, initLossScaling: options.InitLossScaling, autoCheckTensorCorruption: options.CheckTensorCorrupted)
saveModelEveryUpdats: options.SaveModelEveryUpdates, saveGPUMemoryMode: options.SaveGPUMemoryMode, initLossScaling: options.InitLossScaling, autoCheckTensorCorruption: options.CheckTensorCorrupted,
attentionType: options.AttentionType)
{
m_paddingType = options.PaddingType;
m_options = options;
Expand Down
3 changes: 2 additions & 1 deletion Seq2SeqSharp/Applications/Seq2Seq.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ public Seq2Seq(Seq2SeqOptions options, Vocab srcVocab = null, Vocab tgtVocab = n
compilerOptions: options.CompilerOptions, runValidEveryUpdates: options.RunValidEveryUpdates, updateFreq: options.UpdateFreq,
startToRunValidAfterUpdates: options.StartValidAfterUpdates, maxDegressOfParallelism: options.TaskParallelism, mklInstructions: options.MKLInstructions,
weightsUpdateCount: options.WeightsUpdateCount, cudaMemoryAllocatorType: options.CudaMemoryAllocatorType, elementType: options.AMP ? DType.Float16 : DType.Float32,
saveModelEveryUpdats: options.SaveModelEveryUpdates, saveGPUMemoryMode: options.SaveGPUMemoryMode, initLossScaling: options.InitLossScaling, autoCheckTensorCorruption: options.CheckTensorCorrupted)
saveModelEveryUpdats: options.SaveModelEveryUpdates, saveGPUMemoryMode: options.SaveGPUMemoryMode, initLossScaling: options.InitLossScaling, autoCheckTensorCorruption: options.CheckTensorCorrupted,
attentionType: options.AttentionType)
{
m_paddingType = options.PaddingType;
m_options = options;
Expand Down
6 changes: 4 additions & 2 deletions Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,14 @@ public abstract class BaseSeq2SeqFramework<T> where T : Model
DType m_elementType = DType.Float32;
float m_initLossScaling = 1.0f;
bool m_autoCheckTensorCorruption = false;
AttentionTypeEnums m_attentionType = AttentionTypeEnums.Classic;

public float LossScaling = 1.0f;

public BaseSeq2SeqFramework(string deviceIds, ProcessorTypeEnums processorType, string modelFilePath, float memoryUsageRatio = 0.9f,
string compilerOptions = null, int runValidEveryUpdates = 10000, int primaryTaskId = 0, int updateFreq = 1, int startToRunValidAfterUpdates = 0,
int maxDegressOfParallelism = 1, string mklInstructions = "AVX2", int weightsUpdateCount = 0, bool enableTensorCore = true, CudaMemoryDeviceAllocatorType cudaMemoryAllocatorType = CudaMemoryDeviceAllocatorType.CudaMemoryPool,
DType elementType = DType.Float32, int randomSeed = -1, int saveModelEveryUpdats = 10000, bool saveGPUMemoryMode = false, float initLossScaling = 1.0f, bool autoCheckTensorCorruption = false)
DType elementType = DType.Float32, int randomSeed = -1, int saveModelEveryUpdats = 10000, bool saveGPUMemoryMode = false, float initLossScaling = 1.0f, bool autoCheckTensorCorruption = false, AttentionTypeEnums attentionType = AttentionTypeEnums.Classic)
{
m_deviceIds = deviceIds.Split(',').Select(x => int.Parse(x)).ToArray();
m_compilerOptions = compilerOptions;
Expand All @@ -168,6 +169,7 @@ public BaseSeq2SeqFramework(string deviceIds, ProcessorTypeEnums processorType,
m_saveGPUMemoryMode = saveGPUMemoryMode;
m_initLossScaling = initLossScaling;
m_autoCheckTensorCorruption = autoCheckTensorCorruption;
m_attentionType = attentionType;

InitDevices();

Expand All @@ -181,7 +183,7 @@ public BaseSeq2SeqFramework(string deviceIds, ProcessorTypeEnums processorType,
// Initializes the tensor allocator for every configured device.
// The raw compiler-options string is stored as one "--opt1 --opt2 ..." blob; it is split back
// into individual "--"-prefixed tokens before being handed to the allocator.
public void InitDevices()
{
    string[] cudaCompilerOptions = null;
    if (!m_compilerOptions.IsNullOrEmpty())
    {
        // Split on the "--" delimiter, drop the empty leading fragment, and restore the prefix on each token.
        cudaCompilerOptions = Regex.Split(m_compilerOptions, "--")
                                   .Where(token => token != "")
                                   .Select(token => "--" + token)
                                   .ToArray();
    }

    TensorAllocator.InitDevices(m_processorType, m_deviceIds, m_memoryUsageRatio, cudaCompilerOptions, mklInstructions: m_mklInstructions, enableTensorCore: m_enableTensorCore, m_cudaMemoryAllocatorType, m_elementType, attentionTypeEnums: m_attentionType);
}

public virtual List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, IPairBatch sntPairBatch, DecodingOptions decodingOptions, bool isTraining)
Expand Down
5 changes: 3 additions & 2 deletions Seq2SeqSharp/Utils/TensorAllocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ public static class TensorAllocator
private static int[] m_deviceIds;
private static ProcessorTypeEnums m_archType;

public static void InitDevices(ProcessorTypeEnums archType, int[] ids, float memoryUsageRatio = 0.9f, string[] compilerOptions = null, string mklInstructions = "AVX2", bool enableTensorCore = true, CudaMemoryDeviceAllocatorType allocatorType = CudaMemoryDeviceAllocatorType.CudaMemoryPool, DType elementType = DType.Float32)
public static void InitDevices(ProcessorTypeEnums archType, int[] ids, float memoryUsageRatio = 0.9f, string[] compilerOptions = null, string mklInstructions = "AVX2", bool enableTensorCore = true, CudaMemoryDeviceAllocatorType allocatorType = CudaMemoryDeviceAllocatorType.CudaMemoryPool,
DType elementType = DType.Float32, AttentionTypeEnums attentionTypeEnums = AttentionTypeEnums.Classic)
{
if (m_allocator != null)
{
Expand All @@ -41,7 +42,7 @@ public static void InitDevices(ProcessorTypeEnums archType, int[] ids, float mem

if (m_archType == ProcessorTypeEnums.GPU)
{
m_cudaContext = new TSCudaContext(m_deviceIds, memoryUsageRatio, compilerOptions, allocatorType, elementType);
m_cudaContext = new TSCudaContext(m_deviceIds, memoryUsageRatio, compilerOptions, allocatorType, elementType, useFlashAttention: attentionTypeEnums == AttentionTypeEnums.FlashAttentionV2);
m_cudaContext.Precompile();
m_cudaContext.CleanUnusedPTX();

Expand Down
Loading

0 comments on commit 8e4eb3f

Please sign in to comment.