1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/model/op_node.h"
17
18 #include <cmath>
19 #include <algorithm>
20 #include <memory>
21 #include <set>
22 #include <unordered_set>
23 #include <unordered_map>
24 #include <sstream>
25 #include <functional>
26 #include <numeric>
27 #include <utility>
28
29 #include "abstract/ops/primitive_infer_map.h"
30 #include "utils/anf_utils.h"
31 #include "utils/hash_map.h"
32 #include "utils/check_convert_utils.h"
33 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
34 #include "backend/common/graph_kernel/model/node.h"
35 #include "backend/operator/ops_backend_infer_function.h"
36 #include "utils/log_adapter.h"
37 #include "ops/auto_generate/gen_ops_primitive.h"
38
39 namespace mindspore::graphkernel::inner {
GetListInt(const ValuePtr & attr_value)40 std::vector<int64_t> GetListInt(const ValuePtr &attr_value) {
41 std::vector<int64_t> list_int;
42 const auto &vals = attr_value->cast<ValueSequencePtr>()->value();
43 (void)std::transform(vals.begin(), vals.end(), std::back_inserter(list_int),
44 [](const ValuePtr &v) { return AnfUtils::GetIntValue(v); });
45 return list_int;
46 }
47
InferShapeWithAbstract(const PrimitivePtr & prim,const AbstractBasePtrList & abs_list)48 BaseShapePtr InferShapeWithAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &abs_list) {
49 auto shape_optional = abstract::InferShapeByFuncImpl(prim, abs_list, true);
50 if (shape_optional.has_value()) {
51 return shape_optional.value();
52 }
53
54 auto found = abstract::GetBackendPrimitiveInferImpl(prim);
55 if (found.has_value()) {
56 auto infer = found.value();
57 if (infer.IsImplInferShapeAndType()) {
58 return infer.InferShape(prim, abs_list);
59 }
60 }
61 MS_LOG(EXCEPTION) << "The infer function of [" << prim->name() << "] is not defined.";
62 return nullptr;
63 }
64
InferTypeWithAbstract(const PrimitivePtr & prim,const AbstractBasePtrList & abs_list)65 TypePtr InferTypeWithAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &abs_list) {
66 auto type_optional = abstract::InferTypeByFuncImpl(prim, abs_list, true);
67 if (type_optional.has_value()) {
68 return type_optional.value();
69 }
70
71 auto found = abstract::GetBackendPrimitiveInferImpl(prim);
72 if (found.has_value()) {
73 auto infer = found.value();
74 if (infer.IsImplInferShapeAndType()) {
75 return infer.InferType(prim, abs_list);
76 }
77 }
78 MS_LOG(EXCEPTION) << "The infer function of [" << prim->name() << "] is not defined.";
79 return nullptr;
80 }
81
InferValueWithAbstract(const PrimitivePtr & prim,const AbstractBasePtrList & abs_list)82 tensor::TensorPtr InferValueWithAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &abs_list) {
83 auto value_optional = abstract::InferValueByFuncImpl(prim, abs_list);
84 if (value_optional.has_value()) {
85 return std::static_pointer_cast<tensor::Tensor>(value_optional.value());
86 }
87
88 auto found = abstract::GetBackendPrimitiveInferImpl(prim);
89 if (found.has_value()) {
90 auto infer = found.value();
91 if (infer.IsImplInferValue()) {
92 return std::static_pointer_cast<tensor::Tensor>(infer.InferValue(prim, abs_list));
93 }
94 }
95 return nullptr;
96 }
97
GenPrimAndAbstract(const NodePtrList & inputs,const DAttrs & attrs) const98 std::pair<PrimitivePtr, AbstractBasePtrList> PrimOp::GenPrimAndAbstract(const NodePtrList &inputs,
99 const DAttrs &attrs) const {
100 auto prim = std::make_shared<Primitive>(op_);
101 MS_EXCEPTION_IF_NULL(prim);
102 (void)prim->SetAttrs(attrs);
103 AbstractBasePtrList abs_list(inputs.size());
104 (void)std::transform(inputs.cbegin(), inputs.cend(), abs_list.begin(),
105 [](const NodePtr &node) { return node->ToAbstract(); });
106 return std::make_pair(prim, abs_list);
107 }
108
// Infers the output shape(s) of this op via the primitive infer machinery.
// A TupleShape result is flattened into one DShape per output; a plain Shape
// yields a single entry; anything else (e.g. a scalar) yields one empty DShape.
std::vector<DShape> PrimOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  auto [prim, abs_list] = GenPrimAndAbstract(inputs, attrs);
  // Give subclasses a chance to append attr-derived abstracts (axis, keep_dims, ...).
  RectifyAbstract(prim, &abs_list);
  auto baseshape = InferShapeWithAbstract(prim, abs_list);
  MS_EXCEPTION_IF_NULL(baseshape);
  if (baseshape->isa<abstract::TupleShape>()) {
    auto tuple_shape = baseshape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    const auto &shape_elements = tuple_shape->shape();
    std::vector<DShape> result(shape_elements.size());
    // NOTE(review): assumes every tuple element casts to abstract::ShapePtr; a
    // non-tensor element would make the cast return null and crash — confirm
    // upstream guarantees tensor-only tuple outputs here.
    (void)std::transform(shape_elements.cbegin(), shape_elements.cend(), result.begin(),
                         [](const BaseShapePtr &s) { return s->cast<abstract::ShapePtr>()->shape(); });
    return result;
  }
  auto shape = baseshape->cast<abstract::ShapePtr>();
  if (shape != nullptr) {
    return {shape->shape()};
  }
  // Non-Shape result (scalar output): represent it with an empty shape.
  return {DShape()};
}
129
InferType(const NodePtrList & inputs,const DAttrs & attrs)130 std::vector<TypeId> PrimOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
131 auto [prim, abs_list] = GenPrimAndAbstract(inputs, attrs);
132 RectifyAbstract(prim, &abs_list);
133 auto type = InferTypeWithAbstract(prim, abs_list);
134 MS_EXCEPTION_IF_NULL(type);
135 auto get_type_id = [](const TypePtr &t) {
136 return t->isa<TensorType>() ? t->cast<TensorTypePtr>()->element()->type_id() : t->type_id();
137 };
138 if (type->isa<Tuple>()) {
139 auto elements = type->cast<TuplePtr>()->elements();
140 std::vector<TypeId> result(elements.size());
141 (void)std::transform(elements.cbegin(), elements.cend(), result.begin(), get_type_id);
142 return result;
143 }
144 return {get_type_id(type)};
145 }
146
Infer(const NodePtrList & inputs,const DAttrs & attrs)147 NodeBaseList PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
148 Check(inputs, attrs);
149 NodeBaseList result;
150 auto format = InferFormat(inputs, attrs);
151 auto shapes = InferShape(inputs, attrs);
152 auto types = InferType(inputs, attrs);
153 if (shapes.size() != types.size()) {
154 MS_LOG(EXCEPTION) << "The num of shapes and types should be equal. (" << shapes.size() << " vs " << types.size()
155 << ")";
156 }
157 for (size_t i = 0; i < shapes.size(); i++) {
158 (void)result.emplace_back(NodeBase{shapes[i], types[i], format});
159 }
160 return result;
161 }
162
ToString() const163 std::string PrimOp::ToString() const {
164 std::ostringstream oss;
165 oss << Node::ToString();
166 oss << " = " << this->op_ << "(";
167 for (size_t i = 0; i < inputs_.size(); i++) {
168 if (inputs_[i]->NodeType() == NType::Primitive) {
169 oss << inputs_[i]->Node::ToString();
170 } else {
171 oss << inputs_[i]->ToString();
172 }
173 if (i != inputs_.size() - 1) {
174 oss << ", ";
175 }
176 }
177 oss << ")";
178 std::ostringstream attr_oss;
179 bool has_attr = false;
180 std::set<std::string> black_list = {"IsFeatureMapInputList", "IsFeatureMapOutput", "output_names", "input_names"};
181 for (auto attr : attrs_) {
182 if (attr.second != nullptr && black_list.count(attr.first) == 0) {
183 if (has_attr) {
184 attr_oss << ", ";
185 } else {
186 has_attr = true;
187 }
188 attr_oss << attr.first << ": " << attr.second->ToString();
189 }
190 }
191 if (has_attr) {
192 oss << " // attr {" << attr_oss.str() << "}";
193 }
194 return oss.str();
195 }
196
197 template <typename TD, typename TE>
ChangeDataToVec(const NodePtr & n)198 std::vector<TE> ChangeDataToVec(const NodePtr &n) {
199 std::vector<TE> res;
200 TD *data = static_cast<TD *>(std::static_pointer_cast<inner::ConstTensorNode>(n)->data()->data_c());
201 for (size_t elem = 0; elem < n->tensor_size(); elem++) {
202 res.push_back(static_cast<TE>(*(data + elem)));
203 }
204 return res;
205 }
206
// Host-side constant folding for this op on element type TM.
// Supports a fixed set of unary/binary elementwise ops on same-shape constant
// inputs; returns nullptr whenever folding is not supported so the caller can
// fall back to the generic value-infer path.
template <typename TM>
tensor::TensorPtr PrimOp::CalcByOperator(const NodePtrList &inputs, const DAttrs &) const {
  const size_t unary_input_num = 1;
  const size_t binary_input_num = 2;
  if (inputs.size() > 0) {
    // Folding is element-by-element with no broadcasting, so all input shapes
    // must match exactly.
    bool all_shape_equal =
      std::all_of(inputs.begin(), inputs.end(), [&inputs](const NodePtr &t) { return t->shape == inputs[0]->shape; });
    if (!all_shape_equal) {
      return nullptr;
    }
  }
  // Materialize each constant input as a flat host vector of TM.
  std::vector<std::vector<TM>> inputs_tm;
  const auto &op = this->op();
  const auto tid = this->type;
  for (const auto &t : inputs) {
    (void)inputs_tm.emplace_back(ChangeDataToVec<TM, TM>(t));
  }
  if (inputs.size() == unary_input_num) {
    mindspore::HashMap<std::string, std::function<TM(const TM &)>> func_map = {
      {"Abs", [](const TM &a) { return a <= TM(0) ? -a : a; }},
      {"Exp", [](const TM &a) { return exp(a); }},
      {"Log", [](const TM &a) { return log(a); }},
      {"Neg", [](const TM &a) { return -a; }},
      {"Reciprocal",
       [](const TM &a) {
         if (a == TM(0)) {
           MS_LOG(EXCEPTION) << "During graph kernel constant fold for reciprocal, divisor is zero.";
         }
         return TM(1) / a;
       }},
      {"Rsqrt",
       [](const TM &a) {
         if (a == TM(0)) {
           MS_LOG(EXCEPTION) << "During graph kernel constant fold for rsqrt, divisor is zero.";
         }
         return TM(1) / sqrt(a);
       }},
      {"Sqrt", [](const TM &a) { return sqrt(a); }},
    };
    if (func_map.find(op) == func_map.end()) {
      return nullptr;  // op not foldable on host
    }
    const auto &input_a = inputs_tm[0];
    std::vector<TM> res;
    (void)std::transform(input_a.begin(), input_a.end(), std::back_inserter(res),
                         [&func_map, &op](const TM &i) { return func_map[op](i); });
    // Build the result tensor with this op's output dtype and shape.
    return std::make_shared<tensor::Tensor>(tid, this->shape, &res[0], tid);
  } else if (inputs.size() == binary_input_num) {
    mindspore::HashMap<std::string, std::function<TM(const TM &, const TM &)>> func_map = {
      {"Add", [](const TM &a, const TM &b) { return a + b; }},
      {"Sub", [](const TM &a, const TM &b) { return a - b; }},
      {"Mul", [](const TM &a, const TM &b) { return a * b; }},
      {"RealDiv",
       [](const TM &a, const TM &b) {
         if (b == TM(0)) {
           MS_LOG(EXCEPTION) << "During graph kernel constant fold for realdiv, divisor is zero.";
         }
         return a / b;
       }},
    };
    if (func_map.find(op) == func_map.end()) {
      return nullptr;  // op not foldable on host
    }
    const auto &input_a = inputs_tm[0];
    const auto &input_b = inputs_tm[1];
    std::vector<TM> res;
    for (size_t i = 0; i < input_a.size(); i++) {
      (void)res.emplace_back(func_map[op](input_a[i], input_b[i]));
    }
    return std::make_shared<tensor::Tensor>(tid, this->shape, &res[0], tid);
  }
  // Ops with zero or more than two inputs are not folded here.
  return nullptr;
}
280
// Attempts compile-time constant folding of this op.
// All inputs must be constant tensors; the host fold (CalcByOperator) is
// dispatched on the output element type, and when that path is unsupported the
// primitive's registered value-infer is tried. Returns nullptr when the value
// cannot be determined statically.
NodePtr PrimOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) {
  for (auto i : inputs) {
    if (i->NodeType() != NType::Tensor) {
      return nullptr;  // a non-constant input blocks folding
    }
  }
  TypeId output_type = this->type;
  tensor::TensorPtr res = nullptr;
  // Map the runtime TypeId to a concrete C++ element type for CalcByOperator.
  switch (static_cast<int>(output_type)) {
    case TypeId::kNumberTypeUInt8: {
      res = CalcByOperator<uint8_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeInt8: {
      res = CalcByOperator<int8_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeInt16: {
      res = CalcByOperator<int16_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeInt32: {
      res = CalcByOperator<int32_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeInt64: {
      res = CalcByOperator<int64_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeUInt16: {
      res = CalcByOperator<uint16_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeUInt32: {
      res = CalcByOperator<uint32_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeUInt64: {
      res = CalcByOperator<uint64_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeFloat16: {
      res = CalcByOperator<float16>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeFloat32: {
      res = CalcByOperator<float>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeFloat64: {
      res = CalcByOperator<double>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeBFloat16: {
      res = CalcByOperator<bfloat16>(inputs, attrs);
      break;
    }
    default:
      return nullptr;  // unsupported element type (bool, complex, ...)
  }
  if (res == nullptr) {
    // Host fold declined; fall back to the primitive's value inference.
    auto [prim, inputs_abstract] = GenPrimAndAbstract(inputs, attrs);
    RectifyAbstract(prim, &inputs_abstract);
    res = InferValueWithAbstract(prim, inputs_abstract);
  }
  return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
}
348
InferValue(const NodePtrList & inputs,const DAttrs &)349 NodePtr ReshapeOp::InferValue(const NodePtrList &inputs, const DAttrs &) {
350 if (inputs[0]->NodeType() != NType::Tensor) {
351 return nullptr;
352 }
353 void *tensor_data = inputs[0]->As<inner::ConstTensorNode>()->data()->data_c();
354 tensor::TensorPtr result_tensor = std::make_shared<tensor::Tensor>(this->type, this->shape, tensor_data, this->type);
355 return std::make_shared<ConstTensorNode>(result_tensor);
356 }
357
358 // default format shape to fractal_Nz format shape
ToNz(const DShape & default_shape)359 DShape ToNz(const DShape &default_shape) {
360 constexpr size_t nz_size = 2;
361 constexpr auto align16 = 16;
362 auto len = default_shape.size();
363 DShape leading_shape;
364 DShape tail_shape;
365 if (default_shape.size() == 1 && default_shape[0] == 1) {
366 // # As shape (1,) can broadcast to any shape, it can be regarded as a special FractalNZ shape
367 return default_shape;
368 }
369 if (default_shape.size() > nz_size) {
370 (void)leading_shape.insert(leading_shape.cend(), default_shape.cbegin(),
371 default_shape.cend() - SizeToLong(nz_size));
372 }
373 if (default_shape.size() == 1 || (default_shape.size() >= nz_size && default_shape[len - nz_size] == 1)) {
374 // (32) or (N, 1, 32) -> (N, 2, 1, 1, 16)
375 if (default_shape.back() % align16 != 0) {
376 MS_LOG(EXCEPTION) << "default_shape[-1] should be multiplies of 16, but got " << default_shape.back();
377 }
378 tail_shape = {default_shape.back() / align16, 1, 1, align16};
379 } else if (default_shape.size() >= nz_size || default_shape[1] == 1) {
380 // (N, 32, 1) -> (N, 1, 2, 16, 1)
381 if (default_shape[len - nz_size] % align16 != 0) {
382 MS_LOG(EXCEPTION) << "default_shape[-2] should be multiplies of 16, but got " << default_shape[len - nz_size];
383 }
384 tail_shape = {1, default_shape[0] / align16, align16, 1};
385 } else {
386 // (N, 32, 48) -> (N, 3, 2, 16, 16)
387 if (default_shape.back() % align16 != 0 || default_shape[len - nz_size] % align16 != 0) {
388 MS_LOG(EXCEPTION) << "default_shape[-1] and default_shape[-2]should be multiplies of 16, but got "
389 << default_shape.back() << " " << default_shape[len - nz_size];
390 }
391 tail_shape = {default_shape[1] / align16, default_shape[0] / align16, align16, align16};
392 }
393 (void)leading_shape.insert(leading_shape.cend(), tail_shape.cbegin(), tail_shape.cend());
394 return leading_shape;
395 }
396
BroadcastShape(const NodePtrList & inputs,bool to_nz=false)397 DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) {
398 std::vector<std::vector<int64_t>> shapes;
399 for (auto &input : inputs) {
400 if (to_nz && input->format != kOpFormat_FRAC_NZ) {
401 (void)shapes.emplace_back(ToNz(input->shape));
402 } else {
403 (void)shapes.emplace_back(input->shape);
404 }
405 }
406 auto max_dim_input =
407 std::max_element(shapes.begin(), shapes.end(),
408 [](const std::vector<int64_t> &a, const std::vector<int64_t> &b) { return a.size() < b.size(); });
409 auto max_dim = max_dim_input->size();
410 std::vector<std::vector<int64_t>> align_shapes;
411 for (auto &s : shapes) {
412 std::vector<int64_t> cur(max_dim - s.size(), 1);
413 (void)cur.insert(cur.cend(), s.cbegin(), s.cend());
414 (void)align_shapes.emplace_back(cur);
415 }
416 std::vector<int64_t> output_shape(max_dim, 1);
417 for (size_t i = 0; i < max_dim; i++) {
418 for (auto &align_shape : align_shapes) {
419 if (align_shape[i] > 1) {
420 if (output_shape[i] == 1) {
421 output_shape[i] = align_shape[i];
422 }
423 if (output_shape[i] != align_shape[i]) {
424 MS_LOG(EXCEPTION) << "Shape broadcast failed: " << output_shape[i] << " vs " << align_shape[i];
425 }
426 }
427 }
428 }
429 return output_shape;
430 }
431
InferShape(const NodePtrList & inputs,const DAttrs & attrs)432 std::vector<DShape> ElemwiseOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
433 if (std::any_of(inputs.begin(), inputs.end(),
434 [](const NodePtr &input) { return input->format == kOpFormat_FRAC_NZ; })) {
435 return {BroadcastShape(inputs, true)};
436 }
437 return PrimOp::InferShape(inputs, attrs);
438 }
439
InferFormat(const NodePtrList & inputs,const DAttrs &)440 DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &) {
441 if (inputs.empty()) {
442 return kOpFormat_DEFAULT;
443 }
444 auto first_format = inputs[0]->format;
445 for (const auto &inp : inputs) {
446 auto cur_format = inp->format;
447 if (cur_format.find("FRACTAL") != std::string::npos) {
448 // special format
449 return cur_format;
450 }
451 if (cur_format != kOpFormat_DEFAULT && inp->tensor_size() != 1) {
452 return cur_format;
453 }
454 }
455 return first_format;
456 }
457
InferShape(const NodePtrList & inputs,const DAttrs & attrs)458 std::vector<DShape> ArgReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
459 CHECK_ATTR(attrs, "axis");
460 auto axis = GetListInt(attrs.find("axis")->second);
461 const auto &input_shape = inputs[0]->shape;
462 int64_t size = SizeToLong(input_shape.size());
463 std::vector<int64_t> real_axis;
464 (void)std::transform(axis.begin(), axis.end(), std::back_inserter(real_axis),
465 [&size](const int64_t &x) { return x < 0 ? (x + size) : x; });
466
467 DShape new_shape;
468 for (size_t i = 0; i < input_shape.size(); i++) {
469 if (std::find(real_axis.begin(), real_axis.end(), SizeToLong(i)) == real_axis.end()) {
470 (void)new_shape.emplace_back(input_shape[i]);
471 }
472 }
473 if (new_shape.empty()) {
474 (void)new_shape.emplace_back(1);
475 }
476 return {new_shape};
477 }
478
InferType(const NodePtrList &,const DAttrs & attrs)479 std::vector<TypeId> ArgReduceOp::InferType(const NodePtrList &, const DAttrs &attrs) {
480 CHECK_ATTR(attrs, "output_type");
481 return {attrs.find("output_type")->second->cast<TypePtr>()->type_id()};
482 }
483
// Infers the output format of Transpose. An explicit dst_format attribute
// wins; otherwise only rank-4 NCHW<->NHWC permutations are recognized, and
// everything else falls back to the default format.
DFormat TransposeOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) {
  if (attrs.count(kAttrDstFormat) != 0) {
    return GetValue<std::string>(attrs.find(kAttrDstFormat)->second);
  }
  // only support NCHW/NHWC now
  constexpr size_t kRank4 = 4;
  if (inputs[0]->shape.size() != kRank4) {
    return kOpFormat_DEFAULT;
  }
  // NOTE(review): assumes the perm input (inputs[1]) is a constant tensor; a
  // non-constant perm would make the As<ConstTensorNode>() access invalid —
  // confirm callers only reach here with a const perm.
  auto perm_node = inputs[1];
  auto perm_tensor = perm_node->As<inner::ConstTensorNode>()->data();
  auto perm = CheckAndConvertUtils::CheckTensorIntValue("permutation", perm_tensor, "Transpose");
  const auto &ori_format = inputs[0]->format;
  if (ori_format == kOpFormat_DEFAULT || ori_format == kOpFormat_NCHW) {
    std::vector<int64_t> nchw2nhwc = {0, 2, 3, 1};
    if (perm == nchw2nhwc) {
      return kOpFormat_NHWC;
    }
  } else if (ori_format == kOpFormat_NHWC) {
    std::vector<int64_t> nhwc2nchw = {0, 3, 1, 2};
    if (perm == nhwc2nchw) {
      return kOpFormat_NCHW;
    }
  }
  // Any other permutation loses the known layout.
  return kOpFormat_DEFAULT;
}
510
InferValue(const NodePtrList & inputs,const DAttrs & attrs)511 NodePtr ConstantOfShapeOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) {
512 for (auto i : inputs) {
513 if (i->NodeType() != NType::Tensor) {
514 return nullptr;
515 }
516 }
517 const auto &value = GetValue<std::vector<float>>(attrs.find("value")->second);
518 std::vector<float> res;
519 size_t elem_num = LongToSize(std::accumulate(this->shape.begin(), this->shape.end(), 1, std::multiplies<int64_t>()));
520 if (value.size() == 1) {
521 res = std::vector<float>(elem_num, value[0]);
522 } else if (value.size() == elem_num) {
523 res = value;
524 } else {
525 return nullptr;
526 }
527 auto tensor = std::make_shared<tensor::Tensor>(this->type, this->shape, &res[0], kNumberTypeFloat32);
528 return std::make_shared<ConstTensorNode>(tensor);
529 }
530
InferShape(const NodePtrList & inputs,const DAttrs & attrs)531 std::vector<DShape> ConstantOfShapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
532 const auto &value = attrs.find("shape")->second;
533 std::vector<int64_t> res;
534 if (value->isa<ValueSequence>()) {
535 res = GetValue<std::vector<int64_t>>(value);
536 return {res};
537 } else if (value->isa<tensor::Tensor>()) {
538 auto tvalue = value->cast<tensor::TensorPtr>();
539 if (tvalue->data_type_c() == static_cast<int>(TypeId::kNumberTypeInt32)) {
540 int *data = static_cast<int *>(tvalue->data_c());
541 for (size_t elem = 0; elem < tvalue->DataSize(); elem++) {
542 res.push_back(IntToLong(*(data + elem)));
543 }
544 return {res};
545 } else if (tvalue->data_type_c() == static_cast<int>(TypeId::kNumberTypeInt64)) {
546 int64_t *data = static_cast<int64_t *>(tvalue->data_c());
547 res = std::vector<int64_t>(data, data + tvalue->DataSize());
548 return {res};
549 }
550 }
551 return PrimOp::InferShape(inputs, attrs);
552 }
553
InferValue(const NodePtrList & inputs,const DAttrs &)554 NodePtr ShapeOp::InferValue(const NodePtrList &inputs, const DAttrs &) {
555 auto tensor = std::make_shared<tensor::Tensor>(this->type, this->shape, inputs[0]->shape.data(), kNumberTypeInt64);
556 return std::make_shared<ConstTensorNode>(tensor);
557 }
558
InferShape(const NodePtrList & inputs,const DAttrs & attrs)559 std::vector<DShape> PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
560 std::vector<int64_t> shape0 = inputs[0]->shape;
561 size_t n = shape0.size();
562 CHECK_ATTR(attrs, "head");
563 CHECK_ATTR(attrs, "tail");
564 std::vector<int64_t> pad_before = GetListInt(attrs.find("head")->second);
565 std::vector<int64_t> pad_after = GetListInt(attrs.find("tail")->second);
566 if (pad_before.size() != n || pad_after.size() != n) {
567 MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << pad_before.size() << " vs "
568 << pad_after.size();
569 }
570 std::vector<int64_t> output;
571 for (size_t i = 0; i < n; i++) {
572 (void)output.emplace_back(shape0[i] + pad_before[i] + pad_after[i]);
573 }
574 return {output};
575 }
576
InferShape(const NodePtrList & inputs,const DAttrs & attrs)577 std::vector<DShape> UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
578 std::vector<int64_t> shape0 = inputs[0]->shape;
579 size_t n = shape0.size();
580 CHECK_ATTR(attrs, "tail");
581 std::vector<int64_t> unpad_after = GetListInt(attrs.find("tail")->second);
582 if (unpad_after.size() != n) {
583 MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << unpad_after.size();
584 }
585 std::vector<int64_t> output;
586 for (size_t i = 0; i < n; i++) {
587 (void)output.emplace_back(shape0[i] - unpad_after[i]);
588 }
589 return {output};
590 }
591
HadPad(const ShapeVector & pad_list,const std::string & pad_mode)592 bool Conv2dOp::HadPad(const ShapeVector &pad_list, const std::string &pad_mode) {
593 constexpr size_t kTop = 0;
594 constexpr size_t kBottom = 1;
595 constexpr size_t kLeft = 2;
596 constexpr size_t kRight = 3;
597
598 if (pad_list[kTop] != pad_list[kBottom] || pad_list[kLeft] != pad_list[kRight]) {
599 return true;
600 }
601 if (pad_mode != "VALID" && pad_mode != "valid") {
602 return std::any_of(pad_list.begin(), pad_list.end(), [](auto a) { return a != 0; });
603 }
604 return false;
605 }
606
// Infers the Conv2d output shape.
// Rank-4 inputs: NHWC layouts are computed here by the standard conv formula;
// other rank-4 layouts defer to OpaqueOp::InferShape. Non-rank-4 inputs are
// treated as the tiled NCHWc layout (data NCHWc, weight OIHWio-like).
std::vector<DShape> Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  // get the output shape when format is NHWC/NCHW
  if (inputs[0]->shape.size() == kDim4) {
    CHECK_ATTR(attrs, "format");
    if (inputs[0]->format == kOpFormat_NHWC || inputs[1]->format == kOpFormat_NHWC ||
        GetValue<std::string>(attrs.find("format")->second) == kOpFormat_NHWC) {
      CHECK_ATTR(attrs, "pad_mode");
      CHECK_ATTR(attrs, "pad_list");
      CHECK_ATTR(attrs, "kernel_size");
      CHECK_ATTR(attrs, "stride");
      CHECK_ATTR(attrs, "dilation");

      auto x_shape = inputs[0]->shape;
      auto w_shape = inputs[1]->shape;
      auto pad_mode = GetValue<std::string>(attrs.find("pad_mode")->second);
      auto pad_list = GetListInt(attrs.find("pad_list")->second);
      auto kernel_size = GetListInt(attrs.find("kernel_size")->second);
      auto stride = GetListInt(attrs.find("stride")->second);
      auto dilation = GetListInt(attrs.find("dilation")->second);
      constexpr size_t kPadSize = 4;
      constexpr size_t kKernelSize = 2;
      constexpr size_t kStrideSize = 4;
      constexpr size_t kDilationSize = 4;
      if (x_shape.size() != kDim4 || w_shape.size() != kDim4 || pad_list.size() != kPadSize ||
          kernel_size.size() != kKernelSize || stride.size() != kStrideSize || dilation.size() != kDilationSize) {
        MS_LOG(EXCEPTION) << "For 'Conv2D', got sizes of x_shape, w_shape, pad_list, kernel_size, stride and dilation: "
                          << x_shape.size() << ", " << w_shape.size() << ", " << pad_list.size() << ", "
                          << kernel_size.size() << ", " << stride.size() << ", " << dilation.size()
                          << ". But expect: 4, 4, 4, 2, 4, 4";
      }
      auto has_pad = HadPad(pad_list, pad_mode);
      if (!has_pad) {
        pad_list = {0, 0, 0, 0};
      }

      // Effective (dilated) kernel extents; stride/dilation use their H/W slots.
      auto k_h = (kernel_size[0] - 1) * dilation[2] + 1;
      auto k_w = (kernel_size[1] - 1) * dilation[3] + 1;
      // NHWC: x_shape[1] is H, x_shape[2] is W; standard conv output size.
      auto out_h = (x_shape[1] + pad_list[0] + pad_list[1] - k_h) / stride[2] + 1;
      auto out_w = (x_shape[2] + pad_list[2] + pad_list[3] - k_w) / stride[3] + 1;
      // Output is NHWC with channel count from the weight's last dim.
      return {{x_shape[0], out_h, out_w, w_shape[3]}};
    } else {
      return OpaqueOp::InferShape(inputs, attrs);
    }
  }

  // get the output shape when format is NCHWc
  std::vector<int64_t> data_shape = inputs[0]->shape;
  std::vector<int64_t> weight_shape = inputs[1]->shape;
  auto n = data_shape[0];
  auto i_h = data_shape[2];
  auto i_w = data_shape[3];
  auto c_o_o = weight_shape[0];  // outer output-channel blocks
  auto k_h = weight_shape[2];
  auto k_w = weight_shape[3];
  auto c_o_i = weight_shape[5];  // inner output-channel block size

  CHECK_ATTR(attrs, "stride");
  CHECK_ATTR(attrs, "dilation");

  std::vector<int64_t> strides = GetListInt(attrs.find("stride")->second);
  std::vector<int64_t> dilations = GetListInt(attrs.find("dilation")->second);

  auto d_h = dilations[0];
  auto d_w = dilations[1];
  auto s_h = strides[0];
  auto s_w = strides[1];
  // NOTE(review): no padding term here — assumes the NCHWc path is always
  // pad-free (or pre-padded); confirm with the producer of this layout.
  auto k_h_d = (k_h - 1) * d_h + 1;
  auto k_w_d = (k_w - 1) * d_w + 1;
  auto o_h = (i_h - k_h_d) / s_h + 1;
  auto o_w = (i_w - k_w_d) / s_w + 1;

  std::vector<int64_t> output_shape{n, c_o_o, o_h, o_w, c_o_i};
  return {output_shape};
}
681
InferType(const NodePtrList & inputs,const DAttrs & attrs)682 std::vector<TypeId> Conv2dOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
683 if (inputs[0]->shape.size() == kDim4) {
684 return PrimOp::InferType(inputs, attrs);
685 }
686 return {inputs[0]->type};
687 }
688
InferFormat(const NodePtrList & inputs,const DAttrs & attrs)689 DFormat Conv2dOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) {
690 if (inputs[0]->shape.size() == kDim4) {
691 return PrimOp::InferFormat(inputs, attrs);
692 }
693 CHECK_ATTR(attrs, "conv_out_format");
694 return GetValue<std::string>(attrs.find("conv_out_format")->second);
695 }
696
RectifyAbstract(const PrimitivePtr &,AbstractBasePtrList * input_abstract_ptr)697 void ConcatOp::RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) {
698 AbstractBasePtrList rectifyed_abs_list;
699 (void)rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(*input_abstract_ptr));
700 input_abstract_ptr->swap(rectifyed_abs_list);
701 }
702
// Reduce ops pass "keep_dims" (and, for ReduceSum, "skip_mode") as extra
// inputs to the core infer, so append their abstracts here.
void ReduceOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
  CHECK_ATTR(prim->attrs(), "keep_dims");
  auto keep_dims = prim->GetAttr("keep_dims");
  (void)abs_list->emplace_back(keep_dims->ToAbstract());
  bool is_reduce_sum = prim->name() == prim::kPrimReduceSum->name();
  if (is_reduce_sum) {
    CHECK_ATTR(prim->attrs(), "skip_mode");
    auto skip_mode = prim->GetAttr("skip_mode");
    (void)abs_list->emplace_back(skip_mode->ToAbstract());
  }
}
711
RectifyAbstract(const PrimitivePtr & prim,AbstractBasePtrList * abs_list)712 void OneHotOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
713 CHECK_ATTR(prim->attrs(), "axis");
714 (void)abs_list->emplace_back(prim->GetAttr("axis")->ToAbstract());
715 }
716
RectifyAbstract(const PrimitivePtr & prim,AbstractBasePtrList * abs_list)717 void CumSumOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
718 CHECK_ATTR(prim->attrs(), "exclusive");
719 (void)abs_list->emplace_back(prim->GetAttr("exclusive")->ToAbstract());
720 CHECK_ATTR(prim->attrs(), "reverse");
721 (void)abs_list->emplace_back(prim->GetAttr("reverse")->ToAbstract());
722 }
723
RectifyAbstract(const PrimitivePtr & prim,AbstractBasePtrList * abs_list)724 void GatherOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
725 CHECK_ATTR(prim->attrs(), "batch_dims");
726 (void)abs_list->emplace_back(prim->GetAttr("batch_dims")->ToAbstract());
727 }
728
RectifyAbstract(const PrimitivePtr & prim,AbstractBasePtrList * abs_list)729 void ArgReduceOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
730 CHECK_ATTR(prim->attrs(), "axis");
731 (void)abs_list->emplace_back(prim->GetAttr("axis")->ToAbstract());
732 CHECK_ATTR(prim->attrs(), "output_type");
733 (void)abs_list->emplace_back(prim->GetAttr("output_type")->ToAbstract());
734 }
735
RectifyAbstract(const PrimitivePtr & prim,AbstractBasePtrList * abs_list)736 void PagedAttentionOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
737 constexpr size_t PA_INPUT_NUM = 5;
738 constexpr size_t PA_MASK_INPUT_NUM = 6;
739 if (abs_list->size() == PA_INPUT_NUM || abs_list->size() == PA_MASK_INPUT_NUM) {
740 CHECK_ATTR(prim->attrs(), "head_num");
741 (void)abs_list->emplace_back(prim->GetAttr("head_num")->ToAbstract());
742 CHECK_ATTR(prim->attrs(), "scale_value");
743 (void)abs_list->emplace_back(prim->GetAttr("scale_value")->ToAbstract());
744 CHECK_ATTR(prim->attrs(), "kv_head_num");
745 (void)abs_list->emplace_back(prim->GetAttr("kv_head_num")->ToAbstract());
746 }
747 }
748
CompactShape(const ShapeVector & origin,int64_t axis)749 std::vector<size_t> CompactShape(const ShapeVector &origin, int64_t axis) {
750 std::vector<size_t> new_shape;
751 size_t accu = 1;
752 for (size_t i = 0; i < origin.size(); i++) {
753 if (LongToSize(axis) == i) {
754 new_shape.push_back(accu);
755 new_shape.push_back(LongToSize(origin[i]));
756 accu = 1;
757 } else {
758 accu *= LongToSize(origin[i]);
759 }
760 }
761 new_shape.push_back(accu);
762 return new_shape;
763 }
764
// Host-side constant folding for Gather on element type TM.
// The axis comes from the "axis" attribute or a constant third input; the
// parameter shape is compacted to (outer, axis_dim, inner) so gathering is a
// triple loop. Returns nullptr when folding is not supported.
template <typename TM>
tensor::TensorPtr GatherOp::CalcGather(const NodePtrList &inputs, const DAttrs &attrs) const {
  constexpr size_t param_index = 0;
  constexpr size_t indice_index = 1;
  constexpr size_t axis_index = 2;
  constexpr size_t input_num = 3;
  constexpr size_t first_dim = 0;
  constexpr size_t second_dim = 1;
  constexpr size_t third_dim = 2;
  int64_t axis = 0;
  if (attrs.count("axis") > 0) {
    axis = GetValue<int64_t>(attrs.find("axis")->second);
  } else if (inputs.size() == input_num) {
    // Axis supplied as a constant int32 tensor input.
    int *data_axis =
      static_cast<int *>(std::static_pointer_cast<inner::ConstTensorNode>(inputs[axis_index])->data()->data_c());
    axis = IntToLong(*data_axis);
  } else {
    return nullptr;  // no way to determine the axis
  }
  ShapeVector param_shp = inputs[param_index]->shape;
  // Normalize a negative axis.
  axis = axis < 0 ? SizeToLong(param_shp.size()) + axis : axis;
  // Widen the indices to size_t regardless of their stored integer type.
  std::vector<size_t> indices;
  switch (static_cast<int>(inputs[indice_index]->type)) {
    case TypeId::kNumberTypeInt8: {
      indices = ChangeDataToVec<int8_t, size_t>(inputs[indice_index]);
      break;
    }
    case TypeId::kNumberTypeInt16: {
      indices = ChangeDataToVec<int16_t, size_t>(inputs[indice_index]);
      break;
    }
    case TypeId::kNumberTypeInt32: {
      indices = ChangeDataToVec<int32_t, size_t>(inputs[indice_index]);
      break;
    }
    case TypeId::kNumberTypeInt64: {
      indices = ChangeDataToVec<int64_t, size_t>(inputs[indice_index]);
      break;
    }
    default:
      return nullptr;  // unsupported index dtype
  }

  TM *input_x =
    static_cast<TM *>(std::static_pointer_cast<inner::ConstTensorNode>(inputs[param_index])->data()->data_c());
  std::vector<size_t> compact_shp = CompactShape(param_shp, axis);
  std::vector<TM> res;
  if (compact_shp.size() == input_num) {
    // res[o][j][k] = input[o][indices[j]][k] over the compacted (outer,
    // axis, inner) view.
    for (size_t i = 0; i < compact_shp[first_dim]; i++) {
      for (auto j : indices) {
        for (size_t k = 0; k < compact_shp[third_dim]; k++) {
          (void)res.emplace_back(
            input_x[i * compact_shp[second_dim] * compact_shp[third_dim] + j * compact_shp[third_dim] + k]);
        }
      }
    }
    return std::make_shared<tensor::Tensor>(this->type, this->shape, &res[0], this->type);
  }
  return nullptr;
}
825
InferValue(const NodePtrList & inputs,const DAttrs & attrs)826 NodePtr GatherOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) {
827 for (auto i : inputs) {
828 if (i->NodeType() != NType::Tensor) {
829 return nullptr;
830 }
831 }
832 TypeId output_type = this->type;
833 tensor::TensorPtr res = nullptr;
834 switch (static_cast<int>(output_type)) {
835 case TypeId::kNumberTypeUInt8: {
836 res = CalcGather<uint8_t>(inputs, attrs);
837 break;
838 }
839 case TypeId::kNumberTypeInt8: {
840 res = CalcGather<int8_t>(inputs, attrs);
841 break;
842 }
843 case TypeId::kNumberTypeInt16: {
844 res = CalcGather<int16_t>(inputs, attrs);
845 break;
846 }
847 case TypeId::kNumberTypeInt32: {
848 res = CalcGather<int32_t>(inputs, attrs);
849 break;
850 }
851 case TypeId::kNumberTypeInt64: {
852 res = CalcGather<int64_t>(inputs, attrs);
853 break;
854 }
855 case TypeId::kNumberTypeUInt16: {
856 res = CalcGather<uint16_t>(inputs, attrs);
857 break;
858 }
859 case TypeId::kNumberTypeUInt32: {
860 res = CalcGather<uint32_t>(inputs, attrs);
861 break;
862 }
863 case TypeId::kNumberTypeUInt64: {
864 res = CalcGather<uint64_t>(inputs, attrs);
865 break;
866 }
867 case TypeId::kNumberTypeFloat16: {
868 res = CalcGather<float16>(inputs, attrs);
869 break;
870 }
871 case TypeId::kNumberTypeFloat32: {
872 res = CalcGather<float>(inputs, attrs);
873 break;
874 }
875 case TypeId::kNumberTypeFloat64: {
876 res = CalcGather<double>(inputs, attrs);
877 break;
878 }
879 case TypeId::kNumberTypeBFloat16: {
880 res = CalcGather<bfloat16>(inputs, attrs);
881 break;
882 }
883 default:
884 return nullptr;
885 }
886 return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
887 }
888
template <typename TM>
tensor::TensorPtr ConcatOp::CalcConcat(const NodePtrList &inputs, const DAttrs &attrs) {
  // Constant-folds Concat over const inputs; TM is the element type. The last
  // entry of `inputs` is expected to carry the concat axis as a const scalar;
  // returns nullptr when it does not.
  constexpr size_t first_dim = 0;
  constexpr size_t second_dim = 1;
  constexpr size_t third_dim = 2;
  int64_t axis = 0;
  auto axis_node = inputs.back();
  if (axis_node->NodeType() == NType::Scalar) {
    auto scalar_node = axis_node->As<ConstScalarNode>();
    axis = GetValue<int64_t>(scalar_node->data());
  } else {
    return nullptr;
  }
  // Normalize a negative axis against the OUTPUT rank.
  axis = axis < 0 ? SizeToLong(this->shape.size()) + axis : axis;
  // NOTE(review): the loops below iterate over ALL entries of `inputs`,
  // including the trailing axis scalar, whose data and compacted shape are
  // collected here as well — verify the axis input is really meant to
  // contribute data to the concatenation.
  std::vector<std::vector<TM>> inputs_tm;
  for (const auto &t : inputs) {
    (void)inputs_tm.emplace_back(ChangeDataToVec<TM, TM>(t));
  }
  std::vector<std::vector<size_t>> all_shps;
  (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(all_shps),
                       [&axis](const NodePtr &t) { return CompactShape(t->shape, axis); });
  std::vector<TM> res;
  if (all_shps.size() > 0) {
    // Each compacted shape is (outer, axis_dim, inner). For each outer index,
    // interleave the axis_dim * inner block of every input in order.
    const size_t third_dim_size = all_shps[0][third_dim];
    const size_t first_dim_size = all_shps[0][first_dim];
    for (size_t i = 0; i < first_dim_size; i++) {
      for (size_t t = 0; t < inputs_tm.size(); t++) {
        for (size_t j = 0; j < all_shps[t][second_dim]; j++) {
          for (size_t k = 0; k < third_dim_size; k++) {
            (void)res.emplace_back(inputs_tm[t][i * all_shps[t][second_dim] * third_dim_size + j * third_dim_size + k]);
          }
        }
      }
    }
    // NOTE(review): &res[0] is undefined behavior if res ends up empty —
    // confirm zero-sized concats cannot reach this point.
    return std::make_shared<tensor::Tensor>(this->type, this->shape, &res[0], this->type);
  }
  return nullptr;
}
927
InferValue(const NodePtrList & inputs,const DAttrs & attrs)928 NodePtr ConcatOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) {
929 for (auto i : inputs) {
930 if (i->NodeType() != NType::Tensor) {
931 return nullptr;
932 }
933 }
934 TypeId output_type = this->type;
935 tensor::TensorPtr res = nullptr;
936 switch (static_cast<int>(output_type)) {
937 case TypeId::kNumberTypeUInt8: {
938 res = CalcConcat<uint8_t>(inputs, attrs);
939 break;
940 }
941 case TypeId::kNumberTypeInt8: {
942 res = CalcConcat<int8_t>(inputs, attrs);
943 break;
944 }
945 case TypeId::kNumberTypeInt16: {
946 res = CalcConcat<int16_t>(inputs, attrs);
947 break;
948 }
949 case TypeId::kNumberTypeInt32: {
950 res = CalcConcat<int32_t>(inputs, attrs);
951 break;
952 }
953 case TypeId::kNumberTypeInt64: {
954 res = CalcConcat<int64_t>(inputs, attrs);
955 break;
956 }
957 case TypeId::kNumberTypeUInt16: {
958 res = CalcConcat<uint16_t>(inputs, attrs);
959 break;
960 }
961 case TypeId::kNumberTypeUInt32: {
962 res = CalcConcat<uint32_t>(inputs, attrs);
963 break;
964 }
965 case TypeId::kNumberTypeUInt64: {
966 res = CalcConcat<uint64_t>(inputs, attrs);
967 break;
968 }
969 case TypeId::kNumberTypeFloat16: {
970 res = CalcConcat<float16>(inputs, attrs);
971 break;
972 }
973 case TypeId::kNumberTypeFloat32: {
974 res = CalcConcat<float>(inputs, attrs);
975 break;
976 }
977 case TypeId::kNumberTypeFloat64: {
978 res = CalcConcat<double>(inputs, attrs);
979 break;
980 }
981 case TypeId::kNumberTypeBFloat16: {
982 res = CalcConcat<bfloat16>(inputs, attrs);
983 break;
984 }
985 default:
986 return nullptr;
987 }
988 return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
989 }
990
InferShape(const NodePtrList & inputs,const DAttrs & attrs)991 std::vector<DShape> LayoutTransformOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
992 CHECK_ATTR(attrs, kAttrSrcFormat);
993 CHECK_ATTR(attrs, kAttrDstFormat);
994 auto src_format = GetValue<std::string>(attrs.find(kAttrSrcFormat)->second);
995 auto dst_format = GetValue<std::string>(attrs.find(kAttrDstFormat)->second);
996 std::vector<int64_t> data_shape = inputs[0]->shape;
997 if (src_format == kOpFormat_NHWC) {
998 auto n = data_shape[0];
999 auto h = data_shape[1];
1000 auto w = data_shape[2];
1001 auto c = data_shape[3];
1002 auto c_o_i = GkUtils::GetChannelInConvFormat(dst_format);
1003 if (c_o_i == 0) {
1004 c_o_i = 1;
1005 }
1006 auto c_o_o = c / c_o_i;
1007 std::vector<int64_t> output_shape{n, c_o_o, h, w, c_o_i};
1008 return {output_shape};
1009 }
1010 if (dst_format == kOpFormat_NHWC) {
1011 auto n = data_shape[0];
1012 auto c_o_o = data_shape[1];
1013 auto h = data_shape[2];
1014 auto w = data_shape[3];
1015 auto c_o_i = data_shape[4];
1016 auto c = c_o_o * c_o_i;
1017 std::vector<int64_t> output_shape{n, h, w, c};
1018 return {output_shape};
1019 }
1020 // LayoutTransform between nchwnc
1021 auto n = data_shape[0];
1022 auto c_o_o = data_shape[1];
1023 auto h = data_shape[2];
1024 auto w = data_shape[3];
1025 auto c_o_i = data_shape[4];
1026 auto c_o_i_new = GkUtils::GetChannelInConvFormat(dst_format);
1027 if (c_o_i_new == 0) {
1028 c_o_i_new = 1;
1029 }
1030 auto c_o_o_new = c_o_o * c_o_i / c_o_i_new;
1031 std::vector<int64_t> output_shape{n, c_o_o_new, h, w, c_o_i_new};
1032 return {output_shape};
1033 }
1034
InferShape(const NodePtrList & inputs,const DAttrs & attrs)1035 std::vector<DShape> Pool2DOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
1036 CHECK_ATTR(attrs, "global");
1037 std::vector<int64_t> input_shape = inputs[0]->shape;
1038 bool is_nhwc = input_shape.size() == 4;
1039 int64_t n = input_shape[0];
1040 int64_t c;
1041 int64_t h;
1042 int64_t w;
1043 if (is_nhwc) {
1044 constexpr size_t h_idx = 1;
1045 constexpr size_t w_idx = 2;
1046 constexpr size_t c_idx = 3;
1047 h = input_shape[h_idx];
1048 w = input_shape[w_idx];
1049 c = input_shape[c_idx];
1050 } else {
1051 constexpr size_t c_idx = 1;
1052 constexpr size_t h_idx = 2;
1053 constexpr size_t w_idx = 3;
1054 c = input_shape[c_idx];
1055 h = input_shape[h_idx];
1056 w = input_shape[w_idx];
1057 }
1058
1059 if (GetValue<bool>(attrs.find("global")->second)) {
1060 h = 1;
1061 w = 1;
1062 } else {
1063 CHECK_ATTR(attrs, "strides");
1064 CHECK_ATTR(attrs, "kernel_size");
1065 CHECK_ATTR(attrs, "round_mode");
1066 std::vector<int64_t> strides = GetListInt(attrs.find("strides")->second);
1067 std::vector<int64_t> kernels = GetListInt(attrs.find("kernel_size")->second);
1068 if (AnfUtils::GetIntValue(attrs.find("round_mode")->second) == 0) {
1069 // ceil mode
1070 h = ((h - kernels[0] + strides[0] - 1) / strides[0]) + 1;
1071 w = ((w - kernels[1] + strides[1] - 1) / strides[1]) + 1;
1072 } else {
1073 // round mode
1074 h = ((h - kernels[0]) / strides[0]) + 1;
1075 w = ((w - kernels[1]) / strides[1]) + 1;
1076 }
1077 }
1078 if (is_nhwc) {
1079 return {{n, h, w, c}};
1080 } else {
1081 auto ci = input_shape[4];
1082 return {{n, c, h, w, ci}};
1083 }
1084 }
1085
Check(const NodePtrList & inputs,const DAttrs &)1086 void ComplexOp::Check(const NodePtrList &inputs, const DAttrs &) {
1087 if (inputs[0]->type != TypeId::kNumberTypeFloat32) {
1088 MS_LOG(EXCEPTION) << "Complex's input[0] should be float32, but got " << TypeIdToString(inputs[0]->type, true);
1089 }
1090 if (inputs[0]->type != inputs[1]->type) {
1091 MS_LOG(EXCEPTION) << "Complex's input[0] and inputs[1]'s type mismatch: " << TypeIdToString(inputs[0]->type, true)
1092 << " vs " << TypeIdToString(inputs[1]->type, true);
1093 }
1094 }
1095
InferShape(const NodePtrList &,const DAttrs & attrs)1096 std::vector<DShape> StandardNormalOp::InferShape(const NodePtrList &, const DAttrs &attrs) {
1097 CHECK_ATTR(attrs, "shape");
1098 return {GetListInt(attrs.find("shape")->second)};
1099 }
1100
template <typename TM>
tensor::TensorPtr StridedSliceOnnxOp::CalcStridedSliceOnnx(const NodePtrList &inputs, const DAttrs &) const {
  // Constant-folds an ONNX-style StridedSlice over const tensor inputs:
  // inputs = {data, begin, end, axes, strides}. The begin/end/axes/strides
  // buffers are read as int32. TM is the element type of data (and output).
  // Returns nullptr when any begin/end/stride entry is negative.
  constexpr size_t input_index = 0;
  constexpr size_t begin_index = 1;
  constexpr size_t end_index = 2;
  constexpr size_t axes_index = 3;
  constexpr size_t stride_index = 4;

  ShapeVector input_shape = inputs[input_index]->shape;
  std::vector<int> begin = ChangeDataToVec<int, int>(inputs[begin_index]);
  std::vector<int> end = ChangeDataToVec<int, int>(inputs[end_index]);
  std::vector<int> axes = ChangeDataToVec<int, int>(inputs[axes_index]);
  std::vector<int> stride = ChangeDataToVec<int, int>(inputs[stride_index]);

  // info: normalized axis -> set of index positions kept along that axis.
  // Axes absent from the map are copied through untouched.
  std::unordered_map<int, std::unordered_set<size_t>> info;
  for (size_t i = 0; i < axes.size(); i++) {
    int axis = axes[i] < 0 ? axes[i] + SizeToInt(input_shape.size()) : axes[i];
    if (begin[i] < 0 || end[i] < 0 || stride[i] < 0) {
      MS_LOG(INFO) << "Only do infervalue for StridedSliceOnnx when begin, end and stride are non-negative.";
      return nullptr;
    }
    std::unordered_set<size_t> pos;
    int index = begin[i];
    // NOTE(review): a zero stride would loop forever here — presumably
    // validated upstream; confirm.
    while (index < end[i]) {
      (void)pos.insert(IntToSize(index));
      index += stride[i];
    }
    (void)info.emplace(axis, pos);
  }

  TM *input_x =
    static_cast<TM *>(std::static_pointer_cast<inner::ConstTensorNode>(inputs[input_index])->data()->data_c());

  std::vector<TM> res;

  // Depth-first walk of the input dims. `offset` is the flattened offset of
  // the fixed index prefix; at the last dim the kept elements go into `res`.
  std::function<void(size_t, size_t)> func;
  func = [&func, &input_x, &res, &info, &input_shape](size_t dim, size_t offset) {
    if ((dim + 1) == input_shape.size()) {
      for (size_t i = 0; i < LongToSize(input_shape[dim]); i++) {
        if (info.count(SizeToInt(dim)) > 0) {
          if (info[SizeToInt(dim)].count(i) > 0) {
            (void)res.emplace_back(input_x[offset + i]);
          }
        } else {
          (void)res.emplace_back(input_x[offset + i]);
        }
      }
    } else if ((dim + 1) < input_shape.size()) {
      // accu = number of elements spanned by one step of the current dim.
      size_t accu = 1;
      for (size_t j = dim + 1; j < input_shape.size(); j++) {
        accu *= LongToSize(input_shape[j]);
      }
      for (size_t i = 0; i < LongToSize(input_shape[dim]); i++) {
        if (info.count(SizeToInt(dim)) > 0) {
          if (info[SizeToInt(dim)].count(i) > 0) {
            func(dim + 1, offset + i * accu);
          }
        } else {
          func(dim + 1, offset + i * accu);
        }
      }
    }
    return;
  };
  func(0, 0);
  // NOTE(review): &res[0] is undefined behavior when the slice selects
  // nothing — confirm an empty result cannot occur here.
  return std::make_shared<tensor::Tensor>(this->type, this->shape, &res[0], this->type);
}
1168
InferValue(const NodePtrList & inputs,const DAttrs & attrs)1169 NodePtr StridedSliceOnnxOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) {
1170 for (auto i : inputs) {
1171 if (i->NodeType() != NType::Tensor) {
1172 return nullptr;
1173 }
1174 }
1175 TypeId output_type = this->type;
1176 tensor::TensorPtr res = nullptr;
1177 switch (static_cast<int>(output_type)) {
1178 case TypeId::kNumberTypeUInt8: {
1179 res = CalcStridedSliceOnnx<uint8_t>(inputs, attrs);
1180 break;
1181 }
1182 case TypeId::kNumberTypeInt8: {
1183 res = CalcStridedSliceOnnx<int8_t>(inputs, attrs);
1184 break;
1185 }
1186 case TypeId::kNumberTypeInt16: {
1187 res = CalcStridedSliceOnnx<int16_t>(inputs, attrs);
1188 break;
1189 }
1190 case TypeId::kNumberTypeInt32: {
1191 res = CalcStridedSliceOnnx<int32_t>(inputs, attrs);
1192 break;
1193 }
1194 case TypeId::kNumberTypeInt64: {
1195 res = CalcStridedSliceOnnx<int64_t>(inputs, attrs);
1196 break;
1197 }
1198 case TypeId::kNumberTypeUInt16: {
1199 res = CalcStridedSliceOnnx<uint16_t>(inputs, attrs);
1200 break;
1201 }
1202 case TypeId::kNumberTypeUInt32: {
1203 res = CalcStridedSliceOnnx<uint32_t>(inputs, attrs);
1204 break;
1205 }
1206 case TypeId::kNumberTypeUInt64: {
1207 res = CalcStridedSliceOnnx<uint64_t>(inputs, attrs);
1208 break;
1209 }
1210 case TypeId::kNumberTypeFloat16: {
1211 res = CalcStridedSliceOnnx<float16>(inputs, attrs);
1212 break;
1213 }
1214 case TypeId::kNumberTypeFloat32: {
1215 res = CalcStridedSliceOnnx<float>(inputs, attrs);
1216 break;
1217 }
1218 case TypeId::kNumberTypeFloat64: {
1219 res = CalcStridedSliceOnnx<double>(inputs, attrs);
1220 break;
1221 }
1222 case TypeId::kNumberTypeBFloat16: {
1223 res = CalcStridedSliceOnnx<bfloat16>(inputs, attrs);
1224 break;
1225 }
1226 default:
1227 return nullptr;
1228 }
1229 return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
1230 }
1231
void MatMulOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
  // The backend infer function takes transpose_a / transpose_b as trailing
  // inputs, so append both attributes (as abstracts) in that order.
  CHECK_ATTR(prim->attrs(), "transpose_a");
  (void)abs_list->emplace_back(prim->GetAttr("transpose_a")->ToAbstract());
  CHECK_ATTR(prim->attrs(), "transpose_b");
  (void)abs_list->emplace_back(prim->GetAttr("transpose_b")->ToAbstract());
}
1238
std::vector<DShape> MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  // The primitive's infer-shape does not support batch dims, so for inputs of
  // rank > 2 we strip the leading batch dims, infer on the trailing 2-D
  // matrices, then prepend the broadcast batch shape to the result.
  constexpr size_t kMatMulRank = 2;
  if (inputs[0]->shape.size() > kMatMulRank || inputs[1]->shape.size() > kMatMulRank) {
    NodePtrList new_inputs = inputs;
    std::vector<DShape> batches(inputs.size());
    // Replace input i with a rank-2 view of its trailing dims and remember
    // the leading batch dims in batches[i]; rank <= 2 inputs are left as-is.
    auto cut_batches = [&new_inputs, &batches, kMatMulRank](size_t i) -> void {
      const auto &shape_i = new_inputs[i]->shape;
      if (shape_i.size() > kMatMulRank) {
        DShape real_shape(shape_i.cend() - kMatMulRank, shape_i.cend());
        new_inputs[i] = std::make_shared<inner::Node>(NodeBase{real_shape, new_inputs[i]->type, new_inputs[i]->format});
        batches[i].assign(shape_i.cbegin(), shape_i.cend() - kMatMulRank);
      }
    };

    cut_batches(0);
    cut_batches(1);
    // Batch ranks must match exactly; per-dim broadcasting is only allowed
    // when one side's dim is 1.
    if (batches[0].size() != batches[1].size()) {
      MS_LOG(EXCEPTION) << "The Matmul's batch rank should be equal, but got " << batches[0].size() << " vs "
                        << batches[1].size();
    }
    DShape batch;
    for (size_t i = 0; i < batches[0].size(); i++) {
      if (batches[0][i] != batches[1][i]) {
        if (batches[0][i] != 1 && batches[1][i] != 1) {
          MS_LOG(EXCEPTION) << "The Matmul's batch dim is unmatched. got " << inputs[0]->shape << " and "
                            << inputs[1]->shape;
        }
      }
      batch.push_back(std::max(batches[0][i], batches[1][i]));
    }

    // Infer the 2-D result shape, then prepend the broadcast batch dims.
    auto out_shape = PrimOp::InferShape(new_inputs, attrs)[0];
    // just reuse the `batch` vector
    (void)batch.insert(batch.end(), out_shape.begin(), out_shape.end());
    return {batch};
  }
  return PrimOp::InferShape(inputs, attrs);
}
1278
InferType(const NodePtrList & inputs,const DAttrs & attrs)1279 std::vector<TypeId> MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
1280 if (attrs.count("dst_type") != 0) {
1281 return {attrs.find("dst_type")->second->cast<TypePtr>()->type_id()};
1282 }
1283 if (inputs[0]->type == TypeId::kNumberTypeInt8) {
1284 return {TypeId::kNumberTypeInt32};
1285 }
1286 return {inputs[0]->type};
1287 }
1288 } // namespace mindspore::graphkernel::inner
1289