• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pybind_api/ir/tensor_py.h"
18 
19 #include <utility>
20 
21 #include "include/common/pybind_api/api_register.h"
22 #include "abstract/abstract_value.h"
23 #include "utils/cache_embedding_hashmap_struct.h"
24 #include "include/common/utils/python_adapter.h"
25 #include "mindspore/ccsrc/include/backend/distributed/embedding_cache/embedding_cache_utils.h"
26 #include "pybind_api/ir/tensor_index_py.h"
27 #include "pybind_api/ir/hook_py.h"
28 #include "include/common/profiler.h"
29 #include "runtime/hardware/device_context_manager.h"
30 #include "runtime/pynative/op_executor.h"
31 #include "include/backend/mbuf_device_address.h"
32 
33 namespace mindspore {
34 namespace tensor {
35 namespace {
// Registers TensorPy::AsNumpy as the tensor-to-numpy conversion handler.
// Registration happens at static-initialization time via the constructor of
// the file-local `callback_register` object below.
struct TensorToNumpyRegister {
  TensorToNumpyRegister() { python_adapter::PyAdapterCallback::SetTensorToNumpyHandler(tensor::TensorPy::AsNumpy); }
} callback_register;
39 
GetDeviceCtx(const std::string & to)40 device::DeviceContext *GetDeviceCtx(const std::string &to) {
41   const auto &device = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
42   if (to != "CPU" && to != device) {
43     MS_LOG(EXCEPTION) << "The value of 'to' should be same with device, bug got to:" << to << ", device: " << device;
44   }
45   auto device_ctx = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
46     {device, MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
47   MS_EXCEPTION_IF_NULL(device_ctx);
48 
49   device_ctx->Initialize();
50   return device_ctx;
51 }
52 }  // namespace
// Python buffer item sizes (in bytes) used to map buffer formats to dtypes.
constexpr ssize_t kPyBufItemSize1 = 1;
constexpr ssize_t kPyBufItemSize2 = 2;
constexpr ssize_t kPyBufItemSize4 = 4;
constexpr ssize_t kPyBufItemSize8 = 8;
57 
// Maps a python buffer format to a MindSpore TypeId.
// Single-character formats follow the struct-module syntax; the item size is
// used to disambiguate widths (e.g. 'l' is 4 or 8 bytes depending on the
// platform). Multi-character formats cover numpy strings ("{n}w"/"{n}s") and
// complex ("Zf"/"Zd"). Returns kTypeUnknown (with a warning) otherwise.
static TypeId GetDataType(const py::buffer_info &buf) {
  if (buf.format.size() == 1) {
    switch (buf.format.front()) {
      // Floating point: width decided by item size.
      case 'e':
      case 'f':
      case 'd':
        switch (buf.itemsize) {
          case kPyBufItemSize2:
            return TypeId::kNumberTypeFloat16;
          case kPyBufItemSize4:
            return TypeId::kNumberTypeFloat32;
          case kPyBufItemSize8:
            return TypeId::kNumberTypeFloat64;
        }
        break;
      // Signed integers: width decided by item size.
      case 'b':
      case 'h':
      case 'i':
      case 'l':
      case 'q':
        switch (buf.itemsize) {
          case kPyBufItemSize1:
            return TypeId::kNumberTypeInt8;
          case kPyBufItemSize2:
            return TypeId::kNumberTypeInt16;
          case kPyBufItemSize4:
            return TypeId::kNumberTypeInt32;
          case kPyBufItemSize8:
            return TypeId::kNumberTypeInt64;
          default:
            break;
        }
        break;
      // Unsigned integers: width decided by item size.
      case 'B':
      case 'H':
      case 'I':
      case 'L':
      case 'Q':
        switch (buf.itemsize) {
          case kPyBufItemSize1:
            return TypeId::kNumberTypeUInt8;
          case kPyBufItemSize2:
            return TypeId::kNumberTypeUInt16;
          case kPyBufItemSize4:
            return TypeId::kNumberTypeUInt32;
          case kPyBufItemSize8:
            return TypeId::kNumberTypeUInt64;
          default:
            break;
        }
        break;
      case '?':
        return TypeId::kNumberTypeBool;
      // 'T' is the custom bfloat16 format char (see GetPyBufferFromPyArray).
      case 'T':
        return TypeId::kNumberTypeBFloat16;
      default:
        break;
    }
  } else if (buf.format.size() >= 2) {
    // Support np.str_ dtype, format: {x}w. {x} is a number that means the maximum length of the string items.
    if (buf.format.back() == 'w' || buf.format.back() == 's') {
      return TypeId::kObjectTypeString;
    } else if (buf.format == "Zf") {
      return TypeId::kNumberTypeComplex64;
    } else if (buf.format == "Zd") {
      return TypeId::kNumberTypeComplex128;
    }
  }
  MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << ", item size " << buf.itemsize;
  return TypeId::kTypeUnknown;
}
129 
GetPyTypeFormat(TypeId data_type)130 static std::string GetPyTypeFormat(TypeId data_type) {
131   switch (data_type) {
132     case TypeId::kNumberTypeFloat16:
133       return "e";
134     case TypeId::kNumberTypeBFloat16:
135       return "T";
136     case TypeId::kNumberTypeFloat32:
137       return py::format_descriptor<float>::format();
138     case TypeId::kNumberTypeFloat64:
139       return py::format_descriptor<double>::format();
140     case TypeId::kNumberTypeUInt8:
141       return py::format_descriptor<uint8_t>::format();
142     case TypeId::kNumberTypeUInt16:
143       return py::format_descriptor<uint16_t>::format();
144     case TypeId::kNumberTypeUInt32:
145       return py::format_descriptor<uint32_t>::format();
146     case TypeId::kNumberTypeUInt64:
147       return py::format_descriptor<uint64_t>::format();
148     case TypeId::kNumberTypeInt4:
149     case TypeId::kNumberTypeInt8:
150       return py::format_descriptor<int8_t>::format();
151     case TypeId::kNumberTypeInt16:
152       return py::format_descriptor<int16_t>::format();
153     case TypeId::kNumberTypeInt:
154     case TypeId::kNumberTypeInt32:
155       return py::format_descriptor<int32_t>::format();
156     case TypeId::kNumberTypeInt64:
157       return py::format_descriptor<int64_t>::format();
158     case TypeId::kNumberTypeBool:
159       return py::format_descriptor<bool>::format();
160     case TypeId::kObjectTypeString:
161       return py::format_descriptor<uint8_t>::format();
162     case TypeId::kNumberTypeComplex64:
163       return py::format_descriptor<std::complex<float>>::format();
164     case TypeId::kNumberTypeComplex128:
165       return py::format_descriptor<std::complex<double>>::format();
166     case TypeId::kMetaTypeType:
167     case TypeId::kMetaTypeEllipsis:
168     default:
169       MS_LOG(WARNING) << "Unsupported DataType " << data_type << ".";
170       return "";
171   }
172 }
173 
IsCContiguous(const py::array & input)174 static bool IsCContiguous(const py::array &input) {
175   auto flags = static_cast<unsigned int>(input.flags());
176   return (flags & static_cast<unsigned int>(pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_)) != 0;
177 }
178 
// TensorDataNumpy implements TensorData using numpy array.
class TensorDataNumpy : public TensorData {
 public:
  // Takes ownership of the buffer view; the buffer keeps the underlying numpy
  // data alive for the lifetime of this object.
  explicit TensorDataNumpy(py::buffer_info &&buffer) : buffer_(std::make_unique<py::buffer_info>(std::move(buffer))) {}

  ~TensorDataNumpy() override {
    // Destroying a py::buffer_info touches python refcounts, so hold the GIL.
    py::gil_scoped_acquire acquire;
    buffer_.reset();
  }

  /// Total number of elements.
  ssize_t size() const override { return buffer()->size; }

  /// Byte size of a single element.
  ssize_t itemsize() const override { return buffer()->itemsize; }

  /// Total number of bytes.
  ssize_t nbytes() const override { return buffer()->itemsize * buffer()->size; }

  /// Number of dimensions.
  ssize_t ndim() const override { return buffer()->ndim; }

  /// Data pointer.
  void *data() override { return buffer_data(); }

  const void *const_data() const override { return buffer()->ptr; }

  bool is_sub_data() const override { return false; }

  bool has_sub_data() const override { return false; }

  // Data is shared with (not copied from) a numpy array.
  bool is_from_numpy() const override { return true; }

  const std::vector<ssize_t> &shape() const { return buffer()->shape; }

  /// To string.
  std::string ToString(const TypeId, const ShapeVector &, bool use_comma) const override {
    py::gil_scoped_acquire gil_acquire;
    if (use_comma) {
      // Call python np.array2string(data_, separator=', ') to convert string with comma.
      py::dict kwargs;
      kwargs["separator"] = ", ";
      auto np = py::module::import("numpy");
      auto array2string = np.attr("array2string");
      return py::str(array2string(py_array(), **kwargs));
    }
    // without comma.
    return py::str(py_array());
  }

  /// py::array object. by default, use py::str() as the dummy owner to prevent data copy.
  py::array py_array(const py::handle &owner = py::str()) const {
    py::gil_scoped_acquire acquire;
    // bfloat16 ("T") needs its dedicated npy descriptor; all other formats can
    // be derived directly from the buffer info.
    py::dtype np_dtype =
      (buffer()->format == "T") ? py::detail::npy_format_descriptor<bfloat16>::dtype() : py::dtype(*buffer());
    return py::array(np_dtype, buffer()->shape, buffer()->strides, buffer()->ptr, owner);
  }

 private:
  void *buffer_data() const { return buffer_->ptr; }
  // Null-checked accessor for the internal buffer.
  std::unique_ptr<py::buffer_info> const &buffer() const {
    MS_EXCEPTION_IF_NULL(buffer_);
    return buffer_;
  }

  // The internal buffer.
  std::unique_ptr<py::buffer_info> buffer_;
};
247 
248 // This class is uesd to get huge tensor data from persistent storage. Tensor data can be got by slice.
249 // It used at extend embedding to persistent storage.
250 class PersistentTensorDataNumpy : public TensorDataNumpy {
251  public:
PersistentTensorDataNumpy(py::buffer_info && buffer,int slice_num)252   explicit PersistentTensorDataNumpy(py::buffer_info &&buffer, int slice_num)
253       : TensorDataNumpy(std::move(buffer)), slice_num_(slice_num) {}
254 
255   ~PersistentTensorDataNumpy() override = default;
256 
257   // Fill data with a special slice tensor data. It will read data from persistent storage.
FillSliceData(const int32_t param_key,const int slice_index)258   void FillSliceData(const int32_t param_key, const int slice_index) {
259     if (slice_index >= slice_num_) {
260       MS_LOG(ERROR) << "Slice index is out of range, index: " << slice_index;
261       return;
262     }
263     auto emb_store = embedding_storage_manager.Get(param_key);
264     MS_EXCEPTION_IF_NULL(emb_store);
265 
266     size_t first_dim = (size_t)SliceDataShape()[0];
267     size_t start_key = slice_index * first_dim;
268     std::vector<int> keys(first_dim);
269     std::iota(keys.begin(), keys.end(), start_key);
270     if (!emb_store->Get({keys.data(), first_dim * sizeof(int)}, {this->data(), LongToSize(this->nbytes())})) {
271       MS_LOG(EXCEPTION) << "Failed to get data from embedding store!";
272     }
273   }
274 
SliceDataShape() const275   const std::vector<ssize_t> &SliceDataShape() const { return this->shape(); }
276 
277   // Get total silce num of tensor data.
slice_num() const278   int slice_num() const { return slice_num_; }
279 
is_persistent_data() const280   bool is_persistent_data() const override { return true; }
281 
282  private:
283   int slice_num_{1};
284 };
285 
// Acquires a py::buffer_info from a numpy array.
// The buffer protocol does not know the custom bfloat16 type ('T'), so the
// dtype descriptor is temporarily patched to NPY_USHORT (also 2 bytes) to let
// request() succeed, and the resulting format is rewritten back to "T".
py::buffer_info TensorPy::GetPyBufferFromPyArray(const py::array &input) {
  py::buffer_info buf;
  auto descr = py::detail::array_descriptor_proxy(py::detail::array_proxy(input.ptr())->descr);
  // For bfloat16, modify descr->type_num to support acquiring buffer_info from numpy.
  if (descr->type == 'T') {
    // convert descr->type_num from T(NPY_BFLOAT16) to H(NPY_USHORT)
    const int NPY_USHORT = 4;
    int orig_type_num = descr->type_num;
    descr->type_num = NPY_USHORT;
    // acquire buffer_info with type of NPY_USHORT
    buf = input.request();
    // convert buffer_info.format from H(NPY_USHORT) to T(NPY_BFLOAT16)
    buf.format = "T";
    // change back descr->type_num
    descr->type_num = orig_type_num;
  } else {
    buf = input.request();
  }
  return buf;
}
306 
MakeTensor(const py::array & input,const TypePtr & type_ptr)307 TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) {
308   py::gil_scoped_acquire acquire;
309   // Get input buffer info.
310   py::buffer_info buf = TensorPy::GetPyBufferFromPyArray(input);
311   // Check data types.
312   auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kTypeUnknown;
313   auto buf_type = GetDataType(buf);
314   if (buf_type == TypeId::kTypeUnknown && data_type == TypeId::kTypeUnknown) {
315     MS_LOG(EXCEPTION) << "Unsupported tensor type!";
316   }
317   MS_LOG(DEBUG) << "data_type: " << data_type << ", buf_type: " << buf_type;
318   if (data_type == TypeId::kObjectTypeString || buf_type == TypeId::kObjectTypeString) {
319     return TensorPy::MakeTensorOfNumpy(input);
320   }
321   // Use buf type as data type if type_ptr not set.
322   if (data_type == TypeId::kTypeUnknown) {
323     data_type = buf_type;
324   }
325   // Convert input array to C contiguous if need.
326   std::unique_ptr<char[]> tmp_buf;
327   if (!IsCContiguous(input)) {
328     Py_buffer pybuf;
329     if (PyObject_GetBuffer(input.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS) != 0) {
330       MS_LOG(EXCEPTION) << "Failed to get buffer from the input!";
331     }
332     tmp_buf = std::make_unique<char[]>(pybuf.len);
333     if (PyBuffer_ToContiguous(tmp_buf.get(), &pybuf, pybuf.len, 'C') != 0) {
334       MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer.";
335     }
336     PyBuffer_Release(&pybuf);
337     buf.ptr = tmp_buf.get();
338   }
339   // Get tensor shape.
340   ShapeVector shape(buf.shape.begin(), buf.shape.end());
341   if (data_type == buf_type) {
342     // Use memory copy if input data type is the same as the required type.
343     return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf.size * buf.itemsize);
344   }
345   // Create tensor with data type converted.
346   return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf_type);
347 }
348 
349 /// Creates a Tensor from a numpy array without copy
MakeTensorOfNumpy(const py::array & input)350 TensorPtr TensorPy::MakeTensorOfNumpy(const py::array &input) {
351   py::gil_scoped_acquire acquire;
352   // Check format.
353   if (!IsCContiguous(input)) {
354     MS_LOG(EXCEPTION) << "Array should be C contiguous.";
355   }
356   // Get input buffer info.
357   py::buffer_info buf = TensorPy::GetPyBufferFromPyArray(input);
358   // Get tensor dtype and check it.
359   auto dtype = GetDataType(buf);
360   if (dtype == TypeId::kTypeUnknown) {
361     MS_LOG(EXCEPTION) << "Unsupported data type!";
362   }
363   // Get tensor shape.
364   ShapeVector shape(buf.shape.begin(), buf.shape.end());
365   // Make a tensor with shared data with numpy array.
366   auto tensor_data = std::make_shared<TensorDataNumpy>(std::move(buf));
367   return std::make_shared<Tensor>(dtype, shape, tensor_data);
368 }
369 
370 /// Creates a Tensor from a numpy array without copy, use persistent tensor data
MakePersistentDataTensorOfNumpy(const py::array & input,const py::int_ slice_num)371 TensorPtr TensorPy::MakePersistentDataTensorOfNumpy(const py::array &input, const py::int_ slice_num) {
372   py::gil_scoped_acquire acquire;
373   // Check format.
374   if (!IsCContiguous(input)) {
375     MS_LOG(EXCEPTION) << "Array should be C contiguous.";
376   }
377   // Get input buffer info.
378   py::buffer_info buf = TensorPy::GetPyBufferFromPyArray(input);
379   // Get tensor dtype and check it.
380   auto dtype = GetDataType(buf);
381   if (dtype == TypeId::kTypeUnknown) {
382     MS_LOG(EXCEPTION) << "Unsupported data type!";
383   }
384   // Get tensor shape.
385   ShapeVector shape(buf.shape.begin(), buf.shape.end());
386   // Make a tensor with shared data with numpy array.
387   auto tensor_data = std::make_shared<PersistentTensorDataNumpy>(std::move(buf), static_cast<int>(slice_num));
388   return std::make_shared<Tensor>(dtype, shape, tensor_data);
389 }
390 
// Computes C-contiguous byte strides for the given shape and element size:
// strides[i] = item_size * product(shape[i+1..]).
static std::vector<ssize_t> GetStrides(const std::vector<ssize_t> &shape, ssize_t item_size) {
  std::vector<ssize_t> strides(shape.size());
  // Walk dimensions from innermost to outermost, carrying a running product.
  ssize_t running = item_size;
  for (size_t i = shape.size(); i > 0; --i) {
    strides[i - 1] = running;
    running *= shape[i - 1];
  }
  return strides;
}
404 
// Builds a py::buffer_info describing the tensor's host data (C-contiguous
// strides derived from the shape) so it can be exposed through the buffer
// protocol without copying.
static py::buffer_info GetPyBufferInfo(const Tensor &tensor) {
  std::vector<ssize_t> shape(tensor.shape().begin(), tensor.shape().end());
  std::vector<ssize_t> strides = GetStrides(shape, tensor.data().itemsize());
  return py::buffer_info{
    tensor.data_c(), tensor.data().itemsize(), GetPyTypeFormat(tensor.data_type()), tensor.DataDim(), shape, strides};
}
411 
GetPyTupleShape(const Tensor & tensor)412 py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) {
413   auto &shape = tensor.shape();
414   py::tuple dims(shape.size());
415   for (size_t i = 0; i < dims.size(); ++i) {
416     dims[i] = py::int_(shape[i]);
417   }
418   return dims;
419 }
420 
GetPyTupleStrides(const Tensor & tensor)421 py::tuple TensorPy::GetPyTupleStrides(const Tensor &tensor) {
422   std::vector<ssize_t> shape(tensor.shape().begin(), tensor.shape().end());
423   std::vector<ssize_t> strides = GetStrides(shape, tensor.data().itemsize());
424   py::tuple py_strides(strides.size());
425   for (size_t i = 0; i < strides.size(); ++i) {
426     py_strides[i] = py::int_(strides[i]);
427   }
428   return py_strides;
429 }
430 
// Returns the byte size of a single tensor element.
py::int_ TensorPy::GetPyItemSize(const Tensor &tensor) { return tensor.data().itemsize(); }

