1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <algorithm> 17 18 #include "common/common_test.h" 19 #include "common/py_func_graph_fetcher.h" 20 #include "ir/manager.h" 21 #include "utils/log_adapter.h" 22 #include "ir/func_graph_cloner.h" 23 #include "pipeline/jit/parse/parse.h" 24 #include "ir/graph_utils.h" 25 #include "pipeline/jit/resource.h" 26 #include "debug/draw.h" 27 #include "frontend/operator/ops.h" 28 #include "vm/segment_runner.h" 29 #include "vm/transform.h" 30 #include "ir/tensor.h" 31 #include "utils/convert_utils.h" 32 #include "utils/convert_utils_py.h" 33 #include "utils/log_adapter.h" 34 #include "base/core_ops.h" 35 36 namespace mindspore { 37 namespace compile { 38 using Tensor = tensor::Tensor; 39 40 class TestCompileSegmentRunner : public UT::Common { 41 public: 42 TestCompileSegmentRunner() : get_py_fun_("gtest_input.vm", true) { UT::InitPythonPath(); } 43 44 protected: 45 UT::PyFuncGraphFetcher get_py_fun_; 46 VM vm_; 47 }; 48 49 TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) { 50 FuncGraphPtr g = get_py_fun_(prim::kScalarAdd); 51 // g was managed by local variable manager in get_py_fun_ and that manager will be freed as no reference. 52 // so a new manager should be declared to make get_outputs() in segment_runner.cc happy. 53 std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g); 54 55 BackendPtr b = std::make_shared<Backend>("vm"); 56 auto graph_partition = std::make_shared<GraphPartition>(nonlinear_ops, b->name()); 57 auto segments = graph_partition->Partition(g); 58 VectorRef args({1.0, 2.0}); 59 60 auto convertResult = MsVmConvert(segments[0], ""); 61 auto runResult = (*(convertResult.run))(args); 62 ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 3.0); 63 } 64 65 TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) { 66 FuncGraphPtr g = get_py_fun_(prim::kScalarMul); 67 std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g); 68 69 BackendPtr b = std::make_shared<Backend>("vm"); 70 auto graph_partition = std::make_shared<GraphPartition>(nonlinear_ops, b->name()); 71 auto segments = graph_partition->Partition(g); 72 VectorRef args({1.0, 2.0}); 73 74 auto convertResult = MsVmConvert(segments[0], ""); 75 auto runResult = (*(convertResult.run))(args); 76 ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 2.0); 77 } 78 79 TEST_F(TestCompileSegmentRunner, test_if) { 80 FuncGraphPtr g = get_py_fun_("test_if"); 81 std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g); 82 83 BackendPtr b = std::make_shared<Backend>("vm"); 84 auto graph_partition = std::make_shared<GraphPartition>(nonlinear_ops, b->name()); 85 auto segments = graph_partition->Partition(g); 86 VectorRef args({1.0, 2.0}); 87 88 auto convertResult = MsVmConvert(segments[0], ""); 89 auto runResult = (*(convertResult.run))(args); 90 91 auto result = py::cast<bool>(BaseRefToPyData(runResult[0])); 92 ASSERT_TRUE(runResult.size() == 1 && result == false); 93 } 94 95 TEST_F(TestCompileSegmentRunner, test_RunOperation1) { 96 VectorRef args({1}); 97 auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimIdentity->name())), args); 98 ASSERT_EQ(py::cast<int>(BaseRefToPyData(res)), 1); 99 } 100 101 TEST_F(TestCompileSegmentRunner, test_RunOperation2) { 102 VectorRef args({1, 2}); 103 auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimScalarGt->name())), args); 104 ASSERT_EQ(py::cast<bool>(BaseRefToPyData(res)), false); 105 } 106 } // namespace compile 107 } // namespace mindspore 108