• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <string>
18 #include <sstream>
19 
20 #include "ir/dtype.h"
21 #include "utils/ms_utils.h"
22 #include "base/core_ops.h"
23 #include "abstract/param_validator.h"
24 #include "abstract/infer_functions.h"
25 #include "abstract/utils.h"
26 #include "utils/ms_context.h"
27 #include "utils/symbolic.h"
28 #include "utils/shape_utils.h"
29 
namespace {
// Primitive attribute key read by InferImplAllGather and InferImplReduceScatter.
constexpr const char *kRankSize = "rank_size";
}  // namespace
33 
34 namespace mindspore {
35 namespace abstract {
InferImplIdentity(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)36 AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
37                                   const AbstractBasePtrList &args_spec_list) {
38   // An object of a subclass of AbstractBase
39   CheckArgsSize(primitive->name(), args_spec_list, 1);
40   return args_spec_list[0];
41 }
42 
InferImplEnvGetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)43 AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
44                                     const AbstractBasePtrList &args_spec_list) {
45   MS_EXCEPTION_IF_NULL(primitive);
46   // args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
47   CheckArgsSize(primitive->name(), args_spec_list, 3);
48   auto key = args_spec_list[1];
49   auto dflt = args_spec_list[2];
50   TypePtr type = key->GetTypeTrack();
51   MS_EXCEPTION_IF_NULL(type);
52   if (type->type_id() != kObjectTypeSymbolicKeyType) {
53     MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
54   }
55 
56   auto context = MsContext::GetInstance();
57   MS_EXCEPTION_IF_NULL(context);
58   bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
59   if (enable_sparse && dflt->isa<AbstractTensor>()) {
60     auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
61     return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());
62   }
63 
64   if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
65     return dflt;
66   }
67   ValuePtr key_value_ptr = key->GetValueTrack();
68   MS_EXCEPTION_IF_NULL(key_value_ptr);
69   auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();
70   auto expected = key_value_track->abstract();
71   MS_EXCEPTION_IF_NULL(expected);
72   (void)expected->Join(dflt);
73   return expected;
74 }
75 
InferImplEnvSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)76 AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
77                                     const AbstractBasePtrList &args_spec_list) {
78   // args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
79   CheckArgsSize(primitive->name(), args_spec_list, 3);
80 
81   auto key = args_spec_list[1];
82   ValuePtr key_value_ptr = key->GetValueTrack();
83   MS_EXCEPTION_IF_NULL(key_value_ptr);
84   auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();
85   if (key_value_track == nullptr) {
86     MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] expected should be able to cast to SymbolicKeyInstancePtrbut: "
87                       << key_value_ptr->ToString();
88   }
89   auto expected = key_value_track->abstract();
90   MS_EXCEPTION_IF_NULL(expected);
91   return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
92 }
93 
InferImplEnvAdd(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)94 AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
95                                 const AbstractBasePtrList &args_spec_list) {
96   // args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
97   CheckArgsSize(primitive->name(), args_spec_list, 2);
98   return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
99 }
100 
InferImplMakeRefKey(const AnalysisEnginePtr &,const PrimitivePtr & prim,const AbstractBasePtrList &)101 AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &prim, const AbstractBasePtrList &) {
102   ValuePtr name_value = prim->GetAttr("tag");
103   MS_EXCEPTION_IF_NULL(name_value);
104   auto name = name_value->cast<StringImmPtr>();
105   if (name == nullptr) {
106     MS_LOG(EXCEPTION) << "MakeRefKey attr tag should be a String " << name_value->ToString() << ".";
107   }
108   auto refkey = std::make_shared<RefKey>(name->value());
109   if (refkey == nullptr) {
110     MS_LOG(EXCEPTION) << "MakeRefKey std::make_shared<RefKey> failed";
111   }
112   return refkey->ToAbstract();
113 }
114 
InferImplMakeRef(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)115 AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &,
116                                  const AbstractBasePtrList &args_spec_list) {
117   // arguments: key, value, target type(None if no target type)
118   if (args_spec_list.size() != 3) {
119     MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
120                       << ".";
121   }
122   auto tensor = args_spec_list[1]->cast<abstract::AbstractTensorPtr>();
123   return std::make_shared<AbstractRef>(args_spec_list[0], tensor);
124 }
125 
InferImplGetRefKey(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)126 AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
127                                    const AbstractBasePtrList &args_spec_list) {
128   // arguments: value
129   if (args_spec_list.size() != 1) {
130     MS_LOG(EXCEPTION) << "get_ref_key requires 1 parameters, while the input size is " << args_spec_list.size() << ".";
131   }
132   TypePtr type = args_spec_list[0]->GetTypeTrack();
133   if (type->type_id() != kObjectTypeRef) {
134     MS_LOG(EXCEPTION) << "First input of get_ref_key should be a Ref but a " << type->ToString();
135   }
136   auto abs_ref = args_spec_list[0]->cast<AbstractRefPtr>();
137   MS_EXCEPTION_IF_NULL(abs_ref);
138   return abs_ref->ref();
139 }
140 
InferImplGetRefValue(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)141 AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &,
142                                      const AbstractBasePtrList &args_spec_list) {
143   // arguments: value
144   if (args_spec_list.size() != 1) {
145     MS_LOG(EXCEPTION) << "get_ref_value requires 1 parameters, while the input size is " << args_spec_list.size()
146                       << ".";
147   }
148   TypePtr type = args_spec_list[0]->GetTypeTrack();
149   if (type->type_id() != kObjectTypeRef) {
150     return args_spec_list[0];
151   }
152   auto abs_ref = args_spec_list[0]->cast<AbstractRefPtr>();
153   MS_EXCEPTION_IF_NULL(abs_ref);
154   return abs_ref->ref();
155 }
156 
InferImplStateSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)157 AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
158                                       const AbstractBasePtrList &args_spec_list) {
159   // args: Two objects of a subclass of AbstractBase, key and value.
160   CheckArgsSize(primitive->name(), args_spec_list, 2);
161 
162   TypePtr type = args_spec_list[0]->GetTypeTrack();
163   MS_EXCEPTION_IF_NULL(type);
164   if (type->type_id() != kObjectTypeRefKey && type->type_id() != kObjectTypeSymbolicKeyType) {
165     MS_LOG(EXCEPTION) << "First input of StateSetItem should be a RefKey or SymbolicKeyType but a " << type->ToString();
166   }
167   return std::make_shared<AbstractScalar>(kAnyValue, kBool);
168 }
169 
InferImplDepend(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)170 AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
171                                 const AbstractBasePtrList &args_spec_list) {
172   CheckArgsSize(primitive->name(), args_spec_list, 2);
173 
174   // If the dependent has a value, just return depended node.
175   // If depended node is not Any, the dependent maybe eliminated.
176   auto dependant_abstract = args_spec_list[1];
177   auto dependant_value = dependant_abstract->BuildValue();
178   MS_EXCEPTION_IF_NULL(dependant_value);
179   if (dependant_value != kAnyValue) {
180     return args_spec_list[0];
181   }
182 
183   auto depends = args_spec_list[0]->Broaden();  // Avoid eliminating the dependent node.
184   if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
185     // For scalar, need to set value to kAnyValue, because broaden scalar will not change the value.
186     if (depends->isa<AbstractScalar>()) {
187       depends->set_value(kAnyValue);
188     }
189   }
190   return depends;
191 }
192 
InferImplUpdateState(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)193 AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
194                                      const AbstractBasePtrList &args_spec_list) {
195   if (args_spec_list.empty()) {
196     MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at least 1, but got 0";
197   }
198   MS_EXCEPTION_IF_NULL(args_spec_list[0]);
199   return args_spec_list[0]->Broaden();
200 }
201 
InferImplMakeRowTensor(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)202 AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
203                                        const AbstractBasePtrList &args_spec_list) {
204   // Inputs: two tensors and a tuple.
205   const std::string op_name = primitive->name();
206   constexpr size_t size_expected = 3;
207   CheckArgsSize(op_name, args_spec_list, size_expected);
208   auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
209   auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
210   auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
211 
212   auto indices_dtype = indices->element()->BuildType();
213   if (!indices_dtype->isa<Int>()) {
214     MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
215   }
216   auto indices_shp = indices->shape()->shape();
217   if (indices_shp.size() != 1) {
218     MS_EXCEPTION(TypeError) << "Indices must be a 1 dimension tensor, but got a " << indices_shp.size()
219                             << " dimension tensor";
220   }
221   auto values_shp = values->shape()->shape();
222   if (indices_shp[0] != values_shp[0]) {
223     MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
224                             << values_shp[0] << ", but got " << indices_shp[0];
225   }
226 
227   for (const auto &elem_type : dense_shape->ElementsType()) {
228     if (!elem_type->isa<Int>()) {
229       MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
230     }
231   }
232   auto dense_shape_value = dense_shape->BuildValue();
233   MS_EXCEPTION_IF_NULL(dense_shape_value);
234   auto dense_shape_valuetuple = dense_shape_value->cast<ValueTuplePtr>();
235   MS_EXCEPTION_IF_NULL(dense_shape_valuetuple);
236   auto shp = dense_shape_valuetuple->value();
237   ShapeVector dense_shape_vec;
238   (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
239                        [](const ValuePtr &e) -> int64_t {
240                          auto elem = GetValue<int64_t>(e);
241                          return elem;
242                        });
243   if (dense_shape_vec.size() != values_shp.size()) {
244     MS_EXCEPTION(TypeError) << "The size of dense_shape must be the same with the dimension of values "
245                             << values_shp.size() << ", but got " << dense_shape_valuetuple->size();
246   }
247   for (size_t i = 0; i < dense_shape_vec.size(); i++) {
248     if (dense_shape_vec[i] < 0) {
249       MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be positive, but got "
250                               << dense_shape_vec[i];
251     }
252     // The 0th mode might be less or exceed dense_shape[0] due to duplicated selection
253     if (i != 0 && dense_shape_vec[i] != values_shp[i]) {
254       MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i
255                               << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
256     }
257   }
258   auto ret = std::make_shared<AbstractRowTensor>(values->element()->BuildType(), dense_shape_vec);
259   ret->set_indices(indices);
260   ret->set_values(values);
261   ret->set_dense_shape(dense_shape);
262   return ret;
263 }
264 
InferImplRowTensorGetValues(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)265 AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
266                                             const AbstractBasePtrList &args_spec_list) {
267   // Inputs: two tensors and a tuple.
268   const std::string op_name = primitive->name();
269   CheckArgsSize(op_name, args_spec_list, 1);
270   auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0);
271   MS_EXCEPTION_IF_NULL(row_tensor->values());
272   return row_tensor->values();
273 }
274 
InferImplRowTensorGetIndices(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)275 AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
276                                              const AbstractBasePtrList &args_spec_list) {
277   // Inputs: two tensors and a tuple.
278   const std::string op_name = primitive->name();
279   CheckArgsSize(op_name, args_spec_list, 1);
280   auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0);
281   MS_EXCEPTION_IF_NULL(row_tensor->indices());
282   return row_tensor->indices();
283 }
284 
InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)285 AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
286                                                 const AbstractBasePtrList &args_spec_list) {
287   // Inputs: two tensors and a tuple.
288   const std::string op_name = primitive->name();
289   CheckArgsSize(op_name, args_spec_list, 1);
290   auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0);
291   MS_EXCEPTION_IF_NULL(row_tensor->dense_shape());
292   return row_tensor->dense_shape();
293 }
294 
InferImplRowTensorAdd(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)295 AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
296                                       const AbstractBasePtrList &args_spec_list) {
297   // Inputs: row tensor and tensor.
298   const std::string op_name = primitive->name();
299   constexpr size_t args_size = 2;
300   CheckArgsSize(op_name, args_spec_list, args_size);
301   auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0);
302   auto tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
303   MS_EXCEPTION_IF_NULL(row_tensor->dense_shape());
304   MS_EXCEPTION_IF_NULL(tensor->shape());
305   return args_spec_list[0];
306 }
307 
InferImplMakeSparseTensor(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)308 AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
309                                           const AbstractBasePtrList &args_spec_list) {
310   // Inputs: two tensors and a tuple.
311   constexpr auto kMakeSparseInputNum = 3;
312   const std::string op_name = primitive->name();
313   CheckArgsSize(op_name, args_spec_list, kMakeSparseInputNum);
314   auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
315   auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
316   auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
317 
318   auto indices_dtype = indices->element()->BuildType();
319   if (!indices_dtype->isa<Int>()) {
320     MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
321   }
322   auto indices_shp = indices->shape()->shape();
323   if (indices_shp.size() != 2) {
324     MS_EXCEPTION(TypeError) << "Indices must be a 2 dimension tensor, but got a " << indices_shp.size()
325                             << " dimension tensor";
326   }
327   auto values_shp = values->shape()->shape();
328   if (values_shp.size() != 1) {
329     MS_EXCEPTION(TypeError) << "Values must be a 1 dimension tensor, but got a " << values_shp.size()
330                             << " dimension tensor";
331   }
332   if (indices_shp[0] != values_shp[0]) {
333     MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
334                             << values_shp[0] << ", but got " << indices_shp[0];
335   }
336 
337   for (const auto &elem_type : dense_shape->ElementsType()) {
338     if (!elem_type->isa<Int>()) {
339       MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
340     }
341   }
342   auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
343   MS_EXCEPTION_IF_NULL(dense_shape_value);
344   auto shp = dense_shape_value->value();
345   ShapeVector dense_shape_vec;
346   (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
347                        [](const ValuePtr &e) -> int64_t {
348                          auto elem = GetValue<int64_t>(e);
349                          return elem;
350                        });
351   if (LongToSize(indices_shp[1]) != dense_shape_vec.size()) {
352     MS_EXCEPTION(TypeError) << "The size of dense_shape must be equal with the second dimension of indices "
353                             << indices_shp[1] << ", but got " << dense_shape_vec.size();
354   }
355   for (auto dense_shape_elem : dense_shape_vec) {
356     if (dense_shape_elem < 0) {
357       MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
358                               << dense_shape_value->ToString();
359     }
360   }
361   auto ret = std::make_shared<AbstractSparseTensor>(values->element()->BuildType(), dense_shape_vec);
362   ret->set_indices(indices);
363   ret->set_values(values);
364   ret->set_dense_shape(dense_shape);
365   return ret;
366 }
367 
InferImplSparseTensorGetValues(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)368 AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
369                                                const AbstractBasePtrList &args_spec_list) {
370   // Inputs: two tensors and a tuple.
371   const std::string op_name = primitive->name();
372   CheckArgsSize(op_name, args_spec_list, 1);
373   auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
374   MS_EXCEPTION_IF_NULL(sparse_tensor->values());
375   return sparse_tensor->values();
376 }
377 
InferImplSparseTensorGetIndices(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)378 AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
379                                                 const AbstractBasePtrList &args_spec_list) {
380   // Inputs: two tensors and a tuple.
381   const std::string op_name = primitive->name();
382   CheckArgsSize(op_name, args_spec_list, 1);
383   auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
384   MS_EXCEPTION_IF_NULL(sparse_tensor->indices());
385   return sparse_tensor->indices();
386 }
387 
InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)388 AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
389                                                    const AbstractBasePtrList &args_spec_list) {
390   // Inputs: two tensors and a tuple.
391   const std::string op_name = primitive->name();
392   CheckArgsSize(op_name, args_spec_list, 1);
393   auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
394   MS_EXCEPTION_IF_NULL(sparse_tensor->dense_shape());
395   return sparse_tensor->dense_shape();
396 }
397 
InferImplAllSwap(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)398 AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
399                                  const AbstractBasePtrList &args_spec_list) {
400   const std::string op_name = primitive->name();
401   CheckArgsSize(op_name, args_spec_list, 3);
402   auto tensor_in = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
403   MS_EXCEPTION_IF_NULL(tensor_in);
404   MS_EXCEPTION_IF_NULL(tensor_in->shape());
405   auto tensor_in_shape = tensor_in->shape()->shape();
406 
407   auto send_size = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
408   MS_EXCEPTION_IF_NULL(send_size);
409   auto recv_size = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
410   MS_EXCEPTION_IF_NULL(recv_size);
411 
412   // Get the content of the recv size
413   auto recv_size_value_ptr = recv_size->BuildValue();
414   MS_EXCEPTION_IF_NULL(recv_size_value_ptr);
415   auto recv_size_tensor = recv_size_value_ptr->cast<tensor::TensorPtr>();
416   MS_EXCEPTION_IF_NULL(recv_size_tensor);
417   auto data_pos = reinterpret_cast<int64_t *>(recv_size_tensor->data_c());
418   MS_EXCEPTION_IF_NULL(data_pos);
419   int64_t infer_max_size = 0;
420   for (int64_t i = 0; i < recv_size_tensor->DataSize(); ++i) {
421     infer_max_size += *(data_pos + i);
422   }
423 
424   ShapeVector tensor_out_shape = {Shape::SHP_ANY, tensor_in_shape[1]};
425   ShapeVector min_shape = {1, tensor_in_shape[1]};
426 
427   ShapeVector max_shape = {infer_max_size / tensor_in_shape[1], tensor_in_shape[1]};
428 
429   auto tensor_out = std::make_shared<AbstractTensor>(tensor_in->element(),
430                                                      std::make_shared<Shape>(tensor_out_shape, min_shape, max_shape));
431 
432   AbstractTensorPtr ret = std::make_shared<AbstractTensor>(
433     tensor_out->element(), std::make_shared<Shape>(tensor_out_shape, min_shape, max_shape));
434   return ret;
435 }
436 
InferImplAllReduce(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)437 AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
438                                    const AbstractBasePtrList &args_spec_list) {
439   const std::string op_name = primitive->name();
440   CheckArgsSize(op_name, args_spec_list, 1);
441   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
442   MS_EXCEPTION_IF_NULL(x);
443   MS_EXCEPTION_IF_NULL(x->shape());
444   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
445 }
446 
InferImplBroadcast(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)447 AbstractBasePtr InferImplBroadcast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
448                                    const AbstractBasePtrList &args_spec_list) {
449   const std::string op_name = primitive->name();
450   CheckArgsSize(op_name, args_spec_list, 1);
451   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
452   MS_EXCEPTION_IF_NULL(x);
453   MS_EXCEPTION_IF_NULL(x->shape());
454   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
455 }
456 
InferImplAllGather(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)457 AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
458                                    const AbstractBasePtrList &args_spec_list) {
459   const std::string op_name = primitive->name();
460   CheckArgsSize(op_name, args_spec_list, 1);
461   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
462   MS_EXCEPTION_IF_NULL(x);
463   MS_EXCEPTION_IF_NULL(x->shape());
464   auto tmp_shape = x->shape()->shape();
465   if (!primitive->HasAttr(kRankSize)) {
466     MS_LOG(EXCEPTION) << "Primitive don't have rank_size attr";
467   }
468   auto rank_size = GetValue<int>(primitive->GetAttr(kRankSize));
469   if (rank_size == 0) {
470     MS_LOG(EXCEPTION) << "rank_size is 0";
471   }
472   if (tmp_shape.empty()) {
473     MS_LOG(EXCEPTION) << "shape size is 0";
474   }
475   if (tmp_shape[0] > 0) {
476     tmp_shape[0] = tmp_shape[0] * rank_size;
477   }
478   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(tmp_shape));
479 }
480 
InferImplReduceScatter(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)481 AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
482                                        const AbstractBasePtrList &args_spec_list) {
483   const std::string op_name = primitive->name();
484   CheckArgsSize(op_name, args_spec_list, 1);
485   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
486   MS_EXCEPTION_IF_NULL(x);
487   MS_EXCEPTION_IF_NULL(x->shape());
488   auto tmp_shape = x->shape()->shape();
489   if (!primitive->HasAttr(kRankSize)) {
490     MS_LOG(EXCEPTION) << "Primitive don't have rank_size attr";
491   }
492   auto rank_size = GetValue<int>(primitive->GetAttr(kRankSize));
493   if (tmp_shape.empty()) {
494     MS_LOG(EXCEPTION) << "shape size is 0";
495   }
496   tmp_shape[0] = LongMulWithOverflowCheck(tmp_shape[0], rank_size);
497   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(tmp_shape));
498 }
499 
InferImplMemCpyAsync(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)500 AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
501                                      const AbstractBasePtrList &args_spec_list) {
502   const std::string op_name = primitive->name();
503   CheckArgsSize(op_name, args_spec_list, 1);
504   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
505   MS_EXCEPTION_IF_NULL(x);
506   MS_EXCEPTION_IF_NULL(x->shape());
507   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
508 }
509 
InferImplCast(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)510 AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
511                               const AbstractBasePtrList &args_spec_list) {
512   const std::string op_name = primitive->name();
513   // GPU has 2 inputs while tbe has 1 only. Skip CheckArgsSize.
514   auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
515   MS_EXCEPTION_IF_NULL(input_x);
516   auto attr = primitive->GetAttr("dst_type");
517   if (attr == nullptr) {
518     auto type_abs = CheckArg<AbstractType>(op_name, args_spec_list, 1);
519     attr = type_abs->BuildValue();
520     MS_EXCEPTION_IF_NULL(attr);
521     primitive->set_attr("dst_type", attr);
522   }
523   auto input_type = attr->cast<TypePtr>();
524   auto ret = std::make_shared<AbstractTensor>(input_type, input_x->shape());
525   return ret;
526 }
527 
InferImplExpandDims(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)528 AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
529                                     const AbstractBasePtrList &args_spec_list) {
530   const std::string op_name = primitive->name();
531   CheckArgsSize(op_name, args_spec_list, 1);
532   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
533   MS_EXCEPTION_IF_NULL(x);
534   MS_EXCEPTION_IF_NULL(x->shape());
535 
536   std::vector<int64_t> shape;
537   std::vector<int64_t> x_shape = x->shape()->shape();
538   (void)shape.insert(shape.end(), x_shape.begin(), x_shape.end());
539   auto axis = primitive->GetAttr("axis");
540   auto value = GetValue<int64_t>(axis);
541   if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) {
542     MS_LOG(EXCEPTION) << " axis value should be in range [-input_x.dim-1,input_x.dim], but axis value is" << value
543                       << " and input_x.dim is" << x_shape.size();
544   }
545   if (value < 0) {
546     value = value + SizeToInt(x_shape.size()) + 1;
547   }
548   (void)shape.insert(shape.begin() + value, 1);
549 
550   auto ret = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
551   return ret;
552 }
553 
InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)554 AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
555                                                   const AbstractBasePtrList &args_spec_list) {
556   const std::string &op_name = primitive->name();
557   CheckArgsSize(op_name, args_spec_list, 1);
558   AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
559 
560   ShapeVector input_shape = input->shape()->shape();
561   int32_t input_rank = SizeToInt(input_shape.size());
562   ShapeVector inferred_shape(input_rank, Shape::SHP_ANY);
563   ShapeVector min_shape(input_rank, 1);
564   ShapeVector max_shape = input_shape;
565 
566   ShapePtr shape = std::make_shared<Shape>(inferred_shape, min_shape, max_shape);
567   return std::make_shared<AbstractTensor>(input->element(), shape);
568 }
569 
InferImplLoad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)570 AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
571                               const AbstractBasePtrList &args_spec_list) {
572   // Inputs: Ref/Tensor, universal
573   CheckArgsSize(primitive->name(), args_spec_list, 2);
574   auto ref_abs = dyn_cast<abstract::AbstractRef>(args_spec_list[0]);
575   if (ref_abs != nullptr) {
576     // Return tensor value if input is Ref.
577     return ref_abs->CloneAsTensor();
578   }
579   return args_spec_list[0]->Broaden();
580 }
581 
InferImplTransData(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)582 AbstractBasePtr InferImplTransData(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
583                                    const AbstractBasePtrList &args_spec_list) {
584   // An object of a subclass of AbstractBase
585   CheckArgsSize(primitive->name(), args_spec_list, 1);
586   auto output = args_spec_list[0];
587   MS_EXCEPTION_IF_NULL(output);
588   return output;
589 }
590 }  // namespace abstract
591 }  // namespace mindspore
592