1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <iostream>
17 #include <string>
18 #include "common/common_test.h"
19 #include "common/py_func_graph_fetcher.h"
20
21 #include "ir/func_graph_cloner.h"
22 #include "utils/log_adapter.h"
23 #include "pipeline/jit/parse/parse.h"
24 #include "debug/draw.h"
25
26 namespace mindspore {
CheckNoFreeVariables(FuncGraphPtr root)27 void CheckNoFreeVariables(FuncGraphPtr root) {
28 auto mng = Manage(root);
29 for (auto &iter : mng->func_graphs()) {
30 auto g = iter;
31 if (g == nullptr) {
32 continue;
33 }
34 ASSERT_TRUE(g->parent() == nullptr);
35
36 auto nodes = g->nodes();
37 for (auto &node : nodes) {
38 ASSERT_EQ(node->func_graph(), g);
39 auto cnode = node->cast<CNodePtr>();
40 if (cnode != nullptr) {
41 for (auto &inp : cnode->inputs()) {
42 ASSERT_TRUE(inp->func_graph() == nullptr || inp->func_graph() == g);
43 }
44 }
45 }
46 }
47 }
48
CheckCconv(FuncGraphPtr g)49 void CheckCconv(FuncGraphPtr g) {
50 auto mng = Manage(g);
51 auto new_g = LiftingClone(g);
52 CheckNoFreeVariables(new_g);
53 }
54
55 class TestCconv : public UT::Common {
56 public:
TestCconv()57 TestCconv() : getPyFun("gtest_input.optimizer.cconv_test") {}
58
59 virtual void SetUp();
60
61 virtual void TearDown();
62
63 public:
64 UT::PyFuncGraphFetcher getPyFun;
65 };
66
SetUp()67 void TestCconv::SetUp() {}
68
TearDown()69 void TestCconv::TearDown() {}
70
TEST_F(TestCconv, TestStraight) {
  // Python case "test_straight": parse, lift, and check for free variables.
  const FuncGraphPtr graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_straight");
  ASSERT_TRUE(graph != nullptr);
  CheckCconv(graph);
}
76
TEST_F(TestCconv, TestSimpleClosure) {
  // Python case "test_simple_closure": parse, lift, and check for free variables.
  const FuncGraphPtr graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_simple_closure");
  ASSERT_TRUE(graph != nullptr);
  CheckCconv(graph);
}
82
TEST_F(TestCconv, TestMax) {
  // Python case "test_max": parse, lift, and check for free variables.
  const FuncGraphPtr graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_max");
  ASSERT_TRUE(graph != nullptr);
  CheckCconv(graph);
}
88
TEST_F(TestCconv, TestDeepNesting) {
  // Python case "test_deep_nesting": parse, lift, and check for free variables.
  const FuncGraphPtr graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_deep_nesting");
  ASSERT_TRUE(graph != nullptr);
  CheckCconv(graph);
}
94
TEST_F(TestCconv, TestReturnInDoubleWhile) {
  // Python case "test_return_in_double_while": parse, lift, and check for free variables.
  const FuncGraphPtr graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_return_in_double_while");
  ASSERT_TRUE(graph != nullptr);
  CheckCconv(graph);
}
100
TEST_F(TestCconv, TestPow10) {
  // Python case "test_pow10": parse, lift, and check for free variables.
  const FuncGraphPtr graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_pow10");
  ASSERT_TRUE(graph != nullptr);
  CheckCconv(graph);
}
106
TEST_F(TestCconv, TestClosureAsSimpleFv) {
  // Python case "test_closure_as_simple_fv": parse, lift, and check for free variables.
  const FuncGraphPtr graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_as_simple_fv");
  ASSERT_TRUE(graph != nullptr);
  CheckCconv(graph);
}
112
TEST_F(TestCconv, TestClosureAsFv) {
  // Python case "test_closure_as_fv": parse, lift, and check for free variables.
  const FuncGraphPtr graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_as_fv");
  ASSERT_TRUE(graph != nullptr);
  CheckCconv(graph);
}
118
TEST_F(TestCconv, TestClosureAsDoubleFv) {
  // Python case "test_closure_as_double_fv": parse, lift, and check for free variables.
  const FuncGraphPtr graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_as_double_fv");
  ASSERT_TRUE(graph != nullptr);
  CheckCconv(graph);
}
124
TEST_F(TestCconv, TestClosureLiftSameParam) {
  // Python case "test_closure_lift_same_param": parse, lift, and check for free variables.
  const FuncGraphPtr graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_lift_same_param");
  ASSERT_TRUE(graph != nullptr);
  CheckCconv(graph);
}
130
TEST_F(TestCconv, TestClosureAsLoop) {
  // Python case "test_closure_as_loop": parse, lift, and check for free variables.
  const FuncGraphPtr graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_as_loop");
  ASSERT_TRUE(graph != nullptr);
  CheckCconv(graph);
}
136
TEST_F(TestCconv, TestClosureLiftCNode) {
  // Python case "test_closure_lift_cnode": parse, lift, and check for free variables.
  const FuncGraphPtr graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_lift_cnode");
  ASSERT_TRUE(graph != nullptr);
  CheckCconv(graph);
}
142
143 } // namespace mindspore
144