#include #include #include #include #include #include #include #include #include #ifdef USE_DISTRIBUTED #include #endif #include #include #include #include #include using namespace torch::autograd; struct DelayedErrorCtor { DelayedError* operator()(PyObject* args) { TORCH_CHECK( PyTuple_GET_SIZE(args) == 2, "Requires two arguments, got ", PyTuple_GET_SIZE(args)); auto arg1 = PyTuple_GET_ITEM(args, 0); TORCH_CHECK(THPUtils_checkString(arg1), "argument 'msg' must be a string"); std::string msg = THPUtils_unpackString(arg1); auto arg2 = PyTuple_GET_ITEM(args, 1); TORCH_CHECK( THPUtils_checkLong(arg2), "argument 'num_inputs' must be an int"); auto num_inputs = THPUtils_unpackLong(arg2); return new DelayedError(std::move(msg), num_inputs); } }; struct UndefinedGradCtor { UndefinedGrad* operator()(PyObject* args) { TORCH_CHECK( PyTuple_GET_SIZE(args) == 0, "Requires zero arguments, got ", PyTuple_GET_SIZE(args)); return new UndefinedGrad(); } }; struct NoCtor { Node* operator()(PyObject* args) { throw std::runtime_error("Cannot construct"); } }; template static void addClass( PyObject* module, PyTypeObject& type, const char* name, PyGetSetDef* function_properties = nullptr, PyMethodDef* function_methods = nullptr) { createForwardFunctionPyTypeObject( type, name, function_properties, function_methods); Py_INCREF(&type); PyModule_AddObject(module, name, (PyObject*)&type); registerCppFunction(typeid(C), &type); } template < typename T, typename ValueT, typename ParamsT, ValueT ParamsT::*ptr, typename ConvertArgT, PyObject* (*Convert)(ConvertArgT)> PyObject* getTupleAttr(PyObject* obj, void* _unused) { HANDLE_TH_ERRORS THPCppFunction* self = (THPCppFunction*)obj; auto& arr = ((T*)(self->cdata.get()))->*ptr; auto num_elems = arr.size(); THPObjectPtr py_tuple(PyTuple_New(num_elems)); if (!py_tuple) return nullptr; for (const auto i : c10::irange(num_elems)) { PyTuple_SET_ITEM(py_tuple.get(), i, Convert(arr[i])); } return py_tuple.release(); END_HANDLE_TH_ERRORS } template < typename T, typename ValueT, typename ParamsT, ValueT ParamsT::*ptr, typename ConvertArgT, PyObject* (*Convert)(ConvertArgT)> PyObject* getValueAttr(PyObject* obj, void* _unused) { HANDLE_TH_ERRORS THPCppFunction* self = (THPCppFunction*)obj; auto& val = ((T*)(self->cdata.get()))->*ptr; return Convert(val); END_HANDLE_TH_ERRORS } static PyObject* accumulateGradVar(PyObject* _self, void* _unused) { THPCppFunction* self = (THPCppFunction*)_self; auto grad_acc = (AccumulateGrad*)self->cdata.get(); return THPVariable_Wrap(grad_acc->variable); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) static struct PyGetSetDef accumulate_grad_properties[] = { THP_FUNCTION_DEFAULT_PROPERTIES, {(char*)"variable", accumulateGradVar, nullptr, nullptr, nullptr}, {nullptr}}; void THPAutograd_initFunctions() { THPObjectPtr module(PyModule_New("torch._C._functions")); if (!module) throw python_error(); static PyTypeObject AccumulateGradClass; addClass( module, AccumulateGradClass, "AccumulateGrad", accumulate_grad_properties); static PyTypeObject ErrorClass; addClass(module, ErrorClass, "Error"); static PyTypeObject NotImplementedClass; addClass( module, NotImplementedClass, "NotImplemented"); static PyTypeObject DelayedErrorClass; addClass( module, DelayedErrorClass, "DelayedError"); static PyTypeObject UndefinedGradBackwardClass; addClass( module, UndefinedGradBackwardClass, "UndefinedGradBackward"); static PyTypeObject UndefinedGradClass; addClass( module, UndefinedGradClass, "UndefinedGrad"); static PyTypeObject CopyBackwardsClass; addClass(module, CopyBackwardsClass, "CopyBackwards"); #ifdef USE_DISTRIBUTED static PyTypeObject SendRpcBackwardClass; addClass( module, SendRpcBackwardClass, "SendRpcBackward"); #endif static PyTypeObject CopySlicesClass; addClass(module, CopySlicesClass, "CopySlices"); generated::initialize_autogenerated_functions(module); auto c_module = THPObjectPtr(PyImport_ImportModule("torch._C")); if (!c_module) throw python_error(); Py_INCREF(module.get()); if (PyModule_AddObject(c_module, "_functions", module) < 0) { Py_DECREF(module.get()); throw python_error(); } }