1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ir/meta_tensor.h"
18 #include <numeric>
19 #include <functional>
20
21 namespace mindspore {
22 namespace tensor {
23 // MetaTensor has default type_id_ which is TypeId::kTypeUnknown.
MetaTensor()24 MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {}
25
MetaTensor(TypeId data_type,const ShapeVector & shape)26 MetaTensor::MetaTensor(TypeId data_type, const ShapeVector &shape) : data_type_(data_type), shape_(shape) {}
27
MetaTensor(const TypePtr & type_ptr,const ShapeVector & shape)28 MetaTensor::MetaTensor(const TypePtr &type_ptr, const ShapeVector &shape)
29 : data_type_(type_ptr != nullptr ? type_ptr->type_id() : TypeId::kTypeUnknown), shape_(shape) {}
30
// Copy constructor: copies the Value base part, then the element type and the
// shape read through meta_tensor's public accessors.
MetaTensor::MetaTensor(const MetaTensor &meta_tensor)
    : Value(meta_tensor),
      data_type_(meta_tensor.data_type()),
      shape_(meta_tensor.shape()) {}
33
// Copy assignment: replaces the Value base part, the element type and the
// shape with those of meta_tensor. Self-assignment is a no-op.
// The base part is assigned explicitly so that assignment mirrors the copy
// constructor (which copy-constructs Value from its argument) and matches
// MetaSparseTensor::operator=.
MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) {
  if (&meta_tensor == this) {
    return *this;
  }

  Value::operator=(meta_tensor);
  data_type_ = meta_tensor.data_type();
  shape_ = meta_tensor.shape();

  return *this;
}
44
operator ==(const MetaTensor & meta_tensor) const45 bool MetaTensor::operator==(const MetaTensor &meta_tensor) const {
46 return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape();
47 }
48
49 // Get the size of a given dimension by its index number.
50 // The given index number should be in [0, shape_.size()).
51 // param index Dimension index number.
52 // return The size of the dimension if succeed, or -1 if failed.
DimensionSize(size_t index) const53 int64_t MetaTensor::DimensionSize(size_t index) const {
54 int64_t dim_size = -1;
55 if (index < shape_.size()) {
56 dim_size = shape_[index];
57 } else {
58 MS_LOG(ERROR) << "Dimension index is wrong: " << index;
59 }
60 return dim_size;
61 }
62
ElementsNum() const63 int64_t MetaTensor::ElementsNum() const {
64 return std::accumulate(shape_.begin(), shape_.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
65 }
66
// Sets the element type from type_ptr->type_id() and returns type_ptr.
// A null type_ptr is rejected: an error is logged, nothing is modified and
// nullptr is returned.
TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) {
  if (type_ptr != nullptr) {
    (void)set_data_type(type_ptr->type_id());
    return type_ptr;
  }
  MS_LOG(ERROR) << "Dtype to be set is nullptr.";
  return nullptr;
}
75
ToString() const76 std::string MetaTensor::ToString() const {
77 std::ostringstream buf;
78 buf << "MetaTensor(shape=[" << shape() << "]";
79 if (is_parameter_) {
80 buf << ", name=" << param_info_->name();
81 }
82 buf << ")";
83 return buf.str();
84 }
85
MetaSparseTensor()86 MetaSparseTensor::MetaSparseTensor() : data_type_(TypeId::kTypeUnknown) {}
87
MetaSparseTensor(TypeId data_type,const ShapeVector & shape)88 MetaSparseTensor::MetaSparseTensor(TypeId data_type, const ShapeVector &shape) : data_type_(data_type), shape_(shape) {}
89
// Copy constructor: copies the Value base part, then the element type and the
// shape read through the source object's public accessors.
MetaSparseTensor::MetaSparseTensor(const MetaSparseTensor &meta_sparse_tensor)
    : Value(meta_sparse_tensor),
      data_type_(meta_sparse_tensor.data_type()),
      shape_(meta_sparse_tensor.shape()) {}
92
// Copy assignment: replaces the Value base part, the element type and the
// shape with those of meta_sparse_tensor. Assigning an object to itself leaves
// it untouched.
MetaSparseTensor &MetaSparseTensor::operator=(const MetaSparseTensor &meta_sparse_tensor) {
  if (&meta_sparse_tensor != this) {
    Value::operator=(meta_sparse_tensor);
    data_type_ = meta_sparse_tensor.data_type();
    shape_ = meta_sparse_tensor.shape();
  }
  return *this;
}
102
Dtype() const103 TypePtr MetaSparseTensor::Dtype() const { return TypeIdToType(data_type_); }
104 } // namespace tensor
105 } // namespace mindspore
106