1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ir/dtype/tensor_type.h"
18 #include "utils/log_adapter.h"
19 #include "utils/ms_utils.h"
20
21 namespace mindspore {
DeepCopy() const22 TypePtr UndeterminedType::DeepCopy() const {
23 if (IsGeneric() || element_type_ == nullptr) {
24 return std::make_shared<UndeterminedType>();
25 }
26 return std::make_shared<UndeterminedType>(element_type_->DeepCopy());
27 }
28
ToReprString() const29 std::string UndeterminedType::ToReprString() const {
30 if (element_type_ == nullptr) {
31 return "Undetermined";
32 }
33 return "Undetermined[" + element_type_->ToReprString() + "]";
34 }
35
ToString() const36 std::string UndeterminedType::ToString() const {
37 if (element_type_ == nullptr) {
38 return "Undetermined";
39 }
40 return "Undetermined[" + element_type_->ToString() + "]";
41 }
42
DumpText() const43 std::string UndeterminedType::DumpText() const {
44 if (element_type_ == nullptr) {
45 return "Undetermined";
46 }
47 return "Undetermined[" + element_type_->DumpText() + "]";
48 }
49
operator ==(const Type & other) const50 bool UndeterminedType::operator==(const Type &other) const {
51 if (!IsSameObjectType(*this, other)) {
52 return false;
53 }
54 const auto &other_type = static_cast<const UndeterminedType &>(other);
55 return common::IsEqual(element_type_, other_type.element_type_);
56 }
57
hash() const58 size_t UndeterminedType::hash() const {
59 size_t hash_value = hash_combine(static_cast<size_t>(kMetaTypeObject), static_cast<size_t>(object_type()));
60 if (element_type_ != nullptr) {
61 hash_value = hash_combine(hash_value, element_type_->hash());
62 }
63 return hash_value;
64 }
65
DeepCopy() const66 TypePtr TensorType::DeepCopy() const {
67 if (element_type_ == nullptr) {
68 return std::make_shared<TensorType>();
69 }
70 if (IsGeneric()) {
71 return std::make_shared<TensorType>();
72 }
73 return std::make_shared<TensorType>(element_type_->DeepCopy());
74 }
75
ToReprString() const76 std::string TensorType::ToReprString() const {
77 if (element_type_ == nullptr) {
78 return "tensor";
79 }
80 return "tensor[" + element_type_->ToReprString() + "]";
81 }
82
ToString() const83 std::string TensorType::ToString() const {
84 if (element_type_ == nullptr) {
85 return "Tensor";
86 }
87 return "Tensor[" + element_type_->ToString() + "]";
88 }
89
DumpText() const90 std::string TensorType::DumpText() const {
91 if (element_type_ == nullptr) {
92 return "Tensor";
93 }
94 return "Tensor(" + element_type_->DumpText() + ")";
95 }
96
operator ==(const Type & other) const97 bool TensorType::operator==(const Type &other) const {
98 if (this == &other) {
99 return true;
100 }
101 if (!IsSameObjectType(*this, other)) {
102 return false;
103 }
104 return *this == static_cast<const TensorType &>(other);
105 }
106
operator ==(const TensorType & other) const107 bool TensorType::operator==(const TensorType &other) const {
108 if (other.isa<AnyType>()) {
109 return false;
110 }
111 return common::IsEqual(element_type_, other.element_type_);
112 }
113
hash() const114 size_t TensorType::hash() const {
115 size_t hash_value = hash_combine(static_cast<size_t>(kMetaTypeObject), static_cast<size_t>(object_type()));
116 if (element_type_ != nullptr) {
117 hash_value = hash_combine(hash_value, element_type_->hash());
118 }
119 return hash_value;
120 }
121
ToString() const122 std::string AnyType::ToString() const {
123 if (element() == nullptr) {
124 return "Any(Tensor)";
125 }
126 return "Any(Tensor)[" + element()->ToString() + "]";
127 }
128
DumpText() const129 std::string AnyType::DumpText() const {
130 if (element() == nullptr) {
131 return "Any(Tensor)";
132 }
133 return "Any(Tensor)(" + element()->DumpText() + ")";
134 }
135
operator ==(const Type & other) const136 bool AnyType::operator==(const Type &other) const {
137 if (this == &other) {
138 return true;
139 }
140 if (!other.isa<AnyType>()) {
141 return false;
142 }
143 return *this == static_cast<const AnyType &>(other);
144 }
145
operator ==(const AnyType & other) const146 bool AnyType::operator==(const AnyType &other) const {
147 if (this == &other) {
148 return true;
149 }
150 return common::IsEqual(element(), other.element());
151 }
152
ToString() const153 std::string NegligibleType::ToString() const {
154 if (element() == nullptr) {
155 return "Negligible(Tensor)";
156 }
157 return "Negligible(Tensor)[" + element()->ToString() + "]";
158 }
159
DumpText() const160 std::string NegligibleType::DumpText() const {
161 if (element() == nullptr) {
162 return "Negligible(Tensor)";
163 }
164 return "Negligible(Tensor)(" + element()->DumpText() + ")";
165 }
166
ElementsDtypeStr(const StringType str_type) const167 std::string SparseTensorType::ElementsDtypeStr(const StringType str_type) const {
168 std::ostringstream oss;
169 for (const TypePtr &elem : elements_) {
170 if (str_type == kToString) {
171 oss << elem->ToString();
172 } else if (str_type == kDumpText) {
173 oss << elem->DumpText();
174 } else if (str_type == kReprString) {
175 oss << elem->ToReprString();
176 }
177 oss << ",";
178 }
179 return oss.str();
180 }
181
ToString() const182 std::string SparseTensorType::ToString() const {
183 if (elements_.empty()) {
184 return GetSparseTensorTypeName();
185 }
186 return GetSparseTensorTypeName() + "[" + ElementsDtypeStr(kToString) + "]";
187 }
188
DumpText() const189 std::string SparseTensorType::DumpText() const {
190 if (elements_.empty()) {
191 return GetSparseTensorTypeName();
192 }
193 return GetSparseTensorTypeName() + "[" + ElementsDtypeStr(kDumpText) + "]";
194 }
195
ToReprString() const196 std::string SparseTensorType::ToReprString() const {
197 if (elements_.empty()) {
198 return GetSparseTensorTypeName();
199 }
200 return GetSparseTensorTypeName() + "[" + ElementsDtypeStr(kReprString) + "]";
201 }
202
ElementsClone() const203 const TypePtrList SparseTensorType::ElementsClone() const {
204 TypePtrList elems;
205 (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elems), [](const TypePtr &ele) {
206 MS_EXCEPTION_IF_NULL(ele);
207 return ele->DeepCopy();
208 });
209 return elems;
210 }
211
DeepCopy() const212 TypePtr SparseTensorType::DeepCopy() const {
213 if (IsGeneric()) {
214 return std::make_shared<SparseTensorType>();
215 }
216 return std::make_shared<SparseTensorType>(ElementsClone());
217 }
218
operator [](std::size_t dim) const219 const TypePtr SparseTensorType::operator[](std::size_t dim) const {
220 if (dim >= size()) {
221 MS_LOG(EXCEPTION) << "Index " << dim << " is out range of the SparseTensorType size " << size() << ".";
222 }
223 return elements_[dim];
224 }
225
operator ==(const Type & other) const226 bool SparseTensorType::operator==(const Type &other) const {
227 if (!IsSameObjectType(*this, other)) {
228 return false;
229 }
230 const auto &other_type = static_cast<const SparseTensorType &>(other);
231 return TypeListEqual()(elements_, other_type.elements_);
232 }
233
hash() const234 size_t SparseTensorType::hash() const {
235 size_t hash_value = hash_combine(static_cast<size_t>(kMetaTypeObject), static_cast<size_t>(object_type()));
236 return hash_combine(hash_value, TypeListHasher()(elements_));
237 }
238
DeepCopy() const239 TypePtr RowTensorType::DeepCopy() const {
240 MS_EXCEPTION_IF_NULL(element_type_);
241 if (IsGeneric()) {
242 return std::make_shared<RowTensorType>();
243 }
244 return std::make_shared<RowTensorType>(element_type_->DeepCopy());
245 }
246
ToReprString() const247 std::string RowTensorType::ToReprString() const {
248 if (element_type_ == nullptr) {
249 return "RowTensor";
250 }
251 return "RowTensor[" + element_type_->ToReprString() + "]";
252 }
253
ToString() const254 std::string RowTensorType::ToString() const {
255 if (element_type_ == nullptr) {
256 return "RowTensor";
257 }
258 return "RowTensor[" + element_type_->ToString() + "]";
259 }
260
DumpText() const261 std::string RowTensorType::DumpText() const {
262 if (element_type_ == nullptr) {
263 return "RowTensor";
264 }
265 return "RowTensor[" + element_type_->DumpText() + "]";
266 }
267
operator ==(const Type & other) const268 bool RowTensorType::operator==(const Type &other) const {
269 if (!IsSameObjectType(*this, other)) {
270 return false;
271 }
272 const auto &other_type = static_cast<const RowTensorType &>(other);
273 return common::IsEqual(element_type_, other_type.element_type_);
274 }
275
hash() const276 size_t RowTensorType::hash() const {
277 size_t hash_value = hash_combine(static_cast<size_t>(kMetaTypeObject), static_cast<size_t>(object_type()));
278 if (element_type_ != nullptr) {
279 hash_value = hash_combine(hash_value, element_type_->hash());
280 }
281 return hash_value;
282 }
283
DeepCopy() const284 TypePtr COOTensorType::DeepCopy() const {
285 if (IsGeneric()) {
286 return std::make_shared<COOTensorType>();
287 }
288 return std::make_shared<COOTensorType>(ElementsClone());
289 }
290
DeepCopy() const291 TypePtr CSRTensorType::DeepCopy() const {
292 if (IsGeneric()) {
293 return std::make_shared<CSRTensorType>();
294 }
295 return std::make_shared<CSRTensorType>(ElementsClone());
296 }
297
DeepCopy() const298 TypePtr MapTensorType::DeepCopy() const {
299 if (IsGeneric()) {
300 return std::make_shared<MapTensorType>();
301 }
302 MS_EXCEPTION_IF_NULL(key_dtype_);
303 MS_EXCEPTION_IF_NULL(value_dtype_);
304 return std::make_shared<MapTensorType>(key_dtype_->DeepCopy(), value_dtype_->DeepCopy());
305 }
306
ToString() const307 std::string MapTensorType::ToString() const {
308 if (IsGeneric()) {
309 return "MapTensor";
310 }
311 MS_EXCEPTION_IF_NULL(key_dtype_);
312 MS_EXCEPTION_IF_NULL(value_dtype_);
313 return "MapTensor[" + key_dtype_->ToString() + ", " + value_dtype_->ToString() + "]";
314 }
315
ToReprString() const316 std::string MapTensorType::ToReprString() const {
317 if (IsGeneric()) {
318 return "MapTensor";
319 }
320 MS_EXCEPTION_IF_NULL(key_dtype_);
321 MS_EXCEPTION_IF_NULL(value_dtype_);
322 return "MapTensor[" + key_dtype_->ToReprString() + ", " + value_dtype_->ToReprString() + "]";
323 }
324
DumpText() const325 std::string MapTensorType::DumpText() const {
326 if (IsGeneric()) {
327 return "MapTensor";
328 }
329 MS_EXCEPTION_IF_NULL(key_dtype_);
330 MS_EXCEPTION_IF_NULL(value_dtype_);
331 return "MapTensor[" + key_dtype_->DumpText() + ", " + value_dtype_->DumpText() + "]";
332 }
333
operator ==(const Type & other) const334 bool MapTensorType::operator==(const Type &other) const {
335 if (!IsSameObjectType(*this, other)) {
336 return false;
337 }
338 const auto &other_type = static_cast<const MapTensorType &>(other);
339 return common::IsEqual(key_dtype_, other_type.key_dtype_) && common::IsEqual(value_dtype_, other_type.value_dtype_);
340 }
341
hash() const342 size_t MapTensorType::hash() const {
343 size_t hash_value = hash_combine(static_cast<size_t>(kMetaTypeObject), static_cast<size_t>(object_type()));
344 if (!IsGeneric()) {
345 hash_value = hash_combine(hash_value, (key_dtype_ == nullptr ? 0 : key_dtype_->hash()));
346 hash_value = hash_combine(hash_value, (value_dtype_ == nullptr) ? 0 : value_dtype_->hash());
347 }
348 return hash_value;
349 }
350 } // namespace mindspore
351