/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/graph_kernel/model/op_node.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "backend/optimizer/graph_kernel/model/node.h"

namespace mindspore {
namespace opt {
namespace graphkernel {
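// Read a sequence attribute as a vector of int64_t. Elements stored as 'int'
// are converted (with a warning), so callers always see int64_t values.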
std::vector<int64_t> GetListInt(const ValuePtr &attr_value) {
  bool is_int64 = true;
  auto get_int_value = [&is_int64](const ValuePtr &value) -> int64_t {
    if (value->isa<Int64Imm>()) {
      return GetValue<int64_t>(value);
    }
    is_int64 = false;
    return static_cast<int64_t>(GetValue<int>(value));
  };
  std::vector<int64_t> list_int;
  const auto &vals = attr_value->cast<ValueSequeuePtr>()->value();
  (void)std::transform(vals.begin(), vals.end(), std::back_inserter(list_int), get_int_value);
  if (!is_int64) {
    MS_LOG(WARNING) << "Vector type should be 'int64_t' but got 'int'";
  }
  return list_int;
}

void PrimOp::Check(const NodePtrList &inputs, const DAttrs &attrs) {
  CheckShape(inputs, attrs);
  CheckType(inputs, attrs);
  CheckFormat(inputs, attrs);
}

// Check that all input types are identical.
void PrimOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) {
  TypeId tid = inputs[0]->type;
  for (size_t i = 1; i < inputs.size(); i++) {
    if (inputs[i]->type != tid) {
      MS_LOG(EXCEPTION) << "Incompatible dtype between input 0 and " << i;
    }
  }
}

// Check that all formats are compatible; only kOpFormat_DEFAULT is compatible with the others.
void PrimOp::CheckFormat(const NodePtrList &inputs, const DAttrs &attrs) {
  DFormat res = inputs[0]->format;
  size_t i = 0;  // index of the input that 'res' came from
  for (size_t j = 1; j < inputs.size(); j++) {
    if (inputs[j]->format != res) {
      if (inputs[j]->format != kOpFormat_DEFAULT && res != kOpFormat_DEFAULT) {
        MS_LOG(EXCEPTION) << "Incompatible format between input " << i << " and " << j;
      }
      if (res == kOpFormat_DEFAULT) {
        res = inputs[j]->format;
        i = j;
      }
    }
  }
}

NodeBase PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
  Check(inputs, attrs);
  NodeBase nodebase;
  nodebase.shape = InferShape(inputs, attrs);
  nodebase.type = InferType(inputs, attrs);
  nodebase.format = InferFormat(inputs, attrs);
  return nodebase;
}

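// Print this node in a one-line readable form, e.g. something roughly like
// "out = Add(in0, in1)  // attr {transpose_a: false}" (the exact tensor text
// comes from DumpTensor). Bookkeeping attributes on the black list are skipped.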
void PrimOp::Dump(std::ostringstream &os) const {
  DumpTensor(os);
  os << " = " << this->op_ << "(";
  for (size_t i = 0; i < inputs_.size(); i++) {
    inputs_[i]->DumpTensor(os);
    if (i != inputs_.size() - 1) os << ", ";
  }
  os << ")";
  std::ostringstream attr_os;
  bool has_attr = false;
  std::set<std::string> black_list = {"IsFeatureMapInputList", "IsFeatureMapOutput", "output_names", "input_names"};
  for (const auto &attr : attrs_) {
    if (attr.second != nullptr && black_list.count(attr.first) == 0) {
      if (has_attr) {
        attr_os << ", ";
      } else {
        has_attr = true;
      }
      attr_os << attr.first << ": " << attr.second->ToString();
    }
  }
  if (has_attr) {
    os << "  // attr {" << attr_os.str() << "}";
  }
}

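// Fold a scalar operation on constant inputs. TM is the type the computation
// runs in (matching the inputs' dtype) and TD is the wider type the result is
// stored as; e.g. "Add" on two int32 scalars 3 and 4 computes int32 7 and
// stores it as int64 in a new tensor whose type id is still the original tid.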
template <typename TM, typename TD>
tensor::TensorPtr CalcByOperator(const NodePtrList &inputs, const std::string &op, TypeId tid) {
  std::vector<TM> inputs_tm;
  std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_tm), [](const NodePtr &i) {
    return *static_cast<TM *>(std::static_pointer_cast<graphkernel::ConstTensorNode>(i)->data()->data_c());
  });

  std::unordered_map<std::string, std::function<TM(const std::vector<TM> &)>> func_map = {
    {"Add", [](const std::vector<TM> &n) { return n[0] + n[1]; }},
    {"Sub", [](const std::vector<TM> &n) { return n[0] - n[1]; }},
    {"Mul", [](const std::vector<TM> &n) { return n[0] * n[1]; }},
    {"RealDiv", [](const std::vector<TM> &n) { return n[0] / n[1]; }},
    {"Neg", [](const std::vector<TM> &n) { return TM(0) - n[0]; }},
    {"Reciprocal", [](const std::vector<TM> &n) { return TM(1) / n[0]; }},
    {"Log", [](const std::vector<TM> &n) { return log(n[0]); }},
    {"Exp", [](const std::vector<TM> &n) { return exp(n[0]); }},
    {"Abs", [](const std::vector<TM> &n) { return n[0] < TM(0) ? (TM(0) - n[0]) : n[0]; }},
    {"Sqrt", [](const std::vector<TM> &n) { return sqrt(n[0]); }},
    {"Rsqrt", [](const std::vector<TM> &n) { return TM(1) / sqrt(n[0]); }},
  };
  auto iter = func_map.find(op);
  if (iter == func_map.end()) return nullptr;
  return std::make_shared<tensor::Tensor>(static_cast<TD>(iter->second(inputs_tm)), TypeIdToType(tid));
}

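// Constant folding: if every input is a constant tensor, evaluate the op on
// the host and return the result as a new ConstTensorNode; otherwise (or for
// an unsupported op/dtype) return nullptr and keep the node symbolic.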
NodePtr PrimOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op) {
  for (const auto &i : inputs) {
    if (i->NodeType() != NType::Value) return nullptr;
  }
  TypeId output_type = this->type;
  tensor::TensorPtr res = nullptr;
  switch (output_type) {
    case TypeId::kNumberTypeUInt8: {
      res = CalcByOperator<uint8_t, int64_t>(inputs, op, output_type);
      break;
    }
    case TypeId::kNumberTypeInt8: {
      res = CalcByOperator<int8_t, int64_t>(inputs, op, output_type);
      break;
    }
    case TypeId::kNumberTypeInt16: {
      res = CalcByOperator<int16_t, int64_t>(inputs, op, output_type);
      break;
    }
    case TypeId::kNumberTypeInt32: {
      res = CalcByOperator<int32_t, int64_t>(inputs, op, output_type);
      break;
    }
    case TypeId::kNumberTypeInt64: {
      res = CalcByOperator<int64_t, int64_t>(inputs, op, output_type);
      break;
    }
    case TypeId::kNumberTypeUInt16: {
      res = CalcByOperator<uint16_t, int64_t>(inputs, op, output_type);
      break;
    }
    case TypeId::kNumberTypeUInt32: {
      res = CalcByOperator<uint32_t, int64_t>(inputs, op, output_type);
      break;
    }
    case TypeId::kNumberTypeUInt64: {
      res = CalcByOperator<uint64_t, int64_t>(inputs, op, output_type);
      break;
    }
    case TypeId::kNumberTypeFloat16: {
      res = CalcByOperator<float16, double>(inputs, op, output_type);
      break;
    }
    case TypeId::kNumberTypeFloat32: {
      res = CalcByOperator<float, double>(inputs, op, output_type);
      break;
    }
    case TypeId::kNumberTypeFloat64: {
      res = CalcByOperator<double, double>(inputs, op, output_type);
      break;
    }
    default:
      return nullptr;
  }
  return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
}


// Convert a default-format shape to a FRAC_NZ-format shape.
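// The trailing two axes are tiled into 16x16 fractal blocks, effectively
// (..., w1, h1, h0, w0): e.g. (N, 32, 48) -> (N, 48/16, 32/16, 16, 16) =
// (N, 3, 2, 16, 16); an axis of size 1 keeps a block size of 1 (see the
// per-branch examples below).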
DShape ToNz(const DShape &default_shape) {
  constexpr size_t nz_size = 2;
  auto len = default_shape.size();
  DShape leading_shape;
  DShape tail_shape;
  if (default_shape.size() > nz_size) {
    (void)leading_shape.insert(leading_shape.end(), default_shape.begin(), default_shape.end() - SizeToLong(nz_size));
  }
  if (default_shape.size() == 1 || (default_shape.size() >= nz_size && default_shape[len - nz_size] == 1)) {
    // (32) or (N, 1, 32) -> (N, 2, 1, 1, 16)
    if (default_shape.back() % 16 != 0) {
      MS_LOG(EXCEPTION) << "default_shape[-1] should be a multiple of 16, but got " << default_shape.back();
    }
    tail_shape = {default_shape.back() / 16, 1, 1, 16};
  } else if (default_shape.size() >= nz_size && default_shape.back() == 1) {
    // (N, 32, 1) -> (N, 1, 2, 16, 1)
    if (default_shape[len - nz_size] % 16 != 0) {
      MS_LOG(EXCEPTION) << "default_shape[-2] should be a multiple of 16, but got " << default_shape[len - nz_size];
    }
    tail_shape = {1, default_shape[len - nz_size] / 16, 16, 1};
  } else {
    // (N, 32, 48) -> (N, 3, 2, 16, 16)
    if (default_shape.back() % 16 != 0 || default_shape[len - nz_size] % 16 != 0) {
      MS_LOG(EXCEPTION) << "default_shape[-1] and default_shape[-2] should be multiples of 16, but got "
                        << default_shape.back() << " " << default_shape[len - nz_size];
    }
    tail_shape = {default_shape.back() / 16, default_shape[len - nz_size] / 16, 16, 16};
  }
  (void)leading_shape.insert(leading_shape.end(), tail_shape.begin(), tail_shape.end());
  return leading_shape;
}

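// NumPy-style broadcasting over all input shapes: shapes are right-aligned and
// padded with leading 1s, then each axis takes the common non-1 size, e.g.
// (2, 1, 3) and (4, 3) align to (2, 1, 3) and (1, 4, 3), broadcasting to (2, 4, 3).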
DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) {
  std::vector<std::vector<int64_t>> shapes;
  for (auto &input : inputs) {
    if (to_nz && input->format != kOpFormat_FRAC_NZ) {
      shapes.emplace_back(ToNz(input->shape));
    } else {
      shapes.emplace_back(input->shape);
    }
  }
  auto max_dim_input =
    std::max_element(shapes.begin(), shapes.end(),
                     [](const std::vector<int64_t> &a, const std::vector<int64_t> &b) { return a.size() < b.size(); });
  auto max_dim = max_dim_input->size();
  std::vector<std::vector<int64_t>> align_shapes;
  for (auto &s : shapes) {
    std::vector<int64_t> cur(max_dim - s.size(), 1);
    cur.insert(cur.end(), s.begin(), s.end());
    (void)align_shapes.emplace_back(cur);
  }
  std::vector<int64_t> output_shape(max_dim, 1);
  for (size_t i = 0; i < max_dim; i++) {
    for (auto &align_shape : align_shapes) {
      if (align_shape[i] > 1) {
        if (output_shape[i] == 1) {
          output_shape[i] = align_shape[i];
        }
        if (output_shape[i] != align_shape[i]) {
          MS_LOG(EXCEPTION) << "Shape broadcast failed. " << output_shape[i] << " vs " << align_shape[i];
        }
      }
    }
  }
  return output_shape;
}

DShape ElemwiseOp::InferShape(const NodePtrList &inputs, const DAttrs &) {
  if (std::all_of(inputs.begin(), inputs.end(), [](const NodePtr &input) {
        return input->format == kOpFormat_DEFAULT || input->format == kOpFormat_NHWC || input->format == kOpFormat_NCHW;
      })) {
    return BroadcastShape(inputs, false);
  }
  if (std::all_of(inputs.begin(), inputs.end(), [](const NodePtr &input) {
        return input->format == kOpFormat_DEFAULT || input->format == kOpFormat_NHWC ||
               input->format == kOpFormat_NCHW || input->format == kOpFormat_FRAC_NZ;
      })) {
    return BroadcastShape(inputs, true);
  }
  MS_LOG(EXCEPTION) << "Unsupported format.";
}

DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) {
  auto it = std::find_if(inputs.begin(), inputs.end(), [](const NodePtr &i) { return i->format != kOpFormat_DEFAULT; });
  return it == inputs.end() ? kOpFormat_DEFAULT : (*it)->format;
}

NodeBase ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
  auto nodebase = PrimOp::Infer(inputs, attrs);
  auto IsBroadcast = [this](const NodePtrList &inputs) -> bool {
    for (auto &ref : inputs) {
      if (ref->shape.size() != this->shape.size()) return true;
      for (size_t i = 0; i < this->shape.size(); ++i) {
        if (ref->shape[i] != this->shape[i]) return true;
      }
    }
    return false;
  };
  compute_type_ = IsBroadcast(inputs) ? BROADCAST : ELEMWISE;
  return nodebase;
}

TypeId CastOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
  CHECK_ATTR(attrs, "dst_type");
  auto dst_type = attrs.find("dst_type")->second;
  if (dst_type->isa<Type>()) {
    return dst_type->cast<TypePtr>()->type_id();
  }
  return kernel::DtypeToTypeId(GetValue<std::string>(dst_type));
}

void SelectOp::CheckType(const NodePtrList &inputs, const DAttrs &) {
  if (inputs[0]->type != TypeId::kNumberTypeBool) {
    MS_LOG(EXCEPTION) << "Select's input[0] should be of bool type";
  }
  if (inputs[1]->type != inputs[2]->type) {
    MS_LOG(EXCEPTION) << "Select's input[1] and input[2] have mismatched types";
  }
}

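// At most one -1 is expected in the target shape; its value is inferred from
// the remaining elements, e.g. reshaping (4, 6) with shape (3, -1) gives
// 24 / |3 * -1| = 8, i.e. (3, 8).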
DShape ReshapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  CHECK_ATTR(attrs, "shape");
  auto new_shape = GetListInt(attrs.find("shape")->second);
  auto origin_shape = inputs[0]->shape;
  auto origin_product = std::accumulate(origin_shape.begin(), origin_shape.end(), 1, std::multiplies<int64_t>());
  auto new_product = std::accumulate(new_shape.begin(), new_shape.end(), 1, std::multiplies<int64_t>());
  for (size_t i = 0; i < new_shape.size(); i++) {
    if (new_shape[i] == -1) {
      new_shape[i] = origin_product / new_product * (-1);
      return new_shape;
    }
  }
  if (origin_product != new_product) {
    MS_LOG(EXCEPTION) << "The shape product before and after reshaping should be equal";
  }
  return new_shape;
}

DShape BroadcastToOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  CHECK_ATTR(attrs, "shape");
  return GetListInt(attrs.find("shape")->second);
}

// Check that every reduce axis is in range [-size, size).
void ReduceOp::Check(const NodePtrList &inputs, const DAttrs &attrs) {
  PrimOp::Check(inputs, attrs);
  CHECK_ATTR(attrs, "axis");
  auto axis = GetListInt(attrs.find("axis")->second);
  int64_t size = static_cast<int64_t>(inputs[0]->shape.size());
  auto it = std::find_if(axis.begin(), axis.end(), [&size](const int64_t &i) { return (i >= size || i < (-size)); });
  if (it != axis.end()) {
    MS_LOG(EXCEPTION) << "reduce_axis should be in range [" << (-size) << "," << size << "), but got " << (*it);
  }
}

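// With keep_dims the reduced axes collapse to 1, otherwise they are dropped,
// e.g. reducing (2, 3, 4) over axis {1} gives (2, 1, 4) with keep_dims and
// (2, 4) without; reducing every axis without keep_dims yields (1).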
DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  CHECK_ATTR(attrs, "axis");
  CHECK_ATTR(attrs, "keep_dims");
  auto axis = GetListInt(attrs.find("axis")->second);
  auto keepdims = GetValue<bool>(attrs.find("keep_dims")->second);
  if (keepdims) {
    DShape new_shape = inputs[0]->shape;
    for (auto x : axis) {
      new_shape[LongToSize(x)] = 1;
    }
    return new_shape;
  }
  DShape new_shape;
  const auto &input_shape = inputs[0]->shape;
  for (size_t i = 0; i < input_shape.size(); i++) {
    if (std::find(axis.begin(), axis.end(), SizeToLong(i)) == axis.end()) {
      new_shape.emplace_back(input_shape[i]);
    }
  }
  if (new_shape.empty()) {
    new_shape.emplace_back(1);
  }
  return new_shape;
}

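// Output spatial size follows the usual convolution arithmetic with an
// effective (dilated) kernel k_eff = (k - 1) * dilation + 1:
//   out = (in + pad_before + pad_after - k_eff) / stride + 1
// computed independently for H and W; the layout is NHWC throughout.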
DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  auto check_nd = [](const std::vector<int64_t> &shape, size_t n) {
    if (shape.size() != n) {
      MS_LOG(EXCEPTION) << "input dimension should be " << n << ", but got " << shape.size();
    }
  };
  auto shape0 = inputs[0]->shape;
  auto shape1 = inputs[1]->shape;
  check_nd(shape0, 4);
  check_nd(shape1, 4);
  CHECK_ATTR(attrs, "format");
  if (inputs[0]->format != kOpFormat_NHWC && inputs[1]->format != kOpFormat_NHWC &&
      GetValue<std::string>(attrs.find("format")->second) != kOpFormat_NHWC) {
    MS_LOG(EXCEPTION) << "check NHWC format failed";
  }
  auto n = shape0[0];
  auto h = shape0[1];
  auto w = shape0[2];
  auto out_channel = shape1[0];
  CHECK_ATTR(attrs, "pad_list");
  CHECK_ATTR(attrs, "pad_mode");
  CHECK_ATTR(attrs, "kernel_size");
  CHECK_ATTR(attrs, "stride");
  CHECK_ATTR(attrs, "dilation");
  auto pad_list = GetListInt(attrs.find("pad_list")->second);
  auto pad_mode = GetValue<std::string>(attrs.find("pad_mode")->second);
  auto kernel_size = GetListInt(attrs.find("kernel_size")->second);
  auto stride = GetListInt(attrs.find("stride")->second);
  auto dilation = GetListInt(attrs.find("dilation")->second);
  check_nd(pad_list, 4);
  check_nd(kernel_size, 2);
  check_nd(stride, 4);
  check_nd(dilation, 4);
  bool has_pad = false;
  if (pad_list[0] != pad_list[1] || pad_list[2] != pad_list[3]) {
    has_pad = true;
  } else {
    if (pad_mode == "VALID" || pad_mode == "valid") {
      if (std::any_of(pad_list.begin(), pad_list.end(), [](int64_t i) { return i == 0; })) {
        has_pad = true;
      }
    }
  }
  if (!has_pad) {
    pad_list = {0, 0, 0, 0};
  }
  auto k_h = (kernel_size[0] - 1) * dilation[2] + 1;
  auto k_w = (kernel_size[1] - 1) * dilation[3] + 1;
  auto out_h = (h + pad_list[0] + pad_list[1] - k_h) / stride[2] + 1;
  auto out_w = (w + pad_list[2] + pad_list[3] - k_w) / stride[3] + 1;
  std::vector<int64_t> output = {n, out_h, out_w, out_channel};
  return output;
}

TypeId Conv2dOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
  if (attrs.find("dst_type") == attrs.end()) return inputs[0]->type;
  auto dst_type = attrs.find("dst_type")->second;
  if (dst_type->isa<Type>()) {
    return dst_type->cast<TypePtr>()->type_id();
  }
  return kernel::DtypeToTypeId(GetValue<std::string>(dst_type));
}

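// Permute the input shape by 'perm', e.g. perm (0, 2, 3, 1) on an NCHW shape
// (N, C, H, W) yields (N, H, W, C).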
DShape TransposeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  CHECK_ATTR(attrs, "perm");
  auto perm = GetListInt(attrs.find("perm")->second);
  auto &old_shape = inputs[0]->shape;
  DShape new_shape;
  if (perm.size() != old_shape.size()) {
    MS_LOG(EXCEPTION) << "perm.size() != old_shape.size(). " << perm.size() << " vs " << old_shape.size();
  }
  std::transform(perm.begin(), perm.end(), std::back_inserter(new_shape),
                 [&old_shape](int64_t p) { return old_shape[LongToSize(p)]; });
  return new_shape;
}

DFormat TransposeOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) {
  if (inputs[0]->shape.size() != 4) return kOpFormat_DEFAULT;
  CHECK_ATTR(attrs, "perm");
  auto perm = GetListInt(attrs.find("perm")->second);
  const auto &ori_format = inputs[0]->format;
  if (ori_format == kOpFormat_DEFAULT || ori_format == kOpFormat_NCHW) {
    std::vector<int64_t> nchw2nhwc = {0, 2, 3, 1};
    if (perm == nchw2nhwc) return kOpFormat_NHWC;
  } else if (ori_format == kOpFormat_NHWC) {
    std::vector<int64_t> nhwc2nchw = {0, 3, 1, 2};
    if (perm == nhwc2nchw) return kOpFormat_DEFAULT;
  }
  return kOpFormat_DEFAULT;
}

DShape MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  std::vector<int64_t> shape0 = inputs[0]->shape;
  std::vector<int64_t> shape1 = inputs[1]->shape;
  if (shape0.size() != 2 || shape1.size() != 2) {
    MS_LOG(EXCEPTION) << "MatMul's inputs must be 2-D, but got " << shape0.size() << " and " << shape1.size();
  }
  CHECK_ATTR(attrs, "transpose_a");
  CHECK_ATTR(attrs, "transpose_b");
  auto transpose_a = GetValue<bool>(attrs.find("transpose_a")->second);
  auto transpose_b = GetValue<bool>(attrs.find("transpose_b")->second);
  int64_t m = transpose_a ? shape0[1] : shape0[0];
  int64_t k1 = transpose_a ? shape0[0] : shape0[1];
  int64_t k2 = transpose_b ? shape1[1] : shape1[0];
  int64_t n = transpose_b ? shape1[0] : shape1[1];
  if (k1 != k2) {
    MS_LOG(EXCEPTION) << "MatMul's inputs have different k values: " << k1 << " vs " << k2;
  }
  std::vector<int64_t> output = {m, n};
  return output;
}

TypeId MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
  // dst_type is optional; fall back to the input type when it is absent.
  if (attrs.find("dst_type") == attrs.end()) return inputs[0]->type;
  auto dst_type = attrs.find("dst_type")->second;
  if (dst_type->isa<Type>()) {
    return dst_type->cast<TypePtr>()->type_id();
  }
  return kernel::DtypeToTypeId(GetValue<std::string>(dst_type));
}

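// Each output dim is input + head + tail padding, e.g. shape (2, 3) with
// head (1, 0) and tail (0, 2) pads to (3, 5); UnPadAkg below inverts the
// tail padding.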
DShape PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  std::vector<int64_t> shape0 = inputs[0]->shape;
  size_t n = shape0.size();
  CHECK_ATTR(attrs, "head");
  CHECK_ATTR(attrs, "tail");
  std::vector<int64_t> pad_before = GetListInt(attrs.find("head")->second);
  std::vector<int64_t> pad_after = GetListInt(attrs.find("tail")->second);
  if (pad_before.size() != n || pad_after.size() != n) {
    MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << pad_before.size() << " vs "
                      << pad_after.size();
  }
  std::vector<int64_t> output;
  for (size_t i = 0; i < n; i++) {
    output.emplace_back(shape0[i] + pad_before[i] + pad_after[i]);
  }
  return output;
}

DShape UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  std::vector<int64_t> shape0 = inputs[0]->shape;
  size_t n = shape0.size();
  CHECK_ATTR(attrs, "tail");
  std::vector<int64_t> unpad_after = GetListInt(attrs.find("tail")->second);
  if (unpad_after.size() != n) {
    MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << unpad_after.size();
  }
  std::vector<int64_t> output;
  for (size_t i = 0; i < n; i++) {
    output.emplace_back(shape0[i] - unpad_after[i]);
  }
  return output;
}

void ComplexOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) {
  if (inputs[0]->type != TypeId::kNumberTypeFloat32) {
    MS_LOG(EXCEPTION) << "Complex's input[0] should be float32";
  }
  if (inputs[0]->type != inputs[1]->type) {
    MS_LOG(EXCEPTION) << "Complex's input[0] and input[1] have mismatched types";
  }
}

DShape StandardNormalOp::InferShape(const NodePtrList &, const DAttrs &attrs) {
  CHECK_ATTR(attrs, "shape");
  return GetListInt(attrs.find("shape")->second);
}
}  // namespace graphkernel
}  // namespace opt
}  // namespace mindspore