1 /**
2 * Copyright 2023-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "mindspore/core/symbolic_shape/utils.h"
#include <algorithm>
#include <limits>
#include <utility>
#include <memory>
#include "ir/kernel_tensor_value.h"
#include "mindspore/core/utils/check_convert_utils.h"
#include "mindspore/core/ops/op_utils.h"
#include "mindspore/core/symbolic_shape/int_symbol.h"
24
25 namespace mindspore {
26 namespace symshape {
27 namespace {
GenValueByTensorShape(const ShapeVector & shape,const TypePtr & type_ptr)28 SymbolPtr GenValueByTensorShape(const ShapeVector &shape, const TypePtr &type_ptr) {
29 if (IsDynamic(shape)) {
30 return ListSymbol::Make();
31 }
32 if (shape.size() > 1) {
33 MS_LOG(WARNING) << "Symbolic value only support 0-D or 1-D value, but got the shape: " << shape;
34 return ListSymbol::Make();
35 }
36 if (shape.size() == 0) {
37 if (type_ptr->generic_type_id() == kNumberTypeBool) {
38 return BoolSymbol::Make();
39 }
40 if (type_ptr->generic_type_id() == kNumberTypeFloat) {
41 return FloatSymbol::Make();
42 }
43 return IntSymbol::Make();
44 }
45 SymbolPtrList list(LongToSize(shape[0]));
46 if (type_ptr->generic_type_id() == kNumberTypeBool) {
47 std::generate(list.begin(), list.end(), []() { return BoolSymbol::Make(); });
48 } else if (type_ptr->generic_type_id() == kNumberTypeFloat) {
49 std::generate(list.begin(), list.end(), []() { return FloatSymbol::Make(); });
50 } else {
51 std::generate(list.begin(), list.end(), []() { return IntSymbol::Make(); });
52 }
53 return ListSymbol::Make(std::move(list));
54 }
55
GenValueByShape(const BaseShapePtr & baseshape,const TypePtr & type_ptr)56 SymbolPtr GenValueByShape(const BaseShapePtr &baseshape, const TypePtr &type_ptr) {
57 if (baseshape->isa<abstract::NoShape>()) {
58 return GenValueByTensorShape({}, type_ptr);
59 }
60 if (baseshape->isa<abstract::TensorShape>()) {
61 auto tensor_type = type_ptr->cast<TensorTypePtr>();
62 MS_EXCEPTION_IF_NULL(tensor_type);
63 return GenValueByTensorShape(baseshape->cast<abstract::TensorShapePtr>()->shape(), tensor_type->element());
64 }
65 if (baseshape->isa<abstract::DynamicSequenceShape>()) {
66 return ListSymbol::Make();
67 }
68 auto seq_shape = baseshape->cast<abstract::SequenceShapePtr>();
69 MS_EXCEPTION_IF_NULL(seq_shape);
70 SymbolPtrList result(seq_shape->size());
71 auto seq_type = type_ptr->cast<TuplePtr>();
72 MS_EXCEPTION_IF_NULL(seq_type);
73 if (seq_shape->size() != seq_type->size()) {
74 MS_LOG(INTERNAL_EXCEPTION) << "The size of seq_shape and seq_type should equal, but got " << seq_shape->size()
75 << " vs " << seq_type->size();
76 }
77 (void)std::transform(
78 seq_shape->shape().begin(), seq_shape->shape().end(), seq_type->elements().begin(), result.begin(),
79 [](const BaseShapePtr &shape_elm, const TypePtr &type_elm) { return GenValueByShape(shape_elm, type_elm); });
80 return ListSymbol::Make(std::move(result));
81 }
82 } // namespace
83
KernelTensorValueToSymbol(const ValuePtr & v,bool to_scalar)84 SymbolPtr KernelTensorValueToSymbol(const ValuePtr &v, bool to_scalar) {
85 auto type_ptr = v->type();
86 if (type_ptr == nullptr) {
87 MS_LOG(WARNING) << "type of KernelTensorPtr is null! trying getting Tuple Int";
88 auto value = CheckAndConvertUtils::CheckTupleInt(v->ToString(), v, "ConstSymbolicValue");
89 return IntValues2Symbol(value);
90 }
91 if (type_ptr->type_id() == kNumberTypeBool) {
92 return BoolSymbol::Make(ops::GetScalarValue<bool>(v).value());
93 }
94 if (type_ptr->type_id() == kObjectTypeString) {
95 return StrSymbol::Make(ops::GetScalarValue<std::string>(v).value());
96 }
97 if (type_ptr->type_id() == kNumberTypeInt64) {
98 return IntSymbol::Make(ops::GetScalarValue<int64_t>(v).value());
99 }
100 if (type_ptr->type_id() == kNumberTypeInt32) {
101 return IntSymbol::Make(static_cast<int64_t>(ops::GetScalarValue<int32_t>(v).value()));
102 }
103 auto value_opt = ops::GetArrayValue<int64_t>(v);
104 if (value_opt.has_value()) {
105 auto vec = value_opt.value().ToVector();
106 if (to_scalar && !vec.empty()) {
107 return IntSymbol::Make(vec[0]);
108 }
109 return IntValues2Symbol(vec);
110 }
111 MS_LOG(INTERNAL_EXCEPTION) << "Unsupported KernelTensorValue to Symbol: " << type_ptr->ToString();
112 }
113
ConstValueToSymbol(const ValuePtr & v,bool to_scalar)114 SymbolPtr ConstValueToSymbol(const ValuePtr &v, bool to_scalar) {
115 if (v->isa<KernelTensorValue>()) {
116 return KernelTensorValueToSymbol(v, to_scalar);
117 }
118 if (v->isa<ValueSequence>()) {
119 auto seq = v->cast_ptr<ValueSequence>();
120 MS_EXCEPTION_IF_NULL(seq);
121 SymbolPtrList result(seq->size());
122 (void)std::transform(seq->value().begin(), seq->value().end(), result.begin(),
123 [to_scalar](const ValuePtr &v) { return ConstValueToSymbol(v, to_scalar); });
124 return ListSymbol::Make(std::move(result));
125 }
126 if (v->isa<tensor::Tensor>()) {
127 auto tensor_value = CheckAndConvertUtils::CheckTensorIntValue(v->ToString(), v, "ConstSymbolicValue");
128 auto tensor = v->cast_ptr<tensor::Tensor>();
129 return tensor->shape().empty() ? IntSymbol::Make(tensor_value[0]) : IntValues2Symbol(tensor_value);
130 }
131 if (v->isa<IntegerImm>()) {
132 return IntSymbol::Make(v->isa<Int64Imm>() ? GetValue<int64_t>(v) : static_cast<int64_t>(GetValue<int32_t>(v)));
133 }
134 if (v->isa<BoolImm>()) {
135 return BoolSymbol::Make(GetValue<bool>(v));
136 }
137 if (v->isa<FloatImm>()) {
138 return FloatSymbol::Make(v->isa<FP64Imm>() ? GetValue<double>(v) : static_cast<double>(GetValue<float>(v)));
139 }
140 if (v->isa<StringImm>()) {
141 return StrSymbol::Make(GetValue<std::string>(v));
142 }
143 MS_LOG(EXCEPTION)
144 << "Value should be one of {ValueSequence, Tensor, IntegerImm, BoolImm, FloatImm, StringImm}, but got "
145 << v->ToString();
146 return nullptr;
147 }
148
BuildSymbolicValue(const AbstractBasePtr & abstract)149 SymbolPtr BuildSymbolicValue(const AbstractBasePtr &abstract) {
150 auto value_ptr = abstract->GetValue();
151 if (value_ptr->isa<ValueAny>()) {
152 return GenValueByShape(abstract->GetShape(), abstract->GetType());
153 }
154 auto shape = abstract->GetShape();
155 if (shape->isa<abstract::TensorShape>() && shape->cast_ptr<abstract::TensorShape>()->shape().empty()) {
156 return ConstValueToSymbol(value_ptr, true);
157 }
158 return ConstValueToSymbol(value_ptr, false);
159 }
160
ToShape(const Symbol * symbol)161 ShapeVector ToShape(const Symbol *symbol) {
162 if (!symbol->HasData()) {
163 return {abstract::Shape::kShapeRankAny};
164 }
165 auto *list = symbol->as<ListSymbol>();
166 MS_EXCEPTION_IF_NULL(list);
167 ShapeVector shape(list->size());
168 (void)std::transform(list->symbols().cbegin(), list->symbols().cend(), shape.begin(), [](const SymbolPtr &s) {
169 auto int_smbl = s->as<IntSymbol>();
170 MS_EXCEPTION_IF_NULL(int_smbl);
171 if (!int_smbl->HasData()) {
172 return abstract::Shape::kShapeDimAny;
173 }
174 return int_smbl->value();
175 });
176 return shape;
177 }
178
ShapeVector2Symbol(const ShapeVector & shape,const OpPtr & op)179 SymbolPtr ShapeVector2Symbol(const ShapeVector &shape, const OpPtr &op) {
180 if (IsDynamicRank(shape)) {
181 return ListSymbol::Make(op);
182 }
183 SymbolPtrList result(shape.size());
184 (void)std::transform(shape.begin(), shape.end(), result.begin(), [op](int64_t s) {
185 if (s == abstract::Shape::kShapeDimAny) {
186 return IntSymbol::Make(op);
187 } else {
188 return IntSymbol::Make(s, op);
189 }
190 });
191 return ListSymbol::Make(std::move(result), op);
192 }
193
IntValues2Symbol(const std::vector<int64_t> & shape,const OpPtr & op)194 SymbolPtr IntValues2Symbol(const std::vector<int64_t> &shape, const OpPtr &op) {
195 SymbolPtrList result(shape.size());
196 (void)std::transform(shape.begin(), shape.end(), result.begin(), [op](int64_t s) { return IntSymbol::Make(s, op); });
197 return ListSymbol::Make(std::move(result), op);
198 }
199
AsInt(const Symbol * s)200 int64_t AsInt(const Symbol *s) { return s->as<IntSymbol>()->value(); }
201
NormAxis(const ListSymbol * axis,size_t rank)202 std::set<int64_t> NormAxis(const ListSymbol *axis, size_t rank) {
203 std::set<int64_t> result;
204 for (auto &item : axis->symbols()) {
205 result.insert(NormAxis(AsInt(item), rank));
206 }
207 return result;
208 }
209
SymbolListToStr(const SymbolPtrList & slist,const std::string & pre,const std::string & post,bool raw_str)210 std::string SymbolListToStr(const SymbolPtrList &slist, const std::string &pre, const std::string &post, bool raw_str) {
211 std::ostringstream oss;
212 oss << pre;
213 bool first = true;
214 for (auto &s : slist) {
215 if (first) {
216 first = false;
217 } else {
218 oss << ", ";
219 }
220 oss << (raw_str ? s->ToRawString() : s->ToString());
221 }
222 oss << post;
223 return oss.str();
224 }
225
QueryShape(const AbstractBasePtr & abs)226 BaseShapePtr QueryShape(const AbstractBasePtr &abs) {
227 MS_EXCEPTION_IF_NULL(abs);
228 auto symbolic_shape = abs->GetSymbolicShape();
229 if (symbolic_shape == nullptr) {
230 return nullptr;
231 }
232 auto digital_shape = abs->GetShape();
233 MS_EXCEPTION_IF_NULL(digital_shape);
234 if (!symbolic_shape->HasData()) {
235 return digital_shape;
236 }
237 if (digital_shape->isa<abstract::NoShape>()) {
238 return digital_shape;
239 }
240 if (digital_shape->isa<abstract::TensorShape>()) {
241 return std::make_shared<abstract::TensorShape>(ToShape(symbolic_shape));
242 }
243 abstract::BaseShapePtrList shape_arr;
244 shape_arr.reserve(symbolic_shape->size());
245 (void)std::transform(
246 symbolic_shape->symbols().begin(), symbolic_shape->symbols().end(), std::back_inserter(shape_arr),
247 [](const SymbolPtr &s) { return std::make_shared<abstract::TensorShape>(ToShape(s->as<ListSymbol>())); });
248 return std::make_shared<abstract::TupleShape>(std::move(shape_arr));
249 }
250
QueryValue(const AbstractBasePtr & abs)251 ValuePtr QueryValue(const AbstractBasePtr &abs) {
252 MS_EXCEPTION_IF_NULL(abs);
253 auto symbolic_value = abs->GetSymbolicValue();
254 if (symbolic_value == nullptr) {
255 auto value = abs->GetValue();
256 return value != nullptr ? value : kValueAny;
257 }
258 return symbolic_value->ToValue();
259 }
260 } // namespace symshape
261 } // namespace mindspore
262