Skip to content

Commit

Permalink
allocate buffers.
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz committed Feb 26, 2025
1 parent f8d1fa0 commit d7aff2a
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 9 deletions.
10 changes: 10 additions & 0 deletions src/nn/nn-vulkan-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,17 @@ int main() {
NnVulkanDevice device(&netConfig, &nodeConfig, &execution);
NnFakeNodeSynchronizer synchronizer;

float *x = (float *)execution.pipes[0];
for (NnUint i = 0; i < DIM * N_BATCHES; i++)
x[i] = i;

float rmsNormWeight[DIM];
for (NnUint i = 0; i < DIM; i++)
rmsNormWeight[i] = 0.5 + i / (float)DIM;

NnExecutor executor(&netConfig, &nodeConfig, &device, &execution, &synchronizer);
executor.loadWeight("rms_norm", 0, sizeof(rmsNormWeight), (NnByte *)rmsNormWeight);

execution.setBatchSize(N_BATCHES);
executor.forward();
return 0;
Expand Down
172 changes: 165 additions & 7 deletions src/nn/nn-vulkan.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
#include "nn-vulkan.hpp"

// Debug tracing helper. When DEBUG_VULKAN_TRACE is truthy, VULKAN_TRACE(fmt, ...) prints
// the "VULKAN_TRACE: " prefix, the printf-style message and a trailing newline; otherwise
// it expands to a no-op. Both variants are wrapped in do { } while (0) so the macro is
// exactly one statement (terminated by the caller's ';') and therefore stays correct when
// used as the unbraced body of an if/else or a loop.
#if DEBUG_VULKAN_TRACE
    #define VULKAN_TRACE(...) do { printf("VULKAN_TRACE: "); printf(__VA_ARGS__); printf("\n"); } while (0)
#else
    #define VULKAN_TRACE(...) do { } while (0)
#endif

static bool hasPortabilityExtension() {
#ifdef __APPLE__
const std::vector<vk::ExtensionProperties> extensionProperties = vk::enumerateInstanceExtensionProperties();
Expand All @@ -22,7 +28,7 @@ static bool hasValidationLayer() {

#define MEMORY_TYPE_INDEX_NOT_FOUND ~0

static uint32_t findMemoryTypeIndex(vk::PhysicalDevice *physicalDevice, vk::MemoryPropertyFlags expectedFlags) {
static uint32_t findMemoryTypeIndex(const vk::PhysicalDevice *physicalDevice, vk::MemoryPropertyFlags expectedFlags) {
vk::PhysicalDeviceMemoryProperties memoryProperties = physicalDevice->getMemoryProperties();
for (uint32_t index = 0; index < memoryProperties.memoryTypeCount; index++) {
auto flags = memoryProperties.memoryTypes[index].propertyFlags;
Expand All @@ -33,7 +39,136 @@ static uint32_t findMemoryTypeIndex(vk::PhysicalDevice *physicalDevice, vk::Memo
return MEMORY_TYPE_INDEX_NOT_FOUND;
}

NnVulkanDevice::NnVulkanDevice(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) {
// Creates an exclusive-sharing vk::Buffer of `bufferSize` bytes with the given usage flags,
// allocates memory of type `memoryTypeIndex` sized from the buffer's memory requirements,
// binds that memory to the buffer at offset 0 and returns the (buffer, memory) pair.
// The caller becomes responsible for destroying the buffer and freeing the memory.
static std::pair<vk::Buffer, vk::DeviceMemory> createBuffer(const NnVulkanContext *context, const uint32_t memoryTypeIndex, const vk::DeviceSize bufferSize, const vk::BufferUsageFlags usageFlags) {
    const vk::Device &device = context->device;

    // The buffer is used by a single queue family only.
    const vk::BufferCreateInfo createInfo(
        vk::BufferCreateFlags(),
        bufferSize,
        usageFlags,
        vk::SharingMode::eExclusive,
        1,
        &context->queueFamilyIndex);
    vk::Buffer buffer = device.createBuffer(createInfo);

    // Allocate what the driver asks for; this may be larger than `bufferSize`.
    const vk::MemoryRequirements requirements = device.getBufferMemoryRequirements(buffer);
    const vk::MemoryAllocateInfo allocateInfo(requirements.size, memoryTypeIndex);
    vk::DeviceMemory memory = device.allocateMemory(allocateInfo);

    device.bindBufferMemory(buffer, memory, 0);
    return std::pair<vk::Buffer, vk::DeviceMemory>(buffer, memory);
}

// Prepares a staging transfer between host memory and `deviceBuffer`.
//
// Allocates a host-visible buffer of `bufferSize` bytes usable as both transfer source and
// destination, and maps it for the lifetime of this object.
//
// context      - Vulkan context supplying the device and physical device.
// deviceBuffer - device-side buffer that addCopyCommand() copies to or from.
// bufferSize   - number of bytes to stage.
// direction    - COPY_TO_DEVICE or COPY_FROM_DEVICE; selects what copy() and
//                addCopyCommand() do.
//
// Throws std::runtime_error when no host-visible memory type is available.
NnVulkanStagingCopy::NnVulkanStagingCopy(const NnVulkanContext *context, vk::Buffer& deviceBuffer, const vk::DeviceSize bufferSize, const NnStagingVulkanCopyDirection direction) {
    // The direction must be stored: copy() and addCopyCommand() both switch on it.
    // Previously it was left unassigned, so those switches read an indeterminate value.
    this->direction = direction;
    this->deviceBuffer = deviceBuffer;
    this->context = context;
    this->bufferSize = bufferSize;

    // NOTE(review): only eHostVisible is requested. If the selected memory type is not also
    // host-coherent, copy() would need explicit flush/invalidate calls - confirm against
    // findMemoryTypeIndex().
    uint32_t memoryTypeIndex = findMemoryTypeIndex(&context->physicalDevice, vk::MemoryPropertyFlagBits::eHostVisible);
    if (memoryTypeIndex == MEMORY_TYPE_INDEX_NOT_FOUND)
        throw std::runtime_error("Cannot find host visible memory type");

    auto b = createBuffer(context, memoryTypeIndex, bufferSize, vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst);
    hostBuffer = b.first;
    hostMemory = b.second;
    hostPointer = context->device.mapMemory(hostMemory, 0, bufferSize);
}

// Releases the staging resources created by the constructor: unmaps and frees the host
// memory, then destroys the host buffer. The device buffer is not touched.
NnVulkanStagingCopy::~NnVulkanStagingCopy() {
    const vk::Device &device = context->device;
    device.unmapMemory(hostMemory);
    device.freeMemory(hostMemory);
    device.destroyBuffer(hostBuffer);
}

// Moves `bufferSize` bytes between `data` and the mapped staging memory.
// COPY_TO_DEVICE:   data -> staging memory (to be followed by the GPU copy).
// COPY_FROM_DEVICE: staging memory -> data (to be called after the GPU copy).
// Any other direction value is ignored.
void NnVulkanStagingCopy::copy(NnByte *data) {
    if (direction == COPY_TO_DEVICE) {
        std::memcpy(hostPointer, data, bufferSize);
    } else if (direction == COPY_FROM_DEVICE) {
        std::memcpy(data, hostPointer, bufferSize);
    }
}

// Records a whole-buffer copy (offsets 0, `bufferSize` bytes) into `commandBuffer`.
// COPY_TO_DEVICE copies staging buffer -> device buffer, COPY_FROM_DEVICE the reverse.
// Nothing is recorded for any other direction value.
void NnVulkanStagingCopy::addCopyCommand(vk::CommandBuffer commandBuffer) {
    VkBufferCopy region = { 0 };
    region.size = bufferSize;

    if (direction == COPY_TO_DEVICE)
        vkCmdCopyBuffer(commandBuffer, hostBuffer, deviceBuffer, 1, &region);
    else if (direction == COPY_FROM_DEVICE)
        vkCmdCopyBuffer(commandBuffer, deviceBuffer, hostBuffer, 1, &region);
}

// Creates a device buffer of `bufferSize` bytes.
//
// context    - Vulkan context supplying the device and physical device.
// bufferSize - size of the buffer in bytes.
// usageFlags - usage requested by the caller; transfer source/destination are always added
//              so the buffer can take part in staging copies.
// fastAccess - when true, a memory type that is both host-visible and device-local is
//              tried first; if one exists the buffer is mapped and write() can memcpy
//              into it directly. Otherwise device-local memory is used.
//
// Throws std::runtime_error when no suitable memory type is found.
NnVulkanBuffer::NnVulkanBuffer(NnVulkanContext *context, const vk::DeviceSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess) {
    this->context = context;
    this->bufferSize = bufferSize;
    this->hostPointer = nullptr;
    // Must be set explicitly: it is read below even when fastAccess is false, and was
    // previously left uninitialised in that case.
    this->isHostVisible = false;

    uint32_t memoryTypeIndex = MEMORY_TYPE_INDEX_NOT_FOUND;
    if (fastAccess) {
        // NOTE(review): host-coherency is not requested here; confirm whether writes through
        // the mapping need an explicit flush on the selected memory type.
        memoryTypeIndex = findMemoryTypeIndex(&context->physicalDevice, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eDeviceLocal);
        if (memoryTypeIndex != MEMORY_TYPE_INDEX_NOT_FOUND)
            isHostVisible = true;
    }
    if (!isHostVisible) {
        memoryTypeIndex = findMemoryTypeIndex(&context->physicalDevice, vk::MemoryPropertyFlagBits::eDeviceLocal);
        if (memoryTypeIndex == MEMORY_TYPE_INDEX_NOT_FOUND)
            throw std::runtime_error("Cannot find device local memory type");
    }

    auto b = createBuffer(context, memoryTypeIndex, bufferSize, usageFlags | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst);
    deviceBuffer = b.first;
    deviceMemory = b.second;
    if (isHostVisible)
        hostPointer = context->device.mapMemory(deviceMemory, 0, bufferSize);
    // vk::DeviceSize is an unsigned 64-bit value; cast so the argument matches %llu.
    VULKAN_TRACE("Created buffer of size %llu (fastAccess=%d)", (unsigned long long)bufferSize, (int)fastAccess);
}

// Releases the buffer: unmaps the memory if the constructor mapped it, frees the device
// memory and destroys the buffer handle.
NnVulkanBuffer::~NnVulkanBuffer() {
    if (hostPointer != nullptr)
        context->device.unmapMemory(deviceMemory);
    context->device.freeMemory(deviceMemory);
    context->device.destroyBuffer(deviceBuffer);
    // vk::DeviceSize is an unsigned 64-bit value; cast so the argument matches %llu
    // (the previous %lld did not match the argument type).
    VULKAN_TRACE("Destroyed buffer of size %llu", (unsigned long long)bufferSize);
}

void NnVulkanBuffer::write(const NnByte *data) {
if (isHostVisible && hostPointer != nullptr) {
std::memcpy(hostPointer, data, bufferSize);
} else {
NnVulkanStagingCopy copy(context, deviceBuffer, bufferSize, COPY_TO_DEVICE);
copy.copy((NnByte *)data);

vk::CommandBufferAllocateInfo allocInfo(context->commandPool, vk::CommandBufferLevel::ePrimary, 1);
const std::vector<vk::CommandBuffer> cmdBuffers = context->device.allocateCommandBuffers(allocInfo);
vk::CommandBuffer commandBuffer = cmdBuffers.front();
commandBuffer.begin({ vk::CommandBufferUsageFlags{} });
copy.addCopyCommand(commandBuffer);
commandBuffer.end();

vk::Fence fence = context->device.createFence(vk::FenceCreateInfo());
vk::SubmitInfo submitInfo(0, nullptr, nullptr, 1, &commandBuffer);
context->queue.submit({ submitInfo }, fence);
assert(context->device.waitForFences({ fence }, true, uint64_t(-1)) == vk::Result::eSuccess);

context->device.destroyFence(fence);
context->device.freeCommandBuffers(context->commandPool, 1, &commandBuffer);
}
}

// Pre-sizes both containers: `pipes` gets nPipes empty slots and `buffers` gets nBuffers
// empty slots (null unique_ptrs), to be filled in by the owner afterwards.
NnVulkanData::NnVulkanData(const NnUint nPipes, const NnUint nBuffers)
    : pipes(nPipes),
      buffers(nBuffers)
{
}

NnVulkanDevice::NnVulkanDevice(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution)
: data(netConfig->nPipes, nodeConfig->nBuffers)
{
this->netConfig = netConfig;
this->nodeConfig = nodeConfig;
this->netExecution = netExecution;

vk::InstanceCreateFlags createInstanceFlags(0);
std::vector<const char*> instanceLayers = {};
std::vector<const char*> instanceExtensions = {};
Expand Down Expand Up @@ -103,11 +238,17 @@ NnVulkanDevice::NnVulkanDevice(NnNetConfig *netConfig, NnNodeConfig *nodeConfig,

vk::CommandPoolCreateInfo commandPoolCreateInfo(vk::CommandPoolCreateFlags(vk::CommandPoolCreateFlagBits::eTransient), context.queueFamilyIndex);
context.commandPool = context.device.createCommandPool(commandPoolCreateInfo);

context.queue = context.device.getQueue(context.queueFamilyIndex, 0);

for (NnUint i = 0; i < netConfig->nPipes; i++)
data.pipes[i].reset(new NnVulkanBuffer(&context, netConfig->pipes[i].size.nBytes, vk::BufferUsageFlagBits::eUniformBuffer, true));
for (NnUint i = 0; i < nodeConfig->nBuffers; i++)
data.buffers[i].reset(new NnVulkanBuffer(&context, nodeConfig->buffers[i].size.nBytes, vk::BufferUsageFlagBits::eUniformBuffer, false));
}

NnVulkanDevice::~NnVulkanDevice() {
data.pipes.clear();
data.buffers.clear();
context.device.destroyCommandPool(context.commandPool);
context.device.destroy();
context.instance.destroy();
Expand All @@ -118,21 +259,38 @@ NnUint NnVulkanDevice::maxNThreads() {
}

// Creates the device segment for nodeConfig->segments[segmentIndex].
//
// The segment receives pointers to this device's context and shared buffer data, so it
// must not outlive the device. The object is allocated with `new`; ownership passes to
// the caller (presumably the executor - confirm against the NnDevice contract).
// `segmentIndex` is not range-checked here.
NnDeviceSegment *NnVulkanDevice::createSegment(NnUint segmentIndex) {
    // The stale single-argument `return new NnVulkanDeviceSegment(&context);` that preceded
    // these lines made them unreachable and has been removed.
    NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex];
    return new NnVulkanDeviceSegment(&context, &data, segmentConfig);
}

void NnVulkanDevice::syncPointers() {
}

NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context) {
this->context = *context;
// Builds a segment over `segmentConfig` and allocates one storage buffer per op that has
// weights. Each new buffer is appended to data->buffers, and its position there is
// recorded in weightBufferIndex[opIndex]; entries for ops without weights keep the
// value-initialised 0.
NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanData *data, NnSegmentConfig *segmentConfig)
    : weightBufferIndex(segmentConfig->nOps)
{
    this->context = context;
    this->data = data;
    this->segmentConfig = segmentConfig;

    for (NnUint i = 0; i < segmentConfig->nOps; i++) {
        const auto weightBytes = segmentConfig->ops[i].weightSize.nBytes;
        if (!(weightBytes > 0))
            continue; // this op has no weights to upload

        std::unique_ptr<NnVulkanBuffer> weightBuffer(
            new NnVulkanBuffer(context, weightBytes, vk::BufferUsageFlagBits::eStorageBuffer, false));
        data->buffers.push_back(std::move(weightBuffer));
        weightBufferIndex[i] = data->buffers.size() - 1;
    }
}

// Intentionally empty: the segment only stores plain pointers to the context, the shared
// data and the segment config, and the weight buffers it creates are pushed into
// data->buffers, where they are held by unique_ptr.
NnVulkanDeviceSegment::~NnVulkanDeviceSegment() {
}

// Uploads the weights of op `opIndex` into the buffer allocated for it by the constructor.
//
// opIndex - index of the op within this segment; asserted to be in range.
// nBytes  - size of `weight`; asserted to equal the op's configured weight size.
// weight  - source bytes, written in full via NnVulkanBuffer::write().
void NnVulkanDeviceSegment::loadWeight(NnUint opIndex, NnSize nBytes, NnByte *weight) {
    assert(segmentConfig->nOps > opIndex);
    assert(segmentConfig->ops[opIndex].weightSize.nBytes == nBytes);

    // Translate the op index into the position of its buffer inside data->buffers.
    const NnUint bufferIndex = weightBufferIndex[opIndex];
    NnVulkanBuffer *target = data->buffers[bufferIndex].get();
    target->write(weight);

    VULKAN_TRACE("Loaded %ld bytes to weight buffer %d", nBytes, bufferIndex);
}

void NnVulkanDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) {
Expand Down
56 changes: 54 additions & 2 deletions src/nn/nn-vulkan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
#define NN_VULKAN_HPP

#include <vulkan/vulkan.hpp>
#include <vector>
#include "nn-executor.hpp"
#include "nn-cpu-ops.hpp"

#define DEBUG_VULKAN_TRACE true

typedef struct {
vk::Instance instance;
vk::PhysicalDevice physicalDevice;
Expand All @@ -14,9 +17,55 @@ typedef struct {
vk::Queue queue;
} NnVulkanContext;

// Direction of a staging transfer performed by NnVulkanStagingCopy.
enum NnStagingVulkanCopyDirection {
    COPY_TO_DEVICE = 0,   // host data -> staging buffer -> device buffer
    COPY_FROM_DEVICE = 1  // device buffer -> staging buffer -> host data
};

// Host-visible staging buffer used to move bytes between host memory and a device buffer.
//
// Typical use: construct with a direction, call copy() to fill (or later read) the staging
// memory, and record the GPU-side transfer with addCopyCommand().
//
// The object owns the staging buffer and its memory and releases them in the destructor,
// so it is non-copyable: a copy would free the same Vulkan handles twice.
class NnVulkanStagingCopy {
private:
    NnStagingVulkanCopyDirection direction; // selects the behaviour of copy()/addCopyCommand()
    const NnVulkanContext *context;         // not owned
    vk::DeviceSize bufferSize;              // number of bytes staged
    vk::Buffer deviceBuffer;                // not owned; other end of the transfer
    vk::Buffer hostBuffer;                  // owned staging buffer
    vk::DeviceMemory hostMemory;            // owned memory backing hostBuffer
    void *hostPointer;                      // mapping of hostMemory
public:
    NnVulkanStagingCopy(const NnVulkanContext *context, vk::Buffer& deviceBuffer, const vk::DeviceSize bufferSize, const NnStagingVulkanCopyDirection direction);
    ~NnVulkanStagingCopy();
    NnVulkanStagingCopy(const NnVulkanStagingCopy &) = delete;
    NnVulkanStagingCopy &operator=(const NnVulkanStagingCopy &) = delete;
    // Copies bufferSize bytes between `data` and the staging memory, per `direction`.
    void copy(NnByte *data);
    // Records the buffer-to-buffer copy for `direction` into `commandBuffer`.
    void addCopyCommand(vk::CommandBuffer commandBuffer);
};

// A Vulkan buffer together with the device memory bound to it.
//
// If the memory is host-visible it stays mapped for the object's lifetime and write()
// copies into it directly; otherwise write() goes through a staging copy.
//
// The object owns the buffer and memory and releases them in the destructor, so it is
// non-copyable: a copy would free the same Vulkan handles twice.
class NnVulkanBuffer {
private:
    // Defaulted so the flag is never read uninitialised when the fast-access path is
    // not taken in the constructor.
    bool isHostVisible = false;
    NnVulkanContext *context;       // not owned
    vk::DeviceSize bufferSize;      // size in bytes
    vk::Buffer deviceBuffer;        // owned
    vk::DeviceMemory deviceMemory;  // owned memory backing deviceBuffer
    void *hostPointer = nullptr;    // mapping of deviceMemory, or nullptr when not mapped
public:
    NnVulkanBuffer(NnVulkanContext *context, const vk::DeviceSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess);
    ~NnVulkanBuffer();
    NnVulkanBuffer(const NnVulkanBuffer &) = delete;
    NnVulkanBuffer &operator=(const NnVulkanBuffer &) = delete;
    // Uploads bufferSize bytes from `data` into the buffer.
    void write(const NnByte *data);
};

// Container for the Vulkan buffers shared between a device and its segments.
class NnVulkanData {
public:
    // Creates `nPipes` pipe slots and `nBuffers` buffer slots.
    NnVulkanData(const NnUint nPipes, const NnUint nBuffers);

    // One entry per pipe.
    std::vector<std::unique_ptr<NnVulkanBuffer>> pipes;
    // Node buffers; segments append their weight buffers to this list as well.
    std::vector<std::unique_ptr<NnVulkanBuffer>> buffers;
};

class NnVulkanDevice : public NnDevice {
private:
NnVulkanContext context;
NnVulkanData data;
NnNetConfig *netConfig;
NnNodeConfig *nodeConfig;
NnNetExecution *netExecution;
public:
NnVulkanDevice(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution);
~NnVulkanDevice() override;
Expand All @@ -27,9 +76,12 @@ class NnVulkanDevice : public NnDevice {

class NnVulkanDeviceSegment : public NnDeviceSegment {
private:
NnVulkanContext context;
NnVulkanContext *context;
NnVulkanData *data;
NnSegmentConfig *segmentConfig;
std::vector<NnUint> weightBufferIndex;
public:
NnVulkanDeviceSegment(NnVulkanContext *context);
NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanData *data, NnSegmentConfig *segmentConfig);
~NnVulkanDeviceSegment() override;
void loadWeight(NnUint opIndex, NnSize nBytes, NnByte *weight) override;
void forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) override;
Expand Down

0 comments on commit d7aff2a

Please sign in to comment.