1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <iostream> 17 #include <memory> 18 19 #include "common/common_test.h" 20 21 #include "common/py_func_graph_fetcher.h" 22 #include "ir/anf.h" 23 #include "ir/func_graph.h" 24 #include "ir/func_graph_cloner.h" 25 #include "ir/manager.h" 26 #include "ir/value.h" 27 #include "frontend/operator/ops.h" 28 #include "frontend/optimizer/irpass.h" 29 #include "pipeline/jit/resource.h" 30 #include "debug/draw.h" 31 #include "pipeline/jit/parse/data_converter.h" 32 33 namespace mindspore { 34 namespace opt { 35 using abstract::AnalysisResult; 36 37 class TestOptLib : public UT::Common { 38 public: 39 TestOptLib() : getPyFun("gtest_input.optimizer.opt_test", true), irpass() {} 40 void SetUp() { 41 UT::InitPythonPath(); 42 parse::data_converter::ClearObjectCache(); 43 auto ms_context = MsContext::GetInstance(); 44 MS_EXCEPTION_IF_NULL(ms_context); 45 ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); 46 } 47 FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) { 48 equiv_node.clear(); 49 equiv_graph.clear(); 50 51 FuncGraphPtr gbefore_clone = BasicClone(gbefore); 52 OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", std::make_shared<pipeline::Resource>()); 53 transform(gbefore_clone, optimizer); 54 return gbefore_clone; 55 } 56 FuncGraphPtr RunSubs(FuncGraphPtr before, std::vector<SubstitutionPtr> opts = {}) { 57 SubstitutionList eq(opts); 58 return RunTransform(before, eq); 59 } 60 bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform, 61 bool save_graphs = false) { 62 equiv_node.clear(); 63 equiv_graph.clear(); 64 65 FuncGraphPtr gbefore_clone = BasicClone(gbefore); 66 OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", std::make_shared<pipeline::Resource>()); 67 transform(gbefore_clone, optimizer); 68 return Isomorphic(gbefore_clone, gafter, &equiv_graph, &equiv_node); 69 } 70 bool CheckOpt(FuncGraphPtr before, FuncGraphPtr after, std::vector<SubstitutionPtr> opts = {}, 71 bool save_graphs = false) { 72 if (nullptr == before || nullptr == after) { 73 return false; 74 } 75 SubstitutionList eq(opts); 76 return CheckTransform(before, after, eq, save_graphs); 77 } 78 79 public: 80 UT::PyFuncGraphFetcher getPyFun; 81 FuncGraphPairMapEquiv equiv_graph; 82 NodeMapEquiv equiv_node; 83 irpass::OptimizeIRPassLib irpass; 84 }; 85 86 TEST_F(TestOptLib, test_simplify_always_true_false) { 87 FuncGraphPtr before1 = getPyFun.CallAndParseRet("test_simplify_always_true_false", "before_1"); 88 FuncGraphPtr before2 = getPyFun.CallAndParseRet("test_simplify_always_true_false", "before_2"); 89 FuncGraphPtr after = getPyFun.CallAndParseRet("test_simplify_always_true_false", "after"); 90 auto patterns = std::vector<SubstitutionPtr>({irpass.switch_simplify_}); 91 ASSERT_TRUE(CheckOpt(before1, after, patterns)); 92 ASSERT_TRUE(CheckOpt(before2, after, patterns)); 93 } 94 95 TEST_F(TestOptLib, test_inline) { 96 FuncGraphPtr before1 = getPyFun.CallAndParseRet("test_inline", "before"); 97 FuncGraphPtr after = getPyFun.CallAndParseRet("test_inline", "after"); 98 // add infer and renormalize 99 std::shared_ptr<mindspore::pipeline::Resource> res = std::make_shared<mindspore::pipeline::Resource>(); 100 AbstractBasePtrList args_spec_list; 101 tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3}); 102 tensor::TensorPtr y_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3}); 103 104 AbstractBasePtr abstract_v1 = abstract::FromValue(x_tensor, true); 105 AbstractBasePtr abstract_v2 = abstract::FromValue(y_tensor, true); 106 args_spec_list.push_back(abstract_v1); 107 args_spec_list.push_back(abstract_v2); 108 AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list); 109 FuncGraphPtr new_graph = pipeline::ProgramSpecialize(res, before1, result.context); 110 auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_, irpass.switch_simplify_, irpass.inline_}); 111 ASSERT_TRUE(CheckOpt(new_graph, after, patterns)); 112 } 113 114 TEST_F(TestOptLib, test_inline_successively) { 115 FuncGraphPtr before = getPyFun.CallAndParseRet("test_inline_successively", "before"); 116 FuncGraphPtr after = getPyFun.CallAndParseRet("test_inline_successively", "after"); 117 auto patterns = std::vector<SubstitutionPtr>({irpass.inline_}); 118 ASSERT_TRUE(CheckOpt(before, after, patterns)); 119 } 120 121 TEST_F(TestOptLib, test_inline_closure) { 122 FuncGraphPtr before = getPyFun.CallAndParseRet("test_inline_closure", "before"); 123 FuncGraphPtr after = getPyFun.CallAndParseRet("test_inline_closure", "after"); 124 auto patterns = std::vector<SubstitutionPtr>({irpass.inline_}); 125 ASSERT_TRUE(CheckOpt(before, after, patterns)); 126 } 127 128 TEST_F(TestOptLib, test_inline_deep_closure) { 129 FuncGraphPtr before = getPyFun.CallAndParseRet("test_inline_deep_closure", "before"); 130 FuncGraphPtr after = getPyFun.CallAndParseRet("test_inline_deep_closure", "after"); 131 auto patterns = std::vector<SubstitutionPtr>({irpass.inline_}); 132 ASSERT_TRUE(CheckOpt(before, after, patterns)); 133 } 134 135 TEST_F(TestOptLib, test_inline_new_closure) { 136 FuncGraphPtr before = getPyFun.CallAndParseRet("test_inline_new_closure", "before"); 137 FuncGraphPtr after = getPyFun.CallAndParseRet("test_inline_new_closure", "after"); 138 auto patterns = std::vector<SubstitutionPtr>({irpass.inline_}); 139 ASSERT_TRUE(CheckOpt(before, after, patterns)); 140 } 141 142 TEST_F(TestOptLib, test_inline_while) { 143 FuncGraphPtr before = getPyFun.CallAndParseRet("test_inline_while", "before"); 144 auto patterns = std::vector<SubstitutionPtr>({irpass.inline_}); 145 FuncGraphPtr after = RunSubs(before, patterns); 146 ASSERT_TRUE(CheckOpt(before, after, patterns, true)); 147 } 148 149 TEST_F(TestOptLib, test_arithmetic) { 150 FuncGraphPtr b1_0 = getPyFun.CallAndParseRet("test_arithmetic", "multiply_by_zero_l"); 151 FuncGraphPtr b2_0 = getPyFun.CallAndParseRet("test_arithmetic", "multiply_by_zero_r"); 152 FuncGraphPtr b1 = getPyFun.CallAndParseRet("test_arithmetic", "multiply_by_one_l"); 153 FuncGraphPtr b2 = getPyFun.CallAndParseRet("test_arithmetic", "multiply_by_one_r"); 154 FuncGraphPtr b3 = getPyFun.CallAndParseRet("test_arithmetic", "add_zero_l"); 155 FuncGraphPtr b4 = getPyFun.CallAndParseRet("test_arithmetic", "add_zero_r"); 156 FuncGraphPtr b5 = getPyFun.CallAndParseRet("test_arithmetic", "elim_identity"); 157 FuncGraphPtr after = getPyFun.CallAndParseRet("test_arithmetic", "after"); 158 FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_arithmetic", "after_0"); 159 160 auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_}); 161 162 ASSERT_TRUE(CheckOpt(b1_0, after_0, patterns)); 163 ASSERT_TRUE(CheckOpt(b2_0, after_0, patterns)); 164 ASSERT_TRUE(CheckOpt(b1, after, patterns)); 165 ASSERT_TRUE(CheckOpt(b2, after, patterns)); 166 ASSERT_TRUE(CheckOpt(b3, after, patterns)); 167 ASSERT_TRUE(CheckOpt(b4, after, patterns)); 168 ASSERT_TRUE(CheckOpt(b5, after, patterns)); 169 } 170 171 TEST_F(TestOptLib, test_elim_cast_same_dtype) { 172 FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_cast_same_dtype", "fp32_cast_fp32"); 173 FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_cast_same_dtype", "after"); 174 // construct such case that cast srcT equal dstT 175 auto &inputs = before->output()->cast<CNodePtr>()->inputs(); 176 if (inputs.size() > 2) { 177 auto cast_node = inputs[0]; 178 auto cast_py = cast_node->cast<ValueNodePtr>()->value()->cast<PrimitivePyPtr>(); 179 cast_py->set_attr("SrcT", TypeIdToType(kNumberTypeFloat32)); 180 cast_py->set_attr("DstT", TypeIdToType(kNumberTypeFloat32)); 181 182 auto x_node = inputs[1]; 183 std::vector<int64_t> shp = {2, 3}; 184 tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp); 185 auto x_abstract = x_tensor->ToAbstract(); 186 x_node->set_abstract(x_abstract); 187 188 TypePtr t = std::make_shared<TensorType>(std::make_shared<Float>(32)); 189 ValueNodePtr val = std::make_shared<ValueNode>(t); 190 auto t_abstract = t->ToAbstract(); 191 val->set_abstract(t_abstract); 192 before->output()->cast<CNodePtr>()->set_input(2, val); 193 } 194 FuncGraphPtr gbefore_clone = BasicClone(before); 195 auto patterns = std::vector<SubstitutionPtr>({irpass.cast_eliminate_}); 196 ASSERT_TRUE(CheckOpt(before, after, patterns)); 197 198 TypePtr t = std::make_shared<Float>(32); 199 ValueNodePtr val = std::make_shared<ValueNode>(t); 200 auto t_abstract = t->ToAbstract(); 201 val->set_abstract(t_abstract); 202 gbefore_clone->output()->cast<CNodePtr>()->set_input(2, val); 203 ASSERT_TRUE(CheckOpt(gbefore_clone, after, patterns)); 204 } 205 206 TEST_F(TestOptLib, test_elim_reshape_same_shape) { 207 FuncGraphPtr before = getPyFun.CallAndParseRet("elim_reshape_same_shape", "reshape_to_2_3"); 208 FuncGraphPtr after = getPyFun.CallAndParseRet("elim_reshape_same_shape", "after"); 209 // construct such case that shape is equal to reshape target 210 auto &inputs = before->output()->cast<CNodePtr>()->inputs(); 211 if (inputs.size() > 1) { 212 auto x_node = inputs[1]; 213 std::vector<int64_t> shp = {2, 3}; 214 tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp); 215 auto x_abstract = x_tensor->ToAbstract(); 216 x_node->set_abstract(x_abstract); 217 before->output()->set_abstract(x_abstract); 218 } 219 auto patterns = std::vector<SubstitutionPtr>({irpass.reshape_eliminate_}); 220 ASSERT_TRUE(CheckOpt(before, after, patterns)); 221 if (inputs.size() > 1) { 222 auto x_node = inputs[1]; 223 std::vector<int64_t> shp = {3, 2}; 224 tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp); 225 auto x_abstract = x_tensor->ToAbstract(); 226 x_node->set_abstract(x_abstract); 227 } 228 ASSERT_FALSE(CheckOpt(before, after, patterns)); 229 } 230 231 TEST_F(TestOptLib, elim_two_reshape) { 232 FuncGraphPtr before = getPyFun.CallAndParseRet("elim_two_reshape", "before"); 233 FuncGraphPtr after = getPyFun.CallAndParseRet("elim_two_reshape", "after"); 234 235 auto patterns = std::vector<SubstitutionPtr>({irpass.reshape_eliminate_}); 236 ASSERT_TRUE(CheckOpt(before, after, patterns)); 237 } 238 239 TEST_F(TestOptLib, elim_two_cast) { 240 FuncGraphPtr before = getPyFun.CallAndParseRet("elim_two_cast", "before"); 241 FuncGraphPtr after = getPyFun.CallAndParseRet("elim_two_cast", "after"); 242 243 auto patterns = std::vector<SubstitutionPtr>({irpass.cast_eliminate_}); 244 ASSERT_TRUE(CheckOpt(before, after, patterns)); 245 } 246 247 TEST_F(TestOptLib, test_elim_transpose) { 248 FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_transpose", "before"); 249 FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_transpose", "after"); 250 251 auto patterns = std::vector<SubstitutionPtr>({irpass.transpose_eliminate_}); 252 ASSERT_TRUE(CheckOpt(before, after, patterns)); 253 } 254 255 TEST_F(TestOptLib, test_elim_depend_value) { 256 FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_depend_value", "before"); 257 FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_depend_value", "after"); 258 259 auto patterns = std::vector<SubstitutionPtr>({irpass.depend_value_elim_}); 260 ASSERT_TRUE(CheckOpt(before, after, patterns)); 261 } 262 263 TEST_F(TestOptLib, test_elim_tile_multiply_one) { 264 FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "before"); 265 FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "after"); 266 267 auto patterns = std::vector<SubstitutionPtr>({irpass.tile_eliminate_}); 268 ASSERT_TRUE(CheckOpt(before, after, patterns, true)); 269 } 270 271 TEST_F(TestOptLib, test_elim_reduce_mean_shape_one) { 272 FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_reduce_mean_shape_one", "before"); 273 FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_reduce_mean_shape_one", "after"); 274 275 // construct such case that input x shape is (1), keepdims is true 276 auto inputs = before->output()->cast<CNodePtr>()->inputs(); 277 if (inputs.size() > 2) { 278 auto x_node = inputs[1]; 279 std::vector<int64_t> shp = {1}; 280 tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp); 281 auto x_abstract = x_tensor->ToAbstract(); 282 x_node->set_abstract(x_abstract); 283 284 auto reduce_node = inputs[0]; 285 auto reduce = reduce_node->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 286 reduce->set_attr("keep_dims", std::make_shared<BoolImm>(true)); 287 } 288 289 auto patterns = std::vector<SubstitutionPtr>({irpass.reduce_eliminate_}); 290 ASSERT_TRUE(CheckOpt(before, after, patterns)); 291 } 292 293 TEST_F(TestOptLib, test_elim_all_shape_one) { 294 FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_all_shape_one", "before"); 295 FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_all_shape_one", "after"); 296 297 // construct such case that input x shape is (1) keep_dims is true 298 auto inputs = before->output()->cast<CNodePtr>()->inputs(); 299 if (inputs.size() > 2) { 300 auto x_node = inputs[1]; 301 std::vector<int64_t> shp = {1}; 302 tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp); 303 auto x_abstract = x_tensor->ToAbstract(); 304 x_node->set_abstract(x_abstract); 305 306 auto reduce_node = inputs[0]; 307 auto reduce = reduce_node->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 308 reduce->set_attr("keep_dims", std::make_shared<BoolImm>(true)); 309 } 310 auto patterns = std::vector<SubstitutionPtr>({irpass.reduce_eliminate_}); 311 ASSERT_TRUE(CheckOpt(before, after, patterns)); 312 } 313 314 TEST_F(TestOptLib, test_elim_sum_shape_one) { 315 FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_sum_shape_one", "before"); 316 FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_sum_shape_one", "after"); 317 318 // construct such case that input x shape is (1) keepdims is true 319 auto inputs = before->output()->cast<CNodePtr>()->inputs(); 320 if (inputs.size() > 2) { 321 auto x_node = inputs[1]; 322 std::vector<int64_t> shp = {1}; 323 tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp); 324 auto x_abstract = x_tensor->ToAbstract(); 325 x_node->set_abstract(x_abstract); 326 327 auto reduce_node = inputs[0]; 328 auto reduce = reduce_node->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>(); 329 reduce->set_attr("keep_dims", std::make_shared<BoolImm>(true)); 330 } 331 auto patterns = std::vector<SubstitutionPtr>({irpass.reduce_eliminate_}); 332 ASSERT_TRUE(CheckOpt(before, after, patterns)); 333 } 334 335 TEST_F(TestOptLib, test_tuple_getitem) { 336 FuncGraphPtr make_get_0 = getPyFun.CallAndParseRet("test_tuple_getitem", "make_get_0"); 337 FuncGraphPtr make_get_1 = getPyFun.CallAndParseRet("test_tuple_getitem", "make_get_1"); 338 FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_getitem", "after_0"); 339 FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_getitem", "after_1"); 340 341 FuncGraphPtr make_get_const = std::make_shared<FuncGraph>(); 342 auto value_node_1 = NewValueNode(static_cast<int64_t>(1)); 343 auto value_node_2 = NewValueNode(static_cast<int64_t>(2)); 344 std::vector<int64_t> vec{1, 2}; 345 auto value_node_tuple = NewValueNode(MakeValue(vec)); 346 std::vector<AnfNodePtr> node_list{NewValueNode(prim::kPrimTupleGetItem), value_node_tuple, value_node_1}; 347 auto get_item = make_get_const->NewCNode(node_list); 348 make_get_const->set_output(get_item); 349 350 FuncGraphPtr after_2 = std::make_shared<FuncGraph>(); 351 after_2->set_output(value_node_2); 352 353 auto patterns = std::vector<SubstitutionPtr>( 354 {irpass.tuple_list_get_item_eliminator_, irpass.tuple_list_get_item_const_eliminator_, 355 irpass.tuple_list_set_item_eliminator_, irpass.tuple_list_get_set_item_eliminator_, 356 irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_}); 357 ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns)); 358 ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns)); 359 ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns)); 360 } 361 362 TEST_F(TestOptLib, test_tuple_setitem) { 363 FuncGraphPtr before_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "before_0"); 364 FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "before_1"); 365 FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0"); 366 FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1"); 367 368 auto patterns = std::vector<SubstitutionPtr>( 369 {irpass.tuple_list_get_item_eliminator_, irpass.tuple_list_get_item_const_eliminator_, 370 irpass.tuple_list_set_item_eliminator_, irpass.tuple_list_get_set_item_eliminator_, 371 irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_}); 372 373 ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); 374 ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); 375 } 376 377 TEST_F(TestOptLib, test_tuple_get_set_item) { 378 FuncGraphPtr before_0 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0"); 379 FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0"); 380 FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0"); 381 FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0"); 382 383 auto patterns = std::vector<SubstitutionPtr>( 384 {irpass.tuple_list_get_item_eliminator_, irpass.tuple_list_get_item_const_eliminator_, 385 irpass.tuple_list_set_item_eliminator_, irpass.tuple_list_get_set_item_eliminator_, 386 irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_}); 387 388 ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); 389 ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); 390 } 391 392 TEST_F(TestOptLib, test_partial) { 393 FuncGraphPtr before = getPyFun.CallAndParseRet("test_partial", "before"); 394 FuncGraphPtr after = getPyFun.CallAndParseRet("test_partial", "after"); 395 396 auto patterns = std::vector<SubstitutionPtr>({irpass.partial_eliminate_}); 397 398 ASSERT_TRUE(CheckOpt(before, after, patterns)); 399 } 400 401 TEST_F(TestOptLib, test_replace_applicator) { 402 FuncGraphPtr before1 = getPyFun.CallAndParseRet("test_replace_applicator", "before1"); 403 FuncGraphPtr before2 = getPyFun.CallAndParseRet("test_replace_applicator", "before2"); 404 FuncGraphPtr before3 = getPyFun.CallAndParseRet("test_replace_applicator", "before3"); 405 FuncGraphPtr after = getPyFun.CallAndParseRet("test_replace_applicator", "after"); 406 407 auto patterns = std::vector<SubstitutionPtr>({irpass.replace_applicator_}); 408 409 ASSERT_TRUE(CheckOpt(before1, after, patterns)); 410 ASSERT_TRUE(CheckOpt(before2, after, patterns)); 411 ASSERT_TRUE(CheckOpt(before3, before3, patterns)); 412 } 413 414 TEST_F(TestOptLib, test_specialize_on_graph_arguments) { 415 FuncGraphPtr before = getPyFun.CallAndParseRet("test_specialize_on_graph_arguments", "before"); 416 FuncGraphPtr after = getPyFun.CallAndParseRet("test_specialize_on_graph_arguments", "after"); 417 418 auto patterns = std::vector<SubstitutionPtr>({irpass.specialize_transform_}); 419 420 ASSERT_TRUE(CheckOpt(before, after, patterns)); 421 } 422 423 TEST_F(TestOptLib, test_incorporate_getitem) { 424 FuncGraphPtr before1 = getPyFun.CallAndParseRet("test_incorporate_getitem", "before1"); 425 FuncGraphPtr before2 = getPyFun.CallAndParseRet("test_incorporate_getitem", "before2"); 426 FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_incorporate_getitem", "after1"); 427 FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_incorporate_getitem", "after2"); 428 429 auto patterns = std::vector<SubstitutionPtr>({irpass.incorporate_getitem_set_}); 430 431 ASSERT_TRUE(CheckOpt(before1, after1, patterns)); 432 ASSERT_TRUE(CheckOpt(before2, after2, patterns)); 433 } 434 435 TEST_F(TestOptLib, test_incorporate_getitem_through_switch) { 436 FuncGraphPtr before = getPyFun.CallAndParseRet("test_incorporate_getitem_through_switch", "before"); 437 FuncGraphPtr after = getPyFun.CallAndParseRet("test_incorporate_getitem_through_switch", "after"); 438 439 auto patterns = std::vector<SubstitutionPtr>({irpass.incorporate_getitem_set_}); 440 ASSERT_TRUE(CheckOpt(before, after, patterns)); 441 } 442 443 TEST_F(TestOptLib, test_incorporate_call) { 444 FuncGraphPtr before = getPyFun.CallAndParseRet("test_incorporate_call", "before"); 445 FuncGraphPtr after = getPyFun.CallAndParseRet("test_incorporate_call", "after"); 446 447 auto patterns = std::vector<SubstitutionPtr>({irpass.incorporate_call_}); 448 ASSERT_TRUE(CheckOpt(before, after, patterns)); 449 } 450 451 TEST_F(TestOptLib, test_incorporate_call_through_switch) { 452 FuncGraphPtr before = getPyFun.CallAndParseRet("test_incorporate_call_through_switch", "before"); 453 FuncGraphPtr after = getPyFun.CallAndParseRet("test_incorporate_call_through_switch", "after"); 454 auto patterns = std::vector<SubstitutionPtr>({ 455 irpass.incorporate_call_switch_, 456 irpass.incorporate_call_, 457 irpass.arithmetic_simplify_, 458 }); 459 ASSERT_TRUE(CheckOpt(before, after, patterns)); 460 } 461 462 TEST_F(TestOptLib, test_float_tuple_getitem_through_switch) { 463 FuncGraphPtr before = getPyFun.CallAndParseRet("test_float_tuple_getitem_through_switch", "before"); 464 FuncGraphPtr after = getPyFun.CallAndParseRet("test_float_tuple_getitem_through_switch", "after"); 465 466 auto patterns = std::vector<SubstitutionPtr>({irpass.float_tuple_getitem_switch_}); 467 ASSERT_TRUE(CheckOpt(before, after, patterns)); 468 } 469 470 TEST_F(TestOptLib, test_merge_addn) { 471 FuncGraphPtr before = getPyFun.CallAndParseRet("test_merge_addn", "before"); 472 FuncGraphPtr after = getPyFun.CallAndParseRet("test_merge_addn", "after"); 473 474 auto patterns = std::vector<SubstitutionPtr>({irpass.merge_addn_}); 475 ASSERT_TRUE(CheckOpt(before, after, patterns)); 476 } 477 478 TEST_F(TestOptLib, test_filter_addn_zero) { 479 FuncGraphPtr before1 = getPyFun.CallAndParseRet("test_addn_zero", "before_1"); 480 FuncGraphPtr after = getPyFun.CallAndParseRet("test_addn_zero", "after"); 481 FuncGraphPtr before2 = getPyFun.CallAndParseRet("test_addn_zero", "before_2"); 482 FuncGraphPtr before3 = getPyFun.CallAndParseRet("test_addn_zero", "before_3"); 483 FuncGraphPtr before4 = getPyFun.CallAndParseRet("test_addn_zero", "before_4"); 484 auto patterns = std::vector<SubstitutionPtr>({irpass.addn_zero_filter_}); 485 ASSERT_TRUE(CheckOpt(before1, after, patterns)); 486 ASSERT_TRUE(CheckOpt(before2, after, patterns)); 487 ASSERT_TRUE(CheckOpt(before3, after, patterns)); 488 ASSERT_TRUE(CheckOpt(before4, before4, patterns)); 489 } 490 491 TEST_F(TestOptLib, test_minmax_grad) { 492 FuncGraphPtr before11 = getPyFun.CallAndParseRet("test_minmax_grad", "before_11"); 493 FuncGraphPtr before12 = getPyFun.CallAndParseRet("test_minmax_grad", "before_12"); 494 FuncGraphPtr before2 = getPyFun.CallAndParseRet("test_minmax_grad", "before_2"); 495 FuncGraphPtr before31 = getPyFun.CallAndParseRet("test_minmax_grad", "before_31"); 496 FuncGraphPtr before32 = getPyFun.CallAndParseRet("test_minmax_grad", "before_32"); 497 FuncGraphPtr before4 = getPyFun.CallAndParseRet("test_minmax_grad", "before_4"); 498 auto patterns = std::vector<SubstitutionPtr>({irpass.minmaximum_grad_}); 499 ASSERT_TRUE(CheckOpt(before11, before11, patterns)); 500 ASSERT_TRUE(CheckOpt(before12, before12, patterns)); 501 ASSERT_TRUE(CheckOpt(before2, before2, patterns)); 502 ASSERT_TRUE(CheckOpt(before31, before31, patterns)); 503 ASSERT_TRUE(CheckOpt(before32, before32, patterns)); 504 ASSERT_TRUE(CheckOpt(before4, before4, patterns)); 505 } 506 507 TEST_F(TestOptLib, test_reducesum_one) { 508 FuncGraphPtr before1 = getPyFun.CallAndParseRet("test_reducesum_one", "before_1"); 509 FuncGraphPtr before2 = getPyFun.CallAndParseRet("test_reducesum_one", "before_2"); 510 FuncGraphPtr before3 = getPyFun.CallAndParseRet("test_reducesum_one", "before_3"); 511 FuncGraphPtr before4 = getPyFun.CallAndParseRet("test_reducesum_one", "before_4"); 512 FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_reducesum_one", "after_1"); 513 FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_reducesum_one", "after_2"); 514 FuncGraphPtr after3 = getPyFun.CallAndParseRet("test_reducesum_one", "after_3"); 515 auto patterns = std::vector<SubstitutionPtr>({irpass.reduce_eliminate_}); 516 517 std::vector<int64_t> shp = {3, 2, 2, 1}; 518 tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp); 519 auto x_abstract = x_tensor->ToAbstract(); 520 521 std::vector<int64_t> shp2 = {3, 2, 1, 1}; 522 tensor::TensorPtr x_tensor2 = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp2); 523 auto x_abstract2 = x_tensor2->ToAbstract(); 524 525 auto inputs = before1->output()->cast<CNodePtr>()->inputs(); 526 if (inputs.size() > 1) { 527 auto x_node = inputs[1]; 528 x_node->set_abstract(x_abstract); 529 } 530 ASSERT_TRUE(CheckOpt(before1, after1, patterns)); 531 532 auto inputs2 = before2->output()->cast<CNodePtr>()->inputs(); 533 if (inputs2.size() > 1) { 534 auto x_node2 = inputs2[1]; 535 x_node2->set_abstract(x_abstract2); 536 } 537 ASSERT_TRUE(CheckOpt(before2, after1, patterns)); 538 539 auto inputs3 = before2->output()->cast<CNodePtr>()->inputs(); 540 if (inputs3.size() > 1) { 541 auto x_node3 = inputs3[1]; 542 x_node3->set_abstract(x_abstract); 543 } 544 ASSERT_TRUE(CheckOpt(before2, before2, patterns)); 545 546 auto inputs4 = before3->output()->cast<CNodePtr>()->inputs(); 547 if (inputs4.size() > 1) { 548 auto x_node4 = inputs4[1]; 549 x_node4->set_abstract(x_abstract); 550 } 551 ASSERT_TRUE(CheckOpt(before3, after2, patterns)); 552 553 auto inputs5 = before4->output()->cast<CNodePtr>()->inputs(); 554 if (inputs5.size() > 1) { 555 auto x_node5 = inputs5[1]; 556 x_node5->set_abstract(x_abstract2); 557 } 558 ASSERT_TRUE(CheckOpt(before4, after3, patterns)); 559 } 560 561 #ifndef ENABLE_SECURITY 562 TEST_F(TestOptLib, test_print_tuple_wrapper) { 563 FuncGraphPtr before1 = getPyFun.CallAndParseRet("test_print_tuple_wrapper", "before1"); 564 FuncGraphPtr before2 = getPyFun.CallAndParseRet("test_print_tuple_wrapper", "before2"); 565 FuncGraphPtr before3 = getPyFun.CallAndParseRet("test_print_tuple_wrapper", "before3"); 566 FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_print_tuple_wrapper", "after1"); 567 FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_print_tuple_wrapper", "after2"); 568 auto patterns = std::vector<SubstitutionPtr>({irpass.print_tuple_wrapper_}); 569 ASSERT_TRUE(CheckOpt(before1, after1, patterns)); 570 ASSERT_TRUE(CheckOpt(before2, after2, patterns)); 571 ASSERT_TRUE(CheckOpt(before3, before3, patterns)); 572 } 573 #endif 574 575 TEST_F(TestOptLib, test_constant_duplicate_mul) { 576 FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforell"); 577 FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforelr"); 578 FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforerl"); 579 FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforerr"); 580 FuncGraphPtr after = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "after"); 581 auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_}); 582 ASSERT_TRUE(CheckOpt(beforell, after, patterns)); 583 ASSERT_TRUE(CheckOpt(beforelr, after, patterns)); 584 ASSERT_TRUE(CheckOpt(beforerl, after, patterns)); 585 ASSERT_TRUE(CheckOpt(beforerr, after, patterns)); 586 } 587 588 TEST_F(TestOptLib, test_adjust_allreduce_mul_add) { 589 FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell"); 590 FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr"); 591 FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl"); 592 FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr"); 593 FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1"); 594 FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r"); 595 FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l"); 596 FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2"); 597 auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_}); 598 ASSERT_TRUE(CheckOpt(beforell, after1, patterns, true)); 599 ASSERT_TRUE(CheckOpt(beforelr, after1, patterns)); 600 ASSERT_TRUE(CheckOpt(beforerl, after1, patterns)); 601 ASSERT_TRUE(CheckOpt(beforerr, after1, patterns)); 602 } 603 604 TEST_F(TestOptLib, test_row_tensor) { 605 FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_row_tensor", "before_get_indices"); 606 FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_row_tensor", "after_get_indices"); 607 FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_row_tensor", "before_get_values"); 608 FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_row_tensor", "after_get_values"); 609 FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_row_tensor", "before_get_dense_shape"); 610 FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_row_tensor", "after_get_dense_shape"); 611 auto patterns = std::vector<SubstitutionPtr>({irpass.row_tensor_eliminate_}); 612 ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns)); 613 ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns)); 614 ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns)); 615 } 616 617 TEST_F(TestOptLib, test_sparse_tensor) { 618 FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_indices"); 619 FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_indices"); 620 FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_values"); 621 FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_values"); 622 FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_dense_shape"); 623 FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_dense_shape"); 624 auto patterns = std::vector<SubstitutionPtr>({irpass.sparse_tensor_eliminate_}); 625 ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns)); 626 ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns)); 627 ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns)); 628 } 629 } // namespace opt 630 } // namespace mindspore 631