/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 
19 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_UTILS_TENSOR_UTILS_H_
20 #define MINDSPORE_LITE_SRC_EXTENDRT_UTILS_TENSOR_UTILS_H_
21 
22 #include <vector>
23 #include <string>
24 #include <memory>
25 #include <functional>
26 
27 #include "include/api/types.h"
28 #include "ir/tensor.h"
29 #include "include/backend/device_address.h"
30 #include "common/utils.h"
31 #include "common/mutable_tensor_impl.h"
32 #include "mindspore/core/ir/tensor.h"
33 #include "kernel/kernel.h"
34 #include "src/tensor.h"
35 #include "infer/tensor.h"
36 #ifdef ENABLE_CLOUD_INFERENCE
37 #include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h"
38 #endif
39 namespace mindspore {
// TensorData implementation that *references* an externally managed buffer
// instead of owning a copy of it. The optional deleter is presumably invoked
// on the buffer when this object is destroyed (implementation is in the .cc —
// TODO confirm); a null deleter means the caller retains ownership.
class TensorRefData : public tensor::TensorData {
 public:
  // data: externally managed buffer (not copied).
  // elem_count: number of elements in the buffer.
  // data_size: buffer size in bytes.
  // ndim: number of dimensions of the referenced tensor.
  // deleter: optional cleanup callback for `data`; nullptr = caller owns.
  TensorRefData(void *data, size_t elem_count, size_t data_size, size_t ndim,
                const std::function<void(uint8_t *)> &deleter = nullptr);
  ~TensorRefData();

  ssize_t size() const override;
  ssize_t itemsize() const override;
  ssize_t nbytes() const override;
  ssize_t ndim() const override;
  void *data() override;
  const void *const_data() const override;
  // A referenced buffer is always standalone — never a view into, nor the
  // parent of, another tensor's data.
  bool is_sub_data() const override { return false; }
  bool has_sub_data() const override { return false; }
  std::string ToString(TypeId type, const ShapeVector &shape, bool use_comma) const override;

 private:
  void *data_ = nullptr;                             // borrowed (or deleter-owned) buffer
  size_t elem_count_ = 0;                            // element count
  size_t data_size_ = 0;                             // size in bytes
  size_t ndim_ = 0;                                  // dimension count
  std::function<void(uint8_t *)> deleter_ = nullptr; // optional cleanup for data_
};
63 
// Device-name tag for tensors whose memory is managed by Lite rather than a
// backend device. `inline` (C++17) gives the constant a single definition
// shared by every translation unit that includes this header, instead of one
// internal-linkage copy per TU.
inline constexpr auto kLiteDeviceName = "LiteDevice";
65 
// Minimal DeviceAddress wrapper around a raw pointer + size for Lite tensors.
// Host<->device synchronization is intentionally unsupported: both Sync*
// directions report failure, and there is no device memory to clear.
class LiteDeviceAddress : public device::DeviceAddress {
 public:
  LiteDeviceAddress(void *ptr, size_t size) : device::DeviceAddress(ptr, size) {}
  // Replaces the wrapped pointer; ownership stays with the caller.
  void SetData(void *data) { set_ptr(data); }

  // Unsupported: always reports failure.
  bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override {
    return false;
  }
  // Unsupported: always reports failure.
  bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
                        const std::string &format) const override {
    return false;
  }
  // Delegates to the format-taking overload (which also fails).
  bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const override {
    return SyncHostToDevice(shape, size, type, host_ptr, "DefaultFormat");
  }
  // Nothing to release: the wrapped pointer is not owned by this class.
  void ClearDeviceMemory() override {}
};
83 
84 class TensorTensorImpl : public MutableTensorImpl {
85  public:
TensorTensorImpl(const tensor::Tensor & tensor)86   explicit TensorTensorImpl(const tensor::Tensor &tensor) : tensor_(std::make_shared<tensor::Tensor>(tensor)) {}
TensorTensorImpl(const std::shared_ptr<tensor::Tensor> & tensor)87   explicit TensorTensorImpl(const std::shared_ptr<tensor::Tensor> &tensor) : tensor_(tensor) {}
88 
SetData(void *,bool)89   void SetData(void *, bool) override { MS_LOG_EXCEPTION << "Cannot set data for TensorTensorImpl"; }
90 
Data()91   std::shared_ptr<const void> Data() const override {
92     MS_EXCEPTION_IF_NULL(tensor_);
93     return std::shared_ptr<const void>(tensor_->data_c(), [](const void *) {});
94   }
95 
SetDeviceId(int device_id)96   void SetDeviceId(int device_id) override {
97     MS_EXCEPTION_IF_NULL(tensor_);
98     device_id_ = device_id;
99   }
100 
SetDevice(const std::string & device)101   void SetDevice(const std::string &device) override {
102     MS_EXCEPTION_IF_NULL(tensor_);
103     device_ = device;
104   }
105 
GetDeviceId()106   int GetDeviceId() const override {
107     MS_EXCEPTION_IF_NULL(tensor_);
108     return device_id_;
109   }
110 
GetDevice()111   std::string GetDevice() const override {
112     MS_EXCEPTION_IF_NULL(tensor_);
113     return device_;
114   }
115 
MutableData()116   void *MutableData() override {
117     MS_EXCEPTION_IF_NULL(tensor_);
118     return tensor_->data_c();
119   }
120 
SetDeviceData(void * data)121   void SetDeviceData(void *data) override {
122     MS_EXCEPTION_IF_NULL(tensor_);
123     auto old_device_data = GetDeviceData();
124     MS_LOG(ERROR) << "set device data in tensor utils.";
125 #ifdef ENABLE_CLOUD_INFERENCE
126     if (old_device_data != nullptr && device_own_data_) {
127       kernel::AscendAllocatorPlugin::GetInstance().Free(old_device_data, GetDeviceId());
128     }
129 #endif
130     auto data_size = DataSize();
131     auto device_address = std::make_shared<LiteDeviceAddress>(data, data_size);
132     tensor_->set_device_address(device_address);
133     device_own_data_ = false;
134   }
GetDeviceData()135   void *GetDeviceData() override {
136     MS_EXCEPTION_IF_NULL(tensor_);
137     auto device_address = tensor_->device_address();
138     if (device_address == nullptr) {
139       return nullptr;
140     }
141     return device_address->GetMutablePtr();
142   }
143 
IsDevice()144   bool IsDevice() const override {
145     MS_EXCEPTION_IF_NULL(tensor_);
146     return tensor_->device_address() != nullptr;
147   }
148 
IsConst()149   bool IsConst() const override { return false; }
150 
SetShape(const std::vector<int64_t> &)151   void SetShape(const std::vector<int64_t> &) override { MS_LOG_EXCEPTION << "Cannot set shape for TensorTensorImpl"; }
SetDataType(mindspore::DataType)152   void SetDataType(mindspore::DataType) override { MS_LOG_EXCEPTION << "Cannot set data type for TensorTensorImpl"; }
SetName(const std::string & name)153   void SetName(const std::string &name) override {
154     MS_EXCEPTION_IF_NULL(tensor_);
155     tensor_->set_name(name);
156   }
157 
158   mindspore::Format Format() const override;
159 
160   void SetFormat(mindspore::Format format) override;
161 
Name()162   const std::string &Name() const override {
163     MS_EXCEPTION_IF_NULL(tensor_);
164     return tensor_->name();
165   }
DataType()166   enum DataType DataType() const override {
167     MS_EXCEPTION_IF_NULL(tensor_);
168     return static_cast<enum DataType>(tensor_->data_type());
169   }
Shape()170   const std::vector<int64_t> &Shape() const override {
171     MS_EXCEPTION_IF_NULL(tensor_);
172     return tensor_->shape();
173   }
174 
SetAllocator(const std::shared_ptr<Allocator> & allocator)175   void SetAllocator(const std::shared_ptr<Allocator> &allocator) override {
176     MS_EXCEPTION_IF_NULL(tensor_);
177     tensor_->set_user_data("allocator", allocator);
178   }
GetAllocator()179   std::shared_ptr<Allocator> GetAllocator() const override {
180     MS_EXCEPTION_IF_NULL(tensor_);
181     return tensor_->user_data<Allocator>("allocator");
182   }
183 
GetQuantParams()184   std::vector<QuantParam> GetQuantParams() const override {
185     MS_EXCEPTION_IF_NULL(tensor_);
186     auto data = tensor_->user_data<std::vector<QuantParam>>("quant_param");
187     return data ? *data : std::vector<QuantParam>();
188   }
189 
SetQuantParams(const std::vector<QuantParam> & quant_param)190   void SetQuantParams(const std::vector<QuantParam> &quant_param) override {
191     MS_EXCEPTION_IF_NULL(tensor_);
192     tensor_->set_user_data("quant_param", std::make_shared<std::vector<QuantParam>>(quant_param));
193   }
194 
DataSize()195   size_t DataSize() const override {
196     auto elem_num = ElementNum();
197     if (elem_num <= 0) {
198       return 0;
199     }
200     return LongToSize(elem_num) * lite::DataTypeSize(static_cast<enum TypeId>(DataType()));
201   }
202 
Clone()203   std::shared_ptr<Impl> Clone() const override { return std::make_shared<TensorTensorImpl>(tensor_); }
204 
205  private:
206   std::shared_ptr<tensor::Tensor> tensor_ = nullptr;
207   std::string device_ = "";
208   int device_id_ = -1;
209   bool device_own_data_ = true;
210 };
211 
// Batch conversions between the public MSTensor API type and the internal
// tensor::Tensor / tensor::TensorPtr representations (implementations in .cc).
class TensorUtils {
 public:
  // MSTensor <-> TensorPtr
  static std::vector<mindspore::tensor::TensorPtr> MSTensorToTensorPtr(const std::vector<MSTensor> &ms_tensors);
  // tensor_names supplies the name for each converted tensor — presumably
  // index-aligned with tensor_ptrs; confirm against the .cc.
  static std::vector<MSTensor> TensorPtrToMSTensor(std::vector<mindspore::tensor::TensorPtr> tensor_ptrs,
                                                   const std::vector<std::string> &tensor_names);

