1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <algorithm>
18 #include <iterator>
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "abstract/abstract_value.h"
26 #include "abstract/ops/infer_functions.h"
27 #include "abstract/param_validator.h"
28 #include "abstract/utils.h"
29 #include "utils/shape_utils.h"
30 #include "abstract/dshape.h"
31 #include "base/base.h"
32 #include "ir/anf.h"
33 #include "ir/dtype.h"
34 #include "ir/dtype/number.h"
35 #include "ir/dtype/type.h"
36 #include "ir/primitive.h"
37 #include "ir/scalar.h"
38 #include "ir/tensor.h"
39 #include "ir/value.h"
40 #include "mindapi/base/shape_vector.h"
41 #include "mindapi/base/type_id.h"
42 #include "utils/convert_utils_base.h"
43 #include "utils/log_adapter.h"
44 #include "utils/check_convert_utils.h"
45
46 namespace mindspore {
47 namespace abstract {
InferImplScalarToArray(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)48 AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
49 const AbstractBasePtrList &args_abs_list) {
50 // Inputs: a scalar.
51 const std::string op_name = primitive->name();
52 CheckArgsSize(op_name, args_abs_list, 1);
53 AbstractScalarPtr arg = CheckArg<AbstractScalar>(op_name, args_abs_list, 0);
54 return std::make_shared<AbstractTensor>(arg, std::make_shared<Shape>());
55 }
56
InferImplArrayToScalar(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)57 AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
58 const AbstractBasePtrList &args_abs_list) {
59 // Inputs: a tensor with 0 shape.
60 const std::string op_name = primitive->name();
61 CheckArgsSize(op_name, args_abs_list, 1);
62 auto arg = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
63 auto a_shp = arg->shape();
64 MS_EXCEPTION_IF_NULL(a_shp);
65 if (!a_shp->shape().empty()) {
66 MS_LOG(EXCEPTION) << "array_to_scalar requires zero size shape.";
67 }
68 return arg->element();
69 }
70
InferImplBroadcastShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)71 AbstractBasePtr InferImplBroadcastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
72 const AbstractBasePtrList &args_abs_list) {
73 // Inputs: two tuples.
74 const std::string op_name = primitive->name();
75 constexpr size_t args_size = 2;
76 CheckArgsSize(op_name, args_abs_list, args_size);
77 auto xs = CheckArg<AbstractTuple>(op_name, args_abs_list, 0);
78 auto ys = CheckArg<AbstractTuple>(op_name, args_abs_list, 1);
79 auto x_value = xs->BuildValue();
80 MS_EXCEPTION_IF_NULL(x_value);
81 auto value_tuple_x = x_value->cast<ValueTuplePtr>();
82 MS_EXCEPTION_IF_NULL(value_tuple_x);
83 auto shp_tuple_x = value_tuple_x->value();
84 ShapeVector shp_x;
85 (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(shp_x),
86 [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
87 auto tupe_value_y = ys->BuildValue();
88 MS_EXCEPTION_IF_NULL(tupe_value_y);
89 auto value_tuple_y = tupe_value_y->cast<ValueTuplePtr>();
90 MS_EXCEPTION_IF_NULL(value_tuple_y);
91 auto shp_tuple_y = value_tuple_y->value();
92 ShapeVector shp_y;
93 (void)std::transform(std::begin(shp_tuple_y), std::end(shp_tuple_y), std::back_inserter(shp_y),
94 [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
95
96 ShapeVector res = BroadcastShape(shp_x, shp_y);
97 MS_EXCEPTION_IF_NULL(args_abs_list[1]);
98 if (res.empty()) {
99 MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_abs_list[0]->ToString() << "," << args_abs_list[1]->ToString();
100 }
101
102 AbstractBasePtrList elems;
103 (void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int64_t n) -> AbstractBasePtr {
104 return std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(n), kInt64);
105 });
106 return std::make_shared<AbstractTuple>(elems);
107 }
108
InferImplMapCacheIdx(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)109 AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
110 const AbstractBasePtrList &args_abs_list) {
111 const std::string op_name = primitive->name();
112 const size_t size_expected = 5;
113 CheckArgsSize(op_name, args_abs_list, size_expected);
114 auto hash_map = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
115 MS_EXCEPTION_IF_NULL(hash_map->shape());
116
117 auto indices = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
118 auto indices_shp = indices->shape();
119 MS_EXCEPTION_IF_NULL(indices_shp);
120
121 ShapeVector shape(indices_shp->shape().size(), -1);
122
123 auto cache_idx = std::make_shared<AbstractTensor>(hash_map->element(), indices->shape());
124 auto old_emb_idx = std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape));
125 auto miss_emb_idx = std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape));
126 auto swap_emb_idx = std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape));
127
128 AbstractBasePtrList elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
129 return std::make_shared<AbstractTuple>(elements);
130 }
131
InferImplCacheSwapTable(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)132 AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
133 const AbstractBasePtrList &args_abs_list) {
134 const std::string op_name = primitive->name();
135 const size_t size_expected = 3;
136 CheckArgsSize(op_name, args_abs_list, size_expected);
137 auto cache_table = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
138 auto cache_table_shp = cache_table->shape();
139 MS_EXCEPTION_IF_NULL(cache_table_shp);
140
141 auto swap_cache_idx = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
142 auto swap_cache_idx_shp = swap_cache_idx->shape();
143 MS_EXCEPTION_IF_NULL(swap_cache_idx_shp);
144
145 auto cache_table_shape = cache_table_shp->shape();
146 auto swap_cache_idx_shape = swap_cache_idx_shp->shape();
147 ShapeVector shape;
148 shape.emplace_back(swap_cache_idx_shape[0]);
149 shape.emplace_back(cache_table_shape[1]);
150
151 AbstractTensorPtr ret = std::make_shared<AbstractTensor>(cache_table->element(), std::make_shared<Shape>(shape));
152 return ret;
153 }
154
InferImplSubAndFilter(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)155 AbstractBasePtr InferImplSubAndFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
156 const AbstractBasePtrList &args_abs_list) {
157 const std::string op_name = primitive->name();
158 auto input_x = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
159 auto input_x_shp = input_x->shape();
160 MS_EXCEPTION_IF_NULL(input_x_shp);
161
162 ShapeVector shape(input_x_shp->shape().size(), -1);
163
164 auto filter_res = std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape));
165 auto filter_idx = std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape));
166 AbstractBasePtrList elements = {filter_res, filter_idx};
167 return std::make_shared<AbstractTuple>(elements);
168 }
169
InferImplDiv(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)170 AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
171 const AbstractBasePtrList &args_abs_list) {
172 const std::string op_name = primitive->name();
173 const size_t size_expected = 2;
174 CheckArgsSize(op_name, args_abs_list, size_expected);
175 auto x = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
176 auto y = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
177 MS_EXCEPTION_IF_NULL(x);
178 MS_EXCEPTION_IF_NULL(x->shape());
179 MS_EXCEPTION_IF_NULL(y);
180 MS_EXCEPTION_IF_NULL(y->shape());
181 ShapeVector x_shape = x->shape()->shape();
182 ShapeVector y_shape = y->shape()->shape();
183 ShapeVector out_shape = BroadcastShape(x_shape, y_shape);
184 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(out_shape));
185 }
186
InferImplRealInnerDiv(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)187 AbstractBasePtr InferImplRealInnerDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
188 const AbstractBasePtrList &args_abs_list) {
189 const std::string op_name = primitive->name();
190 const size_t size_expected = 2;
191 CheckArgsSize(op_name, args_abs_list, size_expected);
192 auto x = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
193 auto y = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
194 MS_EXCEPTION_IF_NULL(x);
195 MS_EXCEPTION_IF_NULL(x->shape());
196 MS_EXCEPTION_IF_NULL(y);
197 MS_EXCEPTION_IF_NULL(y->shape());
198 ShapeVector x_shape = x->shape()->shape();
199 ShapeVector y_shape = y->shape()->shape();
200 ShapeVector out_shape = BroadcastShape(x_shape, y_shape);
201 if (out_shape.empty()) {
202 MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_abs_list[0]->ToString() << "," << args_abs_list[1]->ToString();
203 }
204 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(out_shape));
205 }
206
InferImplTranspose(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)207 AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
208 const AbstractBasePtrList &args_abs_list) {
209 const std::string &op_name = primitive->name();
210 AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
211 auto input_shp = input->shape()->shape();
212 ValuePtr perm = primitive->GetAttr("perm");
213 MS_EXCEPTION_IF_NULL(perm);
214 auto perm_val = perm->cast<ValueTuplePtr>();
215 MS_EXCEPTION_IF_NULL(perm_val);
216 auto perm_val_data = perm_val->value();
217 ShapeVector perm_vec;
218 (void)std::transform(std::begin(perm_val_data), std::end(perm_val_data), std::back_inserter(perm_vec),
219 [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
220 ShapeVector result_shp;
221 for (size_t i = 0; i < perm_vec.size(); i++) {
222 auto idx = static_cast<size_t>(perm_vec[i]);
223 result_shp.push_back(input_shp[idx]);
224 }
225 return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp));
226 }
227
InferImplMapUniform(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)228 AbstractBasePtr InferImplMapUniform(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
229 const AbstractBasePtrList &args_abs_list) {
230 // Inputs: one tensor.
231 const std::string op_name = primitive->name();
232 const size_t size_expected = 3;
233 CheckArgsSize(op_name, args_abs_list, size_expected);
234 return args_abs_list[0]->Broaden();
235 }
236
InferImplSequenceMask(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)237 AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
238 const AbstractBasePtrList &args_abs_list) {
239 const std::string &op_name = primitive->name();
240 const size_t size_expected = 2;
241 CheckArgsSize(op_name, args_abs_list, size_expected);
242
243 AbstractTensorPtr lengths = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
244 (void)CheckTensorDType(lengths, {kInt32, kInt64}, "Input 1 (lengths) for SequenceMask should be one of: %s");
245
246 int64_t maxlen_value = 0;
247
248 if (args_abs_list[1]->isa<AbstractScalar>()) {
249 AbstractScalarPtr maxlen = CheckArg<AbstractScalar>(op_name, args_abs_list, 1);
250 (void)CheckScalarType(maxlen, {kInt32, kInt64}, "Input 0 (maxlen) for SequenceMask should be one of: %s");
251
252 TypePtr maxlen_type = nullptr;
253 maxlen_type = maxlen->GetTypeTrack();
254 MS_EXCEPTION_IF_NULL(maxlen_type);
255
256 if (maxlen_type->type_id() == TypeId::kNumberTypeInt32) {
257 maxlen_value = static_cast<int64_t>(GetValue<int32_t>(maxlen->BuildValue()));
258 } else if (maxlen_type->type_id() == TypeId::kNumberTypeInt64) {
259 maxlen_value = GetValue<int64_t>(maxlen->BuildValue());
260 }
261 } else if (args_abs_list[1]->isa<AbstractTensor>()) {
262 auto maxlen_tensor_ptr = args_abs_list[1]->cast<AbstractTensorPtr>();
263 MS_EXCEPTION_IF_NULL(maxlen_tensor_ptr);
264 auto maxlen_value_ptr = maxlen_tensor_ptr->BuildValue();
265 MS_EXCEPTION_IF_NULL(maxlen_value_ptr);
266 auto maxlen_tensor = maxlen_value_ptr->cast<tensor::TensorPtr>();
267 MS_EXCEPTION_IF_NULL(maxlen_tensor);
268 maxlen_value = *static_cast<int64_t *>(maxlen_tensor->data_c());
269 }
270
271 if (maxlen_value <= 0) {
272 MS_LOG(EXCEPTION) << "maxlen must be positive, but got: " << maxlen_value;
273 }
274
275 ShapeVector lengths_shape = lengths->shape()->shape();
276 lengths_shape.push_back(maxlen_value);
277 ShapePtr output_shape = std::make_shared<Shape>(lengths_shape);
278 return std::make_shared<AbstractTensor>(kBool, output_shape);
279 }
280
281 // Helper struct for FlattenConcat infer.
282 struct ChunkInfo {
283 size_t bytes{0}; // number of bytes.
284 size_t size{0}; // number of elements.
285 };
286
287 using ChunkMap = std::map<TypeId, std::vector<ChunkInfo>>;
288
289 // Group inputs by data type and fusion size.
GroupingAbstractTensors(const AbstractBasePtrList & elements,size_t fusion_size,const std::string & prim_name)290 static ChunkMap GroupingAbstractTensors(const AbstractBasePtrList &elements, size_t fusion_size,
291 const std::string &prim_name) {
292 ChunkMap chunk_map;
293 for (auto &element : elements) {
294 auto abs_tensor = dyn_cast<abstract::AbstractTensor>(element);
295 if (abs_tensor == nullptr) {
296 MS_LOG(EXCEPTION) << "The input element for '" << prim_name << "' should be Tensor, but got "
297 << element->type_name() << ".";
298 }
299 // Calculate data size (number of elements) by shape.
300 auto base_shape = abs_tensor->GetShape();
301 MS_EXCEPTION_IF_NULL(base_shape);
302 auto shape = base_shape->cast<ShapePtr>();
303 if (shape == nullptr) {
304 MS_LOG(EXCEPTION) << "The input tensors for '" << prim_name << "' should have shape, but got "
305 << base_shape->ToString() << ".";
306 }
307 auto data_size = SizeOf(shape->shape());
308 if (data_size == 0) {
309 MS_LOG(EXCEPTION) << "The input tensors for '" << prim_name << "'should have static shape, but got "
310 << shape->ToString() << ".";
311 }
312 // Find data type from the AbstractTensor.
313 const auto &element_abs = abs_tensor->element();
314 MS_EXCEPTION_IF_NULL(element_abs);
315 auto dtype = element_abs->BuildType();
316 MS_EXCEPTION_IF_NULL(dtype);
317 const auto type_id = dtype->type_id();
318 const auto data_bytes = data_size * abstract::TypeIdSize(type_id);
319 if (fusion_size != 0 && fusion_size < data_bytes) {
320 MS_LOG(EXCEPTION) << "Fusion size " << fusion_size << " is too small for a tensor size " << data_bytes << ".";
321 }
322 // Group them by data type and fusion size.
323 auto &chunks = chunk_map[type_id];
324 if (chunks.empty()) {
325 (void)chunks.emplace_back();
326 }
327 if (fusion_size != 0 && chunks.back().bytes + data_bytes > fusion_size) {
328 (void)chunks.emplace_back();
329 }
330 auto &chunk = chunks.back();
331 chunk.bytes += data_bytes;
332 chunk.size += data_size;
333 }
334 return chunk_map;
335 }
336
InferImplFlattenConcat(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)337 AbstractBasePtr InferImplFlattenConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
338 const AbstractBasePtrList &args_abs_list) {
339 CheckArgsSize(primitive->name(), args_abs_list, 1);
340 auto seq = dyn_cast<abstract::AbstractSequence>(args_abs_list[0]);
341 if (seq == nullptr) {
342 MS_LOG(EXCEPTION) << "The input for '" << primitive->name() << "' should be tuple or list, but got "
343 << args_abs_list[0]->type_name();
344 }
345 // Get fusion size from primitive attribute.
346 const auto fusion_size_attr = primitive->GetAttr("fusion_size");
347 const size_t fusion_size = static_cast<size_t>(fusion_size_attr != nullptr ? GetValue<int64_t>(fusion_size_attr) : 0);
348 // Group inputs by data type and fusion size.
349 auto chunk_map = GroupingAbstractTensors(seq->elements(), fusion_size, primitive->name());
350 // Make result AbstractTuple according to the grouping result.
351 AbstractBasePtrList tuple_element;
352 for (auto &entry : chunk_map) {
353 auto dtype = TypeIdToType(entry.first);
354 for (auto &chunk : entry.second) {
355 ShapeVector shape_vec{static_cast<int64_t>(chunk.size)};
356 auto abs = std::make_shared<abstract::AbstractTensor>(dtype, shape_vec);
357 (void)tuple_element.emplace_back(abs);
358 }
359 }
360 return std::make_shared<abstract::AbstractTuple>(std::move(tuple_element));
361 }
362 } // namespace abstract
363 } // namespace mindspore
364