1 /**
2 * Copyright 2020-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pybind_api/ir/tensor_py.h"
18
19 #include <utility>
20
21 #include "include/common/pybind_api/api_register.h"
22 #include "abstract/abstract_value.h"
23 #include "utils/cache_embedding_hashmap_struct.h"
24 #include "include/common/utils/python_adapter.h"
25 #include "mindspore/ccsrc/include/backend/distributed/embedding_cache/embedding_cache_utils.h"
26 #include "pybind_api/ir/tensor_index_py.h"
27 #include "pybind_api/ir/hook_py.h"
28 #include "include/common/profiler.h"
29 #include "runtime/hardware/device_context_manager.h"
30 #include "runtime/pynative/op_executor.h"
31 #include "include/backend/mbuf_device_address.h"
32
33 namespace mindspore {
34 namespace tensor {
35 namespace {
// Static registrar: installs TensorPy::AsNumpy as the tensor-to-numpy handler
// in python_adapter at shared-library load time (via the global instance below).
struct TensorToNumpyRegister {
  TensorToNumpyRegister() { python_adapter::PyAdapterCallback::SetTensorToNumpyHandler(tensor::TensorPy::AsNumpy); }
} callback_register;
39
GetDeviceCtx(const std::string & to)40 device::DeviceContext *GetDeviceCtx(const std::string &to) {
41 const auto &device = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
42 if (to != "CPU" && to != device) {
43 MS_LOG(EXCEPTION) << "The value of 'to' should be same with device, bug got to:" << to << ", device: " << device;
44 }
45 auto device_ctx = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
46 {device, MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
47 MS_EXCEPTION_IF_NULL(device_ctx);
48
49 device_ctx->Initialize();
50 return device_ctx;
51 }
52 } // namespace
53 constexpr ssize_t kPyBufItemSize1 = 1;
54 constexpr ssize_t kPyBufItemSize2 = 2;
55 constexpr ssize_t kPyBufItemSize4 = 4;
56 constexpr ssize_t kPyBufItemSize8 = 8;
57
GetDataType(const py::buffer_info & buf)58 static TypeId GetDataType(const py::buffer_info &buf) {
59 if (buf.format.size() == 1) {
60 switch (buf.format.front()) {
61 case 'e':
62 case 'f':
63 case 'd':
64 switch (buf.itemsize) {
65 case kPyBufItemSize2:
66 return TypeId::kNumberTypeFloat16;
67 case kPyBufItemSize4:
68 return TypeId::kNumberTypeFloat32;
69 case kPyBufItemSize8:
70 return TypeId::kNumberTypeFloat64;
71 }
72 break;
73 case 'b':
74 case 'h':
75 case 'i':
76 case 'l':
77 case 'q':
78 switch (buf.itemsize) {
79 case kPyBufItemSize1:
80 return TypeId::kNumberTypeInt8;
81 case kPyBufItemSize2:
82 return TypeId::kNumberTypeInt16;
83 case kPyBufItemSize4:
84 return TypeId::kNumberTypeInt32;
85 case kPyBufItemSize8:
86 return TypeId::kNumberTypeInt64;
87 default:
88 break;
89 }
90 break;
91 case 'B':
92 case 'H':
93 case 'I':
94 case 'L':
95 case 'Q':
96 switch (buf.itemsize) {
97 case kPyBufItemSize1:
98 return TypeId::kNumberTypeUInt8;
99 case kPyBufItemSize2:
100 return TypeId::kNumberTypeUInt16;
101 case kPyBufItemSize4:
102 return TypeId::kNumberTypeUInt32;
103 case kPyBufItemSize8:
104 return TypeId::kNumberTypeUInt64;
105 default:
106 break;
107 }
108 break;
109 case '?':
110 return TypeId::kNumberTypeBool;
111 case 'T':
112 return TypeId::kNumberTypeBFloat16;
113 default:
114 break;
115 }
116 } else if (buf.format.size() >= 2) {
117 // Support np.str_ dtype, format: {x}w. {x} is a number that means the maximum length of the string items.
118 if (buf.format.back() == 'w' || buf.format.back() == 's') {
119 return TypeId::kObjectTypeString;
120 } else if (buf.format == "Zf") {
121 return TypeId::kNumberTypeComplex64;
122 } else if (buf.format == "Zd") {
123 return TypeId::kNumberTypeComplex128;
124 }
125 }
126 MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << ", item size " << buf.itemsize;
127 return TypeId::kTypeUnknown;
128 }
129
// Inverse of GetDataType: maps a MindSpore TypeId to a numpy/pybind11
// buffer-format string. Special cases: bfloat16 uses the custom "T" format,
// Int4 shares the int8 buffer format, and string tensors are exposed as raw
// uint8 bytes. Returns "" (with a warning) for unsupported types.
static std::string GetPyTypeFormat(TypeId data_type) {
  switch (data_type) {
    case TypeId::kNumberTypeFloat16:
      return "e";
    case TypeId::kNumberTypeBFloat16:
      // Custom format char; numpy has no native bfloat16.
      return "T";
    case TypeId::kNumberTypeFloat32:
      return py::format_descriptor<float>::format();
    case TypeId::kNumberTypeFloat64:
      return py::format_descriptor<double>::format();
    case TypeId::kNumberTypeUInt8:
      return py::format_descriptor<uint8_t>::format();
    case TypeId::kNumberTypeUInt16:
      return py::format_descriptor<uint16_t>::format();
    case TypeId::kNumberTypeUInt32:
      return py::format_descriptor<uint32_t>::format();
    case TypeId::kNumberTypeUInt64:
      return py::format_descriptor<uint64_t>::format();
    case TypeId::kNumberTypeInt4:
    case TypeId::kNumberTypeInt8:
      return py::format_descriptor<int8_t>::format();
    case TypeId::kNumberTypeInt16:
      return py::format_descriptor<int16_t>::format();
    case TypeId::kNumberTypeInt:
    case TypeId::kNumberTypeInt32:
      return py::format_descriptor<int32_t>::format();
    case TypeId::kNumberTypeInt64:
      return py::format_descriptor<int64_t>::format();
    case TypeId::kNumberTypeBool:
      return py::format_descriptor<bool>::format();
    case TypeId::kObjectTypeString:
      return py::format_descriptor<uint8_t>::format();
    case TypeId::kNumberTypeComplex64:
      return py::format_descriptor<std::complex<float>>::format();
    case TypeId::kNumberTypeComplex128:
      return py::format_descriptor<std::complex<double>>::format();
    case TypeId::kMetaTypeType:
    case TypeId::kMetaTypeEllipsis:
    default:
      MS_LOG(WARNING) << "Unsupported DataType " << data_type << ".";
      return "";
  }
}
173
IsCContiguous(const py::array & input)174 static bool IsCContiguous(const py::array &input) {
175 auto flags = static_cast<unsigned int>(input.flags());
176 return (flags & static_cast<unsigned int>(pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_)) != 0;
177 }
178
179 // TensorDataNumpy implements TensorData using numpy array.
class TensorDataNumpy : public TensorData {
 public:
  // Takes ownership of the buffer view; the tensor shares memory with the
  // originating numpy array instead of copying it.
  explicit TensorDataNumpy(py::buffer_info &&buffer) : buffer_(std::make_unique<py::buffer_info>(std::move(buffer))) {}

  // Destroying a py::buffer_info manipulates Python reference counts, so the
  // GIL must be held while the buffer is released.
  ~TensorDataNumpy() override {
    py::gil_scoped_acquire acquire;
    buffer_.reset();
  }

  /// Total number of elements.
  ssize_t size() const override { return buffer()->size; }

  /// Byte size of a single element.
  ssize_t itemsize() const override { return buffer()->itemsize; }

  /// Total number of bytes.
  ssize_t nbytes() const override { return buffer()->itemsize * buffer()->size; }

  /// Number of dimensions.
  ssize_t ndim() const override { return buffer()->ndim; }

  /// Data pointer.
  void *data() override { return buffer_data(); }

  const void *const_data() const override { return buffer()->ptr; }

  bool is_sub_data() const override { return false; }

  bool has_sub_data() const override { return false; }

  // Marks this data as borrowed from numpy (affects ownership handling upstream).
  bool is_from_numpy() const override { return true; }

  const std::vector<ssize_t> &shape() const { return buffer()->shape; }

  /// To string.
  std::string ToString(const TypeId, const ShapeVector &, bool use_comma) const override {
    py::gil_scoped_acquire gil_acquire;
    if (use_comma) {
      // Call python np.array2string(data_, separator=', ') to convert string with comma.
      py::dict kwargs;
      kwargs["separator"] = ", ";
      auto np = py::module::import("numpy");
      auto array2string = np.attr("array2string");
      return py::str(array2string(py_array(), **kwargs));
    }
    // without comma.
    return py::str(py_array());
  }

  /// py::array object. by default, use py::str() as the dummy owner to prevent data copy.
  py::array py_array(const py::handle &owner = py::str()) const {
    py::gil_scoped_acquire acquire;
    // "T" is the custom bfloat16 format; numpy needs the registered
    // bfloat16 dtype rather than one derived from the buffer.
    py::dtype np_dtype =
      (buffer()->format == "T") ? py::detail::npy_format_descriptor<bfloat16>::dtype() : py::dtype(*buffer());
    return py::array(np_dtype, buffer()->shape, buffer()->strides, buffer()->ptr, owner);
  }

 private:
  void *buffer_data() const { return buffer_->ptr; }
  // Checked accessor: raises if the buffer has already been released.
  std::unique_ptr<py::buffer_info> const &buffer() const {
    MS_EXCEPTION_IF_NULL(buffer_);
    return buffer_;
  }

  // The internal buffer.
  std::unique_ptr<py::buffer_info> buffer_;
};
247
// This class is used to get huge tensor data from persistent storage. Tensor data can be fetched slice by slice.
// It is used when extending embedding tables to persistent storage.
class PersistentTensorDataNumpy : public TensorDataNumpy {
 public:
  // `slice_num` is the total number of equal slices the full tensor is split
  // into; the numpy buffer holds one slice's worth of data.
  explicit PersistentTensorDataNumpy(py::buffer_info &&buffer, int slice_num)
      : TensorDataNumpy(std::move(buffer)), slice_num_(slice_num) {}

  ~PersistentTensorDataNumpy() override = default;

  // Fill data with a special slice tensor data. It will read data from persistent storage.
  void FillSliceData(const int32_t param_key, const int slice_index) {
    if (slice_index >= slice_num_) {
      MS_LOG(ERROR) << "Slice index is out of range, index: " << slice_index;
      return;
    }
    auto emb_store = embedding_storage_manager.Get(param_key);
    MS_EXCEPTION_IF_NULL(emb_store);

    // Keys covered by this slice form the contiguous range
    // [slice_index * first_dim, (slice_index + 1) * first_dim).
    size_t first_dim = (size_t)SliceDataShape()[0];
    size_t start_key = slice_index * first_dim;
    std::vector<int> keys(first_dim);
    std::iota(keys.begin(), keys.end(), start_key);
    // Reads the rows for `keys` directly into this buffer.
    if (!emb_store->Get({keys.data(), first_dim * sizeof(int)}, {this->data(), LongToSize(this->nbytes())})) {
      MS_LOG(EXCEPTION) << "Failed to get data from embedding store!";
    }
  }

  // Shape of a single slice (the shape of the underlying numpy buffer).
  const std::vector<ssize_t> &SliceDataShape() const { return this->shape(); }

  // Get total slice num of tensor data.
  int slice_num() const { return slice_num_; }

  bool is_persistent_data() const override { return true; }

 private:
  int slice_num_{1};
};
285
// Requests a buffer_info from a numpy array, with a workaround for bfloat16:
// numpy's buffer protocol cannot export the custom 'T' dtype directly, so the
// descriptor's type_num is temporarily switched to NPY_USHORT (same 2-byte
// item size), the buffer is requested, and the format string is rewritten
// back to "T". The descriptor is restored before returning.
py::buffer_info TensorPy::GetPyBufferFromPyArray(const py::array &input) {
  py::buffer_info buf;
  auto descr = py::detail::array_descriptor_proxy(py::detail::array_proxy(input.ptr())->descr);
  // For bfloat16, modify descr->type_num to support acquiring buffer_info from numpy.
  if (descr->type == 'T') {
    // convert descr->type_num from T(NPY_BFLOAT16) to H(NPY_USHORT)
    const int NPY_USHORT = 4;
    int orig_type_num = descr->type_num;
    descr->type_num = NPY_USHORT;
    // acquire buffer_info with type of NPY_USHORT
    buf = input.request();
    // convert buffer_info.format from H(NPY_USHORT) to T(NPY_BFLOAT16)
    buf.format = "T";
    // change back descr->type_num
    descr->type_num = orig_type_num;
  } else {
    buf = input.request();
  }
  return buf;
}
306
MakeTensor(const py::array & input,const TypePtr & type_ptr)307 TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) {
308 py::gil_scoped_acquire acquire;
309 // Get input buffer info.
310 py::buffer_info buf = TensorPy::GetPyBufferFromPyArray(input);
311 // Check data types.
312 auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kTypeUnknown;
313 auto buf_type = GetDataType(buf);
314 if (buf_type == TypeId::kTypeUnknown && data_type == TypeId::kTypeUnknown) {
315 MS_LOG(EXCEPTION) << "Unsupported tensor type!";
316 }
317 MS_LOG(DEBUG) << "data_type: " << data_type << ", buf_type: " << buf_type;
318 if (data_type == TypeId::kObjectTypeString || buf_type == TypeId::kObjectTypeString) {
319 return TensorPy::MakeTensorOfNumpy(input);
320 }
321 // Use buf type as data type if type_ptr not set.
322 if (data_type == TypeId::kTypeUnknown) {
323 data_type = buf_type;
324 }
325 // Convert input array to C contiguous if need.
326 std::unique_ptr<char[]> tmp_buf;
327 if (!IsCContiguous(input)) {
328 Py_buffer pybuf;
329 if (PyObject_GetBuffer(input.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS) != 0) {
330 MS_LOG(EXCEPTION) << "Failed to get buffer from the input!";
331 }
332 tmp_buf = std::make_unique<char[]>(pybuf.len);
333 if (PyBuffer_ToContiguous(tmp_buf.get(), &pybuf, pybuf.len, 'C') != 0) {
334 MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer.";
335 }
336 PyBuffer_Release(&pybuf);
337 buf.ptr = tmp_buf.get();
338 }
339 // Get tensor shape.
340 ShapeVector shape(buf.shape.begin(), buf.shape.end());
341 if (data_type == buf_type) {
342 // Use memory copy if input data type is the same as the required type.
343 return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf.size * buf.itemsize);
344 }
345 // Create tensor with data type converted.
346 return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf_type);
347 }
348
349 /// Creates a Tensor from a numpy array without copy
/// Creates a Tensor from a numpy array without copy
TensorPtr TensorPy::MakeTensorOfNumpy(const py::array &input) {
  py::gil_scoped_acquire acquire;
  // Check format. Shared-memory tensors require C-contiguous layout.
  if (!IsCContiguous(input)) {
    MS_LOG(EXCEPTION) << "Array should be C contiguous.";
  }
  // Get input buffer info.
  py::buffer_info buf = TensorPy::GetPyBufferFromPyArray(input);
  // Get tensor dtype and check it.
  auto dtype = GetDataType(buf);
  if (dtype == TypeId::kTypeUnknown) {
    MS_LOG(EXCEPTION) << "Unsupported data type!";
  }
  // Get tensor shape.
  ShapeVector shape(buf.shape.begin(), buf.shape.end());
  // Make a tensor with shared data with numpy array.
  auto tensor_data = std::make_shared<TensorDataNumpy>(std::move(buf));
  return std::make_shared<Tensor>(dtype, shape, tensor_data);
}
369
370 /// Creates a Tensor from a numpy array without copy, use persistent tensor data
/// Creates a Tensor from a numpy array without copy, use persistent tensor data
TensorPtr TensorPy::MakePersistentDataTensorOfNumpy(const py::array &input, const py::int_ slice_num) {
  py::gil_scoped_acquire acquire;
  // Check format. Shared-memory tensors require C-contiguous layout.
  if (!IsCContiguous(input)) {
    MS_LOG(EXCEPTION) << "Array should be C contiguous.";
  }
  // Get input buffer info.
  py::buffer_info buf = TensorPy::GetPyBufferFromPyArray(input);
  // Get tensor dtype and check it.
  auto dtype = GetDataType(buf);
  if (dtype == TypeId::kTypeUnknown) {
    MS_LOG(EXCEPTION) << "Unsupported data type!";
  }
  // Get tensor shape. Note: this is the shape of one slice, not the full tensor.
  ShapeVector shape(buf.shape.begin(), buf.shape.end());
  // Make a tensor with shared data with numpy array.
  auto tensor_data = std::make_shared<PersistentTensorDataNumpy>(std::move(buf), static_cast<int>(slice_num));
  return std::make_shared<Tensor>(dtype, shape, tensor_data);
}
390
// Computes C-contiguous (row-major) byte strides for `shape` with elements of
// `item_size` bytes. Single reverse pass, O(ndim): the innermost stride is
// item_size and each outer stride is the inner stride times the inner extent
// (replaces the previous O(ndim^2) nested-loop version; same results).
static std::vector<ssize_t> GetStrides(const std::vector<ssize_t> &shape, ssize_t item_size) {
  std::vector<ssize_t> strides(shape.size());
  ssize_t stride = item_size;
  for (size_t i = shape.size(); i > 0; --i) {
    strides[i - 1] = stride;
    stride *= shape[i - 1];
  }
  return strides;
}
404
// Builds a pybind11 buffer_info describing the tensor's host data, with
// C-contiguous strides derived from the shape.
static py::buffer_info GetPyBufferInfo(const Tensor &tensor) {
  std::vector<ssize_t> shape(tensor.shape().begin(), tensor.shape().end());
  std::vector<ssize_t> strides = GetStrides(shape, tensor.data().itemsize());
  return py::buffer_info{
    tensor.data_c(), tensor.data().itemsize(), GetPyTypeFormat(tensor.data_type()), tensor.DataDim(), shape, strides};
}
411
GetPyTupleShape(const Tensor & tensor)412 py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) {
413 auto &shape = tensor.shape();
414 py::tuple dims(shape.size());
415 for (size_t i = 0; i < dims.size(); ++i) {
416 dims[i] = py::int_(shape[i]);
417 }
418 return dims;
419 }
420
// Returns the tensor's C-contiguous byte strides as a Python tuple of ints
// (matching numpy's ndarray.strides convention).
py::tuple TensorPy::GetPyTupleStrides(const Tensor &tensor) {
  std::vector<ssize_t> shape(tensor.shape().begin(), tensor.shape().end());
  std::vector<ssize_t> strides = GetStrides(shape, tensor.data().itemsize());
  py::tuple py_strides(strides.size());
  for (size_t i = 0; i < strides.size(); ++i) {
    py_strides[i] = py::int_(strides[i]);
  }
  return py_strides;
}
430
GetPyItemSize(const Tensor & tensor)431 py::int_ TensorPy::GetPyItemSize(const Tensor &tensor) { return tensor.data().itemsize(); }
432
GetPyNBytes(const Tensor & tensor)433 py::int_ TensorPy::GetPyNBytes(const Tensor &tensor) { return tensor.data().nbytes(); }
434
435 template <typename T>
MemCopyFromCacheToHost(void * hashmap_addr,void * host_addr,void * cache_addr,size_t host_max,size_t cache_max,size_t hashmap_size,size_t col_size)436 void MemCopyFromCacheToHost(void *hashmap_addr, void *host_addr, void *cache_addr, size_t host_max, size_t cache_max,
437 size_t hashmap_size, size_t col_size) {
438 auto host_data = static_cast<char *>(host_addr);
439 auto cache_data = static_cast<char *>(cache_addr);
440 auto hashmap_data = static_cast<HashmapEntry<T> *>(hashmap_addr);
441 // default param type float
442 const size_t param_type_size = 4;
443 size_t single_col_bytes = param_type_size * col_size;
444 for (size_t i = 0; i < hashmap_size; ++i) {
445 if (!hashmap_data[i].IsEmpty()) {
446 size_t host_offset = single_col_bytes * LongToSize(hashmap_data[i].key_);
447 size_t cache_offset = single_col_bytes * LongToSize(hashmap_data[i].value_);
448 if (cache_offset + single_col_bytes <= cache_max) {
449 auto ret =
450 memcpy_s(host_data + host_offset, host_max - host_offset, cache_data + cache_offset, single_col_bytes);
451 if (ret != 0) {
452 MS_LOG(EXCEPTION) << "Memcpy failed.";
453 }
454 }
455 }
456 }
457 MS_LOG(INFO) << "Memcpy from cache to host success!";
458 }
459
FlushFromCache(const Tensor & tensor)460 void TensorPy::FlushFromCache(const Tensor &tensor) {
461 py::gil_scoped_release gil_release;
462 tensor.data_sync();
463
464 if (tensor.cache_enable()) {
465 MS_LOG(INFO) << tensor.ToString() << " is cache enable.";
466 auto hashmap_tensor_ptr = tensor.hashmap_tensor_ptr();
467 auto cache_tensor_ptr = tensor.cache_tensor_ptr();
468 if (hashmap_tensor_ptr != nullptr && cache_tensor_ptr != nullptr) {
469 hashmap_tensor_ptr->data_sync();
470 cache_tensor_ptr->data_sync();
471 auto hashmap_size = hashmap_tensor_ptr->shape_c()[0];
472 auto host_shape = tensor.shape_c();
473 auto cache_shape = cache_tensor_ptr->shape_c();
474 if (host_shape.size() != 2 && cache_shape.size() != 2 && host_shape[1] != cache_shape[1]) {
475 MS_LOG(EXCEPTION) << "Got host shape and cache shape invalid."
476 << "host shape:" << host_shape << ", cache shape:" << cache_shape;
477 }
478 auto host_data_max_size = static_cast<size_t>(tensor.Size());
479 auto cache_data_max_size = static_cast<size_t>(cache_tensor_ptr->Size());
480 auto hashmap_data_type = hashmap_tensor_ptr->data_type();
481 if (hashmap_data_type == TypeId::kNumberTypeInt32) {
482 MemCopyFromCacheToHost<int32_t>(hashmap_tensor_ptr->data_c(), tensor.data_c(), cache_tensor_ptr->data_c(),
483 host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]);
484 } else if (hashmap_data_type == TypeId::kNumberTypeInt64) {
485 MemCopyFromCacheToHost<int32_t>(hashmap_tensor_ptr->data_c(), tensor.data_c(), cache_tensor_ptr->data_c(),
486 host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]);
487 } else {
488 MS_LOG(ERROR) << "Hashmap dtype only suppotr int32, in64.";
489 }
490 }
491 }
492 }
493
GetBytes(const Tensor & tensor)494 py::bytes TensorPy::GetBytes(const Tensor &tensor) {
495 py::gil_scoped_acquire acquire;
496 if (tensor.get_copy_done_flag()) {
497 const_cast<Tensor &>(tensor).set_copy_done_flag(false);
498 return py::bytes(static_cast<const char *>(tensor.data_c()), tensor.Size());
499 }
500 tensor.data_sync();
501 return py::bytes(static_cast<const char *>(tensor.data_c()), tensor.Size());
502 }
503
// Copies src into dst with two modes:
//  1. bfloat16 packed in float32: when data_type is BFloat16 and src holds
//     exactly twice dst's bytes, each 4-byte float contributes its upper two
//     bytes — which, on a little-endian host, is the bfloat16 truncation of
//     the float32 value. NOTE(review): assumes little-endian layout; confirm.
//  2. plain copy: chunked because memcpy_s rejects single copies larger than
//     SECUREC_MEM_MAX_LEN.
void CopyFromBuffer(char *dst, size_t dst_size, const char *src, size_t src_size, TypeId data_type) {
  bool fp16_in_fp32 = (data_type == TypeId::kNumberTypeBFloat16) && (dst_size * 2 == src_size);
  if (fp16_in_fp32) {
    int elem_num = static_cast<int>(src_size / sizeof(float));
    for (int i = 0; i < elem_num; ++i) {
      auto dst_ptr = static_cast<char *>(dst + i * sizeof(bfloat16));
      // Skip the low half of each float; copy only the high (bfloat16) half.
      auto src_ptr = static_cast<const char *>(src + sizeof(bfloat16) + i * sizeof(float));
      errno_t ret = memcpy_s(dst_ptr, sizeof(bfloat16), src_ptr, sizeof(bfloat16));
      if (ret != EOK) {
        MS_LOG(EXCEPTION) << "Failed to copy the memory to new tensor:" << ret;
      }
    }
  } else {
    // Copy in SECUREC_MEM_MAX_LEN-sized chunks, then the remainder.
    size_t remain_size = src_size;
    auto dst_ptr = dst;
    auto src_ptr = src;
    while (remain_size > SECUREC_MEM_MAX_LEN) {
      auto ret = memcpy_s(dst_ptr, SECUREC_MEM_MAX_LEN, src_ptr, SECUREC_MEM_MAX_LEN);
      if (ret != EOK) {
        MS_LOG(EXCEPTION) << "Failed to copy the memory to new tensor" << ret;
      }
      remain_size -= SECUREC_MEM_MAX_LEN;
      dst_ptr += SECUREC_MEM_MAX_LEN;
      src_ptr += SECUREC_MEM_MAX_LEN;
    }
    if (remain_size != 0U) {
      auto ret = memcpy_s(dst_ptr, remain_size, src_ptr, remain_size);
      if (ret != EOK) {
        MS_LOG(EXCEPTION) << "Failed to copy the memory to new tensor" << ret;
      }
    }
  }
}
537
// Builds a Tensor of the given shape/dtype from a raw Python bytes object.
// The bytes are copied into freshly allocated tensor storage via
// CopyFromBuffer (which also handles bfloat16-packed-in-float32 payloads).
TensorPtr TensorPy::ConvertBytesToTensor(const py::bytes &bytes_obj, const py::tuple &dims, const TypePtr &type_ptr) {
  ShapeVector shape;
  for (size_t i = 0; i < dims.size(); ++i) {
    shape.push_back(dims[i].cast<int>());
  }
  TypeId data_type = type_ptr ? type_ptr->type_id() : TypeId::kTypeUnknown;
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(data_type, shape);
  // Borrow the bytes object's internal buffer; no intermediate copy.
  const char *tensor_buf = PYBIND11_BYTES_AS_STRING(bytes_obj.ptr());
  char *tensor_data_buf = reinterpret_cast<char *>(tensor->data_c());
  CopyFromBuffer(tensor_data_buf, tensor->Size(), tensor_buf, PYBIND11_BYTES_SIZE(bytes_obj.ptr()), data_type);
  return tensor;
}
550
// Synchronizes device data to host, then converts the tensor to a numpy
// array via AsNumpy. Raises TypeError for bfloat16 tensors when the
// installed numpy version lacks bfloat16 support.
py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
  runtime::ProfilerStageRecorder recorder(runtime::ProfilerStage::kAsnumpy);
  {
    // Release the GIL while waiting on device-to-host synchronization.
    py::gil_scoped_release gil_release;

    // BFloat16 may not be supported in numpy.
    std::string numpy_version = np_dtypes::GetNumpyVersion();
    std::string minimum_numpy_version = np_dtypes::GetMinimumSupportedNumpyVersion();
    if (tensor.data_type() == kNumberTypeBFloat16 && !np_dtypes::NumpyVersionValid(numpy_version)) {
      // NOTE(review): minimum_numpy_version[0] prints only the first character
      // (the major version digit) of the minimum version string.
      MS_EXCEPTION(TypeError) << "For asnumpy, the numpy bfloat16 data type is supported in Numpy versions "
                              << minimum_numpy_version << " to " << minimum_numpy_version[0] << ".x.x, but got "
                              << numpy_version << ", please upgrade numpy version.";
    }

    // A completed copy elsewhere lets us skip data_sync once; the flag is
    // consumed so the next call synchronizes again.
    if (tensor.get_copy_done_flag()) {
      const_cast<Tensor &>(tensor).set_copy_done_flag(false);
      if (tensor.need_release_device_mem()) {
        const_cast<Tensor &>(tensor).set_device_address(nullptr);
      }
      return AsNumpy(tensor);
    }
    tensor.data_sync();

    // Release device address of graph output tensor.
    if (tensor.need_release_device_mem()) {
      const_cast<Tensor &>(tensor).set_device_address(nullptr);
    }
  }
  return AsNumpy(tensor);
}
581
// Converts the tensor's host data to a numpy array without copying. Does NOT
// synchronize from device — use SyncAsNumpy for that.
py::array TensorPy::AsNumpy(const Tensor &tensor) {
  // Use TensorData as the owner to prevent use-after-free problem.
  // We can NOT use Tensor as the owner since its TensorData may change
  // by other operations such as AssignValue().
  py::gil_scoped_acquire acquire;
  py::object owner = py::cast(tensor.data_ptr());
  auto data_numpy = dynamic_cast<const TensorDataNumpy *>(&tensor.data());
  if (data_numpy != nullptr) {
    // Return internal numpy array if tensor data is implemented base on it.
    return data_numpy->py_array(owner);
  }
  // Otherwise, create numpy array by buffer protocol.
  auto info = GetPyBufferInfo(tensor);
  // bfloat16 needs the registered custom dtype; the buffer format alone
  // cannot express it.
  py::dtype np_dtype = (tensor.data_type() == kNumberTypeBFloat16)
                         ? py::detail::npy_format_descriptor<bfloat16>::dtype()
                         : py::dtype(info);
  return py::array(np_dtype, info.shape, info.strides, info.ptr, owner);
}
600
// Syncs the tensor's data to host, then drops its device address so the
// device memory can be reclaimed.
void TensorPy::Offload(const Tensor &tensor) {
  py::gil_scoped_release gil_release;
  tensor.data_sync();

  // Release device address of graph output tensor.
  const_cast<Tensor &>(tensor).set_device_address(nullptr);
}
608
SetDeviceAddress(const Tensor & tensor,uintptr_t addr,const ShapeVector & shape,const TypePtr type_ptr)609 void TensorPy::SetDeviceAddress(const Tensor &tensor, uintptr_t addr, const ShapeVector &shape,
610 const TypePtr type_ptr) {
611 if (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice) {
612 MS_LOG(EXCEPTION) << "set_device_address now only support Ascend backend!";
613 }
614
615 if (type_ptr == nullptr) {
616 MS_LOG(EXCEPTION) << "Dtype to be set is nullptr.";
617 }
618
619 TypeId data_type = type_ptr->type_id();
620 if (data_type != tensor.data_type()) {
621 MS_LOG(EXCEPTION) << "Dtype to be set is not euqal with the tensor's, then tensor's dtype is" << tensor.data_type();
622 }
623
624 if (shape != tensor.shape()) {
625 MS_LOG(EXCEPTION) << "Shape to be set is not euqal with the tensor's, then tensor's shape is" << tensor.shape();
626 }
627
628 void *data = reinterpret_cast<void *>(addr);
629 size_t elem_num = 1;
630 for (size_t i = 0; i < shape.size(); ++i) {
631 elem_num *= shape[i];
632 }
633 auto data_size = elem_num * GetDataTypeSize(data_type);
634 auto device_sync_ = tensor.device_address();
635 if (device_sync_ == nullptr) {
636 auto device_address = std::make_shared<device::MbufDeviceAddress>(data, data_size);
637 const_cast<Tensor &>(tensor).set_device_address(device_address);
638 } else {
639 auto device_address = std::dynamic_pointer_cast<device::MbufDeviceAddress>(device_sync_);
640 device_address->SetData(data);
641 }
642 }
643
// Moves/copies the tensor's data to the backend named by `to` ("CPU" or the
// configured device target). Returns a new tensor holding the moved data, or
// a copy of `self` when the backend reports the data is already in place.
// `blocking` is forwarded to the backend's MoveTo implementation.
TensorPtr TensorPy::MoveTo(const Tensor &self, const std::string &to, bool blocking) {
  py::gil_scoped_release gil_release;
  MS_LOG(INFO) << "Try move tensor to " << to;
  auto context = GetDeviceCtx(to);
  MS_EXCEPTION_IF_NULL(context);
  auto target_tensor = std::make_shared<tensor::Tensor>(self.data_type(), self.shape());
  target_tensor->set_device_address(nullptr);
  bool return_self = false;
  // make sure op execute end before data copy
  runtime::OpExecutor::GetInstance().WaitAll();
  context->device_res_manager_->MoveTo(std::make_shared<tensor::Tensor>(self), target_tensor, to, blocking,
                                       &return_self);
  if (return_self) {
    return std::make_shared<tensor::Tensor>(self);
  }
  return target_tensor;
}
661
// Loads one slice of a persistent-storage tensor into its buffer and returns
// it as a numpy array sharing that memory. Requires the tensor's data to be a
// PersistentTensorDataNumpy.
py::array TensorPy::AsNumpyOfSlice(const Tensor &tensor, const int32_t param_key, const int slice_index) {
  py::gil_scoped_acquire acquire;
  py::object owner = py::cast(tensor.data_ptr());
  auto data_numpy = std::dynamic_pointer_cast<PersistentTensorDataNumpy>(tensor.data_ptr());
  MS_EXCEPTION_IF_NULL(data_numpy);

  // Pull the requested slice from the embedding store into the shared buffer.
  data_numpy->FillSliceData(param_key, slice_index);

  // Return internal numpy array if tensor data is implemented base on it.
  // And persistent tensor data is only implemented base on numpy array.
  return data_numpy->py_array(owner);
}
674
GetShapeFromTuple(const py::tuple & tuple)675 static ShapeVector GetShapeFromTuple(const py::tuple &tuple) {
676 ShapeVector shape;
677 const size_t size = tuple.size();
678 shape.reserve(tuple.size());
679 for (size_t i = 0; i < size; ++i) {
680 shape.push_back(py::int_(tuple[i]));
681 }
682 return shape;
683 }
RegMetaTensor(const py::module * m)684 void RegMetaTensor(const py::module *m) {
685 // Define python MetaTensor class.
686 (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
687 .def(py::init<TypePtr, const ShapeVector>(), py::arg("dtype"), py::arg("shape"))
688 .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
689 .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
690 .def_property("param_info", &MetaTensor::param_info, &MetaTensor::set_param_info)
691 .def(py::pickle(
692 [](const MetaTensor &t) { // __getstate__
693 /* Return a tuple that fully encodes the state of the object */
694 return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
695 },
696 [](const py::tuple &t) { // __setstate__
697 constexpr size_t expect_size = 2;
698 if (t.size() != expect_size) {
699 throw std::runtime_error("Invalid state!");
700 }
701 /* Create a new C++ instance */
702 MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<ShapeVector>());
703 return tensor;
704 }));
705 // Define TensorData as a python class so that ownership of tensor data can be managed.
706 (void)py::class_<TensorData, TensorDataPtr>(*m, "_TensorData");
707 // Define python Tensor class.
708 // dtype should define before Tensor, because Tensor init depend dtype
709 (void)py::class_<BaseTensor, MetaTensor, std::shared_ptr<BaseTensor>>(*m, "BaseTensor");
710 (void)py::class_<Tensor, BaseTensor, std::shared_ptr<Tensor>>(*m, "Tensor")
711 .def(py::init([](const Tensor &tensor) { return std::make_shared<Tensor>(tensor); }), py::arg("input"))
712 .def(py::init([](const Tensor &tensor, const TypePtr &type_ptr) {
713 TypeId data_type = type_ptr ? type_ptr->type_id() : kTypeUnknown;
714 if (data_type == kTypeUnknown || tensor.data_type() == data_type) {
715 return std::make_shared<Tensor>(tensor);
716 }
717 return std::make_shared<Tensor>(tensor, data_type);
718 }),
719 py::arg("input"), py::arg("dtype"))
720 .def(py::init([](const BaseTensor &tensor) { return std::make_shared<Tensor>(tensor); }), py::arg("input"))
721 .def(py::init([](const BaseTensor &tensor, const TypePtr &type_ptr) {
722 TypeId data_type = type_ptr ? type_ptr->type_id() : kTypeUnknown;
723 if (data_type == kTypeUnknown || tensor.data_type() == data_type) {
724 return std::make_shared<Tensor>(tensor);
725 }
726 return std::make_shared<Tensor>(tensor, data_type);
727 }),
728 py::arg("input"), py::arg("dtype"))
729 .def(py::init([](const TypePtr &type_ptr, const py::tuple &shape) {
730 auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64;
731 return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
732 }),
733 py::arg("dtype"), py::arg("shape"))
734 .def(py::init([](const TypePtr &type_ptr, const py::list &shape) {
735 auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64;
736 return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
737 }),
738 py::arg("dtype"), py::arg("shape"))
739 .def(
740 py::init([](const py::array &input, const TypePtr &type_ptr) { return TensorPy::MakeTensor(input, type_ptr); }),
741 py::arg("input"), py::arg("dtype") = nullptr)
742 .def(py::init([](const py::float_ input, const TypePtr &type_ptr) {
743 return TensorPy::MakeTensor(py::array(input), type_ptr);
744 }),
745 py::arg("input"), py::arg("dtype") = nullptr)
746 .def(py::init([](const py::int_ input, const TypePtr &type_ptr) {
747 return TensorPy::MakeTensor(py::array(input), type_ptr);
748 }),
749 py::arg("input"), py::arg("dtype") = nullptr)
750 .def(py::init([](const py::list &input, const TypePtr &type_ptr) {
751 return TensorPy::MakeTensor(py::array(input), type_ptr);
752 }),
753 py::arg("input"), py::arg("dtype") = nullptr)
754 .def(py::init([](const py::tuple &input, const TypePtr &type_ptr) {
755 return TensorPy::MakeTensor(py::array(input), type_ptr);
756 }),
757 py::arg("input"), py::arg("dtype") = nullptr)
758         // We only support array/bool_/int_/float_/list/tuple/complex pybind objects as tensor input;
759         // the array/bool_/int_/float_/list/tuple inits are matched above, so any other pybind object
760         // passed as input raises an error here, except for the complex data type.
761 .def(py::init([](const py::object &input, const TypePtr &type_ptr) {
762 if (!PyComplex_CheckExact(input.ptr())) {
763 MS_LOG(EXCEPTION) << "Unsupported tensor type: " << input.get_type();
764 }
765 return TensorPy::MakeTensor(py::array(input), type_ptr);
766 }),
767 py::arg("input"), py::arg("dtype") = nullptr)
768 .def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag)
769 .def_property("adapter_flag", &Tensor::is_adapter, &Tensor::set_adapter_flag)
770 .def_property_readonly("_dtype", &Tensor::Dtype, R"mydelimiter(
771 Get the tensor's data type.
772
773 Returns:
774 type, the data type of tensor.
775
776 Examples:
777 >>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
778 >>> data.dtype
779 Int32
780 )mydelimiter")
781 .def_property("_shape", TensorPy::GetPyTupleShape, &Tensor::set_shape)
782 .def_property_readonly("_size", &Tensor::DataSize, R"mydelimiter(
783 Get tensor's data size.
784
785 Returns:
786 size_t, the size of tensor.
787
788 Examples:
789 >>> data = mindspore.Tensor(np.ones((2, 3)))
790 >>> data.size
791 6
792 )mydelimiter")
793 .def_property_readonly("_itemsize", TensorPy::GetPyItemSize, R"mydelimiter(
794 Get the tensor's length of one element in bytes.
795
796 Returns:
797 itemsize, length of one element in bytes.
798
799 Examples:
800 >>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
801 >>> data.itemsize
802 4
803 )mydelimiter")
804 .def_property_readonly("_nbytes", TensorPy::GetPyNBytes, R"mydelimiter(
805 Get the tensor's total number of bytes.
806
807 Returns:
808 nbytes, total number of bytes taken by the tensor.
809
810 Examples:
811 >>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
812 >>> data.nbytes
813 4
814 )mydelimiter")
815 .def_property_readonly("_strides", TensorPy::GetPyTupleStrides, R"mydelimiter(
816 Get the tensor's tuple of bytes to step in each dimension
817 when traversing an array.
818
819 Returns:
820 tuple[int], the strides of the tensor.
821
822 Examples:
823 >>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
824 >>> data.strides
825 (4, 4)
826 )mydelimiter")
827 .def("_flatten_tensors", Tensor::FlattenTensors, py::arg("fusion_size") = 0)
828 .def("setitem_index_info", TensorIndex::SetItemIndexInfo)
829 .def("getitem_index_info", TensorIndex::GetItemIndexInfo)
830 .def("_is_flattened", Tensor::IsFlattened)
831 .def("_get_flattened_tensors", Tensor::GetFlattenedTensors)
832 .def("_get_fusion_size", Tensor::GetFusionSize)
833 .def("_is_test_stub", Tensor::CheckStub)
834 .def("from_numpy", TensorPy::MakeTensorOfNumpy, R"mydelimiter(
835 Creates a Tensor from a numpy.ndarray without copy.
836
837 Arg:
838 array (numpy.ndarray): The input ndarray.
839
840 Returns:
841 Tensor, tensor with shared data to input ndarray.
842
843 Examples:
844 >>> a = np.ones((2, 3))
845 >>> t = mindspore.Tensor.from_numpy(a)
846 )mydelimiter")
847 .def("persistent_data_from_numpy", TensorPy::MakePersistentDataTensorOfNumpy, R"mydelimiter(
848 Creates a Tensor from a numpy.ndarray without copy.
849 Use persistent data tensor.
850
851 Arg:
852 array (numpy.ndarray): The input ndarray.
853 slice_num (int): The slice num of persistent data tensor.
854
855 Returns:
856 Tensor, tensor with shared data to input ndarray.
857
858 Examples:
859 >>> a = np.ones((2, 3))
860 >>> t = mindspore.Tensor.persistent_data_from_numpy(a, 1)
861 )mydelimiter")
862 .def("get_bytes", &TensorPy::GetBytes, R"mydelimiter(
863 Get raw data of tensor with type of bytes.
864
865 Returns:
866 Bytes of tensor.
867
868 Examples:
869 >>> import mindspore as ms
870 >>> from mindspore import Tensor
871 >>> x = ms.Tensor([1, 2, 3], ms.int16)
872 >>> print(x.get_bytes())
873 b'\x01\x00\x02\x00\x03\x00'
874 )mydelimiter")
875 .def("convert_bytes_to_tensor", &TensorPy::ConvertBytesToTensor, R"mydelimiter(
876 Convert raw data to tensor.
877
878 Returns:
879 Tensor.
880
881 Examples:
882 >>> import mindspore as ms
883 >>> from mindspore import Tensor
884 >>> x = Tensor([1, 2, 3], ms.int16)
885 >>> out = Tensor.convert_bytes_to_tensor(x.get_bytes(), x.shape, x.dtype)
886 >>> print(x.asnumpy())
887 [1 2 3]
888 )mydelimiter")
889 .def("asnumpy", TensorPy::SyncAsNumpy, R"mydelimiter(
890 Convert tensor to numpy.ndarray.
891
892 Returns:
893 numpy.ndarray.
894
895 Examples:
896 >>> data = mindspore.Tensor(np.ones((2, 3)))
897 >>> array = data.asnumpy()
898 >>> array
899 array([[1., 1., 1.],
900 [1., 1., 1.]])
901 )mydelimiter")
902 .def("_flush_from_cache", TensorPy::FlushFromCache, R"mydelimiter(
903 Flush Cache data to Host if tensor is cache enable.
904
905 Returns:
906 None.
907
908 Examples:
909 >>> data = mindspore.Tensor(np.ones((2, 3)))
910 >>> data._flush_from_cache()
911 )mydelimiter")
912 .def("is_persistent_data", &Tensor::is_persistent_data, R"mydelimiter(
913 Check if tensor have persistent data.
914
915 Returns:
916 Bool.
917
918 Examples:
919 >>> data = mindspore.Tensor(np.ones((2, 3)))
920 >>> data.is_persistent_data()
921 )mydelimiter")
922 .def("asnumpy_of_slice_persistent_data", TensorPy::AsNumpyOfSlice, R"mydelimiter(
923 Convert tensor to numpy.ndarray of a slice.
924
925 Returns:
926 numpy.ndarray.
927
928 Examples:
929 >>> data = mindspore.Tensor(np.ones((2000000000, 256)))
930 >>> data.asnumpy_of_slice_persistent_data(0, 1)
931 )mydelimiter")
932 .def("is_init", &Tensor::is_init, R"mydelimiter(
933 Get tensor init_flag.
934
935 Returns:
936 bool, whether the tensor init.
937
938 Examples:
939 >>> data = mindspore.Tensor(np.ones((2, 3)))
940 >>> data.is_init()
941 False
942 )mydelimiter")
943 .def("set_init_flag", &Tensor::set_init_flag, R"mydelimiter(
944 Set tensor init_flag.
945
946 Examples:
947 >>> data = mindspore.Tensor(np.ones((2, 3)))
948 >>> data.set_init_flag(True)
949 )mydelimiter")
950 .def("dim", &Tensor::DataDim, R"mydelimiter(
951 Get tensor's data dimension.
952
953 Returns:
954 int, the dimension of tensor.
955
956 Examples:
957 >>> data = mindspore.Tensor(np.ones((2, 3)))
958 >>> data.dim()
959 2
960 )mydelimiter")
961 .def("assign_value_cpp", &Tensor::AssignValue, R"mydelimiter(
962 Assign another tensor value to this.
963
964 Arg:
965 value (:class:`mindspore.tensor`): The value tensor.
966
967 Examples:
968 >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
969 >>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32))
970 >>> data.assign_value(data2)
971 >>> data.shape
972 (2, 2)
973 )mydelimiter")
974 .def("set_dtype", &Tensor::SetDtype, R"mydelimiter(
975 Set the tensor's data type.
976
977 Arg:
978 dtype (:class:`mindspore.dtype`): The type of output tensor.
979
980 Examples:
981 >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
982 >>> data.set_dtype(mindspore.int32)
983 mindspore.int32
984 )mydelimiter")
985 .def("offload", &Tensor::Offload, R"mydelimiter(
986 Offload tensor data to file.
987
988 Arg:
989 str : file path to save tensor data.
990 Returns:
991 bool, whether the tensor offload success.
992 Examples:
993 >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
994 >>> data.offload('./test.data')
995 True
996 )mydelimiter")
997 .def("offload_file_path", &Tensor::GetOffloadFilePath, R"mydelimiter(
998 Offload file path for tensor.
999
1000 Returns:
1001 str, offload file path for tensor.
1002 Examples:
1003 >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
1004 >>> ret = data.offload('./test.data')
1005 >>> ret = (data.offload_file_path() != '')
1006 True
1007 )mydelimiter")
1008 .def("move_to", &TensorPy::MoveTo, R"mydelimiter(
1009 Copy tensor between host and device asynchronously if blocking=False,
1010 otherwise synchronously. if the arg `to`=`CPU`, means D2H copy;
1011 if the arg `to`=`GPU` or `to`=`ASCEND`, means H2D copy.
1012
1013 Args:
1014 str: A string, "CPU" or "ASCEND" or "GPU".
1015 bool: A bool type value, Default: ``True`` .
1016
1017 Returns:
1018 Tensor, with the same type and shape as the "self".
1019
1020 Examples:
1021 >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
1022 >>> ret = data.move_to("CPU")
1023 )mydelimiter")
1024 .def("set_cast_dtype", &Tensor::set_cast_dtype, py::arg("dtype") = nullptr)
1025 .def("data_sync", &Tensor::data_sync)
1026 .def("wait_pipeline", &Tensor::ExecuteLazyTask)
1027 .def("is_contiguous", &Tensor::is_contiguous)
1028 .def("stride", &Tensor::stride)
1029 .def("storage_offset", &Tensor::storage_offset)
1030 .def("register_hook", &RegisterHook::RegisterTensorBackwardHook)
1031 .def("remove_hook", &RegisterHook::RemoveTensorBackwardHook)
1032 .def("__str__", &Tensor::ToString)
1033 .def("__repr__", &Tensor::ToStringRepr)
1034 .def("_offload", &TensorPy::Offload)
1035 .def("set_device_address", &TensorPy::SetDeviceAddress, py::arg("addr"), py::arg("shape"), py::arg("dtype"))
1036 .def(
1037 py::pickle(
1038 [](const Tensor &t) { // __getstate__
1039 /* Return a tuple that fully encodes the state of the object */
1040 return py::make_tuple(TensorPy::SyncAsNumpy(t));
1041 },
1042 [](const py::tuple &t) { // __setstate__
1043 if (t.size() != 1) {
1044 throw std::runtime_error("Invalid state!");
1045 }
1046 /* Create a new C++ instance */
1047 return TensorPy::MakeTensor(t[0].cast<py::array>());
1048 }));
1049 }
1050
1051 template <typename T>
GetSparseTensorShape(const T & sparse_tensor)1052 py::tuple GetSparseTensorShape(const T &sparse_tensor) {
1053 auto &shape = sparse_tensor.shape();
1054 py::tuple dims(shape.size());
1055 for (size_t i = 0; i < dims.size(); ++i) {
1056 dims[i] = py::int_(shape[i]);
1057 }
1058 return dims;
1059 }
1060
GetPyTupleShape(const CSRTensor & csr_tensor)1061 py::tuple CSRTensorPy::GetPyTupleShape(const CSRTensor &csr_tensor) { return GetSparseTensorShape(csr_tensor); }
1062
RegCSRTensor(const py::module * m)1063 void RegCSRTensor(const py::module *m) {
1064 // Define python CSRTensor class.
1065 (void)py::class_<CSRTensor, std::shared_ptr<CSRTensor>>(*m, "CSRTensor")
1066 .def(py::init(
1067 [](const BaseTensor &indptr, const BaseTensor &indices, const BaseTensor &values, const py::tuple &shape) {
1068 return std::make_shared<CSRTensor>(std::make_shared<Tensor>(indptr), std::make_shared<Tensor>(indices),
1069 std::make_shared<Tensor>(values), GetShapeFromTuple(shape));
1070 }),
1071 py::arg("indptr"), py::arg("indices"), py::arg("values"), py::arg("shape"))
1072 .def(py::init([](const CSRTensor &csr_tensor) { return std::make_shared<CSRTensor>(csr_tensor); }),
1073 py::arg("input"))
1074 .def_property_readonly("_shape", CSRTensorPy::GetPyTupleShape)
1075 .def_property_readonly("_dtype", &CSRTensor::Dtype)
1076 .def_property_readonly("_indptr", &CSRTensor::GetIndptr)
1077 .def_property_readonly("_indices", &CSRTensor::GetIndices)
1078 .def_property_readonly("_values", &CSRTensor::GetValues)
1079 .def("__str__", &CSRTensor::ToString)
1080 .def("__repr__", &CSRTensor::ToString);
1081 }
1082
GetPyTupleShape(const COOTensor & coo_tensor)1083 py::tuple COOTensorPy::GetPyTupleShape(const COOTensor &coo_tensor) { return GetSparseTensorShape(coo_tensor); }
1084
RegCOOTensor(const py::module * m)1085 void RegCOOTensor(const py::module *m) {
1086 // Define python COOTensor class.
1087 (void)py::class_<COOTensor, std::shared_ptr<COOTensor>>(*m, "COOTensor")
1088 .def(py::init([](const BaseTensor &indices, const BaseTensor &values, const py::tuple &shape) {
1089 return std::make_shared<COOTensor>(std::make_shared<Tensor>(indices), std::make_shared<Tensor>(values),
1090 GetShapeFromTuple(shape));
1091 }),
1092 py::arg("indices"), py::arg("values"), py::arg("shape"))
1093 .def(py::init([](const COOTensor &coo_tensor) { return std::make_shared<COOTensor>(coo_tensor); }),
1094 py::arg("input"))
1095 .def_property_readonly("_shape", COOTensorPy::GetPyTupleShape)
1096 .def_property_readonly("_dtype", &COOTensor::Dtype)
1097 .def_property_readonly("_indices", &COOTensor::GetIndices)
1098 .def_property_readonly("_values", &COOTensor::GetValues)
1099 .def("__str__", &COOTensor::ToString)
1100 .def("__repr__", &COOTensor::ToString);
1101 }
1102
GetPyTupleShape(const RowTensor & row_tensor)1103 py::tuple RowTensorPy::GetPyTupleShape(const RowTensor &row_tensor) { return GetSparseTensorShape(row_tensor); }
1104
RegRowTensor(const py::module * m)1105 void RegRowTensor(const py::module *m) {
1106 // Define python RowTensor class.
1107 (void)py::class_<RowTensor, std::shared_ptr<RowTensor>>(*m, "RowTensor")
1108 .def(py::init([](const BaseTensor &indices, const BaseTensor &values, const py::tuple &shape) {
1109 return std::make_shared<RowTensor>(std::make_shared<Tensor>(indices), std::make_shared<Tensor>(values),
1110 GetShapeFromTuple(shape));
1111 }),
1112 py::arg("indices"), py::arg("values"), py::arg("shape"))
1113 .def(py::init([](const RowTensor &row_tensor) { return std::make_shared<RowTensor>(row_tensor); }),
1114 py::arg("input"))
1115 .def_property_readonly("_shape", RowTensorPy::GetPyTupleShape)
1116 .def_property_readonly("_dtype", &RowTensor::Dtype)
1117 .def_property_readonly("_indices", &RowTensor::GetIndices)
1118 .def_property_readonly("_values", &RowTensor::GetValues)
1119 .def("__str__", &RowTensor::ToString)
1120 .def("__repr__", &RowTensor::ToString);
1121 }
1122 } // namespace tensor
1123 } // namespace mindspore
1124