Skip to content

Commit

Permalink
queue: include a spin-lock spsc queue with copies
Browse files Browse the repository at this point in the history
csegarragonz committed Feb 23, 2024
1 parent f54ae94 commit cecf660
Showing 3 changed files with 59 additions and 16 deletions.
4 changes: 2 additions & 2 deletions include/faabric/mpi/MpiWorld.h
Original file line number Diff line number Diff line change
@@ -38,7 +38,7 @@ struct MpiMessage {
// as the broker already has mocking capabilities
std::vector<std::shared_ptr<MPIMessage>> getMpiMockedMessages(int sendRank);

typedef faabric::util::Queue<std::unique_ptr<MpiMessage>> InMemoryMpiQueue;
typedef faabric::util::SpinLockQueue<MpiMessage> InMemoryMpiQueue;

class MpiWorld
{
@@ -275,7 +275,7 @@ class MpiWorld
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);

// Abstraction of the bulk of the recv work, shared among various functions
void doRecv(std::unique_ptr<MpiMessage> m,
void doRecv(MpiMessage& m,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
37 changes: 35 additions & 2 deletions include/faabric/util/queue.h
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@

#include <condition_variable>
#include <queue>
#include <boost/lockfree/spsc_queue.hpp>
#include <readerwriterqueue/readerwritercircularbuffer.h>

#define DEFAULT_QUEUE_TIMEOUT_MS 5000
@@ -23,7 +24,7 @@ class QueueTimeoutException : public faabric::util::FaabricException
template<typename T>
class BaseQueue
{
virtual void enqueue(T value) = 0;
virtual void enqueue(T value, long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) = 0;

virtual T dequeue(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) = 0;

@@ -36,7 +37,7 @@ template<typename T>
class Queue : public BaseQueue<T>
{
public:
void enqueue(T value) override
void enqueue(T value, long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) override
{
UniqueLock lock(mx);

@@ -227,6 +228,38 @@ class FixedCapacityQueue : public BaseQueue<T>
moodycamel::BlockingReaderWriterCircularBuffer<T> mq;
};

// High-performance single-producer, single-consumer queue backed by a
// fixed-capacity (1024 elements) boost::lockfree::spsc_queue. Elements are
// copied in and out of the ring buffer. Both enqueue and dequeue busy-wait
// (spin) rather than block, so a stalled peer keeps the caller spinning
// indefinitely: use at your own risk!
//
// NOTE(review): per the Boost.Lockfree documentation, spsc_queue expects
// exactly one producer thread and one consumer thread - confirm at call sites.
template<typename T>
class SpinLockQueue
{
  public:
    // Copies value into the queue, spinning for as long as the queue is full.
    // The argument is taken by const reference so that const objects and
    // temporaries can be enqueued as well as plain lvalues.
    // timeoutMs mirrors the signature of the other queues in this file but is
    // currently ignored: this call never times out.
    void enqueue(const T& value, long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS)
    {
        while (!mq.push(value)) {
            // Busy-wait until a slot is free
        }
    }

    // Removes and returns the oldest element, spinning for as long as the
    // queue is empty. T must be default-constructible, as the element is
    // popped into a local before being returned.
    // timeoutMs is currently ignored: this call never times out.
    T dequeue(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS)
    {
        T value;

        while (!mq.pop(value)) {
            // Busy-wait until an element is available
        }

        return value;
    }

    // Not implemented for this queue: always throws std::runtime_error
    long size()
    {
        throw std::runtime_error("Size for fast queue unimplemented!");
    }

    // Discards elements one at a time until a pop reports failure (i.e. the
    // queue is empty)
    void drain()
    {
        while (mq.pop()) {
            // Keep discarding
        }
    }

    // No-op, there is no additional state to reset
    void reset() {}

  private:
    boost::lockfree::spsc_queue<T, boost::lockfree::capacity<1024>> mq;
};

class TokenPool
{
public:
34 changes: 22 additions & 12 deletions src/mpi/MpiWorld.cpp
Original file line number Diff line number Diff line change
@@ -518,15 +518,22 @@ void MpiWorld::send(int sendRank,
void* bufferPtr = malloc(count * dataType->size);
std::memcpy(bufferPtr, buffer, count* dataType->size);

/*
auto msg = std::make_unique<MpiMessage>(MpiMessage{
.id = msgId, .worldId = id, .sendRank = sendRank,
.recvRank = recvRank, .type = dataType->id, .count = count,
.buffer = bufferPtr
});
*/
MpiMessage msg = {
.id = msgId, .worldId = id, .sendRank = sendRank,
.recvRank = recvRank, .type = dataType->id, .count = count,
.buffer = bufferPtr
};

SPDLOG_TRACE(
"MPI - send {} -> {} ({})", sendRank, recvRank, messageType);
getLocalQueue(sendRank, recvRank)->enqueue(std::move(msg));
getLocalQueue(sendRank, recvRank)->enqueue(msg);
} else {
// Create the message
auto m = std::make_shared<MPIMessage>();
@@ -586,7 +593,10 @@ void MpiWorld::recv(int sendRank,
bool isLocal = getHostForRank(sendRank) == getHostForRank(recvRank);

if (isLocal) {
std::unique_ptr<MpiMessage> m = getLocalQueue(sendRank, recvRank)->dequeue();
MpiMessage m = getLocalQueue(sendRank, recvRank)->dequeue();

// Do the processing
doRecv(m, buffer, dataType, count, status, messageType);
} else {
// Recv message from underlying transport
std::shared_ptr<MPIMessage> m = recvBatchReturnLast(sendRank, recvRank);
@@ -596,7 +606,7 @@ void MpiWorld::recv(int sendRank,
}
}

void MpiWorld::doRecv(std::unique_ptr<MpiMessage> m,
void MpiWorld::doRecv(MpiMessage& m,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
@@ -605,28 +615,28 @@ void MpiWorld::doRecv(std::unique_ptr<MpiMessage> m,
{
// Assert message integrity
// Note - these checks won't happen in Release builds
if (m->type != messageType) {
if (m.type != messageType) {
SPDLOG_ERROR("Different message types (got: {}, expected: {})",
m->type,
m.type,
messageType);
}
assert(m->type == messageType);
assert(m->count <= count);
assert(m.type == messageType);
assert(m.count <= count);

// TODO - avoid copy here
// Copy message data
if (m->count > 0) {
std::memcpy(buffer, (void*)m->buffer, count * dataType->size);
free((void*)m->buffer);
if (m.count > 0) {
std::memcpy(buffer, (void*)m.buffer, count * dataType->size);
free((void*)m.buffer);
}

// Set status values if required
if (status != nullptr) {
status->MPI_SOURCE = m->sendRank;
status->MPI_SOURCE = m.sendRank;
status->MPI_ERROR = MPI_SUCCESS;

// Take the message size here as the receive count may be larger
status->bytesSize = m->count * dataType->size;
status->bytesSize = m.count * dataType->size;

// TODO - thread through tag
status->MPI_TAG = -1;

0 comments on commit cecf660

Please sign in to comment.