/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_CXX_API_TENSOR_TENSOR_IMPL_H_
#define MINDSPORE_LITE_SRC_CXX_API_TENSOR_TENSOR_IMPL_H_

#include <cstddef>
#include <numeric>
#include <memory>
#include <algorithm>
#include <string>
#include <vector>
#include <functional>
#include "include/api/types.h"
#include "include/api/status.h"
#include "include/errorcode.h"
#include "include/lite_utils.h"
#include "include/ms_tensor.h"
#include "src/tensor.h"
#include "src/common/log_adapter.h"

namespace mindspore {
using mindspore::lite::RET_OK;

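// Backing implementation for the public MSTensor handle: wraps a lite
// tensor::MSTensor and records whether the wrapped tensor and its data
// buffer are owned by this object or by a running session.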
class MSTensor::Impl {
 public:
  Impl() {}

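  // Deletes the wrapped tensor only when it was not obtained from a session;
  // a data buffer that is not owned is detached first so it is not freed.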
  virtual ~Impl() {
    if (lite_tensor_ == nullptr) {
      return;
    }
    if (!from_session_) {
      if (!own_data_) {
        lite_tensor_->set_data(nullptr);
      }
      delete lite_tensor_;
      lite_tensor_ = nullptr;
    }
  }

  explicit Impl(tensor::MSTensor *tensor) : lite_tensor_(tensor), from_session_(true) {
    if (tensor != nullptr) {
      tensor_name_ = tensor->tensor_name();
    }
  }

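  // Static factories: CreateTensorImpl builds an Impl from a name, data type,
  // shape and data buffer; the string conversions below are compiled out when
  // the string kernel is clipped.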
  static std::shared_ptr<Impl> MS_API CreateTensorImpl(const std::string &name, enum DataType type,
                                                       const std::vector<int64_t> &shape, const void *data,
                                                       size_t data_len);

#ifndef STRING_KERNEL_CLIP
  static std::shared_ptr<Impl> MS_API StringsToTensorImpl(const std::string &name, const std::vector<std::string> &str);

  static std::vector<std::string> MS_API TensorImplToStrings(const std::shared_ptr<Impl> &impl);
#endif

  virtual const std::string &Name() const {
    static std::string empty = "";
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return empty;
    }
    return tensor_name_;
  }

  void SetName(const std::string &name) {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return;
    }
    lite_tensor_->set_tensor_name(name);
    tensor_name_ = name;
  }

  virtual enum DataType DataType() const {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return DataType::kTypeUnknown;
    }
    return static_cast<enum DataType>(lite_tensor_->data_type());
  }

  void SetDataType(enum DataType data_type) {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return;
    }
    lite_tensor_->set_data_type(static_cast<enum TypeId>(data_type));
  }

  int64_t ElementNum() const {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return -1;
    }
    return static_cast<int64_t>(lite_tensor_->ElementsNum());
  }

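  // Converts the lite tensor's int shape to int64_t and caches it in
  // lite_shape_ so a reference can be returned.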
  virtual const std::vector<int64_t> &Shape() const {
    static std::vector<int64_t> empty;
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return empty;
    }
    auto shape = lite_tensor_->shape();
    lite_shape_.resize(shape.size());
    std::transform(shape.begin(), shape.end(), lite_shape_.begin(), [](int c) { return static_cast<int64_t>(c); });
    return lite_shape_;
  }

  virtual std::shared_ptr<Impl> Clone() const { return nullptr; }

  void SetShape(const std::vector<int64_t> &shape) {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return;
    }
    std::vector<int> tensor_shape;
    tensor_shape.resize(shape.size());
    std::transform(shape.begin(), shape.end(), tensor_shape.begin(), [](int64_t c) { return static_cast<int>(c); });
    lite_tensor_->set_shape(tensor_shape);
  }

  std::shared_ptr<Allocator> allocator() const {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return nullptr;
    }
    return lite_tensor_->allocator();
  }

  void SetAllocator(std::shared_ptr<Allocator> allocator) {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return;
    }
    lite_tensor_->set_allocator(allocator);
  }

  mindspore::Format format() {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return mindspore::Format::NHWC;
    }
    return lite_tensor_->format();
  }

  void SetFormat(mindspore::Format format) {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return;
    }
    lite_tensor_->set_format(format);
  }

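  // Returns a non-owning view of the tensor data: the shared_ptr carries a
  // no-op deleter, so the buffer remains managed by the lite tensor.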
  virtual std::shared_ptr<const void> Data() const {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return nullptr;
    }
    return std::shared_ptr<const void>(lite_tensor_->data(), [](const void *) {});
  }

  virtual void *MutableData() {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return nullptr;
    }
    return lite_tensor_->MutableData();
  }

  virtual bool IsConst() const {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return false;
    }
    return lite_tensor_->IsConst();
  }

  virtual size_t DataSize() const {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return 0;
    }
    return lite_tensor_->Size();
  }

  void SetData(void *data) {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return;
    }
    lite_tensor_->set_data(data);
  }

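  // Converts between the public QuantParam struct and lite::LiteQuantParam
  // (bit_num/bitNum, scale, zero_point/zeroPoint).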
  virtual std::vector<QuantParam> QuantParams() const {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return std::vector<QuantParam>{};
    }
    auto lite_quant_params = lite_tensor_->quant_params();
    std::vector<QuantParam> quant_params;
    for (size_t i = 0; i < lite_quant_params.size(); i++) {
      QuantParam param{};
      param.bit_num = lite_quant_params[i].bitNum;
      param.scale = lite_quant_params[i].scale;
      param.zero_point = lite_quant_params[i].zeroPoint;
      quant_params.push_back(param);
    }
    return quant_params;
  }

  void SetQuantParams(std::vector<QuantParam> quant_params) {
    if (lite_tensor_ == nullptr) {
      MS_LOG(ERROR) << "Invalid tensor.";
      return;
    }
    std::vector<lite::LiteQuantParam> lite_quant_params;
    for (size_t i = 0; i < quant_params.size(); i++) {
      lite::LiteQuantParam lite_param{};
      lite_param.bitNum = quant_params[i].bit_num;
      lite_param.scale = quant_params[i].scale;
      lite_param.zeroPoint = quant_params[i].zero_point;
      lite_quant_params.push_back(lite_param);
    }
    lite_tensor_->set_quant_params(lite_quant_params);
  }

  virtual bool IsDevice() const { return false; }

  tensor::MSTensor *lite_tensor() const { return lite_tensor_; }

  Status set_lite_tensor(tensor::MSTensor *tensor) {
    if (tensor == nullptr) {
      MS_LOG(ERROR) << "Tensor to set is null.";
      return kLiteNullptr;
    }
    lite_tensor_ = tensor;
    return kSuccess;
  }

  void set_own_data(bool own_data) { own_data_ = own_data; }

  void set_from_session(bool from_session) { from_session_ = from_session; }

 private:
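  // Wrapped tensor plus ownership flags: from_session_ marks a tensor owned by
  // a session (never deleted here); own_data_ marks whether the data buffer
  // belongs to the wrapped tensor (see the destructor).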
  tensor::MSTensor *lite_tensor_ = nullptr;
  std::string tensor_name_ = "";
  mutable std::vector<int64_t> lite_shape_;
  bool own_data_ = false;
  bool from_session_ = false;
};
}  // namespace mindspore

#endif  // MINDSPORE_LITE_SRC_CXX_API_TENSOR_TENSOR_IMPL_H_