/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #include <algorithm>
18 #include <functional>
19 #include <iterator>
20 #include <numeric>
21 #include "abstract/infer_functions.h"
22 #include "abstract/utils.h"
23 #include "abstract/param_validator.h"
24 #include "utils/shape_utils.h"
25 #include "ops/op_utils.h"
26
27 namespace mindspore {
28 namespace abstract {
InferImplScalarToArray(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)29 AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
30 const AbstractBasePtrList &args_spec_list) {
31 // Inputs: a scalar.
32 const std::string op_name = primitive->name();
33 CheckArgsSize(op_name, args_spec_list, 1);
34 AbstractScalarPtr arg = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
35 return std::make_shared<AbstractTensor>(arg, std::make_shared<Shape>());
36 }
37
InferImplArrayToScalar(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)38 AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
39 const AbstractBasePtrList &args_spec_list) {
40 // Inputs: a tensor with 0 shape.
41 const std::string op_name = primitive->name();
42 CheckArgsSize(op_name, args_spec_list, 1);
43 auto arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
44 auto a_shp = arg->shape();
45 MS_EXCEPTION_IF_NULL(a_shp);
46 if (!a_shp->shape().empty()) {
47 MS_LOG(EXCEPTION) << "array_to_scalar requires zero size shape.";
48 }
49 return arg->element();
50 }
51
InferImplBroadCastShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)52 AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
53 const AbstractBasePtrList &args_spec_list) {
54 // Inputs: two tuples.
55 const std::string op_name = primitive->name();
56 constexpr size_t args_size = 2;
57 CheckArgsSize(op_name, args_spec_list, args_size);
58 auto xs = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
59 auto ys = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
60 auto x_value = xs->BuildValue();
61 MS_EXCEPTION_IF_NULL(x_value);
62 auto value_tuple_x = x_value->cast<ValueTuplePtr>();
63 MS_EXCEPTION_IF_NULL(value_tuple_x);
64 auto shp_tuple_x = value_tuple_x->value();
65 ShapeVector shp_x;
66 (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(shp_x),
67 [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
68 auto tupe_value_y = ys->BuildValue();
69 MS_EXCEPTION_IF_NULL(tupe_value_y);
70 auto value_tuple_y = tupe_value_y->cast<ValueTuplePtr>();
71 MS_EXCEPTION_IF_NULL(value_tuple_y);
72 auto shp_tuple_y = value_tuple_y->value();
73 ShapeVector shp_y;
74 (void)std::transform(std::begin(shp_tuple_y), std::end(shp_tuple_y), std::back_inserter(shp_y),
75 [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
76
77 ShapeVector res = BroadcastShape(shp_x, shp_y);
78 MS_EXCEPTION_IF_NULL(args_spec_list[1]);
79 if (res.empty()) {
80 MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
81 << args_spec_list[1]->ToString();
82 }
83
84 AbstractBasePtrList elems;
85 (void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int64_t n) -> AbstractBasePtr {
86 return std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(n), kInt64);
87 });
88
89 return std::make_shared<AbstractTuple>(elems);
90 }
91
// Infer for Stack: packs a tuple of same-dtype/same-shape tensors into one
// tensor, inserting a new axis of length N (the tuple size) at `axis`.
// Side effect: records "N" (tuple size) and "T" (element type) attrs on the
// primitive for later use by backends.
AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tuple of tensor.
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 1);
  auto arg = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  if (arg->elements().empty()) {
    MS_LOG(EXCEPTION) << "Arg elements is empty.";
  }

  size_t tuple_len = arg->elements().size();
  // The first tensor fixes the expected dtype and shape for all others.
  AbstractTensorPtr tensor_base = CheckArg<AbstractTensor>(op_name, arg->elements(), 0);
  auto shape = tensor_base->shape();
  MS_EXCEPTION_IF_NULL(shape);
  int64_t rank_base = SizeToLong(shape->shape().size());

  ValuePtr axis = primitive->GetAttr("axis");
  // Axis value should be in [-(rank_base + 1), rank_base).
  int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base);
  // If axis is negative, add offset(rank_base + 1) to turn it to positive.
  axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base + 1));

  // Every remaining tensor must match the first in both dtype and shape.
  for (size_t i = 1; i < tuple_len; ++i) {
    AbstractTensorPtr tensor = CheckArg<AbstractTensor>(op_name, arg->elements(), i);
    (void)CheckDtypeSame(op_name, tensor_base, tensor);
    (void)CheckShapeSame(op_name, tensor_base, tensor);
  }
  auto element = tensor_base->element();
  MS_EXCEPTION_IF_NULL(element);
  primitive->set_attr("N", MakeValue(SizeToLong(tuple_len)));
  primitive->set_attr("T", element->BuildType());

  // Broaden the base tensor's abstract, then splice the new stacked axis of
  // length tuple_len into its shape at axis_value.
  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(tensor_base->Broaden());
  MS_EXCEPTION_IF_NULL(ret);
  auto ret_shape_ptr = ret->shape();
  MS_EXCEPTION_IF_NULL(ret_shape_ptr);
  auto ret_shape = ret_shape_ptr->shape();
  (void)ret_shape.insert(ret_shape.begin() + axis_value, SizeToLong(tuple_len));
  ret->set_shape(std::make_shared<Shape>(ret_shape));
  return ret;
}
133
// Infer for Unique: input is a 1-D tensor; output is a tuple (ids, ids_idx).
//   ids:     1-D tensor of data-dependent length ([SHP_ANY], min 1,
//            max = input's max shape) with the input's element type.
//   ids_idx: 1-D tensor mirroring the input's shape; element type is int64
//            when the input is int64, otherwise int32.
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list) {
  // inputs: a 1-d Tensor
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 1);
  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);

  auto shape = input->shape();
  MS_EXCEPTION_IF_NULL(shape);
  if (shape->shape().size() != 1) {
    MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1.";
  }
  // ids length is unknown until runtime; fall back to the static shape as the
  // max when no max shape is recorded.
  ShapeVector ids_shape = {Shape::SHP_ANY};
  ShapeVector min_shape = {1};
  ShapeVector max_shape = shape->max_shape();
  if (max_shape.empty()) {
    max_shape = shape->shape();
  }

  auto ids =
    std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape));
  // Currently we choose the same data type as input for the idx.
  TypePtr ids_idx_type = kInt32;
  MS_EXCEPTION_IF_NULL(input->element());
  MS_EXCEPTION_IF_NULL(input->element()->GetTypeTrack());
  // int64 inputs produce int64 indices; everything else gets int32.
  if (input->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
    ids_idx_type = kInt64;
  }
  // ids_idx carries the input's static/min/max shapes, defaulting min/max to
  // the static shape when they are not recorded.
  ShapeVector idx_shape = shape->shape();
  ShapeVector idx_min_shape = shape->min_shape();
  if (idx_min_shape.empty()) {
    idx_min_shape = shape->shape();
  }
  ShapeVector idx_max_shape = shape->max_shape();
  if (idx_max_shape.empty()) {
    idx_max_shape = shape->shape();
  }

  auto ids_idx = std::make_shared<AbstractTensor>(ids_idx_type, idx_shape);
  ids_idx->set_shape(std::make_shared<Shape>(idx_shape, idx_min_shape, idx_max_shape));
  // outputs: ids, ids_idx
  AbstractBasePtrList elements = {ids, ids_idx};
  return std::make_shared<AbstractTuple>(elements);
}
178
InferImplPadAndShift(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)179 AbstractBasePtr InferImplPadAndShift(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
180 const AbstractBasePtrList &args_spec_list) {
181 // inputs: a 1-d Tensor
182 const std::string op_name = primitive->name();
183 const size_t size_expected = 3;
184 CheckArgsSize(op_name, args_spec_list, size_expected);
185 AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
186 MS_EXCEPTION_IF_NULL(input);
187 auto shape = input->shape();
188 MS_EXCEPTION_IF_NULL(shape);
189 if (shape->shape().size() != 1) {
190 MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1.";
191 }
192 ShapeVector ids_shape = {Shape::SHP_ANY};
193 ShapeVector min_shape = {1};
194 ShapeVector max_shape = shape->max_shape();
195 if (max_shape.empty()) {
196 max_shape = shape->shape();
197 }
198 return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape));
199 }
200
InferImplUniqueGrad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)201 AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
202 const AbstractBasePtrList &args_spec_list) {
203 // inputs: a 1-d Tensor
204 const std::string op_name = primitive->name();
205 const size_t size_expected = 2;
206 CheckArgsSize(op_name, args_spec_list, size_expected);
207 AbstractTuplePtr dout = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
208 CheckArgsSize(op_name + " dout", dout->elements(), size_expected);
209 auto ids = CheckArg<AbstractTensor>(op_name, dout->elements(), 0);
210 auto ids_idx = CheckArg<AbstractTensor>(op_name, dout->elements(), 1);
211 auto ids_shape = ids->shape();
212 auto ids_idx_shape = ids_idx->shape();
213 MS_EXCEPTION_IF_NULL(ids_shape);
214 MS_EXCEPTION_IF_NULL(ids_idx_shape);
215 if (ids->shape()->shape().size() != 1) {
216 MS_LOG(EXCEPTION) << "Dims of dout[0] of " << op_name << "' input must be 1.";
217 }
218 if (ids_idx->shape()->shape().size() != 1) {
219 MS_LOG(EXCEPTION) << "Dims of dout[1] of " << op_name << "' input must be 1.";
220 }
221
222 // outputs: dx
223 return std::make_shared<AbstractTensor>(ids->element(), ids_idx->shape());
224 }
225
InferImplUnsortedSegmentSum(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)226 AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
227 const AbstractBasePtrList &args_spec_list) {
228 const std::string op_name = primitive->name();
229 constexpr size_t args_size = 3;
230 CheckArgsSize(op_name, args_spec_list, args_size);
231 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
232 MS_EXCEPTION_IF_NULL(x);
233 MS_EXCEPTION_IF_NULL(x->shape());
234 auto segment_ids = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
235 MS_EXCEPTION_IF_NULL(segment_ids);
236 MS_EXCEPTION_IF_NULL(segment_ids->shape());
237 auto segment_ids_shape = segment_ids->shape()->shape();
238 (void)CheckTensorDType(x, {kFloat16, kFloat32, kFloat64, kInt32}, "Input 0 (x) for UnsortedSegmentSum should be %s");
239 (void)CheckTensorDType(segment_ids, {kInt32, kInt64}, "Input 1 (segment_ids) for UnsortedSegmentSum should be %s");
240 bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty()); // check if dynamic shape
241 bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
242 bool op_is_dynamic = x_is_dyn || ids_is_dyn;
243 auto x_shape = x->shape()->shape();
244 ShapeVector shape;
245 int64_t num_segments_value = GetUnsortedSegmentOpScalarArg(args_spec_list, op_name);
246 if (num_segments_value <= 0) {
247 MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentSum";
248 }
249 shape.emplace_back(num_segments_value);
250 shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
251 if (!op_is_dynamic) { // not dynamic
252 for (size_t i = 0; i < segment_ids_shape.size(); i++) {
253 if (x_shape[i] != segment_ids_shape[i]) {
254 MS_LOG(EXCEPTION) << "Shape values of segments_ids must match with corresponding x shape values";
255 }
256 }
257 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
258 }
259 ShapeVector min_shape;
260 ShapeVector max_shape;
261 min_shape.emplace_back(num_segments_value);
262 max_shape.emplace_back(num_segments_value);
263 bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
264 bool ids_any_shape =
265 std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
266 if (!x_any_shape && !ids_any_shape) { // only validate when shapes fully known
267 for (size_t i = 0; i < segment_ids_shape.size(); i++) {
268 if (x_shape[i] != segment_ids_shape[i]) {
269 MS_LOG(EXCEPTION) << "Shape values of segments_ids must match with corresponding x shape values";
270 }
271 }
272 }
273 ShapeVector x_shape_min;
274 ShapeVector x_shape_max;
275 x_shape_min = (x_is_dyn) ? x->shape()->min_shape() : x->shape()->shape();
276 x_shape_max = (x_is_dyn) ? x->shape()->max_shape() : x->shape()->shape();
277 min_shape.insert(min_shape.end(), x_shape_min.begin() + segment_ids_shape.size(), x_shape_min.end());
278 max_shape.insert(max_shape.end(), x_shape_max.begin() + segment_ids_shape.size(), x_shape_max.end());
279 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
280 }
281
// Infer for UnsortedSegmentMax: output shape is [num_segments] followed by
// x's dims after the segment_ids prefix. Propagates min/max shapes when an
// input is dynamic.
AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                            const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  const size_t size_expected = 3;
  CheckArgsSize(op_name, args_spec_list, size_expected);
  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(x->shape());
  auto segment_ids = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(segment_ids);
  MS_EXCEPTION_IF_NULL(segment_ids->shape());
  auto segment_ids_shape = segment_ids->shape()->shape();
  (void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMax should be %s");
  (void)CheckTensorDType(segment_ids, {kInt32, kInt64}, "Input 1 (segment_ids) for UnsortedSegmentMax should be %s");
  // A tensor counts as dynamic when both its min and max shapes are recorded.
  bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());  // check if dynamic
  bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
  bool op_is_dynamic = x_is_dyn || ids_is_dyn;
  auto x_shape = x->shape()->shape();
  ShapeVector shape;
  int64_t num_segments_value = GetUnsortedSegmentOpScalarArg(args_spec_list, op_name);
  if (num_segments_value <= 0) {
    MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentMax";
  }
  // Leading dim is num_segments; the rest are x's trailing dims.
  shape.emplace_back(num_segments_value);
  shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
  // NOTE(review): indexes [0] of both shapes and offsets x_shape.begin() by
  // segment_ids' rank without checking non-empty/compatible ranks — confirm
  // inputs are validated upstream.
  if (!op_is_dynamic) {  // not dynamic
    if (x_shape[0] != segment_ids_shape[0]) {
      MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMax";
    }
    return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
  }
  // Dynamic case: build min/max shapes; the leading dim is always exact.
  ShapeVector min_shape;
  ShapeVector max_shape;
  min_shape.emplace_back(num_segments_value);
  max_shape.emplace_back(num_segments_value);
  bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
  bool ids_any_shape =
    std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
  // Only validate the leading-dim match when both shapes are fully known.
  if (!x_any_shape && !ids_any_shape) {
    if (x_shape[0] != segment_ids_shape[0]) {
      MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMax";
    }
  }
  // Fall back to the static shape when a tensor has no recorded min/max.
  ShapeVector x_shape_min;
  ShapeVector x_shape_max;
  x_shape_min = (x_is_dyn) ? x->shape()->min_shape() : x->shape()->shape();
  x_shape_max = (x_is_dyn) ? x->shape()->max_shape() : x->shape()->shape();
  min_shape.insert(min_shape.end(), x_shape_min.begin() + segment_ids_shape.size(), x_shape_min.end());
  max_shape.insert(max_shape.end(), x_shape_max.begin() + segment_ids_shape.size(), x_shape_max.end());
  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
}
332
// Infer for UnsortedSegmentMin: output shape is [num_segments] followed by
// x's dims after the segment_ids prefix. Propagates min/max shapes when an
// input is dynamic. Unlike Sum/Max, segment_ids must be int32 here.
AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                            const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  const size_t size_expected = 3;
  CheckArgsSize(op_name, args_spec_list, size_expected);
  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(x->shape());
  auto segment_ids = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(segment_ids);
  MS_EXCEPTION_IF_NULL(segment_ids->shape());
  auto segment_ids_shape = segment_ids->shape()->shape();
  (void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMin should be %s");
  (void)CheckTensorDType(segment_ids, {kInt32}, "Input 1 (segment_ids) for UnsortedSegmentMin should be %s");
  // A tensor counts as dynamic when both its min and max shapes are recorded.
  bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());  // check if dynamic shape
  bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
  bool op_is_dynamic = x_is_dyn || ids_is_dyn;
  auto x_shape = x->shape()->shape();
  ShapeVector shape;
  int64_t num_segments_value = GetUnsortedSegmentOpScalarArg(args_spec_list, op_name);
  if (num_segments_value <= 0) {
    MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentMin";
  }
  // Leading dim is num_segments; the rest are x's trailing dims.
  shape.emplace_back(num_segments_value);
  shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
  // NOTE(review): indexes [0] of both shapes and offsets x_shape.begin() by
  // segment_ids' rank without checking non-empty/compatible ranks — confirm
  // inputs are validated upstream.
  if (!op_is_dynamic) {  // not dynamic
    if (x_shape[0] != segment_ids_shape[0]) {
      MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMin";
    }
    return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
  }
  // Dynamic case: build min/max shapes; the leading dim is always exact.
  ShapeVector min_shape;
  ShapeVector max_shape;
  min_shape.emplace_back(num_segments_value);
  max_shape.emplace_back(num_segments_value);
  bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
  bool ids_any_shape =
    std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
  if (!x_any_shape && !ids_any_shape) {  // only validate when shapes fully known
    if (x_shape[0] != segment_ids_shape[0]) {
      MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMin";
    }
  }
  // Fall back to the static shape when a tensor has no recorded min/max.
  ShapeVector x_shape_min;
  ShapeVector x_shape_max;
  x_shape_min = (x_is_dyn) ? x->shape()->min_shape() : x->shape()->shape();
  x_shape_max = (x_is_dyn) ? x->shape()->max_shape() : x->shape()->shape();
  min_shape.insert(min_shape.end(), x_shape_min.begin() + segment_ids_shape.size(), x_shape_min.end());
  max_shape.insert(max_shape.end(), x_shape_max.begin() + segment_ids_shape.size(), x_shape_max.end());
  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
}
384
InferImplScatterAdd(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)385 AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
386 const AbstractBasePtrList &args_spec_list) {
387 constexpr auto kScatterAddInputNum = 3;
388 const std::string op_name = primitive->name();
389 CheckRequiredArgsSize(op_name, args_spec_list, kScatterAddInputNum);
390 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
391 MS_EXCEPTION_IF_NULL(x);
392 MS_EXCEPTION_IF_NULL(x->shape());
393 ShapeVector shape = x->shape()->shape();
394 ShapeVector min_shape = x->shape()->min_shape();
395 ShapeVector max_shape = x->shape()->max_shape();
396 CheckMinMaxShape(shape, &min_shape, &max_shape);
397 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
398 }
399
InferImplScatterSub(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)400 AbstractBasePtr InferImplScatterSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
401 const AbstractBasePtrList &args_spec_list) {
402 constexpr auto kScatterSubInputNum = 3;
403 const std::string op_name = primitive->name();
404 CheckRequiredArgsSize(op_name, args_spec_list, kScatterSubInputNum);
405 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
406 MS_EXCEPTION_IF_NULL(x);
407 MS_EXCEPTION_IF_NULL(x->shape());
408 ShapeVector shape = x->shape()->shape();
409 ShapeVector min_shape = x->shape()->min_shape();
410 ShapeVector max_shape = x->shape()->max_shape();
411 CheckMinMaxShape(shape, &min_shape, &max_shape);
412 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
413 }
414
InferImplScatterUpdate(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)415 AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
416 const AbstractBasePtrList &args_spec_list) {
417 const std::string op_name = primitive->name();
418 CheckRequiredArgsSize(op_name, args_spec_list, 3);
419 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
420 MS_EXCEPTION_IF_NULL(x);
421 MS_EXCEPTION_IF_NULL(x->shape());
422 ShapeVector shape = x->shape()->shape();
423 ShapeVector min_shape = x->shape()->min_shape();
424 ShapeVector max_shape = x->shape()->max_shape();
425 CheckMinMaxShape(shape, &min_shape, &max_shape);
426 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
427 }
428
// Infer for MapCacheIdx: returns a tuple
// (cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx).
//   cache_idx mirrors the indices input's shape; the other three outputs are
//   dynamic tensors ([SHP_ANY] per dim, min all-ones, max = indices' max
//   shape). All outputs use hash_map's element type.
AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  const size_t size_expected = 5;
  CheckArgsSize(op_name, args_spec_list, size_expected);
  auto hash_map = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(hash_map->shape());

  auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  auto indices_shp = indices->shape();
  MS_EXCEPTION_IF_NULL(indices_shp);

  // Dynamic output: each dim is SHP_ANY with min 1; max follows the indices'
  // max shape, falling back to its static shape when no max is recorded.
  ShapeVector shape;
  ShapeVector min_shape;
  ShapeVector max_shape;
  if (!indices_shp->max_shape().empty()) {
    max_shape = indices_shp->max_shape();
  } else {
    max_shape = indices_shp->shape();
  }
  for (size_t i = 0; i < max_shape.size(); i++) {
    shape.emplace_back(Shape::SHP_ANY);
    min_shape.emplace_back(1);
  }

  auto cache_idx = std::make_shared<AbstractTensor>(hash_map->element(), indices->shape());
  auto old_emb_idx =
    std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
  auto miss_emb_idx =
    std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
  auto swap_emb_idx =
    std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));

  AbstractBasePtrList elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
  return std::make_shared<AbstractTuple>(elements);
}
465
InferImplCacheSwapTable(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)466 AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
467 const AbstractBasePtrList &args_spec_list) {
468 const std::string op_name = primitive->name();
469 const size_t size_expected = 3;
470 CheckArgsSize(op_name, args_spec_list, size_expected);
471 auto cache_table = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
472 auto cache_table_shp = cache_table->shape();
473 MS_EXCEPTION_IF_NULL(cache_table_shp);
474
475 auto swap_cache_idx = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
476 auto swap_cache_idx_shp = swap_cache_idx->shape();
477 MS_EXCEPTION_IF_NULL(swap_cache_idx_shp);
478
479 auto cache_table_shape = cache_table_shp->shape();
480 auto swap_cache_idx_shape = swap_cache_idx_shp->shape();
481 ShapeVector shape;
482 shape.emplace_back(swap_cache_idx_shape[0]);
483 shape.emplace_back(cache_table_shape[1]);
484 auto swap_cache_idx_max_shape = swap_cache_idx_shp->max_shape();
485 ShapeVector max_shape;
486 ShapeVector min_shape;
487 if (!swap_cache_idx_max_shape.empty()) {
488 max_shape.emplace_back(swap_cache_idx_max_shape[0]);
489 max_shape.emplace_back(cache_table_shape[1]);
490 } else {
491 max_shape = shape;
492 }
493 for (size_t i = 0; i < max_shape.size(); ++i) {
494 min_shape.emplace_back(1);
495 }
496
497 AbstractTensorPtr ret =
498 std::make_shared<AbstractTensor>(cache_table->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
499 return ret;
500 }
501
InferImplUpdateCache(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)502 AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
503 const AbstractBasePtrList &args_spec_list) {
504 const std::string op_name = primitive->name();
505 auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
506
507 ShapeVector shape;
508 shape.emplace_back(1);
509
510 AbstractTensorPtr ret = std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape));
511 return ret;
512 }
513
InferImplSubAndFilter(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)514 AbstractBasePtr InferImplSubAndFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
515 const AbstractBasePtrList &args_spec_list) {
516 const std::string op_name = primitive->name();
517 auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
518 auto input_x_shp = input_x->shape();
519 MS_EXCEPTION_IF_NULL(input_x_shp);
520
521 ShapeVector shape;
522 ShapeVector min_shape;
523 ShapeVector max_shape;
524 if (!input_x_shp->max_shape().empty()) {
525 max_shape = input_x_shp->max_shape();
526 } else {
527 max_shape = input_x_shp->shape();
528 }
529 for (size_t i = 0; i < max_shape.size(); i++) {
530 shape.emplace_back(Shape::SHP_ANY);
531 min_shape.emplace_back(1);
532 }
533 auto filter_res =
534 std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
535 auto filter_idx =
536 std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
537 AbstractBasePtrList elements = {filter_res, filter_idx};
538 return std::make_shared<AbstractTuple>(elements);
539 }
540
InferImplDiv(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)541 AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
542 const AbstractBasePtrList &args_spec_list) {
543 const std::string op_name = primitive->name();
544 const size_t size_expected = 2;
545 CheckArgsSize(op_name, args_spec_list, size_expected);
546 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
547 auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
548 MS_EXCEPTION_IF_NULL(x);
549 MS_EXCEPTION_IF_NULL(x->shape());
550 MS_EXCEPTION_IF_NULL(y);
551 MS_EXCEPTION_IF_NULL(y->shape());
552 ShapeVector x_shape = x->shape()->shape();
553 ShapeVector y_shape = y->shape()->shape();
554 ShapeVector out_shape = BroadcastShape(x_shape, y_shape);
555 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(out_shape));
556 }
557
InferImplRealDiv(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)558 AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
559 const AbstractBasePtrList &args_spec_list) {
560 const std::string op_name = primitive->name();
561 const size_t size_expected = 2;
562 CheckArgsSize(op_name, args_spec_list, size_expected);
563 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
564 auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
565 MS_EXCEPTION_IF_NULL(x);
566 MS_EXCEPTION_IF_NULL(x->shape());
567 MS_EXCEPTION_IF_NULL(y);
568 MS_EXCEPTION_IF_NULL(y->shape());
569 ShapeVector x_shape = x->shape()->shape();
570 ShapeVector y_shape = y->shape()->shape();
571 ShapeVector out_shape = BroadcastShape(x_shape, y_shape);
572 if (out_shape.empty()) {
573 MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
574 << args_spec_list[1]->ToString();
575 }
576 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(out_shape));
577 }
578
// Infer for GatherV2 (Gather): output shape is
//   params_shape[:axis] + indices_shape + params_shape[axis+1:].
// The axis argument may arrive as a scalar or (for dynamic-shape graphs) a
// tensor. Min/max shapes are propagated when either input is dynamic.
AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list) {
  const std::string &op_name = primitive->name();
  constexpr size_t args_size = 3;
  CheckArgsSize(op_name, args_spec_list, args_size);
  AbstractTensorPtr params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  AbstractTensorPtr indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  // An input is treated as dynamic when both its min and max shapes are set.
  bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty());
  bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty());
  int64_t axis_val = 0;
  // 3rd input is a Tensor when GatherV2 is a dynamic shape operator
  constexpr size_t aixs_index = 2;
  if (args_spec_list[aixs_index]->isa<AbstractTensor>()) {
    auto axis = args_spec_list[aixs_index]->cast<AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(axis);
    auto axis_value_ptr = axis->BuildValue();
    MS_EXCEPTION_IF_NULL(axis_value_ptr);
    auto axis_tensor = axis_value_ptr->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(axis_tensor);
    // NOTE(review): assumes the axis tensor stores int64 data — reading
    // through int64_t* would misread an int32 tensor; confirm callers always
    // supply int64 here.
    axis_val = *static_cast<int64_t *>(axis_tensor->data_c());
  } else if (args_spec_list[aixs_index]->isa<AbstractScalar>()) {
    auto axis = args_spec_list[aixs_index]->cast<AbstractScalarPtr>();
    axis_val = GetValue<int64_t>(axis->BuildValue());
  } else {
    MS_LOG(EXCEPTION) << "Invalid abstract type:" << args_spec_list[2]->type_name();
  }
  auto params_shp = params->shape()->shape();
  auto indices_shp = indices->shape()->shape();
  auto params_rank = static_cast<int64_t>(params_shp.size());
  // either inputs or both can be dynamic and computation requires min/max shapes for both
  ShapeVector param_shp_min = (param_dyn) ? params->shape()->min_shape() : params->shape()->shape();
  ShapeVector param_shp_max = (param_dyn) ? params->shape()->max_shape() : params->shape()->shape();
  ShapeVector indices_shp_min = (ind_dyn) ? indices->shape()->min_shape() : indices->shape()->shape();
  ShapeVector indices_shp_max = (ind_dyn) ? indices->shape()->max_shape() : indices->shape()->shape();
  // check axis_val within interval: [-params_rank, params_rank)
  if (-params_rank > axis_val || axis_val >= params_rank) {
    MS_LOG(EXCEPTION) << "For Gather - Axis value must be within [ " << -params_rank << ", " << params_rank << " ) "
                      << "Got " << axis_val << ".";
  }
  // Normalize a negative axis to its positive equivalent.
  if (axis_val < 0) {
    axis_val += params_rank;
  }
  // Splice the indices shape into the params shape in place of the axis dim.
  auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector &params_vec) -> ShapeVector {
    ShapeVector out_vec;
    std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec));
    copy(ind_vec.begin(), ind_vec.end(), std::back_inserter(out_vec));
    copy(params_vec.begin() + axis_val + 1, params_vec.end(), std::back_inserter(out_vec));
    return out_vec;
  };
  ShapeVector out_shape = calc_shape(indices_shp, params_shp);
  // Dynamic case: derive min/max output shapes with the same splice rule.
  if (ind_dyn || param_dyn) {
    ShapeVector min_shape = calc_shape(indices_shp_min, param_shp_min);
    ShapeVector max_shape = calc_shape(indices_shp_max, param_shp_max);
    return std::make_shared<AbstractTensor>(params->element(),
                                            std::make_shared<Shape>(out_shape, min_shape, max_shape));
  }
  return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape));
}
637
InferImplDynamicAssign(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)638 AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
639 const AbstractBasePtrList &args_spec_list) {
640 // Inputs: a tensor
641 const size_t size_expected = 2;
642 CheckArgsSize(primitive->name(), args_spec_list, size_expected);
643
644 MS_LOG(INFO) << "InferImplDynamicAssign " << args_spec_list[0];
645 auto type = args_spec_list[0]->BuildType();
646 MS_EXCEPTION_IF_NULL(type);
647 if (type->type_id() == kObjectTypeRefKey) {
648 return args_spec_list[1]->Broaden();
649 } else {
650 auto x = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 0);
651 auto y = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 1);
652 MS_EXCEPTION_IF_NULL(x);
653 MS_EXCEPTION_IF_NULL(y);
654 auto y_shape = y->shape();
655 MS_EXCEPTION_IF_NULL(y_shape);
656 if (!y_shape->max_shape().empty()) {
657 x->set_shape(y->shape());
658 }
659 return args_spec_list[0];
660 }
661 }
662
InferImplEmbeddingLookup(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)663 AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
664 const AbstractBasePtrList &args_spec_list) {
665 const std::string op_name = primitive->name();
666 auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
667 auto params_shp = params->shape();
668 MS_EXCEPTION_IF_NULL(params_shp);
669 auto params_shape = params_shp->shape();
670 auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
671 auto indices_shp = indices->shape();
672 MS_EXCEPTION_IF_NULL(indices_shp);
673 auto indices_shape = indices_shp->shape();
674 auto indices_max_shape = indices_shp->max_shape();
675 auto indices_min_shape = indices_shp->min_shape();
676 ShapeVector shape;
677 ShapeVector max_shape;
678 ShapeVector min_shape;
679 shape.insert(shape.end(), indices_shape.begin(), indices_shape.end());
680 shape.insert(shape.end(), params_shape.begin() + 1, params_shape.end());
681 if (!indices_max_shape.empty()) {
682 max_shape.insert(max_shape.end(), indices_max_shape.begin(), indices_max_shape.end());
683 max_shape.insert(max_shape.end(), params_shape.begin() + 1, params_shape.end());
684 } else {
685 max_shape = shape;
686 }
687 if (!indices_min_shape.empty()) {
688 min_shape.insert(min_shape.end(), indices_min_shape.begin(), indices_min_shape.end());
689 min_shape.insert(min_shape.end(), params_shape.begin() + 1, params_shape.end());
690 } else {
691 min_shape = shape;
692 }
693
694 AbstractTensorPtr ret =
695 std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
696 return ret;
697 }
698
InferImplDynamicShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)699 AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
700 const AbstractBasePtrList &args_spec_list) {
701 const std::string &op_name = primitive->name();
702 CheckArgsSize(op_name, args_spec_list, 1);
703 AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
704 MS_EXCEPTION_IF_NULL(input->shape());
705 auto shape = input->shape()->shape();
706 bool has_dyn_shape = std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
707 ShapeVector tensor_shp({static_cast<int64_t>(shape.size())});
708 if (has_dyn_shape) {
709 auto elem = std::make_shared<AbstractScalar>(std::make_shared<AnyValue>(), std::make_shared<Int>(64));
710 auto min_value = MakeValue(input->shape()->min_shape());
711 auto max_value = MakeValue(input->shape()->max_shape());
712 auto tensor = std::make_shared<AbstractTensor>(elem, std::make_shared<Shape>(tensor_shp));
713 tensor->set_value_range(min_value, max_value);
714 return tensor;
715 }
716 auto shp_buf_size = sizeof(int64_t) * shape.size();
717 auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, tensor_shp, shape.data(), shp_buf_size);
718
719 return tensor->ToAbstract();
720 }
721
InferImplTranspose(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)722 AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
723 const AbstractBasePtrList &args_spec_list) {
724 const std::string &op_name = primitive->name();
725 AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
726 auto input_shp = input->shape()->shape();
727 ValuePtr perm = primitive->GetAttr("perm");
728 MS_EXCEPTION_IF_NULL(perm);
729 auto perm_val = perm->cast<ValueTuplePtr>();
730 MS_EXCEPTION_IF_NULL(perm_val);
731 auto perm_val_data = perm_val->value();
732 ShapeVector perm_vec;
733 (void)std::transform(std::begin(perm_val_data), std::end(perm_val_data), std::back_inserter(perm_vec),
734 [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
735 ShapeVector result_shp;
736 ShapeVector max_shp;
737 ShapeVector min_shp;
738 ShapeVector x_max_shp = input->shape()->max_shape();
739 ShapeVector x_min_shp = input->shape()->min_shape();
740 CheckMinMaxShape(input_shp, &x_min_shp, &x_max_shp);
741 for (size_t i = 0; i < perm_vec.size(); i++) {
742 auto idx = static_cast<size_t>(perm_vec[i]);
743 result_shp.push_back(input_shp[idx]);
744 max_shp.push_back(x_max_shp[idx]);
745 min_shp.push_back(x_min_shp[idx]);
746 }
747 return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp));
748 }
749
AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
  // Infers the output abstract of Reshape from the "shape" attribute.
  // At most one entry of "shape" may be -1; that entry is inferred by dividing
  // the input's total element count by the product of the known dimensions.
  const std::string op_name = primitive->name();
  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(x->shape());
  ShapeVector shape;
  ShapeVector x_shape = x->shape()->shape();
  ShapeVector x_max_shape = x->shape()->max_shape();
  ShapeVector x_min_shape = x->shape()->min_shape();
  // Static-shape inputs have empty min/max; fall back to the actual shape.
  if (x_max_shape.empty()) {
    x_max_shape = x_shape;
  }
  if (x_min_shape.empty()) {
    x_min_shape = x_shape;
  }
  // The target shape comes from the primitive attribute as a tuple of int64.
  ValuePtr sh = primitive->GetAttr("shape");
  MS_EXCEPTION_IF_NULL(sh);
  auto reshape_value_tuple = sh->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(reshape_value_tuple);
  auto reshape_tuple = reshape_value_tuple->value();

  (void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape),
                       [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });

  auto max_shape = shape;
  auto min_shape = shape;
  // Total element counts of the input (actual / min / max); needed to resolve
  // a -1 entry in the requested shape.
  int64_t x_num = 1;
  int64_t x_min_num = 1;
  int64_t x_max_num = 1;
  for (int64_t value : x_shape) {
    x_num = LongMulWithOverflowCheck(value, x_num);
  }
  for (int64_t value : x_min_shape) {
    x_min_num = LongMulWithOverflowCheck(value, x_min_num);
  }
  for (int64_t value : x_max_shape) {
    x_max_num = LongMulWithOverflowCheck(value, x_max_num);
  }

  auto it_first = find(shape.begin(), shape.end(), -1);
  if (it_first != shape.end()) {
    // Reject a second -1: the inferred dimension would be ambiguous.
    auto it_second = find(it_first + 1, shape.end(), -1);
    if (it_second != shape.end()) {
      MS_LOG(EXCEPTION) << "At most one component of input shape can be -1";
    }
    auto index = LongToSize(std::distance(shape.begin(), it_first));
    // Divide the total element count by every known dimension; the quotient is
    // the size of the -1 dimension. Entries equal to 0 are skipped to avoid
    // division by zero.
    int64_t infer_value = x_num;
    int64_t infer_min_value = x_min_num;
    int64_t infer_max_value = x_max_num;
    for (size_t i = 0; i < shape.size(); ++i) {
      int64_t value = shape[i];
      if (value != -1 && value != 0) {
        infer_value = infer_value / value;
        infer_min_value = infer_min_value / value;
        infer_max_value = infer_max_value / value;
      }
    }
    shape[index] = infer_value;
    min_shape[index] = infer_min_value;
    max_shape[index] = infer_max_value;
  }

  AbstractTensorPtr ret =
    std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
  return ret;
}
817
InferImplMapUniform(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)818 AbstractBasePtr InferImplMapUniform(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
819 const AbstractBasePtrList &args_spec_list) {
820 // Inputs: one tensor.
821 const std::string op_name = primitive->name();
822 const size_t size_expected = 3;
823 CheckArgsSize(op_name, args_spec_list, size_expected);
824 return args_spec_list[0]->Broaden();
825 }
826
InferImplSplit(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)827 AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
828 const AbstractBasePtrList &args_spec_list) {
829 const std::string op_name = primitive->name();
830 CheckArgsSize(op_name, args_spec_list, 1);
831 AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
832 ShapeVector x_shape = input_x->shape()->shape();
833 ShapeVector x_shape_min = input_x->shape()->min_shape();
834 if (x_shape_min.empty()) {
835 x_shape_min = x_shape;
836 }
837 ShapeVector x_shape_max = input_x->shape()->max_shape();
838 if (x_shape_max.empty()) {
839 x_shape_max = x_shape;
840 }
841 int64_t rank = SizeToLong(x_shape.size());
842
843 ValuePtr axis = primitive->GetAttr("axis");
844 int64_t axis_value = CheckAxis(op_name, axis, -(rank + 1), rank);
845 uint64_t axis_value_pos = LongToUlong(GetPositiveAxis(axis_value, LongToSize(rank)));
846 int64_t output_num_value = GetValue<int64_t>(primitive->GetAttr("output_num"));
847 if ((x_shape[axis_value_pos] != Shape::SHP_ANY) && (x_shape[axis_value_pos] % output_num_value != 0)) {
848 MS_LOG(EXCEPTION) << "x_shape[" << axis_value_pos << "] = " << x_shape[axis_value_pos]
849 << " must be divisible by output_num = " << output_num_value;
850 }
851
852 ShapeVector output_shape = x_shape;
853 if (output_shape[axis_value_pos] != Shape::SHP_ANY) {
854 output_shape[axis_value_pos] = static_cast<int>(x_shape[axis_value_pos] / output_num_value);
855 }
856 ShapeVector output_shape_min = x_shape_min;
857 output_shape_min[axis_value_pos] = static_cast<int>(x_shape_min[axis_value_pos] / output_num_value);
858 ShapeVector output_shape_max = x_shape_max;
859 output_shape_max[axis_value_pos] = static_cast<int>(x_shape_max[axis_value_pos] / output_num_value);
860
861 AbstractBasePtrList output_list;
862 for (int64_t i = 0; i < output_num_value; ++i) {
863 auto output = input_x->Broaden();
864 output->set_shape(std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max));
865 output_list.push_back(output);
866 }
867 return std::make_shared<AbstractTuple>(output_list);
868 }
869
InferImplSequenceMask(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)870 AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
871 const AbstractBasePtrList &args_spec_list) {
872 const std::string &op_name = primitive->name();
873 const size_t size_expected = 2;
874 CheckArgsSize(op_name, args_spec_list, size_expected);
875
876 AbstractTensorPtr lengths = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
877 (void)CheckTensorDType(lengths, {kInt32, kInt64}, "Input 1 (lengths) for SequenceMask should be one of: %s");
878
879 int64_t maxlen_value = 0;
880
881 if (args_spec_list[1]->isa<AbstractScalar>()) {
882 AbstractScalarPtr maxlen = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
883 (void)CheckScalarType(maxlen, {kInt32, kInt64}, "Input 0 (maxlen) for SequenceMask should be one of: %s");
884
885 TypePtr maxlen_type = nullptr;
886 maxlen_type = maxlen->GetTypeTrack();
887 MS_EXCEPTION_IF_NULL(maxlen_type);
888
889 if (maxlen_type->type_id() == TypeId::kNumberTypeInt32) {
890 maxlen_value = static_cast<int64_t>(GetValue<int32_t>(maxlen->BuildValue()));
891 } else if (maxlen_type->type_id() == TypeId::kNumberTypeInt64) {
892 maxlen_value = GetValue<int64_t>(maxlen->BuildValue());
893 }
894 } else if (args_spec_list[1]->isa<AbstractTensor>()) {
895 auto maxlen_tensor_ptr = args_spec_list[1]->cast<AbstractTensorPtr>();
896 MS_EXCEPTION_IF_NULL(maxlen_tensor_ptr);
897 auto maxlen_value_ptr = maxlen_tensor_ptr->BuildValue();
898 MS_EXCEPTION_IF_NULL(maxlen_value_ptr);
899 auto maxlen_tensor = maxlen_value_ptr->cast<tensor::TensorPtr>();
900 MS_EXCEPTION_IF_NULL(maxlen_tensor);
901 maxlen_value = *static_cast<int64_t *>(maxlen_tensor->data_c());
902 }
903
904 if (maxlen_value <= 0) {
905 MS_LOG(EXCEPTION) << "maxlen must be positive, but got: " << maxlen_value;
906 }
907
908 ShapeVector lengths_shape = lengths->shape()->shape();
909 ShapeVector lengths_shape_min = lengths->shape()->min_shape();
910 if (lengths_shape_min.empty()) {
911 lengths_shape_min = lengths_shape;
912 }
913 ShapeVector lengths_shape_max = lengths->shape()->max_shape();
914 if (lengths_shape_max.empty()) {
915 lengths_shape_max = lengths_shape;
916 }
917
918 lengths_shape.push_back(maxlen_value);
919 lengths_shape_min.push_back(maxlen_value);
920 lengths_shape_max.push_back(maxlen_value);
921
922 ShapePtr output_shape = std::make_shared<Shape>(lengths_shape, lengths_shape_min, lengths_shape_max);
923 return std::make_shared<AbstractTensor>(kBool, output_shape);
924 }
925
InferImplConcat(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)926 AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
927 const AbstractBasePtrList &args_spec_list) {
928 MS_EXCEPTION_IF_NULL(primitive);
929 const std::string op_name = primitive->name();
930 if (args_spec_list.empty()) {
931 MS_LOG(EXCEPTION) << "args_spec_list is empty.";
932 }
933
934 AbstractTuplePtr arg = nullptr;
935 AbstractTensorPtr tensor_base = nullptr;
936 size_t tuple_len = 0;
937 MS_EXCEPTION_IF_NULL(args_spec_list[0]);
938 if (args_spec_list[0]->isa<AbstractTuple>()) {
939 CheckArgsSize(op_name, args_spec_list, 1);
940 arg = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
941 tuple_len = arg->elements().size();
942 tensor_base = CheckArg<AbstractTensor>(op_name, arg->elements(), 0);
943 } else if (args_spec_list[0]->isa<AbstractTensor>()) {
944 tuple_len = args_spec_list.size();
945 tensor_base = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
946 }
947
948 MS_EXCEPTION_IF_NULL(tensor_base);
949 ShapeVector shape_base = tensor_base->shape()->shape();
950 int64_t rank_base = SizeToLong(shape_base.size());
951 ShapeVector min_shape_base = tensor_base->shape()->min_shape();
952 ShapeVector max_shape_base = tensor_base->shape()->max_shape();
953 CheckMinMaxShape(shape_base, &min_shape_base, &max_shape_base);
954
955 primitive->set_attr("T", tensor_base->element()->BuildType());
956 primitive->set_attr("inputNums", MakeValue(SizeToLong(tuple_len)));
957
958 ValuePtr axis = primitive->GetAttr("axis");
959 // Axis value should be in [-(rank_base + 1), rank_base).
960 int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base);
961 // If axis is negative, add offset(rank_base) to turn it to positive.
962 axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base));
963
964 int64_t all_shp = shape_base[axis_value];
965 int64_t min_all_shp = min_shape_base[axis_value];
966 int64_t max_all_shp = max_shape_base[axis_value];
967 for (size_t i = 1; i < tuple_len; ++i) {
968 AbstractTensorPtr tensor = nullptr;
969 if (args_spec_list[0]->isa<AbstractTuple>()) {
970 tensor = CheckArg<AbstractTensor>(op_name, arg->elements(), i);
971 } else if (args_spec_list[0]->isa<AbstractTensor>()) {
972 tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
973 }
974 ShapeVector shape_tensor = tensor->shape()->shape();
975 int64_t rank_tensor = SizeToLong(shape_tensor.size());
976 ShapeVector min_shape_tensor = tensor->shape()->min_shape();
977 ShapeVector max_shape_tensor = tensor->shape()->max_shape();
978 CheckMinMaxShape(shape_tensor, &min_shape_tensor, &max_shape_tensor);
979 (void)CheckDtypeSame(op_name, tensor_base, tensor);
980 if (rank_tensor != rank_base) {
981 MS_LOG(EXCEPTION) << op_name << " can not concat element " << i << " with the first element: Wrong Rank";
982 }
983 for (int j = 0; j < rank_base; ++j) {
984 if (j != axis_value && shape_tensor[j] != shape_base[j]) {
985 MS_LOG(EXCEPTION) << op_name << " can not concat element " << i << " with the first element: Wrong Size";
986 }
987 }
988 if (all_shp == -1 || shape_base[axis_value] == -1) {
989 all_shp = -1;
990 } else {
991 all_shp += shape_tensor[axis_value];
992 }
993 min_all_shp += min_shape_tensor[axis_value];
994 max_all_shp += max_shape_tensor[axis_value];
995 }
996
997 AbstractTensorPtr ret = dyn_cast<AbstractTensor>(tensor_base->Broaden());
998 MS_EXCEPTION_IF_NULL(ret);
999 auto shape = ret->shape()->shape();
1000 auto min_shape = ret->shape()->min_shape();
1001 auto max_shape = ret->shape()->max_shape();
1002 CheckMinMaxShape(shape, &min_shape, &max_shape);
1003 shape[axis_value] = all_shp;
1004 min_shape[axis_value] = min_all_shp;
1005 max_shape[axis_value] = max_all_shp;
1006 ret->set_shape(std::make_shared<Shape>(shape, min_shape, max_shape));
1007 return ret;
1008 }
1009
InferImplRange(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)1010 AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1011 const AbstractBasePtrList &args_spec_list) {
1012 const std::string &op_name = primitive->name();
1013 if (args_spec_list.size() == 1) {
1014 return args_spec_list[0]->Broaden();
1015 }
1016 constexpr size_t args_size = 3;
1017 constexpr size_t range_start_index = 0;
1018 constexpr size_t range_end_index = 1;
1019 constexpr size_t range_delta_index = 2;
1020 CheckArgsSize(op_name, args_spec_list, args_size);
1021 AbstractTensorPtr range_start = CheckArg<AbstractTensor>(op_name, args_spec_list, range_start_index);
1022 AbstractTensorPtr range_end = CheckArg<AbstractTensor>(op_name, args_spec_list, range_end_index);
1023 AbstractTensorPtr range_delta = CheckArg<AbstractTensor>(op_name, args_spec_list, range_delta_index);
1024
1025 TypePtrList supported_types = {kInt64, kInt32, kFloat32, kFloat64};
1026 TypePtr range_start_type = CheckTensorDType(range_start, supported_types, "range_start input of Range should be %s");
1027 TypePtr range_end_type = CheckTensorDType(range_end, supported_types, "range_start input of Range should be %s");
1028 TypePtr range_delta_type = CheckTensorDType(range_delta, supported_types, "range_start input of Range should be %s");
1029 // check all 3 inputs are same type
1030 if (!IsIdentidityOrSubclass(range_start_type, range_end_type) ||
1031 !IsIdentidityOrSubclass(range_end_type, range_delta_type)) {
1032 MS_LOG(EXCEPTION) << "All inputs must have same type, but got: " << args_spec_list[range_start_index]->type_name()
1033 << ", " << args_spec_list[range_end_index]->type_name() << ", and "
1034 << args_spec_list[range_delta_index]->type_name();
1035 }
1036
1037 int64_t max_output_length = -1;
1038 ValuePtr max_output_length_ptr = primitive->GetAttr("maxlen");
1039 max_output_length = GetValue<int64_t>(max_output_length_ptr);
1040 ShapeVector output_shape = {Shape::SHP_ANY};
1041 ShapeVector min_shape = {1};
1042 ShapeVector max_shape = {max_output_length};
1043 ShapePtr shape = std::make_shared<Shape>(output_shape, min_shape, max_shape);
1044
1045 return std::make_shared<AbstractTensor>(range_start_type, shape);
1046 }
1047
InferImplArgMaxWithValue(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)1048 AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1049 const AbstractBasePtrList &args_spec_list) {
1050 const std::string op_name = primitive->name();
1051 CheckArgsSize(op_name, args_spec_list, 1);
1052 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
1053 MS_EXCEPTION_IF_NULL(x);
1054 MS_EXCEPTION_IF_NULL(x->shape());
1055 // check keep_dims
1056 ValuePtr keep_dims = primitive->GetAttr("keep_dims");
1057 MS_EXCEPTION_IF_NULL(keep_dims);
1058 if (!keep_dims->isa<BoolImm>()) {
1059 MS_LOG(EXCEPTION) << "keep_dims should be Bool.";
1060 }
1061 bool keep_dims_value = GetValue<bool>(keep_dims);
1062 // check axis
1063 ValuePtr axis = primitive->GetAttr("axis");
1064 MS_EXCEPTION_IF_NULL(axis);
1065 if (!axis->isa<Int32Imm>() && !axis->isa<Int64Imm>()) {
1066 MS_LOG(EXCEPTION) << "axis should be Int.";
1067 }
1068 // check axis convert negative to positive value
1069 auto check_axis = [](int64_t &axis, const size_t dim) -> void {
1070 auto dim_ = static_cast<int64_t>(dim);
1071 if (axis < -dim_ || axis >= dim_) {
1072 MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis << ".";
1073 }
1074 if (axis >= -dim_ && axis < 0) {
1075 axis += dim_;
1076 }
1077 return;
1078 };
1079 // main calculate shape func
1080 auto cal_shape = [axis, keep_dims_value, check_axis](ShapeVector &shape, const ShapeVector &x_shape) -> void {
1081 (void)shape.insert(shape.end(), x_shape.begin(), x_shape.end());
1082 auto axis_value = GetValue<int64_t>(axis);
1083 check_axis(axis_value, x_shape.size());
1084 if (keep_dims_value) {
1085 shape[axis_value] = 1;
1086 } else {
1087 (void)shape.erase(std::begin(shape) + axis_value);
1088 }
1089 };
1090 ShapeVector shape = {};
1091 ShapeVector min_shape = {};
1092 ShapeVector max_shape = {};
1093 ShapeVector x_shape = x->shape()->shape();
1094 ShapeVector x_min_shape = x->shape()->min_shape();
1095 ShapeVector x_max_shape = x->shape()->max_shape();
1096 CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
1097 cal_shape(shape, x_shape);
1098 cal_shape(min_shape, x_min_shape);
1099 cal_shape(max_shape, x_max_shape);
1100 TypePtr idx_type = kInt32;
1101 auto index = std::make_shared<AbstractTensor>(idx_type, std::make_shared<Shape>(shape, min_shape, max_shape));
1102 auto value = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
1103 AbstractBasePtrList result = {index, value};
1104 return std::make_shared<AbstractTuple>(result);
1105 }
1106
InferImplSort(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)1107 AbstractBasePtr InferImplSort(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1108 const AbstractBasePtrList &args_spec_list) {
1109 const std::string &op_name = primitive->name();
1110 CheckArgsSize(op_name, args_spec_list, 1);
1111 AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
1112
1113 TypePtrList supported_types = {kFloat16, kFloat32};
1114 (void)CheckTensorDType(input, supported_types, "input for Sort should be %s");
1115
1116 ValuePtr axis_ptr = primitive->GetAttr("axis");
1117 int64_t axis = GetValue<int64_t>(axis_ptr);
1118 int64_t input_rank = input->shape()->shape().size();
1119 if (input_rank == 0) {
1120 MS_LOG(EXCEPTION) << "input must be a Tensor with dimension > 0.";
1121 }
1122
1123 if (!(axis >= -input_rank && axis < input_rank)) {
1124 MS_LOG(EXCEPTION) << "axis is not in the valid range [" << -input_rank << ", " << input_rank << ").";
1125 }
1126
1127 auto sorted_values = std::make_shared<AbstractTensor>(input->element(), input->shape());
1128 TypePtr idx_type = kInt32;
1129 auto indices = std::make_shared<AbstractTensor>(idx_type, input->shape());
1130 AbstractBasePtrList result = {sorted_values, indices};
1131 return std::make_shared<AbstractTuple>(result);
1132 }
1133
InferImplMaskedSelect(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)1134 AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1135 const AbstractBasePtrList &args_spec_list) {
1136 const std::string op_name = primitive->name();
1137 const size_t size_expected = 2;
1138 CheckArgsSize(op_name, args_spec_list, size_expected);
1139 AbstractTensorPtr x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
1140 AbstractTensorPtr mask = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
1141
1142 auto x_shape = x->shape();
1143 auto mask_shape = mask->shape();
1144 auto broadcast_shape = BroadcastShape(x_shape->shape(), mask_shape->shape());
1145 ShapeVector y_shape = {Shape::SHP_ANY};
1146 ShapeVector min_shape = {1};
1147 int64_t max_size = std::accumulate(broadcast_shape.begin(), broadcast_shape.end(), 1, std::multiplies<int64_t>());
1148 ShapeVector max_shape = {max_size};
1149 if (max_shape.empty()) {
1150 max_shape = x_shape->shape();
1151 }
1152 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(y_shape, min_shape, max_shape));
1153 }
1154
InferImplDynamicStitch(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)1155 AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1156 const AbstractBasePtrList &args_spec_list) {
1157 MS_EXCEPTION_IF_NULL(primitive);
1158 auto prim_name = primitive->name();
1159 constexpr int64_t args_size = 2;
1160 (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(args_spec_list.size()), kEqual, args_size,
1161 prim_name);
1162 for (const auto &item : args_spec_list) {
1163 MS_EXCEPTION_IF_NULL(item);
1164 }
1165
1166 // input0: indices
1167 auto input_tuple = args_spec_list[0]->cast<abstract::AbstractSequeuePtr>();
1168 MS_EXCEPTION_IF_NULL(input_tuple);
1169 auto indices = input_tuple->elements();
1170 auto indices0 = indices[0]->cast<abstract::AbstractTensorPtr>();
1171 MS_EXCEPTION_IF_NULL(indices0);
1172 auto indices0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices0->BuildShape())[kShape];
1173
1174 // input1: data
1175 auto input_tuple_1 = args_spec_list[1]->cast<abstract::AbstractSequeuePtr>();
1176 MS_EXCEPTION_IF_NULL(input_tuple_1);
1177 auto data = input_tuple_1->elements();
1178 auto data0 = data[0]->cast<abstract::AbstractTensorPtr>();
1179 MS_EXCEPTION_IF_NULL(data0);
1180 auto data0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(data0->BuildShape())[kShape];
1181 if (indices.size() != data.size()) {
1182 MS_LOG(EXCEPTION) << "The number of input[0] must be the same as input[0]!";
1183 }
1184
1185 int64_t indices_total_size = 0;
1186 std::map<std::string, TypePtr> types;
1187 (void)types.emplace("data0", data0->BuildType());
1188 for (size_t i = 1; i < data.size(); ++i) {
1189 auto indicesi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices[i]->BuildShape())[kShape];
1190 auto datai_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(data[i]->BuildShape())[kShape];
1191 if (indicesi_shape.size() > datai_shape.size()) {
1192 MS_LOG(EXCEPTION) << "The rank of indices[i] must be <= rank of data[i]!";
1193 }
1194 indices_total_size += SizeToLong(indicesi_shape.size());
1195 }
1196 std::set<TypePtr> valid_types = ops::common_valid_types;
1197 auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
1198
1199 ShapeVector out_shape = {abstract::Shape::SHP_ANY};
1200 for (size_t i = indices0_shape.size(); i < data0_shape.size(); ++i) {
1201 out_shape.push_back(data0_shape[i]);
1202 }
1203 const int64_t EXPAND_MAX = 10;
1204 ShapeVector min_shape = out_shape;
1205 ShapeVector max_shape = out_shape;
1206 min_shape[0] = 1;
1207 max_shape[0] = indices_total_size * EXPAND_MAX;
1208 return std::make_shared<AbstractTensor>(infer_type,
1209 std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape));
1210 }
1211
InferImplTensorCopySlices(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)1212 AbstractBasePtr InferImplTensorCopySlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1213 const AbstractBasePtrList &args_spec_list) {
1214 auto &op_name = primitive->name();
1215 constexpr auto kTensorCopySlicesInputNum = 5;
1216 CheckArgsSize(op_name, args_spec_list, kTensorCopySlicesInputNum);
1217 AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
1218 return std::make_shared<AbstractTensor>(input->element(), input->shape());
1219 }
1220 } // namespace abstract
1221 } // namespace mindspore
1222