• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ir/meta_tensor.h"
18 
19 #include <functional>
20 #include <numeric>
21 #include <vector>
22 #include <sstream>
23 #include <string>
24 
25 namespace mindspore {
26 namespace tensor {
27 // MetaTensor has default type_id_ which is TypeId::kTypeUnknown.
MetaTensor()28 MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {}
29 
MetaTensor(const TypeId data_type,const ShapeVector & shape)30 MetaTensor::MetaTensor(const TypeId data_type, const ShapeVector &shape) : data_type_(data_type), shape_(shape) {}
31 
MetaTensor(const TypePtr & type_ptr,const ShapeVector & shape)32 MetaTensor::MetaTensor(const TypePtr &type_ptr, const ShapeVector &shape) {
33   TypeId data_type = TypeId::kTypeUnknown;
34   if (type_ptr != nullptr) {
35     data_type = type_ptr->type_id();
36   }
37   data_type_ = data_type;
38   shape_ = shape;
39 }
40 
// Copy constructor: copies the Value base part, the data type and the shape of
// the source tensor.
// NOTE(review): device_info_ is not copied here, although operator= does copy
// it -- confirm whether that difference is intentional.
MetaTensor::MetaTensor(const MetaTensor &meta_tensor)
    : Value(meta_tensor),
      data_type_(meta_tensor.data_type()),
      shape_(meta_tensor.shape()) {
}
43 
// Copy assignment: takes over the data type, shape and device info of
// meta_tensor. Assigning an object to itself is a no-op.
// NOTE(review): only these three members are assigned; the Value base part is
// left as it is -- confirm that is intended.
// return Reference to this object.
MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) {
  if (this != &meta_tensor) {
    data_type_ = meta_tensor.data_type();
    shape_ = meta_tensor.shape();
    device_info_ = meta_tensor.device_info();
  }
  return *this;
}
55 
operator ==(const MetaTensor & meta_tensor) const56 bool MetaTensor::operator==(const MetaTensor &meta_tensor) const {
57   return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape();
58 }
59 
60 // Get the size of a given dimension by its index number.
61 // The given index number should be in [0, shape_.size()).
62 // param index Dimension index number.
63 // return The size of the dimension if succeed, or -1 if failed.
DimensionSize(const size_t index) const64 int64_t MetaTensor::DimensionSize(const size_t index) const {
65   int64_t dim_size = -1;
66   if (index < shape_.size()) {
67     dim_size = shape_[index];
68   } else {
69     MS_LOG(ERROR) << "Dimension index is wrong: " << index;
70   }
71   return dim_size;
72 }
73 
ElementsNum() const74 int MetaTensor::ElementsNum() const { return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>()); }
75 
// Set the data type of this tensor from a type object.
// param type_ptr Type object whose type_id() is passed to set_data_type().
// return type_ptr itself on success; nullptr (after logging an error) when
//        type_ptr is nullptr, in which case the data type is left unchanged.
TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) {
  if (type_ptr != nullptr) {
    // The value returned by set_data_type() is deliberately discarded.
    (void)set_data_type(type_ptr->type_id());
    return type_ptr;
  }
  MS_LOG(ERROR) << "Dtype to be set is nullptr.";
  return nullptr;
}
84 
SetDeviceInfo(const std::string & format,const TypePtr & data_type,const std::string & host_format)85 void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type, const std::string &host_format) {
86   DeviceInfo info(format, data_type, host_format);
87   set_device_info(info);
88 }
89 
ToString() const90 std::string MetaTensor::ToString() const {
91   std::ostringstream buf;
92   buf << "MetaTensor(shape=[" << shape() << "]";
93   if (is_parameter_) {
94     buf << ", name=" << param_info_->name();
95   }
96   buf << ")";
97   return buf.str();
98 }
99 
DumpText() const100 std::string MetaTensor::DumpText() const {
101   std::ostringstream oss;
102   oss << type_name() << "(" << SizeToInt(data_type_) << ")[";
103   for (size_t i = 0; i < shape_.size(); ++i) {
104     oss << (i > 0 ? ", " : "") << shape_[i];
105   }
106   oss << "]";
107   return oss.str();
108 }
109 }  // namespace tensor
110 }  // namespace mindspore
111