// Returns the total number of bytes occupied by the tensor data.
py::int_ TensorPy::GetPyNBytes(const Tensor &tensor) { return tensor.data().nbytes(); }
434 
// Copies embedding rows from the cache buffer back into the host tensor.
// hashmap_addr holds HashmapEntry<T> slots that map a host row (key_) to a
// cache row (value_); each non-empty slot triggers one row copy of
// col_size elements.
// NOTE(review): assumes 4-byte (float) parameter elements — see param_type_size.
template <typename T>
void MemCopyFromCacheToHost(void *hashmap_addr, void *host_addr, void *cache_addr, size_t host_max, size_t cache_max,
                            size_t hashmap_size, size_t col_size) {
  auto host_data = static_cast<char *>(host_addr);
  auto cache_data = static_cast<char *>(cache_addr);
  auto hashmap_data = static_cast<HashmapEntry<T> *>(hashmap_addr);
  // default param type float
  const size_t param_type_size = 4;
  size_t single_col_bytes = param_type_size * col_size;
  for (size_t i = 0; i < hashmap_size; ++i) {
    if (!hashmap_data[i].IsEmpty()) {
      size_t host_offset = single_col_bytes * LongToSize(hashmap_data[i].key_);
      size_t cache_offset = single_col_bytes * LongToSize(hashmap_data[i].value_);
      // Skip rows that would read past the end of the cache buffer.
      if (cache_offset + single_col_bytes <= cache_max) {
        auto ret =
          memcpy_s(host_data + host_offset, host_max - host_offset, cache_data + cache_offset, single_col_bytes);
        if (ret != 0) {
          MS_LOG(EXCEPTION) << "Memcpy failed.";
        }
      }
    }
  }
  MS_LOG(INFO) << "Memcpy from cache to host success!";
}
459 
FlushFromCache(const Tensor & tensor)460 void TensorPy::FlushFromCache(const Tensor &tensor) {
461   py::gil_scoped_release gil_release;
462   tensor.data_sync();
463 
464   if (tensor.cache_enable()) {
465     MS_LOG(INFO) << tensor.ToString() << " is cache enable.";
466     auto hashmap_tensor_ptr = tensor.hashmap_tensor_ptr();
467     auto cache_tensor_ptr = tensor.cache_tensor_ptr();
468     if (hashmap_tensor_ptr != nullptr && cache_tensor_ptr != nullptr) {
469       hashmap_tensor_ptr->data_sync();
470       cache_tensor_ptr->data_sync();
471       auto hashmap_size = hashmap_tensor_ptr->shape_c()[0];
472       auto host_shape = tensor.shape_c();
473       auto cache_shape = cache_tensor_ptr->shape_c();
474       if (host_shape.size() != 2 && cache_shape.size() != 2 && host_shape[1] != cache_shape[1]) {
475         MS_LOG(EXCEPTION) << "Got host shape and cache shape invalid."
476                           << "host shape:" << host_shape << ", cache shape:" << cache_shape;
477       }
478       auto host_data_max_size = static_cast<size_t>(tensor.Size());
479       auto cache_data_max_size = static_cast<size_t>(cache_tensor_ptr->Size());
480       auto hashmap_data_type = hashmap_tensor_ptr->data_type();
481       if (hashmap_data_type == TypeId::kNumberTypeInt32) {
482         MemCopyFromCacheToHost<int32_t>(hashmap_tensor_ptr->data_c(), tensor.data_c(), cache_tensor_ptr->data_c(),
483                                         host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]);
484       } else if (hashmap_data_type == TypeId::kNumberTypeInt64) {
485         MemCopyFromCacheToHost<int32_t>(hashmap_tensor_ptr->data_c(), tensor.data_c(), cache_tensor_ptr->data_c(),
486                                         host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]);
487       } else {
488         MS_LOG(ERROR) << "Hashmap dtype only suppotr int32, in64.";
489       }
490     }
491   }
492 }
493 
GetBytes(const Tensor & tensor)494 py::bytes TensorPy::GetBytes(const Tensor &tensor) {
495   py::gil_scoped_acquire acquire;
496   if (tensor.get_copy_done_flag()) {
497     const_cast<Tensor &>(tensor).set_copy_done_flag(false);
498     return py::bytes(static_cast<const char *>(tensor.data_c()), tensor.Size());
499   }
500   tensor.data_sync();
501   return py::bytes(static_cast<const char *>(tensor.data_c()), tensor.Size());
502 }
503 
// Copies src_size bytes from src into dst (of dst_size bytes).
// Special case: when data_type is bfloat16 and src is exactly twice the size
// of dst, src is treated as float32 data and only two bytes of each float are
// kept (float32 -> bfloat16 truncation).
// NOTE(review): the "+ sizeof(bfloat16)" source offset keeps the UPPER half of
// each little-endian float — confirm behavior on big-endian targets.
void CopyFromBuffer(char *dst, size_t dst_size, const char *src, size_t src_size, TypeId data_type) {
  bool fp16_in_fp32 = (data_type == TypeId::kNumberTypeBFloat16) && (dst_size * 2 == src_size);
  if (fp16_in_fp32) {
    int elem_num = static_cast<int>(src_size / sizeof(float));
    for (int i = 0; i < elem_num; ++i) {
      auto dst_ptr = static_cast<char *>(dst + i * sizeof(bfloat16));
      auto src_ptr = static_cast<const char *>(src + sizeof(bfloat16) + i * sizeof(float));
      errno_t ret = memcpy_s(dst_ptr, sizeof(bfloat16), src_ptr, sizeof(bfloat16));
      if (ret != EOK) {
        MS_LOG(EXCEPTION) << "Failed to copy the memory to new tensor:" << ret;
      }
    }
  } else {
    // memcpy_s caps one call at SECUREC_MEM_MAX_LEN bytes, so copy in chunks.
    size_t remain_size = src_size;
    auto dst_ptr = dst;
    auto src_ptr = src;
    while (remain_size > SECUREC_MEM_MAX_LEN) {
      auto ret = memcpy_s(dst_ptr, SECUREC_MEM_MAX_LEN, src_ptr, SECUREC_MEM_MAX_LEN);
      if (ret != EOK) {
        MS_LOG(EXCEPTION) << "Failed to copy the memory to new tensor" << ret;
      }
      remain_size -= SECUREC_MEM_MAX_LEN;
      dst_ptr += SECUREC_MEM_MAX_LEN;
      src_ptr += SECUREC_MEM_MAX_LEN;
    }
    if (remain_size != 0U) {
      auto ret = memcpy_s(dst_ptr, remain_size, src_ptr, remain_size);
      if (ret != EOK) {
        MS_LOG(EXCEPTION) << "Failed to copy the memory to new tensor" << ret;
      }
    }
  }
}
537 
ConvertBytesToTensor(const py::bytes & bytes_obj,const py::tuple & dims,const TypePtr & type_ptr)538 TensorPtr TensorPy::ConvertBytesToTensor(const py::bytes &bytes_obj, const py::tuple &dims, const TypePtr &type_ptr) {
539   ShapeVector shape;
540   for (size_t i = 0; i < dims.size(); ++i) {
541     shape.push_back(dims[i].cast<int>());
542   }
543   TypeId data_type = type_ptr ? type_ptr->type_id() : TypeId::kTypeUnknown;
544   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(data_type, shape);
545   const char *tensor_buf = PYBIND11_BYTES_AS_STRING(bytes_obj.ptr());
546   char *tensor_data_buf = reinterpret_cast<char *>(tensor->data_c());
547   CopyFromBuffer(tensor_data_buf, tensor->Size(), tensor_buf, PYBIND11_BYTES_SIZE(bytes_obj.ptr()), data_type);
548   return tensor;
549 }
550 
// Synchronizes tensor data from device to host (when needed) and returns it as
// a numpy array. The GIL is released while waiting for the device copy.
py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
  runtime::ProfilerStageRecorder recorder(runtime::ProfilerStage::kAsnumpy);
  {
    py::gil_scoped_release gil_release;

    // BFloat16 may not be supported in numpy.
    std::string numpy_version = np_dtypes::GetNumpyVersion();
    std::string minimum_numpy_version = np_dtypes::GetMinimumSupportedNumpyVersion();
    if (tensor.data_type() == kNumberTypeBFloat16 && !np_dtypes::NumpyVersionValid(numpy_version)) {
      MS_EXCEPTION(TypeError) << "For asnumpy, the numpy bfloat16 data type is supported in Numpy versions "
                              << minimum_numpy_version << " to " << minimum_numpy_version[0] << ".x.x, but got "
                              << numpy_version << ", please upgrade numpy version.";
    }

    // A finished device-to-host copy means data_sync below can be skipped.
    if (tensor.get_copy_done_flag()) {
      const_cast<Tensor &>(tensor).set_copy_done_flag(false);
      if (tensor.need_release_device_mem()) {
        const_cast<Tensor &>(tensor).set_device_address(nullptr);
      }
      return AsNumpy(tensor);
    }
    tensor.data_sync();

    // Release device address of graph output tensor.
    if (tensor.need_release_device_mem()) {
      const_cast<Tensor &>(tensor).set_device_address(nullptr);
    }
  }
  return AsNumpy(tensor);
}
581 
// Wraps the tensor's host data in a numpy array without copying.
py::array TensorPy::AsNumpy(const Tensor &tensor) {
  // Use TensorData as the owner to prevent use-after-free problem.
  // We can NOT use Tensor as the owner since its TensorData may change
  // by other operations such as AssignValue().
  py::gil_scoped_acquire acquire;
  py::object owner = py::cast(tensor.data_ptr());
  auto data_numpy = dynamic_cast<const TensorDataNumpy *>(&tensor.data());
  if (data_numpy != nullptr) {
    // Return internal numpy array if tensor data is implemented base on it.
    return data_numpy->py_array(owner);
  }
  // Otherwise, create numpy array by buffer protocol.
  auto info = GetPyBufferInfo(tensor);
  // bfloat16 needs its dedicated npy descriptor; other dtypes are derived
  // from the buffer info itself.
  py::dtype np_dtype = (tensor.data_type() == kNumberTypeBFloat16)
                         ? py::detail::npy_format_descriptor<bfloat16>::dtype()
                         : py::dtype(info);
  return py::array(np_dtype, info.shape, info.strides, info.ptr, owner);
}
600 
// Synchronizes tensor data back to host, then releases its device memory.
void TensorPy::Offload(const Tensor &tensor) {
  py::gil_scoped_release gil_release;
  tensor.data_sync();

  // Release device address of graph output tensor.
  const_cast<Tensor &>(tensor).set_device_address(nullptr);
}
608 
SetDeviceAddress(const Tensor & tensor,uintptr_t addr,const ShapeVector & shape,const TypePtr type_ptr)609 void TensorPy::SetDeviceAddress(const Tensor &tensor, uintptr_t addr, const ShapeVector &shape,
610                                 const TypePtr type_ptr) {
611   if (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice) {
612     MS_LOG(EXCEPTION) << "set_device_address now only support Ascend backend!";
613   }
614 
615   if (type_ptr == nullptr) {
616     MS_LOG(EXCEPTION) << "Dtype to be set is nullptr.";
617   }
618 
619   TypeId data_type = type_ptr->type_id();
620   if (data_type != tensor.data_type()) {
621     MS_LOG(EXCEPTION) << "Dtype to be set is not euqal with the tensor's, then tensor's dtype is" << tensor.data_type();
622   }
623 
624   if (shape != tensor.shape()) {
625     MS_LOG(EXCEPTION) << "Shape to be set is not euqal with the tensor's, then tensor's shape is" << tensor.shape();
626   }
627 
628   void *data = reinterpret_cast<void *>(addr);
629   size_t elem_num = 1;
630   for (size_t i = 0; i < shape.size(); ++i) {
631     elem_num *= shape[i];
632   }
633   auto data_size = elem_num * GetDataTypeSize(data_type);
634   auto device_sync_ = tensor.device_address();
635   if (device_sync_ == nullptr) {
636     auto device_address = std::make_shared<device::MbufDeviceAddress>(data, data_size);
637     const_cast<Tensor &>(tensor).set_device_address(device_address);
638   } else {
639     auto device_address = std::dynamic_pointer_cast<device::MbufDeviceAddress>(device_sync_);
640     device_address->SetData(data);
641   }
642 }
643 
// Copies the tensor to the target device ("CPU" or the configured device
// target). Returns a copy of self when the backend sets return_self;
// otherwise returns the newly created target tensor.
TensorPtr TensorPy::MoveTo(const Tensor &self, const std::string &to, bool blocking) {
  py::gil_scoped_release gil_release;
  MS_LOG(INFO) << "Try move tensor to " << to;
  auto context = GetDeviceCtx(to);
  MS_EXCEPTION_IF_NULL(context);
  auto target_tensor = std::make_shared<tensor::Tensor>(self.data_type(), self.shape());
  target_tensor->set_device_address(nullptr);
  bool return_self = false;
  // make sure op execute end before data copy
  runtime::OpExecutor::GetInstance().WaitAll();
  context->device_res_manager_->MoveTo(std::make_shared<tensor::Tensor>(self), target_tensor, to, blocking,
                                       &return_self);
  if (return_self) {
    return std::make_shared<tensor::Tensor>(self);
  }
  return target_tensor;
}
661 
// Returns one slice of a persistent-data tensor as a numpy array, after
// filling that slice from persistent (embedding) storage. Raises if the
// tensor's data is not PersistentTensorDataNumpy.
py::array TensorPy::AsNumpyOfSlice(const Tensor &tensor, const int32_t param_key, const int slice_index) {
  py::gil_scoped_acquire acquire;
  py::object owner = py::cast(tensor.data_ptr());
  auto data_numpy = std::dynamic_pointer_cast<PersistentTensorDataNumpy>(tensor.data_ptr());
  MS_EXCEPTION_IF_NULL(data_numpy);

  data_numpy->FillSliceData(param_key, slice_index);

  // Return internal numpy array if tensor data is implemented base on it.
  // And persistent tensor data is only implemented base on numpy array.
  return data_numpy->py_array(owner);
}
674 
GetShapeFromTuple(const py::tuple & tuple)675 static ShapeVector GetShapeFromTuple(const py::tuple &tuple) {
676   ShapeVector shape;
677   const size_t size = tuple.size();
678   shape.reserve(tuple.size());
679   for (size_t i = 0; i < size; ++i) {
680     shape.push_back(py::int_(tuple[i]));
681   }
682   return shape;
683 }
RegMetaTensor(const py::module * m)684 void RegMetaTensor(const py::module *m) {
685   // Define python MetaTensor class.
686   (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
687     .def(py::init<TypePtr, const ShapeVector>(), py::arg("dtype"), py::arg("shape"))
688     .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
689     .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
690     .def_property("param_info", &MetaTensor::param_info, &MetaTensor::set_param_info)
691     .def(py::pickle(
692       [](const MetaTensor &t) {  // __getstate__
693         /* Return a tuple that fully encodes the state of the object */
694         return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
695       },
696       [](const py::tuple &t) {  // __setstate__
697         constexpr size_t expect_size = 2;
698         if (t.size() != expect_size) {
699           throw std::runtime_error("Invalid state!");
700         }
701         /* Create a new C++ instance */
702         MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<ShapeVector>());
703         return tensor;
704       }));
705   // Define TensorData as a python class so that ownership of tensor data can be managed.
706   (void)py::class_<TensorData, TensorDataPtr>(*m, "_TensorData");
707   // Define python Tensor class.
708   // dtype should define before Tensor, because Tensor init depend dtype
709   (void)py::class_<BaseTensor, MetaTensor, std::shared_ptr<BaseTensor>>(*m, "BaseTensor");
710   (void)py::class_<Tensor, BaseTensor, std::shared_ptr<Tensor>>(*m, "Tensor")
711     .def(py::init([](const Tensor &tensor) { return std::make_shared<Tensor>(tensor); }), py::arg("input"))
712     .def(py::init([](const Tensor &tensor, const TypePtr &type_ptr) {
713            TypeId data_type = type_ptr ? type_ptr->type_id() : kTypeUnknown;
714            if (data_type == kTypeUnknown || tensor.data_type() == data_type) {
715              return std::make_shared<Tensor>(tensor);
716            }
717            return std::make_shared<Tensor>(tensor, data_type);
718          }),
719          py::arg("input"), py::arg("dtype"))
720     .def(py::init([](const BaseTensor &tensor) { return std::make_shared<Tensor>(tensor); }), py::arg("input"))
721     .def(py::init([](const BaseTensor &tensor, const TypePtr &type_ptr) {
722            TypeId data_type = type_ptr ? type_ptr->type_id() : kTypeUnknown;
723            if (data_type == kTypeUnknown || tensor.data_type() == data_type) {
724              return std::make_shared<Tensor>(tensor);
725            }
726            return std::make_shared<Tensor>(tensor, data_type);
727          }),
728          py::arg("input"), py::arg("dtype"))
729     .def(py::init([](const TypePtr &type_ptr, const py::tuple &shape) {
730            auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64;
731            return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
732          }),
733          py::arg("dtype"), py::arg("shape"))
734     .def(py::init([](const TypePtr &type_ptr, const py::list &shape) {
735            auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64;
736            return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
737          }),
738          py::arg("dtype"), py::arg("shape"))
739     .def(
740       py::init([](const py::array &input, const TypePtr &type_ptr) { return TensorPy::MakeTensor(input, type_ptr); }),
741       py::arg("input"), py::arg("dtype") = nullptr)
742     .def(py::init([](const py::float_ input, const TypePtr &type_ptr) {
743            return TensorPy::MakeTensor(py::array(input), type_ptr);
744          }),
745          py::arg("input"), py::arg("dtype") = nullptr)
746     .def(py::init([](const py::int_ input, const TypePtr &type_ptr) {
747            return TensorPy::MakeTensor(py::array(input), type_ptr);
748          }),
749          py::arg("input"), py::arg("dtype") = nullptr)
750     .def(py::init([](const py::list &input, const TypePtr &type_ptr) {
751            return TensorPy::MakeTensor(py::array(input), type_ptr);
752          }),
753          py::arg("input"), py::arg("dtype") = nullptr)
754     .def(py::init([](const py::tuple &input, const TypePtr &type_ptr) {
755            return TensorPy::MakeTensor(py::array(input), type_ptr);
756          }),
757          py::arg("input"), py::arg("dtype") = nullptr)
758     // We only suppot array/bool_/int_/float_/list/tuple/complex pybind objects as tensor input,
759     // and array/bool_/int_/float_/list/tuple init will be matched above, other pybind objects
760     // input will raise error except complex data type.
761     .def(py::init([](const py::object &input, const TypePtr &type_ptr) {
762            if (!PyComplex_CheckExact(input.ptr())) {
763              MS_LOG(EXCEPTION) << "Unsupported tensor type: " << input.get_type();
764            }
765            return TensorPy::MakeTensor(py::array(input), type_ptr);
766          }),
767          py::arg("input"), py::arg("dtype") = nullptr)
768     .def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag)
769     .def_property("adapter_flag", &Tensor::is_adapter, &Tensor::set_adapter_flag)
770     .def_property_readonly("_dtype", &Tensor::Dtype, R"mydelimiter(
771                              Get the tensor's data type.
772 
773                              Returns:
774                                  type, the data type of tensor.
775 
776                              Examples:
777                                  >>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
778                                  >>> data.dtype
779                                  Int32
780                              )mydelimiter")
781     .def_property("_shape", TensorPy::GetPyTupleShape, &Tensor::set_shape)
782     .def_property_readonly("_size", &Tensor::DataSize, R"mydelimiter(
783                              Get tensor's data size.
784 
785                              Returns:
786                                  size_t, the size of tensor.
787 
788                              Examples:
789                                  >>> data = mindspore.Tensor(np.ones((2, 3)))
790                                  >>> data.size
791                                  6
792                              )mydelimiter")
793     .def_property_readonly("_itemsize", TensorPy::GetPyItemSize, R"mydelimiter(
794                              Get the tensor's length of one element in bytes.
795 
796                              Returns:
797                                  itemsize, length of one element in bytes.
798 
799                              Examples:
800                                  >>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
801                                  >>> data.itemsize
802                                  4
803                              )mydelimiter")
804     .def_property_readonly("_nbytes", TensorPy::GetPyNBytes, R"mydelimiter(
805                              Get the tensor's total number of bytes.
806 
807                              Returns:
808                                  nbytes, total number of bytes taken by the tensor.
809 
810                              Examples:
811                                  >>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
812                                  >>> data.nbytes
813                                  4
814                              )mydelimiter")
815     .def_property_readonly("_strides", TensorPy::GetPyTupleStrides, R"mydelimiter(
816                              Get the tensor's tuple of bytes to step in each dimension
817                              when traversing an array.
818 
819                              Returns:
820                                  tuple[int], the strides of the tensor.
821 
822                              Examples:
823                                  >>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
824                                  >>> data.strides
825                                  (4, 4)
826                              )mydelimiter")
827     .def("_flatten_tensors", Tensor::FlattenTensors, py::arg("fusion_size") = 0)
828     .def("setitem_index_info", TensorIndex::SetItemIndexInfo)
829     .def("getitem_index_info", TensorIndex::GetItemIndexInfo)
830     .def("_is_flattened", Tensor::IsFlattened)
831     .def("_get_flattened_tensors", Tensor::GetFlattenedTensors)
832     .def("_get_fusion_size", Tensor::GetFusionSize)
833     .def("_is_test_stub", Tensor::CheckStub)
834     .def("from_numpy", TensorPy::MakeTensorOfNumpy, R"mydelimiter(
835                              Creates a Tensor from a numpy.ndarray without copy.
836 
837                              Arg:
838                                  array (numpy.ndarray): The input ndarray.
839 
840                              Returns:
841                                  Tensor, tensor with shared data to input ndarray.
842 
843                              Examples:
844                                  >>> a = np.ones((2, 3))
845                                  >>> t = mindspore.Tensor.from_numpy(a)
846                              )mydelimiter")
847     .def("persistent_data_from_numpy", TensorPy::MakePersistentDataTensorOfNumpy, R"mydelimiter(
848                              Creates a Tensor from a numpy.ndarray without copy.
849                              Use persistent data tensor.
850 
851                              Arg:
852                                  array (numpy.ndarray): The input ndarray.
853                                  slice_num (int): The slice num of persistent data tensor.
854 
855                              Returns:
856                                  Tensor, tensor with shared data to input ndarray.
857 
858                              Examples:
859                                  >>> a = np.ones((2, 3))
860                                  >>> t = mindspore.Tensor.persistent_data_from_numpy(a, 1)
861                              )mydelimiter")
862     .def("get_bytes", &TensorPy::GetBytes, R"mydelimiter(
863                              Get raw data of tensor with type of bytes.
864 
865                              Returns:
866                                  Bytes of tensor.
867 
868                              Examples:
869                                  >>> import mindspore as ms
870                                  >>> from mindspore import Tensor
871                                  >>> x = ms.Tensor([1, 2, 3], ms.int16)
872                                  >>> print(x.get_bytes())
873                                  b'\x01\x00\x02\x00\x03\x00'
874                              )mydelimiter")
875     .def("convert_bytes_to_tensor", &TensorPy::ConvertBytesToTensor, R"mydelimiter(
876                              Convert raw data to tensor.
877 
878                              Returns:
879                                  Tensor.
880 
881                              Examples:
882                                  >>> import mindspore as ms
883                                  >>> from mindspore import Tensor
884                                  >>> x = Tensor([1, 2, 3], ms.int16)
885                                  >>> out = Tensor.convert_bytes_to_tensor(x.get_bytes(), x.shape, x.dtype)
886                                  >>> print(x.asnumpy())
887                                  [1 2 3]
888                              )mydelimiter")
889     .def("asnumpy", TensorPy::SyncAsNumpy, R"mydelimiter(
890                              Convert tensor to numpy.ndarray.
891 
892                              Returns:
893                                  numpy.ndarray.
894 
895                              Examples:
896                                  >>> data = mindspore.Tensor(np.ones((2, 3)))
897                                  >>> array = data.asnumpy()
898                                  >>> array
899                                  array([[1., 1., 1.],
900                                         [1., 1., 1.]])
901                              )mydelimiter")
902     .def("_flush_from_cache", TensorPy::FlushFromCache, R"mydelimiter(
903                              Flush Cache data to Host if tensor is cache enable.
904 
905                              Returns:
906                                  None.
907 
908                              Examples:
909                                  >>> data = mindspore.Tensor(np.ones((2, 3)))
910                                  >>> data._flush_from_cache()
911                              )mydelimiter")
912     .def("is_persistent_data", &Tensor::is_persistent_data, R"mydelimiter(
913                              Check if tensor have persistent data.
914 
915                              Returns:
916                                  Bool.
917 
918                              Examples:
919                                  >>> data = mindspore.Tensor(np.ones((2, 3)))
920                                  >>> data.is_persistent_data()
921                              )mydelimiter")
922     .def("asnumpy_of_slice_persistent_data", TensorPy::AsNumpyOfSlice, R"mydelimiter(
923                              Convert tensor to numpy.ndarray of a slice.
924 
925                              Returns:
926                                  numpy.ndarray.
927 
928                              Examples:
929                                  >>> data = mindspore.Tensor(np.ones((2000000000, 256)))
930                                  >>> data.asnumpy_of_slice_persistent_data(0, 1)
931                              )mydelimiter")
932     .def("is_init", &Tensor::is_init, R"mydelimiter(
933                              Get tensor init_flag.
934 
935                              Returns:
936                                  bool, whether the tensor init.
937 
938                              Examples:
939                                  >>> data = mindspore.Tensor(np.ones((2, 3)))
940                                  >>> data.is_init()
941                                  False
942                              )mydelimiter")
943     .def("set_init_flag", &Tensor::set_init_flag, R"mydelimiter(
944                              Set tensor init_flag.
945 
946                              Examples:
947                                  >>> data = mindspore.Tensor(np.ones((2, 3)))
948                                  >>> data.set_init_flag(True)
949                              )mydelimiter")
950     .def("dim", &Tensor::DataDim, R"mydelimiter(
951                              Get tensor's data dimension.
952 
953                              Returns:
954                                  int, the dimension of tensor.
955 
956                              Examples:
957                                  >>> data = mindspore.Tensor(np.ones((2, 3)))
958                                  >>> data.dim()
959                                  2
960                              )mydelimiter")
961     .def("assign_value_cpp", &Tensor::AssignValue, R"mydelimiter(
962                              Assign another tensor value to this.
963 
964                              Arg:
965                                  value (:class:`mindspore.tensor`): The value tensor.
966 
967                              Examples:
968                                  >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
969                                  >>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32))
970                                  >>> data.assign_value(data2)
971                                  >>> data.shape
972                                  (2, 2)
973                              )mydelimiter")
974     .def("set_dtype", &Tensor::SetDtype, R"mydelimiter(
975                               Set the tensor's data type.
976 
977                               Arg:
978                                   dtype (:class:`mindspore.dtype`): The type of output tensor.
979 
980                               Examples:
981                                   >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
982                                   >>> data.set_dtype(mindspore.int32)
983                                   mindspore.int32
984                               )mydelimiter")
985     .def("offload", &Tensor::Offload, R"mydelimiter(
986                               Offload tensor data to file.
987 
988                               Arg:
989                                   str : file path to save tensor data.
990                               Returns:
991                                   bool, whether the tensor offload success.
992                               Examples:
993                                   >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
994                                   >>> data.offload('./test.data')
995                                   True
996                               )mydelimiter")
997     .def("offload_file_path", &Tensor::GetOffloadFilePath, R"mydelimiter(
998                               Offload file path for tensor.
999 
1000                               Returns:
1001                                  str, offload file path for tensor.
1002                               Examples:
1003                                   >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
1004                                   >>> ret = data.offload('./test.data')
1005                                   >>> ret = (data.offload_file_path() != '')
1006                                   True
1007                               )mydelimiter")
1008     .def("move_to", &TensorPy::MoveTo, R"mydelimiter(
1009                                Copy tensor between host and device asynchronously if blocking=False,
1010                                otherwise synchronously. if the arg `to`=`CPU`, means D2H copy;
1011                                if the arg `to`=`GPU` or `to`=`ASCEND`, means H2D copy.
1012 
1013                                Args:
1014                                    str: A string, "CPU" or "ASCEND" or "GPU".
1015                                    bool: A bool type value, Default: ``True`` .
1016 
1017                                Returns:
1018                                       Tensor, with the same type and shape as the "self".
1019 
1020                               Examples:
1021                                   >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
1022                                   >>> ret = data.move_to("CPU")
1023                               )mydelimiter")
1024     .def("set_cast_dtype", &Tensor::set_cast_dtype, py::arg("dtype") = nullptr)
1025     .def("data_sync", &Tensor::data_sync)
1026     .def("wait_pipeline", &Tensor::ExecuteLazyTask)
1027     .def("is_contiguous", &Tensor::is_contiguous)
1028     .def("stride", &Tensor::stride)
1029     .def("storage_offset", &Tensor::storage_offset)
1030     .def("register_hook", &RegisterHook::RegisterTensorBackwardHook)
1031     .def("remove_hook", &RegisterHook::RemoveTensorBackwardHook)
1032     .def("__str__", &Tensor::ToString)
1033     .def("__repr__", &Tensor::ToStringRepr)
1034     .def("_offload", &TensorPy::Offload)
1035     .def("set_device_address", &TensorPy::SetDeviceAddress, py::arg("addr"), py::arg("shape"), py::arg("dtype"))
1036     .def(
1037       py::pickle(
1038         [](const Tensor &t) {  // __getstate__
1039           /* Return a tuple that fully encodes the state of the object */
1040           return py::make_tuple(TensorPy::SyncAsNumpy(t));
1041         },
1042         [](const py::tuple &t) {  // __setstate__
1043           if (t.size() != 1) {
1044             throw std::runtime_error("Invalid state!");
1045           }
1046           /* Create a new C++ instance */
1047           return TensorPy::MakeTensor(t[0].cast<py::array>());
1048         }));
1049 }
1050 
1051 template <typename T>
GetSparseTensorShape(const T & sparse_tensor)1052 py::tuple GetSparseTensorShape(const T &sparse_tensor) {
1053   auto &shape = sparse_tensor.shape();
1054   py::tuple dims(shape.size());
1055   for (size_t i = 0; i < dims.size(); ++i) {
1056     dims[i] = py::int_(shape[i]);
1057   }
1058   return dims;
1059 }
1060 
GetPyTupleShape(const CSRTensor & csr_tensor)1061 py::tuple CSRTensorPy::GetPyTupleShape(const CSRTensor &csr_tensor) { return GetSparseTensorShape(csr_tensor); }
1062 
// Register the python CSRTensor class (compressed sparse row tensor) on module `m`.
void RegCSRTensor(const py::module *m) {
  // Define python CSRTensor class.
  (void)py::class_<CSRTensor, std::shared_ptr<CSRTensor>>(*m, "CSRTensor")
    // Construct from the three CSR component tensors plus the dense shape.
    .def(py::init(
           [](const BaseTensor &indptr, const BaseTensor &indices, const BaseTensor &values, const py::tuple &shape) {
             return std::make_shared<CSRTensor>(std::make_shared<Tensor>(indptr), std::make_shared<Tensor>(indices),
                                                std::make_shared<Tensor>(values), GetShapeFromTuple(shape));
           }),
         py::arg("indptr"), py::arg("indices"), py::arg("values"), py::arg("shape"))
    // Copy-construct from another CSRTensor.
    .def(py::init([](const CSRTensor &csr_tensor) { return std::make_shared<CSRTensor>(csr_tensor); }),
         py::arg("input"))
    // Read-only accessors for shape, dtype and the CSR components.
    .def_property_readonly("_shape", CSRTensorPy::GetPyTupleShape)
    .def_property_readonly("_dtype", &CSRTensor::Dtype)
    .def_property_readonly("_indptr", &CSRTensor::GetIndptr)
    .def_property_readonly("_indices", &CSRTensor::GetIndices)
    .def_property_readonly("_values", &CSRTensor::GetValues)
    .def("__str__", &CSRTensor::ToString)
    .def("__repr__", &CSRTensor::ToString);
}
1082 
GetPyTupleShape(const COOTensor & coo_tensor)1083 py::tuple COOTensorPy::GetPyTupleShape(const COOTensor &coo_tensor) { return GetSparseTensorShape(coo_tensor); }
1084 
// Register the python COOTensor class (coordinate-format sparse tensor) on module `m`.
void RegCOOTensor(const py::module *m) {
  // Define python COOTensor class.
  (void)py::class_<COOTensor, std::shared_ptr<COOTensor>>(*m, "COOTensor")
    // Construct from indices/values component tensors plus the dense shape.
    .def(py::init([](const BaseTensor &indices, const BaseTensor &values, const py::tuple &shape) {
           return std::make_shared<COOTensor>(std::make_shared<Tensor>(indices), std::make_shared<Tensor>(values),
                                              GetShapeFromTuple(shape));
         }),
         py::arg("indices"), py::arg("values"), py::arg("shape"))
    // Copy-construct from another COOTensor.
    .def(py::init([](const COOTensor &coo_tensor) { return std::make_shared<COOTensor>(coo_tensor); }),
         py::arg("input"))
    // Read-only accessors for shape, dtype and the COO components.
    .def_property_readonly("_shape", COOTensorPy::GetPyTupleShape)
    .def_property_readonly("_dtype", &COOTensor::Dtype)
    .def_property_readonly("_indices", &COOTensor::GetIndices)
    .def_property_readonly("_values", &COOTensor::GetValues)
    .def("__str__", &COOTensor::ToString)
    .def("__repr__", &COOTensor::ToString);
}
1102 
GetPyTupleShape(const RowTensor & row_tensor)1103 py::tuple RowTensorPy::GetPyTupleShape(const RowTensor &row_tensor) { return GetSparseTensorShape(row_tensor); }
1104 
// Register the python RowTensor class (row-sparse tensor) on module `m`.
void RegRowTensor(const py::module *m) {
  // Define python RowTensor class.
  (void)py::class_<RowTensor, std::shared_ptr<RowTensor>>(*m, "RowTensor")
    // Construct from indices/values component tensors plus the dense shape.
    .def(py::init([](const BaseTensor &indices, const BaseTensor &values, const py::tuple &shape) {
           return std::make_shared<RowTensor>(std::make_shared<Tensor>(indices), std::make_shared<Tensor>(values),
                                              GetShapeFromTuple(shape));
         }),
         py::arg("indices"), py::arg("values"), py::arg("shape"))
    // Copy-construct from another RowTensor.
    .def(py::init([](const RowTensor &row_tensor) { return std::make_shared<RowTensor>(row_tensor); }),
         py::arg("input"))
    // Read-only accessors for shape, dtype and the row-tensor components.
    .def_property_readonly("_shape", RowTensorPy::GetPyTupleShape)
    .def_property_readonly("_dtype", &RowTensor::Dtype)
    .def_property_readonly("_indices", &RowTensor::GetIndices)
    .def_property_readonly("_values", &RowTensor::GetValues)
    .def("__str__", &RowTensor::ToString)
    .def("__repr__", &RowTensor::ToString);
}
1122 }  // namespace tensor
1123 }  // namespace mindspore
1124