-
Notifications
You must be signed in to change notification settings - Fork 104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace List Mesh to Tensor #17667
base: main
Are you sure you want to change the base?
Replace List Mesh to Tensor #17667
Conversation
b4bdc28
to
1edac8f
Compare
ba03eea
to
9d1c27c
Compare
1a2d6ec
to
3476824
Compare
2250279
to
8474bfa
Compare
fc27d2c
to
9dc5bb4
Compare
The PR looks good! Make sure to kick off post commits and T3K unit tests to double check nothing is broken. |
7f12318
to
58b94cc
Compare
…types forward, fix bfloat16 type error
…or ttnn.multidevice, fix python hook name recursion error
module.def( "get_mesh_device_core_grid", [](MeshDevice& mesh_device) { CoreCoord coords = mesh_device.compute_with_storage_grid_size(); new CoreGrid(coords.x, coords.y); }, py::arg("mesh_device"));
…tnn/operations/core.py
…iner array generator
58b94cc
to
1721aec
Compare
@@ -9,8 +9,10 @@ | |||
import pathlib | |||
import re | |||
from types import ModuleType | |||
from typing import List |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove all added lines here.
#include <pybind11/operators.h> | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You shouldn't need these here. Please remove.
#include <pybind11/pytypes.h> | ||
#include <cstdint> | ||
#include <memory> | ||
#include <utility> | ||
|
||
#include "tt-metalium/assert.hpp" | ||
#include "tt-metalium/bfloat16.hpp" | ||
#include "tt-metalium/buffer.hpp" | ||
#include "tt-metalium/core_coord.hpp" | ||
#include "tt-metalium/overloaded.hpp" | ||
#include "tt-metalium/core_coord.hpp" | ||
#include "ttnn/common/constants.hpp" | ||
#include "ttnn/distributed/api.hpp" | ||
#include "ttnn/tensor/host_buffer/borrowed_buffer.hpp" | ||
#include "ttnn/tensor/host_buffer/functions.hpp" | ||
#include "ttnn/tensor/layout/page_config.hpp" | ||
#include "ttnn/tensor/storage.hpp" | ||
#include "ttnn/tensor/tensor_impl.hpp" | ||
#include "ttnn/tensor/tensor_utils.hpp" | ||
#include "ttnn/tensor/tensor.hpp" | ||
#include "ttnn/types.hpp" | ||
#include <tt-metalium/command_queue.hpp> | ||
#include "pybind11/stl.h" | ||
|
||
#include "umd/device/tt_xy_pair.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You've added a lot of new includes. Can you please make sure you're only including the ones you need?
Thanks, @jjiangTT! Just need to clean up some of the includes and then we can merge this in. |
Ticket
Link to Github Issue
#15061
Problem description
ListMeshToTensor was a python class in distributed.py that didn't match the xtensor->torch.tensor convention and instead output a list[torch.tensor]. I've added the utility method ttnn.shardedtensor_to_tensorlist(ttnn.tensor)->list[torch.tensor] instead.
What's changed
-ListMeshToTensor removed, all usages replaced with ttnn.sharded_tensor_to_torch_tensor_list (tensor: ttnn.tensor) hook
-Added temporary python hook in operations/core.py to convert from ttnn.tensor lists to torch.tensor lists
Checklist