1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/optimizer/irpass/symbol_resolver.h"
18
19 #include <string>
20 #include <memory>
21 #include <vector>
22
23 namespace mindspore {
24 namespace opt {
25 namespace irpass {
26 // {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
27 // {prim::kPrimGetAttr, namespace, attr}
28 // {prim::kPrimGetAttr, bool, attr}
29 // {prim::kPrimResolve, namespace, symbol}
operator ()(const OptimizerPtr & optimizer,const AnfNodePtr & node)30 AnfNodePtr ResolverGetAttrResolve::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
31 constexpr char PARSE_SUPER_NAME[] = "namespace";
32 constexpr size_t namespace_index = 1;
33 constexpr size_t symbol_index = 2;
34
35 PatternNode<AnfNodePtr> resolve_node, ns_node, sym_node, attr_node, bool_node;
36 auto GetAttrResolveLambda = [&node, &resolve_node, &attr_node, &optimizer]() -> AnfNodePtr {
37 auto inner = resolve_node.GetNode(node);
38 auto attr = attr_node.GetNode(node);
39 if (IsPrimitiveCNode(inner, prim::kPrimResolve)) {
40 auto resolve_cnode = inner->cast<CNodePtr>();
41 auto namespace_node = resolve_cnode->input(namespace_index);
42 auto symbol_node = resolve_cnode->input(symbol_index);
43 if (!IsValueNode<parse::NameSpace>(namespace_node) || !IsValueNode<parse::Symbol>(symbol_node)) {
44 return nullptr;
45 }
46 // deal with the case of getting attr from a class member
47 // and avoid the case of getting attr from self (the result of ParseSuper)
48 auto ns = GetValueNode<parse::NameSpacePtr>(namespace_node);
49 auto sym = GetValueNode<parse::SymbolPtr>(symbol_node);
50 if (ns->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym->symbol() != PARSE_SUPER_NAME) {
51 return parse::ResolveCellwithAttr(optimizer->manager(), ns, sym, inner, attr);
52 }
53 }
54 return nullptr;
55 };
56
57 auto GetAttrLambda = [&node, &ns_node, &attr_node, &optimizer]() -> AnfNodePtr {
58 auto ns = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
59 auto str = GetValue<std::string>(GetValueNode(attr_node.GetNode(node)));
60 parse::SymbolPtr sym = std::make_shared<parse::Symbol>(str);
61 return parse::ResolveSymbol(optimizer->manager(), ns, sym, node);
62 };
63
64 auto ResolveLambda = [&node, &ns_node, &sym_node, &optimizer]() -> AnfNodePtr {
65 auto ns = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
66 auto sym = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
67 auto manager = optimizer->manager();
68 return parse::ResolveSymbol(manager, ns, sym, node);
69 };
70
71 // {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
72 MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, resolve_node, attr_node), GetAttrResolveLambda,
73 attr_node.CheckFunc(IsValueNode<StringImm>, node));
74 // {prim::kPrimGetAttr, namespace, attr}
75 MATCH_REPLACE_LAMBDA_IF(
76 node, PPrimitive(prim::kPrimGetAttr, ns_node, attr_node), GetAttrLambda,
77 ns_node.CheckFunc(IsValueNode<parse::NameSpace>, node) && attr_node.CheckFunc(IsValueNode<StringImm>, node));
78 // {prim::kPrimGetAttr, bool, attr}
79 MATCH_REPLACE_IF(
80 node, PPrimitive(prim::kPrimGetAttr, bool_node, attr_node), bool_node,
81 bool_node.CheckFunc(IsValueNode<BoolImm>, node) && attr_node.CheckFunc(IsValueNode<StringImm>, node));
82 // {prim::kPrimResolve, namespace, symbol}
83 MATCH_REPLACE_LAMBDA_IF(
84 node, PPrimitive(prim::kPrimResolve, ns_node, sym_node), ResolveLambda,
85 ns_node.CheckFunc(IsValueNode<parse::NameSpace>, node) && sym_node.CheckFunc(IsValueNode<parse::Symbol>, node));
86 return nullptr;
87 }
88 } // namespace irpass
89 } // namespace opt
90 } // namespace mindspore
91