1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_CXX_API_TENSOR_TENSOR_IMPL_H_ 18 #define MINDSPORE_LITE_SRC_CXX_API_TENSOR_TENSOR_IMPL_H_ 19 20 #include <cstddef> 21 #include <numeric> 22 #include <memory> 23 #include <algorithm> 24 #include <string> 25 #include <vector> 26 #include <functional> 27 #include "include/api/types.h" 28 #include "include/api/status.h" 29 #include "include/errorcode.h" 30 #include "include/lite_utils.h" 31 #include "include/ms_tensor.h" 32 #include "src/tensor.h" 33 #include "src/common/log_adapter.h" 34 35 namespace mindspore { 36 using mindspore::lite::RET_OK; 37 38 class MSTensor::Impl { 39 public: Impl()40 Impl() {} 41 ~Impl()42 virtual ~Impl() { 43 if (lite_tensor_ == nullptr) { 44 return; 45 } 46 if (!from_session_) { 47 if (!own_data_) { 48 lite_tensor_->set_data(nullptr); 49 } 50 delete lite_tensor_; 51 lite_tensor_ = nullptr; 52 } 53 } 54 Impl(tensor::MSTensor * tensor)55 explicit Impl(tensor::MSTensor *tensor) : lite_tensor_(tensor), from_session_(true) { 56 if (tensor != nullptr) { 57 tensor_name_ = tensor->tensor_name(); 58 } 59 } 60 61 static std::shared_ptr<Impl> MS_API CreateTensorImpl(const std::string &name, enum DataType type, 62 const std::vector<int64_t> &shape, const void *data, 63 size_t data_len); 64 65 #ifndef STRING_KERNEL_CLIP 66 static std::shared_ptr<Impl> MS_API StringsToTensorImpl(const std::string &name, const std::vector<std::string> &str); 67 68 static std::vector<std::string> MS_API TensorImplToStrings(const std::shared_ptr<Impl> &impl); 69 #endif 70 Name()71 virtual const std::string &Name() const { 72 static std::string empty = ""; 73 if (lite_tensor_ == nullptr) { 74 MS_LOG(ERROR) << "Invalid tensor."; 75 return empty; 76 } 77 return tensor_name_; 78 } 79 SetName(const std::string & name)80 void SetName(const std::string &name) { 81 if (lite_tensor_ == nullptr) { 82 MS_LOG(ERROR) << "Invalid tensor."; 83 return; 84 } 85 lite_tensor_->set_tensor_name(name); 86 tensor_name_ = name; 87 } 88 DataType()89 virtual enum DataType DataType() const { 90 if (lite_tensor_ == nullptr) { 91 MS_LOG(ERROR) << "Invalid tensor."; 92 return DataType::kTypeUnknown; 93 } 94 return static_cast<enum DataType>(lite_tensor_->data_type()); 95 } 96 SetDataType(enum DataType data_type)97 void SetDataType(enum DataType data_type) { 98 if (lite_tensor_ == nullptr) { 99 MS_LOG(ERROR) << "Invalid tensor."; 100 return; 101 } 102 lite_tensor_->set_data_type(static_cast<enum TypeId>(data_type)); 103 } 104 ElementNum()105 int64_t ElementNum() const { 106 if (lite_tensor_ == nullptr) { 107 MS_LOG(ERROR) << "Invalid tensor."; 108 return -1; 109 } 110 return static_cast<int64_t>(lite_tensor_->ElementsNum()); 111 } 112 Shape()113 virtual const std::vector<int64_t> &Shape() const { 114 static std::vector<int64_t> empty; 115 if (lite_tensor_ == nullptr) { 116 MS_LOG(ERROR) << "Invalid tensor."; 117 return empty; 118 } 119 auto shape = lite_tensor_->shape(); 120 lite_shape_.resize(shape.size()); 121 std::transform(shape.begin(), shape.end(), lite_shape_.begin(), [](int c) { return static_cast<int64_t>(c); }); 122 return lite_shape_; 123 } 124 Clone()125 virtual std::shared_ptr<Impl> Clone() const { return nullptr; } 126 SetShape(const std::vector<int64_t> & shape)127 void SetShape(const std::vector<int64_t> &shape) { 128 if (lite_tensor_ == nullptr) { 129 MS_LOG(ERROR) << "Invalid tensor."; 130 return; 131 } 132 std::vector<int> tensor_shape; 133 tensor_shape.resize(shape.size()); 134 std::transform(shape.begin(), shape.end(), tensor_shape.begin(), [](int64_t c) { return static_cast<int>(c); }); 135 lite_tensor_->set_shape(tensor_shape); 136 } 137 allocator()138 std::shared_ptr<Allocator> allocator() const { 139 if (lite_tensor_ == nullptr) { 140 MS_LOG(ERROR) << "Invalid tensor."; 141 return nullptr; 142 } 143 return lite_tensor_->allocator(); 144 } 145 SetAllocator(std::shared_ptr<Allocator> allocator)146 void SetAllocator(std::shared_ptr<Allocator> allocator) { 147 if (lite_tensor_ == nullptr) { 148 MS_LOG(ERROR) << "Invalid tensor."; 149 return; 150 } 151 lite_tensor_->set_allocator(allocator); 152 } 153 format()154 mindspore::Format format() { 155 if (lite_tensor_ == nullptr) { 156 MS_LOG(ERROR) << "Invalid tensor."; 157 return mindspore::Format::NHWC; 158 } 159 return lite_tensor_->format(); 160 } 161 SetFormat(mindspore::Format format)162 void SetFormat(mindspore::Format format) { 163 if (lite_tensor_ == nullptr) { 164 MS_LOG(ERROR) << "Invalid tensor."; 165 return; 166 } 167 lite_tensor_->set_format(format); 168 } 169 Data()170 virtual std::shared_ptr<const void> Data() const { 171 if (lite_tensor_ == nullptr) { 172 MS_LOG(ERROR) << "Invalid tensor."; 173 return nullptr; 174 } 175 return std::shared_ptr<const void>(lite_tensor_->data(), [](const void *) {}); 176 } 177 MutableData()178 virtual void *MutableData() { 179 if (lite_tensor_ == nullptr) { 180 MS_LOG(ERROR) << "Invalid tensor."; 181 return nullptr; 182 } 183 return lite_tensor_->MutableData(); 184 } IsConst()185 virtual bool IsConst() const { 186 if (lite_tensor_ == nullptr) { 187 MS_LOG(ERROR) << "Invalid tensor."; 188 return false; 189 } 190 return lite_tensor_->IsConst(); 191 } 192 DataSize()193 virtual size_t DataSize() const { 194 if (lite_tensor_ == nullptr) { 195 MS_LOG(ERROR) << "Invalid tensor."; 196 return 0; 197 } 198 return lite_tensor_->Size(); 199 } 200 SetData(void * data)201 void SetData(void *data) { 202 if (lite_tensor_ == nullptr) { 203 MS_LOG(ERROR) << "Invalid tensor."; 204 return; 205 } 206 lite_tensor_->set_data(data); 207 } 208 QuantParams()209 virtual std::vector<QuantParam> QuantParams() const { 210 if (lite_tensor_ == nullptr) { 211 MS_LOG(ERROR) << "Invalid tensor."; 212 return std::vector<QuantParam>{}; 213 } 214 auto lite_quant_params = lite_tensor_->quant_params(); 215 std::vector<QuantParam> quant_params; 216 for (size_t i = 0; i < lite_quant_params.size(); i++) { 217 QuantParam param{}; 218 param.bit_num = lite_quant_params[i].bitNum; 219 param.scale = lite_quant_params[i].scale; 220 param.zero_point = lite_quant_params[i].zeroPoint; 221 quant_params.push_back(param); 222 } 223 return quant_params; 224 } 225 SetQuantParams(std::vector<QuantParam> quant_params)226 void SetQuantParams(std::vector<QuantParam> quant_params) { 227 if (lite_tensor_ == nullptr) { 228 MS_LOG(ERROR) << "Invalid tensor."; 229 return; 230 } 231 std::vector<lite::LiteQuantParam> lite_quant_params; 232 for (size_t i = 0; i < quant_params.size(); i++) { 233 lite::LiteQuantParam lite_param{}; 234 lite_param.bitNum = quant_params[i].bit_num; 235 lite_param.scale = quant_params[i].scale; 236 lite_param.zeroPoint = quant_params[i].zero_point; 237 lite_quant_params.push_back(lite_param); 238 } 239 lite_tensor_->set_quant_params(lite_quant_params); 240 } 241 IsDevice()242 virtual bool IsDevice() const { return false; } 243 lite_tensor()244 tensor::MSTensor *lite_tensor() const { return lite_tensor_; } 245 set_lite_tensor(tensor::MSTensor * tensor)246 Status set_lite_tensor(tensor::MSTensor *tensor) { 247 if (tensor == nullptr) { 248 MS_LOG(ERROR) << "Tensor to set is null."; 249 return kLiteNullptr; 250 } 251 lite_tensor_ = tensor; 252 return kSuccess; 253 } 254 set_own_data(bool own_data)255 void set_own_data(bool own_data) { own_data_ = own_data; } 256 set_from_session(bool from_session)257 void set_from_session(bool from_session) { from_session_ = from_session; } 258 259 private: 260 tensor::MSTensor *lite_tensor_ = nullptr; 261 std::string tensor_name_ = ""; 262 mutable std::vector<int64_t> lite_shape_; 263 bool own_data_ = false; 264 bool from_session_ = false; 265 }; 266 } // namespace mindspore 267 268 #endif // MINDSPORE_LITE_SRC_CXX_API_TENSOR_TENSOR_IMPL_H_ 269