/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_EXTENDRT_UTILS_TENSOR_NUMPY_IMPL_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_UTILS_TENSOR_NUMPY_IMPL_H_

#include <vector>
#include <string>
#include <memory>
#include <functional>
#include <utility>
#include <set>

#include "src/common/log_adapter.h"
#include "common/mutable_tensor_impl.h"
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#ifdef ENABLE_CLOUD_INFERENCE
#include "extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h"
#endif

namespace py = pybind11;
namespace mindspore {
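// TensorNumpyImpl adapts a numpy array (held as a py::buffer_info) to the
// MutableTensorImpl interface, so Python-owned data can be fed to the runtime
// without copying. Shape, dtype, and host data come straight from the buffer;
// most setters are rejected because the numpy storage is fixed.
//
// Typical construction from a py::array (illustrative sketch, variable names
// are hypothetical):
//   py::array arr = ...;
//   std::vector<int64_t> shape(arr.shape(), arr.shape() + arr.ndim());
//   auto impl = std::make_shared<TensorNumpyImpl>("input0", arr.request(), shape);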
class TensorNumpyImpl : public MutableTensorImpl {
 public:
  TensorNumpyImpl(const std::string &name, py::buffer_info &&buffer, const std::vector<int64_t> &ms_shape)
      : name_(name), buffer_(std::move(buffer)), ms_shape_(ms_shape) {}
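  // Releasing buffer_ drops a reference on the underlying numpy object, which
  // touches Python interpreter state, so the GIL must be held; acquire it only
  // if this thread does not already hold it.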
  ~TensorNumpyImpl() {
    {
      if (PyGILState_Check() == 0) {
        py::gil_scoped_acquire acquire;
        { buffer_ = py::buffer_info(); }
      } else {
        buffer_ = py::buffer_info();
      }
    }
    if (device_data_ != nullptr) {
      MS_LOG(INFO) << "free device data in tensor numpy impl.";
#ifdef ENABLE_CLOUD_INFERENCE
      kernel::AscendAllocatorPlugin::GetInstance().Free(device_data_, device_id_);
#endif
    }
  }
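  // The logical shape and dtype are fixed by the numpy buffer, so the mutating
  // overrides below only log and return.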
  const std::vector<int64_t> &Shape() const override { return ms_shape_; }
  void SetShape(const std::vector<int64_t> &shape) override {
    MS_LOG(WARNING) << "Cannot call SetShape for numpy tensor";
  }

  enum DataType DataType() const override { return GetDataType(buffer_); }
  void SetDataType(mindspore::DataType data_type) override {
    MS_LOG(WARNING) << "Cannot call SetDataType for numpy tensor";
  }

  void SetName(const std::string &name) override { name_ = name; }
  const std::string &Name() const override { return name_; }

  mindspore::Format Format() const override { return format_; }
  void SetFormat(mindspore::Format) override { MS_LOG(ERROR) << "Cannot call SetFormat for numpy tensor"; }

  void SetAllocator(const std::shared_ptr<Allocator> &allocator) override {
    MS_LOG(ERROR) << "Cannot call SetAllocator for numpy tensor";
  }
  std::shared_ptr<Allocator> GetAllocator() const override { return nullptr; }

  std::vector<QuantParam> GetQuantParams() const override { return {}; }
  void SetQuantParams(const std::vector<QuantParam> &quant_param) override {
    MS_LOG(ERROR) << "Cannot call SetQuantParams for numpy tensor";
  }

  int64_t ElementNum() const override { return buffer_.size; }
  size_t DataSize() const override { return buffer_.size * buffer_.itemsize; }
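  // Takes ownership of a device-side buffer (cloud inference builds only): any
  // previously held device pointer is released through the Ascend allocator
  // plugin, and the destructor frees the new one.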
  void SetDeviceData(void *data) override {
#ifdef ENABLE_CLOUD_INFERENCE
    if (device_data_ != nullptr) {
      kernel::AscendAllocatorPlugin::GetInstance().Free(device_data_, device_id_);
    }
    device_data_ = data;
    return;
#endif
    MS_LOG(ERROR) << "SetDeviceData is not supported in this build.";
  }

  void *GetDeviceData() override { return device_data_; }

  bool IsConst() const override { return false; }
  void SetIsConst(bool) { MS_LOG(ERROR) << "Cannot call SetIsConst for numpy tensor"; }

  bool IsDevice() const override { return false; }

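  // Returns a non-owning view: the shared_ptr aliases the numpy storage and its
  // no-op deleter deliberately frees nothing, since buffer_ keeps the
  // underlying array alive for the lifetime of this object.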
  std::shared_ptr<const void> Data() const override {
    auto data = static_cast<const uint8_t *>(buffer_.ptr);
    return std::shared_ptr<const void>(data, [](const void *) {});
  }

  void SetData(void *, bool) override { MS_LOG(ERROR) << "Cannot call SetData for numpy tensor"; }

  int GetDeviceId() const override { return device_id_; }

  void SetDeviceId(int device_id) override { device_id_ = device_id; }

  std::string GetDevice() const override { return device_; }

  void SetDevice(const std::string &device) override { device_ = device; }

  void *MutableData() override { return buffer_.ptr; }

  std::shared_ptr<Impl> Clone() const override {
    MS_LOG(ERROR) << "Cannot call Clone for numpy tensor";
    return nullptr;
  }

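  // Maps a buffer-protocol format character plus item size to a MindSpore
  // dtype. The characters follow Python's struct module: 'e'/'f'/'d' are
  // floating point, 'b'/'h'/'i'/'l'/'q' signed ints, their upper-case
  // counterparts unsigned ints, and '?' bool. Dispatching on itemsize rather
  // than the letter keeps platform-dependent widths ('l'/'L') correct.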
  static enum DataType GetDataType(const py::buffer_info &buf) {
    std::set<char> fp_format = {'e', 'f', 'd'};
    std::set<char> int_format = {'b', 'h', 'i', 'l', 'q'};
    std::set<char> uint_format = {'B', 'H', 'I', 'L', 'Q'};
    if (buf.format.size() == 1) {
      char format = buf.format.front();
      if (fp_format.find(format) != fp_format.end()) {
        switch (buf.itemsize) {
          case sizeof(uint16_t):
            return DataType::kNumberTypeFloat16;
          case sizeof(uint32_t):
            return DataType::kNumberTypeFloat32;
          case sizeof(uint64_t):
            return DataType::kNumberTypeFloat64;
        }
      } else if (int_format.find(format) != int_format.end()) {
        switch (buf.itemsize) {
          case sizeof(int8_t):
            return DataType::kNumberTypeInt8;
          case sizeof(int16_t):
            return DataType::kNumberTypeInt16;
          case sizeof(int32_t):
            return DataType::kNumberTypeInt32;
          case sizeof(int64_t):
            return DataType::kNumberTypeInt64;
        }
      } else if (uint_format.find(format) != uint_format.end()) {
        switch (buf.itemsize) {
          case sizeof(uint8_t):
            return DataType::kNumberTypeUInt8;
          case sizeof(uint16_t):
            return DataType::kNumberTypeUInt16;
          case sizeof(uint32_t):
            return DataType::kNumberTypeUInt32;
          case sizeof(uint64_t):
            return DataType::kNumberTypeUInt64;
        }
      } else if (format == '?') {
        return DataType::kNumberTypeBool;
      }
    }
    MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize;
    return DataType::kTypeUnknown;
  }

 protected:
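  // buffer_ owns the reference to the numpy storage; ms_shape_ is the shape as
  // seen by MindSpore; device_data_/device_id_/device_ track an optional
  // device-side copy (cloud inference builds). Numpy tensors are reported as NCHW.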
  std::string name_;
  enum Format format_ = mindspore::NCHW;

  py::buffer_info buffer_;
  std::vector<int64_t> ms_shape_;
  void *device_data_ = nullptr;
  std::string device_ = "";
  int device_id_ = -1;
};
}  // namespace mindspore

#endif  // MINDSPORE_LITE_SRC_EXTENDRT_UTILS_TENSOR_NUMPY_IMPL_H_