• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 
17 #ifndef MINDSPORE_TESTS_UT_CPP_PYNATIVE_COMMON_H_
18 #define MINDSPORE_TESTS_UT_CPP_PYNATIVE_COMMON_H_
19 
20 #include "gtest/gtest.h"
21 #include "mockcpp/mockcpp.hpp"
22 #include "common/mockcpp.h"
23 #include "pybind11/embed.h"
24 #include "pybind11/pybind11.h"
25 
26 #include "ir/tensor.h"
27 #include "include/common/utils/stub_tensor.h"
28 
29 namespace mindspore {
30 class PyCommon : public testing::Test {
31  protected:
SetUp()32   virtual void SetUp() {}
33 
TearDown()34   virtual void TearDown() { GlobalMockObject::verify(); }
35 
SetUpTestCase()36   static void SetUpTestCase() {
37     if (Py_IsInitialized() == 0) {
38       guard_ = std::make_unique<pybind11::scoped_interpreter>();
39     }
40     m_ = pybind11::module::import("mindspore");
41     stub_tensor_module_ = pybind11::module::import("mindspore.common._stub_tensor");
42     tensor_module_ = pybind11::module::import("mindspore.common.tensor");
43   }
44 
TearDownTestCase()45   static void TearDownTestCase() {
46     tensor_module_.release();
47     stub_tensor_module_.release();
48     m_.release();
49     guard_ = nullptr;
50   }
51 
NewPyTensor(const tensor::BaseTensorPtr & tensor)52   pybind11::object NewPyTensor(const tensor::BaseTensorPtr &tensor) {
53     return tensor_module_.attr("Tensor")(tensor);
54   }
55 
NewPyStubTensor(const stub::StubNodePtr & stub_tensor)56   pybind11::object NewPyStubTensor(const stub::StubNodePtr &stub_tensor) {
57     return stub_tensor_module_.attr("_convert_stub")(stub_tensor);
58   }
59 
NewPyStubTensor(const tensor::BaseTensorPtr & tensor)60   pybind11::object NewPyStubTensor(const tensor::BaseTensorPtr &tensor) {
61     auto node = stub::MakeTopNode(kTensorType);
62     node.second->SetValue(tensor);
63     return stub_tensor_module_.attr("_convert_stub")(node.first);
64   }
65 
66  protected:
67   inline static pybind11::module m_;
68   inline static pybind11::module stub_tensor_module_;
69   inline static pybind11::module tensor_module_;
70   inline static std::unique_ptr<pybind11::scoped_interpreter> guard_;
71 };
72 }
73 
74 #endif  // MINDSPORE_TESTS_UT_CPP_PYNATIVE_COMMON_H_
75