forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
PyInterpreter.cpp
988 lines (889 loc) · 34.6 KB
/
PyInterpreter.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
#include <ATen/core/PythonFallbackKernel.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_dispatch.h>
#include <string>
using namespace torch;
using namespace at;
using namespace c10;
namespace torch::detail {
namespace {
// CONCRETE_GPU_TRACE(device_type, func_name, ...):
// Calls `torch.<device>._gpu_trace.<func_name>.fire_callbacks(...)` with the
// variadic arguments, where `<device>` comes from
// `DeviceTypeName(device_type, true)`.
//
// Things to know when using it:
//  - It expands to a declaration (`guard`) followed by an `if` statement, so
//    it must be used where a local named `guard` can be declared.
//  - `device_type` must be an assignable lvalue: HIP is rewritten to CUDA in
//    place before the module name is built.
//  - Nothing is called when `Py_IsInitialized()` is false.
//  - Anything caught as `std::exception` is logged via LOG(ERROR) and
//    swallowed; it does not propagate to the caller.
//
// NB: This is a macro and not a template function (like it was before)
// because passing in constexpr char* as template argument breaks some
// versions of MSVC that are being used internally at Meta.
// MSVC 14.16.27023 (vs2017_15.9)
#define CONCRETE_GPU_TRACE(device_type, func_name, ...) \
at::impl::MaybeSetTLSOnEntryGuard guard; \
if (Py_IsInitialized()) { \
pybind11::gil_scoped_acquire gil; \
try { \
/* Masquerade hip as cuda because hip uses `torch.cuda` module. */ \
if (device_type == at::kHIP) { \
device_type = at::kCUDA; \
} \
std::string module_name = "torch." + DeviceTypeName(device_type, true); \
py::module mod = py::module::import(module_name.c_str()); \
py::object hook = \
mod.attr("_gpu_trace").attr(func_name).attr("fire_callbacks"); \
hook(__VA_ARGS__); \
} catch (const std::exception& e) { \
LOG(ERROR) << device_type \
<< " trace hook execution failed: " << e.what(); \
} \
}
struct ConcretePyInterpreterVTable final
: public c10::impl::PyInterpreterVTable {
std::string name() const override;
void incref(PyObject* pyobj) const override;
void decref(PyObject* pyobj, bool has_pyobj_slot) const override;
// TODO: Need to make this work for StorageImpl too. I imagine I'll want to
// operate upon a PyObjectSlot rather than a TensorImpl
c10::intrusive_ptr<c10::TensorImpl> detach(
const c10::TensorImpl* self) const override;
void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
const override;
void reportErrorCallback(PyObject* callback, DispatchKey key) const override;
void python_dispatcher(
const c10::OperatorHandle& op,
c10::DispatchKeySet,
torch::jit::Stack* stack) const override;
// NB: this is defined in python_dispatch.cpp
void python_op_registration_trampoline(
const c10::OperatorHandle& op,
c10::DispatchKey key,
c10::DispatchKeySet keyset,
torch::jit::Stack* stack,
bool with_keyset,
bool with_op) const override {
torch::impl::dispatch::python_op_registration_trampoline_impl(
op, key, keyset, stack, with_keyset, with_op);
}
void throw_abstract_impl_not_imported_error(
std::string opname,
const char* pymodule,
const char* context) const override {
py::gil_scoped_acquire gil;
pybind11::module::import("torch._utils_internal")
.attr("throw_abstract_impl_not_imported_error")(
opname, pymodule, context);
}
bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
const override;
bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat)
const override;
bool is_non_overlapping_and_dense(const c10::TensorImpl* self) const override;
c10::Device device(const c10::TensorImpl* self) const override;
int64_t dim(const c10::TensorImpl* self) const override;
c10::IntArrayRef strides(const c10::TensorImpl* self) const override;
c10::IntArrayRef sizes(const c10::TensorImpl* self) const override;
c10::SymIntArrayRef sym_sizes(const c10::TensorImpl* self) const override;
c10::Layout layout(const c10::TensorImpl* self) const override;
int64_t numel(const c10::TensorImpl* self) const override;
c10::SymInt sym_numel(const c10::TensorImpl* self) const override;
c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override;
c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override;
void trace_gpu_event_creation(at::DeviceType device_type, uintptr_t event)
const override {
CONCRETE_GPU_TRACE(device_type, "EventCreationCallbacks", event);
}
void trace_gpu_event_deletion(at::DeviceType device_type, uintptr_t event)
const override {
CONCRETE_GPU_TRACE(device_type, "EventDeletionCallbacks", event);
}
void trace_gpu_event_record(
at::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const override {
CONCRETE_GPU_TRACE(device_type, "EventRecordCallbacks", event, stream);
}
void trace_gpu_event_wait(
at::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const override {
CONCRETE_GPU_TRACE(device_type, "EventWaitCallbacks", event, stream);
}
void trace_gpu_memory_allocation(at::DeviceType device_type, uintptr_t ptr)
const override {
CONCRETE_GPU_TRACE(device_type, "MemoryAllocationCallbacks", ptr);
}
void trace_gpu_memory_deallocation(at::DeviceType device_type, uintptr_t ptr)
const override {
CONCRETE_GPU_TRACE(device_type, "MemoryDeallocationCallbacks", ptr);
}
void trace_gpu_stream_creation(at::DeviceType device_type, uintptr_t stream)
const override {
CONCRETE_GPU_TRACE(device_type, "StreamCreationCallbacks", stream);
}
void trace_gpu_device_synchronization(
at::DeviceType device_type) const override {
CONCRETE_GPU_TRACE(device_type, "DeviceSynchronizationCallbacks");
}
void trace_gpu_stream_synchronization(
at::DeviceType device_type,
uintptr_t stream) const override {
CONCRETE_GPU_TRACE(device_type, "StreamSynchronizationCallbacks", stream);
}
void trace_gpu_event_synchronization(
at::DeviceType device_type,
uintptr_t event) const override {
CONCRETE_GPU_TRACE(device_type, "EventSynchronizationCallbacks", event);
}
void reset_backward_hooks(const c10::TensorImpl* self) const override;
static ConcretePyInterpreterVTable* instance() {
static ConcretePyInterpreterVTable s;
return &s;
}
};
// Holds the c10::impl::PyInterpreter created for this translation unit.
//
// Construction allocates a PyInterpreter bound to
// ConcretePyInterpreterVTable::instance() and registers it with
// at::impl::PythonOpRegistrationTrampoline; the value returned by
// registerInterpreter() is remembered as `is_main_interpreter()`.
//
// The holder has a raw pointer member and a custom destructor, so copying
// or moving it would leave two holders disarming the same interpreter.
// All copy/move operations are therefore deleted.
class PyInterpreterHolder {
 public:
  PyInterpreterHolder()
      : impl_(new c10::impl::PyInterpreter(
            ConcretePyInterpreterVTable::instance())),
        is_main_interpreter_(
            at::impl::PythonOpRegistrationTrampoline::registerInterpreter(
                impl_)) {}

  PyInterpreterHolder(const PyInterpreterHolder&) = delete;
  PyInterpreterHolder(PyInterpreterHolder&&) = delete;
  PyInterpreterHolder& operator=(const PyInterpreterHolder&) = delete;
  PyInterpreterHolder& operator=(PyInterpreterHolder&&) = delete;

  // NB: intentionally leaks the PyInterpreter, as there may still be
  // references to it that are live, living in objects that aren't being
  // destructed while Python is being cleaned up.
  ~PyInterpreterHolder() {
    impl_->disarm();
  }

  // Returns the held interpreter pointer (never deleted by this class).
  c10::impl::PyInterpreter* get() const noexcept {
    return impl_;
  }

  // Returns the result of registerInterpreter() captured at construction.
  bool is_main_interpreter() const noexcept {
    return is_main_interpreter_;
  }

 private:
  // Declared before is_main_interpreter_ because its initializer reads impl_.
  c10::impl::PyInterpreter* impl_;
  bool is_main_interpreter_;
};
// Invokes `torch_api_function` on the tensor behind `self` via
// handle_torch_function_no_python_arg_parser in TorchDispatch mode.
//
// Parameters:
//   self               - tensor impl; wrapped (non-owning reclaim) into a
//                        Python object and passed as the first positional
//                        argument.
//   func_name          - forwarded to the torch-function handler.
//   torch_api_function - Python callable to dispatch to; nullptr is treated
//                        as "a Python error is already set" and rethrown.
//   module_name        - forwarded to the torch-function handler.
//   extra_args         - additional positional arguments appended after
//                        `self`. They are moved from.
//
// Returns the handler's result as an owned py::object.
// Throws python_error when a Python-level allocation or lookup has failed,
// and a TORCH_CHECK error if the GIL is not held.
py::object torchDispatchFromTensorImpl(
    const c10::TensorImpl* self,
    const char* func_name,
    PyObject* torch_api_function,
    const char* module_name,
    // WARNING: MUST NOT BE TENSOR ARGS
    c10::SmallVector<py::object, 1> extra_args = {}) {
  if (torch_api_function == nullptr) {
    throw python_error();
  }
  TORCH_CHECK(
      PyGILState_Check(),
      "GIL must be held before you call torchDispatchFromTensorImpl");

  std::vector<PyObject*> overloaded_args;
  // TODO: there should be a shorter way to spell this
  // TODO: fix the constness of target
  at::Tensor self_t = at::Tensor(
      c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
  auto self_p =
      py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t)));
  // Wrapping hands back a raw PyObject*; guard against a null result before
  // it is stored in the argument tuple below.
  if (self_p.ptr() == nullptr) {
    throw python_error();
  }
  // NB: this may not be a python tensor if you got here from a mode!
  // TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
  append_overloaded_tensor(&overloaded_args, self_p.ptr());

