• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6 
7 #pragma once
// note: PyTorch's python variable header simply includes pybind11, which conflicts with minpybind,
// so this file just reproduces the minimal API needed to extract Tensors from Python objects.
10 
11 #include <torch/csrc/python_headers.h>
12 #include <ATen/core/Tensor.h>
13 #include <torch/csrc/Export.h>
14 
15 // Python object that backs torch.autograd.Variable
16 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
17 struct THPVariable {
18   PyObject_HEAD;
19   // Payload
20   c10::MaybeOwned<at::Tensor> cdata;
21   // Hooks to be run on backwards pass (corresponds to Python attr
22   // '_backwards_hooks', set by 'register_hook')
23   PyObject* backward_hooks = nullptr;
24 };
25 
// Python class object that THPVariable_Check passes to PyObject_IsInstance.
// THPVariable_Check treats a null value as "not initialized" and returns false.
26 TORCH_PYTHON_API extern PyObject *THPVariableClass;
// NOTE(review): presumably the torch.nn.Parameter class object; it is not
// referenced anywhere in this header. Confirm against libtorch_python.
27 TORCH_PYTHON_API extern PyObject *ParameterClass;
28 
// Wraps the tensor `var` (taken by value) in a Python object.
// NOTE(review): presumably returns a new reference, or nullptr with a Python
// error set. Confirm against the libtorch_python definition.
29 TORCH_PYTHON_API PyObject * THPVariable_Wrap(at::TensorBase var);
30 
THPVariable_Check(PyObject * obj)31 inline bool THPVariable_Check(PyObject *obj)
32 {
33   if (!THPVariableClass)
34       return false;
35 
36   const auto result = PyObject_IsInstance(obj, THPVariableClass);
37   AT_ASSERT(result != -1);
38   return result;
39 }
40 
THPVariable_Unpack(THPVariable * var)41 inline const at::Tensor& THPVariable_Unpack(THPVariable* var) {
42   return *var->cdata;
43 }
44 
THPVariable_Unpack(PyObject * obj)45 inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {
46   return THPVariable_Unpack(reinterpret_cast<THPVariable*>(obj));
47 }
48 
// Returns a c10::impl::PyInterpreter* exported from libtorch_python.
// NOTE(review): presumably the interpreter torch registered for this Python
// process; ownership and lifetime are not shown in this header. Confirm.
49 TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter();
50