• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CORE_SYMBOLIC_SHAPE_UTILS_H_
17 #define MINDSPORE_CORE_SYMBOLIC_SHAPE_UTILS_H_
18 
19 #include <vector>
20 #include <string>
21 #include <set>
22 #include "mindspore/core/symbolic_shape/symbol.h"
23 #include "mindspore/core/abstract/abstract_value.h"
24 
25 namespace mindspore {
26 namespace symshape {
27 /// \brief Build constant symbolic value.
28 MS_CORE_API SymbolPtr ConstValueToSymbol(const ValuePtr &v, bool to_scalar = false);
29 
30 /// \brief Build symbolic value.
31 /// If the abstract's value is ValueAny, the variable value list is generated according to the shape.
32 MS_CORE_API SymbolPtr BuildSymbolicValue(const AbstractBasePtr &abstract);
33 
34 // symbol to ShapeVector
35 MS_CORE_API ShapeVector ToShape(const Symbol *symbol);
ToShape(const SymbolPtr & symbol)36 inline ShapeVector ToShape(const SymbolPtr &symbol) { return ToShape(symbol.get()); }
37 
38 MS_CORE_API SymbolPtr ShapeVector2Symbol(const ShapeVector &shape, const OpPtr &op = nullptr);
39 
40 MS_CORE_API SymbolPtr IntValues2Symbol(const std::vector<int64_t> &shape, const OpPtr &op = nullptr);
41 
42 // get int value from symbol
43 MS_CORE_API int64_t AsInt(const Symbol *s);
AsInt(const SymbolPtr & s)44 inline int64_t AsInt(const SymbolPtr &s) { return AsInt(s.get()); }
45 
46 // get bool value from symbol
AsBool(const Symbol * s)47 inline bool AsBool(const Symbol *s) { return s->as<BoolSymbol>()->value(); }
AsBool(const SymbolPtr & s)48 inline bool AsBool(const SymbolPtr &s) { return AsBool(s.get()); }
49 
/// \brief Normalize a single axis index against `rank`.
/// A negative `axis` counts from the end, so the result is axis + rank.
/// A non-negative `axis` is returned unchanged.
/// No range validation is performed: values outside [-rank, rank) pass
/// through without error.
inline int64_t NormAxis(int64_t axis, size_t rank) {
  if (axis < 0) {
    const int64_t signed_rank = static_cast<int64_t>(rank);
    return axis + signed_rank;
  }
  return axis;
}
51 MS_CORE_API std::set<int64_t> NormAxis(const ListSymbol *axis, size_t rank);
52 
53 MS_CORE_API std::string SymbolListToStr(const SymbolPtrList &slist, const std::string &pre, const std::string &post,
54                                         bool raw_str = false);
55 
56 MS_CORE_API BaseShapePtr QueryShape(const AbstractBasePtr &abs);
57 MS_CORE_API ValuePtr QueryValue(const AbstractBasePtr &abs);
58 }  // namespace symshape
59 }  // namespace mindspore
60 #endif  // MINDSPORE_CORE_SYMBOLIC_SHAPE_UTILS_H_
61