Skip to content

Commit

Permalink
DPL: use the SendingPolicy for every kind of message sending
Browse files Browse the repository at this point in the history
  • Loading branch information
ktf committed Feb 9, 2024
1 parent 4d5481e commit b868bb6
Show file tree
Hide file tree
Showing 16 changed files with 171 additions and 77 deletions.
4 changes: 4 additions & 0 deletions Framework/Core/include/Framework/ChannelInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@ struct InputChannelInfo {
};

struct SendingPolicy;
struct ForwardingPolicy;

/// Output channel information.
/// Groups the channel reference together with the name, accounting type,
/// sending policy and index associated to it.
struct OutputChannelInfo {
/// Name of the channel. Defaults to "invalid".
std::string name = "invalid";
/// Accounting type of the channel. Defaults to ChannelAccountingType::DPL.
ChannelAccountingType channelType = ChannelAccountingType::DPL;
/// The underlying channel. Being a reference, it must be bound when the
/// struct is initialised.
fair::mq::Channel& channel;
/// Policy whose `send` callback is used when sending on this channel.
/// NOTE(review): no default initialiser here; callers dereference it
/// without a null check, so it is expected to always be set - confirm.
SendingPolicy const* policy;
/// Index of the channel, handed to the policy callback when sending.
/// Defaults to {-1} (presumably meaning "not assigned" - confirm).
ChannelIndex index = {-1};
};

struct OutputChannelState {
Expand All @@ -89,6 +91,8 @@ struct ForwardChannelInfo {
/// Whether or not it's a DPL internal channel.
ChannelAccountingType channelType = ChannelAccountingType::DPL;
fair::mq::Channel& channel;
ForwardingPolicy const* policy;
ChannelIndex index = {-1};
};

struct ForwardChannelState {
Expand Down
22 changes: 11 additions & 11 deletions Framework/Core/include/Framework/DataProcessingHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,30 @@
#ifndef O2_FRAMEWORK_DATAPROCESSINGHELPERS_H_
#define O2_FRAMEWORK_DATAPROCESSINGHELPERS_H_

#include "Framework/TimesliceIndex.h"
#include <fairmq/FwdDecls.h>
#include <cstddef>

namespace o2::framework
{

struct ServiceRegistryRef;
struct ForwardChannelInfo;
struct ForwardChannelState;
struct OutputChannelInfo;
struct OutputChannelSpec;
class FairMQDeviceProxy;
struct OutputChannelState;

/// Generic helpers for DataProcessing related functions.
struct DataProcessingHelpers {
/// Send EndOfStream message to a given channel
/// @param ref the service registry; the fair::mq::Device which sends the
/// EndOfStream message is retrieved from its RawDeviceService
/// @param channel the OutputChannelSpec of the channel which needs to be signaled
/// for EndOfStream
static void sendEndOfStream(fair::mq::Device& device, OutputChannelSpec const& channel);
static void sendEndOfStream(ServiceRegistryRef const& ref, OutputChannelSpec const& channel);
/// Send the oldest possible timeslice on a forward channel, unless the value
/// recorded in @a state is already greater than or equal to @a timeslice.
/// @returns true if we did send the oldest possible timeslice message, false otherwise.
static bool sendOldestPossibleTimeframe(ForwardChannelInfo const& info, ForwardChannelState& state, size_t timeslice);
static bool sendOldestPossibleTimeframe(ServiceRegistryRef const& ref, ForwardChannelInfo const& info, ForwardChannelState& state, size_t timeslice);
/// Send the oldest possible timeslice on an output channel, unless the value
/// recorded in @a state is already greater than or equal to @a timeslice.
/// @returns true if we did send the oldest possible timeslice message, false otherwise.
static bool sendOldestPossibleTimeframe(OutputChannelInfo const& info, OutputChannelState& state, size_t timeslice);
static void broadcastOldestPossibleTimeslice(FairMQDeviceProxy& proxy, size_t timeslice);

private:
static void sendOldestPossibleTimeframe(fair::mq::Channel& channel, size_t timeslice);
static bool sendOldestPossibleTimeframe(ServiceRegistryRef const& ref, OutputChannelInfo const& info, OutputChannelState& state, size_t timeslice);
/// Broadcast the oldest possible timeslice to all channels in output
static void broadcastOldestPossibleTimeslice(ServiceRegistryRef const& ref, size_t timeslice);
};

} // namespace o2::framework
Expand Down
3 changes: 3 additions & 0 deletions Framework/Core/include/Framework/DriverInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ struct DriverInfo {
/// These are the policies which can be applied to decide how
/// we send data.
std::vector<SendingPolicy> sendingPolicies;
/// These are the policies which can be applied to decide how
/// we forward data.
std::vector<ForwardingPolicy> forwardingPolicies;
/// The argc with which the driver was started.
int argc;
/// The argv with which the driver was started.
Expand Down
9 changes: 5 additions & 4 deletions Framework/Core/include/Framework/ForwardRoute.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
#include <cstddef>
#include <string>

namespace o2
{
namespace framework
namespace o2::framework
{

struct ForwardingPolicy;

/// This uniquely identifies a route to be forwarded by the device if
/// the InputSpec @a matcher matches an input which should also go to
/// @a channel
Expand All @@ -28,8 +28,9 @@ struct ForwardRoute {
size_t maxTimeslices;
InputSpec matcher;
std::string channel;
// The policy to use when forwarding on this route.
ForwardingPolicy const* policy;
};

} // namespace framework
} // namespace o2
#endif // FRAMEWORK_FORWARDROUTE_H
8 changes: 8 additions & 0 deletions Framework/Core/include/Framework/SendingPolicy.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ struct SendingPolicy {
static std::vector<SendingPolicy> createDefaultPolicies();
};

/// Policy describing how messages are forwarded on a given edge of the
/// topology. It pairs a matcher, used to select the policy, with the
/// callback which performs the actual forwarding.
struct ForwardingPolicy {
/// Signature of the forwarding callback: it receives the parts to forward,
/// the index of the channel to use and the service registry.
using ForwardingCallback = std::function<void(fair::mq::Parts&, ChannelIndex channelIndex, ServiceRegistryRef registry)>;
/// Name of the policy. Defaults to "invalid".
std::string name = "invalid";
/// Predicate used to decide whether this policy applies to an edge.
/// Defaults to nullptr, so it must be set before the policy is evaluated.
EdgeMatcher matcher = nullptr;
/// Callback invoked to forward the messages. Defaults to nullptr, so it
/// must be set before the policy is used.
ForwardingCallback forward = nullptr;
/// @returns the list of default forwarding policies.
/// NOTE(review): the implementation is not visible here; the order of the
/// returned policies presumably matters, since consumers pick the first
/// one whose matcher succeeds - confirm.
static std::vector<ForwardingPolicy> createDefaultPolicies();
};

} // namespace o2::framework

#endif // O2_FRAMEWORK_SENDINGPOLICY_H_
10 changes: 5 additions & 5 deletions Framework/Core/src/CommonServices.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ o2::framework::ServiceSpec CommonServices::decongestionSpec()
timesliceIndex.rescan();
}
}
DataProcessingHelpers::broadcastOldestPossibleTimeslice(proxy, oldestPossibleOutput.timeslice.value);
DataProcessingHelpers::broadcastOldestPossibleTimeslice(ctx.services(), oldestPossibleOutput.timeslice.value);

