// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <torch/csrc/utils/python_compat.h>

// Many APIs have changed/don't exist anymore
#if IS_PYTHON_3_12_PLUS

#include "dim.h"

// Re-enable this some day
PyObject* Dim_init() {
    PyErr_SetString(PyExc_RuntimeError, "First class dim doesn't work with python 3.12");
    return nullptr;
}

#else

#include "minpybind.h"
#include <frameobject.h>
#include <opcode.h>
#include <utility>
#include <new>
#include <iostream>
#include <vector>
//#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/Export.h>
#include <ATen/ATen.h>
#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/functorch/DynamicLayer.h>
#include "arena.h"
#include "dim.h"
#include "python_variable_simple.h"

#if IS_PYTHON_3_11_PLUS
#define Py_BUILD_CORE
#include "internal/pycore_opcode.h"
#undef Py_BUILD_CORE
#endif

// C++ API functions for objects to
// * construct the object, returning a ref-counted handle
// * The actual API, with methods that take/return C-typed values

// extend minpybind.h to include
// * typed handles so that -> can get to their raw API
// * object/handle distinction for the typed handles

// class Dim: ---------------
mpy::handle torch_Tensor___mul__;
mpy::handle _Tensor;
mpy::handle _Tensor_sum;
mpy::handle NamedTuple;
mpy::dict_view pointwise;
mpy::handle torch_Tensor_expand;
binaryfunc THPVariable_getitem;
objobjargproc THPVariable_setitem;
mpy::handle no_slice;
PyTypeObject* torch_Tensor;
mpy::handle torch_Tensor_copy_;
mpy::handle torch_Tensor_split;
bool pointwise_optimize = true;
PyTypeObject* DimType = nullptr;

PyObject* Tensor_getitem(PyObject* self, PyObject* index);
int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value);

namespace{
void maybeInitializeGlobals() {
    // globals that depend on the python dim library,
    // which we can't lookup until we finish initializing the _C module
    if (_Tensor.ptr()) {
        return;
    }
    auto dim = mpy::import("functorch.dim");
    _Tensor = dim.attr("_Tensor");
    pointwise = dim.attr("pointwise");
    _Tensor_sum = _Tensor.attr("sum");
    DimType = (PyTypeObject*) mpy::import("functorch.dim").attr("Dim").ptr();
}

void replaceMappingIfMatches(mpy::handle tp) {
    auto T = (PyTypeObject*) tp.ptr();
    bool recurse = false;
    if (T->tp_as_mapping->mp_subscript == THPVariable_getitem) {
        T->tp_as_mapping->mp_subscript = Tensor_getitem;
        recurse = true;
    }
    if (T->tp_as_mapping->mp_ass_subscript == THPVariable_setitem) {
        T->tp_as_mapping->mp_ass_subscript = Tensor_setitem;
        recurse = true;
    }
    if (recurse) {
        auto result = tp.attr("__subclasses__").call();
        mpy::list_view lv(result);
        for (auto i : lv.enumerate()) {
            replaceMappingIfMatches(lv[i]);
        }
    }
}

void initializeGlobals(Arena & A) {
    auto torch = mpy::import("torch");
    torch_Tensor = (PyTypeObject*) torch.attr("Tensor").ptr();
    torch_Tensor___mul__ = torch.attr("Tensor").attr("__mul__");
    torch_Tensor_expand = torch.attr("_C").attr("TensorBase").attr("expand");
    torch_Tensor_split = torch.attr("_C").attr("TensorBase").attr("split");
    torch_Tensor_copy_ = torch.attr("Tensor").attr("copy_");
    auto py_TensorBase = torch.attr("_C").attr("TensorBase");
    auto TensorBase = (PyTypeObject*) py_TensorBase.ptr();
    THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript;
    THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript;
    NamedTuple = mpy::import("typing").attr("NamedTuple");
    no_slice = PySlice_New(NULL, NULL, NULL);
}

mpy::handle DimensionBindError_;
mpy::handle DimensionBindError() {
    if (!DimensionBindError_.ptr()) {
        DimensionBindError_ = mpy::import("functorch.dim").attr("DimensionBindError");
    }
    return DimensionBindError_;
}

static int64_t n_dims_created = 65;
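// Reading aid (an illustrative sketch, not additional API): the typed handles
// mentioned above let C++ code hold python objects together with their C++
// type, and they distinguish owning from borrowed references, roughly:
//
//   mpy::obj<Dim> d = Dim::create(name);  // owning, ref-counted handle
//   mpy::hdl<Dim> h = d;                  // borrowed, non-owning view
//   h->size();                            // operator-> reaches the C++ API
//
// The exact semantics live in minpybind.h.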
struct Dim : public mpy::base<Dim> {
    int64_t level_; // for stable comparisons in prototype
    mpy::object name_;
    Dim()
    : level_(n_dims_created++) {}
    void init(mpy::object name, int64_t s = -1) {
        name_ = std::move(name);
        size_ = s;
    }

    static bool check_exact(mpy::handle v) {
        return Py_TYPE(v.ptr()) == DimType;
    }

    int64_t size() const {
        if (size_ == -1) {
            mpy::raise_error(PyExc_ValueError, "dimension %S is unbound", name_.ptr());
        }
        return size_;
    }
    void set_size(int64_t v) {
        if (size_ == -1) {
            size_ = v;
        } else if (size_ != v) {
            mpy::raise_error(DimensionBindError(), "Dim '%R' previously bound to a dimension of size %lld cannot bind to a dimension of size %lld", this, this->size_, v);
        }
    }
    bool is_bound() const {
        return size_ != -1;
    }
    static mpy::obj<Dim> create(mpy::object name, int64_t s = -1) {
        if (!DimType) {
            maybeInitializeGlobals();
        }
        auto r = Dim::alloc(DimType);
        r->init(std::move(name), s);
        return r;
    }
    static PyTypeObject Type;
    const at::Tensor& range() {
        if (!range_.defined()) {
            range_ = at::arange(size());
        }
        return range_;
    }
    const at::Tensor& batchtensor() {
        if (!batchtensor_.defined()) {
            batchtensor_ = at::functorch::addBatchDim(range(), 0, level_);
        }
        return batchtensor_;
    }
private:
    int64_t size_{-1};
    at::Tensor range_;
    at::Tensor batchtensor_;
};

struct DimEntry {
    // union of either a negative number indicating which dimension this is from the rhs,
    // or a pointer to a first-class dimension.
    // pointers do not have their highest bit set, so checking if the number is negative tells us
    // that it is not a dim.
    bool is_positional() const {
        return data_ < 0;
    }
    bool is_none() const {
        return data_ == 0;
    }
    int64_t position() const {
        return data_;
    }
    mpy::hdl<Dim> dim() const {
        Dim* result;
        std::memcpy(&result, &data_, sizeof(Dim*));
        return mpy::hdl<Dim>(result);
    }

    DimEntry()
    : data_(0) {}

    DimEntry(int64_t pos)
    : data_(pos) {
        AT_ASSERT(pos < 0);
    }
    DimEntry(mpy::hdl<Dim> d) {
        std::memcpy(&data_, &d, sizeof(int64_t));
    }
    bool operator==(const DimEntry& rhs) const {
        return data_ == rhs.data_;
    }
private:
    int64_t data_;
};

// Dim wrapper methods
DimEntry _wrap_dim(mpy::handle d, size_t N, bool keepdim) {
    if (Dim::check(d)) {
        if (keepdim) {
            mpy::raise_error(PyExc_ValueError, "cannot preserve first-class dimensions with keepdim=True");
        }
        return Dim::unchecked_wrap(d);
    } else if (mpy::is_int(d)) {
        auto i = mpy::to_int(d);
        while (i >= 0) {
            i -= N;
        }
        return i;
    } else {
        return DimEntry();
    }
}

int Dim_init(mpy::hdl<Dim> self, PyObject *args, PyObject *kwds) {
    PY_BEGIN
    static constexpr const char* kwlist[] = {"name", "size", nullptr};
    mpy::handle name;
    mpy::handle size = nullptr;
    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", const_cast<char**>(kwlist), &name, &size)) {
        return -1;
    }
    self->init(mpy::object::borrow(name), (size.ptr() && !mpy::is_none(size)) ? mpy::to_int(size) : -1);
    return 0;
    PY_END(-1)
}
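// Reading aid (hedged sketch): DimEntry packs both cases into a single
// int64_t. Positional dimensions are always stored as negative offsets from
// the end, and Dim* pointers never have their sign bit set, so the sign of
// data_ distinguishes the two:
//
//   DimEntry e1(-2);      // the second-to-last positional dimension
//   DimEntry e2(d);       // bit-copy of an mpy::hdl<Dim>, read back via dim()
//   DimEntry e3;          // data_ == 0, is_none() == true
//
// _wrap_dim above normalizes python ints into this negative convention.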
PyObject* Dim_repr(Dim* self) {
    PY_BEGIN
    mpy::object name = (self->name_.ptr()) ? self->name_ : mpy::unicode_from_string("<uninitialized dim>");
    return name.release();
    PY_END(nullptr)
}

PyObject* Dim_getsize(Dim* self, void*) {
    PY_BEGIN
    return mpy::from_int(self->size()).release();
    PY_END(nullptr)
}

int Dim_setsize(Dim* self, PyObject* size, void*) {
    PY_BEGIN
    self->set_size(mpy::to_int(size));
    return 0;
    PY_END(-1)
}

PyObject* Dim_getis_bound(Dim* self, void*) {
    return PyBool_FromLong(self->is_bound());
}

PyObject* Dim_getlevel(Dim* self, void*) {
    return PyLong_FromLong(self->level_);
}

PyObject* Dim_get_levels(Dim* self, void*) {
    mpy::tuple t(1);
    t.set(0, mpy::object::borrow(self->ptr()));
    return t.release();
}

PyObject* Dim_get_has_device(Dim* self, void*) {
    Py_RETURN_FALSE;
}

PyObject* Dim_get_tensor(Dim* self, void*) {
    return THPVariable_Wrap(self->range());
}

PyObject* Dim_get_batchtensor(Dim* self, void*) {
    return THPVariable_Wrap(self->batchtensor());
}

PyGetSetDef Dim_getsetters[] = {
    {"size", (getter) Dim_getsize, (setter) Dim_setsize, "Dimension size", NULL},
    {"is_bound", (getter) Dim_getis_bound, NULL, "is_bound", NULL},
    {"_level", (getter) Dim_getlevel, NULL, "_level", NULL},
    {"_levels", (getter) Dim_get_levels, NULL, "_levels", NULL},
    {"_has_device", (getter) Dim_get_has_device, NULL, "_has_device", NULL},
    {"_tensor", (getter) Dim_get_tensor, NULL, "_tensor", NULL},
    {"_batchtensor", (getter) Dim_get_batchtensor, NULL, "_batchtensor", NULL},
    {"ndim", (getter) [](PyObject* self, void*) -> PyObject* { return mpy::from_int(1).release(); }, NULL, "ndim", NULL},
    {NULL}  /* Sentinel */
};
}

PyTypeObject Dim::Type = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "_C.Dim",                     /* tp_name */
    sizeof(Dim),                  /* tp_basicsize */
    0,                            /* tp_itemsize */
    Dim::dealloc_stub,            /* tp_dealloc */
    0,                            /* tp_vectorcall_offset */
    0,                            /* tp_getattr */
    0,                            /* tp_setattr */
    0,                            /* tp_as_async */
    (reprfunc)Dim_repr,           /* tp_repr */
    0,                            /* tp_as_number */
    0,                            /* tp_as_sequence */
    0,                            /* tp_as_mapping */
    0,                            /* tp_hash */
    0,                            /* tp_call */
    0,                            /* tp_str */
    0,                            /* tp_getattro */
    0,                            /* tp_setattro */
    0,                            /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    "Dim Object",                 /* tp_doc */
    0,                            /* tp_traverse */
    0,                            /* tp_clear */
    0,                            /* tp_richcompare */
    0,                            /* tp_weaklistoffset */
    0,                            /* tp_iter */
    0,                            /* tp_iternext */
    0,                            /* tp_methods */
    0,                            /* tp_members */
    Dim_getsetters,               /* tp_getset */
    0,                            /* tp_base */
    0,                            /* tp_dict */
    0,                            /* tp_descr_get */
    0,                            /* tp_descr_set */
    0,                            /* tp_dictoffset */
    (initproc)(void*) static_cast<int(*)(mpy::hdl<Dim>,PyObject*,PyObject*)>(Dim_init), /* tp_init */
    0,                            /* tp_alloc */
    Dim::new_stub,                /* tp_new */
};

// class DimList ------------

struct DimList : public mpy::base<DimList> {
    mpy::object name_;
    std::vector<mpy::obj<Dim>> dims_;
    static PyTypeObject Type;
    void init(mpy::object name) {
        name_ = std::move(name);
    }
    void set_dims(std::vector<mpy::obj<Dim>> dims) {
        bound_ = true;
        dims_ = std::move(dims);
    }
    bool is_bound() {
        return bound_;
    }
    void bind_len(int64_t size) {
        if (bound_) {
            int64_t b_size = dims_.size();
            if (b_size != size) {
                mpy::raise_error(DimensionBindError(), "Dimlist has size %lld but it is being bound to size %d", b_size, size);
            }
        } else {
            bound_ = true;
            dims_.resize(size);
            for (Py_ssize_t i = 0; i < size; ++i) {
                dims_[i] = Dim::create(mpy::unicode_from_format("%S%i", name_.ptr(), (int)i));
            }
        }
    }
    int64_t size() const {
        if (!bound_) {
            mpy::raise_error(DimensionBindError(), "DimList not bound");
        }
        return dims_.size();
    }
    void set_bound(bool b) {
        bound_ = b;
    }
private:
    bool bound_ = false;
};

static int DimList_init(DimList *self, PyObject *args, PyObject *kwds);
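// Python-side usage sketch (names follow the functorch.dim frontend that
// wraps these types; illustrative, not a test):
//
//   from functorch.dim import dims
//   i, j = dims(2)      # created through Dim::create via the _dims wrapper
//   i.size = 3          # Dim_setsize; binding twice to different sizes
//                       # raises DimensionBindError
//   x = torch.rand(3, 4)[i, j]   # indexing also binds i and j to 3 and 4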
static PyObject* DimList_repr(DimList* self) {
    PY_BEGIN
    if (self->is_bound()) {
        size_t size = self->dims_.size();
        mpy::tuple t(size);
        for(size_t i = 0; i < size; ++i) {
            t.set(i, self->dims_[i]);
        }
        return mpy::repr(t).release();
    } else if(!mpy::is_none(self->name_)) {
        return mpy::unicode_from_format("*%S", self->name_.ptr()).release();
    } else {
        return mpy::unicode_from_string("<unbound_dimlist>").release();
    }
    PY_END(nullptr)
}

static PyObject* DimList_bind(DimList *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    PY_BEGIN
    mpy::handle sizes;
    static const char * const _keywords[] = {"sizes", nullptr};
    static _PyArg_Parser parser = {"O", _keywords, 0};
    if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) {
        return nullptr;
    }
    if (!mpy::is_sequence(sizes)) {
        mpy::raise_error(PyExc_ValueError, "expected a sequence");
    }
    mpy::sequence_view seq = sizes;
    auto size = seq.size();
    self->bind_len(size);
    for (Py_ssize_t i = 0; i < size; ++i) {
        self->dims_[i]->set_size(mpy::to_int(seq[i]));
    }
    Py_RETURN_NONE;
    PY_END(nullptr)
}

static PyObject* DimList_bind_len(DimList *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    PY_BEGIN
    int size;
    static const char * const _keywords[] = {"N", nullptr};
    static _PyArg_Parser parser = {"i", _keywords, 0};
    if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) {
        return nullptr;
    }
    self->bind_len(size);
    Py_RETURN_NONE;
    PY_END(nullptr)
}

static PyMethodDef DimList_methods[] = {
    {"bind", (PyCFunction)(void*) DimList_bind, METH_FASTCALL | METH_KEYWORDS},
    {"bind_len", (PyCFunction)(void*) DimList_bind_len, METH_FASTCALL | METH_KEYWORDS},
    {NULL, NULL, 0, NULL}        /* Sentinel */
};

static Py_ssize_t DimList_len(DimList* self) {
    PY_BEGIN
    return self->size();
    PY_END(-1)
}

static PyObject * DimList_item(DimList* self, Py_ssize_t idx) {
    PY_BEGIN
    if (!self->is_bound()) {
        mpy::raise_error(DimensionBindError(), "DimList not bound");
    }
    if (idx < 0 || (size_t) idx >= self->dims_.size()) {
        mpy::raise_error(PyExc_IndexError, "index out of bounds");
    }
    mpy::object r = self->dims_[idx];
    return r.release();
    PY_END(nullptr)
}

PySequenceMethods DimList_seq {
    (lenfunc) DimList_len, //lenfunc sq_length;
    0, //binaryfunc sq_concat;
    0, //ssizeargfunc sq_repeat;
    (ssizeargfunc) DimList_item, //ssizeargfunc sq_item;
    0, //void *was_sq_slice;
    0, //ssizeobjargproc sq_ass_item;
    0, //void *was_sq_ass_slice;
    0, //objobjproc sq_contains;
    0, //binaryfunc sq_inplace_concat;
    0, //ssizeargfunc sq_inplace_repeat;
};

static PyObject* DimList_getis_bound(DimList* self, void*) {
    return PyBool_FromLong(self->is_bound());
}

static PyGetSetDef DimList_getsetters[] = {
    {"is_bound", (getter) DimList_getis_bound, NULL, "is_bound", NULL},
    {NULL}  /* Sentinel */
};

static PyObject* DimList_subscript(DimList* self, mpy::handle idx) {
    PY_BEGIN
    if (mpy::is_int(idx)) {
        return DimList_item(self, mpy::to_int(idx));
    } else if (mpy::is_slice(idx)) {
        if (!self->is_bound()) {
            mpy::raise_error(DimensionBindError(), "DimList not bound");
        }
        mpy::slice_view s(idx, self->dims_.size());
        mpy::tuple r(s.slicelength);
        for (Py_ssize_t i = s.start, j = 0; i < s.stop; i += s.step) {
            r.set(j++, self->dims_[i]);
        }
        return r.release();
    } else {
        mpy::raise_error(PyExc_ValueError, "expected an int or a slice");
        return nullptr;
    }
    PY_END(nullptr)
}

PyMappingMethods DimList_mapping = {
    0, //lenfunc mp_length;
    (binaryfunc)(void*) DimList_subscript, //binaryfunc mp_subscript;
    0, //objobjargproc mp_ass_subscript;
};
PyTypeObject DimList::Type = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "_C.DimList",                 /* tp_name */
    sizeof(DimList),              /* tp_basicsize */
    0,                            /* tp_itemsize */
    DimList::dealloc_stub,        /* tp_dealloc */
    0,                            /* tp_vectorcall_offset */
    0,                            /* tp_getattr */
    0,                            /* tp_setattr */
    0,                            /* tp_as_async */
    (reprfunc)DimList_repr,       /* tp_repr */
    0,                            /* tp_as_number */
    &DimList_seq,                 /* tp_as_sequence */
    &DimList_mapping,             /* tp_as_mapping */
    0,                            /* tp_hash */
    0,                            /* tp_call */
    0,                            /* tp_str */
    0,                            /* tp_getattro */
    0,                            /* tp_setattro */
    0,                            /* tp_as_buffer */
    0,                            /* tp_flags */
    "DimList Object",             /* tp_doc */
    0,                            /* tp_traverse */
    0,                            /* tp_clear */
    0,                            /* tp_richcompare */
    0,                            /* tp_weaklistoffset */
    0,                            /* tp_iter */
    0,                            /* tp_iternext */
    DimList_methods,              /* tp_methods */
    0,                            /* tp_members */
    DimList_getsetters,           /* tp_getset */
    0,                            /* tp_base */
    0,                            /* tp_dict */
    0,                            /* tp_descr_get */
    0,                            /* tp_descr_set */
    0,                            /* tp_dictoffset */
    (initproc) DimList_init,      /* tp_init */
    0,                            /* tp_alloc */
    DimList::new_stub,            /* tp_new */
};

static int DimList_init(DimList *self, PyObject *args, PyObject *kwds) {
    PY_BEGIN
    static constexpr const char* kwlist[] = {"len_or_dims", "name", nullptr};
    mpy::handle len_or_dims = nullptr;
    PyObject* name = nullptr;
    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast<char**>(kwlist), &len_or_dims, &name)) {
        return -1;
    }
    self->init(mpy::object::borrow(name ? name : Py_None));
    if (len_or_dims.ptr()) {
        if(mpy::is_int(len_or_dims)) {
            self->bind_len(mpy::to_int(len_or_dims));
        } else if (mpy::is_sequence(len_or_dims)) {
            mpy::sequence_view s(len_or_dims);
            std::vector<mpy::obj<Dim>> dims;
            size_t size = s.size();
            dims.reserve(size);
            for (size_t i = 0; i < size; ++i) {
                auto r = s[i];
                if (mpy::is_int(r)) {
                    dims.emplace_back(Dim::create(mpy::unicode_from_format("%S%i", self->name_.ptr(), (int)i), mpy::to_int(r)));
                } else {
                    dims.emplace_back(Dim::wrap(r));
                }
            }
            self->set_dims(std::move(dims));
        } else {
            PyErr_Format(PyExc_ValueError, "expected a length or a sequence of dimensions");
            return -1;
        }
        return 0;
    }
    return 0;
    PY_END(-1);
}

// Tensor -----------------------------

PyTypeObject* TensorType = nullptr; // the python wrapper type.
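// Note on the "levels" convention used by the wrapper below (a reading aid
// inferred from from_positional/_match_levels): each physical dimension of
// the underlying at::Tensor gets one DimEntry. For a 3-d tensor whose first
// dimension is the first-class dim `i`, levels would be
//   [ i, -2, -1 ]
// i.e. first-class dims appear by name and the remaining positional
// dimensions count backwards from the end.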
mpy::object run_torch_function(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise);

namespace{
at::Tensor _add_batch_dims(Arena& A, at::Tensor t, Slice<DimEntry> levels_) {
    auto levels = Slice<DimEntry>();
    levels.extend(A, levels_);
    while (true) {
        int64_t min_real_index = -1;
        int64_t min_index = -1;
        int64_t min_value = INT_MAX;
        int64_t i = 0;
        int64_t r = 0;
        for (auto l : levels) {
            if (!l.is_none()) {
                if (!l.is_positional() && l.dim()->level_ < min_value) {
                    min_value = l.dim()->level_;
                    min_index = i;
                    min_real_index = r;
                }
                ++i;
            }
            ++r;
        }
        if (min_index == -1) {
            return t;
        }
        auto t2 = at::functorch::addBatchDim(std::move(t), min_index, min_value);
        t = std::move(t2);
        levels[min_real_index] = DimEntry();
    }
}

struct DelayedOperator {
    DelayedOperator(mpy::object o, mpy::vector_args a)
    : orig(std::move(o)), args(a) {
        auto all = a.size();
        // this will outlive the call so
        // take ownership of temporaries
        // in vector args
        auto buf = new mpy::handle[all];
        memcpy(buf, args.args, sizeof(mpy::handle)*all);
        args.args = buf;
        for (auto i : args.enumerate_all()) {
            Py_INCREF(args.args[i].ptr());
        }
        Py_XINCREF(args.kwnames.ptr());
    }
    ~DelayedOperator() {
        for (auto i : args.enumerate_all()) {
            Py_DECREF(args[i].ptr());
        }
        if (args.has_keywords()) {
            Py_XDECREF(args.kwnames.ptr());
        }
        delete [] args.args;
    }
    mpy::object orig;
    mpy::vector_args args;
};

void free_levels_dims(Slice<DimEntry> levels) {
    for(auto e : levels) {
        if (!e.is_positional()) {
            mpy::object::steal(e.dim());
        }
    }
}
}

struct Tensor : public mpy::base<Tensor> {
private:
    at::Tensor tensor_;
    at::Tensor batchtensor_;
    OwnedSlice<DimEntry> levels_;
    bool has_device_;
    std::unique_ptr<DelayedOperator> delayed_;

public:
    at::Tensor& tensor(Arena& A) {
        if (C10_UNLIKELY(!tensor_.defined())) {
            AT_ASSERT(delayed_);
            auto t = Tensor::wrap(run_torch_function(A, delayed_->orig, delayed_->args, true));
            tensor_ = t->tensor(A);
            delayed_.reset();
            // don't force creation of batch tensor if it wasn't already provided.
            batchtensor_ = t->batchtensor_;
            AT_ASSERT(levels() == t->levels());
        }
        return tensor_;
    }
    at::Tensor& batchtensor(Arena& A) {
        if (C10_UNLIKELY(!batchtensor_.defined())) {
            batchtensor_ = _add_batch_dims(A, tensor(A), levels_.slice());
        }
        return batchtensor_;
    }
    Slice<DimEntry> levels() {
        return levels_.slice();
    }
    bool has_device() {
        return has_device_;
    }
    DelayedOperator* delayed() {
        return delayed_.get();
    }
    static PyTypeObject Type;

    static bool check_exact(mpy::handle v) {
        return Py_TYPE(v.ptr()) == TensorType;
    }

    static mpy::obj<Tensor> create() {
        if (!TensorType) {
            TensorType = (PyTypeObject*) mpy::import("functorch.dim").attr("Tensor").ptr();
        }
        return Tensor::alloc(TensorType);
    }
    void capture_levels(Slice<DimEntry> levels) {
        // grab ownership of the dims inside levels
        for (auto l : levels) {
            if (!l.is_positional()) {
                mpy::object::borrow(l.dim()).release();
            }
        }
        levels_.set(levels, free_levels_dims);
    }
    static mpy::object from_positional(Arena & A, at::Tensor tensor, Slice<DimEntry> levels, bool has_device);
    static mpy::obj<Tensor> create_delayed(mpy::object op, mpy::vector_args args, Slice<DimEntry> levels, bool has_device);
    friend struct EnableAllLayers;
};

namespace{
// version in header does an unnecessary refcount +/-
at::functorch::BatchedTensorImpl* maybeGetBatchedImpl(const at::Tensor& tensor) {
    if (at::functorch::isBatchedTensor(tensor)) {
        return static_cast<at::functorch::BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
    }
    return nullptr;
}

TensorRef unchecked_tensor_from(mpy::handle p) {
    auto v = (THPVariable*) p.ptr();
    return TensorRef(*v->cdata);
}

static int64_t ndim_of_levels(Slice<DimEntry> levels) {
    int64_t r = 0;
    for (auto l : levels) {
        if (l.is_positional()) {
            ++r;
        }
    }
    return r;
}

struct TensorInfo {
    TensorRef tensor;
    Slice<DimEntry> levels;
    bool has_device;
    TensorRef batchedtensor;
    int64_t ndim() const {
        return ndim_of_levels(levels);
    }
    operator bool() const {
        return tensor;
    }

    static TensorInfo create(Arena& A, mpy::handle h, bool ensure_batched=true, bool ensure_present=true) {
        if (Tensor::check_exact(h)) {
            auto t = Tensor::unchecked_wrap(h);
            return TensorInfo {t->tensor(A), t->levels(), t->has_device(), ensure_batched ? t->batchtensor(A) : TensorRef()};
        } else if (Dim::check_exact(h)) {
            auto d = Dim::unchecked_wrap(h);
            return TensorInfo {d->range(), Slice<DimEntry>(A, DimEntry(d)), false, ensure_batched ? d->batchtensor() : TensorRef()};
        } else if (THPVariable_Check(h.ptr())) {
            TensorRef t = unchecked_tensor_from(h);
            Slice<DimEntry> levels;
            for (auto i : irange(-t->dim(), 0)) {
                levels.append(A, i);
            }
            return TensorInfo {t, levels, true, t};
        } else {
            if (ensure_present) {
                mpy::raise_error(PyExc_ValueError, "expected a tensor object");
            }
            return TensorInfo {};
        }
    }
};

static PyObject* py_Tensor_from_positional(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    Arena A;
    PY_BEGIN
    #define ARGS(_) _(mpy::handle, tensor) _(mpy::handle, py_levels) _(int, has_device)
    MPY_PARSE_ARGS_KWNAMES("OOp", ARGS)
    #undef ARGS

    if (!THPVariable_Check(tensor.ptr())) {
        mpy::raise_error(PyExc_ValueError, "_tensor is not a Tensor?");
    }

    Slice<DimEntry> levels;
    mpy::sequence_view sq(py_levels);
    for (auto i : sq.enumerate()) {
        mpy::object v = sq[i];
        if (mpy::is_int(v)) {
            auto vi = mpy::to_int(v);
            levels.append(A, vi);
        } else {
            auto dim = Dim::wrap(std::move(v));
            mpy::hdl<Dim> hdim = dim;
            levels.append(A, hdim);
        }
    }
    return Tensor::from_positional(A, THPVariable_Unpack(tensor.ptr()), levels, has_device != 0).release();
    PY_END(nullptr)
}
}

mpy::object Tensor::from_positional(Arena & A, at::Tensor tensor, Slice<DimEntry> levels, bool has_device) {
    size_t seen_dims = 0;
    int last = 0;
    //auto sz = tensor.sizes();
    for (auto i : levels.enumerate()) {
        auto l = levels[i];
        if (l.is_positional()) {
            AT_ASSERT(last == 0 || last + 1 == l.position());
            last = l.position();
        } else {
            mpy::object::borrow(l.dim()).release();
            //AT_ASSERT(sz[i] == l.dim()->size());
            ++seen_dims;
        }
    }
    AT_ASSERT(last == 0 || last == -1);
    if (!seen_dims) {
        return mpy::object::steal(THPVariable_Wrap(std::move(tensor)));
    }

    mpy::obj<Tensor> self = Tensor::create();
    self->tensor_ = std::move(tensor);
    AT_ASSERT(self->tensor_.dim() == levels.size());
    self->levels_.set(levels, free_levels_dims);
    self->has_device_ = has_device;
    mpy::object r = std::move(self);
    return r;
}

mpy::obj<Tensor> Tensor::create_delayed(mpy::object op, mpy::vector_args args, Slice<DimEntry> levels, bool has_device) {
    mpy::obj<Tensor> self = Tensor::create();
    self->capture_levels(levels);
    self->has_device_ = has_device;
    self->delayed_ = std::make_unique<DelayedOperator>(std::move(op), args);
    return self;
}
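// create_delayed defers actually running an operator: __torch_function__
// below uses it for multiplies of tensors with no positional dimensions, so
// that a following .sum(dim) can be lowered into the matrix-product path in
// dot() later in this file instead of materializing the broadcasted product.
// Illustrative python (hedged):
//
//   (x[i, k] * y[k, j]).sum(k)   # the mul is recorded, not executed eagerly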
namespace{
mpy::list slice_to_list(Slice<mpy::handle> h) {
    mpy::list lst(h.size());
    for (auto i : h.enumerate()) {
        lst.set(i, mpy::object::borrow(h[i]));
    }
    return lst;
}

mpy::tuple slice_to_tuple(Slice<mpy::handle> h) {
    mpy::tuple lst(h.size());
    for (auto i : h.enumerate()) {
        lst.set(i, mpy::object::borrow(h[i]));
    }
    return lst;
}

enum UType {
    U_ELEM,
    U_TUPLE_LIKE,
    U_DICT,
};

struct Unflatten {
    mpy::object operator()(Slice<mpy::handle>& elements) {
        mpy::object r;
        switch (type) {
            case U_ELEM: {
                r = mpy::object::borrow(elements[0]);
                elements = elements.slice(1);
            } break;
            case U_TUPLE_LIKE: {
                mpy::tuple tup(children.size());
                for (auto i : children.enumerate()) {
                    tup.set(i, children[i](elements));
                }
                r = obj.call(tup);
            } break;
            case U_DICT: {
                r = mpy::object::checked_steal(PyDict_New());
                mpy::dict_view rv(r);
                mpy::dict_view d(obj);
                Py_ssize_t pos = 0;
                mpy::handle k, v;
                for (int i = 0; d.next(&pos, &k, &v); ++i) {
                    rv.set(k, children[i](elements));
                }
            } break;
        }
        return r;
    }
    UType type;
    mpy::handle obj;
    Slice<Unflatten> children;
};

Unflatten tree_flatten(Arena& A, mpy::handle agg, Slice<mpy::handle>& flat_elements) {
    Slice<Unflatten> c;
    UType utype;
    mpy::handle obj;
    if (mpy::list_view::check(agg)) {
        obj = agg.type();
        utype = U_TUPLE_LIKE;
        mpy::list_view l(agg);
        for (auto i : l.enumerate()) {
            c.append(A, tree_flatten(A, l[i], flat_elements));
        }
    } else if (mpy::tuple_view::check(agg)) {
        obj = agg.type();
        utype = U_TUPLE_LIKE; // includes named tuples
        mpy::tuple_view l(agg);
        for (auto i : l.enumerate()) {
            c.append(A, tree_flatten(A, l[i], flat_elements));
        }
    } else if (mpy::dict_view::check(agg)) {
        utype = U_DICT;
        mpy::dict_view d(agg);
        obj = agg;
        Py_ssize_t pos = 0;
        mpy::handle k, v;
        while (d.next(&pos, &k, &v)) {
            c.append(A, tree_flatten(A, v, flat_elements));
        }
    } else {
        utype = U_ELEM;
        flat_elements.append(A, agg);
    }
    return Unflatten {utype, obj, c};
}

struct UnflattenVectorArgs {
    mpy::vector_args operator()(Arena& A, Slice<mpy::handle>& elements) {
        if (!had_nested) {
            auto args = elements.begin();
            elements = Slice<mpy::handle>();
            return mpy::vector_args(args, nargs, kwnames);
        }
        Slice<mpy::handle> args;
        for (auto u : children) {
            args.append(A, A.autorelease(u(elements)));
        }
        return mpy::vector_args(args.begin(), nargs, kwnames);
    }
    Slice<Unflatten> children;
    Py_ssize_t nargs;
    mpy::handle kwnames;
    bool had_nested;
};

UnflattenVectorArgs tree_flatten(Arena& A, mpy::vector_args args, Slice<mpy::handle>& flat_elements) {
    UnflattenVectorArgs r;
    r.kwnames = args.kwnames;
    r.nargs = args.nargs;
    r.had_nested = false;
    auto N = args.size();
    for(auto i : irange(N)) {
        auto typ = Py_TYPE(args[i].ptr());
        // fast checks that this thing isn't something that is nested.
        bool is_element = !typ->tp_as_sequence || typ == torch_Tensor || typ == TensorType || typ == DimType;
        if (!is_element) {
            flat_elements.extend(A, args.args, args.args + i);
            for (auto j : irange(i)) {
                (void)j;
                r.children.append(A, Unflatten {U_ELEM});
            }
            for (auto j : irange(i, N)) {
                r.children.append(A, tree_flatten(A, args[j], flat_elements));
                if (r.children.back().type != U_ELEM) {
                    r.had_nested = true;
                }
            }
            return r;
        }
    }
    flat_elements.extend(A, args.args, args.args + N);
    return r;
}

struct UnflattenArena {
    Arena A;
    Unflatten unflatten;
};
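// Reading aid: tree_flatten mirrors the usual pytree idiom. It walks lists,
// tuples (including named tuples) and dicts, appends every leaf to
// flat_elements, and returns an Unflatten closure that rebuilds the same
// structure from a flat slice. Roughly:
//
//   Slice<mpy::handle> flat;
//   auto unflatten = tree_flatten(A, agg, flat);  // flat now holds leaves
//   /* ... transform flat in place ... */
//   mpy::object rebuilt = unflatten(flat);        // same shape, new leaves
//
// tree_map below is exactly this pattern.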
PyObject* py_unflatten(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    PY_BEGIN
    #define ARGS(_) _(mpy::handle, ns)
    MPY_PARSE_ARGS_KWNAMES("O", ARGS)
    #undef ARGS
    mpy::sequence_view sv(ns);
    // because we do not have an autorelease pool yet...
    Arena A;
    Slice<mpy::handle> slice;
    mpy::handle Tuple = (PyObject*) &PyTuple_Type;
    auto inputs = Tuple.call(ns);
    mpy::tuple_view tv(inputs);
    for (auto i : tv.enumerate()) {
        slice.append(A, tv[i]);
    }
    auto AA = (UnflattenArena*) PyCapsule_GetPointer(self, "arena");
    auto r = AA->unflatten(slice).release();
    AT_ASSERT(r != nullptr);
    return r;
    PY_END(nullptr)
}

PyMethodDef py_unflatten_def = {"unflatten", (PyCFunction)(void*) py_unflatten, METH_FASTCALL | METH_KEYWORDS};

void free_unflatten_arena(PyObject * pc) {
    delete (UnflattenArena*) PyCapsule_GetPointer(pc, "arena");
}

PyObject* py_tree_flatten(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    PY_BEGIN
    #define ARGS(_) _(mpy::handle, tree)
    MPY_PARSE_ARGS_KWNAMES("O", ARGS)
    #undef ARGS
    auto A = new UnflattenArena;
    Slice<mpy::handle> elements;
    A->unflatten = tree_flatten(A->A, tree, elements);
    auto cap = mpy::object::checked_steal(PyCapsule_New(A, "arena", free_unflatten_arena));
    auto unflatten = mpy::object::checked_steal(PyCFunction_New(&py_unflatten_def, cap.release()));
    mpy::tuple r(2);
    r.set(0, slice_to_list(elements));
    r.set(1, std::move(unflatten));
    return r.release();
    PY_END(nullptr)
}

mpy::object tree_map(Arena& A, const std::function<mpy::handle(mpy::handle)>& fn, mpy::handle agg) {
    Slice<mpy::handle> elements;
    auto unflatten = tree_flatten(A, agg, elements);
    for (auto i : elements.enumerate()) {
        elements[i] = fn(elements[i]);
    }
    return unflatten(elements);
}

// prereq: isinstance(h, _Tensor)
int64_t _Tensor_ndim(mpy::handle h) {
    if (Tensor::check(h)) {
        int64_t r = 0;
        for (auto l : Tensor::unchecked_wrap(h)->levels()) {
            if (l.is_positional()) {
                ++r;
            }
        }
        return r;
    }
    // Dim or DelayedMulTensor
    return 0;
}

mpy::handle handle_from_tensor(Arena& A, TensorRef t) {
    // fast case: tensor is live in python
    std::optional<PyObject*> mb_obj =
        t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false);
    if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
        return *mb_obj;
    }
    return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t)));
}
}

struct EnableAllLayers {
    EnableAllLayers(Arena& A, Slice<DimEntry> levels) {
        std::vector<std::pair<int64_t, int64_t>> layers;
        layers.reserve(levels.size());
        for (auto l : levels) {
            if (!l.is_positional()) {
                auto d = l.dim();
                levels_to_dim_.append(A, d);
            }
        }
        std::sort(levels_to_dim_.begin(), levels_to_dim_.end(), [](mpy::hdl<Dim> lhs, mpy::hdl<Dim> rhs) { return lhs->level_ < rhs->level_; });

        for (auto i : levels_to_dim_.enumerate()) {
            auto batch_size = levels_to_dim_[i]->size();
            auto level = at::functorch::initAndPushDynamicLayer(at::functorch::TransformType::Vmap, batch_size, at::functorch::RandomnessType::Different);
            if (i == 0) {
                levels_start_ = level;
            }
        }
    }

    ~EnableAllLayers() {
        auto to_remove = levels_start_ + levels_to_dim_.size() - 1;
        for (auto i : levels_to_dim_.enumerate()) {
            AT_ASSERT(at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == to_remove - i);
        }
    }

    mpy::obj<Tensor> from_batched(Arena& A, at::Tensor batchedtensor, bool has_device) {
        Slice<DimEntry> levels;
        for (auto i : irange(-batchedtensor.dim(), 0)) {
            levels.append(A, i);
        }
        TensorRef tensor;
        at::functorch::BatchedTensorImpl * impl = maybeGetBatchedImpl(batchedtensor);
        while(true) {
            auto level = impl->level();
            AT_ASSERT(level >= levels_start_ && level < levels_start_ + levels_to_dim_.size());
            mpy::hdl<Dim> dim = levels_to_dim_[level - levels_start_].ptr();
            levels.insert(A, impl->bdim(), dim);
            at::functorch::BatchedTensorImpl * nimpl = maybeGetBatchedImpl(impl->value());
            if (!nimpl) {
                tensor = impl->value();
                break;
            }
            impl = nimpl;
        }

        mpy::obj<Tensor> self = Tensor::create();
        // grab ownership of the tensors
        self->tensor_ = *tensor;
        self->batchtensor_ = std::move(batchedtensor);
        self->has_device_ = has_device;
        self->capture_levels(levels);
        return self;
    }

    void inplace_update_layers(TensorRef batchtensor, Slice<DimEntry> levels) {
        // XXX - requires a patch to functorch to add set_level
        auto impl = maybeGetBatchedImpl(*batchtensor);
        for (auto i : levels_to_dim_.reversed_enumerate()) {
            if (!impl) {
                break;
            }
            if (levels.contains(levels_to_dim_[i])) {
                impl->_unsafe_set_level(levels_start_ + i);
                impl = maybeGetBatchedImpl(impl->value());
            }
        }
    }
private:
    int64_t levels_start_{};
    Slice<mpy::hdl<Dim>> levels_to_dim_;
};

namespace{
TensorRef _match_levels(Arena& A, TensorRef v, Slice<DimEntry> from_levels, Slice<DimEntry> to_levels, bool drop_levels=false) {
    if (from_levels == to_levels) {
        return v;
    }
    // drop_levels -> if a dim appears in from_levels but not to_levels, it is assumed it has stride 0.
    at::IntArrayRef sz = v->sizes();
    at::IntArrayRef sd = v->strides();
    AT_ASSERT(drop_levels || from_levels.size() <= to_levels.size());
    Slice<int64_t> nsz;
    Slice<int64_t> nsd;
    for (auto l : to_levels) {
        auto oidx = from_levels.index(l);
        if (!oidx) {
            nsz.append(A, l.is_positional() ? 1 : l.dim()->size());
            nsd.append(A, 0);
        } else {
            auto idx = *oidx;
            nsz.append(A, sz[idx]);
            nsd.append(A, sd[idx]);
        }
    }
    return A.autorelease(v->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()), at::IntArrayRef(nsd.begin(), nsd.end()), v->storage_offset()));
}
}
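// Example of what _match_levels does (reading aid): with
// from_levels = [i, -1] and to_levels = [j, i, -1], the result views the same
// storage with a new leading j-sized dimension of stride 0, broadcasting the
// data along j without a copy. Missing positional levels likewise become
// size-1/stride-0 axes. This is what lets the pointwise path below align
// every operand to the union of all dims involved.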
mpy::object run_torch_function(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise) {
    if (!pointwise_optimize) {
        is_pointwise = false;
    }
    // std::cout << "__torch_function__ " << ((is_pointwise) ? "pointwise" : "functorch") << " " << orig << "\n";

    Slice<mpy::hdl<Dim>> all_dims;
    Slice<mpy::handle> flat_args;
    auto unflatten_args = tree_flatten(A, args, flat_args);
    TensorRef device_holding_tensor;

    Slice<TensorInfo> infos;
    Slice<DimEntry> result_levels;
    for (auto f : flat_args) {
        infos.append(A, TensorInfo::create(A, f, !is_pointwise, false));
        if (infos.back()) {
            TensorInfo& info = infos.back();
            AT_ASSERT(is_pointwise || info.batchedtensor);
            if (!device_holding_tensor && info.has_device) {
                device_holding_tensor = infos.back().tensor;
            }
            for (auto l : info.levels) {
                if (!result_levels.contains(l)) {
                    result_levels.append(A, l);
                }
            }
        }
    }

    if (is_pointwise) {
        for (auto i : flat_args.enumerate()) {
            if (infos[i]) {
                TensorRef tensor = infos[i].tensor;
                if (device_holding_tensor && !infos[i].has_device) {
                    tensor = A.autorelease(tensor->to(device_holding_tensor->device()));
                }
                auto ml = _match_levels(A, tensor, infos[i].levels, result_levels);
                flat_args[i] = handle_from_tensor(A, std::move(ml));
            }
        }

        Slice<mpy::handle> flat_it = flat_args;
        mpy::vector_args uargs = unflatten_args(A, flat_it);

        mpy::object result = orig.call_vector(uargs);

        // fast wrap for normal case where operator just returns a tensor.
        if (THPVariable_Check(result.ptr())) {
            return Tensor::from_positional(A, THPVariable_Unpack(result.ptr()), result_levels, device_holding_tensor);
        }
        auto wrap = [&](mpy::handle h) {
            if (THPVariable_Check(h.ptr())){
                return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), result_levels, device_holding_tensor));
            }
            return h;
        };
        return tree_map(A, wrap, result);
    } else {
        // std::cout << orig << " calling functorch...\n";
        // std::cout << "rl: " << result_levels << "\n";
        EnableAllLayers guard(A, result_levels);
        for (auto i : flat_args.enumerate()) {
            if (infos[i]) {
                TensorRef batched = infos[i].batchedtensor;
                if (device_holding_tensor && !infos[i].has_device) {
                    batched = A.autorelease(batched->to(device_holding_tensor->device()));
                }
                guard.inplace_update_layers(batched, infos[i].levels);
                flat_args[i] = handle_from_tensor(A, batched);
            }
        }
        Slice<mpy::handle> flat_it = flat_args;
        mpy::vector_args uargs = unflatten_args(A, flat_it);
        AT_ASSERT(flat_it.size() == 0);
        mpy::object result = orig.call_vector(uargs);
        auto wrap = [&](mpy::handle h) {
            if (THPVariable_Check(h.ptr())) {
                return A.autorelease(guard.from_batched(A, THPVariable_Unpack(h.ptr()), device_holding_tensor));
            }
            return h;
        };
        if (THPVariable_Check(result.ptr())) {
            return guard.from_batched(A, THPVariable_Unpack(result.ptr()), device_holding_tensor);
        }
        return tree_map(A, wrap, result);
    }
}

namespace{

mpy::object __torch_function__(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise) {
    if (orig == torch_Tensor___mul__) {
        AT_ASSERT(args.nargs == 2 && !args.has_keywords());
        auto lhs = args[0];
        auto rhs = args[1];
        if (mpy::isinstance(lhs, _Tensor) && mpy::isinstance(rhs, _Tensor) && _Tensor_ndim(lhs) == 0 && _Tensor_ndim(rhs) == 0) {
            bool has_device = false;
            Slice<DimEntry> levels;
            for (auto i : args.enumerate_positional()) {
                auto t = TensorInfo::create(A, args[i], false);
                // something like a mask * rhs, which matrix multiplies don't correctly promote
                if (!t.tensor->is_floating_point()) {
                    return run_torch_function(A, orig, args, is_pointwise);
                }
                has_device = has_device || t.has_device;
                for (auto l : t.levels) {
                    if (!levels.contains(l)) {
                        levels.append(A, l);
                    }
                }
            }
            // std::cout << "__torch_function__ " << "delay" << " " << orig << "\n";
            return Tensor::create_delayed(mpy::object::borrow(orig), args, levels, has_device);
        }
    }
    return run_torch_function(A, orig, args, is_pointwise);
}

mpy::vector_args as_vector_args(Arena& A, mpy::handle args, mpy::handle kwargs) {
    auto pos_args = (mpy::handle*) &PyTuple_GET_ITEM(args.ptr(), 0);
    auto pos_n = PyTuple_GET_SIZE(args.ptr());
    if (!kwargs.ptr()) {
        return mpy::vector_args(pos_args, pos_n, nullptr);
    }
    Slice<mpy::handle> all_args;
    Slice<mpy::handle> kwnames;
    all_args.extend(A, pos_args, pos_args + pos_n);
    mpy::dict_view dv(kwargs);
    Py_ssize_t pos = 0;
    mpy::handle key, value;
    while (dv.next(&pos, &key, &value)) {
        all_args.append(A, value);
        kwnames.append(A, key);
    }
    return mpy::vector_args(all_args.begin(), pos_n, A.autorelease(slice_to_tuple(kwnames)));
}

PyObject* py___torch_function__(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    Arena A;
    PY_BEGIN
    maybeInitializeGlobals();
    AT_ASSERT(nargs == 4 || nargs == 5);
    auto va = as_vector_args(A, args[3], nargs == 5 ? args[4] : nullptr);
    bool is_pointwise = pointwise.contains(args[1]);
    return __torch_function__(A, args[1], std::move(va), is_pointwise).release();
    PY_END(nullptr)
}
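// Summary of the two paths in run_torch_function (a reading aid): pointwise
// operators align every operand to result_levels with stride-0 expansion via
// _match_levels and run the original operator once on plain tensors;
// everything else pushes one functorch vmap layer per first-class dim with
// EnableAllLayers, runs the operator on batched tensors, and translates the
// resulting BatchedTensorImpl levels back into dims via from_batched.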
mpy::object levels_to_tuple(Slice<DimEntry> slice) {
    mpy::tuple t(slice.size());
    for (auto i : slice.enumerate()) {
        t.set(i, slice[i].is_positional() ? mpy::from_int(slice[i].position()) : mpy::object::borrow(slice[i].dim()));
    }
    mpy::object r = std::move(t);
    return r;
}

PyObject* Tensor_ndim(Tensor* self, void*) {
    Py_ssize_t i = 0;
    for (auto l : self->levels()) {
        if (l.is_positional()) {
            ++i;
        }
    }
    return mpy::from_int(i).release();
}

PyGetSetDef Tensor_getsetters[] = {
    {"_has_device", (getter) [](PyObject* self, void*) -> PyObject* { return mpy::from_bool(((Tensor*)self)->has_device()).release(); }, NULL},
    {"_tensor", (getter) [](PyObject* self, void*) -> PyObject* {
        Arena A;
        return THPVariable_Wrap(((Tensor*)self)->tensor(A));
    }, NULL},
    {"_batchtensor", (getter) [](PyObject* self, void*) -> PyObject* {
        Arena A;
        return THPVariable_Wrap(((Tensor*)self)->batchtensor(A));
    }, NULL},
    {"_levels", (getter) [](PyObject* self, void*) -> PyObject* {
        PY_BEGIN
        return levels_to_tuple(((Tensor*)self)->levels()).release();
        PY_END(nullptr)
    }},
    {"ndim", (getter) Tensor_ndim, NULL, "ndim", NULL},
    {NULL}  /* Sentinel */
};

PyMethodDef Tensor_methods[] = {
    {NULL, NULL, 0, NULL}        /* Sentinel */
};
}

PyTypeObject Tensor::Type = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "_C.Tensor",                  /* tp_name */
    sizeof(Tensor),               /* tp_basicsize */
    0,                            /* tp_itemsize */
    Tensor::dealloc_stub,         /* tp_dealloc */
    0,                            /* tp_vectorcall_offset */
    0,                            /* tp_getattr */
    0,                            /* tp_setattr */
    0,                            /* tp_as_async */
    0,                            /* tp_repr */
    0,                            /* tp_as_number */
    0,                            /* tp_as_sequence */
    0,                            /* tp_as_mapping */
    0,                            /* tp_hash */
    0,                            /* tp_call */
    0,                            /* tp_str */
    0,                            /* tp_getattro */
    0,                            /* tp_setattro */
    0,                            /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    "Tensor Object",              /* tp_doc */
    0,                            /* tp_traverse */
    0,                            /* tp_clear */
    0,                            /* tp_richcompare */
    0,                            /* tp_weaklistoffset */
    0,                            /* tp_iter */
    0,                            /* tp_iternext */
    Tensor_methods,               /* tp_methods */
    0,                            /* tp_members */
    Tensor_getsetters,            /* tp_getset */
    0,                            /* tp_base */
    0,                            /* tp_dict */
    0,                            /* tp_descr_get */
    0,                            /* tp_descr_set */
    0,                            /* tp_dictoffset */
    0,                            /* tp_init */
    0,                            /* tp_alloc */
    Tensor::new_stub,             /* tp_new */
};

// dim() --------------------

static bool relevant_op(_Py_CODEUNIT c) {
    switch(c) {
        case STORE_NAME:
        case STORE_GLOBAL:
        case STORE_FAST:
        case STORE_DEREF:
            return true;
        default:
            return false;
    }
}

static mpy::object create_dim(mpy::object name, mpy::handle size) {
    auto d = Dim::create(std::move(name));
    if (!mpy::is_none(size)) {
        d->set_size(mpy::to_int(size));
    }
    return std::move(d);
}

static mpy::object create_dimlist(mpy::object name, mpy::handle size) {
    auto d = DimList::create(std::move(name));
    if (!mpy::is_none(size)) {
        if (mpy::is_int(size)) {
            d->bind_len(mpy::to_int(size));
        } else {
            mpy::sequence_view s(size);
            d->bind_len(s.size());
            for (auto i : irange(d->size())) {
                d->dims_[i]->set_size(mpy::to_int(s[i]));
            }
        }
    }
    return std::move(d);
}

// Python wrappers that make new reflection primitives available for older runtimes
#if !(IS_PYTHON_3_11_PLUS)
#define _PyCode_CODE(CO) ((_Py_CODEUNIT*)PyBytes_AS_STRING((CO)->co_code))
#endif

namespace{
struct PyInstDecoder {
    PyInstDecoder(PyCodeObject* code_object, int lasti)
    : code_object_(code_object), code_(_PyCode_CODE(code_object)), offset_(lasti / sizeof(_Py_CODEUNIT)) {}
    // On Windows, _PyOpcode_Caches and _PyOpcode_Deopt are private symbols
    // See https://github.com/pytorch/pytorch/issues/93854
    void next() {
    #if IS_PYTHON_3_11_PLUS
        offset_ += _PyOpcode_Caches[opcode()];
    #endif
        offset_ += 1;
    }
    int opcode() {
        auto r = _Py_OPCODE(code_[offset_]);
    #if IS_PYTHON_3_11_PLUS
        r = _PyOpcode_Deopt[r];
    #endif
        return r;
    }
    int oparg() {
        return _Py_OPARG(code_[offset_]);
    }
    mpy::object name() {
        mpy::object names;
        switch(opcode()) {
            case STORE_NAME:
            case STORE_GLOBAL:
                names = mpy::object::borrow(code_object_->co_names);
                break;
            case STORE_FAST:
                names = mpy::object::steal(PyCode_GetVarnames(code_object_));
                break;
            case STORE_DEREF:
                names = mpy::object::steal(PyCode_GetCellvars(code_object_));
                break;
            default:
                return mpy::object();
        }
        return mpy::object::steal(PySequence_GetItem(names.ptr(), oparg()));
    }
private:
    PyCodeObject* code_object_;
    _Py_CODEUNIT* code_;
    int offset_;
};

template<mpy::object (*create_object)(mpy::object, mpy::handle)>
static PyObject* _dims(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    PY_BEGIN
    Py_ssize_t specified_ndims = -1;
    Py_ssize_t found_ndims = 0;
    Py_ssize_t sizes = -1;
    mpy::handle n = Py_None;
    mpy::handle py_sizes = Py_None;
    if (nargs || kwnames) {
        mpy::vector_args va(args, nargs, kwnames);
        va.parse("dims", {"n", "sizes"}, {&n, &py_sizes}, 0);
        if (!mpy::is_none(py_sizes)) {
            sizes = mpy::sequence_view(py_sizes).size();
            specified_ndims = sizes;
        }
        if (!mpy::is_none(n)) {
            specified_ndims = mpy::to_int(n);
        }
    }

    PyThreadState* state = PyThreadState_GET();
    auto f = mpy::obj<PyFrameObject>::steal(PyThreadState_GetFrame(state));
    auto c = mpy::obj<PyCodeObject>::steal(PyFrame_GetCode(f.ptr()));
    auto lasti = PyFrame_GetLasti(f.ptr());
    auto decoder = PyInstDecoder(c.ptr(), lasti);
    #if IS_PYTHON_3_11_PLUS
    // When py3.11 adapts bytecode lasti points to the precall
    // rather than the call instruction after it
    if (decoder.opcode() == PRECALL) {
        decoder.next();
    }
    #endif
    decoder.next();

    if (relevant_op(decoder.opcode())) {
        found_ndims = 1;
    } else if (decoder.opcode() == UNPACK_SEQUENCE) {
        found_ndims = decoder.oparg();
        decoder.next();
    }

    if (specified_ndims == -1) {
        if (found_ndims == 0) {
            mpy::raise_error(PyExc_SyntaxError, "dims() must be assigned to a sequence of variable names or have argument n specified");
        }
        specified_ndims = found_ndims;
    }
    if (found_ndims != specified_ndims) {
        found_ndims = 0; // avoid taking the wrong names for dimensions
    }

    auto genobject = [&](int i) -> mpy::object {
        mpy::object name;
        if (i < found_ndims) {
            name = decoder.name();
        }
        if (!name.ptr()) {
            name = mpy::unicode_from_format("d%d", i);
            found_ndims = 0; // once we fail at finding a name, we can't find any more
        } else {
            decoder.next();
        }
        return create_object(std::move(name), sizes != -1 ? mpy::sequence_view(py_sizes)[i] : mpy::handle(Py_None));
    };
    if (sizes != -1 && sizes != specified_ndims) {
        mpy::raise_error(PyExc_ValueError, "expected %d sizes but found %d", int(specified_ndims), int(sizes));
    }
    if (specified_ndims == 1) {
        return genobject(0).release();
    }
    mpy::tuple result(specified_ndims);
    for (int i = 0; i < specified_ndims; ++i) {
        result.set(i, genobject(i));
    }
    return result.release();
    PY_END(nullptr)
}

struct DotPart {
    Slice<DimEntry> dims;
    size_t total_size = 1;
    void append(Arena& A, mpy::hdl<Dim> d) {
        total_size *= d->size();
        dims.append(A, d);
    }
};

template<typename T>
static at::ArrayRef<T> as_array_ref(Slice<T> t) {
    return at::ArrayRef<T>(t.begin(), t.end());
}

static TensorRef dot_prepare(Arena& A, std::initializer_list<DotPart> parts, const TensorInfo& t) {
    Slice<DimEntry> new_levels;
    bool needs_reshape = false;
    for (auto p : parts) {
        if (p.dims.size() != 1) {
            needs_reshape = true;
        }
        new_levels.extend(A, p.dims);
    }
    auto r = _match_levels(A, t.tensor, t.levels, new_levels, true);
    if (!needs_reshape) {
        return r;
    }
    Slice<int64_t> view;
    for (auto p : parts) {
        view.append(A, p.total_size);
    }
    return A.autorelease(r->reshape(at::IntArrayRef(view.begin(), view.end())));
}

static mpy::object dot_finish(Arena& A, std::initializer_list<DotPart> parts, at::Tensor r) {
    Slice<DimEntry> result_levels;
    bool needs_reshape = false;
    for (auto p : parts) {
        if (p.dims.size() != 1) {
            needs_reshape = true;
        }
        result_levels.extend(A, p.dims);
    }
    if (needs_reshape) {
        Slice<int64_t> new_size;
        for (auto l : result_levels) {
            new_size.append(A, l.dim()->size());
        }
        r = r.reshape(at::IntArrayRef(new_size.begin(), new_size.end()));
    }
    return Tensor::from_positional(A, std::move(r), result_levels, true);
}
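// The dot() below classifies every dim of the two operands the way batched
// matmul lowerings do (sketch, einsum-style):
//   lro: on both sides, not summed  -> batch dims of bmm
//   lo:  lhs only, not summed       -> rows (m)
//   ro:  rhs only, not summed       -> cols (n)
//   lr:  summed over                -> contraction dim (k)
// e.g. x[b, m, k] * y[b, k, n] summed over k uses lro={b}, lo={m}, ro={n},
// lr={k} and lowers to at::bmm; with no batch dims it lowers to at::mm.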
static mpy::object dot(Arena& A, TensorInfo lhs, TensorInfo rhs, Slice<DimEntry> sum) {
    auto lhs_strides = lhs.tensor->strides();
    auto rhs_strides = rhs.tensor->strides();

    DotPart lro_dims;
    DotPart lo_dims;
    DotPart ro_dims;
    DotPart lr_dims;

    auto insert_dim = [&] (mpy::hdl<Dim> d, std::optional<int> lhs_idx, std::optional<int> rhs_idx) {
        bool reduced = sum.contains(d);
        int64_t lhs_stride = lhs_idx ? lhs_strides[*lhs_idx] : 0;
        int64_t rhs_stride = rhs_idx ? rhs_strides[*rhs_idx] : 0;
        if (reduced) {
            // lr
            lr_dims.append(A, d);
        } else {
            if ((lhs_stride == 0) == (rhs_stride == 0)) {
                // lro
                lro_dims.append(A, d);
            } else if (lhs_stride != 0) {
                // lo
                lo_dims.append(A, d);
            } else {
                AT_ASSERT(rhs_stride != 0);
                ro_dims.append(A, d);
            }
        }
    };

    auto rhs_seen = A.allocate<bool>(rhs.levels.size());
    std::fill(rhs_seen, rhs_seen + rhs.levels.size(), false);

    for (auto i : lhs.levels.enumerate()) {
        auto d = lhs.levels[i];
        auto rhs_idx = rhs.levels.index(d);
        if (rhs_idx) {
            rhs_seen[*rhs_idx] = true;
        }
        insert_dim(d.dim(), i, rhs_idx);
    }

    for (auto i : rhs.levels.enumerate()) {
        if (rhs_seen[i]) {
            continue;
        }
        auto d = rhs.levels[i];
        insert_dim(d.dim(), std::nullopt, i);
    }

    if (lr_dims.dims.size() != sum.size()) {
        for (auto & d : sum) {
            if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) {
                mpy::raise_error(DimensionBindError(), "summing over non-existent dimension %S", d.dim().ptr());
            }
        }
    }

    // std::cout << lhs.levels << " " << rhs.levels << " " << sum << "\n";
    // std::cout << lro_dims.dims << " " << lo_dims.dims << " " << ro_dims.dims << " " << lr_dims.dims << "\n";

    // no batch, just call mm
    if (lro_dims.dims.size() != 0) {
        auto lhs_ = dot_prepare(A, {lro_dims, lo_dims, lr_dims}, lhs);
        auto rhs_ = dot_prepare(A, {lro_dims, lr_dims, ro_dims}, rhs);
        return dot_finish(A, {lro_dims, lo_dims, ro_dims}, at::bmm(*lhs_, *rhs_));
    } else {
        auto lhs_ = dot_prepare(A, {lo_dims, lr_dims}, lhs);
        auto rhs_ = dot_prepare(A, {lr_dims, ro_dims}, rhs);
        return dot_finish(A, {lo_dims, ro_dims}, at::mm(*lhs_, *rhs_));
    }
}

static PyObject* test_c(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    PY_BEGIN
    Arena A;
    Slice<int> s(A, 3, 4, 5);
    AT_ASSERT(s.size() == 3 && s.capacity() == 8);
    AT_ASSERT(s[0] == 3 && s[1] == 4 && s[2] == 5);
    s.append(A, 6);
    AT_ASSERT(s[3] == 6);
    for(int i : irange(10)) {
        s.append(A, i);
    }
    AT_ASSERT(s[0] == 3 && s.back() == 9 && s.size() == 14 && s.capacity() == 16);

    Slice<int> s2(A, -1, -2, -3);
    AT_ASSERT(s2[1] == -2 && s[0] == 3);

    auto ss = s.slice(1, 2);
    AT_ASSERT(ss.size() == 1);
    AT_ASSERT(ss[0] == 4);
    AT_ASSERT(ss.capacity() == 1);
    ss.append(A, -4);
    AT_ASSERT(ss.size() == 2 && ss[1] == -4);
    ss[0] = 3;
    AT_ASSERT(s[1] == 4);

    s.insert(A, s.slice(1, 4), ss);
    AT_ASSERT(s[1] == 3 && s[2] == -4 && s[3] == 0);

    auto sz = s.size();
    s.insert(A, s.slice(1, 1), 4);
    AT_ASSERT(s[1] == 4 && sz + 1 == s.size());

    Slice<int> d(A, 0, 1, 2, 3, 4);

    Slice<int> b(A, 0, 1, 2, 3, 4);
    b.insert(A, b.slice(1, 1), d);
    AT_ASSERT(b.size() == 10);
    AT_ASSERT(b[1] == 0);
    AT_ASSERT(b[5] == 4);
    AT_ASSERT(b.back() == 4);

    Py_RETURN_NONE;
    PY_END(nullptr);
}

static PyObject* order(PyObject *_, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    Arena A;
    PY_BEGIN
    if (kwnames) {
        mpy::raise_error(PyExc_TypeError, "unexpected keyword arguments %S", kwnames);
    }
    AT_ASSERT(nargs-- > 0);
    Slice<DimEntry> orig_levels;
    Slice<DimEntry> levels;
    TensorRef data;
    mpy::handle self = args++[0];
    bool has_device;
    if (Tensor::check_exact(self)) {
        auto t = Tensor::unchecked_wrap(self);
        orig_levels = t->levels();
        data = t->tensor(A);
        has_device = t->has_device();
    } else {
        auto d = Dim::unchecked_wrap(self);
        orig_levels.append(A, d);
        data = d->range();
        has_device = false;
    }

    Slice<DimEntry> flat_positional_dims;
    Slice<std::pair<int, int>> to_flatten;
    levels.extend(A, orig_levels);

    int orig_ndim = ndim_of_levels(levels);
    auto append = [&](DimEntry d) {
        auto midx = levels.index(d);
        if (!midx) {
            if (d.is_positional()) {
                mpy::raise_error(PyExc_ValueError, "tensor has %d positional dimensions, but %d specified, or it was specified twice", int(orig_ndim), int(d.position() + orig_ndim));
            } else {
                mpy::raise_error(PyExc_ValueError, "tensor of dimensions %R does not contain dim %R or it was specified twice", levels_to_tuple(orig_levels).ptr(), d.dim().ptr());
            }
        }
        levels[*midx] = DimEntry();
        flat_positional_dims.append(A, d);
    };

    int n_new_positional = 0;
    for (auto i : irange(nargs)) {
        mpy::handle arg = args[i];
        DimEntry entry = _wrap_dim(arg, orig_ndim, false);
        if (!entry.is_none()) {
            append(entry);
            ++n_new_positional;
        } else if (DimList::check(arg)) {
            auto dl = DimList::unchecked_wrap(arg);
            for (mpy::obj<Dim> & d : dl->dims_) {
                append(mpy::hdl<Dim>(d));
                ++n_new_positional;
            }
        } else {
            ++n_new_positional;
            if (!mpy::is_sequence(arg)) {
                mpy::raise_error(PyExc_ValueError, "expected a Dim, List[Dim], or Sequence[Dim]");
            }
            mpy::sequence_view sq(arg);
            auto N = sq.size();
            to_flatten.append(A, std::make_pair(flat_positional_dims.size(), N));
            for (auto j : irange(N)) {
                DimEntry e = _wrap_dim(A.autorelease(sq[j]), orig_ndim, false);
                if (e.is_none()) {
                    mpy::raise_error(PyExc_ValueError, "expected a Dim, or int");
                }
                append(e);
            }
        }
    }

    int ndim = 0;
    int insert_point = -1;
    Slice<DimEntry> new_levels;
    for (auto l : levels) {
        if (l.is_none()) {
            continue;
        }
        if (l.is_positional()) {
            ndim++;
            if (insert_point == -1) {
                insert_point = new_levels.size();
                new_levels.extend(A, flat_positional_dims);
            }
        }
        new_levels.append(A, l);
    }
    if (insert_point == -1) {
        insert_point = new_levels.size();
        new_levels.extend(A, flat_positional_dims);
    }

    at::Tensor ndata = *_match_levels(A, data, orig_levels, new_levels);

    if (to_flatten.size()) {
        Slice<int64_t> view;
        auto sz = ndata.sizes();
        // before the new positional dims
        for (auto i : irange(0, insert_point)) {
            view.append(A, sz[i]);
        }
        int i = 0;
        for (auto to_flat : to_flatten) {
            for (; i < to_flat.first; ++i) {
                view.append(A, sz[insert_point + i]);
            }
            int64_t new_size = 1;
            int last = i + to_flat.second;
            for (; i < last; ++i) {
                new_size *= sz[insert_point + i];
            }
            view.append(A, new_size);
        }
        for (; i < flat_positional_dims.size(); ++i) {
            view.append(A, sz[insert_point + i]);
        }
        // after the new positional dims
        for (auto i : irange(insert_point + flat_positional_dims.size(), levels.size())) {
            view.append(A, sz[i]);
        }
        // we shortened the number of dimensions, so remove them from new levels
        // we will renumber them later
        auto n_to_remove = flat_positional_dims.size() - n_new_positional;
        new_levels.insert(A, new_levels.slice(insert_point, insert_point + n_to_remove), Slice<DimEntry>());
        ndata = std::move(ndata).reshape(at::IntArrayRef(view.begin(), view.end()));
    }

    // renumber the positional dimension
    int seen = 0;
    for (auto i : new_levels.reversed_enumerate()) {
        if (new_levels[i].is_positional() || (i >= insert_point && i < insert_point + n_new_positional)) {
            new_levels[i] = --seen;
        }
    }
    return Tensor::from_positional(A, std::move(ndata), new_levels, has_device).release();

    PY_END(nullptr)
}
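// Python-side sketch of order() (per the functorch.dim frontend; hedged):
//
//   x = torch.rand(3, 4)[i, j]   # fully first-class: levels [i, j]
//   y = x.order(i, j)            # plain 3x4 tensor, i then j positional
//
// As the code above shows, the ordered dims are inserted ahead of any
// remaining positional dims, and a sequence of dims passed as one argument
// is flattened into a single new positional dimension.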
static PyObject* expand(PyObject *_, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    Arena A;
    PY_BEGIN
    AT_ASSERT(nargs-- > 0);
    auto info = TensorInfo::create(A, args++[0], false);
    for (auto i : irange(nargs)) {
        if (!Dim::check(args[i])) {
            maybeInitializeGlobals();
            mpy::vector_args vargs(args - 1, nargs + 1, kwnames);
            if (THPVariable_Check(args[-1])) {
                return torch_Tensor_expand.call_vector(vargs).release();
            } else {
                return __torch_function__(A, torch_Tensor_expand, vargs, false).release();
            }
        }
    }
    const at::Tensor& data = *info.tensor;
    auto levels = info.levels;
    Slice<DimEntry> new_levels;
    Slice<int64_t> sz;
    Slice<int64_t> sd;
    for (auto i : irange(nargs)) {
        auto d = Dim::unchecked_wrap(args[i]);
        if (levels.contains(d) || new_levels.contains(d)) {
            mpy::raise_error(DimensionBindError(), "expanding dimension %R already exists in tensor with dims", d.ptr());
        }
        new_levels.append(A, d);
        sz.append(A, d->size());
        sd.append(A, 0);
    }
    new_levels.extend(A, levels);
    at::IntArrayRef osz = data.sizes();
    at::IntArrayRef osd = data.strides();
    sz.extend(A, osz.begin(), osz.end());
    sd.extend(A, osd.begin(), osd.end());
    at::Tensor ndata = data.as_strided(at::IntArrayRef(sz.begin(), sz.end()), at::IntArrayRef(sd.begin(), sd.end()), data.storage_offset());
    return Tensor::from_positional(A, std::move(ndata), new_levels, info.has_device).release();
    PY_END(nullptr)
}

static void _bind_dims_to_size(Arena & A, int64_t sz, int64_t sd,
                               Slice<mpy::hdl<Dim>> dims, Slice<int64_t>& nsz, Slice<int64_t>& nsd) {
    int64_t rhs_prod = 1;
    for (auto i : dims.enumerate()) {
        if (!dims[i]->is_bound()) {
            for (auto j : irange(i + 1, dims.size())) {
                if (!dims[j]->is_bound()) {
                    mpy::raise_error(DimensionBindError(), "cannot infer the sizes of two dimensions at once %R and %R", dims[i].ptr(), dims[j].ptr());
                }
                rhs_prod *= dims[j]->size();
            }
            if (sz % rhs_prod != 0) {
                mpy::tuple tup(dims.size());
                for (auto j : dims.enumerate()) {
                    tup.set(j, dims[j]->is_bound() ? mpy::from_int(dims[j]->size()) : mpy::unicode_from_string("?"));
                }
                mpy::raise_error(DimensionBindError(), "inferred dimension does not evenly fit into larger dimension: %d vs %R", (int) sz, tup.ptr());
            }
            int64_t inferred_size = sz / rhs_prod;
            dims[i]->set_size(inferred_size);
            rhs_prod = sz;
            break;
        }
        rhs_prod *= dims[i]->size();
    }
    if (rhs_prod != sz) {
        mpy::tuple tup(dims.size());
        for (auto j : dims.enumerate()) {
            tup.set(j, mpy::object::borrow(dims[j]));
        }
        mpy::raise_error(DimensionBindError(), "Dimension sizes do not match (%d != %d) when matching dimension pack %R", (int) sz, (int) rhs_prod, tup.ptr());
    }
    auto new_strides = A.allocate<int64_t>(dims.size());
    auto prev_stride = sd;
    for (auto i : dims.reversed_enumerate()) {
        new_strides[i] = prev_stride;
        prev_stride = dims[i]->size() * prev_stride;
    }
    for (auto i : dims.enumerate()) {
        nsd.append(A, new_strides[i]);
        nsz.append(A, dims[i]->size());
    }
}
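// Example of _bind_dims_to_size (reading aid): binding a dim pack (i, j) to a
// physical dimension of size 12/stride sd with i already bound to 3 infers
// j = 12/3 = 4; the pack gets row-major strides derived from sd (j: sd,
// i: 4*sd). At most one dim per pack may be unbound, and the bound sizes must
// multiply out to exactly the physical size or a DimensionBindError is
// raised.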
static bool has_dims(mpy::handle d) {
    return Dim::check_exact(d) || Tensor::check_exact(d);
}

struct IndexingInfo {
    bool can_call_original; // if true, then it is safe to just call getitem or setitem, these objects do not need special handling
    bool advanced_indexing; // requires actual lookup
    TensorRef self;
    Slice<mpy::handle> flat_inputs;
    Slice<DimEntry> result_levels;
    bool has_device;
};
}

IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice<mpy::handle> input, Slice<DimEntry> keys, Slice<mpy::handle> values, bool has_dimpacks_or_none);

namespace{
Slice<mpy::handle> as_slice(mpy::tuple_view tv) {
    PyObject** begin = &PyTuple_GET_ITEM(tv.ptr(), 0);
    return Slice<mpy::handle>((mpy::handle*)begin, (mpy::handle*)(begin + tv.size()));
}

Slice<mpy::handle> as_slice(mpy::list_view tv) {
    PyObject** begin = &PyList_GET_ITEM(tv.ptr(), 0);
    return Slice<mpy::handle>((mpy::handle*)begin, (mpy::handle*)(begin + tv.size()));
}

bool maybe_dimpack(Slice<mpy::handle>& elements, mpy::handle s, bool check_first=true) {
    // can we avoid rechecking?
    if (mpy::list_view::check(s)) {
        mpy::list_view tv(s);
        if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) {
            elements = as_slice(tv);
            return true;
        }
    }
    // can we avoid rechecking?
    if (mpy::tuple_view::check(s)) {
        mpy::tuple_view tv(s);
        if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) {
            elements = as_slice(tv);
            return true;
        }
    }
    return false;
};

bool is_dimpack(mpy::handle s) {
    Slice<mpy::handle> e;
    return maybe_dimpack(e, s);
}

mpy::object invoke_getitem(Arena& A, const IndexingInfo& iinfo) {
    at::Tensor rtensor;
    if (iinfo.advanced_indexing) {
        auto self_hdl = handle_from_tensor(A, iinfo.self);
        auto tup = slice_to_tuple(iinfo.flat_inputs);
        // std::cout << "calling original getindex " << self_hdl << " " << tup << "\n";
        auto pytensor = mpy::object::checked_steal(THPVariable_getitem(self_hdl.ptr(), tup.ptr()));
        rtensor = THPVariable_Unpack(pytensor.ptr());
    } else {
        // std::cout << "skipping original getindex\n";
        rtensor = *iinfo.self;
    }
    // std::cout << "returning (from_positional)\n";
    return Tensor::from_positional(A, std::move(rtensor), iinfo.result_levels, iinfo.has_device);
}

mpy::object index(Arena& A, mpy::handle self, mpy::handle dims, mpy::handle indices) {
    maybeInitializeGlobals();
    Slice<mpy::handle> dims_list;
    Slice<mpy::handle> indices_list;
    // we allow for matching single dims to multiple dims,
    // so we first have to normalize everything into the case where there is a list on both the lhs and the rhs
    bool lhs_list = mpy::tuple_view::check(dims) || mpy::list_view::check(dims);
    bool rhs_list = mpy::tuple_view::check(indices) || mpy::list_view::check(indices);
    if (lhs_list && rhs_list) {
        mpy::sequence_view dv(dims);
        mpy::sequence_view ind(indices);
        Py_ssize_t N = dv.size();
        if (N != ind.size()) {
            mpy::raise_error(PyExc_TypeError, "dims (%d) and indices (%d) must have the same length", int(N), int(ind.size()));
        }
        for (auto i : irange(N)) {
            dims_list.append(A, A.autorelease(dv[i]));
            indices_list.append(A, A.autorelease(ind[i]));
        }
    } else {
        dims_list.append(A, dims);
        indices_list.append(A, indices);
    }

    // dims being indexed can be grouped together into a single index space, and we have to
    // flatten them into a single dimension before we can index them...
    auto self_info = TensorInfo::create(A, self, false);
    auto ndim = self_info.ndim();
    Slice<DimEntry> new_levels;
    Slice<DimEntry> to_flatten;
    Slice<DimEntry> dims_list_flat;
    auto parse_dim_entry = [&](mpy::handle s) -> DimEntry {
        auto d = _wrap_dim(s, ndim, false);
        if (d.is_none()) {
            mpy::raise_error(PyExc_TypeError, "expected a dimension specifier but found %R", s.ptr());
        }
        return d;
    };
    auto dim_not_present = [&](DimEntry d) {
        if (d.is_positional()) {
            mpy::raise_error(PyExc_TypeError, "dimension %d not in tensor of %d dimensions", d.position() + ndim, ndim);
        } else {
            mpy::raise_error(PyExc_TypeError, "dimension %R not in tensor", d.dim()->ptr());
        }
    };

    for (auto i : dims_list.enumerate()) {
        Slice<mpy::handle> m;
        if (maybe_dimpack(m, dims_list[i], /*check_first=*/false)) {
            if (m.size() == 0) {
                // plausible semantics work for this to have 0 elements (e.g. the index will always be 0)
                dims_list_flat.append(A, DimEntry()); // value is just dropped
            }
            auto first = parse_dim_entry(m[0]);
            dims_list_flat.append(A, first);
            if (m.size() == 1) {
                continue;
            }
            if (to_flatten.size() == 0) {
                new_levels.extend(A, self_info.levels);
            }
            Slice<DimEntry> rest;
            for (auto i : irange(1, m.size())) {
                auto d = parse_dim_entry(m[i]);
                if (!new_levels.remove(A, d)) {
                    dim_not_present(d);
                }
                rest.append(A, d);
            }

            auto first_idx = new_levels.index(first);
            if (!first_idx) {
                dim_not_present(first);
            }
            new_levels.insert(A, new_levels.slice(*first_idx + 1, *first_idx + 1), rest);
            to_flatten.extend(A, rest);
        } else {
            dims_list_flat.append(A, parse_dim_entry(dims_list[i]));
        }
    }
    if (to_flatten.size() > 0) {
        TensorRef rearranged = _match_levels(A, self_info.tensor, self_info.levels, new_levels);
        at::IntArrayRef sizes = rearranged->sizes();
        Slice<int64_t> new_sizes;
        Slice<DimEntry> reshape_levels;
        for (auto i : new_levels.enumerate()) {
            if (to_flatten.contains(new_levels[i])) {
                new_sizes.back() *= sizes[i];
            } else {
                new_sizes.append(A, sizes[i]);
                reshape_levels.append(A, new_levels[i]);
            }
        }
        self_info.tensor = A.autorelease(rearranged->reshape(at::IntArrayRef(new_sizes.begin(), new_sizes.end())));
        self_info.levels = reshape_levels;
        // note: we are using the first level in a flattened group to represent the group for the rest of the op
        // we need to be careful not to rely on the dimension's size because it doesn't match the size of the whole group
    }
    bool has_dimpacks = false;
    for (auto idx : indices_list) {
        if (mpy::tuple_view::check(idx) || mpy::list_view::check(idx)) {
            has_dimpacks = true;
            break;
        }
    }
    IndexingInfo info = getsetitem_flat(A, self_info, Slice<mpy::handle>(), dims_list_flat, indices_list, has_dimpacks);
    return invoke_getitem(A, info);
}

// true -- the indices were flattened out of a tuple, list or sequence...
Slice<mpy::handle> slice_from_sequence(Arena& A, mpy::handle value) {
    if (mpy::tuple_view::check(value)) {
        return as_slice(mpy::tuple_view(value));
    } else if (mpy::list_view::check(value)) {
        return as_slice(mpy::list_view(value));
    } else {
        mpy::sequence_view sv(value);
        Slice<mpy::handle> r;
        for (auto i : sv.enumerate()) {
            r.append(A, A.autorelease(sv[i]));
        }
        return r;
    }
}
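// Python-side sketch of the dim packs handled here (hedged): an inner
// tuple/list of dims inside an index splits one physical dimension, e.g.
//
//   img[(h2, h), (w2, w)]   # splits H into (h2, h) and W into (w2, w)
//
// index() above flattens groups of dims on the lhs into one dimension first
// and then defers to getsetitem_flat, which performs the actual binding via
// _bind_dims_to_size.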
    mpy::sequence_view sv(index);
    if (sv.size() >= 32) {
        indices.extend(A, slice_from_sequence(A, index));
        return true;
    }
    for (auto i : sv.enumerate()) {
        mpy::handle item;
        try {
            item = sv[i];
        } catch (mpy::exception_set & e) {
            PyErr_Clear();
            indices.append(A, index);
            return false;
        }
        if (THPVariable_Check(item.ptr()) || mpy::is_sequence(item) || PySlice_Check(item.ptr()) || item.ptr() == Py_Ellipsis || mpy::is_none(item) || has_dims(item)) {
            indices.extend(A, slice_from_sequence(A, index));
            return true;
        }
    }
    indices.append(A, index);
    return false;
}

IndexingInfo getsetitem(Arena & A, mpy::handle self, mpy::handle index, bool tensors_have_dims) {
    bool can_call_original_getitem = !tensors_have_dims;

    Slice<mpy::handle> input;
    if (has_dims(index)) {
        input.append(A, index);
    } else {
        bool is_sequence = extractIndices(A, index, input);
        // nothing about first-class dims here, fall back to getitem
        if (can_call_original_getitem && !is_sequence) {
            return { true };
        }
    }

    int64_t dims_indexed = 0;
    int64_t expanding_object = -1;
    DimList* unbound_dim_list = nullptr;
    auto check_expanding = [&](int64_t i) {
        if (expanding_object != -1) {
            mpy::raise_error(DimensionBindError(), "at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets %d and %d", (int) expanding_object, (int) i);
        }
        expanding_object = i;
    };
    Slice<int64_t> dimlists;

    // calculate how many dimensions have been indexed, in order to compute the size of ...
    // or expand a potentially unbound dimension list.
    bool has_dimpacks_or_none = false;
    for (auto i : input.enumerate()) {
        mpy::handle s = input[i];
        if (Dim::check_exact(s) || Tensor::check_exact(s)) {
            can_call_original_getitem = false;
            ++dims_indexed;
        } else if (s.ptr() == Py_Ellipsis) {
            check_expanding(i);
        } else if (DimList::check(s)) {
            can_call_original_getitem = false;
            auto dl = DimList::unchecked_wrap(s);
            if (!dl->is_bound()) {
                check_expanding(i);
                unbound_dim_list = dl.ptr();
            } else {
                dims_indexed += dl->dims_.size();
            }
            dimlists.append(A, i);
        } else if (mpy::is_none(s)) {
            has_dimpacks_or_none = true;
        } else if (is_dimpack(s)) {
            can_call_original_getitem = false;
            has_dimpacks_or_none = true;
            ++dims_indexed;
        } else {
            ++dims_indexed;
        }
    }

    // at this point, if we haven't seen any Dim objects, we can also fall back to the original getitem.
    if (can_call_original_getitem) {
        return {true};
    }

    // std::cout << "__getitem__ " << self << " " << index << "\n";

    TensorInfo self_info = TensorInfo::create(A, self, false, true);
    auto ndim = self_info.ndim();
    if (dims_indexed > ndim) {
        mpy::raise_error(PyExc_ValueError, "at least %d indices were supplied but the tensor only has %d dimensions", (int) dims_indexed, (int) ndim);
    }
    // expand any unbound dimension list, or expand ... into individual : slices.
    auto expanding_dims = ndim - dims_indexed;
    if (expanding_object != -1) {
        if (unbound_dim_list) {
            unbound_dim_list->bind_len(expanding_dims);
        } else {
            // ...
            Slice<mpy::handle> no_slices;
            for (auto i : irange(expanding_dims)) {
                (void) i;
                no_slices.append(A, no_slice);
            }
            input.insert(A, input.slice(expanding_object, expanding_object + 1), no_slices);
        }
    }

    // flatten out any dimensions stored in dimlist elements directly into the inputs
    // std::cout << dimlists << " <- dim lists!\n";
    for (int64_t i = dimlists.size() - 1; i >= 0; --i) {
        auto idx = dimlists[i];
        // we added more elements to input because of ...,
        // so we need to also adjust the index to get back to where the
        // dimlist existed
        if (!unbound_dim_list && expanding_object != -1 && idx > expanding_object) {
            idx += expanding_dims;
        }
        auto dl = DimList::unchecked_wrap(input[idx]);
        // XXX - would be better if we used an OwnedSlice in DimList
        Slice<mpy::handle> more_dims((mpy::handle*) &*dl->dims_.begin(), (mpy::handle*) &*dl->dims_.end());
        input.insert(A, input.slice(idx, idx + 1), more_dims);
    }

    return getsetitem_flat(A, self_info, input, Slice<DimEntry>(), Slice<mpy::handle>(), has_dimpacks_or_none);
}
}

IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice<mpy::handle> input, Slice<DimEntry> keys, Slice<mpy::handle> values, bool has_dimpacks_or_none) {
    // At this point:
    // ..., DimList have been eliminated
    // Dim, Tensor, Tuple[Dim,...], int, slice still remain

    // we have to count how many times we see a dimension.
    // A[i, j] is a simple binding operation, but A[i, i+j] or A[i, i] requires advanced indexing.
    Slice<mpy::hdl<Dim>> seen_dims;
    Slice<int64_t> seen_dims_nuses;
    auto add_dim = [&](mpy::hdl<Dim> entry) {
        auto midx = seen_dims.index(entry);
        if (!midx) {
            seen_dims.append(A, entry);
            seen_dims_nuses.append(A, 1);
        } else {
            ++seen_dims_nuses[*midx];
        }
    };

    Slice<mpy::handle> input_it = input;

    Slice<mpy::handle> flat_inputs;
    // flat_inputs will start with an empty mpy::handle if the
    // actual value is in the tensor-like object in the tensor info
    Slice<TensorInfo> tensor_inputs;

    auto append_flat_handle = [&](mpy::handle h) {
        flat_inputs.append(A, h);
        tensor_inputs.append(A, TensorInfo());
    };
    TensorRef device_holding_tensor;
    auto append_tensor_input = [&](TensorInfo ti) {
        flat_inputs.append(A, mpy::handle());
        tensor_inputs.append(A, ti);
        if (ti.has_device && !device_holding_tensor) {
            device_holding_tensor = ti.tensor;
        }
    };

    Slice<int64_t> nsz;
    Slice<int64_t> nsd;
    at::IntArrayRef sz = self_info.tensor->sizes();
    at::IntArrayRef sd = self_info.tensor->strides();
    auto append_size = [&](int i) {
        if (has_dimpacks_or_none) {
            nsz.append(A, sz[i]);
            nsd.append(A, sd[i]);
        }
    };
    // std::cout << "self levels: " << self_info.levels << "\n";

    auto parse_nones = [&]() {
        while (input_it.size() && mpy::is_none(input_it[0])) {
            append_flat_handle(no_slice);
            nsz.append(A, 1);
            nsd.append(A, 0);
            input_it = input_it.slice(1);
        }
    };

    auto append_item = [&](int i, mpy::handle arg) {
        if (Dim::check_exact(arg)) {
            auto d = Dim::unchecked_wrap(arg);
            d->set_size(sz[i]);
            add_dim(d);
            append_size(i);
            append_flat_handle(arg);
            return;
        }
        auto info = TensorInfo::create(A, arg, false, false);
        if (info) {
            append_size(i);
            append_tensor_input(info);
            for (auto il : info.levels) {
                if (!il.is_positional()) {
                    add_dim(il.dim());
                }
            }
            return;
        }

        if (has_dimpacks_or_none) {
            Slice<mpy::handle> mp;
            if (maybe_dimpack(mp, arg)) {
                // dim pack
                Slice<mpy::hdl<Dim>> dim_pack;
                for (auto d : mp) {
                    dim_pack.append(A, Dim::wrap(d));
                    add_dim(dim_pack.back());
                    append_flat_handle(dim_pack.back());
                }
                _bind_dims_to_size(A, sz[i], sd[i], dim_pack, nsz, nsd);
                return;
            }
        }

        append_size(i);
        append_flat_handle(arg);
    };

    // pair up the indexing expressions with the dimensions of self that they index;
    // self may have first-class dims, which do not participate in the indexing.
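    // Illustrative example of the use counting above: for a 3-d tensor A and dims i, j,
    //   A[i, j] uses each dim once, so both are pure binding operations (no getitem
    //           call is needed beyond a restride);
    //   A[i, i] uses i twice, so both occurrences are replaced by i's range tensor
    //           below and the expression takes the advanced-indexing path.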
    for (auto i : self_info.levels.enumerate()) {
        auto l = self_info.levels[i];
        auto idx = keys.index(l);
        if (idx) {
            append_item(i, values[*idx]);
        } else if (l.is_positional()) {
            // grab an index from the positional list
            parse_nones();
            if (!input_it.size()) {
                // we might have fewer indices than tensor dimensions,
                // which implicitly indexes the remaining dimensions with :
                append_flat_handle(no_slice);
                append_size(i);
            } else {
                mpy::handle arg = input_it[0];
                input_it = input_it.slice(1);
                append_item(i, arg);
            }
        } else {
            add_dim(l.dim());
            append_flat_handle(l.dim());
            append_size(i);
        }
    }
    // any trailing Nones may have no existing dimension associated with them in self.
    parse_nones();

    // we have to restride the tensor to collapse dimension packs and introduce our none dimensions.
    if (has_dimpacks_or_none) {
        self_info.tensor = A.autorelease(self_info.tensor->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()), at::IntArrayRef(nsd.begin(), nsd.end()), self_info.tensor->storage_offset()));
    }

    // figure out what the shape of the indexing tensors will be,
    // and what the shape of the resulting tensor will be
    Slice<DimEntry> result_levels;
    Slice<DimEntry> index_levels;
    int64_t tensor_insert_point = -1;
    bool requires_getindex = false;
    auto mark_tensor_index = [&] {
        if (tensor_insert_point == -1) {
            tensor_insert_point = result_levels.size();
        } else if (tensor_insert_point != result_levels.size()) {
            tensor_insert_point = 0;
        }
    };
    for (auto i : flat_inputs.enumerate()) {
        auto inp = flat_inputs[i];
        if (tensor_inputs[i]) {
            requires_getindex = true;
            mark_tensor_index();
            for (auto l : tensor_inputs[i].levels) {
                // std::cout << "Consider to add " << l << "\n";
                if (!index_levels.contains(l)) {
                    index_levels.append(A, l);
                }
            }
        } else if (Dim::check_exact(inp)) {
            auto d = Dim::unchecked_wrap(inp);
            // dimensions used only once are just binding operations
            if (1 == seen_dims_nuses[*seen_dims.index(d)]) {
                flat_inputs[i] = no_slice;
                result_levels.append(A, d);
            } else {
                requires_getindex = true;
                flat_inputs[i] = mpy::handle();
                tensor_inputs[i] = TensorInfo {d->range(), Slice<DimEntry>(A, DimEntry(d)), false, TensorRef()};
                if (!index_levels.contains(d)) {
                    index_levels.append(A, d);
                }
                mark_tensor_index();
            }
        } else {
            if (inp.ptr() != no_slice.ptr()) {
                requires_getindex = true;
            }
            if (!mpy::is_int(inp)) {
                // note: actual positional indexes are accurately computed later
                result_levels.append(A, -1);
            }
        }
    }

    // indexing dimensions appear in the result at the _first use of a tensor_ in the indexing,
    // so insert the indexing levels into the result levels at this spot
    if (tensor_insert_point != -1) {
        result_levels.insert(A, result_levels.slice(tensor_insert_point, tensor_insert_point), index_levels);
    }

    // std::cout << "flat inputs: " << flat_inputs << "\n";
    // std::cout << "result_levels: " << result_levels << "\n";
    // std::cout << "index_levels: " << index_levels << "\n";

    // get all the tensors to be the right shape for indexing
    if (requires_getindex) {
        for (auto i : flat_inputs.enumerate()) {
            if (tensor_inputs[i]) {
                AT_ASSERT(!flat_inputs[i].ptr());
                // std::cout << "tensor " << i << " " << tensor_inputs[i].levels << "\n";
                TensorRef t = tensor_inputs[i].tensor;
                if (!tensor_inputs[i].has_device && device_holding_tensor) {
                    t = A.autorelease(t->to(device_holding_tensor->device()));
                }
                flat_inputs[i] = handle_from_tensor(A, _match_levels(A, t, tensor_inputs[i].levels, index_levels));
            }
        }
    }

    // previously we didn't know how many positional dimensions there would be, so we couldn't
    // number them correctly; fill that in now.
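    // Example of the renumbering below: if result_levels is [d0, -1, -1] at this
    // point, the reversed pass rewrites the positional placeholders right-to-left
    // into [d0, -2, -1], matching the convention that positional levels count
    // backwards from the last dimension.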
    auto seen_positionals = 0;
    for (auto i : result_levels.reversed_enumerate()) {
        if (result_levels[i].is_positional()) {
            result_levels[i] = -(++seen_positionals);
        }
    }

    return IndexingInfo {false, requires_getindex, self_info.tensor, flat_inputs, result_levels, self_info.has_device};
}

namespace{
mpy::object __getitem__(Arena & A, mpy::handle self, mpy::handle index) {
    maybeInitializeGlobals();
    auto iinfo = getsetitem(A, self, index, has_dims(self));
    if (iinfo.can_call_original) {
        return mpy::object::checked_steal(THPVariable_getitem(self.ptr(), index.ptr()));
    }
    return invoke_getitem(A, iinfo);
}

void __setitem__(Arena & A, mpy::handle self, mpy::handle index, mpy::handle rhs) {
    maybeInitializeGlobals();
    auto iinfo = getsetitem(A, self, index, has_dims(self) || has_dims(rhs));
    if (iinfo.can_call_original) {
        if (-1 == THPVariable_setitem(self.ptr(), index.ptr(), rhs.ptr())) {
            throw mpy::exception_set();
        }
        return;
    }

    auto rhs_info = TensorInfo::create(A, rhs, false, false);
    if (rhs_info) { // otherwise rhs can be a scalar...
        for (auto l : rhs_info.levels) {
            if (!iinfo.result_levels.contains(l)) {
                if (l.is_positional()) {
                    mpy::raise_error(DimensionBindError(), "rhs contains too many dimensions (%d) compared to indexed value (%d)", rhs_info.ndim(), ndim_of_levels(iinfo.result_levels));
                } else {
                    auto tup = levels_to_tuple(iinfo.result_levels);
                    mpy::raise_error(DimensionBindError(), "rhs of setitem contains dimension %R which is not in the dimensions on the left (%R)", l.dim().ptr(), tup.ptr());
                }
            }
        }
        auto rhs_matched = _match_levels(A, rhs_info.tensor, rhs_info.levels, iinfo.result_levels);
        rhs = handle_from_tensor(A, rhs_matched);
    }
    self = handle_from_tensor(A, iinfo.self);

    if (iinfo.advanced_indexing) {
        auto tup = slice_to_tuple(iinfo.flat_inputs);
        if (-1 == THPVariable_setitem(self.ptr(), tup.ptr(), rhs.ptr())) {
            throw mpy::exception_set();
        }
    } else {
        torch_Tensor_copy_.call(self, rhs);
    }
}
}

PyObject* Tensor_getitem(PyObject* self, PyObject* index) {
    Arena A;
    PY_BEGIN
    return __getitem__(A, self, index).release();
    PY_END(nullptr);
}

int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value) {
    Arena A;
    PY_BEGIN
    __setitem__(A, self, index, value);
    return 0;
    PY_END(-1);
}

namespace{
PyObject* py___getitem__(PyObject *_, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    Arena A;
    PY_BEGIN
    AT_ASSERT(nargs == 2);
    return __getitem__(A, args[0], args[1]).release();
    PY_END(nullptr)
}

PyObject* py___setitem__(PyObject *_, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    Arena A;
    PY_BEGIN
    AT_ASSERT(nargs == 3);
    __setitem__(A, args[0], args[1], args[2]);
    Py_RETURN_NONE;
    PY_END(nullptr)
}

PyObject* py_index(PyObject *_, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    Arena A;
    PY_BEGIN
    mpy::vector_args va(args, nargs, kwnames);
    mpy::handle self, dims, indices;
    va.parse("index", {"self", "dims", "indices"}, {&self, &dims, &indices}, 3);
    return index(A, self, dims, indices).release();
    PY_END(nullptr)
}

PyObject* py_stack(PyObject *_, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    Arena A;
    PY_BEGIN
    mpy::vector_args va(args, nargs, kwnames);
    mpy::handle tensors, new_dim, dim;
    va.parse("stack", {"tensors", "new_dim", "dim"}, {&tensors, &new_dim, &dim}, 2);
    Slice<DimEntry> result_levels;
    Slice<TensorInfo> infos;
    mpy::sequence_view sv(tensors);
    auto new_dim_d = Dim::wrap(new_dim);
    for (auto i : sv.enumerate()) {
        infos.append(A, TensorInfo::create(A, A.autorelease(sv[i]), false));
        for (auto l : infos.back().levels) {
            if (!result_levels.contains(l)) {
                result_levels.append(A, l);
            }
        }
    }
    new_dim_d->set_size(infos.size());
    std::vector<at::Tensor> inputs;
    inputs.reserve(infos.size());
    for (auto in : infos) {
        inputs.emplace_back(*_match_levels(A, in.tensor, in.levels, result_levels));
    }
    auto ndim = ndim_of_levels(result_levels);
    int64_t rawdim = 0;
    if (dim.ptr()) {
        auto d = _wrap_dim(dim, ndim, false);
        auto idx = result_levels.index(d);
        if (!idx) {
            mpy::raise_error(PyExc_TypeError, "Dimension %R does not exist in inputs", dim.ptr());
        }
        rawdim = *idx;
    }
    auto result = at::stack(inputs, rawdim);
    result_levels.insert(A, rawdim, new_dim_d);
    return Tensor::from_positional(A, std::move(result), result_levels, true).release();
    PY_END(nullptr)
}

PyObject* py_split(PyObject *_, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    Arena A;
    PY_BEGIN
    maybeInitializeGlobals();
    mpy::vector_args va(args, nargs, kwnames);
    mpy::handle self, split_size_or_sections, dim;
    va.parse("split", {"self", "split_size_or_sections", "dim"}, {&self, &split_size_or_sections, &dim}, 2);
    bool dim_is_object = dim.ptr() && Dim::check_exact(dim);
    Slice<mpy::handle> sizes;

    bool all_dims = true;
    bool all_ints = true;

    if (!mpy::is_int(split_size_or_sections)) {
        mpy::sequence_view sv(split_size_or_sections);
        for (auto i : sv.enumerate()) {
            sizes.append(A, A.autorelease(sv[i]));
            if (Dim::check_exact(sizes.back())) {
                all_ints = false;
            } else {
                all_dims = false;
            }
        }
    }
    if (all_ints) {
        if (dim_is_object) {
            mpy::raise_error(PyExc_TypeError, "when dim is specified as a Dim object, split sizes must also be dimensions.");
        }
        // call original split (if self has dimensions this will use torch function to do the split)
        return torch_Tensor_split.call_vector(mpy::vector_args(args, nargs, kwnames)).release();
    }
    if (!all_dims) {
        mpy::raise_error(PyExc_TypeError, "split list must be ints or dims but got a mix");
    }

    auto self_info = TensorInfo::create(A, self, false);
    auto ndim = self_info.ndim();
    if (!dim_is_object && ndim == 0) {
        mpy::raise_error(PyExc_TypeError, "split expects at least a 1-dimensional tensor");
    }
    DimEntry dim_l = dim.ptr() ? _wrap_dim(dim, ndim, false) : -ndim;

    auto idx = self_info.levels.index(dim_l);
    if (!idx) {
        if (!dim.ptr()) {
            dim = A.autorelease(mpy::from_int(0));
        }
        mpy::raise_error(PyExc_TypeError, "tensor does not contain dimension %R", dim.ptr());
    }
    Slice<int64_t> indices;

    int64_t total_size = 0;
    Slice<int64_t> unbound;
    for (auto i : sizes.enumerate()) {
        auto d = Dim::unchecked_wrap(sizes[i]);
        if (d->is_bound()) {
            indices.append(A, d->size());
            total_size += indices.back();
        } else {
            indices.append(A, 0);
            unbound.append(A, i);
        }
    }
    auto tensor_size = self_info.tensor->sizes()[*idx];

    if (unbound.size()) {
        if (total_size > tensor_size) {
            mpy::raise_error(PyExc_TypeError, "sizes of target dimensions add up to more (%d) than source dim (%d)", int(total_size), int(tensor_size));
        }
        auto remaining_size = tensor_size - total_size;
        auto chunk_size = (remaining_size + unbound.size() - 1) / unbound.size();
        for (auto u : unbound) {
            auto sz = std::min(chunk_size, remaining_size);
            Dim::unchecked_wrap(sizes[u])->set_size(sz);
            indices[u] = sz;
            remaining_size -= sz;
        }
    } else if (tensor_size != total_size) {
        mpy::raise_error(PyExc_TypeError, "sum of sizes of target dimensions (%d) does not match source dim (%d)", int(total_size), int(tensor_size));
    }

    auto result_tensors = self_info.tensor->split_with_sizes(at::IntArrayRef(indices.begin(), indices.end()), *idx);
    mpy::tuple result(result_tensors.size());
    Slice<DimEntry> new_levels;
    new_levels.extend(A, self_info.levels);
    for (auto i : sizes.enumerate()) {
        new_levels[*idx] = Dim::unchecked_wrap(sizes[i]);
        result.set(i, Tensor::from_positional(A, std::move(result_tensors[i]), new_levels, true));
    }

    return result.release();
    PY_END(nullptr)
}

Slice<DimEntry> _wrap_dims(Arena& A, mpy::handle d, size_t N, bool keepdim) {
    auto de = _wrap_dim(d, N, keepdim);
    Slice<DimEntry> r;
    if (!de.is_none()) {
        r.append(A, de);
    } else {
        mpy::sequence_view sq(d);
        for (auto i : sq.enumerate()) {
            r.append(A, _wrap_dim(A.autorelease(sq[i]), N, keepdim));
        }
    }
    return r;
}

struct WrappedOperator : public mpy::base<WrappedOperator> {
    mpy::object orig;
    PyMethodDef method_def;
    mpy::object name, doc;

    bool is_pointwise = false;
    int64_t dim_offset = 0;
    int64_t keepdim_offset = 1;
    std::string dim_name;
    bool single_dim = false;
    bool reduce = true;

    static PyTypeObject Type;

    void init(mpy::object orig_, PyCFunction wrapper_implementation, std::string dim_name_="") {
        orig = std::move(orig_);
        method_def.ml_meth = wrapper_implementation;
        name = orig.attr("__name__");
        doc = orig.attr("__doc__");
        dim_name = std::move(dim_name_);
        if (!mpy::is_none(doc) && !dim_name.empty()) {
            doc = mpy::unicode_from_format("%S\nArgument '%s' can be either an integer or a torchdim.Dim object.\n", doc.ptr(), dim_name.c_str());
        }
        method_def.ml_name = mpy::is_none(name) ? "" : PyUnicode_AsUTF8(name.ptr());
        method_def.ml_doc = mpy::is_none(doc) ? "" : PyUnicode_AsUTF8(doc.ptr());
        method_def.ml_flags = METH_FASTCALL | METH_KEYWORDS;
    }

    mpy::object function() {
        return mpy::object::checked_steal(PyCFunction_New(&method_def, ptr()));
    }
};
}

PyTypeObject WrappedOperator::Type = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "_C.WrappedOperator",               /* tp_name */
    sizeof(WrappedOperator),            /* tp_basicsize */
    0,                                  /* tp_itemsize */
    WrappedOperator::dealloc_stub,      /* tp_dealloc */
    0,                                  /* tp_vectorcall_offset */
    0,                                  /* tp_getattr */
    0,                                  /* tp_setattr */
    0,                                  /* tp_as_async */
    0,                                  /* tp_repr */
    0,                                  /* tp_as_number */
    0,                                  /* tp_as_sequence */
    0,                                  /* tp_as_mapping */
    0,                                  /* tp_hash */
    0,                                  /* tp_call */
    0,                                  /* tp_str */
    0,                                  /* tp_getattro */
    0,                                  /* tp_setattro */
    0,                                  /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,                 /* tp_flags */
    "Wrapped Object Holder",            /* tp_doc */
    0,                                  /* tp_traverse */
    0,                                  /* tp_clear */
    0,                                  /* tp_richcompare */
    0,                                  /* tp_weaklistoffset */
    0,                                  /* tp_iter */
    0,                                  /* tp_iternext */
    0,                                  /* tp_methods */
    0,                                  /* tp_members */
    0,                                  /* tp_getset */
    0,                                  /* tp_base */
    0,                                  /* tp_dict */
    0,                                  /* tp_descr_get */
    0,                                  /* tp_descr_set */
    0,                                  /* tp_dictoffset */
    0,                                  /* tp_init */
    0,                                  /* tp_alloc */
    WrappedOperator::new_stub,          /* tp_new */
};

namespace{
PyObject* patched_dim_method(PyObject * self_, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    Arena A;
    auto self = WrappedOperator::unchecked_wrap(self_);
    PY_BEGIN
    mpy::vector_args va(args, nargs, kwnames);

    auto _getarg = [&](const char* name, int64_t offset_) -> mpy::handle {
        auto offset = offset_ + 1; // do not include self
        auto idx = va.index(name, offset);
        return idx == -1 ? mpy::handle() : va[idx];
    };
    Slice<mpy::handle> patched_args;
    patched_args.extend(A, va.begin(), va.end());
    auto _patcharg = [&](const char* name, int64_t offset_, mpy::handle value) {
        auto offset = offset_ + 1; // do not include self
        auto idx = va.index(name, offset);
        if (idx == -1) {
            mpy::raise_error(PyExc_ValueError, "Missing argument %s", name);
        }
        patched_args[idx] = value;
    };

    auto dim = _getarg(self->dim_name.c_str(), self->dim_offset);
    if (!dim.ptr()) {
        auto info = TensorInfo::create(A, args[0], true);
        EnableAllLayers l(A, info.levels);
        l.inplace_update_layers(info.batchedtensor, info.levels);
        patched_args[0] = handle_from_tensor(A, info.batchedtensor);
        auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames);
        return l.from_batched(A, THPVariable_Unpack(r.ptr()), info.has_device).release();
    }

    auto info = TensorInfo::create(A, args[0]);
    auto keepdim = false;
    if (self->reduce) {
        auto py_keepdim = _getarg("keepdim", self->keepdim_offset);
        if (py_keepdim.ptr()) {
            keepdim = mpy::to_bool(py_keepdim);
        }
    }

    auto ndim = info.ndim();
    auto dims = _wrap_dims(A, dim, ndim, keepdim);
    Slice<int64_t> dim_indices;
    auto seen = A.allocate<bool>(info.levels.size());
    std::fill(seen, seen + info.levels.size(), false);

    for (auto d : dims) {
        auto midx = info.levels.index(d);
        if (!midx) {
            auto tup = levels_to_tuple(info.levels);
            mpy::raise_error(PyExc_ValueError, "Tensor with dimensions %R does not contain one of %R\n", tup.ptr(), dim.ptr());
        }
        seen[*midx] = true;
        dim_indices.append(A, *midx);
    }
    Slice<DimEntry> new_levels;
    if (self->reduce && !keepdim) {
        for (auto i : info.levels.enumerate()) {
            if (!seen[i]) {
                new_levels.append(A, info.levels[i]);
            }
        }
    } else {
        new_levels = info.levels;
    }

    mpy::object py_indices;
    if (dim_indices.size() == 1) {
        py_indices = mpy::from_int(dim_indices[0]);
    } else {
        mpy::tuple tup(dim_indices.size());
        for (auto i : dim_indices.enumerate()) {
            tup.set(i, mpy::from_int(dim_indices[i]));
        }
        py_indices = std::move(tup);
    }
    _patcharg(self->dim_name.c_str(), self->dim_offset, py_indices);
    // at this point the Dim arguments have been replaced with the integer positions
    // they occupy in the unwrapped tensor, so the original op can run on a plain tensor
    patched_args[0] = handle_from_tensor(A, info.tensor);
    auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames);
    auto wrap = [&](mpy::handle h) {
        if (THPVariable_Check(h.ptr())) {
            return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), new_levels, info.has_device));
        }
        return h;
    };
    return tree_map(A, wrap, r).release();
    PY_END(nullptr)
}

PyObject* _wrap(PyObject * self_, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    Arena A;
    PY_BEGIN

    #define ARGS(_) _(mpy::handle, orig) _(mpy::handle, dim_offset) _(mpy::handle, keepdim_offset) \
                    _(mpy::handle, dim_name) _(mpy::handle, single_dim) _(mpy::handle, reduce)
    MPY_PARSE_ARGS_KWNAMES("O|OOOOO", ARGS)

    std::string dim_name_str;
    if (dim_name.ptr()) {
        dim_name_str = PyUnicode_AsUTF8(dim_name.ptr());
    } else {
        dim_name_str = "dim";
    }
    auto info = WrappedOperator::create(mpy::object::borrow(orig), (PyCFunction)(void*) patched_dim_method, std::move(dim_name_str));
    if (dim_offset.ptr()) {
        info->dim_offset = mpy::to_int(dim_offset);
    }
    if (keepdim_offset.ptr()) {
        info->keepdim_offset = mpy::to_int(keepdim_offset);
    }
    if (single_dim.ptr()) {
        info->single_dim = mpy::to_bool(single_dim);
    }
    if (reduce.ptr()) {
        info->reduce = mpy::to_bool(reduce);
    }
    return info->function().release();
    #undef ARGS

    PY_END(nullptr)
}

PyObject* call_torch_function(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    PY_BEGIN
    Arena A;
    maybeInitializeGlobals();
    auto info = WrappedOperator::unchecked_wrap(self);
    return __torch_function__(A, info->orig, mpy::vector_args(args, nargs, kwnames), info->is_pointwise).release();
    PY_END(nullptr)
}

PyObject* _wrap_method(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    PY_BEGIN
    AT_ASSERT(nargs == 2);
    // XXX - ignore the wrapped python function; we will call the torch function directly
    mpy::handle orig = args[0];
    if (!pointwise.ptr()) {
        auto dim = mpy::import("functorch.dim");
        pointwise = dim.attr("pointwise");
    }
    auto info = WrappedOperator::create(mpy::object::borrow(orig), (PyCFunction)(void*) call_torch_function);
    info->is_pointwise = pointwise.contains(orig);
    return PyInstanceMethod_New(info->function().release());
    PY_END(nullptr);
}

PyObject* Tensor_sum(PyObject * self_, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    Arena A;
    PY_BEGIN
    maybeInitializeGlobals();
    mpy::vector_args va(args, nargs, kwnames);
    auto self_t = Tensor::unchecked_wrap(args[0]); // typed view of args[0]; named to avoid shadowing self_
    auto d = self_t->delayed();
    if (!d) {
        return _Tensor_sum.call_vector(va).release();
    }
    mpy::handle self, dim, keepdim, dtype;
    va.parse("sum", {"self", "dim", "keepdim", "dtype"}, {&self, &dim, &keepdim, &dtype}, 1, 1);

    if (dtype.ptr() || (keepdim.ptr() && mpy::to_bool(keepdim))) {
        // std::cout << "SKIPPING fusion because dtype or keepdim=True specified\n";
        return _Tensor_sum.call_vector(va).release();
    }
    auto levels = self_t->levels();

    auto N = ndim_of_levels(levels);
    auto reduced_dims = _wrap_dims(A, dim, N, false);

    return dot(A, TensorInfo::create(A, d->args[0], false), TensorInfo::create(A, d->args[1], false), reduced_dims).release();
    PY_END(nullptr)
}

PyObject* _parse_test(PyObject * self_, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    PY_BEGIN
    maybeInitializeGlobals();
    int required = mpy::to_int(args[0]);
    int kwonly = mpy::to_int(args[1]);

    mpy::vector_args va(args + 2, nargs - 2, kwnames);

    mpy::handle a, b, c, d;
    va.parse("_parse_test", {"a", "b", "c", "d"}, {&a, &b, &c, &d}, required, kwonly);
    mpy::tuple r(4);
    r.set(0, mpy::object::borrow(a.ptr() ? a : Py_None));
    r.set(1, mpy::object::borrow(b.ptr() ? b : Py_None));
    r.set(2, mpy::object::borrow(c.ptr() ? c : Py_None));
    r.set(3, mpy::object::borrow(d.ptr() ? d : Py_None));
    return r.release();
    PY_END(nullptr)
}

PyObject* _set_pointwise_optimize(PyObject * self_, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    PY_BEGIN
    mpy::handle value;
    mpy::vector_args va(args, nargs, kwnames);
    va.parse("_set_pointwise_optimize", {"value"}, {&value}, 1);
    pointwise_optimize = mpy::to_bool(value);
    Py_RETURN_NONE;
    PY_END(nullptr)
}

PyObject* _patch_tensor_class(PyObject * self_, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) {
    PY_BEGIN
    auto torch = mpy::import("torch");
    auto py_TensorBase = torch.attr("_C").attr("TensorBase");
    replaceMappingIfMatches(py_TensorBase);
    Py_RETURN_NONE;
    PY_END(nullptr)
}

const char* dims_doc = R"""(
dims(n=None, sizes=None) -> torchdim.Dim or Tuple[torchdim.Dim, ...]

Creates and returns one or more Dim objects.

Args:
    n (int, optional): The number of dimensions to create. Can be omitted if sizes is specified.
    sizes (List[Optional[int]], optional): A list the same size as the number of dimensions to be
        created, specifying each dimension's size, or None to leave the size unset.

Example::
    >>> batch, channel, width, height = dims(4)
    >>> batch, channel, width, height = dims(sizes=[None, 3, 224, 224])
)""";

PyMethodDef methods[] = {
    {"dims", (PyCFunction)(void*) _dims<create_dim>, METH_FASTCALL | METH_KEYWORDS, dims_doc},
    {"dimlists", (PyCFunction)(void*) _dims<create_dimlist>, METH_FASTCALL | METH_KEYWORDS},
    {"_test_c", (PyCFunction)(void*) test_c, METH_FASTCALL | METH_KEYWORDS},
    {"_wrap_method", (PyCFunction)(void*) _wrap_method, METH_FASTCALL | METH_KEYWORDS},
    {"Tensor_from_positional", (PyCFunction)(void*) py_Tensor_from_positional, METH_FASTCALL | METH_KEYWORDS},
    {"__torch_function__", (PyCFunction)(void*) py___torch_function__, METH_FASTCALL | METH_KEYWORDS},
    {"tree_flatten", (PyCFunction)(void*) py_tree_flatten, METH_FASTCALL | METH_KEYWORDS},
    {"order", (PyCFunction)(void*) order, METH_FASTCALL | METH_KEYWORDS},
    {"index", (PyCFunction)(void*) py_index, METH_FASTCALL | METH_KEYWORDS},
    {"stack", (PyCFunction)(void*) py_stack, METH_FASTCALL | METH_KEYWORDS},
    {"split", (PyCFunction)(void*) py_split, METH_FASTCALL | METH_KEYWORDS},
    {"expand", (PyCFunction)(void*) expand, METH_FASTCALL | METH_KEYWORDS},
    {"__getitem__", (PyCFunction)(void*) py___getitem__, METH_FASTCALL | METH_KEYWORDS},
    {"__setitem__", (PyCFunction)(void*) py___setitem__, METH_FASTCALL | METH_KEYWORDS},
    {"_wrap", (PyCFunction)(void*) _wrap, METH_FASTCALL | METH_KEYWORDS},
    {"Tensor_sum", (PyCFunction)(void*) Tensor_sum, METH_FASTCALL | METH_KEYWORDS},
    {"_parse_test", (PyCFunction)(void*) _parse_test, METH_FASTCALL | METH_KEYWORDS},
    {"_set_pointwise_optimize", (PyCFunction)(void*) _set_pointwise_optimize, METH_FASTCALL | METH_KEYWORDS},
    {"_patch_tensor_class", (PyCFunction)(void*) _patch_tensor_class, METH_FASTCALL | METH_KEYWORDS},
    {NULL, NULL, 0, NULL} /* Sentinel */
};

struct PyModuleDef module_def = {
    PyModuleDef_HEAD_INIT,
    "_C",   /* name of module */
    NULL,   /* module documentation, may be NULL */
    -1,     /* size of per-interpreter state of the module,
               or -1 if the module keeps state in global variables. */
    methods
};
}

PyObject* Dim_init() {
    Arena A;
    try {
        mpy::object mod = mpy::object::checked_steal(PyModule_Create(&module_def));
        Dim::ready(mod, "Dim");
        DimList::ready(mod, "DimList");
        Tensor::ready(mod, "Tensor");
        WrappedOperator::ready(mod, "_WrappedOperator");
        Py_INCREF(&PyInstanceMethod_Type);
        PyModule_AddObject(mod.ptr(), "_instancemethod", (PyObject *)&PyInstanceMethod_Type);

        initializeGlobals(A);
        return mod.release();
    } catch (mpy::exception_set& err) {
        return nullptr;
    }
}

#endif
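// Usage sketch (illustrative only; assumes the documented functorch.dim surface:
// dims(), indexing-based binding, Dim-based reductions, and Tensor.order):
//
//   import torch
//   from functorch.dim import dims
//
//   batch, channel = dims(2)      # create two first-class dimensions
//   x = torch.rand(4, 3)
//   xb = x[batch, channel]        # bind batch=4 and channel=3 by indexing
//   s = xb.sum(channel)           # reduce over a Dim object instead of an int axis
//   y = s.order(batch)            # convert back to an ordinary positional tensor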