1 // Copyright (c) Facebook, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // This source code is licensed under the BSD-style license found in the 5 // LICENSE file in the root directory of this source tree. 6 7 #pragma once 8 // note: pytorch's python variable simple includes pybind which conflicts with minpybind 9 // so this file just reproduces the minimial API needed to extract Tensors from python objects. 10 11 #include <torch/csrc/python_headers.h> 12 #include <ATen/core/Tensor.h> 13 #include <torch/csrc/Export.h> 14 15 // Python object that backs torch.autograd.Variable 16 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) 17 struct THPVariable { 18 PyObject_HEAD; 19 // Payload 20 c10::MaybeOwned<at::Tensor> cdata; 21 // Hooks to be run on backwards pass (corresponds to Python attr 22 // '_backwards_hooks', set by 'register_hook') 23 PyObject* backward_hooks = nullptr; 24 }; 25 26 TORCH_PYTHON_API extern PyObject *THPVariableClass; 27 TORCH_PYTHON_API extern PyObject *ParameterClass; 28 29 TORCH_PYTHON_API PyObject * THPVariable_Wrap(at::TensorBase var); 30 THPVariable_Check(PyObject * obj)31inline bool THPVariable_Check(PyObject *obj) 32 { 33 if (!THPVariableClass) 34 return false; 35 36 const auto result = PyObject_IsInstance(obj, THPVariableClass); 37 AT_ASSERT(result != -1); 38 return result; 39 } 40 THPVariable_Unpack(THPVariable * var)41inline const at::Tensor& THPVariable_Unpack(THPVariable* var) { 42 return *var->cdata; 43 } 44 THPVariable_Unpack(PyObject * obj)45inline const at::Tensor& THPVariable_Unpack(PyObject* obj) { 46 return THPVariable_Unpack(reinterpret_cast<THPVariable*>(obj)); 47 } 48 49 TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter(); 50