/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "abstract/abstract_value.h"
#include "abstract/ops/infer_functions.h"
#include "abstract/param_validator.h"
#include "abstract/utils.h"
#include "utils/shape_utils.h"
#include "abstract/dshape.h"
#include "base/base.h"
#include "ir/anf.h"
#include "ir/dtype.h"
#include "ir/dtype/number.h"
#include "ir/dtype/type.h"
#include "ir/primitive.h"
#include "ir/scalar.h"
#include "ir/tensor.h"
#include "ir/value.h"
#include "mindapi/base/shape_vector.h"
#include "mindapi/base/type_id.h"
#include "utils/convert_utils_base.h"
#include "utils/log_adapter.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace abstract {
AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_abs_list) {
  // Inputs: a scalar.
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_abs_list, 1);
  AbstractScalarPtr arg = CheckArg<AbstractScalar>(op_name, args_abs_list, 0);
  return std::make_shared<AbstractTensor>(arg, std::make_shared<Shape>());
}

AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_abs_list) {
  // Inputs: a tensor with an empty (0-dimensional) shape.
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_abs_list, 1);
  auto arg = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
  auto a_shp = arg->shape();
  MS_EXCEPTION_IF_NULL(a_shp);
  if (!a_shp->shape().empty()) {
    MS_LOG(EXCEPTION) << "array_to_scalar requires zero size shape.";
  }
  return arg->element();
}

AbstractBasePtr InferImplBroadcastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_abs_list) {
  // Inputs: two tuples.
  const std::string op_name = primitive->name();
  constexpr size_t args_size = 2;
  CheckArgsSize(op_name, args_abs_list, args_size);
  auto xs = CheckArg<AbstractTuple>(op_name, args_abs_list, 0);
  auto ys = CheckArg<AbstractTuple>(op_name, args_abs_list, 1);
  auto x_value = xs->BuildValue();
  MS_EXCEPTION_IF_NULL(x_value);
  auto value_tuple_x = x_value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple_x);
  auto shp_tuple_x = value_tuple_x->value();
  ShapeVector shp_x;
  (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(shp_x),
                       [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
  auto tuple_value_y = ys->BuildValue();
  MS_EXCEPTION_IF_NULL(tuple_value_y);
  auto value_tuple_y = tuple_value_y->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple_y);
  auto shp_tuple_y = value_tuple_y->value();
  ShapeVector shp_y;
  (void)std::transform(std::begin(shp_tuple_y), std::end(shp_tuple_y), std::back_inserter(shp_y),
                       [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });

  ShapeVector res = BroadcastShape(shp_x, shp_y);
  MS_EXCEPTION_IF_NULL(args_abs_list[1]);
  if (res.empty()) {
    MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_abs_list[0]->ToString() << "," << args_abs_list[1]->ToString();
  }

  AbstractBasePtrList elems;
  (void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int64_t n) -> AbstractBasePtr {
    return std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(n), kInt64);
  });
  return std::make_shared<AbstractTuple>(elems);
}

AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_abs_list) {
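  // Inputs: five arguments; the hash map (arg 0) and indices (arg 1) must be tensors and determine the output shapes.
  // Returns (cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx); cache_idx matches the indices shape, the rest are dynamic.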
  const std::string op_name = primitive->name();
  const size_t size_expected = 5;
  CheckArgsSize(op_name, args_abs_list, size_expected);
  auto hash_map = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
  MS_EXCEPTION_IF_NULL(hash_map->shape());

  auto indices = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
  auto indices_shp = indices->shape();
  MS_EXCEPTION_IF_NULL(indices_shp);

  ShapeVector shape(indices_shp->shape().size(), -1);

  auto cache_idx = std::make_shared<AbstractTensor>(hash_map->element(), indices->shape());
  auto old_emb_idx = std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape));
  auto miss_emb_idx = std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape));
  auto swap_emb_idx = std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape));

  AbstractBasePtrList elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
  return std::make_shared<AbstractTuple>(elements);
}

AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_abs_list) {
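  // Inputs: three arguments; the cache table (arg 0) and swap cache indices (arg 1) must be tensors.
  // Output shape is [swap_cache_idx.shape[0], cache_table.shape[1]] with the cache table's element type.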
  const std::string op_name = primitive->name();
  const size_t size_expected = 3;
  CheckArgsSize(op_name, args_abs_list, size_expected);
  auto cache_table = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
  auto cache_table_shp = cache_table->shape();
  MS_EXCEPTION_IF_NULL(cache_table_shp);

  auto swap_cache_idx = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
  auto swap_cache_idx_shp = swap_cache_idx->shape();
  MS_EXCEPTION_IF_NULL(swap_cache_idx_shp);

  auto cache_table_shape = cache_table_shp->shape();
  auto swap_cache_idx_shape = swap_cache_idx_shp->shape();
  ShapeVector shape;
  shape.emplace_back(swap_cache_idx_shape[0]);
  shape.emplace_back(cache_table_shape[1]);

  AbstractTensorPtr ret = std::make_shared<AbstractTensor>(cache_table->element(), std::make_shared<Shape>(shape));
  return ret;
}

AbstractBasePtr InferImplSubAndFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_abs_list) {
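  // Inputs: the first argument must be a tensor.
  // Returns (filter_res, filter_idx), both with dynamic shapes of the same rank as the input.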
  const std::string op_name = primitive->name();
  auto input_x = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
  auto input_x_shp = input_x->shape();
  MS_EXCEPTION_IF_NULL(input_x_shp);

  ShapeVector shape(input_x_shp->shape().size(), -1);

  auto filter_res = std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape));
  auto filter_idx = std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape));
  AbstractBasePtrList elements = {filter_res, filter_idx};
  return std::make_shared<AbstractTuple>(elements);
}

AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const AbstractBasePtrList &args_abs_list) {
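  // Inputs: two tensors; the output shape is the broadcast of their shapes, keeping x's element type.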
  const std::string op_name = primitive->name();
  const size_t size_expected = 2;
  CheckArgsSize(op_name, args_abs_list, size_expected);
  auto x = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
  auto y = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(x->shape());
  MS_EXCEPTION_IF_NULL(y);
  MS_EXCEPTION_IF_NULL(y->shape());
  ShapeVector x_shape = x->shape()->shape();
  ShapeVector y_shape = y->shape()->shape();
  ShapeVector out_shape = BroadcastShape(x_shape, y_shape);
  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(out_shape));
}

AbstractBasePtr InferImplRealInnerDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_abs_list) {
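  // Inputs: two tensors; like Div, but raises an exception when the shapes cannot be broadcast.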
  const std::string op_name = primitive->name();
  const size_t size_expected = 2;
  CheckArgsSize(op_name, args_abs_list, size_expected);
  auto x = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
  auto y = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(x->shape());
  MS_EXCEPTION_IF_NULL(y);
  MS_EXCEPTION_IF_NULL(y->shape());
  ShapeVector x_shape = x->shape()->shape();
  ShapeVector y_shape = y->shape()->shape();
  ShapeVector out_shape = BroadcastShape(x_shape, y_shape);
  if (out_shape.empty()) {
    MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_abs_list[0]->ToString() << "," << args_abs_list[1]->ToString();
  }
  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(out_shape));
}

AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_abs_list) {
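  // Inputs: one tensor; the output shape is the input shape permuted by the 'perm' attribute.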
  const std::string &op_name = primitive->name();
  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
  auto input_shp = input->shape()->shape();
  ValuePtr perm = primitive->GetAttr("perm");
  MS_EXCEPTION_IF_NULL(perm);
  auto perm_val = perm->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(perm_val);
  auto perm_val_data = perm_val->value();
  ShapeVector perm_vec;
  (void)std::transform(std::begin(perm_val_data), std::end(perm_val_data), std::back_inserter(perm_vec),
                       [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
  ShapeVector result_shp;
  for (size_t i = 0; i < perm_vec.size(); i++) {
    auto idx = static_cast<size_t>(perm_vec[i]);
    result_shp.push_back(input_shp[idx]);
  }
  return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp));
}

AbstractBasePtr InferImplMapUniform(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_abs_list) {
  // Inputs: three arguments; the result broadens the first input (a tensor).
  const std::string op_name = primitive->name();
  const size_t size_expected = 3;
  CheckArgsSize(op_name, args_abs_list, size_expected);
  return args_abs_list[0]->Broaden();
}

AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_abs_list) {
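  // Inputs: lengths (an int32/int64 tensor) and maxlen (an int32/int64 scalar or a tensor).
  // Output: a bool tensor whose shape is lengths.shape + [maxlen].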
  const std::string &op_name = primitive->name();
  const size_t size_expected = 2;
  CheckArgsSize(op_name, args_abs_list, size_expected);

  AbstractTensorPtr lengths = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
  (void)CheckTensorDType(lengths, {kInt32, kInt64}, "Input 0 (lengths) for SequenceMask should be one of: %s");

  int64_t maxlen_value = 0;

  if (args_abs_list[1]->isa<AbstractScalar>()) {
    AbstractScalarPtr maxlen = CheckArg<AbstractScalar>(op_name, args_abs_list, 1);
    (void)CheckScalarType(maxlen, {kInt32, kInt64}, "Input 1 (maxlen) for SequenceMask should be one of: %s");

    TypePtr maxlen_type = maxlen->GetTypeTrack();
    MS_EXCEPTION_IF_NULL(maxlen_type);

    if (maxlen_type->type_id() == TypeId::kNumberTypeInt32) {
      maxlen_value = static_cast<int64_t>(GetValue<int32_t>(maxlen->BuildValue()));
    } else if (maxlen_type->type_id() == TypeId::kNumberTypeInt64) {
      maxlen_value = GetValue<int64_t>(maxlen->BuildValue());
    }
  } else if (args_abs_list[1]->isa<AbstractTensor>()) {
    auto maxlen_tensor_ptr = args_abs_list[1]->cast<AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(maxlen_tensor_ptr);
    auto maxlen_value_ptr = maxlen_tensor_ptr->BuildValue();
    MS_EXCEPTION_IF_NULL(maxlen_value_ptr);
    auto maxlen_tensor = maxlen_value_ptr->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(maxlen_tensor);
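    // Read the scalar value from the tensor; the raw data buffer is interpreted as a single int64.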
    maxlen_value = *static_cast<int64_t *>(maxlen_tensor->data_c());
  }

  if (maxlen_value <= 0) {
    MS_LOG(EXCEPTION) << "maxlen must be positive, but got: " << maxlen_value;
  }

  ShapeVector lengths_shape = lengths->shape()->shape();
  lengths_shape.push_back(maxlen_value);
  ShapePtr output_shape = std::make_shared<Shape>(lengths_shape);
  return std::make_shared<AbstractTensor>(kBool, output_shape);
}

// Helper struct for FlattenConcat infer.
struct ChunkInfo {
  size_t bytes{0};  // number of bytes.
  size_t size{0};   // number of elements.
};

using ChunkMap = std::map<TypeId, std::vector<ChunkInfo>>;

// Group inputs by data type and fusion size.
static ChunkMap GroupingAbstractTensors(const AbstractBasePtrList &elements, size_t fusion_size,
                                        const std::string &prim_name) {
  ChunkMap chunk_map;
  for (auto &element : elements) {
    auto abs_tensor = dyn_cast<abstract::AbstractTensor>(element);
    if (abs_tensor == nullptr) {
      MS_LOG(EXCEPTION) << "The input element for '" << prim_name << "' should be Tensor, but got "
                        << element->type_name() << ".";
    }
    // Calculate data size (number of elements) by shape.
    auto base_shape = abs_tensor->GetShape();
    MS_EXCEPTION_IF_NULL(base_shape);
    auto shape = base_shape->cast<ShapePtr>();
    if (shape == nullptr) {
      MS_LOG(EXCEPTION) << "The input tensors for '" << prim_name << "' should have shape, but got "
                        << base_shape->ToString() << ".";
    }
    auto data_size = SizeOf(shape->shape());
    if (data_size == 0) {
      MS_LOG(EXCEPTION) << "The input tensors for '" << prim_name << "' should have static shape, but got "
                        << shape->ToString() << ".";
    }
    // Find data type from the AbstractTensor.
    const auto &element_abs = abs_tensor->element();
    MS_EXCEPTION_IF_NULL(element_abs);
    auto dtype = element_abs->BuildType();
    MS_EXCEPTION_IF_NULL(dtype);
    const auto type_id = dtype->type_id();
    const auto data_bytes = data_size * abstract::TypeIdSize(type_id);
    if (fusion_size != 0 && fusion_size < data_bytes) {
      MS_LOG(EXCEPTION) << "Fusion size " << fusion_size << " is too small for a tensor size " << data_bytes << ".";
    }
    // Group them by data type and fusion size.
    auto &chunks = chunk_map[type_id];
    if (chunks.empty()) {
      (void)chunks.emplace_back();
    }
    if (fusion_size != 0 && chunks.back().bytes + data_bytes > fusion_size) {
      (void)chunks.emplace_back();
    }
    auto &chunk = chunks.back();
    chunk.bytes += data_bytes;
    chunk.size += data_size;
  }
  return chunk_map;
}

AbstractBasePtr InferImplFlattenConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_abs_list) {
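  // Inputs: a tuple or list of tensors with static shapes.
  // Output: a tuple of flattened 1-D tensors, one per (dtype, fusion-size) chunk produced by the grouping above.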
  CheckArgsSize(primitive->name(), args_abs_list, 1);
  auto seq = dyn_cast<abstract::AbstractSequence>(args_abs_list[0]);
  if (seq == nullptr) {
    MS_LOG(EXCEPTION) << "The input for '" << primitive->name() << "' should be tuple or list, but got "
                      << args_abs_list[0]->type_name();
  }
  // Get fusion size from primitive attribute.
  const auto fusion_size_attr = primitive->GetAttr("fusion_size");
  const size_t fusion_size = static_cast<size_t>(fusion_size_attr != nullptr ? GetValue<int64_t>(fusion_size_attr) : 0);
  // Group inputs by data type and fusion size.
  auto chunk_map = GroupingAbstractTensors(seq->elements(), fusion_size, primitive->name());
  // Make result AbstractTuple according to the grouping result.
  AbstractBasePtrList tuple_element;
  for (auto &entry : chunk_map) {
    auto dtype = TypeIdToType(entry.first);
    for (auto &chunk : entry.second) {
      ShapeVector shape_vec{static_cast<int64_t>(chunk.size)};
      auto abs = std::make_shared<abstract::AbstractTensor>(dtype, shape_vec);
      (void)tuple_element.emplace_back(abs);
    }
  }
  return std::make_shared<abstract::AbstractTuple>(std::move(tuple_element));
}
}  // namespace abstract
}  // namespace mindspore