1 #include <Python.h>
2 #include <c10/util/irange.h>
3 #include <torch/csrc/autograd/functions/accumulate_grad.h>
4 #include <torch/csrc/autograd/functions/basic_ops.h>
5 #include <torch/csrc/autograd/functions/pybind.h>
6 #include <torch/csrc/autograd/functions/tensor.h>
7 #include <torch/csrc/autograd/generated/python_functions.h>
8 #include <torch/csrc/autograd/python_cpp_function.h>
9 #include <torch/csrc/autograd/python_variable.h>
10 #ifdef USE_DISTRIBUTED
11 #include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
12 #endif
13 #include <torch/csrc/jit/python/python_tracer.h>
14 #include <torch/csrc/utils/pybind.h>
15 #include <torch/csrc/utils/python_numbers.h>
16 #include <torch/csrc/utils/python_strings.h>
17 
18 #include <utility>
19 
20 using namespace torch::autograd;
21 
22 struct DelayedErrorCtor {
operator ()DelayedErrorCtor23   DelayedError* operator()(PyObject* args) {
24     TORCH_CHECK(
25         PyTuple_GET_SIZE(args) == 2,
26         "Requires two arguments, got ",
27         PyTuple_GET_SIZE(args));
28     auto arg1 = PyTuple_GET_ITEM(args, 0);
29     TORCH_CHECK(THPUtils_checkString(arg1), "argument 'msg' must be a string");
30     std::string msg = THPUtils_unpackString(arg1);
31     auto arg2 = PyTuple_GET_ITEM(args, 1);
32     TORCH_CHECK(
33         THPUtils_checkLong(arg2), "argument 'num_inputs' must be an int");
34     auto num_inputs = THPUtils_unpackLong(arg2);
35     return new DelayedError(std::move(msg), num_inputs);
36   }
37 };
38 
39 struct UndefinedGradCtor {
operator ()UndefinedGradCtor40   UndefinedGrad* operator()(PyObject* args) {
41     TORCH_CHECK(
42         PyTuple_GET_SIZE(args) == 0,
43         "Requires zero arguments, got ",
44         PyTuple_GET_SIZE(args));
45     return new UndefinedGrad();
46   }
47 };
48 
49 struct NoCtor {
operator ()NoCtor50   Node* operator()(PyObject* args) {
51     throw std::runtime_error("Cannot construct");
52   }
53 };
54 
55 template <typename C, typename T>
addClass(PyObject * module,PyTypeObject & type,const char * name,PyGetSetDef * function_properties=nullptr,PyMethodDef * function_methods=nullptr)56 static void addClass(
57     PyObject* module,
58     PyTypeObject& type,
59     const char* name,
60     PyGetSetDef* function_properties = nullptr,
61     PyMethodDef* function_methods = nullptr) {
62   createForwardFunctionPyTypeObject<T>(
63       type, name, function_properties, function_methods);
64   Py_INCREF(&type);
65   PyModule_AddObject(module, name, (PyObject*)&type);
66   registerCppFunction(typeid(C), &type);
67 }
68 
69 template <
70     typename T,
71     typename ValueT,
72     typename ParamsT,
73     ValueT ParamsT::*ptr,
74     typename ConvertArgT,
75     PyObject* (*Convert)(ConvertArgT)>
getTupleAttr(PyObject * obj,void * _unused)76 PyObject* getTupleAttr(PyObject* obj, void* _unused) {
77   HANDLE_TH_ERRORS
78   THPCppFunction* self = (THPCppFunction*)obj;
79   auto& arr = ((T*)(self->cdata.get()))->*ptr;
80   auto num_elems = arr.size();
81   THPObjectPtr py_tuple(PyTuple_New(num_elems));
82   if (!py_tuple)
83     return nullptr;
84   for (const auto i : c10::irange(num_elems)) {
85     PyTuple_SET_ITEM(py_tuple.get(), i, Convert(arr[i]));
86   }
87   return py_tuple.release();
88   END_HANDLE_TH_ERRORS
89 }
90 
91 template <
92     typename T,
93     typename ValueT,
94     typename ParamsT,
95     ValueT ParamsT::*ptr,
96     typename ConvertArgT,
97     PyObject* (*Convert)(ConvertArgT)>
getValueAttr(PyObject * obj,void * _unused)98 PyObject* getValueAttr(PyObject* obj, void* _unused) {
99   HANDLE_TH_ERRORS
100   THPCppFunction* self = (THPCppFunction*)obj;
101   auto& val = ((T*)(self->cdata.get()))->*ptr;
102   return Convert(val);
103   END_HANDLE_TH_ERRORS
104 }
105 
accumulateGradVar(PyObject * _self,void * _unused)106 static PyObject* accumulateGradVar(PyObject* _self, void* _unused) {
107   THPCppFunction* self = (THPCppFunction*)_self;
108   auto grad_acc = (AccumulateGrad*)self->cdata.get();
109   return THPVariable_Wrap(grad_acc->variable);
110 }
111 
112 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
113 static struct PyGetSetDef accumulate_grad_properties[] = {
114     THP_FUNCTION_DEFAULT_PROPERTIES,
115     {(char*)"variable", accumulateGradVar, nullptr, nullptr, nullptr},
116     {nullptr}};
117 
THPAutograd_initFunctions()118 void THPAutograd_initFunctions() {
119   THPObjectPtr module(PyModule_New("torch._C._functions"));
120   if (!module)
121     throw python_error();
122 
123   static PyTypeObject AccumulateGradClass;
124   addClass<AccumulateGrad, NoCtor>(
125       module,
126       AccumulateGradClass,
127       "AccumulateGrad",
128       accumulate_grad_properties);
129 
130   static PyTypeObject ErrorClass;
131   addClass<Error, NoCtor>(module, ErrorClass, "Error");
132 
133   static PyTypeObject NotImplementedClass;
134   addClass<NotImplemented, NoCtor>(
135       module, NotImplementedClass, "NotImplemented");
136 
137   static PyTypeObject DelayedErrorClass;
138   addClass<DelayedError, DelayedErrorCtor>(
139       module, DelayedErrorClass, "DelayedError");
140 
141   static PyTypeObject UndefinedGradBackwardClass;
142   addClass<UndefinedGradBackward, NoCtor>(
143       module, UndefinedGradBackwardClass, "UndefinedGradBackward");
144 
145   static PyTypeObject UndefinedGradClass;
146   addClass<UndefinedGrad, UndefinedGradCtor>(
147       module, UndefinedGradClass, "UndefinedGrad");
148 
149   static PyTypeObject CopyBackwardsClass;
150   addClass<CopyBackwards, NoCtor>(module, CopyBackwardsClass, "CopyBackwards");
151 
152 #ifdef USE_DISTRIBUTED
153   static PyTypeObject SendRpcBackwardClass;
154   addClass<torch::distributed::autograd::SendRpcBackward, NoCtor>(
155       module, SendRpcBackwardClass, "SendRpcBackward");
156 #endif
157 
158   static PyTypeObject CopySlicesClass;
159   addClass<CopySlices, NoCtor>(module, CopySlicesClass, "CopySlices");
160 
161   generated::initialize_autogenerated_functions(module);
162 
163   auto c_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
164   if (!c_module)
165     throw python_error();
166 
167   Py_INCREF(module.get());
168   if (PyModule_AddObject(c_module, "_functions", module) < 0) {
169     Py_DECREF(module.get());
170     throw python_error();
171   }
172 }
173