1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CORE_IR_BASE_TENSOR_H_
18 #define MINDSPORE_CORE_IR_BASE_TENSOR_H_
19
20 #include <future>
21 #include <memory>
22 #include <string>
23 #include <vector>
24 #include <numeric>
25 #include <mutex>
26 #include <algorithm>
27 #include <condition_variable>
28 #include <utility>
29 #include <iomanip>
30 #include "ir/device_sync.h"
31 #include "ir/meta_tensor.h"
32 #include "utils/log_adapter.h"
33 #include "base/float16.h"
34 #include "base/bfloat16.h"
35 #include "utils/shape_utils.h"
36 #include "utils/ms_exception.h"
37 #include "ir/device_event.h"
38 #include "utils/os.h"
39 #include "ir/meta_grad_data.h"
40 #include "ir/tensor_data.h"
41 #include "utils/ms_utils_secure.h"
42 #include "mindspore/core/base/complex_storage.h"
43 #include "utils/temp_file_manager.h"
44 #include "utils/system/env.h"
45
46 // brief mindspore namespace.
47 //
48 // mindspore namespace is the top level namespace of MindSpore project.
49 // Other namespace should be a sub namespace of mindspore namespace in the ME project.
50 namespace mindspore {
51 // brief mindspore::tensor namespace
// Synchronization status attached to a tensor.
// NOTE(review): judging by the enumerator names, each value says whether a host<->device
// copy is pending and in which direction; the exact semantics live in the code that reads
// this enum - confirm there.
// The numeric values are spelled out explicitly; they equal the implicit declaration-order values.
enum TensorSyncStatus {
  kNoNeedSync = 0,
  kNeedSyncHostToDevice = 1,
  kNeedSyncHostToDeviceImmediately = 2,
  kNeedSyncDeviceToHost = 3,
  kNeedSyncDeviceToHostImmediately = 4
};
59
// Identifies the compression scheme applied to a tensor's data buffer; it is the type of the
// `compression_type` argument of the BaseTensor constructor for tensors whose shape and buffer
// size may be inconsistent.
// Every enumerator carries an explicit numeric value.
// NOTE(review): presumably the values are pinned so they stay stable if stored or exchanged
// outside this process - confirm before renumbering.
enum TensorCompressionType {
  kNoCompression = 0,
  kIndexing = 1,
  kSparse = 2,
  kFSE = 3,
  kBitPacking = 4,
  kFSEInt = 5,
  kFSEInfer = 6
};
69
70 // A sub namespace in ME to support tensor related definition.
71 namespace tensor {
// Text written in place of the elements/rows that are omitted from a summarized tensor string.
constexpr const char *kEllipsis = "...";
// A dimension longer than this is summarized (for tensors with more than one dimension);
// half of this value is kept from each end.
constexpr int kThreshold = 6;
// Same summarization limit, applied to the single dimension of a 1-D tensor.
constexpr int kThreshold1D = 1000;

// Number of elements printed per line for a 1-D tensor, by element category.
constexpr int kThreshold1DFloat = kThreshold * 2;
constexpr int kThreshold1DInt = kThreshold * 4;
constexpr int kThreshold1DBool = kThreshold * 2;
79 template <typename T, typename U>
NewData(const U * input,size_t size)80 std::unique_ptr<T[]> NewData(const U *input, size_t size) {
81 if (input == nullptr || size == 0) {
82 return nullptr;
83 }
84 if (size > INT32_MAX) {
85 MS_LOG(WARNING) << "Try to alloca a large memory, size is:" << size * sizeof(T);
86 }
87
88 auto data = std::make_unique<T[]>(size);
89 if constexpr (!std::is_same<T, U>::value &&
90 (std::is_same<T, float16>::value || std::is_same<U, float16>::value ||
91 std::is_same<T, bfloat16>::value || std::is_same<U, bfloat16>::value ||
92 std::is_same<T, ComplexStorage<float>>::value || std::is_same<U, ComplexStorage<float>>::value ||
93 std::is_same<T, ComplexStorage<double>>::value || std::is_same<U, ComplexStorage<double>>::value)) {
94 // Because float16 and bfloat16 do not support implicit cast from/to other types,
95 // We can not use std::copy() on array of float16 and bfloat16, use a loop here.
96 for (size_t i = 0; i < size; ++i) {
97 data[i] = static_cast<T>(input[i]);
98 }
99 } else {
100 // otherwise, use std::copy for better performance.
101 std::copy(input, input + size, data.get());
102 }
103 return data;
104 }
105
// Allocates a one-element T[] buffer holding `scalar` converted to T.
//
// \param[in] scalar The value to store; converted with static_cast<T>.
// \return A buffer of exactly one element.
template <typename T, typename Scalar>
std::unique_ptr<T[]> NewData(Scalar scalar) {
  std::unique_ptr<T[]> holder = std::make_unique<T[]>(1);
  *holder.get() = static_cast<T>(scalar);
  return holder;
}
112
113 template <typename T>
CopyData(const ShapeVector & shape,void * const data,TypeId data_type)114 std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, TypeId data_type) {
115 const size_t size = SizeOf(shape);
116 switch (data_type) {
117 case kNumberTypeBool: {
118 auto buf = static_cast<bool *>(data);
119 return NewData<T>(buf, size);
120 }
121 case kNumberTypeUInt8: {
122 auto buf = static_cast<uint8_t *>(data);
123 return NewData<T>(buf, size);
124 }
125 case kNumberTypeInt4: {
126 auto buf = static_cast<int8_t *>(data);
127 return NewData<T>(buf, size);
128 }
129 case kNumberTypeInt8: {
130 auto buf = static_cast<int8_t *>(data);
131 return NewData<T>(buf, size);
132 }
133 case kNumberTypeInt16: {
134 auto buf = static_cast<int16_t *>(data);
135 return NewData<T>(buf, size);
136 }
137 case kNumberTypeInt32: {
138 auto buf = static_cast<int32_t *>(data);
139 return NewData<T>(buf, size);
140 }
141 case kNumberTypeInt64: {
142 auto buf = static_cast<int64_t *>(data);
143 return NewData<T>(buf, size);
144 }
145 case kNumberTypeUInt16: {
146 auto buf = static_cast<uint16_t *>(data);
147 return NewData<T>(buf, size);
148 }
149 case kNumberTypeUInt32: {
150 auto buf = static_cast<uint32_t *>(data);
151 return NewData<T>(buf, size);
152 }
153 case kNumberTypeUInt64: {
154 auto buf = static_cast<uint64_t *>(data);
155 return NewData<T>(buf, size);
156 }
157 case kNumberTypeFloat16: {
158 auto buf = static_cast<float16 *>(data);
159 return NewData<T>(buf, size);
160 }
161 case kNumberTypeFloat32: {
162 auto buf = static_cast<float *>(data);
163 return NewData<T>(buf, size);
164 }
165 case kNumberTypeFloat64: {
166 auto buf = static_cast<double *>(data);
167 return NewData<T>(buf, size);
168 }
169 #ifndef KERNEL_EXECUTOR_ANDROID
170 case kNumberTypeBFloat16: {
171 auto buf = static_cast<bfloat16 *>(data);
172 return NewData<T>(buf, size);
173 }
174 #endif
175 case kNumberTypeComplex64: {
176 auto buf = static_cast<ComplexStorage<float> *>(data);
177 return NewData<T>(buf, size);
178 }
179 case kNumberTypeComplex128: {
180 auto buf = static_cast<ComplexStorage<double> *>(data);
181 return NewData<T>(buf, size);
182 }
183 case kObjectTypeString: {
184 auto buf = static_cast<uint8_t *>(data);
185 return NewData<T>(buf, size);
186 }
187 default:
188 break;
189 }
190 MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
191 }
192
193 template <typename T>
CopyData(const ShapeVector & shape,void * const data,size_t data_len)194 std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, size_t data_len) {
195 size_t size = SizeOf(shape);
196 if (size * sizeof(T) != data_len) {
197 MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T)
198 << " item size " << sizeof(T);
199 }
200 auto buf = static_cast<T *>(data);
201 return NewData<T>(buf, size);
202 }
203
204 // TensorStringifier provide methods to convert tensor data to its string representation.
205 template <typename T>
206 class TensorStringifier {
207 public:
TensorStringifier(const T * data,size_t data_size,size_t ndim)208 TensorStringifier(const T *data, size_t data_size, size_t ndim) : data_(data), data_size_(data_size), ndim_(ndim) {}
209 ~TensorStringifier() = default;
210
ToString(TypeId,const ShapeVector & shape,bool use_comma)211 std::string ToString(TypeId, const ShapeVector &shape, bool use_comma) const {
212 constexpr auto valid =
213 std::is_same<T, bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
214 std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value ||
215 std::is_same<T, uint16_t>::value || std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value ||
216 std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value ||
217 std::is_same<T, ComplexStorage<float>>::value || std::is_same<T, ComplexStorage<double>>::value ||
218 std::is_same<T, bfloat16>::value;
219 static_assert(valid, "Type is invalid");
220 if (data_size_ == 0) {
221 return "";
222 }
223 if (data_ == nullptr) {
224 return "<uninitialized>";
225 }
226
227 std::ostringstream ss;
228 if (data_size_ == 1 && ndim_ == 0) { // Scalar
229 int max = 0;
230 OutputDataString(ss, 0, 0, 1, false, &max);
231 return ss.str();
232 }
233
234 int num_width = 0;
235 ssize_t cursor = 0;
236 SummaryStringRecursive(ss, shape, &cursor, 0, use_comma, &num_width);
237 return ProcessPlaceholder(ss, num_width);
238 }
239
240 private:
OutputFloatDataString(std::ostringstream & ss,bool isScalar,const T & value)241 static void OutputFloatDataString(std::ostringstream &ss, bool isScalar, const T &value) {
242 if (isScalar) {
243 ss << value;
244 } else {
245 // The placeholder of float16 is fixed at 11, while float/double is fixed at 15.
246 const int width = std::is_same<T, float16>::value ? 11 : 15;
247 // The printing precision of float16 is fixed at 4, while float/double is fixed at 8.
248 const int precision = std::is_same<T, float16>::value ? 4 : 8;
249 ss << std::setw(width) << std::setprecision(precision) << std::setiosflags(std::ios::scientific | std::ios::right)
250 << value;
251 }
252 }
253
OutputBoolDataString(std::ostringstream & ss,bool isScalar,const T & value)254 static void OutputBoolDataString(std::ostringstream &ss, bool isScalar, const T &value) {
255 if (isScalar) {
256 ss << (value ? "True" : "False");
257 } else {
258 constexpr int bool_max_width = sizeof("False") - 1;
259 ss << std::setw(bool_max_width) << std::setiosflags(std::ios::right) << (value ? "True" : "False");
260 }
261 }
262
OutputOtherDataString(std::ostringstream & ss,bool isScalar,const T & value,int * max_width)263 static void OutputOtherDataString(std::ostringstream &ss, bool isScalar, const T &value, int *max_width) {
264 std::ostringstream value_ss;
265 if constexpr (std::is_same<T, uint8_t>::value) {
266 value_ss << static_cast<uint16_t>(value);
267 } else if constexpr (std::is_same<T, int8_t>::value) {
268 value_ss << static_cast<int16_t>(value);
269 } else {
270 value_ss << value;
271 }
272 auto value_str = value_ss.str();
273 if (!isScalar) {
274 const int width = static_cast<int>(value_str.size());
275 *max_width = std::max(*max_width, width);
276 // Add a padding string before the number, such as "###123", for subsequent replacement.
277 std::string pad(width, '#');
278 ss << pad;
279 }
280 ss << value_str;
281 }
282
ProcessPlaceholder(const std::ostringstream & ss,int max_width)283 static std::string ProcessPlaceholder(const std::ostringstream &ss, int max_width) {
284 std::string str = ss.str();
285 if constexpr (std::is_same<T, bool>::value || std::is_same<T, float16>::value || std::is_same<T, float>::value ||
286 std::is_same<T, double>::value) {
287 return str;
288 }
289 // Replace # with placeholder.
290 size_t index = str.find('#');
291 while (index != std::string::npos) {
292 size_t pos = index;
293 while (str[pos] == '#') {
294 pos++;
295 }
296 size_t len = pos - index;
297 std::string space(max_width - SizeToInt(len), ' ');
298 str = str.replace(index, len, space);
299 index = str.find('#', index);
300 }
301 return str;
302 }
303
OutputDataString(std::ostringstream & ss,ssize_t cursor,ssize_t start,ssize_t end,bool use_comma,int * max_width)304 void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end, bool use_comma,
305 int *max_width) const {
306 const bool isScalar = ndim_ == 0 && end - start == 1;
307 constexpr auto isBool = std::is_same<T, bool>::value;
308 constexpr auto isFloat =
309 std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
310 constexpr auto isComplex =
311 std::is_same<T, ComplexStorage<float>>::value || std::is_same<T, ComplexStorage<double>>::value;
312 constexpr int linefeedThreshold = isFloat ? kThreshold1DFloat : (isBool ? kThreshold1DBool : kThreshold1DInt);
313 for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
314 const auto value = data_[cursor + i];
315 if constexpr (isComplex) {
316 ss << value;
317 } else if constexpr (isFloat) {
318 OutputFloatDataString(ss, isScalar, value);
319 } else if (isBool) {
320 OutputBoolDataString(ss, isScalar, value);
321 } else {
322 OutputOtherDataString(ss, isScalar, value, max_width);
323 }
324 if (!isScalar && i != end - 1) {
325 if (use_comma) {
326 ss << ',';
327 }
328 ss << ' ';
329 }
330 if (!isScalar && ndim_ == 1 && end - start > (kThreshold >> 1) && (i + 1) % linefeedThreshold == 0) {
331 // Add a line feed every {threshold of type} for 1D tensor.
332 ss << '\n' << ' ';
333 }
334 }
335 }
336
SummaryStringRecursive(std::ostringstream & ss,const ShapeVector & shape,ssize_t * cursor,ssize_t depth,bool use_comma,int * max_width)337 void SummaryStringRecursive(std::ostringstream &ss, const ShapeVector &shape, ssize_t *cursor, ssize_t depth,
338 bool use_comma, int *max_width) const {
339 if (depth >= static_cast<ssize_t>(ndim_)) {
340 return;
341 }
342 ss << '[';
343 if (depth == static_cast<ssize_t>(ndim_) - 1) { // Bottom dimension
344 ssize_t num = shape[depth];
345 if ((num > kThreshold && ndim_ > 1) || (num > kThreshold1D && ndim_ == 1)) {
346 OutputDataString(ss, *cursor, 0, kThreshold >> 1, use_comma, max_width);
347 ss << ' ' << kEllipsis << ' ';
348 OutputDataString(ss, *cursor, num - (kThreshold >> 1), num, use_comma, max_width);
349 } else {
350 OutputDataString(ss, *cursor, 0, num, use_comma, max_width);
351 }
352 *cursor += num;
353 } else { // Middle dimension
354 ssize_t num = shape[depth];
355 // Handle the first half.
356 for (ssize_t i = 0; i < std::min(static_cast<ssize_t>(kThreshold >> 1), num); i++) {
357 if (i > 0) {
358 if (use_comma) {
359 ss << ',';
360 }
361 ss << '\n';
362 ss << std::setw(depth + 1) << ' '; // Add the indent.
363 }
364 SummaryStringRecursive(ss, shape, cursor, depth + 1, use_comma, max_width);
365 }
366 // Handle the ignored part.
367 if (num > kThreshold) {
368 if (use_comma) {
369 ss << ',';
370 }
371 ss << '\n';
372 ss << std::setw(depth + 1) << ' '; // Add the indent.
373 ss << kEllipsis;
374 // Ignored at this layer.
375 ssize_t ignored = shape[depth + 1];
376 const size_t offset = 2;
377 for (ssize_t i = depth + offset; i < static_cast<ssize_t>(ndim_); i++) {
378 ignored *= shape[i];
379 }
380 // Multiple with ignored layers number.
381 ignored *= (num - kThreshold);
382 *cursor += ignored;
383 }
384 // Handle the second half.
385 if (num > (kThreshold >> 1)) {
386 ssize_t iter_times =
387 std::min(static_cast<ssize_t>(num - (kThreshold >> 1)), static_cast<ssize_t>(kThreshold >> 1));
388 for (ssize_t i = 0; i < iter_times; i++) {
389 if (use_comma && (i != 0 || num <= kThreshold)) { // Not just after ignored part || Not handle ignored part
390 ss << ',';
391 }
392 ss << '\n';
393 ss << std::setw(depth + 1) << ' '; // Add the indent.
394 SummaryStringRecursive(ss, shape, cursor, depth + 1, use_comma, max_width);
395 }
396 }
397 }
398 ss << ']';
399 }
400
401 const T *data_;
402 const size_t data_size_;
403 const size_t ndim_;
404 };
405 // Tensor data implementation.
406 template <typename T>
407 class TensorDataImpl : public TensorData {
408 public:
TensorDataImpl(const ShapeVector & shape)409 explicit TensorDataImpl(const ShapeVector &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {}
~TensorDataImpl()410 ~TensorDataImpl() override {
411 try {
412 RemoveOffloadFile();
413 } catch (const std::exception &e) {
414 MS_LOG(ERROR) << "Exception occurred when cleaning tensor. Error info " << e.what();
415 } catch (...) {
416 MS_LOG(ERROR) << "Exception occurred when cleaning tensor.";
417 }
418 }
419
TensorDataImpl(const ShapeVector & shape,void * data,size_t data_len)420 TensorDataImpl(const ShapeVector &shape, void *data, size_t data_len)
421 : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_len)) {}
422
TensorDataImpl(const ShapeVector & shape,void * data,TypeId data_type)423 TensorDataImpl(const ShapeVector &shape, void *data, TypeId data_type)
424 : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_type)) {}
425
426 template <typename U>
TensorDataImpl(const ShapeVector & shape,const U * input,size_t size)427 TensorDataImpl(const ShapeVector &shape, const U *input, size_t size)
428 : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(NewData<T>(input, size)) {}
429
430 template <typename Scalar>
TensorDataImpl(const ShapeVector & shape,Scalar scalar)431 TensorDataImpl(const ShapeVector &shape, Scalar scalar)
432 : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(NewData<T>(scalar)) {}
433
size()434 ssize_t size() const override { return static_cast<ssize_t>(data_size_); }
435
itemsize()436 ssize_t itemsize() const override { return static_cast<ssize_t>(sizeof(T)); }
437
nbytes()438 ssize_t nbytes() const override { return size() * itemsize(); }
439
ndim()440 ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }
441
is_sub_data()442 bool is_sub_data() const override { return false; }
443
has_sub_data()444 bool has_sub_data() const override { return false; }
445
data()446 void *data() override {
447 if (data_ != nullptr) {
448 return data_.get();
449 }
450
451 if (data_size_ > INT32_MAX) {
452 MS_LOG(WARNING) << "Try to alloca a large memory, size is:" << data_size_ * sizeof(T);
453 }
454 // Lazy allocation.
455 data_ = std::make_unique<T[]>(data_size_);
456
457 // Load data from file
458 if (!file_path_.empty()) {
459 auto fs = mindspore::system::Env::GetFileSystem();
460 MS_EXCEPTION_IF_NULL(fs);
461 if (fs->FileExist(file_path_)) {
462 auto file = fs->CreateWriteFile(file_path_, "r+");
463 MS_EXCEPTION_IF_NULL(file);
464 bool success = file->PRead(data_.get(), data_size_ * sizeof(T), 0);
465 if (!success) {
466 MS_LOG(WARNING) << "Tensor load data from file: " << file_path_ << " failed!";
467 }
468 if (!file->Close()) {
469 MS_LOG(WARNING) << "Close tensor file: " << file_path_ << " failed!";
470 }
471 } else {
472 MS_LOG(WARNING) << "Invalid tensor file path: " << file_path_;
473 }
474 }
475 return data_.get();
476 }
477
set_file_path(const std::string & file_path)478 void set_file_path(const std::string &file_path) override { file_path_ = file_path; }
479
file_path()480 const std::string file_path() const override { return file_path_; }
481
const_data()482 const void *const_data() const override {
483 // May return nullptr if data not initialized.
484 return data_.get();
485 }
486
equals(const TensorDataImpl<T> & other)487 virtual bool equals(const TensorDataImpl<T> &other) const {
488 auto ptr = &other;
489 if (ptr == this) {
490 return true;
491 }
492 if (data_ == nullptr || ptr->data_ == nullptr) {
493 return false;
494 }
495 return (ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) &&
496 std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get());
497 }
498
equals(const TensorData & other)499 bool equals(const TensorData &other) const override {
500 // Not same type, compare data byte by byte.
501 return TensorData::equals(other);
502 }
503
ToString(TypeId type,const ShapeVector & shape,bool use_comma)504 std::string ToString(TypeId type, const ShapeVector &shape, bool use_comma) const override {
505 TensorStringifier<T> stringifier{data_.get(), data_size_, ndim_};
506 return stringifier.ToString(type, shape, use_comma);
507 }
508
509 private:
RemoveOffloadFile()510 void RemoveOffloadFile() {
511 if (!file_path_.empty()) {
512 TempFileManager::GetInstance().RemoveFile(file_path_);
513 TempFileManager::GetInstance().UnRegister(file_path_);
514 file_path_ = "";
515 }
516 }
517
518 size_t ndim_{0};
519 size_t data_size_{0};
520 std::unique_ptr<T[]> data_;
521 std::string file_path_{""};
522 };
523 template <template <class> class ImplClass = TensorDataImpl, typename... Args>
MakeTensorData(TypeId data_type,Args &&...args)524 TensorDataPtr MakeTensorData(TypeId data_type, Args &&... args) {
525 switch (data_type) {
526 case kNumberTypeBool:
527 return std::make_shared<ImplClass<bool>>(std::forward<Args>(args)...);
528 case kNumberTypeUInt8:
529 return std::make_shared<ImplClass<uint8_t>>(std::forward<Args>(args)...);
530 case kNumberTypeInt4:
531 return std::make_shared<ImplClass<int8_t>>(std::forward<Args>(args)...);
532 case kNumberTypeInt8:
533 return std::make_shared<ImplClass<int8_t>>(std::forward<Args>(args)...);
534 case kNumberTypeInt16:
535 return std::make_shared<ImplClass<int16_t>>(std::forward<Args>(args)...);
536 case kNumberTypeInt:
537 case kNumberTypeInt32:
538 return std::make_shared<ImplClass<int32_t>>(std::forward<Args>(args)...);
539 case kNumberTypeInt64:
540 return std::make_shared<ImplClass<int64_t>>(std::forward<Args>(args)...);
541 case kNumberTypeUInt16:
542 return std::make_shared<ImplClass<uint16_t>>(std::forward<Args>(args)...);
543 case kNumberTypeUInt32:
544 return std::make_shared<ImplClass<uint32_t>>(std::forward<Args>(args)...);
545 case kNumberTypeUInt64:
546 return std::make_shared<ImplClass<uint64_t>>(std::forward<Args>(args)...);
547 case kNumberTypeFloat16:
548 return std::make_shared<ImplClass<float16>>(std::forward<Args>(args)...);
549 case kNumberTypeFloat:
550 return std::make_shared<ImplClass<float>>(std::forward<Args>(args)...);
551 case kNumberTypeFloat32:
552 return std::make_shared<ImplClass<float>>(std::forward<Args>(args)...);
553 case kNumberTypeFloat64:
554 return std::make_shared<ImplClass<double>>(std::forward<Args>(args)...);
555 #ifndef KERNEL_EXECUTOR_ANDROID
556 case kNumberTypeBFloat16:
557 return std::make_shared<ImplClass<bfloat16>>(std::forward<Args>(args)...);
558 #endif
559 case kNumberTypeComplex64:
560 return std::make_shared<ImplClass<ComplexStorage<float>>>(std::forward<Args>(args)...);
561 case kNumberTypeComplex128:
562 return std::make_shared<ImplClass<ComplexStorage<double>>>(std::forward<Args>(args)...);
563 case kObjectTypeString:
564 return std::make_shared<ImplClass<uint8_t>>(std::forward<Args>(args)...);
565 case kObjectTypeTensorType:
566 case kObjectTypeMapTensorType:
567 return std::make_shared<ImplClass<int>>(std::forward<Args>(args)...);
568 default:
569 break;
570 }
571 MS_LOG(ERROR) << "Cannot construct Tensor because of unsupported data type: " << TypeIdToString(data_type) << ".";
572 return nullptr;
573 }
// Forward declaration so the pointer aliases can be named before the class is defined.
class BaseTensor;
// Shared-ownership handle to a BaseTensor.
using BaseTensorPtr = std::shared_ptr<BaseTensor>;
// Sequence of shared BaseTensor handles.
using BaseTensorPtrList = std::vector<BaseTensorPtr>;
577
578 // BaseTensor entity class
579 class MS_CORE_API BaseTensor : public MetaTensor {
580 public:
581 BaseTensor() = default;
582
583 /// \brief Create base tensor from another base tensor, data is shared.
584 ///
585 /// \param[in] tensor [BaseTensor] The input base tensor.
586 explicit BaseTensor(const BaseTensor &tensor);
587
588 /// \brief Create base tensor with given data type from another tensor.
589 ///
590 /// \param[in] tensor [BaseTensor] The input tensor.
591 /// \param[in] data_type [TypeId] The new tensor data type.
592 BaseTensor(const BaseTensor &tensor, TypeId data_type);
593
594 /// \brief Create base tensor with the given shared tensor data.
595 ///
596 /// \param[in] data_type [TypeId] Data type of the tensor.
597 /// \param[in] shape The shape represented by ShapeVector of the tensor.
598 /// \param[in] data The shared tensor data.
599 BaseTensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data);
600
601 /// \brief Create a lazy allocated tensor.
602 ///
603 /// \param[in] data_type [TypeId] Data type of the tensor.
604 /// \param[in] shape The shape represented by ShapeVector of the tensor.
605 BaseTensor(TypeId data_type, const ShapeVector &shape);
606
607 /// \brief Create a tensor with input data buffer.
608 ///
609 /// \param[in] data_type [TypeId] Data type of the tensor.
610 /// \param[in] shape The shape represented by ShapeVector of the tensor.
611 /// \param[in] data The input data to be copied into tensor.
612 /// \param[in] data_len The length of data in bytes.
613 BaseTensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len);
614
615 /// \brief Create a tensor with input data buffer and given source data type.
616 ///
617 /// \param[in] data_type [TypeId] Data type of the tensor.
618 /// \param[in] shape The shape represented by ShapeVector of the tensor.
619 /// \param[in] data The input data to be copied into tensor.
620 /// \param[in] src_data_type The source data type.
621 BaseTensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type);
622
623 /// \brief Create 1 dimension tensor from an int vector.
624 ///
625 /// \param[in] input [std::vector<int64_t>] the data for tensor.
626 /// \param[in] data_type [TypeId] data type.
627 explicit BaseTensor(const std::vector<int64_t> &input, const TypePtr &data_type = nullptr);
628
629 /// \brief Create 1 dimension tensor from an int vector.
630 ///
631 /// \param[in] input [std::vector<int32_t>] the data for tensor.
632 /// \param[in] data_type [TypeId] data type.
633 explicit BaseTensor(const std::vector<int32_t> &input, const TypePtr &data_type = nullptr);
634
635 /// \brief Create 1 dimension tensor from a float vector.
636 ///
637 /// \param[in] input [std::vector<double>] the data for tensor.
638 /// \param[in] data_type [TypeId] data type.
639 explicit BaseTensor(const std::vector<double> &input, const TypePtr &data_type = nullptr);
640
641 /// \brief Create 1 dimension tensor from a float vector.
642 ///
643 /// \param[in] input [std::vector<float>] the data for tensor.
644 /// \param[in] data_type [TypeId] data type.
645 explicit BaseTensor(const std::vector<float> &input, const TypePtr &data_type = nullptr);
646
647 /// \brief Create 0 dimension tensor from an int64_t scalar.
648 ///
649 /// \param[in] input [int64] the data for tensor.
650 /// \param[in] data_type [TypeId] data type.
651 explicit BaseTensor(int64_t input, const TypePtr &data_type = nullptr);
652
653 /// \brief Create 0 dimension tensor from an int32_t scalar.
654 ///
655 /// \param[in] input [int32] the data for tensor.
656 /// \param[in] data_type [TypeId] data type.
657 explicit BaseTensor(int32_t input, const TypePtr &data_type = nullptr);
658
659 /// \brief Create 0 dimension tensor from an int16_t scalar.
660 ///
661 /// \param[in] input [int16] the data for tensor.
662 /// \param[in] data_type [TypeId] data type.
663 explicit BaseTensor(int16_t input, const TypePtr &data_type = nullptr);
664
665 /// \brief Create 0 dimension tensor from an int8_t scalar.
666 ///
667 /// \param[in] input [int8] the data for tensor.
668 /// \param[in] data_type [TypeId] data type.
669 explicit BaseTensor(int8_t input, const TypePtr &data_type = nullptr);
670
671 /// \brief Create 0 dimension tensor from a double scalar.
672 ///
673 /// \param[in] input [double] the data for tensor.
674 /// \param[in] data_type [TypeId] data type.
675 explicit BaseTensor(double input, const TypePtr &data_type = nullptr);
676
677 /// \brief Create 0 dimension tensor from a float scalar.
678 ///
679 /// \param[in] input [float] the data for tensor.
680 /// \param[in] data_type [TypeId] data type.
681 explicit BaseTensor(float input, const TypePtr &data_type = nullptr);
682
683 /// \brief Create 0 dimension tensor from a float16 scalar.
684 ///
685 /// \param[in] input [float16] the data for tensor.
686 /// \param[in] data_type [TypeId] data type.
687 explicit BaseTensor(float16 input, const TypePtr &data_type = nullptr);
688
689 /// \brief Create 0 dimension tensor from a bfloat16 scalar.
690 ///
691 /// \param[in] input [bfloat16] the data for tensor.
692 /// \param[in] data_type [TypeId] data type.
693 explicit BaseTensor(bfloat16 input, const TypePtr &data_type = nullptr);
694
695 /// \brief Create 0 dimension tensor from a uint64 scalar.
696 ///
697 /// \param[in] input [uint64] the data for tensor.
698 /// \param[in] data_type [TypeId] data type.
699 explicit BaseTensor(uint64_t input, const TypePtr &data_type = nullptr);
700
701 /// \brief Create 0 dimension tensor from a uint32 scalar.
702 ///
703 /// \param[in] input [uint32] the data for tensor.
704 /// \param[in] data_type [TypeId] data type.
705 explicit BaseTensor(uint32_t input, const TypePtr &data_type = nullptr);
706
707 /// \brief Create 0 dimension tensor from a uint16 scalar.
708 ///
709 /// \param[in] input [uint16] the data for tensor.
710 /// \param[in] data_type [TypeId] data type.
711 explicit BaseTensor(uint16_t input, const TypePtr &data_type = nullptr);
712
713 /// \brief Create 0 dimension tensor from a uint8 scalar.
714 ///
715 /// \param[in] input [uint8] the data for tensor.
716 /// \param[in] data_type [TypeId] data type.
717 explicit BaseTensor(uint8_t input, const TypePtr &data_type = nullptr);
718
719 /// \brief Create 0 dimension tensor from a bool scalar.
720 ///
721 /// \param[in] input [bool] the data for tensor.
722 /// \param[in] data_type [TypeId] data type.
723 explicit BaseTensor(bool input, const TypePtr &data_type = nullptr);
724
725 /// \brief Create a chunk tensor with the given data size.
726 ///
727 /// \param[in] data_type [TypeId] Data type of the tensor.
728 /// \param[in] data_size The tensor chunk data size in number of elements.
729 BaseTensor(TypeId data_type, size_t data_size);
730
731 /// \brief Create a Tensor which shape and size may be inconsistent, such as Tensor with compression data.
732 ///
733 /// \param[in] origin_data_type [TypeId] Data type of the origin tensor.
734 /// \param[in] shape The shape represented by ShapeVector of the tensor.
735 /// \param[in] compression_data_size The compression data buffer size.
736 /// \param[in] TensorCompressionType The tensor compression type.
737 BaseTensor(TypeId origin_data_type, const ShapeVector &shape, size_t compression_data_size,
738 TensorCompressionType compression_type);
739
740 BaseTensor &operator=(const BaseTensor &tensor);
741
742 /// Destructor of BaseTensor.
743 ~BaseTensor() override = default;
744
745 MS_DECLARE_PARENT(BaseTensor, MetaTensor);
746
747 /// \brief Assign value to this tensor.
748 ///
749 /// \param[in] tensor The input tensor.
750 /// \return Tensor with new value.
751 BaseTensor &AssignValue(const BaseTensor &tensor);
752
753 /// \brief Compare two tensor objects to see if they have same data type, shape and data address.
754 ///
755 /// \param[in] tensor The BaseTensor object to be compared.
756 /// \return True if having same type, shape and data address, otherwise false.
757 bool operator==(const BaseTensor &tensor) const;
758
759 /// \brief Create Abstract for Tensor.
760 ///
761 /// \return Abstract of Tensor.
762 abstract::AbstractBasePtr ToAbstract() override;
763
764 /// \brief Get Abstract cache. The value of the abstract is null.
765 /// Only used by InferShape in PyNative mode.
766 ///
767 /// \return Abstract of tensor.
768 abstract::AbstractBasePtr GetAbstractCache();
769
770 /// \brief It is different from 'operator==' which just compares shape/type/address,
771 /// it does real value comparison.
772 ///
773 /// \param[in] tensor The BaseTensor object to be compared.
774 /// \return True if it has the same value, otherwise false.
775 bool ValueEqual(const BaseTensor &tensor) const;
776
777 bool operator==(const Value &other) const override {
778 if (other.isa<BaseTensor>()) {
779 auto &other_ = static_cast<const BaseTensor &>(other);
780 return *this == other_;
781 }
782 return false;
783 }
784
785 /// \brief Gets tensor's dimension.
786 ///
787 /// \return The number of dimensions of the tensor data.
DataDim()788 int DataDim() const { return static_cast<int>(data().ndim()); }
789
790 /// \brief Getting tensor data size.
791 ///
792 /// \return The total number of elements of the tensor data.
DataSize()793 size_t DataSize() const { return data().size(); }
794
795 /// \brief Get the data type of the tensor for C++
796 ///
797 /// \return [int] The tensor's data type will be cast to int to return.
data_type_c()798 int data_type_c() const { return static_cast<int>(data_type_); }
799
800 /// \brief Get the tensor's shape for C++
801 ///
802 /// \return [ShapeVector]
shape_c(void)803 ShapeVector shape_c(void) const { return shape(); }
804
805 /// \brief Get BaseTensor data pointer for c++ type
806 ///
807 /// \return The pointer to the object
data_c()808 void *data_c() { return data().data(); }
809
810 /// \brief Get BaseTensor data byte-size for c++ type
811 ///
812 /// \return byte size of BaseTensor data
Size()813 size_t Size() const { return static_cast<size_t>(data().nbytes()); }
814
/// \brief Get BaseTensor data pointer for c++ type (const overload).
///
/// \return The pointer to the object
data_c()816 void *data_c() const { return data_->data(); }
817
818 /// \brief To synchronize data with the device, you need to wait for the data to be valid.
819 ///
820 void data_sync(bool need_wait = true) const;
821
822 /// \brief Get the internal data object.
823 ///
824 /// \return The reference to internal data object.
data()825 TensorData &data() {
826 MS_EXCEPTION_IF_NULL(data_);
827 return *data_;
828 }
829
830 /// \brief Get the internal data shared pointer.
831 ///
/// \return The const reference to the shared pointer holding the internal data object.
data_ptr()833 const TensorDataPtr &data_ptr() const { return data_; }
834
835 /// \brief Get the internal data object.
836 ///
837 /// \return The reference to internal data object.
data()838 const TensorData &data() const { return *data_; }
839
set_data(const TensorDataPtr & data)840 void set_data(const TensorDataPtr &data) { data_ = data; }
841
842 TypeId set_data_type(TypeId data_type) override;
843
844 size_t set_shape(const ShapeVector &shape) override;
845
846 /// \brief Get information about shape and data type.
847 ///
848 /// \return Information about shape and data type.
849 std::string GetShapeAndDataTypeInfo() const;
850
851 /// \brief Get display information of limit size.
852 ///
853 /// \param[in] limit_size The limit size.
854 /// \return The display information of limit size.
855 std::string ToStringInternal(size_t limit_size) const;
856
857 /// \brief Get display information with unlimited size.
858 ///
859 /// \return The display information with unlimited size.
860 std::string ToStringNoLimit() const;
861
862 /// \brief Get display information of this BaseTensor.
863 ///
864 /// \return The display information of this BaseTensor.
865 std::string ToString() const override;
866
867 /// \brief Get display information in repr form.
868 ///
869 /// \return The display information in repr form.
870 std::string ToStringRepr() const;
871
872 /// \brief Check if this BaseTensor is forward output.
873 ///
874 /// \return Whether this BaseTensor is forward output.
is_forward_output()875 bool is_forward_output() const { return is_forward_output_; }
876
877 /// \brief Set the forward output flag of this BaseTensor.
878 ///
879 /// \param[in] is_forward_output Whether this BaseTensor is forward output.
set_is_forward_output(bool is_forward_output)880 void set_is_forward_output(bool is_forward_output) { is_forward_output_ = is_forward_output; }
881
882 /// \brief Get the device address.
883 ///
884 /// \return The device address.
885 DeviceSyncPtr device_address() const;
886
887 /// \brief Set the device address.
888 ///
889 /// \param[in] device_sync The input Device synchronization.
/// \param[in] need_update_ref_count If need_update_ref_count is true, the device address cannot be released and
/// reused, so false should be passed when setting the device address of a feature map tensor.
892 void set_device_address(const DeviceSyncPtr &device_sync, bool need_update_ref_count = true);
893
894 /// \brief Get the id of this BaseTensor.
895 ///
896 /// \return The id of this BaseTensor.
id()897 std::string id() const { return id_; }
898
/// \brief Register the lazy callback function; it is stored statically, so it is shared by all tensors.
900 ///
901 /// \param[in] lazy_callback Wait for async tasks finish before data_sync.
RegisterLazyCallback(const std::function<void (void)> & lazy_callback)902 static void RegisterLazyCallback(const std::function<void(void)> &lazy_callback) { lazy_callback_ = lazy_callback; }
903
904 /// \brief Set contiguous callback function to this BaseTensor
905 ///
906 /// \param[in] contiguous_callback The callback from backend when need to make tensor contiguous.
set_contiguous_callback(const std::function<DeviceSyncPtr (const DeviceSyncPtr &)> & contiguous_callback)907 void set_contiguous_callback(const std::function<DeviceSyncPtr(const DeviceSyncPtr &)> &contiguous_callback) {
908 contiguous_callback_ = contiguous_callback;
909 }
910
911 /// @brief Get Pynative auto_grad meta data.
912 /// @return Auto grad meta data
auto_grad_meta_data()913 const AutoGradMetaDataPtr &auto_grad_meta_data() const { return auto_grad_meta_data_; }
914
915 /// @brief Set Pynative auto_grad meta data.
916 /// @param auto_grad_meta_data
set_auto_grad_meta_data(const AutoGradMetaDataPtr & auto_grad_meta_data)917 void set_auto_grad_meta_data(const AutoGradMetaDataPtr &auto_grad_meta_data) {
918 auto_grad_meta_data_ = auto_grad_meta_data;
919 }
920
921 /// \brief Get tensor storage info.
922 ///
923 /// \return BaseTensor storage info, the value is nullptr default.
924 const TensorStorageInfoPtr storage_info() const;
925
926 /// \brief Set tensor abstract.
927 ///
928 /// \param[in] abstract The abstract of tensor.
set_abstract(const std::weak_ptr<abstract::AbstractBase> & abstract)929 void set_abstract(const std::weak_ptr<abstract::AbstractBase> &abstract) { abstract_ = abstract; }
930
931 /// \brief Set synchronization status.
932 ///
933 /// \param[in] sync_status The input synchronization status.
set_sync_status(TensorSyncStatus sync_status)934 void set_sync_status(TensorSyncStatus sync_status) const { sync_status_ = sync_status; }
935
936 /// \brief Get synchronization status.
937 ///
938 /// \return The synchronization status.
sync_status()939 TensorSyncStatus sync_status() const { return sync_status_; }
940
941 /// \brief Check the value of sync_status_.
942 ///
/// \return True if sync_status_ is kNeedSyncDeviceToHostImmediately.
NeedSyncDeviceToHostImmediately()944 bool NeedSyncDeviceToHostImmediately() const { return sync_status_ == kNeedSyncDeviceToHostImmediately; }
945
946 /// \brief Check the value of sync_status_.
947 ///
/// \return True if sync_status_ is kNeedSyncDeviceToHost.
NeedSyncDeviceToHost()949 bool NeedSyncDeviceToHost() const { return sync_status_ == kNeedSyncDeviceToHost; }
950
951 /// \brief Check the value of sync_status_.
952 ///
/// \return True if sync_status_ is kNeedSyncHostToDevice.
NeedSyncHostToDevice()954 bool NeedSyncHostToDevice() const { return sync_status_ == kNeedSyncHostToDevice; }
955
956 /// \brief Check the value of sync_status_.
957 ///
/// \return True if sync_status_ is kNeedSyncHostToDeviceImmediately.
NeedSyncHostToDeviceImmediately()959 bool NeedSyncHostToDeviceImmediately() const { return sync_status_ == kNeedSyncHostToDeviceImmediately; }
960
961 /// \brief Get tensor's BaseShape.
962 ///
963 /// \return The BaseShape of this tensor.
base_shape_ptr()964 const BaseShapePtr &base_shape_ptr() const { return base_shape_ptr_; }
965
966 /// \brief Set tensor's BaseShape.
967 ///
/// \param[in] base_shape The tensor's BaseShape.
set_base_shape(const BaseShapePtr & base_shape)969 void set_base_shape(const BaseShapePtr &base_shape) { base_shape_ptr_ = base_shape; }
970
971 /// \brief Determines whether the memory of tensor is contiguous.
972 ///
973 /// \return True if tensor memory is contiguous, false otherwise.
974 bool is_contiguous() const;
975
976 /// \brief Get tensor storage stride.
977 ///
978 /// \return storage stride.
979 std::vector<int64_t> stride() const;
980
981 /// \brief Get tensor storage offset.
982 ///
983 /// \return storage offset.
984 const int64_t storage_offset() const;
985
set_need_pipeline_sync(bool need_pipeline_sync)986 void set_need_pipeline_sync(bool need_pipeline_sync) { need_pipeline_sync_ = need_pipeline_sync; }
987
988 /// \brief Execute lazy task.
989 ///
990 void ExecuteLazyTask() const;
991
992 protected:
993 bool is_forward_output_{false};
994 bool need_pipeline_sync_{false};
995 std::string id_{""};
996 mutable DeviceSyncPtr device_sync_{nullptr};
997 mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
998 AutoGradMetaDataPtr auto_grad_meta_data_{nullptr};
999 std::weak_ptr<abstract::AbstractBase> abstract_;
1000 TensorDataPtr data_{nullptr};
// Tensor base shape which contains dynamic shape info.
1002 BaseShapePtr base_shape_ptr_{nullptr};
1003 inline static std::function<void(void)> lazy_callback_{nullptr};
1004 std::function<DeviceSyncPtr(const DeviceSyncPtr &)> contiguous_callback_{nullptr};
1005 };
1006
1007 // Convert shape vector to string.
1008 MS_CORE_API std::string ShapeToString(const ShapeVector &shape);
1009
CopyTensorData(const TensorDataPtr & dest,const TensorDataPtr & src)1010 inline static void CopyTensorData(const TensorDataPtr &dest, const TensorDataPtr &src) {
1011 auto dest_bytes = dest->nbytes();
1012 auto src_bytes = src->nbytes();
1013 auto err = common::huge_memcpy(static_cast<uint8_t *>(dest->data()), dest_bytes,
1014 static_cast<const uint8_t *>(src->const_data()), src_bytes);
1015 if (err != EOK) {
1016 MS_LOG(INTERNAL_EXCEPTION) << "Copy tensor data failed! bytes: " << dest_bytes << "/" << src_bytes << ".";
1017 }
1018 }
1019 } // namespace tensor
1020 } // namespace mindspore
1021
1022 #endif // MINDSPORE_CORE_IR_BASE_TENSOR_H_
1023