• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "mindspore/core/symbolic_shape/utils.h"
17 #include <algorithm>
18 #include <utility>
19 #include <memory>
20 #include "ir/kernel_tensor_value.h"
21 #include "mindspore/core/utils/check_convert_utils.h"
22 #include "mindspore/core/ops/op_utils.h"
23 #include "mindspore/core/symbolic_shape/int_symbol.h"
24 
25 namespace mindspore {
26 namespace symshape {
27 namespace {
GenValueByTensorShape(const ShapeVector & shape,const TypePtr & type_ptr)28 SymbolPtr GenValueByTensorShape(const ShapeVector &shape, const TypePtr &type_ptr) {
29   if (IsDynamic(shape)) {
30     return ListSymbol::Make();
31   }
32   if (shape.size() > 1) {
33     MS_LOG(WARNING) << "Symbolic value only support 0-D or 1-D value, but got the shape: " << shape;
34     return ListSymbol::Make();
35   }
36   if (shape.size() == 0) {
37     if (type_ptr->generic_type_id() == kNumberTypeBool) {
38       return BoolSymbol::Make();
39     }
40     if (type_ptr->generic_type_id() == kNumberTypeFloat) {
41       return FloatSymbol::Make();
42     }
43     return IntSymbol::Make();
44   }
45   SymbolPtrList list(LongToSize(shape[0]));
46   if (type_ptr->generic_type_id() == kNumberTypeBool) {
47     std::generate(list.begin(), list.end(), []() { return BoolSymbol::Make(); });
48   } else if (type_ptr->generic_type_id() == kNumberTypeFloat) {
49     std::generate(list.begin(), list.end(), []() { return FloatSymbol::Make(); });
50   } else {
51     std::generate(list.begin(), list.end(), []() { return IntSymbol::Make(); });
52   }
53   return ListSymbol::Make(std::move(list));
54 }
55 
GenValueByShape(const BaseShapePtr & baseshape,const TypePtr & type_ptr)56 SymbolPtr GenValueByShape(const BaseShapePtr &baseshape, const TypePtr &type_ptr) {
57   if (baseshape->isa<abstract::NoShape>()) {
58     return GenValueByTensorShape({}, type_ptr);
59   }
60   if (baseshape->isa<abstract::TensorShape>()) {
61     auto tensor_type = type_ptr->cast<TensorTypePtr>();
62     MS_EXCEPTION_IF_NULL(tensor_type);
63     return GenValueByTensorShape(baseshape->cast<abstract::TensorShapePtr>()->shape(), tensor_type->element());
64   }
65   if (baseshape->isa<abstract::DynamicSequenceShape>()) {
66     return ListSymbol::Make();
67   }
68   auto seq_shape = baseshape->cast<abstract::SequenceShapePtr>();
69   MS_EXCEPTION_IF_NULL(seq_shape);
70   SymbolPtrList result(seq_shape->size());
71   auto seq_type = type_ptr->cast<TuplePtr>();
72   MS_EXCEPTION_IF_NULL(seq_type);
73   if (seq_shape->size() != seq_type->size()) {
74     MS_LOG(INTERNAL_EXCEPTION) << "The size of seq_shape and seq_type should equal, but got " << seq_shape->size()
75                                << " vs " << seq_type->size();
76   }
77   (void)std::transform(
78     seq_shape->shape().begin(), seq_shape->shape().end(), seq_type->elements().begin(), result.begin(),
79     [](const BaseShapePtr &shape_elm, const TypePtr &type_elm) { return GenValueByShape(shape_elm, type_elm); });
80   return ListSymbol::Make(std::move(result));
81 }
82 }  // namespace
83 
KernelTensorValueToSymbol(const ValuePtr & v,bool to_scalar)84 SymbolPtr KernelTensorValueToSymbol(const ValuePtr &v, bool to_scalar) {
85   auto type_ptr = v->type();
86   if (type_ptr == nullptr) {
87     MS_LOG(WARNING) << "type of KernelTensorPtr is null! trying getting Tuple Int";
88     auto value = CheckAndConvertUtils::CheckTupleInt(v->ToString(), v, "ConstSymbolicValue");
89     return IntValues2Symbol(value);
90   }
91   if (type_ptr->type_id() == kNumberTypeBool) {
92     return BoolSymbol::Make(ops::GetScalarValue<bool>(v).value());
93   }
94   if (type_ptr->type_id() == kObjectTypeString) {
95     return StrSymbol::Make(ops::GetScalarValue<std::string>(v).value());
96   }
97   if (type_ptr->type_id() == kNumberTypeInt64) {
98     return IntSymbol::Make(ops::GetScalarValue<int64_t>(v).value());
99   }
100   if (type_ptr->type_id() == kNumberTypeInt32) {
101     return IntSymbol::Make(static_cast<int64_t>(ops::GetScalarValue<int32_t>(v).value()));
102   }
103   auto value_opt = ops::GetArrayValue<int64_t>(v);
104   if (value_opt.has_value()) {
105     auto vec = value_opt.value().ToVector();
106     if (to_scalar && !vec.empty()) {
107       return IntSymbol::Make(vec[0]);
108     }
109     return IntValues2Symbol(vec);
110   }
111   MS_LOG(INTERNAL_EXCEPTION) << "Unsupported KernelTensorValue to Symbol: " << type_ptr->ToString();
112 }
113 
ConstValueToSymbol(const ValuePtr & v,bool to_scalar)114 SymbolPtr ConstValueToSymbol(const ValuePtr &v, bool to_scalar) {
115   if (v->isa<KernelTensorValue>()) {
116     return KernelTensorValueToSymbol(v, to_scalar);
117   }
118   if (v->isa<ValueSequence>()) {
119     auto seq = v->cast_ptr<ValueSequence>();
120     MS_EXCEPTION_IF_NULL(seq);
121     SymbolPtrList result(seq->size());
122     (void)std::transform(seq->value().begin(), seq->value().end(), result.begin(),
123                          [to_scalar](const ValuePtr &v) { return ConstValueToSymbol(v, to_scalar); });
124     return ListSymbol::Make(std::move(result));
125   }
126   if (v->isa<tensor::Tensor>()) {
127     auto tensor_value = CheckAndConvertUtils::CheckTensorIntValue(v->ToString(), v, "ConstSymbolicValue");
128     auto tensor = v->cast_ptr<tensor::Tensor>();
129     return tensor->shape().empty() ? IntSymbol::Make(tensor_value[0]) : IntValues2Symbol(tensor_value);
130   }
131   if (v->isa<IntegerImm>()) {
132     return IntSymbol::Make(v->isa<Int64Imm>() ? GetValue<int64_t>(v) : static_cast<int64_t>(GetValue<int32_t>(v)));
133   }
134   if (v->isa<BoolImm>()) {
135     return BoolSymbol::Make(GetValue<bool>(v));
136   }
137   if (v->isa<FloatImm>()) {
138     return FloatSymbol::Make(v->isa<FP64Imm>() ? GetValue<double>(v) : static_cast<double>(GetValue<float>(v)));
139   }
140   if (v->isa<StringImm>()) {
141     return StrSymbol::Make(GetValue<std::string>(v));
142   }
143   MS_LOG(EXCEPTION)
144     << "Value should be one of {ValueSequence, Tensor, IntegerImm, BoolImm, FloatImm, StringImm}, but got "
145     << v->ToString();
146   return nullptr;
147 }
148 
BuildSymbolicValue(const AbstractBasePtr & abstract)149 SymbolPtr BuildSymbolicValue(const AbstractBasePtr &abstract) {
150   auto value_ptr = abstract->GetValue();
151   if (value_ptr->isa<ValueAny>()) {
152     return GenValueByShape(abstract->GetShape(), abstract->GetType());
153   }
154   auto shape = abstract->GetShape();
155   if (shape->isa<abstract::TensorShape>() && shape->cast_ptr<abstract::TensorShape>()->shape().empty()) {
156     return ConstValueToSymbol(value_ptr, true);
157   }
158   return ConstValueToSymbol(value_ptr, false);
159 }
160 
ToShape(const Symbol * symbol)161 ShapeVector ToShape(const Symbol *symbol) {
162   if (!symbol->HasData()) {
163     return {abstract::Shape::kShapeRankAny};
164   }
165   auto *list = symbol->as<ListSymbol>();
166   MS_EXCEPTION_IF_NULL(list);
167   ShapeVector shape(list->size());
168   (void)std::transform(list->symbols().cbegin(), list->symbols().cend(), shape.begin(), [](const SymbolPtr &s) {
169     auto int_smbl = s->as<IntSymbol>();
170     MS_EXCEPTION_IF_NULL(int_smbl);
171     if (!int_smbl->HasData()) {
172       return abstract::Shape::kShapeDimAny;
173     }
174     return int_smbl->value();
175   });
176   return shape;
177 }
178 
ShapeVector2Symbol(const ShapeVector & shape,const OpPtr & op)179 SymbolPtr ShapeVector2Symbol(const ShapeVector &shape, const OpPtr &op) {
180   if (IsDynamicRank(shape)) {
181     return ListSymbol::Make(op);
182   }
183   SymbolPtrList result(shape.size());
184   (void)std::transform(shape.begin(), shape.end(), result.begin(), [op](int64_t s) {
185     if (s == abstract::Shape::kShapeDimAny) {
186       return IntSymbol::Make(op);
187     } else {
188       return IntSymbol::Make(s, op);
189     }
190   });
191   return ListSymbol::Make(std::move(result), op);
192 }
193 
IntValues2Symbol(const std::vector<int64_t> & shape,const OpPtr & op)194 SymbolPtr IntValues2Symbol(const std::vector<int64_t> &shape, const OpPtr &op) {
195   SymbolPtrList result(shape.size());
196   (void)std::transform(shape.begin(), shape.end(), result.begin(), [op](int64_t s) { return IntSymbol::Make(s, op); });
197   return ListSymbol::Make(std::move(result), op);
198 }
199 
AsInt(const Symbol * s)200 int64_t AsInt(const Symbol *s) { return s->as<IntSymbol>()->value(); }
201 
NormAxis(const ListSymbol * axis,size_t rank)202 std::set<int64_t> NormAxis(const ListSymbol *axis, size_t rank) {
203   std::set<int64_t> result;
204   for (auto &item : axis->symbols()) {
205     result.insert(NormAxis(AsInt(item), rank));
206   }
207   return result;
208 }
209 
SymbolListToStr(const SymbolPtrList & slist,const std::string & pre,const std::string & post,bool raw_str)210 std::string SymbolListToStr(const SymbolPtrList &slist, const std::string &pre, const std::string &post, bool raw_str) {
211   std::ostringstream oss;
212   oss << pre;
213   bool first = true;
214   for (auto &s : slist) {
215     if (first) {
216       first = false;
217     } else {
218       oss << ", ";
219     }
220     oss << (raw_str ? s->ToRawString() : s->ToString());
221   }
222   oss << post;
223   return oss.str();
224 }
225 
QueryShape(const AbstractBasePtr & abs)226 BaseShapePtr QueryShape(const AbstractBasePtr &abs) {
227   MS_EXCEPTION_IF_NULL(abs);
228   auto symbolic_shape = abs->GetSymbolicShape();
229   if (symbolic_shape == nullptr) {
230     return nullptr;
231   }
232   auto digital_shape = abs->GetShape();
233   MS_EXCEPTION_IF_NULL(digital_shape);
234   if (!symbolic_shape->HasData()) {
235     return digital_shape;
236   }
237   if (digital_shape->isa<abstract::NoShape>()) {
238     return digital_shape;
239   }
240   if (digital_shape->isa<abstract::TensorShape>()) {
241     return std::make_shared<abstract::TensorShape>(ToShape(symbolic_shape));
242   }
243   abstract::BaseShapePtrList shape_arr;
244   shape_arr.reserve(symbolic_shape->size());
245   (void)std::transform(
246     symbolic_shape->symbols().begin(), symbolic_shape->symbols().end(), std::back_inserter(shape_arr),
247     [](const SymbolPtr &s) { return std::make_shared<abstract::TensorShape>(ToShape(s->as<ListSymbol>())); });
248   return std::make_shared<abstract::TupleShape>(std::move(shape_arr));
249 }
250 
QueryValue(const AbstractBasePtr & abs)251 ValuePtr QueryValue(const AbstractBasePtr &abs) {
252   MS_EXCEPTION_IF_NULL(abs);
253   auto symbolic_value = abs->GetSymbolicValue();
254   if (symbolic_value == nullptr) {
255     auto value = abs->GetValue();
256     return value != nullptr ? value : kValueAny;
257   }
258   return symbolic_value->ToValue();
259 }
260 }  // namespace symshape
261 }  // namespace mindspore
262