Commit f8113c1
fix: nnuint. (#174)
b4rtaz authored Feb 18, 2025
1 parent 24156d8 commit f8113c1
Showing 22 changed files with 622 additions and 592 deletions.
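
This commit renames the unsigned count type used for indices, positions, and batch sizes from NnSize to NnUint throughout the sources; NnSize stays in use for byte counts (see the sentBytes/recvBytes change in src/dllama.cpp below). The typedefs themselves are not in the hunks shown on this page, so the following is only a hedged sketch of what such a split could look like, with both aliases assumed to be plain unsigned integer typedefs:

    // Hypothetical sketch, not taken from the repository: illustrates the
    // naming split this commit applies (counts vs. byte sizes).
    #include <cstdint>
    #include <cstddef>

    typedef uint32_t NnUint;  // counts: tokens, batch sizes, positions, loop indices
    typedef size_t   NnSize;  // byte sizes: buffer lengths, bytes sent/received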
16 changes: 8 additions & 8 deletions src/app.cpp
@@ -70,7 +70,7 @@ AppCliArgs AppCliArgs::parse(int argc, char **argv, bool requireMode) {

args.nWorkers = count;
args.workerHosts = new char*[count];
- args.workerPorts = new NnSize[count];
+ args.workerPorts = new NnUint[count];

for (int s = 0; s < count; s++) {
char *v = argv[i + 1 + s];
@@ -111,7 +111,7 @@ AppCliArgs AppCliArgs::parse(int argc, char **argv, bool requireMode) {

AppCliArgs::~AppCliArgs() {
if (workerHosts != nullptr) {
- for (NnSize i = 0; i < nWorkers; i++)
+ for (NnUint i = 0; i < nWorkers; i++)
delete[] workerHosts[i];
delete[] workerHosts;
}
@@ -130,21 +130,21 @@ RootLlmInference::RootLlmInference(LlmNet *net, NnDevice *device, NnNetExecution
this->network = network; // May be nullptr!
}

- void RootLlmInference::setBatchSize(NnSize batchSize) {
+ void RootLlmInference::setBatchSize(NnUint batchSize) {
execution->setBatchSize(batchSize);
controlPacket.batchSize = batchSize;
}

- void RootLlmInference::setPosition(NnSize position) {
+ void RootLlmInference::setPosition(NnUint position) {
assert(position >= 0);
assert(position + execution->batchSize - 1 < header->seqLen);

controlPacket.position = position;
- for (NnSize i = 0; i < execution->batchSize; i++)
+ for (NnUint i = 0; i < execution->batchSize; i++)
positionPipe[i] = (float)(position + i);
}

- void RootLlmInference::setToken(NnSize batchIndex, NnSize token) {
+ void RootLlmInference::setToken(NnUint batchIndex, NnUint token) {
assert(batchIndex >= 0 && batchIndex < execution->batchSize);
tokenPipe[batchIndex] = (float)token;
}
@@ -179,14 +179,14 @@ bool WorkerLlmInference::tryReadControlPacket() {
isFinished = true;
return true;
}
- for (NnSize i = 0; i < controlPacket.batchSize; i++)
+ for (NnUint i = 0; i < controlPacket.batchSize; i++)
positionPipe[i] = (float)(controlPacket.position + i);
execution->setBatchSize(controlPacket.batchSize);
return true;
}

void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *context)) {
- NnSize nNodes = args->nWorkers + 1;
+ NnUint nNodes = args->nWorkers + 1;

LlmHeader header = loadLlmHeader(args->modelPath, args->maxSeqLen, args->syncType);
if (nNodes > header.nKvHeads)
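
The RootLlmInference setters changed above are always driven in the same order by the front ends further down this page (src/dllama-api.cpp and src/dllama.cpp): set the batch size, set the position, set one token per batch slot, then run forward(). A condensed sketch of that calling pattern, with inference, tokens, pos, and batchSize standing in for the caller's local state:

    // Sketch of the calling pattern, not a verbatim excerpt.
    inference->setBatchSize(batchSize);           // also recorded in controlPacket for workers
    inference->setPosition(pos);                  // fills positionPipe[0..batchSize-1]
    for (NnUint j = 0; j < batchSize; j++)
        inference->setToken(j, tokens[pos + j]);  // fills tokenPipe
    inference->forward();                         // one step across root + workers
    pos += batchSize;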
24 changes: 12 additions & 12 deletions src/app.hpp
@@ -10,36 +10,36 @@
class AppCliArgs {
public:
char *mode;
- NnSize nThreads;
- NnSize nBatches;
+ NnUint nThreads;
+ NnUint nBatches;
bool help;

// inference
char *modelPath;
char *tokenizerPath;
char *prompt;
NnFloatType syncType;
- NnSize nWorkers;
+ NnUint nWorkers;
char **workerHosts;
- NnSize *workerPorts;
+ NnUint *workerPorts;
float temperature;
float topp;
- NnSize steps;
+ NnUint steps;
bool benchmark;
unsigned long long seed;
ChatTemplateType chatTemplateType;
- NnSize maxSeqLen;
+ NnUint maxSeqLen;

// worker
- NnSize port;
+ NnUint port;

static AppCliArgs parse(int argc, char **argv, bool hasMode);
~AppCliArgs();
};

typedef struct {
- NnSize position;
- NnSize batchSize; // 0 = stop signal
+ NnUint position;
+ NnUint batchSize; // 0 = stop signal
} LlmControlPacket;

class RootLlmInference {
@@ -56,9 +56,9 @@ class RootLlmInference {
LlmControlPacket controlPacket;
public:
RootLlmInference(LlmNet *net, NnDevice *device, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network);
- void setBatchSize(NnSize batchSize);
- void setPosition(NnSize position);
- void setToken(NnSize batchIndex, NnSize token);
+ void setBatchSize(NnUint batchSize);
+ void setPosition(NnUint position);
+ void setToken(NnUint batchIndex, NnUint token);
void forward();
void finish();
};
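
LlmControlPacket encodes the stop signal as batchSize == 0, which is exactly what WorkerLlmInference::tryReadControlPacket reacts to in the src/app.cpp hunk above (it sets isFinished and returns). A minimal sketch of that convention from the worker's side; recvControlPacket is an assumed helper standing in for however the packet actually arrives over the network:

    // Sketch only; the field handling mirrors the hunk above, the receive helper is assumed.
    LlmControlPacket packet;
    recvControlPacket(network, &packet);
    if (packet.batchSize == 0) {
        isFinished = true;                 // root signalled shutdown
    } else {
        for (NnUint i = 0; i < packet.batchSize; i++)
            positionPipe[i] = (float)(packet.position + i);
        execution->setBatchSize(packet.batchSize);
    }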
8 changes: 4 additions & 4 deletions src/dllama-api.cpp
@@ -380,20 +380,20 @@ class ApiServer {
buffer += inputPrompt.publicPrompt;
}

- NnSize pos = startPos;
+ NnUint pos = startPos;
int token;
- for (NnSize i = 0; ;) {
+ for (NnUint i = 0; ;) {
long remainingTokens = promptEndPos - pos;
if (remainingTokens <= 0)
break;

- NnSize batchSize = remainingTokens < args->nBatches
+ NnUint batchSize = remainingTokens < args->nBatches
? remainingTokens
: args->nBatches;

inference->setBatchSize(batchSize);
inference->setPosition(pos);
- for (NnSize j = 0; j < batchSize; j++)
+ for (NnUint j = 0; j < batchSize; j++)
inference->setToken(j, promptTokens[i + j]);

inference->forward();
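
The loop above feeds the prompt to the model in chunks of at most args->nBatches tokens: batchSize is simply the remaining token count clamped to that limit, and the same shape recurs twice in src/dllama.cpp below. A standalone sketch of the loop, simplified, with the index advance filled in as an assumption (that part falls outside the hunks shown here):

    // Sketch of the prompt-feeding loop shape; promptTokens, promptEndPos and
    // nBatches come from the caller, the i/pos advance is assumed.
    NnUint pos = startPos;
    for (NnUint i = 0; ;) {
        long remainingTokens = promptEndPos - pos;
        if (remainingTokens <= 0)
            break;
        NnUint batchSize = remainingTokens < nBatches
            ? (NnUint)remainingTokens
            : nBatches;                    // clamp to the configured batch limit
        inference->setBatchSize(batchSize);
        inference->setPosition(pos);
        for (NnUint j = 0; j < batchSize; j++)
            inference->setToken(j, promptTokens[i + j]);
        inference->forward();
        i += batchSize;
        pos += batchSize;
    }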
40 changes: 20 additions & 20 deletions src/dllama.cpp
@@ -16,7 +16,7 @@ static void inference(AppInferenceContext *context) {
std::vector<int> inputTokensVec(std::strlen(context->args->prompt) + 3);
int *inputTokens = inputTokensVec.data();

- NnSize pos = 0;
+ NnUint pos = 0;
int token;
int nInputTokens;
context->tokenizer->encode(context->args->prompt, inputTokens, &nInputTokens, true, false);
@@ -27,21 +27,21 @@
throw std::runtime_error("The number of prompt tokens is greater than the number of steps");

Timer evalTimer;
- size_t sentBytes = 0;
- size_t recvBytes = 0;
+ NnSize sentBytes = 0;
+ NnSize recvBytes = 0;
printf("%s\n", context->args->prompt);
for (;;) {
Timer batchTimer;
long remainingTokens = nInputTokens - 1 - (long)pos;
if (remainingTokens <= 0)
break;
- NnSize batchSize = remainingTokens < context->args->nBatches
+ NnUint batchSize = remainingTokens < context->args->nBatches
? remainingTokens
: context->args->nBatches;

context->inference->setBatchSize(batchSize);
context->inference->setPosition(pos);
- for (NnSize i = 0; i < batchSize; i++)
+ for (NnUint i = 0; i < batchSize; i++)
context->inference->setToken(i, inputTokens[pos + i]);

context->inference->forward();
@@ -57,15 +57,15 @@ static void inference(AppInferenceContext *context) {
recvBytes / 1024,
batchSize);
}
- NnSize evalTime = evalTimer.elapsedMiliseconds();
+ NnUint evalTime = evalTimer.elapsedMiliseconds();

fflush(stdout);

context->inference->setBatchSize(1);
context->tokenizer->resetDecoder();

Timer predTimer;
- const NnSize maxPos = std::min(context->header->seqLen, context->args->steps);
+ const NnUint maxPos = std::min(context->header->seqLen, context->args->steps);
for (; pos < maxPos; pos++) {
Timer tokenTimer;
context->inference->setPosition(pos);
@@ -86,10 +86,10 @@
piece == nullptr ? "~" : piece);
fflush(stdout);
}
- NnSize predTime = predTimer.elapsedMiliseconds();
+ NnUint predTime = predTimer.elapsedMiliseconds();

- NnSize nEvalTokens = nInputTokens - 1;
- NnSize nPredTokens = pos - nEvalTokens;
+ NnUint nEvalTokens = nInputTokens - 1;
+ NnUint nPredTokens = pos - nEvalTokens;
printf("\n");
printf("Evaluation\n");
printf(" nBatches: %d\n", context->args->nBatches);
@@ -104,11 +104,11 @@
predTime / ((float) nPredTokens));
}

- static size_t readStdin(const char *guide, char *buffer, size_t size) {
+ static NnUint readStdin(const char *guide, char *buffer, NnUint size) {
std::fflush(stdin);
std::printf("%s", guide);
if (std::fgets(buffer, size, stdin) != NULL) {
- size_t length = std::strlen(buffer);
+ NnUint length = std::strlen(buffer);
if (length > 0 && buffer[length - 1] == '\n') {
buffer[length - 1] = '\0';
length--;
@@ -119,20 +119,20 @@ static size_t readStdin(const char *guide, char *buffer, size_t size) {
}

static void chat(AppInferenceContext *context) {
- const NnSize seqLen = context->header->seqLen;
+ const NnUint seqLen = context->header->seqLen;
char prompt[2048];

TokenizerChatStops stops(context->tokenizer);
ChatTemplateGenerator templateGenerator(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
EosDetector eosDetector(stops.nStops, context->tokenizer->eosTokenIds.data(), stops.stops, stops.maxStopLength, stops.maxStopLength);

- const size_t sysPromptLength = readStdin("💻 System prompt (optional): ", prompt, sizeof(prompt));
+ const NnUint sysPromptLength = readStdin("💻 System prompt (optional): ", prompt, sizeof(prompt));
std::vector<ChatItem> deltaItems;
if (sysPromptLength > 0)
deltaItems.push_back(ChatItem{"system", prompt});

- NnSize pos = 0;
- size_t userPromptLength;
+ NnUint pos = 0;
+ NnUint userPromptLength;
int token;
int nInputTokens;
do {
@@ -149,18 +149,18 @@
bool addBos = pos == 0;
context->tokenizer->encode((char*)inputPrompt.content, inputTokens, &nInputTokens, addBos, true);

- NnSize userPromptEndPos = (NnSize)std::min<unsigned int>(seqLen, pos + nInputTokens - 1);
- for (NnSize i = 0; ;) {
+ NnUint userPromptEndPos = (NnUint)std::min<unsigned int>(seqLen, pos + nInputTokens - 1);
+ for (NnUint i = 0; ;) {
int remainingTokens = userPromptEndPos - pos;
if (remainingTokens <= 0)
break;
- NnSize batchSize = remainingTokens < context->args->nBatches
+ NnUint batchSize = remainingTokens < context->args->nBatches
? remainingTokens
: context->args->nBatches;

context->inference->setBatchSize(batchSize);
context->inference->setPosition(pos);
- for (NnSize j = 0; j < batchSize; j++)
+ for (NnUint j = 0; j < batchSize; j++)
context->inference->setToken(j, inputTokens[i + j]);

context->inference->forward();
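
At the end of inference() the counters renamed above feed a short timing summary; the tail of it is visible in the -104,11 hunk, where predTime / ((float) nPredTokens) is milliseconds per generated token. The arithmetic is just ms/token and its inverse; a small sketch, with the printf layout assumed rather than copied:

    // evalTime / predTime are in milliseconds (Timer::elapsedMiliseconds above).
    float evalMsPerTok = evalTime / (float)nEvalTokens;
    float predMsPerTok = predTime / (float)nPredTokens;
    printf("Evaluation: %.2f tok/s (%.2f ms/tok)\n", 1000.0f / evalMsPerTok, evalMsPerTok);  // assumed format
    printf("Prediction: %.2f tok/s (%.2f ms/tok)\n", 1000.0f / predMsPerTok, predMsPerTok);  // assumed format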
(The remaining 18 changed files in this commit are not shown here.)
