• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/common/expander/core/infer.h"
18 
19 #include "mindspore/core/ops/other_ops.h"
20 #include "mindspore/core/ops/nn_optimizer_ops.h"
21 #include "mindspore/core/ops/nn_ops.h"
22 #include "mindspore/core/ops/math_ops.h"
23 #include "mindspore/core/ops/image_ops.h"
24 #include "mindspore/core/ops/array_ops.h"
25 #include "abstract/ops/primitive_infer_map.h"
26 #include "ir/anf.h"
27 #include "ir/primitive.h"
28 #include "ops/base_operator.h"
29 
30 namespace mindspore {
31 namespace expander {
InferAnfnode(const AnfNodePtr & anfnode) const32 void CppInfer::InferAnfnode(const AnfNodePtr &anfnode) const {
33   if (anfnode->isa<ValueNode>()) {
34     anfnode->set_abstract(anfnode->cast<ValueNodePtr>()->value()->ToAbstract());
35     return;
36   }
37   auto cnode = anfnode->cast<CNodePtr>();
38   MS_EXCEPTION_IF_NULL(cnode);
39   auto prim = GetCNodePrimitive(cnode);
40   MS_EXCEPTION_IF_NULL(prim);
41   AbstractBasePtrList abs_list;
42   abs_list.reserve(cnode->size());
43   (void)std::transform(cnode->weak_inputs().cbegin() + 1, cnode->weak_inputs().cend(), std::back_inserter(abs_list),
44                        [](const AnfNodeWeakPtr &weak_node) {
45                          AnfNodePtr node = weak_node.lock();
46                          MS_EXCEPTION_IF_NULL(node);
47                          const auto &abs = node->abstract();
48                          if (abs == nullptr) {
49                            MS_EXCEPTION_IF_CHECK_FAIL(node->isa<ValueNode>(), node->ToString() + " has no abstract");
50                            return node->cast<ValueNodePtr>()->value()->ToAbstract();
51                          }
52                          return abs;
53                        });
54 
55   auto abstract_optional = abstract::InferAbstractByFuncImpl(prim, abs_list);
56   if (abstract_optional.has_value()) {
57     cnode->set_abstract(abstract_optional.value());
58     return;
59   }
60 
61   auto &infer_impl = CppInfer::infer_impl_cache()[prim];
62   if (infer_impl.Get() == nullptr) {
63     auto found = abstract::GetPrimitiveInferImpl(prim);
64     if (found.has_value() && found.value().IsImplInferShapeAndType()) {
65       infer_impl = found.value();
66     } else {
67       MS_LOG(EXCEPTION) << "The infer function of [" << prim->name() << "] is not defined.";
68     }
69   }
70   cnode->set_abstract(infer_impl.InferShapeAndType(nullptr, prim, abs_list));
71 }
72 
GetShape(const NodePtr & node)73 BaseShapePtr CppInfer::GetShape(const NodePtr &node) {
74   auto abs = GetAbstract(node);
75   MS_EXCEPTION_IF_NULL(abs);
76   return abs->BuildShape();
77 }
78 
GetDtype(const NodePtr & node)79 TypePtr CppInfer::GetDtype(const NodePtr &node) {
80   auto abs = GetAbstract(node);
81   MS_EXCEPTION_IF_NULL(abs);
82   return abs->BuildType();
83 }
84 }  // namespace expander
85 }  // namespace mindspore
86