1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <iostream> 17 #include <string> 18 #include "common/common_test.h" 19 #include "common/py_func_graph_fetcher.h" 20 21 #include "ir/func_graph_cloner.h" 22 #include "utils/log_adapter.h" 23 #include "pipeline/jit/parse/parse.h" 24 #include "debug/draw.h" 25 26 namespace mindspore { 27 void CheckNoFreeVariables(FuncGraphPtr root) { 28 auto mng = Manage(root); 29 for (auto &iter : mng->func_graphs()) { 30 auto g = iter; 31 if (g == nullptr) { 32 continue; 33 } 34 ASSERT_TRUE(g->parent() == nullptr); 35 36 auto nodes = g->nodes(); 37 for (auto &node : nodes) { 38 ASSERT_EQ(node->func_graph(), g); 39 auto cnode = node->cast<CNodePtr>(); 40 if (cnode != nullptr) { 41 for (auto &inp : cnode->inputs()) { 42 ASSERT_TRUE(inp->func_graph() == nullptr || inp->func_graph() == g); 43 } 44 } 45 } 46 } 47 } 48 49 void CheckCconv(FuncGraphPtr g) { 50 auto mng = Manage(g); 51 auto new_g = LiftingClone(g); 52 CheckNoFreeVariables(new_g); 53 } 54 55 class TestCconv : public UT::Common { 56 public: 57 TestCconv() : getPyFun("gtest_input.optimizer.cconv_test") {} 58 59 virtual void SetUp(); 60 61 virtual void TearDown(); 62 63 public: 64 UT::PyFuncGraphFetcher getPyFun; 65 }; 66 67 void TestCconv::SetUp() {} 68 69 void TestCconv::TearDown() {} 70 71 TEST_F(TestCconv, TestStraight) { 72 FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_straight"); 73 ASSERT_TRUE(nullptr != func_graph); 74 CheckCconv(func_graph); 75 } 76 77 TEST_F(TestCconv, TestSimpleClosure) { 78 FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_simple_closure"); 79 ASSERT_TRUE(nullptr != func_graph); 80 CheckCconv(func_graph); 81 } 82 83 TEST_F(TestCconv, TestMax) { 84 FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_max"); 85 ASSERT_TRUE(nullptr != func_graph); 86 CheckCconv(func_graph); 87 } 88 89 TEST_F(TestCconv, TestDeepNesting) { 90 FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_deep_nesting"); 91 ASSERT_TRUE(nullptr != func_graph); 92 CheckCconv(func_graph); 93 } 94 95 TEST_F(TestCconv, TestReturnInDoubleWhile) { 96 FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_return_in_double_while"); 97 ASSERT_TRUE(nullptr != func_graph); 98 CheckCconv(func_graph); 99 } 100 101 TEST_F(TestCconv, TestPow10) { 102 FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_pow10"); 103 ASSERT_TRUE(nullptr != func_graph); 104 CheckCconv(func_graph); 105 } 106 107 TEST_F(TestCconv, TestClosureAsSimpleFv) { 108 FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_as_simple_fv"); 109 ASSERT_TRUE(nullptr != func_graph); 110 CheckCconv(func_graph); 111 } 112 113 TEST_F(TestCconv, TestClosureAsFv) { 114 FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_as_fv"); 115 ASSERT_TRUE(nullptr != func_graph); 116 CheckCconv(func_graph); 117 } 118 119 TEST_F(TestCconv, TestClosureAsDoubleFv) { 120 FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_as_double_fv"); 121 ASSERT_TRUE(nullptr != func_graph); 122 CheckCconv(func_graph); 123 } 124 125 TEST_F(TestCconv, TestClosureLiftSameParam) { 126 FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_lift_same_param"); 127 ASSERT_TRUE(nullptr != func_graph); 128 CheckCconv(func_graph); 129 } 130 131 TEST_F(TestCconv, TestClosureAsLoop) { 132 FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_as_loop"); 133 ASSERT_TRUE(nullptr != func_graph); 134 CheckCconv(func_graph); 135 } 136 137 TEST_F(TestCconv, TestClosureLiftCNode) { 138 FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_lift_cnode"); 139 ASSERT_TRUE(nullptr != func_graph); 140 CheckCconv(func_graph); 141 } 142 143 } // namespace mindspore 144