for (int fi = 0; fi < proxy.getNumForwardChannels(); fi++) {
auto& info = proxy.getForwardChannelInfo(ChannelIndex{fi});
Expand All @@ -598,7 +598,7 @@ o2::framework::ServiceSpec CommonServices::decongestionSpec()
O2_SIGNPOST_EVENT_EMIT(data_processor_context, cid, "oldest_possible_timeslice", "Skipping channel %{public}s", info.name.c_str());
continue;
}
if (DataProcessingHelpers::sendOldestPossibleTimeframe(info, state, oldestPossibleOutput.timeslice.value)) {
if (DataProcessingHelpers::sendOldestPossibleTimeframe(ctx.services(), info, state, oldestPossibleOutput.timeslice.value)) {
O2_SIGNPOST_EVENT_EMIT(data_processor_context, cid, "oldest_possible_timeslice",
"Forwarding to channel %{public}s oldest possible timeslice %" PRIu64 ", priority %d",
info.name.c_str(), (uint64_t)oldestPossibleOutput.timeslice.value, 20);
Expand Down Expand Up @@ -646,7 +646,7 @@ o2::framework::ServiceSpec CommonServices::decongestionSpec()
O2_SIGNPOST_EVENT_EMIT(data_processor_context, cid, "oldest_possible_timeslice", "Queueing oldest possible timeslice %" PRIu64 " propagation for execution.",
(uint64_t)oldestPossibleOutput.timeslice.value);
AsyncQueueHelpers::post(
queue, decongestion.oldestPossibleTimesliceTask, [oldestPossibleOutput, &decongestion, &proxy, &spec, device, &timesliceIndex]() {
queue, decongestion.oldestPossibleTimesliceTask, [ref = services, oldestPossibleOutput, &decongestion, &proxy, &spec, device, &timesliceIndex]() {
O2_SIGNPOST_ID_FROM_POINTER(cid, data_processor_context, &decongestion);
if (decongestion.lastTimeslice >= oldestPossibleOutput.timeslice.value) {
O2_SIGNPOST_EVENT_EMIT(data_processor_context, cid, "oldest_possible_timeslice", "Not sending already sent value: %" PRIu64 "> %" PRIu64,
Expand All @@ -655,7 +655,7 @@ o2::framework::ServiceSpec CommonServices::decongestionSpec()
}
O2_SIGNPOST_EVENT_EMIT(data_processor_context, cid, "oldest_possible_timeslice", "Running oldest possible timeslice %" PRIu64 " propagation.",
(uint64_t)oldestPossibleOutput.timeslice.value);
DataProcessingHelpers::broadcastOldestPossibleTimeslice(proxy, oldestPossibleOutput.timeslice.value);
DataProcessingHelpers::broadcastOldestPossibleTimeslice(ref, oldestPossibleOutput.timeslice.value);

for (int fi = 0; fi < proxy.getNumForwardChannels(); fi++) {
auto& info = proxy.getForwardChannelInfo(ChannelIndex{fi});
Expand All @@ -665,7 +665,7 @@ o2::framework::ServiceSpec CommonServices::decongestionSpec()
O2_SIGNPOST_EVENT_EMIT(data_processor_context, cid, "oldest_possible_timeslice", "Skipping channel %{public}s", info.name.c_str());
continue;
}
if (DataProcessingHelpers::sendOldestPossibleTimeframe(info, state, oldestPossibleOutput.timeslice.value)) {
if (DataProcessingHelpers::sendOldestPossibleTimeframe(ref, info, state, oldestPossibleOutput.timeslice.value)) {
O2_SIGNPOST_EVENT_EMIT(data_processor_context, cid, "oldest_possible_timeslice",
"Forwarding to channel %{public}s oldest possible timeslice %" PRIu64 ", priority %d",
info.name.c_str(), (uint64_t)oldestPossibleOutput.timeslice.value, 20);
Expand Down
22 changes: 6 additions & 16 deletions Framework/Core/src/DataProcessingDevice.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -680,19 +680,11 @@ static auto forwardInputs = [](ServiceRegistryRef registry, TimesliceSlot slot,
if (forwardedParts[fi].Size() == 0) {
continue;
}
auto channel = proxy.getForwardChannel(ChannelIndex{fi});
LOG(debug) << "Forwarding to " << channel->GetName() << " " << fi;
ForwardChannelInfo info = proxy.getForwardChannelInfo(ChannelIndex{fi});
LOG(debug) << "Forwarding to " << info.name << " " << fi;
// in DPL we are using subchannel 0 only
auto& parts = forwardedParts[fi];
int timeout = 30000;
auto res = channel->Send(parts, timeout);
if (res == (size_t)fair::mq::TransferCode::timeout) {
LOGP(warning, "Timed out sending after {}s. Downstream backpressure detected on {}.", timeout / 1000, channel->GetName());
channel->Send(parts);
LOGP(info, "Downstream backpressure on {} recovered.", channel->GetName());
} else if (res == (size_t)fair::mq::TransferCode::error) {
LOGP(fatal, "Error while sending on channel {}", channel->GetName());
}
info.policy->forward(parts, ChannelIndex{fi}, registry);
}

auto& asyncQueue = registry.get<AsyncQueue>();
Expand All @@ -713,7 +705,7 @@ static auto forwardInputs = [](ServiceRegistryRef registry, TimesliceSlot slot,
LOG(debug) << "Skipping channel";
continue;
}
if (DataProcessingHelpers::sendOldestPossibleTimeframe(info, state, oldestTimeslice.timeslice.value)) {
if (DataProcessingHelpers::sendOldestPossibleTimeframe(registry, info, state, oldestTimeslice.timeslice.value)) {
LOGP(debug, "Forwarding to channel {} oldest possible timeslice {}, prio 20", info.name, oldestTimeslice.timeslice.value);
}
}
Expand Down Expand Up @@ -1678,8 +1670,7 @@ void DataProcessingDevice::doRun(ServiceRegistryRef ref)

for (auto& channel : spec.outputChannels) {
LOGP(detail, "Sending end of stream to {}", channel.name);
auto& rawDevice = ref.get<RawDeviceService>();
DataProcessingHelpers::sendEndOfStream(*rawDevice.device(), channel);
DataProcessingHelpers::sendEndOfStream(ref, channel);
}
// This is needed because the transport is deleted before the device.
relayer.clear();
Expand Down Expand Up @@ -2460,8 +2451,7 @@ bool DataProcessingDevice::tryDispatchComputation(ServiceRegistryRef ref, std::v
if (state.streaming == StreamingState::EndOfStreaming) {
LOGP(detail, "Broadcasting end of stream");
for (auto& channel : spec.outputChannels) {
auto& rawDevice = ref.get<RawDeviceService>();
DataProcessingHelpers::sendEndOfStream(*rawDevice.device(), channel);
DataProcessingHelpers::sendEndOfStream(ref, channel);
}
switchState(StreamingState::Idle);
}
Expand Down
41 changes: 18 additions & 23 deletions Framework/Core/src/DataProcessingHelpers.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -18,79 +18,74 @@
#include "Headers/DataHeader.h"
#include "Headers/Stack.h"
#include "Framework/Logger.h"
#include "Framework/SendingPolicy.h"
#include "Framework/RawDeviceService.h"

#include <fairmq/Device.h>
#include <fairmq/Channel.h>

namespace o2::framework
{
/// Send an end-of-stream message on the channel described by @a channel.
/// The message is made of two parts: a header stack carrying a
/// SourceInfoHeader whose state is InputChannelState::Completed, and a
/// payload created via NewMessage() without arguments.
/// NOTE(review): the parts are sent directly through the device on
/// subchannel 0, not through a SendingPolicy callback.
void DataProcessingHelpers::sendEndOfStream(fair::mq::Device& device, OutputChannelSpec const& channel)
void DataProcessingHelpers::sendEndOfStream(ServiceRegistryRef const& ref, OutputChannelSpec const& channel)
{
// The device is obtained from the RawDeviceService held by the registry.
fair::mq::Device* device = ref.get<RawDeviceService>().device();
fair::mq::Parts parts;
fair::mq::MessagePtr payload(device.NewMessage());
fair::mq::MessagePtr payload(device->NewMessage());
SourceInfoHeader sih;
sih.state = InputChannelState::Completed;
// The header is allocated with the allocator of the transport associated
// to subchannel 0 of the named channel.
auto channelAlloc = o2::pmr::getTransportAllocator(device.GetChannel(channel.name, 0).Transport());
auto channelAlloc = o2::pmr::getTransportAllocator(device->GetChannel(channel.name, 0).Transport());
auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, sih});
// sigh... See if we can avoid having it const by not
// exposing it to the user in the first place.
parts.AddPart(std::move(header));
parts.AddPart(std::move(payload));
device.Send(parts, channel.name, 0);
device->Send(parts, channel.name, 0);
// Note: this is logged after the send call has returned.
LOGP(info, "Sending end-of-stream message to channel {}", channel.name);
}

/// Build the two-part message announcing the oldest possible timeslice
/// (a header stack carrying a DomainInfoHeader plus a payload created via
/// CreateMessage() without arguments) and hand it to @a callback.
/// @param ref the service registry, passed through to the callback
/// @param transport the transport used to create the message parts
/// @param index the index of the channel, passed through to the callback
/// @param callback the callback which performs the actual sending
/// @param timeslice the value stored in DomainInfoHeader::oldestPossibleTimeslice
/// NOTE(review): this helper is defined at namespace scope without `static`
/// and is not declared in the DataProcessingHelpers header shown in this
/// change; consider giving it internal linkage.
void DataProcessingHelpers::sendOldestPossibleTimeframe(fair::mq::Channel& channel, size_t timeslice)
void doSendOldestPossibleTimeframe(ServiceRegistryRef ref, fair::mq::TransportFactory* transport, ChannelIndex index, SendingPolicy::SendingCallback const& callback, size_t timeslice)
{
fair::mq::Parts parts;
fair::mq::MessagePtr payload(channel.Transport()->CreateMessage());
fair::mq::MessagePtr payload(transport->CreateMessage());
o2::framework::DomainInfoHeader dih;
dih.oldestPossibleTimeslice = timeslice;
auto channelAlloc = o2::pmr::getTransportAllocator(channel.Transport());
auto channelAlloc = o2::pmr::getTransportAllocator(transport);
auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dih});
// sigh... See if we can avoid having it const by not
// exposing it to the user in the first place.
parts.AddPart(std::move(header));
parts.AddPart(std::move(payload));

// Direct send on the channel: try with a timeout first, then fall back to
// a send without timeout if the first attempt timed out.
auto timeout = 1000;
auto res = channel.Send(parts, timeout);
if (res == (size_t)fair::mq::TransferCode::timeout) {
LOGP(warning, "Timed out sending oldest possible timeslice after {}s. Downstream backpressure detected on {}.", timeout / 1000, channel.GetName());
channel.Send(parts);
LOGP(info, "Downstream backpressure on {} recovered.", channel.GetName());
}
if (res < (size_t)fair::mq::TransferCode::success) {
LOGP(fatal, "Error sending oldest possible timeframe {} on channel {} (code {})", timeslice, channel.GetName(), res);
}
// The actual sending is delegated to the callback provided by the caller.
callback(parts, index, ref);
}

/// Notify a forward channel about the oldest possible timeslice.
/// Nothing is sent if the value recorded in @a state is already greater
/// than or equal to @a timeslice; otherwise the message is sent and
/// @a state is updated with the new value.
/// @returns true if the message was sent, false if it was skipped.
/// NOTE(review): the `forward` callback of the ForwardingPolicy is passed
/// where a SendingPolicy::SendingCallback is expected; this relies on the two
/// callback types being the same or convertible - confirm against
/// SendingPolicy.h.
bool DataProcessingHelpers::sendOldestPossibleTimeframe(ForwardChannelInfo const& info, ForwardChannelState& state, size_t timeslice)
bool DataProcessingHelpers::sendOldestPossibleTimeframe(ServiceRegistryRef const& ref, ForwardChannelInfo const& info, ForwardChannelState& state, size_t timeslice)
{
// Skip values which are not newer than what was already sent on this channel.
if (state.oldestForChannel.value >= timeslice) {
return false;
}
sendOldestPossibleTimeframe(info.channel, timeslice);
doSendOldestPossibleTimeframe(ref, info.channel.Transport(), info.index, info.policy->forward, timeslice);
// Remember what was sent so that later calls can skip older or equal values.
state.oldestForChannel = {timeslice};
return true;
}

/// Notify an output channel about the oldest possible timeslice.
/// Nothing is sent if the value recorded in @a state is already greater
/// than or equal to @a timeslice; otherwise the message is sent using the
/// `send` callback of the channel's SendingPolicy and @a state is updated
/// with the new value.
/// @returns true if the message was sent, false if it was skipped.
bool DataProcessingHelpers::sendOldestPossibleTimeframe(OutputChannelInfo const& info, OutputChannelState& state, size_t timeslice)
bool DataProcessingHelpers::sendOldestPossibleTimeframe(ServiceRegistryRef const& ref, OutputChannelInfo const& info, OutputChannelState& state, size_t timeslice)
{
// Skip values which are not newer than what was already sent on this channel.
if (state.oldestForChannel.value >= timeslice) {
return false;
}
sendOldestPossibleTimeframe(info.channel, timeslice);
doSendOldestPossibleTimeframe(ref, info.channel.Transport(), info.index, info.policy->send, timeslice);
// Remember what was sent so that later calls can skip older or equal values.
state.oldestForChannel = {timeslice};
return true;
}

/// Send the oldest possible timeslice to every output channel known to the
/// FairMQDeviceProxy. Channels whose recorded value is already greater than
/// or equal to @a timeslice are skipped by sendOldestPossibleTimeframe; its
/// return value is not inspected here.
void DataProcessingHelpers::broadcastOldestPossibleTimeslice(FairMQDeviceProxy& proxy, size_t timeslice)
void DataProcessingHelpers::broadcastOldestPossibleTimeslice(ServiceRegistryRef const& ref, size_t timeslice)
{
// The proxy is looked up from the service registry.
auto& proxy = ref.get<FairMQDeviceProxy>();
for (int ci = 0; ci < proxy.getNumOutputChannels(); ++ci) {
// Info and state are looked up with the same channel index.
auto& info = proxy.getOutputChannelInfo({ci});
auto& state = proxy.getOutputChannelState({ci});
sendOldestPossibleTimeframe(info, state, timeslice);
sendOldestPossibleTimeframe(ref, info, state, timeslice);
}
}

Expand Down
35 changes: 24 additions & 11 deletions Framework/Core/src/DeviceSpecHelpers.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ void DeviceSpecHelpers::processOutEdgeActions(ConfigContext const& configContext
const std::vector<OutputSpec>& outputsMatchers,
const std::vector<ChannelConfigurationPolicy>& channelPolicies,
const std::vector<SendingPolicy>& sendingPolicies,
const std::vector<ForwardingPolicy>& forwardingPolicies,
std::string const& channelPrefix,
ComputingOffer const& defaultOffer,
OverrideServiceSpecs const& overrideServices)
Expand Down Expand Up @@ -656,7 +657,7 @@ void DeviceSpecHelpers::processOutEdgeActions(ConfigContext const& configContext
// whether this is a real OutputRoute or if it's a forward from
// a previous consumer device.
// FIXME: where do I find the InputSpec for the forward?
auto appendOutputRouteToSourceDeviceChannel = [&outputsMatchers, &workflow, &devices, &logicalEdges, &sendingPolicies, &configContext](
auto appendOutputRouteToSourceDeviceChannel = [&outputsMatchers, &workflow, &devices, &logicalEdges, &sendingPolicies, &forwardingPolicies, &configContext](
size_t ei, size_t di, size_t ci) {
assert(ei < logicalEdges.size());
assert(di < devices.size());
Expand All @@ -670,29 +671,40 @@ void DeviceSpecHelpers::processOutEdgeActions(ConfigContext const& configContext
assert(edge.outputGlobalIndex < outputsMatchers.size());
// Iterate over all the policies and apply the first one that matches.
SendingPolicy const* policyPtr = nullptr;
ForwardingPolicy const* forwardPolicyPtr = nullptr;
for (auto& policy : sendingPolicies) {
if (policy.matcher(producer, consumer, configContext)) {
policyPtr = &policy;
break;
}
}
assert(forwardingPolicies.empty() == false);
for (auto& policy : forwardingPolicies) {
if (policy.matcher(producer, consumer, configContext)) {
forwardPolicyPtr = &policy;
break;
}
}
assert(policyPtr != nullptr);
assert(forwardPolicyPtr != nullptr);

if (edge.isForward == false) {
OutputRoute route{
edge.timeIndex,
consumer.maxInputTimeslices,
outputsMatchers[edge.outputGlobalIndex],
channel.name,
policyPtr,
.timeslice = edge.timeIndex,
.maxTimeslices = consumer.maxInputTimeslices,
.matcher = outputsMatchers[edge.outputGlobalIndex],
.channel = channel.name,
.policy = policyPtr,
};
device.outputs.emplace_back(route);
} else {
ForwardRoute route{
edge.timeIndex,
consumer.maxInputTimeslices,
workflow[edge.consumer].inputs[edge.consumerInputIndex],
channel.name};
.timeslice = edge.timeIndex,
.maxTimeslices = consumer.maxInputTimeslices,
.matcher = workflow[edge.consumer].inputs[edge.consumerInputIndex],
.channel = channel.name,
.policy = forwardPolicyPtr,
};
device.forwards.emplace_back(route);
}
};
Expand Down Expand Up @@ -1051,6 +1063,7 @@ void DeviceSpecHelpers::dataProcessorSpecs2DeviceSpecs(const WorkflowSpec& workf
std::vector<ResourcePolicy> const& resourcePolicies,
std::vector<CallbacksPolicy> const& callbacksPolicies,
std::vector<SendingPolicy> const& sendingPolicies,
std::vector<ForwardingPolicy> const& forwardingPolicies,
std::vector<DeviceSpec>& devices,
ResourceManager& resourceManager,
std::string const& uniqueWorkflowId,
Expand Down Expand Up @@ -1111,7 +1124,7 @@ void DeviceSpecHelpers::dataProcessorSpecs2DeviceSpecs(const WorkflowSpec& workf
defaultOffer.memory /= deviceCount + 1;

processOutEdgeActions(configContext, devices, deviceIndex, connections, resourceManager, outEdgeIndex, logicalEdges,
outActions, workflow, outputs, channelPolicies, sendingPolicies, channelPrefix, defaultOffer, overrideServices);
outActions, workflow, outputs, channelPolicies, sendingPolicies, forwardingPolicies, channelPrefix, defaultOffer, overrideServices);

// FIXME: is this not the case???
std::sort(connections.begin(), connections.end());
Expand Down
Loading

0 comments on commit b868bb6

Please sign in to comment.