  auto args = py::reinterpret_steal<py::object>(
      PyTuple_New(static_cast<Py_ssize_t>(1 + extra_args.size())));
  // PyTuple_New returns NULL on allocation failure; PyTuple_SET_ITEM does no
  // checking of its own, so bail out here instead of writing through NULL.
  if (args.ptr() == nullptr) {
    throw python_error();
  }
  // The tuple steals the references released below.
  PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
  Py_ssize_t i = 1;
  for (auto& a : extra_args) {
    if (a.ptr() == nullptr) {
      throw python_error();
    }
    PyTuple_SET_ITEM(args.ptr(), i, std::move(a).release().ptr());
    ++i;
  }

  py::dict kwargs;

  return py::reinterpret_steal<py::object>(
      handle_torch_function_no_python_arg_parser(
          overloaded_args,
          args.ptr(),
          kwargs.ptr(),
          func_name,
          torch_api_function,
          module_name,
          TorchFunctionName::TorchDispatch));
}
// NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
// The caller has to know statically whether `pyobj` carries a PyObjectSlot:
// - with a PyObjectSlot, PyObject resurrection has to be handled carefully;
// - without one, a plain decref is all that is needed.
// PyObject_IsInstance could answer the same question at runtime, but an
// incorrect `__instancecheck__` must not be able to change the semantics
// here, so the flag is passed explicitly.
void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
    const {
  // Leak the pyobj if not initialized. This can happen if we are running
  // exit handlers that are destructing tensors with residual (owned)
  // PyObjects stored in them.
  if (!Py_IsInitialized()) {
    return;
  }

  pybind11::gil_scoped_acquire gil;

  // Only objects with a PyObjectSlot (Tensor / Storage) can have been
  // resurrected (see THPVariable_clear); every other Python object is simply
  // decref'ed.
  const bool still_referenced = has_pyobj_slot && Py_REFCNT(pyobj) > 1;
  if (still_referenced) {
    if (THPVariable_Check(pyobj)) {
      // A weak ref resurrected the PyObject without flipping ownership. It
      // is too late to rescue it, so stub out the payload so later uses
      // fail. No error is raised: this is most likely running inside a
      // destructor.
      TORCH_WARN(
          "Deallocating Tensor that still has live PyObject references. "
          "This probably happened because you took out a weak reference to "
          "Tensor and didn't call _fix_weakref() after dereferencing it. "
          "Subsequent accesses to this tensor via the PyObject will now fail.");
      reinterpret_cast<THPVariable*>(pyobj)->cdata =
          c10::MaybeOwned<torch::autograd::Variable>();
    } else if (THPStorage_Check(pyobj)) {
      TORCH_WARN(
          "Deallocating UntypedStorage that still has live PyObject references. "
          "This probably happened because you took out a weak reference to "
          "UntypedStorage and didn't call _fix_weakref() after dereferencing it. "
          "Subsequent accesses to this storage via the PyObject will now fail.");
      reinterpret_cast<THPStorage*>(pyobj)->cdata =
          c10::MaybeOwned<c10::Storage>();
    }
  }

  Py_DECREF(pyobj);
}
// Takes the GIL and increments the refcount of `pyobj`. Does nothing when
// the Python runtime is not initialized.
void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const {
  if (Py_IsInitialized()) {
    pybind11::gil_scoped_acquire gil;
    Py_INCREF(pyobj);
  }
}
// True when the tensor's dispatch key set contains DispatchKey::Python.
bool isPythonTensor(const at::Tensor& tensor) {
  const auto* impl = tensor.unsafeGetTensorImpl();
  return impl->key_set().has(c10::DispatchKey::Python);
}
// Calls the Python `callback` with the string form of `key`, under the GIL.
void ConcretePyInterpreterVTable::reportErrorCallback(
    PyObject* callback,
    DispatchKey key) const {
  py::gil_scoped_acquire gil;
  // `callback` is borrowed: take our own reference for the duration of the
  // call.
  const auto py_callback = py::reinterpret_borrow<py::object>(callback);
  // The key is passed as a string because not every DispatchKey is
  // pybind'ed into Python and nothing enforces that it is.
  py_callback(c10::toString(key));
}
// Boxed __torch_dispatch__ entry point.
//
// Pops the operator's arguments off `stack`, converts them to Python
// args/kwargs, calls handle_torch_function_no_python_arg_parser in
// TorchDispatch mode and pushes the Python result back onto `stack`.
// NB: by the time arguments are on the stack, ALL defaults are already
// present.
void ConcretePyInterpreterVTable::dispatch(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) const {
  const auto& schema = op.schema();
  auto arguments = torch::jit::pop(*stack, schema.arguments().size());

  py::gil_scoped_acquire gil;

  std::vector<PyObject*> overloaded_args;
  py::handle torch_api_function_overload = getTorchApiFunction(op);

  // Records `value` as an overloaded argument when it is a tensor carrying
  // the Python dispatch key; anything else is ignored.
  const auto collect_python_tensor = [&overloaded_args](
                                         const c10::IValue& value) {
    if (!value.isTensor()) {
      return;
    }
    const auto& tensor = value.toTensor();
    if (isPythonTensor(tensor)) {
      append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
    }
  };

  // Scan top-level tensors and tensors nested one level deep inside lists.
  for (const auto& ivalue : arguments) {
    if (ivalue.isTensor()) {
      collect_python_tensor(ivalue);
    } else if (ivalue.isList()) {
      for (const auto& element : ivalue.toListRef()) {
        collect_python_tensor(element);
      }
    }
  }

  auto [args, kwargs] = parseIValuesToPyArgsKwargs(op, arguments);

  PyObject* result = handle_torch_function_no_python_arg_parser(
      overloaded_args,
      args.ptr(),
      kwargs.ptr(),
      nullptr,
      torch_api_function_overload.ptr(),
      nullptr,
      TorchFunctionName::TorchDispatch);
  pushPyOutToStack(
      op,
      stack,
      py::reinterpret_steal<py::object>(result),
      "__torch_dispatch__");
}
// Python dispatcher entry point.
//
// Looks up the handler for the highest-priority key of `ks` in the
// overload's `_dispatch_cache`, falling back to `_get_dispatch(key)`.
// If the handler is itself a DispatchKey, the operator is re-called boxed
// for that key; otherwise the handler is invoked as a Python callable with
// the popped arguments and its result is pushed back onto `stack`.
void ConcretePyInterpreterVTable::python_dispatcher(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet ks,
    torch::jit::Stack* stack) const {
  py::gil_scoped_acquire gil;
  py::handle torch_api_function_overload = getTorchApiFunction(op);

  // TODO: if necessary, can optimize to cache the cache lookup
  // TODO: if necessary, can optimize OpOverload to have slots
  auto dispatch_cache =
      py::dict(torch_api_function_overload.attr("_dispatch_cache"));
  if (dispatch_cache.ptr() == nullptr) {
    throw python_error();
  }

  const c10::DispatchKey top_key = ks.highestPriorityTypeId();

  // TODO: allow this to be non-owning
  const py::object py_key = py::cast(top_key);
  PyObject* cached = PyDict_GetItem(dispatch_cache.ptr(), py_key.ptr());
  py::object handler = py::reinterpret_borrow<py::object>(cached);
  if (handler.ptr() == nullptr) {
    // Slow path: ask Python to compute (and presumably cache) the handler.
    handler = torch_api_function_overload.attr("_get_dispatch")(top_key);
  }

  if (py::isinstance<c10::DispatchKey>(handler)) {
    // NB: not redispatch, as that will permanently remove the python
    // dispatcher for subsequent redispatches
    const auto target_key = py::cast<c10::DispatchKey>(handler);
    op.callBoxedForDispatchKey(target_key, *stack);
    return;
  }

  const auto& schema = op.schema();
  auto arguments = torch::jit::pop(*stack, schema.arguments().size());

  auto [args, kwargs] = parseIValuesToPyArgsKwargs(op, arguments);

  auto result = py::reinterpret_steal<py::object>(
      PyObject_Call(handler.ptr(), args.ptr(), kwargs.ptr()));
  if (result.ptr() == nullptr) {
    throw python_error();
  }

  pushPyOutToStack(op, stack, std::move(result), "Python dispatcher");
}
// Dispatches torch.ops.aten.detach.default to Python for `self` and returns
// the TensorImpl of the tensor that comes back. Fails a TORCH_CHECK if the
// Python result is not a Tensor.
c10::intrusive_ptr<c10::TensorImpl> ConcretePyInterpreterVTable::detach(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  const py::object detach_op = py::module::import("torch")
                                   .attr("ops")
                                   .attr("aten")
                                   .attr("detach")
                                   .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "detach", detach_op.ptr(), "torch.ops.aten");

