1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_TESTS_UT_CPP_PYNATIVE_COMMON_H_ 18 #define MINDSPORE_TESTS_UT_CPP_PYNATIVE_COMMON_H_ 19 20 #include "gtest/gtest.h" 21 #include "mockcpp/mockcpp.hpp" 22 #include "common/mockcpp.h" 23 #include "pybind11/embed.h" 24 #include "pybind11/pybind11.h" 25 26 #include "ir/tensor.h" 27 #include "include/common/utils/stub_tensor.h" 28 29 namespace mindspore { 30 class PyCommon : public testing::Test { 31 protected: SetUp()32 virtual void SetUp() {} 33 TearDown()34 virtual void TearDown() { GlobalMockObject::verify(); } 35 SetUpTestCase()36 static void SetUpTestCase() { 37 if (Py_IsInitialized() == 0) { 38 guard_ = std::make_unique<pybind11::scoped_interpreter>(); 39 } 40 m_ = pybind11::module::import("mindspore"); 41 stub_tensor_module_ = pybind11::module::import("mindspore.common._stub_tensor"); 42 tensor_module_ = pybind11::module::import("mindspore.common.tensor"); 43 } 44 TearDownTestCase()45 static void TearDownTestCase() { 46 tensor_module_.release(); 47 stub_tensor_module_.release(); 48 m_.release(); 49 guard_ = nullptr; 50 } 51 NewPyTensor(const tensor::BaseTensorPtr & tensor)52 pybind11::object NewPyTensor(const tensor::BaseTensorPtr &tensor) { 53 return tensor_module_.attr("Tensor")(tensor); 54 } 55 NewPyStubTensor(const stub::StubNodePtr & stub_tensor)56 pybind11::object NewPyStubTensor(const stub::StubNodePtr &stub_tensor) { 57 return stub_tensor_module_.attr("_convert_stub")(stub_tensor); 58 } 59 NewPyStubTensor(const tensor::BaseTensorPtr & tensor)60 pybind11::object NewPyStubTensor(const tensor::BaseTensorPtr &tensor) { 61 auto node = stub::MakeTopNode(kTensorType); 62 node.second->SetValue(tensor); 63 return stub_tensor_module_.attr("_convert_stub")(node.first); 64 } 65 66 protected: 67 inline static pybind11::module m_; 68 inline static pybind11::module stub_tensor_module_; 69 inline static pybind11::module tensor_module_; 70 inline static std::unique_ptr<pybind11::scoped_interpreter> guard_; 71 }; 72 } 73 74 #endif // MINDSPORE_TESTS_UT_CPP_PYNATIVE_COMMON_H_ 75