• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ir/meta_tensor.h"
18 #include <numeric>
19 #include <functional>
20 
21 namespace mindspore {
22 namespace tensor {
23 // MetaTensor has default type_id_ which is TypeId::kTypeUnknown.
MetaTensor()24 MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {}
25 
MetaTensor(TypeId data_type,const ShapeVector & shape)26 MetaTensor::MetaTensor(TypeId data_type, const ShapeVector &shape) : data_type_(data_type), shape_(shape) {}
27 
MetaTensor(const TypePtr & type_ptr,const ShapeVector & shape)28 MetaTensor::MetaTensor(const TypePtr &type_ptr, const ShapeVector &shape)
29     : data_type_(type_ptr != nullptr ? type_ptr->type_id() : TypeId::kTypeUnknown), shape_(shape) {}
30 
MetaTensor(const MetaTensor & meta_tensor)31 MetaTensor::MetaTensor(const MetaTensor &meta_tensor)
32     : Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {}
33 
// Copy assignment. Mirrors the copy constructor: the Value base subobject is assigned in addition
// to the element type id and the shape. An assigned-to MetaTensor therefore ends up in the same
// state as a copy-constructed one; previously the base part was silently left stale.
// This is also consistent with MetaSparseTensor::operator=. Self-assignment is a no-op.
MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) {
  if (&meta_tensor == this) {
    return *this;
  }

  Value::operator=(meta_tensor);
  data_type_ = meta_tensor.data_type();
  shape_ = meta_tensor.shape();

  return *this;
}
44 
operator ==(const MetaTensor & meta_tensor) const45 bool MetaTensor::operator==(const MetaTensor &meta_tensor) const {
46   return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape();
47 }
48 
49 // Get the size of a given dimension by its index number.
50 // The given index number should be in [0, shape_.size()).
51 // param index Dimension index number.
52 // return The size of the dimension if succeed, or -1 if failed.
DimensionSize(size_t index) const53 int64_t MetaTensor::DimensionSize(size_t index) const {
54   int64_t dim_size = -1;
55   if (index < shape_.size()) {
56     dim_size = shape_[index];
57   } else {
58     MS_LOG(ERROR) << "Dimension index is wrong: " << index;
59   }
60   return dim_size;
61 }
62 
ElementsNum() const63 int64_t MetaTensor::ElementsNum() const {
64   return std::accumulate(shape_.begin(), shape_.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
65 }
66 
// Sets this tensor's element type id from `type_ptr` and returns `type_ptr`.
// A null `type_ptr` is rejected: an error is logged, the tensor is left unchanged and nullptr is returned.
TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) {
  if (type_ptr != nullptr) {
    (void)set_data_type(type_ptr->type_id());
    return type_ptr;
  }
  MS_LOG(ERROR) << "Dtype to be set is nullptr.";
  return nullptr;
}
75 
ToString() const76 std::string MetaTensor::ToString() const {
77   std::ostringstream buf;
78   buf << "MetaTensor(shape=[" << shape() << "]";
79   if (is_parameter_) {
80     buf << ", name=" << param_info_->name();
81   }
82   buf << ")";
83   return buf.str();
84 }
85 
MetaSparseTensor()86 MetaSparseTensor::MetaSparseTensor() : data_type_(TypeId::kTypeUnknown) {}
87 
MetaSparseTensor(TypeId data_type,const ShapeVector & shape)88 MetaSparseTensor::MetaSparseTensor(TypeId data_type, const ShapeVector &shape) : data_type_(data_type), shape_(shape) {}
89 
MetaSparseTensor(const MetaSparseTensor & meta_sparse_tensor)90 MetaSparseTensor::MetaSparseTensor(const MetaSparseTensor &meta_sparse_tensor)
91     : Value(meta_sparse_tensor), data_type_(meta_sparse_tensor.data_type()), shape_(meta_sparse_tensor.shape()) {}
92 
// Copy assignment: assigns the Value base subobject, then the element type id and the shape.
// Self-assignment leaves the object untouched.
MetaSparseTensor &MetaSparseTensor::operator=(const MetaSparseTensor &meta_sparse_tensor) {
  if (&meta_sparse_tensor != this) {
    Value::operator=(meta_sparse_tensor);
    data_type_ = meta_sparse_tensor.data_type();
    shape_ = meta_sparse_tensor.shape();
  }
  return *this;
}
102 
Dtype() const103 TypePtr MetaSparseTensor::Dtype() const { return TypeIdToType(data_type_); }
104 }  // namespace tensor
105 }  // namespace mindspore
106