• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <iostream>
17 #include <memory>
18 
19 #include "common/common_test.h"
20 
21 #include "common/py_func_graph_fetcher.h"
22 #include "ir/anf.h"
23 #include "ir/func_graph.h"
24 #include "ir/func_graph_cloner.h"
25 #include "ir/manager.h"
26 #include "ir/value.h"
27 #include "frontend/operator/ops.h"
28 #include "frontend/optimizer/irpass.h"
29 #include "pipeline/jit/resource.h"
30 #include "debug/draw.h"
31 #include "pipeline/jit/parse/data_converter.h"
32 
33 namespace mindspore {
34 namespace opt {
35 using abstract::AnalysisResult;
36 
37 class TestOptLib : public UT::Common {
38  public:
TestOptLib()39   TestOptLib() : getPyFun("gtest_input.optimizer.opt_test", true), irpass() {}
SetUp()40   void SetUp() {
41     UT::InitPythonPath();
42     parse::data_converter::ClearObjectCache();
43     auto ms_context = MsContext::GetInstance();
44     MS_EXCEPTION_IF_NULL(ms_context);
45     ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
46   }
RunTransform(FuncGraphPtr gbefore,const SubstitutionList & transform)47   FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) {
48     equiv_node.clear();
49     equiv_graph.clear();
50 
51     FuncGraphPtr gbefore_clone = BasicClone(gbefore);
52     OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", std::make_shared<pipeline::Resource>());
53     transform(gbefore_clone, optimizer);
54     return gbefore_clone;
55   }
RunSubs(FuncGraphPtr before,std::vector<SubstitutionPtr> opts={})56   FuncGraphPtr RunSubs(FuncGraphPtr before, std::vector<SubstitutionPtr> opts = {}) {
57     SubstitutionList eq(opts);
58     return RunTransform(before, eq);
59   }
CheckTransform(FuncGraphPtr gbefore,FuncGraphPtr gafter,const SubstitutionList & transform,bool save_graphs=false)60   bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform,
61                       bool save_graphs = false) {
62     equiv_node.clear();
63     equiv_graph.clear();
64 
65     FuncGraphPtr gbefore_clone = BasicClone(gbefore);
66     OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", std::make_shared<pipeline::Resource>());
67     transform(gbefore_clone, optimizer);
68     return Isomorphic(gbefore_clone, gafter, &equiv_graph, &equiv_node);
69   }
CheckOpt(FuncGraphPtr before,FuncGraphPtr after,std::vector<SubstitutionPtr> opts={},bool save_graphs=false)70   bool CheckOpt(FuncGraphPtr before, FuncGraphPtr after, std::vector<SubstitutionPtr> opts = {},
71                 bool save_graphs = false) {
72     if (nullptr == before || nullptr == after) {
73       return false;
74     }
75     SubstitutionList eq(opts);
76     return CheckTransform(before, after, eq, save_graphs);
77   }
78 
79  public:
80   UT::PyFuncGraphFetcher getPyFun;
81   FuncGraphPairMapEquiv equiv_graph;
82   NodeMapEquiv equiv_node;
83   irpass::OptimizeIRPassLib irpass;
84 };
85 
TEST_F(TestOptLib, test_simplify_always_true_false) {
  // Both "before_1" and "before_2" must be rewritten by `switch_simplify_` into a graph isomorphic to "after".
  auto fetch = [this](const char *tag) { return getPyFun.CallAndParseRet("test_simplify_always_true_false", tag); };
  FuncGraphPtr first_variant = fetch("before_1");
  FuncGraphPtr second_variant = fetch("before_2");
  FuncGraphPtr expected = fetch("after");

  const std::vector<SubstitutionPtr> switch_pass{irpass.switch_simplify_};
  ASSERT_TRUE(CheckOpt(first_variant, expected, switch_pass));
  ASSERT_TRUE(CheckOpt(second_variant, expected, switch_pass));
}
94 
TEST_F(TestOptLib, test_inline) {
  FuncGraphPtr before1 = getPyFun.CallAndParseRet("test_inline", "before");
  FuncGraphPtr after = getPyFun.CallAndParseRet("test_inline", "after");

  // Run abstract inference and specialization on "before" first, with two float32 tensors of
  // shape (2, 3) as the argument specs, and feed the specialized graph to the passes.
  auto res = std::make_shared<pipeline::Resource>();
  const std::vector<int64_t> arg_shape{2, 3};
  tensor::TensorPtr lhs = std::make_shared<tensor::Tensor>(kFloat32->type_id(), arg_shape);
  tensor::TensorPtr rhs = std::make_shared<tensor::Tensor>(kFloat32->type_id(), arg_shape);

  AbstractBasePtrList args_spec_list{abstract::FromValue(lhs, true), abstract::FromValue(rhs, true)};
  AnalysisResult analysis = pipeline::AbstractAnalyze(res, before1, args_spec_list);
  FuncGraphPtr specialized = pipeline::ProgramSpecialize(res, before1, analysis.context);

  const std::vector<SubstitutionPtr> passes{irpass.arithmetic_simplify_, irpass.switch_simplify_, irpass.inline_};
  ASSERT_TRUE(CheckOpt(specialized, after, passes));
}
113 
TEST_F(TestOptLib, test_inline_successively) {
  // `inline_` alone must turn "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_inline_successively", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_inline_successively", "after");
  const std::vector<SubstitutionPtr> inline_pass{irpass.inline_};
  ASSERT_TRUE(CheckOpt(source, expected, inline_pass));
}
120 
TEST_F(TestOptLib, test_inline_closure) {
  // `inline_` alone must turn "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_inline_closure", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_inline_closure", "after");
  const std::vector<SubstitutionPtr> inline_pass{irpass.inline_};
  ASSERT_TRUE(CheckOpt(source, expected, inline_pass));
}
127 
TEST_F(TestOptLib, test_inline_deep_closure) {
  // `inline_` alone must turn "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_inline_deep_closure", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_inline_deep_closure", "after");
  const std::vector<SubstitutionPtr> inline_pass{irpass.inline_};
  ASSERT_TRUE(CheckOpt(source, expected, inline_pass));
}
134 
TEST_F(TestOptLib, test_inline_new_closure) {
  // `inline_` alone must turn "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_inline_new_closure", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_inline_new_closure", "after");
  const std::vector<SubstitutionPtr> inline_pass{irpass.inline_};
  ASSERT_TRUE(CheckOpt(source, expected, inline_pass));
}
141 
TEST_F(TestOptLib, test_inline_while) {
  // No expected graph is fetched from Python: it is produced by running `inline_` once via RunSubs,
  // so this checks that a second, independent run of the pass yields an isomorphic result.
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_inline_while", "before");
  const std::vector<SubstitutionPtr> inline_pass{irpass.inline_};
  FuncGraphPtr reference = RunSubs(source, inline_pass);
  const bool save_graphs = true;
  ASSERT_TRUE(CheckOpt(source, reference, inline_pass, save_graphs));
}
148 
TEST_F(TestOptLib, test_arithmetic) {
  auto fetch = [this](const char *tag) { return getPyFun.CallAndParseRet("test_arithmetic", tag); };
  FuncGraphPtr b1_0 = fetch("multiply_by_zero_l");
  FuncGraphPtr b2_0 = fetch("multiply_by_zero_r");
  FuncGraphPtr b1 = fetch("multiply_by_one_l");
  FuncGraphPtr b2 = fetch("multiply_by_one_r");
  FuncGraphPtr b3 = fetch("add_zero_l");
  FuncGraphPtr b4 = fetch("add_zero_r");
  FuncGraphPtr b5 = fetch("elim_identity");
  FuncGraphPtr after = fetch("after");
  FuncGraphPtr after_0 = fetch("after_0");

  const std::vector<SubstitutionPtr> simplify{irpass.arithmetic_simplify_};

  // The "multiply_by_zero_*" fixtures must simplify to "after_0".
  ASSERT_TRUE(CheckOpt(b1_0, after_0, simplify));
  ASSERT_TRUE(CheckOpt(b2_0, after_0, simplify));

  // The remaining fixtures (multiply by one, add zero, identity) must simplify to "after".
  ASSERT_TRUE(CheckOpt(b1, after, simplify));
  ASSERT_TRUE(CheckOpt(b2, after, simplify));
  ASSERT_TRUE(CheckOpt(b3, after, simplify));
  ASSERT_TRUE(CheckOpt(b4, after, simplify));
  ASSERT_TRUE(CheckOpt(b5, after, simplify));
}
170 
TEST_F(TestOptLib, test_elim_cast_same_dtype) {
  FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_cast_same_dtype", "fp32_cast_fp32");
  FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_cast_same_dtype", "after");

  // Construct a cast whose SrcT equals its DstT (both float32). The setup rewrites inputs 0..2 of the
  // output CNode, and the clone below rewrites input 2 again, so require that shape up front instead
  // of silently skipping the setup and then indexing input 2 of the clone anyway.
  auto output = before->output()->cast<CNodePtr>();
  ASSERT_TRUE(output != nullptr);
  const auto &inputs = output->inputs();
  ASSERT_GT(inputs.size(), 2U);

  auto cast_vnode = inputs[0]->cast<ValueNodePtr>();
  ASSERT_TRUE(cast_vnode != nullptr);
  auto cast_py = cast_vnode->value()->cast<PrimitivePyPtr>();
  ASSERT_TRUE(cast_py != nullptr);
  cast_py->set_attr("SrcT", TypeIdToType(kNumberTypeFloat32));
  cast_py->set_attr("DstT", TypeIdToType(kNumberTypeFloat32));

  // x is a float32 tensor of shape (2, 3).
  std::vector<int64_t> shp = {2, 3};
  tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp);
  inputs[1]->set_abstract(x_tensor->ToAbstract());

  // Variant 1: the destination type operand is a float32 tensor type.
  TypePtr tensor_type = std::make_shared<TensorType>(std::make_shared<Float>(32));
  ValueNodePtr tensor_type_node = std::make_shared<ValueNode>(tensor_type);
  tensor_type_node->set_abstract(tensor_type->ToAbstract());
  output->set_input(2, tensor_type_node);

  // Clone after variant 1 is wired in; the clone's input 2 is replaced for variant 2 below.
  FuncGraphPtr gbefore_clone = BasicClone(before);
  auto patterns = std::vector<SubstitutionPtr>({irpass.cast_eliminate_});
  ASSERT_TRUE(CheckOpt(before, after, patterns));

  // Variant 2: the destination type operand is a scalar float32 type.
  TypePtr scalar_type = std::make_shared<Float>(32);
  ValueNodePtr scalar_type_node = std::make_shared<ValueNode>(scalar_type);
  scalar_type_node->set_abstract(scalar_type->ToAbstract());
  auto clone_output = gbefore_clone->output()->cast<CNodePtr>();
  ASSERT_TRUE(clone_output != nullptr);
  ASSERT_GT(clone_output->inputs().size(), 2U);
  clone_output->set_input(2, scalar_type_node);
  ASSERT_TRUE(CheckOpt(gbefore_clone, after, patterns));
}
205 
TEST_F(TestOptLib, test_elim_reshape_same_shape) {
  FuncGraphPtr before = getPyFun.CallAndParseRet("elim_reshape_same_shape", "reshape_to_2_3");
  FuncGraphPtr after = getPyFun.CallAndParseRet("elim_reshape_same_shape", "after");

  // Both cases below rewrite the abstract of input 1 (x) of the output CNode; fail loudly if it is
  // missing rather than silently running the pass on an unprepared graph.
  auto output = before->output()->cast<CNodePtr>();
  ASSERT_TRUE(output != nullptr);
  ASSERT_GT(output->inputs().size(), 1U);
  AnfNodePtr x_node = output->inputs()[1];
  auto patterns = std::vector<SubstitutionPtr>({irpass.reshape_eliminate_});

  // Positive case: x and the reshape output both get a float32 (2, 3) abstract (presumably the
  // reshape target, per the fixture name "reshape_to_2_3"), so the reshape should be eliminated.
  std::vector<int64_t> same_shape = {2, 3};
  tensor::TensorPtr same_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), same_shape);
  auto same_abstract = same_tensor->ToAbstract();
  x_node->set_abstract(same_abstract);
  before->output()->set_abstract(same_abstract);
  ASSERT_TRUE(CheckOpt(before, after, patterns));

  // Negative case: x becomes (3, 2) while the output abstract stays (2, 3); the reshape must remain.
  std::vector<int64_t> other_shape = {3, 2};
  tensor::TensorPtr other_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), other_shape);
  x_node->set_abstract(other_tensor->ToAbstract());
  ASSERT_FALSE(CheckOpt(before, after, patterns));
}
230 
TEST_F(TestOptLib, elim_two_reshape) {
  // `reshape_eliminate_` must rewrite "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("elim_two_reshape", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("elim_two_reshape", "after");
  const std::vector<SubstitutionPtr> reshape_pass{irpass.reshape_eliminate_};
  ASSERT_TRUE(CheckOpt(source, expected, reshape_pass));
}
238 
TEST_F(TestOptLib, elim_two_cast) {
  // `cast_eliminate_` must rewrite "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("elim_two_cast", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("elim_two_cast", "after");
  const std::vector<SubstitutionPtr> cast_pass{irpass.cast_eliminate_};
  ASSERT_TRUE(CheckOpt(source, expected, cast_pass));
}
246 
TEST_F(TestOptLib, test_elim_transpose) {
  // `transpose_eliminate_` must rewrite "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_elim_transpose", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_elim_transpose", "after");
  const std::vector<SubstitutionPtr> transpose_pass{irpass.transpose_eliminate_};
  ASSERT_TRUE(CheckOpt(source, expected, transpose_pass));
}
254 
TEST_F(TestOptLib, test_elim_depend_value) {
  // `depend_value_elim_` must rewrite "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_elim_depend_value", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_elim_depend_value", "after");
  const std::vector<SubstitutionPtr> depend_pass{irpass.depend_value_elim_};
  ASSERT_TRUE(CheckOpt(source, expected, depend_pass));
}
262 
TEST_F(TestOptLib, test_elim_tile_multiply_one) {
  // `tile_eliminate_` must rewrite "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "after");
  const std::vector<SubstitutionPtr> tile_pass{irpass.tile_eliminate_};
  const bool save_graphs = true;
  ASSERT_TRUE(CheckOpt(source, expected, tile_pass, save_graphs));
}
270 
TEST_F(TestOptLib, test_elim_reduce_mean_shape_one) {
  FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_reduce_mean_shape_one", "before");
  FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_reduce_mean_shape_one", "after");

  // Construct the case where input x has shape (1) and keep_dims is true. The setup needs the output
  // CNode's primitive (input 0) and x (input 1); fail loudly if they are missing instead of silently
  // running the pass on an unprepared graph. Bind the inputs by reference rather than copying them.
  auto output = before->output()->cast<CNodePtr>();
  ASSERT_TRUE(output != nullptr);
  const auto &inputs = output->inputs();
  ASSERT_GT(inputs.size(), 2U);

  std::vector<int64_t> shp = {1};
  tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp);
  inputs[1]->set_abstract(x_tensor->ToAbstract());

  auto reduce_vnode = inputs[0]->cast<ValueNodePtr>();
  ASSERT_TRUE(reduce_vnode != nullptr);
  auto reduce = reduce_vnode->value()->cast<PrimitivePtr>();
  ASSERT_TRUE(reduce != nullptr);
  reduce->set_attr("keep_dims", std::make_shared<BoolImm>(true));

  auto patterns = std::vector<SubstitutionPtr>({irpass.reduce_eliminate_});
  ASSERT_TRUE(CheckOpt(before, after, patterns));
}
292 
TEST_F(TestOptLib, test_elim_all_shape_one) {
  FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_all_shape_one", "before");
  FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_all_shape_one", "after");

  // Construct the case where input x has shape (1) and keep_dims is true. The setup needs the output
  // CNode's primitive (input 0) and x (input 1); fail loudly if they are missing instead of silently
  // running the pass on an unprepared graph. Bind the inputs by reference rather than copying them.
  auto output = before->output()->cast<CNodePtr>();
  ASSERT_TRUE(output != nullptr);
  const auto &inputs = output->inputs();
  ASSERT_GT(inputs.size(), 2U);

  std::vector<int64_t> shp = {1};
  tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp);
  inputs[1]->set_abstract(x_tensor->ToAbstract());

  auto reduce_vnode = inputs[0]->cast<ValueNodePtr>();
  ASSERT_TRUE(reduce_vnode != nullptr);
  auto reduce = reduce_vnode->value()->cast<PrimitivePtr>();
  ASSERT_TRUE(reduce != nullptr);
  reduce->set_attr("keep_dims", std::make_shared<BoolImm>(true));

  auto patterns = std::vector<SubstitutionPtr>({irpass.reduce_eliminate_});
  ASSERT_TRUE(CheckOpt(before, after, patterns));
}
313 
TEST_F(TestOptLib, test_elim_sum_shape_one) {
  FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_sum_shape_one", "before");
  FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_sum_shape_one", "after");

  // Construct the case where input x has shape (1) and keep_dims is true. The setup needs the output
  // CNode's primitive (input 0) and x (input 1); fail loudly if they are missing instead of silently
  // running the pass on an unprepared graph. Bind the inputs by reference rather than copying them.
  auto output = before->output()->cast<CNodePtr>();
  ASSERT_TRUE(output != nullptr);
  const auto &inputs = output->inputs();
  ASSERT_GT(inputs.size(), 2U);

  std::vector<int64_t> shp = {1};
  tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp);
  inputs[1]->set_abstract(x_tensor->ToAbstract());

  auto reduce_vnode = inputs[0]->cast<ValueNodePtr>();
  ASSERT_TRUE(reduce_vnode != nullptr);
  auto reduce = reduce_vnode->value()->cast<PrimitivePtr>();
  ASSERT_TRUE(reduce != nullptr);
  reduce->set_attr("keep_dims", std::make_shared<BoolImm>(true));

  auto patterns = std::vector<SubstitutionPtr>({irpass.reduce_eliminate_});
  ASSERT_TRUE(CheckOpt(before, after, patterns));
}
334 
TEST_F(TestOptLib, test_tuple_getitem) {
  auto fetch = [this](const char *tag) { return getPyFun.CallAndParseRet("test_tuple_getitem", tag); };
  FuncGraphPtr make_get_0 = fetch("make_get_0");
  FuncGraphPtr make_get_1 = fetch("make_get_1");
  FuncGraphPtr after_0 = fetch("after_0");
  FuncGraphPtr after_1 = fetch("after_1");

  // Hand-built case: TupleGetItem on the constant tuple (1, 2) at index 1; the expected graph
  // simply outputs the constant 2.
  auto index_one = NewValueNode(static_cast<int64_t>(1));
  auto constant_two = NewValueNode(static_cast<int64_t>(2));
  std::vector<int64_t> tuple_values{1, 2};
  auto const_tuple = NewValueNode(MakeValue(tuple_values));

  FuncGraphPtr make_get_const = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> getitem_inputs{NewValueNode(prim::kPrimTupleGetItem), const_tuple, index_one};
  make_get_const->set_output(make_get_const->NewCNode(getitem_inputs));

  FuncGraphPtr after_2 = std::make_shared<FuncGraph>();
  after_2->set_output(constant_two);

  const std::vector<SubstitutionPtr> tuple_passes{
    irpass.tuple_list_get_item_eliminator_,     irpass.tuple_list_get_item_const_eliminator_,
    irpass.tuple_list_set_item_eliminator_,     irpass.tuple_list_get_set_item_eliminator_,
    irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_};
  ASSERT_TRUE(CheckOpt(make_get_0, after_0, tuple_passes));
  ASSERT_TRUE(CheckOpt(make_get_1, after_1, tuple_passes));
  ASSERT_TRUE(CheckOpt(make_get_const, after_2, tuple_passes));
}
361 
TEST_F(TestOptLib, test_tuple_setitem) {
  auto fetch = [this](const char *tag) { return getPyFun.CallAndParseRet("test_tuple_setitem", tag); };
  FuncGraphPtr before_0 = fetch("before_0");
  FuncGraphPtr before_1 = fetch("before_1");
  FuncGraphPtr after_0 = fetch("after_0");
  FuncGraphPtr after_1 = fetch("after_1");

  // Full set of tuple/list item substitutions.
  const std::vector<SubstitutionPtr> tuple_passes{
    irpass.tuple_list_get_item_eliminator_,     irpass.tuple_list_get_item_const_eliminator_,
    irpass.tuple_list_set_item_eliminator_,     irpass.tuple_list_get_set_item_eliminator_,
    irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_};

  ASSERT_TRUE(CheckOpt(before_0, after_0, tuple_passes));
  ASSERT_TRUE(CheckOpt(before_1, after_1, tuple_passes));
}
376 
TEST_F(TestOptLib, test_tuple_get_set_item) {
  FuncGraphPtr before_0 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0");
  FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0");
  // Fetch the second fixture pair by its own tags; re-fetching "before_0"/"after_0" here would just
  // repeat the first check and leave case 1 untested.
  FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_1");
  FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_1");

  auto patterns = std::vector<SubstitutionPtr>(
    {irpass.tuple_list_get_item_eliminator_, irpass.tuple_list_get_item_const_eliminator_,
     irpass.tuple_list_set_item_eliminator_, irpass.tuple_list_get_set_item_eliminator_,
     irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_});

  ASSERT_TRUE(CheckOpt(before_0, after_0, patterns));
  ASSERT_TRUE(CheckOpt(before_1, after_1, patterns));
}
391 
TEST_F(TestOptLib, test_partial) {
  // `partial_eliminate_` must rewrite "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_partial", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_partial", "after");
  const std::vector<SubstitutionPtr> partial_pass{irpass.partial_eliminate_};
  ASSERT_TRUE(CheckOpt(source, expected, partial_pass));
}
400 
TEST_F(TestOptLib, test_replace_applicator) {
  auto fetch = [this](const char *tag) { return getPyFun.CallAndParseRet("test_replace_applicator", tag); };
  FuncGraphPtr before1 = fetch("before1");
  FuncGraphPtr before2 = fetch("before2");
  FuncGraphPtr before3 = fetch("before3");
  FuncGraphPtr after = fetch("after");

  const std::vector<SubstitutionPtr> applicator_pass{irpass.replace_applicator_};

  // "before1" and "before2" are rewritten to "after"; "before3" must come out unchanged.
  ASSERT_TRUE(CheckOpt(before1, after, applicator_pass));
  ASSERT_TRUE(CheckOpt(before2, after, applicator_pass));
  ASSERT_TRUE(CheckOpt(before3, before3, applicator_pass));
}
413 
TEST_F(TestOptLib, test_specialize_on_graph_arguments) {
  // `specialize_transform_` must rewrite "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_specialize_on_graph_arguments", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_specialize_on_graph_arguments", "after");
  const std::vector<SubstitutionPtr> specialize_pass{irpass.specialize_transform_};
  ASSERT_TRUE(CheckOpt(source, expected, specialize_pass));
}
422 
TEST_F(TestOptLib, test_incorporate_getitem) {
  auto fetch = [this](const char *tag) { return getPyFun.CallAndParseRet("test_incorporate_getitem", tag); };
  FuncGraphPtr before1 = fetch("before1");
  FuncGraphPtr before2 = fetch("before2");
  FuncGraphPtr after1 = fetch("after1");
  FuncGraphPtr after2 = fetch("after2");

  // Each "beforeN" must be rewritten into its matching "afterN".
  const std::vector<SubstitutionPtr> getitem_passes{irpass.incorporate_getitem_set_};
  ASSERT_TRUE(CheckOpt(before1, after1, getitem_passes));
  ASSERT_TRUE(CheckOpt(before2, after2, getitem_passes));
}
434 
TEST_F(TestOptLib, test_incorporate_getitem_through_switch) {
  // `incorporate_getitem_set_` must rewrite "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_incorporate_getitem_through_switch", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_incorporate_getitem_through_switch", "after");
  const std::vector<SubstitutionPtr> getitem_passes{irpass.incorporate_getitem_set_};
  ASSERT_TRUE(CheckOpt(source, expected, getitem_passes));
}
442 
TEST_F(TestOptLib, test_incorporate_call) {
  // `incorporate_call_` must rewrite "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_incorporate_call", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_incorporate_call", "after");
  const std::vector<SubstitutionPtr> call_pass{irpass.incorporate_call_};
  ASSERT_TRUE(CheckOpt(source, expected, call_pass));
}
450 
TEST_F(TestOptLib, test_incorporate_call_through_switch) {
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_incorporate_call_through_switch", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_incorporate_call_through_switch", "after");

  // The switch/call incorporation passes run together with arithmetic simplification.
  const std::vector<SubstitutionPtr> passes{irpass.incorporate_call_switch_, irpass.incorporate_call_,
                                            irpass.arithmetic_simplify_};
  ASSERT_TRUE(CheckOpt(source, expected, passes));
}
461 
TEST_F(TestOptLib, test_float_tuple_getitem_through_switch) {
  // `float_tuple_getitem_switch_` must rewrite "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_float_tuple_getitem_through_switch", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_float_tuple_getitem_through_switch", "after");
  const std::vector<SubstitutionPtr> float_pass{irpass.float_tuple_getitem_switch_};
  ASSERT_TRUE(CheckOpt(source, expected, float_pass));
}
469 
TEST_F(TestOptLib, test_merge_addn) {
  // `merge_addn_` must rewrite "before" into a graph isomorphic to "after".
  FuncGraphPtr source = getPyFun.CallAndParseRet("test_merge_addn", "before");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_merge_addn", "after");
  const std::vector<SubstitutionPtr> merge_pass{irpass.merge_addn_};
  ASSERT_TRUE(CheckOpt(source, expected, merge_pass));
}
477 
TEST_F(TestOptLib, test_filter_addn_zero) {
  auto fetch = [this](const char *tag) { return getPyFun.CallAndParseRet("test_addn_zero", tag); };
  FuncGraphPtr before1 = fetch("before_1");
  FuncGraphPtr after = fetch("after");
  FuncGraphPtr before2 = fetch("before_2");
  FuncGraphPtr before3 = fetch("before_3");
  FuncGraphPtr before4 = fetch("before_4");

  const std::vector<SubstitutionPtr> filter_pass{irpass.addn_zero_filter_};

  // "before_1".."before_3" are rewritten to "after"; "before_4" must be left as is.
  ASSERT_TRUE(CheckOpt(before1, after, filter_pass));
  ASSERT_TRUE(CheckOpt(before2, after, filter_pass));
  ASSERT_TRUE(CheckOpt(before3, after, filter_pass));
  ASSERT_TRUE(CheckOpt(before4, before4, filter_pass));
}
490 
TEST_F(TestOptLib, test_minmax_grad) {
  auto fetch = [this](const char *tag) { return getPyFun.CallAndParseRet("test_minmax_grad", tag); };
  FuncGraphPtr before11 = fetch("before_11");
  FuncGraphPtr before12 = fetch("before_12");
  FuncGraphPtr before2 = fetch("before_2");
  FuncGraphPtr before31 = fetch("before_31");
  FuncGraphPtr before32 = fetch("before_32");
  FuncGraphPtr before4 = fetch("before_4");

  const std::vector<SubstitutionPtr> grad_pass{irpass.minmaximum_grad_};

  // Every fixture is compared against itself: `minmaximum_grad_` must leave each one isomorphic
  // to its original form. Checked in order, stopping at the first failure.
  for (const FuncGraphPtr &graph : {before11, before12, before2, before31, before32, before4}) {
    ASSERT_TRUE(CheckOpt(graph, graph, grad_pass));
  }
}
506 
TEST_F(TestOptLib, test_reducesum_one) {
  auto fetch = [this](const char *tag) { return getPyFun.CallAndParseRet("test_reducesum_one", tag); };
  FuncGraphPtr before1 = fetch("before_1");
  FuncGraphPtr before2 = fetch("before_2");
  FuncGraphPtr before3 = fetch("before_3");
  FuncGraphPtr before4 = fetch("before_4");
  FuncGraphPtr after1 = fetch("after_1");
  FuncGraphPtr after2 = fetch("after_2");
  FuncGraphPtr after3 = fetch("after_3");
  const std::vector<SubstitutionPtr> reduce_pass{irpass.reduce_eliminate_};

  // Abstract of a float32 tensor with the given shape.
  auto float32_abstract = [](const std::vector<int64_t> &shape) {
    tensor::TensorPtr t = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shape);
    return t->ToAbstract();
  };
  auto abstract_3221 = float32_abstract({3, 2, 2, 1});
  auto abstract_3211 = float32_abstract({3, 2, 1, 1});

  // Sets the abstract of input 1 of `graph`'s output CNode; does nothing if that input is absent.
  auto set_x_abstract = [](const FuncGraphPtr &graph, const AbstractBasePtr &abs) {
    const auto &node_inputs = graph->output()->cast<CNodePtr>()->inputs();
    if (node_inputs.size() > 1) {
      node_inputs[1]->set_abstract(abs);
    }
  };

  set_x_abstract(before1, abstract_3221);
  ASSERT_TRUE(CheckOpt(before1, after1, reduce_pass));

  set_x_abstract(before2, abstract_3211);
  ASSERT_TRUE(CheckOpt(before2, after1, reduce_pass));

  // Same graph with a (3, 2, 2, 1) input: the pass must leave it unchanged.
  set_x_abstract(before2, abstract_3221);
  ASSERT_TRUE(CheckOpt(before2, before2, reduce_pass));

  set_x_abstract(before3, abstract_3221);
  ASSERT_TRUE(CheckOpt(before3, after2, reduce_pass));

  set_x_abstract(before4, abstract_3211);
  ASSERT_TRUE(CheckOpt(before4, after3, reduce_pass));
}
560 
561 #ifndef ENABLE_SECURITY
TEST_F(TestOptLib, test_print_tuple_wrapper) {
  auto fetch = [this](const char *tag) { return getPyFun.CallAndParseRet("test_print_tuple_wrapper", tag); };
  FuncGraphPtr before1 = fetch("before1");
  FuncGraphPtr before2 = fetch("before2");
  FuncGraphPtr before3 = fetch("before3");
  FuncGraphPtr after1 = fetch("after1");
  FuncGraphPtr after2 = fetch("after2");

  const std::vector<SubstitutionPtr> print_pass{irpass.print_tuple_wrapper_};

  // "before1"/"before2" map to "after1"/"after2"; "before3" must be left unchanged.
  ASSERT_TRUE(CheckOpt(before1, after1, print_pass));
  ASSERT_TRUE(CheckOpt(before2, after2, print_pass));
  ASSERT_TRUE(CheckOpt(before3, before3, print_pass));
}
573 #endif
574 
TEST_F(TestOptLib, test_constant_duplicate_mul) {
  auto fetch = [this](const char *tag) { return getPyFun.CallAndParseRet("test_constant_duplicate_mul", tag); };
  FuncGraphPtr beforell = fetch("beforell");
  FuncGraphPtr beforelr = fetch("beforelr");
  FuncGraphPtr beforerl = fetch("beforerl");
  FuncGraphPtr beforerr = fetch("beforerr");
  FuncGraphPtr after = fetch("after");

  // All four operand arrangements must simplify to the same "after" graph.
  const std::vector<SubstitutionPtr> simplify{irpass.arithmetic_simplify_};
  ASSERT_TRUE(CheckOpt(beforell, after, simplify));
  ASSERT_TRUE(CheckOpt(beforelr, after, simplify));
  ASSERT_TRUE(CheckOpt(beforerl, after, simplify));
  ASSERT_TRUE(CheckOpt(beforerr, after, simplify));
}
587 
TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
  FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell");
  FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr");
  FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl");
  FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr");
  FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1");
  // NOTE(review): the fixture's "before2l"/"before2r"/"after2" graphs are not exercised here; they used
  // to be parsed into unused locals. TODO confirm whether adjust_all_reduce_mul_add_ should cover them.

  // All four variants must be rewritten into "after1".
  auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_});
  ASSERT_TRUE(CheckOpt(beforell, after1, patterns, true));
  ASSERT_TRUE(CheckOpt(beforelr, after1, patterns));
  ASSERT_TRUE(CheckOpt(beforerl, after1, patterns));
  ASSERT_TRUE(CheckOpt(beforerr, after1, patterns));
}
603 
TEST_F(TestOptLib, test_row_tensor) {
  auto fetch = [this](const char *tag) { return getPyFun.CallAndParseRet("test_row_tensor", tag); };
  FuncGraphPtr before_get_indices = fetch("before_get_indices");
  FuncGraphPtr after_get_indices = fetch("after_get_indices");
  FuncGraphPtr before_get_values = fetch("before_get_values");
  FuncGraphPtr after_get_values = fetch("after_get_values");
  FuncGraphPtr before_get_dense_shape = fetch("before_get_dense_shape");
  FuncGraphPtr after_get_dense_shape = fetch("after_get_dense_shape");

  // Each accessor fixture must be rewritten into its "after_*" counterpart.
  const std::vector<SubstitutionPtr> row_tensor_pass{irpass.row_tensor_eliminate_};
  ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, row_tensor_pass));
  ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, row_tensor_pass));
  ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, row_tensor_pass));
}
616 
TEST_F(TestOptLib, test_sparse_tensor) {
  auto fetch = [this](const char *tag) { return getPyFun.CallAndParseRet("test_sparse_tensor", tag); };
  FuncGraphPtr before_get_indices = fetch("before_get_indices");
  FuncGraphPtr after_get_indices = fetch("after_get_indices");
  FuncGraphPtr before_get_values = fetch("before_get_values");
  FuncGraphPtr after_get_values = fetch("after_get_values");
  FuncGraphPtr before_get_dense_shape = fetch("before_get_dense_shape");
  FuncGraphPtr after_get_dense_shape = fetch("after_get_dense_shape");

  // Each accessor fixture must be rewritten into its "after_*" counterpart.
  const std::vector<SubstitutionPtr> sparse_tensor_pass{irpass.sparse_tensor_eliminate_};
  ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, sparse_tensor_pass));
  ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, sparse_tensor_pass));
  ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, sparse_tensor_pass));
}
629 }  // namespace opt
630 }  // namespace mindspore
631