Skip to content

Commit

Permalink
DPL: finally use concepts to separate the make method (#13611)
Browse files Browse the repository at this point in the history
  • Loading branch information
ktf authored Oct 28, 2024
1 parent ffc81fd commit 804f7a2
Showing 1 changed file with 104 additions and 85 deletions.
189 changes: 104 additions & 85 deletions Framework/Core/include/Framework/DataAllocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <TClass.h>
#include <gsl/span>

#include <memory>
#include <vector>
#include <map>
#include <string>
Expand All @@ -57,14 +58,6 @@ namespace o2::framework
{
struct ServiceRegistry;

#define ERROR_STRING \
"data type T not supported by API, " \
"\n specializations available for" \
"\n - trivially copyable, non-polymorphic structures" \
"\n - arrays of those" \
"\n - TObject with additional constructor arguments" \
"\n - std containers of those"

/// Helper to allow framework managed objecs to have a callback
/// when they go out of scope. For example, this could
/// be used to serialize a message into a buffer before the
Expand Down Expand Up @@ -130,6 +123,10 @@ struct LifetimeHolder {
}
};

template <typename T>
concept VectorOfMessageableTypes = is_specialization_v<T, std::vector> &&
is_messageable<typename T::value_type>::value;

/// This allocator is responsible to make sure that the messages created match
/// the provided spec and that depending on how many pipelined reader we
/// have, messages get created on the channel for the reader of the current
Expand All @@ -143,6 +140,7 @@ class DataAllocator
using DataOrigin = o2::header::DataOrigin;
using DataDescription = o2::header::DataDescription;
using SubSpecificationType = o2::header::DataHeader::SubSpecificationType;

template <typename T>
requires std::is_fundamental_v<T>
struct UninitializedVector {
Expand All @@ -163,93 +161,114 @@ class DataAllocator
// and with subspecification 0xdeadbeef.
void cookDeadBeef(const Output& spec);

/// Generic helper to create an object which is owned by the framework and
/// returned as a reference to the own object.
/// Note: decltype(auto) will deduce the return type from the expression and it
/// will be lvalue reference for the framework-owned objects. Instances of local
/// variables like shared_ptr will be returned by value/move/return value optimization.
/// Objects created this way will be sent to the channel specified by @spec
template <typename T, typename... Args>
requires is_specialization_v<T, o2::framework::DataAllocator::UninitializedVector>
decltype(auto) make(const Output& spec, Args... args)
{
auto& timingInfo = mRegistry.get<TimingInfo>();
auto& context = mRegistry.get<MessageContext>();

if constexpr (is_specialization_v<T, UninitializedVector>) {
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
// plain buffer as polymorphic spectator std::vector, which does not run constructors / destructors
using ValueType = typename T::value_type;

// Note: initial payload size is 0 and will be set by the context before sending
fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, 0);
return context.add<MessageContext::VectorObject<ValueType, MessageContext::ContainerRefObject<std::vector<ValueType, o2::pmr::NoConstructAllocator<ValueType>>>>>(
std::move(headerMessage), routeIndex, 0, std::forward<Args>(args)...)
.get();
} else if constexpr (is_specialization_v<T, std::vector> && has_messageable_value_type<T>::value) {
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
// this catches all std::vector objects with messageable value type before checking if is also
// has a root dictionary, so non-serialized transmission is preferred
using ValueType = typename T::value_type;

// Note: initial payload size is 0 and will be set by the context before sending
fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, 0);
return context.add<MessageContext::VectorObject<ValueType>>(std::move(headerMessage), routeIndex, 0, std::forward<Args>(args)...).get();
} else if constexpr (has_root_dictionary<T>::value == true && is_messageable<T>::value == false) {
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
// Extended support for types implementing the Root ClassDef interface, both TObject
// derived types and others
if constexpr (enable_root_serialization<T>::value) {
fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodROOT, 0);

return context.add<typename enable_root_serialization<T>::object_type>(std::move(headerMessage), routeIndex, std::forward<Args>(args)...).get();
} else {
static_assert(enable_root_serialization<T>::value, "Please make sure you include RootMessageContext.h");
}
// Note: initial payload size is 0 and will be set by the context before sending
} else if constexpr (std::is_base_of_v<std::string, T>) {
auto* s = new std::string(args...);
adopt(spec, s);
return *s;
} else if constexpr (requires { static_cast<struct TableBuilder>(std::declval<std::decay_t<T>>()); }) {
auto tb = std::move(LifetimeHolder<TableBuilder>(new std::decay_t<T>(args...)));
adopt(spec, tb);
return tb;
} else if constexpr (requires { static_cast<struct TreeToTable>(std::declval<std::decay_t<T>>()); }) {
auto t2t = std::move(LifetimeHolder<TreeToTable>(new std::decay_t<T>(args...)));
adopt(spec, t2t);
return t2t;
} else if constexpr (sizeof...(Args) == 0) {
if constexpr (is_messageable<T>::value == true) {
return *reinterpret_cast<T*>(newChunk(spec, sizeof(T)).data());
} else {
static_assert(always_static_assert_v<T>, ERROR_STRING);
}
} else if constexpr (sizeof...(Args) == 1) {
using FirstArg = typename std::tuple_element<0, std::tuple<Args...>>::type;
if constexpr (std::is_integral_v<FirstArg>) {
if constexpr (is_messageable<T>::value == true) {
auto [nElements] = std::make_tuple(args...);
auto size = nElements * sizeof(T);
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);

fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, size);
return context.add<MessageContext::SpanObject<T>>(std::move(headerMessage), routeIndex, 0, nElements).get();
}
} else if constexpr (std::is_same_v<FirstArg, std::shared_ptr<arrow::Schema>>) {
if constexpr (std::is_base_of_v<arrow::ipc::RecordBatchWriter, T>) {
auto [schema] = std::make_tuple(args...);
std::shared_ptr<arrow::ipc::RecordBatchWriter> writer;
create(spec, &writer, schema);
return writer;
}
} else {
static_assert(always_static_assert_v<T>, ERROR_STRING);
}
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
// plain buffer as polymorphic spectator std::vector, which does not run constructors / destructors
using ValueType = typename T::value_type;