  PyObject* result = out.ptr();
  TORCH_CHECK(
      THPVariable_Check(result),
      "detach returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(result)),
      ", expected Tensor");
  return THPVariable_Unpack(result).getIntrusivePtr();
}
// Asks Python whether `self` is contiguous in `memory_format`.
//
// MemoryFormat::Contiguous goes through aten.is_contiguous.default (kept for
// backwards compatibility); every other format goes through
// aten.is_contiguous.memory_format with the format as an extra argument.
// A Python result of None defers to TensorImpl::is_contiguous_default;
// any other non-bool result fails a TORCH_CHECK.
bool ConcretePyInterpreterVTable::is_contiguous(
    const c10::TensorImpl* self,
    at::MemoryFormat memory_format) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  const py::object overload_packet = py::module::import("torch")
                                         .attr("ops")
                                         .attr("aten")
                                         .attr("is_contiguous");

  py::object out;
  if (memory_format == at::MemoryFormat::Contiguous) {
    // For backwards compatibility
    out = torchDispatchFromTensorImpl(
        self,
        "is_contiguous",
        overload_packet.attr("default").ptr(),
        "torch.ops.aten");
  } else {
    out = torchDispatchFromTensorImpl(
        self,
        "is_contiguous",
        overload_packet.attr("memory_format").ptr(),
        "torch.ops.aten",
        {py::cast(memory_format)});
  }

  if (out.is_none()) {
    return self->is_contiguous_default(memory_format);
  }

  PyObject* result = out.ptr();
  TORCH_CHECK(
      PyBool_Check(result),
      "is_contiguous returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(result)),
      ", expected bool");
  return PyObject_IsTrue(result) != 0;
}
// Asks Python (aten.is_strides_like_format.default) whether the strides of
// `self` match `memory_format`. A result of None defers to
// TensorImpl::is_strides_like_default; any other non-bool result fails a
// TORCH_CHECK.
bool ConcretePyInterpreterVTable::is_strides_like(
    const c10::TensorImpl* self,
    at::MemoryFormat memory_format) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  // NB: the op is intentionally suffixed with _format to avoid
  // triggering matches against "_like" suffix
  const py::object strides_like_op = py::module::import("torch")
                                         .attr("ops")
                                         .attr("aten")
                                         .attr("is_strides_like_format")
                                         .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self,
      "is_strides_like",
      strides_like_op.ptr(),
      "torch.ops.aten",
      {py::cast(memory_format)});

  if (out.is_none()) {
    return self->is_strides_like_default(memory_format);
  }

  PyObject* result = out.ptr();
  TORCH_CHECK(
      PyBool_Check(result),
      "is_strides_like_format returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(result)),
      ", expected bool");
  return PyObject_IsTrue(result) != 0;
}
// Asks Python (aten.is_non_overlapping_and_dense.default) about `self`.
// A result of None defers to
// TensorImpl::is_non_overlapping_and_dense_default; any other non-bool
// result fails a TORCH_CHECK.
bool ConcretePyInterpreterVTable::is_non_overlapping_and_dense(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  const py::object dense_op = py::module::import("torch")
                                  .attr("ops")
                                  .attr("aten")
                                  .attr("is_non_overlapping_and_dense")
                                  .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "is_non_overlapping_and_dense", dense_op.ptr(), "torch.ops.aten");

  if (out.is_none()) {
    return self->is_non_overlapping_and_dense_default();
  }

  PyObject* result = out.ptr();
  TORCH_CHECK(
      PyBool_Check(result),
      "is_non_overlapping_and_dense returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(result)),
      ", expected bool");
  return PyObject_IsTrue(result) != 0;
}
// Dispatches aten.dim.default to Python for `self` and returns the result.
// Fails a TORCH_CHECK unless Python returned an int.
int64_t ConcretePyInterpreterVTable::dim(const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  const py::object dim_op = py::module::import("torch")
                                .attr("ops")
                                .attr("aten")
                                .attr("dim")
                                .attr("default");
  auto out =
      torchDispatchFromTensorImpl(self, "dim", dim_op.ptr(), "torch.ops.aten");

  PyObject* result = out.ptr();
  TORCH_CHECK(
      PyLong_Check(result),
      "dim returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(result)),
      ", expected int");
  return THPUtils_unpackLong(result);
}
// Dispatches prim.device.default to Python for `self` and converts the
// returned object with toDevice().
c10::Device ConcretePyInterpreterVTable::device(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  const py::object device_op = py::module::import("torch")
                                   .attr("ops")
                                   .attr("prim")
                                   .attr("device")
                                   .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "device", device_op.ptr(), "torch.ops.prim");
  return toDevice(out.ptr());
}
// Stores `capsule` as attribute `attr_name` on the PyObject associated with
// `tensor`. Fails a TORCH_CHECK if the tensor's PyObject slot yields no
// object for this interpreter.
static void set_tensor_attr_with_capsule(
    const c10::TensorImpl* tensor,
    py::capsule& capsule,
    const char* attr_name) {
  const std::optional<PyObject*> maybe_pyobj =
      tensor->pyobj_slot()->check_pyobj(
          getPyInterpreter(), /*ignore_hermetic_tls=*/false);
  TORCH_CHECK(
      maybe_pyobj.has_value(), "Tensor subclass's PyInterpreter has no value");

  py::handle tensor_handle(*maybe_pyobj);
  tensor_handle.attr(attr_name) = capsule;
}
// Note [Tensor Subclass custom size/stride caching strategy]
// Tensor subclasses can use __torch_dispatch__ to override size/stride calls.
// However, this presents a problem:
// (1) When you return a custom (maybe symbolic) size/stride
// from python, we need to stash this fresh vector of ints/symints
// somewhere so that it has the same lifetime as the tensor.
// (2) If the subclass experiences a metadata mutation,
// this stashed vector is no longer valid, so we need to allocate a fresh
// buffer to store the new sizes the next time someone asks for them.
//
// We handle this in the same way that `TensorImpl::sizes_default()`
// handles its buffer: we simply reallocate the buffer whenever
// the number of dimensions changes due to a resize.
// Notably, we do *not* reallocate the buffer if the values changed,
// but the number of dimensions stayed the same (e.g. `.transpose_()`).
//
// get_set_cached_attr<T>:
// Copies the elements of the Python sequence `obj` into a heap buffer of T
// that is owned by a py::capsule stored as attribute `base_attr_name` on the
// tensor's PyObject, and returns an ArrayRef<T> viewing that buffer.
//
//   tensor         - tensor whose PyObject carries the cached buffer.
//   base_attr_name - attribute name for the capsule; the element count is
//                    stored alongside it under `<base_attr_name>_len`.
//   obj            - Python iterable of values convertible to T.
//
// Returns an ArrayRef of `py::len(obj)` elements. The view points into the
// capsule's buffer, so a later call that reallocates the buffer invalidates
// it.
// Fails a TORCH_CHECK if the tensor has no PyObject for this interpreter.
template <typename T>
static c10::ArrayRef<T> get_set_cached_attr(
const c10::TensorImpl* tensor,
const char* base_attr_name,
const py::object& obj) {
std::optional<PyObject*> mb_obj =
tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
TORCH_CHECK(
mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
auto tensor_obj = mb_obj.value();
auto buffer_len_attr_name = std::string(base_attr_name) + std::string("_len");
// The presence of the `_len` attribute is what marks the buffer as
// already allocated.
bool is_buffer_allocated = false;
size_t curr_size = 0;
if (PyObject_HasAttrString(tensor_obj, buffer_len_attr_name.c_str())) {
auto len_pyobj = py::handle(tensor_obj).attr(buffer_len_attr_name.c_str());
curr_size = py::cast<size_t>(len_pyobj);
is_buffer_allocated = true;
}
size_t new_size = py::len(obj);
// We do the smallvector optimization here: any time the new_size is <=5,
// we always allocate our buffer to size 5, so that if the next resize
// is also to <=5 elements, we don't need to reallocate.
// Note: I tried removing this optimization and tripped ASAN
// in a batchnorm kernel here:
// https://pipelinesghubeus21.actions.githubusercontent.com/mBh68xKhi8LyM7tp3vECvYXNFvuV4gyVGgmYCteuEZP9JH92QN/_apis/pipelines/1/runs/3373307/signedlogcontent/790?urlExpires=2023-09-15T21%3A13%3A51.4327798Z&urlSigningMethod=HMACV1&urlSignature=tDeX7ZqaARVU5NNwyr5yYqqkWq3A2j4z8FFdqYwGr0Q%3D
// We should fix this instead.
bool needs_resize = false;
// We need to resize if:
// (1) we haven't allocated our buffer at all yet
// (2) Our buffer size is different from the new size
// (note: we use the small vector optimization, where our buffer
// is always allocated to at least size 5, and any resizes
// within the <= 5 regime do not require a reallocation).
auto is_smallvector = curr_size <= 5;
needs_resize = !is_buffer_allocated || (is_smallvector && new_size > 5) ||
(!is_smallvector && curr_size != new_size);
if (needs_resize) {
// If our current buffer is not the right size (either because we haven't
// allocated it yet, or there was a metadata mutation that changed the
// number of dims of the tensor), allocate a fresh buffer. Note that this
// will trash the previous buffer if there already was one, invalidating any
// existing SymIntArrayRef's from an old .sym_size() call.
auto new_buffer_size = new_size;
if (new_size <= 5) {
// This is the smallvector optimization
new_buffer_size = 5;
}
T* ptr = new T[new_buffer_size];
// The capsule takes over ownership of `ptr`; its destructor callback
// delete[]s the array.
auto capsule =
py::capsule(ptr, [](void* p) { delete[] reinterpret_cast<T*>(p); });
int64_t idx = 0;
for (auto it = obj.begin(); it != obj.end(); ++it, ++idx) {
ptr[idx] = py::cast<T>(*it);
}
// Set the buffer
set_tensor_attr_with_capsule(tensor, capsule, base_attr_name);
// Set the len buffer
// NOTE(review): `_len` is only written on this (re)allocation path, so
// while sizes stay within the <= 5 regime it keeps the element count from
// the last allocation rather than the most recent one.
py::handle(tensor_obj).attr(buffer_len_attr_name.c_str()) = new_size;
} else {
TORCH_INTERNAL_ASSERT(PyObject_HasAttrString(tensor_obj, base_attr_name));
auto curr_buffer_pyobj = py::handle(tensor_obj).attr(base_attr_name);
void* buffer_pycapsule =
PyCapsule_GetPointer(curr_buffer_pyobj.ptr(), nullptr);
auto curr_buffer = reinterpret_cast<T*>(buffer_pycapsule);
// Overwrite the buffer with our new values, but only if any of them changed
// (due to a metadata mutation).
// This is technically not thread safe, because the update happens lazily.
// The original metadata mutation call on the tensor might have been thread
// safe (e.g. a .resize_() call), but we won't actually mutate the size
// buffer until the first call to .sizes() which the user might not access
// in a thread-safe way. For now we are not explicitly locking, but maybe we
// should.
int64_t idx = 0;
// Quick sanity assert that our buffer size is large enough
// to compare against all the elements in the new buffer.
size_t curr_buffer_size = 5;
if (curr_buffer_size < curr_size) {
curr_buffer_size = curr_size;
}
TORCH_INTERNAL_ASSERT(curr_buffer_size >= new_size);
for (auto it = obj.begin(); it != obj.end(); ++it, ++idx) {
auto actual_val = py::cast<T>(*it);
if constexpr (std::is_same_v<T, c10::SymInt>) {
// if our SymInts are symbolic, we are *not* doing an equality check on
// the symints. we just want to see if the nodes are the same. this is
// because we don't want to introduce any guards here.
if (!curr_buffer[idx].is_same(actual_val)) {
curr_buffer[idx] = actual_val;
}
} else {
if (curr_buffer[idx] != actual_val) {
curr_buffer[idx] = actual_val;
}
}
}
}
// The correct data is now stored at the buffer - read and return it.
auto curr_buffer_pyobj = py::handle(tensor_obj).attr(base_attr_name);
void* buffer_pycapsule =
PyCapsule_GetPointer(curr_buffer_pyobj.ptr(), nullptr);
auto curr_buffer = reinterpret_cast<T*>(buffer_pycapsule);
return c10::ArrayRef<T>(curr_buffer, new_size);
}
// Dispatches aten.stride.default to Python for `self`.
//
// A result of None defers to TensorImpl::strides_default (and is rejected
// when the tensor has symbolic sizes/strides). Otherwise the result must be
// a list or tuple; it is copied into the per-tensor "_strides_capsule"
// buffer and a view of that buffer is returned.
c10::IntArrayRef ConcretePyInterpreterVTable::strides(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  const py::object stride_op = py::module::import("torch")
                                   .attr("ops")
                                   .attr("aten")
                                   .attr("stride")
                                   .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "stride", stride_op.ptr(), "torch.ops.aten");

  if (out.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call strides on a tensor with symbolic shapes/strides");
    return self->strides_default();
  }

  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "strides must be a list or a tuple");
  return get_set_cached_attr<int64_t>(self, "_strides_capsule", out);
}
// Dispatches aten.size.default to Python for `self`.
//
// A result of None defers to TensorImpl::sizes_default (and is rejected when
// the tensor has symbolic sizes/strides). Otherwise the result must be a
// list or tuple; it is copied into the per-tensor "_sizes_capsule" buffer
// and a view of that buffer is returned.
c10::IntArrayRef ConcretePyInterpreterVTable::sizes(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  const py::object size_op = py::module::import("torch")
                                 .attr("ops")
                                 .attr("aten")
                                 .attr("size")
                                 .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "size", size_op.ptr(), "torch.ops.aten");

  if (out.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call sizes on a tensor with symbolic shapes/strides");
    return self->sizes_default();
  }

  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "sizes must be a list or a tuple");
  return get_set_cached_attr<int64_t>(self, "_sizes_capsule", out);
  END_HANDLE_TH_ERRORS_PYBIND
}
// Dispatches aten.sym_size.default to Python for `self`.
//
// A result of None defers to TensorImpl::sym_sizes_default. Otherwise the
// result must be a list or tuple; it is copied into the per-tensor
// "_sym_sizes_capsule" buffer and a view of that buffer is returned.
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_sizes(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  const py::object sym_size_op = py::module::import("torch")
                                     .attr("ops")
                                     .attr("aten")
                                     .attr("sym_size")
                                     .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "sym_size", sym_size_op.ptr(), "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_sizes_default();
  }

  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "sym_size must be a list or a tuple");

  // See Note [Tensor Subclass custom size/stride caching strategy]
  return get_set_cached_attr<c10::SymInt>(self, "_sym_sizes_capsule", out);
  END_HANDLE_TH_ERRORS_PYBIND
}
// Dispatches prim.layout.default to Python for `self`.
// Accepts either a Layout object (converted with toLayout) or a Python int
// (cast to c10::Layout); anything else fails a TORCH_CHECK.
c10::Layout ConcretePyInterpreterVTable::layout(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  const py::object layout_op = py::module::import("torch")
                                   .attr("ops")
                                   .attr("prim")
                                   .attr("layout")
                                   .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "layout", layout_op.ptr(), "torch.ops.prim");

  PyObject* result = out.ptr();
  const bool is_layout_object = THPLayout_Check(result);
  TORCH_CHECK(
      is_layout_object || PyLong_Check(result),
      "layout returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(result)),
      ", expected Layout");

  if (is_layout_object) {
    return toLayout(result);
  }
  return static_cast<c10::Layout>(py::cast<int64_t>(out));
}
// Dispatches aten.numel.default to Python for `self` and returns the element
// count.
//
// A result of None defers to TensorImpl::numel_default, which is only
// allowed when the tensor has no symbolic sizes/strides (TORCH_CHECK
// otherwise). Any other result is cast to int64_t; pybind11 throws if the
// object cannot be converted.
int64_t ConcretePyInterpreterVTable::numel(const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "numel",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("numel")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    // The message names numel: it used to say "sizes", copied from sizes(),
    // which pointed users at the wrong call.
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call numel on a tensor with symbolic shapes/strides");
    return self->numel_default();
  }
  return py::cast<int64_t>(out);
}
// Dispatches aten.sym_numel.default to Python for `self`.
// None defers to TensorImpl::sym_numel_default; a symbolic int is cast to
// c10::SymInt directly; anything else is cast through int64_t.
c10::SymInt ConcretePyInterpreterVTable::sym_numel(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  const py::object sym_numel_op = py::module::import("torch")
                                      .attr("ops")
                                      .attr("aten")
                                      .attr("sym_numel")
                                      .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "sym_numel", sym_numel_op.ptr(), "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_numel_default();
  }
  if (torch::is_symint(out)) {
    return out.cast<c10::SymInt>();
  }
  return c10::SymInt{py::cast<int64_t>(out)};
}
// Dispatches aten.sym_storage_offset.default to Python for `self`.
// None defers to TensorImpl::sym_storage_offset_default; a symbolic int is
// cast to c10::SymInt directly; anything else is cast through int64_t.
c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  const py::object storage_offset_op = py::module::import("torch")
                                           .attr("ops")
                                           .attr("aten")
                                           .attr("sym_storage_offset")
                                           .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "sym_storage_offset", storage_offset_op.ptr(), "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_storage_offset_default();
  }
  if (torch::is_symint(out)) {
    return out.cast<c10::SymInt>();
  }
  return c10::SymInt{py::cast<int64_t>(out)};
}
// Dispatches aten.sym_stride.default to Python for `self`.
//
// A result of None defers to TensorImpl::sym_strides_default. Otherwise the
// result must be a list or tuple; it is copied into the per-tensor
// "_sym_strides_capsule" buffer and a view of that buffer is returned.
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  const py::object sym_stride_op = py::module::import("torch")
                                       .attr("ops")
                                       .attr("aten")
                                       .attr("sym_stride")
                                       .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "sym_stride", sym_stride_op.ptr(), "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_strides_default();
  }

  // SymIntNodes and plain ints both get squeezed into `SymInt`s, since that
  // is the format `sym_strides()` are stored in.
  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "sym_strides must be a list or a tuple");
  return get_set_cached_attr<c10::SymInt>(self, "_sym_strides_capsule", out);
  END_HANDLE_TH_ERRORS_PYBIND
}
// Sets the `_backward_hooks` attribute of the Python object wrapping `self`
// to None.
//
// Failures are reported instead of being dropped: if wrapping the tensor
// yields no object, or the attribute assignment fails, the pending Python
// error is raised as python_error (translated by
// END_HANDLE_TH_ERRORS_PYBIND). Previously the setattr result was ignored,
// which could leave a Python error indicator set with nobody reporting it.
void ConcretePyInterpreterVTable::reset_backward_hooks(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  Tensor self_t =
      Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
                 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
             unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
  auto self_p =
      py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t)));
  if (self_p.ptr() == nullptr) {
    throw python_error();
  }
  // PyObject_SetAttrString returns -1 and sets an exception on failure.
  if (PyObject_SetAttrString(self_p.ptr(), "_backward_hooks", Py_None) != 0) {
    throw python_error();
  }
  END_HANDLE_TH_ERRORS_PYBIND
}
// Returns whatever streaming the getPyInterpreter() pointer into an output
// stream produces, as a string.
std::string ConcretePyInterpreterVTable::name() const {
  std::ostringstream formatted;
  formatted << getPyInterpreter();
  return formatted.str();
}
// The single PyInterpreterHolder for this translation unit. Its constructor
// runs during static initialization; getPyInterpreter() and
// isMainPyInterpreter() below read from it.
PyInterpreterHolder self_interpreter;
} // anonymous namespace
// Returns the Python object for operator `op`, i.e.
// `torch.ops.<namespace>.<name>.<overload>` (with `default` standing in for
// an empty overload name). The lookup lambda is handed to op.getPythonOp(),
// which decides when to run it.
py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
  return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
    // Split "ns::name" into its namespace and unqualified name (the overload
    // name lives on the schema, not in this string).
    // TODO: put this into the library
    const auto& schema = op.schema();
    const auto& qualified_name = op.operator_name().name;
    const auto& overload_name = schema.overload_name();

    const auto separator_pos = qualified_name.find("::");
    TORCH_INTERNAL_ASSERT(separator_pos != std::string::npos, qualified_name);

    // Own copies so both pieces are null terminated.
    const std::string ns_part = qualified_name.substr(0, separator_pos);
    const std::string name_part =
        qualified_name.substr(separator_pos + strlen("::"));

    py::handle torch_api_function = py::module::import("torch")
                                        .attr("ops")
                                        .attr(ns_part.c_str())
                                        .attr(name_part.c_str());

    const char* overload_attr =
        overload_name.empty() ? "default" : overload_name.c_str();
    return torch_api_function.attr(overload_attr).ptr();
  });
}
} // namespace torch::detail
// Returns the PyInterpreter pointer held by torch::detail::self_interpreter.
c10::impl::PyInterpreter* getPyInterpreter() {
  const auto& holder = torch::detail::self_interpreter;
  return holder.get();
}
// Returns the is_main_interpreter flag recorded by
// torch::detail::self_interpreter when it was constructed.
bool isMainPyInterpreter() {
  const auto& holder = torch::detail::self_interpreter;
  return holder.is_main_interpreter();
}