/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_SHAPE_UTILS_INFO_H_
#define MINDSPORE_SHAPE_UTILS_INFO_H_

#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
#include "abstract/dshape.h"
#include "utils/log_adapter.h"

namespace mindspore {
// Joins the dimensions of a shape into a comma-separated string, e.g. "2, 3, 4".
inline std::string ShapeVectorToString(const ShapeVector &shape) {
  std::string str_shape = "";
  for (auto &item : shape) {
    str_shape += std::to_string(item) + ", ";
  }
  str_shape = str_shape.length() >= 2 ? str_shape.substr(0, str_shape.length() - 2) : str_shape;
  return str_shape;
}

// Returns the number of elements implied by `shape`, or 0 if any dimension is non-positive.
inline size_t SizeOf(const ShapeVector &shape) {
  size_t data_size = 1;
  for (auto dim : shape) {
    if (dim <= 0) {
      // For a dynamic shape with negative dimensions, the data size is defined to be zero.
      return 0;
    }
    if (SIZE_MAX / dim < data_size) {
      MS_EXCEPTION(ValueError) << "The product value of shape (" << ShapeVectorToString(shape)
                               << ") exceeds the maximum value of size_t: " << SIZE_MAX;
    }
    data_size *= static_cast<size_t>(dim);
  }
  return data_size;
}

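// Usage sketch (illustrative only): a fully static shape yields its element count,
// while any non-positive dimension yields 0.
//   SizeOf({2, 3, 4});   // 24
//   SizeOf({2, -1, 4});  // 0, the shape is dynamic
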
// Returns true if the shape describes a single element, i.e. a scalar shape {} or the
// one-dimensional shape {1}.
inline bool IsOneElementShape(const ShapeVector &shape) {
  if (shape.empty()) {
    return true;
  } else if (shape.size() == 1 && shape[0] == 1) {
    return true;
  } else {
    return false;
  }
}

// Returns true if two shapes are considered matched for value inference: either both
// describe a single element or they are identical.
inline bool IsMactchedShapeInferValue(const ShapeVector &shape1, const ShapeVector &shape2) {
  if (IsOneElementShape(shape1) && IsOneElementShape(shape2)) {
    return true;
  }
  if (shape1 == shape2) {
    return true;
  }
  return false;
}

// Returns true if the shape has a dynamic (unknown) rank, i.e. it contains the
// kShapeRankAny sentinel; throws if the sentinel appears in an invalid position.
inline bool IsDynamicRank(const ShapeVector &shape) {
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] > abstract::Shape::kShapeRankAny) {
      continue;
    }

    if (shape.size() == abstract::Shape::kDynamicRankLen) {
      return true;
    } else if (i == 1) {
      MS_LOG(DEBUG) << "Shape (" << ShapeVectorToString(shape) << ") is a valid shape for a real tuple tensor.";
      return true;
    } else {
      MS_EXCEPTION(ValueError) << "Shape should have only one -2 for a normal tensor, or [not -2, -2] for a real "
                                  "tuple tensor, or no -2 at all, but got ("
                               << ShapeVectorToString(shape) << ").";
    }
  }

  return false;
}

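// Illustrative calls, assuming the usual sentinel values kShapeRankAny == -2 and
// kDynamicRankLen == 1 (both defined outside this header):
//   IsDynamicRank({-2});     // true, tensor of unknown rank
//   IsDynamicRank({3, -2});  // true, treated as a real tuple tensor
//   IsDynamicRank({3, 4});   // false
//   IsDynamicRank({-2, 3});  // throws ValueError
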
// Returns true if any dimension of the shape is the kShapeDimAny sentinel.
inline bool IsDynamicShape(const ShapeVector &shape) {
  return std::any_of(shape.cbegin(), shape.cend(),
                     [](ShapeValueDType s) { return s == abstract::Shape::kShapeDimAny; });
}

// Returns true if the shape is dynamic in either sense (unknown dimension or unknown
// rank); throws if it contains a value below kShapeRankAny.
inline bool IsDynamic(const ShapeVector &shape) {
  for (auto &s : shape) {
    if (s > abstract::Shape::kShapeDimAny) {
      continue;
    }

    if (s < abstract::Shape::kShapeRankAny) {
      MS_EXCEPTION(ValueError) << "Shape should not have values less than -2 but got (" << ShapeVectorToString(shape)
                               << ").";
    }

    return true;
  }

  return false;
}

// Returns true if the shape is exactly {0}, i.e. a one-dimensional tensor with zero elements.
inline bool IsShapeEmpty(const ShapeVector &shape) {
  constexpr size_t kOne = 1;
  constexpr size_t kZero = 0;
  return shape.size() == kOne && shape[0] == kZero;
}

// Returns true if any dimension of the shape is 0, i.e. the tensor holds no data.
inline bool IsShapeNone(const ShapeVector &shape) {
  return std::any_of(shape.begin(), shape.end(), [](const auto &dim) { return dim == 0; });
}

// Used for ops constrained to produce an output shape identical to their input shapes.
inline ShapeVector InferOutShapeSameAsInShape(const ShapeArray &input_shapes) {
  ShapeVector out_shape{};
  for (size_t i = 0; i < input_shapes.size(); i++) {
    auto in_shape = input_shapes[i];
    // Scalar case.
    if (in_shape.empty()) {
      return out_shape;
    }
    // Skip to the next input shape if the current shape is dynamic rank.
    if (IsDynamicRank(in_shape)) {
      continue;
    }
    // Initialize the output shape.
    auto rank = in_shape.size();
    if (out_shape.empty()) {
      out_shape.resize(rank, abstract::Shape::kShapeDimAny);
    }
    if (out_shape.size() != rank) {
      MS_EXCEPTION(ValueError) << "Ranks of inputs must all be the same if they are not dynamic.";
    }
    for (size_t j = 0; j < rank; j++) {
      if (out_shape[j] != abstract::Shape::kShapeDimAny && in_shape[j] != abstract::Shape::kShapeDimAny &&
          out_shape[j] != in_shape[j]) {
        MS_EXCEPTION(ValueError) << "Corresponding axes of input shapes must be the same if they are not dynamic.";
      }
      if (out_shape[j] == abstract::Shape::kShapeDimAny && in_shape[j] != abstract::Shape::kShapeDimAny) {
        out_shape[j] = in_shape[j];
      }
    }
  }
  // If all input shapes are dynamic rank, return a dynamic-rank output.
  if (out_shape.empty()) {
    return {abstract::Shape::kShapeRankAny};
  }
  return out_shape;
}

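// A merge example, for illustration, assuming kShapeDimAny == -1 and
// kShapeRankAny == -2 (defined outside this header):
//   InferOutShapeSameAsInShape({{-1, 3}, {2, -1}});  // -> {2, 3}
//   InferOutShapeSameAsInShape({{-2}, {-2}});        // -> {-2}, every input is dynamic rank
//   InferOutShapeSameAsInShape({{2, 3}, {2, 4}});    // throws ValueError
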
// Formats a vector of printable values as "[v0, v1, ..., vn]".
template <typename T>
std::string VectorToString(const std::vector<T> &values) {
  std::stringstream ss;
  ss << "[";
  auto size = values.size();
  for (size_t i = 0; i < size; ++i) {
    ss << values[i];
    if (i != size - 1) {
      ss << ", ";
    }
  }
  ss << "]";
  return ss.str();
}
}  // namespace mindspore

#endif  // MINDSPORE_SHAPE_UTILS_INFO_H_