• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_CXX_API_TENSOR_TENSOR_IMPL_H_
18 #define MINDSPORE_LITE_SRC_RUNTIME_CXX_API_TENSOR_TENSOR_IMPL_H_
19 
20 #include <cstddef>
21 #include <numeric>
22 #include <memory>
23 #include <algorithm>
24 #include <string>
25 #include <vector>
26 #include <functional>
27 #include "include/api/types.h"
28 #include "include/api/status.h"
29 #include "include/errorcode.h"
30 #include "src/tensor.h"
31 #include "src/common/log_adapter.h"
32 #include "ir/api_tensor_impl.h"
33 #include "common/mutable_tensor_impl.h"
34 #if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE)
35 #include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h"
36 #endif
37 
38 namespace mindspore {
39 using mindspore::lite::RET_OK;
40 
41 class LiteTensorImpl : public MutableTensorImpl {
42  public:
LiteTensorImpl()43   LiteTensorImpl() {}
44 
~LiteTensorImpl()45   ~LiteTensorImpl() override {
46     if (lite_tensor_ == nullptr) {
47       return;
48     }
49 #if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE)
50     if (GetDeviceData() != nullptr && own_data_) {
51       MS_LOG(INFO) << "free device data in tensor impl.";
52       kernel::AscendAllocatorPlugin::GetInstance().Free(GetDeviceData(), GetDeviceId());
53       lite_tensor_->set_device_data(nullptr);
54     }
55 #endif
56     if (!from_session_) {
57       if (!own_data_) {
58         lite_tensor_->set_data(nullptr);
59       }
60       delete lite_tensor_;
61       lite_tensor_ = nullptr;
62     }
63   }
64 
LiteTensorImpl(lite::Tensor * tensor)65   explicit LiteTensorImpl(lite::Tensor *tensor) : lite_tensor_(tensor), from_session_(true) {
66     if (tensor != nullptr) {
67       tensor_name_ = tensor->tensor_name();
68     }
69   }
70 
71   static std::shared_ptr<LiteTensorImpl> MS_API CreateTensorImpl(const std::string &name, enum DataType type,
72                                                                  const std::vector<int64_t> &shape, const void *data,
73                                                                  size_t data_len);
74   static std::shared_ptr<LiteTensorImpl> MS_API CreateTensorImplByDeepCopy(const std::string &name, enum DataType type,
75                                                                            const std::vector<int64_t> &shape,
76                                                                            const void *data, size_t data_len);
77 
78 #ifndef STRING_KERNEL_CLIP
79   static std::shared_ptr<LiteTensorImpl> MS_API StringsToTensorImpl(const std::string &name,
80                                                                     const std::vector<std::string> &str);
81 
82   static std::vector<std::string> MS_API TensorImplToStrings(const std::shared_ptr<LiteTensorImpl> &impl);
83 #endif
84 
Name()85   const std::string &Name() const override {
86     static const std::string empty = "";
87     if (lite_tensor_ == nullptr) {
88       MS_LOG(ERROR) << "Invalid tensor.";
89       return empty;
90     }
91     return tensor_name_;
92   }
93 
SetName(const std::string & name)94   void SetName(const std::string &name) override {
95     if (lite_tensor_ == nullptr) {
96       MS_LOG(ERROR) << "Invalid tensor.";
97       return;
98     }
99     lite_tensor_->set_tensor_name(name);
100     tensor_name_ = name;
101   }
102 
DataType()103   enum DataType DataType() const override {
104     if (lite_tensor_ == nullptr) {
105       MS_LOG(ERROR) << "Invalid tensor.";
106       return DataType::kTypeUnknown;
107     }
108     return static_cast<enum DataType>(lite_tensor_->data_type());
109   }
110 
SetDataType(enum DataType data_type)111   void SetDataType(enum DataType data_type) override {
112     if (lite_tensor_ == nullptr) {
113       MS_LOG(ERROR) << "Invalid tensor.";
114       return;
115     }
116     lite_tensor_->set_data_type(static_cast<enum TypeId>(data_type));
117   }
118 
ElementNum()119   int64_t ElementNum() const override {
120     if (lite_tensor_ == nullptr) {
121       MS_LOG(ERROR) << "Invalid tensor.";
122       return -1;
123     }
124     return static_cast<int64_t>(lite_tensor_->ElementsNum());
125   }
126 
Shape()127   const std::vector<int64_t> &Shape() const override {
128     static std::vector<int64_t> empty{};
129     if (lite_tensor_ == nullptr) {
130       MS_LOG(ERROR) << "Invalid tensor.";
131       return empty;
132     }
133     auto shape = lite_tensor_->shape();
134     lite_shape_.resize(shape.size());
135     std::transform(shape.begin(), shape.end(), lite_shape_.begin(), [](int c) { return static_cast<int64_t>(c); });
136     return lite_shape_;
137   }
138 
GetDevice()139   std::string GetDevice() const override { return lite_tensor_->get_device(); }
140 
SetDevice(const std::string & device)141   void SetDevice(const std::string &device) override {
142 #if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE)
143     void *device_data = GetDeviceData();
144     if (device_data != nullptr && own_data_) {
145       MS_LOG(INFO) << "free device data in tensor impl.";
146       kernel::AscendAllocatorPlugin::GetInstance().Free(device_data, GetDeviceId());
147     }
148 #endif
149     lite_tensor_->set_device(device);
150     own_data_ = false;
151   }
152 
GetDeviceId()153   int GetDeviceId() const override { return lite_tensor_->get_device_id(); }
154 
SetDeviceId(int device_id)155   void SetDeviceId(int device_id) override { lite_tensor_->set_device_id(device_id); }
156 
Clone()157   std::shared_ptr<mindspore::MSTensor::Impl> Clone() const override { return nullptr; }
158 
SetShape(const std::vector<int64_t> & shape)159   void SetShape(const std::vector<int64_t> &shape) override {
160     if (lite_tensor_ == nullptr) {
161       MS_LOG(ERROR) << "Invalid tensor.";
162       return;
163     }
164     std::vector<int> tensor_shape;
165     tensor_shape.resize(shape.size());
166     std::transform(shape.begin(), shape.end(), tensor_shape.begin(), [](int64_t c) { return static_cast<int>(c); });
167     lite_tensor_->set_shape(tensor_shape);
168   }
169 
GetAllocator()170   std::shared_ptr<Allocator> GetAllocator() const override {
171     if (lite_tensor_ == nullptr) {
172       MS_LOG(ERROR) << "Invalid tensor.";
173       return nullptr;
174     }
175     return lite_tensor_->allocator();
176   }
177 
SetAllocator(const std::shared_ptr<Allocator> & allocator)178   void SetAllocator(const std::shared_ptr<Allocator> &allocator) override {
179     if (lite_tensor_ == nullptr) {
180       MS_LOG(ERROR) << "Invalid tensor.";
181       return;
182     }
183     lite_tensor_->set_allocator(allocator);
184   }
185 
Format()186   mindspore::Format Format() const override {
187     if (lite_tensor_ == nullptr) {
188       MS_LOG(ERROR) << "Invalid tensor.";
189       return mindspore::Format::NHWC;
190     }
191     return lite_tensor_->format();
192   }
193 
SetFormat(const mindspore::Format format)194   void SetFormat(const mindspore::Format format) override {
195     if (lite_tensor_ == nullptr) {
196       MS_LOG(ERROR) << "Invalid tensor.";
197       return;
198     }
199     lite_tensor_->set_format(format);
200   }
201 
Data()202   std::shared_ptr<const void> Data() const override {
203     if (lite_tensor_ == nullptr) {
204       MS_LOG(ERROR) << "Invalid tensor.";
205       return nullptr;
206     }
207     return std::shared_ptr<const void>(lite_tensor_->data(), [](const void *) {});
208   }
209 
MutableData()210   void *MutableData() override {
211     if (lite_tensor_ == nullptr) {
212       MS_LOG(ERROR) << "Invalid tensor.";
213       return nullptr;
214     }
215     auto ret = lite_tensor_->MutableData();
216     own_data_ = lite_tensor_->own_data();
217     return ret;
218   }
IsConst()219   bool IsConst() const override {
220     if (lite_tensor_ == nullptr) {
221       MS_LOG(ERROR) << "Invalid tensor.";
222       return false;
223     }
224     return lite_tensor_->IsConst();
225   }
226 
DataSize()227   size_t DataSize() const override {
228     if (lite_tensor_ == nullptr) {
229       MS_LOG(ERROR) << "Invalid tensor.";
230       return 0;
231     }
232     return lite_tensor_->Size();
233   }
234 
SetData(void * data,bool own_data)235   void SetData(void *data, bool own_data) override {
236     if (lite_tensor_ == nullptr) {
237       MS_LOG(ERROR) << "Invalid tensor.";
238       return;
239     }
240     lite_tensor_->set_data(data, own_data);
241   }
242 
GetQuantParams()243   std::vector<QuantParam> GetQuantParams() const override {
244     if (lite_tensor_ == nullptr) {
245       MS_LOG(ERROR) << "Invalid tensor.";
246       return std::vector<QuantParam>{};
247     }
248     auto lite_quant_params = lite_tensor_->quant_params();
249     std::vector<QuantParam> quant_params;
250     for (size_t i = 0; i < lite_quant_params.size(); i++) {
251       QuantParam param{};
252       param.bit_num = lite_quant_params[i].bitNum;
253       param.scale = lite_quant_params[i].scale;
254       param.zero_point = lite_quant_params[i].zeroPoint;
255       param.min = lite_quant_params[i].min;
256       param.max = lite_quant_params[i].max;
257       quant_params.push_back(param);
258     }
259     return quant_params;
260   }
261 
SetQuantParams(const std::vector<QuantParam> & quant_params)262   void SetQuantParams(const std::vector<QuantParam> &quant_params) override {
263     if (lite_tensor_ == nullptr) {
264       MS_LOG(ERROR) << "Invalid tensor.";
265       return;
266     }
267     std::vector<lite::LiteQuantParam> lite_quant_params;
268     for (size_t i = 0; i < quant_params.size(); i++) {
269       lite::LiteQuantParam lite_param{};
270       lite_param.bitNum = quant_params[i].bit_num;
271       lite_param.scale = quant_params[i].scale;
272       lite_param.zeroPoint = quant_params[i].zero_point;
273       lite_quant_params.push_back(lite_param);
274     }
275     lite_tensor_->set_quant_params(lite_quant_params);
276   }
277 
IsDevice()278   bool IsDevice() const override {
279     if (lite_tensor_ == nullptr) {
280       MS_LOG(ERROR) << "Invalid tensor.";
281       return false;
282     }
283     return lite_tensor_->is_device();
284   }
285 
lite_tensor()286   lite::Tensor *lite_tensor() const { return lite_tensor_; }
287 
set_lite_tensor(lite::Tensor * tensor)288   Status set_lite_tensor(lite::Tensor *tensor) {
289     if (tensor == nullptr) {
290       MS_LOG(ERROR) << "Tensor to set is null.";
291       return kLiteNullptr;
292     }
293     lite_tensor_ = tensor;
294     return kSuccess;
295   }
296 
set_own_data(bool own_data)297   void set_own_data(bool own_data) { own_data_ = own_data; }
298 
set_from_session(bool from_session)299   void set_from_session(bool from_session) { from_session_ = from_session; }
300 
301   void SetDeviceData(void *data) override;
302   void *GetDeviceData() override;
303 
304  private:
305   lite::Tensor *lite_tensor_ = nullptr;
306   std::string tensor_name_ = "";
307   mutable std::vector<int64_t> lite_shape_;
308   bool own_data_ = false;
309   bool from_session_ = false;
310 };
311 using LiteTensorImplPtr = std::shared_ptr<LiteTensorImpl>;
312 }  // namespace mindspore
313 
314 #endif  // MINDSPORE_LITE_SRC_RUNTIME_CXX_API_TENSOR_TENSOR_IMPL_H_
315