1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <string>
18 #include <sstream>
19
20 #include "ir/dtype.h"
21 #include "utils/ms_utils.h"
22 #include "base/core_ops.h"
23 #include "abstract/param_validator.h"
24 #include "abstract/infer_functions.h"
25 #include "abstract/utils.h"
26 #include "utils/ms_context.h"
27 #include "utils/symbolic.h"
28 #include "utils/shape_utils.h"
29
namespace {
// Name of the primitive attribute read by the AllGather / ReduceScatter infer functions below.
constexpr auto kRankSize = "rank_size";
}  // namespace
33
34 namespace mindspore {
35 namespace abstract {
InferImplIdentity(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)36 AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
37 const AbstractBasePtrList &args_spec_list) {
38 // An object of a subclass of AbstractBase
39 CheckArgsSize(primitive->name(), args_spec_list, 1);
40 return args_spec_list[0];
41 }
42
InferImplEnvGetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)43 AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
44 const AbstractBasePtrList &args_spec_list) {
45 MS_EXCEPTION_IF_NULL(primitive);
46 // args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
47 CheckArgsSize(primitive->name(), args_spec_list, 3);
48 auto key = args_spec_list[1];
49 auto dflt = args_spec_list[2];
50 TypePtr type = key->GetTypeTrack();
51 MS_EXCEPTION_IF_NULL(type);
52 if (type->type_id() != kObjectTypeSymbolicKeyType) {
53 MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
54 }
55
56 auto context = MsContext::GetInstance();
57 MS_EXCEPTION_IF_NULL(context);
58 bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
59 if (enable_sparse && dflt->isa<AbstractTensor>()) {
60 auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
61 return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());
62 }
63
64 if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
65 return dflt;
66 }
67 ValuePtr key_value_ptr = key->GetValueTrack();
68 MS_EXCEPTION_IF_NULL(key_value_ptr);
69 auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();
70 auto expected = key_value_track->abstract();
71 MS_EXCEPTION_IF_NULL(expected);
72 (void)expected->Join(dflt);
73 return expected;
74 }
75
InferImplEnvSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)76 AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
77 const AbstractBasePtrList &args_spec_list) {
78 // args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
79 CheckArgsSize(primitive->name(), args_spec_list, 3);
80
81 auto key = args_spec_list[1];
82 ValuePtr key_value_ptr = key->GetValueTrack();
83 MS_EXCEPTION_IF_NULL(key_value_ptr);
84 auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();
85 if (key_value_track == nullptr) {
86 MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] expected should be able to cast to SymbolicKeyInstancePtrbut: "
87 << key_value_ptr->ToString();
88 }
89 auto expected = key_value_track->abstract();
90 MS_EXCEPTION_IF_NULL(expected);
91 return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
92 }
93
InferImplEnvAdd(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)94 AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
95 const AbstractBasePtrList &args_spec_list) {
96 // args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
97 CheckArgsSize(primitive->name(), args_spec_list, 2);
98 return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
99 }
100
InferImplMakeRefKey(const AnalysisEnginePtr &,const PrimitivePtr & prim,const AbstractBasePtrList &)101 AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &prim, const AbstractBasePtrList &) {
102 ValuePtr name_value = prim->GetAttr("tag");
103 MS_EXCEPTION_IF_NULL(name_value);
104 auto name = name_value->cast<StringImmPtr>();
105 if (name == nullptr) {
106 MS_LOG(EXCEPTION) << "MakeRefKey attr tag should be a String " << name_value->ToString() << ".";
107 }
108 auto refkey = std::make_shared<RefKey>(name->value());
109 if (refkey == nullptr) {
110 MS_LOG(EXCEPTION) << "MakeRefKey std::make_shared<RefKey> failed";
111 }
112 return refkey->ToAbstract();
113 }
114
InferImplMakeRef(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)115 AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &,
116 const AbstractBasePtrList &args_spec_list) {
117 // arguments: key, value, target type(None if no target type)
118 if (args_spec_list.size() != 3) {
119 MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
120 << ".";
121 }
122 auto tensor = args_spec_list[1]->cast<abstract::AbstractTensorPtr>();
123 return std::make_shared<AbstractRef>(args_spec_list[0], tensor);
124 }
125
InferImplGetRefKey(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)126 AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
127 const AbstractBasePtrList &args_spec_list) {
128 // arguments: value
129 if (args_spec_list.size() != 1) {
130 MS_LOG(EXCEPTION) << "get_ref_key requires 1 parameters, while the input size is " << args_spec_list.size() << ".";
131 }
132 TypePtr type = args_spec_list[0]->GetTypeTrack();
133 if (type->type_id() != kObjectTypeRef) {
134 MS_LOG(EXCEPTION) << "First input of get_ref_key should be a Ref but a " << type->ToString();
135 }
136 auto abs_ref = args_spec_list[0]->cast<AbstractRefPtr>();
137 MS_EXCEPTION_IF_NULL(abs_ref);
138 return abs_ref->ref();
139 }
140
InferImplGetRefValue(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)141 AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &,
142 const AbstractBasePtrList &args_spec_list) {
143 // arguments: value
144 if (args_spec_list.size() != 1) {
145 MS_LOG(EXCEPTION) << "get_ref_value requires 1 parameters, while the input size is " << args_spec_list.size()
146 << ".";
147 }
148 TypePtr type = args_spec_list[0]->GetTypeTrack();
149 if (type->type_id() != kObjectTypeRef) {
150 return args_spec_list[0];
151 }
152 auto abs_ref = args_spec_list[0]->cast<AbstractRefPtr>();
153 MS_EXCEPTION_IF_NULL(abs_ref);
154 return abs_ref->ref();
155 }
156
InferImplStateSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)157 AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
158 const AbstractBasePtrList &args_spec_list) {
159 // args: Two objects of a subclass of AbstractBase, key and value.
160 CheckArgsSize(primitive->name(), args_spec_list, 2);
161
162 TypePtr type = args_spec_list[0]->GetTypeTrack();
163 MS_EXCEPTION_IF_NULL(type);
164 if (type->type_id() != kObjectTypeRefKey && type->type_id() != kObjectTypeSymbolicKeyType) {
165 MS_LOG(EXCEPTION) << "First input of StateSetItem should be a RefKey or SymbolicKeyType but a " << type->ToString();
166 }
167 return std::make_shared<AbstractScalar>(kAnyValue, kBool);
168 }
169
InferImplDepend(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)170 AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
171 const AbstractBasePtrList &args_spec_list) {
172 CheckArgsSize(primitive->name(), args_spec_list, 2);
173
174 // If the dependent has a value, just return depended node.
175 // If depended node is not Any, the dependent maybe eliminated.
176 auto dependant_abstract = args_spec_list[1];
177 auto dependant_value = dependant_abstract->BuildValue();
178 MS_EXCEPTION_IF_NULL(dependant_value);
179 if (dependant_value != kAnyValue) {
180 return args_spec_list[0];
181 }
182
183 auto depends = args_spec_list[0]->Broaden(); // Avoid eliminating the dependent node.
184 if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
185 // For scalar, need to set value to kAnyValue, because broaden scalar will not change the value.
186 if (depends->isa<AbstractScalar>()) {
187 depends->set_value(kAnyValue);
188 }
189 }
190 return depends;
191 }
192
InferImplUpdateState(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)193 AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
194 const AbstractBasePtrList &args_spec_list) {
195 if (args_spec_list.empty()) {
196 MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at least 1, but got 0";
197 }
198 MS_EXCEPTION_IF_NULL(args_spec_list[0]);
199 return args_spec_list[0]->Broaden();
200 }
201
InferImplMakeRowTensor(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)202 AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
203 const AbstractBasePtrList &args_spec_list) {
204 // Inputs: two tensors and a tuple.
205 const std::string op_name = primitive->name();
206 constexpr size_t size_expected = 3;
207 CheckArgsSize(op_name, args_spec_list, size_expected);
208 auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
209 auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
210 auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
211
212 auto indices_dtype = indices->element()->BuildType();
213 if (!indices_dtype->isa<Int>()) {
214 MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
215 }
216 auto indices_shp = indices->shape()->shape();
217 if (indices_shp.size() != 1) {
218 MS_EXCEPTION(TypeError) << "Indices must be a 1 dimension tensor, but got a " << indices_shp.size()
219 << " dimension tensor";
220 }
221 auto values_shp = values->shape()->shape();
222 if (indices_shp[0] != values_shp[0]) {
223 MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
224 << values_shp[0] << ", but got " << indices_shp[0];
225 }
226
227 for (const auto &elem_type : dense_shape->ElementsType()) {
228 if (!elem_type->isa<Int>()) {
229 MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
230 }
231 }
232 auto dense_shape_value = dense_shape->BuildValue();
233 MS_EXCEPTION_IF_NULL(dense_shape_value);
234 auto dense_shape_valuetuple = dense_shape_value->cast<ValueTuplePtr>();
235 MS_EXCEPTION_IF_NULL(dense_shape_valuetuple);
236 auto shp = dense_shape_valuetuple->value();
237 ShapeVector dense_shape_vec;
238 (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
239 [](const ValuePtr &e) -> int64_t {
240 auto elem = GetValue<int64_t>(e);
241 return elem;
242 });
243 if (dense_shape_vec.size() != values_shp.size()) {
244 MS_EXCEPTION(TypeError) << "The size of dense_shape must be the same with the dimension of values "
245 << values_shp.size() << ", but got " << dense_shape_valuetuple->size();
246 }
247 for (size_t i = 0; i < dense_shape_vec.size(); i++) {
248 if (dense_shape_vec[i] < 0) {
249 MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be positive, but got "
250 << dense_shape_vec[i];
251 }
252 // The 0th mode might be less or exceed dense_shape[0] due to duplicated selection
253 if (i != 0 && dense_shape_vec[i] != values_shp[i]) {
254 MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i
255 << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
256 }
257 }
258 auto ret = std::make_shared<AbstractRowTensor>(values->element()->BuildType(), dense_shape_vec);
259 ret->set_indices(indices);
260 ret->set_values(values);
261 ret->set_dense_shape(dense_shape);
262 return ret;
263 }
264
InferImplRowTensorGetValues(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)265 AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
266 const AbstractBasePtrList &args_spec_list) {
267 // Inputs: two tensors and a tuple.
268 const std::string op_name = primitive->name();
269 CheckArgsSize(op_name, args_spec_list, 1);
270 auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0);
271 MS_EXCEPTION_IF_NULL(row_tensor->values());
272 return row_tensor->values();
273 }
274
InferImplRowTensorGetIndices(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)275 AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
276 const AbstractBasePtrList &args_spec_list) {
277 // Inputs: two tensors and a tuple.
278 const std::string op_name = primitive->name();
279 CheckArgsSize(op_name, args_spec_list, 1);
280 auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0);
281 MS_EXCEPTION_IF_NULL(row_tensor->indices());
282 return row_tensor->indices();
283 }
284
InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)285 AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
286 const AbstractBasePtrList &args_spec_list) {
287 // Inputs: two tensors and a tuple.
288 const std::string op_name = primitive->name();
289 CheckArgsSize(op_name, args_spec_list, 1);
290 auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0);
291 MS_EXCEPTION_IF_NULL(row_tensor->dense_shape());
292 return row_tensor->dense_shape();
293 }
294
InferImplRowTensorAdd(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)295 AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
296 const AbstractBasePtrList &args_spec_list) {
297 // Inputs: row tensor and tensor.
298 const std::string op_name = primitive->name();
299 constexpr size_t args_size = 2;
300 CheckArgsSize(op_name, args_spec_list, args_size);
301 auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0);
302 auto tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
303 MS_EXCEPTION_IF_NULL(row_tensor->dense_shape());
304 MS_EXCEPTION_IF_NULL(tensor->shape());
305 return args_spec_list[0];
306 }
307
InferImplMakeSparseTensor(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)308 AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
309 const AbstractBasePtrList &args_spec_list) {
310 // Inputs: two tensors and a tuple.
311 constexpr auto kMakeSparseInputNum = 3;
312 const std::string op_name = primitive->name();
313 CheckArgsSize(op_name, args_spec_list, kMakeSparseInputNum);
314 auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
315 auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
316 auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
317
318 auto indices_dtype = indices->element()->BuildType();
319 if (!indices_dtype->isa<Int>()) {
320 MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
321 }
322 auto indices_shp = indices->shape()->shape();
323 if (indices_shp.size() != 2) {
324 MS_EXCEPTION(TypeError) << "Indices must be a 2 dimension tensor, but got a " << indices_shp.size()
325 << " dimension tensor";
326 }
327 auto values_shp = values->shape()->shape();
328 if (values_shp.size() != 1) {
329 MS_EXCEPTION(TypeError) << "Values must be a 1 dimension tensor, but got a " << values_shp.size()
330 << " dimension tensor";
331 }
332 if (indices_shp[0] != values_shp[0]) {
333 MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
334 << values_shp[0] << ", but got " << indices_shp[0];
335 }
336
337 for (const auto &elem_type : dense_shape->ElementsType()) {
338 if (!elem_type->isa<Int>()) {
339 MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
340 }
341 }
342 auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
343 MS_EXCEPTION_IF_NULL(dense_shape_value);
344 auto shp = dense_shape_value->value();
345 ShapeVector dense_shape_vec;
346 (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
347 [](const ValuePtr &e) -> int64_t {
348 auto elem = GetValue<int64_t>(e);
349 return elem;
350 });
351 if (LongToSize(indices_shp[1]) != dense_shape_vec.size()) {
352 MS_EXCEPTION(TypeError) << "The size of dense_shape must be equal with the second dimension of indices "
353 << indices_shp[1] << ", but got " << dense_shape_vec.size();
354 }
355 for (auto dense_shape_elem : dense_shape_vec) {
356 if (dense_shape_elem < 0) {
357 MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
358 << dense_shape_value->ToString();
359 }
360 }
361 auto ret = std::make_shared<AbstractSparseTensor>(values->element()->BuildType(), dense_shape_vec);
362 ret->set_indices(indices);
363 ret->set_values(values);
364 ret->set_dense_shape(dense_shape);
365 return ret;
366 }
367
InferImplSparseTensorGetValues(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)368 AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
369 const AbstractBasePtrList &args_spec_list) {
370 // Inputs: two tensors and a tuple.
371 const std::string op_name = primitive->name();
372 CheckArgsSize(op_name, args_spec_list, 1);
373 auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
374 MS_EXCEPTION_IF_NULL(sparse_tensor->values());
375 return sparse_tensor->values();
376 }
377
InferImplSparseTensorGetIndices(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)378 AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
379 const AbstractBasePtrList &args_spec_list) {
380 // Inputs: two tensors and a tuple.
381 const std::string op_name = primitive->name();
382 CheckArgsSize(op_name, args_spec_list, 1);
383 auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
384 MS_EXCEPTION_IF_NULL(sparse_tensor->indices());
385 return sparse_tensor->indices();
386 }
387
InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)388 AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
389 const AbstractBasePtrList &args_spec_list) {
390 // Inputs: two tensors and a tuple.
391 const std::string op_name = primitive->name();
392 CheckArgsSize(op_name, args_spec_list, 1);
393 auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
394 MS_EXCEPTION_IF_NULL(sparse_tensor->dense_shape());
395 return sparse_tensor->dense_shape();
396 }
397
InferImplAllSwap(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)398 AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
399 const AbstractBasePtrList &args_spec_list) {
400 const std::string op_name = primitive->name();
401 CheckArgsSize(op_name, args_spec_list, 3);
402 auto tensor_in = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
403 MS_EXCEPTION_IF_NULL(tensor_in);
404 MS_EXCEPTION_IF_NULL(tensor_in->shape());
405 auto tensor_in_shape = tensor_in->shape()->shape();
406
407 auto send_size = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
408 MS_EXCEPTION_IF_NULL(send_size);
409 auto recv_size = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
410 MS_EXCEPTION_IF_NULL(recv_size);
411
412 // Get the content of the recv size
413 auto recv_size_value_ptr = recv_size->BuildValue();
414 MS_EXCEPTION_IF_NULL(recv_size_value_ptr);
415 auto recv_size_tensor = recv_size_value_ptr->cast<tensor::TensorPtr>();
416 MS_EXCEPTION_IF_NULL(recv_size_tensor);
417 auto data_pos = reinterpret_cast<int64_t *>(recv_size_tensor->data_c());
418 MS_EXCEPTION_IF_NULL(data_pos);
419 int64_t infer_max_size = 0;
420 for (int64_t i = 0; i < recv_size_tensor->DataSize(); ++i) {
421 infer_max_size += *(data_pos + i);
422 }
423
424 ShapeVector tensor_out_shape = {Shape::SHP_ANY, tensor_in_shape[1]};
425 ShapeVector min_shape = {1, tensor_in_shape[1]};
426
427 ShapeVector max_shape = {infer_max_size / tensor_in_shape[1], tensor_in_shape[1]};
428
429 auto tensor_out = std::make_shared<AbstractTensor>(tensor_in->element(),
430 std::make_shared<Shape>(tensor_out_shape, min_shape, max_shape));
431
432 AbstractTensorPtr ret = std::make_shared<AbstractTensor>(
433 tensor_out->element(), std::make_shared<Shape>(tensor_out_shape, min_shape, max_shape));
434 return ret;
435 }
436
InferImplAllReduce(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)437 AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
438 const AbstractBasePtrList &args_spec_list) {
439 const std::string op_name = primitive->name();
440 CheckArgsSize(op_name, args_spec_list, 1);
441 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
442 MS_EXCEPTION_IF_NULL(x);
443 MS_EXCEPTION_IF_NULL(x->shape());
444 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
445 }
446
InferImplBroadcast(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)447 AbstractBasePtr InferImplBroadcast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
448 const AbstractBasePtrList &args_spec_list) {
449 const std::string op_name = primitive->name();
450 CheckArgsSize(op_name, args_spec_list, 1);
451 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
452 MS_EXCEPTION_IF_NULL(x);
453 MS_EXCEPTION_IF_NULL(x->shape());
454 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
455 }
456
InferImplAllGather(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)457 AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
458 const AbstractBasePtrList &args_spec_list) {
459 const std::string op_name = primitive->name();
460 CheckArgsSize(op_name, args_spec_list, 1);
461 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
462 MS_EXCEPTION_IF_NULL(x);
463 MS_EXCEPTION_IF_NULL(x->shape());
464 auto tmp_shape = x->shape()->shape();
465 if (!primitive->HasAttr(kRankSize)) {
466 MS_LOG(EXCEPTION) << "Primitive don't have rank_size attr";
467 }
468 auto rank_size = GetValue<int>(primitive->GetAttr(kRankSize));
469 if (rank_size == 0) {
470 MS_LOG(EXCEPTION) << "rank_size is 0";
471 }
472 if (tmp_shape.empty()) {
473 MS_LOG(EXCEPTION) << "shape size is 0";
474 }
475 if (tmp_shape[0] > 0) {
476 tmp_shape[0] = tmp_shape[0] * rank_size;
477 }
478 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(tmp_shape));
479 }
480
InferImplReduceScatter(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)481 AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
482 const AbstractBasePtrList &args_spec_list) {
483 const std::string op_name = primitive->name();
484 CheckArgsSize(op_name, args_spec_list, 1);
485 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
486 MS_EXCEPTION_IF_NULL(x);
487 MS_EXCEPTION_IF_NULL(x->shape());
488 auto tmp_shape = x->shape()->shape();
489 if (!primitive->HasAttr(kRankSize)) {
490 MS_LOG(EXCEPTION) << "Primitive don't have rank_size attr";
491 }
492 auto rank_size = GetValue<int>(primitive->GetAttr(kRankSize));
493 if (tmp_shape.empty()) {
494 MS_LOG(EXCEPTION) << "shape size is 0";
495 }
496 tmp_shape[0] = LongMulWithOverflowCheck(tmp_shape[0], rank_size);
497 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(tmp_shape));
498 }
499
InferImplMemCpyAsync(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)500 AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
501 const AbstractBasePtrList &args_spec_list) {
502 const std::string op_name = primitive->name();
503 CheckArgsSize(op_name, args_spec_list, 1);
504 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
505 MS_EXCEPTION_IF_NULL(x);
506 MS_EXCEPTION_IF_NULL(x->shape());
507 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
508 }
509
InferImplCast(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)510 AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
511 const AbstractBasePtrList &args_spec_list) {
512 const std::string op_name = primitive->name();
513 // GPU has 2 inputs while tbe has 1 only. Skip CheckArgsSize.
514 auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
515 MS_EXCEPTION_IF_NULL(input_x);
516 auto attr = primitive->GetAttr("dst_type");
517 if (attr == nullptr) {
518 auto type_abs = CheckArg<AbstractType>(op_name, args_spec_list, 1);
519 attr = type_abs->BuildValue();
520 MS_EXCEPTION_IF_NULL(attr);
521 primitive->set_attr("dst_type", attr);
522 }
523 auto input_type = attr->cast<TypePtr>();
524 auto ret = std::make_shared<AbstractTensor>(input_type, input_x->shape());
525 return ret;
526 }
527
InferImplExpandDims(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)528 AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
529 const AbstractBasePtrList &args_spec_list) {
530 const std::string op_name = primitive->name();
531 CheckArgsSize(op_name, args_spec_list, 1);
532 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
533 MS_EXCEPTION_IF_NULL(x);
534 MS_EXCEPTION_IF_NULL(x->shape());
535
536 std::vector<int64_t> shape;
537 std::vector<int64_t> x_shape = x->shape()->shape();
538 (void)shape.insert(shape.end(), x_shape.begin(), x_shape.end());
539 auto axis = primitive->GetAttr("axis");
540 auto value = GetValue<int64_t>(axis);
541 if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) {
542 MS_LOG(EXCEPTION) << " axis value should be in range [-input_x.dim-1,input_x.dim], but axis value is" << value
543 << " and input_x.dim is" << x_shape.size();
544 }
545 if (value < 0) {
546 value = value + SizeToInt(x_shape.size()) + 1;
547 }
548 (void)shape.insert(shape.begin() + value, 1);
549
550 auto ret = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
551 return ret;
552 }
553
InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)554 AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
555 const AbstractBasePtrList &args_spec_list) {
556 const std::string &op_name = primitive->name();
557 CheckArgsSize(op_name, args_spec_list, 1);
558 AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
559
560 ShapeVector input_shape = input->shape()->shape();
561 int32_t input_rank = SizeToInt(input_shape.size());
562 ShapeVector inferred_shape(input_rank, Shape::SHP_ANY);
563 ShapeVector min_shape(input_rank, 1);
564 ShapeVector max_shape = input_shape;
565
566 ShapePtr shape = std::make_shared<Shape>(inferred_shape, min_shape, max_shape);
567 return std::make_shared<AbstractTensor>(input->element(), shape);
568 }
569
InferImplLoad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)570 AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
571 const AbstractBasePtrList &args_spec_list) {
572 // Inputs: Ref/Tensor, universal
573 CheckArgsSize(primitive->name(), args_spec_list, 2);
574 auto ref_abs = dyn_cast<abstract::AbstractRef>(args_spec_list[0]);
575 if (ref_abs != nullptr) {
576 // Return tensor value if input is Ref.
577 return ref_abs->CloneAsTensor();
578 }
579 return args_spec_list[0]->Broaden();
580 }
581
InferImplTransData(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)582 AbstractBasePtr InferImplTransData(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
583 const AbstractBasePtrList &args_spec_list) {
584 // An object of a subclass of AbstractBase
585 CheckArgsSize(primitive->name(), args_spec_list, 1);
586 auto output = args_spec_list[0];
587 MS_EXCEPTION_IF_NULL(output);
588 return output;
589 }
590 } // namespace abstract
591 } // namespace mindspore
592