/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/step_parallel_utils.h"

#include <inttypes.h>
#include <sys/time.h>
#include <algorithm>

#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>

#include "base/core_ops.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/node_check.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"

namespace mindspore {
namespace parallel {
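// Return true when `cnode` is a call to the primitive named `name`, i.e. its first input is a
// ValueNode holding that Primitive; a null cnode never matches.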
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
  if (!cnode) {
    return false;
  }
  ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(anf_node);
  PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
  MS_EXCEPTION_IF_NULL(prim);
  return (prim->name() == name);
}

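// Decide whether the parallel pass should handle `cnode`: blacklisted primitives are skipped,
// GetNext and virtual output nodes are always handled, Cast nodes without OperatorInfo are
// skipped, and all other nodes follow their forward flag.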
bool IsParallelCareNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
  if (prim_node == nullptr) {
    return false;
  }
  PrimitivePtr prim = prim_node->value()->cast<PrimitivePtr>();
  if (prim == nullptr) {
    return false;
  }
  if (IsInParallelBlackList(prim)) {
    MS_LOG(DEBUG) << "Parallel does not care about node: " << prim->name();
    return false;
  }
  // GetNext is not in the forward graph, so we need to mark it as a forward node
  if (prim->name() == GET_NEXT || prim->name() == VIRTUAL_OUTPUT) {
    return true;
  }
  if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
    return false;
  }

  return cnode->in_forward_flag();
}

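// Collect the shapes of the tensors stored in a ValueList/ValueTuple value node; collection stops
// at the first element that is not a tensor.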
Shapes GetValueListShape(const AnfNodePtr &node) {
  Shapes shapes;
  std::vector<ValuePtr> inputs_seq;
  if (IsValueNode<ValueList>(node)) {
    inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
  } else if (IsValueNode<ValueTuple>(node)) {
    inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
  } else {
    MS_LOG(EXCEPTION) << "The node is neither a ValueList nor a ValueTuple";
  }
  for (auto &ele : inputs_seq) {
    auto tensor = ele->cast<tensor::TensorPtr>();
    if (tensor == nullptr) {
      MS_LOG(WARNING) << "The value node is not a tensor";
      break;
    }
    auto one_shape = tensor->shape();
    shapes.push_back(one_shape);
  }
  return shapes;
}

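// Query the (possibly tuple) shape of a node. Value lists/tuples are delegated to
// GetValueListShape, MakeRef nodes to GetRefKeyNodeShape, and calls whose first input is a CNode
// fall back to the shape of their first argument.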
Shapes GetNodeShape(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  Shapes shapes;
  if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
    return GetValueListShape(node);
  }
  BaseShapePtr base_shape_ptr = node->Shape();
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    if (IsValueNode<Primitive>(cnode->input(0))) {
      PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
      MS_EXCEPTION_IF_NULL(prim);
      if (prim->name() == MAKEREF) {
        AnfNodePtr ref_node = cnode->input(1);
        auto func_graph = cnode->func_graph();
        MS_EXCEPTION_IF_NULL(ref_node);
        MS_EXCEPTION_IF_NULL(func_graph);
        return GetRefKeyNodeShape(ref_node, func_graph);
      }
    }
    if (cnode->input(0)->isa<CNode>()) {
      if (cnode->inputs().size() < 2) {
        MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
      }
      base_shape_ptr = cnode->input(1)->Shape();
    }
  }
  if (base_shape_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
                      << node->fullname_with_scope();
  }
  auto tuple_shape_ptr = dyn_cast<abstract::SequeueShape>(base_shape_ptr);
  if (tuple_shape_ptr != nullptr) {
    auto tuple_shape = tuple_shape_ptr->shape();
    for (auto &shape : tuple_shape) {
      auto each_shape = dyn_cast<abstract::Shape>(shape);
      MS_EXCEPTION_IF_NULL(each_shape);
      shapes.push_back(each_shape->shape());
    }
  } else {
    auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
    MS_EXCEPTION_IF_NULL(shape_ptr);
    shapes.push_back(shape_ptr->shape());
  }
  return shapes;
}

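// Build a unique instance name for the `index`-th operator inserted for `node` by hashing the
// node's full scope name suffixed with the index.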
std::string CreateInstanceName(const CNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsValueNode<Primitive>(node->input(0))) {
    MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have a primitive";
  }
  std::string name_base = node->fullname_with_scope();
  std::string name = name_base + "_" + std::to_string(index);
  std::string instance_name = HashInstanceName(name);
  return instance_name;
}

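// If the primitive held by `new_node_input[0]` carries a GROUP attribute, look up the rank list
// registered for that group hash in the device manager and record it under GROUP_RANKS.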
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
  if (new_node_input.empty()) {
    return;
  }

  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);

  auto attrs = prim->attrs();
  auto iter = attrs.find(GROUP);
  if (iter != attrs.end()) {
    auto value = iter->second;
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<StringImm>()) {
      std::string hash_name = value->cast<StringImmPtr>()->value();
      MS_EXCEPTION_IF_NULL(g_device_manager);
      std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name);
      (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name));
    }
  }
}

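// Build the input list for the operator that replaces `node`: a new value node for the
// instantiated primitive followed by the inputs required by `replace_op`. Extra positional
// parameters are spliced in at their recorded positions, and SyncBatchNorm keeps all of the
// original node's inputs.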
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
                                       const CNodePtr &node) {
  OperatorArgs arg_replace_op = replace_op.second;
  ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name);
  if (pyop_instance == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed";
  }
  OperatorParams params = arg_replace_op.second;
  if (node->inputs().size() < 2) {
    // The GetNext operator does not have inputs
    if (node->inputs().size() == 1) {
      return {NewValueNode(pyop_instance)};
    }
    MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
  }
  std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};

  if (replace_op.first == EMBEDDING_LOOKUP) {
    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
  }

  if (!params.empty()) {
    Param param_first = *(params.begin());
    int64_t first_position = param_first.second;
    if (first_position == 1) {
      replace_input.pop_back();
    }
    for (auto &param : params) {
      AnfNodePtr val = NewValueNode(param.first.second);
      if (val == nullptr) {
        MS_LOG(EXCEPTION) << "Failure: val is nullptr";
      }
      int64_t position = param.second;
      (void)replace_input.insert(replace_input.begin() + position, val);
    }
  } else if (replace_op.first == SYNC_BATCH_NORM) {
    for (size_t i = 2; i < node->inputs().size(); ++i) {
      replace_input.push_back(node->input(i));
    }
  }
  SetCommunicationOpGroupLabel(replace_input);
  return replace_input;
}

}  // namespace parallel
}  // namespace mindspore