[executorch] Migrate most of extension/... to new namespace
Pull Request resolved: #4617

Migrate these headers to the new `::executorch::extension` namespace. Add temporary aliases from the old `::torch::executor` namespace so we can migrate users incrementally.
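The pattern, in a minimal sketch (`do_work` is a hypothetical placeholder, not a real ExecuTorch symbol):

```cpp
// New canonical home for the API.
namespace executorch {
namespace extension {
void do_work();
} // namespace extension
} // namespace executorch

// Deprecated alias kept so existing callers keep compiling during migration.
namespace torch {
namespace executor {
// TODO(T197294990): Remove once all users have moved to the new
// `::executorch` namespaces.
using ::executorch::extension::do_work;
} // namespace executor
} // namespace torch
```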

ghstack-source-id: 239152036
@exported-using-ghexport

Differential Revision: [D60938936](https://our.internmc.facebook.com/intern/diff/D60938936/)
dbort committed Aug 21, 2024
1 parent b66d62a commit 48e2d57
Showing 49 changed files with 715 additions and 429 deletions.
22 changes: 13 additions & 9 deletions extension/aten_util/aten_bridge.cpp
@@ -11,8 +11,8 @@
#include <executorch/runtime/platform/assert.h>
#include <cstring>

namespace torch {
namespace util {
namespace executorch {
namespace extension {

namespace {
void check_tensor_meta(const at::Tensor& a, const exec_aten::Tensor& b) {
@@ -55,14 +55,15 @@ ET_CHECK_MSG(
}
// check dtype
ET_CHECK_MSG(
b.scalar_type() == torchToExecuTorchScalarType(a.options().dtype()),
b.scalar_type() == torch_to_executorch_scalar_type(a.options().dtype()),
"dtypes dont match a %hhd vs. b %hhd",
torchToExecuTorchScalarType(a.options().dtype()),
torch_to_executorch_scalar_type(a.options().dtype()),
b.scalar_type());
}
} // namespace

torch::executor::ScalarType torchToExecuTorchScalarType(caffe2::TypeMeta type) {
torch::executor::ScalarType torch_to_executorch_scalar_type(
caffe2::TypeMeta type) {
switch (c10::typeMetaToScalarType(type)) {
case c10::ScalarType::Byte:
return torch::executor::ScalarType::Byte;
Expand Down Expand Up @@ -91,7 +92,8 @@ torch::executor::ScalarType torchToExecuTorchScalarType(caffe2::TypeMeta type) {
}
}

c10::ScalarType execuTorchtoTorchScalarType(torch::executor::ScalarType type) {
c10::ScalarType executorch_to_torch_scalar_type(
torch::executor::ScalarType type) {
switch (type) {
case torch::executor::ScalarType::Byte:
return c10::ScalarType::Byte;
@@ -147,7 +149,8 @@ void alias_etensor_to_attensor(
}

at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& etensor) {
c10::ScalarType dtype = execuTorchtoTorchScalarType(etensor.scalar_type());
c10::ScalarType dtype =
executorch_to_torch_scalar_type(etensor.scalar_type());
std::vector<int64_t> at_tensor_sizes(
etensor.sizes().begin(), etensor.sizes().end());
std::vector<int64_t> at_tensor_strides(
@@ -162,5 +165,6 @@ at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& etensor) {
check_tensor_meta(t, etensor);
return t;
}
} // namespace util
} // namespace torch

} // namespace extension
} // namespace executorch
42 changes: 38 additions & 4 deletions extension/aten_util/aten_bridge.h
@@ -18,12 +18,14 @@
#include <memory>
#include <vector>

namespace torch {
namespace util {
namespace executorch {
namespace extension {

torch::executor::ScalarType torchToExecuTorchScalarType(caffe2::TypeMeta type);
torch::executor::ScalarType torch_to_executorch_scalar_type(
caffe2::TypeMeta type);

c10::ScalarType execuTorchtoTorchScalarType(torch::executor::ScalarType type);
c10::ScalarType executorch_to_torch_scalar_type(
torch::executor::ScalarType type);

/*
* @param[in] aten_tensor Input at::Tensor
@@ -45,5 +47,37 @@ void alias_etensor_to_attensor(at::Tensor& at, torch::executor::Tensor& et);
* cloned.
*/
at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& et);

} // namespace extension
} // namespace executorch

namespace torch {
namespace executor {
namespace util {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::alias_attensor_to_etensor;
using ::executorch::extension::alias_etensor_to_attensor;
inline torch::executor::ScalarType torchToExecuTorchScalarType(
caffe2::TypeMeta type) {
return ::executorch::extension::torch_to_executorch_scalar_type(type);
}
inline c10::ScalarType execuTorchtoTorchScalarType(
torch::executor::ScalarType type) {
return ::executorch::extension::executorch_to_torch_scalar_type(type);
}
} // namespace util
} // namespace executor
} // namespace torch

// Some users refer to these as `torch::util::`.
namespace torch {
namespace util {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::torch::executor::util::alias_attensor_to_etensor;
using ::torch::executor::util::alias_etensor_to_attensor;
using ::torch::executor::util::execuTorchtoTorchScalarType;
using ::torch::executor::util::torchToExecuTorchScalarType;
} // namespace util
} // namespace torch
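For illustration, both spellings resolve to the same implementation while the deprecated aliases remain (a usage sketch; `example` is a hypothetical caller, not part of this change):

```cpp
#include <executorch/extension/aten_util/aten_bridge.h>

void example(const at::Tensor& t) {
  // New, canonical spelling.
  auto via_new_ns = ::executorch::extension::torch_to_executorch_scalar_type(
      t.options().dtype());
  // Deprecated spelling, forwarded through the inline alias above.
  auto via_old_ns = ::torch::executor::util::torchToExecuTorchScalarType(
      t.options().dtype());
  (void)via_new_ns;
  (void)via_old_ns;
}
```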
61 changes: 39 additions & 22 deletions extension/aten_util/make_aten_functor_from_et_functor.h
@@ -24,8 +24,9 @@
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <torch/torch.h>

namespace torch {
namespace executor {
namespace executorch {
namespace extension {
namespace internal {

// Map types from ETen to ATen.
// This is used to convert ETen arguments into ATen.
@@ -105,29 +106,35 @@ struct type_convert<
torch::executor::Tensor>>>
final {
explicit type_convert(ATensor value) : value_(value) {
auto sizes = std::make_shared<std::vector<Tensor::SizesType>>(
value_.sizes().begin(), value_.sizes().end());
auto sizes =
std::make_shared<std::vector<torch::executor::Tensor::SizesType>>(
value_.sizes().begin(), value_.sizes().end());
const ssize_t dim = sizes->size();
auto dim_order = std::make_shared<std::vector<Tensor::DimOrderType>>(dim);
auto strides = std::make_shared<std::vector<Tensor::StridesType>>(dim);
auto dim_order =
std::make_shared<std::vector<torch::executor::Tensor::DimOrderType>>(
dim);
auto strides =
std::make_shared<std::vector<torch::executor::Tensor::StridesType>>(
dim);

std::iota(dim_order->begin(), dim_order->end(), 0);
dim_order_to_stride_nocheck(
::executorch::runtime::dim_order_to_stride_nocheck(
sizes->data(), dim_order->data(), dim, strides->data());

auto tensor_impl = std::make_shared<TensorImpl>(
auto tensor_impl = std::make_shared<torch::executor::TensorImpl>(
static_cast<torch::executor::ScalarType>(value_.scalar_type()),
sizes->size(),
sizes->data(),
value_.mutable_data_ptr(),
dim_order->data(),
strides->data());

converted_ = std::unique_ptr<Tensor, std::function<void(Tensor*)>>(
new Tensor(tensor_impl.get()),
[sizes, dim_order, strides, tensor_impl](Tensor* pointer) {
delete pointer;
});
converted_ = std::unique_ptr<
torch::executor::Tensor,
std::function<void(torch::executor::Tensor*)>>(
new torch::executor::Tensor(tensor_impl.get()),
[sizes, dim_order, strides, tensor_impl](
torch::executor::Tensor* pointer) { delete pointer; });
}

ETensor call() {
@@ -136,7 +143,10 @@

private:
ATensor value_;
std::unique_ptr<Tensor, std::function<void(Tensor*)>> converted_;
std::unique_ptr<
torch::executor::Tensor,
std::function<void(torch::executor::Tensor*)>>
converted_;
};

// Tensors: ETen to ATen.
@@ -258,7 +268,12 @@ struct wrapper_impl<R (*)(Args...), f, int, N> {
using TupleArgsType = std::tuple<typename type_map<Args>::type...>;
static constexpr size_t num_args = sizeof...(Args);
static_assert(
(N < num_args && std::is_same_v<element_t<N, typelist<Args...>>, R>) ||
(N < num_args &&
std::is_same_v<
executorch::extension::kernel_util_internal::element_t<
N,
executorch::extension::kernel_util_internal::typelist<Args...>>,
R>) ||
N == -1,
"The index of the out tensor can't be greater or equal to num_args and "
"the Nth argument type has to be the same as the return type.");
@@ -298,16 +313,18 @@
}
};

} // namespace executor
} // namespace torch
} // namespace internal
} // namespace extension
} // namespace executorch

// Wrapper macro for out variant function. N is the index of the out tensor.
// We need N to know how to preserve the semantics of modifying out tensor and
// return the reference without allocating a new memory buffer for out tensor.
#define _WRAP_2(func, N) \
::torch::executor::wrapper_impl<decltype(&func), func, decltype(N), N>::wrap
#define _WRAP_2(func, N) \
::executorch::extension::internal:: \
wrapper_impl<decltype(&func), func, decltype(N), N>::wrap
#define _WRAP_1(func) \
::torch::executor::wrapper_impl<decltype(&func), func>::wrap
::executorch::extension::internal::wrapper_impl<decltype(&func), func>::wrap

#define GET_MACRO(_1, _2, NAME, ...) NAME
#define WRAP_TO_ATEN(...) GET_MACRO(__VA_ARGS__, _WRAP_2, _WRAP_1)(__VA_ARGS__)
#define _GET_MACRO(_1, _2, NAME, ...) NAME
#define WRAP_TO_ATEN(...) _GET_MACRO(__VA_ARGS__, _WRAP_2, _WRAP_1)(__VA_ARGS__)
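As a rough usage sketch of the macro (the library name `my_custom` and the schema string are illustrative assumptions; `my_op_out` is an ExecuTorch-style out-variant kernel like the one defined in the test file below):

```cpp
#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
#include <torch/library.h>

// ExecuTorch-style kernel; the out tensor is argument index 1.
torch::executor::Tensor& my_op_out(
    const torch::executor::Tensor& a,
    torch::executor::Tensor& out);

TORCH_LIBRARY(my_custom, m) {
  // WRAP_TO_ATEN(my_op_out, 1) adapts the kernel to ATen types and returns
  // a reference to the out argument instead of allocating a new tensor.
  m.def(
      "my_op.out(Tensor a, *, Tensor(a!) out) -> Tensor(a!)",
      WRAP_TO_ATEN(my_op_out, 1));
}
```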
2 changes: 1 addition & 1 deletion extension/aten_util/test/aten_bridge_test.cpp
@@ -16,8 +16,8 @@
#include <gtest/gtest.h>

using namespace ::testing;
using namespace torch::util;
using namespace torch::executor;
using namespace torch::executor::util;

namespace {
at::Tensor generate_at_tensor() {
extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp
@@ -14,10 +14,11 @@
#include <gtest/gtest.h>
#include <torch/library.h>

namespace torch {
namespace executor {

using namespace ::testing;
using ::executorch::extension::internal::type_convert;
using ::executorch::extension::internal::type_map;
using ::torch::executor::ScalarType;
using ::torch::executor::Tensor;

Tensor& my_op_out(const Tensor& a, Tensor& out) {
(void)a;
@@ -420,6 +421,3 @@ TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_ArrayRefOptional) {
EXPECT_EQ(stack.size(), 1);
EXPECT_EQ(stack[0].toTensor().const_data_ptr<int64_t>()[0], 4);
}

} // namespace executor
} // namespace torch
28 changes: 19 additions & 9 deletions extension/data_loader/buffer_data_loader.h
@@ -14,9 +14,8 @@
#include <executorch/runtime/platform/log.h>
#include <cstring>

namespace torch {
namespace executor {
namespace util {
namespace executorch {
namespace extension {

/**
* A DataLoader that wraps a pre-allocated buffer. The FreeableBuffers
Expand All @@ -25,12 +24,13 @@ namespace util {
* This can be used to wrap data that is directly embedded into the firmware
* image, or to wrap data that was allocated elsewhere.
*/
class BufferDataLoader final : public DataLoader {
class BufferDataLoader final : public executorch::runtime::DataLoader {
public:
BufferDataLoader(const void* data, size_t size)
: data_(reinterpret_cast<const uint8_t*>(data)), size_(size) {}

ET_NODISCARD Result<FreeableBuffer> load(
ET_NODISCARD
executorch::runtime::Result<executorch::runtime::FreeableBuffer> load(
size_t offset,
size_t size,
ET_UNUSED const DataLoader::SegmentInfo& segment_info) const override {
@@ -41,14 +41,15 @@ class BufferDataLoader final : public DataLoader {
offset,
size,
size_);
return FreeableBuffer(data_ + offset, size, /*free_fn=*/nullptr);
return executorch::runtime::FreeableBuffer(
data_ + offset, size, /*free_fn=*/nullptr);
}

ET_NODISCARD Result<size_t> size() const override {
ET_NODISCARD executorch::runtime::Result<size_t> size() const override {
return size_;
}

ET_NODISCARD Error load_into(
ET_NODISCARD executorch::runtime::Error load_into(
size_t offset,
size_t size,
ET_UNUSED const SegmentInfo& segment_info,
@@ -63,14 +64,23 @@
return result.error();
}
std::memcpy(buffer, result->data(), size);
return Error::Ok;
return executorch::runtime::Error::Ok;
}

private:
const uint8_t* const data_; // uint8 is easier to index into.
const size_t size_;
};

} // namespace extension
} // namespace executorch

namespace torch {
namespace executor {
namespace util {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::BufferDataLoader;
} // namespace util
} // namespace executor
} // namespace torch
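A hedged usage sketch (assumes `model_data`/`model_size` describe a caller-owned buffer that outlives the loader, and that `DataLoader::SegmentInfo` is default-constructible; both names are stand-ins, not part of this diff):

```cpp
#include <executorch/extension/data_loader/buffer_data_loader.h>

void read_header(const void* model_data, size_t model_size) {
  executorch::extension::BufferDataLoader loader(model_data, model_size);

  // size() simply reports the wrapped buffer's length.
  auto total = loader.size();

  // load() returns a FreeableBuffer pointing directly into the wrapped
  // buffer; no copy is made. SegmentInfo() is assumed to default to a
  // Program-type segment.
  auto header = loader.load(
      /*offset=*/0,
      /*size=*/16,
      executorch::runtime::DataLoader::SegmentInfo());
  if (header.ok()) {
    const void* p = header->data();
    (void)p;
  }
  (void)total;
}
```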
15 changes: 9 additions & 6 deletions extension/data_loader/file_data_loader.cpp
@@ -34,9 +34,13 @@
#define ET_HAVE_PREAD 1
#endif // !ET_HAVE_PREAD

namespace torch {
namespace executor {
namespace util {
using executorch::runtime::Error;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::Result;

namespace executorch {
namespace extension {

namespace {

/**
Expand Down Expand Up @@ -287,6 +291,5 @@ ET_NODISCARD Error FileDataLoader::load_into(
return Error::Ok;
}

} // namespace util
} // namespace executor
} // namespace torch
} // namespace extension
} // namespace executorch