• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/model/op_node.h"
17 
18 #include <cmath>
19 #include <algorithm>
20 #include <memory>
21 #include <set>
22 #include <unordered_set>
23 #include <unordered_map>
24 #include <sstream>
25 #include <functional>
26 #include <numeric>
27 #include <utility>
28 
29 #include "abstract/ops/primitive_infer_map.h"
30 #include "utils/anf_utils.h"
31 #include "utils/hash_map.h"
32 #include "utils/check_convert_utils.h"
33 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
34 #include "backend/common/graph_kernel/model/node.h"
35 #include "backend/operator/ops_backend_infer_function.h"
36 #include "utils/log_adapter.h"
37 #include "ops/auto_generate/gen_ops_primitive.h"
38 
39 namespace mindspore::graphkernel::inner {
GetListInt(const ValuePtr & attr_value)40 std::vector<int64_t> GetListInt(const ValuePtr &attr_value) {
41   std::vector<int64_t> list_int;
42   const auto &vals = attr_value->cast<ValueSequencePtr>()->value();
43   (void)std::transform(vals.begin(), vals.end(), std::back_inserter(list_int),
44                        [](const ValuePtr &v) { return AnfUtils::GetIntValue(v); });
45   return list_int;
46 }
47 
InferShapeWithAbstract(const PrimitivePtr & prim,const AbstractBasePtrList & abs_list)48 BaseShapePtr InferShapeWithAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &abs_list) {
49   auto shape_optional = abstract::InferShapeByFuncImpl(prim, abs_list, true);
50   if (shape_optional.has_value()) {
51     return shape_optional.value();
52   }
53 
54   auto found = abstract::GetBackendPrimitiveInferImpl(prim);
55   if (found.has_value()) {
56     auto infer = found.value();
57     if (infer.IsImplInferShapeAndType()) {
58       return infer.InferShape(prim, abs_list);
59     }
60   }
61   MS_LOG(EXCEPTION) << "The infer function of [" << prim->name() << "] is not defined.";
62   return nullptr;
63 }
64 
InferTypeWithAbstract(const PrimitivePtr & prim,const AbstractBasePtrList & abs_list)65 TypePtr InferTypeWithAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &abs_list) {
66   auto type_optional = abstract::InferTypeByFuncImpl(prim, abs_list, true);
67   if (type_optional.has_value()) {
68     return type_optional.value();
69   }
70 
71   auto found = abstract::GetBackendPrimitiveInferImpl(prim);
72   if (found.has_value()) {
73     auto infer = found.value();
74     if (infer.IsImplInferShapeAndType()) {
75       return infer.InferType(prim, abs_list);
76     }
77   }
78   MS_LOG(EXCEPTION) << "The infer function of [" << prim->name() << "] is not defined.";
79   return nullptr;
80 }
81 
InferValueWithAbstract(const PrimitivePtr & prim,const AbstractBasePtrList & abs_list)82 tensor::TensorPtr InferValueWithAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &abs_list) {
83   auto value_optional = abstract::InferValueByFuncImpl(prim, abs_list);
84   if (value_optional.has_value()) {
85     return std::static_pointer_cast<tensor::Tensor>(value_optional.value());
86   }
87 
88   auto found = abstract::GetBackendPrimitiveInferImpl(prim);
89   if (found.has_value()) {
90     auto infer = found.value();
91     if (infer.IsImplInferValue()) {
92       return std::static_pointer_cast<tensor::Tensor>(infer.InferValue(prim, abs_list));
93     }
94   }
95   return nullptr;
96 }
97 
GenPrimAndAbstract(const NodePtrList & inputs,const DAttrs & attrs) const98 std::pair<PrimitivePtr, AbstractBasePtrList> PrimOp::GenPrimAndAbstract(const NodePtrList &inputs,
99                                                                         const DAttrs &attrs) const {
100   auto prim = std::make_shared<Primitive>(op_);
101   MS_EXCEPTION_IF_NULL(prim);
102   (void)prim->SetAttrs(attrs);
103   AbstractBasePtrList abs_list(inputs.size());
104   (void)std::transform(inputs.cbegin(), inputs.cend(), abs_list.begin(),
105                        [](const NodePtr &node) { return node->ToAbstract(); });
106   return std::make_pair(prim, abs_list);
107 }
108 
InferShape(const NodePtrList & inputs,const DAttrs & attrs)109 std::vector<DShape> PrimOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
110   auto [prim, abs_list] = GenPrimAndAbstract(inputs, attrs);
111   RectifyAbstract(prim, &abs_list);
112   auto baseshape = InferShapeWithAbstract(prim, abs_list);
113   MS_EXCEPTION_IF_NULL(baseshape);
114   if (baseshape->isa<abstract::TupleShape>()) {
115     auto tuple_shape = baseshape->cast<abstract::TupleShapePtr>();
116     MS_EXCEPTION_IF_NULL(tuple_shape);
117     const auto &shape_elements = tuple_shape->shape();
118     std::vector<DShape> result(shape_elements.size());
119     (void)std::transform(shape_elements.cbegin(), shape_elements.cend(), result.begin(),
120                          [](const BaseShapePtr &s) { return s->cast<abstract::ShapePtr>()->shape(); });
121     return result;
122   }
123   auto shape = baseshape->cast<abstract::ShapePtr>();
124   if (shape != nullptr) {
125     return {shape->shape()};
126   }
127   return {DShape()};
128 }
129 
InferType(const NodePtrList & inputs,const DAttrs & attrs)130 std::vector<TypeId> PrimOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
131   auto [prim, abs_list] = GenPrimAndAbstract(inputs, attrs);
132   RectifyAbstract(prim, &abs_list);
133   auto type = InferTypeWithAbstract(prim, abs_list);
134   MS_EXCEPTION_IF_NULL(type);
135   auto get_type_id = [](const TypePtr &t) {
136     return t->isa<TensorType>() ? t->cast<TensorTypePtr>()->element()->type_id() : t->type_id();
137   };
138   if (type->isa<Tuple>()) {
139     auto elements = type->cast<TuplePtr>()->elements();
140     std::vector<TypeId> result(elements.size());
141     (void)std::transform(elements.cbegin(), elements.cend(), result.begin(), get_type_id);
142     return result;
143   }
144   return {get_type_id(type)};
145 }
146 
Infer(const NodePtrList & inputs,const DAttrs & attrs)147 NodeBaseList PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
148   Check(inputs, attrs);
149   NodeBaseList result;
150   auto format = InferFormat(inputs, attrs);
151   auto shapes = InferShape(inputs, attrs);
152   auto types = InferType(inputs, attrs);
153   if (shapes.size() != types.size()) {
154     MS_LOG(EXCEPTION) << "The num of shapes and types should be equal. (" << shapes.size() << " vs " << types.size()
155                       << ")";
156   }
157   for (size_t i = 0; i < shapes.size(); i++) {
158     (void)result.emplace_back(NodeBase{shapes[i], types[i], format});
159   }
160   return result;
161 }
162 
ToString() const163 std::string PrimOp::ToString() const {
164   std::ostringstream oss;
165   oss << Node::ToString();
166   oss << " = " << this->op_ << "(";
167   for (size_t i = 0; i < inputs_.size(); i++) {
168     if (inputs_[i]->NodeType() == NType::Primitive) {
169       oss << inputs_[i]->Node::ToString();
170     } else {
171       oss << inputs_[i]->ToString();
172     }
173     if (i != inputs_.size() - 1) {
174       oss << ", ";
175     }
176   }
177   oss << ")";
178   std::ostringstream attr_oss;
179   bool has_attr = false;
180   std::set<std::string> black_list = {"IsFeatureMapInputList", "IsFeatureMapOutput", "output_names", "input_names"};
181   for (auto attr : attrs_) {
182     if (attr.second != nullptr && black_list.count(attr.first) == 0) {
183       if (has_attr) {
184         attr_oss << ", ";
185       } else {
186         has_attr = true;
187       }
188       attr_oss << attr.first << ": " << attr.second->ToString();
189     }
190   }
191   if (has_attr) {
192     oss << "  // attr {" << attr_oss.str() << "}";
193   }
194   return oss.str();
195 }
196 
197 template <typename TD, typename TE>
ChangeDataToVec(const NodePtr & n)198 std::vector<TE> ChangeDataToVec(const NodePtr &n) {
199   std::vector<TE> res;
200   TD *data = static_cast<TD *>(std::static_pointer_cast<inner::ConstTensorNode>(n)->data()->data_c());
201   for (size_t elem = 0; elem < n->tensor_size(); elem++) {
202     res.push_back(static_cast<TE>(*(data + elem)));
203   }
204   return res;
205 }
206 
207 template <typename TM>
CalcByOperator(const NodePtrList & inputs,const DAttrs &) const208 tensor::TensorPtr PrimOp::CalcByOperator(const NodePtrList &inputs, const DAttrs &) const {
209   const size_t unary_input_num = 1;
210   const size_t binary_input_num = 2;
211   if (inputs.size() > 0) {
212     bool all_shape_equal =
213       std::all_of(inputs.begin(), inputs.end(), [&inputs](const NodePtr &t) { return t->shape == inputs[0]->shape; });
214     if (!all_shape_equal) {
215       return nullptr;
216     }
217   }
218   std::vector<std::vector<TM>> inputs_tm;
219   const auto &op = this->op();
220   const auto tid = this->type;
221   for (const auto &t : inputs) {
222     (void)inputs_tm.emplace_back(ChangeDataToVec<TM, TM>(t));
223   }
224   if (inputs.size() == unary_input_num) {
225     mindspore::HashMap<std::string, std::function<TM(const TM &)>> func_map = {
226       {"Abs", [](const TM &a) { return a <= TM(0) ? -a : a; }},
227       {"Exp", [](const TM &a) { return exp(a); }},
228       {"Log", [](const TM &a) { return log(a); }},
229       {"Neg", [](const TM &a) { return -a; }},
230       {"Reciprocal",
231        [](const TM &a) {
232          if (a == TM(0)) {
233            MS_LOG(EXCEPTION) << "During graph kernel constant fold for reciprocal, divisor is zero.";
234          }
235          return TM(1) / a;
236        }},
237       {"Rsqrt",
238        [](const TM &a) {
239          if (a == TM(0)) {
240            MS_LOG(EXCEPTION) << "During graph kernel constant fold for rsqrt, divisor is zero.";
241          }
242          return TM(1) / sqrt(a);
243        }},
244       {"Sqrt", [](const TM &a) { return sqrt(a); }},
245     };
246     if (func_map.find(op) == func_map.end()) {
247       return nullptr;
248     }
249     const auto &input_a = inputs_tm[0];
250     std::vector<TM> res;
251     (void)std::transform(input_a.begin(), input_a.end(), std::back_inserter(res),
252                          [&func_map, &op](const TM &i) { return func_map[op](i); });
253     return std::make_shared<tensor::Tensor>(tid, this->shape, &res[0], tid);
254   } else if (inputs.size() == binary_input_num) {
255     mindspore::HashMap<std::string, std::function<TM(const TM &, const TM &)>> func_map = {
256       {"Add", [](const TM &a, const TM &b) { return a + b; }},
257       {"Sub", [](const TM &a, const TM &b) { return a - b; }},
258       {"Mul", [](const TM &a, const TM &b) { return a * b; }},
259       {"RealDiv",
260        [](const TM &a, const TM &b) {
261          if (b == TM(0)) {
262            MS_LOG(EXCEPTION) << "During graph kernel constant fold for realdiv, divisor is zero.";
263          }
264          return a / b;
265        }},
266     };
267     if (func_map.find(op) == func_map.end()) {
268       return nullptr;
269     }
270     const auto &input_a = inputs_tm[0];
271     const auto &input_b = inputs_tm[1];
272     std::vector<TM> res;
273     for (size_t i = 0; i < input_a.size(); i++) {
274       (void)res.emplace_back(func_map[op](input_a[i], input_b[i]));
275     }
276     return std::make_shared<tensor::Tensor>(tid, this->shape, &res[0], tid);
277   }
278   return nullptr;
279 }
280 
// Tries to constant-fold this op when every input is a constant tensor.
// First dispatches to the templated local calculator (CalcByOperator) on the
// concrete output dtype; if the op is unsupported there, falls back to the
// primitive's registered value-infer. Returns the folded value wrapped in a
// ConstTensorNode, or nullptr when folding is not possible.
NodePtr PrimOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) {
  // Folding requires every input to be a constant tensor.
  for (auto i : inputs) {
    if (i->NodeType() != NType::Tensor) {
      return nullptr;
    }
  }
  TypeId output_type = this->type;
  tensor::TensorPtr res = nullptr;
  // Dispatch on the output dtype; each case instantiates the calculator for
  // the matching C++ element type.
  switch (static_cast<int>(output_type)) {
    case TypeId::kNumberTypeUInt8: {
      res = CalcByOperator<uint8_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeInt8: {
      res = CalcByOperator<int8_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeInt16: {
      res = CalcByOperator<int16_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeInt32: {
      res = CalcByOperator<int32_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeInt64: {
      res = CalcByOperator<int64_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeUInt16: {
      res = CalcByOperator<uint16_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeUInt32: {
      res = CalcByOperator<uint32_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeUInt64: {
      res = CalcByOperator<uint64_t>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeFloat16: {
      res = CalcByOperator<float16>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeFloat32: {
      res = CalcByOperator<float>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeFloat64: {
      res = CalcByOperator<double>(inputs, attrs);
      break;
    }
    case TypeId::kNumberTypeBFloat16: {
      res = CalcByOperator<bfloat16>(inputs, attrs);
      break;
    }
    default:
      // Dtype unsupported by the local calculator: give up on folding entirely.
      return nullptr;
  }
  if (res == nullptr) {
    // Local calculation did not cover this op (or shapes mismatched): try the
    // primitive's registered InferValue implementation instead.
    auto [prim, inputs_abstract] = GenPrimAndAbstract(inputs, attrs);
    RectifyAbstract(prim, &inputs_abstract);
    res = InferValueWithAbstract(prim, inputs_abstract);
  }
  return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
}
348 
InferValue(const NodePtrList & inputs,const DAttrs &)349 NodePtr ReshapeOp::InferValue(const NodePtrList &inputs, const DAttrs &) {
350   if (inputs[0]->NodeType() != NType::Tensor) {
351     return nullptr;
352   }
353   void *tensor_data = inputs[0]->As<inner::ConstTensorNode>()->data()->data_c();
354   tensor::TensorPtr result_tensor = std::make_shared<tensor::Tensor>(this->type, this->shape, tensor_data, this->type);
355   return std::make_shared<ConstTensorNode>(result_tensor);
356 }
357 
358 // default format shape to fractal_Nz format shape
ToNz(const DShape & default_shape)359 DShape ToNz(const DShape &default_shape) {
360   constexpr size_t nz_size = 2;
361   constexpr auto align16 = 16;
362   auto len = default_shape.size();
363   DShape leading_shape;
364   DShape tail_shape;
365   if (default_shape.size() == 1 && default_shape[0] == 1) {
366     // # As shape (1,) can broadcast to any shape, it can be regarded as a special FractalNZ shape
367     return default_shape;
368   }
369   if (default_shape.size() > nz_size) {
370     (void)leading_shape.insert(leading_shape.cend(), default_shape.cbegin(),
371                                default_shape.cend() - SizeToLong(nz_size));
372   }
373   if (default_shape.size() == 1 || (default_shape.size() >= nz_size && default_shape[len - nz_size] == 1)) {
374     // (32) or (N, 1, 32) -> (N, 2, 1, 1, 16)
375     if (default_shape.back() % align16 != 0) {
376       MS_LOG(EXCEPTION) << "default_shape[-1] should be multiplies of 16, but got " << default_shape.back();
377     }
378     tail_shape = {default_shape.back() / align16, 1, 1, align16};
379   } else if (default_shape.size() >= nz_size || default_shape[1] == 1) {
380     // (N, 32, 1) -> (N, 1, 2, 16, 1)
381     if (default_shape[len - nz_size] % align16 != 0) {
382       MS_LOG(EXCEPTION) << "default_shape[-2] should be multiplies of 16, but got " << default_shape[len - nz_size];
383     }
384     tail_shape = {1, default_shape[0] / align16, align16, 1};
385   } else {
386     // (N, 32, 48) -> (N, 3, 2, 16, 16)
387     if (default_shape.back() % align16 != 0 || default_shape[len - nz_size] % align16 != 0) {
388       MS_LOG(EXCEPTION) << "default_shape[-1] and default_shape[-2]should be multiplies of 16, but got "
389                         << default_shape.back() << " " << default_shape[len - nz_size];
390     }
391     tail_shape = {default_shape[1] / align16, default_shape[0] / align16, align16, align16};
392   }
393   (void)leading_shape.insert(leading_shape.cend(), tail_shape.cbegin(), tail_shape.cend());
394   return leading_shape;
395 }
396 
BroadcastShape(const NodePtrList & inputs,bool to_nz=false)397 DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) {
398   std::vector<std::vector<int64_t>> shapes;
399   for (auto &input : inputs) {
400     if (to_nz && input->format != kOpFormat_FRAC_NZ) {
401       (void)shapes.emplace_back(ToNz(input->shape));
402     } else {
403       (void)shapes.emplace_back(input->shape);
404     }
405   }
406   auto max_dim_input =
407     std::max_element(shapes.begin(), shapes.end(),
408                      [](const std::vector<int64_t> &a, const std::vector<int64_t> &b) { return a.size() < b.size(); });
409   auto max_dim = max_dim_input->size();
410   std::vector<std::vector<int64_t>> align_shapes;
411   for (auto &s : shapes) {
412     std::vector<int64_t> cur(max_dim - s.size(), 1);
413     (void)cur.insert(cur.cend(), s.cbegin(), s.cend());
414     (void)align_shapes.emplace_back(cur);
415   }
416   std::vector<int64_t> output_shape(max_dim, 1);
417   for (size_t i = 0; i < max_dim; i++) {
418     for (auto &align_shape : align_shapes) {
419       if (align_shape[i] > 1) {
420         if (output_shape[i] == 1) {
421           output_shape[i] = align_shape[i];
422         }
423         if (output_shape[i] != align_shape[i]) {
424           MS_LOG(EXCEPTION) << "Shape broadcast failed: " << output_shape[i] << " vs " << align_shape[i];
425         }
426       }
427     }
428   }
429   return output_shape;
430 }
431 
InferShape(const NodePtrList & inputs,const DAttrs & attrs)432 std::vector<DShape> ElemwiseOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
433   if (std::any_of(inputs.begin(), inputs.end(),
434                   [](const NodePtr &input) { return input->format == kOpFormat_FRAC_NZ; })) {
435     return {BroadcastShape(inputs, true)};
436   }
437   return PrimOp::InferShape(inputs, attrs);
438 }
439 
InferFormat(const NodePtrList & inputs,const DAttrs &)440 DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &) {
441   if (inputs.empty()) {
442     return kOpFormat_DEFAULT;
443   }
444   auto first_format = inputs[0]->format;
445   for (const auto &inp : inputs) {
446     auto cur_format = inp->format;
447     if (cur_format.find("FRACTAL") != std::string::npos) {
448       // special format
449       return cur_format;
450     }
451     if (cur_format != kOpFormat_DEFAULT && inp->tensor_size() != 1) {
452       return cur_format;
453     }
454   }
455   return first_format;
456 }
457 
InferShape(const NodePtrList & inputs,const DAttrs & attrs)458 std::vector<DShape> ArgReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
459   CHECK_ATTR(attrs, "axis");
460   auto axis = GetListInt(attrs.find("axis")->second);
461   const auto &input_shape = inputs[0]->shape;
462   int64_t size = SizeToLong(input_shape.size());
463   std::vector<int64_t> real_axis;
464   (void)std::transform(axis.begin(), axis.end(), std::back_inserter(real_axis),
465                        [&size](const int64_t &x) { return x < 0 ? (x + size) : x; });
466 
467   DShape new_shape;
468   for (size_t i = 0; i < input_shape.size(); i++) {
469     if (std::find(real_axis.begin(), real_axis.end(), SizeToLong(i)) == real_axis.end()) {
470       (void)new_shape.emplace_back(input_shape[i]);
471     }
472   }
473   if (new_shape.empty()) {
474     (void)new_shape.emplace_back(1);
475   }
476   return {new_shape};
477 }
478 
InferType(const NodePtrList &,const DAttrs & attrs)479 std::vector<TypeId> ArgReduceOp::InferType(const NodePtrList &, const DAttrs &attrs) {
480   CHECK_ATTR(attrs, "output_type");
481   return {attrs.find("output_type")->second->cast<TypePtr>()->type_id()};
482 }
483 
InferFormat(const NodePtrList & inputs,const DAttrs & attrs)484 DFormat TransposeOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) {
485   if (attrs.count(kAttrDstFormat) != 0) {
486     return GetValue<std::string>(attrs.find(kAttrDstFormat)->second);
487   }
488   // only support NCHW/NHWC now
489   constexpr size_t kRank4 = 4;
490   if (inputs[0]->shape.size() != kRank4) {
491     return kOpFormat_DEFAULT;
492   }
493   auto perm_node = inputs[1];
494   auto perm_tensor = perm_node->As<inner::ConstTensorNode>()->data();
495   auto perm = CheckAndConvertUtils::CheckTensorIntValue("permutation", perm_tensor, "Transpose");
496   const auto &ori_format = inputs[0]->format;
497   if (ori_format == kOpFormat_DEFAULT || ori_format == kOpFormat_NCHW) {
498     std::vector<int64_t> nchw2nhwc = {0, 2, 3, 1};
499     if (perm == nchw2nhwc) {
500       return kOpFormat_NHWC;
501     }
502   } else if (ori_format == kOpFormat_NHWC) {
503     std::vector<int64_t> nhwc2nchw = {0, 3, 1, 2};
504     if (perm == nhwc2nchw) {
505       return kOpFormat_NCHW;
506     }
507   }
508   return kOpFormat_DEFAULT;
509 }
510 
InferValue(const NodePtrList & inputs,const DAttrs & attrs)511 NodePtr ConstantOfShapeOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) {
512   for (auto i : inputs) {
513     if (i->NodeType() != NType::Tensor) {
514       return nullptr;
515     }
516   }
517   const auto &value = GetValue<std::vector<float>>(attrs.find("value")->second);
518   std::vector<float> res;
519   size_t elem_num = LongToSize(std::accumulate(this->shape.begin(), this->shape.end(), 1, std::multiplies<int64_t>()));
520   if (value.size() == 1) {
521     res = std::vector<float>(elem_num, value[0]);
522   } else if (value.size() == elem_num) {
523     res = value;
524   } else {
525     return nullptr;
526   }
527   auto tensor = std::make_shared<tensor::Tensor>(this->type, this->shape, &res[0], kNumberTypeFloat32);
528   return std::make_shared<ConstTensorNode>(tensor);
529 }
530 
InferShape(const NodePtrList & inputs,const DAttrs & attrs)531 std::vector<DShape> ConstantOfShapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
532   const auto &value = attrs.find("shape")->second;
533   std::vector<int64_t> res;
534   if (value->isa<ValueSequence>()) {
535     res = GetValue<std::vector<int64_t>>(value);
536     return {res};
537   } else if (value->isa<tensor::Tensor>()) {
538     auto tvalue = value->cast<tensor::TensorPtr>();
539     if (tvalue->data_type_c() == static_cast<int>(TypeId::kNumberTypeInt32)) {
540       int *data = static_cast<int *>(tvalue->data_c());
541       for (size_t elem = 0; elem < tvalue->DataSize(); elem++) {
542         res.push_back(IntToLong(*(data + elem)));
543       }
544       return {res};
545     } else if (tvalue->data_type_c() == static_cast<int>(TypeId::kNumberTypeInt64)) {
546       int64_t *data = static_cast<int64_t *>(tvalue->data_c());
547       res = std::vector<int64_t>(data, data + tvalue->DataSize());
548       return {res};
549     }
550   }
551   return PrimOp::InferShape(inputs, attrs);
552 }
553 
// Constant-folds Shape: the output tensor's data is the input node's shape
// values, stored as int64. `this->shape`/`this->type` are the already-inferred
// 1-D output shape and dtype of this Shape node.
NodePtr ShapeOp::InferValue(const NodePtrList &inputs, const DAttrs &) {
  auto tensor = std::make_shared<tensor::Tensor>(this->type, this->shape, inputs[0]->shape.data(), kNumberTypeInt64);
  return std::make_shared<ConstTensorNode>(tensor);
}
558 
InferShape(const NodePtrList & inputs,const DAttrs & attrs)559 std::vector<DShape> PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
560   std::vector<int64_t> shape0 = inputs[0]->shape;
561   size_t n = shape0.size();
562   CHECK_ATTR(attrs, "head");
563   CHECK_ATTR(attrs, "tail");
564   std::vector<int64_t> pad_before = GetListInt(attrs.find("head")->second);
565   std::vector<int64_t> pad_after = GetListInt(attrs.find("tail")->second);
566   if (pad_before.size() != n || pad_after.size() != n) {
567     MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << pad_before.size() << " vs "
568                       << pad_after.size();
569   }
570   std::vector<int64_t> output;
571   for (size_t i = 0; i < n; i++) {
572     (void)output.emplace_back(shape0[i] + pad_before[i] + pad_after[i]);
573   }
574   return {output};
575 }
576 
InferShape(const NodePtrList & inputs,const DAttrs & attrs)577 std::vector<DShape> UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
578   std::vector<int64_t> shape0 = inputs[0]->shape;
579   size_t n = shape0.size();
580   CHECK_ATTR(attrs, "tail");
581   std::vector<int64_t> unpad_after = GetListInt(attrs.find("tail")->second);
582   if (unpad_after.size() != n) {
583     MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << unpad_after.size();
584   }
585   std::vector<int64_t> output;
586   for (size_t i = 0; i < n; i++) {
587     (void)output.emplace_back(shape0[i] - unpad_after[i]);
588   }
589   return {output};
590 }
591 
HadPad(const ShapeVector & pad_list,const std::string & pad_mode)592 bool Conv2dOp::HadPad(const ShapeVector &pad_list, const std::string &pad_mode) {
593   constexpr size_t kTop = 0;
594   constexpr size_t kBottom = 1;
595   constexpr size_t kLeft = 2;
596   constexpr size_t kRight = 3;
597 
598   if (pad_list[kTop] != pad_list[kBottom] || pad_list[kLeft] != pad_list[kRight]) {
599     return true;
600   }
601   if (pad_mode != "VALID" && pad_mode != "valid") {
602     return std::any_of(pad_list.begin(), pad_list.end(), [](auto a) { return a != 0; });
603   }
604   return false;
605 }
606 
// Infers the output shape of Conv2D for two layout families:
//  - 4-D input (NHWC path): computes the spatial output dims from kernel size,
//    stride, dilation and padding; non-NHWC 4-D cases defer to OpaqueOp.
//  - otherwise: assumes a blocked NCHWc data layout and a 6-D fractal weight
//    layout (weight_shape[5] is read as the inner output-channel block —
//    NOTE(review): layout assumption inferred from the indexing; confirm
//    against the producer of these nodes).
std::vector<DShape> Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  // get the output shape when format is NHWC/NCHW
  if (inputs[0]->shape.size() == kDim4) {
    CHECK_ATTR(attrs, "format");
    // The NHWC fast path triggers if either operand or the attr says NHWC.
    if (inputs[0]->format == kOpFormat_NHWC || inputs[1]->format == kOpFormat_NHWC ||
        GetValue<std::string>(attrs.find("format")->second) == kOpFormat_NHWC) {
      CHECK_ATTR(attrs, "pad_mode");
      CHECK_ATTR(attrs, "pad_list");
      CHECK_ATTR(attrs, "kernel_size");
      CHECK_ATTR(attrs, "stride");
      CHECK_ATTR(attrs, "dilation");

      auto x_shape = inputs[0]->shape;
      auto w_shape = inputs[1]->shape;
      auto pad_mode = GetValue<std::string>(attrs.find("pad_mode")->second);
      auto pad_list = GetListInt(attrs.find("pad_list")->second);
      auto kernel_size = GetListInt(attrs.find("kernel_size")->second);
      auto stride = GetListInt(attrs.find("stride")->second);
      auto dilation = GetListInt(attrs.find("dilation")->second);
      // Expected attribute arities for a 2-D convolution.
      constexpr size_t kPadSize = 4;
      constexpr size_t kKernelSize = 2;
      constexpr size_t kStrideSize = 4;
      constexpr size_t kDilationSize = 4;
      if (x_shape.size() != kDim4 || w_shape.size() != kDim4 || pad_list.size() != kPadSize ||
          kernel_size.size() != kKernelSize || stride.size() != kStrideSize || dilation.size() != kDilationSize) {
        MS_LOG(EXCEPTION) << "For 'Conv2D', got sizes of x_shape, w_shape, pad_list, kernel_size, stride and dilation: "
                          << x_shape.size() << ", " << w_shape.size() << ", " << pad_list.size() << ", "
                          << kernel_size.size() << ", " << stride.size() << ", " << dilation.size()
                          << ". But expect: 4, 4, 4, 2, 4, 4";
      }
      auto has_pad = HadPad(pad_list, pad_mode);
      if (!has_pad) {
        // Padding is effectively disabled (e.g. VALID mode): treat as zero pads.
        pad_list = {0, 0, 0, 0};
      }

      // Effective (dilated) kernel extent, then standard conv output-size formula.
      auto k_h = (kernel_size[0] - 1) * dilation[2] + 1;
      auto k_w = (kernel_size[1] - 1) * dilation[3] + 1;
      auto out_h = (x_shape[1] + pad_list[0] + pad_list[1] - k_h) / stride[2] + 1;
      auto out_w = (x_shape[2] + pad_list[2] + pad_list[3] - k_w) / stride[3] + 1;
      // NHWC output: {N, H_out, W_out, C_out}; C_out comes from the weight's last dim.
      return {{x_shape[0], out_h, out_w, w_shape[3]}};
    } else {
      return OpaqueOp::InferShape(inputs, attrs);
    }
  }

  // get the output shape when format is NCHWc
  std::vector<int64_t> data_shape = inputs[0]->shape;
  std::vector<int64_t> weight_shape = inputs[1]->shape;
  auto n = data_shape[0];
  auto i_h = data_shape[2];
  auto i_w = data_shape[3];
  auto c_o_o = weight_shape[0];  // outer output-channel blocks
  auto k_h = weight_shape[2];
  auto k_w = weight_shape[3];
  auto c_o_i = weight_shape[5];  // inner output-channel block size

  CHECK_ATTR(attrs, "stride");
  CHECK_ATTR(attrs, "dilation");

  std::vector<int64_t> strides = GetListInt(attrs.find("stride")->second);
  std::vector<int64_t> dilations = GetListInt(attrs.find("dilation")->second);

  auto d_h = dilations[0];
  auto d_w = dilations[1];
  auto s_h = strides[0];
  auto s_w = strides[1];
  // No padding on this path: VALID-style output-size formula with dilated kernel.
  auto k_h_d = (k_h - 1) * d_h + 1;
  auto k_w_d = (k_w - 1) * d_w + 1;
  auto o_h = (i_h - k_h_d) / s_h + 1;
  auto o_w = (i_w - k_w_d) / s_w + 1;

  std::vector<int64_t> output_shape{n, c_o_o, o_h, o_w, c_o_i};
  return {output_shape};
}
681 
InferType(const NodePtrList & inputs,const DAttrs & attrs)682 std::vector<TypeId> Conv2dOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
683   if (inputs[0]->shape.size() == kDim4) {
684     return PrimOp::InferType(inputs, attrs);
685   }
686   return {inputs[0]->type};
687 }
688 
InferFormat(const NodePtrList & inputs,const DAttrs & attrs)689 DFormat Conv2dOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) {
690   if (inputs[0]->shape.size() == kDim4) {
691     return PrimOp::InferFormat(inputs, attrs);
692   }
693   CHECK_ATTR(attrs, "conv_out_format");
694   return GetValue<std::string>(attrs.find("conv_out_format")->second);
695 }
696 
RectifyAbstract(const PrimitivePtr &,AbstractBasePtrList * input_abstract_ptr)697 void ConcatOp::RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) {
698   AbstractBasePtrList rectifyed_abs_list;
699   (void)rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(*input_abstract_ptr));
700   input_abstract_ptr->swap(rectifyed_abs_list);
701 }
702 
RectifyAbstract(const PrimitivePtr & prim,AbstractBasePtrList * abs_list)703 void ReduceOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
704   CHECK_ATTR(prim->attrs(), "keep_dims");
705   (void)abs_list->emplace_back(prim->GetAttr("keep_dims")->ToAbstract());
706   if (prim->name() == prim::kPrimReduceSum->name()) {
707     CHECK_ATTR(prim->attrs(), "skip_mode");
708     (void)abs_list->emplace_back(prim->GetAttr("skip_mode")->ToAbstract());
709   }
710 }
711 
RectifyAbstract(const PrimitivePtr & prim,AbstractBasePtrList * abs_list)712 void OneHotOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
713   CHECK_ATTR(prim->attrs(), "axis");
714   (void)abs_list->emplace_back(prim->GetAttr("axis")->ToAbstract());
715 }
716 
RectifyAbstract(const PrimitivePtr & prim,AbstractBasePtrList * abs_list)717 void CumSumOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
718   CHECK_ATTR(prim->attrs(), "exclusive");
719   (void)abs_list->emplace_back(prim->GetAttr("exclusive")->ToAbstract());
720   CHECK_ATTR(prim->attrs(), "reverse");
721   (void)abs_list->emplace_back(prim->GetAttr("reverse")->ToAbstract());
722 }
723 
RectifyAbstract(const PrimitivePtr & prim,AbstractBasePtrList * abs_list)724 void GatherOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
725   CHECK_ATTR(prim->attrs(), "batch_dims");
726   (void)abs_list->emplace_back(prim->GetAttr("batch_dims")->ToAbstract());
727 }
728 
RectifyAbstract(const PrimitivePtr & prim,AbstractBasePtrList * abs_list)729 void ArgReduceOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
730   CHECK_ATTR(prim->attrs(), "axis");
731   (void)abs_list->emplace_back(prim->GetAttr("axis")->ToAbstract());
732   CHECK_ATTR(prim->attrs(), "output_type");
733   (void)abs_list->emplace_back(prim->GetAttr("output_type")->ToAbstract());
734 }
735 
RectifyAbstract(const PrimitivePtr & prim,AbstractBasePtrList * abs_list)736 void PagedAttentionOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
737   constexpr size_t PA_INPUT_NUM = 5;
738   constexpr size_t PA_MASK_INPUT_NUM = 6;
739   if (abs_list->size() == PA_INPUT_NUM || abs_list->size() == PA_MASK_INPUT_NUM) {
740     CHECK_ATTR(prim->attrs(), "head_num");
741     (void)abs_list->emplace_back(prim->GetAttr("head_num")->ToAbstract());
742     CHECK_ATTR(prim->attrs(), "scale_value");
743     (void)abs_list->emplace_back(prim->GetAttr("scale_value")->ToAbstract());
744     CHECK_ATTR(prim->attrs(), "kv_head_num");
745     (void)abs_list->emplace_back(prim->GetAttr("kv_head_num")->ToAbstract());
746   }
747 }
748 
CompactShape(const ShapeVector & origin,int64_t axis)749 std::vector<size_t> CompactShape(const ShapeVector &origin, int64_t axis) {
750   std::vector<size_t> new_shape;
751   size_t accu = 1;
752   for (size_t i = 0; i < origin.size(); i++) {
753     if (LongToSize(axis) == i) {
754       new_shape.push_back(accu);
755       new_shape.push_back(LongToSize(origin[i]));
756       accu = 1;
757     } else {
758       accu *= LongToSize(origin[i]);
759     }
760   }
761   new_shape.push_back(accu);
762   return new_shape;
763 }
764 
// Constant-folds a Gather whose inputs are all known constants: selects slices
// of the param tensor along `axis` according to the indices tensor.
// Returns the folded tensor, or nullptr when folding is not possible (axis
// unavailable, unsupported index dtype, or a shape that does not compact into
// exactly three dims).
template <typename TM>
tensor::TensorPtr GatherOp::CalcGather(const NodePtrList &inputs, const DAttrs &attrs) const {
  constexpr size_t param_index = 0;
  constexpr size_t indice_index = 1;
  constexpr size_t axis_index = 2;
  constexpr size_t input_num = 3;
  constexpr size_t first_dim = 0;
  constexpr size_t second_dim = 1;
  constexpr size_t third_dim = 2;
  int64_t axis = 0;
  // The axis comes either from the "axis" attr or from a third constant input.
  if (attrs.count("axis") > 0) {
    axis = GetValue<int64_t>(attrs.find("axis")->second);
  } else if (inputs.size() == input_num) {
    // NOTE(review): the axis input's buffer is read as int32 regardless of its
    // actual dtype — confirm the frontend always materializes it as int32.
    int *data_axis =
      static_cast<int *>(std::static_pointer_cast<inner::ConstTensorNode>(inputs[axis_index])->data()->data_c());
    axis = IntToLong(*data_axis);
  } else {
    return nullptr;
  }
  ShapeVector param_shp = inputs[param_index]->shape;
  // Normalize a negative axis against the param rank.
  axis = axis < 0 ? SizeToLong(param_shp.size()) + axis : axis;
  // Indices are widened to size_t; negative index values are therefore not
  // handled here (they would wrap) — presumably the indices are non-negative.
  std::vector<size_t> indices;
  switch (static_cast<int>(inputs[indice_index]->type)) {
    case TypeId::kNumberTypeInt8: {
      indices = ChangeDataToVec<int8_t, size_t>(inputs[indice_index]);
      break;
    }
    case TypeId::kNumberTypeInt16: {
      indices = ChangeDataToVec<int16_t, size_t>(inputs[indice_index]);
      break;
    }
    case TypeId::kNumberTypeInt32: {
      indices = ChangeDataToVec<int32_t, size_t>(inputs[indice_index]);
      break;
    }
    case TypeId::kNumberTypeInt64: {
      indices = ChangeDataToVec<int64_t, size_t>(inputs[indice_index]);
      break;
    }
    default:
      // Unsupported index dtype: give up on folding.
      return nullptr;
  }

  TM *input_x =
    static_cast<TM *>(std::static_pointer_cast<inner::ConstTensorNode>(inputs[param_index])->data()->data_c());
  // Compact the param shape to [outer, axis_dim, inner] (see CompactShape).
  std::vector<size_t> compact_shp = CompactShape(param_shp, axis);
  std::vector<TM> res;
  // A valid in-range axis always yields exactly 3 compacted dims.
  if (compact_shp.size() == input_num) {
    for (size_t i = 0; i < compact_shp[first_dim]; i++) {
      for (auto j : indices) {
        for (size_t k = 0; k < compact_shp[third_dim]; k++) {
          // Flat offset into [outer, axis_dim, inner], picking row j on axis.
          (void)res.emplace_back(
            input_x[i * compact_shp[second_dim] * compact_shp[third_dim] + j * compact_shp[third_dim] + k]);
        }
      }
    }
    return std::make_shared<tensor::Tensor>(this->type, this->shape, &res[0], this->type);
  }
  return nullptr;
}
825 
InferValue(const NodePtrList & inputs,const DAttrs & attrs)826 NodePtr GatherOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) {
827   for (auto i : inputs) {
828     if (i->NodeType() != NType::Tensor) {
829       return nullptr;
830     }
831   }
832   TypeId output_type = this->type;
833   tensor::TensorPtr res = nullptr;
834   switch (static_cast<int>(output_type)) {
835     case TypeId::kNumberTypeUInt8: {
836       res = CalcGather<uint8_t>(inputs, attrs);
837       break;
838     }
839     case TypeId::kNumberTypeInt8: {
840       res = CalcGather<int8_t>(inputs, attrs);
841       break;
842     }
843     case TypeId::kNumberTypeInt16: {
844       res = CalcGather<int16_t>(inputs, attrs);
845       break;
846     }
847     case TypeId::kNumberTypeInt32: {
848       res = CalcGather<int32_t>(inputs, attrs);
849       break;
850     }
851     case TypeId::kNumberTypeInt64: {
852       res = CalcGather<int64_t>(inputs, attrs);
853       break;
854     }
855     case TypeId::kNumberTypeUInt16: {
856       res = CalcGather<uint16_t>(inputs, attrs);
857       break;
858     }
859     case TypeId::kNumberTypeUInt32: {
860       res = CalcGather<uint32_t>(inputs, attrs);
861       break;
862     }
863     case TypeId::kNumberTypeUInt64: {
864       res = CalcGather<uint64_t>(inputs, attrs);
865       break;
866     }
867     case TypeId::kNumberTypeFloat16: {
868       res = CalcGather<float16>(inputs, attrs);
869       break;
870     }
871     case TypeId::kNumberTypeFloat32: {
872       res = CalcGather<float>(inputs, attrs);
873       break;
874     }
875     case TypeId::kNumberTypeFloat64: {
876       res = CalcGather<double>(inputs, attrs);
877       break;
878     }
879     case TypeId::kNumberTypeBFloat16: {
880       res = CalcGather<bfloat16>(inputs, attrs);
881       break;
882     }
883     default:
884       return nullptr;
885   }
886   return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
887 }
888 
// Constant-folds a Concat: interleaves the (compacted) input buffers along
// the concat axis. Returns nullptr when the axis is not a constant scalar or
// there is nothing to concatenate.
// NOTE(review): ConcatOp::InferValue only calls this after requiring every
// input to be NType::Tensor, while the guard below requires the last input to
// be NType::Scalar — under that caller this function always returns nullptr.
// Confirm the intended node kind of the axis input.
template <typename TM>
tensor::TensorPtr ConcatOp::CalcConcat(const NodePtrList &inputs, const DAttrs &attrs) {
  constexpr size_t first_dim = 0;
  constexpr size_t second_dim = 1;
  constexpr size_t third_dim = 2;
  int64_t axis = 0;
  // The last input carries the concat axis as a constant scalar.
  auto axis_node = inputs.back();
  if (axis_node->NodeType() == NType::Scalar) {
    auto scalar_node = axis_node->As<ConstScalarNode>();
    axis = GetValue<int64_t>(scalar_node->data());
  } else {
    return nullptr;
  }
  // Normalize a negative axis against the output rank.
  axis = axis < 0 ? SizeToLong(this->shape.size()) + axis : axis;
  // NOTE(review): these loops run over ALL inputs, including the trailing
  // axis node — confirm whether it should be excluded from the data below.
  std::vector<std::vector<TM>> inputs_tm;
  for (const auto &t : inputs) {
    (void)inputs_tm.emplace_back(ChangeDataToVec<TM, TM>(t));
  }
  // Compact each input's shape to [outer, axis_dim, inner] (see CompactShape).
  std::vector<std::vector<size_t>> all_shps;
  (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(all_shps),
                       [&axis](const NodePtr &t) { return CompactShape(t->shape, axis); });
  std::vector<TM> res;
  if (all_shps.size() > 0) {
    // Inputs only differ in the axis dim, so outer/inner sizes are shared.
    const size_t third_dim_size = all_shps[0][third_dim];
    const size_t first_dim_size = all_shps[0][first_dim];
    for (size_t i = 0; i < first_dim_size; i++) {
      for (size_t t = 0; t < inputs_tm.size(); t++) {
        for (size_t j = 0; j < all_shps[t][second_dim]; j++) {
          for (size_t k = 0; k < third_dim_size; k++) {
            (void)res.emplace_back(inputs_tm[t][i * all_shps[t][second_dim] * third_dim_size + j * third_dim_size + k]);
          }
        }
      }
    }
    return std::make_shared<tensor::Tensor>(this->type, this->shape, &res[0], this->type);
  }
  return nullptr;
}
927 
InferValue(const NodePtrList & inputs,const DAttrs & attrs)928 NodePtr ConcatOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) {
929   for (auto i : inputs) {
930     if (i->NodeType() != NType::Tensor) {
931       return nullptr;
932     }
933   }
934   TypeId output_type = this->type;
935   tensor::TensorPtr res = nullptr;
936   switch (static_cast<int>(output_type)) {
937     case TypeId::kNumberTypeUInt8: {
938       res = CalcConcat<uint8_t>(inputs, attrs);
939       break;
940     }
941     case TypeId::kNumberTypeInt8: {
942       res = CalcConcat<int8_t>(inputs, attrs);
943       break;
944     }
945     case TypeId::kNumberTypeInt16: {
946       res = CalcConcat<int16_t>(inputs, attrs);
947       break;
948     }
949     case TypeId::kNumberTypeInt32: {
950       res = CalcConcat<int32_t>(inputs, attrs);
951       break;
952     }
953     case TypeId::kNumberTypeInt64: {
954       res = CalcConcat<int64_t>(inputs, attrs);
955       break;
956     }
957     case TypeId::kNumberTypeUInt16: {
958       res = CalcConcat<uint16_t>(inputs, attrs);
959       break;
960     }
961     case TypeId::kNumberTypeUInt32: {
962       res = CalcConcat<uint32_t>(inputs, attrs);
963       break;
964     }
965     case TypeId::kNumberTypeUInt64: {
966       res = CalcConcat<uint64_t>(inputs, attrs);
967       break;
968     }
969     case TypeId::kNumberTypeFloat16: {
970       res = CalcConcat<float16>(inputs, attrs);
971       break;
972     }
973     case TypeId::kNumberTypeFloat32: {
974       res = CalcConcat<float>(inputs, attrs);
975       break;
976     }
977     case TypeId::kNumberTypeFloat64: {
978       res = CalcConcat<double>(inputs, attrs);
979       break;
980     }
981     case TypeId::kNumberTypeBFloat16: {
982       res = CalcConcat<bfloat16>(inputs, attrs);
983       break;
984     }
985     default:
986       return nullptr;
987   }
988   return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
989 }
990 
InferShape(const NodePtrList & inputs,const DAttrs & attrs)991 std::vector<DShape> LayoutTransformOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
992   CHECK_ATTR(attrs, kAttrSrcFormat);
993   CHECK_ATTR(attrs, kAttrDstFormat);
994   auto src_format = GetValue<std::string>(attrs.find(kAttrSrcFormat)->second);
995   auto dst_format = GetValue<std::string>(attrs.find(kAttrDstFormat)->second);
996   std::vector<int64_t> data_shape = inputs[0]->shape;
997   if (src_format == kOpFormat_NHWC) {
998     auto n = data_shape[0];
999     auto h = data_shape[1];
1000     auto w = data_shape[2];
1001     auto c = data_shape[3];
1002     auto c_o_i = GkUtils::GetChannelInConvFormat(dst_format);
1003     if (c_o_i == 0) {
1004       c_o_i = 1;
1005     }
1006     auto c_o_o = c / c_o_i;
1007     std::vector<int64_t> output_shape{n, c_o_o, h, w, c_o_i};
1008     return {output_shape};
1009   }
1010   if (dst_format == kOpFormat_NHWC) {
1011     auto n = data_shape[0];
1012     auto c_o_o = data_shape[1];
1013     auto h = data_shape[2];
1014     auto w = data_shape[3];
1015     auto c_o_i = data_shape[4];
1016     auto c = c_o_o * c_o_i;
1017     std::vector<int64_t> output_shape{n, h, w, c};
1018     return {output_shape};
1019   }
1020   // LayoutTransform between nchwnc
1021   auto n = data_shape[0];
1022   auto c_o_o = data_shape[1];
1023   auto h = data_shape[2];
1024   auto w = data_shape[3];
1025   auto c_o_i = data_shape[4];
1026   auto c_o_i_new = GkUtils::GetChannelInConvFormat(dst_format);
1027   if (c_o_i_new == 0) {
1028     c_o_i_new = 1;
1029   }
1030   auto c_o_o_new = c_o_o * c_o_i / c_o_i_new;
1031   std::vector<int64_t> output_shape{n, c_o_o_new, h, w, c_o_i_new};
1032   return {output_shape};
1033 }
1034 
InferShape(const NodePtrList & inputs,const DAttrs & attrs)1035 std::vector<DShape> Pool2DOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
1036   CHECK_ATTR(attrs, "global");
1037   std::vector<int64_t> input_shape = inputs[0]->shape;
1038   bool is_nhwc = input_shape.size() == 4;
1039   int64_t n = input_shape[0];
1040   int64_t c;
1041   int64_t h;
1042   int64_t w;
1043   if (is_nhwc) {
1044     constexpr size_t h_idx = 1;
1045     constexpr size_t w_idx = 2;
1046     constexpr size_t c_idx = 3;
1047     h = input_shape[h_idx];
1048     w = input_shape[w_idx];
1049     c = input_shape[c_idx];
1050   } else {
1051     constexpr size_t c_idx = 1;
1052     constexpr size_t h_idx = 2;
1053     constexpr size_t w_idx = 3;
1054     c = input_shape[c_idx];
1055     h = input_shape[h_idx];
1056     w = input_shape[w_idx];
1057   }
1058 
1059   if (GetValue<bool>(attrs.find("global")->second)) {
1060     h = 1;
1061     w = 1;
1062   } else {
1063     CHECK_ATTR(attrs, "strides");
1064     CHECK_ATTR(attrs, "kernel_size");
1065     CHECK_ATTR(attrs, "round_mode");
1066     std::vector<int64_t> strides = GetListInt(attrs.find("strides")->second);
1067     std::vector<int64_t> kernels = GetListInt(attrs.find("kernel_size")->second);
1068     if (AnfUtils::GetIntValue(attrs.find("round_mode")->second) == 0) {
1069       // ceil mode
1070       h = ((h - kernels[0] + strides[0] - 1) / strides[0]) + 1;
1071       w = ((w - kernels[1] + strides[1] - 1) / strides[1]) + 1;
1072     } else {
1073       // round mode
1074       h = ((h - kernels[0]) / strides[0]) + 1;
1075       w = ((w - kernels[1]) / strides[1]) + 1;
1076     }
1077   }
1078   if (is_nhwc) {
1079     return {{n, h, w, c}};
1080   } else {
1081     auto ci = input_shape[4];
1082     return {{n, c, h, w, ci}};
1083   }
1084 }
1085 
Check(const NodePtrList & inputs,const DAttrs &)1086 void ComplexOp::Check(const NodePtrList &inputs, const DAttrs &) {
1087   if (inputs[0]->type != TypeId::kNumberTypeFloat32) {
1088     MS_LOG(EXCEPTION) << "Complex's input[0] should be float32, but got " << TypeIdToString(inputs[0]->type, true);
1089   }
1090   if (inputs[0]->type != inputs[1]->type) {
1091     MS_LOG(EXCEPTION) << "Complex's input[0] and inputs[1]'s type mismatch: " << TypeIdToString(inputs[0]->type, true)
1092                       << " vs " << TypeIdToString(inputs[1]->type, true);
1093   }
1094 }
1095 
InferShape(const NodePtrList &,const DAttrs & attrs)1096 std::vector<DShape> StandardNormalOp::InferShape(const NodePtrList &, const DAttrs &attrs) {
1097   CHECK_ATTR(attrs, "shape");
1098   return {GetListInt(attrs.find("shape")->second)};
1099 }
1100 
// Constant-folds an ONNX-style strided slice: for each sliced axis, keeps the
// positions produced by begin/end/stride, then copies the surviving elements
// by a depth-first walk over the input dims. Returns nullptr when any of
// begin/end/stride is negative (only the non-negative case is folded here).
// NOTE(review): begin/end/axes/strides buffers are read as int32 via
// ChangeDataToVec<int, int> — confirm these inputs are always int32 tensors.
template <typename TM>
tensor::TensorPtr StridedSliceOnnxOp::CalcStridedSliceOnnx(const NodePtrList &inputs, const DAttrs &) const {
  constexpr size_t input_index = 0;
  constexpr size_t begin_index = 1;
  constexpr size_t end_index = 2;
  constexpr size_t axes_index = 3;
  constexpr size_t stride_index = 4;

  ShapeVector input_shape = inputs[input_index]->shape;
  std::vector<int> begin = ChangeDataToVec<int, int>(inputs[begin_index]);
  std::vector<int> end = ChangeDataToVec<int, int>(inputs[end_index]);
  std::vector<int> axes = ChangeDataToVec<int, int>(inputs[axes_index]);
  std::vector<int> stride = ChangeDataToVec<int, int>(inputs[stride_index]);

  // For each sliced axis: the set of index positions that survive the slice.
  // Axes missing from this map are kept in full.
  std::unordered_map<int, std::unordered_set<size_t>> info;
  for (size_t i = 0; i < axes.size(); i++) {
    // Normalize a negative axis against the input rank.
    int axis = axes[i] < 0 ? axes[i] + SizeToInt(input_shape.size()) : axes[i];
    if (begin[i] < 0 || end[i] < 0 || stride[i] < 0) {
      MS_LOG(INFO) << "Only do infervalue for StridedSliceOnnx when begin, end and stride are non-negative.";
      return nullptr;
    }
    std::unordered_set<size_t> pos;
    int index = begin[i];
    while (index < end[i]) {
      (void)pos.insert(IntToSize(index));
      index += stride[i];
    }
    (void)info.emplace(axis, pos);
  }

  TM *input_x =
    static_cast<TM *>(std::static_pointer_cast<inner::ConstTensorNode>(inputs[input_index])->data()->data_c());

  std::vector<TM> res;

  // Recursive dim-by-dim walk: `offset` is the flat index of the current
  // prefix; elements are appended in row-major order of the kept positions.
  std::function<void(size_t, size_t)> func;
  func = [&func, &input_x, &res, &info, &input_shape](size_t dim, size_t offset) {
    if ((dim + 1) == input_shape.size()) {
      // Innermost dim: emit the kept elements directly.
      for (size_t i = 0; i < LongToSize(input_shape[dim]); i++) {
        if (info.count(SizeToInt(dim)) > 0) {
          if (info[SizeToInt(dim)].count(i) > 0) {
            (void)res.emplace_back(input_x[offset + i]);
          }
        } else {
          (void)res.emplace_back(input_x[offset + i]);
        }
      }
    } else if ((dim + 1) < input_shape.size()) {
      // `accu` is the flat stride of one step along the current dim.
      size_t accu = 1;
      for (size_t j = dim + 1; j < input_shape.size(); j++) {
        accu *= LongToSize(input_shape[j]);
      }
      for (size_t i = 0; i < LongToSize(input_shape[dim]); i++) {
        if (info.count(SizeToInt(dim)) > 0) {
          if (info[SizeToInt(dim)].count(i) > 0) {
            func(dim + 1, offset + i * accu);
          }
        } else {
          func(dim + 1, offset + i * accu);
        }
      }
    }
    return;
  };
  func(0, 0);
  return std::make_shared<tensor::Tensor>(this->type, this->shape, &res[0], this->type);
}
1168 
InferValue(const NodePtrList & inputs,const DAttrs & attrs)1169 NodePtr StridedSliceOnnxOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) {
1170   for (auto i : inputs) {
1171     if (i->NodeType() != NType::Tensor) {
1172       return nullptr;
1173     }
1174   }
1175   TypeId output_type = this->type;
1176   tensor::TensorPtr res = nullptr;
1177   switch (static_cast<int>(output_type)) {
1178     case TypeId::kNumberTypeUInt8: {
1179       res = CalcStridedSliceOnnx<uint8_t>(inputs, attrs);
1180       break;
1181     }
1182     case TypeId::kNumberTypeInt8: {
1183       res = CalcStridedSliceOnnx<int8_t>(inputs, attrs);
1184       break;
1185     }
1186     case TypeId::kNumberTypeInt16: {
1187       res = CalcStridedSliceOnnx<int16_t>(inputs, attrs);
1188       break;
1189     }
1190     case TypeId::kNumberTypeInt32: {
1191       res = CalcStridedSliceOnnx<int32_t>(inputs, attrs);
1192       break;
1193     }
1194     case TypeId::kNumberTypeInt64: {
1195       res = CalcStridedSliceOnnx<int64_t>(inputs, attrs);
1196       break;
1197     }
1198     case TypeId::kNumberTypeUInt16: {
1199       res = CalcStridedSliceOnnx<uint16_t>(inputs, attrs);
1200       break;
1201     }
1202     case TypeId::kNumberTypeUInt32: {
1203       res = CalcStridedSliceOnnx<uint32_t>(inputs, attrs);
1204       break;
1205     }
1206     case TypeId::kNumberTypeUInt64: {
1207       res = CalcStridedSliceOnnx<uint64_t>(inputs, attrs);
1208       break;
1209     }
1210     case TypeId::kNumberTypeFloat16: {
1211       res = CalcStridedSliceOnnx<float16>(inputs, attrs);
1212       break;
1213     }
1214     case TypeId::kNumberTypeFloat32: {
1215       res = CalcStridedSliceOnnx<float>(inputs, attrs);
1216       break;
1217     }
1218     case TypeId::kNumberTypeFloat64: {
1219       res = CalcStridedSliceOnnx<double>(inputs, attrs);
1220       break;
1221     }
1222     case TypeId::kNumberTypeBFloat16: {
1223       res = CalcStridedSliceOnnx<bfloat16>(inputs, attrs);
1224       break;
1225     }
1226     default:
1227       return nullptr;
1228   }
1229   return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
1230 }
1231 
RectifyAbstract(const PrimitivePtr & prim,AbstractBasePtrList * abs_list)1232 void MatMulOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
1233   CHECK_ATTR(prim->attrs(), "transpose_a");
1234   (void)abs_list->emplace_back(prim->GetAttr("transpose_a")->ToAbstract());
1235   CHECK_ATTR(prim->attrs(), "transpose_b");
1236   (void)abs_list->emplace_back(prim->GetAttr("transpose_b")->ToAbstract());
1237 }
1238 
// Infer MatMul's output shape, handling batched inputs by hand: the prim's
// infer only supports rank-2 inputs, so batch dims are stripped, broadcast
// against each other (size-1 dims broadcast), and re-attached to the rank-2
// result. NOTE(review): both inputs must have the same batch rank here —
// implicit rank expansion is not supported; confirm callers guarantee this.
std::vector<DShape> MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  // the prim's infer shape does not supports batch dims
  constexpr size_t kMatMulRank = 2;
  if (inputs[0]->shape.size() > kMatMulRank || inputs[1]->shape.size() > kMatMulRank) {
    NodePtrList new_inputs = inputs;
    std::vector<DShape> batches(inputs.size());
    // Replace input i with a rank-2 view of its trailing dims and remember
    // the leading (batch) dims in batches[i].
    auto cut_batches = [&new_inputs, &batches, kMatMulRank](size_t i) -> void {
      const auto &shape_i = new_inputs[i]->shape;
      if (shape_i.size() > kMatMulRank) {
        DShape real_shape(shape_i.cend() - kMatMulRank, shape_i.cend());
        new_inputs[i] = std::make_shared<inner::Node>(NodeBase{real_shape, new_inputs[i]->type, new_inputs[i]->format});
        batches[i].assign(shape_i.cbegin(), shape_i.cend() - kMatMulRank);
      }
    };

    cut_batches(0);
    cut_batches(1);
    if (batches[0].size() != batches[1].size()) {
      MS_LOG(EXCEPTION) << "The Matmul's batch rank should be equal, but got " << batches[0].size() << " vs "
                        << batches[1].size();
    }
    // Broadcast batch dims pairwise; a mismatch is only allowed when one
    // side is 1.
    DShape batch;
    for (size_t i = 0; i < batches[0].size(); i++) {
      if (batches[0][i] != batches[1][i]) {
        if (batches[0][i] != 1 && batches[1][i] != 1) {
          MS_LOG(EXCEPTION) << "The Matmul's batch dim is unmatched. got " << inputs[0]->shape << " and "
                            << inputs[1]->shape;
        }
      }
      batch.push_back(std::max(batches[0][i], batches[1][i]));
    }

    // Infer the rank-2 core shape, then prepend the broadcast batch dims.
    auto out_shape = PrimOp::InferShape(new_inputs, attrs)[0];
    // just reuse the `batch` vector
    (void)batch.insert(batch.end(), out_shape.begin(), out_shape.end());
    return {batch};
  }
  return PrimOp::InferShape(inputs, attrs);
}
1278 
InferType(const NodePtrList & inputs,const DAttrs & attrs)1279 std::vector<TypeId> MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
1280   if (attrs.count("dst_type") != 0) {
1281     return {attrs.find("dst_type")->second->cast<TypePtr>()->type_id()};
1282   }
1283   if (inputs[0]->type == TypeId::kNumberTypeInt8) {
1284     return {TypeId::kNumberTypeInt32};
1285   }
1286   return {inputs[0]->type};
1287 }
1288 }  // namespace mindspore::graphkernel::inner
1289