1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_CXX_API_TENSOR_TENSOR_IMPL_H_ 18 #define MINDSPORE_LITE_SRC_RUNTIME_CXX_API_TENSOR_TENSOR_IMPL_H_ 19 20 #include <cstddef> 21 #include <numeric> 22 #include <memory> 23 #include <algorithm> 24 #include <string> 25 #include <vector> 26 #include <functional> 27 #include "include/api/types.h" 28 #include "include/api/status.h" 29 #include "include/errorcode.h" 30 #include "src/tensor.h" 31 #include "src/common/log_adapter.h" 32 #include "ir/api_tensor_impl.h" 33 #include "common/mutable_tensor_impl.h" 34 #if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE) 35 #include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h" 36 #endif 37 38 namespace mindspore { 39 using mindspore::lite::RET_OK; 40 41 class LiteTensorImpl : public MutableTensorImpl { 42 public: LiteTensorImpl()43 LiteTensorImpl() {} 44 ~LiteTensorImpl()45 ~LiteTensorImpl() override { 46 if (lite_tensor_ == nullptr) { 47 return; 48 } 49 #if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE) 50 if (GetDeviceData() != nullptr && own_data_) { 51 MS_LOG(INFO) << "free device data in tensor impl."; 52 kernel::AscendAllocatorPlugin::GetInstance().Free(GetDeviceData(), GetDeviceId()); 53 lite_tensor_->set_device_data(nullptr); 54 } 55 #endif 56 if (!from_session_) { 57 if (!own_data_) { 58 lite_tensor_->set_data(nullptr); 59 } 60 delete lite_tensor_; 61 lite_tensor_ = nullptr; 62 } 63 } 64 LiteTensorImpl(lite::Tensor * tensor)65 explicit LiteTensorImpl(lite::Tensor *tensor) : lite_tensor_(tensor), from_session_(true) { 66 if (tensor != nullptr) { 67 tensor_name_ = tensor->tensor_name(); 68 } 69 } 70 71 static std::shared_ptr<LiteTensorImpl> MS_API CreateTensorImpl(const std::string &name, enum DataType type, 72 const std::vector<int64_t> &shape, const void *data, 73 size_t data_len); 74 static std::shared_ptr<LiteTensorImpl> MS_API CreateTensorImplByDeepCopy(const std::string &name, enum DataType type, 75 const std::vector<int64_t> &shape, 76 const void *data, size_t data_len); 77 78 #ifndef STRING_KERNEL_CLIP 79 static std::shared_ptr<LiteTensorImpl> MS_API StringsToTensorImpl(const std::string &name, 80 const std::vector<std::string> &str); 81 82 static std::vector<std::string> MS_API TensorImplToStrings(const std::shared_ptr<LiteTensorImpl> &impl); 83 #endif 84 Name()85 const std::string &Name() const override { 86 static const std::string empty = ""; 87 if (lite_tensor_ == nullptr) { 88 MS_LOG(ERROR) << "Invalid tensor."; 89 return empty; 90 } 91 return tensor_name_; 92 } 93 SetName(const std::string & name)94 void SetName(const std::string &name) override { 95 if (lite_tensor_ == nullptr) { 96 MS_LOG(ERROR) << "Invalid tensor."; 97 return; 98 } 99 lite_tensor_->set_tensor_name(name); 100 tensor_name_ = name; 101 } 102 DataType()103 enum DataType DataType() const override { 104 if (lite_tensor_ == nullptr) { 105 MS_LOG(ERROR) << "Invalid tensor."; 106 return DataType::kTypeUnknown; 107 } 108 return static_cast<enum DataType>(lite_tensor_->data_type()); 109 } 110 SetDataType(enum DataType data_type)111 void SetDataType(enum DataType data_type) override { 112 if (lite_tensor_ == nullptr) { 113 MS_LOG(ERROR) << "Invalid tensor."; 114 return; 115 } 116 lite_tensor_->set_data_type(static_cast<enum TypeId>(data_type)); 117 } 118 ElementNum()119 int64_t ElementNum() const override { 120 if (lite_tensor_ == nullptr) { 121 MS_LOG(ERROR) << "Invalid tensor."; 122 return -1; 123 } 124 return static_cast<int64_t>(lite_tensor_->ElementsNum()); 125 } 126 Shape()127 const std::vector<int64_t> &Shape() const override { 128 static std::vector<int64_t> empty{}; 129 if (lite_tensor_ == nullptr) { 130 MS_LOG(ERROR) << "Invalid tensor."; 131 return empty; 132 } 133 auto shape = lite_tensor_->shape(); 134 lite_shape_.resize(shape.size()); 135 std::transform(shape.begin(), shape.end(), lite_shape_.begin(), [](int c) { return static_cast<int64_t>(c); }); 136 return lite_shape_; 137 } 138 GetDevice()139 std::string GetDevice() const override { return lite_tensor_->get_device(); } 140 SetDevice(const std::string & device)141 void SetDevice(const std::string &device) override { 142 #if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE) 143 void *device_data = GetDeviceData(); 144 if (device_data != nullptr && own_data_) { 145 MS_LOG(INFO) << "free device data in tensor impl."; 146 kernel::AscendAllocatorPlugin::GetInstance().Free(device_data, GetDeviceId()); 147 } 148 #endif 149 lite_tensor_->set_device(device); 150 own_data_ = false; 151 } 152 GetDeviceId()153 int GetDeviceId() const override { return lite_tensor_->get_device_id(); } 154 SetDeviceId(int device_id)155 void SetDeviceId(int device_id) override { lite_tensor_->set_device_id(device_id); } 156 Clone()157 std::shared_ptr<mindspore::MSTensor::Impl> Clone() const override { return nullptr; } 158 SetShape(const std::vector<int64_t> & shape)159 void SetShape(const std::vector<int64_t> &shape) override { 160 if (lite_tensor_ == nullptr) { 161 MS_LOG(ERROR) << "Invalid tensor."; 162 return; 163 } 164 std::vector<int> tensor_shape; 165 tensor_shape.resize(shape.size()); 166 std::transform(shape.begin(), shape.end(), tensor_shape.begin(), [](int64_t c) { return static_cast<int>(c); }); 167 lite_tensor_->set_shape(tensor_shape); 168 } 169 GetAllocator()170 std::shared_ptr<Allocator> GetAllocator() const override { 171 if (lite_tensor_ == nullptr) { 172 MS_LOG(ERROR) << "Invalid tensor."; 173 return nullptr; 174 } 175 return lite_tensor_->allocator(); 176 } 177 SetAllocator(const std::shared_ptr<Allocator> & allocator)178 void SetAllocator(const std::shared_ptr<Allocator> &allocator) override { 179 if (lite_tensor_ == nullptr) { 180 MS_LOG(ERROR) << "Invalid tensor."; 181 return; 182 } 183 lite_tensor_->set_allocator(allocator); 184 } 185 Format()186 mindspore::Format Format() const override { 187 if (lite_tensor_ == nullptr) { 188 MS_LOG(ERROR) << "Invalid tensor."; 189 return mindspore::Format::NHWC; 190 } 191 return lite_tensor_->format(); 192 } 193 SetFormat(const mindspore::Format format)194 void SetFormat(const mindspore::Format format) override { 195 if (lite_tensor_ == nullptr) { 196 MS_LOG(ERROR) << "Invalid tensor."; 197 return; 198 } 199 lite_tensor_->set_format(format); 200 } 201 Data()202 std::shared_ptr<const void> Data() const override { 203 if (lite_tensor_ == nullptr) { 204 MS_LOG(ERROR) << "Invalid tensor."; 205 return nullptr; 206 } 207 return std::shared_ptr<const void>(lite_tensor_->data(), [](const void *) {}); 208 } 209 MutableData()210 void *MutableData() override { 211 if (lite_tensor_ == nullptr) { 212 MS_LOG(ERROR) << "Invalid tensor."; 213 return nullptr; 214 } 215 auto ret = lite_tensor_->MutableData(); 216 own_data_ = lite_tensor_->own_data(); 217 return ret; 218 } IsConst()219 bool IsConst() const override { 220 if (lite_tensor_ == nullptr) { 221 MS_LOG(ERROR) << "Invalid tensor."; 222 return false; 223 } 224 return lite_tensor_->IsConst(); 225 } 226 DataSize()227 size_t DataSize() const override { 228 if (lite_tensor_ == nullptr) { 229 MS_LOG(ERROR) << "Invalid tensor."; 230 return 0; 231 } 232 return lite_tensor_->Size(); 233 } 234 SetData(void * data,bool own_data)235 void SetData(void *data, bool own_data) override { 236 if (lite_tensor_ == nullptr) { 237 MS_LOG(ERROR) << "Invalid tensor."; 238 return; 239 } 240 lite_tensor_->set_data(data, own_data); 241 } 242 GetQuantParams()243 std::vector<QuantParam> GetQuantParams() const override { 244 if (lite_tensor_ == nullptr) { 245 MS_LOG(ERROR) << "Invalid tensor."; 246 return std::vector<QuantParam>{}; 247 } 248 auto lite_quant_params = lite_tensor_->quant_params(); 249 std::vector<QuantParam> quant_params; 250 for (size_t i = 0; i < lite_quant_params.size(); i++) { 251 QuantParam param{}; 252 param.bit_num = lite_quant_params[i].bitNum; 253 param.scale = lite_quant_params[i].scale; 254 param.zero_point = lite_quant_params[i].zeroPoint; 255 param.min = lite_quant_params[i].min; 256 param.max = lite_quant_params[i].max; 257 quant_params.push_back(param); 258 } 259 return quant_params; 260 } 261 SetQuantParams(const std::vector<QuantParam> & quant_params)262 void SetQuantParams(const std::vector<QuantParam> &quant_params) override { 263 if (lite_tensor_ == nullptr) { 264 MS_LOG(ERROR) << "Invalid tensor."; 265 return; 266 } 267 std::vector<lite::LiteQuantParam> lite_quant_params; 268 for (size_t i = 0; i < quant_params.size(); i++) { 269 lite::LiteQuantParam lite_param{}; 270 lite_param.bitNum = quant_params[i].bit_num; 271 lite_param.scale = quant_params[i].scale; 272 lite_param.zeroPoint = quant_params[i].zero_point; 273 lite_quant_params.push_back(lite_param); 274 } 275 lite_tensor_->set_quant_params(lite_quant_params); 276 } 277 IsDevice()278 bool IsDevice() const override { 279 if (lite_tensor_ == nullptr) { 280 MS_LOG(ERROR) << "Invalid tensor."; 281 return false; 282 } 283 return lite_tensor_->is_device(); 284 } 285 lite_tensor()286 lite::Tensor *lite_tensor() const { return lite_tensor_; } 287 set_lite_tensor(lite::Tensor * tensor)288 Status set_lite_tensor(lite::Tensor *tensor) { 289 if (tensor == nullptr) { 290 MS_LOG(ERROR) << "Tensor to set is null."; 291 return kLiteNullptr; 292 } 293 lite_tensor_ = tensor; 294 return kSuccess; 295 } 296 set_own_data(bool own_data)297 void set_own_data(bool own_data) { own_data_ = own_data; } 298 set_from_session(bool from_session)299 void set_from_session(bool from_session) { from_session_ = from_session; } 300 301 void SetDeviceData(void *data) override; 302 void *GetDeviceData() override; 303 304 private: 305 lite::Tensor *lite_tensor_ = nullptr; 306 std::string tensor_name_ = ""; 307 mutable std::vector<int64_t> lite_shape_; 308 bool own_data_ = false; 309 bool from_session_ = false; 310 }; 311 using LiteTensorImplPtr = std::shared_ptr<LiteTensorImpl>; 312 } // namespace mindspore 313 314 #endif // MINDSPORE_LITE_SRC_RUNTIME_CXX_API_TENSOR_TENSOR_IMPL_H_ 315