/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include "backend/optimizer/gpu/replace_addn_fusion.h"
17 #include "backend/session/anf_runtime_algorithm.h"
18 #include "ir/primitive.h"
19 #include "utils/utils.h"
20 #include "backend/optimizer/common/helper.h"
21
22 namespace mindspore {
23 namespace opt {
DefinePattern() const24 const BaseRef ReplaceAddNFusion::DefinePattern() const {
25 VectorRef addn = VectorRef({prim::kPrimAddN, A, B});
26 return addn;
27 }
28
// Rewrites a two-operand AddN node as an equivalent binary Add node.
//
// DefinePattern() only matches AddN with two operands. In addition, the node's
// "n" attribute must be present and equal kAddNInputNum; otherwise the node is
// left untouched.
//
// graph: graph that owns `node`; the new Add node is created in it.
// node:  the matched AddN node.
// Returns the new Add node, or nullptr when no rewrite is performed.
// Raises (via MS_EXCEPTION_IF_NULL) if graph, node, its inputs, the new node or
// the graph manager is null.
const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  // Named input_a / input_b so they do not shadow the pattern members A and B.
  auto input_a = AnfAlgo::GetInputNode(cnode, 0);
  auto input_b = AnfAlgo::GetInputNode(cnode, 1);
  MS_EXCEPTION_IF_NULL(input_a);
  MS_EXCEPTION_IF_NULL(input_b);

  // GetNodeAttr presumably raises on a missing attribute; skip the rewrite
  // instead of aborting the whole pass.
  if (!AnfAlgo::HasNodeAttr("n", cnode)) {
    return nullptr;
  }
  const int64_t num_input = AnfAlgo::GetNodeAttr<int64_t>(node, "n");
  // Compare in the signed domain so a negative "n" cannot wrap around if
  // kAddNInputNum happens to be an unsigned constant.
  if (num_input != static_cast<int64_t>(kAddNInputNum)) {
    return nullptr;
  }

  auto prim = std::make_shared<Primitive>(prim::kPrimAdd->name());
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), input_a, input_b};
  auto add_new = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(add_new);

  // The Add stands in for `node` at every use site, so give it exactly the
  // inferred dtype/shape of the AddN output it replaces.
  std::vector<TypeId> outputs_type;
  std::vector<std::vector<size_t>> outputs_shape;
  outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, 0));
  outputs_shape.push_back(AnfAlgo::GetOutputInferShape(node, 0));
  AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, add_new.get());

  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // NOTE(review): the pass driver may already replace `node` when a non-null
  // node is returned; kept to preserve the original behavior.
  (void)manager->Replace(node, add_new);
  return add_new;
}
56 } // namespace opt
57 } // namespace mindspore
58