1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/graph_kernel/model/op_node.h"
17
18 #include "backend/optimizer/graph_kernel/model/node.h"
19
20 namespace mindspore {
21 namespace opt {
22 namespace graphkernel {
GetListInt(const ValuePtr & attr_value)23 std::vector<int64_t> GetListInt(const ValuePtr &attr_value) {
24 bool is_int64 = true;
25 auto get_int_value = [&is_int64](const ValuePtr &value) -> int64_t {
26 if (value->isa<Int64Imm>()) {
27 return GetValue<int64_t>(value);
28 }
29 is_int64 = false;
30 return static_cast<int64_t>(GetValue<int>(value));
31 };
32 std::vector<int64_t> list_int;
33 const auto &vals = attr_value->cast<ValueSequeuePtr>()->value();
34 (void)std::transform(vals.begin(), vals.end(), std::back_inserter(list_int), get_int_value);
35 if (!is_int64) {
36 MS_LOG(WARNING) << "Vector type should be 'int64_t' but got 'int'";
37 }
38 return list_int;
39 }
40
Check(const NodePtrList & inputs,const DAttrs & attrs)41 void PrimOp::Check(const NodePtrList &inputs, const DAttrs &attrs) {
42 CheckShape(inputs, attrs);
43 CheckType(inputs, attrs);
44 CheckFormat(inputs, attrs);
45 }
46
47 // check all type to be identical
CheckType(const NodePtrList & inputs,const DAttrs & attrs)48 void PrimOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) {
49 TypeId tid = inputs[0]->type;
50 for (size_t i = 1; i < inputs.size(); i++) {
51 if (inputs[i]->type != tid) {
52 MS_LOG(EXCEPTION) << "Incompatible dtype between input " << 0 << "and" << i;
53 }
54 }
55 }
56
57 // check all formats are compatible, only DefaultForant is compatible with others
CheckFormat(const NodePtrList & inputs,const DAttrs & attrs)58 void PrimOp::CheckFormat(const NodePtrList &inputs, const DAttrs &attrs) {
59 DFormat res = inputs[0]->format;
60 size_t i = 0;
61 for (size_t j = 1; j < inputs.size(); j++) {
62 if (inputs[j]->format != res) {
63 if (inputs[j]->format != kOpFormat_DEFAULT && res != kOpFormat_DEFAULT) {
64 MS_LOG(EXCEPTION) << "Incompatible format between input " << i << "and" << (j + 1);
65 }
66 if (res == kOpFormat_DEFAULT) {
67 res = inputs[j]->format;
68 i = j + 1;
69 }
70 }
71 }
72 }
73
Infer(const NodePtrList & inputs,const DAttrs & attrs)74 NodeBase PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
75 Check(inputs, attrs);
76 NodeBase nodebase;
77 nodebase.shape = InferShape(inputs, attrs);
78 nodebase.type = InferType(inputs, attrs);
79 nodebase.format = InferFormat(inputs, attrs);
80 return nodebase;
81 }
82
Dump(std::ostringstream & os) const83 void PrimOp::Dump(std::ostringstream &os) const {
84 DumpTensor(os);
85 os << " = " << this->op_ << "(";
86 for (size_t i = 0; i < inputs_.size(); i++) {
87 inputs_[i]->DumpTensor(os);
88 if (i != inputs_.size() - 1) os << ", ";
89 }
90 os << ")";
91 std::ostringstream attr_os;
92 bool has_attr = false;
93 std::set<std::string> black_list = {"IsFeatureMapInputList", "IsFeatureMapOutput", "output_names", "input_names"};
94 for (auto attr : attrs_) {
95 if (attr.second != nullptr && black_list.count(attr.first) == 0) {
96 if (has_attr) {
97 attr_os << ", ";
98 } else {
99 has_attr = true;
100 }
101 attr_os << attr.first << ": " << attr.second->ToString();
102 }
103 }
104 if (has_attr) {
105 os << " // attr {" << attr_os.str() << "}";
106 }
107 }
108
109 template <typename TM, typename TD>
CalcByOperator(const NodePtrList & inputs,const std::string & op,TypeId tid)110 tensor::TensorPtr CalcByOperator(const NodePtrList &inputs, const std::string &op, TypeId tid) {
111 std::vector<TM> inputs_tm;
112 std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_tm), [](const NodePtr &i) {
113 return *static_cast<TM *>(std::static_pointer_cast<graphkernel::ConstTensorNode>(i)->data()->data_c());
114 });
115
116 std::unordered_map<std::string, std::function<TM(const std::vector<TM> &)>> func_map = {
117 {"Add", [](const std::vector<TM> &n) { return n[0] + n[1]; }},
118 {"Sub", [](const std::vector<TM> &n) { return n[0] - n[1]; }},
119 {"Mul", [](const std::vector<TM> &n) { return n[0] * n[1]; }},
120 {"RealDiv", [](const std::vector<TM> &n) { return n[0] / n[1]; }},
121 {"Neg", [](const std::vector<TM> &n) { return TM(0) - n[0]; }},
122 {"Reciprocal", [](const std::vector<TM> &n) { return TM(1) / n[0]; }},
123 {"Log", [](const std::vector<TM> &n) { return log(n[0]); }},
124 {"Exp", [](const std::vector<TM> &n) { return exp(n[0]); }},
125 {"Abs", [](const std::vector<TM> &n) { return n[0] < TM(0) ? (TM(0) - n[0]) : n[0]; }},
126 {"Sqrt", [](const std::vector<TM> &n) { return sqrt(n[0]); }},
127 {"Rsqrt", [](const std::vector<TM> &n) { return TM(1) / sqrt(n[0]); }},
128 };
129 if (func_map.find(op) == func_map.end()) return nullptr;
130 return std::make_shared<tensor::Tensor>(static_cast<TD>(func_map[op](inputs_tm)), TypeIdToType(tid));
131 }
132
InferValue(const NodePtrList & inputs,const DAttrs & attrs,const std::string & op)133 NodePtr PrimOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op) {
134 for (auto i : inputs) {
135 if (i->NodeType() != NType::Value) return nullptr;
136 }
137 TypeId output_type = this->type;
138 tensor::TensorPtr res = nullptr;
139 switch (output_type) {
140 case TypeId::kNumberTypeUInt8: {
141 res = CalcByOperator<uint8_t, int64_t>(inputs, op, output_type);
142 break;
143 }
144 case TypeId::kNumberTypeInt8: {
145 res = CalcByOperator<int8_t, int64_t>(inputs, op, output_type);
146 break;
147 }
148 case TypeId::kNumberTypeInt16: {
149 res = CalcByOperator<int16_t, int64_t>(inputs, op, output_type);
150 break;
151 }
152 case TypeId::kNumberTypeInt32: {
153 res = CalcByOperator<int32_t, int64_t>(inputs, op, output_type);
154 break;
155 }
156 case TypeId::kNumberTypeInt64: {
157 res = CalcByOperator<int64_t, int64_t>(inputs, op, output_type);
158 break;
159 }
160 case TypeId::kNumberTypeUInt16: {
161 res = CalcByOperator<uint16_t, int64_t>(inputs, op, output_type);
162 break;
163 }
164 case TypeId::kNumberTypeUInt32: {
165 res = CalcByOperator<uint32_t, int64_t>(inputs, op, output_type);
166 break;
167 }
168 case TypeId::kNumberTypeUInt64: {
169 res = CalcByOperator<uint64_t, int64_t>(inputs, op, output_type);
170 break;
171 }
172 case TypeId::kNumberTypeFloat16: {
173 res = CalcByOperator<float16, double>(inputs, op, output_type);
174 break;
175 }
176 case TypeId::kNumberTypeFloat32: {
177 res = CalcByOperator<float, double>(inputs, op, output_type);
178 break;
179 }
180 case TypeId::kNumberTypeFloat64: {
181 res = CalcByOperator<double, double>(inputs, op, output_type);
182 break;
183 }
184 default:
185 return nullptr;
186 }
187 return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
188 }
189
190 // default format shape to fractal_Nz format shape
ToNz(const DShape & default_shape)191 DShape ToNz(const DShape &default_shape) {
192 constexpr size_t nz_size = 2;
193 auto len = default_shape.size();
194 DShape leading_shape;
195 DShape tail_shape;
196 if (default_shape.size() > nz_size) {
197 (void)leading_shape.insert(leading_shape.end(), default_shape.begin(), default_shape.end() - SizeToLong(nz_size));
198 }
199 if (default_shape.size() == 1 || (default_shape.size() >= nz_size && default_shape[len - nz_size] == 1)) {
200 // (32) or (N, 1, 32) -> (N, 2, 1, 1, 16)
201 if (default_shape.back() % 16 != 0) {
202 MS_LOG(EXCEPTION) << "default_shape[-1] should be multiplies of 16, but got " << default_shape.back();
203 }
204 tail_shape = {default_shape.back() / 16, 1, 1, 16};
205 } else if (default_shape.size() >= nz_size || default_shape[1] == 1) {
206 // (N, 32, 1) -> (N, 1, 2, 16, 1)
207 if (default_shape[len - nz_size] % 16 != 0) {
208 MS_LOG(EXCEPTION) << "default_shape[-2] should be multiplies of 16, but got " << default_shape[len - nz_size];
209 }
210 tail_shape = {1, default_shape[0] / 16, 16, 1};
211 } else {
212 // (N, 32, 48) -> (N, 3, 2, 16, 16)
213 if (default_shape.back() % 16 != 0 || default_shape[len - nz_size] % 16 != 0) {
214 MS_LOG(EXCEPTION) << "default_shape[-1] and default_shape[-2]should be multiplies of 16, but got "
215 << default_shape.back() << " " << default_shape[len - nz_size];
216 }
217 tail_shape = {default_shape[1] / 16, default_shape[0] / 16, 16, 16};
218 }
219 (void)leading_shape.insert(leading_shape.end(), tail_shape.begin(), tail_shape.end());
220 return leading_shape;
221 }
222
BroadcastShape(const NodePtrList & inputs,bool to_nz=false)223 DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) {
224 std::vector<std::vector<int64_t>> shapes;
225 for (auto &input : inputs) {
226 if (to_nz && input->format != kOpFormat_FRAC_NZ) {
227 shapes.emplace_back(ToNz(input->shape));
228 } else {
229 shapes.emplace_back(input->shape);
230 }
231 }
232 auto max_dim_input =
233 std::max_element(shapes.begin(), shapes.end(),
234 [](const std::vector<int64_t> &a, const std::vector<int64_t> &b) { return a.size() < b.size(); });
235 auto max_dim = max_dim_input->size();
236 std::vector<std::vector<int64_t>> align_shapes;
237 for (auto &s : shapes) {
238 std::vector<int64_t> cur(max_dim - s.size(), 1);
239 cur.insert(cur.end(), s.begin(), s.end());
240 (void)align_shapes.emplace_back(cur);
241 }
242 std::vector<int64_t> output_shape(max_dim, 1);
243 for (size_t i = 0; i < max_dim; i++) {
244 for (auto &align_shape : align_shapes) {
245 if (align_shape[i] > 1) {
246 if (output_shape[i] == 1) {
247 output_shape[i] = align_shape[i];
248 }
249 if (output_shape[i] != align_shape[i]) {
250 MS_LOG(EXCEPTION) << "Shape broadcast failed. " << output_shape[i] << " vs " << align_shape[i];
251 }
252 }
253 }
254 }
255 return output_shape;
256 }
257
InferShape(const NodePtrList & inputs,const DAttrs &)258 DShape ElemwiseOp::InferShape(const NodePtrList &inputs, const DAttrs &) {
259 if (std::all_of(inputs.begin(), inputs.end(), [](const NodePtr &input) {
260 return input->format == kOpFormat_DEFAULT || input->format == kOpFormat_NHWC || input->format == kOpFormat_NCHW;
261 })) {
262 return BroadcastShape(inputs, false);
263 }
264 if (std::all_of(inputs.begin(), inputs.end(), [](const NodePtr &input) {
265 return input->format == kOpFormat_DEFAULT || input->format == kOpFormat_NHWC ||
266 input->format == kOpFormat_NCHW || input->format == kOpFormat_FRAC_NZ;
267 })) {
268 return BroadcastShape(inputs, true);
269 }
270 MS_LOG(EXCEPTION) << "Unsupported format.";
271 }
272
InferFormat(const NodePtrList & inputs,const DAttrs & attrs)273 DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) {
274 auto it = std::find_if(inputs.begin(), inputs.end(), [](const NodePtr &i) { return i->format != kOpFormat_DEFAULT; });
275 return it == inputs.end() ? kOpFormat_DEFAULT : (*it)->format;
276 }
277
Infer(const NodePtrList & inputs,const DAttrs & attrs)278 NodeBase ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
279 auto nodebase = PrimOp::Infer(inputs, attrs);
280 auto IsBroadcast = [this](const NodePtrList &inputs) -> bool {
281 for (auto &ref : inputs) {
282 if (ref->shape.size() != this->shape.size()) return true;
283 for (size_t i = 0; i < this->shape.size(); ++i) {
284 if (ref->shape[i] != this->shape[i]) return true;
285 }
286 }
287 return false;
288 };
289 compute_type_ = IsBroadcast(inputs) ? BROADCAST : ELEMWISE;
290 return nodebase;
291 }
292
InferType(const NodePtrList & inputs,const DAttrs & attrs)293 TypeId CastOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
294 CHECK_ATTR(attrs, "dst_type");
295 auto dst_type = attrs.find("dst_type")->second;
296 if (dst_type->isa<Type>()) {
297 return dst_type->cast<TypePtr>()->type_id();
298 }
299 return kernel::DtypeToTypeId(GetValue<std::string>(dst_type));
300 }
301
CheckType(const NodePtrList & inputs,const DAttrs &)302 void SelectOp::CheckType(const NodePtrList &inputs, const DAttrs &) {
303 if (inputs[0]->type != TypeId::kNumberTypeBool) {
304 MS_LOG(EXCEPTION) << "Select's input[0] should be bool type";
305 }
306 if (inputs[1]->type != inputs[2]->type) {
307 MS_LOG(EXCEPTION) << "Select's input[1] and input[2]'s type doesn't match";
308 }
309 }
310
InferShape(const NodePtrList & inputs,const DAttrs & attrs)311 DShape ReshapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
312 CHECK_ATTR(attrs, "shape");
313 auto new_shape = GetListInt(attrs.find("shape")->second);
314 auto origin_shape = inputs[0]->shape;
315 auto origin_product = std::accumulate(origin_shape.begin(), origin_shape.end(), 1, std::multiplies<int64_t>());
316 auto new_product = std::accumulate(new_shape.begin(), new_shape.end(), 1, std::multiplies<int64_t>());
317 for (size_t i = 0; i < new_shape.size(); i++) {
318 if (new_shape[i] == -1) {
319 new_shape[i] = origin_product / new_product * (-1);
320 return new_shape;
321 }
322 }
323 if (origin_product != new_product) {
324 MS_LOG(EXCEPTION) << "The shape product before and after reshaping should be equal";
325 }
326 return new_shape;
327 }
328
InferShape(const NodePtrList & inputs,const DAttrs & attrs)329 DShape BroadcastToOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
330 CHECK_ATTR(attrs, "shape");
331 return GetListInt(attrs.find("shape")->second);
332 }
333
334 // check rudece axis in range [-size,size)
Check(const NodePtrList & inputs,const DAttrs & attrs)335 void ReduceOp::Check(const NodePtrList &inputs, const DAttrs &attrs) {
336 PrimOp::Check(inputs, attrs);
337 CHECK_ATTR(attrs, "axis");
338 auto axis = GetListInt(attrs.find("axis")->second);
339 int64_t size = static_cast<int64_t>(inputs[0]->shape.size());
340 auto it = std::find_if(axis.begin(), axis.end(), [&size](const int64_t &i) { return (i >= size || i < (-size)); });
341 if (it != axis.end()) {
342 MS_LOG(EXCEPTION) << "reduce_axis should be in range [" << (-size) << "," << size << ")"
343 << ",but got " << (*it);
344 }
345 }
346
InferShape(const NodePtrList & inputs,const DAttrs & attrs)347 DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
348 CHECK_ATTR(attrs, "axis");
349 CHECK_ATTR(attrs, "keep_dims");
350 auto axis = GetListInt(attrs.find("axis")->second);
351 auto keepdims = GetValue<bool>(attrs.find("keep_dims")->second);
352 if (keepdims) {
353 DShape new_shape = inputs[0]->shape;
354 for (auto x : axis) {
355 new_shape[LongToSize(x)] = 1;
356 }
357 return new_shape;
358 }
359 DShape new_shape;
360 const auto &input_shape = inputs[0]->shape;
361 for (size_t i = 0; i < input_shape.size(); i++) {
362 if (std::find(axis.begin(), axis.end(), i) == axis.end()) {
363 new_shape.emplace_back(input_shape[i]);
364 }
365 }
366 if (new_shape.empty()) {
367 new_shape.emplace_back(1);
368 }
369 return new_shape;
370 }
371
InferShape(const NodePtrList & inputs,const DAttrs & attrs)372 DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
373 auto check_nd = [](const std::vector<int64_t> &shape, size_t n) {
374 if (shape.size() != n) {
375 MS_LOG(EXCEPTION) << "input dimension should be " << n << ", but got " << shape.size();
376 }
377 };
378 auto shape0 = inputs[0]->shape;
379 auto shape1 = inputs[1]->shape;
380 check_nd(shape0, 4);
381 check_nd(shape1, 4);
382 CHECK_ATTR(attrs, "format");
383 if (inputs[0]->format != kOpFormat_NHWC && inputs[1]->format != kOpFormat_NHWC &&
384 GetValue<std::string>(attrs.find("format")->second) != kOpFormat_NHWC) {
385 MS_LOG(EXCEPTION) << "check NHWC format failed";
386 }
387 auto n = shape0[0];
388 auto h = shape0[1];
389 auto w = shape0[2];
390 auto out_channel = shape1[0];
391 CHECK_ATTR(attrs, "pad_list");
392 CHECK_ATTR(attrs, "pad_mode");
393 CHECK_ATTR(attrs, "kernel_size");
394 CHECK_ATTR(attrs, "stride");
395 CHECK_ATTR(attrs, "dilation");
396 auto pad_list = GetListInt(attrs.find("pad_list")->second);
397 auto pad_mode = GetValue<std::string>(attrs.find("pad_mode")->second);
398 auto kernel_size = GetListInt(attrs.find("kernel_size")->second);
399 auto stride = GetListInt(attrs.find("stride")->second);
400 auto dilation = GetListInt(attrs.find("dilation")->second);
401 check_nd(pad_list, 4);
402 check_nd(kernel_size, 2);
403 check_nd(stride, 4);
404 check_nd(dilation, 4);
405 bool has_pad = false;
406 if (pad_list[0] != pad_list[1] || pad_list[2] != pad_list[3]) {
407 has_pad = true;
408 } else {
409 if (pad_mode == "VALID" || pad_mode == "valid") {
410 if (std::any_of(pad_list.begin(), pad_list.end(), [](int i) { return i == 0; })) {
411 has_pad = true;
412 }
413 }
414 }
415 if (!has_pad) {
416 pad_list = {0, 0, 0, 0};
417 }
418 auto k_h = (kernel_size[0] - 1) * dilation[2] + 1;
419 auto k_w = (kernel_size[1] - 1) * dilation[3] + 1;
420 auto out_h = (h + pad_list[0] + pad_list[1] - k_h) / stride[2] + 1;
421 auto out_w = (w + pad_list[2] + pad_list[3] - k_w) / stride[3] + 1;
422 std::vector<int64_t> output = {n, out_h, out_w, out_channel};
423 return output;
424 }
425
InferType(const NodePtrList & inputs,const DAttrs & attrs)426 TypeId Conv2dOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
427 if (attrs.find("dst_type") == attrs.end()) return inputs[0]->type;
428 auto dst_type = attrs.find("dst_type")->second;
429 if (dst_type->isa<Type>()) {
430 return dst_type->cast<TypePtr>()->type_id();
431 }
432 return kernel::DtypeToTypeId(GetValue<std::string>(dst_type));
433 }
434
InferShape(const NodePtrList & inputs,const DAttrs & attrs)435 DShape TransposeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
436 CHECK_ATTR(attrs, "perm");
437 auto perm = GetListInt(attrs.find("perm")->second);
438 auto &old_shape = inputs[0]->shape;
439 DShape new_shape;
440 if (perm.size() != old_shape.size()) {
441 MS_LOG(EXCEPTION) << "perm.size() != old_shape.size(). " << perm.size() << " vs " << old_shape.size();
442 }
443 std::transform(perm.begin(), perm.end(), std::back_inserter(new_shape),
444 [&old_shape](int64_t p) { return old_shape[LongToSize(p)]; });
445 return new_shape;
446 }
447
InferFormat(const NodePtrList & inputs,const DAttrs & attrs)448 DFormat TransposeOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) {
449 if (inputs[0]->shape.size() != 4) return kOpFormat_DEFAULT;
450 CHECK_ATTR(attrs, "perm");
451 auto perm = GetListInt(attrs.find("perm")->second);
452 const auto &ori_format = inputs[0]->format;
453 if (ori_format == kOpFormat_DEFAULT || ori_format == kOpFormat_NCHW) {
454 std::vector<int64_t> nchw2nhwc = {0, 2, 3, 1};
455 if (perm == nchw2nhwc) return kOpFormat_NHWC;
456 } else if (ori_format == kOpFormat_NHWC) {
457 std::vector<int64_t> nhwc2nchw = {0, 3, 1, 2};
458 if (perm == nhwc2nchw) return kOpFormat_DEFAULT;
459 }
460 return kOpFormat_DEFAULT;
461 }
462
InferShape(const NodePtrList & inputs,const DAttrs & attrs)463 DShape MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
464 std::vector<int64_t> shape0 = inputs[0]->shape;
465 std::vector<int64_t> shape1 = inputs[1]->shape;
466 if (shape0.size() != 2 || shape1.size() != 2) {
467 MS_LOG(EXCEPTION) << "MatMul's input's dimension must be 2, but got " << shape0.size() << " and " << shape1.size();
468 }
469 CHECK_ATTR(attrs, "transpose_a");
470 CHECK_ATTR(attrs, "transpose_b");
471 auto transpose_a = GetValue<bool>(attrs.find("transpose_a")->second);
472 auto transpose_b = GetValue<bool>(attrs.find("transpose_b")->second);
473 int64_t m = transpose_a ? shape0[1] : shape0[0];
474 int64_t k1 = transpose_a ? shape0[0] : shape0[1];
475 int64_t k2 = transpose_b ? shape1[1] : shape1[0];
476 int64_t n = transpose_b ? shape1[0] : shape1[1];
477 if (k1 != k2) {
478 MS_LOG(EXCEPTION) << "MatMul's inputs have different k value " << k1 << " vs " << k2;
479 }
480 std::vector<int64_t> output = {m, n};
481 return output;
482 }
483
InferType(const NodePtrList & inputs,const DAttrs & attrs)484 TypeId MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
485 CHECK_ATTR(attrs, "dst_type");
486 if (attrs.find("dst_type") == attrs.end()) return inputs[0]->type;
487 auto dst_type = attrs.find("dst_type")->second;
488 if (dst_type->isa<Type>()) {
489 return dst_type->cast<TypePtr>()->type_id();
490 }
491 return kernel::DtypeToTypeId(GetValue<std::string>(dst_type));
492 }
493
InferShape(const NodePtrList & inputs,const DAttrs & attrs)494 DShape PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
495 std::vector<int64_t> shape0 = inputs[0]->shape;
496 size_t n = shape0.size();
497 CHECK_ATTR(attrs, "head");
498 CHECK_ATTR(attrs, "tail");
499 std::vector<int64_t> pad_before = GetListInt(attrs.find("head")->second);
500 std::vector<int64_t> pad_after = GetListInt(attrs.find("tail")->second);
501 if (pad_before.size() != n || pad_after.size() != n) {
502 MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << pad_before.size() << " vs "
503 << pad_after.size();
504 }
505 std::vector<int64_t> output;
506 for (size_t i = 0; i < n; i++) {
507 output.emplace_back(shape0[i] + pad_before[i] + pad_after[i]);
508 }
509 return output;
510 }
511
InferShape(const NodePtrList & inputs,const DAttrs & attrs)512 DShape UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
513 std::vector<int64_t> shape0 = inputs[0]->shape;
514 size_t n = shape0.size();
515 CHECK_ATTR(attrs, "tail");
516 std::vector<int64_t> unpad_after = GetListInt(attrs.find("tail")->second);
517 if (unpad_after.size() != n) {
518 MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << unpad_after.size();
519 }
520 std::vector<int64_t> output;
521 for (size_t i = 0; i < n; i++) {
522 output.emplace_back(shape0[i] - unpad_after[i]);
523 }
524 return output;
525 }
526
CheckType(const NodePtrList & inputs,const DAttrs & attrs)527 void ComplexOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) {
528 if (inputs[0]->type != TypeId::kNumberTypeFloat32) {
529 MS_LOG(EXCEPTION) << "Complex's input[0] should be float32";
530 }
531 if (inputs[0]->type != inputs[1]->type) {
532 MS_LOG(EXCEPTION) << "Complex's input[0] and inputs[1]'s type mismatch";
533 }
534 }
535
InferShape(const NodePtrList &,const DAttrs & attrs)536 DShape StandardNormalOp::InferShape(const NodePtrList &, const DAttrs &attrs) {
537 CHECK_ATTR(attrs, "shape");
538 return GetListInt(attrs.find("shape")->second);
539 }
540 } // namespace graphkernel
541 } // namespace opt
542 } // namespace mindspore
543