
Commit 10c1fd5

Format

lockshaw committed Feb 9, 2025
1 parent: a4de9be
Showing 5 changed files with 35 additions and 23 deletions.
lib/models/src/models/dlrm/dlrm.cc (6 changes: 3 additions & 3 deletions)
@@ -129,8 +129,7 @@ ComputationGraph get_dlrm_computation_graph(DLRMConfig const &config) {
   std::vector<tensor_guid_t> sparse_inputs =
       repeat(num_elements(config.embedding_size), [&]() {
         return create_input_tensor(
-            {config.batch_size, config.embedding_bag_size},
-            DataType::INT64);
+            {config.batch_size, config.embedding_bag_size}, DataType::INT64);
       });
 
   tensor_guid_t dense_input = create_input_tensor(
@@ -146,7 +145,8 @@ ComputationGraph get_dlrm_computation_graph(DLRMConfig const &config) {
 
   std::vector<tensor_guid_t> emb_outputs = transform(
       zip(config.embedding_size, sparse_inputs),
-      [&](std::pair<nonnegative_int, tensor_guid_t> const &combined_pair) -> tensor_guid_t {
+      [&](std::pair<nonnegative_int, tensor_guid_t> const &combined_pair)
+          -> tensor_guid_t {
         return create_dlrm_sparse_embedding_network(
             /*cgb=*/cgb,
             /*config=*/config,
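Note: the lambda reflow in the second hunk is the usual clang-format behavior — once the parameter list fills the column limit, the trailing return type drops to its own continuation line. A minimal self-contained sketch of the same pattern, using std::transform and std::string as stand-ins for the project's transform/zip helpers and tensor_guid_t (not the project's actual code):

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<int, std::string>> zipped = {{1, "a"}, {2, "b"}};
  std::vector<std::string> out(zipped.size());
  // with a long parameter list, clang-format breaks before the trailing
  // return type, matching the dlrm.cc hunk above
  std::transform(zipped.begin(),
                 zipped.end(),
                 out.begin(),
                 [](std::pair<int, std::string> const &combined_pair)
                     -> std::string { return combined_pair.second; });
  return 0;
}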
lib/op-attrs/include/op-attrs/datatype.h (5 changes: 3 additions & 2 deletions)
@@ -53,13 +53,14 @@ using real_type_t = typename data_type_enum_to_class<DT>::type;
 nonnegative_int size_of_datatype(DataType);
 
 /**
- * @brief Maximally semantics-preserving casts, not including identity
+ * @brief Maximally semantics-preserving casts, not including identity
  * casts (e.g., `float -> float` returns `false`)
  */
 bool can_strictly_promote_datatype_from_to(DataType from, DataType to);
 
 /**
- * @brief Equivalent to [`torch.can_cast`](https://pytorch.org/docs/stable/generated/torch.can_cast.html),
+ * @brief Equivalent to
+ * [`torch.can_cast`](https://pytorch.org/docs/stable/generated/torch.can_cast.html),
  * except that identity casts (e.g., `float -> float`) return `false`
  */
 bool can_torch_strictly_promote_datatype_from_to(DataType from, DataType to);
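Note: a short usage sketch of the two predicates declared above; the expected results follow from the doc comments and from the switch tables in datatype.cc below (the include path is an assumption based on this file's location):

#include "op-attrs/datatype.h" // assumed include path
#include <cassert>

using namespace FlexFlow;

int main() {
  // bool widens losslessly to any numeric type
  assert(can_strictly_promote_datatype_from_to(DataType::BOOL, DataType::FLOAT));
  // int32 -> float is lossy, so only the torch-style predicate allows it
  assert(!can_strictly_promote_datatype_from_to(DataType::INT32, DataType::FLOAT));
  assert(can_torch_strictly_promote_datatype_from_to(DataType::INT32, DataType::FLOAT));
  // identity casts return false for both, hence "strictly"
  assert(!can_torch_strictly_promote_datatype_from_to(DataType::FLOAT, DataType::FLOAT));
  return 0;
}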
lib/op-attrs/src/op-attrs/datatype.cc (20 changes: 14 additions & 6 deletions)
@@ -28,8 +28,11 @@ bool can_strictly_promote_datatype_from_to(DataType src, DataType dst) {
   std::unordered_set<DataType> allowed;
   switch (src) {
     case DataType::BOOL:
-      allowed = {
-          DataType::INT32, DataType::INT64, DataType::HALF, DataType::FLOAT, DataType::DOUBLE};
+      allowed = {DataType::INT32,
+                 DataType::INT64,
+                 DataType::HALF,
+                 DataType::FLOAT,
+                 DataType::DOUBLE};
       break;
     case DataType::INT32:
       allowed = {DataType::INT64};
@@ -55,14 +58,19 @@ bool can_torch_strictly_promote_datatype_from_to(DataType src, DataType dst) {
   std::unordered_set<DataType> allowed;
   switch (src) {
     case DataType::BOOL:
-      allowed = {
-          DataType::INT32, DataType::INT64, DataType::HALF, DataType::FLOAT, DataType::DOUBLE};
+      allowed = {DataType::INT32,
+                 DataType::INT64,
+                 DataType::HALF,
+                 DataType::FLOAT,
+                 DataType::DOUBLE};
       break;
     case DataType::INT32:
-      allowed = {DataType::INT64, DataType::HALF, DataType::FLOAT, DataType::DOUBLE};
+      allowed = {
+          DataType::INT64, DataType::HALF, DataType::FLOAT, DataType::DOUBLE};
       break;
     case DataType::INT64:
-      allowed = {DataType::INT32, DataType::HALF, DataType::FLOAT, DataType::DOUBLE};
+      allowed = {
+          DataType::INT32, DataType::HALF, DataType::FLOAT, DataType::DOUBLE};
       break;
     case DataType::HALF:
       allowed = {DataType::FLOAT, DataType::DOUBLE};
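Note: the switch bodies above are essentially a promotion table; a hedged, table-driven restatement of the torch-style relation is sketched below (only the cases visible in this hunk are filled in — the FLOAT and DOUBLE source cases fall outside the hunk and are omitted; this is an illustration of the design choice, not the implementation this commit formats):

#include "op-attrs/datatype.h" // assumed include path
#include <unordered_map>
#include <unordered_set>

namespace FlexFlow {

// hypothetical alternative: the same relation as data rather than control flow
bool can_torch_strictly_promote_datatype_from_to_alt(DataType src,
                                                     DataType dst) {
  static std::unordered_map<DataType, std::unordered_set<DataType>> const
      allowed = {
          {DataType::BOOL,
           {DataType::INT32,
            DataType::INT64,
            DataType::HALF,
            DataType::FLOAT,
            DataType::DOUBLE}},
          {DataType::INT32,
           {DataType::INT64, DataType::HALF, DataType::FLOAT, DataType::DOUBLE}},
          {DataType::INT64,
           {DataType::INT32, DataType::HALF, DataType::FLOAT, DataType::DOUBLE}},
          {DataType::HALF, {DataType::FLOAT, DataType::DOUBLE}},
          // FLOAT and DOUBLE source cases elided (not visible in this hunk)
      };
  auto it = allowed.find(src);
  return it != allowed.end() && it->second.count(dst) > 0;
}

} // namespace FlexFlow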
lib/op-attrs/src/op-attrs/ops/cast.cc (6 changes: 4 additions & 2 deletions)
@@ -6,7 +6,8 @@ namespace FlexFlow {
 tl::expected<TensorShape, std::string>
     get_output_shape(CastAttrs const &attrs, TensorShape const &input) {
 
-  if (!can_torch_strictly_promote_datatype_from_to(input.data_type, attrs.dtype)) {
+  if (!can_torch_strictly_promote_datatype_from_to(input.data_type,
+                                                   attrs.dtype)) {
     return tl::unexpected(fmt::format(
         "Cast cannot strictly promote input datatype {} to output datatype {}",
         input.data_type,
@@ -21,7 +22,8 @@ tl::expected<TensorShape, std::string>
 tl::expected<ParallelTensorShape, std::string>
     get_output_shape(CastAttrs const &attrs, ParallelTensorShape const &input) {
 
-  if (!can_torch_strictly_promote_datatype_from_to(input.data_type, attrs.dtype)) {
+  if (!can_torch_strictly_promote_datatype_from_to(input.data_type,
+                                                   attrs.dtype)) {
     return tl::unexpected(fmt::format(
         "Cast cannot strictly promote input datatype {} to output datatype {}",
         input.data_type,
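Note: both overloads report failure through tl::expected rather than by throwing. A brief consumer-side sketch (the header path and the report_cast_shape wrapper are assumptions for illustration; attrs and input must be constructed elsewhere, since this diff shows no CastAttrs constructor):

#include "op-attrs/ops/cast.h" // assumed header path for CastAttrs/get_output_shape
#include <iostream>
#include <string>

using namespace FlexFlow;

// hypothetical helper: reports whether the cast is a strict promotion
void report_cast_shape(CastAttrs const &attrs, TensorShape const &input) {
  tl::expected<TensorShape, std::string> result = get_output_shape(attrs, input);
  if (result.has_value()) {
    TensorShape output = result.value();
    (void)output; // same shape as the input, with attrs.dtype as its datatype
    std::cout << "cast is a strict promotion\n";
  } else {
    // e.g. "Cast cannot strictly promote input datatype FLOAT to output
    // datatype FLOAT" when an identity cast is requested
    std::cerr << result.error() << "\n";
  }
}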
lib/op-attrs/test/src/op-attrs/datatype.cc (21 changes: 11 additions & 10 deletions)
@@ -35,22 +35,23 @@ TEST_SUITE(FF_TEST_SUITE) {
   }
 
   TEST_CASE("can_torch_strictly_promote_datatype_from_to(DataType, DataType)") {
-    CHECK(
-        can_torch_strictly_promote_datatype_from_to(DataType::BOOL, DataType::INT32));
+    CHECK(can_torch_strictly_promote_datatype_from_to(DataType::BOOL,
+                                                      DataType::INT32));
     CHECK(can_torch_strictly_promote_datatype_from_to(DataType::INT32,
-                                                     DataType::INT64));
+                                                      DataType::INT64));
     CHECK(can_torch_strictly_promote_datatype_from_to(DataType::FLOAT,
-                                                     DataType::DOUBLE));
+                                                      DataType::DOUBLE));
 
     RC_SUBCASE("is strict", [](DataType d) {
       RC_ASSERT(!can_torch_strictly_promote_datatype_from_to(d, d));
     });
 
-    RC_SUBCASE("is transitive if end-points are not the same", [](DataType d1, DataType d2, DataType d3) {
-      RC_PRE(can_torch_strictly_promote_datatype_from_to(d1, d2));
-      RC_PRE(can_torch_strictly_promote_datatype_from_to(d2, d3));
-      RC_PRE(d1 != d3);
-      RC_ASSERT(can_torch_strictly_promote_datatype_from_to(d1, d3));
-    });
+    RC_SUBCASE("is transitive if end-points are not the same",
+               [](DataType d1, DataType d2, DataType d3) {
+                 RC_PRE(can_torch_strictly_promote_datatype_from_to(d1, d2));
+                 RC_PRE(can_torch_strictly_promote_datatype_from_to(d2, d3));
+                 RC_PRE(d1 != d3);
+                 RC_ASSERT(can_torch_strictly_promote_datatype_from_to(d1, d3));
+               });
   }
 }
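Note: the transitivity subcase is a rapidcheck property over randomly generated triples. Because the switches above mention only six DataType members, the same property can also be checked exhaustively; a deterministic sketch (assuming BOOL, INT32, INT64, HALF, FLOAT, DOUBLE are the whole enum, which this diff does not confirm):

#include "op-attrs/datatype.h" // assumed include path
#include <cassert>
#include <vector>

using namespace FlexFlow;

int main() {
  std::vector<DataType> const all = {DataType::BOOL,
                                     DataType::INT32,
                                     DataType::INT64,
                                     DataType::HALF,
                                     DataType::FLOAT,
                                     DataType::DOUBLE};
  // for all d1 != d3: promote(d1,d2) && promote(d2,d3) implies promote(d1,d3)
  for (DataType d1 : all) {
    for (DataType d2 : all) {
      for (DataType d3 : all) {
        if (d1 != d3 && can_torch_strictly_promote_datatype_from_to(d1, d2) &&
            can_torch_strictly_promote_datatype_from_to(d2, d3)) {
          assert(can_torch_strictly_promote_datatype_from_to(d1, d3));
        }
      }
    }
  }
  return 0;
}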
