diff --git a/src/app.cpp b/src/app.cpp index 4847f14..7f9eead 100644 --- a/src/app.cpp +++ b/src/app.cpp @@ -70,7 +70,7 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) { args.nWorkers = count; args.workerHosts = new char*[count]; - args.workerPorts = new NnSize[count]; + args.workerPorts = new NnUint[count]; for (int s = 0; s < count; s++) { char *v = argv[i + 1 + s]; @@ -111,7 +111,7 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) { AppCliArgs::~AppCliArgs() { if (workerHosts != nullptr) { - for (NnSize i = 0; i < nWorkers; i++) + for (NnUint i = 0; i < nWorkers; i++) delete[] workerHosts[i]; delete[] workerHosts; } @@ -130,21 +130,21 @@ RootLlmInference::RootLlmInference(LlmNet *net, NnDevice *device, NnNetExecution this->network = network; // May be nullptr! } -void RootLlmInference::setBatchSize(NnSize batchSize) { +void RootLlmInference::setBatchSize(NnUint batchSize) { execution->setBatchSize(batchSize); controlPacket.batchSize = batchSize; } -void RootLlmInference::setPosition(NnSize position) { +void RootLlmInference::setPosition(NnUint position) { assert(position >= 0); assert(position + execution->batchSize - 1 < header->seqLen); controlPacket.position = position; - for (NnSize i = 0; i < execution->batchSize; i++) + for (NnUint i = 0; i < execution->batchSize; i++) positionPipe[i] = (float)(position + i); } -void RootLlmInference::setToken(NnSize batchIndex, NnSize token) { +void RootLlmInference::setToken(NnUint batchIndex, NnUint token) { assert(batchIndex >= 0 && batchIndex < execution->batchSize); tokenPipe[batchIndex] = (float)token; } @@ -179,14 +179,14 @@ bool WorkerLlmInference::tryReadControlPacket() { isFinished = true; return true; } - for (NnSize i = 0; i < controlPacket.batchSize; i++) + for (NnUint i = 0; i < controlPacket.batchSize; i++) positionPipe[i] = (float)(controlPacket.position + i); execution->setBatchSize(controlPacket.batchSize); return true; } void runInferenceApp(AppCliArgs 
*args, void (*handler)(AppInferenceContext *context)) { - NnSize nNodes = args->nWorkers + 1; + NnUint nNodes = args->nWorkers + 1; LlmHeader header = loadLlmHeader(args->modelPath, args->maxSeqLen, args->syncType); if (nNodes > header.nKvHeads) diff --git a/src/app.hpp b/src/app.hpp index d809bb1..07d2470 100644 --- a/src/app.hpp +++ b/src/app.hpp @@ -10,8 +10,8 @@ class AppCliArgs { public: char *mode; - NnSize nThreads; - NnSize nBatches; + NnUint nThreads; + NnUint nBatches; bool help; // inference @@ -19,27 +19,27 @@ class AppCliArgs { char *tokenizerPath; char *prompt; NnFloatType syncType; - NnSize nWorkers; + NnUint nWorkers; char **workerHosts; - NnSize *workerPorts; + NnUint *workerPorts; float temperature; float topp; - NnSize steps; + NnUint steps; bool benchmark; unsigned long long seed; ChatTemplateType chatTemplateType; - NnSize maxSeqLen; + NnUint maxSeqLen; // worker - NnSize port; + NnUint port; static AppCliArgs parse(int argc, char **argv, bool hasMode); ~AppCliArgs(); }; typedef struct { - NnSize position; - NnSize batchSize; // 0 = stop signal + NnUint position; + NnUint batchSize; // 0 = stop signal } LlmControlPacket; class RootLlmInference { @@ -56,9 +56,9 @@ class RootLlmInference { LlmControlPacket controlPacket; public: RootLlmInference(LlmNet *net, NnDevice *device, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network); - void setBatchSize(NnSize batchSize); - void setPosition(NnSize position); - void setToken(NnSize batchIndex, NnSize token); + void setBatchSize(NnUint batchSize); + void setPosition(NnUint position); + void setToken(NnUint batchIndex, NnUint token); void forward(); void finish(); }; diff --git a/src/dllama-api.cpp b/src/dllama-api.cpp index 8130826..39e7eb9 100644 --- a/src/dllama-api.cpp +++ b/src/dllama-api.cpp @@ -380,20 +380,20 @@ class ApiServer { buffer += inputPrompt.publicPrompt; } - NnSize pos = startPos; + NnUint pos = startPos; int token; - for (NnSize i = 0; ;) { + for (NnUint i = 0; ;) { 
long remainingTokens = promptEndPos - pos; if (remainingTokens <= 0) break; - NnSize batchSize = remainingTokens < args->nBatches + NnUint batchSize = remainingTokens < args->nBatches ? remainingTokens : args->nBatches; inference->setBatchSize(batchSize); inference->setPosition(pos); - for (NnSize j = 0; j < batchSize; j++) + for (NnUint j = 0; j < batchSize; j++) inference->setToken(j, promptTokens[i + j]); inference->forward(); diff --git a/src/dllama.cpp b/src/dllama.cpp index 2237482..0dda75a 100644 --- a/src/dllama.cpp +++ b/src/dllama.cpp @@ -16,7 +16,7 @@ static void inference(AppInferenceContext *context) { std::vector inputTokensVec(std::strlen(context->args->prompt) + 3); int *inputTokens = inputTokensVec.data(); - NnSize pos = 0; + NnUint pos = 0; int token; int nInputTokens; context->tokenizer->encode(context->args->prompt, inputTokens, &nInputTokens, true, false); @@ -27,21 +27,21 @@ static void inference(AppInferenceContext *context) { throw std::runtime_error("The number of prompt tokens is greater than the number of steps"); Timer evalTimer; - size_t sentBytes = 0; - size_t recvBytes = 0; + NnSize sentBytes = 0; + NnSize recvBytes = 0; printf("%s\n", context->args->prompt); for (;;) { Timer batchTimer; long remainingTokens = nInputTokens - 1 - (long)pos; if (remainingTokens <= 0) break; - NnSize batchSize = remainingTokens < context->args->nBatches + NnUint batchSize = remainingTokens < context->args->nBatches ? 
remainingTokens : context->args->nBatches; context->inference->setBatchSize(batchSize); context->inference->setPosition(pos); - for (NnSize i = 0; i < batchSize; i++) + for (NnUint i = 0; i < batchSize; i++) context->inference->setToken(i, inputTokens[pos + i]); context->inference->forward(); @@ -57,7 +57,7 @@ static void inference(AppInferenceContext *context) { recvBytes / 1024, batchSize); } - NnSize evalTime = evalTimer.elapsedMiliseconds(); + NnUint evalTime = evalTimer.elapsedMiliseconds(); fflush(stdout); @@ -65,7 +65,7 @@ static void inference(AppInferenceContext *context) { context->tokenizer->resetDecoder(); Timer predTimer; - const NnSize maxPos = std::min(context->header->seqLen, context->args->steps); + const NnUint maxPos = std::min(context->header->seqLen, context->args->steps); for (; pos < maxPos; pos++) { Timer tokenTimer; context->inference->setPosition(pos); @@ -86,10 +86,10 @@ static void inference(AppInferenceContext *context) { piece == nullptr ? "~" : piece); fflush(stdout); } - NnSize predTime = predTimer.elapsedMiliseconds(); + NnUint predTime = predTimer.elapsedMiliseconds(); - NnSize nEvalTokens = nInputTokens - 1; - NnSize nPredTokens = pos - nEvalTokens; + NnUint nEvalTokens = nInputTokens - 1; + NnUint nPredTokens = pos - nEvalTokens; printf("\n"); printf("Evaluation\n"); printf(" nBatches: %d\n", context->args->nBatches); @@ -104,11 +104,11 @@ static void inference(AppInferenceContext *context) { predTime / ((float) nPredTokens)); } -static size_t readStdin(const char *guide, char *buffer, size_t size) { +static NnUint readStdin(const char *guide, char *buffer, NnUint size) { std::fflush(stdin); std::printf("%s", guide); if (std::fgets(buffer, size, stdin) != NULL) { - size_t length = std::strlen(buffer); + NnUint length = std::strlen(buffer); if (length > 0 && buffer[length - 1] == '\n') { buffer[length - 1] = '\0'; length--; @@ -119,20 +119,20 @@ static size_t readStdin(const char *guide, char *buffer, size_t size) { } static void 
chat(AppInferenceContext *context) { - const NnSize seqLen = context->header->seqLen; + const NnUint seqLen = context->header->seqLen; char prompt[2048]; TokenizerChatStops stops(context->tokenizer); ChatTemplateGenerator templateGenerator(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]); EosDetector eosDetector(stops.nStops, context->tokenizer->eosTokenIds.data(), stops.stops, stops.maxStopLength, stops.maxStopLength); - const size_t sysPromptLength = readStdin("💻 System prompt (optional): ", prompt, sizeof(prompt)); + const NnUint sysPromptLength = readStdin("💻 System prompt (optional): ", prompt, sizeof(prompt)); std::vector deltaItems; if (sysPromptLength > 0) deltaItems.push_back(ChatItem{"system", prompt}); - NnSize pos = 0; - size_t userPromptLength; + NnUint pos = 0; + NnUint userPromptLength; int token; int nInputTokens; do { @@ -149,18 +149,18 @@ static void chat(AppInferenceContext *context) { bool addBos = pos == 0; context->tokenizer->encode((char*)inputPrompt.content, inputTokens, &nInputTokens, addBos, true); - NnSize userPromptEndPos = (NnSize)std::min(seqLen, pos + nInputTokens - 1); - for (NnSize i = 0; ;) { + NnUint userPromptEndPos = (NnUint)std::min(seqLen, pos + nInputTokens - 1); + for (NnUint i = 0; ;) { int remainingTokens = userPromptEndPos - pos; if (remainingTokens <= 0) break; - NnSize batchSize = remainingTokens < context->args->nBatches + NnUint batchSize = remainingTokens < context->args->nBatches ? 
remainingTokens : context->args->nBatches; context->inference->setBatchSize(batchSize); context->inference->setPosition(pos); - for (NnSize j = 0; j < batchSize; j++) + for (NnUint j = 0; j < batchSize; j++) context->inference->setToken(j, inputTokens[i + j]); context->inference->forward(); diff --git a/src/llm.cpp b/src/llm.cpp index 39aa61e..6a1ceeb 100644 --- a/src/llm.cpp +++ b/src/llm.cpp @@ -23,7 +23,7 @@ static const char *archTypeToString(LlmArchType type) { throw std::runtime_error("Unsupported architecture"); } -LlmHeader loadLlmHeader(const char *path, const NnSize maxSeqLen, NnFloatType syncType) { +LlmHeader loadLlmHeader(const char *path, const NnUint maxSeqLen, NnFloatType syncType) { LlmHeader header; std::memset(&header, 0, sizeof(LlmHeader)); header.weightType = F_UNK; @@ -93,7 +93,7 @@ LlmHeader loadLlmHeader(const char *path, const NnSize maxSeqLen, NnFloatType sy header.headSize = header.dim / header.nHeads; header.kvDim = (header.dim *header.nKvHeads) / header.nHeads; header.syncType = syncType; - header.fileSize = (size_t)seekToEnd(fd); + header.fileSize = (NnSize)seekToEnd(fd); return header; } @@ -122,7 +122,7 @@ void printLlmHeader(LlmHeader *header) { } } -LlmNet buildLlmNet(LlmHeader *h, NnSize nNodes, NnSize nBatches) { +LlmNet buildLlmNet(LlmHeader *h, NnUint nNodes, NnUint nBatches) { LlmNet n; n.tokenEmbeddingSize = size2D(F_32, h->vocabSize, h->dim); n.rmsNormSize = size1D(F_32, h->dim); @@ -146,37 +146,37 @@ LlmNet buildLlmNet(LlmHeader *h, NnSize nNodes, NnSize nBatches) { n.tokenPipeIndex = netBuilder.addPipe("TOK", size2D(F_32, nBatches, 1)); n.xPipeIndex = netBuilder.addPipe("X", size2D(F_32, nBatches, h->dim)); n.logitsPipeIndex = netBuilder.addPipe("LG", size2D(F_32, nBatches, h->vocabSize)); - const NnSize zqPipeIndex = netBuilder.addPipe("ZQ", size2D(h->syncType, nBatches, h->dim * nNodes)); + const NnUint zqPipeIndex = netBuilder.addPipe("ZQ", size2D(h->syncType, nBatches, h->dim * nNodes)); n.header = h; n.netConfig = 
netBuilder.build(); n.nodeConfigs = new NnNodeConfig[nNodes]; - for (NnSize nodeIndex = 0; nodeIndex < nNodes; nodeIndex++) { + for (NnUint nodeIndex = 0; nodeIndex < nNodes; nodeIndex++) { NnRopeSlice ropeSlice = sliceRope(h->dim, h->kvDim, h->nKvHeads, nNodes, h->seqLen, h->headSize, h->ropeTheta, nodeIndex); NnNodeConfigBuilder nodeBuilder(nodeIndex); - const NnSize xBufferIndex = nodeBuilder.addBuffer("x", size2D(F_32, nBatches, h->dim)); + const NnUint xBufferIndex = nodeBuilder.addBuffer("x", size2D(F_32, nBatches, h->dim)); - const NnSize yBufferIndex = nodeBuilder.addBuffer("y", size2D(F_32, nBatches, h->dim)); - const NnSize yqBufferIndex = h->syncType == F_32 + const NnUint yBufferIndex = nodeBuilder.addBuffer("y", size2D(F_32, nBatches, h->dim)); + const NnUint yqBufferIndex = h->syncType == F_32 ? yBufferIndex : nodeBuilder.addBuffer("yq", size2D(h->syncType, nBatches, h->dim)); - const NnSize yqSliceIndex = nodeBuilder.addBuffer("yq_slice", size2D(h->syncType, nBatches, h->dim / nNodes)); + const NnUint yqSliceIndex = nodeBuilder.addBuffer("yq_slice", size2D(h->syncType, nBatches, h->dim / nNodes)); - const NnSize qBufferIndex = nodeBuilder.addBuffer("q", size2D(F_32, nBatches, n.qSlice.d0)); - const NnSize kTempBufferIndex = nodeBuilder.addBuffer("k_temp", size2D(F_32, nBatches, n.kSlice.d0)); - const NnSize vTempBufferIndex = nodeBuilder.addBuffer("v_temp", size2D(F_32, nBatches, n.vSlice.d0)); + const NnUint qBufferIndex = nodeBuilder.addBuffer("q", size2D(F_32, nBatches, n.qSlice.d0)); + const NnUint kTempBufferIndex = nodeBuilder.addBuffer("k_temp", size2D(F_32, nBatches, n.kSlice.d0)); + const NnUint vTempBufferIndex = nodeBuilder.addBuffer("v_temp", size2D(F_32, nBatches, n.vSlice.d0)); - const NnSize dBufferIndex = nodeBuilder.addBuffer("d", size2D(F_32, nBatches, n.w1Slice.d0)); - const NnSize dqBufferIndex = h->syncType == F_32 + const NnUint dBufferIndex = nodeBuilder.addBuffer("d", size2D(F_32, nBatches, n.w1Slice.d0)); + const NnUint 
dqBufferIndex = h->syncType == F_32 ? dBufferIndex : nodeBuilder.addBuffer("d", size2D(h->syncType, nBatches, n.w1Slice.d0)); - const NnSize lBufferIndex = nodeBuilder.addBuffer("l", size2D(F_32, nBatches, n.w3Slice.d0)); - const NnSize invRmsBufferIndex = nodeBuilder.addBuffer("inv_rms", size2D(F_32, nBatches, 1)); - const NnSize ropeCacheBufferIndex = nodeBuilder.addBuffer("rope_cache", ropeSlice.cacheSize); - const NnSize attBufferIndex = nodeBuilder.addBuffer("att", multiHeadAttSlice.attSize); - const NnSize logitsSliceBufferIndex = nodeBuilder.addBuffer("lg", size2D(F_32, nBatches, h->vocabSize / nNodes)); + const NnUint lBufferIndex = nodeBuilder.addBuffer("l", size2D(F_32, nBatches, n.w3Slice.d0)); + const NnUint invRmsBufferIndex = nodeBuilder.addBuffer("inv_rms", size2D(F_32, nBatches, 1)); + const NnUint ropeCacheBufferIndex = nodeBuilder.addBuffer("rope_cache", ropeSlice.cacheSize); + const NnUint attBufferIndex = nodeBuilder.addBuffer("att", multiHeadAttSlice.attSize); + const NnUint logitsSliceBufferIndex = nodeBuilder.addBuffer("lg", size2D(F_32, nBatches, h->vocabSize / nNodes)); NnSegmentConfigBuilder start; if (nodeIndex == 0) { @@ -191,9 +191,9 @@ LlmNet buildLlmNet(LlmHeader *h, NnSize nNodes, NnSize nBatches) { start.setSyncPointers(true); nodeBuilder.addSegment(start.build()); - for (NnSize layerIndex = 0; layerIndex < h->nLayers; layerIndex++) { - const NnSize kBufferIndex = nodeBuilder.addBuffer("k", kvCacheSlice.keySize); - const NnSize vBufferIndex = nodeBuilder.addBuffer("v", kvCacheSlice.valueSize); + for (NnUint layerIndex = 0; layerIndex < h->nLayers; layerIndex++) { + const NnUint kBufferIndex = nodeBuilder.addBuffer("k", kvCacheSlice.keySize); + const NnUint vBufferIndex = nodeBuilder.addBuffer("v", kvCacheSlice.valueSize); NnSegmentConfigBuilder att; NnSegmentConfigBuilder ff; @@ -436,7 +436,7 @@ LlmNet buildLlmNet(LlmHeader *h, NnSize nNodes, NnSize nBatches) { } void releaseLlmNet(LlmNet *net) { - for (NnSize nodeIndex = 0; 
nodeIndex < net->netConfig.nNodes; nodeIndex++) + for (NnUint nodeIndex = 0; nodeIndex < net->netConfig.nNodes; nodeIndex++) releaseNodeConfig(&net->nodeConfigs[nodeIndex]); releaseNetConfig(&net->netConfig); delete[] net->nodeConfigs; @@ -445,14 +445,19 @@ void releaseLlmNet(LlmNet *net) { void loadLlmNetWeight(const char *path, LlmNet *net, NnRootWeightLoader *loader) { MmapFile file; openMmapFile(&file, path, net->header->fileSize); +#if DEBUG_USE_MMAP_FOR_WEIGHTS + assert(net->netConfig.nNodes == 1); +#else std::unique_ptr fdPtr(&file, closeMmapFile); + printf("💿 Loading weights...\n"); +#endif NnByte *data = (NnByte *)file.data; NnByte *b = &data[net->header->headerSize]; - NnSize nodeIndex = 0; + NnUint nodeIndex = 0; b += loader->loadRoot("embedding", 0, net->tokenEmbeddingSize.nBytes, b); - for (NnSize layerIndex = 0; layerIndex < net->header->nLayers; layerIndex++) { + for (NnUint layerIndex = 0; layerIndex < net->header->nLayers; layerIndex++) { b += loader->loadRowMatmulSlices("block_matmul_q", layerIndex, &net->qSlice, b); b += loader->loadRowMatmulSlices("block_matmul_k", layerIndex, &net->kSlice, b); b += loader->loadRowMatmulSlices("block_matmul_v", layerIndex, &net->vSlice, b); @@ -467,8 +472,9 @@ void loadLlmNetWeight(const char *path, LlmNet *net, NnRootWeightLoader *loader) b += loader->loadAll("final_rms_norm", 0, net->rmsNormSize.nBytes, b); b += loader->loadRowMatmulSlices("final_matmul_logits", 0, &net->wclsSlice, b); - unsigned long missingBytes = (b - data) - net->header->fileSize; - assert(missingBytes == 0); + long long missingBytes = (long long)(b - data) - net->header->fileSize; + if (missingBytes != 0) + throw std::runtime_error("Missing bytes in weight file: " + std::to_string(missingBytes)); printf("💿 Weights loaded\n"); loader->finish(); diff --git a/src/llm.hpp b/src/llm.hpp index 67918f7..fbec5db 100644 --- a/src/llm.hpp +++ b/src/llm.hpp @@ -37,29 +37,29 @@ enum LlmArchType { }; typedef struct { - size_t headerSize; - size_t 
fileSize; + NnSize headerSize; + NnSize fileSize; int version; LlmArchType archType; - NnSize dim; - NnSize nLayers; - NnSize nHeads; - NnSize headSize; - NnSize nKvHeads; - NnSize nExperts; - NnSize nActiveExperts; - NnSize origSeqLen; // Original model context length - NnSize seqLen; // Limited context length by the `--max-seq-len` argument - NnSize hiddenDim; + NnUint dim; + NnUint nLayers; + NnUint nHeads; + NnUint headSize; + NnUint nKvHeads; + NnUint nExperts; + NnUint nActiveExperts; + NnUint origSeqLen; // Original model context length + NnUint seqLen; // Limited context length by the `--max-seq-len` argument + NnUint hiddenDim; LlmHiddenAct hiddenAct; - NnSize kvDim; - NnSize vocabSize; + NnUint kvDim; + NnUint vocabSize; float ropeTheta; NnRopeType ropeType; float ropeScalingFactor; float ropeScalingLowFreqFactor; float ropeScalingHighFreqFactory; - NnSize ropeScalingOrigMaxSeqLen; + NnUint ropeScalingOrigMaxSeqLen; float normEpsilon; NnFloatType weightType; @@ -78,17 +78,17 @@ typedef struct { NnColMatmulSlice w2Slice; NnRowMatmulSlice w3Slice; NnRowMatmulSlice wclsSlice; - NnSize positionPipeIndex; - NnSize tokenPipeIndex; - NnSize xPipeIndex; - NnSize logitsPipeIndex; + NnUint positionPipeIndex; + NnUint tokenPipeIndex; + NnUint xPipeIndex; + NnUint logitsPipeIndex; NnSize2D tokenEmbeddingSize; NnSize2D rmsNormSize; } LlmNet; LlmHeader loadLlmHeader(const char* path, const unsigned int maxSeqLen, NnFloatType syncType); void printLlmHeader(LlmHeader *header); -LlmNet buildLlmNet(LlmHeader *h, NnSize nNodes, NnSize nBatches); +LlmNet buildLlmNet(LlmHeader *h, NnUint nNodes, NnUint nBatches); void releaseLlmNet(LlmNet *net); void loadLlmNetWeight(const char* path, LlmNet *net, NnRootWeightLoader *loader); diff --git a/src/nn/nn-config-builder.hpp b/src/nn/nn-config-builder.hpp index 4c0f71b..9709ded 100644 --- a/src/nn/nn-config-builder.hpp +++ b/src/nn/nn-config-builder.hpp @@ -6,7 +6,7 @@ #include static char *cloneString(const char *str) { - NnSize len 
= std::strlen(str); + NnUint len = std::strlen(str); char *copy = new char[len + 1]; std::memcpy(copy, str, len + 1); return copy; @@ -14,17 +14,17 @@ static char *cloneString(const char *str) { class NnNetConfigBuilder { public: - NnSize nNodes; - NnSize nBatches; + NnUint nNodes; + NnUint nBatches; std::list pipes; - NnNetConfigBuilder(NnSize nNodes, NnSize nBatches) { + NnNetConfigBuilder(NnUint nNodes, NnUint nBatches) { this->nNodes = nNodes; this->nBatches = nBatches; } - NnSize addPipe(const char *name, NnSize2D size) { - NnSize pipeIndex = pipes.size(); + NnUint addPipe(const char *name, NnSize2D size) { + NnUint pipeIndex = pipes.size(); pipes.push_back({ cloneString(name), size }); return pipeIndex; } @@ -42,16 +42,16 @@ class NnNetConfigBuilder { class NnNodeConfigBuilder { public: - NnSize nodeIndex; + NnUint nodeIndex; std::list buffers; std::list segments; - NnNodeConfigBuilder(NnSize nodeIndex) { + NnNodeConfigBuilder(NnUint nodeIndex) { this->nodeIndex = nodeIndex; } - NnSize addBuffer(const char *name, NnSize2D size) { - NnSize bufferIndex = buffers.size(); + NnUint addBuffer(const char *name, NnSize2D size) { + NnUint bufferIndex = buffers.size(); buffers.push_back({ cloneString(name), size }); return bufferIndex; } @@ -84,8 +84,8 @@ class NnSegmentConfigBuilder { public: template - void addOp(NnOpCode code, const char *name, NnSize index, NnPointerConfig input, NnPointerConfig output, NnSize2D weightSize, T config) { - NnSize configSize = sizeof(T); + void addOp(NnOpCode code, const char *name, NnUint index, NnPointerConfig input, NnPointerConfig output, NnSize2D weightSize, T config) { + NnUint configSize = sizeof(T); NnByte *configCopy = new NnByte[configSize]; std::memcpy(configCopy, &config, configSize); ops.push_back({ @@ -100,7 +100,7 @@ class NnSegmentConfigBuilder { }); }; - void addSync(NnSize pipeIndex, NnSyncType syncType) { + void addSync(NnUint pipeIndex, NnSyncType syncType) { syncs.push_back({ pipeIndex, syncType }); } diff --git 
a/src/nn/nn-core.cpp b/src/nn/nn-core.cpp index ee662c7..5136d55 100644 --- a/src/nn/nn-core.cpp +++ b/src/nn/nn-core.cpp @@ -97,24 +97,24 @@ NnSize2D size0() { return { F_UNK, 0, 0, 0, 0 }; } -NnSize2D size1D(NnFloatType floatType, NnSize x) { +NnSize2D size1D(NnFloatType floatType, NnUint x) { return size2D(floatType, 1, x); } -NnSize2D size2D(NnFloatType floatType, NnSize y, NnSize x) { +NnSize2D size2D(NnFloatType floatType, NnUint y, NnUint x) { NnSize length = y * x; return { floatType, y, x, length, getBytes(floatType, length) }; } -NnPointerConfig pointerConfig(NnPointerType type, NnSize index) { +NnPointerConfig pointerConfig(NnPointerType type, NnUint index) { return { type, index, SLICE_NONE, PNTR_BATCH_DEFAULT, 0 /* not used*/ }; } -NnPointerConfig pointerConfigWithPipedBatch(NnPointerType type, NnSize index, NnSize pipeIndex) { +NnPointerConfig pointerConfigWithPipedBatch(NnPointerType type, NnUint index, NnUint pipeIndex) { return { type, index, SLICE_NONE, PNTR_BATCH_PIPE, pipeIndex }; } -NnPointerConfig slicedPointerConfig(NnPointerType type, NnSize index) { +NnPointerConfig slicedPointerConfig(NnPointerType type, NnUint index) { return { type, index, SLICE_NODE_PART, PNTR_BATCH_DEFAULT, 0 /* not used*/ }; } @@ -123,17 +123,17 @@ bool hasPointerContinuousMemory(NnPointerConfig *config) { } void releaseNetConfig(NnNetConfig *netConfig) { - for (NnSize pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++) { + for (NnUint pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++) { delete[] netConfig->pipes[pipeIndex].name; } delete[] netConfig->pipes; } void releaseNodeConfig(NnNodeConfig *nodeConfig) { - for (NnSize segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) { + for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) { NnSegmentConfig *segment = &nodeConfig->segments[segmentIndex]; if (segment->nOps > 0) { - for (NnSize opIndex = 0; opIndex < segment->nOps; opIndex++) { + for (NnUint 
opIndex = 0; opIndex < segment->nOps; opIndex++) { NnOpConfig *op = &segment->ops[opIndex]; delete[] op->name; delete[] op->config; @@ -143,7 +143,7 @@ void releaseNodeConfig(NnNodeConfig *nodeConfig) { if (segment->nSyncs > 0) delete[] segment->syncs; } - for (NnSize bufferIndex = 0; bufferIndex < nodeConfig->nBuffers; bufferIndex++) + for (NnUint bufferIndex = 0; bufferIndex < nodeConfig->nBuffers; bufferIndex++) delete[] nodeConfig->buffers[bufferIndex].name; delete[] nodeConfig->buffers; delete[] nodeConfig->segments; @@ -151,13 +151,13 @@ void releaseNodeConfig(NnNodeConfig *nodeConfig) { void printNodeRequiredMemory(NnNetConfig *netConfig, NnNodeConfig *nodeConfig) { unsigned long total = 0; - for (NnSize pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++) + for (NnUint pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++) total += netConfig->pipes[pipeIndex].size.nBytes; - for (NnSize bufferIndex = 0; bufferIndex < nodeConfig->nBuffers; bufferIndex++) + for (NnUint bufferIndex = 0; bufferIndex < nodeConfig->nBuffers; bufferIndex++) total += nodeConfig->buffers[bufferIndex].size.nBytes; - for (NnSize segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) { + for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) { NnSegmentConfig *segment = &nodeConfig->segments[segmentIndex]; - for (NnSize opIndex = 0; opIndex < segment->nOps; opIndex++) { + for (NnUint opIndex = 0; opIndex < segment->nOps; opIndex++) { total += segment->ops[opIndex].weightSize.nBytes; total += segment->ops[opIndex].configSize; } @@ -169,19 +169,19 @@ Timer::Timer() { startTime = std::chrono::high_resolution_clock::now(); } -NnSize Timer::elapsedMiliseconds() { +NnUint Timer::elapsedMiliseconds() { auto endTime = std::chrono::high_resolution_clock::now(); - return (NnSize)std::chrono::duration_cast(endTime - startTime).count(); + return (NnUint)std::chrono::duration_cast(endTime - startTime).count(); } -NnSize 
Timer::elapsedMicroseconds() { +NnUint Timer::elapsedMicroseconds() { auto endTime = std::chrono::high_resolution_clock::now(); - return (NnSize)std::chrono::duration_cast(endTime - startTime).count(); + return (NnUint)std::chrono::duration_cast(endTime - startTime).count(); } // slicers -NnKvCacheSlice sliceKvCache(NnSize kvDim, NnSize seqLen, NnSize nNodes) { +NnKvCacheSlice sliceKvCache(NnUint kvDim, NnUint seqLen, NnUint nNodes) { NnKvCacheSlice s; assert(kvDim % nNodes == 0); s.kvDim0 = kvDim / nNodes; @@ -190,7 +190,7 @@ NnKvCacheSlice sliceKvCache(NnSize kvDim, NnSize seqLen, NnSize nNodes) { return s; } -NnRowMatmulSlice sliceRowMatmul(NnFloatType type, NnSize nNodes, NnSize n, NnSize d) { +NnRowMatmulSlice sliceRowMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d) { NnRowMatmulSlice s; assert(d % nNodes == 0); s.type = type; @@ -202,7 +202,7 @@ NnRowMatmulSlice sliceRowMatmul(NnFloatType type, NnSize nNodes, NnSize n, NnSiz return s; } -NnColMatmulSlice sliceColMatmul(NnFloatType type, NnSize nNodes, NnSize n, NnSize d) { +NnColMatmulSlice sliceColMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d) { NnColMatmulSlice s; assert(n % nNodes == 0); s.type = type; @@ -215,7 +215,7 @@ NnColMatmulSlice sliceColMatmul(NnFloatType type, NnSize nNodes, NnSize n, NnSiz return s; } -NnRopeSlice sliceRope(NnSize dim, NnSize kvDim, NnSize nKvHeads, NnSize nNodes, NnSize seqLen, NnSize headSize, float ropeTheta, NnSize nodeIndex) { +NnRopeSlice sliceRope(NnUint dim, NnUint kvDim, NnUint nKvHeads, NnUint nNodes, NnUint seqLen, NnUint headSize, float ropeTheta, NnUint nodeIndex) { NnRopeSlice s; assert(dim >= kvDim); assert(dim % nNodes == 0); @@ -242,7 +242,7 @@ NnRopeSlice sliceRope(NnSize dim, NnSize kvDim, NnSize nKvHeads, NnSize nNodes, return s; } -NnMultiHeadAttSlice sliceMultiHeadAtt(NnSize nHeads, NnSize seqLen, NnSize nNodes) { +NnMultiHeadAttSlice sliceMultiHeadAtt(NnUint nHeads, NnUint seqLen, NnUint nNodes) { NnMultiHeadAttSlice s; 
assert(nHeads % nNodes == 0); s.nHeads = nHeads; @@ -253,7 +253,7 @@ NnMultiHeadAttSlice sliceMultiHeadAtt(NnSize nHeads, NnSize seqLen, NnSize nNode // splitters -NnSize splitRowMatmulWeight(NnRowMatmulSlice *slice, NnSize nodeIndex, NnByte *weight, NnByte *weight0) { +NnUint splitRowMatmulWeight(NnRowMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0) { NnSize blockSize = getBlockSize(slice->type); NnSize batchBytes = getBytes(slice->type, blockSize); assert(slice->n % blockSize == 0); @@ -261,8 +261,8 @@ NnSize splitRowMatmulWeight(NnRowMatmulSlice *slice, NnSize nodeIndex, NnByte *w NnSize n = slice->n / blockSize; NnSize offset = slice->d0 * nodeIndex * n * batchBytes; NnSize copiedBytes = 0; - for (NnSize d = 0; d < slice->d0; d++) { - for (NnSize j = 0; j < n; j++) { + for (NnUint d = 0; d < slice->d0; d++) { + for (NnUint j = 0; j < n; j++) { NnSize o = (d * n + j) * batchBytes; std::memcpy(weight0 + o, weight + offset + o, batchBytes); copiedBytes += batchBytes; @@ -271,7 +271,7 @@ NnSize splitRowMatmulWeight(NnRowMatmulSlice *slice, NnSize nodeIndex, NnByte *w return copiedBytes; } -NnSize splitColMatmulWeight(NnColMatmulSlice *slice, NnSize nodeIndex, NnByte *weight, NnByte *weight0) { +NnUint splitColMatmulWeight(NnColMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0) { NnSize blockSize = getBlockSize(slice->type); NnSize batchBytes = getBytes(slice->type, blockSize); assert(slice->n0 % blockSize == 0); @@ -281,7 +281,7 @@ NnSize splitColMatmulWeight(NnColMatmulSlice *slice, NnSize nodeIndex, NnByte *w NnSize row0Bytes = (slice->n0 / blockSize) * batchBytes; NnSize rowOffsetBytes = nodeIndex * row0Bytes; NnSize copiedBytes = 0; - for (NnSize d = 0; d < slice->d; d++) { + for (NnUint d = 0; d < slice->d; d++) { std::memcpy(&weight0[row0Bytes * d], &weight[rowBytes * d + rowOffsetBytes], row0Bytes); copiedBytes += row0Bytes; } diff --git a/src/nn/nn-core.hpp b/src/nn/nn-core.hpp index 49d075f..256ea9b 100644 --- 
a/src/nn/nn-core.hpp +++ b/src/nn/nn-core.hpp @@ -11,8 +11,8 @@ typedef struct { NnFloatType floatType; - NnSize y; - NnSize x; + NnUint y; + NnUint x; NnSize length; NnSize nBytes; } NnSize2D; @@ -20,49 +20,49 @@ typedef struct { // slices typedef struct { - NnSize kvDim0; + NnUint kvDim0; NnSize2D keySize; NnSize2D valueSize; } NnKvCacheSlice; typedef struct { NnFloatType type; - NnSize nNodes; - NnSize d0; - NnSize n; + NnUint nNodes; + NnUint d0; + NnUint n; NnSize2D size; NnSize2D sliceSize; } NnRowMatmulSlice; typedef struct { NnFloatType type; - NnSize nNodes; - NnSize n; - NnSize n0; - NnSize d; + NnUint nNodes; + NnUint n; + NnUint n0; + NnUint d; NnSize2D size; NnSize2D sliceSize; } NnColMatmulSlice; typedef struct { - NnSize qDim0; - NnSize qDimStart; - NnSize qDimEnd; - NnSize qShift; - NnSize kvDim; - NnSize kvDim0; - NnSize kvDimStart; - NnSize sliceDim; - NnSize seqLen; - NnSize headSize; - NnSize nKvHeads; + NnUint qDim0; + NnUint qDimStart; + NnUint qDimEnd; + NnUint qShift; + NnUint kvDim; + NnUint kvDim0; + NnUint kvDimStart; + NnUint sliceDim; + NnUint seqLen; + NnUint headSize; + NnUint nKvHeads; float ropeTheta; NnSize2D cacheSize; } NnRopeSlice; typedef struct { - NnSize nHeads; - NnSize nHeads0; + NnUint nHeads; + NnUint nHeads0; NnSize2D attSize; } NnMultiHeadAttSlice; @@ -138,48 +138,48 @@ typedef struct { typedef struct { NnPointerType pointerType; - NnSize pointerIndex; + NnUint pointerIndex; NnPointerSliceType sliceType; NnPointerBatchType batchType; - NnSize batchArg0; + NnUint batchArg0; } NnPointerConfig; typedef struct { NnOpCode code; char *name; - NnSize index; + NnUint index; NnPointerConfig input; NnPointerConfig output; NnSize2D weightSize; NnByte *config; - NnSize configSize; + NnUint configSize; } NnOpConfig; typedef struct { - NnSize pipeIndex; + NnUint pipeIndex; NnSyncType syncType; } NnSyncConfig; typedef struct { - NnSize nOps; + NnUint nOps; NnOpConfig *ops; - NnSize nSyncs; + NnUint nSyncs; NnSyncConfig *syncs; bool 
syncPointers; } NnSegmentConfig; typedef struct { - NnSize nBatches; - NnSize nNodes; - NnSize nPipes; + NnUint nBatches; + NnUint nNodes; + NnUint nPipes; NnPipeConfig *pipes; } NnNetConfig; typedef struct { - NnSize nodeIndex; - NnSize nBuffers; + NnUint nodeIndex; + NnUint nBuffers; NnBufferConfig *buffers; - NnSize nSegments; + NnUint nSegments; NnSegmentConfig *segments; } NnNodeConfig; @@ -194,7 +194,7 @@ typedef struct { } NnInvRmsOpConfig; typedef struct { - NnSize invRmsBufferIndex; + NnUint invRmsBufferIndex; } NnRmsNormOpConfig; typedef struct { @@ -203,24 +203,24 @@ typedef struct { typedef struct { bool isQ; - NnSize positionPipeIndex; - NnSize ropeCacheBufferIndex; + NnUint positionPipeIndex; + NnUint ropeCacheBufferIndex; float ropeScalingFactor; float ropeScalingLowFreqFactor; float ropeScalingHighFreqFactory; - NnSize ropeScalingOrigMaxSeqLen; + NnUint ropeScalingOrigMaxSeqLen; NnRopeSlice slice; } NnRopeLlamaOpConfig; typedef struct { - NnSize nKvHeads; - NnSize headSize; - NnSize seqLen; - NnSize positionPipeIndex; - NnSize queryBufferIndex; - NnSize keyCacheBufferIndex; - NnSize valueCacheBufferIndex; - NnSize attBufferIndex; + NnUint nKvHeads; + NnUint headSize; + NnUint seqLen; + NnUint positionPipeIndex; + NnUint queryBufferIndex; + NnUint keyCacheBufferIndex; + NnUint valueCacheBufferIndex; + NnUint attBufferIndex; NnRowMatmulSlice qSlice; NnKvCacheSlice kvCacheSlice; NnMultiHeadAttSlice multiHeadAttSlice; @@ -251,11 +251,11 @@ NnSize getBytes(NnFloatType floatType, NnSize n); NnSize getBlockSize(NnFloatType floatType); NnOpQuantType getOpQuantType(NnFloatType input, NnFloatType weight, NnFloatType output); NnSize2D size0(); -NnSize2D size1D(NnFloatType floatType, NnSize x); -NnSize2D size2D(NnFloatType floatType, NnSize y, NnSize x); -NnPointerConfig pointerConfig(NnPointerType type, NnSize index); -NnPointerConfig pointerConfigWithPipedBatch(NnPointerType type, NnSize index, NnSize pipeIndex); -NnPointerConfig 
slicedPointerConfig(NnPointerType type, NnSize index); +NnSize2D size1D(NnFloatType floatType, NnUint x); +NnSize2D size2D(NnFloatType floatType, NnUint y, NnUint x); +NnPointerConfig pointerConfig(NnPointerType type, NnUint index); +NnPointerConfig pointerConfigWithPipedBatch(NnPointerType type, NnUint index, NnUint pipeIndex); +NnPointerConfig slicedPointerConfig(NnPointerType type, NnUint index); bool hasPointerContinuousMemory(NnPointerConfig *config); void releaseNetConfig(NnNetConfig *netConfig); @@ -268,21 +268,21 @@ class Timer { std::chrono::time_point startTime; public: Timer(); - NnSize elapsedMiliseconds(); - NnSize elapsedMicroseconds(); + NnUint elapsedMiliseconds(); + NnUint elapsedMicroseconds(); }; // slicers -NnKvCacheSlice sliceKvCache(NnSize kvDim, NnSize seqLen, NnSize nNodes); -NnRowMatmulSlice sliceRowMatmul(NnFloatType type, NnSize nNodes, NnSize n, NnSize d); -NnColMatmulSlice sliceColMatmul(NnFloatType type, NnSize nNodes, NnSize n, NnSize d); -NnRopeSlice sliceRope(NnSize dim, NnSize kvDim, NnSize nKvHeads, NnSize nNodes, NnSize seqLen, NnSize headSize, float ropeTheta, NnSize nodeIndex); -NnMultiHeadAttSlice sliceMultiHeadAtt(NnSize nHeads, NnSize seqLen, NnSize nNodes); +NnKvCacheSlice sliceKvCache(NnUint kvDim, NnUint seqLen, NnUint nNodes); +NnRowMatmulSlice sliceRowMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d); +NnColMatmulSlice sliceColMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d); +NnRopeSlice sliceRope(NnUint dim, NnUint kvDim, NnUint nKvHeads, NnUint nNodes, NnUint seqLen, NnUint headSize, float ropeTheta, NnUint nodeIndex); +NnMultiHeadAttSlice sliceMultiHeadAtt(NnUint nHeads, NnUint seqLen, NnUint nNodes); // splitters -NnSize splitRowMatmulWeight(NnRowMatmulSlice *slice, NnSize nodeIndex, NnByte *weight, NnByte *weight0); -NnSize splitColMatmulWeight(NnColMatmulSlice *slice, NnSize nodeIndex, NnByte *weight, NnByte *weight0); +NnUint splitRowMatmulWeight(NnRowMatmulSlice *slice, NnUint nodeIndex, 
NnByte *weight, NnByte *weight0); +NnUint splitColMatmulWeight(NnColMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0); #endif diff --git a/src/nn/nn-cpu-ops-test.cpp b/src/nn/nn-cpu-ops-test.cpp index 8af47b6..f0a990b 100644 --- a/src/nn/nn-cpu-ops-test.cpp +++ b/src/nn/nn-cpu-ops-test.cpp @@ -3,20 +3,20 @@ // framework -void rand(float *o, const NnSize n, const NnSize seed) { +void rand(float *o, const NnUint n, const NnUint seed) { srand(seed + 123456); - for (NnSize i = 0; i < n; i++) { + for (NnUint i = 0; i < n; i++) { float v = (float)(rand() / RAND_MAX); o[i] = v * 2.0f - 1.0f; } } -void compare_F32(const char *name, const float *a, const float *b, const NnSize n, const float epsilon) { - for (NnSize i = 0; i < n; i++) { +void compare_F32(const char *name, const float *a, const float *b, const NnUint n, const float epsilon) { + for (NnUint i = 0; i < n; i++) { float error = fabs(a[i] - b[i]); if (error > epsilon) { printf("❌ %s failed\n", name); - for (NnSize j = i; j < i + 16 && j < n; j++) + for (NnUint j = i; j < i + 16 && j < n; j++) printf(" [%3d] %f != %f\n", j, a[j], b[j]); exit(1); } @@ -71,7 +71,7 @@ void testSplitThreads() { void testConvertF32toF16() { float x[] = {0.0f, 0.25f, 0.3456f, 1.0f}; - for (NnSize i = 0; i < sizeof(x) / sizeof(float); i++) { + for (NnUint i = 0; i < sizeof(x) / sizeof(float); i++) { NnFp16 f16 = CONVERT_F32_TO_F16(x[i]); float f32 = CONVERT_F16_TO_F32(f16); compare_F32("convertF32toF16", &x[i], &f32, 1, 0.0005); @@ -79,7 +79,7 @@ void testConvertF32toF16() { } // quantization -void testQuantization(const NnSize m) { +void testQuantization(const NnUint m) { std::vector a(m * Q40_BLOCK_SIZE); std::vector aTemp(m * Q40_BLOCK_SIZE); std::vector aQ40(m); @@ -118,7 +118,7 @@ void testInvRms() { } // rmsNorm -void testRmsNorm(const NnSize m) { +void testRmsNorm(const NnUint m) { std::vector x(m); std::vector xQ80(m / Q80_BLOCK_SIZE); std::vector w(m); @@ -137,8 +137,8 @@ void testRmsNorm(const NnSize m) { } 
// a *= b -void testMul(const NnSize m) { - const NnSize n = Q80_BLOCK_SIZE * m; +void testMul(const NnUint m) { + const NnUint n = Q80_BLOCK_SIZE * m; std::vector a0(n); std::vector b0(n); @@ -158,8 +158,8 @@ void testMul(const NnSize m) { } // y += x -void testAdd(const NnSize m) { - const NnSize n = Q80_BLOCK_SIZE * m; +void testAdd(const NnUint m) { + const NnUint n = Q80_BLOCK_SIZE * m; std::vector y(n); std::vector yTemp(n); @@ -179,7 +179,7 @@ void testAdd(const NnSize m) { void testSoftmax() { std::vector y(8); - for (NnSize i = 0; i < 8; i++) + for (NnUint i = 0; i < 8; i++) y[i] = i / 8.0f; softmax_F32(y.data(), 8); @@ -199,7 +199,7 @@ void testSoftmax() { void testSilu() { std::vector y(8); - for (NnSize i = 0; i < 8; i++) + for (NnUint i = 0; i < 8; i++) y[i] = i / 8.0f; silu_F32(y.data(), 8, 1, 0); @@ -218,9 +218,9 @@ void testSilu() { } // matmul -void testMatmul_F32_Q40_F32(const NnSize m = 2) { - const NnSize n = Q80_BLOCK_SIZE * m; - const NnSize d = Q80_BLOCK_SIZE * m; +void testMatmul_F32_Q40_F32(const NnUint m = 2) { + const NnUint n = Q80_BLOCK_SIZE * m; + const NnUint d = Q80_BLOCK_SIZE * m; std::vector x(n); std::vector w(n * d); @@ -241,9 +241,9 @@ void testMatmul_F32_Q40_F32(const NnSize m = 2) { } void testLlamafileSgemm() { - const NnSize batchSize = 8; - const NnSize n = 256; - const NnSize d = 128; + const NnUint batchSize = 8; + const NnUint n = 256; + const NnUint d = 128; std::vector x(n * batchSize); std::vector xQ((n * batchSize) / Q80_BLOCK_SIZE); @@ -260,7 +260,7 @@ void testLlamafileSgemm() { // f32 - for (NnSize i = 0; i < batchSize; i++) { + for (NnUint i = 0; i < batchSize; i++) { matmul_F32_F32_F32(o.data() + i * d, x.data() + i * n, w.data(), n, d, 1, 0); } diff --git a/src/nn/nn-cpu-ops.cpp b/src/nn/nn-cpu-ops.cpp index 52e84ef..5f8b65f 100644 --- a/src/nn/nn-cpu-ops.cpp +++ b/src/nn/nn-cpu-ops.cpp @@ -136,7 +136,7 @@ static float invRms_F32(const float *x, const unsigned int size, const float eps return 1.0f / sqrtf(sum); 
} -static void rmsNorm_F32(float *output, const float *x, const float invRms, const float *w, const NnSize size, const NnSize nThreads, const NnSize threadIndex) { +static void rmsNorm_F32(float *output, const float *x, const float invRms, const float *w, const NnUint size, const NnUint nThreads, const NnUint threadIndex) { SPLIT_THREADS(start, end, size, nThreads, threadIndex); unsigned int i = start; #if defined(__ARM_NEON) @@ -168,21 +168,21 @@ static void rmsNorm_F32(float *output, const float *x, const float invRms, const output[i] = w[i] * (invRms * x[i]); } -static void rmsNorm_Q80_F32_F32(float *output, const NnBlockQ80 *x, const float invRms, const float *w, const NnSize size, const NnSize nThreads, const NnSize threadIndex) { +static void rmsNorm_Q80_F32_F32(float *output, const NnBlockQ80 *x, const float invRms, const float *w, const NnUint size, const NnUint nThreads, const NnUint threadIndex) { assert(size % Q80_BLOCK_SIZE == 0); - const NnSize nBlocks = size / Q80_BLOCK_SIZE; + const NnUint nBlocks = size / Q80_BLOCK_SIZE; SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex); - for (NnSize i = start; i < end; i++) { + for (NnUint i = start; i < end; i++) { float d = CONVERT_F16_TO_F32(x[i].d); - for (NnSize j = 0; j < Q80_BLOCK_SIZE; j++) { - NnSize k = i * Q80_BLOCK_SIZE + j; + for (NnUint j = 0; j < Q80_BLOCK_SIZE; j++) { + NnUint k = i * Q80_BLOCK_SIZE + j; output[k] = w[k] * (invRms * d * x[i].qs[j]); } } } -static void matmul_F32_F32_F32(float *output, const float *x, const float *w, const NnSize n, const NnSize d, const NnSize nThreads, const NnSize threadIndex) { +static void matmul_F32_F32_F32(float *output, const float *x, const float *w, const NnUint n, const NnUint d, const NnUint nThreads, const NnUint threadIndex) { SPLIT_THREADS(start, end, d, nThreads, threadIndex); unsigned int i, j; #if defined(__ARM_NEON) @@ -222,7 +222,7 @@ static void matmul_F32_F32_F32(float *output, const float *x, const float *w, co #endif } -static void 
matmul_Q80_Q40_F32(float *output, const NnBlockQ80 *x, const NnBlockQ40 *w, const NnSize n, const NnSize d, const NnSize nThreads, const NnSize threadIndex) { +static void matmul_Q80_Q40_F32(float *output, const NnBlockQ80 *x, const NnBlockQ40 *w, const NnUint n, const NnUint d, const NnUint nThreads, const NnUint threadIndex) { SPLIT_THREADS(start, end, d, nThreads, threadIndex); assert(n % Q40_BLOCK_SIZE == 0); const unsigned int nBlocks = n / Q40_BLOCK_SIZE; @@ -352,9 +352,9 @@ static void matmul_Q80_Q40_F32(float *output, const NnBlockQ80 *x, const NnBlock output[di] = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + vaddvq_f32(sumv2) + vaddvq_f32(sumv3); } #elif defined(__AVX512F__) - for (NnSize i = start; i < end; i++) { + for (NnUint i = start; i < end; i++) { float sum = 0.0f; - for (NnSize j = 0; j < nBlocks; j++) { + for (NnUint j = 0; j < nBlocks; j++) { const NnBlockQ40 *wb = &w[i * nBlocks + j]; const NnBlockQ80 *xb = &x[j]; const float s = CONVERT_F16_TO_F32(wb->d) * CONVERT_F16_TO_F32(xb->d); @@ -379,9 +379,9 @@ static void matmul_Q80_Q40_F32(float *output, const NnBlockQ80 *x, const NnBlock output[i] = sum; } #elif defined(__AVX2__) - for (NnSize i = start; i < end; i++) { + for (NnUint i = start; i < end; i++) { float sum = 0.0f; - for (NnSize j = 0; j < nBlocks; j++) { + for (NnUint j = 0; j < nBlocks; j++) { const NnBlockQ40 *wb = &w[i * nBlocks + j]; const NnBlockQ80 *xb = &x[j]; const float s = CONVERT_F16_TO_F32(wb->d) * CONVERT_F16_TO_F32(xb->d); @@ -423,13 +423,13 @@ static void matmul_Q80_Q40_F32(float *output, const NnBlockQ80 *x, const NnBlock output[i] = sum; } #else - for (NnSize i = start; i < end; i++) { + for (NnUint i = start; i < end; i++) { float sum = 0.0; - for (NnSize j = 0; j < nBlocks; j++) { + for (NnUint j = 0; j < nBlocks; j++) { const NnBlockQ40 *wb = &w[i * nBlocks + j]; const NnBlockQ80 *xb = &x[j]; const float s = CONVERT_F16_TO_F32(wb->d) * CONVERT_F16_TO_F32(xb->d); - for (NnSize k = 0; k < Q40_BLOCK_SIZE / 2; k++) { + for 
(NnUint k = 0; k < Q40_BLOCK_SIZE / 2; k++) { const int w0 = (wb->qs[k] & 0x0F) - 8; const int w1 = (wb->qs[k] >> 4) - 8; const int i1 = xb->qs[k]; @@ -445,7 +445,7 @@ static void matmul_Q80_Q40_F32(float *output, const NnBlockQ80 *x, const NnBlock #define SQRT_2_OVER_PI 0.79788456080286535587989211986876f #define GELU_COEF_A 0.044715f -static void gelu_F32(float *output, const unsigned int n, const NnSize nThreads, const NnSize threadIndex) { +static void gelu_F32(float *output, const unsigned int n, const NnUint nThreads, const NnUint threadIndex) { SPLIT_THREADS(start, end, n, nThreads, threadIndex); for (unsigned int i = start; i < end; i++) { float x = output[i]; @@ -453,7 +453,7 @@ static void gelu_F32(float *output, const unsigned int n, const NnSize nThreads, } } -static void silu_F32(float *output, const unsigned int n, const NnSize nThreads, const NnSize threadIndex) { +static void silu_F32(float *output, const unsigned int n, const NnUint nThreads, const NnUint threadIndex) { SPLIT_THREADS(start, end, n, nThreads, threadIndex); unsigned int i = start; #if defined(__ARM_NEON) @@ -493,15 +493,15 @@ static void silu_F32(float *output, const unsigned int n, const NnSize nThreads, } } -static void add_F32(float *output, const float *x, const unsigned int n, const NnSize nThreads, const NnSize threadIndex) { +static void add_F32(float *output, const float *x, const unsigned int n, const NnUint nThreads, const NnUint threadIndex) { SPLIT_THREADS(start, end, n, nThreads, threadIndex); for (unsigned int i = start; i < end; i++) { output[i] += x[i]; } } -static void add_Q80_F32(float *y, const NnBlockQ80 *x, const NnSize n, const NnSize nThreads, const NnSize threadIndex) { - const NnSize nBlocks = n / Q80_BLOCK_SIZE; +static void add_Q80_F32(float *y, const NnBlockQ80 *x, const NnUint n, const NnUint nThreads, const NnUint threadIndex) { + const NnUint nBlocks = n / Q80_BLOCK_SIZE; SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex); #if defined(__ARM_NEON) 
@@ -586,18 +586,18 @@ static void add_Q80_F32(float *y, const NnBlockQ80 *x, const NnSize n, const NnS #endif } -void softmax_F32(float *x, const NnSize size) { +void softmax_F32(float *x, const NnUint size) { if (size == 0) return; #if defined(__ARM_NEON) - NnSize j; + NnUint j; float maxVal; if (size >= 4) { float32x4_t fs; float32x4_t fmaxv = vld1q_f32(&x[0]); j = size - (size % 4); - for (NnSize i = 4; i < j; i += 4) { + for (NnUint i = 4; i < j; i += 4) { fs = vld1q_f32(&x[i]); fmaxv = vmaxq_f32(fmaxv, fs); } @@ -611,7 +611,7 @@ void softmax_F32(float *x, const NnSize size) { const float32x4_t maxVal_vec = vdupq_n_f32(maxVal); float32x4_t sumv = vdupq_n_f32(0.0f); - NnSize i = 0; + NnUint i = 0; for (; i + 4 <= size; i += 4) { float32x4_t val = vld1q_f32(x + i); val = vsubq_f32(val, maxVal_vec); @@ -645,7 +645,7 @@ void softmax_F32(float *x, const NnSize size) { #elif defined(__AVX2__) float maxVal; const unsigned avxEnd = size - (size % 8); - NnSize i = 0; + NnUint i = 0; if (avxEnd >= 8) { __m256 max_vec = _mm256_loadu_ps(x); @@ -702,18 +702,18 @@ void softmax_F32(float *x, const NnSize size) { x[i] *= inv_sum; #else float maxVal = x[0]; - for (NnSize i = 1; i < size; i++) { + for (NnUint i = 1; i < size; i++) { if (x[i] > maxVal) maxVal = x[i]; } float sum = 0.0f; - for (NnSize i = 0; i < size; i++) { + for (NnUint i = 0; i < size; i++) { x[i] = expf(x[i] - maxVal); sum += x[i]; } if (sum == 0.0) sum = 0.000001; - for (NnSize i = 0; i < size; i++) + for (NnUint i = 0; i < size; i++) x[i] /= sum; #endif } @@ -751,21 +751,21 @@ static float dotProduct_F32(const float *a, const float *b, const unsigned int s static void multiheadAtt_F32( float *x, const float *q, float *att, float *keyCache, float *valueCache, - const unsigned pos, const NnSize nHeads, const NnSize nHeads0, const NnSize nKvHeads, const NnSize kvDim0, const NnSize headSize, const NnSize seqLen, - const NnSize nThreads, const NnSize threadIndex) + const unsigned pos, const NnUint nHeads, const 
NnUint nHeads0, const NnUint nKvHeads, const NnUint kvDim0, const NnUint headSize, const NnUint seqLen, + const NnUint nThreads, const NnUint threadIndex) { SPLIT_THREADS(h0Start, h0End, nHeads0, nThreads, threadIndex); - const NnSize kvMul = nHeads / nKvHeads; + const NnUint kvMul = nHeads / nKvHeads; const float headSizeRoot = sqrtf(headSize); - for (NnSize h0 = h0Start; h0 < h0End; h0++) { + for (NnUint h0 = h0Start; h0 < h0End; h0++) { const float *hQ = &q[h0 * headSize]; - const NnSize headIndex = h0 / kvMul; + const NnUint headIndex = h0 / kvMul; const float *hKc = &keyCache[headIndex * headSize]; const float *hVc = &valueCache[headIndex * headSize]; float *hAtt = &att[h0 * seqLen]; - for (NnSize t = 0; t <= pos; t++) { + for (NnUint t = 0; t <= pos; t++) { const float *posK = &hKc[t * kvDim0]; const float score = dotProduct_F32(hQ, posK, headSize) / headSizeRoot; hAtt[t] = score; @@ -776,7 +776,7 @@ static void multiheadAtt_F32( float *hX = &x[h0 * headSize]; std::memset(hX, 0, headSize * sizeof(float)); - for (NnSize t = 0; t <= pos; t++) { + for (NnUint t = 0; t <= pos; t++) { const float *posV = &hVc[t * kvDim0]; const float posA = hAtt[t]; for (int i = 0; i < headSize; i++) { @@ -786,7 +786,7 @@ static void multiheadAtt_F32( } } -static void mul_F32(float *output, const float *x, const NnSize n, const NnSize nThreads, const NnSize threadIndex) { +static void mul_F32(float *output, const float *x, const NnUint n, const NnUint nThreads, const NnUint threadIndex) { SPLIT_THREADS(start, end, n, nThreads, threadIndex); unsigned int i = start; @@ -813,35 +813,35 @@ static void mul_F32(float *output, const float *x, const NnSize n, const NnSize output[i] *= x[i]; } -static void mul_Q80_F32(float *output, const NnBlockQ80 *x, const NnSize n, const NnSize nThreads, const NnSize threadIndex) { - const NnSize nBlocks = n / Q80_BLOCK_SIZE; +static void mul_Q80_F32(float *output, const NnBlockQ80 *x, const NnUint n, const NnUint nThreads, const NnUint threadIndex) { 
+ const NnUint nBlocks = n / Q80_BLOCK_SIZE; SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex); - for (NnSize i = start; i < end; i++) { + for (NnUint i = start; i < end; i++) { const NnBlockQ80 *b = &x[i]; float d = CONVERT_F16_TO_F32(b->d); - for (NnSize j = 0; j < Q80_BLOCK_SIZE; j++) { - NnSize k = i * Q80_BLOCK_SIZE + j; + for (NnUint j = 0; j < Q80_BLOCK_SIZE; j++) { + NnUint k = i * Q80_BLOCK_SIZE + j; output[k] *= d * b->qs[j]; } } } -static void copy_UNK(NnByte *output, const NnByte *x, NnSize size, const NnSize nThreads, const NnSize threadIndex) { +static void copy_UNK(NnByte *output, const NnByte *x, NnUint size, const NnUint nThreads, const NnUint threadIndex) { SPLIT_THREADS(start, end, size, nThreads, threadIndex); - NnSize s = end - start; + NnUint s = end - start; if (s != 0) std::memcpy(&output[start], &x[start], s); } // -static void mergeAddForward_F32_F32(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { - NnSize nSlices = context->inputSize.x / context->outputSize.x; +static void mergeAddForward_F32_F32(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { + NnUint nSlices = context->inputSize.x / context->outputSize.x; - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { float *output = (float *)context->output[batchIndex]; float *input = (float *)context->input[batchIndex]; - for (NnSize sliceIndex = 0; sliceIndex < nSlices; sliceIndex++) { + for (NnUint sliceIndex = 0; sliceIndex < nSlices; sliceIndex++) { float *i = &input[sliceIndex * context->outputSize.x]; DEBUG_VECTOR(context, "input", i); add_F32( @@ -854,16 +854,16 @@ static void mergeAddForward_F32_F32(NnSize nThreads, NnSize threadIndex, NnSize } } -static void mergeAddForward_Q80_F32(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { +static void mergeAddForward_Q80_F32(NnUint nThreads, NnUint 
threadIndex, NnUint batchSize, NnCpuOpContext *context) { assert(context->inputSize.floatType == F_Q80); assert(context->outputSize.floatType == F_32); - NnSize nSlices = context->inputSize.x / context->outputSize.x; - NnSize xSize = context->outputSize.x / Q80_BLOCK_SIZE; - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + NnUint nSlices = context->inputSize.x / context->outputSize.x; + NnUint xSize = context->outputSize.x / Q80_BLOCK_SIZE; + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { float *output = (float *)context->output[batchIndex]; NnBlockQ80 *input = (NnBlockQ80 *)context->input[batchIndex]; - for (NnSize sliceIndex = 0; sliceIndex < nSlices; sliceIndex++) { + for (NnUint sliceIndex = 0; sliceIndex < nSlices; sliceIndex++) { add_Q80_F32( output, &input[sliceIndex * xSize], @@ -880,11 +880,11 @@ static void initEmbeddingForward(NnCpuOpContext *context) { ASSERT_EQ(context->weightSize.x, context->outputSize.x); } -static void embeddingForward_F32_F32_F32(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { - NnSize dimSize = getBytes(F_32, context->outputSize.x); +static void embeddingForward_F32_F32_F32(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { + NnUint dimSize = getBytes(F_32, context->outputSize.x); - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { - NnSize token = (NnSize)*((float *)context->input[batchIndex]); + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { + NnUint token = (NnUint)*((float *)context->input[batchIndex]); copy_UNK( context->output[batchIndex], &context->weight[token * dimSize], @@ -894,11 +894,11 @@ static void embeddingForward_F32_F32_F32(NnSize nThreads, NnSize threadIndex, Nn } } -static void embeddingForward_F32_F32_Q80(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { - NnSize dimSize = getBytes(F_32, context->outputSize.x); +static void 
embeddingForward_F32_F32_Q80(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { + NnUint dimSize = getBytes(F_32, context->outputSize.x); - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { - NnSize token = (NnSize)*((float *)context->input[batchIndex]); + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { + NnUint token = (NnUint)*((float *)context->input[batchIndex]); quantizeF32toQ80( (float *)&context->weight[token * dimSize], (NnBlockQ80 *)context->output[batchIndex], @@ -908,14 +908,14 @@ static void embeddingForward_F32_F32_Q80(NnSize nThreads, NnSize threadIndex, Nn } } -static void invRmsForward_F32_F32(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { +static void invRmsForward_F32_F32(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { if (threadIndex == 0) { ASSERT_EQ(context->inputSize.y, context->nBatches); ASSERT_EQ(context->outputSize.x, 1); ASSERT_EQ(context->outputSize.y, context->nBatches); const NnInvRmsOpConfig *config = (NnInvRmsOpConfig *)context->opConfig; - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { float *input = (float *)context->input[batchIndex]; float *output = (float *)context->output[batchIndex]; DEBUG_VECTOR(context, "input", input); @@ -944,14 +944,14 @@ static void initRmsNormForward_ANY_F32_F32(NnCpuOpContext *context) { ASSERT_EQ(rmsBufferConfig->size.y, context->nBatches); } -static void rmsNormForward_F32_F32_F32(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { +static void rmsNormForward_F32_F32_F32(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { ASSERT_EQ(context->inputSize.floatType, F_32); NnRmsNormOpConfig *config = (NnRmsNormOpConfig *)context->opConfig; const float *weight = (float *)context->weight; const float *invRms = (float 
*)context->buffers[config->invRmsBufferIndex]; - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { float *input = (float *)context->input[batchIndex]; float *output = (float *)context->output[batchIndex]; rmsNorm_F32( @@ -965,14 +965,14 @@ static void rmsNormForward_F32_F32_F32(NnSize nThreads, NnSize threadIndex, NnSi } } -static void rmsNormForward_Q80_F32_F32(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { +static void rmsNormForward_Q80_F32_F32(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { ASSERT_EQ(context->inputSize.floatType, F_Q80); NnRmsNormOpConfig *config = (NnRmsNormOpConfig *)context->opConfig; const float *weight = (float *)context->weight; const float *invRms = (float *)context->buffers[config->invRmsBufferIndex]; - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { NnBlockQ80 *input = (NnBlockQ80 *)context->input[batchIndex]; float *output = (float *)context->output[batchIndex]; rmsNorm_Q80_F32_F32( @@ -1000,12 +1000,12 @@ static void initMatmulForward(NnCpuOpContext *context) { } -static bool matmulForward_llamafile(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { +static bool matmulForward_llamafile(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { if (batchSize == 1 || !context->hasInputContinuousMemory || !context->hasOutputContinuousMemory) return false; - const NnSize n = context->weightSize.y / getBlockSize(context->inputSize.floatType); - const NnSize d = context->weightSize.x; + const NnUint n = context->weightSize.y / getBlockSize(context->inputSize.floatType); + const NnUint d = context->weightSize.x; return llamafile_sgemm( d, batchSize, n, context->weight, n, @@ -1018,12 +1018,12 @@ static bool matmulForward_llamafile(NnSize 
nThreads, NnSize threadIndex, NnSize ); } -static void matmulForward_F32_F32_F32(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { +static void matmulForward_F32_F32_F32(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { if (matmulForward_llamafile(nThreads, threadIndex, batchSize, context)) return; const float *weight = (float *)context->weight; - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { float *input = (float *)context->input[batchIndex]; float *output = (float *)context->output[batchIndex]; DEBUG_VECTOR(context, "input", input); @@ -1038,12 +1038,12 @@ static void matmulForward_F32_F32_F32(NnSize nThreads, NnSize threadIndex, NnSiz } } -static void matmulForward_Q80_Q40_F32(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { +static void matmulForward_Q80_Q40_F32(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { if (matmulForward_llamafile(nThreads, threadIndex, batchSize, context)) return; const NnBlockQ40 *weight = (NnBlockQ40 *)context->weight; - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { NnBlockQ80 *input = (NnBlockQ80 *)context->input[batchIndex]; float *output = (float *)context->output[batchIndex]; matmul_Q80_Q40_F32( @@ -1057,23 +1057,23 @@ static void matmulForward_Q80_Q40_F32(NnSize nThreads, NnSize threadIndex, NnSiz } } -static void siluForward_F32_F32(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { - ASSERT_EQ(context->weightSize.nBytes, 0); +static void siluForward_F32_F32(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { + assert(context->weightSize.nBytes == 0); ASSERT_EQ(context->inputSize.x, context->outputSize.x); ASSERT_EQ(context->inputSize.y, context->outputSize.y); 
- for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { float *output = (float *)context->output[batchIndex]; silu_F32(output, context->outputSize.x, nThreads, threadIndex); } } -static void geluForward_F32_F32_F32(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { - ASSERT_EQ(context->weightSize.nBytes, 0); +static void geluForward_F32_F32_F32(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { + assert(context->weightSize.nBytes == 0); ASSERT_EQ(context->inputSize.x, context->outputSize.x); ASSERT_EQ(context->inputSize.y, context->outputSize.y); - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { float *output = (float *)context->output[batchIndex]; gelu_F32(output, context->outputSize.x, nThreads, threadIndex); } @@ -1088,9 +1088,9 @@ static void initRopeLlama31Forward(NnCpuOpContext *context) { const NnRopeSlice *slice = &config->slice; float *cache = (float *)context->buffers[config->ropeCacheBufferIndex]; - for (NnSize pos = 0; pos < slice->seqLen; pos++) { - for (NnSize i = slice->kvDimStart; i < slice->qDimEnd; i += 2) { - const NnSize headDim = i % slice->headSize; + for (NnUint pos = 0; pos < slice->seqLen; pos++) { + for (NnUint i = slice->kvDimStart; i < slice->qDimEnd; i += 2) { + const NnUint headDim = i % slice->headSize; const float freq = 1.0f / powf(slice->ropeTheta, headDim / (float)slice->headSize); const float val = pos * freq; const float fcr = cosf(val); @@ -1117,26 +1117,26 @@ static inline float ropeLlama31Scale(const float freq, const NnRopeLlamaOpConfig return (1 - smooth) * freq / config->ropeScalingFactor + smooth * freq; } -static void ropeLlamaForward_F32_F32(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { +static void ropeLlamaForward_F32_F32(NnUint nThreads, NnUint 
threadIndex, NnUint batchSize, NnCpuOpContext *context) { const NnRopeLlamaOpConfig *config = (NnRopeLlamaOpConfig *)context->opConfig; const NnRopeSlice *slice = &config->slice; - const NnSize dim0Half = (config->isQ ? slice->qDim0 : slice->kvDim0) / 2; - const NnSize shift = config->isQ ? slice->qShift : 0; + const NnUint dim0Half = (config->isQ ? slice->qDim0 : slice->kvDim0) / 2; + const NnUint shift = config->isQ ? slice->qShift : 0; SPLIT_THREADS(s, e, dim0Half, nThreads, threadIndex); - const NnSize iStart = s * 2; - const NnSize iEnd = e * 2; + const NnUint iStart = s * 2; + const NnUint iEnd = e * 2; const bool applyScale = config->ropeScalingFactor != 1.0f; const float *cache = (float *)context->buffers[config->ropeCacheBufferIndex]; const float *positions = (float *)context->pipes[config->positionPipeIndex]; - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { float *x = (float *)context->input[batchIndex]; - const NnSize pos = (NnSize)positions[batchIndex]; + const NnUint pos = (NnUint)positions[batchIndex]; const float *posCache = &cache[pos * slice->sliceDim + shift]; - for (NnSize i = iStart; i < iEnd; i += 2) { + for (NnUint i = iStart; i < iEnd; i += 2) { const float fcr = posCache[i]; const float fci = posCache[i + 1]; const float v0 = x[i]; @@ -1157,7 +1157,7 @@ static void ropeLlamaForward_F32_F32(NnSize nThreads, NnSize threadIndex, NnSize static void initMultiHeadAttForward(NnCpuOpContext *context) { const NnMultiHeadAttOpConfig *config = (NnMultiHeadAttOpConfig *)context->opConfig; - ASSERT_EQ(context->weightSize.nBytes, 0); + assert(context->weightSize.nBytes == 0); ASSERT_EQ(context->inputSize.x, config->qSlice.d0); ASSERT_EQ(context->inputSize.y, context->nBatches); NnSize2D *querySize = &context->bufferConfigs[config->queryBufferIndex].size; @@ -1167,7 +1167,7 @@ static void initMultiHeadAttForward(NnCpuOpContext *context) { ASSERT_EQ(posSize->y, 
context->nBatches); } -static void multiHeadAttForward_F32_F32(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { +static void multiHeadAttForward_F32_F32(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { const NnMultiHeadAttOpConfig *config = (NnMultiHeadAttOpConfig *)context->opConfig; float *query = (float *)context->buffers[config->queryBufferIndex]; @@ -1176,10 +1176,10 @@ static void multiHeadAttForward_F32_F32(NnSize nThreads, NnSize threadIndex, NnS float *att = (float *)context->buffers[config->attBufferIndex]; const float *positions = (float *)context->pipes[config->positionPipeIndex]; - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { float *i = (float *)context->input[batchIndex]; float *q = &query[batchIndex * config->qSlice.d0]; - NnSize pos = (NnSize)positions[batchIndex]; + NnUint pos = (NnUint)positions[batchIndex]; assert(pos < config->seqLen); DEBUG_VECTOR(context, "input", i); @@ -1191,12 +1191,12 @@ static void multiHeadAttForward_F32_F32(NnSize nThreads, NnSize threadIndex, NnS } } -static void mulForward_F32_F32(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { - ASSERT_EQ(context->weightSize.nBytes, 0); +static void mulForward_F32_F32(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { + assert(context->weightSize.nBytes == 0); ASSERT_EQ(context->inputSize.x, context->outputSize.x); ASSERT_EQ(context->inputSize.y, context->outputSize.y); - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { float *input = (float *)context->input[batchIndex]; float *output = (float *)context->output[batchIndex]; mul_F32( @@ -1208,8 +1208,8 @@ static void mulForward_F32_F32(NnSize nThreads, NnSize threadIndex, NnSize batch } } -static void mulForward_Q80_F32(NnSize 
nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { +static void mulForward_Q80_F32(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { NnBlockQ80 *input = (NnBlockQ80 *)context->input[batchIndex]; float *output = (float *)context->output[batchIndex]; mul_Q80_F32( @@ -1226,10 +1226,10 @@ static void initCastForward(NnCpuOpContext *context) { ASSERT_EQ(context->inputSize.y, context->outputSize.y); } -static void castForward_ANY(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { - const NnSize rowBytes = context->outputSize.nBytes / context->outputSize.y; +static void castForward_ANY(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { + const NnUint rowBytes = context->outputSize.nBytes / context->outputSize.y; - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { copy_UNK( context->output[batchIndex], context->input[batchIndex], @@ -1239,11 +1239,11 @@ static void castForward_ANY(NnSize nThreads, NnSize threadIndex, NnSize batchSiz } } -static void castForward_F32_Q80(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { +static void castForward_F32_Q80(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { ASSERT_EQ(context->inputSize.floatType, F_32); ASSERT_EQ(context->outputSize.floatType, F_Q80); - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { float *input = (float *)context->input[batchIndex]; NnBlockQ80 *output = (NnBlockQ80 *)context->output[batchIndex]; quantizeF32toQ80( @@ -1255,11 +1255,11 @@ static void castForward_F32_Q80(NnSize nThreads, NnSize threadIndex, NnSize batc } } 
-static void castForward_Q80_F32(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context) { +static void castForward_Q80_F32(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context) { ASSERT_EQ(context->inputSize.floatType, F_Q80); ASSERT_EQ(context->outputSize.floatType, F_32); - for (NnSize batchIndex = 0; batchIndex < batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) { NnBlockQ80 *input = (NnBlockQ80 *)context->input[batchIndex]; float *output = (float *)context->output[batchIndex]; dequantizeQ80toF32( diff --git a/src/nn/nn-cpu-ops.hpp b/src/nn/nn-cpu-ops.hpp index c0d0f1b..bd242d0 100644 --- a/src/nn/nn-cpu-ops.hpp +++ b/src/nn/nn-cpu-ops.hpp @@ -32,12 +32,12 @@ typedef struct { } NnCpuOpContext; typedef void (*NnCpuOpForwardInit)(NnCpuOpContext *context); -typedef void (*NnCpuOpForward)(NnSize nThreads, NnSize threadIndex, NnSize batchSize, NnCpuOpContext *context); +typedef void (*NnCpuOpForward)(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context); void printCpuInstructionSet(); NnCpuOpForwardInit getCpuOpForwardInit(NnOpCode code, NnOpQuantType quantType); NnCpuOpForward getCpuOpForward(NnOpCode code, NnOpQuantType quantType); -void softmax_F32(float *x, const NnSize size); +void softmax_F32(float *x, const NnUint size); #endif \ No newline at end of file diff --git a/src/nn/nn-cpu-test.cpp b/src/nn/nn-cpu-test.cpp index 52ede4e..df335a4 100644 --- a/src/nn/nn-cpu-test.cpp +++ b/src/nn/nn-cpu-test.cpp @@ -7,12 +7,12 @@ #define N_BATCHES 2 void buildConfig(NnNetConfig *netConfig, NnNodeConfig *nodeConfig) { - NnSize nNodes = 1; + NnUint nNodes = 1; NnNetConfigBuilder netBuilder(nNodes, N_BATCHES); - NnSize xPipeIndex = netBuilder.addPipe("X", size2D(F_32, N_BATCHES, DIM)); + NnUint xPipeIndex = netBuilder.addPipe("X", size2D(F_32, N_BATCHES, DIM)); NnNodeConfigBuilder nodeBuilder(0); - NnSize invRmsBufferIndex = 
nodeBuilder.addBuffer("inv_rms", size2D(F_32, N_BATCHES, 1)); + NnUint invRmsBufferIndex = nodeBuilder.addBuffer("inv_rms", size2D(F_32, N_BATCHES, 1)); NnSegmentConfigBuilder segmentBuilder; segmentBuilder.addSync(xPipeIndex, SYNC_NODE_SLICES_EXCEPT_ROOT); @@ -34,10 +34,10 @@ void buildConfig(NnNetConfig *netConfig, NnNodeConfig *nodeConfig) { *nodeConfig = nodeBuilder.build(); } -void print2D(const char *name, NnSize x, NnSize y, float *w) { - for (NnSize i = 0; i < y; i++) { +void print2D(const char *name, NnUint x, NnUint y, float *w) { + for (NnUint i = 0; i < y; i++) { printf("%s[%d] = ", name, i); - for (NnSize j = 0; j < x; j++) + for (NnUint j = 0; j < x; j++) printf("%f ", w[i * x + j]); printf("\n"); } @@ -46,22 +46,22 @@ void print2D(const char *name, NnSize x, NnSize y, float *w) { int main() { initQuants(); - NnSize nThreads = 2; + NnUint nThreads = 2; NnNetConfig netConfig; NnNodeConfig nodeConfig; buildConfig(&netConfig, &nodeConfig); NnNetExecution execution(nThreads, &netConfig); float *x = (float *)execution.pipes[0]; - for (NnSize b = 0; b < N_BATCHES; b++) { - for (NnSize i = 0; i < DIM; i++) + for (NnUint b = 0; b < N_BATCHES; b++) { + for (NnUint i = 0; i < DIM; i++) x[b * DIM + i] = i / (float)DIM + (float)b; } print2D("x", DIM, N_BATCHES, x); float rmsNormWeight[DIM]; - for (NnSize i = 0; i < DIM; i++) + for (NnUint i = 0; i < DIM; i++) rmsNormWeight[i] = 0.5 + i / (float)DIM; NnCpuDevice device(&netConfig, &nodeConfig, &execution); diff --git a/src/nn/nn-cpu.cpp b/src/nn/nn-cpu.cpp index 768f949..b9e0305 100644 --- a/src/nn/nn-cpu.cpp +++ b/src/nn/nn-cpu.cpp @@ -16,7 +16,7 @@ #define BUFFER_ALIGNMENT 64 -static NnByte *allocAlignedBuffer(size_t size) { +static NnByte *allocAlignedBuffer(NnSize size) { NnByte *buffer; #ifdef _WIN32 buffer = (NnByte *)_aligned_malloc(size, BUFFER_ALIGNMENT); @@ -47,7 +47,7 @@ NnCpuDevice::NnCpuDevice(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNet nBuffers = nodeConfig->nBuffers; buffers = new NnByte 
*[nBuffers]; - for (NnSize bufferIndex = 0; bufferIndex < nBuffers; bufferIndex++) { + for (NnUint bufferIndex = 0; bufferIndex < nBuffers; bufferIndex++) { NnBufferConfig *config = &nodeConfig->buffers[bufferIndex]; NnByte *buffer = allocAlignedBuffer(config->size.nBytes); buffers[bufferIndex] = buffer; @@ -58,17 +58,17 @@ NnCpuDevice::NnCpuDevice(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNet } NnCpuDevice::~NnCpuDevice() { - for (NnSize bufferIndex = 0; bufferIndex < nBuffers; bufferIndex++) + for (NnUint bufferIndex = 0; bufferIndex < nBuffers; bufferIndex++) releaseAlignedBuffer(buffers[bufferIndex]); delete[] buffers; delete[] bufferFlags; } -NnSize NnCpuDevice::maxNThreads() { +NnUint NnCpuDevice::maxNThreads() { return std::thread::hardware_concurrency(); } -NnDeviceSegment *NnCpuDevice::createSegment(NnSize segmentIndex) { +NnDeviceSegment *NnCpuDevice::createSegment(NnUint segmentIndex) { NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex]; assert(segmentConfig->nOps > 0); @@ -82,7 +82,7 @@ NnDeviceSegment *NnCpuDevice::createSegment(NnSize segmentIndex) { NnByte **inputs = inputsPtr.get(); NnByte **outputs = outputsPtr.get(); - for (NnSize opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) { + for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) { NnOpConfig *opConfig = &segmentConfig->ops[opIndex]; NnSize2D inputSize; NnSize2D outputSize; @@ -92,9 +92,9 @@ NnDeviceSegment *NnCpuDevice::createSegment(NnSize segmentIndex) { inputSize.floatType, opConfig->weightSize.floatType, outputSize.floatType); - #if DEBUG_CPU_OP_QUANTS +#if DEBUG_CPU_OP_QUANTS printf("%20s %2d: %s\n", opConfig->name, opConfig->index, opQuantTypeToString(opQuant)); - #endif +#endif NnCpuOpForward forward = getCpuOpForward(opConfig->code, opQuant); if (forward == nullptr) { throw std::invalid_argument( @@ -114,7 +114,7 @@ NnDeviceSegment *NnCpuDevice::createSegment(NnSize segmentIndex) { NnCpuOpForward *opForward = new 
NnCpuOpForward[segmentConfig->nOps]; NnCpuOpContext *opContexts = new NnCpuOpContext[segmentConfig->nOps]; - for (NnSize opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) { + for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) { NnOpConfig *opConfig = &segmentConfig->ops[opIndex]; NnCpuOpContext *opContext = &opContexts[opIndex]; NnCpuOpForwardInit opInit = getCpuOpForwardInit(opConfig->code, opQuants[opIndex]); @@ -136,10 +136,12 @@ NnDeviceSegment *NnCpuDevice::createSegment(NnSize segmentIndex) { opContext->outputSize = outputSizes[opIndex]; opContext->hasOutputContinuousMemory = hasPointerContinuousMemory(&opConfig->output); +#if not(DEBUG_USE_MMAP_FOR_WEIGHTS) if (opContext->weightSize.nBytes > 0) opContext->weight = allocAlignedBuffer(opContext->weightSize.nBytes); else opContext->weight = nullptr; +#endif if (opInit != nullptr) opInit(opContext); @@ -149,14 +151,16 @@ NnDeviceSegment *NnCpuDevice::createSegment(NnSize segmentIndex) { } NnCpuDeviceSegment::~NnCpuDeviceSegment() { - for (NnSize opIndex = 0; opIndex < nOps; opIndex++) { + for (NnUint opIndex = 0; opIndex < nOps; opIndex++) { NnCpuOpContext *context = &opContexts[opIndex]; if (opIndex == 0) { delete[] context->input; delete[] context->output; } +#if not(DEBUG_USE_MMAP_FOR_WEIGHTS) if (context->weightSize.nBytes > 0) releaseAlignedBuffer(context->weight); +#endif } delete[] opForward; delete[] opContexts; @@ -178,8 +182,8 @@ void NnCpuDevice::resolvePointer(NnByte **pntr, NnSize2D *pntrSize, NnPointerCon if (pointerConfig->batchType == PNTR_BATCH_DEFAULT) { ASSERT_EQ(sourceSize->y, netConfig->nBatches); - NnSize batchBytes = getBytes(sourceSize->floatType, sourceSize->x); - for (NnSize batchIndex = 0; batchIndex < netConfig->nBatches; batchIndex++) + NnUint batchBytes = getBytes(sourceSize->floatType, sourceSize->x); + for (NnUint batchIndex = 0; batchIndex < netConfig->nBatches; batchIndex++) pntr[batchIndex] = &source[batchIndex * batchBytes]; *pntrSize = *sourceSize; @@ 
-187,9 +191,9 @@ void NnCpuDevice::resolvePointer(NnByte **pntr, NnSize2D *pntrSize, NnPointerCon return; if (pointerConfig->sliceType == SLICE_NODE_PART) { assert(sourceSize->x % netConfig->nNodes == 0); - NnSize xSlice = sourceSize->x / netConfig->nNodes; - NnSize xSliceBytes = getBytes(sourceSize->floatType, xSlice); - for (NnSize batchIndex = 0; batchIndex < netConfig->nBatches; batchIndex++) + NnUint xSlice = sourceSize->x / netConfig->nNodes; + NnUint xSliceBytes = getBytes(sourceSize->floatType, xSlice); + for (NnUint batchIndex = 0; batchIndex < netConfig->nBatches; batchIndex++) pntr[batchIndex] = &pntr[batchIndex][xSliceBytes * nodeConfig->nodeIndex]; *pntrSize = size2D(sourceSize->floatType, sourceSize->y, xSlice); return; @@ -206,30 +210,34 @@ void NnCpuDevice::resolvePointer(NnByte **pntr, NnSize2D *pntrSize, NnPointerCon } void NnCpuDevice::syncPointers() { - NnSize nDynamicPointers = dynamicPointers.size(); - for (NnSize dynamicPointerIndex = 0; dynamicPointerIndex < nDynamicPointers; dynamicPointerIndex++) { + NnUint nDynamicPointers = dynamicPointers.size(); + for (NnUint dynamicPointerIndex = 0; dynamicPointerIndex < nDynamicPointers; dynamicPointerIndex++) { NnCpuDynamicPointer *dp = &dynamicPointers[dynamicPointerIndex]; assert(dp->pointerConfig->batchType == PNTR_BATCH_PIPE); float *pipe = (float *)netExecution->pipes[dp->pointerConfig->batchArg0]; - for (NnSize batchIndex = 0; batchIndex < netExecution->batchSize; batchIndex++) { - NnSize index = (NnSize)pipe[batchIndex]; + for (NnUint batchIndex = 0; batchIndex < netExecution->batchSize; batchIndex++) { + NnUint index = (NnUint)pipe[batchIndex]; assert(index < dp->sourceSize->y); - NnSize nBytes = dp->sourceSize->nBytes / dp->sourceSize->y; + NnUint nBytes = dp->sourceSize->nBytes / dp->sourceSize->y; dp->pntr[batchIndex] = &dp->source[index * nBytes]; } } } -void NnCpuDeviceSegment::loadWeight(NnSize opIndex, NnSize nBytes, NnByte *weight) { +void NnCpuDeviceSegment::loadWeight(NnUint 
opIndex, NnSize nBytes, NnByte *weight) { assert(opIndex >= 0); assert(opIndex < nOps); NnCpuOpContext *context = &opContexts[opIndex]; - ASSERT_EQ(context->weightSize.nBytes, nBytes); + assert(context->weightSize.nBytes == nBytes); +#if DEBUG_USE_MMAP_FOR_WEIGHTS + context->weight = weight; +#else std::memcpy(context->weight, weight, nBytes); +#endif } -void NnCpuDeviceSegment::forward(NnSize opIndex, NnSize nThreads, NnSize threadIndex, NnSize batchSize) { +void NnCpuDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) { NnCpuOpContext *context = &opContexts[opIndex]; // printf("forward: %d %s (%d/%d)\n", opIndex, context->name, threadIndex + 1, nThreads); fflush(stdout); opForward[opIndex](nThreads, threadIndex, batchSize, context); diff --git a/src/nn/nn-cpu.hpp b/src/nn/nn-cpu.hpp index e0662ce..b09fc59 100644 --- a/src/nn/nn-cpu.hpp +++ b/src/nn/nn-cpu.hpp @@ -5,6 +5,8 @@ #include "nn-executor.hpp" #include "nn-cpu-ops.hpp" +#define DEBUG_USE_MMAP_FOR_WEIGHTS false + typedef struct { NnByte *source; NnSize2D *sourceSize; @@ -19,28 +21,28 @@ class NnCpuDevice : public NnDevice { NnNetConfig *netConfig; NnNodeConfig *nodeConfig; NnNetExecution *netExecution; - NnSize nBuffers; + NnUint nBuffers; NnByte *bufferFlags; std::vector dynamicPointers; public: NnCpuDevice(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution); ~NnCpuDevice(); - NnSize maxNThreads() override; - NnDeviceSegment *createSegment(NnSize segmentIndex) override; + NnUint maxNThreads() override; + NnDeviceSegment *createSegment(NnUint segmentIndex) override; void syncPointers() override; void resolvePointer(NnByte **pntr, NnSize2D *pntrSize, NnPointerConfig *pointerConfig); }; class NnCpuDeviceSegment : public NnDeviceSegment { public: - NnSize nOps; + NnUint nOps; NnCpuOpForward *opForward; NnCpuOpContext *opContexts; - NnCpuDeviceSegment(NnCpuOpForward *opForward, NnCpuOpContext *opContexts, NnSize nOps) + 
NnCpuDeviceSegment(NnCpuOpForward *opForward, NnCpuOpContext *opContexts, NnUint nOps) : opForward(opForward), opContexts(opContexts), nOps(nOps) {} ~NnCpuDeviceSegment() override; - void loadWeight(NnSize opIndex, NnSize nBytes, NnByte *weight) override; - void forward(NnSize opIndex, NnSize nThreads, NnSize threadIndex, NnSize batchSize) override; + void loadWeight(NnUint opIndex, NnSize nBytes, NnByte *weight) override; + void forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) override; }; #endif \ No newline at end of file diff --git a/src/nn/nn-executor.cpp b/src/nn/nn-executor.cpp index 8f96e41..2860852 100644 --- a/src/nn/nn-executor.cpp +++ b/src/nn/nn-executor.cpp @@ -5,18 +5,18 @@ #define DEBUG_EXECUTOR_BENCHMARK false -void NnFakeNodeSynchronizer::sync(NnSize segmentIndex, NnSize nThreads, NnSize threadIndex) { +void NnFakeNodeSynchronizer::sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) { // Nothing } -NnNetExecution::NnNetExecution(NnSize nThreads, NnNetConfig *netConfig) { +NnNetExecution::NnNetExecution(NnUint nThreads, NnNetConfig *netConfig) { this->nThreads = nThreads; this->nBatches = netConfig->nBatches; this->nPipes = netConfig->nPipes; this->batchSize = 0; // This value must be overwritten before calling forward pipes = new NnByte *[netConfig->nPipes]; - for (NnSize pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++) { + for (NnUint pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++) { NnPipeConfig *pipeConfig = &netConfig->pipes[pipeIndex]; NnByte *pipe = new NnByte[pipeConfig->size.nBytes]; std::memset(pipe, 0, pipeConfig->size.nBytes); @@ -25,12 +25,12 @@ NnNetExecution::NnNetExecution(NnSize nThreads, NnNetConfig *netConfig) { } NnNetExecution::~NnNetExecution() { - for (NnSize pipeIndex = 0; pipeIndex < nPipes; pipeIndex++) + for (NnUint pipeIndex = 0; pipeIndex < nPipes; pipeIndex++) delete[] pipes[pipeIndex]; delete[] pipes; } -void NnNetExecution::setBatchSize(NnSize batchSize) { 
+void NnNetExecution::setBatchSize(NnUint batchSize) { assert(batchSize > 0 && batchSize <= nBatches); this->batchSize = batchSize; } @@ -38,20 +38,20 @@ void NnNetExecution::setBatchSize(NnSize batchSize) { NnExecutor::NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnDevice *device, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer) : segments(nodeConfig->nSegments), steps() { - NnSize maxNThreads = device->maxNThreads(); + NnUint maxNThreads = device->maxNThreads(); if (netExecution->nThreads > maxNThreads) throw std::invalid_argument("This CPU supports max " + std::to_string(maxNThreads) + " threads"); this->netExecution = netExecution; this->nodeConfig = nodeConfig; bool useSynchronizer = netConfig->nNodes > 1; - for (NnSize segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) { + for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) { NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex]; if (segmentConfig->nOps > 0) { NnDeviceSegment *segment = device->createSegment(segmentIndex); segments[segmentIndex] = std::unique_ptr(segment); - for (NnSize opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) + for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) steps.push_back(NnExecutorStep{ STEP_EXECUTE_OP, segment, opIndex, &segmentConfig->ops[opIndex] }); } if (useSynchronizer && segmentConfig->nSyncs > 0) @@ -65,11 +65,11 @@ NnExecutor::NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnDevic context.nThreads = netExecution->nThreads; context.synchronizer = synchronizer; context.device = device; - context.nSteps = (NnSize)steps.size(); + context.nSteps = (NnUint)steps.size(); context.steps = steps.data(); threads = new NnExecutorThread[netExecution->nThreads]; - for (NnSize threadIndex = 0; threadIndex < netExecution->nThreads; threadIndex++) { + for (NnUint threadIndex = 0; threadIndex < netExecution->nThreads; threadIndex++) { NnExecutorThread 
*thread = &threads[threadIndex]; thread->threadIndex = threadIndex; thread->context = &context; @@ -80,10 +80,10 @@ NnExecutor::~NnExecutor() { delete[] threads; } -void NnExecutor::loadWeight(const char *name, NnSize index, NnSize nBytes, NnByte *weight) { - for (NnSize segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) { +void NnExecutor::loadWeight(const char *name, NnUint index, NnSize nBytes, NnByte *weight) { + for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) { NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex]; - for (NnSize opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) { + for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) { NnOpConfig *opConfig = &segmentConfig->ops[opIndex]; if (opConfig->index == index && std::strcmp(opConfig->name, name) == 0) { NnDeviceSegment *segment = segments[segmentIndex].get(); @@ -96,7 +96,7 @@ void NnExecutor::loadWeight(const char *name, NnSize index, NnSize nBytes, NnByt throw std::invalid_argument("Cannot locate op by name: " + std::string(name)); } -inline void executeStep(NnExecutorStep *step, NnSize nThreads, NnExecutorThread *thread, NnExecutorContext *context) { +inline void executeStep(NnExecutorStep *step, NnUint nThreads, NnExecutorThread *thread, NnExecutorContext *context) { #if DEBUG_EXECUTOR_BENCHMARK assert(nThreads == 1); Timer startTime; @@ -114,7 +114,7 @@ inline void executeStep(NnExecutorStep *step, NnSize nThreads, NnExecutorThread } #if DEBUG_EXECUTOR_BENCHMARK - NnSize duration = startTime.elapsedMicroseconds(); + NnUint duration = startTime.elapsedMicroseconds(); if (step->type == STEP_EXECUTE_OP) printf("🕒 [OP %16s %2d] %u μs\n", opCodeToString(step->opConfig->code), step->opConfig->index, duration); else if (step->type == STEP_SYNC_NODES) @@ -125,8 +125,8 @@ inline void executeStep(NnExecutorStep *step, NnSize nThreads, NnExecutorThread static inline void *executorThreadHandler(void *arg) { 
NnExecutorThread *thread = (NnExecutorThread *)arg; NnExecutorContext *context = thread->context; - NnSize nThreads = context->nThreads; - NnSize doneCount = nThreads - 1; + NnUint nThreads = context->nThreads; + NnUint doneCount = nThreads - 1; while (true) { const unsigned int currentStepIndex = context->currentStepIndex.load(); @@ -136,7 +136,7 @@ static inline void *executorThreadHandler(void *arg) { NnExecutorStep *step = &context->steps[currentStepIndex]; executeStep(step, nThreads, thread, context); - NnSize currentCount = context->doneThreadCount.fetch_add(1); + NnUint currentCount = context->doneThreadCount.fetch_add(1); if (currentCount == doneCount) { context->doneThreadCount.store(0); context->currentStepIndex.fetch_add(1); @@ -150,12 +150,12 @@ static inline void *executorThreadHandler(void *arg) { void NnExecutor::forward() { assert(netExecution->batchSize > 0); - NnSize nThreads = netExecution->nThreads; + NnUint nThreads = netExecution->nThreads; context.currentStepIndex.exchange(0); context.doneThreadCount.exchange(0); context.batchSize = netExecution->batchSize; - NnSize threadIndex; + NnUint threadIndex; for (threadIndex = 1; threadIndex < nThreads; threadIndex++) { int result = pthread_create(&threads[threadIndex].handler, NULL, (PthreadFunc)executorThreadHandler, (void *)&threads[threadIndex]); if (result != 0) diff --git a/src/nn/nn-executor.hpp b/src/nn/nn-executor.hpp index a8625aa..854cbb5 100644 --- a/src/nn/nn-executor.hpp +++ b/src/nn/nn-executor.hpp @@ -9,41 +9,41 @@ class NnDeviceSegment { public: virtual ~NnDeviceSegment() {}; - virtual void loadWeight(NnSize opIndex, NnSize nBytes, NnByte *weight) = 0; - virtual void forward(NnSize opIndex, NnSize nThreads, NnSize threadIndex, NnSize batchSize) = 0; + virtual void loadWeight(NnUint opIndex, NnSize nBytes, NnByte *weight) = 0; + virtual void forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) = 0; }; class NnDevice { public: - virtual NnSize maxNThreads() = 
0; - virtual NnDeviceSegment *createSegment(NnSize segmentIndex) = 0; + virtual NnUint maxNThreads() = 0; + virtual NnDeviceSegment *createSegment(NnUint segmentIndex) = 0; virtual void syncPointers() = 0; }; class NnNodeSynchronizer { public: virtual ~NnNodeSynchronizer() {}; - virtual void sync(NnSize segmentIndex, NnSize nThreads, NnSize threadIndex) = 0; + virtual void sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) = 0; }; class NnFakeNodeSynchronizer : public NnNodeSynchronizer { public: ~NnFakeNodeSynchronizer() override {}; - void sync(NnSize segmentIndex, NnSize nThreads, NnSize threadIndex) override; + void sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) override; }; class NnNetExecution { public: - NnSize nThreads; + NnUint nThreads; NnByte **pipes; - NnSize batchSize; + NnUint batchSize; private: - NnSize nBatches; - NnSize nPipes; + NnUint nBatches; + NnUint nPipes; public: - NnNetExecution(NnSize nThreads, NnNetConfig *netConfig); + NnNetExecution(NnUint nThreads, NnNetConfig *netConfig); ~NnNetExecution(); - void setBatchSize(NnSize batchSize); + void setBatchSize(NnUint batchSize); }; enum NnExecutorStepType { @@ -55,23 +55,23 @@ enum NnExecutorStepType { typedef struct { NnExecutorStepType type; NnDeviceSegment *segment; - NnSize arg0; + NnUint arg0; NnOpConfig *opConfig; } NnExecutorStep; typedef struct { - NnSize nThreads; - NnSize nSteps; + NnUint nThreads; + NnUint nSteps; NnExecutorStep *steps; NnNodeSynchronizer *synchronizer; NnDevice *device; std::atomic_uint currentStepIndex; std::atomic_uint doneThreadCount; - NnSize batchSize; + NnUint batchSize; } NnExecutorContext; typedef struct { - NnSize threadIndex; + NnUint threadIndex; NnExecutorContext *context; PthreadHandler handler; } NnExecutorThread; @@ -86,7 +86,7 @@ class NnExecutor { NnExecutorContext context; NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnDevice *device, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer); 
~NnExecutor(); - void loadWeight(const char *name, NnSize index, NnSize nBytes, NnByte *weight); + void loadWeight(const char *name, NnUint index, NnSize nBytes, NnByte *weight); void forward(); }; diff --git a/src/nn/nn-network.cpp b/src/nn/nn-network.cpp index 113a78c..250aa51 100644 --- a/src/nn/nn-network.cpp +++ b/src/nn/nn-network.cpp @@ -20,7 +20,7 @@ typedef SSIZE_T ssize_t; #define SOCKET_LAST_ERRCODE errno #define SOCKET_LAST_ERROR strerror(errno) -#define ACK 23571113 +#define ACK 23571114 #define MAX_CHUNK_SIZE 4096 static inline bool isEagainError() { @@ -81,7 +81,7 @@ void setReuseAddr(int socket) { #endif } -void writeSocket(int socket, const void *data, size_t size) { +void writeSocket(int socket, const void *data, NnSize size) { while (size > 0) { ssize_t s = send(socket, (const char*)data, size, 0); if (s < 0) { @@ -97,9 +97,9 @@ void writeSocket(int socket, const void *data, size_t size) { } } -static inline bool tryReadSocket(int socket, void *data, size_t size, unsigned long maxAttempts) { +static inline bool tryReadSocket(int socket, void *data, NnSize size, unsigned long maxAttempts) { // maxAttempts = 0 means infinite attempts - size_t s = size; + NnSize s = size; while (s > 0) { ssize_t r = recv(socket, (char*)data, s, 0); if (r < 0) { @@ -122,21 +122,21 @@ static inline bool tryReadSocket(int socket, void *data, size_t size, unsigned l return true; } -void readSocket(int socket, void *data, size_t size) { +void readSocket(int socket, void *data, NnSize size) { if (!tryReadSocket(socket, data, size, 0)) { throw std::runtime_error("Error reading from socket"); } } static void readAckPacket(int socket) { - NnSize packet; + NnUint packet; readSocket(socket, &packet, sizeof(packet)); if (packet != ACK) throw std::runtime_error("Invalid ack packet"); } static void writeAckPacket(int socket) { - NnSize packet = ACK; + NnUint packet = ACK; writeSocket(socket, &packet, sizeof(packet)); } @@ -254,13 +254,13 @@ 
NnWriteNetworkException::NnWriteNetworkException(int code, const char *message) std::unique_ptr NnNetwork::serve(int port) { int serverSocket = createServerSocket(port); - NnSize nSockets; - NnSize nodeIndex; + NnUint nSockets; + NnUint nodeIndex; int rootSocket = acceptSocket(serverSocket); printf("⭕ The root node has connected\n"); readSocket(rootSocket, &nSockets, sizeof(nSockets)); - NnSize nNodes = nSockets - 1; // nSockets - 1 root node + NnUint nNodes = nSockets - 1; // nSockets - 1 root node printf("⭕ nNodes: %d\n", nNodes); readSocket(rootSocket, &nodeIndex, sizeof(nodeIndex)); printf("⭕ NodeIndex: %d\n", nodeIndex); @@ -271,8 +271,8 @@ std::unique_ptr NnNetwork::serve(int port) { int *ports = new int[nNodes]; printf("⭕ Socket[0]: accepted root node\n"); - size_t hostLen; - for (NnSize i = 0; i < nNodes; i++) { + NnUint hostLen; + for (NnUint i = 0; i < nNodes; i++) { readSocket(rootSocket, &hostLen, sizeof(hostLen)); hosts[i] = new char[hostLen]; readSocket(rootSocket, hosts[i], hostLen); @@ -284,8 +284,8 @@ std::unique_ptr NnNetwork::serve(int port) { // We need to wait here until the root node will send a "root is ready" packet readAckPacket(rootSocket); - for (NnSize i = 0; i < nNodes; i++) { - NnSize socketIndex = i + 1; + for (NnUint i = 0; i < nNodes; i++) { + NnUint socketIndex = i + 1; if (i >= nodeIndex) { printf("⭕ Socket[%d]: connecting to %s:%d worker\n", socketIndex, hosts[i], ports[i]); sockets[socketIndex] = connectSocket(hosts[i], ports[i]); @@ -297,7 +297,7 @@ std::unique_ptr NnNetwork::serve(int port) { } } - for (NnSize i = 0; i < nNodes; i++) + for (NnUint i = 0; i < nNodes; i++) delete[] hosts[i]; delete[] hosts; delete[] ports; @@ -308,22 +308,21 @@ std::unique_ptr NnNetwork::serve(int port) { return std::unique_ptr(new NnNetwork(nSockets, sockets)); } -std::unique_ptr NnNetwork::connect(NnSize nSockets, char **hosts, NnSize *ports) { +std::unique_ptr NnNetwork::connect(NnUint nSockets, char **hosts, NnUint *ports) { assert(nSockets 
> 0); int *sockets = new int[nSockets]; struct sockaddr_in addr; - NnSize confirmPacket; - for (NnSize i = 0; i < nSockets; i++) { + for (NnUint i = 0; i < nSockets; i++) { printf("⭕ Socket[%d]: connecting to %s:%d worker\n", i, hosts[i], ports[i]); int socket = connectSocket(hosts[i], ports[i]); sockets[i] = socket; writeSocket(socket, &nSockets, sizeof(nSockets)); writeSocket(socket, &i, sizeof(i)); // send node index - for (NnSize j = 0; j < nSockets; j++) { + for (NnUint j = 0; j < nSockets; j++) { if (j == i) continue; - size_t hostLen = strlen(hosts[j]) + 1; + NnUint hostLen = strlen(hosts[j]) + 1; writeSocket(socket, &hostLen, sizeof(hostLen)); writeSocket(socket, hosts[j], hostLen); writeSocket(socket, &ports[j], sizeof(ports[j])); @@ -331,22 +330,24 @@ std::unique_ptr NnNetwork::connect(NnSize nSockets, char **hosts, NnS readAckPacket(socket); printf("⭕ Socket[%d]: connected\n", i); } - for (NnSize i = 0; i < nSockets; i++) { + for (NnUint i = 0; i < nSockets; i++) { writeAckPacket(sockets[i]); } printf("⭕ Network is initialized\n"); return std::unique_ptr(new NnNetwork(nSockets, sockets)); } -NnNetwork::NnNetwork(NnSize nSockets, int *sockets) - : sentBytes(0), recvBytes(0) -{ +NnNetwork::NnNetwork(NnUint nSockets, int *sockets) { this->nSockets = nSockets; this->sockets = sockets; + this->sentBytes = new NnSize[nSockets]; + this->recvBytes = new NnSize[nSockets]; } NnNetwork::~NnNetwork() { - for (NnSize i = 0; i < nSockets; i++) { + delete[] sentBytes; + delete[] recvBytes; + for (NnUint i = 0; i < nSockets; i++) { shutdown(sockets[i], 2); close(sockets[i]); } @@ -355,67 +356,67 @@ NnNetwork::~NnNetwork() { } void NnNetwork::setTurbo(bool enabled) { - for (NnSize i = 0; i < nSockets; i++) { + for (NnUint i = 0; i < nSockets; i++) { ::setNonBlocking(sockets[i], enabled); } } -void NnNetwork::write(NnSize socketIndex, const void *data, size_t size) { - assert(socketIndex >= 0 && socketIndex < nSockets); - sentBytes.fetch_add(size); +void 
NnNetwork::write(const NnUint socketIndex, const void *data, const NnSize size) { + assert(socketIndex < nSockets); - char *current = (char*)data; + NnByte *current = (NnByte *)data; int s = sockets[socketIndex]; - for (size_t chunk = 0; chunk < size; chunk += MAX_CHUNK_SIZE) { - size_t chunkSize = chunk + MAX_CHUNK_SIZE < size ? MAX_CHUNK_SIZE : size - chunk; + for (NnSize chunk = 0; chunk < size; chunk += MAX_CHUNK_SIZE) { + NnSize chunkSize = chunk + MAX_CHUNK_SIZE < size ? MAX_CHUNK_SIZE : size - chunk; writeSocket(s, current, chunkSize); current += chunkSize; } + sentBytes[socketIndex] += size; } -void NnNetwork::read(NnSize socketIndex, void *data, size_t size) { - assert(socketIndex >= 0 && socketIndex < nSockets); - recvBytes.fetch_add(size); +void NnNetwork::read(const NnUint socketIndex, void *data, const NnSize size) { + assert(socketIndex < nSockets); - char *current = (char*)data; + NnByte *current = (NnByte *)data; int s = sockets[socketIndex]; - for (size_t chunk = 0; chunk < size; chunk += MAX_CHUNK_SIZE) { - size_t chunkSize = chunk + MAX_CHUNK_SIZE < size ? MAX_CHUNK_SIZE : size - chunk; + for (NnSize chunk = 0; chunk < size; chunk += MAX_CHUNK_SIZE) { + NnSize chunkSize = chunk + MAX_CHUNK_SIZE < size ? 
MAX_CHUNK_SIZE : size - chunk; readSocket(s, current, chunkSize); current += chunkSize; } + recvBytes[socketIndex] += size; } -void NnNetwork::writeAck(NnSize socketIndex) { +void NnNetwork::writeAck(const NnUint socketIndex) { assert(socketIndex >= 0 && socketIndex < nSockets); writeAckPacket(sockets[socketIndex]); } -void NnNetwork::readAck(NnSize socketIndex) { +void NnNetwork::readAck(const NnUint socketIndex) { assert(socketIndex >= 0 && socketIndex < nSockets); readAckPacket(sockets[socketIndex]); } -bool NnNetwork::tryReadWithMaxAttempts(NnSize socketIndex, void *data, size_t size, unsigned long maxAttempts) { +bool NnNetwork::tryReadWithMaxAttempts(NnUint socketIndex, void *data, NnSize size, unsigned long maxAttempts) { assert(socketIndex >= 0 && socketIndex < nSockets); if (tryReadSocket(sockets[socketIndex], data, size, maxAttempts)) { - recvBytes.fetch_add(size); + recvBytes[socketIndex] += size; return true; } return false; } -void NnNetwork::writeMany(NnSize n, NnSocketIo *ios) { +void NnNetwork::writeMany(NnUint n, NnSocketIo *ios) { bool isWriting; - size_t nBytes = 0; - for (NnSize i = 0; i < n; i++) { + NnSize nBytes = 0; + for (NnUint i = 0; i < n; i++) { NnSocketIo *io = &ios[i]; - assert(io->socketIndex >= 0 && io->socketIndex < nSockets); - nBytes += io->size; + assert(io->socketIndex < nSockets); + sentBytes[io->socketIndex] += io->size; } do { isWriting = false; - for (NnSize i = 0; i < n; i++) { + for (NnUint i = 0; i < n; i++) { NnSocketIo *io = &ios[i]; if (io->size > 0) { isWriting = true; @@ -435,12 +436,11 @@ void NnNetwork::writeMany(NnSize n, NnSocketIo *ios) { } } } while (isWriting); - sentBytes.fetch_add(nBytes); } -void NnNetwork::writeAll(void *data, size_t size) { +void NnNetwork::writeAll(void *data, NnSize size) { std::vector ios(nSockets); - for (NnSize i = 0; i < nSockets; i++) { + for (NnUint i = 0; i < nSockets; i++) { NnSocketIo *io = &ios[i]; io->socketIndex = i; io->data = data; @@ -449,17 +449,17 @@ void 
NnNetwork::writeAll(void *data, size_t size) { writeMany(nSockets, &ios[0]); } -void NnNetwork::readMany(NnSize n, NnSocketIo *ios) { +void NnNetwork::readMany(NnUint n, NnSocketIo *ios) { bool isReading; - size_t nBytes = 0; - for (NnSize i = 0; i < n; i++) { + NnSize nBytes = 0; + for (NnUint i = 0; i < n; i++) { NnSocketIo *io = &ios[i]; - assert(io->socketIndex >= 0 && io->socketIndex < nSockets); - nBytes += io->size; + assert(io->socketIndex < nSockets); + recvBytes[io->socketIndex] += io->size; } do { isReading = false; - for (NnSize i = 0; i < n; i++) { + for (NnUint i = 0; i < n; i++) { NnSocketIo *io = &ios[i]; if (io->size > 0) { isReading = true; @@ -478,29 +478,34 @@ void NnNetwork::readMany(NnSize n, NnSocketIo *ios) { } } } while (isReading); - recvBytes.fetch_add(nBytes); } -void NnNetwork::getStats(size_t *sentBytes, size_t *recvBytes) { - *sentBytes = this->sentBytes.load(); - *recvBytes = this->recvBytes.load(); +void NnNetwork::getStats(NnSize *sentBytes, NnSize *recvBytes) { + *sentBytes = 0; + *recvBytes = 0; + for (NnUint i = 0; i < nSockets; i++) { + *sentBytes += this->sentBytes[i]; + *recvBytes += this->recvBytes[i]; + } resetStats(); } void NnNetwork::resetStats() { - sentBytes.exchange(0); - recvBytes.exchange(0); + for (NnUint i = 0; i < nSockets; i++) { + sentBytes[i] = 0; + recvBytes[i] = 0; + } } -static void syncWithRoot(NnNetwork *network, NnByte nodeIndex, NnByte *buffer, NnSize nBytes, NnSize nThreads, NnSize threadIndex) { +static void syncWithRoot(NnNetwork *network, NnByte nodeIndex, NnByte *buffer, NnSize nBytes, NnUint nThreads, NnUint threadIndex) { if (nodeIndex == 0) { // root - unsigned int nSocketsPerThread = network->nSockets / nThreads + (network->nSockets % nThreads > threadIndex ? 1 : 0); + NnUint nSocketsPerThread = network->nSockets / nThreads + (network->nSockets % nThreads > threadIndex ? 
1 : 0); if (nSocketsPerThread == 0) return; std::vector<NnSocketIo> ios(nSocketsPerThread); - for (int i = 0; i < nSocketsPerThread; i++) { + for (NnUint i = 0; i < nSocketsPerThread; i++) { ios[i].socketIndex = threadIndex + i * nThreads; ios[i].data = buffer; ios[i].size = nBytes; @@ -519,10 +524,10 @@ static void syncWithRoot(NnNetwork *network, NnByte nodeIndex, NnByte *buffer, N } } -static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnSize nodeIndex, NnSize nNodes, NnByte *buffer, NnSize nBytes, NnSize nThreads, NnSize threadIndex) { +static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnUint nodeIndex, NnUint nNodes, NnByte *buffer, NnSize nBytes, NnUint nThreads, NnUint threadIndex) { bool isWorker = nodeIndex != 0; - NnSize nSockets = onlyFromWorkerToRoot && isWorker ? 1 : network->nSockets; - NnSize nSocketsPerThread = nSockets / nThreads + (nSockets % nThreads > threadIndex ? 1 : 0); + NnUint nSockets = onlyFromWorkerToRoot && isWorker ? 1 : network->nSockets; + NnUint nSocketsPerThread = nSockets / nThreads + (nSockets % nThreads > threadIndex ? 1 : 0); if (nSocketsPerThread == 0) return; NnSize sliceBytes = nBytes / nNodes; @@ -531,8 +536,8 @@ static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnSize if (!onlyFromWorkerToRoot || isWorker) { NnByte *mySliceData = &buffer[sliceBytes * nodeIndex]; - for (unsigned int i = 0; i < nSocketsPerThread; i++) { - unsigned int socketIndex = threadIndex + i * nThreads; + for (NnUint i = 0; i < nSocketsPerThread; i++) { + NnUint socketIndex = threadIndex + i * nThreads; ios[i].socketIndex = socketIndex; ios[i].data = mySliceData; ios[i].size = sliceBytes; @@ -541,9 +546,9 @@ static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnSize if (!onlyFromWorkerToRoot || !isWorker) { - for (unsigned int i = 0; i < nSocketsPerThread; i++) { - unsigned int socketIndex = threadIndex + i * nThreads; - int sliceIndex = socketIndex >= nodeIndex ?
socketIndex + 1 : socketIndex; + for (NnUint i = 0; i < nSocketsPerThread; i++) { + NnUint socketIndex = threadIndex + i * nThreads; + NnUint sliceIndex = socketIndex >= nodeIndex ? socketIndex + 1 : socketIndex; NnByte *sliceData = &buffer[sliceBytes * sliceIndex]; ios[i].socketIndex = socketIndex; ios[i].data = sliceData; @@ -560,16 +565,16 @@ NnNetworkNodeSynchronizer::NnNetworkNodeSynchronizer(NnNetwork *network, NnNetEx this->nodeConfig = nodeConfig; } -void NnNetworkNodeSynchronizer::sync(NnSize segmentIndex, NnSize nThreads, NnSize threadIndex) { +void NnNetworkNodeSynchronizer::sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) { NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex]; - for (NnSize syncIndex = 0; syncIndex < segmentConfig->nSyncs; syncIndex++) { + for (NnUint syncIndex = 0; syncIndex < segmentConfig->nSyncs; syncIndex++) { NnSyncConfig *syncConfig = &segmentConfig->syncs[syncIndex]; NnByte *pipe = execution->pipes[syncConfig->pipeIndex]; NnPipeConfig *pipeConfig = &netConfig->pipes[syncConfig->pipeIndex]; NnSize batchBytes = getBytes(pipeConfig->size.floatType, pipeConfig->size.x); - for (NnSize batchIndex = 0; batchIndex < execution->batchSize; batchIndex++) { + for (NnUint batchIndex = 0; batchIndex < execution->batchSize; batchIndex++) { NnByte *pipeBatch = &pipe[batchIndex * batchBytes]; if (syncConfig->syncType == SYNC_WITH_ROOT) { @@ -585,15 +590,15 @@ void NnNetworkNodeSynchronizer::sync(NnSize segmentIndex, NnSize nThreads, NnSiz } } -static void writeString(NnNetwork *network, NnSize socketIndex, char *str) { - NnSize bytes = std::strlen(str) + 1; - network->write(socketIndex, &bytes, sizeof(NnSize)); +static void writeString(NnNetwork *network, NnUint socketIndex, char *str) { + NnUint bytes = std::strlen(str) + 1; + network->write(socketIndex, &bytes, sizeof(NnUint)); network->write(socketIndex, str, bytes); } -static char *readString(NnNetwork *network, NnSize socketIndex) { - NnSize bytes; - 
network->read(socketIndex, &bytes, sizeof(NnSize)); +static char *readString(NnNetwork *network, NnUint socketIndex) { + NnUint bytes; + network->read(socketIndex, &bytes, sizeof(NnUint)); char *str = new char[bytes]; network->read(socketIndex, str, bytes); return str; @@ -603,12 +608,12 @@ NnRootConfigWriter::NnRootConfigWriter(NnNetwork *network) { this->network = network; } -void NnRootConfigWriter::writeNet(NnSize socketIndex, NnNetConfig *config) { +void NnRootConfigWriter::writeNet(NnUint socketIndex, NnNetConfig *config) { network->writeAck(socketIndex); network->write(socketIndex, &config->nBatches, sizeof(config->nBatches)); network->write(socketIndex, &config->nNodes, sizeof(config->nNodes)); network->write(socketIndex, &config->nPipes, sizeof(config->nPipes)); - for (NnSize pipeIndex = 0; pipeIndex < config->nPipes; pipeIndex++) { + for (NnUint pipeIndex = 0; pipeIndex < config->nPipes; pipeIndex++) { NnPipeConfig *pipeConfig = &config->pipes[pipeIndex]; network->write(socketIndex, &pipeConfig->size, sizeof(pipeConfig->size)); writeString(network, socketIndex, pipeConfig->name); @@ -616,30 +621,30 @@ void NnRootConfigWriter::writeNet(NnSize socketIndex, NnNetConfig *config) { network->readAck(socketIndex); } -void NnRootConfigWriter::writeNode(NnSize socketIndex, NnNodeConfig *config) { +void NnRootConfigWriter::writeNode(NnUint socketIndex, NnNodeConfig *config) { network->writeAck(socketIndex); network->write(socketIndex, &config->nodeIndex, sizeof(config->nodeIndex)); network->write(socketIndex, &config->nBuffers, sizeof(config->nBuffers)); network->write(socketIndex, &config->nSegments, sizeof(config->nSegments)); - for (NnSize bufferIndex = 0; bufferIndex < config->nBuffers; bufferIndex++) { + for (NnUint bufferIndex = 0; bufferIndex < config->nBuffers; bufferIndex++) { NnBufferConfig *bufferConfig = &config->buffers[bufferIndex]; network->write(socketIndex, &bufferConfig->size, sizeof(bufferConfig->size)); writeString(network, socketIndex, 
bufferConfig->name); } - for (NnSize segmentIndex = 0; segmentIndex < config->nSegments; segmentIndex++) { + for (NnUint segmentIndex = 0; segmentIndex < config->nSegments; segmentIndex++) { NnSegmentConfig *segmentConfig = &config->segments[segmentIndex]; network->write(socketIndex, &segmentConfig->nSyncs, sizeof(segmentConfig->nSyncs)); network->write(socketIndex, &segmentConfig->nOps, sizeof(segmentConfig->nOps)); network->write(socketIndex, &segmentConfig->syncPointers, sizeof(segmentConfig->syncPointers)); - for (NnSize syncIndex = 0; syncIndex < segmentConfig->nSyncs; syncIndex++) { + for (NnUint syncIndex = 0; syncIndex < segmentConfig->nSyncs; syncIndex++) { NnSyncConfig *syncConfig = &segmentConfig->syncs[syncIndex]; network->write(socketIndex, &syncConfig->pipeIndex, sizeof(syncConfig->pipeIndex)); network->write(socketIndex, &syncConfig->syncType, sizeof(syncConfig->syncType)); } - for (NnSize opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) { + for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) { NnOpConfig *opConfig = &segmentConfig->ops[opIndex]; network->write(socketIndex, &opConfig->code, sizeof(opConfig->code)); network->write(socketIndex, &opConfig->index, sizeof(opConfig->index)); @@ -656,8 +661,8 @@ void NnRootConfigWriter::writeNode(NnSize socketIndex, NnNodeConfig *config) { } void NnRootConfigWriter::writeToWorkers(NnNetConfig *netConfig, NnNodeConfig *nodeConfigs) { - for (NnSize nodeIndex = 1; nodeIndex < netConfig->nNodes; nodeIndex++) { - NnSize socketIndex = nodeIndex - 1; + for (NnUint nodeIndex = 1; nodeIndex < netConfig->nNodes; nodeIndex++) { + NnUint socketIndex = nodeIndex - 1; writeNet(socketIndex, netConfig); writeNode(socketIndex, &nodeConfigs[nodeIndex]); } @@ -674,7 +679,7 @@ NnNetConfig NnWorkerConfigReader::readNet() { network->read(ROOT_SOCKET_INDEX, &config.nNodes, sizeof(config.nNodes)); network->read(ROOT_SOCKET_INDEX, &config.nPipes, sizeof(config.nPipes)); config.pipes = new 
NnPipeConfig[config.nPipes]; - for (NnSize pipeIndex = 0; pipeIndex < config.nPipes; pipeIndex++) { + for (NnUint pipeIndex = 0; pipeIndex < config.nPipes; pipeIndex++) { NnPipeConfig *pipeConfig = &config.pipes[pipeIndex]; network->read(ROOT_SOCKET_INDEX, &pipeConfig->size, sizeof(pipeConfig->size)); pipeConfig->name = readString(network, ROOT_SOCKET_INDEX); @@ -694,13 +699,13 @@ NnNodeConfig NnWorkerConfigReader::readNode() { config.buffers = new NnBufferConfig[config.nBuffers]; config.segments = new NnSegmentConfig[config.nSegments]; - for (NnSize bufferIndex = 0; bufferIndex < config.nBuffers; bufferIndex++) { + for (NnUint bufferIndex = 0; bufferIndex < config.nBuffers; bufferIndex++) { NnBufferConfig *bufferConfig = &config.buffers[bufferIndex]; network->read(ROOT_SOCKET_INDEX, &bufferConfig->size, sizeof(bufferConfig->size)); bufferConfig->name = readString(network, ROOT_SOCKET_INDEX); } - for (NnSize segmentIndex = 0; segmentIndex < config.nSegments; segmentIndex++) { + for (NnUint segmentIndex = 0; segmentIndex < config.nSegments; segmentIndex++) { NnSegmentConfig *segmentConfig = &config.segments[segmentIndex]; network->read(ROOT_SOCKET_INDEX, &segmentConfig->nSyncs, sizeof(segmentConfig->nSyncs)); network->read(ROOT_SOCKET_INDEX, &segmentConfig->nOps, sizeof(segmentConfig->nOps)); @@ -709,7 +714,7 @@ NnNodeConfig NnWorkerConfigReader::readNode() { if (segmentConfig->nSyncs > 0) { segmentConfig->syncs = new NnSyncConfig[segmentConfig->nSyncs]; - for (NnSize syncIndex = 0; syncIndex < segmentConfig->nSyncs; syncIndex++) { + for (NnUint syncIndex = 0; syncIndex < segmentConfig->nSyncs; syncIndex++) { NnSyncConfig *syncConfig = &segmentConfig->syncs[syncIndex]; network->read(ROOT_SOCKET_INDEX, &syncConfig->pipeIndex, sizeof(syncConfig->pipeIndex)); network->read(ROOT_SOCKET_INDEX, &syncConfig->syncType, sizeof(syncConfig->syncType)); @@ -719,7 +724,7 @@ NnNodeConfig NnWorkerConfigReader::readNode() { if (segmentConfig->nOps > 0) { segmentConfig->ops = new 
NnOpConfig[segmentConfig->nOps]; - for (NnSize opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) { + for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) { NnOpConfig *opConfig = &segmentConfig->ops[opIndex]; network->read(ROOT_SOCKET_INDEX, &opConfig->code, sizeof(opConfig->code)); network->read(ROOT_SOCKET_INDEX, &opConfig->index, sizeof(opConfig->index)); @@ -739,7 +744,7 @@ NnNodeConfig NnWorkerConfigReader::readNode() { return config; } -NnRootWeightLoader::NnRootWeightLoader(NnExecutor *executor, NnNetwork *network, NnSize nNodes) { +NnRootWeightLoader::NnRootWeightLoader(NnExecutor *executor, NnNetwork *network, NnUint nNodes) { this->executor = executor; this->network = network; this->nNodes = nNodes; @@ -752,8 +757,8 @@ NnRootWeightLoader::~NnRootWeightLoader() { } void NnRootWeightLoader::finish() { - NnSize zeroSize = 0; - for (NnSize socketIndex = 0; socketIndex < nNodes - 1; socketIndex++) { + NnUint zeroSize = 0; + for (NnUint socketIndex = 0; socketIndex < nNodes - 1; socketIndex++) { network->write(socketIndex, &zeroSize, sizeof(zeroSize)); network->readAck(socketIndex); } @@ -772,9 +777,9 @@ void NnRootWeightLoader::allocate(NnSize size) { } } -void NnRootWeightLoader::writeWeight(NnSize nodeIndex, const char *opName, NnSize opIndex, NnSize nBytes, NnByte *weight) { - NnSize nameSize = std::strlen(opName) + 1; - NnSize socketIndex = nodeIndex - 1; +void NnRootWeightLoader::writeWeight(NnUint nodeIndex, const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight) { + NnUint nameSize = std::strlen(opName) + 1; + NnUint socketIndex = nodeIndex - 1; network->write(socketIndex, &nameSize, sizeof(nameSize)); network->write(socketIndex, opName, nameSize); network->write(socketIndex, &opIndex, sizeof(opIndex)); @@ -782,41 +787,49 @@ void NnRootWeightLoader::writeWeight(NnSize nodeIndex, const char *opName, NnSiz network->write(socketIndex, weight, nBytes); } -NnSize NnRootWeightLoader::loadRoot(const char *opName, NnSize opIndex, 
NnSize nBytes, NnByte *weight) { +NnSize NnRootWeightLoader::loadRoot(const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight) { executor->loadWeight(opName, opIndex, nBytes, weight); return nBytes; } -NnSize NnRootWeightLoader::loadAll(const char *opName, NnSize opIndex, NnSize nBytes, NnByte *weight) { - for (NnSize nodeIndex = 0; nodeIndex < nNodes; nodeIndex++) { - if (nodeIndex == 0) - executor->loadWeight(opName, opIndex, nBytes, weight); - else +NnSize NnRootWeightLoader::loadAll(const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight) { + executor->loadWeight(opName, opIndex, nBytes, weight); + + if (nNodes > 1) { + for (NnUint nodeIndex = 1; nodeIndex < nNodes; nodeIndex++) writeWeight(nodeIndex, opName, opIndex, nBytes, weight); } return nBytes; } -NnSize NnRootWeightLoader::loadRowMatmulSlices(const char *opName, NnSize opIndex, NnRowMatmulSlice *slice, NnByte *weight) { - allocate(slice->sliceSize.nBytes); - for (NnSize nodeIndex = 0; nodeIndex < nNodes; nodeIndex++) { - splitRowMatmulWeight(slice, nodeIndex, weight, temp); - if (nodeIndex == 0) - executor->loadWeight(opName, opIndex, slice->sliceSize.nBytes, temp); - else - writeWeight(nodeIndex, opName, opIndex, slice->sliceSize.nBytes, temp); +NnSize NnRootWeightLoader::loadRowMatmulSlices(const char *opName, NnUint opIndex, NnRowMatmulSlice *slice, NnByte *weight) { + if (nNodes == 1) { + executor->loadWeight(opName, opIndex, slice->sliceSize.nBytes, weight); + } else { + allocate(slice->sliceSize.nBytes); + for (NnUint nodeIndex = 0; nodeIndex < nNodes; nodeIndex++) { + splitRowMatmulWeight(slice, nodeIndex, weight, temp); + if (nodeIndex == 0) + executor->loadWeight(opName, opIndex, slice->sliceSize.nBytes, temp); + else + writeWeight(nodeIndex, opName, opIndex, slice->sliceSize.nBytes, temp); + } } return slice->size.nBytes; } -NnSize NnRootWeightLoader::loadColMatmulSlices(const char *opName, NnSize opIndex, NnColMatmulSlice *slice, NnByte *weight) { - 
allocate(slice->sliceSize.nBytes); - for (NnSize nodeIndex = 0; nodeIndex < nNodes; nodeIndex++) { - splitColMatmulWeight(slice, nodeIndex, weight, temp); - if (nodeIndex == 0) - executor->loadWeight(opName, opIndex, slice->sliceSize.nBytes, temp); - else - writeWeight(nodeIndex, opName, opIndex, slice->sliceSize.nBytes, temp); +NnSize NnRootWeightLoader::loadColMatmulSlices(const char *opName, NnUint opIndex, NnColMatmulSlice *slice, NnByte *weight) { + if (nNodes == 1) { + executor->loadWeight(opName, opIndex, slice->sliceSize.nBytes, weight); + } else { + allocate(slice->sliceSize.nBytes); + for (NnUint nodeIndex = 0; nodeIndex < nNodes; nodeIndex++) { + splitColMatmulWeight(slice, nodeIndex, weight, temp); + if (nodeIndex == 0) + executor->loadWeight(opName, opIndex, slice->sliceSize.nBytes, temp); + else + writeWeight(nodeIndex, opName, opIndex, slice->sliceSize.nBytes, temp); + } } return slice->size.nBytes; } @@ -832,7 +845,7 @@ NnWorkerWeightReader::~NnWorkerWeightReader() { delete[] temp; } -void NnWorkerWeightReader::allocate(NnSize size) { +void NnWorkerWeightReader::allocate(NnUint size) { if (tempSize < size) { if (tempSize > 0) delete[] temp; @@ -842,8 +855,8 @@ } void NnWorkerWeightReader::read() { - NnSize nameSize; - NnSize opIndex; + NnUint nameSize; + NnUint opIndex; NnSize nBytes; while (true) { network->read(0, &nameSize, sizeof(nameSize)); @@ -855,15 +868,15 @@ } break; } - char *opName = new char[nameSize]; + std::unique_ptr<char[]> opNamePtr(new char[nameSize]); + char *opName = opNamePtr.get(); network->read(ROOT_SOCKET_INDEX, opName, nameSize); network->read(ROOT_SOCKET_INDEX, &opIndex, sizeof(opIndex)); network->read(ROOT_SOCKET_INDEX, &nBytes, sizeof(nBytes)); allocate(nBytes); network->read(0, temp, nBytes); executor->loadWeight(opName, opIndex, nBytes, temp); - printf("💿 Loaded %22s %3d, %12d kB\n", opName, opIndex, nBytes / 1024); - delete[] opName; + 
printf("💿 Loaded %22s %3d, %12zu kB\n", opName, opIndex, nBytes / 1024); } printf("💿 Weights loaded\n"); } diff --git a/src/nn/nn-network.hpp b/src/nn/nn-network.hpp index 8e24159..c490aa0 100644 --- a/src/nn/nn-network.hpp +++ b/src/nn/nn-network.hpp @@ -9,8 +9,8 @@ void initSockets(); void cleanupSockets(); int acceptSocket(int serverSocket); void setReuseAddr(int socket); -void writeSocket(int socket, const void* data, size_t size); -void readSocket(int socket, void* data, size_t size); +void writeSocket(int socket, const void* data, NnSize size); +void readSocket(int socket, void* data, NnSize size); int createServerSocket(int port); void closeServerSocket(int serverSocket); @@ -29,36 +29,36 @@ class NnWriteNetworkException : public std::exception { }; struct NnSocketIo { - NnSize socketIndex; + NnUint socketIndex; const void *data; - size_t size; + NnSize size; }; class NnNetwork { private: int *sockets; - std::atomic_uint sentBytes; - std::atomic_uint recvBytes; + NnSize *sentBytes; + NnSize *recvBytes; public: static std::unique_ptr<NnNetwork> serve(int port); - static std::unique_ptr<NnNetwork> connect(NnSize nSockets, char **hosts, NnSize *ports); + static std::unique_ptr<NnNetwork> connect(NnUint nSockets, char **hosts, NnUint *ports); - NnSize nSockets; + NnUint nSockets; - NnNetwork(NnSize nSockets, int *sockets); + NnNetwork(NnUint nSockets, int *sockets); ~NnNetwork(); void setTurbo(bool enabled); - void write(NnSize socketIndex, const void *data, size_t size); - void read(NnSize socketIndex, void *data, size_t size); - void writeAck(NnSize socketIndex); - void readAck(NnSize socketIndex); - bool tryReadWithMaxAttempts(NnSize socketIndex, void *data, size_t size, unsigned long maxAttempts); - void writeMany(NnSize n, NnSocketIo *ios); - void writeAll(void *data, size_t size); - void readMany(NnSize n, NnSocketIo *ios); - void getStats(size_t *sentBytes, size_t *recvBytes); + void write(const NnUint socketIndex, const void *data, const NnSize size); + void read(const NnUint
socketIndex, void *data, const NnSize size); + void writeAck(const NnUint socketIndex); + void readAck(const NnUint socketIndex); + bool tryReadWithMaxAttempts(NnUint socketIndex, void *data, NnSize size, unsigned long maxAttempts); + void writeMany(NnUint n, NnSocketIo *ios); + void writeAll(void *data, NnSize size); + void readMany(NnUint n, NnSocketIo *ios); + void getStats(NnSize *sentBytes, NnSize *recvBytes); void resetStats(); }; @@ -71,7 +71,7 @@ class NnNetworkNodeSynchronizer : public NnNodeSynchronizer { public: NnNetworkNodeSynchronizer(NnNetwork *network, NnNetExecution *execution, NnNetConfig *netConfig, NnNodeConfig *nodeConfig); ~NnNetworkNodeSynchronizer() override {}; - void sync(NnSize segmentIndex, NnSize nThreads, NnSize threadIndex) override; + void sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) override; }; class NnRootConfigWriter { @@ -79,8 +79,8 @@ class NnRootConfigWriter { NnNetwork *network; public: NnRootConfigWriter(NnNetwork *network); - void writeNet(NnSize socketIndex, NnNetConfig *config); - void writeNode(NnSize socketIndex, NnNodeConfig *config); + void writeNet(NnUint socketIndex, NnNetConfig *config); + void writeNode(NnUint socketIndex, NnNodeConfig *config); void writeToWorkers(NnNetConfig *netConfig, NnNodeConfig *nodeConfigs); }; @@ -97,17 +97,17 @@ class NnRootWeightLoader { private: NnExecutor *executor; NnNetwork *network; - NnSize nNodes; + NnUint nNodes; NnByte *temp; NnSize tempSize; public: - NnRootWeightLoader(NnExecutor *executor, NnNetwork *network, NnSize nNodes); + NnRootWeightLoader(NnExecutor *executor, NnNetwork *network, NnUint nNodes); ~NnRootWeightLoader(); - void writeWeight(NnSize nodeIndex, const char *opName, NnSize opIndex, NnSize nBytes, NnByte *weight); - NnSize loadRoot(const char *opName, NnSize opIndex, NnSize nBytes, NnByte *weight); - NnSize loadAll(const char *opName, NnSize opIndex, NnSize nBytes, NnByte *weight); - NnSize loadRowMatmulSlices(const char *opName, NnSize 
opIndex, NnRowMatmulSlice *slice, NnByte *weight); - NnSize loadColMatmulSlices(const char *opName, NnSize opIndex, NnColMatmulSlice *slice, NnByte *weight); + void writeWeight(NnUint nodeIndex, const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight); + NnSize loadRoot(const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight); + NnSize loadAll(const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight); + NnSize loadRowMatmulSlices(const char *opName, NnUint opIndex, NnRowMatmulSlice *slice, NnByte *weight); + NnSize loadColMatmulSlices(const char *opName, NnUint opIndex, NnColMatmulSlice *slice, NnByte *weight); void finish(); private: void allocate(NnSize size);}; @@ -117,13 +117,13 @@ class NnWorkerWeightReader { NnExecutor *executor; NnNetwork *network; NnByte *temp; - NnSize tempSize; + NnUint tempSize; public: NnWorkerWeightReader(NnExecutor *executor, NnNetwork *network); ~NnWorkerWeightReader(); void read(); private: - void allocate(NnSize size); + void allocate(NnUint size); }; #endif diff --git a/src/nn/nn-quants.cpp b/src/nn/nn-quants.cpp index 6a8b67b..f239193 100644 --- a/src/nn/nn-quants.cpp +++ b/src/nn/nn-quants.cpp @@ -11,7 +11,7 @@ float f16ToF32Lookup[65536]; void initQuants() { #if defined(CONVERT_F16_TO_F32_LOOKUP) - for (NnSize i = 0; i < 65536; i++) + for (NnUint i = 0; i < 65536; i++) f16ToF32Lookup[i] = convertF16toF32Impl((NnFp16)i); #endif } @@ -64,18 +64,18 @@ NnFp16 convertF32ToF16Impl(const float x) { return s | (e << 10) | (m >> 13); } -void quantizeF32toQ80(const float *input, NnBlockQ80 *output, const NnSize n, const NnSize nThreads, const NnSize threadIndex) { +void quantizeF32toQ80(const float *input, NnBlockQ80 *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex) { assert(n % Q80_BLOCK_SIZE == 0); - const NnSize nBlocks = n / Q80_BLOCK_SIZE; + const NnUint nBlocks = n / Q80_BLOCK_SIZE; SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex); #if defined(__ARM_NEON) - for (NnSize i 
= start; i < end; i++) { + for (NnUint i = start; i < end; i++) { const float *x = &input[i * Q80_BLOCK_SIZE]; NnBlockQ80 *y = &output[i]; float32x4_t amaxVec = vdupq_n_f32(0.0f); - for (NnSize j = 0; j < Q80_BLOCK_SIZE; j += 4) { + for (NnUint j = 0; j < Q80_BLOCK_SIZE; j += 4) { const float32x4_t vec = vld1q_f32(&x[j]); const float32x4_t abs_vec = vabsq_f32(vec); amaxVec = vmaxq_f32(amaxVec, abs_vec); @@ -90,7 +90,7 @@ void quantizeF32toQ80(const float *input, NnBlockQ80 *output, const NnSize n, co const float32x4_t vid_vec = vdupq_n_f32(id); - for (NnSize j = 0; j < Q80_BLOCK_SIZE; j += 4) { + for (NnUint j = 0; j < Q80_BLOCK_SIZE; j += 4) { float32x4_t vec = vld1q_f32(&x[j]); vec = vmulq_f32(vec, vid_vec); @@ -106,7 +106,7 @@ void quantizeF32toQ80(const float *input, NnBlockQ80 *output, const NnSize n, co } } #elif defined(__AVX2__) - for (NnSize i = start; i < end; ++i) { + for (NnUint i = start; i < end; ++i) { const float *x = input + i * Q80_BLOCK_SIZE; NnBlockQ80 *y = output + i; @@ -152,12 +152,12 @@ void quantizeF32toQ80(const float *input, NnBlockQ80 *output, const NnSize n, co } } #else - for (NnSize i = start; i < end; i++) { + for (NnUint i = start; i < end; i++) { const float *x = &input[i * Q80_BLOCK_SIZE]; NnBlockQ80 *y = &output[i]; float amax = 0.0f; - for (NnSize j = 0; j < Q80_BLOCK_SIZE; j++) { + for (NnUint j = 0; j < Q80_BLOCK_SIZE; j++) { const float v = fabsf(x[j]); amax = amax > v ? amax : v; } @@ -165,14 +165,14 @@ void quantizeF32toQ80(const float *input, NnBlockQ80 *output, const NnSize n, co const float d = amax / ((1 << 7) - 1); const float id = d ? 
1.0f / d : 0.0f; y->d = CONVERT_F32_TO_F16(d); - for (NnSize j = 0; j < Q80_BLOCK_SIZE; ++j) { + for (NnUint j = 0; j < Q80_BLOCK_SIZE; ++j) { y->qs[j] = roundf(x[j] * id); } } #endif } -void dequantizeQ80toF32(const NnBlockQ80 *input, float* output, const NnSize k, const NnSize nThreads, const NnSize threadIndex) { +void dequantizeQ80toF32(const NnBlockQ80 *input, float* output, const NnUint k, const NnUint nThreads, const NnUint threadIndex) { assert(k % Q80_BLOCK_SIZE == 0); const int nBlocks = k / Q80_BLOCK_SIZE; const int blocksPerThread = nBlocks / nThreads; @@ -190,16 +190,16 @@ void dequantizeQ80toF32(const NnBlockQ80 *input, float* output, const NnSize k, } } -void quantizeF32toQ40(const float *x, NnBlockQ40 *output, const NnSize n, const NnSize nThreads, const NnSize threadIndex) { +void quantizeF32toQ40(const float *x, NnBlockQ40 *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex) { assert(n % Q40_BLOCK_SIZE == 0); - const NnSize nBlocks = n / Q40_BLOCK_SIZE; - const NnSize halfSize = Q40_BLOCK_SIZE / 2; + const NnUint nBlocks = n / Q40_BLOCK_SIZE; + const NnUint halfSize = Q40_BLOCK_SIZE / 2; SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex); - for (NnSize i = start; i < end; i++) { + for (NnUint i = start; i < end; i++) { float amax = 0.0f; float max = 0.0f; - for (NnSize j = 0; j < Q40_BLOCK_SIZE; j++) { + for (NnUint j = 0; j < Q40_BLOCK_SIZE; j++) { float v = x[i * Q40_BLOCK_SIZE + j]; if (amax < fabsf(v)) { amax = fabsf(v); @@ -212,7 +212,7 @@ void quantizeF32toQ40(const float *x, NnBlockQ40 *output, const NnSize n, const NnBlockQ40 *o = &output[i]; o->d = CONVERT_F32_TO_F16(d); - for (NnSize j = 0; j < halfSize; j++) { + for (NnUint j = 0; j < halfSize; j++) { const float x0 = x[i * Q40_BLOCK_SIZE + j] * id; const float x1 = x[i * Q40_BLOCK_SIZE + halfSize + j] * id; @@ -226,12 +226,12 @@ void quantizeF32toQ40(const float *x, NnBlockQ40 *output, const NnSize n, const } } -void dequantizeQ40toF32(const NnBlockQ40 *x, 
float *output, const NnSize n, const NnSize nThreads, const NnSize threadIndex) { +void dequantizeQ40toF32(const NnBlockQ40 *x, float *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex) { assert(n % Q40_BLOCK_SIZE == 0); - const NnSize nBlocks = n / Q40_BLOCK_SIZE; + const NnUint nBlocks = n / Q40_BLOCK_SIZE; SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex); - for (NnSize i = start; i < end; i++) { + for (NnUint i = start; i < end; i++) { const NnBlockQ40 *b = &x[i]; const float d = CONVERT_F16_TO_F32(b->d); diff --git a/src/nn/nn-quants.hpp b/src/nn/nn-quants.hpp index fe3cae0..8f0b86e 100644 --- a/src/nn/nn-quants.hpp +++ b/src/nn/nn-quants.hpp @@ -10,7 +10,8 @@ #endif typedef std::uint8_t NnByte; -typedef std::uint32_t NnSize; +typedef std::uint32_t NnUint; +typedef std::size_t NnSize; typedef std::uint16_t NnFp16; float convertF16toF32Impl(const NnFp16 value); @@ -71,17 +72,17 @@ typedef struct { } NnBlockQ80; void initQuants(); -void quantizeF32toQ80(const float *input, NnBlockQ80 *output, const NnSize k, const NnSize nThreads, const NnSize threadIndex); -void dequantizeQ80toF32(const NnBlockQ80 *input, float* output, const NnSize k, const NnSize nThreads, const NnSize threadIndex); -void quantizeF32toQ40(const float *x, NnBlockQ40 *output, const NnSize n, const NnSize nThreads, const NnSize threadIndex); -void dequantizeQ40toF32(const NnBlockQ40 *x, float *output, const NnSize n, const NnSize nThreads, const NnSize threadIndex); +void quantizeF32toQ80(const float *input, NnBlockQ80 *output, const NnUint k, const NnUint nThreads, const NnUint threadIndex); +void dequantizeQ80toF32(const NnBlockQ80 *input, float* output, const NnUint k, const NnUint nThreads, const NnUint threadIndex); +void quantizeF32toQ40(const float *x, NnBlockQ40 *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex); +void dequantizeQ40toF32(const NnBlockQ40 *x, float *output, const NnUint n, const NnUint nThreads, const NnUint 
threadIndex); const char *floatTypeToString(NnFloatType type); #define SPLIT_THREADS(varStart, varEnd, rangeLen, nThreads, threadIndex) \ - const NnSize rangeSlice = rangeLen / nThreads; \ - const NnSize rangeRest = rangeLen % nThreads; \ - const NnSize varStart = threadIndex * rangeSlice + (threadIndex < rangeRest ? threadIndex : rangeRest); \ - const NnSize varEnd = varStart + rangeSlice + (threadIndex < rangeRest ? 1 : 0); + const NnUint rangeSlice = rangeLen / nThreads; \ + const NnUint rangeRest = rangeLen % nThreads; \ + const NnUint varStart = threadIndex * rangeSlice + (threadIndex < rangeRest ? threadIndex : rangeRest); \ + const NnUint varEnd = varStart + rangeSlice + (threadIndex < rangeRest ? 1 : 0); #endif \ No newline at end of file diff --git a/src/tokenizer.cpp b/src/tokenizer.cpp index 14b8e49..ee7e0a4 100644 --- a/src/tokenizer.cpp +++ b/src/tokenizer.cpp @@ -293,7 +293,7 @@ void Tokenizer::encode(char *text, int *tokens, int *nTokens, bool addBos, bool } #if DEBUG_TOKENIZER_BENCHMARK - NnSize duration = startTime.elapsedMicroseconds(); + NnUint duration = startTime.elapsedMicroseconds(); printf("🕒 [%22s] %u μs\n", "ENCODER", duration); #endif #if DEBUG_TOKENIZER_ENCODER @@ -420,7 +420,7 @@ int Sampler::sample(float* logits) { } } #if DEBUG_SAMPLER_BENCHMARK - NnSize duration = startTime.elapsedMicroseconds(); + NnUint duration = startTime.elapsedMicroseconds(); printf("🕒 [%22s] %u μs\n", "SAMPLER", duration); #endif return next;