• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/gpu/replace_addn_fusion.h"
17 #include "backend/session/anf_runtime_algorithm.h"
18 #include "ir/primitive.h"
19 #include "utils/utils.h"
20 #include "backend/optimizer/common/helper.h"
21 
22 namespace mindspore {
23 namespace opt {
DefinePattern() const24 const BaseRef ReplaceAddNFusion::DefinePattern() const {
25   VectorRef addn = VectorRef({prim::kPrimAddN, A, B});
26   return addn;
27 }
28 
/**
 * Rewrites a matched two-input AddN node into an equivalent Add node.
 *
 * The new Add node takes the same two operands as the AddN. It is given the
 * inferred dtype and shape of the first operand. The graph manager is then used
 * to replace `node` with it.
 *
 * @param graph Function graph that owns `node`; used to create the new node and
 *              to obtain the manager that performs the replacement.
 * @param node  The node matched by DefinePattern(); must be a CNode.
 * @return The newly created Add CNode when the node's "n" attribute equals
 *         kAddNInputNum; nullptr otherwise (no rewrite is performed here).
 */
const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  // Cast once and check it: a failed cast would otherwise flow on as a null CNode.
  auto cnode = utils::cast<CNodePtr>(node);
  MS_EXCEPTION_IF_NULL(cnode);
  // Named input_a/input_b so they do not shadow the pattern variables A and B
  // referenced by DefinePattern().
  auto input_a = AnfAlgo::GetInputNode(cnode, 0);
  auto input_b = AnfAlgo::GetInputNode(cnode, 1);
  MS_EXCEPTION_IF_NULL(input_a);
  MS_EXCEPTION_IF_NULL(input_b);
  // Only rewrite when the node's "n" attribute reports exactly kAddNInputNum inputs.
  int64_t num_input = AnfAlgo::GetNodeAttr<int64_t>(node, "n");
  if (num_input != kAddNInputNum) {
    return nullptr;
  }
  auto prim = std::make_shared<Primitive>(prim::kPrimAdd->name());
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), input_a, input_b};
  auto add_new = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(add_new);
  // The single Add output reuses the inferred dtype and shape of the first operand.
  std::vector<TypeId> outputs_type;
  std::vector<std::vector<size_t>> outputs_shape;
  outputs_type.push_back(AnfAlgo::GetOutputInferDataType(input_a, 0));
  outputs_shape.push_back(AnfAlgo::GetOutputInferShape(input_a, 0));
  AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, add_new.get());
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // add_new is already a CNodePtr; no further cast is needed.
  manager->Replace(cnode, add_new);
  return add_new;
}
56 }  // namespace opt
57 }  // namespace mindspore
58