1 /**
2 * Copyright 2019-2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ir/meta_tensor.h"
18
19 #include <functional>
20 #include <numeric>
21 #include <vector>
22 #include <sstream>
23 #include <string>
24
25 namespace mindspore {
26 namespace tensor {
27 // MetaTensor has default type_id_ which is TypeId::kTypeUnknown.
MetaTensor()28 MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {}
29
MetaTensor(const TypeId data_type,const ShapeVector & shape)30 MetaTensor::MetaTensor(const TypeId data_type, const ShapeVector &shape) : data_type_(data_type), shape_(shape) {}
31
MetaTensor(const TypePtr & type_ptr,const ShapeVector & shape)32 MetaTensor::MetaTensor(const TypePtr &type_ptr, const ShapeVector &shape) {
33 TypeId data_type = TypeId::kTypeUnknown;
34 if (type_ptr != nullptr) {
35 data_type = type_ptr->type_id();
36 }
37 data_type_ = data_type;
38 shape_ = shape;
39 }
40
// Copy constructor.
// Copies the Value base together with the element type, shape and device
// info. device_info_ is copied here so that copy construction carries over the
// same MetaTensor fields that MetaTensor::operator= assigns (data type, shape,
// device info); previously a copy-constructed tensor silently lost its device
// info while an assigned one kept it.
MetaTensor::MetaTensor(const MetaTensor &meta_tensor)
    : Value(meta_tensor),
      data_type_(meta_tensor.data_type()),
      shape_(meta_tensor.shape()),
      device_info_(meta_tensor.device_info()) {}
43
// Copy assignment: takes over the element type, shape and device info of
// meta_tensor. Self-assignment leaves the object untouched. The Value base
// part is not assigned by this operator.
MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) {
  if (this != &meta_tensor) {
    data_type_ = meta_tensor.data_type();
    shape_ = meta_tensor.shape();
    device_info_ = meta_tensor.device_info();
  }
  return *this;
}
55
operator ==(const MetaTensor & meta_tensor) const56 bool MetaTensor::operator==(const MetaTensor &meta_tensor) const {
57 return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape();
58 }
59
60 // Get the size of a given dimension by its index number.
61 // The given index number should be in [0, shape_.size()).
62 // param index Dimension index number.
63 // return The size of the dimension if succeed, or -1 if failed.
DimensionSize(const size_t index) const64 int64_t MetaTensor::DimensionSize(const size_t index) const {
65 int64_t dim_size = -1;
66 if (index < shape_.size()) {
67 dim_size = shape_[index];
68 } else {
69 MS_LOG(ERROR) << "Dimension index is wrong: " << index;
70 }
71 return dim_size;
72 }
73
ElementsNum() const74 int MetaTensor::ElementsNum() const { return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>()); }
75
// Sets the element type from type_ptr->type_id() and returns type_ptr.
// A null type_ptr is rejected: an error is logged, the element type is left
// unchanged, and nullptr is returned.
TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) {
  if (type_ptr != nullptr) {
    (void)set_data_type(type_ptr->type_id());
    return type_ptr;
  }
  MS_LOG(ERROR) << "Dtype to be set is nullptr.";
  return nullptr;
}
84
SetDeviceInfo(const std::string & format,const TypePtr & data_type,const std::string & host_format)85 void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type, const std::string &host_format) {
86 DeviceInfo info(format, data_type, host_format);
87 set_device_info(info);
88 }
89
ToString() const90 std::string MetaTensor::ToString() const {
91 std::ostringstream buf;
92 buf << "MetaTensor(shape=[" << shape() << "]";
93 if (is_parameter_) {
94 buf << ", name=" << param_info_->name();
95 }
96 buf << ")";
97 return buf.str();
98 }
99
DumpText() const100 std::string MetaTensor::DumpText() const {
101 std::ostringstream oss;
102 oss << type_name() << "(" << SizeToInt(data_type_) << ")[";
103 for (size_t i = 0; i < shape_.size(); ++i) {
104 oss << (i > 0 ? ", " : "") << shape_[i];
105 }
106 oss << "]";
107 return oss.str();
108 }
109 } // namespace tensor
110 } // namespace mindspore
111