• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ir/dtype/tensor_type.h"
18 #include "utils/log_adapter.h"
19 #include "utils/ms_utils.h"
20 
21 namespace mindspore {
DeepCopy() const22 TypePtr UndeterminedType::DeepCopy() const {
23   if (IsGeneric() || element_type_ == nullptr) {
24     return std::make_shared<UndeterminedType>();
25   }
26   return std::make_shared<UndeterminedType>(element_type_->DeepCopy());
27 }
28 
ToReprString() const29 std::string UndeterminedType::ToReprString() const {
30   if (element_type_ == nullptr) {
31     return "Undetermined";
32   }
33   return "Undetermined[" + element_type_->ToReprString() + "]";
34 }
35 
ToString() const36 std::string UndeterminedType::ToString() const {
37   if (element_type_ == nullptr) {
38     return "Undetermined";
39   }
40   return "Undetermined[" + element_type_->ToString() + "]";
41 }
42 
DumpText() const43 std::string UndeterminedType::DumpText() const {
44   if (element_type_ == nullptr) {
45     return "Undetermined";
46   }
47   return "Undetermined[" + element_type_->DumpText() + "]";
48 }
49 
operator ==(const Type & other) const50 bool UndeterminedType::operator==(const Type &other) const {
51   if (!IsSameObjectType(*this, other)) {
52     return false;
53   }
54   const auto &other_type = static_cast<const UndeterminedType &>(other);
55   return common::IsEqual(element_type_, other_type.element_type_);
56 }
57 
hash() const58 size_t UndeterminedType::hash() const {
59   size_t hash_value = hash_combine(static_cast<size_t>(kMetaTypeObject), static_cast<size_t>(object_type()));
60   if (element_type_ != nullptr) {
61     hash_value = hash_combine(hash_value, element_type_->hash());
62   }
63   return hash_value;
64 }
65 
DeepCopy() const66 TypePtr TensorType::DeepCopy() const {
67   if (element_type_ == nullptr) {
68     return std::make_shared<TensorType>();
69   }
70   if (IsGeneric()) {
71     return std::make_shared<TensorType>();
72   }
73   return std::make_shared<TensorType>(element_type_->DeepCopy());
74 }
75 
ToReprString() const76 std::string TensorType::ToReprString() const {
77   if (element_type_ == nullptr) {
78     return "tensor";
79   }
80   return "tensor[" + element_type_->ToReprString() + "]";
81 }
82 
ToString() const83 std::string TensorType::ToString() const {
84   if (element_type_ == nullptr) {
85     return "Tensor";
86   }
87   return "Tensor[" + element_type_->ToString() + "]";
88 }
89 
DumpText() const90 std::string TensorType::DumpText() const {
91   if (element_type_ == nullptr) {
92     return "Tensor";
93   }
94   return "Tensor(" + element_type_->DumpText() + ")";
95 }
96 
operator ==(const Type & other) const97 bool TensorType::operator==(const Type &other) const {
98   if (this == &other) {
99     return true;
100   }
101   if (!IsSameObjectType(*this, other)) {
102     return false;
103   }
104   return *this == static_cast<const TensorType &>(other);
105 }
106 
operator ==(const TensorType & other) const107 bool TensorType::operator==(const TensorType &other) const {
108   if (other.isa<AnyType>()) {
109     return false;
110   }
111   return common::IsEqual(element_type_, other.element_type_);
112 }
113 
hash() const114 size_t TensorType::hash() const {
115   size_t hash_value = hash_combine(static_cast<size_t>(kMetaTypeObject), static_cast<size_t>(object_type()));
116   if (element_type_ != nullptr) {
117     hash_value = hash_combine(hash_value, element_type_->hash());
118   }
119   return hash_value;
120 }
121 
ToString() const122 std::string AnyType::ToString() const {
123   if (element() == nullptr) {
124     return "Any(Tensor)";
125   }
126   return "Any(Tensor)[" + element()->ToString() + "]";
127 }
128 
DumpText() const129 std::string AnyType::DumpText() const {
130   if (element() == nullptr) {
131     return "Any(Tensor)";
132   }
133   return "Any(Tensor)(" + element()->DumpText() + ")";
134 }
135 
operator ==(const Type & other) const136 bool AnyType::operator==(const Type &other) const {
137   if (this == &other) {
138     return true;
139   }
140   if (!other.isa<AnyType>()) {
141     return false;
142   }
143   return *this == static_cast<const AnyType &>(other);
144 }
145 
operator ==(const AnyType & other) const146 bool AnyType::operator==(const AnyType &other) const {
147   if (this == &other) {
148     return true;
149   }
150   return common::IsEqual(element(), other.element());
151 }
152 
ToString() const153 std::string NegligibleType::ToString() const {
154   if (element() == nullptr) {
155     return "Negligible(Tensor)";
156   }
157   return "Negligible(Tensor)[" + element()->ToString() + "]";
158 }
159 
DumpText() const160 std::string NegligibleType::DumpText() const {
161   if (element() == nullptr) {
162     return "Negligible(Tensor)";
163   }
164   return "Negligible(Tensor)(" + element()->DumpText() + ")";
165 }
166 
ElementsDtypeStr(const StringType str_type) const167 std::string SparseTensorType::ElementsDtypeStr(const StringType str_type) const {
168   std::ostringstream oss;
169   for (const TypePtr &elem : elements_) {
170     if (str_type == kToString) {
171       oss << elem->ToString();
172     } else if (str_type == kDumpText) {
173       oss << elem->DumpText();
174     } else if (str_type == kReprString) {
175       oss << elem->ToReprString();
176     }
177     oss << ",";
178   }
179   return oss.str();
180 }
181 
ToString() const182 std::string SparseTensorType::ToString() const {
183   if (elements_.empty()) {
184     return GetSparseTensorTypeName();
185   }
186   return GetSparseTensorTypeName() + "[" + ElementsDtypeStr(kToString) + "]";
187 }
188 
DumpText() const189 std::string SparseTensorType::DumpText() const {
190   if (elements_.empty()) {
191     return GetSparseTensorTypeName();
192   }
193   return GetSparseTensorTypeName() + "[" + ElementsDtypeStr(kDumpText) + "]";
194 }
195 
ToReprString() const196 std::string SparseTensorType::ToReprString() const {
197   if (elements_.empty()) {
198     return GetSparseTensorTypeName();
199   }
200   return GetSparseTensorTypeName() + "[" + ElementsDtypeStr(kReprString) + "]";
201 }
202 
ElementsClone() const203 const TypePtrList SparseTensorType::ElementsClone() const {
204   TypePtrList elems;
205   (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elems), [](const TypePtr &ele) {
206     MS_EXCEPTION_IF_NULL(ele);
207     return ele->DeepCopy();
208   });
209   return elems;
210 }
211 
DeepCopy() const212 TypePtr SparseTensorType::DeepCopy() const {
213   if (IsGeneric()) {
214     return std::make_shared<SparseTensorType>();
215   }
216   return std::make_shared<SparseTensorType>(ElementsClone());
217 }
218 
operator [](std::size_t dim) const219 const TypePtr SparseTensorType::operator[](std::size_t dim) const {
220   if (dim >= size()) {
221     MS_LOG(EXCEPTION) << "Index " << dim << " is out range of the SparseTensorType size " << size() << ".";
222   }
223   return elements_[dim];
224 }
225 
operator ==(const Type & other) const226 bool SparseTensorType::operator==(const Type &other) const {
227   if (!IsSameObjectType(*this, other)) {
228     return false;
229   }
230   const auto &other_type = static_cast<const SparseTensorType &>(other);
231   return TypeListEqual()(elements_, other_type.elements_);
232 }
233 
hash() const234 size_t SparseTensorType::hash() const {
235   size_t hash_value = hash_combine(static_cast<size_t>(kMetaTypeObject), static_cast<size_t>(object_type()));
236   return hash_combine(hash_value, TypeListHasher()(elements_));
237 }
238 
DeepCopy() const239 TypePtr RowTensorType::DeepCopy() const {
240   MS_EXCEPTION_IF_NULL(element_type_);
241   if (IsGeneric()) {
242     return std::make_shared<RowTensorType>();
243   }
244   return std::make_shared<RowTensorType>(element_type_->DeepCopy());
245 }
246 
ToReprString() const247 std::string RowTensorType::ToReprString() const {
248   if (element_type_ == nullptr) {
249     return "RowTensor";
250   }
251   return "RowTensor[" + element_type_->ToReprString() + "]";
252 }
253 
ToString() const254 std::string RowTensorType::ToString() const {
255   if (element_type_ == nullptr) {
256     return "RowTensor";
257   }
258   return "RowTensor[" + element_type_->ToString() + "]";
259 }
260 
DumpText() const261 std::string RowTensorType::DumpText() const {
262   if (element_type_ == nullptr) {
263     return "RowTensor";
264   }
265   return "RowTensor[" + element_type_->DumpText() + "]";
266 }
267 
operator ==(const Type & other) const268 bool RowTensorType::operator==(const Type &other) const {
269   if (!IsSameObjectType(*this, other)) {
270     return false;
271   }
272   const auto &other_type = static_cast<const RowTensorType &>(other);
273   return common::IsEqual(element_type_, other_type.element_type_);
274 }
275 
hash() const276 size_t RowTensorType::hash() const {
277   size_t hash_value = hash_combine(static_cast<size_t>(kMetaTypeObject), static_cast<size_t>(object_type()));
278   if (element_type_ != nullptr) {
279     hash_value = hash_combine(hash_value, element_type_->hash());
280   }
281   return hash_value;
282 }
283 
DeepCopy() const284 TypePtr COOTensorType::DeepCopy() const {
285   if (IsGeneric()) {
286     return std::make_shared<COOTensorType>();
287   }
288   return std::make_shared<COOTensorType>(ElementsClone());
289 }
290 
DeepCopy() const291 TypePtr CSRTensorType::DeepCopy() const {
292   if (IsGeneric()) {
293     return std::make_shared<CSRTensorType>();
294   }
295   return std::make_shared<CSRTensorType>(ElementsClone());
296 }
297 
DeepCopy() const298 TypePtr MapTensorType::DeepCopy() const {
299   if (IsGeneric()) {
300     return std::make_shared<MapTensorType>();
301   }
302   MS_EXCEPTION_IF_NULL(key_dtype_);
303   MS_EXCEPTION_IF_NULL(value_dtype_);
304   return std::make_shared<MapTensorType>(key_dtype_->DeepCopy(), value_dtype_->DeepCopy());
305 }
306 
ToString() const307 std::string MapTensorType::ToString() const {
308   if (IsGeneric()) {
309     return "MapTensor";
310   }
311   MS_EXCEPTION_IF_NULL(key_dtype_);
312   MS_EXCEPTION_IF_NULL(value_dtype_);
313   return "MapTensor[" + key_dtype_->ToString() + ", " + value_dtype_->ToString() + "]";
314 }
315 
ToReprString() const316 std::string MapTensorType::ToReprString() const {
317   if (IsGeneric()) {
318     return "MapTensor";
319   }
320   MS_EXCEPTION_IF_NULL(key_dtype_);
321   MS_EXCEPTION_IF_NULL(value_dtype_);
322   return "MapTensor[" + key_dtype_->ToReprString() + ", " + value_dtype_->ToReprString() + "]";
323 }
324 
DumpText() const325 std::string MapTensorType::DumpText() const {
326   if (IsGeneric()) {
327     return "MapTensor";
328   }
329   MS_EXCEPTION_IF_NULL(key_dtype_);
330   MS_EXCEPTION_IF_NULL(value_dtype_);
331   return "MapTensor[" + key_dtype_->DumpText() + ", " + value_dtype_->DumpText() + "]";
332 }
333 
operator ==(const Type & other) const334 bool MapTensorType::operator==(const Type &other) const {
335   if (!IsSameObjectType(*this, other)) {
336     return false;
337   }
338   const auto &other_type = static_cast<const MapTensorType &>(other);
339   return common::IsEqual(key_dtype_, other_type.key_dtype_) && common::IsEqual(value_dtype_, other_type.value_dtype_);
340 }
341 
hash() const342 size_t MapTensorType::hash() const {
343   size_t hash_value = hash_combine(static_cast<size_t>(kMetaTypeObject), static_cast<size_t>(object_type()));
344   if (!IsGeneric()) {
345     hash_value = hash_combine(hash_value, (key_dtype_ == nullptr ? 0 : key_dtype_->hash()));
346     hash_value = hash_combine(hash_value, (value_dtype_ == nullptr) ? 0 : value_dtype_->hash());
347   }
348   return hash_value;
349 }
350 }  // namespace mindspore
351