// Note: initial payload size is 0 and will be set by the context before sending
fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, 0);
return context.add<MessageContext::VectorObject<ValueType, MessageContext::ContainerRefObject<std::vector<ValueType, o2::pmr::NoConstructAllocator<ValueType>>>>>(
std::move(headerMessage), routeIndex, 0, std::forward<Args>(args)...)
.get();
}

template <typename T, typename... Args>
requires VectorOfMessageableTypes<T>
decltype(auto) make(const Output& spec, Args... args)
{
auto& timingInfo = mRegistry.get<TimingInfo>();
auto& context = mRegistry.get<MessageContext>();

auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
// this catches all std::vector objects with messageable value type before checking if is also
// has a root dictionary, so non-serialized transmission is preferred
using ValueType = typename T::value_type;

// Note: initial payload size is 0 and will be set by the context before sending
fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, 0);
return context.add<MessageContext::VectorObject<ValueType>>(std::move(headerMessage), routeIndex, 0, std::forward<Args>(args)...).get();
}

template <typename T, typename... Args>
requires(!VectorOfMessageableTypes<T> && has_root_dictionary<T>::value == true && is_messageable<T>::value == false)
decltype(auto) make(const Output& spec, Args... args)
{
auto& timingInfo = mRegistry.get<TimingInfo>();
auto& context = mRegistry.get<MessageContext>();

auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
// Extended support for types implementing the Root ClassDef interface, both TObject
// derived types and others
if constexpr (enable_root_serialization<T>::value) {
fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodROOT, 0);

return context.add<typename enable_root_serialization<T>::object_type>(std::move(headerMessage), routeIndex, std::forward<Args>(args)...).get();
} else {
static_assert(always_static_assert_v<T>, ERROR_STRING);
static_assert(enable_root_serialization<T>::value, "Please make sure you include RootMessageContext.h");
}
}

template <typename T, typename... Args>
requires std::is_base_of_v<std::string, T>
decltype(auto) make(const Output& spec, Args... args)
{
auto* s = new std::string(args...);
adopt(spec, s);
return *s;
}

template <typename T, typename... Args>
requires(requires { static_cast<struct TableBuilder>(std::declval<std::decay_t<T>>()); })
decltype(auto) make(const Output& spec, Args... args)
{
auto tb = std::move(LifetimeHolder<TableBuilder>(new std::decay_t<T>(args...)));
adopt(spec, tb);
return tb;
}

template <typename T, typename... Args>
requires(requires { static_cast<struct TreeToTable>(std::declval<std::decay_t<T>>()); })
decltype(auto) make(const Output& spec, Args... args)
{
auto t2t = std::move(LifetimeHolder<TreeToTable>(new std::decay_t<T>(args...)));
adopt(spec, t2t);
return t2t;
}

template <typename T>
requires is_messageable<T>::value && (!is_specialization_v<T, UninitializedVector>)
decltype(auto) make(const Output& spec)
{
return *reinterpret_cast<T*>(newChunk(spec, sizeof(T)).data());
}

template <typename T>
requires is_messageable<T>::value && (!is_specialization_v<T, UninitializedVector>)
decltype(auto) make(const Output& spec, std::integral auto nElements)
{
auto& timingInfo = mRegistry.get<TimingInfo>();
auto& context = mRegistry.get<MessageContext>();
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);

fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, nElements * sizeof(T));
return context.add<MessageContext::SpanObject<T>>(std::move(headerMessage), routeIndex, 0, nElements).get();
}

template <typename T, typename Arg>
decltype(auto) make(const Output& spec, std::same_as<std::shared_ptr<arrow::Schema>> auto schema)
{
std::shared_ptr<arrow::ipc::RecordBatchWriter> writer;
create(spec, &writer, schema);
return writer;
}

/// Adopt a string in the framework and serialize / send
/// it to the consumers of @a spec once done.
void
Expand Down

0 comments on commit 804f7a2

Please sign in to comment.