• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <algorithm>
18 #include <functional>
19 #include <iterator>
20 #include <numeric>
21 #include "abstract/infer_functions.h"
22 #include "abstract/utils.h"
23 #include "abstract/param_validator.h"
24 #include "utils/shape_utils.h"
25 #include "ops/op_utils.h"
26 
27 namespace mindspore {
28 namespace abstract {
InferImplScalarToArray(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)29 AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
30                                        const AbstractBasePtrList &args_spec_list) {
31   // Inputs: a scalar.
32   const std::string op_name = primitive->name();
33   CheckArgsSize(op_name, args_spec_list, 1);
34   AbstractScalarPtr arg = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
35   return std::make_shared<AbstractTensor>(arg, std::make_shared<Shape>());
36 }
37 
InferImplArrayToScalar(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)38 AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
39                                        const AbstractBasePtrList &args_spec_list) {
40   // Inputs: a tensor with 0 shape.
41   const std::string op_name = primitive->name();
42   CheckArgsSize(op_name, args_spec_list, 1);
43   auto arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
44   auto a_shp = arg->shape();
45   MS_EXCEPTION_IF_NULL(a_shp);
46   if (!a_shp->shape().empty()) {
47     MS_LOG(EXCEPTION) << "array_to_scalar requires zero size shape.";
48   }
49   return arg->element();
50 }
51 
InferImplBroadCastShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)52 AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
53                                         const AbstractBasePtrList &args_spec_list) {
54   // Inputs: two tuples.
55   const std::string op_name = primitive->name();
56   constexpr size_t args_size = 2;
57   CheckArgsSize(op_name, args_spec_list, args_size);
58   auto xs = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
59   auto ys = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
60   auto x_value = xs->BuildValue();
61   MS_EXCEPTION_IF_NULL(x_value);
62   auto value_tuple_x = x_value->cast<ValueTuplePtr>();
63   MS_EXCEPTION_IF_NULL(value_tuple_x);
64   auto shp_tuple_x = value_tuple_x->value();
65   ShapeVector shp_x;
66   (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(shp_x),
67                        [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
68   auto tupe_value_y = ys->BuildValue();
69   MS_EXCEPTION_IF_NULL(tupe_value_y);
70   auto value_tuple_y = tupe_value_y->cast<ValueTuplePtr>();
71   MS_EXCEPTION_IF_NULL(value_tuple_y);
72   auto shp_tuple_y = value_tuple_y->value();
73   ShapeVector shp_y;
74   (void)std::transform(std::begin(shp_tuple_y), std::end(shp_tuple_y), std::back_inserter(shp_y),
75                        [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
76 
77   ShapeVector res = BroadcastShape(shp_x, shp_y);
78   MS_EXCEPTION_IF_NULL(args_spec_list[1]);
79   if (res.empty()) {
80     MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
81                       << args_spec_list[1]->ToString();
82   }
83 
84   AbstractBasePtrList elems;
85   (void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int64_t n) -> AbstractBasePtr {
86     return std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(n), kInt64);
87   });
88 
89   return std::make_shared<AbstractTuple>(elems);
90 }
91 
InferImplStack(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)92 AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
93                                const AbstractBasePtrList &args_spec_list) {
94   // Inputs: a tuple of tensor.
95   const std::string op_name = primitive->name();
96   CheckArgsSize(op_name, args_spec_list, 1);
97   auto arg = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
98   if (arg->elements().empty()) {
99     MS_LOG(EXCEPTION) << "Arg elements is empty.";
100   }
101 
102   size_t tuple_len = arg->elements().size();
103   AbstractTensorPtr tensor_base = CheckArg<AbstractTensor>(op_name, arg->elements(), 0);
104   auto shape = tensor_base->shape();
105   MS_EXCEPTION_IF_NULL(shape);
106   int64_t rank_base = SizeToLong(shape->shape().size());
107 
108   ValuePtr axis = primitive->GetAttr("axis");
109   // Axis value should be in [-(rank_base + 1), rank_base).
110   int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base);
111   // If axis is negative, add offset(rank_base + 1) to turn it to positive.
112   axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base + 1));
113 
114   for (size_t i = 1; i < tuple_len; ++i) {
115     AbstractTensorPtr tensor = CheckArg<AbstractTensor>(op_name, arg->elements(), i);
116     (void)CheckDtypeSame(op_name, tensor_base, tensor);
117     (void)CheckShapeSame(op_name, tensor_base, tensor);
118   }
119   auto element = tensor_base->element();
120   MS_EXCEPTION_IF_NULL(element);
121   primitive->set_attr("N", MakeValue(SizeToLong(tuple_len)));
122   primitive->set_attr("T", element->BuildType());
123 
124   AbstractTensorPtr ret = dyn_cast<AbstractTensor>(tensor_base->Broaden());
125   MS_EXCEPTION_IF_NULL(ret);
126   auto ret_shape_ptr = ret->shape();
127   MS_EXCEPTION_IF_NULL(ret_shape_ptr);
128   auto ret_shape = ret_shape_ptr->shape();
129   (void)ret_shape.insert(ret_shape.begin() + axis_value, SizeToLong(tuple_len));
130   ret->set_shape(std::make_shared<Shape>(ret_shape));
131   return ret;
132 }
133 
InferImplUnique(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)134 AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
135                                 const AbstractBasePtrList &args_spec_list) {
136   // inputs: a 1-d Tensor
137   const std::string op_name = primitive->name();
138   CheckArgsSize(op_name, args_spec_list, 1);
139   AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
140 
141   auto shape = input->shape();
142   MS_EXCEPTION_IF_NULL(shape);
143   if (shape->shape().size() != 1) {
144     MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1.";
145   }
146   ShapeVector ids_shape = {Shape::SHP_ANY};
147   ShapeVector min_shape = {1};
148   ShapeVector max_shape = shape->max_shape();
149   if (max_shape.empty()) {
150     max_shape = shape->shape();
151   }
152 
153   auto ids =
154     std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape));
155   // Currently we choose the same data type as input for the idx.
156   TypePtr ids_idx_type = kInt32;
157   MS_EXCEPTION_IF_NULL(input->element());
158   MS_EXCEPTION_IF_NULL(input->element()->GetTypeTrack());
159   if (input->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
160     ids_idx_type = kInt64;
161   }
162   ShapeVector idx_shape = shape->shape();
163   ShapeVector idx_min_shape = shape->min_shape();
164   if (idx_min_shape.empty()) {
165     idx_min_shape = shape->shape();
166   }
167   ShapeVector idx_max_shape = shape->max_shape();
168   if (idx_max_shape.empty()) {
169     idx_max_shape = shape->shape();
170   }
171 
172   auto ids_idx = std::make_shared<AbstractTensor>(ids_idx_type, idx_shape);
173   ids_idx->set_shape(std::make_shared<Shape>(idx_shape, idx_min_shape, idx_max_shape));
174   // outputs: ids, ids_idx
175   AbstractBasePtrList elements = {ids, ids_idx};
176   return std::make_shared<AbstractTuple>(elements);
177 }
178 
InferImplPadAndShift(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)179 AbstractBasePtr InferImplPadAndShift(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
180                                      const AbstractBasePtrList &args_spec_list) {
181   // inputs: a 1-d Tensor
182   const std::string op_name = primitive->name();
183   const size_t size_expected = 3;
184   CheckArgsSize(op_name, args_spec_list, size_expected);
185   AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
186   MS_EXCEPTION_IF_NULL(input);
187   auto shape = input->shape();
188   MS_EXCEPTION_IF_NULL(shape);
189   if (shape->shape().size() != 1) {
190     MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1.";
191   }
192   ShapeVector ids_shape = {Shape::SHP_ANY};
193   ShapeVector min_shape = {1};
194   ShapeVector max_shape = shape->max_shape();
195   if (max_shape.empty()) {
196     max_shape = shape->shape();
197   }
198   return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape));
199 }
200 
InferImplUniqueGrad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)201 AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
202                                     const AbstractBasePtrList &args_spec_list) {
203   // inputs: a 1-d Tensor
204   const std::string op_name = primitive->name();
205   const size_t size_expected = 2;
206   CheckArgsSize(op_name, args_spec_list, size_expected);
207   AbstractTuplePtr dout = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
208   CheckArgsSize(op_name + " dout", dout->elements(), size_expected);
209   auto ids = CheckArg<AbstractTensor>(op_name, dout->elements(), 0);
210   auto ids_idx = CheckArg<AbstractTensor>(op_name, dout->elements(), 1);
211   auto ids_shape = ids->shape();
212   auto ids_idx_shape = ids_idx->shape();
213   MS_EXCEPTION_IF_NULL(ids_shape);
214   MS_EXCEPTION_IF_NULL(ids_idx_shape);
215   if (ids->shape()->shape().size() != 1) {
216     MS_LOG(EXCEPTION) << "Dims of dout[0] of " << op_name << "' input must be 1.";
217   }
218   if (ids_idx->shape()->shape().size() != 1) {
219     MS_LOG(EXCEPTION) << "Dims of dout[1] of " << op_name << "' input must be 1.";
220   }
221 
222   // outputs: dx
223   return std::make_shared<AbstractTensor>(ids->element(), ids_idx->shape());
224 }
225 
InferImplUnsortedSegmentSum(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)226 AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
227                                             const AbstractBasePtrList &args_spec_list) {
228   const std::string op_name = primitive->name();
229   constexpr size_t args_size = 3;
230   CheckArgsSize(op_name, args_spec_list, args_size);
231   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
232   MS_EXCEPTION_IF_NULL(x);
233   MS_EXCEPTION_IF_NULL(x->shape());
234   auto segment_ids = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
235   MS_EXCEPTION_IF_NULL(segment_ids);
236   MS_EXCEPTION_IF_NULL(segment_ids->shape());
237   auto segment_ids_shape = segment_ids->shape()->shape();
238   (void)CheckTensorDType(x, {kFloat16, kFloat32, kFloat64, kInt32}, "Input 0 (x) for UnsortedSegmentSum should be %s");
239   (void)CheckTensorDType(segment_ids, {kInt32, kInt64}, "Input 1 (segment_ids) for UnsortedSegmentSum should be %s");
240   bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());  // check if dynamic shape
241   bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
242   bool op_is_dynamic = x_is_dyn || ids_is_dyn;
243   auto x_shape = x->shape()->shape();
244   ShapeVector shape;
245   int64_t num_segments_value = GetUnsortedSegmentOpScalarArg(args_spec_list, op_name);
246   if (num_segments_value <= 0) {
247     MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentSum";
248   }
249   shape.emplace_back(num_segments_value);
250   shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
251   if (!op_is_dynamic) {  // not dynamic
252     for (size_t i = 0; i < segment_ids_shape.size(); i++) {
253       if (x_shape[i] != segment_ids_shape[i]) {
254         MS_LOG(EXCEPTION) << "Shape values of segments_ids must match with corresponding x shape values";
255       }
256     }
257     return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
258   }
259   ShapeVector min_shape;
260   ShapeVector max_shape;
261   min_shape.emplace_back(num_segments_value);
262   max_shape.emplace_back(num_segments_value);
263   bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
264   bool ids_any_shape =
265     std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
266   if (!x_any_shape && !ids_any_shape) {  // only validate when shapes fully known
267     for (size_t i = 0; i < segment_ids_shape.size(); i++) {
268       if (x_shape[i] != segment_ids_shape[i]) {
269         MS_LOG(EXCEPTION) << "Shape values of segments_ids must match with corresponding x shape values";
270       }
271     }
272   }
273   ShapeVector x_shape_min;
274   ShapeVector x_shape_max;
275   x_shape_min = (x_is_dyn) ? x->shape()->min_shape() : x->shape()->shape();
276   x_shape_max = (x_is_dyn) ? x->shape()->max_shape() : x->shape()->shape();
277   min_shape.insert(min_shape.end(), x_shape_min.begin() + segment_ids_shape.size(), x_shape_min.end());
278   max_shape.insert(max_shape.end(), x_shape_max.begin() + segment_ids_shape.size(), x_shape_max.end());
279   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
280 }
281 
InferImplUnsortedSegmentMax(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)282 AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
283                                             const AbstractBasePtrList &args_spec_list) {
284   const std::string op_name = primitive->name();
285   const size_t size_expected = 3;
286   CheckArgsSize(op_name, args_spec_list, size_expected);
287   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
288   MS_EXCEPTION_IF_NULL(x->shape());
289   auto segment_ids = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
290   MS_EXCEPTION_IF_NULL(segment_ids);
291   MS_EXCEPTION_IF_NULL(segment_ids->shape());
292   auto segment_ids_shape = segment_ids->shape()->shape();
293   (void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMax should be %s");
294   (void)CheckTensorDType(segment_ids, {kInt32, kInt64}, "Input 1 (segment_ids) for UnsortedSegmentMax should be %s");
295   bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());  // check if dynamic
296   bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
297   bool op_is_dynamic = x_is_dyn || ids_is_dyn;
298   auto x_shape = x->shape()->shape();
299   ShapeVector shape;
300   int64_t num_segments_value = GetUnsortedSegmentOpScalarArg(args_spec_list, op_name);
301   if (num_segments_value <= 0) {
302     MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentMax";
303   }
304   shape.emplace_back(num_segments_value);
305   shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
306   if (!op_is_dynamic) {  // not dynamic
307     if (x_shape[0] != segment_ids_shape[0]) {
308       MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMax";
309     }
310     return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
311   }
312   ShapeVector min_shape;
313   ShapeVector max_shape;
314   min_shape.emplace_back(num_segments_value);
315   max_shape.emplace_back(num_segments_value);
316   bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
317   bool ids_any_shape =
318     std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
319   if (!x_any_shape && !ids_any_shape) {
320     if (x_shape[0] != segment_ids_shape[0]) {
321       MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMax";
322     }
323   }
324   ShapeVector x_shape_min;
325   ShapeVector x_shape_max;
326   x_shape_min = (x_is_dyn) ? x->shape()->min_shape() : x->shape()->shape();
327   x_shape_max = (x_is_dyn) ? x->shape()->max_shape() : x->shape()->shape();
328   min_shape.insert(min_shape.end(), x_shape_min.begin() + segment_ids_shape.size(), x_shape_min.end());
329   max_shape.insert(max_shape.end(), x_shape_max.begin() + segment_ids_shape.size(), x_shape_max.end());
330   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
331 }
332 
InferImplUnsortedSegmentMin(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)333 AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
334                                             const AbstractBasePtrList &args_spec_list) {
335   const std::string op_name = primitive->name();
336   const size_t size_expected = 3;
337   CheckArgsSize(op_name, args_spec_list, size_expected);
338   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
339   MS_EXCEPTION_IF_NULL(x);
340   MS_EXCEPTION_IF_NULL(x->shape());
341   auto segment_ids = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
342   MS_EXCEPTION_IF_NULL(segment_ids);
343   MS_EXCEPTION_IF_NULL(segment_ids->shape());
344   auto segment_ids_shape = segment_ids->shape()->shape();
345   (void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMin should be %s");
346   (void)CheckTensorDType(segment_ids, {kInt32}, "Input 1 (segment_ids) for UnsortedSegmentMin should be %s");
347   bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());  // check if dynamic shape
348   bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
349   bool op_is_dynamic = x_is_dyn || ids_is_dyn;
350   auto x_shape = x->shape()->shape();
351   ShapeVector shape;
352   int64_t num_segments_value = GetUnsortedSegmentOpScalarArg(args_spec_list, op_name);
353   if (num_segments_value <= 0) {
354     MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentMin";
355   }
356   shape.emplace_back(num_segments_value);
357   shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
358   if (!op_is_dynamic) {  // not dynamic
359     if (x_shape[0] != segment_ids_shape[0]) {
360       MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMin";
361     }
362     return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
363   }
364   ShapeVector min_shape;
365   ShapeVector max_shape;
366   min_shape.emplace_back(num_segments_value);
367   max_shape.emplace_back(num_segments_value);
368   bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
369   bool ids_any_shape =
370     std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
371   if (!x_any_shape && !ids_any_shape) {  // only validate when shapes fully known
372     if (x_shape[0] != segment_ids_shape[0]) {
373       MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMin";
374     }
375   }
376   ShapeVector x_shape_min;
377   ShapeVector x_shape_max;
378   x_shape_min = (x_is_dyn) ? x->shape()->min_shape() : x->shape()->shape();
379   x_shape_max = (x_is_dyn) ? x->shape()->max_shape() : x->shape()->shape();
380   min_shape.insert(min_shape.end(), x_shape_min.begin() + segment_ids_shape.size(), x_shape_min.end());
381   max_shape.insert(max_shape.end(), x_shape_max.begin() + segment_ids_shape.size(), x_shape_max.end());
382   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
383 }
384 
InferImplScatterAdd(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)385 AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
386                                     const AbstractBasePtrList &args_spec_list) {
387   constexpr auto kScatterAddInputNum = 3;
388   const std::string op_name = primitive->name();
389   CheckRequiredArgsSize(op_name, args_spec_list, kScatterAddInputNum);
390   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
391   MS_EXCEPTION_IF_NULL(x);
392   MS_EXCEPTION_IF_NULL(x->shape());
393   ShapeVector shape = x->shape()->shape();
394   ShapeVector min_shape = x->shape()->min_shape();
395   ShapeVector max_shape = x->shape()->max_shape();
396   CheckMinMaxShape(shape, &min_shape, &max_shape);
397   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
398 }
399 
InferImplScatterSub(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)400 AbstractBasePtr InferImplScatterSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
401                                     const AbstractBasePtrList &args_spec_list) {
402   constexpr auto kScatterSubInputNum = 3;
403   const std::string op_name = primitive->name();
404   CheckRequiredArgsSize(op_name, args_spec_list, kScatterSubInputNum);
405   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
406   MS_EXCEPTION_IF_NULL(x);
407   MS_EXCEPTION_IF_NULL(x->shape());
408   ShapeVector shape = x->shape()->shape();
409   ShapeVector min_shape = x->shape()->min_shape();
410   ShapeVector max_shape = x->shape()->max_shape();
411   CheckMinMaxShape(shape, &min_shape, &max_shape);
412   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
413 }
414 
InferImplScatterUpdate(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)415 AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
416                                        const AbstractBasePtrList &args_spec_list) {
417   const std::string op_name = primitive->name();
418   CheckRequiredArgsSize(op_name, args_spec_list, 3);
419   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
420   MS_EXCEPTION_IF_NULL(x);
421   MS_EXCEPTION_IF_NULL(x->shape());
422   ShapeVector shape = x->shape()->shape();
423   ShapeVector min_shape = x->shape()->min_shape();
424   ShapeVector max_shape = x->shape()->max_shape();
425   CheckMinMaxShape(shape, &min_shape, &max_shape);
426   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
427 }
428 
InferImplMapCacheIdx(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)429 AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
430                                      const AbstractBasePtrList &args_spec_list) {
431   const std::string op_name = primitive->name();
432   const size_t size_expected = 5;
433   CheckArgsSize(op_name, args_spec_list, size_expected);
434   auto hash_map = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
435   MS_EXCEPTION_IF_NULL(hash_map->shape());
436 
437   auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
438   auto indices_shp = indices->shape();
439   MS_EXCEPTION_IF_NULL(indices_shp);
440 
441   ShapeVector shape;
442   ShapeVector min_shape;
443   ShapeVector max_shape;
444   if (!indices_shp->max_shape().empty()) {
445     max_shape = indices_shp->max_shape();
446   } else {
447     max_shape = indices_shp->shape();
448   }
449   for (size_t i = 0; i < max_shape.size(); i++) {
450     shape.emplace_back(Shape::SHP_ANY);
451     min_shape.emplace_back(1);
452   }
453 
454   auto cache_idx = std::make_shared<AbstractTensor>(hash_map->element(), indices->shape());
455   auto old_emb_idx =
456     std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
457   auto miss_emb_idx =
458     std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
459   auto swap_emb_idx =
460     std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
461 
462   AbstractBasePtrList elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
463   return std::make_shared<AbstractTuple>(elements);
464 }
465 
InferImplCacheSwapTable(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)466 AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
467                                         const AbstractBasePtrList &args_spec_list) {
468   const std::string op_name = primitive->name();
469   const size_t size_expected = 3;
470   CheckArgsSize(op_name, args_spec_list, size_expected);
471   auto cache_table = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
472   auto cache_table_shp = cache_table->shape();
473   MS_EXCEPTION_IF_NULL(cache_table_shp);
474 
475   auto swap_cache_idx = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
476   auto swap_cache_idx_shp = swap_cache_idx->shape();
477   MS_EXCEPTION_IF_NULL(swap_cache_idx_shp);
478 
479   auto cache_table_shape = cache_table_shp->shape();
480   auto swap_cache_idx_shape = swap_cache_idx_shp->shape();
481   ShapeVector shape;
482   shape.emplace_back(swap_cache_idx_shape[0]);
483   shape.emplace_back(cache_table_shape[1]);
484   auto swap_cache_idx_max_shape = swap_cache_idx_shp->max_shape();
485   ShapeVector max_shape;
486   ShapeVector min_shape;
487   if (!swap_cache_idx_max_shape.empty()) {
488     max_shape.emplace_back(swap_cache_idx_max_shape[0]);
489     max_shape.emplace_back(cache_table_shape[1]);
490   } else {
491     max_shape = shape;
492   }
493   for (size_t i = 0; i < max_shape.size(); ++i) {
494     min_shape.emplace_back(1);
495   }
496 
497   AbstractTensorPtr ret =
498     std::make_shared<AbstractTensor>(cache_table->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
499   return ret;
500 }
501 
InferImplUpdateCache(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)502 AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
503                                      const AbstractBasePtrList &args_spec_list) {
504   const std::string op_name = primitive->name();
505   auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
506 
507   ShapeVector shape;
508   shape.emplace_back(1);
509 
510   AbstractTensorPtr ret = std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape));
511   return ret;
512 }
513 
InferImplSubAndFilter(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)514 AbstractBasePtr InferImplSubAndFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
515                                       const AbstractBasePtrList &args_spec_list) {
516   const std::string op_name = primitive->name();
517   auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
518   auto input_x_shp = input_x->shape();
519   MS_EXCEPTION_IF_NULL(input_x_shp);
520 
521   ShapeVector shape;
522   ShapeVector min_shape;
523   ShapeVector max_shape;
524   if (!input_x_shp->max_shape().empty()) {
525     max_shape = input_x_shp->max_shape();
526   } else {
527     max_shape = input_x_shp->shape();
528   }
529   for (size_t i = 0; i < max_shape.size(); i++) {
530     shape.emplace_back(Shape::SHP_ANY);
531     min_shape.emplace_back(1);
532   }
533   auto filter_res =
534     std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
535   auto filter_idx =
536     std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
537   AbstractBasePtrList elements = {filter_res, filter_idx};
538   return std::make_shared<AbstractTuple>(elements);
539 }
540 
InferImplDiv(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)541 AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
542                              const AbstractBasePtrList &args_spec_list) {
543   const std::string op_name = primitive->name();
544   const size_t size_expected = 2;
545   CheckArgsSize(op_name, args_spec_list, size_expected);
546   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
547   auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
548   MS_EXCEPTION_IF_NULL(x);
549   MS_EXCEPTION_IF_NULL(x->shape());
550   MS_EXCEPTION_IF_NULL(y);
551   MS_EXCEPTION_IF_NULL(y->shape());
552   ShapeVector x_shape = x->shape()->shape();
553   ShapeVector y_shape = y->shape()->shape();
554   ShapeVector out_shape = BroadcastShape(x_shape, y_shape);
555   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(out_shape));
556 }
557 
InferImplRealDiv(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)558 AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
559                                  const AbstractBasePtrList &args_spec_list) {
560   const std::string op_name = primitive->name();
561   const size_t size_expected = 2;
562   CheckArgsSize(op_name, args_spec_list, size_expected);
563   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
564   auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
565   MS_EXCEPTION_IF_NULL(x);
566   MS_EXCEPTION_IF_NULL(x->shape());
567   MS_EXCEPTION_IF_NULL(y);
568   MS_EXCEPTION_IF_NULL(y->shape());
569   ShapeVector x_shape = x->shape()->shape();
570   ShapeVector y_shape = y->shape()->shape();
571   ShapeVector out_shape = BroadcastShape(x_shape, y_shape);
572   if (out_shape.empty()) {
573     MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
574                       << args_spec_list[1]->ToString();
575   }
576   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(out_shape));
577 }
578 
InferImplGatherV2(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)579 AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
580                                   const AbstractBasePtrList &args_spec_list) {
581   const std::string &op_name = primitive->name();
582   constexpr size_t args_size = 3;
583   CheckArgsSize(op_name, args_spec_list, args_size);
584   AbstractTensorPtr params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
585   AbstractTensorPtr indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
586   bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty());
587   bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty());
588   int64_t axis_val = 0;
589   // 3rd input is a Tensor when GatherV2 is a dynamic shape operator
590   constexpr size_t aixs_index = 2;
591   if (args_spec_list[aixs_index]->isa<AbstractTensor>()) {
592     auto axis = args_spec_list[aixs_index]->cast<AbstractTensorPtr>();
593     MS_EXCEPTION_IF_NULL(axis);
594     auto axis_value_ptr = axis->BuildValue();
595     MS_EXCEPTION_IF_NULL(axis_value_ptr);
596     auto axis_tensor = axis_value_ptr->cast<tensor::TensorPtr>();
597     MS_EXCEPTION_IF_NULL(axis_tensor);
598     axis_val = *static_cast<int64_t *>(axis_tensor->data_c());
599   } else if (args_spec_list[aixs_index]->isa<AbstractScalar>()) {
600     auto axis = args_spec_list[aixs_index]->cast<AbstractScalarPtr>();
601     axis_val = GetValue<int64_t>(axis->BuildValue());
602   } else {
603     MS_LOG(EXCEPTION) << "Invalid abstract type:" << args_spec_list[2]->type_name();
604   }
605   auto params_shp = params->shape()->shape();
606   auto indices_shp = indices->shape()->shape();
607   auto params_rank = static_cast<int64_t>(params_shp.size());
608   // either inputs or both can be dynamic and computation requires min/max shapes for both
609   ShapeVector param_shp_min = (param_dyn) ? params->shape()->min_shape() : params->shape()->shape();
610   ShapeVector param_shp_max = (param_dyn) ? params->shape()->max_shape() : params->shape()->shape();
611   ShapeVector indices_shp_min = (ind_dyn) ? indices->shape()->min_shape() : indices->shape()->shape();
612   ShapeVector indices_shp_max = (ind_dyn) ? indices->shape()->max_shape() : indices->shape()->shape();
613   // check axis_val within interval: [-params_rank, params_rank)
614   if (-params_rank > axis_val || axis_val >= params_rank) {
615     MS_LOG(EXCEPTION) << "For Gather - Axis value must be within [ " << -params_rank << ", " << params_rank << " ) "
616                       << "Got " << axis_val << ".";
617   }
618   if (axis_val < 0) {
619     axis_val += params_rank;
620   }
621   auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector &params_vec) -> ShapeVector {
622     ShapeVector out_vec;
623     std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec));
624     copy(ind_vec.begin(), ind_vec.end(), std::back_inserter(out_vec));
625     copy(params_vec.begin() + axis_val + 1, params_vec.end(), std::back_inserter(out_vec));
626     return out_vec;
627   };
628   ShapeVector out_shape = calc_shape(indices_shp, params_shp);
629   if (ind_dyn || param_dyn) {
630     ShapeVector min_shape = calc_shape(indices_shp_min, param_shp_min);
631     ShapeVector max_shape = calc_shape(indices_shp_max, param_shp_max);
632     return std::make_shared<AbstractTensor>(params->element(),
633                                             std::make_shared<Shape>(out_shape, min_shape, max_shape));
634   }
635   return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape));
636 }
637 
InferImplDynamicAssign(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)638 AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
639                                        const AbstractBasePtrList &args_spec_list) {
640   // Inputs: a tensor
641   const size_t size_expected = 2;
642   CheckArgsSize(primitive->name(), args_spec_list, size_expected);
643 
644   MS_LOG(INFO) << "InferImplDynamicAssign " << args_spec_list[0];
645   auto type = args_spec_list[0]->BuildType();
646   MS_EXCEPTION_IF_NULL(type);
647   if (type->type_id() == kObjectTypeRefKey) {
648     return args_spec_list[1]->Broaden();
649   } else {
650     auto x = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 0);
651     auto y = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 1);
652     MS_EXCEPTION_IF_NULL(x);
653     MS_EXCEPTION_IF_NULL(y);
654     auto y_shape = y->shape();
655     MS_EXCEPTION_IF_NULL(y_shape);
656     if (!y_shape->max_shape().empty()) {
657       x->set_shape(y->shape());
658     }
659     return args_spec_list[0];
660   }
661 }
662 
InferImplEmbeddingLookup(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)663 AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
664                                          const AbstractBasePtrList &args_spec_list) {
665   const std::string op_name = primitive->name();
666   auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
667   auto params_shp = params->shape();
668   MS_EXCEPTION_IF_NULL(params_shp);
669   auto params_shape = params_shp->shape();
670   auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
671   auto indices_shp = indices->shape();
672   MS_EXCEPTION_IF_NULL(indices_shp);
673   auto indices_shape = indices_shp->shape();
674   auto indices_max_shape = indices_shp->max_shape();
675   auto indices_min_shape = indices_shp->min_shape();
676   ShapeVector shape;
677   ShapeVector max_shape;
678   ShapeVector min_shape;
679   shape.insert(shape.end(), indices_shape.begin(), indices_shape.end());
680   shape.insert(shape.end(), params_shape.begin() + 1, params_shape.end());
681   if (!indices_max_shape.empty()) {
682     max_shape.insert(max_shape.end(), indices_max_shape.begin(), indices_max_shape.end());
683     max_shape.insert(max_shape.end(), params_shape.begin() + 1, params_shape.end());
684   } else {
685     max_shape = shape;
686   }
687   if (!indices_min_shape.empty()) {
688     min_shape.insert(min_shape.end(), indices_min_shape.begin(), indices_min_shape.end());
689     min_shape.insert(min_shape.end(), params_shape.begin() + 1, params_shape.end());
690   } else {
691     min_shape = shape;
692   }
693 
694   AbstractTensorPtr ret =
695     std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
696   return ret;
697 }
698 
InferImplDynamicShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)699 AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
700                                       const AbstractBasePtrList &args_spec_list) {
701   const std::string &op_name = primitive->name();
702   CheckArgsSize(op_name, args_spec_list, 1);
703   AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
704   MS_EXCEPTION_IF_NULL(input->shape());
705   auto shape = input->shape()->shape();
706   bool has_dyn_shape = std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
707   ShapeVector tensor_shp({static_cast<int64_t>(shape.size())});
708   if (has_dyn_shape) {
709     auto elem = std::make_shared<AbstractScalar>(std::make_shared<AnyValue>(), std::make_shared<Int>(64));
710     auto min_value = MakeValue(input->shape()->min_shape());
711     auto max_value = MakeValue(input->shape()->max_shape());
712     auto tensor = std::make_shared<AbstractTensor>(elem, std::make_shared<Shape>(tensor_shp));
713     tensor->set_value_range(min_value, max_value);
714     return tensor;
715   }
716   auto shp_buf_size = sizeof(int64_t) * shape.size();
717   auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, tensor_shp, shape.data(), shp_buf_size);
718 
719   return tensor->ToAbstract();
720 }
721 
InferImplTranspose(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)722 AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
723                                    const AbstractBasePtrList &args_spec_list) {
724   const std::string &op_name = primitive->name();
725   AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
726   auto input_shp = input->shape()->shape();
727   ValuePtr perm = primitive->GetAttr("perm");
728   MS_EXCEPTION_IF_NULL(perm);
729   auto perm_val = perm->cast<ValueTuplePtr>();
730   MS_EXCEPTION_IF_NULL(perm_val);
731   auto perm_val_data = perm_val->value();
732   ShapeVector perm_vec;
733   (void)std::transform(std::begin(perm_val_data), std::end(perm_val_data), std::back_inserter(perm_vec),
734                        [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
735   ShapeVector result_shp;
736   ShapeVector max_shp;
737   ShapeVector min_shp;
738   ShapeVector x_max_shp = input->shape()->max_shape();
739   ShapeVector x_min_shp = input->shape()->min_shape();
740   CheckMinMaxShape(input_shp, &x_min_shp, &x_max_shp);
741   for (size_t i = 0; i < perm_vec.size(); i++) {
742     auto idx = static_cast<size_t>(perm_vec[i]);
743     result_shp.push_back(input_shp[idx]);
744     max_shp.push_back(x_max_shp[idx]);
745     min_shp.push_back(x_min_shp[idx]);
746   }
747   return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp));
748 }
749 
InferImplReshape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)750 AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
751                                  const AbstractBasePtrList &args_spec_list) {
752   const std::string op_name = primitive->name();
753   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
754   MS_EXCEPTION_IF_NULL(x);
755   MS_EXCEPTION_IF_NULL(x->shape());
756   ShapeVector shape;
757   ShapeVector x_shape = x->shape()->shape();
758   ShapeVector x_max_shape = x->shape()->max_shape();
759   ShapeVector x_min_shape = x->shape()->min_shape();
760   if (x_max_shape.empty()) {
761     x_max_shape = x_shape;
762   }
763   if (x_min_shape.empty()) {
764     x_min_shape = x_shape;
765   }
766   ValuePtr sh = primitive->GetAttr("shape");
767   MS_EXCEPTION_IF_NULL(sh);
768   auto reshape_value_tuple = sh->cast<ValueTuplePtr>();
769   MS_EXCEPTION_IF_NULL(reshape_value_tuple);
770   auto reshape_tuple = reshape_value_tuple->value();
771 
772   (void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape),
773                        [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
774 
775   auto max_shape = shape;
776   auto min_shape = shape;
777   int64_t x_num = 1;
778   int64_t x_min_num = 1;
779   int64_t x_max_num = 1;
780   for (int64_t value : x_shape) {
781     x_num = LongMulWithOverflowCheck(value, x_num);
782   }
783   for (int64_t value : x_min_shape) {
784     x_min_num = LongMulWithOverflowCheck(value, x_min_num);
785   }
786   for (int64_t value : x_max_shape) {
787     x_max_num = LongMulWithOverflowCheck(value, x_max_num);
788   }
789 
790   auto it_first = find(shape.begin(), shape.end(), -1);
791   if (it_first != shape.end()) {
792     auto it_second = find(it_first + 1, shape.end(), -1);
793     if (it_second != shape.end()) {
794       MS_LOG(EXCEPTION) << "At most one component of input shape can be -1";
795     }
796     auto index = LongToSize(std::distance(shape.begin(), it_first));
797     int64_t infer_value = x_num;
798     int64_t infer_min_value = x_min_num;
799     int64_t infer_max_value = x_max_num;
800     for (size_t i = 0; i < shape.size(); ++i) {
801       int64_t value = shape[i];
802       if (value != -1 && value != 0) {
803         infer_value = infer_value / value;
804         infer_min_value = infer_min_value / value;
805         infer_max_value = infer_max_value / value;
806       }
807     }
808     shape[index] = infer_value;
809     min_shape[index] = infer_min_value;
810     max_shape[index] = infer_max_value;
811   }
812 
813   AbstractTensorPtr ret =
814     std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
815   return ret;
816 }
817 
InferImplMapUniform(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)818 AbstractBasePtr InferImplMapUniform(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
819                                     const AbstractBasePtrList &args_spec_list) {
820   // Inputs: one tensor.
821   const std::string op_name = primitive->name();
822   const size_t size_expected = 3;
823   CheckArgsSize(op_name, args_spec_list, size_expected);
824   return args_spec_list[0]->Broaden();
825 }
826 
InferImplSplit(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)827 AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
828                                const AbstractBasePtrList &args_spec_list) {
829   const std::string op_name = primitive->name();
830   CheckArgsSize(op_name, args_spec_list, 1);
831   AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
832   ShapeVector x_shape = input_x->shape()->shape();
833   ShapeVector x_shape_min = input_x->shape()->min_shape();
834   if (x_shape_min.empty()) {
835     x_shape_min = x_shape;
836   }
837   ShapeVector x_shape_max = input_x->shape()->max_shape();
838   if (x_shape_max.empty()) {
839     x_shape_max = x_shape;
840   }
841   int64_t rank = SizeToLong(x_shape.size());
842 
843   ValuePtr axis = primitive->GetAttr("axis");
844   int64_t axis_value = CheckAxis(op_name, axis, -(rank + 1), rank);
845   uint64_t axis_value_pos = LongToUlong(GetPositiveAxis(axis_value, LongToSize(rank)));
846   int64_t output_num_value = GetValue<int64_t>(primitive->GetAttr("output_num"));
847   if ((x_shape[axis_value_pos] != Shape::SHP_ANY) && (x_shape[axis_value_pos] % output_num_value != 0)) {
848     MS_LOG(EXCEPTION) << "x_shape[" << axis_value_pos << "] = " << x_shape[axis_value_pos]
849                       << " must be divisible by output_num = " << output_num_value;
850   }
851 
852   ShapeVector output_shape = x_shape;
853   if (output_shape[axis_value_pos] != Shape::SHP_ANY) {
854     output_shape[axis_value_pos] = static_cast<int>(x_shape[axis_value_pos] / output_num_value);
855   }
856   ShapeVector output_shape_min = x_shape_min;
857   output_shape_min[axis_value_pos] = static_cast<int>(x_shape_min[axis_value_pos] / output_num_value);
858   ShapeVector output_shape_max = x_shape_max;
859   output_shape_max[axis_value_pos] = static_cast<int>(x_shape_max[axis_value_pos] / output_num_value);
860 
861   AbstractBasePtrList output_list;
862   for (int64_t i = 0; i < output_num_value; ++i) {
863     auto output = input_x->Broaden();
864     output->set_shape(std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max));
865     output_list.push_back(output);
866   }
867   return std::make_shared<AbstractTuple>(output_list);
868 }
869 
InferImplSequenceMask(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)870 AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
871                                       const AbstractBasePtrList &args_spec_list) {
872   const std::string &op_name = primitive->name();
873   const size_t size_expected = 2;
874   CheckArgsSize(op_name, args_spec_list, size_expected);
875 
876   AbstractTensorPtr lengths = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
877   (void)CheckTensorDType(lengths, {kInt32, kInt64}, "Input 1 (lengths) for SequenceMask should be one of: %s");
878 
879   int64_t maxlen_value = 0;
880 
881   if (args_spec_list[1]->isa<AbstractScalar>()) {
882     AbstractScalarPtr maxlen = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
883     (void)CheckScalarType(maxlen, {kInt32, kInt64}, "Input 0 (maxlen) for SequenceMask should be one of: %s");
884 
885     TypePtr maxlen_type = nullptr;
886     maxlen_type = maxlen->GetTypeTrack();
887     MS_EXCEPTION_IF_NULL(maxlen_type);
888 
889     if (maxlen_type->type_id() == TypeId::kNumberTypeInt32) {
890       maxlen_value = static_cast<int64_t>(GetValue<int32_t>(maxlen->BuildValue()));
891     } else if (maxlen_type->type_id() == TypeId::kNumberTypeInt64) {
892       maxlen_value = GetValue<int64_t>(maxlen->BuildValue());
893     }
894   } else if (args_spec_list[1]->isa<AbstractTensor>()) {
895     auto maxlen_tensor_ptr = args_spec_list[1]->cast<AbstractTensorPtr>();
896     MS_EXCEPTION_IF_NULL(maxlen_tensor_ptr);
897     auto maxlen_value_ptr = maxlen_tensor_ptr->BuildValue();
898     MS_EXCEPTION_IF_NULL(maxlen_value_ptr);
899     auto maxlen_tensor = maxlen_value_ptr->cast<tensor::TensorPtr>();
900     MS_EXCEPTION_IF_NULL(maxlen_tensor);
901     maxlen_value = *static_cast<int64_t *>(maxlen_tensor->data_c());
902   }
903 
904   if (maxlen_value <= 0) {
905     MS_LOG(EXCEPTION) << "maxlen must be positive, but got: " << maxlen_value;
906   }
907 
908   ShapeVector lengths_shape = lengths->shape()->shape();
909   ShapeVector lengths_shape_min = lengths->shape()->min_shape();
910   if (lengths_shape_min.empty()) {
911     lengths_shape_min = lengths_shape;
912   }
913   ShapeVector lengths_shape_max = lengths->shape()->max_shape();
914   if (lengths_shape_max.empty()) {
915     lengths_shape_max = lengths_shape;
916   }
917 
918   lengths_shape.push_back(maxlen_value);
919   lengths_shape_min.push_back(maxlen_value);
920   lengths_shape_max.push_back(maxlen_value);
921 
922   ShapePtr output_shape = std::make_shared<Shape>(lengths_shape, lengths_shape_min, lengths_shape_max);
923   return std::make_shared<AbstractTensor>(kBool, output_shape);
924 }
925 
InferImplConcat(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)926 AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
927                                 const AbstractBasePtrList &args_spec_list) {
928   MS_EXCEPTION_IF_NULL(primitive);
929   const std::string op_name = primitive->name();
930   if (args_spec_list.empty()) {
931     MS_LOG(EXCEPTION) << "args_spec_list is empty.";
932   }
933 
934   AbstractTuplePtr arg = nullptr;
935   AbstractTensorPtr tensor_base = nullptr;
936   size_t tuple_len = 0;
937   MS_EXCEPTION_IF_NULL(args_spec_list[0]);
938   if (args_spec_list[0]->isa<AbstractTuple>()) {
939     CheckArgsSize(op_name, args_spec_list, 1);
940     arg = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
941     tuple_len = arg->elements().size();
942     tensor_base = CheckArg<AbstractTensor>(op_name, arg->elements(), 0);
943   } else if (args_spec_list[0]->isa<AbstractTensor>()) {
944     tuple_len = args_spec_list.size();
945     tensor_base = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
946   }
947 
948   MS_EXCEPTION_IF_NULL(tensor_base);
949   ShapeVector shape_base = tensor_base->shape()->shape();
950   int64_t rank_base = SizeToLong(shape_base.size());
951   ShapeVector min_shape_base = tensor_base->shape()->min_shape();
952   ShapeVector max_shape_base = tensor_base->shape()->max_shape();
953   CheckMinMaxShape(shape_base, &min_shape_base, &max_shape_base);
954 
955   primitive->set_attr("T", tensor_base->element()->BuildType());
956   primitive->set_attr("inputNums", MakeValue(SizeToLong(tuple_len)));
957 
958   ValuePtr axis = primitive->GetAttr("axis");
959   // Axis value should be in [-(rank_base + 1), rank_base).
960   int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base);
961   // If axis is negative, add offset(rank_base) to turn it to positive.
962   axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base));
963 
964   int64_t all_shp = shape_base[axis_value];
965   int64_t min_all_shp = min_shape_base[axis_value];
966   int64_t max_all_shp = max_shape_base[axis_value];
967   for (size_t i = 1; i < tuple_len; ++i) {
968     AbstractTensorPtr tensor = nullptr;
969     if (args_spec_list[0]->isa<AbstractTuple>()) {
970       tensor = CheckArg<AbstractTensor>(op_name, arg->elements(), i);
971     } else if (args_spec_list[0]->isa<AbstractTensor>()) {
972       tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
973     }
974     ShapeVector shape_tensor = tensor->shape()->shape();
975     int64_t rank_tensor = SizeToLong(shape_tensor.size());
976     ShapeVector min_shape_tensor = tensor->shape()->min_shape();
977     ShapeVector max_shape_tensor = tensor->shape()->max_shape();
978     CheckMinMaxShape(shape_tensor, &min_shape_tensor, &max_shape_tensor);
979     (void)CheckDtypeSame(op_name, tensor_base, tensor);
980     if (rank_tensor != rank_base) {
981       MS_LOG(EXCEPTION) << op_name << " can not concat element " << i << " with the first element: Wrong Rank";
982     }
983     for (int j = 0; j < rank_base; ++j) {
984       if (j != axis_value && shape_tensor[j] != shape_base[j]) {
985         MS_LOG(EXCEPTION) << op_name << " can not concat element " << i << " with the first element: Wrong Size";
986       }
987     }
988     if (all_shp == -1 || shape_base[axis_value] == -1) {
989       all_shp = -1;
990     } else {
991       all_shp += shape_tensor[axis_value];
992     }
993     min_all_shp += min_shape_tensor[axis_value];
994     max_all_shp += max_shape_tensor[axis_value];
995   }
996 
997   AbstractTensorPtr ret = dyn_cast<AbstractTensor>(tensor_base->Broaden());
998   MS_EXCEPTION_IF_NULL(ret);
999   auto shape = ret->shape()->shape();
1000   auto min_shape = ret->shape()->min_shape();
1001   auto max_shape = ret->shape()->max_shape();
1002   CheckMinMaxShape(shape, &min_shape, &max_shape);
1003   shape[axis_value] = all_shp;
1004   min_shape[axis_value] = min_all_shp;
1005   max_shape[axis_value] = max_all_shp;
1006   ret->set_shape(std::make_shared<Shape>(shape, min_shape, max_shape));
1007   return ret;
1008 }
1009 
InferImplRange(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)1010 AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1011                                const AbstractBasePtrList &args_spec_list) {
1012   const std::string &op_name = primitive->name();
1013   if (args_spec_list.size() == 1) {
1014     return args_spec_list[0]->Broaden();
1015   }
1016   constexpr size_t args_size = 3;
1017   constexpr size_t range_start_index = 0;
1018   constexpr size_t range_end_index = 1;
1019   constexpr size_t range_delta_index = 2;
1020   CheckArgsSize(op_name, args_spec_list, args_size);
1021   AbstractTensorPtr range_start = CheckArg<AbstractTensor>(op_name, args_spec_list, range_start_index);
1022   AbstractTensorPtr range_end = CheckArg<AbstractTensor>(op_name, args_spec_list, range_end_index);
1023   AbstractTensorPtr range_delta = CheckArg<AbstractTensor>(op_name, args_spec_list, range_delta_index);
1024 
1025   TypePtrList supported_types = {kInt64, kInt32, kFloat32, kFloat64};
1026   TypePtr range_start_type = CheckTensorDType(range_start, supported_types, "range_start input of Range should be %s");
1027   TypePtr range_end_type = CheckTensorDType(range_end, supported_types, "range_start input of Range should be %s");
1028   TypePtr range_delta_type = CheckTensorDType(range_delta, supported_types, "range_start input of Range should be %s");
1029   // check all 3 inputs are same type
1030   if (!IsIdentidityOrSubclass(range_start_type, range_end_type) ||
1031       !IsIdentidityOrSubclass(range_end_type, range_delta_type)) {
1032     MS_LOG(EXCEPTION) << "All inputs must have same type, but got: " << args_spec_list[range_start_index]->type_name()
1033                       << ", " << args_spec_list[range_end_index]->type_name() << ", and "
1034                       << args_spec_list[range_delta_index]->type_name();
1035   }
1036 
1037   int64_t max_output_length = -1;
1038   ValuePtr max_output_length_ptr = primitive->GetAttr("maxlen");
1039   max_output_length = GetValue<int64_t>(max_output_length_ptr);
1040   ShapeVector output_shape = {Shape::SHP_ANY};
1041   ShapeVector min_shape = {1};
1042   ShapeVector max_shape = {max_output_length};
1043   ShapePtr shape = std::make_shared<Shape>(output_shape, min_shape, max_shape);
1044 
1045   return std::make_shared<AbstractTensor>(range_start_type, shape);
1046 }
1047 
InferImplArgMaxWithValue(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)1048 AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1049                                          const AbstractBasePtrList &args_spec_list) {
1050   const std::string op_name = primitive->name();
1051   CheckArgsSize(op_name, args_spec_list, 1);
1052   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
1053   MS_EXCEPTION_IF_NULL(x);
1054   MS_EXCEPTION_IF_NULL(x->shape());
1055   // check keep_dims
1056   ValuePtr keep_dims = primitive->GetAttr("keep_dims");
1057   MS_EXCEPTION_IF_NULL(keep_dims);
1058   if (!keep_dims->isa<BoolImm>()) {
1059     MS_LOG(EXCEPTION) << "keep_dims should be Bool.";
1060   }
1061   bool keep_dims_value = GetValue<bool>(keep_dims);
1062   // check axis
1063   ValuePtr axis = primitive->GetAttr("axis");
1064   MS_EXCEPTION_IF_NULL(axis);
1065   if (!axis->isa<Int32Imm>() && !axis->isa<Int64Imm>()) {
1066     MS_LOG(EXCEPTION) << "axis should be Int.";
1067   }
1068   // check axis convert negative to positive value
1069   auto check_axis = [](int64_t &axis, const size_t dim) -> void {
1070     auto dim_ = static_cast<int64_t>(dim);
1071     if (axis < -dim_ || axis >= dim_) {
1072       MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis << ".";
1073     }
1074     if (axis >= -dim_ && axis < 0) {
1075       axis += dim_;
1076     }
1077     return;
1078   };
1079   // main calculate shape func
1080   auto cal_shape = [axis, keep_dims_value, check_axis](ShapeVector &shape, const ShapeVector &x_shape) -> void {
1081     (void)shape.insert(shape.end(), x_shape.begin(), x_shape.end());
1082     auto axis_value = GetValue<int64_t>(axis);
1083     check_axis(axis_value, x_shape.size());
1084     if (keep_dims_value) {
1085       shape[axis_value] = 1;
1086     } else {
1087       (void)shape.erase(std::begin(shape) + axis_value);
1088     }
1089   };
1090   ShapeVector shape = {};
1091   ShapeVector min_shape = {};
1092   ShapeVector max_shape = {};
1093   ShapeVector x_shape = x->shape()->shape();
1094   ShapeVector x_min_shape = x->shape()->min_shape();
1095   ShapeVector x_max_shape = x->shape()->max_shape();
1096   CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
1097   cal_shape(shape, x_shape);
1098   cal_shape(min_shape, x_min_shape);
1099   cal_shape(max_shape, x_max_shape);
1100   TypePtr idx_type = kInt32;
1101   auto index = std::make_shared<AbstractTensor>(idx_type, std::make_shared<Shape>(shape, min_shape, max_shape));
1102   auto value = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
1103   AbstractBasePtrList result = {index, value};
1104   return std::make_shared<AbstractTuple>(result);
1105 }
1106 
InferImplSort(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)1107 AbstractBasePtr InferImplSort(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1108                               const AbstractBasePtrList &args_spec_list) {
1109   const std::string &op_name = primitive->name();
1110   CheckArgsSize(op_name, args_spec_list, 1);
1111   AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
1112 
1113   TypePtrList supported_types = {kFloat16, kFloat32};
1114   (void)CheckTensorDType(input, supported_types, "input for Sort should be %s");
1115 
1116   ValuePtr axis_ptr = primitive->GetAttr("axis");
1117   int64_t axis = GetValue<int64_t>(axis_ptr);
1118   int64_t input_rank = input->shape()->shape().size();
1119   if (input_rank == 0) {
1120     MS_LOG(EXCEPTION) << "input must be a Tensor with dimension > 0.";
1121   }
1122 
1123   if (!(axis >= -input_rank && axis < input_rank)) {
1124     MS_LOG(EXCEPTION) << "axis is not in the valid range [" << -input_rank << ", " << input_rank << ").";
1125   }
1126 
1127   auto sorted_values = std::make_shared<AbstractTensor>(input->element(), input->shape());
1128   TypePtr idx_type = kInt32;
1129   auto indices = std::make_shared<AbstractTensor>(idx_type, input->shape());
1130   AbstractBasePtrList result = {sorted_values, indices};
1131   return std::make_shared<AbstractTuple>(result);
1132 }
1133 
InferImplMaskedSelect(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)1134 AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1135                                       const AbstractBasePtrList &args_spec_list) {
1136   const std::string op_name = primitive->name();
1137   const size_t size_expected = 2;
1138   CheckArgsSize(op_name, args_spec_list, size_expected);
1139   AbstractTensorPtr x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
1140   AbstractTensorPtr mask = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
1141 
1142   auto x_shape = x->shape();
1143   auto mask_shape = mask->shape();
1144   auto broadcast_shape = BroadcastShape(x_shape->shape(), mask_shape->shape());
1145   ShapeVector y_shape = {Shape::SHP_ANY};
1146   ShapeVector min_shape = {1};
1147   int64_t max_size = std::accumulate(broadcast_shape.begin(), broadcast_shape.end(), 1, std::multiplies<int64_t>());
1148   ShapeVector max_shape = {max_size};
1149   if (max_shape.empty()) {
1150     max_shape = x_shape->shape();
1151   }
1152   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(y_shape, min_shape, max_shape));
1153 }
1154 
InferImplDynamicStitch(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)1155 AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1156                                        const AbstractBasePtrList &args_spec_list) {
1157   MS_EXCEPTION_IF_NULL(primitive);
1158   auto prim_name = primitive->name();
1159   constexpr int64_t args_size = 2;
1160   (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(args_spec_list.size()), kEqual, args_size,
1161                                            prim_name);
1162   for (const auto &item : args_spec_list) {
1163     MS_EXCEPTION_IF_NULL(item);
1164   }
1165 
1166   // input0: indices
1167   auto input_tuple = args_spec_list[0]->cast<abstract::AbstractSequeuePtr>();
1168   MS_EXCEPTION_IF_NULL(input_tuple);
1169   auto indices = input_tuple->elements();
1170   auto indices0 = indices[0]->cast<abstract::AbstractTensorPtr>();
1171   MS_EXCEPTION_IF_NULL(indices0);
1172   auto indices0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices0->BuildShape())[kShape];
1173 
1174   // input1: data
1175   auto input_tuple_1 = args_spec_list[1]->cast<abstract::AbstractSequeuePtr>();
1176   MS_EXCEPTION_IF_NULL(input_tuple_1);
1177   auto data = input_tuple_1->elements();
1178   auto data0 = data[0]->cast<abstract::AbstractTensorPtr>();
1179   MS_EXCEPTION_IF_NULL(data0);
1180   auto data0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(data0->BuildShape())[kShape];
1181   if (indices.size() != data.size()) {
1182     MS_LOG(EXCEPTION) << "The number of input[0] must be the same as input[0]!";
1183   }
1184 
1185   int64_t indices_total_size = 0;
1186   std::map<std::string, TypePtr> types;
1187   (void)types.emplace("data0", data0->BuildType());
1188   for (size_t i = 1; i < data.size(); ++i) {
1189     auto indicesi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices[i]->BuildShape())[kShape];
1190     auto datai_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(data[i]->BuildShape())[kShape];
1191     if (indicesi_shape.size() > datai_shape.size()) {
1192       MS_LOG(EXCEPTION) << "The rank of indices[i] must be <= rank of data[i]!";
1193     }
1194     indices_total_size += SizeToLong(indicesi_shape.size());
1195   }
1196   std::set<TypePtr> valid_types = ops::common_valid_types;
1197   auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
1198 
1199   ShapeVector out_shape = {abstract::Shape::SHP_ANY};
1200   for (size_t i = indices0_shape.size(); i < data0_shape.size(); ++i) {
1201     out_shape.push_back(data0_shape[i]);
1202   }
1203   const int64_t EXPAND_MAX = 10;
1204   ShapeVector min_shape = out_shape;
1205   ShapeVector max_shape = out_shape;
1206   min_shape[0] = 1;
1207   max_shape[0] = indices_total_size * EXPAND_MAX;
1208   return std::make_shared<AbstractTensor>(infer_type,
1209                                           std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape));
1210 }
1211 
InferImplTensorCopySlices(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)1212 AbstractBasePtr InferImplTensorCopySlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1213                                           const AbstractBasePtrList &args_spec_list) {
1214   auto &op_name = primitive->name();
1215   constexpr auto kTensorCopySlicesInputNum = 5;
1216   CheckArgsSize(op_name, args_spec_list, kTensorCopySlicesInputNum);
1217   AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
1218   return std::make_shared<AbstractTensor>(input->element(), input->shape());
1219 }
1220 }  // namespace abstract
1221 }  // namespace mindspore
1222