1 /**
2 * Copyright 2022-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "include/common/expander/core/emitter.h"

#include <algorithm>
#include <functional>
#include <numeric>
#include <unordered_set>
#include <utility>
#include "include/common/utils/utils.h"
#include "ir/anf.h"
#include "ops/sequence_ops.h"
#include "ops/math_ops.h"
#include "ops/array_ops.h"
#include "ops/framework_ops.h"
#include "include/common/utils/convert_utils.h"
#include "ir/functor.h"
#include "ops/primitive_c.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/ms_context.h"
#include "ops/op_def.h"
#include "ir/primitive.h"
37
38 namespace mindspore {
39 namespace expander {
40 namespace {
GetIntList(const NodePtr & node)41 std::pair<bool, std::vector<int64_t>> GetIntList(const NodePtr &node) {
42 MS_EXCEPTION_IF_NULL(node);
43 ValuePtr value_ptr = node->BuildValue();
44 if (value_ptr != nullptr) {
45 if ((value_ptr->isa<ValueSequence>() && !value_ptr->ContainsValueAny()) || value_ptr->isa<Scalar>()) {
46 return std::make_pair(true, CheckAndConvertUtils::CheckIntOrTupleInt("value", value_ptr, "GetIntList"));
47 }
48 if (value_ptr->isa<tensor::Tensor>()) {
49 auto tensor = value_ptr->cast<tensor::TensorPtr>();
50 MS_EXCEPTION_IF_NULL(tensor);
51 // In pynative mode, need data sync before get tensor value, otherwise the tensor value may be undefined.
52 tensor->data_sync();
53 return std::make_pair(true, CheckAndConvertUtils::CheckTensorIntValue("value", value_ptr, "GetIntList"));
54 }
55 }
56 return std::make_pair(false, std::vector<int64_t>{});
57 }
58
CreateZeroScalar(const TypePtr & type)59 ValuePtr CreateZeroScalar(const TypePtr &type) {
60 auto tensor = std::make_shared<tensor::Tensor>(0, type);
61 return CreateValueFromTensor(tensor);
62 }
63
/// \brief Resolve the single unknown dimension (-1) of a Reshape target shape against a known source shape.
/// \param[in] x_shape Source shape; must not be dynamic.
/// \param[in] dst_shape Requested target shape; may contain exactly one unknown dimension.
/// \return `dst_shape` itself when it is not dynamic, otherwise a copy with the unknown dimension filled in.
/// \note Raises an exception when `dst_shape` is dynamic rank, `x_shape` is dynamic, `dst_shape` does not
///       contain exactly one unknown dimension, or the product of the known target dimensions is zero.
ShapeVector CalReshapeRealDstShape(const ShapeVector &x_shape, const ShapeVector &dst_shape) {
  if (!IsDynamicShape(dst_shape)) {
    return dst_shape;
  }

  if (IsDynamicRank(dst_shape) || IsDynamic(x_shape)) {
    MS_LOG(EXCEPTION) << "The source shape(" << x_shape << ") or target shape(" << dst_shape
                      << ") is invalid for Reshape const infer!";
  }

  if (std::count(dst_shape.begin(), dst_shape.end(), abstract::Shape::kShapeDimAny) != 1) {
    MS_LOG(EXCEPTION) << "The target shape can only have one -1 for Reshape, but got " << dst_shape;
  }

  // The init value fixes the accumulator type: a plain `1` would accumulate in `int` and truncate
  // element counts that do not fit in 32 bits.
  const int64_t total_size =
    std::accumulate(x_shape.cbegin(), x_shape.cend(), int64_t{1}, std::multiplies<int64_t>());

  // Multiply all known target dimensions and remember where the unknown one sits.
  size_t target_idx = 0;
  int64_t dst_size = 1;
  for (size_t i = 0; i < dst_shape.size(); ++i) {
    if (dst_shape[i] == abstract::Shape::kShapeDimAny) {
      target_idx = i;
      continue;
    }
    dst_size *= dst_shape[i];
  }
  MS_EXCEPTION_IF_CHECK_FAIL(dst_size != 0, "Cannot divide zeros!");

  ShapeVector res_shape(dst_shape);
  res_shape[target_idx] = total_size / dst_size;
  return res_shape;
}
93 } // namespace
94
Emit(const std::string & op_name,const NodePtrList & inputs,const DAttr & attrs)95 NodePtr Emitter::Emit(const std::string &op_name, const NodePtrList &inputs, const DAttr &attrs) {
96 auto prim = NewPrimitive(op_name, attrs);
97 return EmitOp(prim, inputs);
98 }
99
EmitOp(const PrimitivePtr & prim,const NodePtrList & inputs)100 NodePtr Emitter::EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) {
101 MS_EXCEPTION(NotImplementedError) << "Base Emitter not implemented EmitOp() method";
102 }
103
NewPrimitive(const std::string & op_name,const DAttr & attrs)104 PrimitivePtr Emitter::NewPrimitive(const std::string &op_name, const DAttr &attrs) {
105 PrimitivePtr prim = nullptr;
106 if (mindspore::ops::IsPrimitiveFunction(op_name)) {
107 prim = std::make_shared<Primitive>(op_name);
108 if (!attrs.empty()) {
109 prim->SetAttrs(attrs);
110 }
111 } else {
112 auto &func = Emitter::primc_func_cache()[op_name];
113 if (func == nullptr) {
114 const auto &op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
115 const auto iter = op_primc_fns.find(op_name);
116 prim = iter == op_primc_fns.end() ? std::make_shared<ops::PrimitiveC>(op_name) : (func = iter->second)();
117 } else {
118 prim = func();
119 }
120 }
121 MS_EXCEPTION_IF_NULL(prim);
122 if (!attrs.empty()) {
123 (void)prim->SetAttrs(attrs);
124 }
125 return prim;
126 }
127
EmitValue(const ValuePtr & value)128 NodePtr Emitter::EmitValue(const ValuePtr &value) {
129 MS_EXCEPTION(NotImplementedError) << "Base Emitter not implemented EmitValue() method";
130 }
131
/// \brief Emit an Exp operator on `x` with the fixed "base", "scale" and "shift" attributes.
NodePtr Emitter::Exp(const NodePtr &x) {
  const DAttr exp_attrs = {
    {"base", MakeValue<float>(-1.0)}, {"scale", MakeValue<float>(1.0)}, {"shift", MakeValue<float>(0.0)}};
  return Emit(kExpOpName, {x}, exp_attrs);
}
136
/// \brief Emit a Log operator on `x` with the fixed "base", "scale", "shift" and "cust_aicpu" attributes.
NodePtr Emitter::Log(const NodePtr &x) {
  const DAttr log_attrs = {{"base", MakeValue<pyfloat>(-1.0)},
                           {"scale", MakeValue<pyfloat>(1.0)},
                           {"shift", MakeValue<pyfloat>(0.0)},
                           {"cust_aicpu", MakeValue(kLogOpName)}};
  return Emit(kLogOpName, {x}, log_attrs);
}
144
/// \brief Cast `node` to `type`.
/// \return `node` itself when its dtype already has the requested type id, otherwise a new Cast node
///         whose second input is the target type id as an int64 value.
NodePtr Emitter::Cast(const NodePtr &node, const TypePtr &type) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(type);
  const auto src_type_id = node->dtype()->type_id();
  const auto dst_type_id = type->type_id();
  // Only emit a node when the type actually changes.
  if (src_type_id != dst_type_id) {
    return Emit("Cast", {node, Value(static_cast<int64_t>(dst_type_id))});
  }
  return node;
}
154
/// \brief Reshape `node` to `shape`, folding the operation away where possible.
/// \param[in] node Input node; must not be null.
/// \param[in] shape Target shape node (constant or not).
/// \return - a Reshape node fed with TensorToTuple(shape) when the shape value is not constant;
///         - for a constant tensor input with data: `node` itself, or a new constant tensor with the
///           resolved shape;
///         - `node` itself when the constant target shape matches the node shape dimension by dimension
///           (a -1 entry matches anything);
///         - a Reshape node otherwise.
NodePtr Emitter::Reshape(const NodePtr &node, const NodePtr &shape) {
  MS_EXCEPTION_IF_NULL(node);
  auto [is_const_shape, target_shape] = GetIntList(shape);
  if (!is_const_shape) {
    return Emit(kReshapeOpName, {node, TensorToTuple(shape)});
  }

  if (node->input_type() == InputType::kConstant) {
    // If node and shape is both known, return node itself or a new tensor with target shape.
    auto node_value = node->BuildValue();
    MS_EXCEPTION_IF_NULL(node_value);
    auto const_tensor = node_value->cast<tensor::TensorPtr>();
    if (const_tensor != nullptr && const_tensor->data().const_data() != nullptr) {
      const auto &src_shape = const_tensor->shape_c();
      auto resolved_shape = CalReshapeRealDstShape(src_shape, target_shape);
      if (src_shape == resolved_shape) {
        return node;
      }
      auto type_id = const_tensor->data_type();
      return this->Tensor(type_id, resolved_shape, const_tensor->data_c(), type_id);
    }
  }

  // The reshape is a no-op only when the rank is known and equal, and every target dimension
  // is either -1 or equal to the current one.
  auto current_shape = node->shape();
  bool is_noop = !IsDynamicRank(current_shape) && target_shape.size() == current_shape.size();
  for (size_t dim = 0; is_noop && dim < target_shape.size(); ++dim) {
    is_noop = (target_shape[dim] == current_shape[dim]) || (target_shape[dim] == -1);
  }
  if (is_noop) {
    return node;
  }
  return Emit(kReshapeOpName, {node, shape});
}
193
/// \brief Emit a MatMul of `a` and `b`; the transpose flags are passed as value inputs.
NodePtr Emitter::MatMul(const NodePtr &a, const NodePtr &b, bool transpose_a, bool transpose_b) {
  const NodePtrList matmul_inputs = {a, b, Value(transpose_a), Value(transpose_b)};
  return Emit(prim::kPrimMatMul->name(), matmul_inputs);
}
197
/// \brief Emit a BatchMatMul of `a` and `b`; the transpose flags are passed as value inputs.
NodePtr Emitter::BatchMatMul(const NodePtr &a, const NodePtr &b, bool transpose_a, bool transpose_b) {
  const NodePtrList bmm_inputs = {a, b, Value(transpose_a), Value(transpose_b)};
  return Emit(prim::kPrimBatchMatMul->name(), bmm_inputs);
}
201
/// \brief Emit a MatMulExt of `a` and `b` through UnifyDtypeAndEmit, with no extra attributes.
NodePtr Emitter::MatMulExt(const NodePtr &a, const NodePtr &b) {
  const auto &op_name = prim::kPrimMatMulExt->name();
  return UnifyDtypeAndEmit(op_name, a, b, {});
}
205
/// \brief Transpose `node` by `perm`, skipping the operator when the permutation is the identity.
/// \return A Transpose node fed with TensorToTuple(perm) when perm is not constant; `node` itself when
///         the constant perm maps every axis to itself; a Transpose node otherwise.
NodePtr Emitter::Transpose(const NodePtr &node, const NodePtr &perm) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(perm);
  auto [is_const_perm, perm_values] = GetIntList(perm);
  if (!is_const_perm) {
    return Emit(kTransposeOpName, {node, TensorToTuple(perm)});
  }
  // perm like [0, 1, 2, 3] does not need transpose.
  const auto rank = SizeToLong(perm_values.size());
  int64_t axis = 0;
  for (const auto perm_value : perm_values) {
    // perm value may be negative, e.g. [0, -3, 2, 3] is equal to [0, 1, 2, 3]
    const auto normalized = perm_value < 0 ? (perm_value + rank) : perm_value;
    if (normalized != axis) {
      return Emit(kTransposeOpName, {node, perm});
    }
    ++axis;
  }
  return node;
}
225
/// \brief Tile `node` by `dims`, skipping the operator when it would not change the node.
/// \return A Tile node fed with TensorToTuple(dims) when dims is not constant; `node` itself when every
///         multiple is 1 and the node rank is at least the number of multiples; a Tile node otherwise.
NodePtr Emitter::Tile(const NodePtr &node, const NodePtr &dims) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(dims);
  auto [is_const_dims, multiples] = GetIntList(dims);
  if (!is_const_dims) {
    return Emit(kTileOpName, {node, TensorToTuple(dims)});
  }
  const bool has_real_multiple =
    std::any_of(multiples.begin(), multiples.end(), [](int64_t multiple) { return multiple != 1; });
  if (has_real_multiple || node->shape().size() < multiples.size()) {
    return Emit(kTileOpName, {node, dims});
  }
  return node;
}
240
/// \brief Broadcast `x` to the shape of `y`.
/// \param[in] x Node to broadcast; must not be null.
/// \param[in] y Node providing the target shape; must not be null.
/// \return `x` itself when both shapes are static and equal, otherwise a BroadcastTo node whose
///         target is Shape(y).
NodePtr Emitter::BroadcastTo(const NodePtr &x, const NodePtr &y) {
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(y);
  const auto &x_shape = x->shape();
  const auto &y_shape = y->shape();
  // Equality can only be trusted for fully static shapes.
  const bool is_static = !IsDynamic(x_shape) && !IsDynamic(y_shape);
  if (is_static && x_shape == y_shape) {
    return x;
  }
  return Emit("BroadcastTo", {x, Shape(y)});
}
248
/// \brief Produce a "zero" counterpart of `node`.
/// Constant inputs are dispatched on their built value first; anything not handled there, and every
/// non-constant input, is dispatched on the node's abstract.
/// \param[in] node Input node; must not be null.
/// \return A zero tensor/scalar/sequence node, or `node` itself for monad, type, none and empty
///         fixed-length sequence abstracts.
/// \note Raises an exception when the abstract kind is not supported.
NodePtr Emitter::ZerosLike(const NodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->input_type() == InputType::kConstant) {
    if (node->dtype()->type_id() == kMetaTypeNone) {
      return Tensor(0);
    }
    auto const_value = node->BuildValue();
    MS_EXCEPTION_IF_NULL(const_value);
    if (const_value->isa<ValueSequence>()) {
      return Emit(kSequenceZerosLikeOpName, {node});
    }
    if (const_value->isa<Scalar>()) {
      return EmitValue(CreateZeroScalar(const_value->type()));
    }
    if (const_value->isa<Type>()) {
      return Tensor(0, const_value->type());
    }
    if (const_value->isa<Monad>()) {
      return Tensor(0);
    }
    // Other constant kinds fall through to the abstract-based handling below.
  }

  auto node_abs = node->abstract();
  MS_EXCEPTION_IF_NULL(node_abs);

  if (node_abs->isa<abstract::AbstractTensor>()) {
    return Emit(kZerosLikeOpName, {node});
  }
  const bool is_passthrough = node_abs->isa<abstract::AbstractMonad>() ||
                              node_abs->isa<abstract::AbstractType>() ||
                              node_abs->isa<abstract::AbstractNone>();
  if (is_passthrough) {
    return node;
  }
  if (node_abs->isa<abstract::AbstractSequence>()) {
    auto seq_abs = node_abs->cast<abstract::AbstractSequencePtr>();
    // An empty fixed-length sequence has nothing to zero out.
    const bool is_empty_static = !seq_abs->dynamic_len() && seq_abs->empty();
    if (is_empty_static) {
      return node;
    }
    return Emit(kSequenceZerosLikeOpName, {node});
  }
  if (node_abs->isa<abstract::AbstractScalar>()) {
    return EmitValue(CreateZeroScalar(node_abs->BuildType()));
  }

  MS_LOG(EXCEPTION) << "Cannot emit ZerosLike for " << node->ToString() << " with abstract " << node_abs;
}
289
Fill(double value,const ShapeVector & shape,TypeId data_type)290 NodePtr Emitter::Fill(double value, const ShapeVector &shape, TypeId data_type) {
291 size_t data_num = LongToSize(std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
292 std::vector<double> data(data_num, value);
293 return Tensor(data_type, shape, &data[0], TypeId::kNumberTypeFloat64);
294 }
295
Fill(int64_t value,const ShapeVector & shape,TypeId data_type)296 NodePtr Emitter::Fill(int64_t value, const ShapeVector &shape, TypeId data_type) {
297 size_t data_num = LongToSize(std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
298 std::vector<int64_t> data(data_num, value);
299 return Tensor(data_type, shape, &data[0], TypeId::kNumberTypeInt64);
300 }
301
/// \brief Convert a constant scalar node into a constant tensor node.
/// \param[in] node Node whose built value must be a Scalar; an exception is raised otherwise.
/// \return The value node holding the converted tensor.
NodePtr Emitter::ScalarToTensor(const NodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto built_value = node->BuildValue();
  MS_EXCEPTION_IF_NULL(built_value);
  const auto scalar_value = built_value->cast<ScalarPtr>();
  MS_EXCEPTION_IF_NULL(scalar_value);
  return EmitValue(mindspore::ScalarToTensor(scalar_value));
}
311
/// \brief Convert a scalar node into a tensor node.
/// \param[in] node Scalar node; must not be null and must have a built value.
/// \param[in] dtype Target type; must not be null when the scalar value is not a constant.
/// \return A constant tensor node when the built value is a Scalar, otherwise a ScalarToTensor
///         operator node whose second input is the type id of `dtype`.
NodePtr Emitter::ScalarToTensor(const NodePtr &node, const TypePtr &dtype) {
  MS_EXCEPTION_IF_NULL(node);
  auto value = node->BuildValue();
  MS_EXCEPTION_IF_NULL(value);
  auto scalar = value->cast<ScalarPtr>();
  if (scalar == nullptr) {
    // Value unknown at build time: defer the conversion to a runtime operator.
    MS_EXCEPTION_IF_NULL(dtype);
    return Emit("ScalarToTensor", {node, Value(static_cast<int64_t>(dtype->type_id()))});
  }
  // NOTE(review): `dtype` is not used on this constant path — confirm that is intended.
  return EmitValue(mindspore::ScalarToTensor(scalar));
}
324
/// \brief Logical negation of a scalar node.
/// \param[in] node Node with a scalar abstract; must not be null.
/// \return A constant bool value node when the scalar value is known, otherwise a BoolNot operator node.
/// \note Raises an exception when the node's abstract is not a scalar.
NodePtr Emitter::BoolNot(const NodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto abs = node->abstract();
  MS_EXCEPTION_IF_NULL(abs);
  if (!abs->isa<abstract::AbstractScalar>()) {
    MS_LOG(EXCEPTION) << "BoolNot only supports scalar, but got " << abs->ToString();
  }
  auto value_ptr = node->BuildValue();
  MS_EXCEPTION_IF_NULL(value_ptr);
  // Unknown (ValueAny) or None values cannot be folded at build time.
  if (value_ptr->isa<ValueAny>() || value_ptr->isa<None>()) {
    return Emit("BoolNot", {node});
  }
  return Value(!GetValue<bool>(value_ptr));
}
344
NeedReduce(const ShapeVector & shape,const std::vector<int64_t> & axis,bool keep_dim,bool skip_mode) const345 std::pair<bool, ShapeVector> Emitter::NeedReduce(const ShapeVector &shape, const std::vector<int64_t> &axis,
346 bool keep_dim, bool skip_mode) const {
347 if (IsDynamic(shape)) {
348 return std::make_pair(true, shape);
349 }
350 if (shape.empty() || (skip_mode && axis.empty())) {
351 return std::make_pair(false, shape);
352 }
353 auto rank = SizeToLong(shape.size());
354 std::vector<bool> axis_map;
355 if (axis.empty()) {
356 axis_map = std::vector<bool>(shape.size(), true);
357 } else {
358 axis_map = std::vector<bool>(shape.size(), false);
359 for (size_t i = 0; i < axis.size(); ++i) {
360 if (axis[i] < -rank || axis[i] >= rank) {
361 MS_EXCEPTION(ValueError) << "Reduce axis[" << i << "] is " << axis[i] << ", which is out of range [-" << rank
362 << ", " << rank << ") for shape: " << shape;
363 }
364 auto axis_i = axis[i] < 0 ? axis[i] + rank : axis[i];
365 axis_map[LongToSize(axis_i)] = true;
366 }
367 }
368 // Calc reduce output shape
369 ShapeVector out_shape;
370 bool need_reduce = false;
371 for (size_t i = 0; i < shape.size(); ++i) {
372 if (!axis_map[i]) {
373 // not reduce axis
374 out_shape.push_back(shape[i]);
375 } else {
376 // reduce axis
377 if (shape[i] != 1) {
378 need_reduce = true;
379 }
380 if (keep_dim) {
381 out_shape.push_back(1);
382 }
383 }
384 }
385 return std::make_pair(need_reduce, out_shape);
386 }
387
NeedReduce(const NodePtr & shape,const NodePtr & axis,bool keep_dim,bool skip_mode)388 std::pair<bool, NodePtr> Emitter::NeedReduce(const NodePtr &shape, const NodePtr &axis, bool keep_dim, bool skip_mode) {
389 auto [axis_success, axis_value] = GetIntList(axis);
390 if (axis_success && skip_mode && axis_value.empty()) {
391 return std::make_pair(false, shape);
392 }
393 auto [shape_success, shape_value] = GetIntList(shape);
394 if (shape_success && axis_success) {
395 auto [need_reduce, shape_vec] = NeedReduce(shape_value, axis_value, keep_dim, skip_mode);
396 return std::make_pair(need_reduce, Value(shape_vec));
397 }
398
399 auto v = Value(ShapeVector{});
400 return std::make_pair(true, v);
401 }
402
/// \brief Sum `x` over `axis`.
/// \return A Reshape of `x` to the computed output shape when NeedReduce reports that no real reduction
///         is needed, otherwise a ReduceSum node with the axis converted by TensorToTuple.
NodePtr Emitter::ReduceSum(const NodePtr &x, const NodePtr &axis, bool keep_dims, bool skip_mode) {
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(axis);
  auto [need_reduce, out_shape] = NeedReduce(Shape(x), axis, keep_dims, skip_mode);
  if (need_reduce) {
    return Emit(prim::kPrimReduceSum->name(), {x, TensorToTuple(axis), Value(keep_dims), Value(skip_mode)});
  }
  return Reshape(x, out_shape);
}
413
/// \brief Sum `x` over a constant list of axes (skip_mode is always false).
/// When built with WITH_BACKEND, an empty axis list is expanded to all axes of `x` if its rank is known.
NodePtr Emitter::ReduceSum(const NodePtr &x, const ShapeVector &axis, bool keep_dims) {
  MS_EXCEPTION_IF_NULL(x);
  ShapeVector reduce_axes(axis);
#ifdef WITH_BACKEND
  const auto &x_shape = x->shape();
  if (reduce_axes.empty()) {
    if (!IsDynamicRank(x_shape)) {
      const auto rank = SizeToLong(x_shape.size());
      for (int64_t dim = 0; dim < rank; dim++) {
        reduce_axes.push_back(dim);
      }
    } else {
      MS_LOG(DEBUG) << "For ReduceSum, it may wrong with a empty axis for dynamic rank case.";
    }
  }
#endif
  return ReduceSum(x, Value<ShapeVector>(reduce_axes), keep_dims, false);
}
431
/// \brief Emit a SumExt node; the type id of the input's dtype is passed as the fourth input.
NodePtr Emitter::SumExt(const NodePtr &input, const NodePtr &axis, const NodePtr &keep_dims) {
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(axis);
  MS_EXCEPTION_IF_NULL(keep_dims);
  const auto dtype_node = Value(static_cast<int64_t>(input->dtype()->type_id()));
  MS_EXCEPTION_IF_NULL(dtype_node);

  return Emit("SumExt", {input, axis, keep_dims, dtype_node});
}
442
/// \brief Emit a Gather node.
/// \param[in] params Node to gather from; must not be null.
/// \param[in] indices Indices node; must not be null.
/// \param[in] axis Axis node; must not be null.
/// \param[in] batch_dims Passed to the operator as a constant value input.
NodePtr Emitter::Gather(const NodePtr &params, const NodePtr &indices, const NodePtr &axis, int64_t batch_dims) {
  MS_EXCEPTION_IF_NULL(params);
  MS_EXCEPTION_IF_NULL(indices);
  MS_EXCEPTION_IF_NULL(axis);
  return Emit(kGatherOpName, {params, indices, axis, Value(batch_dims)});
}
/// \brief Convenience overload of Gather: wraps the constant `axis` in a value node and forwards.
NodePtr Emitter::Gather(const NodePtr &params, const NodePtr &indices, int64_t axis, int64_t batch_dims) {
  return Gather(params, indices, Value(axis), batch_dims);
}
452
GetConstInputs(const NodePtrList & inputs,const std::vector<bool> & only_depend_shape,const ShapeValidFunc & valid_func)453 std::tuple<bool, ShapeArray, std::vector<std::vector<size_t>>> GetConstInputs(
454 const NodePtrList &inputs, const std::vector<bool> &only_depend_shape, const ShapeValidFunc &valid_func) {
455 bool all_const = true;
456 ShapeArray const_args;
457 std::vector<std::vector<size_t>> pos_idx;
458
459 for (size_t i = 0; i < inputs.size(); ++i) {
460 MS_EXCEPTION_IF_NULL(inputs[i]);
461
462 if (!only_depend_shape[i]) {
463 // input[i]'s value is used
464 auto [success, vec] = GetIntList(inputs[i]);
465 if (!success) {
466 all_const = false;
467 break;
468 }
469 pos_idx.push_back({const_args.size()});
470 const_args.push_back(vec);
471 } else {
472 // input[i]'s shape is used
473 auto abs = inputs[i]->abstract();
474 MS_EXCEPTION_IF_NULL(abs);
475
476 if (auto sequence_abs = abs->cast<abstract::AbstractSequencePtr>(); sequence_abs != nullptr) {
477 auto begin_idx = const_args.size();
478 auto is_const = ops::TryGetShapeArg(sequence_abs, &const_args, &pos_idx);
479
480 if (is_const) {
481 for (size_t j = begin_idx; j < const_args.size(); ++j) {
482 is_const = valid_func ? valid_func(j, const_args[j]) : !IsDynamic(const_args[j]);
483 if (!is_const) {
484 break;
485 }
486 }
487 }
488
489 if (!is_const) {
490 all_const = false;
491 break;
492 }
493 } else {
494 auto input_shape = inputs[i]->shape();
495 auto input_valid = valid_func ? valid_func(i, input_shape) : !IsDynamic(input_shape);
496 if (!input_valid) {
497 all_const = false;
498 break;
499 }
500 pos_idx.push_back({const_args.size()});
501 const_args.push_back(input_shape);
502 }
503 }
504 }
505
506 return std::make_tuple(all_const, const_args, pos_idx);
507 }
508
/// \brief Run a shape-calculation functor at build time when possible, otherwise emit a ShapeCalc node.
/// \param[in] functor The calculation functor; must not be null.
/// \param[in] inputs Input nodes.
/// \param[in] value_depend Indices into `inputs` whose values (not just shapes) are needed.
/// \param[in] valid_func Optional predicate forwarded to GetConstInputs.
/// \return One value node per functor output when all inputs are constant. Otherwise the ShapeCalc
///         node's outputs: one TupleGetItem per element when the output abstract is a fixed-length,
///         non-empty tuple whose first element is itself a tuple, else the node itself.
/// \note Raises an exception when an index in `value_depend` is outside `inputs`.
NodePtrList Emitter::ShapeCalc(const ShapeCalcBaseFunctorPtr &functor, const NodePtrList &inputs,
                               const std::vector<int64_t> &value_depend, const ShapeValidFunc &valid_func) {
  MS_EXCEPTION_IF_NULL(functor);
  std::vector<bool> only_depend_shape(inputs.size(), true);
  for (auto idx : value_depend) {
    const auto pos = LongToSize(idx);
    // Guard the write below: an out-of-range index would otherwise be undefined behaviour.
    if (pos >= only_depend_shape.size()) {
      MS_LOG(EXCEPTION) << "For ShapeCalc, the value_depend index " << idx << " is out of range [0, "
                        << only_depend_shape.size() << ").";
    }
    only_depend_shape[pos] = false;
  }

  // Try to get all const input shapes or values, and call the shape calc function when success.
  auto [all_const, const_args, pos_idx] = GetConstInputs(inputs, only_depend_shape, valid_func);
  NodePtrList res;
  if (all_const) {
    auto out = functor->Calc(const_args, pos_idx);
    res.reserve(out.size());
    (void)std::transform(out.begin(), out.end(), std::back_inserter(res),
                         [this](const ShapeVector &sh) { return Value(sh); });
    return res;
  }

  auto out = Emit(kShapeCalcOpName, inputs,
                  {{kAttrFunctor, functor},
                   {kAttrOnlyDependShape, MakeValue(only_depend_shape)},
                   {kAttrInputIsDynamicShape, MakeValue(true)}});
  MS_EXCEPTION_IF_NULL(out);
  auto abs = out->abstract();
  MS_EXCEPTION_IF_NULL(abs);
  auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_abs);
  const bool is_tuple_of_tuples =
    !tuple_abs->dynamic_len() && tuple_abs->size() != 0 && tuple_abs->elements()[0]->isa<abstract::AbstractTuple>();
  if (is_tuple_of_tuples) {
    // Multiple outputs: expose each one separately.
    res.reserve(tuple_abs->size());
    for (size_t i = 0; i < tuple_abs->size(); ++i) {
      res.push_back(TupleGetItem(out, i));
    }
  } else {
    res.push_back(out);
  }
  return res;
}
550
/// \brief Convert a tensor node into a tuple node; tuple nodes are returned unchanged.
/// \return For a tensor abstract: a constant tuple value node when the tensor value is known, else a
///         TensorToTuple operator node. For a tuple abstract: `node` itself.
/// \note Raises an internal exception when the abstract is neither a tensor nor a tuple.
NodePtr Emitter::TensorToTuple(const NodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto node_abs = node->abstract();
  MS_EXCEPTION_IF_NULL(node_abs);
  if (!node_abs->isa<abstract::AbstractTensor>()) {
    if (!node_abs->isa<abstract::AbstractTuple>()) {
      MS_LOG(INTERNAL_EXCEPTION) << "A Tensor or tuple is expected, but got " << node_abs->BuildType()->ToString()
                                 << ".";
    }
    return node;
  }
  auto [is_const, int_values] = GetIntList(node);
  if (!is_const) {
    return Emit(kTensorToTupleOpName, {node});
  }
  return EmitValue(MakeValue(int_values));
}
567
UnifyDtype2(const NodePtr & lhs,const NodePtr & rhs)568 std::tuple<NodePtr, NodePtr> Emitter::UnifyDtype2(const NodePtr &lhs, const NodePtr &rhs) {
569 auto it1 = type_vector_[lhs->dtype()->type_id()];
570 auto it2 = type_vector_[rhs->dtype()->type_id()];
571 if (!it1 || !it2 || it1 == it2) {
572 return {lhs, rhs};
573 }
574 if (it1 < it2) {
575 return {this->Cast(lhs, rhs->dtype()), rhs};
576 }
577 return {lhs, this->Cast(rhs, lhs->dtype())};
578 }
579
/// \brief Emit SparseSoftmaxCrossEntropyWithLogits and scale its result by `dout`.
/// In graph mode the result is first tied to `out` through Depend.
NodePtr Emitter::SparseSoftmaxCrossEntropyWithLogits(const NodePtrList &inputs, const DAttr &attrs, const NodePtr &out,
                                                     const NodePtr &dout, bool is_graph_mode) {
  auto grad_node = Emit("SparseSoftmaxCrossEntropyWithLogits", inputs, attrs);
  if (is_graph_mode) {
    grad_node = Depend(grad_node, out);
  }
  return Mul(grad_node, dout);
}
589
/// \brief Base-class placeholder: always raises NotImplementedError.
NodePtr Emitter::Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) {
  MS_EXCEPTION(NotImplementedError) << "Base Emitter not implement Conditional() method";
}
593
/// \brief Base-class placeholder: always raises NotImplementedError.
NodePtr Emitter::While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list) {
  MS_EXCEPTION(NotImplementedError) << "Base Emitter not implement While() method";
}
597
/// \brief Build an if/else construct: both branches become sub-graphs selected by a Switch node,
///        and the selected graph is called from the outer function graph.
/// \return The call node, whose abstract is the recorded branch output abstract.
NodePtr CtrlFlowBlock::IfThenElse(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) {
  // The true branch is built first; BuildSubgraph records the output abstract on first use.
  const auto true_branch = BuildSubgraph(true_case);
  const auto false_branch = BuildSubgraph(false_case);
  const auto switch_node = emitter_->Emit("Switch", {cond, true_branch, false_branch});

  auto call_node = func_graph_->FuncGraph::NewCNode({switch_node->get()});
  call_node->set_abstract(out_abstract_);
  return emitter_->NewIrNode(call_node->cast<AnfNodePtr>());
}
608
/// \brief Build a while loop as a self-calling function graph.
/// A new graph `while_fg` holds the condition and a Switch between two Partial nodes: one for the
/// loop body (whose output is replaced by a call back into `while_fg`) and one for a body that just
/// returns `init_list`. Nodes captured from outside become parameters of `while_fg`.
/// \param[in] cond Condition node; must not be null and must wrap a CNode.
/// \param[in] while_body_func Builds the loop body; must return as many outputs as `init_list` has.
/// \param[in] init_list Initial loop variables.
/// \return The node in the outer graph that calls `while_fg`.
/// \note Raises an exception when the body's output count differs from `init_list.size()`.
NodePtr CtrlFlowBlock::While(const NodePtr &cond, const BlockFunc &while_body_func, const NodePtrList &init_list) {
  MS_EXCEPTION_IF_NULL(cond);
  auto while_fg = std::make_shared<FuncGraph>();
  MS_EXCEPTION_IF_NULL(while_fg);
  auto cond_cnode = cond->get()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cond_cnode);

  // The condition is evaluated inside the loop graph.
  cond_cnode->set_func_graph(while_fg);
  auto while_fg_emitter = CreateInnerEmitter(while_fg, std::make_shared<CppInferWithPartial>());
  MS_EXCEPTION_IF_NULL(while_fg_emitter);
  AnfNodePtrList main_while_fg_inputs = {NewValueNode(while_fg)};
  std::map<AnfNodePtr, ParameterPtr> param_map;
  // Maps an outer node to a parameter of while_fg, creating the parameter (and recording the outer
  // node as a call argument) the first time the node is seen.
  auto replace_by_param = [&main_while_fg_inputs, &param_map, &while_fg](const AnfNodePtr &inp) {
    auto &param = param_map[inp];
    if (param == nullptr) {
      param = while_fg->add_parameter();
      param->set_abstract(inp->abstract());
      (void)main_while_fg_inputs.emplace_back(inp);
    }
    return param;
  };

  // Branch that leaves the loop: a body that returns the loop variables unchanged.
  auto empty_body_func = [&init_list](Emitter *) { return init_list; };
  auto empty_body_fg_with_inputs = BuildSubgraphOfPartial(empty_body_func);
  for (size_t i = 1; i < empty_body_fg_with_inputs.size(); i++) {
    auto inp = empty_body_fg_with_inputs[i]->get();
    empty_body_fg_with_inputs[i] = while_fg_emitter->NewIrNode(replace_by_param(inp));
  }
  for (size_t i = 1; i < cond_cnode->size(); i++) {
    auto inp = cond_cnode->input(i);
    MS_EXCEPTION_IF_NULL(inp);
    if (!inp->isa<ValueNode>()) {
      cond_cnode->set_input(i, replace_by_param(inp));
    }
  }

  auto body_with_inputs = BuildSubgraphOfPartial(while_body_func);
  auto body_fg_node = body_with_inputs[0]->get()->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(body_fg_node);
  auto body_fg = body_fg_node->value()->cast<FuncGraphPtr>();
  MS_EXCEPTION_IF_NULL(body_fg);
  for (size_t i = 1; i < body_with_inputs.size(); i++) {
    body_with_inputs[i] = while_fg_emitter->NewIrNode(replace_by_param(body_with_inputs[i]->get()));
  }
  // replace the body's output to call the outside while-fg
  AnfNodePtrList body_while_fg_inputs{NewValueNode(while_fg)};
  if (IsPrimitiveCNode(body_fg->output(), prim::kPrimMakeTuple)) {
    auto mt = body_fg->output()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(mt);
    (void)body_while_fg_inputs.insert(body_while_fg_inputs.end(), mt->inputs().begin() + 1, mt->inputs().end());
  } else {
    body_while_fg_inputs.push_back(body_fg->output());
  }
  if (body_while_fg_inputs.size() - 1 != init_list.size()) {
    MS_LOG(EXCEPTION) << "The while body's output size should be equal to init_list.size(), but got "
                      << (body_while_fg_inputs.size() - 1) << " vs " << init_list.size();
  }
  // while_fg may have more parameters than the body has outputs (extra captured nodes). Forward
  // those through the recursive call, reusing the body's own parameter when it already captured
  // the same node and adding a new body parameter otherwise.
  if (body_while_fg_inputs.size() < main_while_fg_inputs.size()) {
    for (size_t i = body_while_fg_inputs.size(); i < main_while_fg_inputs.size(); i++) {
      auto inp = while_fg->parameters()[i - 1];
      auto iter = std::find_if(body_with_inputs.begin(), body_with_inputs.end(),
                               [&inp](const NodePtr &no) { return no->get() == inp; });
      if (iter != body_with_inputs.end()) {
        auto param_idx = iter - body_with_inputs.begin() - 1;
        body_while_fg_inputs.push_back(body_fg->parameters()[LongToSize(param_idx)]);
      } else {
        body_with_inputs.push_back(while_fg_emitter->NewIrNode(inp));
        auto p = body_fg->add_parameter();
        p->set_abstract(inp->abstract());
        body_while_fg_inputs.push_back(p);
      }
    }
  }
  auto body_call_fg = body_fg->NewCNode(body_while_fg_inputs);
  body_call_fg->set_abstract(out_abstract_);
  body_fg->set_output(body_call_fg);

  auto tb = while_fg_emitter->Emit("Partial", body_with_inputs);
  auto fb = while_fg_emitter->Emit("Partial", empty_body_fg_with_inputs);
  auto s = while_fg_emitter->Emit("Switch", {cond, tb, fb});
  auto cnode = while_fg->NewCNode({s->get()});
  cnode->set_abstract(out_abstract_);
  while_fg->set_output(cnode);

  auto main_cnode = func_graph_->FuncGraph::NewCNode(main_while_fg_inputs);
  main_cnode->set_abstract(out_abstract_);
  return emitter_->NewIrNode(main_cnode);
}
693
CreateInnerEmitter(const FuncGraphPtr & fg,const ExpanderInferPtr & infer) const694 EmitterPtr CtrlFlowBlock::CreateInnerEmitter(const FuncGraphPtr &fg, const ExpanderInferPtr &infer) const {
695 return emitter_creator_ ? emitter_creator_(fg, infer) : std::make_shared<IrEmitter>(fg, infer);
696 }
697
BuildSubgraph(const BlockFunc & func)698 NodePtr CtrlFlowBlock::BuildSubgraph(const BlockFunc &func) {
699 auto fg = std::make_shared<FuncGraph>();
700 MS_EXCEPTION_IF_NULL(fg);
701 fg->set_indirect(std::make_shared<bool>(true));
702 auto e = CreateInnerEmitter(fg, emitter_->infer());
703 MS_EXCEPTION_IF_NULL(e);
704 auto outputs = func(e.get());
705 if (outputs.empty()) {
706 MS_LOG(EXCEPTION) << "The block function should not return empty list.";
707 }
708 if (output_num_ == 0) {
709 output_num_ = outputs.size();
710 } else if (output_num_ != outputs.size()) {
711 MS_LOG(EXCEPTION) << "The count of outputs of each block function should be equal, but got " << output_num_
712 << " vs " << outputs.size() << ".";
713 }
714 NodePtr output;
715 if (output_num_ > 1) {
716 output = e->MakeTuple(outputs);
717 SetSequenceNodeElementsUseFlags(output->get(), std::make_shared<std::vector<bool>>(output_num_, true));
718 } else {
719 output = outputs[0];
720 }
721 fg->set_output(output->get());
722 if (out_abstract_ == nullptr) {
723 out_abstract_ = output->abstract();
724 }
725 return emitter_->Value(fg);
726 }
727
BuildSubgraphOfPartial(const BlockFunc & func)728 NodePtrList CtrlFlowBlock::BuildSubgraphOfPartial(const BlockFunc &func) {
729 auto fg = std::make_shared<FuncGraph>();
730 MS_EXCEPTION_IF_NULL(fg);
731 fg->set_indirect(std::make_shared<bool>(true));
732 auto sub_emitter = CreateInnerEmitter(fg, emitter_->infer());
733 MS_EXCEPTION_IF_NULL(sub_emitter);
734 auto output = func(sub_emitter.get());
735 if (output.empty()) {
736 MS_LOG(EXCEPTION) << "The block function should not return empty list.";
737 }
738 if (output_num_ == 0) {
739 output_num_ = output.size();
740 } else if (output_num_ != output.size()) {
741 MS_LOG(EXCEPTION) << "The count of outputs of each block function should be equal, but got " << output_num_
742 << " vs " << output.size() << ".";
743 }
744 fg->set_output((output_num_ > 1) ? sub_emitter->MakeTuple(output)->get() : output[0]->get());
745 if (out_abstract_ == nullptr) {
746 out_abstract_ = fg->output()->abstract();
747 }
748 if (output_num_ > 1) {
749 SetSequenceNodeElementsUseFlags(fg->output(), std::make_shared<std::vector<bool>>(output_num_, true));
750 }
751
752 // replace the captured inputs to parameter
753 std::function<void(const CNodePtr &)> dfs;
754 std::unordered_set<AnfNodePtr> visited;
755 std::map<AnfNodePtr, ParameterPtr> param_map;
756 NodePtrList fg_with_inputs = {emitter_->Value(fg)};
757 dfs = [&visited, &dfs, &fg, ¶m_map, &fg_with_inputs, this](const CNodePtr &node) {
758 (void)visited.insert(node);
759 for (size_t i = 0; i < node->size(); i++) {
760 auto inp = node->input(i);
761 if (inp->func_graph() == nullptr) {
762 continue;
763 }
764 if (inp->func_graph() == fg) {
765 if (inp->isa<CNode>() && visited.count(inp) == 0) {
766 dfs(inp->cast<CNodePtr>());
767 }
768 } else {
769 auto ¶m = param_map[inp];
770 if (param == nullptr) {
771 param = fg->add_parameter();
772 param->set_abstract(inp->abstract());
773 (void)fg_with_inputs.emplace_back(emitter_->NewIrNode(inp));
774 }
775 node->set_input(i, param);
776 }
777 }
778 };
779 dfs(fg->get_return());
780 return fg_with_inputs;
781 }
782
Infer(const NodePtr & node)783 void CtrlFlowBlock::CppInferWithPartial::Infer(const NodePtr &node) {
784 if (IsPrimitiveCNode(node->get(), prim::kPrimPartial) || IsPrimitiveCNode(node->get(), prim::kPrimSwitch)) {
785 return;
786 }
787 CppInfer::Infer(node);
788 }
789
EmitOp(const PrimitivePtr & prim,const NodePtrList & inputs)790 NodePtr IrEmitter::EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) {
791 AnfNodePtrList cnode_inputs = {NewValueNode(prim)};
792 cnode_inputs.reserve(inputs.size() + 1);
793 (void)std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(cnode_inputs), [](const NodePtr &no) {
794 MS_EXCEPTION_IF_NULL(no);
795 return no->get();
796 });
797 auto cnode = func_graph_->NewCNode(cnode_inputs);
798 if (scope_ != nullptr) {
799 cnode->set_scope(scope_);
800 }
801 auto node = NewIrNode(cnode->cast<AnfNodePtr>());
802 infer_->Infer(node);
803 return node;
804 }
805
EmitValue(const ValuePtr & value)806 NodePtr IrEmitter::EmitValue(const ValuePtr &value) {
807 auto node = NewIrNode(NewValueNode(value));
808 infer_->Infer(node);
809 return node;
810 }
811
/// \brief `lhs + rhs`, emitted as Add through the emitter that owns `lhs`.
NodePtr operator+(const NodePtr &lhs, const NodePtr &rhs) {
  auto owner = lhs->emitter();
  return owner->Add(lhs, rhs);
}
/// \brief `lhs - rhs`, emitted as Sub through the emitter that owns `lhs`.
NodePtr operator-(const NodePtr &lhs, const NodePtr &rhs) {
  auto owner = lhs->emitter();
  return owner->Sub(lhs, rhs);
}
/// \brief `lhs * rhs`, emitted as Mul through the emitter that owns `lhs`.
NodePtr operator*(const NodePtr &lhs, const NodePtr &rhs) {
  auto owner = lhs->emitter();
  return owner->Mul(lhs, rhs);
}
/// \brief `lhs / rhs`, emitted as RealDiv through the emitter that owns `lhs`.
NodePtr operator/(const NodePtr &lhs, const NodePtr &rhs) {
  auto owner = lhs->emitter();
  return owner->RealDiv(lhs, rhs);
}
/// \brief `-node`, emitted as Neg through the emitter that owns `node`.
NodePtr operator-(const NodePtr &node) {
  auto owner = node->emitter();
  return owner->Neg(node);
}
817 } // namespace expander
818 } // namespace mindspore
819