  static std::vector<mindspore::tensor::Tensor> MSTensorToTensor(const std::vector<MSTensor> &ms_tensors);
  static std::vector<MSTensor> TensorToMSTensor(std::vector<mindspore::tensor::Tensor> tensors,
                                                const std::vector<std::string> &tensor_names);

  // TensorPtr <-> Tensor
  static std::vector<mindspore::tensor::TensorPtr> TensorToTensorPtr(
    const std::vector<mindspore::tensor::Tensor> &tensors);
  static std::vector<mindspore::tensor::Tensor> TensorPtrToTensor(
    const std::vector<mindspore::tensor::TensorPtr> &tensor_ptrs);
};
229 
// Conversions from lite::Tensor to the kernel-layer address/tensor wrappers
// used by the cloud inference path (implementations in .cc).
class CloudTensorUtils {
 public:
  /* lite tensor ---> Address */
  static kernel::AddressPtr LiteTensorToAddressPtr(const lite::Tensor *lite_tensor);
  static std::vector<mindspore::kernel::AddressPtr> LiteTensorToAddressPtrVec(
    const std::vector<lite::Tensor *> &lite_tensors);

  /* lite tensor ---> kernel tensor */
  // NOTE(review): returns raw pointers — ownership/lifetime of the created
  // KernelTensor objects is not visible here; verify against the .cc.
  static kernel::KernelTensor *LiteTensorToKernelTensorPtr(const lite::Tensor *lite_tensor);
  static std::vector<kernel::KernelTensor *> LiteTensorToKernelTensorPtrVec(
    const std::vector<lite::Tensor *> &lite_tensors);
};
242 
// Batch get/set of shapes across a list of abstract tensors.
class AbstractTensorUtils {
 public:
  static std::vector<std::vector<int64_t>> GetTensorListShapes(const std::vector<infer::abstract::Tensor *> &tensors);
  // NOTE(review): "Shapse" is a typo of "Shapes"; kept as-is because renaming
  // would break existing callers of this public interface.
  static bool SetTensorListShapse(const std::vector<infer::abstract::Tensor *> &tensors,
                                  const std::vector<std::vector<int64_t>> &shapes);
};
249 }  // namespace mindspore
250 
251 #endif  // MINDSPORE_LITE_SRC_EXTENDRT_UTILS_TENSOR_UTILS_H_
252