• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/irpass/symbol_resolver.h"
18 
19 #include <string>
20 #include <memory>
21 #include <vector>
22 
23 namespace mindspore {
24 namespace opt {
25 namespace irpass {
26 // {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
27 // {prim::kPrimGetAttr, namespace, attr}
28 // {prim::kPrimGetAttr, bool, attr}
29 // {prim::kPrimResolve, namespace, symbol}
operator ()(const OptimizerPtr & optimizer,const AnfNodePtr & node)30 AnfNodePtr ResolverGetAttrResolve::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
31   constexpr char PARSE_SUPER_NAME[] = "namespace";
32   constexpr size_t namespace_index = 1;
33   constexpr size_t symbol_index = 2;
34 
35   PatternNode<AnfNodePtr> resolve_node, ns_node, sym_node, attr_node, bool_node;
36   auto GetAttrResolveLambda = [&node, &resolve_node, &attr_node, &optimizer]() -> AnfNodePtr {
37     auto inner = resolve_node.GetNode(node);
38     auto attr = attr_node.GetNode(node);
39     if (IsPrimitiveCNode(inner, prim::kPrimResolve)) {
40       auto resolve_cnode = inner->cast<CNodePtr>();
41       auto namespace_node = resolve_cnode->input(namespace_index);
42       auto symbol_node = resolve_cnode->input(symbol_index);
43       if (!IsValueNode<parse::NameSpace>(namespace_node) || !IsValueNode<parse::Symbol>(symbol_node)) {
44         return nullptr;
45       }
46       // deal with the case of getting attr from a class member
47       // and avoid the case of getting attr from self (the result of ParseSuper)
48       auto ns = GetValueNode<parse::NameSpacePtr>(namespace_node);
49       auto sym = GetValueNode<parse::SymbolPtr>(symbol_node);
50       if (ns->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym->symbol() != PARSE_SUPER_NAME) {
51         return parse::ResolveCellwithAttr(optimizer->manager(), ns, sym, inner, attr);
52       }
53     }
54     return nullptr;
55   };
56 
57   auto GetAttrLambda = [&node, &ns_node, &attr_node, &optimizer]() -> AnfNodePtr {
58     auto ns = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
59     auto str = GetValue<std::string>(GetValueNode(attr_node.GetNode(node)));
60     parse::SymbolPtr sym = std::make_shared<parse::Symbol>(str);
61     return parse::ResolveSymbol(optimizer->manager(), ns, sym, node);
62   };
63 
64   auto ResolveLambda = [&node, &ns_node, &sym_node, &optimizer]() -> AnfNodePtr {
65     auto ns = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
66     auto sym = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
67     auto manager = optimizer->manager();
68     return parse::ResolveSymbol(manager, ns, sym, node);
69   };
70 
71   // {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
72   MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, resolve_node, attr_node), GetAttrResolveLambda,
73                           attr_node.CheckFunc(IsValueNode<StringImm>, node));
74   // {prim::kPrimGetAttr, namespace, attr}
75   MATCH_REPLACE_LAMBDA_IF(
76     node, PPrimitive(prim::kPrimGetAttr, ns_node, attr_node), GetAttrLambda,
77     ns_node.CheckFunc(IsValueNode<parse::NameSpace>, node) && attr_node.CheckFunc(IsValueNode<StringImm>, node));
78   // {prim::kPrimGetAttr, bool, attr}
79   MATCH_REPLACE_IF(
80     node, PPrimitive(prim::kPrimGetAttr, bool_node, attr_node), bool_node,
81     bool_node.CheckFunc(IsValueNode<BoolImm>, node) && attr_node.CheckFunc(IsValueNode<StringImm>, node));
82   // {prim::kPrimResolve, namespace, symbol}
83   MATCH_REPLACE_LAMBDA_IF(
84     node, PPrimitive(prim::kPrimResolve, ns_node, sym_node), ResolveLambda,
85     ns_node.CheckFunc(IsValueNode<parse::NameSpace>, node) && sym_node.CheckFunc(IsValueNode<parse::Symbol>, node));
86   return nullptr;
87 }
88 }  // namespace irpass
89 }  // namespace opt
90 }  // namespace mindspore
91