• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_TESTS_UT_CPP_PRE_ACTIVATE_COMMON_PATTERN_TO_PATTERN_PASS_UTILS_H_
18 #define MINDSPORE_TESTS_UT_CPP_PRE_ACTIVATE_COMMON_PATTERN_TO_PATTERN_PASS_UTILS_H_
19 
20 #include <vector>
21 #include <memory>
22 #include "common/common_test.h"
23 #include "ir/anf.h"
24 #include "ir/value.h"
25 #include "include/common/utils/utils.h"
26 #include "include/backend/anf_runtime_algorithm.h"
27 
28 #define private public
29 #define protected public
30 #include "include/backend/optimizer/pattern_to_pattern.h"
31 #undef private
32 #undef protected
33 
34 namespace mindspore {
35 namespace opt {
36 class CheckPattern {
37  public:
CheckPattern()38   CheckPattern()
39       : m_(std::make_shared<PatternMap>()),
40         src_pattern_(SrcPattern(m_)),
41         pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
42         primitive_vars_(std::make_shared<PrimitiveVarMap>()),
43         equiv_(std::make_shared<Equiv>()){};
build_pattern_map(const AnfNodePtr & node)44   bool build_pattern_map(const AnfNodePtr &node) {
45     VarPtr root_g = std::make_shared<Var>("RootG");
46     auto src_pattern_root = SexpToNode(src_pattern_.GetRoot(), root_g, primitive_vars_.get(), multigraph_);
47     auto primitive = GetCNodePrimitive(src_pattern_root);
48     if (IsPrimitiveCNode(node, primitive)) {
49       MS_EXCEPTION_IF_NULL(primitive_vars_);
50       MS_EXCEPTION_IF_NULL(equiv_);
51       equiv_->clear();
52       EquivPtr equiv = pattern_engine_.Match(src_pattern_root, node, *primitive_vars_, equiv_);
53       if (equiv != nullptr && !equiv->empty()) {
54         return src_pattern_.build_pattern_map(node, equiv);
55       }
56     }
57     return false;
58   }
59   PatternMapPtr m_;
60   SrcPattern src_pattern_;
61   PatternEngine pattern_engine_;
62   PrimitiveVarMapPtr primitive_vars_;
63   EquivPtr equiv_;
64   bool multigraph_ = true;
65 };
66 }  // namespace opt
67 }  // namespace mindspore
68 
69 #endif  // MINDSPORE_TESTS_UT_CPP_PRE_ACTIVATE_COMMON_PATTERN_TO_PATTERN_PASS_UTILS_H_
70